ScalarReplAggregatesHLSL.cpp
//===- ScalarReplAggregatesHLSL.cpp - Scalar Replacement of Aggregates ----===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
//
// Based on ScalarReplAggregates.cpp. The difference is that the HLSL version
// keeps arrays intact so that it can break up every structure.
//
//===----------------------------------------------------------------------===//
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include "llvm/Transforms/Utils/Local.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HLSL/DxilConstants.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/DxilModule.h"
#include "dxc/HlslIntrinsicOp.h"
#include "dxc/HLSL/DxilTypeSystem.h"
#include "dxc/HLSL/HLMatrixLowerHelper.h"
#include "dxc/HLSL/DxilOperations.h"
#include <deque>
#include <unordered_map>
#include <unordered_set>

using namespace llvm;
using namespace hlsl;

#define DEBUG_TYPE "scalarreplhlsl"

STATISTIC(NumReplaced, "Number of allocas broken up");
STATISTIC(NumPromoted, "Number of allocas promoted");
STATISTIC(NumAdjusted, "Number of scalar allocas adjusted to allow promotion");
STATISTIC(NumConverted, "Number of aggregates converted to scalar");

namespace {

class SROA_Helper {
public:
  // Split V into AllocaInsts with Builder and save the new AllocaInsts into
  // Elts. Then do SROA on V.
  static bool DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
                                  IRBuilder<> &Builder, bool bFlatVector,
                                  bool hasPrecise, DxilTypeSystem &typeSys,
                                  SmallVector<Value *, 32> &DeadInsts);
  static bool DoScalarReplacement(GlobalVariable *GV,
                                  std::vector<Value *> &Elts,
                                  IRBuilder<> &Builder, bool bFlatVector,
                                  bool hasPrecise, DxilTypeSystem &typeSys,
                                  SmallVector<Value *, 32> &DeadInsts);
  // Lower memcpy related to V.
  static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                          DxilTypeSystem &typeSys, const DataLayout &DL,
                          bool bAllowReplace);
  static void MarkEmptyStructUsers(Value *V,
                                   SmallVector<Value *, 32> &DeadInsts);
  static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);

private:
  SROA_Helper(Value *V, ArrayRef<Value *> Elts,
              SmallVector<Value *, 32> &DeadInsts)
      : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts) {}
  void RewriteForScalarRepl(Value *V, IRBuilder<> &Builder);

private:
  // Must be a pointer type val.
  Value *OldVal;
  // Flattened elements for OldVal.
  ArrayRef<Value *> NewElts;
  SmallVector<Value *, 32> &DeadInsts;

  void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
  void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
  void RewriteForLoad(LoadInst *loadInst);
  void RewriteForStore(StoreInst *storeInst);
  void RewriteMemIntrin(MemIntrinsic *MI, Instruction *Inst);
  void RewriteCall(CallInst *CI);
  void RewriteBitCast(BitCastInst *BCI);
};
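
// Illustrative sketch of SROA_Helper only (the IR below is invented for this
// comment, it is not taken from a test): given an aggregate alloca such as
//
//   %s  = alloca { float, <2 x float> }
//   %p0 = getelementptr inbounds { float, <2 x float> },
//                       { float, <2 x float> }* %s, i32 0, i32 0
//   store float %x, float* %p0
//
// DoScalarReplacement conceptually creates one alloca per element
// (%s.0 : float, %s.1 : <2 x float>) and RewriteForScalarRepl retargets every
// GEP/load/store/memcpy user at the matching element, so the store above ends
// up as "store float %x, float* %s.0".
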
struct SROA_HLSL : public FunctionPass {
  SROA_HLSL(bool Promote, int T, bool hasDT, char &ID, int ST, int AT, int SLT)
      : FunctionPass(ID), HasDomTree(hasDT), RunPromotion(Promote) {
    if (AT == -1)
      ArrayElementThreshold = 8;
    else
      ArrayElementThreshold = AT;
    if (SLT == -1)
      // Do not limit the scalar integer load size if no threshold is given.
      ScalarLoadThreshold = -1;
    else
      ScalarLoadThreshold = SLT;
  }

  bool runOnFunction(Function &F) override;

  bool performScalarRepl(Function &F, DxilTypeSystem &typeSys);
  bool performPromotion(Function &F);
  bool markPrecise(Function &F);

private:
  bool HasDomTree;
  bool RunPromotion;

  /// DeadInsts - Keep track of instructions we have made dead, so that
  /// we can remove them after we are done working.
  SmallVector<Value *, 32> DeadInsts;

  /// AllocaInfo - When analyzing uses of an alloca instruction, this captures
  /// information about the uses. All these fields are initialized to false
  /// and set to true when something is learned.
  struct AllocaInfo {
    /// The alloca to promote.
    AllocaInst *AI;

    /// CheckedPHIs - This is a set of verified PHI nodes, to prevent infinite
    /// looping and avoid redundant work.
    SmallPtrSet<PHINode *, 8> CheckedPHIs;

    /// isUnsafe - This is set to true if the alloca cannot be SROA'd.
    bool isUnsafe : 1;

    /// isMemCpySrc - This is true if this aggregate is memcpy'd from.
    bool isMemCpySrc : 1;

    /// isMemCpyDst - This is true if this aggregate is memcpy'd into.
    bool isMemCpyDst : 1;

    /// hasSubelementAccess - This is true if a subelement of the alloca is
    /// ever accessed, or false if the alloca is only accessed with mem
    /// intrinsics or load/store that only access the entire alloca at once.
    bool hasSubelementAccess : 1;

    /// hasALoadOrStore - This is true if there are any loads or stores to it.
    /// The alloca may just be accessed with memcpy, for example, which would
    /// not set this.
    bool hasALoadOrStore : 1;

    /// hasArrayIndexing - This is true if there is any dynamic array
    /// indexing into it.
    bool hasArrayIndexing : 1;

    /// hasVectorIndexing - This is true if there is any dynamic vector
    /// indexing into it.
    bool hasVectorIndexing : 1;

    explicit AllocaInfo(AllocaInst *ai)
        : AI(ai), isUnsafe(false), isMemCpySrc(false), isMemCpyDst(false),
          hasSubelementAccess(false), hasALoadOrStore(false),
          hasArrayIndexing(false), hasVectorIndexing(false) {}
  };

  /// ArrayElementThreshold - The maximum number of elements an array can
  /// have to be considered for SROA.
  unsigned ArrayElementThreshold;

  /// ScalarLoadThreshold - The maximum size in bits of scalars to load when
  /// converting to scalar.
  unsigned ScalarLoadThreshold;

  void MarkUnsafe(AllocaInfo &I, Instruction *User) {
    I.isUnsafe = true;
    DEBUG(dbgs() << " Transformation preventing inst: " << *User << '\n');
  }

  bool isSafeAllocaToScalarRepl(AllocaInst *AI);

  void isSafeForScalarRepl(Instruction *I, uint64_t Offset, AllocaInfo &Info);
  void isSafePHISelectUseForScalarRepl(Instruction *User, uint64_t Offset,
                                       AllocaInfo &Info);
  void isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset, AllocaInfo &Info);
  void isSafeMemAccess(uint64_t Offset, uint64_t MemSize, Type *MemOpType,
                       bool isStore, AllocaInfo &Info, Instruction *TheAccess,
                       bool AllowWholeAccess);
  bool TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
                        const DataLayout &DL);
  void DeleteDeadInstructions();

  bool ShouldAttemptScalarRepl(AllocaInst *AI);
};

// SROA_DT - SROA that uses DominatorTree.
struct SROA_DT_HLSL : public SROA_HLSL {
  static char ID;

public:
  SROA_DT_HLSL(bool Promote = false, int T = -1, int ST = -1, int AT = -1,
               int SLT = -1)
      : SROA_HLSL(Promote, T, true, ID, ST, AT, SLT) {
    initializeSROA_DTPass(*PassRegistry::getPassRegistry());
  }

  // getAnalysisUsage - This pass does not require any passes, but we know it
  // will not alter the CFG, so say so.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.setPreservesCFG();
  }
};

// SROA_SSAUp - SROA that uses SSAUpdater.
struct SROA_SSAUp_HLSL : public SROA_HLSL {
  static char ID;

public:
  SROA_SSAUp_HLSL(bool Promote = false, int T = -1, int ST = -1, int AT = -1,
                  int SLT = -1)
      : SROA_HLSL(Promote, T, false, ID, ST, AT, SLT) {
    initializeSROA_SSAUpPass(*PassRegistry::getPassRegistry());
  }

  // getAnalysisUsage - This pass does not require any passes, but we know it
  // will not alter the CFG, so say so.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.setPreservesCFG();
  }
};

// Simple struct to split memcpy into ld/st
struct MemcpySplitter {
  llvm::LLVMContext &m_context;
  DxilTypeSystem &m_typeSys;

public:
  MemcpySplitter(llvm::LLVMContext &context, DxilTypeSystem &typeSys)
      : m_context(context), m_typeSys(typeSys) {}
  void Split(llvm::Function &F);

  static void PatchMemCpyWithZeroIdxGEP(Module &M);
  static void PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI, const DataLayout &DL);
  static void SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                          DxilFieldAnnotation *fieldAnnotation,
                          DxilTypeSystem &typeSys);
};
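
// Illustrative sketch only (types and names invented for this comment):
// SplitMemCpy turns a whole-aggregate copy such as
//
//   call void @llvm.memcpy.p0i8.p0i8.i64(i8* %dst, i8* %src, i64 12, ...)
//
// where %dst/%src are bitcasts of pointers to a struct like
// { float, float, float }, into per-field copies, roughly
//
//   %f0 = load float, float* %src.0   ; then store float %f0, float* %dst.0
//
// so that SROA and mem2reg see ordinary loads and stores instead of an opaque
// memory intrinsic.
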
}

char SROA_DT_HLSL::ID = 0;
char SROA_SSAUp_HLSL::ID = 0;

INITIALIZE_PASS_BEGIN(SROA_DT_HLSL, "scalarreplhlsl",
                      "Scalar Replacement of Aggregates HLSL (DT)", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(SROA_DT_HLSL, "scalarreplhlsl",
                    "Scalar Replacement of Aggregates HLSL (DT)", false, false)

INITIALIZE_PASS_BEGIN(SROA_SSAUp_HLSL, "scalarreplhlsl-ssa",
                      "Scalar Replacement of Aggregates HLSL (SSAUp)", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(SROA_SSAUp_HLSL, "scalarreplhlsl-ssa",
                    "Scalar Replacement of Aggregates HLSL (SSAUp)", false,
                    false)

// Public interface to the ScalarReplAggregates pass
FunctionPass *llvm::createScalarReplAggregatesHLSLPass(bool UseDomTree,
                                                       bool Promote) {
  if (UseDomTree)
    return new SROA_DT_HLSL(Promote);
  return new SROA_SSAUp_HLSL(Promote);
}
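
// Usage sketch (hedged: this is not necessarily how the dxc pipeline wires the
// pass up, just a minimal legacy-PassManager example of the factory above):
//
//   legacy::PassManager PM;
//   // UseDomTree = true selects SROA_DT_HLSL; Promote = true also runs the
//   // promotion step after replacement.
//   PM.add(llvm::createScalarReplAggregatesHLSLPass(/*UseDomTree*/ true,
//                                                   /*Promote*/ true));
//   PM.run(M); // M is an llvm::Module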

//===----------------------------------------------------------------------===//
// Convert To Scalar Optimization.
//===----------------------------------------------------------------------===//

namespace {
/// ConvertToScalarInfo - This class implements the "Convert To Scalar"
/// optimization, which scans the uses of an alloca and determines if it can
/// rewrite it in terms of a single new alloca that can be mem2reg'd.
class ConvertToScalarInfo {
  /// AllocaSize - The size of the alloca being considered in bytes.
  unsigned AllocaSize;
  const DataLayout &DL;
  unsigned ScalarLoadThreshold;

  /// IsNotTrivial - This is set to true if there is some access to the object
  /// which means that mem2reg can't promote it.
  bool IsNotTrivial;

  /// ScalarKind - Tracks the kind of alloca being considered for promotion,
  /// computed based on the uses of the alloca rather than the LLVM type
  /// system.
  enum {
    Unknown,

    // Accesses via GEPs that are consistent with element access of a vector
    // type. This will not be converted into a vector unless there is a later
    // access using an actual vector type.
    ImplicitVector,

    // Accesses via vector operations and GEPs that are consistent with the
    // layout of a vector type.
    Vector,

    // An integer bag-of-bits with bitwise operations for insertion and
    // extraction. Any combination of types can be converted into this kind
    // of scalar.
    Integer
  } ScalarKind;

  /// VectorTy - This tracks the type that we should promote the vector to if
  /// it is possible to turn it into a vector. This starts out null, and if it
  /// isn't possible to turn into a vector type, it gets set to VoidTy.
  VectorType *VectorTy;

  /// HadNonMemTransferAccess - True if there is at least one access to the
  /// alloca that is not a MemTransferInst. We don't want to turn structs into
  /// large integers unless there is some potential for optimization.
  bool HadNonMemTransferAccess;

  /// HadDynamicAccess - True if some element of this alloca was accessed
  /// dynamically. We don't yet have support for turning a dynamic access into
  /// a large integer.
  bool HadDynamicAccess;

public:
  explicit ConvertToScalarInfo(unsigned Size, const DataLayout &DL,
                               unsigned SLT)
      : AllocaSize(Size), DL(DL), ScalarLoadThreshold(SLT),
        IsNotTrivial(false), ScalarKind(Unknown), VectorTy(nullptr),
        HadNonMemTransferAccess(false), HadDynamicAccess(false) {}

  AllocaInst *TryConvert(AllocaInst *AI);

private:
  bool CanConvertToScalar(Value *V, uint64_t Offset, Value *NonConstantIdx);
  void MergeInTypeForLoadOrStore(Type *In, uint64_t Offset);
  bool MergeInVectorType(VectorType *VInTy, uint64_t Offset);
  void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset,
                           Value *NonConstantIdx);

  Value *ConvertScalar_ExtractValue(Value *NV, Type *ToType, uint64_t Offset,
                                    Value *NonConstantIdx,
                                    IRBuilder<> &Builder);
  Value *ConvertScalar_InsertValue(Value *StoredVal, Value *ExistingVal,
                                   uint64_t Offset, Value *NonConstantIdx,
                                   IRBuilder<> &Builder);
};
} // end anonymous namespace.

/// TryConvert - Analyze the specified alloca, and if it is safe to do so,
/// rewrite it to be a new alloca which is mem2reg'able. This returns the new
/// alloca if possible or null if not.
AllocaInst *ConvertToScalarInfo::TryConvert(AllocaInst *AI) {
  // If we can't convert this scalar, or if mem2reg can trivially do it, bail
  // out.
  if (!CanConvertToScalar(AI, 0, nullptr) || !IsNotTrivial)
    return nullptr;

  // If an alloca has only memset / memcpy uses, it may still have an Unknown
  // ScalarKind. Treat it as an Integer below.
  if (ScalarKind == Unknown)
    ScalarKind = Integer;

  if (ScalarKind == Vector && VectorTy->getBitWidth() != AllocaSize * 8)
    ScalarKind = Integer;

  // If we were able to find a vector type that can handle this with
  // insert/extract elements, and if there was at least one use that had
  // a vector type, promote this to a vector. We don't want to promote
  // random stuff that doesn't use vectors (e.g. <9 x double>) because then
  // we just get a lot of insert/extracts. If at least one vector is
  // involved, then we probably really do have a union of vector/array.
  Type *NewTy;
  if (ScalarKind == Vector) {
    assert(VectorTy && "Missing type for vector scalar.");
    DEBUG(dbgs() << "CONVERT TO VECTOR: " << *AI << "\n TYPE = " << *VectorTy
                 << '\n');
    NewTy = VectorTy; // Use the vector type.
  } else {
    unsigned BitWidth = AllocaSize * 8;

    // Do not convert to scalar integer if the alloca size exceeds the
    // scalar load threshold.
    if (BitWidth > ScalarLoadThreshold)
      return nullptr;

    if ((ScalarKind == ImplicitVector || ScalarKind == Integer) &&
        !HadNonMemTransferAccess && !DL.fitsInLegalInteger(BitWidth))
      return nullptr;
    // Dynamic accesses on integers aren't yet supported. They need us to shift
    // by a dynamic amount which could be difficult to work out as we might not
    // know whether to use a left or right shift.
    if (ScalarKind == Integer && HadDynamicAccess)
      return nullptr;

    DEBUG(dbgs() << "CONVERT TO SCALAR INTEGER: " << *AI << "\n");
    // Create and insert the integer alloca.
    NewTy = IntegerType::get(AI->getContext(), BitWidth);
  }
  AllocaInst *NewAI =
      new AllocaInst(NewTy, nullptr, "", AI->getParent()->begin());
  ConvertUsesToScalar(AI, NewAI, 0, nullptr);
  return NewAI;
}
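
// Worked example (illustrative, not from a test case): for
//   %a = alloca [4 x float]          ; AllocaSize == 16 bytes
// whose uses are float loads/stores at constant offsets plus one whole-alloca
// <4 x float> access, CanConvertToScalar ends with ScalarKind == Vector and
// VectorTy == <4 x float>, so TryConvert rewrites %a into
//   %a.conv = alloca <4 x float>
// If no vector-typed use is ever seen, the fallback path instead builds an
// i128 "bag of bits" alloca, subject to ScalarLoadThreshold and the
// fitsInLegalInteger check above.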

/// MergeInTypeForLoadOrStore - Add the 'In' type to the accumulated vector type
/// (VectorTy) so far at the offset specified by Offset (which is specified in
/// bytes).
///
/// There are two cases we handle here:
///   1) A union of vector types of the same size and potentially its elements.
///      Here we turn element accesses into insert/extract element operations.
///      This promotes a <4 x float> with a store of float to the third element
///      into a <4 x float> that uses insert element.
///   2) A fully general blob of memory, which we turn into some (potentially
///      large) integer type with extract and insert operations where the loads
///      and stores would mutate the memory. We mark this by setting VectorTy
///      to VoidTy.
void ConvertToScalarInfo::MergeInTypeForLoadOrStore(Type *In, uint64_t Offset) {
  // If we already decided to turn this into a blob of integer memory, there is
  // nothing to be done.
  if (ScalarKind == Integer)
    return;

  // If this could be contributing to a vector, analyze it.

  // If the In type is a vector that is the same size as the alloca, see if it
  // matches the existing VecTy.
  if (VectorType *VInTy = dyn_cast<VectorType>(In)) {
    if (MergeInVectorType(VInTy, Offset))
      return;
  } else if (In->isFloatTy() || In->isDoubleTy() ||
             (In->isIntegerTy() && In->getPrimitiveSizeInBits() >= 8 &&
              isPowerOf2_32(In->getPrimitiveSizeInBits()))) {
    // Full width accesses can be ignored, because they can always be turned
    // into bitcasts.
    unsigned EltSize = In->getPrimitiveSizeInBits() / 8;
    if (EltSize == AllocaSize)
      return;

    // If we're accessing something that could be an element of a vector, see
    // if the implied vector agrees with what we already have and if Offset is
    // compatible with it.
    if (Offset % EltSize == 0 && AllocaSize % EltSize == 0 &&
        (!VectorTy ||
         EltSize == VectorTy->getElementType()->getPrimitiveSizeInBits() / 8)) {
      if (!VectorTy) {
        ScalarKind = ImplicitVector;
        VectorTy = VectorType::get(In, AllocaSize / EltSize);
      }
      return;
    }
  }

  // Otherwise, we have a case that we can't handle with an optimized vector
  // form. We can still turn this into a large integer.
  ScalarKind = Integer;
}
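
// Worked example (illustrative numbers): with AllocaSize == 16 and a plain
// "store float" at byte Offset == 8, EltSize is 4, both Offset and AllocaSize
// are multiples of EltSize, and no VectorTy has been seen yet, so ScalarKind
// becomes ImplicitVector with VectorTy == <4 x float>. A later whole-alloca
// <4 x float> access would upgrade ScalarKind to Vector via MergeInVectorType.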

/// MergeInVectorType - Handles the vector case of MergeInTypeForLoadOrStore,
/// returning true if the type was successfully merged and false otherwise.
bool ConvertToScalarInfo::MergeInVectorType(VectorType *VInTy,
                                            uint64_t Offset) {
  if (VInTy->getBitWidth() / 8 == AllocaSize && Offset == 0) {
    // If we're storing/loading a vector of the right size, allow it as a
    // vector. If this is the first vector we see, remember the type so that
    // we know the element size. If this is a subsequent access, ignore it
    // even if it is a differing type but the same size. Worst case we can
    // bitcast the resultant vectors.
    if (!VectorTy)
      VectorTy = VInTy;

    ScalarKind = Vector;
    return true;
  }

  return false;
}

/// CanConvertToScalar - V is a pointer. If we can convert the pointee and all
/// its accesses to a single vector type, return true and set VecTy to
/// the new type. If we could convert the alloca into a single promotable
/// integer, return true but set VecTy to VoidTy. Further, if the use is not a
/// completely trivial use that mem2reg could promote, set IsNotTrivial. Offset
/// is the current offset from the base of the alloca being analyzed.
///
/// If we see at least one access to the value that is a vector type, set the
/// SawVec flag.
bool ConvertToScalarInfo::CanConvertToScalar(Value *V, uint64_t Offset,
                                             Value *NonConstantIdx) {
  for (User *U : V->users()) {
    Instruction *UI = cast<Instruction>(U);

    if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
      // Don't break volatile loads.
      if (!LI->isSimple())
        return false;
      HadNonMemTransferAccess = true;
      MergeInTypeForLoadOrStore(LI->getType(), Offset);
      continue;
    }

    if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
      // Storing the pointer, not into the value?
      if (SI->getOperand(0) == V || !SI->isSimple())
        return false;
      HadNonMemTransferAccess = true;
      MergeInTypeForLoadOrStore(SI->getOperand(0)->getType(), Offset);
      continue;
    }

    if (BitCastInst *BCI = dyn_cast<BitCastInst>(UI)) {
      if (!onlyUsedByLifetimeMarkers(BCI))
        IsNotTrivial = true; // Can't be mem2reg'd.
      if (!CanConvertToScalar(BCI, Offset, NonConstantIdx))
        return false;
      continue;
    }

    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UI)) {
      // If this is a GEP with variable indices, we can't handle it.
      PointerType *PtrTy = dyn_cast<PointerType>(GEP->getPointerOperandType());
      if (!PtrTy)
        return false;

      // Compute the offset that this GEP adds to the pointer.
      SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
      Value *GEPNonConstantIdx = nullptr;
      if (!GEP->hasAllConstantIndices()) {
        if (!isa<VectorType>(PtrTy->getElementType()))
          return false;
        if (NonConstantIdx)
          return false;
        GEPNonConstantIdx = Indices.pop_back_val();
        if (!GEPNonConstantIdx->getType()->isIntegerTy(32))
          return false;
        HadDynamicAccess = true;
      } else
        GEPNonConstantIdx = NonConstantIdx;
      uint64_t GEPOffset = DL.getIndexedOffset(PtrTy, Indices);

      // See if all uses can be converted.
      if (!CanConvertToScalar(GEP, Offset + GEPOffset, GEPNonConstantIdx))
        return false;
      IsNotTrivial = true; // Can't be mem2reg'd.
      HadNonMemTransferAccess = true;
      continue;
    }

    // If this is a constant sized memset of a constant value (e.g. 0) we can
    // handle it.
    if (MemSetInst *MSI = dyn_cast<MemSetInst>(UI)) {
      // Store to dynamic index.
      if (NonConstantIdx)
        return false;
      // Store of constant value.
      if (!isa<ConstantInt>(MSI->getValue()))
        return false;

      // Store of constant size.
      ConstantInt *Len = dyn_cast<ConstantInt>(MSI->getLength());
      if (!Len)
        return false;

      // If the size differs from the alloca, we can only convert the alloca to
      // an integer bag-of-bits.
      // FIXME: This should handle all of the cases that are currently accepted
      // as vector element insertions.
      if (Len->getZExtValue() != AllocaSize || Offset != 0)
        ScalarKind = Integer;

      IsNotTrivial = true; // Can't be mem2reg'd.
      HadNonMemTransferAccess = true;
      continue;
    }

    // If this is a memcpy or memmove into or out of the whole allocation, we
    // can handle it like a load or store of the scalar type.
    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(UI)) {
      // Store to dynamic index.
      if (NonConstantIdx)
        return false;
      ConstantInt *Len = dyn_cast<ConstantInt>(MTI->getLength());
      if (!Len || Len->getZExtValue() != AllocaSize || Offset != 0)
        return false;

      IsNotTrivial = true; // Can't be mem2reg'd.
      continue;
    }

    // If this is a lifetime intrinsic, we can handle it.
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(UI)) {
      if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
          II->getIntrinsicID() == Intrinsic::lifetime_end) {
        continue;
      }
    }

    // Otherwise, we cannot handle this!
    return false;
  }

  return true;
}
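
// Illustrative note (example IR invented for this comment): the dynamic-index
// path above accepts exactly one i32 index into a vector-typed pointee, e.g.
//
//   %v = alloca <4 x float>
//   %p = getelementptr <4 x float>, <4 x float>* %v, i32 0, i32 %i
//   %x = load float, float* %p
//
// which sets HadDynamicAccess and is later rewritten as an element access with
// a dynamic index, whereas two stacked dynamic GEPs or a non-i32 index make
// the whole alloca ineligible.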

/// ConvertUsesToScalar - Convert all of the users of Ptr to use the new alloca
/// directly. This happens when we are converting an "integer union" to a
/// single integer scalar, or when we are converting a "vector union" to a
/// vector with insert/extractelement instructions.
///
/// Offset is an offset from the original alloca, in bits that need to be
/// shifted to the right. By the end of this, there should be no uses of Ptr.
void ConvertToScalarInfo::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI,
                                              uint64_t Offset,
                                              Value *NonConstantIdx) {
  while (!Ptr->use_empty()) {
    Instruction *User = cast<Instruction>(Ptr->user_back());

    if (BitCastInst *CI = dyn_cast<BitCastInst>(User)) {
      ConvertUsesToScalar(CI, NewAI, Offset, NonConstantIdx);
      CI->eraseFromParent();
      continue;
    }

    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      // Compute the offset that this GEP adds to the pointer.
      SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
      Value *GEPNonConstantIdx = nullptr;
      if (!GEP->hasAllConstantIndices()) {
        assert(!NonConstantIdx &&
               "Dynamic GEP reading from dynamic GEP unsupported");
        GEPNonConstantIdx = Indices.pop_back_val();
      } else
        GEPNonConstantIdx = NonConstantIdx;
      uint64_t GEPOffset =
          DL.getIndexedOffset(GEP->getPointerOperandType(), Indices);
      ConvertUsesToScalar(GEP, NewAI, Offset + GEPOffset * 8,
                          GEPNonConstantIdx);
      GEP->eraseFromParent();
      continue;
    }

    IRBuilder<> Builder(User);

    if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
      // The load is a bit extract from NewAI shifted right by Offset bits.
      Value *LoadedVal = Builder.CreateLoad(NewAI);
      Value *NewLoadVal = ConvertScalar_ExtractValue(
          LoadedVal, LI->getType(), Offset, NonConstantIdx, Builder);
      LI->replaceAllUsesWith(NewLoadVal);
      LI->eraseFromParent();
      continue;
    }

    if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
      assert(SI->getOperand(0) != Ptr && "Consistency error!");
      Instruction *Old = Builder.CreateLoad(NewAI, NewAI->getName() + ".in");
      Value *New = ConvertScalar_InsertValue(SI->getOperand(0), Old, Offset,
                                             NonConstantIdx, Builder);
      Builder.CreateStore(New, NewAI);
      SI->eraseFromParent();

      // If the load we just inserted is now dead, then the inserted store
      // overwrote the entire thing.
      if (Old->use_empty())
        Old->eraseFromParent();
      continue;
    }

    // If this is a constant sized memset of a constant value (e.g. 0) we can
    // transform it into a store of the expanded constant value.
    if (MemSetInst *MSI = dyn_cast<MemSetInst>(User)) {
      assert(MSI->getRawDest() == Ptr && "Consistency error!");
      assert(!NonConstantIdx && "Cannot replace dynamic memset with insert");
      int64_t SNumBytes = cast<ConstantInt>(MSI->getLength())->getSExtValue();
      if (SNumBytes > 0 && (SNumBytes >> 32) == 0) {
        unsigned NumBytes = static_cast<unsigned>(SNumBytes);
        unsigned Val = cast<ConstantInt>(MSI->getValue())->getZExtValue();

        // Compute the value replicated the right number of times.
        APInt APVal(NumBytes * 8, Val);

        // Splat the value if non-zero.
        if (Val)
          for (unsigned i = 1; i != NumBytes; ++i)
            APVal |= APVal << 8;
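        // For example (values invented for this comment): Val == 0xAB and
        // NumBytes == 4 produce APVal == 0xABABABAB, i.e. the memset byte
        // replicated across all 32 bits before it is inserted below.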
        Instruction *Old = Builder.CreateLoad(NewAI, NewAI->getName() + ".in");
        Value *New = ConvertScalar_InsertValue(
            ConstantInt::get(User->getContext(), APVal), Old, Offset, nullptr,
            Builder);
        Builder.CreateStore(New, NewAI);

        // If the load we just inserted is now dead, then the memset overwrote
        // the entire thing.
        if (Old->use_empty())
          Old->eraseFromParent();
      }
      MSI->eraseFromParent();
      continue;
    }

    // If this is a memcpy or memmove into or out of the whole allocation, we
    // can handle it like a load or store of the scalar type.
    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(User)) {
      assert(Offset == 0 && "must be store to start of alloca");
      assert(!NonConstantIdx && "Cannot replace dynamic transfer with insert");

      // If the source and destination are both to the same alloca, then this
      // is a noop copy-to-self, just delete it. Otherwise, emit a load and
      // store as appropriate.
      AllocaInst *OrigAI = cast<AllocaInst>(GetUnderlyingObject(Ptr, DL, 0));
      if (GetUnderlyingObject(MTI->getSource(), DL, 0) != OrigAI) {
        // Dest must be OrigAI, change this to be a load from the original
        // pointer (bitcasted), then a store to our new alloca.
        assert(MTI->getRawDest() == Ptr && "Neither use is of pointer?");
        Value *SrcPtr = MTI->getSource();
        PointerType *SPTy = cast<PointerType>(SrcPtr->getType());
        PointerType *AIPTy = cast<PointerType>(NewAI->getType());
        if (SPTy->getAddressSpace() != AIPTy->getAddressSpace()) {
          AIPTy = PointerType::get(AIPTy->getElementType(),
                                   SPTy->getAddressSpace());
        }
        SrcPtr = Builder.CreateBitCast(SrcPtr, AIPTy);

        LoadInst *SrcVal = Builder.CreateLoad(SrcPtr, "srcval");
        SrcVal->setAlignment(MTI->getAlignment());
        Builder.CreateStore(SrcVal, NewAI);
      } else if (GetUnderlyingObject(MTI->getDest(), DL, 0) != OrigAI) {
        // Src must be OrigAI, change this to be a load from NewAI then a store
        // through the original dest pointer (bitcasted).
        assert(MTI->getRawSource() == Ptr && "Neither use is of pointer?");
        LoadInst *SrcVal = Builder.CreateLoad(NewAI, "srcval");

        PointerType *DPTy = cast<PointerType>(MTI->getDest()->getType());
        PointerType *AIPTy = cast<PointerType>(NewAI->getType());
        if (DPTy->getAddressSpace() != AIPTy->getAddressSpace()) {
          AIPTy = PointerType::get(AIPTy->getElementType(),
                                   DPTy->getAddressSpace());
        }
        Value *DstPtr = Builder.CreateBitCast(MTI->getDest(), AIPTy);

        StoreInst *NewStore = Builder.CreateStore(SrcVal, DstPtr);
        NewStore->setAlignment(MTI->getAlignment());
      } else {
        // Noop transfer. Src == Dst
      }

      MTI->eraseFromParent();
      continue;
    }

    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) {
      if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
          II->getIntrinsicID() == Intrinsic::lifetime_end) {
        // There's no need to preserve these, as the resulting alloca will be
        // converted to a register anyways.
        II->eraseFromParent();
        continue;
      }
    }

    llvm_unreachable("Unsupported operation!");
  }
}
  682. /// ConvertScalar_ExtractValue - Extract a value of type ToType from an integer
  683. /// or vector value FromVal, extracting the bits from the offset specified by
  684. /// Offset. This returns the value, which is of type ToType.
  685. ///
  686. /// This happens when we are converting an "integer union" to a single
  687. /// integer scalar, or when we are converting a "vector union" to a vector with
  688. /// insert/extractelement instructions.
  689. ///
  690. /// Offset is an offset from the original alloca, in bits that need to be
  691. /// shifted to the right.
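///
/// For example (illustrative only), extracting a float at bit offset 32 from
/// an i64-backed alloca on a little-endian target lowers to roughly:
///   %shift = lshr i64 %FromVal, 32
///   %trunc = trunc i64 %shift to i32
///   %res   = bitcast i32 %trunc to float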
  692. Value *ConvertToScalarInfo::ConvertScalar_ExtractValue(Value *FromVal,
  693. Type *ToType,
  694. uint64_t Offset,
  695. Value *NonConstantIdx,
  696. IRBuilder<> &Builder) {
  697. // If the load is of the whole new alloca, no conversion is needed.
  698. Type *FromType = FromVal->getType();
  699. if (FromType == ToType && Offset == 0)
  700. return FromVal;
  701. // If the result alloca is a vector type, this is either an element
  702. // access or a bitcast to another vector type of the same size.
  703. if (VectorType *VTy = dyn_cast<VectorType>(FromType)) {
  704. unsigned FromTypeSize = DL.getTypeAllocSize(FromType);
  705. unsigned ToTypeSize = DL.getTypeAllocSize(ToType);
  706. if (FromTypeSize == ToTypeSize)
  707. return Builder.CreateBitCast(FromVal, ToType);
  708. // Otherwise it must be an element access.
  709. unsigned Elt = 0;
  710. if (Offset) {
  711. unsigned EltSize = DL.getTypeAllocSizeInBits(VTy->getElementType());
  712. Elt = Offset / EltSize;
  713. assert(EltSize * Elt == Offset && "Invalid modulus in validity checking");
  714. }
  715. // Return the element extracted out of it.
  716. Value *Idx;
  717. if (NonConstantIdx) {
  718. if (Elt)
  719. Idx = Builder.CreateAdd(NonConstantIdx, Builder.getInt32(Elt),
  720. "dyn.offset");
  721. else
  722. Idx = NonConstantIdx;
  723. } else
  724. Idx = Builder.getInt32(Elt);
  725. Value *V = Builder.CreateExtractElement(FromVal, Idx);
  726. if (V->getType() != ToType)
  727. V = Builder.CreateBitCast(V, ToType);
  728. return V;
  729. }
  730. // If ToType is a first class aggregate, extract out each of the pieces and
  731. // use insertvalue's to form the FCA.
  732. if (StructType *ST = dyn_cast<StructType>(ToType)) {
  733. assert(!NonConstantIdx &&
  734. "Dynamic indexing into struct types not supported");
  735. const StructLayout &Layout = *DL.getStructLayout(ST);
  736. Value *Res = UndefValue::get(ST);
  737. for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
  738. Value *Elt = ConvertScalar_ExtractValue(
  739. FromVal, ST->getElementType(i),
  740. Offset + Layout.getElementOffsetInBits(i), nullptr, Builder);
  741. Res = Builder.CreateInsertValue(Res, Elt, i);
  742. }
  743. return Res;
  744. }
  745. if (ArrayType *AT = dyn_cast<ArrayType>(ToType)) {
  746. assert(!NonConstantIdx &&
  747. "Dynamic indexing into array types not supported");
  748. uint64_t EltSize = DL.getTypeAllocSizeInBits(AT->getElementType());
  749. Value *Res = UndefValue::get(AT);
  750. for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) {
  751. Value *Elt =
  752. ConvertScalar_ExtractValue(FromVal, AT->getElementType(),
  753. Offset + i * EltSize, nullptr, Builder);
  754. Res = Builder.CreateInsertValue(Res, Elt, i);
  755. }
  756. return Res;
  757. }
  758. // Otherwise, this must be a union that was converted to an integer value.
  759. IntegerType *NTy = cast<IntegerType>(FromVal->getType());
  760. // If this is a big-endian system and the load is narrower than the
  761. // full alloca type, we need to do a shift to get the right bits.
  762. int ShAmt = 0;
  763. if (DL.isBigEndian()) {
  764. // On big-endian machines, the lowest bit is stored at the bit offset
  765. // from the pointer given by getTypeStoreSizeInBits. This matters for
  766. // integers with a bitwidth that is not a multiple of 8.
  767. ShAmt = DL.getTypeStoreSizeInBits(NTy) - DL.getTypeStoreSizeInBits(ToType) -
  768. Offset;
  769. } else {
  770. ShAmt = Offset;
  771. }
// Note: a negative shift amount can occur here; since negative shifts are not
// defined, it is handled by shifting left (shl) instead. We do this to support
// (f.e.) loads off the end of a structure where only some bits are used.
  775. if (ShAmt > 0 && (unsigned)ShAmt < NTy->getBitWidth())
  776. FromVal = Builder.CreateLShr(FromVal,
  777. ConstantInt::get(FromVal->getType(), ShAmt));
  778. else if (ShAmt < 0 && (unsigned)-ShAmt < NTy->getBitWidth())
  779. FromVal = Builder.CreateShl(FromVal,
  780. ConstantInt::get(FromVal->getType(), -ShAmt));
  781. // Finally, unconditionally truncate the integer to the right width.
  782. unsigned LIBitWidth = DL.getTypeSizeInBits(ToType);
  783. if (LIBitWidth < NTy->getBitWidth())
  784. FromVal = Builder.CreateTrunc(
  785. FromVal, IntegerType::get(FromVal->getContext(), LIBitWidth));
  786. else if (LIBitWidth > NTy->getBitWidth())
  787. FromVal = Builder.CreateZExt(
  788. FromVal, IntegerType::get(FromVal->getContext(), LIBitWidth));
  789. // If the result is an integer, this is a trunc or bitcast.
  790. if (ToType->isIntegerTy()) {
  791. // Should be done.
  792. } else if (ToType->isFloatingPointTy() || ToType->isVectorTy()) {
  793. // Just do a bitcast, we know the sizes match up.
  794. FromVal = Builder.CreateBitCast(FromVal, ToType);
  795. } else {
  796. // Otherwise must be a pointer.
  797. FromVal = Builder.CreateIntToPtr(FromVal, ToType);
  798. }
  799. assert(FromVal->getType() == ToType && "Didn't convert right?");
  800. return FromVal;
  801. }
  802. /// ConvertScalar_InsertValue - Insert the value "SV" into the existing integer
  803. /// or vector value "Old" at the offset specified by Offset.
  804. ///
  805. /// This happens when we are converting an "integer union" to a
  806. /// single integer scalar, or when we are converting a "vector union" to a
  807. /// vector with insert/extractelement instructions.
  808. ///
  809. /// Offset is an offset from the original alloca, in bits that need to be
  810. /// shifted to the right.
  811. ///
  812. /// NonConstantIdx is an index value if there was a GEP with a non-constant
  813. /// index value. If this is 0 then all GEPs used to find this insert address
  814. /// are constant.
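///
/// For example (illustrative only), inserting an i8 value at bit offset 8 into
/// an i32-backed alloca on a little-endian target lowers to roughly:
///   %ext    = zext i8 %SV to i32
///   %shl    = shl i32 %ext, 8
///   %masked = and i32 %Old, -65281   ; ~0x0000FF00
///   %ins    = or i32 %masked, %shl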
  815. Value *ConvertToScalarInfo::ConvertScalar_InsertValue(Value *SV, Value *Old,
  816. uint64_t Offset,
  817. Value *NonConstantIdx,
  818. IRBuilder<> &Builder) {
  819. // Convert the stored type to the actual type, shift it left to insert
  820. // then 'or' into place.
  821. Type *AllocaType = Old->getType();
  822. LLVMContext &Context = Old->getContext();
  823. if (VectorType *VTy = dyn_cast<VectorType>(AllocaType)) {
  824. uint64_t VecSize = DL.getTypeAllocSizeInBits(VTy);
  825. uint64_t ValSize = DL.getTypeAllocSizeInBits(SV->getType());
  826. // Changing the whole vector with memset or with an access of a different
  827. // vector type?
  828. if (ValSize == VecSize)
  829. return Builder.CreateBitCast(SV, AllocaType);
  830. // Must be an element insertion.
  831. Type *EltTy = VTy->getElementType();
  832. if (SV->getType() != EltTy)
  833. SV = Builder.CreateBitCast(SV, EltTy);
  834. uint64_t EltSize = DL.getTypeAllocSizeInBits(EltTy);
  835. unsigned Elt = Offset / EltSize;
  836. Value *Idx;
  837. if (NonConstantIdx) {
  838. if (Elt)
  839. Idx = Builder.CreateAdd(NonConstantIdx, Builder.getInt32(Elt),
  840. "dyn.offset");
  841. else
  842. Idx = NonConstantIdx;
  843. } else
  844. Idx = Builder.getInt32(Elt);
  845. return Builder.CreateInsertElement(Old, SV, Idx);
  846. }
  847. // If SV is a first-class aggregate value, insert each value recursively.
  848. if (StructType *ST = dyn_cast<StructType>(SV->getType())) {
  849. assert(!NonConstantIdx &&
  850. "Dynamic indexing into struct types not supported");
  851. const StructLayout &Layout = *DL.getStructLayout(ST);
  852. for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
  853. Value *Elt = Builder.CreateExtractValue(SV, i);
  854. Old = ConvertScalar_InsertValue(Elt, Old,
  855. Offset + Layout.getElementOffsetInBits(i),
  856. nullptr, Builder);
  857. }
  858. return Old;
  859. }
  860. if (ArrayType *AT = dyn_cast<ArrayType>(SV->getType())) {
  861. assert(!NonConstantIdx &&
  862. "Dynamic indexing into array types not supported");
  863. uint64_t EltSize = DL.getTypeAllocSizeInBits(AT->getElementType());
  864. for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) {
  865. Value *Elt = Builder.CreateExtractValue(SV, i);
  866. Old = ConvertScalar_InsertValue(Elt, Old, Offset + i * EltSize, nullptr,
  867. Builder);
  868. }
  869. return Old;
  870. }
  871. // If SV is a float, convert it to the appropriate integer type.
  872. // If it is a pointer, do the same.
  873. unsigned SrcWidth = DL.getTypeSizeInBits(SV->getType());
  874. unsigned DestWidth = DL.getTypeSizeInBits(AllocaType);
  875. unsigned SrcStoreWidth = DL.getTypeStoreSizeInBits(SV->getType());
  876. unsigned DestStoreWidth = DL.getTypeStoreSizeInBits(AllocaType);
  877. if (SV->getType()->isFloatingPointTy() || SV->getType()->isVectorTy())
  878. SV =
  879. Builder.CreateBitCast(SV, IntegerType::get(SV->getContext(), SrcWidth));
  880. else if (SV->getType()->isPointerTy())
  881. SV = Builder.CreatePtrToInt(SV, DL.getIntPtrType(SV->getType()));
  882. // Zero extend or truncate the value if needed.
  883. if (SV->getType() != AllocaType) {
  884. if (SV->getType()->getPrimitiveSizeInBits() <
  885. AllocaType->getPrimitiveSizeInBits())
  886. SV = Builder.CreateZExt(SV, AllocaType);
  887. else {
  888. // Truncation may be needed if storing more than the alloca can hold
  889. // (undefined behavior).
  890. SV = Builder.CreateTrunc(SV, AllocaType);
  891. SrcWidth = DestWidth;
  892. SrcStoreWidth = DestStoreWidth;
  893. }
  894. }
  895. // If this is a big-endian system and the store is narrower than the
  896. // full alloca type, we need to do a shift to get the right bits.
  897. int ShAmt = 0;
  898. if (DL.isBigEndian()) {
  899. // On big-endian machines, the lowest bit is stored at the bit offset
  900. // from the pointer given by getTypeStoreSizeInBits. This matters for
  901. // integers with a bitwidth that is not a multiple of 8.
  902. ShAmt = DestStoreWidth - SrcStoreWidth - Offset;
  903. } else {
  904. ShAmt = Offset;
  905. }
// Note: a negative shift amount can occur here; since negative shifts are not
// defined, it is handled by shifting right (lshr) instead. We do this to
// support (f.e.) stores off the end of a structure where only some bits in the
// structure are set.
  909. APInt Mask(APInt::getLowBitsSet(DestWidth, SrcWidth));
  910. if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) {
  911. SV = Builder.CreateShl(SV, ConstantInt::get(SV->getType(), ShAmt));
  912. Mask <<= ShAmt;
  913. } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) {
  914. SV = Builder.CreateLShr(SV, ConstantInt::get(SV->getType(), -ShAmt));
  915. Mask = Mask.lshr(-ShAmt);
  916. }
  917. // Mask out the bits we are about to insert from the old value, and or
  918. // in the new bits.
  919. if (SrcWidth != DestWidth) {
  920. assert(DestWidth > SrcWidth);
  921. Old = Builder.CreateAnd(Old, ConstantInt::get(Context, ~Mask), "mask");
  922. SV = Builder.CreateOr(Old, SV, "ins");
  923. }
  924. return SV;
  925. }
  926. //===----------------------------------------------------------------------===//
  927. // SRoA Driver
  928. //===----------------------------------------------------------------------===//
  929. bool SROA_HLSL::runOnFunction(Function &F) {
  930. Module *M = F.getParent();
  931. HLModule &HLM = M->GetOrCreateHLModule();
  932. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  933. bool Changed = performScalarRepl(F, typeSys);
// Change the remaining memcpys into load/store pairs.
  935. MemcpySplitter splitter(F.getContext(), typeSys);
  936. splitter.Split(F);
  937. Changed |= markPrecise(F);
  938. return Changed;
  939. }
  940. namespace {
  941. class AllocaPromoter : public LoadAndStorePromoter {
  942. AllocaInst *AI;
  943. DIBuilder *DIB;
  944. SmallVector<DbgDeclareInst *, 4> DDIs;
  945. SmallVector<DbgValueInst *, 4> DVIs;
  946. public:
  947. AllocaPromoter(ArrayRef<Instruction *> Insts, SSAUpdater &S, DIBuilder *DB)
  948. : LoadAndStorePromoter(Insts, S), AI(nullptr), DIB(DB) {}
  949. void run(AllocaInst *AI, const SmallVectorImpl<Instruction *> &Insts) {
  950. // Remember which alloca we're promoting (for isInstInList).
  951. this->AI = AI;
  952. if (auto *L = LocalAsMetadata::getIfExists(AI)) {
  953. if (auto *DINode = MetadataAsValue::getIfExists(AI->getContext(), L)) {
  954. for (User *U : DINode->users())
  955. if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U))
  956. DDIs.push_back(DDI);
  957. else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U))
  958. DVIs.push_back(DVI);
  959. }
  960. }
  961. LoadAndStorePromoter::run(Insts);
  962. AI->eraseFromParent();
  963. for (SmallVectorImpl<DbgDeclareInst *>::iterator I = DDIs.begin(),
  964. E = DDIs.end();
  965. I != E; ++I) {
  966. DbgDeclareInst *DDI = *I;
  967. DDI->eraseFromParent();
  968. }
  969. for (SmallVectorImpl<DbgValueInst *>::iterator I = DVIs.begin(),
  970. E = DVIs.end();
  971. I != E; ++I) {
  972. DbgValueInst *DVI = *I;
  973. DVI->eraseFromParent();
  974. }
  975. }
  976. bool
  977. isInstInList(Instruction *I,
  978. const SmallVectorImpl<Instruction *> &Insts) const override {
  979. if (LoadInst *LI = dyn_cast<LoadInst>(I))
  980. return LI->getOperand(0) == AI;
  981. return cast<StoreInst>(I)->getPointerOperand() == AI;
  982. }
  983. void updateDebugInfo(Instruction *Inst) const override {
  984. for (SmallVectorImpl<DbgDeclareInst *>::const_iterator I = DDIs.begin(),
  985. E = DDIs.end();
  986. I != E; ++I) {
  987. DbgDeclareInst *DDI = *I;
  988. if (StoreInst *SI = dyn_cast<StoreInst>(Inst))
  989. ConvertDebugDeclareToDebugValue(DDI, SI, *DIB);
  990. else if (LoadInst *LI = dyn_cast<LoadInst>(Inst))
  991. ConvertDebugDeclareToDebugValue(DDI, LI, *DIB);
  992. }
  993. for (SmallVectorImpl<DbgValueInst *>::const_iterator I = DVIs.begin(),
  994. E = DVIs.end();
  995. I != E; ++I) {
  996. DbgValueInst *DVI = *I;
  997. Value *Arg = nullptr;
  998. if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
// If an argument is zero- or sign-extended then use the argument directly.
// The extension may be zapped by an optimization pass in the future.
  1001. if (ZExtInst *ZExt = dyn_cast<ZExtInst>(SI->getOperand(0)))
  1002. Arg = dyn_cast<Argument>(ZExt->getOperand(0));
  1003. if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0)))
  1004. Arg = dyn_cast<Argument>(SExt->getOperand(0));
  1005. if (!Arg)
  1006. Arg = SI->getOperand(0);
  1007. } else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
  1008. Arg = LI->getOperand(0);
  1009. } else {
  1010. continue;
  1011. }
  1012. DIB->insertDbgValueIntrinsic(Arg, 0, DVI->getVariable(),
  1013. DVI->getExpression(), DVI->getDebugLoc(),
  1014. Inst);
  1015. }
  1016. }
  1017. };
  1018. } // end anon namespace
  1019. /// isSafeSelectToSpeculate - Select instructions that use an alloca and are
  1020. /// subsequently loaded can be rewritten to load both input pointers and then
  1021. /// select between the result, allowing the load of the alloca to be promoted.
  1022. /// From this:
  1023. /// %P2 = select i1 %cond, i32* %Alloca, i32* %Other
  1024. /// %V = load i32* %P2
  1025. /// to:
  1026. /// %V1 = load i32* %Alloca -> will be mem2reg'd
  1027. /// %V2 = load i32* %Other
  1028. /// %V = select i1 %cond, i32 %V1, i32 %V2
  1029. ///
  1030. /// We can do this to a select if its only uses are loads and if the operand to
  1031. /// the select can be loaded unconditionally.
  1032. static bool isSafeSelectToSpeculate(SelectInst *SI) {
  1033. const DataLayout &DL = SI->getModule()->getDataLayout();
  1034. bool TDerefable = isDereferenceablePointer(SI->getTrueValue(), DL);
  1035. bool FDerefable = isDereferenceablePointer(SI->getFalseValue(), DL);
  1036. for (User *U : SI->users()) {
  1037. LoadInst *LI = dyn_cast<LoadInst>(U);
  1038. if (!LI || !LI->isSimple())
  1039. return false;
// Both operands to the select need to be dereferenceable, either absolutely
// (e.g. allocas) or at this point because we can see other accesses to it.
  1042. if (!TDerefable &&
  1043. !isSafeToLoadUnconditionally(SI->getTrueValue(), LI,
  1044. LI->getAlignment()))
  1045. return false;
  1046. if (!FDerefable &&
  1047. !isSafeToLoadUnconditionally(SI->getFalseValue(), LI,
  1048. LI->getAlignment()))
  1049. return false;
  1050. }
  1051. return true;
  1052. }
  1053. /// isSafePHIToSpeculate - PHI instructions that use an alloca and are
  1054. /// subsequently loaded can be rewritten to load both input pointers in the pred
  1055. /// blocks and then PHI the results, allowing the load of the alloca to be
  1056. /// promoted.
  1057. /// From this:
  1058. /// %P2 = phi [i32* %Alloca, i32* %Other]
  1059. /// %V = load i32* %P2
  1060. /// to:
  1061. /// %V1 = load i32* %Alloca -> will be mem2reg'd
  1062. /// ...
  1063. /// %V2 = load i32* %Other
  1064. /// ...
  1065. /// %V = phi [i32 %V1, i32 %V2]
  1066. ///
/// We can do this to a PHI if its only uses are loads and if the operands to
/// the PHI can be loaded unconditionally.
  1069. static bool isSafePHIToSpeculate(PHINode *PN) {
  1070. // For now, we can only do this promotion if the load is in the same block as
  1071. // the PHI, and if there are no stores between the phi and load.
  1072. // TODO: Allow recursive phi users.
  1073. // TODO: Allow stores.
  1074. BasicBlock *BB = PN->getParent();
  1075. unsigned MaxAlign = 0;
  1076. for (User *U : PN->users()) {
  1077. LoadInst *LI = dyn_cast<LoadInst>(U);
  1078. if (!LI || !LI->isSimple())
  1079. return false;
  1080. // For now we only allow loads in the same block as the PHI. This is a
  1081. // common case that happens when instcombine merges two loads through a PHI.
  1082. if (LI->getParent() != BB)
  1083. return false;
  1084. // Ensure that there are no instructions between the PHI and the load that
  1085. // could store.
  1086. for (BasicBlock::iterator BBI = PN; &*BBI != LI; ++BBI)
  1087. if (BBI->mayWriteToMemory())
  1088. return false;
  1089. MaxAlign = std::max(MaxAlign, LI->getAlignment());
  1090. }
  1091. const DataLayout &DL = PN->getModule()->getDataLayout();
  1092. // Okay, we know that we have one or more loads in the same block as the PHI.
  1093. // We can transform this if it is safe to push the loads into the predecessor
  1094. // blocks. The only thing to watch out for is that we can't put a possibly
  1095. // trapping load in the predecessor if it is a critical edge.
  1096. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
  1097. BasicBlock *Pred = PN->getIncomingBlock(i);
  1098. Value *InVal = PN->getIncomingValue(i);
  1099. // If the terminator of the predecessor has side-effects (an invoke),
  1100. // there is no safe place to put a load in the predecessor.
  1101. if (Pred->getTerminator()->mayHaveSideEffects())
  1102. return false;
  1103. // If the value is produced by the terminator of the predecessor
  1104. // (an invoke), there is no valid place to put a load in the predecessor.
  1105. if (Pred->getTerminator() == InVal)
  1106. return false;
  1107. // If the predecessor has a single successor, then the edge isn't critical.
  1108. if (Pred->getTerminator()->getNumSuccessors() == 1)
  1109. continue;
  1110. // If this pointer is always safe to load, or if we can prove that there is
  1111. // already a load in the block, then we can move the load to the pred block.
  1112. if (isDereferenceablePointer(InVal, DL) ||
  1113. isSafeToLoadUnconditionally(InVal, Pred->getTerminator(), MaxAlign))
  1114. continue;
  1115. return false;
  1116. }
  1117. return true;
  1118. }
  1119. /// tryToMakeAllocaBePromotable - This returns true if the alloca only has
  1120. /// direct (non-volatile) loads and stores to it. If the alloca is close but
  1121. /// not quite there, this will transform the code to allow promotion. As such,
  1122. /// it is a non-pure predicate.
  1123. static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout &DL) {
  1124. SetVector<Instruction *, SmallVector<Instruction *, 4>,
  1125. SmallPtrSet<Instruction *, 4>>
  1126. InstsToRewrite;
  1127. for (User *U : AI->users()) {
  1128. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  1129. if (!LI->isSimple())
  1130. return false;
  1131. continue;
  1132. }
  1133. if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
  1134. if (SI->getOperand(0) == AI || !SI->isSimple())
  1135. return false; // Don't allow a store OF the AI, only INTO the AI.
  1136. continue;
  1137. }
  1138. if (SelectInst *SI = dyn_cast<SelectInst>(U)) {
  1139. // If the condition being selected on is a constant, fold the select, yes
  1140. // this does (rarely) happen early on.
  1141. if (ConstantInt *CI = dyn_cast<ConstantInt>(SI->getCondition())) {
  1142. Value *Result = SI->getOperand(1 + CI->isZero());
  1143. SI->replaceAllUsesWith(Result);
  1144. SI->eraseFromParent();
  1145. // This is very rare and we just scrambled the use list of AI, start
  1146. // over completely.
  1147. return tryToMakeAllocaBePromotable(AI, DL);
  1148. }
  1149. // If it is safe to turn "load (select c, AI, ptr)" into a select of two
  1150. // loads, then we can transform this by rewriting the select.
  1151. if (!isSafeSelectToSpeculate(SI))
  1152. return false;
  1153. InstsToRewrite.insert(SI);
  1154. continue;
  1155. }
  1156. if (PHINode *PN = dyn_cast<PHINode>(U)) {
  1157. if (PN->use_empty()) { // Dead PHIs can be stripped.
  1158. InstsToRewrite.insert(PN);
  1159. continue;
  1160. }
  1161. // If it is safe to turn "load (phi [AI, ptr, ...])" into a PHI of loads
  1162. // in the pred blocks, then we can transform this by rewriting the PHI.
  1163. if (!isSafePHIToSpeculate(PN))
  1164. return false;
  1165. InstsToRewrite.insert(PN);
  1166. continue;
  1167. }
  1168. if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  1169. if (onlyUsedByLifetimeMarkers(BCI)) {
  1170. InstsToRewrite.insert(BCI);
  1171. continue;
  1172. }
  1173. }
  1174. return false;
  1175. }
  1176. // If there are no instructions to rewrite, then all uses are load/stores and
  1177. // we're done!
  1178. if (InstsToRewrite.empty())
  1179. return true;
  1180. // If we have instructions that need to be rewritten for this to be promotable
  1181. // take care of it now.
  1182. for (unsigned i = 0, e = InstsToRewrite.size(); i != e; ++i) {
  1183. if (BitCastInst *BCI = dyn_cast<BitCastInst>(InstsToRewrite[i])) {
  1184. // This could only be a bitcast used by nothing but lifetime intrinsics.
  1185. for (BitCastInst::user_iterator I = BCI->user_begin(),
  1186. E = BCI->user_end();
  1187. I != E;)
  1188. cast<Instruction>(*I++)->eraseFromParent();
  1189. BCI->eraseFromParent();
  1190. continue;
  1191. }
  1192. if (SelectInst *SI = dyn_cast<SelectInst>(InstsToRewrite[i])) {
  1193. // Selects in InstsToRewrite only have load uses. Rewrite each as two
  1194. // loads with a new select.
  1195. while (!SI->use_empty()) {
  1196. LoadInst *LI = cast<LoadInst>(SI->user_back());
  1197. IRBuilder<> Builder(LI);
  1198. LoadInst *TrueLoad =
  1199. Builder.CreateLoad(SI->getTrueValue(), LI->getName() + ".t");
  1200. LoadInst *FalseLoad =
  1201. Builder.CreateLoad(SI->getFalseValue(), LI->getName() + ".f");
  1202. // Transfer alignment and AA info if present.
  1203. TrueLoad->setAlignment(LI->getAlignment());
  1204. FalseLoad->setAlignment(LI->getAlignment());
  1205. AAMDNodes Tags;
  1206. LI->getAAMetadata(Tags);
  1207. if (Tags) {
  1208. TrueLoad->setAAMetadata(Tags);
  1209. FalseLoad->setAAMetadata(Tags);
  1210. }
  1211. Value *V =
  1212. Builder.CreateSelect(SI->getCondition(), TrueLoad, FalseLoad);
  1213. V->takeName(LI);
  1214. LI->replaceAllUsesWith(V);
  1215. LI->eraseFromParent();
  1216. }
  1217. // Now that all the loads are gone, the select is gone too.
  1218. SI->eraseFromParent();
  1219. continue;
  1220. }
  1221. // Otherwise, we have a PHI node which allows us to push the loads into the
  1222. // predecessors.
  1223. PHINode *PN = cast<PHINode>(InstsToRewrite[i]);
  1224. if (PN->use_empty()) {
  1225. PN->eraseFromParent();
  1226. continue;
  1227. }
  1228. Type *LoadTy = cast<PointerType>(PN->getType())->getElementType();
  1229. PHINode *NewPN = PHINode::Create(LoadTy, PN->getNumIncomingValues(),
  1230. PN->getName() + ".ld", PN);
  1231. // Get the AA tags and alignment to use from one of the loads. It doesn't
  1232. // matter which one we get and if any differ, it doesn't matter.
  1233. LoadInst *SomeLoad = cast<LoadInst>(PN->user_back());
  1234. AAMDNodes AATags;
  1235. SomeLoad->getAAMetadata(AATags);
  1236. unsigned Align = SomeLoad->getAlignment();
  1237. // Rewrite all loads of the PN to use the new PHI.
  1238. while (!PN->use_empty()) {
  1239. LoadInst *LI = cast<LoadInst>(PN->user_back());
  1240. LI->replaceAllUsesWith(NewPN);
  1241. LI->eraseFromParent();
  1242. }
  1243. // Inject loads into all of the pred blocks. Keep track of which blocks we
  1244. // insert them into in case we have multiple edges from the same block.
  1245. DenseMap<BasicBlock *, LoadInst *> InsertedLoads;
  1246. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
  1247. BasicBlock *Pred = PN->getIncomingBlock(i);
  1248. LoadInst *&Load = InsertedLoads[Pred];
  1249. if (!Load) {
  1250. Load = new LoadInst(PN->getIncomingValue(i),
  1251. PN->getName() + "." + Pred->getName(),
  1252. Pred->getTerminator());
  1253. Load->setAlignment(Align);
  1254. if (AATags)
  1255. Load->setAAMetadata(AATags);
  1256. }
  1257. NewPN->addIncoming(Load, Pred);
  1258. }
  1259. PN->eraseFromParent();
  1260. }
  1261. ++NumAdjusted;
  1262. return true;
  1263. }
  1264. bool SROA_HLSL::performPromotion(Function &F) {
  1265. std::vector<AllocaInst *> Allocas;
  1266. const DataLayout &DL = F.getParent()->getDataLayout();
  1267. DominatorTree *DT = nullptr;
  1268. if (HasDomTree)
  1269. DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  1270. AssumptionCache &AC =
  1271. getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  1272. BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
  1273. DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
  1274. bool Changed = false;
  1275. SmallVector<Instruction *, 64> Insts;
  1276. while (1) {
  1277. Allocas.clear();
  1278. // Find allocas that are safe to promote, by looking at all instructions in
  1279. // the entry node
  1280. for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
  1281. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
  1282. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
// Skip allocas that have debug info when promotion is not being run.
  1284. if (DDI && !RunPromotion) {
  1285. continue;
  1286. }
  1287. if (tryToMakeAllocaBePromotable(AI, DL))
  1288. Allocas.push_back(AI);
  1289. }
  1290. if (Allocas.empty())
  1291. break;
  1292. if (HasDomTree)
  1293. PromoteMemToReg(Allocas, *DT, nullptr, &AC);
  1294. else {
  1295. SSAUpdater SSA;
  1296. for (unsigned i = 0, e = Allocas.size(); i != e; ++i) {
  1297. AllocaInst *AI = Allocas[i];
  1298. // Build list of instructions to promote.
  1299. for (User *U : AI->users())
  1300. Insts.push_back(cast<Instruction>(U));
  1301. AllocaPromoter(Insts, SSA, &DIB).run(AI, Insts);
  1302. Insts.clear();
  1303. }
  1304. }
  1305. NumPromoted += Allocas.size();
  1306. Changed = true;
  1307. }
  1308. return Changed;
  1309. }
/// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for
/// SROA. Any struct or array type qualifies.
  1312. bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
  1313. Type *T = AI->getAllocatedType();
  1314. // promote every struct.
  1315. if (StructType *ST = dyn_cast<StructType>(T))
  1316. return true;
  1317. // promote every array.
  1318. if (ArrayType *AT = dyn_cast<ArrayType>(T))
  1319. return true;
  1320. return false;
  1321. }
// performScalarRepl - A simple worklist-driven algorithm that visits the
// alloca instructions in the entry block and splits them into their scalar
// elements when it is safe to do so.
//
  1326. bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
  1327. std::vector<AllocaInst *> AllocaList;
  1328. const DataLayout &DL = F.getParent()->getDataLayout();
  1329. // Scan the entry basic block, adding allocas to the worklist.
  1330. BasicBlock &BB = F.getEntryBlock();
  1331. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
  1332. if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
  1333. if (A->hasNUsesOrMore(1))
  1334. AllocaList.emplace_back(A);
  1335. }
// Merge GEP uses for the allocas.
  1337. for (auto A : AllocaList)
  1338. HLModule::MergeGepUse(A);
  1339. DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
  1340. // Process the worklist
  1341. bool Changed = false;
  1342. for (AllocaInst *Alloc : AllocaList) {
  1343. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(Alloc);
  1344. unsigned debugOffset = 0;
  1345. std::deque<AllocaInst *> WorkList;
  1346. WorkList.emplace_back(Alloc);
  1347. while (!WorkList.empty()) {
  1348. AllocaInst *AI = WorkList.front();
  1349. WorkList.pop_front();
  1350. // Handle dead allocas trivially. These can be formed by SROA'ing arrays
  1351. // with unused elements.
  1352. if (AI->use_empty()) {
  1353. AI->eraseFromParent();
  1354. Changed = true;
  1355. continue;
  1356. }
  1357. const bool bAllowReplace = true;
  1358. if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
  1359. bAllowReplace)) {
  1360. Changed = true;
  1361. continue;
  1362. }
  1363. // If this alloca is impossible for us to promote, reject it early.
  1364. if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
  1365. continue;
  1366. // Check to see if we can perform the core SROA transformation. We cannot
  1367. // transform the allocation instruction if it is an array allocation
  1368. // (allocations OF arrays are ok though), and an allocation of a scalar
  1369. // value cannot be decomposed at all.
  1370. uint64_t AllocaSize = DL.getTypeAllocSize(AI->getAllocatedType());
  1371. // Do not promote [0 x %struct].
  1372. if (AllocaSize == 0)
  1373. continue;
  1374. Type *Ty = AI->getAllocatedType();
  1375. // Skip empty struct type.
  1376. if (SROA_Helper::IsEmptyStructType(Ty, typeSys)) {
  1377. SROA_Helper::MarkEmptyStructUsers(AI, DeadInsts);
  1378. DeleteDeadInstructions();
  1379. continue;
  1380. }
// If the alloca looks like a good candidate for scalar replacement, and if
// all its users can be transformed, then split up the aggregate into its
// separate elements.
  1385. if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
  1386. std::vector<Value *> Elts;
  1387. IRBuilder<> Builder(AI);
  1388. bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
  1389. bool SROAed = SROA_Helper::DoScalarReplacement(
  1390. AI, Elts, Builder, /*bFlatVector*/ true, hasPrecise, typeSys,
  1391. DeadInsts);
  1392. if (SROAed) {
  1393. Type *Ty = AI->getAllocatedType();
  1394. // Skip empty struct parameters.
  1395. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  1396. if (!HLMatrixLower::IsMatrixType(Ty)) {
  1397. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  1398. if (SA && SA->IsEmptyStruct()) {
  1399. for (User *U : AI->users()) {
  1400. if (StoreInst *SI = dyn_cast<StoreInst>(U))
  1401. DeadInsts.emplace_back(SI);
  1402. }
  1403. DeleteDeadInstructions();
  1404. AI->replaceAllUsesWith(UndefValue::get(AI->getType()));
  1405. AI->eraseFromParent();
  1406. continue;
  1407. }
  1408. }
  1409. }
  1410. // Push Elts into workList.
  1411. for (auto iter = Elts.begin(); iter != Elts.end(); iter++)
  1412. WorkList.emplace_back(cast<AllocaInst>(*iter));
  1413. // Now erase any instructions that were made dead while rewriting the
  1414. // alloca.
  1415. DeleteDeadInstructions();
  1416. ++NumReplaced;
  1417. AI->eraseFromParent();
  1418. Changed = true;
  1419. continue;
  1420. }
  1421. }
  1422. // Add debug info.
  1423. if (DDI != nullptr && AI != Alloc) {
  1424. Type *Ty = AI->getAllocatedType();
  1425. unsigned size = DL.getTypeAllocSize(Ty);
  1426. DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
  1427. debugOffset += size;
  1428. DIB.insertDeclare(AI, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
  1429. DDI);
  1430. }
  1431. }
  1432. }
  1433. return Changed;
  1434. }
// markPrecise - The precise attribute on an alloca may be lost when the alloca
// is removed by promotion, so preserve it by marking the precise attribute
// with a function call on stores to the alloca.
  1437. bool SROA_HLSL::markPrecise(Function &F) {
  1438. bool Changed = false;
  1439. BasicBlock &BB = F.getEntryBlock();
  1440. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
  1441. if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
  1442. // TODO: Only do this on basic types.
  1443. if (HLModule::HasPreciseAttributeWithMetadata(A)) {
  1444. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(A,
  1445. *(F.getParent()));
  1446. Changed = true;
  1447. }
  1448. }
  1449. return Changed;
  1450. }
/// DeleteDeadInstructions - Erase instructions on the DeadInsts list,
/// recursively including all their operands that become trivially dead.
  1453. void SROA_HLSL::DeleteDeadInstructions() {
  1454. while (!DeadInsts.empty()) {
  1455. Instruction *I = cast<Instruction>(DeadInsts.pop_back_val());
  1456. for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI)
  1457. if (Instruction *U = dyn_cast<Instruction>(*OI)) {
  1458. // Zero out the operand and see if it becomes trivially dead.
  1459. // (But, don't add allocas to the dead instruction list -- they are
  1460. // already on the worklist and will be deleted separately.)
  1461. *OI = nullptr;
  1462. if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U))
  1463. DeadInsts.push_back(U);
  1464. }
  1465. I->eraseFromParent();
  1466. }
  1467. }
  1468. /// isSafeForScalarRepl - Check if instruction I is a safe use with regard to
  1469. /// performing scalar replacement of alloca AI. The results are flagged in
  1470. /// the Info parameter. Offset indicates the position within AI that is
  1471. /// referenced by this instruction.
  1472. void SROA_HLSL::isSafeForScalarRepl(Instruction *I, uint64_t Offset,
  1473. AllocaInfo &Info) {
  1474. if (I->getType()->isPointerTy()) {
  1475. // Don't check object pointers.
  1476. if (HLModule::IsHLSLObjectType(I->getType()->getPointerElementType()))
  1477. return;
  1478. }
  1479. const DataLayout &DL = I->getModule()->getDataLayout();
  1480. for (Use &U : I->uses()) {
  1481. Instruction *User = cast<Instruction>(U.getUser());
  1482. if (BitCastInst *BC = dyn_cast<BitCastInst>(User)) {
  1483. isSafeForScalarRepl(BC, Offset, Info);
  1484. } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User)) {
  1485. uint64_t GEPOffset = Offset;
  1486. isSafeGEP(GEPI, GEPOffset, Info);
  1487. if (!Info.isUnsafe)
  1488. isSafeForScalarRepl(GEPI, GEPOffset, Info);
  1489. } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
  1490. ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength());
  1491. if (!Length || Length->isNegative())
  1492. return MarkUnsafe(Info, User);
  1493. isSafeMemAccess(Offset, Length->getZExtValue(), nullptr,
  1494. U.getOperandNo() == 0, Info, MI,
  1495. true /*AllowWholeAccess*/);
  1496. } else if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
  1497. if (!LI->isSimple())
  1498. return MarkUnsafe(Info, User);
  1499. Type *LIType = LI->getType();
  1500. isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info,
  1501. LI, true /*AllowWholeAccess*/);
  1502. Info.hasALoadOrStore = true;
  1503. } else if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
  1504. // Store is ok if storing INTO the pointer, not storing the pointer
  1505. if (!SI->isSimple() || SI->getOperand(0) == I)
  1506. return MarkUnsafe(Info, User);
  1507. Type *SIType = SI->getOperand(0)->getType();
  1508. isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info,
  1509. SI, true /*AllowWholeAccess*/);
  1510. Info.hasALoadOrStore = true;
  1511. } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) {
  1512. if (II->getIntrinsicID() != Intrinsic::lifetime_start &&
  1513. II->getIntrinsicID() != Intrinsic::lifetime_end)
  1514. return MarkUnsafe(Info, User);
  1515. } else if (isa<PHINode>(User) || isa<SelectInst>(User)) {
  1516. isSafePHISelectUseForScalarRepl(User, Offset, Info);
  1517. } else if (CallInst *CI = dyn_cast<CallInst>(User)) {
  1518. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1519. // HL functions are safe for scalar repl.
  1520. if (group == HLOpcodeGroup::NotHL)
  1521. return MarkUnsafe(Info, User);
  1522. } else {
  1523. return MarkUnsafe(Info, User);
  1524. }
  1525. if (Info.isUnsafe)
  1526. return;
  1527. }
  1528. }
  1529. /// isSafePHIUseForScalarRepl - If we see a PHI node or select using a pointer
  1530. /// derived from the alloca, we can often still split the alloca into elements.
  1531. /// This is useful if we have a large alloca where one element is phi'd
  1532. /// together somewhere: we can SRoA and promote all the other elements even if
  1533. /// we end up not being able to promote this one.
  1534. ///
  1535. /// All we require is that the uses of the PHI do not index into other parts of
  1536. /// the alloca. The most important use case for this is single load and stores
  1537. /// that are PHI'd together, which can happen due to code sinking.
  1538. void SROA_HLSL::isSafePHISelectUseForScalarRepl(Instruction *I, uint64_t Offset,
  1539. AllocaInfo &Info) {
  1540. // If we've already checked this PHI, don't do it again.
  1541. if (PHINode *PN = dyn_cast<PHINode>(I))
  1542. if (!Info.CheckedPHIs.insert(PN).second)
  1543. return;
  1544. const DataLayout &DL = I->getModule()->getDataLayout();
  1545. for (User *U : I->users()) {
  1546. Instruction *UI = cast<Instruction>(U);
  1547. if (BitCastInst *BC = dyn_cast<BitCastInst>(UI)) {
  1548. isSafePHISelectUseForScalarRepl(BC, Offset, Info);
  1549. } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(UI)) {
  1550. // Only allow "bitcast" GEPs for simplicity. We could generalize this,
  1551. // but would have to prove that we're staying inside of an element being
  1552. // promoted.
  1553. if (!GEPI->hasAllZeroIndices())
  1554. return MarkUnsafe(Info, UI);
  1555. isSafePHISelectUseForScalarRepl(GEPI, Offset, Info);
  1556. } else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
  1557. if (!LI->isSimple())
  1558. return MarkUnsafe(Info, UI);
  1559. Type *LIType = LI->getType();
  1560. isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info,
  1561. LI, false /*AllowWholeAccess*/);
  1562. Info.hasALoadOrStore = true;
  1563. } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
  1564. // Store is ok if storing INTO the pointer, not storing the pointer
  1565. if (!SI->isSimple() || SI->getOperand(0) == I)
  1566. return MarkUnsafe(Info, UI);
  1567. Type *SIType = SI->getOperand(0)->getType();
  1568. isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info,
  1569. SI, false /*AllowWholeAccess*/);
  1570. Info.hasALoadOrStore = true;
  1571. } else if (isa<PHINode>(UI) || isa<SelectInst>(UI)) {
  1572. isSafePHISelectUseForScalarRepl(UI, Offset, Info);
  1573. } else {
  1574. return MarkUnsafe(Info, UI);
  1575. }
  1576. if (Info.isUnsafe)
  1577. return;
  1578. }
  1579. }
  1580. /// isSafeGEP - Check if a GEP instruction can be handled for scalar
  1581. /// replacement. It is safe when all the indices are constant, in-bounds
  1582. /// references, and when the resulting offset corresponds to an element within
  1583. /// the alloca type. The results are flagged in the Info parameter. Upon
  1584. /// return, Offset is adjusted as specified by the GEP indices.
  1585. void SROA_HLSL::isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset,
  1586. AllocaInfo &Info) {
  1587. gep_type_iterator GEPIt = gep_type_begin(GEPI), E = gep_type_end(GEPI);
  1588. if (GEPIt == E)
  1589. return;
  1590. bool NonConstant = false;
  1591. unsigned NonConstantIdxSize = 0;
  1592. // Compute the offset due to this GEP and check if the alloca has a
  1593. // component element at that offset.
  1594. SmallVector<Value *, 8> Indices(GEPI->op_begin() + 1, GEPI->op_end());
  1595. auto indicesIt = Indices.begin();
  1596. // Walk through the GEP type indices, checking the types that this indexes
  1597. // into.
  1598. uint32_t arraySize = 0;
  1599. bool isArrayIndexing = false;
  1600. for (;GEPIt != E; ++GEPIt) {
  1601. Type *Ty = *GEPIt;
  1602. if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
// Don't descend into structs when marking hasArrayIndexing and
// hasVectorIndexing; deeper levels won't affect scalar repl on the struct.
  1605. break;
  1606. }
  1607. if (GEPIt->isArrayTy()) {
  1608. arraySize = GEPIt->getArrayNumElements();
  1609. isArrayIndexing = true;
  1610. }
  1611. if (GEPIt->isVectorTy()) {
  1612. arraySize = GEPIt->getVectorNumElements();
  1613. isArrayIndexing = false;
  1614. }
  1615. // Allow dynamic indexing
  1616. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1617. if (!IdxVal) {
  1618. // for dynamic index, use array size - 1 to check the offset
  1619. *indicesIt = Constant::getIntegerValue(
  1620. Type::getInt32Ty(GEPI->getContext()), APInt(32, arraySize - 1));
  1621. if (isArrayIndexing)
  1622. Info.hasArrayIndexing = true;
  1623. else
  1624. Info.hasVectorIndexing = true;
  1625. NonConstant = true;
  1626. }
  1627. indicesIt++;
  1628. }
// Continue iterating over the remaining indices to handle the non-constant ones.
  1630. for (;GEPIt != E; ++GEPIt) {
  1631. Type *Ty = *GEPIt;
  1632. if (Ty->isArrayTy()) {
  1633. arraySize = GEPIt->getArrayNumElements();
  1634. }
  1635. if (Ty->isVectorTy()) {
  1636. arraySize = GEPIt->getVectorNumElements();
  1637. }
  1638. // Allow dynamic indexing
  1639. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1640. if (!IdxVal) {
  1641. // for dynamic index, use array size - 1 to check the offset
  1642. *indicesIt = Constant::getIntegerValue(
  1643. Type::getInt32Ty(GEPI->getContext()), APInt(32, arraySize - 1));
  1644. NonConstant = true;
  1645. }
  1646. indicesIt++;
  1647. }
  1648. // If this GEP is non-constant then the last operand must have been a
  1649. // dynamic index into a vector. Pop this now as it has no impact on the
  1650. // constant part of the offset.
  1651. if (NonConstant)
  1652. Indices.pop_back();
  1653. const DataLayout &DL = GEPI->getModule()->getDataLayout();
  1654. Offset += DL.getIndexedOffset(GEPI->getPointerOperandType(), Indices);
  1655. if (!TypeHasComponent(Info.AI->getAllocatedType(), Offset, NonConstantIdxSize,
  1656. DL))
  1657. MarkUnsafe(Info, GEPI);
  1658. }
  1659. /// isHomogeneousAggregate - Check if type T is a struct or array containing
  1660. /// elements of the same type (which is always true for arrays). If so,
  1661. /// return true with NumElts and EltTy set to the number of elements and the
  1662. /// element type, respectively.
  1663. static bool isHomogeneousAggregate(Type *T, unsigned &NumElts, Type *&EltTy) {
  1664. if (ArrayType *AT = dyn_cast<ArrayType>(T)) {
  1665. NumElts = AT->getNumElements();
  1666. EltTy = (NumElts == 0 ? nullptr : AT->getElementType());
  1667. return true;
  1668. }
  1669. if (StructType *ST = dyn_cast<StructType>(T)) {
  1670. NumElts = ST->getNumContainedTypes();
  1671. EltTy = (NumElts == 0 ? nullptr : ST->getContainedType(0));
  1672. for (unsigned n = 1; n < NumElts; ++n) {
  1673. if (ST->getContainedType(n) != EltTy)
  1674. return false;
  1675. }
  1676. return true;
  1677. }
  1678. return false;
  1679. }
  1680. /// isCompatibleAggregate - Check if T1 and T2 are either the same type or are
  1681. /// "homogeneous" aggregates with the same element type and number of elements.
  1682. static bool isCompatibleAggregate(Type *T1, Type *T2) {
  1683. if (T1 == T2)
  1684. return true;
  1685. unsigned NumElts1, NumElts2;
  1686. Type *EltTy1, *EltTy2;
  1687. if (isHomogeneousAggregate(T1, NumElts1, EltTy1) &&
  1688. isHomogeneousAggregate(T2, NumElts2, EltTy2) && NumElts1 == NumElts2 &&
  1689. EltTy1 == EltTy2)
  1690. return true;
  1691. return false;
  1692. }
  1693. /// isSafeMemAccess - Check if a load/store/memcpy operates on the entire AI
  1694. /// alloca or has an offset and size that corresponds to a component element
  1695. /// within it. The offset checked here may have been formed from a GEP with a
  1696. /// pointer bitcasted to a different type.
  1697. ///
  1698. /// If AllowWholeAccess is true, then this allows uses of the entire alloca as a
  1699. /// unit. If false, it only allows accesses known to be in a single element.
  1700. void SROA_HLSL::isSafeMemAccess(uint64_t Offset, uint64_t MemSize,
  1701. Type *MemOpType, bool isStore, AllocaInfo &Info,
  1702. Instruction *TheAccess, bool AllowWholeAccess) {
// All HLSL cares about here is Info.hasVectorIndexing, so there is nothing
// to do.
  1705. }
  1706. /// TypeHasComponent - Return true if T has a component type with the
  1707. /// specified offset and size. If Size is zero, do not check the size.
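/// For example, assuming a typical data layout where i32 and float are both
/// 4 bytes, { i32, [4 x float] } has a 4-byte component at offset 8: the
/// [4 x float] field starts at offset 4, and offset 4 within it lands exactly
/// on a float.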
  1708. bool SROA_HLSL::TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
  1709. const DataLayout &DL) {
  1710. Type *EltTy;
  1711. uint64_t EltSize;
  1712. if (StructType *ST = dyn_cast<StructType>(T)) {
  1713. const StructLayout *Layout = DL.getStructLayout(ST);
  1714. unsigned EltIdx = Layout->getElementContainingOffset(Offset);
  1715. EltTy = ST->getContainedType(EltIdx);
  1716. EltSize = DL.getTypeAllocSize(EltTy);
  1717. Offset -= Layout->getElementOffset(EltIdx);
  1718. } else if (ArrayType *AT = dyn_cast<ArrayType>(T)) {
  1719. EltTy = AT->getElementType();
  1720. EltSize = DL.getTypeAllocSize(EltTy);
  1721. if (Offset >= AT->getNumElements() * EltSize)
  1722. return false;
  1723. Offset %= EltSize;
  1724. } else if (VectorType *VT = dyn_cast<VectorType>(T)) {
  1725. EltTy = VT->getElementType();
  1726. EltSize = DL.getTypeAllocSize(EltTy);
  1727. if (Offset >= VT->getNumElements() * EltSize)
  1728. return false;
  1729. Offset %= EltSize;
  1730. } else {
  1731. return false;
  1732. }
  1733. if (Offset == 0 && (Size == 0 || EltSize == Size))
  1734. return true;
  1735. // Check if the component spans multiple elements.
  1736. if (Offset + Size > EltSize)
  1737. return false;
  1738. return TypeHasComponent(EltTy, Offset, Size, DL);
  1739. }
/// LoadVectorOrStructArray - Load a vector array like [2 x <4 x float>] from
/// element arrays like four [2 x float], or a struct array like
/// [2 x { <4 x float>, <4 x uint> }] from element arrays like
/// [2 x <4 x float>] and [2 x <4 x uint>].
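/// For example (illustrative), loading [2 x <2 x float>] from two split element
/// arrays of type [2 x float] emits, for each index i, a load from each element
/// array that is insertelement'd into a <2 x float>, which is then
/// insertvalue'd into the result array.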
  1744. static Value *LoadVectorOrStructArray(ArrayType *AT, ArrayRef<Value *> NewElts,
  1745. SmallVector<Value *, 8> &idxList,
  1746. IRBuilder<> &Builder) {
  1747. Type *EltTy = AT->getElementType();
  1748. Value *retVal = llvm::UndefValue::get(AT);
  1749. Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
  1750. uint32_t arraySize = AT->getNumElements();
  1751. for (uint32_t i = 0; i < arraySize; i++) {
  1752. Constant *idx = ConstantInt::get(i32Ty, i);
  1753. idxList.emplace_back(idx);
  1754. if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
  1755. Value *EltVal = LoadVectorOrStructArray(EltAT, NewElts, idxList, Builder);
  1756. retVal = Builder.CreateInsertValue(retVal, EltVal, i);
  1757. } else {
assert((EltTy->isVectorTy() || EltTy->isStructTy()) &&
"must be a vector or struct type");
  1760. bool isVectorTy = EltTy->isVectorTy();
  1761. Value *retVec = llvm::UndefValue::get(EltTy);
  1762. if (isVectorTy) {
  1763. for (uint32_t c = 0; c < EltTy->getVectorNumElements(); c++) {
  1764. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1765. Value *elt = Builder.CreateLoad(GEP);
  1766. retVec = Builder.CreateInsertElement(retVec, elt, c);
  1767. }
  1768. } else {
  1769. for (uint32_t c = 0; c < EltTy->getStructNumElements(); c++) {
  1770. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1771. Value *elt = Builder.CreateLoad(GEP);
  1772. retVec = Builder.CreateInsertValue(retVec, elt, c);
  1773. }
  1774. }
  1775. retVal = Builder.CreateInsertValue(retVal, retVec, i);
  1776. }
  1777. idxList.pop_back();
  1778. }
  1779. return retVal;
  1780. }
/// StoreVectorOrStructArray - Store a vector array like [2 x <4 x float>] to
/// element arrays like four [2 x float], or a struct array like
/// [2 x { <4 x float>, <4 x uint> }] to element arrays like
/// [2 x <4 x float>] and [2 x <4 x uint>].
  1785. static void StoreVectorOrStructArray(ArrayType *AT, Value *val,
  1786. ArrayRef<Value *> NewElts,
  1787. SmallVector<Value *, 8> &idxList,
  1788. IRBuilder<> &Builder) {
  1789. Type *EltTy = AT->getElementType();
  1790. Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
  1791. uint32_t arraySize = AT->getNumElements();
  1792. for (uint32_t i = 0; i < arraySize; i++) {
  1793. Value *elt = Builder.CreateExtractValue(val, i);
  1794. Constant *idx = ConstantInt::get(i32Ty, i);
  1795. idxList.emplace_back(idx);
  1796. if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
  1797. StoreVectorOrStructArray(EltAT, elt, NewElts, idxList, Builder);
  1798. } else {
assert((EltTy->isVectorTy() || EltTy->isStructTy()) &&
"must be a vector or struct type");
  1801. bool isVectorTy = EltTy->isVectorTy();
  1802. if (isVectorTy) {
  1803. for (uint32_t c = 0; c < EltTy->getVectorNumElements(); c++) {
  1804. Value *component = Builder.CreateExtractElement(elt, c);
  1805. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1806. Builder.CreateStore(component, GEP);
  1807. }
  1808. } else {
  1809. for (uint32_t c = 0; c < EltTy->getStructNumElements(); c++) {
  1810. Value *field = Builder.CreateExtractValue(elt, c);
  1811. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1812. Builder.CreateStore(field, GEP);
  1813. }
  1814. }
  1815. }
  1816. idxList.pop_back();
  1817. }
  1818. }
  1819. /// HasPadding - Return true if the specified type has any structure or
  1820. /// alignment padding in between the elements that would be split apart
  1821. /// by SROA; return false otherwise.
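/// For example, with a typical data layout { i8, i32 } has three bytes of
/// padding between its fields and returns true, while { i32, i32 } has none
/// and returns false.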
  1822. static bool HasPadding(Type *Ty, const DataLayout &DL) {
  1823. if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
  1824. Ty = ATy->getElementType();
  1825. return DL.getTypeSizeInBits(Ty) != DL.getTypeAllocSizeInBits(Ty);
  1826. }
  1827. // SROA currently handles only Arrays and Structs.
  1828. StructType *STy = cast<StructType>(Ty);
  1829. const StructLayout *SL = DL.getStructLayout(STy);
  1830. unsigned PrevFieldBitOffset = 0;
  1831. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
  1832. unsigned FieldBitOffset = SL->getElementOffsetInBits(i);
  1833. // Check to see if there is any padding between this element and the
  1834. // previous one.
  1835. if (i) {
  1836. unsigned PrevFieldEnd =
  1837. PrevFieldBitOffset + DL.getTypeSizeInBits(STy->getElementType(i - 1));
  1838. if (PrevFieldEnd < FieldBitOffset)
  1839. return true;
  1840. }
  1841. PrevFieldBitOffset = FieldBitOffset;
  1842. }
  1843. // Check for tail padding.
  1844. if (unsigned EltCount = STy->getNumElements()) {
  1845. unsigned PrevFieldEnd =
  1846. PrevFieldBitOffset +
  1847. DL.getTypeSizeInBits(STy->getElementType(EltCount - 1));
  1848. if (PrevFieldEnd < SL->getSizeInBits())
  1849. return true;
  1850. }
  1851. return false;
  1852. }
/// isSafeAllocaToScalarRepl - Check to see if the specified allocation of an
/// aggregate can be broken down into elements. Returns true if it is safe to
/// do so.
  1856. bool SROA_HLSL::isSafeAllocaToScalarRepl(AllocaInst *AI) {
  1857. // Loop over the use list of the alloca. We can only transform it if all of
  1858. // the users are safe to transform.
  1859. AllocaInfo Info(AI);
  1860. isSafeForScalarRepl(AI, 0, Info);
  1861. if (Info.isUnsafe) {
  1862. DEBUG(dbgs() << "Cannot transform: " << *AI << '\n');
  1863. return false;
  1864. }
// Dynamic vector indexing requires translating the vector into an array first.
  1866. if (Info.hasVectorIndexing)
  1867. return false;
  1868. const DataLayout &DL = AI->getModule()->getDataLayout();
  1869. // Okay, we know all the users are promotable. If the aggregate is a memcpy
  1870. // source and destination, we have to be careful. In particular, the memcpy
  1871. // could be moving around elements that live in structure padding of the LLVM
  1872. // types, but may actually be used. In these cases, we refuse to promote the
  1873. // struct.
  1874. if (Info.isMemCpySrc && Info.isMemCpyDst &&
  1875. HasPadding(AI->getAllocatedType(), DL))
  1876. return false;
  1877. return true;
  1878. }
  1879. // Copy data from srcPtr to destPtr.
  1880. static void SimplePtrCopy(Value *DestPtr, Value *SrcPtr,
  1881. llvm::SmallVector<llvm::Value *, 16> &idxList,
  1882. IRBuilder<> &Builder) {
  1883. if (idxList.size() > 1) {
  1884. DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
  1885. SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
  1886. }
  1887. llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
  1888. Builder.CreateStore(ld, DestPtr);
  1889. }
  1890. // Copy srcVal to destPtr.
  1891. static void SimpleValCopy(Value *DestPtr, Value *SrcVal,
  1892. llvm::SmallVector<llvm::Value *, 16> &idxList,
  1893. IRBuilder<> &Builder) {
  1894. Value *DestGEP = Builder.CreateInBoundsGEP(DestPtr, idxList);
  1895. Value *Val = SrcVal;
  1896. // Skip beginning pointer type.
  1897. for (unsigned i = 1; i < idxList.size(); i++) {
  1898. ConstantInt *idx = cast<ConstantInt>(idxList[i]);
  1899. Type *Ty = Val->getType();
  1900. if (Ty->isAggregateType()) {
  1901. Val = Builder.CreateExtractValue(Val, idx->getLimitedValue());
  1902. }
  1903. }
  1904. Builder.CreateStore(Val, DestGEP);
  1905. }
  1906. static void SimpleCopy(Value *Dest, Value *Src,
  1907. llvm::SmallVector<llvm::Value *, 16> &idxList,
  1908. IRBuilder<> &Builder) {
  1909. if (Src->getType()->isPointerTy())
  1910. SimplePtrCopy(Dest, Src, idxList, Builder);
  1911. else
  1912. SimpleValCopy(Dest, Src, idxList, Builder);
  1913. }
// Split an aggregate copy into element-wise load/store pairs.
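// For example (illustrative), copying a value of type
//   struct { float2 v; float s; }
// is split into one load/store pair per field: the float2 is copied with a
// vector load/store and the float with a scalar load/store. Matrices are
// copied through HLMatLoadStore intrinsic calls, and HLSL object types are
// copied as a whole via SimpleCopy.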
  1915. static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
  1916. SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
  1917. DxilTypeSystem &typeSys,
  1918. DxilFieldAnnotation *fieldAnnotation) {
  1919. if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
  1920. Constant *idx = Constant::getIntegerValue(
  1921. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  1922. idxList.emplace_back(idx);
  1923. SplitCpy(PT->getElementType(), Dest, Src, idxList, Builder, typeSys,
  1924. fieldAnnotation);
  1925. idxList.pop_back();
  1926. } else if (HLMatrixLower::IsMatrixType(Ty)) {
// If there is no fieldAnnotation, default to row-major.
// Since the value is loaded and immediately stored back, either orientation
// gives the same result.
  1929. bool bRowMajor = true;
  1930. if (fieldAnnotation) {
DXASSERT(fieldAnnotation->HasMatrixAnnotation(),
         "must have matrix annotation");
  1933. bRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
  1934. MatrixOrientation::RowMajor;
  1935. }
  1936. Module *M = Builder.GetInsertPoint()->getModule();
  1937. Value *DestGEP = Builder.CreateInBoundsGEP(Dest, idxList);
  1938. Value *SrcGEP = Builder.CreateInBoundsGEP(Src, idxList);
  1939. if (bRowMajor) {
  1940. Value *Load = HLModule::EmitHLOperationCall(
  1941. Builder, HLOpcodeGroup::HLMatLoadStore,
  1942. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty, {SrcGEP},
  1943. *M);
  1944. // Generate Matrix Store.
  1945. HLModule::EmitHLOperationCall(
  1946. Builder, HLOpcodeGroup::HLMatLoadStore,
  1947. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore), Ty,
  1948. {DestGEP, Load}, *M);
  1949. } else {
  1950. Value *Load = HLModule::EmitHLOperationCall(
  1951. Builder, HLOpcodeGroup::HLMatLoadStore,
  1952. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad), Ty, {SrcGEP},
  1953. *M);
  1954. // Generate Matrix Store.
  1955. HLModule::EmitHLOperationCall(
  1956. Builder, HLOpcodeGroup::HLMatLoadStore,
  1957. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
  1958. {DestGEP, Load}, *M);
  1959. }
  1960. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  1961. if (HLModule::IsHLSLObjectType(ST)) {
  1962. // Avoid split HLSL object.
  1963. SimpleCopy(Dest, Src, idxList, Builder);
  1964. return;
  1965. }
  1966. DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
  1967. DXASSERT(STA, "require annotation here");
  1968. for (uint32_t i = 0; i < ST->getNumElements(); i++) {
  1969. llvm::Type *ET = ST->getElementType(i);
  1970. Constant *idx = llvm::Constant::getIntegerValue(
  1971. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  1972. idxList.emplace_back(idx);
  1973. DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
  1974. SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, &EltAnnotation);
  1975. idxList.pop_back();
  1976. }
  1977. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  1978. Type *ET = AT->getElementType();
  1979. for (uint32_t i = 0; i < AT->getNumElements(); i++) {
  1980. Constant *idx = Constant::getIntegerValue(
  1981. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  1982. idxList.emplace_back(idx);
  1983. SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, fieldAnnotation);
  1984. idxList.pop_back();
  1985. }
  1986. } else {
  1987. SimpleCopy(Dest, Src, idxList, Builder);
  1988. }
  1989. }
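/// SplitPtr - Recursively walk Ty and collect in EltPtrList one GEP into Ptr
/// per leaf element (matrix, HLSL object, vector, or scalar). Struct fields
/// are expanded; arrays whose ultimate element is a basic type are kept as a
/// single pointer, and arrays of structs are not supported here.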
  1990. static void SplitPtr(Type *Ty, Value *Ptr, SmallVector<Value *, 16> &idxList,
  1991. SmallVector<Value *, 16> &EltPtrList,
  1992. IRBuilder<> &Builder) {
  1993. if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
  1994. Constant *idx = Constant::getIntegerValue(
  1995. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  1996. idxList.emplace_back(idx);
  1997. SplitPtr(PT->getElementType(), Ptr, idxList, EltPtrList, Builder);
  1998. idxList.pop_back();
  1999. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  2000. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2001. EltPtrList.emplace_back(GEP);
  2002. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2003. if (HLModule::IsHLSLObjectType(ST)) {
  2004. // Avoid split HLSL object.
  2005. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2006. EltPtrList.emplace_back(GEP);
  2007. return;
  2008. }
  2009. for (uint32_t i = 0; i < ST->getNumElements(); i++) {
  2010. llvm::Type *ET = ST->getElementType(i);
  2011. Constant *idx = llvm::Constant::getIntegerValue(
  2012. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  2013. idxList.emplace_back(idx);
  2014. SplitPtr(ET, Ptr, idxList, EltPtrList, Builder);
  2015. idxList.pop_back();
  2016. }
  2017. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  2018. if (AT->getNumContainedTypes() == 0) {
  2019. // Skip case like [0 x %struct].
  2020. return;
  2021. }
  2022. Type *ElTy = AT->getElementType();
  2023. SmallVector<ArrayType *, 4> nestArrayTys;
  2024. nestArrayTys.emplace_back(AT);
// Support multiple levels of array nesting.
  2026. while (ElTy->isArrayTy()) {
  2027. ArrayType *ElAT = cast<ArrayType>(ElTy);
  2028. nestArrayTys.emplace_back(ElAT);
  2029. ElTy = ElAT->getElementType();
  2030. }
if (!ElTy->isStructTy() || HLMatrixLower::IsMatrixType(ElTy)) {
// Do not split an array of basic types; keep the whole array pointer.
Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
EltPtrList.emplace_back(GEP);
} else {
DXASSERT(0, "array of struct is not supported when splitting pointers");
}
  2040. } else {
  2041. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2042. EltPtrList.emplace_back(GEP);
  2043. }
  2044. }
// Handle the case where bitcast (gep ptr, 0, 0) has been folded into bitcast ptr.
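// For example (illustrative): if a 16-byte memcpy operand has type
//   struct Outer { struct Inner { float4 f; } i; }*
// the size already matches at the outer struct, but single-element structs
// are peeled further, so PatchZeroIdxGEP rebuilds the operand as
//   gep %ptr, i32 0, i32 0, i32 0   ; a float4*
// restoring the zero-index GEP that was folded into a plain bitcast.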
  2046. static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsigned size, unsigned level) {
  2047. unsigned ptrSize = DL.getTypeAllocSize(Ty);
  2048. // Size match, return current level.
  2049. if (ptrSize == size) {
// For a single-element struct, go one level deeper while the size stays the
// same. This leaves the memcpy at the deeper level when flattening.
  2052. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2053. if (ST->getNumElements() == 1) {
  2054. return MatchSizeByCheckElementType(ST->getElementType(0), DL, size, level+1);
  2055. }
  2056. }
// Don't do this for arrays; arrays are flattened as a struct of arrays.
  2059. return level;
  2060. }
// Adding zero indices can never make ptrSize larger.
  2062. if (ptrSize < size)
  2063. return 0;
  2064. // ptrSize > size.
  2065. // Try to use element type to make size match.
  2066. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2067. return MatchSizeByCheckElementType(ST->getElementType(0), DL, size, level+1);
  2068. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  2069. return MatchSizeByCheckElementType(AT->getElementType(), DL, size, level+1);
  2070. } else {
  2071. return 0;
  2072. }
  2073. }
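// Rebuild a zero-index GEP of depth 'level' on Ptr, bitcast it back to the
// raw pointer type, and point the memcpy at the new cast instead of RawPtr.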
  2074. static void PatchZeroIdxGEP(Value *Ptr, Value *RawPtr, MemCpyInst *MI,
  2075. unsigned level, IRBuilder<> &Builder) {
  2076. Value *zeroIdx = Builder.getInt32(0);
  2077. SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
  2078. Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
// Use BitCastInst::Create directly (rather than the builder) so the
// zero-index GEP is not folded away.
  2080. CastInst *Cast =
  2081. BitCastInst::Create(Instruction::BitCast, GEP, RawPtr->getType());
  2082. Builder.Insert(Cast);
  2083. MI->replaceUsesOfWith(RawPtr, Cast);
  2084. // Remove RawPtr if possible.
  2085. if (RawPtr->user_empty()) {
  2086. if (Instruction *I = dyn_cast<Instruction>(RawPtr)) {
  2087. I->eraseFromParent();
  2088. }
  2089. }
  2090. }
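// Undo the GEP folding that inlining can introduce around a memcpy: when a
// zero-index GEP has been collapsed into a plain bitcast of the pointer,
// rebuild it so later memcpy splitting sees a properly typed pointer.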
  2091. void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI,
  2092. const DataLayout &DL) {
  2093. Value *Dest = MI->getRawDest();
  2094. Value *Src = MI->getRawSource();
// Only remove one level of bitcast generated by inlining.
  2096. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  2097. Dest = BC->getOperand(0);
  2098. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  2099. Src = BC->getOperand(0);
  2100. IRBuilder<> Builder(MI);
  2101. ConstantInt *zero = Builder.getInt32(0);
  2102. Type *DestTy = Dest->getType()->getPointerElementType();
  2103. Type *SrcTy = Src->getType()->getPointerElementType();
  2104. // Support case when bitcast (gep ptr, 0,0) is transformed into
  2105. // bitcast ptr.
  2106. // Also replace (gep ptr, 0) with ptr.
  2107. ConstantInt *Length = cast<ConstantInt>(MI->getLength());
  2108. unsigned size = Length->getLimitedValue();
  2109. if (unsigned level = MatchSizeByCheckElementType(DestTy, DL, size, 0)) {
  2110. PatchZeroIdxGEP(Dest, MI->getRawDest(), MI, level, Builder);
  2111. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Dest)) {
  2112. if (GEP->getNumIndices() == 1) {
  2113. Value *idx = *GEP->idx_begin();
  2114. if (idx == zero) {
  2115. GEP->replaceAllUsesWith(GEP->getPointerOperand());
  2116. }
  2117. }
  2118. }
  2119. if (unsigned level = MatchSizeByCheckElementType(SrcTy, DL, size, 0)) {
  2120. PatchZeroIdxGEP(Src, MI->getRawSource(), MI, level, Builder);
  2121. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
  2122. if (GEP->getNumIndices() == 1) {
  2123. Value *idx = *GEP->idx_begin();
  2124. if (idx == zero) {
  2125. GEP->replaceAllUsesWith(GEP->getPointerOperand());
  2126. }
  2127. }
  2128. }
  2129. }
  2130. void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(Module &M) {
  2131. const DataLayout &DL = M.getDataLayout();
  2132. for (Function &F : M.functions()) {
  2133. for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
  2134. for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
  2135. // Avoid invalidating the iterator.
  2136. Instruction *I = BI++;
  2137. if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
  2138. PatchMemCpyWithZeroIdxGEP(MI, DL);
  2139. }
  2140. }
  2141. }
  2142. }
  2143. }
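// Erase the memcpy and, if its (possibly bitcast or GEP) operands become
// dead, erase those too.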
  2144. static void DeleteMemcpy(MemCpyInst *MI) {
  2145. Value *Op0 = MI->getOperand(0);
  2146. Value *Op1 = MI->getOperand(1);
  2147. // delete memcpy
  2148. MI->eraseFromParent();
  2149. if (Instruction *op0 = dyn_cast<Instruction>(Op0)) {
  2150. if (op0->user_empty())
  2151. op0->eraseFromParent();
  2152. }
  2153. if (Instruction *op1 = dyn_cast<Instruction>(Op1)) {
  2154. if (op1->user_empty())
  2155. op1->eraseFromParent();
  2156. }
  2157. }
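// Split one memcpy into element-wise copies via SplitCpy. Self-copies are
// simply deleted, and copies between pointers of different element types are
// left untouched.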
  2158. void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
  2159. DxilFieldAnnotation *fieldAnnotation,
  2160. DxilTypeSystem &typeSys) {
  2161. Value *Dest = MI->getRawDest();
  2162. Value *Src = MI->getRawSource();
// Only remove one level of bitcast generated by inlining.
  2164. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  2165. Dest = BC->getOperand(0);
  2166. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  2167. Src = BC->getOperand(0);
  2168. if (Dest == Src) {
  2169. // delete self copy.
  2170. DeleteMemcpy(MI);
  2171. return;
  2172. }
  2173. IRBuilder<> Builder(MI);
  2174. Type *DestTy = Dest->getType()->getPointerElementType();
  2175. Type *SrcTy = Src->getType()->getPointerElementType();
// Compare element types rather than pointer types so that copies between
// different address spaces are still allowed.
  2177. if (DestTy != SrcTy) {
  2178. return;
  2179. }
  2180. llvm::SmallVector<llvm::Value *, 16> idxList;
// Split the copy element by element.
// Matrices are treated as scalar types and never use memcpy, so passing a
// null fieldAnnotation through is safe here.
  2184. SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, typeSys,
  2185. fieldAnnotation);
  2186. // delete memcpy
  2187. DeleteMemcpy(MI);
  2188. }
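// Walk every instruction in F and split each memcpy found. A sketch of the
// overall effect (illustrative HLSL):
//   struct S { float a; float b; };
//   S s1 = s2;        // the front end emits a memcpy
// becomes, after splitting:
//   s1.a = s2.a;      // element-wise load/store
//   s1.b = s2.b;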
  2189. void MemcpySplitter::Split(llvm::Function &F) {
  2190. const DataLayout &DL = F.getParent()->getDataLayout();
  2191. // Walk all instruction in the function.
  2192. for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
  2193. for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
  2194. // Avoid invalidating the iterator.
  2195. Instruction *I = BI++;
  2196. if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
// Matrices are treated as scalar types and never use memcpy, so a null
// fieldAnnotation is safe here.
  2199. SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys);
  2200. }
  2201. }
  2202. }
  2203. }
  2204. //===----------------------------------------------------------------------===//
  2205. // SRoA Helper
  2206. //===----------------------------------------------------------------------===//
/// RewriteForGEP - Rewrite the GEP to be relative to a new element when the
/// GEP selects a struct field that has its own new element. Otherwise, create
/// GEPs into every new element and rewrite the GEP's users against those new
/// GEPs.
  2210. void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
  2211. assert(OldVal == GEP->getPointerOperand() && "");
  2212. Value *NewPointer = nullptr;
  2213. SmallVector<Value *, 8> NewArgs;
  2214. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  2215. for (; GEPIt != E; ++GEPIt) {
  2216. if (GEPIt->isStructTy()) {
  2217. // must be const
  2218. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  2219. assert(IdxVal->getLimitedValue() < NewElts.size() && "");
  2220. NewPointer = NewElts[IdxVal->getLimitedValue()];
// This idx selects NewPointer; it is not part of the new GEP's indices.
  2222. GEPIt++;
  2223. break;
  2224. } else if (GEPIt->isArrayTy()) {
  2225. // Add array idx.
  2226. NewArgs.push_back(GEPIt.getOperand());
  2227. } else if (GEPIt->isPointerTy()) {
  2228. // Add pointer idx.
  2229. NewArgs.push_back(GEPIt.getOperand());
  2230. } else if (GEPIt->isVectorTy()) {
  2231. // Add vector idx.
  2232. NewArgs.push_back(GEPIt.getOperand());
  2233. } else {
  2234. llvm_unreachable("should break from structTy");
  2235. }
  2236. }
  2237. if (NewPointer) {
  2238. // Struct split.
  2239. // Add rest of idx.
  2240. for (; GEPIt != E; ++GEPIt) {
  2241. NewArgs.push_back(GEPIt.getOperand());
  2242. }
  2243. // If only 1 level struct, just use the new pointer.
  2244. Value *NewGEP = NewPointer;
  2245. if (NewArgs.size() > 1) {
  2246. NewGEP = Builder.CreateInBoundsGEP(NewPointer, NewArgs);
  2247. NewGEP->takeName(GEP);
  2248. }
  2249. assert(NewGEP->getType() == GEP->getType() && "type mismatch");
  2250. GEP->replaceAllUsesWith(NewGEP);
  2251. if (isa<Instruction>(GEP))
  2252. DeadInsts.push_back(GEP);
  2253. } else {
  2254. // End at array of basic type.
  2255. Type *Ty = GEP->getType()->getPointerElementType();
  2256. if (Ty->isVectorTy() || Ty->isStructTy() || Ty->isArrayTy()) {
  2257. SmallVector<Value *, 8> NewArgs;
  2258. NewArgs.append(GEP->idx_begin(), GEP->idx_end());
  2259. SmallVector<Value *, 8> NewGEPs;
  2260. // create new geps
  2261. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2262. Value *NewGEP = Builder.CreateGEP(nullptr, NewElts[i], NewArgs);
  2263. NewGEPs.emplace_back(NewGEP);
  2264. }
  2265. SROA_Helper helper(GEP, NewGEPs, DeadInsts);
  2266. helper.RewriteForScalarRepl(GEP, Builder);
  2267. for (Value *NewGEP : NewGEPs) {
  2268. if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
  2269. // Delete unused newGEP.
  2270. cast<Instruction>(NewGEP)->eraseFromParent();
  2271. }
  2272. }
  2273. if (GEP->user_empty() && isa<Instruction>(GEP))
  2274. DeadInsts.push_back(GEP);
  2275. } else {
  2276. Value *vecIdx = NewArgs.back();
  2277. if (ConstantInt *immVecIdx = dyn_cast<ConstantInt>(vecIdx)) {
  2278. // Replace vecArray[arrayIdx][immVecIdx]
  2279. // with scalarArray_immVecIdx[arrayIdx]
  2280. // Pop the vecIdx.
  2281. NewArgs.pop_back();
  2282. Value *NewGEP = NewElts[immVecIdx->getLimitedValue()];
  2283. if (NewArgs.size() > 1) {
  2284. NewGEP = Builder.CreateInBoundsGEP(NewGEP, NewArgs);
  2285. NewGEP->takeName(GEP);
  2286. }
  2287. assert(NewGEP->getType() == GEP->getType() && "type mismatch");
  2288. GEP->replaceAllUsesWith(NewGEP);
  2289. if (isa<Instruction>(GEP))
  2290. DeadInsts.push_back(GEP);
  2291. } else {
  2292. // dynamic vector indexing.
  2293. assert(0 && "should not reach here");
  2294. }
  2295. }
  2296. }
  2297. }
  2298. /// isVectorOrStructArray - Check if T is array of vector or struct.
  2299. static bool isVectorOrStructArray(Type *T) {
  2300. if (!T->isArrayTy())
  2301. return false;
  2302. T = HLModule::GetArrayEltTy(T);
  2303. return T->isStructTy() || T->isVectorTy();
  2304. }
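/// SimplifyStructValUsage - Given a struct value whose per-field values are
/// already known (Elts), forward extractvalue uses directly to the elements
/// and rebuild insertvalue chains so the aggregate value itself can die.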
  2305. static void SimplifyStructValUsage(Value *StructVal, std::vector<Value *> Elts,
  2306. SmallVectorImpl<Value *> &DeadInsts) {
  2307. for (User *user : StructVal->users()) {
  2308. if (ExtractValueInst *Extract = dyn_cast<ExtractValueInst>(user)) {
  2309. DXASSERT(Extract->getNumIndices() == 1, "only support 1 index case");
  2310. unsigned index = Extract->getIndices()[0];
  2311. Value *Elt = Elts[index];
  2312. Extract->replaceAllUsesWith(Elt);
  2313. DeadInsts.emplace_back(Extract);
  2314. } else if (InsertValueInst *Insert = dyn_cast<InsertValueInst>(user)) {
  2315. DXASSERT(Insert->getNumIndices() == 1, "only support 1 index case");
  2316. unsigned index = Insert->getIndices()[0];
  2317. if (Insert->getAggregateOperand() == StructVal) {
  2318. // Update field.
  2319. std::vector<Value *> NewElts = Elts;
  2320. NewElts[index] = Insert->getInsertedValueOperand();
  2321. SimplifyStructValUsage(Insert, NewElts, DeadInsts);
  2322. } else {
  2323. // Insert to another bigger struct.
  2324. IRBuilder<> Builder(Insert);
  2325. Value *TmpStructVal = UndefValue::get(StructVal->getType());
  2326. for (unsigned i = 0; i < Elts.size(); i++) {
  2327. TmpStructVal =
  2328. Builder.CreateInsertValue(TmpStructVal, Elts[i], {i});
  2329. }
  2330. Insert->replaceUsesOfWith(StructVal, TmpStructVal);
  2331. }
  2332. }
  2333. }
  2334. }
  2335. /// RewriteForLoad - Replace OldVal with flattened NewElts in LoadInst.
  2336. void SROA_Helper::RewriteForLoad(LoadInst *LI) {
  2337. Type *LIType = LI->getType();
  2338. Type *ValTy = OldVal->getType()->getPointerElementType();
  2339. IRBuilder<> Builder(LI);
  2340. if (LIType->isVectorTy()) {
// Replace:
//   %res = load <2 x i32>* %alloc
// with:
//   %load.0 = load i32* %alloc.0
//   %insert.0 = insertelement <2 x i32> undef, i32 %load.0, i32 0
//   %load.1 = load i32* %alloc.1
//   %insert = insertelement <2 x i32> %insert.0, i32 %load.1, i32 1
  2348. Value *Insert = UndefValue::get(LIType);
  2349. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2350. Value *Load = Builder.CreateLoad(NewElts[i], "load");
  2351. Insert = Builder.CreateInsertElement(Insert, Load, i, "insert");
  2352. }
  2353. LI->replaceAllUsesWith(Insert);
  2354. DeadInsts.push_back(LI);
  2355. } else if (isCompatibleAggregate(LIType, ValTy)) {
  2356. if (isVectorOrStructArray(LIType)) {
// Replace:
//   %res = load [2 x <2 x float>]* %alloc
// with a recursive per-element sequence that loads every scalar from the
// split allocas (%alloc.0 and %alloc.1, each of type [2 x float]*) and
// rebuilds the [2 x <2 x float>] value with insertelement/insertvalue.
  2365. Type *i32Ty = Type::getInt32Ty(LIType->getContext());
  2366. Value *zero = ConstantInt::get(i32Ty, 0);
  2367. SmallVector<Value *, 8> idxList;
  2368. idxList.emplace_back(zero);
  2369. Value *newLd =
  2370. LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
  2371. LI->replaceAllUsesWith(newLd);
  2372. DeadInsts.push_back(LI);
  2373. } else {
// Replace:
//   %res = load { i32, i32 }* %alloc
// with:
//   %load.0 = load i32* %alloc.0
//   %insert.0 = insertvalue { i32, i32 } undef, i32 %load.0, 0
//   %load.1 = load i32* %alloc.1
//   %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1
// (This also works for arrays instead of structs.)
  2383. Module *M = LI->getModule();
  2384. Value *Insert = UndefValue::get(LIType);
  2385. std::vector<Value *> LdElts(NewElts.size());
  2386. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2387. Value *Ptr = NewElts[i];
  2388. Type *Ty = Ptr->getType()->getPointerElementType();
  2389. Value *Load = nullptr;
  2390. if (!HLMatrixLower::IsMatrixType(Ty))
  2391. Load = Builder.CreateLoad(Ptr, "load");
  2392. else {
  2393. // Generate Matrix Load.
  2394. Load = HLModule::EmitHLOperationCall(
  2395. Builder, HLOpcodeGroup::HLMatLoadStore,
  2396. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty,
  2397. {Ptr}, *M);
  2398. }
  2399. LdElts[i] = Load;
  2400. Insert = Builder.CreateInsertValue(Insert, Load, i, "insert");
  2401. }
  2402. LI->replaceAllUsesWith(Insert);
  2403. if (LIType->isStructTy()) {
  2404. SimplifyStructValUsage(Insert, LdElts, DeadInsts);
  2405. }
  2406. DeadInsts.push_back(LI);
  2407. }
  2408. } else {
llvm_unreachable("other types don't need rewrite");
  2410. }
  2411. }
  2412. /// RewriteForStore - Replace OldVal with flattened NewElts in StoreInst.
  2413. void SROA_Helper::RewriteForStore(StoreInst *SI) {
  2414. Value *Val = SI->getOperand(0);
  2415. Type *SIType = Val->getType();
  2416. IRBuilder<> Builder(SI);
  2417. Type *ValTy = OldVal->getType()->getPointerElementType();
  2418. if (SIType->isVectorTy()) {
// Replace:
//   store <2 x float> %val, <2 x float>* %alloc
// with:
//   %val.0 = extractelement <2 x float> %val, i32 0
//   store float %val.0, float* %alloc.0
//   %val.1 = extractelement <2 x float> %val, i32 1
//   store float %val.1, float* %alloc.1
  2426. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2427. Value *Extract = Builder.CreateExtractElement(Val, i, Val->getName());
  2428. Builder.CreateStore(Extract, NewElts[i]);
  2429. }
  2430. DeadInsts.push_back(SI);
  2431. } else if (isCompatibleAggregate(SIType, ValTy)) {
  2432. if (isVectorOrStructArray(SIType)) {
// Replace:
//   store [2 x <2 x i32>] %val, [2 x <2 x i32>]* %alloc, align 16
// with:
//   %val.0 = extractvalue [2 x <2 x i32>] %val, 0
//   %alloc.0.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.0, i32 0, i32 0
//   %val.0.0 = extractelement <2 x i32> %val.0, i64 0
//   store i32 %val.0.0, i32* %alloc.0.0
//   %alloc.1.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.1, i32 0, i32 0
//   %val.0.1 = extractelement <2 x i32> %val.0, i64 1
//   store i32 %val.0.1, i32* %alloc.1.0
//   %val.1 = extractvalue [2 x <2 x i32>] %val, 1
//   %alloc.0.1 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.0, i32 0, i32 1
//   %val.1.0 = extractelement <2 x i32> %val.1, i64 0
//   store i32 %val.1.0, i32* %alloc.0.1
//   %alloc.1.1 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.1, i32 0, i32 1
//   %val.1.1 = extractelement <2 x i32> %val.1, i64 1
//   store i32 %val.1.1, i32* %alloc.1.1
  2454. ArrayType *AT = cast<ArrayType>(SIType);
  2455. Type *i32Ty = Type::getInt32Ty(SIType->getContext());
  2456. Value *zero = ConstantInt::get(i32Ty, 0);
  2457. SmallVector<Value *, 8> idxList;
  2458. idxList.emplace_back(zero);
  2459. StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
  2460. DeadInsts.push_back(SI);
  2461. } else {
  2462. // Replace:
  2463. // store { i32, i32 } %val, { i32, i32 }* %alloc
  2464. // with:
  2465. // %val.0 = extractvalue { i32, i32 } %val, 0
  2466. // store i32 %val.0, i32* %alloc.0
  2467. // %val.1 = extractvalue { i32, i32 } %val, 1
  2468. // store i32 %val.1, i32* %alloc.1
  2469. // (Also works for arrays instead of structs)
  2470. Module *M = SI->getModule();
  2471. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2472. Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
  2473. if (!HLMatrixLower::IsMatrixType(Extract->getType())) {
  2474. Builder.CreateStore(Extract, NewElts[i]);
  2475. } else {
  2476. // Generate Matrix Store.
  2477. HLModule::EmitHLOperationCall(
  2478. Builder, HLOpcodeGroup::HLMatLoadStore,
  2479. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore),
  2480. Extract->getType(), {NewElts[i], Extract}, *M);
  2481. }
  2482. }
  2483. DeadInsts.push_back(SI);
  2484. }
  2485. } else {
llvm_unreachable("other types don't need rewrite");
  2487. }
  2488. }
  2489. /// RewriteMemIntrin - MI is a memcpy/memset/memmove from or to AI.
  2490. /// Rewrite it to copy or set the elements of the scalarized memory.
  2491. void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Instruction *Inst) {
  2492. // If this is a memcpy/memmove, construct the other pointer as the
  2493. // appropriate type. The "Other" pointer is the pointer that goes to memory
  2494. // that doesn't have anything to do with the alloca that we are promoting. For
  2495. // memset, this Value* stays null.
  2496. Value *OtherPtr = nullptr;
  2497. unsigned MemAlignment = MI->getAlignment();
  2498. if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove/memcopy
  2499. if (Inst == MTI->getRawDest())
  2500. OtherPtr = MTI->getRawSource();
  2501. else {
  2502. assert(Inst == MTI->getRawSource());
  2503. OtherPtr = MTI->getRawDest();
  2504. }
  2505. }
  2506. // If there is an other pointer, we want to convert it to the same pointer
  2507. // type as AI has, so we can GEP through it safely.
  2508. if (OtherPtr) {
  2509. unsigned AddrSpace =
  2510. cast<PointerType>(OtherPtr->getType())->getAddressSpace();
  2511. // Remove bitcasts and all-zero GEPs from OtherPtr. This is an
  2512. // optimization, but it's also required to detect the corner case where
  2513. // both pointer operands are referencing the same memory, and where
// OtherPtr may be a bitcast or GEP that is currently being rewritten. (This
  2515. // function is only called for mem intrinsics that access the whole
  2516. // aggregate, so non-zero GEPs are not an issue here.)
  2517. OtherPtr = OtherPtr->stripPointerCasts();
  2518. // Copying the alloca to itself is a no-op: just delete it.
  2519. if (OtherPtr == OldVal || OtherPtr == NewElts[0]) {
  2520. // This code will run twice for a no-op memcpy -- once for each operand.
  2521. // Put only one reference to MI on the DeadInsts list.
  2522. for (SmallVectorImpl<Value *>::const_iterator I = DeadInsts.begin(),
  2523. E = DeadInsts.end();
  2524. I != E; ++I)
  2525. if (*I == MI)
  2526. return;
  2527. DeadInsts.push_back(MI);
  2528. return;
  2529. }
  2530. // If the pointer is not the right type, insert a bitcast to the right
  2531. // type.
  2532. Type *NewTy =
  2533. PointerType::get(OldVal->getType()->getPointerElementType(), AddrSpace);
  2534. if (OtherPtr->getType() != NewTy)
  2535. OtherPtr = new BitCastInst(OtherPtr, NewTy, OtherPtr->getName(), MI);
  2536. }
  2537. // Process each element of the aggregate.
  2538. bool SROADest = MI->getRawDest() == Inst;
  2539. Constant *Zero = Constant::getNullValue(Type::getInt32Ty(MI->getContext()));
  2540. const DataLayout &DL = MI->getModule()->getDataLayout();
  2541. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2542. // If this is a memcpy/memmove, emit a GEP of the other element address.
  2543. Value *OtherElt = nullptr;
  2544. unsigned OtherEltAlign = MemAlignment;
  2545. if (OtherPtr) {
  2546. Value *Idx[2] = {Zero,
  2547. ConstantInt::get(Type::getInt32Ty(MI->getContext()), i)};
  2548. OtherElt = GetElementPtrInst::CreateInBounds(
  2549. OtherPtr, Idx, OtherPtr->getName() + "." + Twine(i), MI);
  2550. uint64_t EltOffset;
  2551. PointerType *OtherPtrTy = cast<PointerType>(OtherPtr->getType());
  2552. Type *OtherTy = OtherPtrTy->getElementType();
  2553. if (StructType *ST = dyn_cast<StructType>(OtherTy)) {
  2554. EltOffset = DL.getStructLayout(ST)->getElementOffset(i);
  2555. } else {
  2556. Type *EltTy = cast<SequentialType>(OtherTy)->getElementType();
  2557. EltOffset = DL.getTypeAllocSize(EltTy) * i;
  2558. }
  2559. // The alignment of the other pointer is the guaranteed alignment of the
  2560. // element, which is affected by both the known alignment of the whole
  2561. // mem intrinsic and the alignment of the element. If the alignment of
// the memcpy (e.g.) is 32 but the element is at a 4-byte offset, then the
  2563. // known alignment is just 4 bytes.
  2564. OtherEltAlign = (unsigned)MinAlign(OtherEltAlign, EltOffset);
  2565. }
  2566. Value *EltPtr = NewElts[i];
  2567. Type *EltTy = cast<PointerType>(EltPtr->getType())->getElementType();
  2568. // If we got down to a scalar, insert a load or store as appropriate.
  2569. if (EltTy->isSingleValueType()) {
  2570. if (isa<MemTransferInst>(MI)) {
  2571. if (SROADest) {
  2572. // From Other to Alloca.
  2573. Value *Elt = new LoadInst(OtherElt, "tmp", false, OtherEltAlign, MI);
  2574. new StoreInst(Elt, EltPtr, MI);
  2575. } else {
  2576. // From Alloca to Other.
  2577. Value *Elt = new LoadInst(EltPtr, "tmp", MI);
  2578. new StoreInst(Elt, OtherElt, false, OtherEltAlign, MI);
  2579. }
  2580. continue;
  2581. }
  2582. assert(isa<MemSetInst>(MI));
  2583. // If the stored element is zero (common case), just store a null
  2584. // constant.
  2585. Constant *StoreVal;
  2586. if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getArgOperand(1))) {
  2587. if (CI->isZero()) {
  2588. StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0>
  2589. } else {
  2590. // If EltTy is a vector type, get the element type.
  2591. Type *ValTy = EltTy->getScalarType();
  2592. // Construct an integer with the right value.
  2593. unsigned EltSize = DL.getTypeSizeInBits(ValTy);
  2594. APInt OneVal(EltSize, CI->getZExtValue());
  2595. APInt TotalVal(OneVal);
  2596. // Set each byte.
  2597. for (unsigned i = 0; 8 * i < EltSize; ++i) {
  2598. TotalVal = TotalVal.shl(8);
  2599. TotalVal |= OneVal;
  2600. }
  2601. // Convert the integer value to the appropriate type.
  2602. StoreVal = ConstantInt::get(CI->getContext(), TotalVal);
  2603. if (ValTy->isPointerTy())
  2604. StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy);
  2605. else if (ValTy->isFloatingPointTy())
  2606. StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy);
  2607. assert(StoreVal->getType() == ValTy && "Type mismatch!");
  2608. // If the requested value was a vector constant, create it.
  2609. if (EltTy->isVectorTy()) {
  2610. unsigned NumElts = cast<VectorType>(EltTy)->getNumElements();
  2611. StoreVal = ConstantVector::getSplat(NumElts, StoreVal);
  2612. }
  2613. }
  2614. new StoreInst(StoreVal, EltPtr, MI);
  2615. continue;
  2616. }
  2617. // Otherwise, if we're storing a byte variable, use a memset call for
  2618. // this element.
  2619. }
  2620. unsigned EltSize = DL.getTypeAllocSize(EltTy);
  2621. if (!EltSize)
  2622. continue;
  2623. IRBuilder<> Builder(MI);
  2624. // Finally, insert the meminst for this element.
  2625. if (isa<MemSetInst>(MI)) {
  2626. Builder.CreateMemSet(EltPtr, MI->getArgOperand(1), EltSize,
  2627. MI->isVolatile());
  2628. } else {
  2629. assert(isa<MemTransferInst>(MI));
  2630. Value *Dst = SROADest ? EltPtr : OtherElt; // Dest ptr
  2631. Value *Src = SROADest ? OtherElt : EltPtr; // Src ptr
  2632. if (isa<MemCpyInst>(MI))
  2633. Builder.CreateMemCpy(Dst, Src, EltSize, OtherEltAlign,
  2634. MI->isVolatile());
  2635. else
  2636. Builder.CreateMemMove(Dst, Src, EltSize, OtherEltAlign,
  2637. MI->isVolatile());
  2638. }
  2639. }
  2640. DeadInsts.push_back(MI);
  2641. }
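/// RewriteBitCast - Rewrite a bitcast from a derived-struct pointer to a
/// parent-struct pointer as a zero-index GEP and then flatten that GEP. Only
/// bitcasts where the destination struct is reachable as the first (possibly
/// nested) field of the source struct are supported.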
  2642. void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
  2643. Type *DstTy = BCI->getType();
  2644. Value *Val = BCI->getOperand(0);
  2645. Type *SrcTy = Val->getType();
  2646. if (!DstTy->isPointerTy()) {
  2647. assert(0 && "Type mismatch.");
  2648. return;
  2649. }
  2650. if (!SrcTy->isPointerTy()) {
  2651. assert(0 && "Type mismatch.");
  2652. return;
  2653. }
  2654. DstTy = DstTy->getPointerElementType();
  2655. SrcTy = SrcTy->getPointerElementType();
  2656. if (!DstTy->isStructTy()) {
  2657. assert(0 && "Type mismatch.");
  2658. return;
  2659. }
  2660. if (!SrcTy->isStructTy()) {
  2661. assert(0 && "Type mismatch.");
  2662. return;
  2663. }
  2664. // Only support bitcast to parent struct type.
  2665. StructType *DstST = cast<StructType>(DstTy);
  2666. StructType *SrcST = cast<StructType>(SrcTy);
  2667. bool bTypeMatch = false;
  2668. unsigned level = 0;
  2669. while (SrcST) {
  2670. level++;
  2671. Type *EltTy = SrcST->getElementType(0);
  2672. if (EltTy == DstST) {
  2673. bTypeMatch = true;
  2674. break;
  2675. }
  2676. SrcST = dyn_cast<StructType>(EltTy);
  2677. }
  2678. if (!bTypeMatch) {
  2679. assert(0 && "Type mismatch.");
  2680. return;
  2681. }
  2682. std::vector<Value*> idxList(level+1);
  2683. ConstantInt *zeroIdx = ConstantInt::get(Type::getInt32Ty(Val->getContext()), 0);
  2684. for (unsigned i=0;i<(level+1);i++)
  2685. idxList[i] = zeroIdx;
  2686. IRBuilder<> Builder(BCI);
  2687. Instruction *GEP = cast<Instruction>(Builder.CreateInBoundsGEP(Val, idxList));
  2688. BCI->replaceAllUsesWith(GEP);
  2689. BCI->eraseFromParent();
  2690. IRBuilder<> GEPBuilder(GEP);
  2691. RewriteForGEP(cast<GEPOperator>(GEP), GEPBuilder);
  2692. }
  2693. /// RewriteCall - Replace OldVal with flattened NewElts in CallInst.
  2694. void SROA_Helper::RewriteCall(CallInst *CI) {
  2695. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2696. Function *F = CI->getCalledFunction();
  2697. if (group != HLOpcodeGroup::NotHL) {
  2698. unsigned opcode = GetHLOpcode(CI);
  2699. if (group == HLOpcodeGroup::HLIntrinsic) {
  2700. IntrinsicOp IOP = static_cast<IntrinsicOp>(opcode);
  2701. switch (IOP) {
  2702. case IntrinsicOp::MOP_Append: {
// Buffer Append is already expanded during codegen, so this must be an
// OutputStream Append.
  2705. SmallVector<Value *, 4> flatArgs;
  2706. for (Value *arg : CI->arg_operands()) {
  2707. if (arg == OldVal) {
// Flatten the aggregate argument into its elements.
// Every Elt has pointer type, which is fine for Append.
  2711. for (Value *Elt : NewElts)
  2712. flatArgs.emplace_back(Elt);
  2713. } else
  2714. flatArgs.emplace_back(arg);
  2715. }
  2716. SmallVector<Type *, 4> flatParamTys;
  2717. for (Value *arg : flatArgs)
  2718. flatParamTys.emplace_back(arg->getType());
  2719. // Don't need flat return type for Append.
  2720. FunctionType *flatFuncTy =
  2721. FunctionType::get(CI->getType(), flatParamTys, false);
  2722. Function *flatF =
  2723. GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode);
  2724. IRBuilder<> Builder(CI);
  2725. // Append return void, don't need to replace CI with flatCI.
  2726. Builder.CreateCall(flatF, flatArgs);
  2727. DeadInsts.push_back(CI);
  2728. } break;
  2729. default:
  2730. DXASSERT(0, "cannot flatten hlsl intrinsic.");
  2731. }
  2732. }
  2733. // TODO: check other high level dx operations if need to.
  2734. } else {
DXASSERT(0, "should have been handled by inlining");
  2736. }
  2737. }
  2738. /// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
  2739. void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
  2740. if (GEPOperator *GEP = dyn_cast<GEPOperator>(CE)) {
  2741. if (OldVal == GEP->getPointerOperand()) {
  2742. // Flatten GEP.
  2743. RewriteForGEP(GEP, Builder);
  2744. return;
  2745. }
  2746. }
  2747. // Skip unused CE.
  2748. if (CE->use_empty())
  2749. return;
  2750. Instruction *constInst = CE->getAsInstruction();
  2751. Builder.Insert(constInst);
  2752. // Replace CE with constInst.
  2753. for (Value::use_iterator UI = CE->use_begin(), E = CE->use_end(); UI != E;) {
  2754. Use &TheUse = *UI++;
  2755. if (isa<Instruction>(TheUse.getUser()))
  2756. TheUse.set(constInst);
  2757. else {
  2758. RewriteForConstExpr(cast<ConstantExpr>(TheUse.getUser()), Builder);
  2759. }
  2760. }
  2761. }
  2762. /// RewriteForScalarRepl - OldVal is being split into NewElts, so rewrite
  2763. /// users of V, which references it, to use the separate elements.
  2764. void SROA_Helper::RewriteForScalarRepl(Value *V, IRBuilder<> &Builder) {
  2765. for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;) {
  2766. Use &TheUse = *UI++;
  2767. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(TheUse.getUser())) {
  2768. RewriteForConstExpr(CE, Builder);
  2769. continue;
  2770. }
  2771. Instruction *User = cast<Instruction>(TheUse.getUser());
  2772. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  2773. IRBuilder<> Builder(GEP);
  2774. RewriteForGEP(cast<GEPOperator>(GEP), Builder);
  2775. } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
  2776. RewriteForLoad(ldInst);
  2777. else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
  2778. RewriteForStore(stInst);
  2779. else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
  2780. RewriteMemIntrin(MI, cast<Instruction>(V));
  2781. else if (CallInst *CI = dyn_cast<CallInst>(User))
  2782. RewriteCall(CI);
  2783. else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
  2784. RewriteBitCast(BCI);
  2785. else {
  2786. assert(0 && "not support.");
  2787. }
  2788. }
  2789. }
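// Build a nested array type with FinalEltTy as the innermost element type,
// e.g. (illustrative) FinalEltTy = float with nests [4 x [2 x <elt>]] yields
// [4 x [2 x float]].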
  2790. static ArrayType *CreateNestArrayTy(Type *FinalEltTy,
  2791. ArrayRef<ArrayType *> nestArrayTys) {
  2792. Type *newAT = FinalEltTy;
  2793. for (auto ArrayTy = nestArrayTys.rbegin(), E=nestArrayTys.rend(); ArrayTy != E;
  2794. ++ArrayTy)
  2795. newAT = ArrayType::get(newAT, (*ArrayTy)->getNumElements());
  2796. return cast<ArrayType>(newAT);
  2797. }
  2798. /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
  2799. /// Then do SROA on V.
  2800. bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
  2801. IRBuilder<> &Builder, bool bFlatVector,
  2802. bool hasPrecise, DxilTypeSystem &typeSys,
  2803. SmallVector<Value *, 32> &DeadInsts) {
  2804. DEBUG(dbgs() << "Found inst to SROA: " << *V << '\n');
  2805. Type *Ty = V->getType();
// Skip non-pointer types.
  2807. if (!Ty->isPointerTy())
  2808. return false;
  2809. Ty = Ty->getPointerElementType();
// Skip non-aggregate types.
  2811. if (!Ty->isAggregateType())
  2812. return false;
  2813. // Skip matrix types.
  2814. if (HLMatrixLower::IsMatrixType(Ty))
  2815. return false;
  2816. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2817. // Skip HLSL object types.
  2818. if (HLModule::IsHLSLObjectType(ST)) {
  2819. return false;
  2820. }
  2821. unsigned numTypes = ST->getNumContainedTypes();
  2822. Elts.reserve(numTypes);
  2823. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2824. // Skip empty struct.
  2825. if (SA && SA->IsEmptyStruct())
  2826. return true;
  2827. for (int i = 0, e = numTypes; i != e; ++i) {
  2828. AllocaInst *NA = Builder.CreateAlloca(ST->getContainedType(i), nullptr, V->getName() + "." + Twine(i));
  2829. bool markPrecise = hasPrecise;
  2830. if (SA) {
  2831. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2832. markPrecise |= FA.IsPrecise();
  2833. }
  2834. if (markPrecise)
  2835. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2836. Elts.push_back(NA);
  2837. }
  2838. } else {
  2839. ArrayType *AT = cast<ArrayType>(Ty);
  2840. if (AT->getNumContainedTypes() == 0) {
  2841. // Skip case like [0 x %struct].
  2842. return false;
  2843. }
  2844. Type *ElTy = AT->getElementType();
  2845. SmallVector<ArrayType *, 4> nestArrayTys;
  2846. nestArrayTys.emplace_back(AT);
// Support multiple levels of array nesting.
  2848. while (ElTy->isArrayTy()) {
  2849. ArrayType *ElAT = cast<ArrayType>(ElTy);
  2850. nestArrayTys.emplace_back(ElAT);
  2851. ElTy = ElAT->getElementType();
  2852. }
  2853. if (ElTy->isStructTy() &&
  2854. // Skip Matrix type.
  2855. !HLMatrixLower::IsMatrixType(ElTy)) {
  2856. // Skip HLSL object types.
  2857. if (HLModule::IsHLSLObjectType(ElTy)) {
  2858. return false;
  2859. }
  2860. // for array of struct
  2861. // split into arrays of struct elements
  2862. StructType *ElST = cast<StructType>(ElTy);
  2863. unsigned numTypes = ElST->getNumContainedTypes();
  2864. Elts.reserve(numTypes);
  2865. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
  2866. // Skip empty struct.
  2867. if (SA && SA->IsEmptyStruct())
  2868. return true;
  2869. for (int i = 0, e = numTypes; i != e; ++i) {
  2870. AllocaInst *NA = Builder.CreateAlloca(
  2871. CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys), nullptr,
  2872. V->getName() + "." + Twine(i));
  2873. bool markPrecise = hasPrecise;
  2874. if (SA) {
  2875. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2876. markPrecise |= FA.IsPrecise();
  2877. }
  2878. if (markPrecise)
  2879. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2880. Elts.push_back(NA);
  2881. }
  2882. } else if (ElTy->isVectorTy()) {
  2883. // Skip vector if required.
  2884. if (!bFlatVector)
  2885. return false;
  2886. // for array of vector
  2887. // split into arrays of scalar
  2888. VectorType *ElVT = cast<VectorType>(ElTy);
  2889. Elts.reserve(ElVT->getNumElements());
  2890. ArrayType *scalarArrayTy = CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
  2891. for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
  2892. AllocaInst *NA = Builder.CreateAlloca(scalarArrayTy, nullptr,
  2893. V->getName() + "." + Twine(i));
  2894. if (hasPrecise)
  2895. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2896. Elts.push_back(NA);
  2897. }
  2898. } else
  2899. // Skip array of basic types.
  2900. return false;
  2901. }
  2902. // Now that we have created the new alloca instructions, rewrite all the
  2903. // uses of the old alloca.
  2904. SROA_Helper helper(V, Elts, DeadInsts);
  2905. helper.RewriteForScalarRepl(V, Builder);
  2906. return true;
  2907. }
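// GetEltInit - Extract the initializer for element 'idx' of Ty out of Init.
// For arrays this transposes the initializer: (illustrative) splitting a
// [2 x { i32, float }] constant on field 0 yields a [2 x i32] constant built
// from Array[0].field0 and Array[1].field0.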
  2908. static Constant *GetEltInit(Type *Ty, Constant *Init, unsigned idx,
  2909. Type *EltTy) {
  2910. if (isa<UndefValue>(Init))
  2911. return UndefValue::get(EltTy);
  2912. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2913. return Init->getAggregateElement(idx);
  2914. } else if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  2915. return Init->getAggregateElement(idx);
  2916. } else {
  2917. ArrayType *AT = cast<ArrayType>(Ty);
  2918. ArrayType *EltArrayTy = cast<ArrayType>(EltTy);
  2919. std::vector<Constant *> Elts;
  2920. if (!AT->getElementType()->isArrayTy()) {
  2921. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2922. // Get Array[i]
  2923. Constant *InitArrayElt = Init->getAggregateElement(i);
  2924. // Get Array[i].idx
  2925. InitArrayElt = InitArrayElt->getAggregateElement(idx);
  2926. Elts.emplace_back(InitArrayElt);
  2927. }
  2928. return ConstantArray::get(EltArrayTy, Elts);
  2929. } else {
  2930. Type *EltTy = AT->getElementType();
  2931. ArrayType *NestEltArrayTy = cast<ArrayType>(EltArrayTy->getElementType());
  2932. // Nested array.
  2933. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2934. // Get Array[i]
  2935. Constant *InitArrayElt = Init->getAggregateElement(i);
  2936. // Get Array[i].idx
  2937. InitArrayElt = GetEltInit(EltTy, InitArrayElt, idx, NestEltArrayTy);
  2938. Elts.emplace_back(InitArrayElt);
  2939. }
  2940. return ConstantArray::get(EltArrayTy, Elts);
  2941. }
  2942. }
  2943. }
  2944. /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
  2945. /// Then do SROA on V.
  2946. bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
  2947. IRBuilder<> &Builder, bool bFlatVector,
  2948. bool hasPrecise, DxilTypeSystem &typeSys,
  2949. SmallVector<Value *, 32> &DeadInsts) {
  2950. DEBUG(dbgs() << "Found inst to SROA: " << *GV << '\n');
  2951. Type *Ty = GV->getType();
// Skip non-pointer types.
  2953. if (!Ty->isPointerTy())
  2954. return false;
  2955. Ty = Ty->getPointerElementType();
// Skip non-aggregate types.
  2957. if (!Ty->isAggregateType() && !bFlatVector)
  2958. return false;
  2959. // Skip basic types.
  2960. if (Ty->isSingleValueType() && !Ty->isVectorTy())
  2961. return false;
  2962. // Skip matrix types.
  2963. if (HLMatrixLower::IsMatrixType(Ty))
  2964. return false;
  2965. Module *M = GV->getParent();
  2966. Constant *Init = GV->getInitializer();
  2967. if (!Init)
  2968. Init = UndefValue::get(Ty);
  2969. bool isConst = GV->isConstant();
  2970. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  2971. unsigned AddressSpace = GV->getType()->getAddressSpace();
  2972. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  2973. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2974. // Skip HLSL object types.
  2975. if (HLModule::IsHLSLObjectType(ST))
  2976. return false;
  2977. unsigned numTypes = ST->getNumContainedTypes();
  2978. Elts.reserve(numTypes);
  2979. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2980. for (int i = 0, e = numTypes; i != e; ++i) {
  2981. Constant *EltInit = GetEltInit(Ty, Init, i, ST->getElementType(i));
  2982. GlobalVariable *EltGV = new llvm::GlobalVariable(
  2983. *M, ST->getContainedType(i), /*IsConstant*/ isConst, linkage,
  2984. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  2985. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  2986. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2987. // TODO: set precise.
  2988. // if (hasPrecise || FA.IsPrecise())
  2989. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  2990. Elts.push_back(EltGV);
  2991. }
  2992. } else if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  2993. // TODO: support dynamic indexing on vector by change it to array.
  2994. unsigned numElts = VT->getNumElements();
  2995. Elts.reserve(numElts);
  2996. Type *EltTy = VT->getElementType();
  2997. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2998. for (int i = 0, e = numElts; i != e; ++i) {
  2999. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  3000. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3001. *M, EltTy, /*IsConstant*/ isConst, linkage,
  3002. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3003. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3004. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  3005. // TODO: set precise.
  3006. // if (hasPrecise || FA.IsPrecise())
  3007. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3008. Elts.push_back(EltGV);
  3009. }
  3010. } else {
  3011. ArrayType *AT = cast<ArrayType>(Ty);
  3012. if (AT->getNumContainedTypes() == 0) {
  3013. // Skip case like [0 x %struct].
  3014. return false;
  3015. }
  3016. Type *ElTy = AT->getElementType();
  3017. SmallVector<ArrayType *, 4> nestArrayTys;
  3018. nestArrayTys.emplace_back(AT);
// Support multiple levels of array nesting.
  3020. while (ElTy->isArrayTy()) {
  3021. ArrayType *ElAT = cast<ArrayType>(ElTy);
  3022. nestArrayTys.emplace_back(ElAT);
  3023. ElTy = ElAT->getElementType();
  3024. }
  3025. if (ElTy->isStructTy() &&
  3026. // Skip Matrix type.
  3027. !HLMatrixLower::IsMatrixType(ElTy)) {
  3028. // for array of struct
  3029. // split into arrays of struct elements
  3030. StructType *ElST = cast<StructType>(ElTy);
  3031. unsigned numTypes = ElST->getNumContainedTypes();
  3032. Elts.reserve(numTypes);
  3033. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
  3034. for (int i = 0, e = numTypes; i != e; ++i) {
  3035. Type *EltTy =
  3036. CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys);
  3037. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  3038. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3039. *M, EltTy, /*IsConstant*/ isConst, linkage,
  3040. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3041. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3042. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  3043. // TODO: set precise.
  3044. // if (hasPrecise || FA.IsPrecise())
  3045. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3046. Elts.push_back(EltGV);
  3047. }
  3048. } else if (ElTy->isVectorTy()) {
  3049. // Skip vector if required.
  3050. if (!bFlatVector)
  3051. return false;
  3052. // for array of vector
  3053. // split into arrays of scalar
  3054. VectorType *ElVT = cast<VectorType>(ElTy);
  3055. Elts.reserve(ElVT->getNumElements());
  3056. ArrayType *scalarArrayTy =
  3057. CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
  3058. for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
  3059. Constant *EltInit = GetEltInit(Ty, Init, i, scalarArrayTy);
  3060. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3061. *M, scalarArrayTy, /*IsConstant*/ isConst, linkage,
  3062. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3063. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3064. // TODO: set precise.
  3065. // if (hasPrecise)
  3066. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3067. Elts.push_back(EltGV);
  3068. }
  3069. } else
  3070. // Skip array of basic types.
  3071. return false;
  3072. }
  3073. // Now that we have created the new alloca instructions, rewrite all the
  3074. // uses of the old alloca.
  3075. SROA_Helper helper(GV, Elts, DeadInsts);
  3076. helper.RewriteForScalarRepl(GV, Builder);
  3077. return true;
  3078. }
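// PointerStatus summarizes how a pointer (such as an alloca or a global
// variable) is read and written, so LowerMemcpy can decide whether a single
// full-size memcpy can be replaced outright or must be split element by
// element.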
  3079. struct PointerStatus {
  3080. /// Keep track of what stores to the pointer look like.
  3081. enum StoredType {
  3082. /// There is no store to this pointer. It can thus be marked constant.
  3083. NotStored,
/// This ptr is a global that is stored to, but the only thing ever stored is
/// the constant it was initialized with. This is only tracked for scalar
/// globals.
InitializerStored,
/// This ptr is stored to, but only its initializer and one other value are
/// ever stored to it. If this global isStoredOnce, we track the value stored
/// to it in StoredOnceValue below. This is only tracked for scalar globals.
  3092. StoredOnce,
  3093. /// This ptr is only assigned by a memcpy.
  3094. MemcopyDestOnce,
  3095. /// This ptr is stored to by multiple values or something else that we
  3096. /// cannot track.
  3097. Stored
  3098. } StoredType;
/// Keep track of what loads from the pointer look like.
enum LoadedType {
/// There is no load from this pointer.
NotLoaded,
/// This ptr is only read by a memcpy.
MemcopySrcOnce,
/// This ptr is loaded from by multiple instructions or something else that
/// we cannot track.
Loaded
} LoadedType;
  3109. /// If only one value (besides the initializer constant) is ever stored to
  3110. /// this global, keep track of what value it is.
  3111. Value *StoredOnceValue;
/// The set of memcpys in which this ptr is used.
std::unordered_set<MemCpyInst *> memcpySet;
/// The memcpy that uses this ptr as its dest.
MemCpyInst *StoringMemcpy;
/// The memcpy that uses this ptr as its src.
MemCpyInst *LoadingMemcpy;
  3118. /// These start out null/false. When the first accessing function is noticed,
  3119. /// it is recorded. When a second different accessing function is noticed,
  3120. /// HasMultipleAccessingFunctions is set to true.
  3121. const Function *AccessingFunction;
  3122. bool HasMultipleAccessingFunctions;
  3123. /// Size of the ptr.
  3124. unsigned Size;
/// Look at all uses of the pointer and fill in the PointerStatus structure.
  3128. static void analyzePointer(const Value *V, PointerStatus &PS,
  3129. DxilTypeSystem &typeSys, bool bStructElt);
  3130. PointerStatus(unsigned size)
  3131. : StoredType(NotStored), LoadedType(NotLoaded), StoredOnceValue(nullptr),
  3132. StoringMemcpy(nullptr), LoadingMemcpy(nullptr),
  3133. AccessingFunction(nullptr), HasMultipleAccessingFunctions(false),
  3134. Size(size) {}
  3135. void MarkAsStored() {
  3136. StoredType = PointerStatus::StoredType::Stored;
  3137. StoredOnceValue = nullptr;
  3138. }
  3139. void MarkAsLoaded() { LoadedType = PointerStatus::LoadedType::Loaded; }
  3140. };
  3141. void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
  3142. DxilTypeSystem &typeSys, bool bStructElt) {
  3143. if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
  3144. if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
  3145. PS.StoredType = PointerStatus::StoredType::InitializerStored;
  3146. }
  3147. }
  3148. for (const User *U : V->users()) {
  3149. if (const Instruction *I = dyn_cast<Instruction>(U)) {
  3150. const Function *F = I->getParent()->getParent();
  3151. if (!PS.AccessingFunction) {
  3152. PS.AccessingFunction = F;
  3153. } else {
  3154. if (F != PS.AccessingFunction)
  3155. PS.HasMultipleAccessingFunctions = true;
  3156. }
  3157. }
  3158. if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(U)) {
  3159. analyzePointer(BC, PS, typeSys, bStructElt);
  3160. } else if (const MemCpyInst *MC = dyn_cast<MemCpyInst>(U)) {
// Do not collect memcpys on struct GEP uses;
// these memcpys will be flattened at the next level.
  3163. if (!bStructElt) {
  3164. MemCpyInst *MI = const_cast<MemCpyInst *>(MC);
  3165. PS.memcpySet.insert(MI);
  3166. bool bFullCopy = false;
  3167. if (ConstantInt *Length = dyn_cast<ConstantInt>(MC->getLength())) {
  3168. bFullCopy = PS.Size == Length->getLimitedValue();
  3169. }
  3170. if (MC->getRawDest() == V) {
  3171. if (bFullCopy &&
  3172. PS.StoredType == PointerStatus::StoredType::NotStored) {
  3173. PS.StoredType = PointerStatus::StoredType::MemcopyDestOnce;
  3174. PS.StoringMemcpy = MI;
  3175. } else {
  3176. PS.MarkAsStored();
  3177. PS.StoringMemcpy = nullptr;
  3178. }
  3179. } else if (MC->getRawSource() == V) {
  3180. if (bFullCopy &&
  3181. PS.LoadedType == PointerStatus::LoadedType::NotLoaded) {
  3182. PS.LoadedType = PointerStatus::LoadedType::MemcopySrcOnce;
  3183. PS.LoadingMemcpy = MI;
  3184. } else {
  3185. PS.MarkAsLoaded();
  3186. PS.LoadingMemcpy = nullptr;
  3187. }
  3188. }
  3189. } else {
  3190. if (MC->getRawDest() == V) {
  3191. PS.MarkAsStored();
  3192. } else {
  3193. DXASSERT(MC->getRawSource() == V, "must be source here");
  3194. PS.MarkAsLoaded();
  3195. }
  3196. }
  3197. } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  3198. gep_type_iterator GEPIt = gep_type_begin(GEP);
  3199. gep_type_iterator GEPEnd = gep_type_end(GEP);
  3200. // Skip pointer idx.
  3201. GEPIt++;
// Struct elements will be flattened at the next level.
  3203. bool bStructElt = (GEPIt != GEPEnd) && GEPIt->isStructTy();
  3204. analyzePointer(GEP, PS, typeSys, bStructElt);
  3205. } else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) {
  3206. Value *V = SI->getOperand(0);
  3207. if (PS.StoredType == PointerStatus::StoredType::NotStored) {
  3208. PS.StoredType = PointerStatus::StoredType::StoredOnce;
  3209. PS.StoredOnceValue = V;
  3210. } else {
  3211. PS.MarkAsStored();
  3212. }
  3213. } else if (const LoadInst *LI = dyn_cast<LoadInst>(U)) {
  3214. PS.MarkAsLoaded();
  3215. } else if (const CallInst *CI = dyn_cast<CallInst>(U)) {
  3216. Function *F = CI->getCalledFunction();
  3217. DxilFunctionAnnotation *annotation = typeSys.GetFunctionAnnotation(F);
  3218. if (!annotation) {
  3219. HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
  3220. switch (group) {
  3221. case HLOpcodeGroup::HLMatLoadStore: {
  3222. HLMatLoadStoreOpcode opcode =
  3223. static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
  3224. switch (opcode) {
  3225. case HLMatLoadStoreOpcode::ColMatLoad:
  3226. case HLMatLoadStoreOpcode::RowMatLoad:
  3227. PS.MarkAsLoaded();
  3228. break;
  3229. case HLMatLoadStoreOpcode::ColMatStore:
  3230. case HLMatLoadStoreOpcode::RowMatStore:
  3231. PS.MarkAsStored();
  3232. break;
  3233. default:
  3234. DXASSERT(0, "invalid opcode");
  3235. PS.MarkAsStored();
  3236. PS.MarkAsLoaded();
  3237. }
  3238. } break;
  3239. case HLOpcodeGroup::HLSubscript: {
  3240. HLSubscriptOpcode opcode =
  3241. static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(CI));
  3242. switch (opcode) {
  3243. case HLSubscriptOpcode::VectorSubscript:
  3244. case HLSubscriptOpcode::ColMatElement:
  3245. case HLSubscriptOpcode::ColMatSubscript:
  3246. case HLSubscriptOpcode::RowMatElement:
  3247. case HLSubscriptOpcode::RowMatSubscript:
  3248. analyzePointer(CI, PS, typeSys, bStructElt);
  3249. break;
  3250. default:
  3251. // Rest are resource ptr like buf[i].
  3252. // Only read of resource handle.
  3253. PS.MarkAsLoaded();
  3254. break;
  3255. }
  3256. } break;
  3257. default: {
  3258. // If not sure its out param or not. Take as out param.
  3259. PS.MarkAsStored();
  3260. PS.MarkAsLoaded();
  3261. }
  3262. }
  3263. continue;
  3264. }
  3265. unsigned argSize = F->arg_size();
  3266. for (unsigned i = 0; i < argSize; i++) {
  3267. Value *arg = CI->getArgOperand(i);
  3268. if (V == arg) {
  3269. DxilParamInputQual inputQual =
  3270. annotation->GetParameterAnnotation(i).GetParamInputQual();
  3271. if (inputQual != DxilParamInputQual::In) {
  3272. PS.MarkAsStored();
  3273. if (inputQual == DxilParamInputQual::Inout)
  3274. PS.MarkAsLoaded();
  3275. break;
  3276. } else {
  3277. PS.MarkAsLoaded();
  3278. break;
  3279. }
  3280. }
  3281. }
  3282. }
  3283. }
  3284. }
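// Illustrative example of the classification above: for HLSL roughly like
//   struct T { float4 a, b; };
//   T t = src;          // the front end emits one full-size memcpy into 't'
// the alloca for 't' is written by exactly one full memcpy, so its StoredType
// becomes MemcopyDestOnce with StoringMemcpy pointing at that memcpy. Any
// additional store or partial copy demotes it to the plain Stored state.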
static void ReplaceConstantWithInst(Constant *C, Value *V,
                                    IRBuilder<> &Builder) {
  for (auto it = C->user_begin(); it != C->user_end();) {
    User *U = *(it++);
    if (Instruction *I = dyn_cast<Instruction>(U)) {
      I->replaceUsesOfWith(C, V);
    } else {
      ConstantExpr *CE = cast<ConstantExpr>(U);
      Instruction *Inst = CE->getAsInstruction();
      Builder.Insert(Inst);
      Inst->replaceUsesOfWith(C, V);
      ReplaceConstantWithInst(CE, Inst, Builder);
    }
  }
}

static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
  if (Constant *C = dyn_cast<Constant>(V)) {
    if (isa<Constant>(Src)) {
      V->replaceAllUsesWith(Src);
    } else {
      // Replace Constant with a non-Constant.
      IRBuilder<> Builder(MC);
      ReplaceConstantWithInst(C, Src, Builder);
    }
  } else {
    V->replaceAllUsesWith(Src);
  }

  Value *RawDest = MC->getOperand(0);
  Value *RawSrc = MC->getOperand(1);
  MC->eraseFromParent();
  if (Instruction *I = dyn_cast<Instruction>(RawDest)) {
    if (I->user_empty())
      I->eraseFromParent();
  }
  if (Instruction *I = dyn_cast<Instruction>(RawSrc)) {
    if (I->user_empty())
      I->eraseFromParent();
  }
}
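// Sketch of what ReplaceMemcpy achieves on a MemcopyDestOnce pointer
// (simplified IR, types elided):
//   %t = alloca %T
//   call @llvm.memcpy(%t, %src, sizeof(%T), ...)
//   ... loads through %t ...
// becomes
//   ... loads through %src ...
// with the memcpy erased and now-unused dest/src instructions cleaned up.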
bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                              DxilTypeSystem &typeSys, const DataLayout &DL,
                              bool bAllowReplace) {
  Type *Ty = V->getType();
  if (!Ty->isPointerTy()) {
    return false;
  }
  // Get the access status and collect memcpy uses.
  // If the pointer is written by a single full memcpy, replace the dest with
  // the src when the dest is not an out param; otherwise flatten the memcpy.
  unsigned size = DL.getTypeAllocSize(Ty->getPointerElementType());
  PointerStatus PS(size);
  const bool bStructElt = false;
  PointerStatus::analyzePointer(V, PS, typeSys, bStructElt);

  if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
    if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce) {
      // Replace with the src of the memcpy.
      MemCpyInst *MC = PS.StoringMemcpy;
      if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
        Value *Src = MC->getOperand(1);
        // Only remove one level of bitcast generated from inlining.
        if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
          Src = BC->getOperand(0);

        if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
          // For a GEP, the base pointer could have other GEP reads/writes,
          // so scanning this one GEP is not enough.
          Value *Ptr = GEP->getPointerOperand();
          if (CallInst *PtrCI = dyn_cast<CallInst>(Ptr)) {
            hlsl::HLOpcodeGroup group =
                hlsl::GetHLOpcodeGroup(PtrCI->getCalledFunction());
            if (group == HLOpcodeGroup::HLSubscript) {
              HLSubscriptOpcode opcode =
                  static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
              if (opcode == HLSubscriptOpcode::CBufferSubscript) {
                // A pointer from a CBuffer is safe.
                ReplaceMemcpy(V, Src, MC);
                return true;
              }
            }
          }
        } else if (!isa<CallInst>(Src)) {
          // Resource pointers should not be replaced.
          // Need to make sure Src is not updated after the current memcpy.
          // For now, check that Src only has one store.
          PointerStatus SrcPS(size);
          PointerStatus::analyzePointer(Src, SrcPS, typeSys, bStructElt);
          if (SrcPS.StoredType != PointerStatus::StoredType::Stored) {
            ReplaceMemcpy(V, Src, MC);
            return true;
          }
        }
      }
    } else if (PS.LoadedType == PointerStatus::LoadedType::MemcopySrcOnce) {
      // Replace the dest of the memcpy.
      MemCpyInst *MC = PS.LoadingMemcpy;
      if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
        Value *Dest = MC->getOperand(0);
        // Only remove one level of bitcast generated from inlining.
        if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
          Dest = BC->getOperand(0);
        // For a GEP, the base pointer could have other GEP reads/writes,
        // so scanning this one GEP is not enough.
        // Resource pointers should not be replaced either.
        if (!isa<GEPOperator>(Dest) && !isa<CallInst>(Dest) &&
            !isa<BitCastOperator>(Dest)) {
          // Need to make sure Dest is not updated after the current memcpy.
          // For now, check that Dest only has one store.
          PointerStatus DestPS(size);
          PointerStatus::analyzePointer(Dest, DestPS, typeSys, bStructElt);
          if (DestPS.StoredType != PointerStatus::StoredType::Stored) {
            ReplaceMemcpy(Dest, V, MC);
            // V still needs to be flattened.
            // Lower the memcpys that came from Dest.
            return LowerMemcpy(V, annotation, typeSys, DL, bAllowReplace);
          }
        }
      }
    }
  }

  for (MemCpyInst *MC : PS.memcpySet) {
    MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
  }
  return false;
}
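// When no full replacement is possible, every collected memcpy falls through
// to MemcpySplitter::SplitMemCpy above, which rewrites it as per-field
// load/store pairs so the element pointers can still be SROA'ed later.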
/// MarkEmptyStructUsers - Add instructions related to an empty struct to
/// DeadInsts.
void SROA_Helper::MarkEmptyStructUsers(Value *V,
                                       SmallVector<Value *, 32> &DeadInsts) {
  for (User *U : V->users()) {
    MarkEmptyStructUsers(U, DeadInsts);
  }
  if (Instruction *I = dyn_cast<Instruction>(V)) {
    // Only instructions with no users need to be added here;
    // DeleteDeadInst will delete everything else.
    if (I->user_empty())
      DeadInsts.emplace_back(I);
  }
}

bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
  if (isa<ArrayType>(Ty))
    Ty = Ty->getArrayElementType();

  if (StructType *ST = dyn_cast<StructType>(Ty)) {
    if (!HLMatrixLower::IsMatrixType(Ty)) {
      DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
      if (SA && SA->IsEmptyStruct())
        return true;
    }
  }
  return false;
}
//===----------------------------------------------------------------------===//
// SROA on function parameters.
//===----------------------------------------------------------------------===//

namespace {
class SROA_Parameter_HLSL : public ModulePass {
  HLModule *m_pHLModule;

public:
  static char ID; // Pass identification, replacement for typeid
  explicit SROA_Parameter_HLSL() : ModulePass(ID) {}
  const char *getPassName() const override { return "SROA Parameter HLSL"; }

  bool runOnModule(Module &M) override {
    // Patch memcpy to cover the case where bitcast (gep ptr, 0, 0) is
    // transformed into bitcast ptr.
    MemcpySplitter::PatchMemCpyWithZeroIdxGEP(M);

    m_pHLModule = &M.GetOrCreateHLModule();
    // Load up debug information, to cross-reference values and the
    // instructions used to load them.
    m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;

    std::deque<Function *> WorkList;
    for (Function &F : M.functions()) {
      HLOpcodeGroup group = GetHLOpcodeGroup(&F);
      // Skip HL operations.
      if (group != HLOpcodeGroup::NotHL ||
          group == HLOpcodeGroup::HLExtIntrinsic) {
        continue;
      }

      if (F.isDeclaration()) {
        // Skip llvm intrinsics.
        if (F.isIntrinsic())
          continue;
        // Skip unused external functions.
        if (F.user_empty())
          continue;
      }

      WorkList.emplace_back(&F);
    }

    // Process the worklist.
    while (!WorkList.empty()) {
      Function *F = WorkList.front();
      WorkList.pop_front();
      createFlattenedFunction(F);
    }

    // Replace each function with its flattened version once all the functions
    // have been flattened.
    for (auto Iter : funcMap)
      replaceCall(Iter.first, Iter.second);

    // Remove the flattened functions.
    for (auto Iter : funcMap) {
      Function *F = Iter.first;
      F->eraseFromParent();
    }

    // Flatten internal globals.
    std::vector<GlobalVariable *> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      if (HLModule::IsStaticGlobal(&GV) ||
          HLModule::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      } else {
        // Merge GEP uses for the global.
        HLModule::MergeGepUse(&GV);
      }
    }

    for (GlobalVariable *GV : staticGVs)
      flattenGlobal(GV);

    // Remove unused internal globals.
    staticGVs.clear();
    for (GlobalVariable &GV : M.globals()) {
      if (HLModule::IsStaticGlobal(&GV) ||
          HLModule::IsSharedMemoryGlobal(&GV)) {
        staticGVs.emplace_back(&GV);
      }
    }

    for (GlobalVariable *GV : staticGVs) {
      bool onlyStoreUse = true;
      for (User *user : GV->users()) {
        if (isa<StoreInst>(user))
          continue;
        if (isa<ConstantExpr>(user) && user->user_empty())
          continue;

        // Check matrix store.
        if (HLMatrixLower::IsMatrixType(
                GV->getType()->getPointerElementType())) {
          if (CallInst *CI = dyn_cast<CallInst>(user)) {
            if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
                HLOpcodeGroup::HLMatLoadStore) {
              HLMatLoadStoreOpcode opcode =
                  static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
              if (opcode == HLMatLoadStoreOpcode::ColMatStore ||
                  opcode == HLMatLoadStoreOpcode::RowMatStore)
                continue;
            }
          }
        }

        onlyStoreUse = false;
        break;
      }

      if (onlyStoreUse) {
        for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
          Value *User = *(UserIt++);
          if (Instruction *I = dyn_cast<Instruction>(User)) {
            I->eraseFromParent();
          } else {
            ConstantExpr *CE = cast<ConstantExpr>(User);
            CE->dropAllReferences();
          }
        }
        GV->eraseFromParent();
      }
    }

    return true;
  }
private:
  void DeleteDeadInstructions();
  void moveFunctionBody(Function *F, Function *flatF);
  void replaceCall(Function *F, Function *flatF);
  void createFlattenedFunction(Function *F);
  void createFlattenedFunctionCall(Function *F, Function *flatF, CallInst *CI);
  void
  flattenArgument(Function *F, Value *Arg, bool bForParam,
                  DxilParameterAnnotation &paramAnnotation,
                  std::vector<Value *> &FlatParamList,
                  std::vector<DxilParameterAnnotation> &FlatRetAnnotationList,
                  IRBuilder<> &Builder, DbgDeclareInst *DDI);
  Value *castArgumentIfRequired(Value *V, Type *Ty, bool bOut,
                                bool hasShaderInputOutput,
                                DxilParamInputQual inputQual,
                                DxilFieldAnnotation &annotation,
                                std::deque<Value *> &WorkList,
                                IRBuilder<> &Builder);
  // Replace an argument whose type changed during flattening.
  void replaceCastArgument(Value *&NewArg, Value *OldArg,
                           DxilParamInputQual inputQual,
                           IRBuilder<> &CallBuilder, IRBuilder<> &RetBuilder);
  // Replace uses of a parameter whose type changed during flattening.
  // Also add information to Arg if required.
  void replaceCastParameter(Value *NewParam, Value *OldParam, Function &F,
                            Argument *Arg, const DxilParamInputQual inputQual,
                            IRBuilder<> &Builder);
  void allocateSemanticIndex(
      std::vector<DxilParameterAnnotation> &FlatAnnotationList,
      unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap);
  bool hasDynamicVectorIndexing(Value *V);
  void flattenGlobal(GlobalVariable *GV);

  /// DeadInsts - Keep track of instructions we have made dead, so that
  /// we can remove them after we are done working.
  SmallVector<Value *, 32> DeadInsts;
  // Map from the original function to its flattened version.
  std::unordered_map<Function *, Function *> funcMap;
  // Map from the original arg/param to its flattened, cast version.
  std::unordered_map<Value *, std::pair<Value *, DxilParamInputQual>>
      castParamMap;
  // Map from the first element of a flattened vector to the list of all
  // elements of that vector.
  std::unordered_map<Value *, SmallVector<Value *, 4>> vectorEltsMap;
  // Set of row-major matrix parameters.
  std::unordered_set<Value *> castRowMajorParamMap;
  bool m_HasDbgInfo;
};
}

char SROA_Parameter_HLSL::ID = 0;

INITIALIZE_PASS(SROA_Parameter_HLSL, "scalarrepl-param-hlsl",
                "Scalar Replacement of Aggregates HLSL (parameters)", false,
                false)
/// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,
/// recursively including all their operands that become trivially dead.
void SROA_Parameter_HLSL::DeleteDeadInstructions() {
  while (!DeadInsts.empty()) {
    Instruction *I = cast<Instruction>(DeadInsts.pop_back_val());

    for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI)
      if (Instruction *U = dyn_cast<Instruction>(*OI)) {
        // Zero out the operand and see if it becomes trivially dead.
        // (But, don't add allocas to the dead instruction list -- they are
        // already on the worklist and will be deleted separately.)
        *OI = nullptr;
        if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U))
          DeadInsts.push_back(U);
      }

    I->eraseFromParent();
  }
}

bool SROA_Parameter_HLSL::hasDynamicVectorIndexing(Value *V) {
  for (User *U : V->users()) {
    if (!U->getType()->isPointerTy())
      continue;

    if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
      gep_type_iterator GEPIt = gep_type_begin(U), E = gep_type_end(U);

      for (; GEPIt != E; ++GEPIt) {
        if (isa<VectorType>(*GEPIt)) {
          Value *VecIdx = GEPIt.getOperand();
          if (!isa<ConstantInt>(VecIdx))
            return true;
        }
      }
    }
  }
  return false;
}
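// Dynamic vector indexing looks roughly like this in IR (illustrative):
//   %p = getelementptr <4 x float>, <4 x float>* %v, i32 0, i32 %i
// A non-constant %i on the vector dimension forces the vector to stay in
// memory, so such globals are not split into scalars below.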
void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
  Type *Ty = GV->getType()->getPointerElementType();
  // Skip basic types.
  if (!Ty->isAggregateType() && !Ty->isVectorTy())
    return;

  std::deque<Value *> WorkList;
  WorkList.push_back(GV);
  // Merge GEP uses for the global.
  HLModule::MergeGepUse(GV);
  DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  // Only used to create ConstantExprs.
  IRBuilder<> Builder(m_pHLModule->GetCtx());
  std::vector<Instruction *> deadAllocas;

  const DataLayout &DL = GV->getParent()->getDataLayout();
  unsigned debugOffset = 0;
  std::unordered_map<Value *, StringRef> EltNameMap;
  // Process the worklist.
  while (!WorkList.empty()) {
    GlobalVariable *EltGV = cast<GlobalVariable>(WorkList.front());
    WorkList.pop_front();

    const bool bAllowReplace = true;
    if (SROA_Helper::LowerMemcpy(EltGV, /*annotation*/ nullptr, dxilTypeSys, DL,
                                 bAllowReplace)) {
      continue;
    }

    // Flatten global vectors when there is no dynamic vector indexing.
    bool bFlatVector = !hasDynamicVectorIndexing(EltGV);

    std::vector<Value *> Elts;
    bool SROAed = SROA_Helper::DoScalarReplacement(
        EltGV, Elts, Builder, bFlatVector,
        // TODO: set precise.
        /*hasPrecise*/ false, dxilTypeSys, DeadInsts);

    if (SROAed) {
      // Push Elts into the worklist.
      // Use rbegin so the order does not change.
      for (auto iter = Elts.rbegin(); iter != Elts.rend(); iter++) {
        WorkList.push_front(*iter);
        if (m_HasDbgInfo) {
          StringRef EltName = (*iter)->getName().ltrim(GV->getName());
          EltNameMap[*iter] = EltName;
        }
      }
      EltGV->removeDeadConstantUsers();
      // Now erase any instructions that were made dead while rewriting the
      // alloca.
      DeleteDeadInstructions();
      ++NumReplaced;
    } else {
      // Add debug info for flattened globals.
      if (m_HasDbgInfo && GV != EltGV) {
        DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
        Type *Ty = EltGV->getType()->getElementType();
        unsigned size = DL.getTypeAllocSizeInBits(Ty);
        unsigned align = DL.getPrefTypeAlignment(Ty);
        HLModule::CreateElementGlobalVariableDebugInfo(
            GV, Finder, EltGV, size, align, debugOffset, EltNameMap[EltGV]);
        debugOffset += size;
      }
    }
  }

  DeleteDeadInstructions();

  if (GV->user_empty()) {
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }
}
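// Illustrative result of flattenGlobal: a static global such as
//   static struct { float2 a; float b; } g;
// is rewritten (roughly) into separate globals for g's "a" and "b" fields,
// which are then re-queued on the worklist until only scalars (or vectors
// without dynamic indexing) remain.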
static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx,
                                             DxilFieldAnnotation &annotation,
                                             DxilTypeSystem &dxilTypeSys) {
  while (Ty->isArrayTy())
    Ty = Ty->getArrayElementType();
  if (StructType *ST = dyn_cast<StructType>(Ty)) {
    if (HLMatrixLower::IsMatrixType(Ty))
      return annotation;
    DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
    if (SA) {
      DxilFieldAnnotation &FA = SA->GetFieldAnnotation(idx);
      return FA;
    }
  }
  return annotation;
}
// Note: Semantic index allocation.
// Semantic indices are allocated based on the linear layout.
// For the following code
/*
    struct S {
     float4 m;
     float4 m2;
    };
    S s[2] : semantic;

    struct S2 {
     float4 m[2];
     float4 m2[2];
    };
    S2 s2 : semantic;
*/
// the semantic indices are:
//   s[0].m  : semantic0
//   s[0].m2 : semantic1
//   s[1].m  : semantic2
//   s[1].m2 : semantic3
//   s2.m[0] : semantic0
//   s2.m[1] : semantic1
//   s2.m2[0]: semantic2
//   s2.m2[1]: semantic3
// But when the argument is flattened, the result looks like this:
//   float4 s_m[2], float4 s_m2[2].
//   float4 s2_m[2], float4 s2_m2[2].
// To do the allocation, we need to map each element to its flattened
// argument. Say the arg index of float4 s_m[2] is 0 and float4 s_m2[2] is 1;
// we need to get 0 from s[0].m and s[1].m, and 1 from s[0].m2 and s[1].m2.
// Arguments with the same semantic string are allocated from the type where
// the semantic starts (S2 for s2.m[2] and s2.m2[2]).
// Iterate over each element of the type, saving and updating the semantic
// index as we go.
// The map from an element to its arg (s[0].m2 -> s_m2[2]) is done via argIdx.
// argIdx is only incremented by 1 when a struct field is finished.
static unsigned AllocateSemanticIndex(
    Type *Ty, unsigned &semIndex, unsigned argIdx, unsigned endArgIdx,
    std::vector<DxilParameterAnnotation> &FlatAnnotationList) {
  if (Ty->isPointerTy()) {
    return AllocateSemanticIndex(Ty->getPointerElementType(), semIndex, argIdx,
                                 endArgIdx, FlatAnnotationList);
  } else if (Ty->isArrayTy()) {
    unsigned arraySize = Ty->getArrayNumElements();
    unsigned updatedArgIdx = argIdx;
    Type *EltTy = Ty->getArrayElementType();
    for (unsigned i = 0; i < arraySize; i++) {
      updatedArgIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
                                            FlatAnnotationList);
    }
    return updatedArgIdx;
  } else if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
    unsigned fieldsCount = Ty->getStructNumElements();
    for (unsigned i = 0; i < fieldsCount; i++) {
      Type *EltTy = Ty->getStructElementType(i);
      argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
                                     FlatAnnotationList);
      if (!(EltTy->isStructTy() && !HLMatrixLower::IsMatrixType(EltTy))) {
        // Update argIdx only when the element is a leaf node.
        argIdx++;
      }
    }
    return argIdx;
  } else {
    DXASSERT(argIdx < endArgIdx, "arg index out of bound");
    DxilParameterAnnotation &paramAnnotation = FlatAnnotationList[argIdx];
    // Get the element size.
    unsigned rows = 1;
    if (paramAnnotation.HasMatrixAnnotation()) {
      const DxilMatrixAnnotation &matrix =
          paramAnnotation.GetMatrixAnnotation();
      if (matrix.Orientation == MatrixOrientation::RowMajor) {
        rows = matrix.Rows;
      } else {
        DXASSERT(matrix.Orientation == MatrixOrientation::ColumnMajor, "");
        rows = matrix.Cols;
      }
    }
    // Save semIndex.
    for (unsigned i = 0; i < rows; i++)
      paramAnnotation.AppendSemanticIndex(semIndex + i);
    // Update semIndex.
    semIndex += rows;
    return argIdx;
  }
}
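// Worked example for the allocator above (illustrative): for
//   struct S { float4 m; float4 m2; };  S s[2] : SEM;
// the flattened args are s_m[2] (argIdx 0) and s_m2[2] (argIdx 1). Each array
// iteration restarts at the same argIdx, so walking S twice appends semantic
// indices 0,2 to s_m and 1,3 to s_m2, matching the linear layout described in
// the note above.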
void SROA_Parameter_HLSL::allocateSemanticIndex(
    std::vector<DxilParameterAnnotation> &FlatAnnotationList,
    unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap) {
  unsigned endArgIndex = FlatAnnotationList.size();

  // Allocate semantic indices.
  for (unsigned i = startArgIndex; i < endArgIndex; ++i) {
    // Group by semantic names.
    DxilParameterAnnotation &flatParamAnnotation = FlatAnnotationList[i];
    const std::string &semantic = flatParamAnnotation.GetSemanticString();

    // If the semantic is undefined, an error will be emitted elsewhere. For
    // now, we should avoid asserting.
    if (semantic.empty())
      continue;

    unsigned semGroupEnd = i + 1;
    while (semGroupEnd < endArgIndex &&
           FlatAnnotationList[semGroupEnd].GetSemanticString() == semantic) {
      ++semGroupEnd;
    }

    StringRef baseSemName; // The 'FOO' in 'FOO1'.
    uint32_t semIndex;     // The '1' in 'FOO1'.
    // Split the semantic name and index.
    Semantic::DecomposeNameAndIndex(semantic, &baseSemName, &semIndex);

    DXASSERT(semanticTypeMap.count(semantic) > 0, "Must have semantic type");
    Type *semanticTy = semanticTypeMap[semantic];

    AllocateSemanticIndex(semanticTy, semIndex, /*argIdx*/ i,
                          /*endArgIdx*/ semGroupEnd, FlatAnnotationList);
    // Update i.
    i = semGroupEnd - 1;
  }
}
//
// Cast parameters.
//

static void CopyHandleToResourcePtr(Value *Handle, Value *ResPtr, HLModule &HLM,
                                    IRBuilder<> &Builder) {
  // Cast it to resource.
  Type *ResTy = ResPtr->getType()->getPointerElementType();
  Value *Res = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
                                       (unsigned)HLCastOpcode::HandleToResCast,
                                       ResTy, {Handle}, *HLM.GetModule());
  // Store casted resource to OldArg.
  Builder.CreateStore(Res, ResPtr);
}

static void CopyHandlePtrToResourcePtr(Value *HandlePtr, Value *ResPtr,
                                       HLModule &HLM, IRBuilder<> &Builder) {
  // Load the handle.
  Value *Handle = Builder.CreateLoad(HandlePtr);
  CopyHandleToResourcePtr(Handle, ResPtr, HLM, Builder);
}

static Value *CastResourcePtrToHandle(Value *Res, Type *HandleTy, HLModule &HLM,
                                      IRBuilder<> &Builder) {
  // Load OldArg.
  Value *LdRes = Builder.CreateLoad(Res);
  Value *Handle = HLM.EmitHLOperationCall(
      Builder, HLOpcodeGroup::HLCreateHandle,
      /*opcode*/ 0, HandleTy, {LdRes}, *HLM.GetModule());
  return Handle;
}

static void CopyResourcePtrToHandlePtr(Value *Res, Value *HandlePtr,
                                       HLModule &HLM, IRBuilder<> &Builder) {
  Type *HandleTy = HandlePtr->getType()->getPointerElementType();
  Value *Handle = CastResourcePtrToHandle(Res, HandleTy, HLM, Builder);
  Builder.CreateStore(Handle, HandlePtr);
}

static void CopyVectorPtrToEltsPtr(Value *VecPtr, ArrayRef<Value *> elts,
                                   unsigned vecSize, IRBuilder<> &Builder) {
  Value *Vec = Builder.CreateLoad(VecPtr);
  for (unsigned i = 0; i < vecSize; i++) {
    Value *Elt = Builder.CreateExtractElement(Vec, i);
    Builder.CreateStore(Elt, elts[i]);
  }
}

static void CopyEltsPtrToVectorPtr(ArrayRef<Value *> elts, Value *VecPtr,
                                   Type *VecTy, unsigned vecSize,
                                   IRBuilder<> &Builder) {
  Value *Vec = UndefValue::get(VecTy);
  for (unsigned i = 0; i < vecSize; i++) {
    Value *Elt = Builder.CreateLoad(elts[i]);
    Vec = Builder.CreateInsertElement(Vec, Elt, i);
  }
  Builder.CreateStore(Vec, VecPtr);
}
static void CopyMatToArrayPtr(Value *Mat, Value *ArrayPtr,
                              unsigned arrayBaseIdx, HLModule &HLM,
                              IRBuilder<> &Builder, bool bRowMajor) {
  Type *Ty = Mat->getType();
  // The matrix value is row major.
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(Mat->getType(), col, row);
  Type *VecTy = HLMatrixLower::LowerMatrixType(Ty);
  Value *Vec =
      HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
                              (unsigned)HLCastOpcode::RowMatrixToVecCast, VecTy,
                              {Mat}, *HLM.GetModule());
  Value *zero = Builder.getInt32(0);

  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < col; c++) {
      unsigned rowMatIdx = HLMatrixLower::GetRowMajorIdx(r, c, col);
      Value *Elt = Builder.CreateExtractElement(Vec, rowMatIdx);
      unsigned matIdx =
          bRowMajor ? rowMatIdx : HLMatrixLower::GetColMajorIdx(r, c, row);
      Value *Ptr = Builder.CreateInBoundsGEP(
          ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
      Builder.CreateStore(Elt, Ptr);
    }
  }
}

static void CopyMatPtrToArrayPtr(Value *MatPtr, Value *ArrayPtr,
                                 unsigned arrayBaseIdx, HLModule &HLM,
                                 IRBuilder<> &Builder, bool bRowMajor) {
  Type *Ty = MatPtr->getType()->getPointerElementType();
  Value *Mat = nullptr;

  if (bRowMajor) {
    Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
                                  (unsigned)HLMatLoadStoreOpcode::RowMatLoad,
                                  Ty, {MatPtr}, *HLM.GetModule());
  } else {
    Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
                                  (unsigned)HLMatLoadStoreOpcode::ColMatLoad,
                                  Ty, {MatPtr}, *HLM.GetModule());
    // The matrix value should be row major.
    Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
                                  (unsigned)HLCastOpcode::ColMatrixToRowMatrix,
                                  Ty, {Mat}, *HLM.GetModule());
  }

  CopyMatToArrayPtr(Mat, ArrayPtr, arrayBaseIdx, HLM, Builder, bRowMajor);
}

static Value *LoadArrayPtrToMat(Value *ArrayPtr, unsigned arrayBaseIdx,
                                Type *Ty, HLModule &HLM, IRBuilder<> &Builder,
                                bool bRowMajor) {
  unsigned col, row;
  HLMatrixLower::GetMatrixInfo(Ty, col, row);
  // HLInit operands are in row major order.
  SmallVector<Value *, 16> Elts;
  Value *zero = Builder.getInt32(0);
  for (unsigned r = 0; r < row; r++) {
    for (unsigned c = 0; c < col; c++) {
      unsigned matIdx = bRowMajor ? HLMatrixLower::GetRowMajorIdx(r, c, col)
                                  : HLMatrixLower::GetColMajorIdx(r, c, row);
      Value *Ptr = Builder.CreateInBoundsGEP(
          ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
      Value *Elt = Builder.CreateLoad(Ptr);
      Elts.emplace_back(Elt);
    }
  }
  return HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLInit,
                                 /*opcode*/ 0, Ty, {Elts}, *HLM.GetModule());
}

static void CopyArrayPtrToMatPtr(Value *ArrayPtr, unsigned arrayBaseIdx,
                                 Value *MatPtr, HLModule &HLM,
                                 IRBuilder<> &Builder, bool bRowMajor) {
  Type *Ty = MatPtr->getType()->getPointerElementType();
  Value *Mat =
      LoadArrayPtrToMat(ArrayPtr, arrayBaseIdx, Ty, HLM, Builder, bRowMajor);
  if (bRowMajor) {
    HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
                            (unsigned)HLMatLoadStoreOpcode::RowMatStore, Ty,
                            {MatPtr, Mat}, *HLM.GetModule());
  } else {
    // Mat is row major.
    // Cast it to col major before storing.
    Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
                                  (unsigned)HLCastOpcode::RowMatrixToColMatrix,
                                  Ty, {Mat}, *HLM.GetModule());
    HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
                            (unsigned)HLMatLoadStoreOpcode::ColMatStore, Ty,
                            {MatPtr, Mat}, *HLM.GetModule());
  }
}
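// Index mapping used above (illustrative, for a float2x3, row r / column c):
//   row major array slot:    r * 3 + c  ->  m00 m01 m02 m10 m11 m12
//   column major array slot: c * 2 + r  ->  m00 m10 m01 m11 m02 m12
// The element is always extracted from the row-major vector; only the
// destination slot changes with bRowMajor.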
using CopyFunctionTy = void(Value *FromPtr, Value *ToPtr, HLModule &HLM,
                            Type *HandleTy, IRBuilder<> &Builder,
                            bool bRowMajor);

static void
CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
                            std::vector<Value *> &idxList, unsigned calcIdx,
                            Type *HandleTy, HLModule &HLM, IRBuilder<> &Builder,
                            CopyFunctionTy CastCopyFn, bool bRowMajor) {
  if (CurFromTy->isVectorTy()) {
    // Copy vector to array.
    Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
    Value *V = Builder.CreateLoad(FromPtr);
    unsigned vecSize = CurFromTy->getVectorNumElements();
    Value *zeroIdx = Builder.getInt32(0);
    for (unsigned i = 0; i < vecSize; i++) {
      Value *ToPtr = Builder.CreateInBoundsGEP(
          ToArray, {zeroIdx, Builder.getInt32(calcIdx++)});
      Value *Elt = Builder.CreateExtractElement(V, i);
      Builder.CreateStore(Elt, ToPtr);
    }
  } else if (HLMatrixLower::IsMatrixType(CurFromTy)) {
    // Copy matrix to array.
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
    // Calculate the offset.
    unsigned offset = calcIdx * col * row;
    Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
    CopyMatPtrToArrayPtr(FromPtr, ToArray, offset, HLM, Builder, bRowMajor);
  } else if (!CurFromTy->isArrayTy()) {
    Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
    Value *ToPtr = Builder.CreateInBoundsGEP(
        ToArray, {Builder.getInt32(0), Builder.getInt32(calcIdx)});
    CastCopyFn(FromPtr, ToPtr, HLM, HandleTy, Builder, bRowMajor);
  } else {
    unsigned size = CurFromTy->getArrayNumElements();
    Type *FromEltTy = CurFromTy->getArrayElementType();
    for (unsigned i = 0; i < size; i++) {
      idxList.push_back(Builder.getInt32(i));
      unsigned idx = calcIdx * size + i;
      CastCopyArrayMultiDimTo1Dim(FromArray, ToArray, FromEltTy, idxList, idx,
                                  HandleTy, HLM, Builder, CastCopyFn,
                                  bRowMajor);
      idxList.pop_back();
    }
  }
}
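// Illustrative flattening performed above: for float2x2 m[2][2] the recursion
// visits m[i][j] with calcIdx = i * 2 + j, and each matrix is then copied to
// destination slots starting at calcIdx * 4 (col * row) of the 1-dim array.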
static void
CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
                            std::vector<Value *> &idxList, unsigned calcIdx,
                            Type *HandleTy, HLModule &HLM, IRBuilder<> &Builder,
                            CopyFunctionTy CastCopyFn, bool bRowMajor) {
  if (CurToTy->isVectorTy()) {
    // Copy array to vector.
    Value *V = UndefValue::get(CurToTy);
    unsigned vecSize = CurToTy->getVectorNumElements();
    // Calculate the offset.
    unsigned offset = calcIdx * vecSize;
    Value *zeroIdx = Builder.getInt32(0);
    Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
    for (unsigned i = 0; i < vecSize; i++) {
      Value *FromPtr = Builder.CreateInBoundsGEP(
          FromArray, {zeroIdx, Builder.getInt32(offset++)});
      Value *Elt = Builder.CreateLoad(FromPtr);
      V = Builder.CreateInsertElement(V, Elt, i);
    }
    Builder.CreateStore(V, ToPtr);
  } else if (HLMatrixLower::IsMatrixType(CurToTy)) {
    // Copy array to matrix.
    unsigned col, row;
    HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
    // Calculate the offset.
    unsigned offset = calcIdx * col * row;
    Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
    CopyArrayPtrToMatPtr(FromArray, offset, ToPtr, HLM, Builder, bRowMajor);
  } else if (!CurToTy->isArrayTy()) {
    Value *FromPtr = Builder.CreateInBoundsGEP(
        FromArray, {Builder.getInt32(0), Builder.getInt32(calcIdx)});
    Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
    CastCopyFn(FromPtr, ToPtr, HLM, HandleTy, Builder, bRowMajor);
  } else {
    unsigned size = CurToTy->getArrayNumElements();
    Type *ToEltTy = CurToTy->getArrayElementType();
    for (unsigned i = 0; i < size; i++) {
      idxList.push_back(Builder.getInt32(i));
      unsigned idx = calcIdx * size + i;
      CastCopyArray1DimToMultiDim(FromArray, ToArray, ToEltTy, idxList, idx,
                                  HandleTy, HLM, Builder, CastCopyFn,
                                  bRowMajor);
      idxList.pop_back();
    }
  }
}
static void CastCopyOldPtrToNewPtr(Value *OldPtr, Value *NewPtr, HLModule &HLM,
                                   Type *HandleTy, IRBuilder<> &Builder,
                                   bool bRowMajor) {
  Type *NewTy = NewPtr->getType()->getPointerElementType();
  Type *OldTy = OldPtr->getType()->getPointerElementType();
  if (NewTy == HandleTy) {
    CopyResourcePtrToHandlePtr(OldPtr, NewPtr, HLM, Builder);
  } else if (OldTy->isVectorTy()) {
    // Copy vector to array.
    Value *V = Builder.CreateLoad(OldPtr);
    unsigned vecSize = OldTy->getVectorNumElements();
    Value *zeroIdx = Builder.getInt32(0);
    for (unsigned i = 0; i < vecSize; i++) {
      Value *EltPtr = Builder.CreateGEP(NewPtr, {zeroIdx, Builder.getInt32(i)});
      Value *Elt = Builder.CreateExtractElement(V, i);
      Builder.CreateStore(Elt, EltPtr);
    }
  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
    CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
                         bRowMajor);
  } else if (OldTy->isArrayTy()) {
    std::vector<Value *> idxList;
    idxList.emplace_back(Builder.getInt32(0));
    CastCopyArrayMultiDimTo1Dim(OldPtr, NewPtr, OldTy, idxList, /*calcIdx*/ 0,
                                HandleTy, HLM, Builder, CastCopyOldPtrToNewPtr,
                                bRowMajor);
  }
}

static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
                                   Type *HandleTy, IRBuilder<> &Builder,
                                   bool bRowMajor) {
  Type *NewTy = NewPtr->getType()->getPointerElementType();
  Type *OldTy = OldPtr->getType()->getPointerElementType();
  if (NewTy == HandleTy) {
    CopyHandlePtrToResourcePtr(NewPtr, OldPtr, HLM, Builder);
  } else if (OldTy->isVectorTy()) {
    // Copy array to vector.
    Value *V = UndefValue::get(OldTy);
    unsigned vecSize = OldTy->getVectorNumElements();
    Value *zeroIdx = Builder.getInt32(0);
    for (unsigned i = 0; i < vecSize; i++) {
      Value *EltPtr = Builder.CreateGEP(NewPtr, {zeroIdx, Builder.getInt32(i)});
      Value *Elt = Builder.CreateLoad(EltPtr);
      V = Builder.CreateInsertElement(V, Elt, i);
    }
    Builder.CreateStore(V, OldPtr);
  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
    CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
                         bRowMajor);
  } else if (OldTy->isArrayTy()) {
    std::vector<Value *> idxList;
    idxList.emplace_back(Builder.getInt32(0));
    CastCopyArray1DimToMultiDim(NewPtr, OldPtr, OldTy, idxList, /*calcIdx*/ 0,
                                HandleTy, HLM, Builder, CastCopyNewPtrToOldPtr,
                                bRowMajor);
  }
}
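// The two helpers above are symmetric: CastCopyOldPtrToNewPtr copies the
// original value into its flattened form (before a call at the call site in
// replaceCastArgument, and before each return for out params in
// replaceCastParameter), while CastCopyNewPtrToOldPtr copies the flattened
// form back (at function entry for in params, and after the call for out
// arguments).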
void SROA_Parameter_HLSL::replaceCastArgument(Value *&NewArg, Value *OldArg,
                                              DxilParamInputQual inputQual,
                                              IRBuilder<> &CallBuilder,
                                              IRBuilder<> &RetBuilder) {
  Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();

  Type *NewTy = NewArg->getType();
  Type *OldTy = OldArg->getType();

  bool bIn = inputQual == DxilParamInputQual::Inout ||
             inputQual == DxilParamInputQual::In;
  bool bOut = inputQual == DxilParamInputQual::Inout ||
              inputQual == DxilParamInputQual::Out;

  if (NewArg->getType() == HandleTy) {
    Value *Handle =
        CastResourcePtrToHandle(OldArg, HandleTy, *m_pHLModule, CallBuilder);
    // Use Handle as NewArg.
    NewArg = Handle;
  } else if (vectorEltsMap.count(NewArg)) {
    Type *VecTy = OldTy;
    if (VecTy->isPointerTy())
      VecTy = VecTy->getPointerElementType();
    // Flattened vector.
    SmallVector<Value *, 4> &elts = vectorEltsMap[NewArg];
    unsigned vecSize = elts.size();
    if (NewTy->isPointerTy()) {
      if (bIn) {
        // Copy OldArg to NewArg before the call.
        CopyVectorPtrToEltsPtr(OldArg, elts, vecSize, CallBuilder);
      }
      // bOut must be true here.
      // Store NewArg to OldArg after the call.
      CopyEltsPtrToVectorPtr(elts, OldArg, VecTy, vecSize, RetBuilder);
    } else {
      // Must be an in parameter.
      // Copy OldArg to NewArg before the call.
      Value *Vec = OldArg;
      if (OldTy->isPointerTy()) {
        Vec = CallBuilder.CreateLoad(OldArg);
      }
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = CallBuilder.CreateExtractElement(Vec, i);
        // Save the element to update the arg in createFlattenedFunctionCall.
        elts[i] = Elt;
      }
    }
  } else if (!NewTy->isPointerTy()) {
    // A ptr param was cast to a non-ptr param.
    // Must be an in param.
    // Load OldArg as NewArg before the call.
    NewArg = CallBuilder.CreateLoad(OldArg);
  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
    bool bRowMajor = castRowMajorParamMap.count(NewArg);
    CopyMatToArrayPtr(OldArg, NewArg, /*arrayBaseIdx*/ 0, *m_pHLModule,
                      CallBuilder, bRowMajor);
  } else {
    bool bRowMajor = castRowMajorParamMap.count(NewArg);
    // NewTy is a pointer type.
    // Copy OldArg to NewArg before the call.
    if (bIn) {
      CastCopyOldPtrToNewPtr(OldArg, NewArg, *m_pHLModule, HandleTy,
                             CallBuilder, bRowMajor);
    }
    if (bOut) {
      // Store NewArg to OldArg after the call.
      CastCopyNewPtrToOldPtr(NewArg, OldArg, *m_pHLModule, HandleTy, RetBuilder,
                             bRowMajor);
    }
  }
}
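// Call-site example for replaceCastArgument (illustrative): for a Texture2D
// argument, the resource pointer is loaded and wrapped in an HLCreateHandle
// call before the call, and the handle value itself is passed as NewArg; for
// an inout struct or array, the value is copied into the flattened pointer
// before the call and copied back from it afterwards.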
void SROA_Parameter_HLSL::replaceCastParameter(
    Value *NewParam, Value *OldParam, Function &F, Argument *Arg,
    const DxilParamInputQual inputQual, IRBuilder<> &Builder) {
  Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  Type *HandlePtrTy = PointerType::get(HandleTy, 0);
  Module &M = *m_pHLModule->GetModule();

  Type *NewTy = NewParam->getType();
  Type *OldTy = OldParam->getType();

  bool bIn = inputQual == DxilParamInputQual::Inout ||
             inputQual == DxilParamInputQual::In;
  bool bOut = inputQual == DxilParamInputQual::Inout ||
              inputQual == DxilParamInputQual::Out;

  // Make sure the insert point is after the OldParam instruction.
  if (Instruction *I = dyn_cast<Instruction>(OldParam)) {
    Builder.SetInsertPoint(I->getNextNode());
  }

  if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(OldParam)) {
    // Add debug info to the new param.
    DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
    DIExpression *DDIExp = DDI->getExpression();
    DIB.insertDeclare(NewParam, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
                      Builder.GetInsertPoint());
  }

  if (isa<Argument>(OldParam) && OldTy->isPointerTy()) {
    // OldParam will be removed with the old function.
    // Create an alloca to replace it.
    Value *AllocParam = Builder.CreateAlloca(OldTy->getPointerElementType());
    OldParam->replaceAllUsesWith(AllocParam);
    OldParam = AllocParam;
  }

  if (NewTy == HandleTy) {
    CopyHandleToResourcePtr(NewParam, OldParam, *m_pHLModule, Builder);
    // Save the resource attribute.
    Type *ResTy = OldTy->getPointerElementType();
    MDNode *MD = HLModule::GetDxilResourceAttrib(ResTy, M);
    m_pHLModule->MarkDxilResourceAttrib(Arg, MD);
  } else if (vectorEltsMap.count(NewParam)) {
    // The vector was flattened to scalars.
    Type *VecTy = OldTy;
    if (VecTy->isPointerTy())
      VecTy = VecTy->getPointerElementType();
    // Flattened vector.
    SmallVector<Value *, 4> &elts = vectorEltsMap[NewParam];
    unsigned vecSize = elts.size();
    if (NewTy->isPointerTy()) {
      if (bIn) {
        // Copy NewParam to OldParam at the entry.
        CopyEltsPtrToVectorPtr(elts, OldParam, VecTy, vecSize, Builder);
      }
      // bOut must be true here.
      // Store OldParam to NewParam before every return.
      for (auto &BB : F.getBasicBlockList()) {
        if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
          IRBuilder<> RetBuilder(RI);
          CopyVectorPtrToEltsPtr(OldParam, elts, vecSize, RetBuilder);
        }
      }
    } else {
      // Must be an in parameter.
      // Copy NewParam to OldParam at the entry.
      Value *Vec = UndefValue::get(VecTy);
      for (unsigned i = 0; i < vecSize; i++) {
        Vec = Builder.CreateInsertElement(Vec, elts[i], i);
      }
      if (OldTy->isPointerTy()) {
        Builder.CreateStore(Vec, OldParam);
      } else {
        OldParam->replaceAllUsesWith(Vec);
      }
    }
  } else if (!NewTy->isPointerTy()) {
    // A ptr param was cast to a non-ptr param.
    // Must be an in param.
    // Store NewParam to OldParam at the entry.
    Builder.CreateStore(NewParam, OldParam);
  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
    bool bRowMajor = castRowMajorParamMap.count(NewParam);
    Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
                                   *m_pHLModule, Builder, bRowMajor);
    OldParam->replaceAllUsesWith(Mat);
  } else {
    bool bRowMajor = castRowMajorParamMap.count(NewParam);
    // NewTy is a pointer type.
    if (bIn) {
      // Copy NewParam to OldParam at the entry.
      CastCopyNewPtrToOldPtr(NewParam, OldParam, *m_pHLModule, HandleTy,
                             Builder, bRowMajor);
    }
    if (bOut) {
      // Store OldParam to NewParam before every return.
      for (auto &BB : F.getBasicBlockList()) {
        if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
          IRBuilder<> RetBuilder(RI);
          CastCopyOldPtrToNewPtr(OldParam, NewParam, *m_pHLModule, HandleTy,
                                 RetBuilder, bRowMajor);
        }
      }
    }

    Type *NewEltTy = HLModule::GetArrayEltTy(NewTy);
    Type *OldEltTy = HLModule::GetArrayEltTy(OldTy);

    if (NewEltTy == HandlePtrTy) {
      // Save the resource attribute.
      Type *ResTy = OldEltTy;
      MDNode *MD = HLModule::GetDxilResourceAttrib(ResTy, M);
      m_pHLModule->MarkDxilResourceAttrib(Arg, MD);
    }
  }
}
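// Inside the flattened function, replaceCastParameter mirrors the call-site
// logic: e.g. a handle parameter is cast back to a resource and stored into
// an alloca that takes over all uses of the old resource parameter, and an
// out vector parameter that was split into scalar pointers is written back
// element by element before every return.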
Value *SROA_Parameter_HLSL::castArgumentIfRequired(
    Value *V, Type *Ty, bool bOut, bool hasShaderInputOutput,
    DxilParamInputQual inputQual, DxilFieldAnnotation &annotation,
    std::deque<Value *> &WorkList, IRBuilder<> &Builder) {
  Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  Module &M = *m_pHLModule->GetModule();

  // Remove the pointer for vector/scalar parameters which are not out.
  if (V->getType()->isPointerTy() && !Ty->isAggregateType() && !bOut) {
    Value *Ptr = Builder.CreateAlloca(Ty);
    V->replaceAllUsesWith(Ptr);
    // Create a load here to get the correct type.
    // The Ptr will be stored with the correct value in replaceCastParameter
    // and replaceCastArgument.
    V = Builder.CreateLoad(Ptr);
    castParamMap[V] = std::make_pair(Ptr, inputQual);
  }

  // Lower resource types to the handle type.
  if (HLModule::IsHLSLObjectType(Ty) &&
      !HLModule::IsStreamOutputPtrType(V->getType())) {
    Value *Res = V;
    if (!bOut) {
      Value *LdRes = Builder.CreateLoad(Res);
      V = m_pHLModule->EmitHLOperationCall(Builder,
                                           HLOpcodeGroup::HLCreateHandle,
                                           /*opcode*/ 0, HandleTy, {LdRes}, M);
    } else {
      V = Builder.CreateAlloca(HandleTy);
    }
    castParamMap[V] = std::make_pair(Res, inputQual);
  } else if (Ty->isArrayTy()) {
    unsigned arraySize = 1;
    Type *AT = Ty;
    while (AT->isArrayTy()) {
      arraySize *= AT->getArrayNumElements();
      AT = AT->getArrayElementType();
    }
    if (HLModule::IsHLSLObjectType(AT)) {
      Value *Res = V;
      Type *Ty = ArrayType::get(HandleTy, arraySize);
      V = Builder.CreateAlloca(Ty);
      castParamMap[V] = std::make_pair(Res, inputQual);
    }
  }

  if (!hasShaderInputOutput) {
    if (Ty->isVectorTy()) {
      Value *OldV = V;
      Type *EltTy = Ty->getVectorElementType();
      unsigned vecSize = Ty->getVectorNumElements();

      // Split the vector into scalars.
      if (OldV->getType()->isPointerTy()) {
        // Split into scalar ptrs.
        V = Builder.CreateAlloca(EltTy);
        vectorEltsMap[V].emplace_back(V);
        for (unsigned i = 1; i < vecSize; i++) {
          Value *Elt = Builder.CreateAlloca(EltTy);
          vectorEltsMap[V].emplace_back(Elt);
        }
      } else {
        // Split into scalars.
        V = Builder.CreateExtractElement(OldV, (uint64_t)0);
        vectorEltsMap[V].emplace_back(V);
        for (unsigned i = 1; i < vecSize; i++) {
          Value *Elt = Builder.CreateExtractElement(OldV, i);
          vectorEltsMap[V].emplace_back(Elt);
        }
      }
      // Add to the worklist in reverse order.
      for (unsigned i = vecSize - 1; i > 0; i--) {
        Value *Elt = vectorEltsMap[V][i];
        WorkList.push_front(Elt);
      }
      // For the case where OldV comes from an input vector ptr.
      if (castParamMap.count(OldV)) {
        OldV = castParamMap[OldV].first;
      }
      castParamMap[V] = std::make_pair(OldV, inputQual);
    } else if (HLMatrixLower::IsMatrixType(Ty)) {
      unsigned col, row;
      Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
      Value *Mat = V;
      // Cast the matrix to an array.
      Type *AT = ArrayType::get(EltTy, col * row);
      V = Builder.CreateAlloca(AT);
      castParamMap[V] = std::make_pair(Mat, inputQual);
      DXASSERT(annotation.HasMatrixAnnotation(), "need matrix annotation here");
      if (annotation.GetMatrixAnnotation().Orientation ==
          hlsl::MatrixOrientation::RowMajor) {
        castRowMajorParamMap.insert(V);
      }
    } else if (Ty->isArrayTy()) {
      unsigned arraySize = 1;
      Type *AT = Ty;
      unsigned dim = 0;
      while (AT->isArrayTy()) {
        ++dim;
        arraySize *= AT->getArrayNumElements();
        AT = AT->getArrayElementType();
      }
      if (VectorType *VT = dyn_cast<VectorType>(AT)) {
        Value *VecArray = V;
        Type *AT = ArrayType::get(VT->getElementType(),
                                  arraySize * VT->getNumElements());
        V = Builder.CreateAlloca(AT);
        castParamMap[V] = std::make_pair(VecArray, inputQual);
      } else if (HLMatrixLower::IsMatrixType(AT)) {
        unsigned col, row;
        Type *EltTy = HLMatrixLower::GetMatrixInfo(AT, col, row);
        Value *MatArray = V;
        Type *AT = ArrayType::get(EltTy, arraySize * col * row);
        V = Builder.CreateAlloca(AT);
        castParamMap[V] = std::make_pair(MatArray, inputQual);
        DXASSERT(annotation.HasMatrixAnnotation(),
                 "need matrix annotation here");
        if (annotation.GetMatrixAnnotation().Orientation ==
            hlsl::MatrixOrientation::RowMajor) {
          castRowMajorParamMap.insert(V);
        }
      } else if (dim > 1) {
        // Flatten the multi-dim array to a 1-dim array of the element type.
        Value *MultiArray = V;
        V = Builder.CreateAlloca(ArrayType::get(AT, arraySize));
        castParamMap[V] = std::make_pair(MultiArray, inputQual);
      }
    }
  } else {
    // An entry-function matrix value parameter carries a row/column-major
    // annotation. Make sure its users use a row major matrix value.
    bool updateToColMajor = annotation.HasMatrixAnnotation() &&
                            annotation.GetMatrixAnnotation().Orientation ==
                                MatrixOrientation::ColumnMajor;
    if (updateToColMajor) {
      if (V->getType()->isPointerTy()) {
        for (User *user : V->users()) {
          CallInst *CI = dyn_cast<CallInst>(user);
          if (!CI)
            continue;

          HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
          if (group != HLOpcodeGroup::HLMatLoadStore)
            continue;
          HLMatLoadStoreOpcode opcode =
              static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
          Type *opcodeTy = Builder.getInt32Ty();
          switch (opcode) {
          case HLMatLoadStoreOpcode::RowMatLoad: {
            // Update the matrix function opcode to the col major version.
            Value *rowOpArg = ConstantInt::get(
                opcodeTy,
                static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad));
            CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
            // Cast it to row major.
            CallInst *RowMat = HLModule::EmitHLOperationCall(
                Builder, HLOpcodeGroup::HLCast,
                (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {CI}, M);
            CI->replaceAllUsesWith(RowMat);
            // Set the arg back to CI.
            RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, CI);
          } break;
          case HLMatLoadStoreOpcode::RowMatStore: {
            // Update the matrix function opcode to the col major version.
            Value *rowOpArg = ConstantInt::get(
                opcodeTy,
                static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore));
            CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
            Value *Mat = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
            // Cast it to col major.
            CallInst *RowMat = HLModule::EmitHLOperationCall(
                Builder, HLOpcodeGroup::HLCast,
                (unsigned)HLCastOpcode::RowMatrixToColMatrix, Ty, {Mat}, M);
            CI->setArgOperand(HLOperandIndex::kMatStoreValOpIdx, RowMat);
          } break;
          }
        }
      } else {
        CallInst *RowMat = HLModule::EmitHLOperationCall(
            Builder, HLOpcodeGroup::HLCast,
            (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {V}, M);
        V->replaceAllUsesWith(RowMat);
        // Set the arg back to V.
        RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, V);
      }
    }
  }
  return V;
}
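// Example of the rewrites done by castArgumentIfRequired (illustrative):
//   float3x4 m            -> alloca [12 x float]   (row/col major recorded)
//   Texture2D t (in)      -> handle from HLCreateHandle
//   float4 v (non-entry)  -> four scalars / scalar allocas queued for
//                            further flattening
// castParamMap remembers the original value so replaceCastParameter and
// replaceCastArgument can wire the two representations together later.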
  4471. void SROA_Parameter_HLSL::flattenArgument(
  4472. Function *F, Value *Arg, bool bForParam,
  4473. DxilParameterAnnotation &paramAnnotation,
  4474. std::vector<Value *> &FlatParamList,
  4475. std::vector<DxilParameterAnnotation> &FlatAnnotationList,
  4476. IRBuilder<> &Builder, DbgDeclareInst *DDI) {
  4477. std::deque<Value *> WorkList;
  4478. WorkList.push_back(Arg);
  4479. Function *Entry = m_pHLModule->GetEntryFunction();
  4480. bool hasShaderInputOutput = F == Entry;
  4481. if (m_pHLModule->HasHLFunctionProps(Entry)) {
  4482. HLFunctionProps &funcProps = m_pHLModule->GetHLFunctionProps(Entry);
  4483. if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
  4484. Function *patchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc;
  4485. hasShaderInputOutput |= F == patchConstantFunc;
  4486. }
  4487. }
  4488. unsigned startArgIndex = FlatAnnotationList.size();
  4489. // Map from value to annotation.
  4490. std::unordered_map<Value *, DxilFieldAnnotation> annotationMap;
  4491. annotationMap[Arg] = paramAnnotation;
  4492. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  4493. const std::string &semantic = paramAnnotation.GetSemanticString();
  4494. bool bSemOverride = !semantic.empty();
  4495. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  4496. bool bOut = inputQual == DxilParamInputQual::Out ||
  4497. inputQual == DxilParamInputQual::Inout ||
  4498. inputQual == DxilParamInputQual::OutStream0 ||
  4499. inputQual == DxilParamInputQual::OutStream1 ||
  4500. inputQual == DxilParamInputQual::OutStream2 ||
  4501. inputQual == DxilParamInputQual::OutStream3;
  4502. // Map from semantic string to type.
  4503. llvm::StringMap<Type *> semanticTypeMap;
  4504. // Original semantic type.
  4505. if (!semantic.empty()) {
  4506. // Unwrap top-level array if primitive
  4507. if (inputQual == DxilParamInputQual::InputPatch ||
  4508. inputQual == DxilParamInputQual::OutputPatch ||
  4509. inputQual == DxilParamInputQual::InputPrimitive) {
  4510. Type *Ty = Arg->getType();
  4511. if (Ty->isPointerTy())
  4512. Ty = Ty->getPointerElementType();
  4513. if (Ty->isArrayTy())
  4514. semanticTypeMap[semantic] = Ty->getArrayElementType();
  4515. } else {
  4516. semanticTypeMap[semantic] = Arg->getType();
  4517. }
  4518. }
  4519. std::vector<Instruction*> deadAllocas;
  4520. DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
  4521. unsigned debugOffset = 0;
  4522. const DataLayout &DL = F->getParent()->getDataLayout();
  4523. // Process the worklist
  4524. while (!WorkList.empty()) {
  4525. Value *V = WorkList.front();
  4526. WorkList.pop_front();
  4527. // Do not skip unused parameter.
  4528. DxilFieldAnnotation &annotation = annotationMap[V];
  4529. const bool bAllowReplace = !bOut;
  4530. SROA_Helper::LowerMemcpy(V, &annotation, dxilTypeSys, DL, bAllowReplace);
  4531. std::vector<Value *> Elts;
  4532. // Not flat vector for entry function currently.
  4533. bool SROAed = SROA_Helper::DoScalarReplacement(
  4534. V, Elts, Builder, /*bFlatVector*/ false, annotation.IsPrecise(),
  4535. dxilTypeSys, DeadInsts);
  4536. if (SROAed) {
  4537. Type *Ty = V->getType()->getPointerElementType();
  4538. // Skip empty struct parameters.
  4539. if (SROA_Helper::IsEmptyStructType(Ty, dxilTypeSys)) {
  4540. SROA_Helper::MarkEmptyStructUsers(V, DeadInsts);
  4541. DeleteDeadInstructions();
  4542. continue;
  4543. }
  4544. // Push Elts into workList.
  4545. // Use rbegin to make sure the order not change.
  4546. for (auto iter = Elts.rbegin(); iter != Elts.rend(); iter++)
  4547. WorkList.push_front(*iter);
  4548. bool precise = annotation.IsPrecise();
  4549. const std::string &semantic = annotation.GetSemanticString();
  4550. hlsl::InterpolationMode interpMode = annotation.GetInterpolationMode();
  4551. for (unsigned i=0;i<Elts.size();i++) {
  4552. Value *Elt = Elts[i];
  4553. DxilFieldAnnotation EltAnnotation = GetEltAnnotation(Ty, i, annotation, dxilTypeSys);
  4554. const std::string &eltSem = EltAnnotation.GetSemanticString();
  4555. if (!semantic.empty()) {
  4556. if (!eltSem.empty()) {
  4557. // TODO: warning for override the semantic in EltAnnotation.
  4558. }
  4559. // Just save parent semantic here, allocate later.
  4560. EltAnnotation.SetSemanticString(semantic);
  4561. } else if (!eltSem.empty() &&
  4562. semanticTypeMap.count(eltSem) == 0) {
  4563. Type *EltTy = HLModule::GetArrayEltTy(Ty);
  4564. DXASSERT(EltTy->isStructTy(), "must be a struct type to has semantic.");
  4565. semanticTypeMap[eltSem] = EltTy->getStructElementType(i);
  4566. }
  4567. if (precise)
  4568. EltAnnotation.SetPrecise();
  4569. if (EltAnnotation.GetInterpolationMode().GetKind() == DXIL::InterpolationMode::Undefined)
  4570. EltAnnotation.SetInterpolationMode(interpMode);
  4571. annotationMap[Elt] = EltAnnotation;
  4572. }
  4573. annotationMap.erase(V);
  4574. ++NumReplaced;
  4575. if (Instruction *I = dyn_cast<Instruction>(V))
  4576. deadAllocas.emplace_back(I);
  4577. } else {
  4578. if (bSemOverride) {
  4579. if (!annotation.GetSemanticString().empty()) {
  4580. // TODO: warning for override the semantic in EltAnnotation.
  4581. }
  4582. // Just save parent semantic here, allocate later.
  4583. annotation.SetSemanticString(semantic);
  4584. }
  4585. Type *Ty = V->getType();
  4586. if (Ty->isPointerTy())
  4587. Ty = Ty->getPointerElementType();
  4588. // Flatten array of SV_Target.
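// E.g. (illustrative), an output declared roughly as
//   float4 rt[2] : SV_Target1;
// is rewritten below into a local array plus one local per element, carrying
// the semantics SV_Target1 and SV_Target2, with the local array copied into
// those per-element locals before each return (for parameters).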
  4589. StringRef semanticStr = annotation.GetSemanticString();
  4590. if (semanticStr.upper().find("SV_TARGET") == 0 &&
  4591. Ty->isArrayTy()) {
  4592. Type *Ty = cast<ArrayType>(V->getType()->getPointerElementType());
  4593. StringRef targetStr;
  4594. unsigned targetIndex;
  4595. Semantic::DecomposeNameAndIndex(semanticStr, &targetStr, &targetIndex);
  4596. // Replace target parameter with local target.
  4597. AllocaInst *localTarget = Builder.CreateAlloca(Ty);
  4598. V->replaceAllUsesWith(localTarget);
  4599. unsigned arraySize = 1;
  4600. std::vector<unsigned> arraySizeList;
  4601. while (Ty->isArrayTy()) {
  4602. unsigned size = Ty->getArrayNumElements();
  4603. arraySizeList.emplace_back(size);
  4604. arraySize *= size;
  4605. Ty = Ty->getArrayElementType();
  4606. }
  4607. unsigned arrayLevel = arraySizeList.size();
  4608. std::vector<unsigned> arrayIdxList(arrayLevel, 0);
  4609. // Create flattened target.
  4610. DxilFieldAnnotation EltAnnotation = annotation;
  4611. for (unsigned i=0;i<arraySize;i++) {
  4612. Value *Elt = Builder.CreateAlloca(Ty);
  4613. EltAnnotation.SetSemanticString(targetStr.str()+std::to_string(targetIndex+i));
  4614. // Add semantic type.
  4615. semanticTypeMap[EltAnnotation.GetSemanticString()] = Ty;
  4616. annotationMap[Elt] = EltAnnotation;
  4617. WorkList.push_front(Elt);
  4618. // Copy local target to flattened target.
  4619. std::vector<Value*> idxList(arrayLevel+1);
  4620. idxList[0] = Builder.getInt32(0);
  4621. for (unsigned idx=0;idx<arrayLevel; idx++) {
  4622. idxList[idx+1] = Builder.getInt32(arrayIdxList[idx]);
  4623. }
  4624. if (bForParam) {
  4625. // If Argument, copy before each return.
  4626. for (auto &BB : F->getBasicBlockList()) {
  4627. TerminatorInst *TI = BB.getTerminator();
  4628. if (isa<ReturnInst>(TI)) {
  4629. IRBuilder<> RetBuilder(TI);
  4630. Value *Ptr = RetBuilder.CreateGEP(localTarget, idxList);
  4631. Value *V = RetBuilder.CreateLoad(Ptr);
  4632. RetBuilder.CreateStore(V, Elt);
  4633. }
  4634. }
  4635. } else {
  4636. // Else, copy with Builder.
  4637. Value *Ptr = Builder.CreateGEP(localTarget, idxList);
  4638. Value *V = Builder.CreateLoad(Ptr);
  4639. Builder.CreateStore(V, Elt);
  4640. }
  4641. // Update arrayIdxList.
  4642. for (unsigned idx=arrayLevel;idx>0;idx--) {
  4643. arrayIdxList[idx-1]++;
  4644. if (arrayIdxList[idx-1] < arraySizeList[idx-1])
  4645. break;
  4646. arrayIdxList[idx-1] = 0;
  4647. }
  4648. }
  4649. // Don't override flattened SV_Target.
  4650. if (V == Arg) {
  4651. bSemOverride = false;
  4652. }
  4653. continue;
  4654. }
  4655. // Cast vector/matrix/resource parameter.
  4656. V = castArgumentIfRequired(V, Ty, bOut, hasShaderInputOutput, inputQual,
  4657. annotation, WorkList, Builder);
  4658. // Cannot SROA, save it to final parameter list.
  4659. FlatParamList.emplace_back(V);
  4660. // Create ParamAnnotation for V.
  4661. FlatAnnotationList.emplace_back(DxilParameterAnnotation());
  4662. DxilParameterAnnotation &flatParamAnnotation = FlatAnnotationList.back();
  4663. flatParamAnnotation.SetParamInputQual(paramAnnotation.GetParamInputQual());
  4664. flatParamAnnotation.SetInterpolationMode(annotation.GetInterpolationMode());
  4665. flatParamAnnotation.SetSemanticString(annotation.GetSemanticString());
  4666. flatParamAnnotation.SetCompType(annotation.GetCompType().GetKind());
  4667. flatParamAnnotation.SetMatrixAnnotation(annotation.GetMatrixAnnotation());
  4668. flatParamAnnotation.SetPrecise(annotation.IsPrecise());
  4669. // Add debug info.
  4670. if (DDI && V != Arg) {
  4671. Value *TmpV = V;
4672. // If V was cast, add debug info to the original V.
  4673. if (castParamMap.count(V)) {
  4674. TmpV = castParamMap[V].first;
4675. // One more level for the pointer of an input vector:
4676. // it is cast from a pointer to a non-pointer, then cast to scalars.
  4677. if (castParamMap.count(TmpV)) {
  4678. TmpV = castParamMap[TmpV].first;
  4679. }
  4680. }
  4681. Type *Ty = TmpV->getType();
  4682. if (Ty->isPointerTy())
  4683. Ty = Ty->getPointerElementType();
  4684. unsigned size = DL.getTypeAllocSize(Ty);
  4685. DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
  4686. debugOffset += size;
  4687. DIB.insertDeclare(TmpV, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
  4688. Builder.GetInsertPoint());
  4689. }
  4690. // Flatten stream out.
  4691. if (HLModule::IsStreamOutputPtrType(V->getType())) {
  4692. // For stream output objects.
  4693. // Create a value as output value.
  4694. Type *outputType = V->getType()->getPointerElementType()->getStructElementType(0);
  4695. Value *outputVal = Builder.CreateAlloca(outputType);
  4696. // For each stream.Append(data)
  4697. // transform into
  4698. // d = load data
  4699. // store outputVal, d
  4700. // stream.Append(outputVal)
  4701. for (User *user : V->users()) {
  4702. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  4703. unsigned opcode = GetHLOpcode(CI);
  4704. if (opcode == static_cast<unsigned>(IntrinsicOp::MOP_Append)) {
  4705. if (CI->getNumArgOperands() == (HLOperandIndex::kStreamAppendDataOpIndex + 1)) {
  4706. Value *data =
  4707. CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
  4708. DXASSERT(data->getType()->isPointerTy(),
  4709. "Append value must be pointer.");
  4710. IRBuilder<> Builder(CI);
  4711. llvm::SmallVector<llvm::Value *, 16> idxList;
  4712. SplitCpy(data->getType(), outputVal, data, idxList, Builder,
  4713. dxilTypeSys, &flatParamAnnotation);
  4714. CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex, outputVal);
  4715. }
  4716. else {
4717. // Append has already been flattened.
4718. // Flatten the store of outputVal as well.
4719. // Must be a struct to be flattened.
  4720. IRBuilder<> Builder(CI);
  4721. llvm::SmallVector<llvm::Value *, 16> idxList;
  4722. llvm::SmallVector<llvm::Value *, 16> EltPtrList;
  4723. // split
  4724. SplitPtr(outputVal->getType(), outputVal, idxList, EltPtrList,
  4725. Builder);
  4726. unsigned eltCount = CI->getNumArgOperands()-2;
  4727. DXASSERT_LOCALVAR(eltCount, eltCount == EltPtrList.size(), "invalid element count");
  4728. for (unsigned i = HLOperandIndex::kStreamAppendDataOpIndex; i < CI->getNumArgOperands(); i++) {
  4729. Value *DataPtr = CI->getArgOperand(i);
  4730. Value *EltPtr =
  4731. EltPtrList[i - HLOperandIndex::kStreamAppendDataOpIndex];
  4732. llvm::SmallVector<llvm::Value *, 16> idxList;
  4733. SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
  4734. Builder, dxilTypeSys, &flatParamAnnotation);
  4735. CI->setArgOperand(i, EltPtr);
  4736. }
  4737. }
  4738. }
  4739. }
  4740. }
  4741. // Then split output value to generate ParamQual.
  4742. WorkList.push_front(outputVal);
  4743. }
  4744. }
  4745. }
  4746. // Now erase any instructions that were made dead while rewriting the
  4747. // alloca.
  4748. DeleteDeadInstructions();
  4749. // Erase dead allocas after all uses deleted.
  4750. for (Instruction *I : deadAllocas)
  4751. I->eraseFromParent();
  4752. unsigned endArgIndex = FlatAnnotationList.size();
  4753. if (bForParam && startArgIndex < endArgIndex) {
  4754. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  4755. if (inputQual == DxilParamInputQual::OutStream0 ||
  4756. inputQual == DxilParamInputQual::OutStream1 ||
  4757. inputQual == DxilParamInputQual::OutStream2 ||
  4758. inputQual == DxilParamInputQual::OutStream3)
  4759. startArgIndex++;
  4760. DxilParameterAnnotation &flatParamAnnotation =
  4761. FlatAnnotationList[startArgIndex];
  4762. const std::string &semantic = flatParamAnnotation.GetSemanticString();
  4763. if (!semantic.empty())
  4764. allocateSemanticIndex(FlatAnnotationList, startArgIndex,
  4765. semanticTypeMap);
  4766. }
  4767. }
4768. /// moveFunctionBody - Move the body of F into flatF.
  4769. void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
  4770. bool updateRetType = F->getReturnType() != flatF->getReturnType();
  4771. // Splice the body of the old function right into the new function.
  4772. flatF->getBasicBlockList().splice(flatF->begin(), F->getBasicBlockList());
  4773. // Update Block uses.
  4774. if (updateRetType) {
  4775. for (BasicBlock &BB : flatF->getBasicBlockList()) {
  4776. if (updateRetType) {
  4777. // Replace ret with ret void.
  4778. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  4779. // Create store for return.
  4780. IRBuilder<> Builder(RI);
  4781. Builder.CreateRetVoid();
  4782. RI->eraseFromParent();
  4783. }
  4784. }
  4785. }
  4786. }
  4787. }
  4788. static void SplitArrayCopy(Value *V, DxilTypeSystem &typeSys,
  4789. DxilFieldAnnotation *fieldAnnotation) {
  4790. for (auto U = V->user_begin(); U != V->user_end();) {
  4791. User *user = *(U++);
  4792. if (StoreInst *ST = dyn_cast<StoreInst>(user)) {
  4793. Value *ptr = ST->getPointerOperand();
  4794. Value *val = ST->getValueOperand();
  4795. IRBuilder<> Builder(ST);
  4796. SmallVector<Value *, 16> idxList;
  4797. SplitCpy(ptr->getType(), ptr, val, idxList, Builder, typeSys,
  4798. fieldAnnotation);
  4799. ST->eraseFromParent();
  4800. }
  4801. }
  4802. }
  4803. static void CheckArgUsage(Value *V, bool &bLoad, bool &bStore) {
  4804. if (bLoad && bStore)
  4805. return;
  4806. for (User *user : V->users()) {
  4807. if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
  4808. bLoad = true;
  4809. } else if (StoreInst *SI = dyn_cast<StoreInst>(user)) {
  4810. bStore = true;
  4811. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
  4812. CheckArgUsage(GEP, bLoad, bStore);
  4813. } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
  4814. if (CI->getType()->isPointerTy())
  4815. CheckArgUsage(CI, bLoad, bStore);
  4816. else {
  4817. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  4818. if (group == HLOpcodeGroup::HLMatLoadStore) {
  4819. HLMatLoadStoreOpcode opcode =
  4820. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
  4821. switch (opcode) {
  4822. case HLMatLoadStoreOpcode::ColMatLoad:
  4823. case HLMatLoadStoreOpcode::RowMatLoad:
  4824. bLoad = true;
  4825. break;
  4826. case HLMatLoadStoreOpcode::ColMatStore:
  4827. case HLMatLoadStoreOpcode::RowMatStore:
  4828. bStore = true;
  4829. break;
  4830. }
  4831. }
  4832. }
  4833. }
  4834. }
  4835. }
  4836. // Support store to input and load from output.
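// Sketch (illustrative): for a signature roughly like
//   void main(in float4 a : A, out float4 b : B)
// where the body also stores to 'a' or loads from 'b', each such argument is
// replaced by a temp alloca; inputs are copied into the temp at entry, and
// output temps are copied back to the real output before every return.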
  4837. static void LegalizeDxilInputOutputs(Function *F,
  4838. DxilFunctionAnnotation *EntryAnnotation,
  4839. DxilTypeSystem &typeSys) {
  4840. BasicBlock &EntryBlk = F->getEntryBlock();
  4841. Module *M = F->getParent();
  4842. // Map from output to the temp created for it.
  4843. std::unordered_map<Argument *, Value*> outputTempMap;
  4844. for (Argument &arg : F->args()) {
  4845. Type *Ty = arg.getType();
  4846. DxilParameterAnnotation &paramAnnotation = EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  4847. DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
  4848. bool isColMajor = false;
  4849. // Skip arg which is not a pointer.
  4850. if (!Ty->isPointerTy()) {
  4851. if (HLMatrixLower::IsMatrixType(Ty)) {
  4852. // Replace matrix arg with cast to vec. It will be lowered in
  4853. // DxilGenerationPass.
  4854. isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
  4855. MatrixOrientation::ColumnMajor;
  4856. IRBuilder<> Builder(EntryBlk.getFirstInsertionPt());
  4857. HLCastOpcode opcode = isColMajor ? HLCastOpcode::ColMatrixToVecCast
  4858. : HLCastOpcode::RowMatrixToVecCast;
  4859. Value *undefVal = UndefValue::get(Ty);
  4860. Value *Cast = HLModule::EmitHLOperationCall(
  4861. Builder, HLOpcodeGroup::HLCast, static_cast<unsigned>(opcode), Ty,
  4862. {undefVal}, *M);
  4863. arg.replaceAllUsesWith(Cast);
  4864. // Set arg as the operand.
  4865. CallInst *CI = cast<CallInst>(Cast);
  4866. CI->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, &arg);
  4867. }
  4868. continue;
  4869. }
  4870. Ty = Ty->getPointerElementType();
  4871. bool bLoad = false;
  4872. bool bStore = false;
  4873. CheckArgUsage(&arg, bLoad, bStore);
  4874. bool bNeedTemp = false;
  4875. bool bStoreInputToTemp = false;
  4876. bool bLoadOutputFromTemp = false;
  4877. if (qual == DxilParamInputQual::In && bStore) {
  4878. bNeedTemp = true;
  4879. bStoreInputToTemp = true;
  4880. } else if (qual == DxilParamInputQual::Out && bLoad) {
  4881. bNeedTemp = true;
  4882. bLoadOutputFromTemp = true;
  4883. } else if (bLoad && bStore) {
  4884. switch (qual) {
  4885. case DxilParamInputQual::InputPrimitive:
  4886. case DxilParamInputQual::InputPatch:
  4887. case DxilParamInputQual::OutputPatch: {
  4888. bNeedTemp = true;
  4889. bStoreInputToTemp = true;
  4890. } break;
  4891. case DxilParamInputQual::Inout:
  4892. break;
  4893. default:
  4894. DXASSERT(0, "invalid input qual here");
  4895. }
  4896. } else if (qual == DxilParamInputQual::Inout) {
4897. // Reached only when !(bLoad && bStore); inout still needs copy-in and copy-out.
  4898. bNeedTemp = true;
  4899. bLoadOutputFromTemp = true;
  4900. bStoreInputToTemp = true;
  4901. }
  4902. if (HLMatrixLower::IsMatrixType(Ty)) {
  4903. bNeedTemp = true;
  4904. if (qual == DxilParamInputQual::In)
  4905. bStoreInputToTemp = bLoad;
  4906. else if (qual == DxilParamInputQual::Out)
  4907. bLoadOutputFromTemp = bStore;
  4908. else if (qual == DxilParamInputQual::Inout) {
  4909. bStoreInputToTemp = true;
  4910. bLoadOutputFromTemp = true;
  4911. }
  4912. }
  4913. if (bNeedTemp) {
  4914. IRBuilder<> Builder(EntryBlk.getFirstInsertionPt());
  4915. AllocaInst *temp = Builder.CreateAlloca(Ty);
  4916. // Replace all uses with temp.
  4917. arg.replaceAllUsesWith(temp);
  4918. // Copy input to temp.
  4919. if (bStoreInputToTemp) {
  4920. llvm::SmallVector<llvm::Value *, 16> idxList;
  4921. // split copy.
  4922. SplitCpy(temp->getType(), temp, &arg, idxList, Builder, typeSys,
  4923. &paramAnnotation);
  4924. }
4925. // Record the temp so the copy back to the output is generated later.
  4926. if (bLoadOutputFromTemp) {
  4927. outputTempMap[&arg] = temp;
  4928. }
  4929. }
  4930. }
  4931. for (BasicBlock &BB : F->getBasicBlockList()) {
  4932. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  4933. IRBuilder<> Builder(RI);
  4934. // Copy temp to output.
  4935. for (auto It : outputTempMap) {
  4936. Argument *output = It.first;
  4937. Value *temp = It.second;
  4938. llvm::SmallVector<llvm::Value *, 16> idxList;
  4939. DxilParameterAnnotation &paramAnnotation =
  4940. EntryAnnotation->GetParameterAnnotation(output->getArgNo());
  4941. auto Iter = Builder.GetInsertPoint();
  4942. bool onlyRetBlk = false;
  4943. if (RI != BB.begin())
  4944. Iter--;
  4945. else
  4946. onlyRetBlk = true;
  4947. // split copy.
  4948. SplitCpy(output->getType(), output, temp, idxList, Builder, typeSys,
  4949. &paramAnnotation);
  4950. }
  4951. // Clone the return.
  4952. Builder.CreateRet(RI->getReturnValue());
  4953. RI->eraseFromParent();
  4954. }
  4955. }
  4956. }
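// Rough sketch of the flattening performed below (types and names illustrative):
//   struct VSOut { float4 pos : SV_Position; float2 uv : TEXCOORD0; };
//   VSOut main(float3 p : POSITION, float2 uv : TEXCOORD0);
// becomes a void "main.flat" function whose parameters are the flattened
// inputs plus out parameters for the flattened return value, with per-element
// semantics allocated from the parent semantic strings.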
  4957. void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
  4958. DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  4959. // Skip void (void) function.
  4960. if (F->getReturnType()->isVoidTy() && F->getArgumentList().empty()) {
  4961. return;
  4962. }
  4963. // Clear maps for cast.
  4964. castParamMap.clear();
  4965. vectorEltsMap.clear();
  4966. DxilFunctionAnnotation *funcAnnotation = m_pHLModule->GetFunctionAnnotation(F);
  4967. DXASSERT(funcAnnotation, "must find annotation for function");
  4968. std::deque<Value *> WorkList;
  4969. std::vector<Value *> FlatParamList;
  4970. std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
  4971. const bool bForParamTrue = true;
  4972. // Add all argument to worklist.
  4973. for (Argument &Arg : F->args()) {
  4974. // merge GEP use for arg.
  4975. HLModule::MergeGepUse(&Arg);
4976. // The insert point may have been removed, so recreate the builder each time.
  4977. IRBuilder<> Builder(F->getEntryBlock().getFirstInsertionPt());
  4978. DxilParameterAnnotation &paramAnnotation =
  4979. funcAnnotation->GetParameterAnnotation(Arg.getArgNo());
  4980. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&Arg);
  4981. flattenArgument(F, &Arg, bForParamTrue, paramAnnotation, FlatParamList,
  4982. FlatParamAnnotationList, Builder, DDI);
  4983. }
  4984. Type *retType = F->getReturnType();
  4985. std::vector<Value *> FlatRetList;
  4986. std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
  4987. // Split and change to out parameter.
  4988. if (!retType->isVoidTy()) {
  4989. Instruction *InsertPt = F->getEntryBlock().getFirstInsertionPt();
  4990. IRBuilder<> Builder(InsertPt);
  4991. Value *retValAddr = Builder.CreateAlloca(retType);
  4992. DxilParameterAnnotation &retAnnotation =
  4993. funcAnnotation->GetRetTypeAnnotation();
  4994. Module &M = *m_pHLModule->GetModule();
  4995. Type *voidTy = Type::getVoidTy(m_pHLModule->GetCtx());
  4996. // Create DbgDecl for the ret value.
  4997. if (DISubprogram *funcDI = getDISubprogram(F)) {
  4998. DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
  4999. DITypeIdentifierMap EmptyMap;
  5000. DIType * RetDIType = RetDITyRef.resolve(EmptyMap);
  5001. DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
  5002. DILocalVariable *RetVar = DIB.createLocalVariable(llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI, F->getName().str() + ".Ret", funcDI->getFile(),
  5003. funcDI->getLine(), RetDIType);
  5004. DIExpression *Expr = nullptr;
  5005. // TODO: how to get col?
  5006. DILocation *DL = DILocation::get(F->getContext(), funcDI->getLine(), 0, funcDI);
  5007. DIB.insertDeclare(retValAddr, RetVar, Expr, DL, Builder.GetInsertPoint());
  5008. }
  5009. for (BasicBlock &BB : F->getBasicBlockList()) {
  5010. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  5011. // Create store for return.
  5012. IRBuilder<> RetBuilder(RI);
  5013. if (!retAnnotation.HasMatrixAnnotation()) {
  5014. RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
  5015. } else {
  5016. bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
  5017. MatrixOrientation::RowMajor;
  5018. Value *RetVal = RI->getReturnValue();
  5019. if (!isRowMajor) {
5020. // The matrix value is row major, but ColMatStore requires column major.
  5021. // Cast before store.
  5022. RetVal = HLModule::EmitHLOperationCall(
  5023. RetBuilder, HLOpcodeGroup::HLCast,
  5024. static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
  5025. RetVal->getType(), {RetVal}, M);
  5026. }
  5027. unsigned opcode = static_cast<unsigned>(
  5028. isRowMajor ? HLMatLoadStoreOpcode::RowMatStore
  5029. : HLMatLoadStoreOpcode::ColMatStore);
  5030. HLModule::EmitHLOperationCall(RetBuilder,
  5031. HLOpcodeGroup::HLMatLoadStore, opcode,
  5032. voidTy, {retValAddr, RetVal}, M);
  5033. }
  5034. }
  5035. }
  5036. // Create a fake store to keep retValAddr so it can be flattened.
  5037. if (retValAddr->user_empty()) {
  5038. Builder.CreateStore(UndefValue::get(retType), retValAddr);
  5039. }
  5040. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(retValAddr);
  5041. flattenArgument(F, retValAddr, bForParamTrue,
  5042. funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
  5043. FlatRetAnnotationList, Builder, DDI);
  5044. }
5045. // Always convert the return value into an out parameter.
5046. // This way there is no need to special-case the return value when generating storeOutput.
  5047. if (FlatRetList.size() ||
  5048. // For empty struct return type.
  5049. !retType->isVoidTy()) {
  5050. // Return value is flattened.
  5051. // Change return value into out parameter.
  5052. retType = Type::getVoidTy(retType->getContext());
5053. // Merge return data into param data.
  5054. FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(), FlatRetList.end());
  5055. FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
  5056. FlatRetAnnotationList.begin(),
  5057. FlatRetAnnotationList.end());
  5058. }
  5059. // TODO: take care AttributeSet for function and each argument.
  5060. std::vector<Type *> FinalTypeList;
  5061. for (Value * arg : FlatParamList) {
  5062. FinalTypeList.emplace_back(arg->getType());
  5063. }
  5064. unsigned extraParamSize = 0;
  5065. if (m_pHLModule->HasHLFunctionProps(F)) {
  5066. HLFunctionProps &funcProps = m_pHLModule->GetHLFunctionProps(F);
  5067. if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
  5068. auto &VS = funcProps.ShaderProps.VS;
  5069. Type *outFloatTy = Type::getFloatPtrTy(F->getContext());
  5070. // Add out float parameter for each clip plane.
  5071. unsigned i=0;
  5072. for (; i < DXIL::kNumClipPlanes; i++) {
  5073. if (!VS.clipPlanes[i])
  5074. break;
  5075. FinalTypeList.emplace_back(outFloatTy);
  5076. }
  5077. extraParamSize = i;
  5078. }
  5079. }
  5080. FunctionType *flatFuncTy = FunctionType::get(retType, FinalTypeList, false);
  5081. // Return if nothing changed.
  5082. if (flatFuncTy == F->getFunctionType()) {
  5083. // Copy semantic allocation.
  5084. if (!FlatParamAnnotationList.empty()) {
  5085. if (!FlatParamAnnotationList[0].GetSemanticString().empty()) {
  5086. for (unsigned i = 0; i < FlatParamAnnotationList.size(); i++) {
  5087. DxilParameterAnnotation &paramAnnotation = funcAnnotation->GetParameterAnnotation(i);
  5088. DxilParameterAnnotation &flatParamAnnotation = FlatParamAnnotationList[i];
  5089. paramAnnotation.SetSemanticIndexVec(flatParamAnnotation.GetSemanticIndexVec());
  5090. paramAnnotation.SetSemanticString(flatParamAnnotation.GetSemanticString());
  5091. }
  5092. }
  5093. }
  5094. // Support store to input and load from output.
  5095. LegalizeDxilInputOutputs(F, funcAnnotation, typeSys);
  5096. return;
  5097. }
  5098. std::string flatName = F->getName().str() + ".flat";
  5099. DXASSERT(nullptr == F->getParent()->getFunction(flatName),
  5100. "else overwriting existing function");
  5101. Function *flatF =
  5102. cast<Function>(F->getParent()->getOrInsertFunction(flatName, flatFuncTy));
  5103. funcMap[F] = flatF;
  5104. // Update function debug info.
  5105. if (DISubprogram *funcDI = getDISubprogram(F))
  5106. funcDI->replaceFunction(flatF);
  5107. // Create FunctionAnnotation for flatF.
  5108. DxilFunctionAnnotation *flatFuncAnnotation = m_pHLModule->AddFunctionAnnotation(flatF);
5109. // No need to set return info; flatF always returns void now.
  5110. // Param Info
  5111. for (unsigned ArgNo = 0; ArgNo < FlatParamAnnotationList.size(); ++ArgNo) {
  5112. DxilParameterAnnotation &paramAnnotation = flatFuncAnnotation->GetParameterAnnotation(ArgNo);
  5113. paramAnnotation = FlatParamAnnotationList[ArgNo];
  5114. }
  5115. DXASSERT(flatF->arg_size() == (extraParamSize + FlatParamAnnotationList.size()), "parameter count mismatch");
  5116. // ShaderProps.
  5117. if (m_pHLModule->HasHLFunctionProps(F)) {
  5118. HLFunctionProps &funcProps = m_pHLModule->GetHLFunctionProps(F);
  5119. std::unique_ptr<HLFunctionProps> flatFuncProps = std::make_unique<HLFunctionProps>();
  5120. flatFuncProps->shaderKind = funcProps.shaderKind;
  5121. flatFuncProps->ShaderProps = funcProps.ShaderProps;
  5122. m_pHLModule->AddHLFunctionProps(flatF, flatFuncProps);
  5123. if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
  5124. auto &VS = funcProps.ShaderProps.VS;
  5125. unsigned clipArgIndex = FlatParamAnnotationList.size();
  5126. // Add out float SV_ClipDistance for each clip plane.
  5127. for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
  5128. if (!VS.clipPlanes[i])
  5129. break;
  5130. DxilParameterAnnotation &paramAnnotation =
  5131. flatFuncAnnotation->GetParameterAnnotation(clipArgIndex+i);
  5132. paramAnnotation.SetParamInputQual(DxilParamInputQual::Out);
  5133. Twine semName = Twine("SV_ClipDistance") + Twine(i);
  5134. paramAnnotation.SetSemanticString(semName.str());
  5135. paramAnnotation.SetCompType(DXIL::ComponentType::F32);
  5136. paramAnnotation.AppendSemanticIndex(i);
  5137. }
  5138. }
  5139. }
  5140. // Move function body into flatF.
  5141. moveFunctionBody(F, flatF);
  5142. // Replace old parameters with flatF Arguments.
  5143. auto argIter = flatF->arg_begin();
  5144. auto flatArgIter = FlatParamList.begin();
  5145. LLVMContext &Context = F->getContext();
5146. // Parameter casts are emitted at the beginning of the entry block.
  5147. IRBuilder<> Builder(flatF->getEntryBlock().getFirstInsertionPt());
  5148. while (argIter != flatF->arg_end()) {
  5149. Argument *Arg = argIter++;
  5150. if (flatArgIter == FlatParamList.end()) {
  5151. DXASSERT(extraParamSize>0, "parameter count mismatch");
  5152. break;
  5153. }
  5154. Value *flatArg = *(flatArgIter++);
  5155. if (castParamMap.count(flatArg)) {
  5156. replaceCastParameter(flatArg, castParamMap[flatArg].first, *flatF, Arg,
  5157. castParamMap[flatArg].second, Builder);
  5158. }
  5159. flatArg->replaceAllUsesWith(Arg);
  5160. // Update arg debug info.
  5161. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(flatArg);
  5162. if (DDI) {
  5163. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(Arg));
  5164. DDI->setArgOperand(0, VMD);
  5165. }
  5166. HLModule::MergeGepUse(Arg);
  5167. // Flatten store of array parameter.
  5168. if (Arg->getType()->isPointerTy()) {
  5169. Type *Ty = Arg->getType()->getPointerElementType();
  5170. if (Ty->isArrayTy())
  5171. SplitArrayCopy(
  5172. Arg, typeSys,
  5173. &flatFuncAnnotation->GetParameterAnnotation(Arg->getArgNo()));
  5174. }
  5175. }
  5176. // Support store to input and load from output.
  5177. LegalizeDxilInputOutputs(flatF, flatFuncAnnotation, typeSys);
  5178. }
  5179. void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *flatF, CallInst *CI) {
  5180. DxilFunctionAnnotation *funcAnnotation = m_pHLModule->GetFunctionAnnotation(F);
  5181. DXASSERT(funcAnnotation, "must find annotation for function");
  5182. // Clear maps for cast.
  5183. castParamMap.clear();
  5184. vectorEltsMap.clear();
  5185. DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  5186. std::vector<Value *> FlatParamList;
  5187. std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
  5188. IRBuilder<> AllocaBuilder(
  5189. CI->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  5190. IRBuilder<> CallBuilder(CI);
  5191. IRBuilder<> RetBuilder(CI->getNextNode());
  5192. Type *retType = F->getReturnType();
  5193. std::vector<Value *> FlatRetList;
  5194. std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
  5195. const bool bForParamFalse = false;
  5196. // Split and change to out parameter.
  5197. if (!retType->isVoidTy()) {
  5198. Value *retValAddr = AllocaBuilder.CreateAlloca(retType);
  5199. // Create DbgDecl for the ret value.
  5200. if (DISubprogram *funcDI = getDISubprogram(F)) {
  5201. DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
  5202. DITypeIdentifierMap EmptyMap;
  5203. DIType * RetDIType = RetDITyRef.resolve(EmptyMap);
  5204. DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
  5205. DILocalVariable *RetVar = DIB.createLocalVariable(llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI, F->getName().str() + ".Ret", funcDI->getFile(),
  5206. funcDI->getLine(), RetDIType);
  5207. DIExpression *Expr = nullptr;
  5208. // TODO: how to get col?
  5209. DILocation *DL = DILocation::get(F->getContext(), funcDI->getLine(), 0, funcDI);
  5210. DIB.insertDeclare(retValAddr, RetVar, Expr, DL, CI);
  5211. }
  5212. DxilParameterAnnotation &retAnnotation = funcAnnotation->GetRetTypeAnnotation();
  5213. // Load ret value and replace CI.
  5214. Value *newRetVal = nullptr;
  5215. if (!retAnnotation.HasMatrixAnnotation()) {
  5216. newRetVal = RetBuilder.CreateLoad(retValAddr);
  5217. } else {
  5218. bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
  5219. MatrixOrientation::RowMajor;
  5220. unsigned opcode =
  5221. static_cast<unsigned>(isRowMajor ? HLMatLoadStoreOpcode::RowMatLoad
  5222. : HLMatLoadStoreOpcode::ColMatLoad);
  5223. newRetVal = HLModule::EmitHLOperationCall(RetBuilder, HLOpcodeGroup::HLMatLoadStore,
  5224. opcode, retType, {retValAddr},
  5225. *m_pHLModule->GetModule());
  5226. if (!isRowMajor) {
  5227. // ColMatLoad will return a col major.
  5228. // Matrix value should be row major.
  5229. // Cast it here.
  5230. newRetVal = HLModule::EmitHLOperationCall(
  5231. RetBuilder, HLOpcodeGroup::HLCast,
  5232. static_cast<unsigned>(HLCastOpcode::ColMatrixToRowMatrix), retType,
  5233. {newRetVal}, *m_pHLModule->GetModule());
  5234. }
  5235. }
  5236. CI->replaceAllUsesWith(newRetVal);
5237. // Flatten the return value.
  5238. flattenArgument(flatF, retValAddr, bForParamFalse,
  5239. funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
  5240. FlatRetAnnotationList, AllocaBuilder,
  5241. /*DbgDeclareInst*/ nullptr);
  5242. }
  5243. std::vector<Value *> args;
  5244. for (auto &arg : CI->arg_operands()) {
  5245. args.emplace_back(arg.get());
  5246. }
  5247. // Remove CI from user of args.
  5248. CI->dropAllReferences();
  5249. // Add all argument to worklist.
  5250. for (unsigned i=0;i<args.size();i++) {
  5251. DxilParameterAnnotation &paramAnnotation =
  5252. funcAnnotation->GetParameterAnnotation(i);
  5253. Value *arg = args[i];
  5254. Type *Ty = arg->getType();
  5255. if (Ty->isPointerTy()) {
5256. // For a pointer argument, alloca a temp of the pointee type and replace it in CI.
  5257. Value *tempArg =
  5258. AllocaBuilder.CreateAlloca(arg->getType()->getPointerElementType());
  5259. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  5260. // TODO: support special InputQual like InputPatch.
  5261. if (inputQual == DxilParamInputQual::In ||
  5262. inputQual == DxilParamInputQual::Inout) {
  5263. // Copy in param.
  5264. llvm::SmallVector<llvm::Value *, 16> idxList;
  5265. // split copy to avoid load of struct.
  5266. SplitCpy(Ty, tempArg, arg, idxList, CallBuilder, typeSys,
  5267. &paramAnnotation);
  5268. }
  5269. if (inputQual == DxilParamInputQual::Out ||
  5270. inputQual == DxilParamInputQual::Inout) {
  5271. // Copy out param.
  5272. llvm::SmallVector<llvm::Value *, 16> idxList;
  5273. // split copy to avoid load of struct.
  5274. SplitCpy(Ty, arg, tempArg, idxList, RetBuilder, typeSys,
  5275. &paramAnnotation);
  5276. }
  5277. arg = tempArg;
  5278. flattenArgument(flatF, arg, bForParamFalse, paramAnnotation,
  5279. FlatParamList, FlatParamAnnotationList, AllocaBuilder,
  5280. /*DbgDeclareInst*/ nullptr);
  5281. } else {
5282. // Split a vector argument into its scalar elements.
  5283. if (Ty->isVectorTy()) {
  5284. unsigned vecSize = Ty->getVectorNumElements();
  5285. for (unsigned vi = 0; vi < vecSize; vi++) {
  5286. Value *Elt = CallBuilder.CreateExtractElement(arg, vi);
  5287. // Cannot SROA, save it to final parameter list.
  5288. FlatParamList.emplace_back(Elt);
  5289. // Create ParamAnnotation for V.
  5290. FlatRetAnnotationList.emplace_back(DxilParameterAnnotation());
  5291. DxilParameterAnnotation &flatParamAnnotation =
  5292. FlatRetAnnotationList.back();
  5293. flatParamAnnotation = paramAnnotation;
  5294. }
  5295. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  5296. unsigned col, row;
  5297. Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
  5298. Value *Mat = arg;
  5299. // Cast matrix to array.
  5300. Type *AT = ArrayType::get(EltTy, col * row);
  5301. arg = AllocaBuilder.CreateAlloca(AT);
  5302. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  5303. castParamMap[arg] = std::make_pair(Mat, inputQual);
  5304. DXASSERT(paramAnnotation.HasMatrixAnnotation(),
  5305. "need matrix annotation here");
  5306. if (paramAnnotation.GetMatrixAnnotation().Orientation ==
  5307. hlsl::MatrixOrientation::RowMajor) {
  5308. castRowMajorParamMap.insert(arg);
  5309. }
  5310. // Cannot SROA, save it to final parameter list.
  5311. FlatParamList.emplace_back(arg);
  5312. // Create ParamAnnotation for V.
  5313. FlatRetAnnotationList.emplace_back(DxilParameterAnnotation());
  5314. DxilParameterAnnotation &flatParamAnnotation =
  5315. FlatRetAnnotationList.back();
  5316. flatParamAnnotation = paramAnnotation;
  5317. } else {
  5318. // Cannot SROA, save it to final parameter list.
  5319. FlatParamList.emplace_back(arg);
  5320. // Create ParamAnnotation for V.
  5321. FlatRetAnnotationList.emplace_back(DxilParameterAnnotation());
  5322. DxilParameterAnnotation &flatParamAnnotation =
  5323. FlatRetAnnotationList.back();
  5324. flatParamAnnotation = paramAnnotation;
  5325. }
  5326. }
  5327. }
5328. // Always convert the return value into an out parameter.
5329. // This way there is no need to special-case the return value when generating storeOutput.
  5330. if (FlatRetList.size() ||
  5331. // For empty struct return type.
  5332. !retType->isVoidTy()) {
5333. // Merge return data into param data.
  5334. FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(), FlatRetList.end());
  5335. FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
  5336. FlatRetAnnotationList.begin(),
  5337. FlatRetAnnotationList.end());
  5338. }
  5339. RetBuilder.SetInsertPoint(CI->getNextNode());
  5340. unsigned paramSize = FlatParamList.size();
  5341. for (unsigned i = 0; i < paramSize; i++) {
  5342. Value *&flatArg = FlatParamList[i];
  5343. if (castParamMap.count(flatArg)) {
  5344. replaceCastArgument(flatArg, castParamMap[flatArg].first,
  5345. castParamMap[flatArg].second, CallBuilder,
  5346. RetBuilder);
  5347. if (vectorEltsMap.count(flatArg) && !flatArg->getType()->isPointerTy()) {
  5348. // Vector elements need to be updated.
  5349. SmallVector<Value *, 4> &elts = vectorEltsMap[flatArg];
  5350. // Back one step.
  5351. --i;
  5352. for (Value *elt : elts) {
  5353. FlatParamList[++i] = elt;
  5354. }
  5355. }
  5356. }
  5357. }
  5358. CallInst *NewCI = CallBuilder.CreateCall(flatF, FlatParamList);
  5359. CallBuilder.SetInsertPoint(NewCI);
  5360. CI->eraseFromParent();
  5361. }
  5362. void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
  5363. // Update entry function.
  5364. if (F == m_pHLModule->GetEntryFunction()) {
  5365. m_pHLModule->SetEntryFunction(flatF);
  5366. }
  5367. // Update patch constant function.
  5368. if (m_pHLModule->HasHLFunctionProps(flatF)) {
  5369. HLFunctionProps &funcProps = m_pHLModule->GetHLFunctionProps(flatF);
  5370. if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
  5371. Function *oldPatchConstantFunc =
  5372. funcProps.ShaderProps.HS.patchConstantFunc;
  5373. if (funcMap.count(oldPatchConstantFunc))
  5374. funcProps.ShaderProps.HS.patchConstantFunc =
  5375. funcMap[oldPatchConstantFunc];
  5376. }
  5377. }
  5378. // TODO: flatten vector argument and lower resource argument when flatten
  5379. // functions.
  5380. for (auto it = F->user_begin(); it != F->user_end(); ) {
  5381. CallInst *CI = cast<CallInst>(*(it++));
  5382. createFlattenedFunctionCall(F, flatF, CI);
  5383. }
  5384. }
  5385. // Public interface to the SROA_Parameter_HLSL pass
  5386. ModulePass *llvm::createSROA_Parameter_HLSL() {
  5387. return new SROA_Parameter_HLSL();
  5388. }
  5389. //===----------------------------------------------------------------------===//
  5390. // Lower static global into Alloca.
  5391. //===----------------------------------------------------------------------===//
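// Rough idea (illustrative): a non-aggregate static global such as
//   static float gCounter;
// that is stored to from exactly one function is replaced by an alloca in that
// function's entry block, with any non-undef initializer turned into a store.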
  5392. namespace {
  5393. class LowerStaticGlobalIntoAlloca : public ModulePass {
  5394. HLModule *m_pHLModule;
  5395. public:
  5396. static char ID; // Pass identification, replacement for typeid
  5397. explicit LowerStaticGlobalIntoAlloca() : ModulePass(ID) {}
  5398. const char *getPassName() const override { return "Lower static global into Alloca"; }
  5399. bool runOnModule(Module &M) override {
  5400. m_pHLModule = &M.GetOrCreateHLModule();
  5401. // Lower static global into allocas.
  5402. std::vector<GlobalVariable *> staticGVs;
  5403. for (GlobalVariable &GV : M.globals()) {
  5404. bool isStaticGlobal =
  5405. HLModule::IsStaticGlobal(&GV) &&
  5406. GV.getType()->getAddressSpace() == DXIL::kDefaultAddrSpace;
  5407. if (isStaticGlobal &&
  5408. !GV.getType()->getElementType()->isAggregateType()) {
  5409. staticGVs.emplace_back(&GV);
  5410. }
  5411. }
  5412. bool bUpdated = false;
  5413. const DataLayout &DL = M.getDataLayout();
  5414. for (GlobalVariable *GV : staticGVs) {
  5415. bUpdated |= lowerStaticGlobalIntoAlloca(GV, DL);
  5416. }
  5417. return bUpdated;
  5418. }
  5419. private:
  5420. bool lowerStaticGlobalIntoAlloca(GlobalVariable *GV, const DataLayout &DL);
  5421. };
  5422. }
  5423. bool LowerStaticGlobalIntoAlloca::lowerStaticGlobalIntoAlloca(GlobalVariable *GV, const DataLayout &DL) {
  5424. DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  5425. unsigned size = DL.getTypeAllocSize(GV->getType()->getElementType());
  5426. PointerStatus PS(size);
  5427. GV->removeDeadConstantUsers();
  5428. PS.analyzePointer(GV, PS, typeSys, /*bStructElt*/ false);
  5429. bool NotStored = (PS.StoredType == PointerStatus::NotStored) ||
  5430. (PS.StoredType == PointerStatus::InitializerStored);
5431. // Make sure GV is only used within one function.
5432. // Skip GVs that are never stored to.
  5433. if (PS.HasMultipleAccessingFunctions || NotStored)
  5434. return false;
  5435. Function *F = const_cast<Function*>(PS.AccessingFunction);
  5436. IRBuilder<> Builder(F->getEntryBlock().getFirstInsertionPt());
  5437. AllocaInst *AI = Builder.CreateAlloca(GV->getType()->getElementType());
5438. // Store the initializer if one exists.
  5439. if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
  5440. Builder.CreateStore(GV->getInitializer(), GV);
  5441. }
  5442. ReplaceConstantWithInst(GV, AI, Builder);
  5443. GV->eraseFromParent();
  5444. return true;
  5445. }
  5446. char LowerStaticGlobalIntoAlloca::ID = 0;
  5447. INITIALIZE_PASS(LowerStaticGlobalIntoAlloca, "static-global-to-alloca",
  5448. "Lower static global into Alloca", false,
  5449. false)
  5450. // Public interface to the LowerStaticGlobalIntoAlloca pass
  5451. ModulePass *llvm::createLowerStaticGlobalIntoAlloca() {
  5452. return new LowerStaticGlobalIntoAlloca();
  5453. }
  5454. //===----------------------------------------------------------------------===//
  5455. // Lower one type to another type.
  5456. //===----------------------------------------------------------------------===//
  5457. namespace {
  5458. class LowerTypePass : public ModulePass {
  5459. public:
  5460. explicit LowerTypePass(char &ID)
  5461. : ModulePass(ID) {}
  5462. bool runOnModule(Module &M) override;
  5463. private:
  5464. bool runOnFunction(Function &F, bool HasDbgInfo);
  5465. AllocaInst *lowerAlloca(AllocaInst *A);
  5466. GlobalVariable *lowerInternalGlobal(GlobalVariable *GV);
  5467. protected:
  5468. virtual bool needToLower(Value *V) = 0;
  5469. virtual void lowerUseWithNewValue(Value *V, Value *NewV) = 0;
  5470. virtual Type *lowerType(Type *Ty) = 0;
  5471. virtual Constant *lowerInitVal(Constant *InitVal, Type *NewTy) = 0;
  5472. virtual StringRef getGlobalPrefix() = 0;
  5473. virtual void initialize(Module &M) {};
  5474. };
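// LowerTypePass is a small template-method style base: subclasses decide which
// values need lowering (needToLower), what the replacement type is (lowerType),
// how uses are rewritten (lowerUseWithNewValue), and how initializers are
// translated (lowerInitVal); the base pass walks entry-block allocas and
// internal globals and drives those hooks, appending getGlobalPrefix() to the
// names of replacement globals.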
  5475. AllocaInst *LowerTypePass::lowerAlloca(AllocaInst *A) {
  5476. IRBuilder<> Builder(A);
  5477. Type *NewTy = lowerType(A->getAllocatedType());
  5478. return Builder.CreateAlloca(NewTy);
  5479. }
  5480. GlobalVariable *LowerTypePass::lowerInternalGlobal(GlobalVariable *GV) {
  5481. Type *NewTy = lowerType(GV->getType()->getPointerElementType());
5482. // Default the init val to undef.
  5483. Constant *InitVal = UndefValue::get(NewTy);
  5484. if (GV->hasInitializer()) {
  5485. Constant *OldInitVal = GV->getInitializer();
  5486. if (isa<ConstantAggregateZero>(OldInitVal))
  5487. InitVal = ConstantAggregateZero::get(NewTy);
  5488. else if (!isa<UndefValue>(OldInitVal)) {
  5489. InitVal = lowerInitVal(OldInitVal, NewTy);
  5490. }
  5491. }
  5492. bool isConst = GV->isConstant();
  5493. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  5494. unsigned AddressSpace = GV->getType()->getAddressSpace();
  5495. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  5496. Module *M = GV->getParent();
  5497. GlobalVariable *NewGV = new llvm::GlobalVariable(
  5498. *M, NewTy, /*IsConstant*/ isConst, linkage,
  5499. /*InitVal*/ InitVal, GV->getName() + getGlobalPrefix(),
  5500. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  5501. return NewGV;
  5502. }
  5503. bool LowerTypePass::runOnFunction(Function &F, bool HasDbgInfo) {
  5504. std::vector<AllocaInst *> workList;
  5505. // Scan the entry basic block, adding allocas to the worklist.
  5506. BasicBlock &BB = F.getEntryBlock();
  5507. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
  5508. if (!isa<AllocaInst>(I))
  5509. continue;
  5510. AllocaInst *A = cast<AllocaInst>(I);
  5511. if (needToLower(A))
  5512. workList.emplace_back(A);
  5513. }
  5514. LLVMContext &Context = F.getContext();
  5515. for (AllocaInst *A : workList) {
  5516. AllocaInst *NewA = lowerAlloca(A);
  5517. if (HasDbgInfo) {
  5518. // Add debug info.
  5519. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(A);
  5520. if (DDI) {
  5521. Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
  5522. Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
  5523. Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(NewA));
  5524. IRBuilder<> debugBuilder(DDI);
  5525. debugBuilder.CreateCall(DDI->getCalledFunction(),
  5526. {VMD, DDIVar, DDIExp});
  5527. }
  5528. }
  5529. // Replace users.
  5530. lowerUseWithNewValue(A, NewA);
  5531. // Remove alloca.
  5532. A->eraseFromParent();
  5533. }
  5534. return true;
  5535. }
  5536. bool LowerTypePass::runOnModule(Module &M) {
  5537. initialize(M);
  5538. // Load up debug information, to cross-reference values and the instructions
  5539. // used to load them.
  5540. bool HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  5541. llvm::DebugInfoFinder Finder;
  5542. if (HasDbgInfo) {
  5543. Finder.processModule(M);
  5544. }
  5545. std::vector<AllocaInst*> multiDimAllocas;
  5546. for (Function &F : M.functions()) {
  5547. if (F.isDeclaration())
  5548. continue;
  5549. runOnFunction(F, HasDbgInfo);
  5550. }
  5551. // Work on internal global.
  5552. std::vector<GlobalVariable *> vecGVs;
  5553. for (GlobalVariable &GV : M.globals()) {
  5554. if (HLModule::IsStaticGlobal(&GV) || HLModule::IsSharedMemoryGlobal(&GV)) {
  5555. if (needToLower(&GV) && !GV.user_empty())
  5556. vecGVs.emplace_back(&GV);
  5557. }
  5558. }
  5559. for (GlobalVariable *GV : vecGVs) {
  5560. GlobalVariable *NewGV = lowerInternalGlobal(GV);
  5561. // Add debug info.
  5562. if (HasDbgInfo) {
  5563. HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, NewGV);
  5564. }
  5565. // Replace users.
  5566. lowerUseWithNewValue(GV, NewGV);
  5567. // Remove GV.
  5568. GV->removeDeadConstantUsers();
  5569. GV->eraseFromParent();
  5570. }
  5571. return true;
  5572. }
  5573. }
  5574. //===----------------------------------------------------------------------===//
  5575. // DynamicIndexingVector to Array.
  5576. //===----------------------------------------------------------------------===//
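// Rough idea (illustrative): a vector that is subscripted with a dynamic index,
// e.g.
//   float4 v; ... v[i] = x;   // i not a compile-time constant
// has its <4 x float> storage replaced by [4 x float] so the subscript can be
// expressed as a GEP; vectors with only static indexing are left as vectors
// unless ReplaceAllVectors is set.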
  5577. namespace {
  5578. class DynamicIndexingVectorToArray : public LowerTypePass {
  5579. bool ReplaceAllVectors;
  5580. public:
  5581. explicit DynamicIndexingVectorToArray(bool ReplaceAll = false)
  5582. : LowerTypePass(ID), ReplaceAllVectors(ReplaceAll) {}
  5583. static char ID; // Pass identification, replacement for typeid
  5584. void applyOptions(PassOptions O) override;
  5585. void dumpConfig(raw_ostream &OS) override;
  5586. protected:
  5587. bool needToLower(Value *V) override;
  5588. void lowerUseWithNewValue(Value *V, Value *NewV) override;
  5589. Type *lowerType(Type *Ty) override;
  5590. Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
  5591. StringRef getGlobalPrefix() override { return ".v"; }
  5592. private:
  5593. bool HasVectorDynamicIndexing(Value *V);
  5594. void ReplaceVecGEP(Value *GEP, ArrayRef<Value *> idxList, Value *A,
  5595. IRBuilder<> &Builder);
  5596. void ReplaceVecArrayGEP(Value *GEP, ArrayRef<Value *> idxList, Value *A,
  5597. IRBuilder<> &Builder);
  5598. void ReplaceVectorWithArray(Value *Vec, Value *Array);
  5599. void ReplaceVectorArrayWithArray(Value *VecArray, Value *Array);
  5600. void ReplaceStaticIndexingOnVector(Value *V);
  5601. };
  5602. void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
  5603. GetPassOptionBool(O, "ReplaceAllVectors", &ReplaceAllVectors,
  5604. ReplaceAllVectors);
  5605. }
  5606. void DynamicIndexingVectorToArray::dumpConfig(raw_ostream &OS) {
  5607. ModulePass::dumpConfig(OS);
  5608. OS << ",ReplaceAllVectors=" << ReplaceAllVectors;
  5609. }
  5610. void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
  5611. for (auto U = V->user_begin(), E = V->user_end(); U != E;) {
  5612. Value *User = *(U++);
  5613. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  5614. // Only work on element access for vector.
  5615. if (GEP->getNumOperands() == 3) {
  5616. auto Idx = GEP->idx_begin();
  5617. // Skip the pointer idx.
  5618. Idx++;
  5619. ConstantInt *constIdx = cast<ConstantInt>(Idx);
  5620. for (auto GEPU = GEP->user_begin(), GEPE = GEP->user_end();
  5621. GEPU != GEPE;) {
  5622. Instruction *GEPUser = cast<Instruction>(*(GEPU++));
  5623. IRBuilder<> Builder(GEPUser);
  5624. if (LoadInst *ldInst = dyn_cast<LoadInst>(GEPUser)) {
  5625. // Change
  5626. // ld a->x
  5627. // into
  5628. // b = ld a
  5629. // b.x
  5630. Value *ldVal = Builder.CreateLoad(V);
  5631. Value *Elt = Builder.CreateExtractElement(ldVal, constIdx);
  5632. ldInst->replaceAllUsesWith(Elt);
  5633. ldInst->eraseFromParent();
  5634. } else {
  5635. // Change
  5636. // st val, a->x
  5637. // into
  5638. // tmp = ld a
  5639. // tmp.x = val
  5640. // st tmp, a
  5641. // Must be store inst here.
  5642. StoreInst *stInst = cast<StoreInst>(GEPUser);
  5643. Value *val = stInst->getValueOperand();
  5644. Value *ldVal = Builder.CreateLoad(V);
  5645. ldVal = Builder.CreateInsertElement(ldVal, val, constIdx);
  5646. Builder.CreateStore(ldVal, V);
  5647. stInst->eraseFromParent();
  5648. }
  5649. }
  5650. GEP->eraseFromParent();
  5651. } else if (GEP->getNumIndices() == 1) {
  5652. Value *Idx = *GEP->idx_begin();
  5653. if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
  5654. if (C->getLimitedValue() == 0) {
  5655. GEP->replaceAllUsesWith(V);
  5656. GEP->eraseFromParent();
  5657. }
  5658. }
  5659. }
  5660. }
  5661. }
  5662. }
  5663. bool DynamicIndexingVectorToArray::needToLower(Value *V) {
  5664. Type *Ty = V->getType()->getPointerElementType();
  5665. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  5666. if (isa<GlobalVariable>(V) || ReplaceAllVectors) {
  5667. return true;
  5668. }
5669. // Don't lower local vectors that only use static indexing.
  5670. if (HasVectorDynamicIndexing(V)) {
  5671. return true;
  5672. } else {
5673. // Rewrite static vector indexing as whole-vector load/store.
  5674. ReplaceStaticIndexingOnVector(V);
  5675. return false;
  5676. }
  5677. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  5678. // Array must be replaced even without dynamic indexing to remove vector
  5679. // type in dxil.
  5680. // TODO: optimize static array index in later pass.
  5681. Type *EltTy = HLModule::GetArrayEltTy(AT);
  5682. return isa<VectorType>(EltTy);
  5683. }
  5684. return false;
  5685. }
  5686. void DynamicIndexingVectorToArray::ReplaceVecGEP(Value *GEP, ArrayRef<Value *> idxList,
  5687. Value *A, IRBuilder<> &Builder) {
  5688. Value *newGEP = Builder.CreateGEP(A, idxList);
  5689. if (GEP->getType()->getPointerElementType()->isVectorTy()) {
  5690. ReplaceVectorWithArray(GEP, newGEP);
  5691. } else {
  5692. GEP->replaceAllUsesWith(newGEP);
  5693. }
  5694. }
  5695. void DynamicIndexingVectorToArray::ReplaceVectorWithArray(Value *Vec, Value *A) {
  5696. unsigned size = Vec->getType()->getPointerElementType()->getVectorNumElements();
  5697. for (auto U = Vec->user_begin(); U != Vec->user_end();) {
  5698. User *User = (*U++);
  5699. // GlobalVariable user.
  5700. if (isa<ConstantExpr>(User)) {
  5701. if (User->user_empty())
  5702. continue;
  5703. if (GEPOperator *GEP = dyn_cast<GEPOperator>(User)) {
  5704. IRBuilder<> Builder(Vec->getContext());
  5705. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  5706. ReplaceVecGEP(GEP, idxList, A, Builder);
  5707. continue;
  5708. }
  5709. }
5710. // Instruction user.
  5711. Instruction *UserInst = cast<Instruction>(User);
  5712. IRBuilder<> Builder(UserInst);
  5713. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  5714. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  5715. ReplaceVecGEP(cast<GEPOperator>(GEP), idxList, A, Builder);
  5716. GEP->eraseFromParent();
  5717. } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User)) {
5718. // If loading the whole vector, split the load into per-element loads.
  5719. Value *newLd = UndefValue::get(ldInst->getType());
  5720. Value *zero = Builder.getInt32(0);
  5721. for (unsigned i = 0; i < size; i++) {
  5722. Value *idx = Builder.getInt32(i);
  5723. Value *GEP = Builder.CreateInBoundsGEP(A, {zero, idx});
  5724. Value *Elt = Builder.CreateLoad(GEP);
  5725. newLd = Builder.CreateInsertElement(newLd, Elt, i);
  5726. }
  5727. ldInst->replaceAllUsesWith(newLd);
  5728. ldInst->eraseFromParent();
  5729. } else if (StoreInst *stInst = dyn_cast<StoreInst>(User)) {
  5730. Value *val = stInst->getValueOperand();
  5731. Value *zero = Builder.getInt32(0);
  5732. for (unsigned i = 0; i < size; i++) {
  5733. Value *Elt = Builder.CreateExtractElement(val, i);
  5734. Value *idx = Builder.getInt32(i);
  5735. Value *GEP = Builder.CreateInBoundsGEP(A, {zero, idx});
  5736. Builder.CreateStore(Elt, GEP);
  5737. }
  5738. stInst->eraseFromParent();
  5739. } else {
5740. // Vector parameters should already be lowered;
5741. // no function call should take a vector here.
5742. DXASSERT(0, "not implemented yet");
  5743. }
  5744. }
  5745. }
  5746. void DynamicIndexingVectorToArray::ReplaceVecArrayGEP(Value *GEP,
  5747. ArrayRef<Value *> idxList, Value *A,
  5748. IRBuilder<> &Builder) {
  5749. Value *newGEP = Builder.CreateGEP(A, idxList);
  5750. Type *Ty = GEP->getType()->getPointerElementType();
  5751. if (Ty->isVectorTy()) {
  5752. ReplaceVectorWithArray(GEP, newGEP);
  5753. } else if (Ty->isArrayTy()) {
  5754. ReplaceVectorArrayWithArray(GEP, newGEP);
  5755. } else {
  5756. DXASSERT(Ty->isSingleValueType(), "must be vector subscript here");
  5757. GEP->replaceAllUsesWith(newGEP);
  5758. }
  5759. }
  5760. void DynamicIndexingVectorToArray::ReplaceVectorArrayWithArray(Value *VA, Value *A) {
  5761. for (auto U = VA->user_begin(); U != VA->user_end();) {
  5762. User *User = *(U++);
  5763. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  5764. IRBuilder<> Builder(GEP);
  5765. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  5766. ReplaceVecArrayGEP(GEP, idxList, A, Builder);
  5767. GEP->eraseFromParent();
  5768. } else if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(User)) {
  5769. IRBuilder<> Builder(GEPOp->getContext());
  5770. SmallVector<Value *, 4> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
  5771. ReplaceVecArrayGEP(GEPOp, idxList, A, Builder);
  5772. } else {
5773. DXASSERT(0, "Array pointer should only be used by GEP");
  5774. }
  5775. }
  5776. }
  5777. void DynamicIndexingVectorToArray::lowerUseWithNewValue(Value *V, Value *NewV) {
  5778. Type *Ty = V->getType()->getPointerElementType();
  5779. // Replace V with NewV.
  5780. if (Ty->isVectorTy()) {
  5781. ReplaceVectorWithArray(V, NewV);
  5782. } else {
  5783. ReplaceVectorArrayWithArray(V, NewV);
  5784. }
  5785. }
  5786. Type *DynamicIndexingVectorToArray::lowerType(Type *Ty) {
  5787. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  5788. return ArrayType::get(VT->getElementType(), VT->getNumElements());
  5789. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  5790. SmallVector<ArrayType *, 4> nestArrayTys;
  5791. nestArrayTys.emplace_back(AT);
  5792. Type *EltTy = AT->getElementType();
5793. // Support multiple levels of array nesting.
  5794. while (EltTy->isArrayTy()) {
  5795. ArrayType *ElAT = cast<ArrayType>(EltTy);
  5796. nestArrayTys.emplace_back(ElAT);
  5797. EltTy = ElAT->getElementType();
  5798. }
  5799. if (EltTy->isVectorTy()) {
  5800. Type *vecAT = ArrayType::get(EltTy->getVectorElementType(),
  5801. EltTy->getVectorNumElements());
  5802. return CreateNestArrayTy(vecAT, nestArrayTys);
  5803. }
  5804. return nullptr;
  5805. }
  5806. return nullptr;
  5807. }
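// E.g. (illustrative), lowerType turns [3 x [2 x <4 x float>]] into
// [3 x [2 x [4 x float]]]: only the innermost vector is rewritten and the
// outer array nesting is rebuilt around it by CreateNestArrayTy.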
  5808. Constant *DynamicIndexingVectorToArray::lowerInitVal(Constant *InitVal, Type *NewTy) {
  5809. Type *VecTy = InitVal->getType();
  5810. ArrayType *ArrayTy = cast<ArrayType>(NewTy);
  5811. if (VecTy->isVectorTy()) {
  5812. SmallVector<Constant *, 4> Elts;
  5813. for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
  5814. Elts.emplace_back(InitVal->getAggregateElement(i));
  5815. }
  5816. return ConstantArray::get(ArrayTy, Elts);
  5817. } else {
  5818. ArrayType *AT = cast<ArrayType>(VecTy);
  5819. ArrayType *EltArrayTy = cast<ArrayType>(ArrayTy->getElementType());
  5820. SmallVector<Constant *, 4> Elts;
  5821. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  5822. Constant *Elt = lowerInitVal(InitVal->getAggregateElement(i), EltArrayTy);
  5823. Elts.emplace_back(Elt);
  5824. }
  5825. return ConstantArray::get(ArrayTy, Elts);
  5826. }
  5827. }
  5828. bool DynamicIndexingVectorToArray::HasVectorDynamicIndexing(Value *V) {
  5829. for (auto User : V->users()) {
  5830. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  5831. for (auto Idx = GEP->idx_begin(); Idx != GEP->idx_end(); ++Idx) {
  5832. if (!isa<ConstantInt>(Idx))
  5833. return true;
  5834. }
  5835. }
  5836. }
  5837. return false;
  5838. }
  5839. }
  5840. char DynamicIndexingVectorToArray::ID = 0;
  5841. INITIALIZE_PASS(DynamicIndexingVectorToArray, "dynamic-vector-to-array",
  5842. "Replace dynamic indexing vector with array", false,
  5843. false)
  5844. // Public interface to the DynamicIndexingVectorToArray pass
  5845. ModulePass *llvm::createDynamicIndexingVectorToArrayPass(bool ReplaceAllVector) {
  5846. return new DynamicIndexingVectorToArray(ReplaceAllVector);
  5847. }
  5848. //===----------------------------------------------------------------------===//
  5849. // Flatten multi dim array into 1 dim.
  5850. //===----------------------------------------------------------------------===//
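// Rough idea (illustrative): a multi-dimensional array such as
//   float a[3][4];
// is replaced by a one-dimensional [12 x float], and an access a[i][j] is
// re-indexed as i*4 + j when its GEPs are rewritten.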
  5851. namespace {
  5852. class MultiDimArrayToOneDimArray : public LowerTypePass {
  5853. public:
  5854. explicit MultiDimArrayToOneDimArray() : LowerTypePass(ID) {}
  5855. static char ID; // Pass identification, replacement for typeid
  5856. protected:
  5857. bool needToLower(Value *V) override;
  5858. void lowerUseWithNewValue(Value *V, Value *NewV) override;
  5859. Type *lowerType(Type *Ty) override;
  5860. Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
  5861. StringRef getGlobalPrefix() override { return ".1dim"; }
  5862. };
  5863. bool MultiDimArrayToOneDimArray::needToLower(Value *V) {
  5864. Type *Ty = V->getType()->getPointerElementType();
  5865. ArrayType *AT = dyn_cast<ArrayType>(Ty);
  5866. if (!AT)
  5867. return false;
  5868. if (!isa<ArrayType>(AT->getElementType())) {
  5869. return false;
  5870. } else {
  5871. // Merge all GEP.
  5872. HLModule::MergeGepUse(V);
  5873. return true;
  5874. }
  5875. }
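// Index linearization sketch (illustrative): for a GEP like
//   getelementptr [3 x [4 x float]]* %a, i32 0, i32 %i, i32 %j
// the loop below folds the array dimensions into one index, idx = %i*4 + %j,
// and emits a GEP on the one-dim array with {0, idx} (plus a trailing vector
// index when the innermost element type is still a vector).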
void ReplaceMultiDimGEP(User *GEP, Value *OneDim, IRBuilder<> &Builder) {
  gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  Value *PtrOffset = GEPIt.getOperand();
  ++GEPIt;
  Value *ArrayIdx = GEPIt.getOperand();
  ++GEPIt;
  Value *VecIdx = nullptr;
  for (; GEPIt != E; ++GEPIt) {
    if (GEPIt->isArrayTy()) {
      unsigned arraySize = GEPIt->getArrayNumElements();
      Value *V = GEPIt.getOperand();
      ArrayIdx = Builder.CreateMul(ArrayIdx, Builder.getInt32(arraySize));
      ArrayIdx = Builder.CreateAdd(V, ArrayIdx);
    } else {
      DXASSERT_NOMSG(isa<VectorType>(*GEPIt));
      VecIdx = GEPIt.getOperand();
    }
  }

  Value *NewGEP = nullptr;
  if (!VecIdx)
    NewGEP = Builder.CreateGEP(OneDim, {PtrOffset, ArrayIdx});
  else
    NewGEP = Builder.CreateGEP(OneDim, {PtrOffset, ArrayIdx, VecIdx});

  GEP->replaceAllUsesWith(NewGEP);
}
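// Replace every GEP use of the multi-dim pointer with an equivalent GEP on
// the one-dim replacement; GEP instructions are erased afterwards, while
// GEPOperator constant expressions are only rewritten.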
void MultiDimArrayToOneDimArray::lowerUseWithNewValue(Value *MultiDim,
                                                      Value *OneDim) {
  LLVMContext &Context = MultiDim->getContext();
  // All users should address an element, so they must be GEPs.
  // Replace the users of the multi-dim value.
  for (auto it = MultiDim->user_begin(); it != MultiDim->user_end();) {
    User *U = *(it++);
    if (U->user_empty())
      continue;
    GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U);
    if (!GEP) {
      DXASSERT_NOMSG(isa<GEPOperator>(U));
      // NewGEP must be a GEPOperator too; no instruction will be built.
      IRBuilder<> Builder(Context);
      ReplaceMultiDimGEP(U, OneDim, Builder);
    } else {
      IRBuilder<> Builder(GEP);
      ReplaceMultiDimGEP(U, OneDim, Builder);
    }
    if (GEP)
      GEP->eraseFromParent();
  }
}
Type *MultiDimArrayToOneDimArray::lowerType(Type *Ty) {
  ArrayType *AT = cast<ArrayType>(Ty);
  unsigned arraySize = AT->getNumElements();
  Type *EltTy = AT->getElementType();
  // Support multiple levels of nested arrays.
  while (EltTy->isArrayTy()) {
    ArrayType *ElAT = cast<ArrayType>(EltTy);
    arraySize *= ElAT->getNumElements();
    EltTy = ElAT->getElementType();
  }
  return ArrayType::get(EltTy, arraySize);
}
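// Collect the leaf constants of a nested constant array in row-major order,
// e.g. [[1, 2], [3, 4]] yields 1, 2, 3, 4.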
void FlattenMultiDimConstArray(Constant *V, std::vector<Constant *> &Elts) {
  if (!V->getType()->isArrayTy()) {
    Elts.emplace_back(V);
  } else {
    ArrayType *AT = cast<ArrayType>(V->getType());
    for (unsigned i = 0; i < AT->getNumElements(); i++) {
      FlattenMultiDimConstArray(V->getAggregateElement(i), Elts);
    }
  }
}
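// Lower the initializer to the flattened type: zeroinitializer and undef map
// directly, anything else is rebuilt from its flattened elements, and a
// missing initializer becomes undef.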
Constant *MultiDimArrayToOneDimArray::lowerInitVal(Constant *InitVal,
                                                   Type *NewTy) {
  if (InitVal) {
    // MultiDim array init should be done by store.
    if (isa<ConstantAggregateZero>(InitVal))
      InitVal = ConstantAggregateZero::get(NewTy);
    else if (isa<UndefValue>(InitVal))
      InitVal = UndefValue::get(NewTy);
    else {
      std::vector<Constant *> Elts;
      FlattenMultiDimConstArray(InitVal, Elts);
      InitVal = ConstantArray::get(cast<ArrayType>(NewTy), Elts);
    }
  } else {
    InitVal = UndefValue::get(NewTy);
  }
  return InitVal;
}

} // namespace
char MultiDimArrayToOneDimArray::ID = 0;

INITIALIZE_PASS(MultiDimArrayToOneDimArray, "multi-dim-one-dim",
                "Flatten multi-dim array into one-dim array", false, false)

// Public interface to the MultiDimArrayToOneDimArray pass
ModulePass *llvm::createMultiDimArrayToOneDimArrayPass() {
  return new MultiDimArrayToOneDimArray();
}

//===----------------------------------------------------------------------===//
// Lower resource into handle.
//===----------------------------------------------------------------------===//
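// Resource-typed globals and allocas (and arrays of them) are retyped to the
// handle type taken from the HL module; resource loads and stores are
// rewritten in terms of HLCreateHandle calls and handle-to-resource casts.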
namespace {

class ResourceToHandle : public LowerTypePass {
public:
  explicit ResourceToHandle() : LowerTypePass(ID) {}
  static char ID; // Pass identification, replacement for typeid

protected:
  bool needToLower(Value *V) override;
  void lowerUseWithNewValue(Value *V, Value *NewV) override;
  Type *lowerType(Type *Ty) override;
  Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
  StringRef getGlobalPrefix() override { return ".res"; }
  void initialize(Module &M) override;

private:
  void ReplaceResourceWithHandle(Value *ResPtr, Value *HandlePtr);
  void ReplaceResourceGEPWithHandleGEP(Value *GEP, ArrayRef<Value *> idxList,
                                       Value *A, IRBuilder<> &Builder);
  void ReplaceResourceArrayWithHandleArray(Value *VA, Value *A);

  Type *m_HandleTy;
  HLModule *m_pHLM;
};

void ResourceToHandle::initialize(Module &M) {
  DXASSERT(M.HasHLModule(), "require HLModule");
  m_pHLM = &M.GetHLModule();
  m_HandleTy = m_pHLM->GetOP()->GetHandleType();
}
bool ResourceToHandle::needToLower(Value *V) {
  Type *Ty = V->getType()->getPointerElementType();
  Ty = HLModule::GetArrayEltTy(Ty);
  return HLModule::IsHLSLObjectType(Ty) && !HLModule::IsStreamOutputType(Ty);
}

Type *ResourceToHandle::lowerType(Type *Ty) {
  if (HLModule::IsHLSLObjectType(Ty) && !HLModule::IsStreamOutputType(Ty)) {
    return m_HandleTy;
  }

  ArrayType *AT = cast<ArrayType>(Ty);
  SmallVector<ArrayType *, 4> nestArrayTys;
  nestArrayTys.emplace_back(AT);

  Type *EltTy = AT->getElementType();
  // Support multiple levels of nested arrays.
  while (EltTy->isArrayTy()) {
    ArrayType *ElAT = cast<ArrayType>(EltTy);
    nestArrayTys.emplace_back(ElAT);
    EltTy = ElAT->getElementType();
  }

  return CreateNestArrayTy(m_HandleTy, nestArrayTys);
}

Constant *ResourceToHandle::lowerInitVal(Constant *InitVal, Type *NewTy) {
  DXASSERT(isa<UndefValue>(InitVal), "resource cannot have real init val");
  return UndefValue::get(NewTy);
}
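// Rewrite direct loads and stores of a resource pointer against the new handle
// pointer: a load of the resource becomes a load of the handle (createHandle
// users take the handle directly, stores of the loaded value go through a
// handle-to-resource cast), and a store of a resource first creates a handle
// from the stored value.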
void ResourceToHandle::ReplaceResourceWithHandle(Value *ResPtr,
                                                 Value *HandlePtr) {
  for (auto it = ResPtr->user_begin(); it != ResPtr->user_end();) {
    User *U = *(it++);
    if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
      IRBuilder<> Builder(LI);
      Value *Handle = Builder.CreateLoad(HandlePtr);
      Type *ResTy = LI->getType();
      // Used by createHandle or Store.
      for (auto ldIt = LI->user_begin(); ldIt != LI->user_end();) {
        User *ldU = *(ldIt++);
        if (StoreInst *SI = dyn_cast<StoreInst>(ldU)) {
          Value *TmpRes = HLModule::EmitHLOperationCall(
              Builder, HLOpcodeGroup::HLCast,
              (unsigned)HLCastOpcode::HandleToResCast, ResTy, {Handle},
              *m_pHLM->GetModule());
          SI->replaceUsesOfWith(LI, TmpRes);
        } else {
          CallInst *CI = cast<CallInst>(ldU);
          HLOpcodeGroup group =
              hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
          DXASSERT(group == HLOpcodeGroup::HLCreateHandle,
                   "must be createHandle");
          CI->replaceAllUsesWith(Handle);
          CI->eraseFromParent();
        }
      }
      LI->eraseFromParent();
    } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
      Value *Res = SI->getValueOperand();
      IRBuilder<> Builder(SI);
      // CreateHandle from Res.
      Value *Handle = HLModule::EmitHLOperationCall(
          Builder, HLOpcodeGroup::HLCreateHandle,
          /*opcode*/ 0, m_HandleTy, {Res}, *m_pHLM->GetModule());
      // Store Handle to HandlePtr.
      Builder.CreateStore(Handle, HandlePtr);
      // Remove resource Store.
      SI->eraseFromParent();
    } else {
      DXASSERT(0, "invalid operation on resource");
    }
  }
}
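// Recreate the GEP on the handle array and recurse into the result: another
// array level goes back through ReplaceResourceArrayWithHandleArray, while a
// leaf resource is handled by ReplaceResourceWithHandle.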
void ResourceToHandle::ReplaceResourceGEPWithHandleGEP(
    Value *GEP, ArrayRef<Value *> idxList, Value *A, IRBuilder<> &Builder) {
  Value *newGEP = Builder.CreateGEP(A, idxList);
  Type *Ty = GEP->getType()->getPointerElementType();
  if (Ty->isArrayTy()) {
    ReplaceResourceArrayWithHandleArray(GEP, newGEP);
  } else {
    DXASSERT(HLModule::IsHLSLObjectType(Ty), "must be resource type here");
    ReplaceResourceWithHandle(GEP, newGEP);
  }
}
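// Rebuild every GEP (instruction or constant expression) on the resource
// array VA against the handle array A; no other use of the array pointer is
// expected.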
void ResourceToHandle::ReplaceResourceArrayWithHandleArray(Value *VA,
                                                           Value *A) {
  for (auto U = VA->user_begin(); U != VA->user_end();) {
    User *User = *(U++);
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      IRBuilder<> Builder(GEP);
      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      ReplaceResourceGEPWithHandleGEP(GEP, idxList, A, Builder);
      GEP->eraseFromParent();
    } else if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(User)) {
      IRBuilder<> Builder(GEPOp->getContext());
      SmallVector<Value *, 4> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
      ReplaceResourceGEPWithHandleGEP(GEPOp, idxList, A, Builder);
    } else {
      DXASSERT(0, "Array pointer should only be used by GEP");
    }
  }
}
void ResourceToHandle::lowerUseWithNewValue(Value *V, Value *NewV) {
  Type *Ty = V->getType()->getPointerElementType();
  // Replace V with NewV.
  if (Ty->isArrayTy()) {
    ReplaceResourceArrayWithHandleArray(V, NewV);
  } else {
    ReplaceResourceWithHandle(V, NewV);
  }
}

} // namespace

char ResourceToHandle::ID = 0;

INITIALIZE_PASS(ResourceToHandle, "resource-handle",
                "Lower resource into handle", false, false)

// Public interface to the ResourceToHandle pass
ModulePass *llvm::createResourceToHandlePass() {
  return new ResourceToHandle();
}