//===- ScalarReplAggregatesHLSL.cpp - Scalar Replacement of Aggregates ----===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
//
// Based on ScalarReplAggregates.cpp. The difference is that the HLSL version
// keeps arrays intact so it can break up every structure.
//
//===----------------------------------------------------------------------===//
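//
// Illustrative sketch of the transformation: a local aggregate such as
//
//   %s = alloca { float, [4 x float] }
//
// is replaced by one alloca per field,
//
//   %s.0 = alloca float
//   %s.1 = alloca [4 x float]
//
// where the [4 x float] element is kept as an array rather than being
// scalarized further.
//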
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/PromoteMemToReg.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HLSL/DxilConstants.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/DxilUtil.h"
#include "dxc/HLSL/DxilModule.h"
#include "dxc/HlslIntrinsicOp.h"
#include "dxc/HLSL/DxilTypeSystem.h"
#include "dxc/HLSL/HLMatrixLowerHelper.h"
#include "dxc/HLSL/DxilOperations.h"
#include <deque>
#include <unordered_map>
#include <unordered_set>

using namespace llvm;
using namespace hlsl;

#define DEBUG_TYPE "scalarreplhlsl"

STATISTIC(NumReplaced, "Number of allocas broken up");
STATISTIC(NumPromoted, "Number of allocas promoted");
STATISTIC(NumAdjusted, "Number of scalar allocas adjusted to allow promotion");
STATISTIC(NumConverted, "Number of aggregates converted to scalar");

namespace {

class SROA_Helper {
public:
  // Split V into AllocaInsts with Builder and save the new AllocaInsts into
  // Elts. Then do SROA on V.
  static bool DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
                                  IRBuilder<> &Builder, bool bFlatVector,
                                  bool hasPrecise, DxilTypeSystem &typeSys,
                                  SmallVector<Value *, 32> &DeadInsts);
  static bool DoScalarReplacement(GlobalVariable *GV,
                                  std::vector<Value *> &Elts,
                                  IRBuilder<> &Builder, bool bFlatVector,
                                  bool hasPrecise, DxilTypeSystem &typeSys,
                                  SmallVector<Value *, 32> &DeadInsts);
  // Lower memcpy related to V.
  static bool LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
                          DxilTypeSystem &typeSys, const DataLayout &DL,
                          bool bAllowReplace);
  static void MarkEmptyStructUsers(Value *V,
                                   SmallVector<Value *, 32> &DeadInsts);
  static bool IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys);

private:
  SROA_Helper(Value *V, ArrayRef<Value *> Elts,
              SmallVector<Value *, 32> &DeadInsts)
      : OldVal(V), NewElts(Elts), DeadInsts(DeadInsts) {}
  void RewriteForScalarRepl(Value *V, IRBuilder<> &Builder);

private:
  // Must be a pointer-typed value.
  Value *OldVal;
  // Flattened elements for OldVal.
  ArrayRef<Value *> NewElts;
  SmallVector<Value *, 32> &DeadInsts;
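
  // Per-use rewrite helpers: each one rewrites a single kind of use of
  // OldVal (constant expressions, GEPs, loads, stores, mem intrinsics,
  // calls, bitcasts) to refer to the flattened NewElts instead.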
  void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
  void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
  void RewriteForLoad(LoadInst *loadInst);
  void RewriteForStore(StoreInst *storeInst);
  void RewriteMemIntrin(MemIntrinsic *MI, Instruction *Inst);
  void RewriteCall(CallInst *CI);
  void RewriteBitCast(BitCastInst *BCI);
};

struct SROA_HLSL : public FunctionPass {
  SROA_HLSL(bool Promote, int T, bool hasDT, char &ID, int ST, int AT, int SLT)
      : FunctionPass(ID), HasDomTree(hasDT), RunPromotion(Promote) {
    if (AT == -1)
      ArrayElementThreshold = 8;
    else
      ArrayElementThreshold = AT;
    if (SLT == -1)
      // Do not limit the scalar integer load size if no threshold is given.
      ScalarLoadThreshold = -1;
    else
      ScalarLoadThreshold = SLT;
  }

  bool runOnFunction(Function &F) override;

  bool performScalarRepl(Function &F, DxilTypeSystem &typeSys);
  bool performPromotion(Function &F);
  bool markPrecise(Function &F);

private:
  bool HasDomTree;
  bool RunPromotion;

  /// DeadInsts - Keep track of instructions we have made dead, so that
  /// we can remove them after we are done working.
  SmallVector<Value *, 32> DeadInsts;

  /// AllocaInfo - When analyzing uses of an alloca instruction, this captures
  /// information about the uses. All these fields are initialized to false
  /// and set to true when something is learned.
  struct AllocaInfo {
    /// The alloca to promote.
    AllocaInst *AI;

    /// CheckedPHIs - This is a set of verified PHI nodes, to prevent infinite
    /// looping and avoid redundant work.
    SmallPtrSet<PHINode *, 8> CheckedPHIs;

    /// isUnsafe - This is set to true if the alloca cannot be SROA'd.
    bool isUnsafe : 1;

    /// isMemCpySrc - This is true if this aggregate is memcpy'd from.
    bool isMemCpySrc : 1;

    /// isMemCpyDst - This is true if this aggregate is memcpy'd into.
    bool isMemCpyDst : 1;

    /// hasSubelementAccess - This is true if a subelement of the alloca is
    /// ever accessed, or false if the alloca is only accessed with mem
    /// intrinsics or load/store that only access the entire alloca at once.
    bool hasSubelementAccess : 1;

    /// hasALoadOrStore - This is true if there are any loads or stores to it.
    /// The alloca may just be accessed with memcpy, for example, which would
    /// not set this.
    bool hasALoadOrStore : 1;

    /// hasArrayIndexing - This is true if there is any dynamic array
    /// indexing into it.
    bool hasArrayIndexing : 1;

    /// hasVectorIndexing - This is true if there is any dynamic vector
    /// indexing into it.
    bool hasVectorIndexing : 1;

    explicit AllocaInfo(AllocaInst *ai)
        : AI(ai), isUnsafe(false), isMemCpySrc(false), isMemCpyDst(false),
          hasSubelementAccess(false), hasALoadOrStore(false),
          hasArrayIndexing(false), hasVectorIndexing(false) {}
  };

  /// ArrayElementThreshold - The maximum number of elements an array can
  /// have to be considered for SROA.
  unsigned ArrayElementThreshold;

  /// ScalarLoadThreshold - The maximum size in bits of scalars to load when
  /// converting to scalar.
  unsigned ScalarLoadThreshold;

  void MarkUnsafe(AllocaInfo &I, Instruction *User) {
    I.isUnsafe = true;
    DEBUG(dbgs() << " Transformation preventing inst: " << *User << '\n');
  }

  bool isSafeAllocaToScalarRepl(AllocaInst *AI);

  void isSafeForScalarRepl(Instruction *I, uint64_t Offset, AllocaInfo &Info);
  void isSafePHISelectUseForScalarRepl(Instruction *User, uint64_t Offset,
                                       AllocaInfo &Info);
  void isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset, AllocaInfo &Info);
  void isSafeMemAccess(uint64_t Offset, uint64_t MemSize, Type *MemOpType,
                       bool isStore, AllocaInfo &Info, Instruction *TheAccess,
                       bool AllowWholeAccess);
  bool TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
                        const DataLayout &DL);
  void DeleteDeadInstructions();

  bool ShouldAttemptScalarRepl(AllocaInst *AI);
};

// SROA_DT - SROA that uses DominatorTree.
struct SROA_DT_HLSL : public SROA_HLSL {
  static char ID;

public:
  SROA_DT_HLSL(bool Promote = false, int T = -1, int ST = -1, int AT = -1,
               int SLT = -1)
      : SROA_HLSL(Promote, T, true, ID, ST, AT, SLT) {
    initializeSROA_DTPass(*PassRegistry::getPassRegistry());
  }

  // getAnalysisUsage - This pass does not require any passes, but we know it
  // will not alter the CFG, so say so.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.setPreservesCFG();
  }
};

// SROA_SSAUp - SROA that uses SSAUpdater.
struct SROA_SSAUp_HLSL : public SROA_HLSL {
  static char ID;

public:
  SROA_SSAUp_HLSL(bool Promote = false, int T = -1, int ST = -1, int AT = -1,
                  int SLT = -1)
      : SROA_HLSL(Promote, T, false, ID, ST, AT, SLT) {
    initializeSROA_SSAUpPass(*PassRegistry::getPassRegistry());
  }

  // getAnalysisUsage - This pass does not require any passes, but we know it
  // will not alter the CFG, so say so.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AssumptionCacheTracker>();
    AU.setPreservesCFG();
  }
};

// Simple struct to split memcpy into ld/st.
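//
// Illustrative sketch of the rewrite: a call such as
//   memcpy(&dst, &src, sizeof(S));   // where S is, say, { float a; float2 b; }
// becomes, roughly, a per-field copy
//   dst.a = src.a;
//   dst.b = src.b;
// driven by the DxilTypeSystem field annotations for the struct layout.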
struct MemcpySplitter {
  llvm::LLVMContext &m_context;
  DxilTypeSystem &m_typeSys;

public:
  MemcpySplitter(llvm::LLVMContext &context, DxilTypeSystem &typeSys)
      : m_context(context), m_typeSys(typeSys) {}
  void Split(llvm::Function &F);

  static void PatchMemCpyWithZeroIdxGEP(Module &M);
  static void PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI, const DataLayout &DL);
  static void SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
                          DxilFieldAnnotation *fieldAnnotation,
                          DxilTypeSystem &typeSys);
};

}

char SROA_DT_HLSL::ID = 0;
char SROA_SSAUp_HLSL::ID = 0;

INITIALIZE_PASS_BEGIN(SROA_DT_HLSL, "scalarreplhlsl",
                      "Scalar Replacement of Aggregates HLSL (DT)", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(SROA_DT_HLSL, "scalarreplhlsl",
                    "Scalar Replacement of Aggregates HLSL (DT)", false, false)

INITIALIZE_PASS_BEGIN(SROA_SSAUp_HLSL, "scalarreplhlsl-ssa",
                      "Scalar Replacement of Aggregates HLSL (SSAUp)", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_END(SROA_SSAUp_HLSL, "scalarreplhlsl-ssa",
                    "Scalar Replacement of Aggregates HLSL (SSAUp)", false,
                    false)

// Public interface to the ScalarReplAggregates pass
FunctionPass *llvm::createScalarReplAggregatesHLSLPass(bool UseDomTree,
                                                       bool Promote) {
  if (UseDomTree)
    return new SROA_DT_HLSL(Promote);
  return new SROA_SSAUp_HLSL(Promote);
}
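
// Usage sketch (illustrative, assumed driver code): a host would create the
// pass through the factory above and schedule it in a legacy pass manager,
// e.g.
//
//   llvm::legacy::PassManager PM;
//   PM.add(llvm::createScalarReplAggregatesHLSLPass(/*UseDomTree=*/true,
//                                                   /*Promote=*/true));
//   PM.run(M);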

//===----------------------------------------------------------------------===//
// Convert To Scalar Optimization.
//===----------------------------------------------------------------------===//

namespace {

/// ConvertToScalarInfo - This class implements the "Convert To Scalar"
/// optimization, which scans the uses of an alloca and determines if it can
/// rewrite it in terms of a single new alloca that can be mem2reg'd.
class ConvertToScalarInfo {
  /// AllocaSize - The size of the alloca being considered in bytes.
  unsigned AllocaSize;
  const DataLayout &DL;
  unsigned ScalarLoadThreshold;

  /// IsNotTrivial - This is set to true if there is some access to the object
  /// which means that mem2reg can't promote it.
  bool IsNotTrivial;

  /// ScalarKind - Tracks the kind of alloca being considered for promotion,
  /// computed based on the uses of the alloca rather than the LLVM type
  /// system.
  enum {
    Unknown,

    // Accesses via GEPs that are consistent with element access of a vector
    // type. This will not be converted into a vector unless there is a later
    // access using an actual vector type.
    ImplicitVector,

    // Accesses via vector operations and GEPs that are consistent with the
    // layout of a vector type.
    Vector,

    // An integer bag-of-bits with bitwise operations for insertion and
    // extraction. Any combination of types can be converted into this kind
    // of scalar.
    Integer
  } ScalarKind;

  /// VectorTy - This tracks the type that we should promote the vector to if
  /// it is possible to turn it into a vector. This starts out null, and if it
  /// isn't possible to turn into a vector type, it gets set to VoidTy.
  VectorType *VectorTy;

  /// HadNonMemTransferAccess - True if there is at least one access to the
  /// alloca that is not a MemTransferInst. We don't want to turn structs into
  /// large integers unless there is some potential for optimization.
  bool HadNonMemTransferAccess;

  /// HadDynamicAccess - True if some element of this alloca was dynamic.
  /// We don't yet have support for turning a dynamic access into a large
  /// integer.
  bool HadDynamicAccess;

public:
  explicit ConvertToScalarInfo(unsigned Size, const DataLayout &DL,
                               unsigned SLT)
      : AllocaSize(Size), DL(DL), ScalarLoadThreshold(SLT),
        IsNotTrivial(false), ScalarKind(Unknown), VectorTy(nullptr),
        HadNonMemTransferAccess(false), HadDynamicAccess(false) {}

  AllocaInst *TryConvert(AllocaInst *AI);

private:
  bool CanConvertToScalar(Value *V, uint64_t Offset, Value *NonConstantIdx);
  void MergeInTypeForLoadOrStore(Type *In, uint64_t Offset);
  bool MergeInVectorType(VectorType *VInTy, uint64_t Offset);
  void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset,
                           Value *NonConstantIdx);

  Value *ConvertScalar_ExtractValue(Value *NV, Type *ToType, uint64_t Offset,
                                    Value *NonConstantIdx,
                                    IRBuilder<> &Builder);
  Value *ConvertScalar_InsertValue(Value *StoredVal, Value *ExistingVal,
                                   uint64_t Offset, Value *NonConstantIdx,
                                   IRBuilder<> &Builder);
};

} // end anonymous namespace.
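
// Illustrative overview (sketch inferred from the code below): TryConvert
// first walks every use of the alloca with CanConvertToScalar to pick a
// ScalarKind, then builds a single vector or integer alloca of the same size
// and rewrites all uses through ConvertUsesToScalar so that mem2reg can
// promote the result.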

/// TryConvert - Analyze the specified alloca, and if it is safe to do so,
/// rewrite it to be a new alloca which is mem2reg'able. This returns the new
/// alloca if possible or null if not.
AllocaInst *ConvertToScalarInfo::TryConvert(AllocaInst *AI) {
  // If we can't convert this scalar, or if mem2reg can trivially do it, bail
  // out.
  if (!CanConvertToScalar(AI, 0, nullptr) || !IsNotTrivial)
    return nullptr;

  // If an alloca has only memset / memcpy uses, it may still have an Unknown
  // ScalarKind. Treat it as an Integer below.
  if (ScalarKind == Unknown)
    ScalarKind = Integer;

  if (ScalarKind == Vector && VectorTy->getBitWidth() != AllocaSize * 8)
    ScalarKind = Integer;

  // If we were able to find a vector type that can handle this with
  // insert/extract elements, and if there was at least one use that had
  // a vector type, promote this to a vector. We don't want to promote
  // random stuff that doesn't use vectors (e.g. <9 x double>) because then
  // we just get a lot of insert/extracts. If at least one vector is
  // involved, then we probably really do have a union of vector/array.
  Type *NewTy;
  if (ScalarKind == Vector) {
    assert(VectorTy && "Missing type for vector scalar.");
    DEBUG(dbgs() << "CONVERT TO VECTOR: " << *AI << "\n TYPE = " << *VectorTy
                 << '\n');
    NewTy = VectorTy; // Use the vector type.
  } else {
    unsigned BitWidth = AllocaSize * 8;

    // Do not convert to scalar integer if the alloca size exceeds the
    // scalar load threshold.
    if (BitWidth > ScalarLoadThreshold)
      return nullptr;

    if ((ScalarKind == ImplicitVector || ScalarKind == Integer) &&
        !HadNonMemTransferAccess && !DL.fitsInLegalInteger(BitWidth))
      return nullptr;
    // Dynamic accesses on integers aren't yet supported. They need us to shift
    // by a dynamic amount which could be difficult to work out as we might not
    // know whether to use a left or right shift.
    if (ScalarKind == Integer && HadDynamicAccess)
      return nullptr;

    DEBUG(dbgs() << "CONVERT TO SCALAR INTEGER: " << *AI << "\n");
    // Create and insert the integer alloca.
    NewTy = IntegerType::get(AI->getContext(), BitWidth);
  }
  AllocaInst *NewAI =
      new AllocaInst(NewTy, nullptr, "", AI->getParent()->begin());
  ConvertUsesToScalar(AI, NewAI, 0, nullptr);
  return NewAI;
}

/// MergeInTypeForLoadOrStore - Add the 'In' type to the accumulated vector
/// type (VectorTy) so far at the offset specified by Offset (which is
/// specified in bytes).
///
/// There are two cases we handle here:
///   1) A union of vector types of the same size and potentially its elements.
///      Here we turn element accesses into insert/extract element operations.
///      This promotes a <4 x float> with a store of float to the third element
///      into a <4 x float> that uses insert element.
///   2) A fully general blob of memory, which we turn into some (potentially
///      large) integer type with extract and insert operations where the loads
///      and stores would mutate the memory. We mark this by setting VectorTy
///      to VoidTy.
void ConvertToScalarInfo::MergeInTypeForLoadOrStore(Type *In, uint64_t Offset) {
  // If we already decided to turn this into a blob of integer memory, there is
  // nothing to be done.
  if (ScalarKind == Integer)
    return;

  // If this could be contributing to a vector, analyze it.

  // If the In type is a vector that is the same size as the alloca, see if it
  // matches the existing VecTy.
  if (VectorType *VInTy = dyn_cast<VectorType>(In)) {
    if (MergeInVectorType(VInTy, Offset))
      return;
  } else if (In->isFloatTy() || In->isDoubleTy() ||
             (In->isIntegerTy() && In->getPrimitiveSizeInBits() >= 8 &&
              isPowerOf2_32(In->getPrimitiveSizeInBits()))) {
    // Full width accesses can be ignored, because they can always be turned
    // into bitcasts.
    unsigned EltSize = In->getPrimitiveSizeInBits() / 8;
    if (EltSize == AllocaSize)
      return;

    // If we're accessing something that could be an element of a vector, see
    // if the implied vector agrees with what we already have and if Offset is
    // compatible with it.
    if (Offset % EltSize == 0 && AllocaSize % EltSize == 0 &&
        (!VectorTy ||
         EltSize == VectorTy->getElementType()->getPrimitiveSizeInBits() / 8)) {
      if (!VectorTy) {
        ScalarKind = ImplicitVector;
        VectorTy = VectorType::get(In, AllocaSize / EltSize);
      }
      return;
    }
  }

  // Otherwise, we have a case that we can't handle with an optimized vector
  // form. We can still turn this into a large integer.
  ScalarKind = Integer;
}

/// MergeInVectorType - Handles the vector case of MergeInTypeForLoadOrStore,
/// returning true if the type was successfully merged and false otherwise.
bool ConvertToScalarInfo::MergeInVectorType(VectorType *VInTy,
                                            uint64_t Offset) {
  if (VInTy->getBitWidth() / 8 == AllocaSize && Offset == 0) {
    // If we're storing/loading a vector of the right size, allow it as a
    // vector. If this is the first vector we see, remember the type so that
    // we know the element size. If this is a subsequent access, ignore it
    // even if it is a differing type but the same size. Worst case we can
    // bitcast the resultant vectors.
    if (!VectorTy)
      VectorTy = VInTy;

    ScalarKind = Vector;
    return true;
  }

  return false;
}

/// CanConvertToScalar - V is a pointer. If we can convert the pointee and all
/// its accesses to a single vector type, return true and set VecTy to
/// the new type. If we could convert the alloca into a single promotable
/// integer, return true but set VecTy to VoidTy. Further, if the use is not a
/// completely trivial use that mem2reg could promote, set IsNotTrivial. Offset
/// is the current offset from the base of the alloca being analyzed.
///
/// If we see at least one access to the value that uses a vector type, record
/// that in ScalarKind.
bool ConvertToScalarInfo::CanConvertToScalar(Value *V, uint64_t Offset,
                                             Value *NonConstantIdx) {
  for (User *U : V->users()) {
    Instruction *UI = cast<Instruction>(U);

    if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
      // Don't break volatile loads.
      if (!LI->isSimple())
        return false;
      HadNonMemTransferAccess = true;
      MergeInTypeForLoadOrStore(LI->getType(), Offset);
      continue;
    }

    if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
      // Storing the pointer, not into the value?
      if (SI->getOperand(0) == V || !SI->isSimple())
        return false;
      HadNonMemTransferAccess = true;
      MergeInTypeForLoadOrStore(SI->getOperand(0)->getType(), Offset);
      continue;
    }

    if (BitCastInst *BCI = dyn_cast<BitCastInst>(UI)) {
      if (!onlyUsedByLifetimeMarkers(BCI))
        IsNotTrivial = true; // Can't be mem2reg'd.
      if (!CanConvertToScalar(BCI, Offset, NonConstantIdx))
        return false;
      continue;
    }

    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(UI)) {
      // If this is a GEP with variable indices, we can't handle it.
      PointerType *PtrTy = dyn_cast<PointerType>(GEP->getPointerOperandType());
      if (!PtrTy)
        return false;

      // Compute the offset that this GEP adds to the pointer.
      SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
      Value *GEPNonConstantIdx = nullptr;
      if (!GEP->hasAllConstantIndices()) {
        if (!isa<VectorType>(PtrTy->getElementType()))
          return false;
        if (NonConstantIdx)
          return false;
        GEPNonConstantIdx = Indices.pop_back_val();
        if (!GEPNonConstantIdx->getType()->isIntegerTy(32))
          return false;
        HadDynamicAccess = true;
      } else
        GEPNonConstantIdx = NonConstantIdx;
      uint64_t GEPOffset = DL.getIndexedOffset(PtrTy, Indices);
      // See if all uses can be converted.
      if (!CanConvertToScalar(GEP, Offset + GEPOffset, GEPNonConstantIdx))
        return false;
      IsNotTrivial = true; // Can't be mem2reg'd.
      HadNonMemTransferAccess = true;
      continue;
    }

    // If this is a constant sized memset of a constant value (e.g. 0) we can
    // handle it.
    if (MemSetInst *MSI = dyn_cast<MemSetInst>(UI)) {
      // Store to dynamic index.
      if (NonConstantIdx)
        return false;
      // Store of constant value.
      if (!isa<ConstantInt>(MSI->getValue()))
        return false;

      // Store of constant size.
      ConstantInt *Len = dyn_cast<ConstantInt>(MSI->getLength());
      if (!Len)
        return false;

      // If the size differs from the alloca, we can only convert the alloca to
      // an integer bag-of-bits.
      // FIXME: This should handle all of the cases that are currently accepted
      // as vector element insertions.
      if (Len->getZExtValue() != AllocaSize || Offset != 0)
        ScalarKind = Integer;

      IsNotTrivial = true; // Can't be mem2reg'd.
      HadNonMemTransferAccess = true;
      continue;
    }

    // If this is a memcpy or memmove into or out of the whole allocation, we
    // can handle it like a load or store of the scalar type.
    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(UI)) {
      // Store to dynamic index.
      if (NonConstantIdx)
        return false;
      ConstantInt *Len = dyn_cast<ConstantInt>(MTI->getLength());
      if (!Len || Len->getZExtValue() != AllocaSize || Offset != 0)
        return false;

      IsNotTrivial = true; // Can't be mem2reg'd.
      continue;
    }

    // If this is a lifetime intrinsic, we can handle it.
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(UI)) {
      if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
          II->getIntrinsicID() == Intrinsic::lifetime_end) {
        continue;
      }
    }

    // Otherwise, we cannot handle this!
    return false;
  }

  return true;
}

/// ConvertUsesToScalar - Convert all of the users of Ptr to use the new alloca
/// directly. This happens when we are converting an "integer union" to a
/// single integer scalar, or when we are converting a "vector union" to a
/// vector with insert/extractelement instructions.
///
/// Offset is an offset from the original alloca, in bits that need to be
/// shifted to the right. By the end of this, there should be no uses of Ptr.
void ConvertToScalarInfo::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI,
                                              uint64_t Offset,
                                              Value *NonConstantIdx) {
  while (!Ptr->use_empty()) {
    Instruction *User = cast<Instruction>(Ptr->user_back());

    if (BitCastInst *CI = dyn_cast<BitCastInst>(User)) {
      ConvertUsesToScalar(CI, NewAI, Offset, NonConstantIdx);
      CI->eraseFromParent();
      continue;
    }

    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      // Compute the offset that this GEP adds to the pointer.
      SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
      Value *GEPNonConstantIdx = nullptr;
      if (!GEP->hasAllConstantIndices()) {
        assert(!NonConstantIdx &&
               "Dynamic GEP reading from dynamic GEP unsupported");
        GEPNonConstantIdx = Indices.pop_back_val();
      } else
        GEPNonConstantIdx = NonConstantIdx;
      uint64_t GEPOffset =
          DL.getIndexedOffset(GEP->getPointerOperandType(), Indices);
      ConvertUsesToScalar(GEP, NewAI, Offset + GEPOffset * 8,
                          GEPNonConstantIdx);
      GEP->eraseFromParent();
      continue;
    }

    IRBuilder<> Builder(User);

    if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
      // The load is a bit extract from NewAI shifted right by Offset bits.
      Value *LoadedVal = Builder.CreateLoad(NewAI);
      Value *NewLoadVal = ConvertScalar_ExtractValue(
          LoadedVal, LI->getType(), Offset, NonConstantIdx, Builder);
      LI->replaceAllUsesWith(NewLoadVal);
      LI->eraseFromParent();
      continue;
    }

    if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
      assert(SI->getOperand(0) != Ptr && "Consistency error!");
      Instruction *Old = Builder.CreateLoad(NewAI, NewAI->getName() + ".in");
      Value *New = ConvertScalar_InsertValue(SI->getOperand(0), Old, Offset,
                                             NonConstantIdx, Builder);
      Builder.CreateStore(New, NewAI);
      SI->eraseFromParent();

      // If the load we just inserted is now dead, then the inserted store
      // overwrote the entire thing.
      if (Old->use_empty())
        Old->eraseFromParent();
      continue;
    }

    // If this is a constant sized memset of a constant value (e.g. 0) we can
    // transform it into a store of the expanded constant value.
    if (MemSetInst *MSI = dyn_cast<MemSetInst>(User)) {
      assert(MSI->getRawDest() == Ptr && "Consistency error!");
      assert(!NonConstantIdx && "Cannot replace dynamic memset with insert");
      int64_t SNumBytes = cast<ConstantInt>(MSI->getLength())->getSExtValue();
      if (SNumBytes > 0 && (SNumBytes >> 32) == 0) {
        unsigned NumBytes = static_cast<unsigned>(SNumBytes);
        unsigned Val = cast<ConstantInt>(MSI->getValue())->getZExtValue();

        // Compute the value replicated the right number of times.
        APInt APVal(NumBytes * 8, Val);

        // Splat the value if non-zero.
        if (Val)
          for (unsigned i = 1; i != NumBytes; ++i)
            APVal |= APVal << 8;

        Instruction *Old = Builder.CreateLoad(NewAI, NewAI->getName() + ".in");
        Value *New = ConvertScalar_InsertValue(
            ConstantInt::get(User->getContext(), APVal), Old, Offset, nullptr,
            Builder);
        Builder.CreateStore(New, NewAI);

        // If the load we just inserted is now dead, then the memset overwrote
        // the entire thing.
        if (Old->use_empty())
          Old->eraseFromParent();
      }
      MSI->eraseFromParent();
      continue;
    }

    // If this is a memcpy or memmove into or out of the whole allocation, we
    // can handle it like a load or store of the scalar type.
    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(User)) {
      assert(Offset == 0 && "must be store to start of alloca");
      assert(!NonConstantIdx && "Cannot replace dynamic transfer with insert");

      // If the source and destination are both to the same alloca, then this
      // is a noop copy-to-self, just delete it. Otherwise, emit a load and
      // store as appropriate.
      AllocaInst *OrigAI = cast<AllocaInst>(GetUnderlyingObject(Ptr, DL, 0));
      if (GetUnderlyingObject(MTI->getSource(), DL, 0) != OrigAI) {
        // Dest must be OrigAI, change this to be a load from the original
        // pointer (bitcasted), then a store to our new alloca.
        assert(MTI->getRawDest() == Ptr && "Neither use is of pointer?");
        Value *SrcPtr = MTI->getSource();
        PointerType *SPTy = cast<PointerType>(SrcPtr->getType());
        PointerType *AIPTy = cast<PointerType>(NewAI->getType());
        if (SPTy->getAddressSpace() != AIPTy->getAddressSpace()) {
          AIPTy = PointerType::get(AIPTy->getElementType(),
                                   SPTy->getAddressSpace());
        }
        SrcPtr = Builder.CreateBitCast(SrcPtr, AIPTy);

        LoadInst *SrcVal = Builder.CreateLoad(SrcPtr, "srcval");
        SrcVal->setAlignment(MTI->getAlignment());
        Builder.CreateStore(SrcVal, NewAI);
      } else if (GetUnderlyingObject(MTI->getDest(), DL, 0) != OrigAI) {
        // Src must be OrigAI, change this to be a load from NewAI then a store
        // through the original dest pointer (bitcasted).
        assert(MTI->getRawSource() == Ptr && "Neither use is of pointer?");
        LoadInst *SrcVal = Builder.CreateLoad(NewAI, "srcval");

        PointerType *DPTy = cast<PointerType>(MTI->getDest()->getType());
        PointerType *AIPTy = cast<PointerType>(NewAI->getType());
        if (DPTy->getAddressSpace() != AIPTy->getAddressSpace()) {
          AIPTy = PointerType::get(AIPTy->getElementType(),
                                   DPTy->getAddressSpace());
        }
        Value *DstPtr = Builder.CreateBitCast(MTI->getDest(), AIPTy);

        StoreInst *NewStore = Builder.CreateStore(SrcVal, DstPtr);
        NewStore->setAlignment(MTI->getAlignment());
      } else {
        // Noop transfer. Src == Dst
      }

      MTI->eraseFromParent();
      continue;
    }

    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) {
      if (II->getIntrinsicID() == Intrinsic::lifetime_start ||
          II->getIntrinsicID() == Intrinsic::lifetime_end) {
        // There's no need to preserve these, as the resulting alloca will be
        // converted to a register anyways.
        II->eraseFromParent();
        continue;
      }
    }

    llvm_unreachable("Unsupported operation!");
  }
}
  683. /// ConvertScalar_ExtractValue - Extract a value of type ToType from an integer
  684. /// or vector value FromVal, extracting the bits from the offset specified by
  685. /// Offset. This returns the value, which is of type ToType.
  686. ///
  687. /// This happens when we are converting an "integer union" to a single
  688. /// integer scalar, or when we are converting a "vector union" to a vector with
  689. /// insert/extractelement instructions.
  690. ///
  691. /// Offset is an offset from the original alloca, in bits that need to be
  692. /// shifted to the right.
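///
/// Illustrative sketch (little-endian, integer "union" case): extracting an
/// i16 at bit offset 16 from an i64 value lowers to roughly
///   %sh = lshr i64 %FromVal, 16
///   %v  = trunc i64 %sh to i16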
  693. Value *ConvertToScalarInfo::ConvertScalar_ExtractValue(Value *FromVal,
  694. Type *ToType,
  695. uint64_t Offset,
  696. Value *NonConstantIdx,
  697. IRBuilder<> &Builder) {
  698. // If the load is of the whole new alloca, no conversion is needed.
  699. Type *FromType = FromVal->getType();
  700. if (FromType == ToType && Offset == 0)
  701. return FromVal;
  702. // If the result alloca is a vector type, this is either an element
  703. // access or a bitcast to another vector type of the same size.
  704. if (VectorType *VTy = dyn_cast<VectorType>(FromType)) {
  705. unsigned FromTypeSize = DL.getTypeAllocSize(FromType);
  706. unsigned ToTypeSize = DL.getTypeAllocSize(ToType);
  707. if (FromTypeSize == ToTypeSize)
  708. return Builder.CreateBitCast(FromVal, ToType);
  709. // Otherwise it must be an element access.
  710. unsigned Elt = 0;
  711. if (Offset) {
  712. unsigned EltSize = DL.getTypeAllocSizeInBits(VTy->getElementType());
  713. Elt = Offset / EltSize;
  714. assert(EltSize * Elt == Offset && "Invalid modulus in validity checking");
  715. }
  716. // Return the element extracted out of it.
  717. Value *Idx;
  718. if (NonConstantIdx) {
  719. if (Elt)
  720. Idx = Builder.CreateAdd(NonConstantIdx, Builder.getInt32(Elt),
  721. "dyn.offset");
  722. else
  723. Idx = NonConstantIdx;
  724. } else
  725. Idx = Builder.getInt32(Elt);
  726. Value *V = Builder.CreateExtractElement(FromVal, Idx);
  727. if (V->getType() != ToType)
  728. V = Builder.CreateBitCast(V, ToType);
  729. return V;
  730. }
  731. // If ToType is a first class aggregate, extract out each of the pieces and
  732. // use insertvalue's to form the FCA.
  733. if (StructType *ST = dyn_cast<StructType>(ToType)) {
  734. assert(!NonConstantIdx &&
  735. "Dynamic indexing into struct types not supported");
  736. const StructLayout &Layout = *DL.getStructLayout(ST);
  737. Value *Res = UndefValue::get(ST);
  738. for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
  739. Value *Elt = ConvertScalar_ExtractValue(
  740. FromVal, ST->getElementType(i),
  741. Offset + Layout.getElementOffsetInBits(i), nullptr, Builder);
  742. Res = Builder.CreateInsertValue(Res, Elt, i);
  743. }
  744. return Res;
  745. }
  746. if (ArrayType *AT = dyn_cast<ArrayType>(ToType)) {
  747. assert(!NonConstantIdx &&
  748. "Dynamic indexing into array types not supported");
  749. uint64_t EltSize = DL.getTypeAllocSizeInBits(AT->getElementType());
  750. Value *Res = UndefValue::get(AT);
  751. for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) {
  752. Value *Elt =
  753. ConvertScalar_ExtractValue(FromVal, AT->getElementType(),
  754. Offset + i * EltSize, nullptr, Builder);
  755. Res = Builder.CreateInsertValue(Res, Elt, i);
  756. }
  757. return Res;
  758. }
  759. // Otherwise, this must be a union that was converted to an integer value.
  760. IntegerType *NTy = cast<IntegerType>(FromVal->getType());
  761. // If this is a big-endian system and the load is narrower than the
  762. // full alloca type, we need to do a shift to get the right bits.
  763. int ShAmt = 0;
  764. if (DL.isBigEndian()) {
  765. // On big-endian machines, the lowest bit is stored at the bit offset
  766. // from the pointer given by getTypeStoreSizeInBits. This matters for
  767. // integers with a bitwidth that is not a multiple of 8.
  768. ShAmt = DL.getTypeStoreSizeInBits(NTy) - DL.getTypeStoreSizeInBits(ToType) -
  769. Offset;
  770. } else {
  771. ShAmt = Offset;
  772. }
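  // For example, on a big-endian target, extracting an i16 at bit offset 0
  // from an i64 gives ShAmt = 64 - 16 - 0 = 48, so the wanted bits are
  // shifted down from the most significant end.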
  // Note: we support negative shift amounts (handled with a shl), even though
  // such loads are not well defined. We do this to support (e.g.) loads off
  // the end of a structure where only some bits are used.
  776. if (ShAmt > 0 && (unsigned)ShAmt < NTy->getBitWidth())
  777. FromVal = Builder.CreateLShr(FromVal,
  778. ConstantInt::get(FromVal->getType(), ShAmt));
  779. else if (ShAmt < 0 && (unsigned)-ShAmt < NTy->getBitWidth())
  780. FromVal = Builder.CreateShl(FromVal,
  781. ConstantInt::get(FromVal->getType(), -ShAmt));
  782. // Finally, unconditionally truncate the integer to the right width.
  783. unsigned LIBitWidth = DL.getTypeSizeInBits(ToType);
  784. if (LIBitWidth < NTy->getBitWidth())
  785. FromVal = Builder.CreateTrunc(
  786. FromVal, IntegerType::get(FromVal->getContext(), LIBitWidth));
  787. else if (LIBitWidth > NTy->getBitWidth())
  788. FromVal = Builder.CreateZExt(
  789. FromVal, IntegerType::get(FromVal->getContext(), LIBitWidth));
  790. // If the result is an integer, this is a trunc or bitcast.
  791. if (ToType->isIntegerTy()) {
  792. // Should be done.
  793. } else if (ToType->isFloatingPointTy() || ToType->isVectorTy()) {
  794. // Just do a bitcast, we know the sizes match up.
  795. FromVal = Builder.CreateBitCast(FromVal, ToType);
  796. } else {
  797. // Otherwise must be a pointer.
  798. FromVal = Builder.CreateIntToPtr(FromVal, ToType);
  799. }
  800. assert(FromVal->getType() == ToType && "Didn't convert right?");
  801. return FromVal;
  802. }
  803. /// ConvertScalar_InsertValue - Insert the value "SV" into the existing integer
  804. /// or vector value "Old" at the offset specified by Offset.
  805. ///
  806. /// This happens when we are converting an "integer union" to a
  807. /// single integer scalar, or when we are converting a "vector union" to a
  808. /// vector with insert/extractelement instructions.
  809. ///
  810. /// Offset is an offset from the original alloca, in bits that need to be
  811. /// shifted to the right.
  812. ///
  813. /// NonConstantIdx is an index value if there was a GEP with a non-constant
  814. /// index value. If this is 0 then all GEPs used to find this insert address
  815. /// are constant.
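///
/// Illustrative sketch (little-endian, integer "union" case): inserting an i16
/// at bit offset 16 into an i64 value lowers to roughly
///   %ext = zext i16 %SV to i64
///   %sh  = shl i64 %ext, 16
///   %old = and i64 %Old, 0xFFFFFFFF0000FFFF
///   %ins = or i64 %old, %sh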
  816. Value *ConvertToScalarInfo::ConvertScalar_InsertValue(Value *SV, Value *Old,
  817. uint64_t Offset,
  818. Value *NonConstantIdx,
  819. IRBuilder<> &Builder) {
  820. // Convert the stored type to the actual type, shift it left to insert
  821. // then 'or' into place.
  822. Type *AllocaType = Old->getType();
  823. LLVMContext &Context = Old->getContext();
  824. if (VectorType *VTy = dyn_cast<VectorType>(AllocaType)) {
  825. uint64_t VecSize = DL.getTypeAllocSizeInBits(VTy);
  826. uint64_t ValSize = DL.getTypeAllocSizeInBits(SV->getType());
  827. // Changing the whole vector with memset or with an access of a different
  828. // vector type?
  829. if (ValSize == VecSize)
  830. return Builder.CreateBitCast(SV, AllocaType);
  831. // Must be an element insertion.
  832. Type *EltTy = VTy->getElementType();
  833. if (SV->getType() != EltTy)
  834. SV = Builder.CreateBitCast(SV, EltTy);
  835. uint64_t EltSize = DL.getTypeAllocSizeInBits(EltTy);
  836. unsigned Elt = Offset / EltSize;
  837. Value *Idx;
  838. if (NonConstantIdx) {
  839. if (Elt)
  840. Idx = Builder.CreateAdd(NonConstantIdx, Builder.getInt32(Elt),
  841. "dyn.offset");
  842. else
  843. Idx = NonConstantIdx;
  844. } else
  845. Idx = Builder.getInt32(Elt);
  846. return Builder.CreateInsertElement(Old, SV, Idx);
  847. }
  848. // If SV is a first-class aggregate value, insert each value recursively.
  849. if (StructType *ST = dyn_cast<StructType>(SV->getType())) {
  850. assert(!NonConstantIdx &&
  851. "Dynamic indexing into struct types not supported");
  852. const StructLayout &Layout = *DL.getStructLayout(ST);
  853. for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) {
  854. Value *Elt = Builder.CreateExtractValue(SV, i);
  855. Old = ConvertScalar_InsertValue(Elt, Old,
  856. Offset + Layout.getElementOffsetInBits(i),
  857. nullptr, Builder);
  858. }
  859. return Old;
  860. }
  861. if (ArrayType *AT = dyn_cast<ArrayType>(SV->getType())) {
  862. assert(!NonConstantIdx &&
  863. "Dynamic indexing into array types not supported");
  864. uint64_t EltSize = DL.getTypeAllocSizeInBits(AT->getElementType());
  865. for (unsigned i = 0, e = AT->getNumElements(); i != e; ++i) {
  866. Value *Elt = Builder.CreateExtractValue(SV, i);
  867. Old = ConvertScalar_InsertValue(Elt, Old, Offset + i * EltSize, nullptr,
  868. Builder);
  869. }
  870. return Old;
  871. }
  872. // If SV is a float, convert it to the appropriate integer type.
  873. // If it is a pointer, do the same.
  874. unsigned SrcWidth = DL.getTypeSizeInBits(SV->getType());
  875. unsigned DestWidth = DL.getTypeSizeInBits(AllocaType);
  876. unsigned SrcStoreWidth = DL.getTypeStoreSizeInBits(SV->getType());
  877. unsigned DestStoreWidth = DL.getTypeStoreSizeInBits(AllocaType);
  878. if (SV->getType()->isFloatingPointTy() || SV->getType()->isVectorTy())
  879. SV =
  880. Builder.CreateBitCast(SV, IntegerType::get(SV->getContext(), SrcWidth));
  881. else if (SV->getType()->isPointerTy())
  882. SV = Builder.CreatePtrToInt(SV, DL.getIntPtrType(SV->getType()));
  883. // Zero extend or truncate the value if needed.
  884. if (SV->getType() != AllocaType) {
  885. if (SV->getType()->getPrimitiveSizeInBits() <
  886. AllocaType->getPrimitiveSizeInBits())
  887. SV = Builder.CreateZExt(SV, AllocaType);
  888. else {
  889. // Truncation may be needed if storing more than the alloca can hold
  890. // (undefined behavior).
  891. SV = Builder.CreateTrunc(SV, AllocaType);
  892. SrcWidth = DestWidth;
  893. SrcStoreWidth = DestStoreWidth;
  894. }
  895. }
  896. // If this is a big-endian system and the store is narrower than the
  897. // full alloca type, we need to do a shift to get the right bits.
  898. int ShAmt = 0;
  899. if (DL.isBigEndian()) {
  900. // On big-endian machines, the lowest bit is stored at the bit offset
  901. // from the pointer given by getTypeStoreSizeInBits. This matters for
  902. // integers with a bitwidth that is not a multiple of 8.
  903. ShAmt = DestStoreWidth - SrcStoreWidth - Offset;
  904. } else {
  905. ShAmt = Offset;
  906. }
  // Note: we support negative shift amounts (handled with a lshr), even though
  // such stores are not well defined. We do this to support (e.g.) stores off
  // the end of a structure where only some bits in the structure are set.
  910. APInt Mask(APInt::getLowBitsSet(DestWidth, SrcWidth));
  911. if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) {
  912. SV = Builder.CreateShl(SV, ConstantInt::get(SV->getType(), ShAmt));
  913. Mask <<= ShAmt;
  914. } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) {
  915. SV = Builder.CreateLShr(SV, ConstantInt::get(SV->getType(), -ShAmt));
  916. Mask = Mask.lshr(-ShAmt);
  917. }
  918. // Mask out the bits we are about to insert from the old value, and or
  919. // in the new bits.
  920. if (SrcWidth != DestWidth) {
  921. assert(DestWidth > SrcWidth);
  922. Old = Builder.CreateAnd(Old, ConstantInt::get(Context, ~Mask), "mask");
  923. SV = Builder.CreateOr(Old, SV, "ins");
  924. }
  925. return SV;
  926. }
  927. //===----------------------------------------------------------------------===//
  928. // SRoA Driver
  929. //===----------------------------------------------------------------------===//
  930. bool SROA_HLSL::runOnFunction(Function &F) {
  931. Module *M = F.getParent();
  932. HLModule &HLM = M->GetOrCreateHLModule();
  933. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  934. bool Changed = performScalarRepl(F, typeSys);
  // Lower any remaining memcpys into load/store pairs.
  936. MemcpySplitter splitter(F.getContext(), typeSys);
  937. splitter.Split(F);
  938. Changed |= markPrecise(F);
  939. return Changed;
  940. }
  941. namespace {
  942. class AllocaPromoter : public LoadAndStorePromoter {
  943. AllocaInst *AI;
  944. DIBuilder *DIB;
  945. SmallVector<DbgDeclareInst *, 4> DDIs;
  946. SmallVector<DbgValueInst *, 4> DVIs;
  947. public:
  948. AllocaPromoter(ArrayRef<Instruction *> Insts, SSAUpdater &S, DIBuilder *DB)
  949. : LoadAndStorePromoter(Insts, S), AI(nullptr), DIB(DB) {}
  950. void run(AllocaInst *AI, const SmallVectorImpl<Instruction *> &Insts) {
  951. // Remember which alloca we're promoting (for isInstInList).
  952. this->AI = AI;
  953. if (auto *L = LocalAsMetadata::getIfExists(AI)) {
  954. if (auto *DINode = MetadataAsValue::getIfExists(AI->getContext(), L)) {
  955. for (User *U : DINode->users())
  956. if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U))
  957. DDIs.push_back(DDI);
  958. else if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U))
  959. DVIs.push_back(DVI);
  960. }
  961. }
  962. LoadAndStorePromoter::run(Insts);
  963. AI->eraseFromParent();
  964. for (SmallVectorImpl<DbgDeclareInst *>::iterator I = DDIs.begin(),
  965. E = DDIs.end();
  966. I != E; ++I) {
  967. DbgDeclareInst *DDI = *I;
  968. DDI->eraseFromParent();
  969. }
  970. for (SmallVectorImpl<DbgValueInst *>::iterator I = DVIs.begin(),
  971. E = DVIs.end();
  972. I != E; ++I) {
  973. DbgValueInst *DVI = *I;
  974. DVI->eraseFromParent();
  975. }
  976. }
  977. bool
  978. isInstInList(Instruction *I,
  979. const SmallVectorImpl<Instruction *> &Insts) const override {
  980. if (LoadInst *LI = dyn_cast<LoadInst>(I))
  981. return LI->getOperand(0) == AI;
  982. return cast<StoreInst>(I)->getPointerOperand() == AI;
  983. }
  984. void updateDebugInfo(Instruction *Inst) const override {
  985. for (SmallVectorImpl<DbgDeclareInst *>::const_iterator I = DDIs.begin(),
  986. E = DDIs.end();
  987. I != E; ++I) {
  988. DbgDeclareInst *DDI = *I;
  989. if (StoreInst *SI = dyn_cast<StoreInst>(Inst))
  990. ConvertDebugDeclareToDebugValue(DDI, SI, *DIB);
  991. else if (LoadInst *LI = dyn_cast<LoadInst>(Inst))
  992. ConvertDebugDeclareToDebugValue(DDI, LI, *DIB);
  993. }
  994. for (SmallVectorImpl<DbgValueInst *>::const_iterator I = DVIs.begin(),
  995. E = DVIs.end();
  996. I != E; ++I) {
  997. DbgValueInst *DVI = *I;
  998. Value *Arg = nullptr;
  999. if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
        // If an argument is zero- or sign-extended then use the argument
        // directly; the extension may be removed by a later optimization pass.
  1002. if (ZExtInst *ZExt = dyn_cast<ZExtInst>(SI->getOperand(0)))
  1003. Arg = dyn_cast<Argument>(ZExt->getOperand(0));
  1004. if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0)))
  1005. Arg = dyn_cast<Argument>(SExt->getOperand(0));
  1006. if (!Arg)
  1007. Arg = SI->getOperand(0);
  1008. } else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
  1009. Arg = LI->getOperand(0);
  1010. } else {
  1011. continue;
  1012. }
  1013. DIB->insertDbgValueIntrinsic(Arg, 0, DVI->getVariable(),
  1014. DVI->getExpression(), DVI->getDebugLoc(),
  1015. Inst);
  1016. }
  1017. }
  1018. };
  1019. } // end anon namespace
  1020. /// isSafeSelectToSpeculate - Select instructions that use an alloca and are
  1021. /// subsequently loaded can be rewritten to load both input pointers and then
  1022. /// select between the result, allowing the load of the alloca to be promoted.
  1023. /// From this:
  1024. /// %P2 = select i1 %cond, i32* %Alloca, i32* %Other
  1025. /// %V = load i32* %P2
  1026. /// to:
  1027. /// %V1 = load i32* %Alloca -> will be mem2reg'd
  1028. /// %V2 = load i32* %Other
  1029. /// %V = select i1 %cond, i32 %V1, i32 %V2
  1030. ///
  1031. /// We can do this to a select if its only uses are loads and if the operand to
  1032. /// the select can be loaded unconditionally.
  1033. static bool isSafeSelectToSpeculate(SelectInst *SI) {
  1034. const DataLayout &DL = SI->getModule()->getDataLayout();
  1035. bool TDerefable = isDereferenceablePointer(SI->getTrueValue(), DL);
  1036. bool FDerefable = isDereferenceablePointer(SI->getFalseValue(), DL);
  1037. for (User *U : SI->users()) {
  1038. LoadInst *LI = dyn_cast<LoadInst>(U);
  1039. if (!LI || !LI->isSimple())
  1040. return false;
  1041. // Both operands to the select need to be dereferencable, either absolutely
  1042. // (e.g. allocas) or at this point because we can see other accesses to it.
  1043. if (!TDerefable &&
  1044. !isSafeToLoadUnconditionally(SI->getTrueValue(), LI,
  1045. LI->getAlignment()))
  1046. return false;
  1047. if (!FDerefable &&
  1048. !isSafeToLoadUnconditionally(SI->getFalseValue(), LI,
  1049. LI->getAlignment()))
  1050. return false;
  1051. }
  1052. return true;
  1053. }
  1054. /// isSafePHIToSpeculate - PHI instructions that use an alloca and are
  1055. /// subsequently loaded can be rewritten to load both input pointers in the pred
  1056. /// blocks and then PHI the results, allowing the load of the alloca to be
  1057. /// promoted.
  1058. /// From this:
  1059. /// %P2 = phi [i32* %Alloca, i32* %Other]
  1060. /// %V = load i32* %P2
  1061. /// to:
  1062. /// %V1 = load i32* %Alloca -> will be mem2reg'd
  1063. /// ...
  1064. /// %V2 = load i32* %Other
  1065. /// ...
  1066. /// %V = phi [i32 %V1, i32 %V2]
  1067. ///
/// We can do this to a PHI if its only uses are loads and if each incoming
/// pointer can be loaded unconditionally.
  1070. static bool isSafePHIToSpeculate(PHINode *PN) {
  1071. // For now, we can only do this promotion if the load is in the same block as
  1072. // the PHI, and if there are no stores between the phi and load.
  1073. // TODO: Allow recursive phi users.
  1074. // TODO: Allow stores.
  1075. BasicBlock *BB = PN->getParent();
  1076. unsigned MaxAlign = 0;
  1077. for (User *U : PN->users()) {
  1078. LoadInst *LI = dyn_cast<LoadInst>(U);
  1079. if (!LI || !LI->isSimple())
  1080. return false;
  1081. // For now we only allow loads in the same block as the PHI. This is a
  1082. // common case that happens when instcombine merges two loads through a PHI.
  1083. if (LI->getParent() != BB)
  1084. return false;
  1085. // Ensure that there are no instructions between the PHI and the load that
  1086. // could store.
  1087. for (BasicBlock::iterator BBI = PN; &*BBI != LI; ++BBI)
  1088. if (BBI->mayWriteToMemory())
  1089. return false;
  1090. MaxAlign = std::max(MaxAlign, LI->getAlignment());
  1091. }
  1092. const DataLayout &DL = PN->getModule()->getDataLayout();
  1093. // Okay, we know that we have one or more loads in the same block as the PHI.
  1094. // We can transform this if it is safe to push the loads into the predecessor
  1095. // blocks. The only thing to watch out for is that we can't put a possibly
  1096. // trapping load in the predecessor if it is a critical edge.
  1097. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
  1098. BasicBlock *Pred = PN->getIncomingBlock(i);
  1099. Value *InVal = PN->getIncomingValue(i);
  1100. // If the terminator of the predecessor has side-effects (an invoke),
  1101. // there is no safe place to put a load in the predecessor.
  1102. if (Pred->getTerminator()->mayHaveSideEffects())
  1103. return false;
  1104. // If the value is produced by the terminator of the predecessor
  1105. // (an invoke), there is no valid place to put a load in the predecessor.
  1106. if (Pred->getTerminator() == InVal)
  1107. return false;
  1108. // If the predecessor has a single successor, then the edge isn't critical.
  1109. if (Pred->getTerminator()->getNumSuccessors() == 1)
  1110. continue;
  1111. // If this pointer is always safe to load, or if we can prove that there is
  1112. // already a load in the block, then we can move the load to the pred block.
  1113. if (isDereferenceablePointer(InVal, DL) ||
  1114. isSafeToLoadUnconditionally(InVal, Pred->getTerminator(), MaxAlign))
  1115. continue;
  1116. return false;
  1117. }
  1118. return true;
  1119. }
  1120. /// tryToMakeAllocaBePromotable - This returns true if the alloca only has
  1121. /// direct (non-volatile) loads and stores to it. If the alloca is close but
  1122. /// not quite there, this will transform the code to allow promotion. As such,
  1123. /// it is a non-pure predicate.
  1124. static bool tryToMakeAllocaBePromotable(AllocaInst *AI, const DataLayout &DL) {
  1125. SetVector<Instruction *, SmallVector<Instruction *, 4>,
  1126. SmallPtrSet<Instruction *, 4>>
  1127. InstsToRewrite;
  1128. for (User *U : AI->users()) {
  1129. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  1130. if (!LI->isSimple())
  1131. return false;
  1132. continue;
  1133. }
  1134. if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
  1135. if (SI->getOperand(0) == AI || !SI->isSimple())
  1136. return false; // Don't allow a store OF the AI, only INTO the AI.
  1137. continue;
  1138. }
  1139. if (SelectInst *SI = dyn_cast<SelectInst>(U)) {
  1140. // If the condition being selected on is a constant, fold the select, yes
  1141. // this does (rarely) happen early on.
  1142. if (ConstantInt *CI = dyn_cast<ConstantInt>(SI->getCondition())) {
  1143. Value *Result = SI->getOperand(1 + CI->isZero());
  1144. SI->replaceAllUsesWith(Result);
  1145. SI->eraseFromParent();
  1146. // This is very rare and we just scrambled the use list of AI, start
  1147. // over completely.
  1148. return tryToMakeAllocaBePromotable(AI, DL);
  1149. }
  1150. // If it is safe to turn "load (select c, AI, ptr)" into a select of two
  1151. // loads, then we can transform this by rewriting the select.
  1152. if (!isSafeSelectToSpeculate(SI))
  1153. return false;
  1154. InstsToRewrite.insert(SI);
  1155. continue;
  1156. }
  1157. if (PHINode *PN = dyn_cast<PHINode>(U)) {
  1158. if (PN->use_empty()) { // Dead PHIs can be stripped.
  1159. InstsToRewrite.insert(PN);
  1160. continue;
  1161. }
  1162. // If it is safe to turn "load (phi [AI, ptr, ...])" into a PHI of loads
  1163. // in the pred blocks, then we can transform this by rewriting the PHI.
  1164. if (!isSafePHIToSpeculate(PN))
  1165. return false;
  1166. InstsToRewrite.insert(PN);
  1167. continue;
  1168. }
  1169. if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
  1170. if (onlyUsedByLifetimeMarkers(BCI)) {
  1171. InstsToRewrite.insert(BCI);
  1172. continue;
  1173. }
  1174. }
  1175. return false;
  1176. }
  1177. // If there are no instructions to rewrite, then all uses are load/stores and
  1178. // we're done!
  1179. if (InstsToRewrite.empty())
  1180. return true;
  1181. // If we have instructions that need to be rewritten for this to be promotable
  1182. // take care of it now.
  1183. for (unsigned i = 0, e = InstsToRewrite.size(); i != e; ++i) {
  1184. if (BitCastInst *BCI = dyn_cast<BitCastInst>(InstsToRewrite[i])) {
  1185. // This could only be a bitcast used by nothing but lifetime intrinsics.
  1186. for (BitCastInst::user_iterator I = BCI->user_begin(),
  1187. E = BCI->user_end();
  1188. I != E;)
  1189. cast<Instruction>(*I++)->eraseFromParent();
  1190. BCI->eraseFromParent();
  1191. continue;
  1192. }
  1193. if (SelectInst *SI = dyn_cast<SelectInst>(InstsToRewrite[i])) {
  1194. // Selects in InstsToRewrite only have load uses. Rewrite each as two
  1195. // loads with a new select.
  1196. while (!SI->use_empty()) {
  1197. LoadInst *LI = cast<LoadInst>(SI->user_back());
  1198. IRBuilder<> Builder(LI);
  1199. LoadInst *TrueLoad =
  1200. Builder.CreateLoad(SI->getTrueValue(), LI->getName() + ".t");
  1201. LoadInst *FalseLoad =
  1202. Builder.CreateLoad(SI->getFalseValue(), LI->getName() + ".f");
  1203. // Transfer alignment and AA info if present.
  1204. TrueLoad->setAlignment(LI->getAlignment());
  1205. FalseLoad->setAlignment(LI->getAlignment());
  1206. AAMDNodes Tags;
  1207. LI->getAAMetadata(Tags);
  1208. if (Tags) {
  1209. TrueLoad->setAAMetadata(Tags);
  1210. FalseLoad->setAAMetadata(Tags);
  1211. }
  1212. Value *V =
  1213. Builder.CreateSelect(SI->getCondition(), TrueLoad, FalseLoad);
  1214. V->takeName(LI);
  1215. LI->replaceAllUsesWith(V);
  1216. LI->eraseFromParent();
  1217. }
  1218. // Now that all the loads are gone, the select is gone too.
  1219. SI->eraseFromParent();
  1220. continue;
  1221. }
  1222. // Otherwise, we have a PHI node which allows us to push the loads into the
  1223. // predecessors.
  1224. PHINode *PN = cast<PHINode>(InstsToRewrite[i]);
  1225. if (PN->use_empty()) {
  1226. PN->eraseFromParent();
  1227. continue;
  1228. }
  1229. Type *LoadTy = cast<PointerType>(PN->getType())->getElementType();
  1230. PHINode *NewPN = PHINode::Create(LoadTy, PN->getNumIncomingValues(),
  1231. PN->getName() + ".ld", PN);
  1232. // Get the AA tags and alignment to use from one of the loads. It doesn't
  1233. // matter which one we get and if any differ, it doesn't matter.
  1234. LoadInst *SomeLoad = cast<LoadInst>(PN->user_back());
  1235. AAMDNodes AATags;
  1236. SomeLoad->getAAMetadata(AATags);
  1237. unsigned Align = SomeLoad->getAlignment();
  1238. // Rewrite all loads of the PN to use the new PHI.
  1239. while (!PN->use_empty()) {
  1240. LoadInst *LI = cast<LoadInst>(PN->user_back());
  1241. LI->replaceAllUsesWith(NewPN);
  1242. LI->eraseFromParent();
  1243. }
  1244. // Inject loads into all of the pred blocks. Keep track of which blocks we
  1245. // insert them into in case we have multiple edges from the same block.
  1246. DenseMap<BasicBlock *, LoadInst *> InsertedLoads;
  1247. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
  1248. BasicBlock *Pred = PN->getIncomingBlock(i);
  1249. LoadInst *&Load = InsertedLoads[Pred];
  1250. if (!Load) {
  1251. Load = new LoadInst(PN->getIncomingValue(i),
  1252. PN->getName() + "." + Pred->getName(),
  1253. Pred->getTerminator());
  1254. Load->setAlignment(Align);
  1255. if (AATags)
  1256. Load->setAAMetadata(AATags);
  1257. }
  1258. NewPN->addIncoming(Load, Pred);
  1259. }
  1260. PN->eraseFromParent();
  1261. }
  1262. ++NumAdjusted;
  1263. return true;
  1264. }
  1265. bool SROA_HLSL::performPromotion(Function &F) {
  1266. std::vector<AllocaInst *> Allocas;
  1267. const DataLayout &DL = F.getParent()->getDataLayout();
  1268. DominatorTree *DT = nullptr;
  1269. if (HasDomTree)
  1270. DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  1271. AssumptionCache &AC =
  1272. getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  1273. BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function
  1274. DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
  1275. bool Changed = false;
  1276. SmallVector<Instruction *, 64> Insts;
  1277. while (1) {
  1278. Allocas.clear();
  1279. // Find allocas that are safe to promote, by looking at all instructions in
  1280. // the entry node
  1281. for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
  1282. if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
  1283. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(AI);
        // Skip allocas that carry debug info when promotion is not being run.
  1285. if (DDI && !RunPromotion) {
  1286. continue;
  1287. }
  1288. if (tryToMakeAllocaBePromotable(AI, DL))
  1289. Allocas.push_back(AI);
  1290. }
  1291. if (Allocas.empty())
  1292. break;
  1293. if (HasDomTree)
  1294. PromoteMemToReg(Allocas, *DT, nullptr, &AC);
  1295. else {
  1296. SSAUpdater SSA;
  1297. for (unsigned i = 0, e = Allocas.size(); i != e; ++i) {
  1298. AllocaInst *AI = Allocas[i];
  1299. // Build list of instructions to promote.
  1300. for (User *U : AI->users())
  1301. Insts.push_back(cast<Instruction>(U));
  1302. AllocaPromoter(Insts, SSA, &DIB).run(AI, Insts);
  1303. Insts.clear();
  1304. }
  1305. }
  1306. NumPromoted += Allocas.size();
  1307. Changed = true;
  1308. }
  1309. return Changed;
  1310. }
/// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for
/// SROA. For HLSL, every struct or array alloca is a candidate.
bool SROA_HLSL::ShouldAttemptScalarRepl(AllocaInst *AI) {
  Type *T = AI->getAllocatedType();
  // Promote every struct.
  if (isa<StructType>(T))
    return true;
  // Promote every array.
  if (isa<ArrayType>(T))
    return true;
  return false;
}
  1323. // performScalarRepl - This algorithm is a simple worklist driven algorithm,
  1324. // which runs on all of the alloca instructions in the entry block, removing
  1325. // them if they are only used by getelementptr instructions.
  1326. //
  1327. bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
  1328. std::vector<AllocaInst *> AllocaList;
  1329. const DataLayout &DL = F.getParent()->getDataLayout();
  1330. // Scan the entry basic block, adding allocas to the worklist.
  1331. BasicBlock &BB = F.getEntryBlock();
  1332. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
  1333. if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
  1334. if (A->hasNUsesOrMore(1))
  1335. AllocaList.emplace_back(A);
  1336. }
  // Merge GEP uses for the allocas.
  1338. for (auto A : AllocaList)
  1339. HLModule::MergeGepUse(A);
  // Make sure the biggest allocas are split first. This simplifies memcpy
  // checks between part of a big alloca and a small alloca: the big alloca is
  // split into smaller pieces first, so by the time a small alloca is
  // processed, its counterpart is an alloca flattened from the big alloca
  // rather than a GEP into it.
  1345. auto size_cmp = [&DL](const AllocaInst *a0, const AllocaInst *a1) -> bool {
  1346. return DL.getTypeAllocSize(a0->getAllocatedType()) >
  1347. DL.getTypeAllocSize(a1->getAllocatedType());
  1348. };
  1349. std::sort(AllocaList.begin(), AllocaList.end(), size_cmp);
  1350. DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
  1351. // Process the worklist
  1352. bool Changed = false;
  1353. for (AllocaInst *Alloc : AllocaList) {
  1354. DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(Alloc);
  1355. unsigned debugOffset = 0;
  1356. std::deque<AllocaInst *> WorkList;
  1357. WorkList.emplace_back(Alloc);
  1358. while (!WorkList.empty()) {
  1359. AllocaInst *AI = WorkList.front();
  1360. WorkList.pop_front();
  1361. // Handle dead allocas trivially. These can be formed by SROA'ing arrays
  1362. // with unused elements.
  1363. if (AI->use_empty()) {
  1364. AI->eraseFromParent();
  1365. Changed = true;
  1366. continue;
  1367. }
  1368. const bool bAllowReplace = true;
  1369. if (SROA_Helper::LowerMemcpy(AI, /*annotation*/ nullptr, typeSys, DL,
  1370. bAllowReplace)) {
  1371. Changed = true;
  1372. continue;
  1373. }
  1374. // If this alloca is impossible for us to promote, reject it early.
  1375. if (AI->isArrayAllocation() || !AI->getAllocatedType()->isSized())
  1376. continue;
  1377. // Check to see if we can perform the core SROA transformation. We cannot
  1378. // transform the allocation instruction if it is an array allocation
  1379. // (allocations OF arrays are ok though), and an allocation of a scalar
  1380. // value cannot be decomposed at all.
  1381. uint64_t AllocaSize = DL.getTypeAllocSize(AI->getAllocatedType());
  1382. // Do not promote [0 x %struct].
  1383. if (AllocaSize == 0)
  1384. continue;
  1385. Type *Ty = AI->getAllocatedType();
  1386. // Skip empty struct type.
  1387. if (SROA_Helper::IsEmptyStructType(Ty, typeSys)) {
  1388. SROA_Helper::MarkEmptyStructUsers(AI, DeadInsts);
  1389. DeleteDeadInstructions();
  1390. continue;
  1391. }
      // If the alloca looks like a good candidate for scalar replacement, and
      // if all its users can be transformed, then split up the aggregate into
      // its separate elements.
  1396. if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
  1397. std::vector<Value *> Elts;
  1398. IRBuilder<> Builder(AI);
  1399. bool hasPrecise = HLModule::HasPreciseAttributeWithMetadata(AI);
  1400. bool SROAed = SROA_Helper::DoScalarReplacement(
  1401. AI, Elts, Builder, /*bFlatVector*/ true, hasPrecise, typeSys,
  1402. DeadInsts);
  1403. if (SROAed) {
  1404. Type *Ty = AI->getAllocatedType();
  1405. // Skip empty struct parameters.
  1406. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  1407. if (!HLMatrixLower::IsMatrixType(Ty)) {
  1408. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  1409. if (SA && SA->IsEmptyStruct()) {
  1410. for (User *U : AI->users()) {
  1411. if (StoreInst *SI = dyn_cast<StoreInst>(U))
  1412. DeadInsts.emplace_back(SI);
  1413. }
  1414. DeleteDeadInstructions();
  1415. AI->replaceAllUsesWith(UndefValue::get(AI->getType()));
  1416. AI->eraseFromParent();
  1417. continue;
  1418. }
  1419. }
  1420. }
  1421. // Push Elts into workList.
  1422. for (auto iter = Elts.begin(); iter != Elts.end(); iter++)
  1423. WorkList.emplace_back(cast<AllocaInst>(*iter));
  1424. // Now erase any instructions that were made dead while rewriting the
  1425. // alloca.
  1426. DeleteDeadInstructions();
  1427. ++NumReplaced;
  1428. AI->eraseFromParent();
  1429. Changed = true;
  1430. continue;
  1431. }
  1432. }
  1433. // Add debug info.
  1434. if (DDI != nullptr && AI != Alloc) {
  1435. Type *Ty = AI->getAllocatedType();
  1436. unsigned size = DL.getTypeAllocSize(Ty);
  1437. DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
  1438. debugOffset += size;
  1439. DIB.insertDeclare(AI, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
  1440. DDI);
  1441. }
  1442. }
  1443. }
  1444. return Changed;
  1445. }
// markPrecise - The precise attribute on an alloca may be lost when the alloca
// is promoted, so preserve it by attaching a precise-marking function call to
// the alloca pointer.
  1448. bool SROA_HLSL::markPrecise(Function &F) {
  1449. bool Changed = false;
  1450. BasicBlock &BB = F.getEntryBlock();
  1451. for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
  1452. if (AllocaInst *A = dyn_cast<AllocaInst>(I)) {
  1453. // TODO: Only do this on basic types.
  1454. if (HLModule::HasPreciseAttributeWithMetadata(A)) {
  1455. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(A,
  1456. *(F.getParent()));
  1457. Changed = true;
  1458. }
  1459. }
  1460. return Changed;
  1461. }
  1462. /// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,
  1463. /// recursively including all their operands that become trivially dead.
  1464. void SROA_HLSL::DeleteDeadInstructions() {
  1465. while (!DeadInsts.empty()) {
  1466. Instruction *I = cast<Instruction>(DeadInsts.pop_back_val());
  1467. for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI)
  1468. if (Instruction *U = dyn_cast<Instruction>(*OI)) {
  1469. // Zero out the operand and see if it becomes trivially dead.
  1470. // (But, don't add allocas to the dead instruction list -- they are
  1471. // already on the worklist and will be deleted separately.)
  1472. *OI = nullptr;
  1473. if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U))
  1474. DeadInsts.push_back(U);
  1475. }
  1476. I->eraseFromParent();
  1477. }
  1478. }
  1479. /// isSafeForScalarRepl - Check if instruction I is a safe use with regard to
  1480. /// performing scalar replacement of alloca AI. The results are flagged in
  1481. /// the Info parameter. Offset indicates the position within AI that is
  1482. /// referenced by this instruction.
  1483. void SROA_HLSL::isSafeForScalarRepl(Instruction *I, uint64_t Offset,
  1484. AllocaInfo &Info) {
  1485. if (I->getType()->isPointerTy()) {
  1486. // Don't check object pointers.
  1487. if (HLModule::IsHLSLObjectType(I->getType()->getPointerElementType()))
  1488. return;
  1489. }
  1490. const DataLayout &DL = I->getModule()->getDataLayout();
  1491. for (Use &U : I->uses()) {
  1492. Instruction *User = cast<Instruction>(U.getUser());
  1493. if (BitCastInst *BC = dyn_cast<BitCastInst>(User)) {
  1494. isSafeForScalarRepl(BC, Offset, Info);
  1495. } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(User)) {
  1496. uint64_t GEPOffset = Offset;
  1497. isSafeGEP(GEPI, GEPOffset, Info);
  1498. if (!Info.isUnsafe)
  1499. isSafeForScalarRepl(GEPI, GEPOffset, Info);
  1500. } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
  1501. ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength());
  1502. if (!Length || Length->isNegative())
  1503. return MarkUnsafe(Info, User);
  1504. isSafeMemAccess(Offset, Length->getZExtValue(), nullptr,
  1505. U.getOperandNo() == 0, Info, MI,
  1506. true /*AllowWholeAccess*/);
  1507. } else if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
  1508. if (!LI->isSimple())
  1509. return MarkUnsafe(Info, User);
  1510. Type *LIType = LI->getType();
  1511. isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info,
  1512. LI, true /*AllowWholeAccess*/);
  1513. Info.hasALoadOrStore = true;
  1514. } else if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
  1515. // Store is ok if storing INTO the pointer, not storing the pointer
  1516. if (!SI->isSimple() || SI->getOperand(0) == I)
  1517. return MarkUnsafe(Info, User);
  1518. Type *SIType = SI->getOperand(0)->getType();
  1519. isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info,
  1520. SI, true /*AllowWholeAccess*/);
  1521. Info.hasALoadOrStore = true;
  1522. } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(User)) {
  1523. if (II->getIntrinsicID() != Intrinsic::lifetime_start &&
  1524. II->getIntrinsicID() != Intrinsic::lifetime_end)
  1525. return MarkUnsafe(Info, User);
  1526. } else if (isa<PHINode>(User) || isa<SelectInst>(User)) {
  1527. isSafePHISelectUseForScalarRepl(User, Offset, Info);
  1528. } else if (CallInst *CI = dyn_cast<CallInst>(User)) {
  1529. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1530. // HL functions are safe for scalar repl.
  1531. if (group == HLOpcodeGroup::NotHL)
  1532. return MarkUnsafe(Info, User);
  1533. } else {
  1534. return MarkUnsafe(Info, User);
  1535. }
  1536. if (Info.isUnsafe)
  1537. return;
  1538. }
  1539. }
  1540. /// isSafePHIUseForScalarRepl - If we see a PHI node or select using a pointer
  1541. /// derived from the alloca, we can often still split the alloca into elements.
  1542. /// This is useful if we have a large alloca where one element is phi'd
  1543. /// together somewhere: we can SRoA and promote all the other elements even if
  1544. /// we end up not being able to promote this one.
  1545. ///
  1546. /// All we require is that the uses of the PHI do not index into other parts of
/// the alloca. The most important use case for this is single loads and stores
/// that are PHI'd together, which can happen due to code sinking.
  1549. void SROA_HLSL::isSafePHISelectUseForScalarRepl(Instruction *I, uint64_t Offset,
  1550. AllocaInfo &Info) {
  1551. // If we've already checked this PHI, don't do it again.
  1552. if (PHINode *PN = dyn_cast<PHINode>(I))
  1553. if (!Info.CheckedPHIs.insert(PN).second)
  1554. return;
  1555. const DataLayout &DL = I->getModule()->getDataLayout();
  1556. for (User *U : I->users()) {
  1557. Instruction *UI = cast<Instruction>(U);
  1558. if (BitCastInst *BC = dyn_cast<BitCastInst>(UI)) {
  1559. isSafePHISelectUseForScalarRepl(BC, Offset, Info);
  1560. } else if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(UI)) {
  1561. // Only allow "bitcast" GEPs for simplicity. We could generalize this,
  1562. // but would have to prove that we're staying inside of an element being
  1563. // promoted.
  1564. if (!GEPI->hasAllZeroIndices())
  1565. return MarkUnsafe(Info, UI);
  1566. isSafePHISelectUseForScalarRepl(GEPI, Offset, Info);
  1567. } else if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
  1568. if (!LI->isSimple())
  1569. return MarkUnsafe(Info, UI);
  1570. Type *LIType = LI->getType();
  1571. isSafeMemAccess(Offset, DL.getTypeAllocSize(LIType), LIType, false, Info,
  1572. LI, false /*AllowWholeAccess*/);
  1573. Info.hasALoadOrStore = true;
  1574. } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
  1575. // Store is ok if storing INTO the pointer, not storing the pointer
  1576. if (!SI->isSimple() || SI->getOperand(0) == I)
  1577. return MarkUnsafe(Info, UI);
  1578. Type *SIType = SI->getOperand(0)->getType();
  1579. isSafeMemAccess(Offset, DL.getTypeAllocSize(SIType), SIType, true, Info,
  1580. SI, false /*AllowWholeAccess*/);
  1581. Info.hasALoadOrStore = true;
  1582. } else if (isa<PHINode>(UI) || isa<SelectInst>(UI)) {
  1583. isSafePHISelectUseForScalarRepl(UI, Offset, Info);
  1584. } else {
  1585. return MarkUnsafe(Info, UI);
  1586. }
  1587. if (Info.isUnsafe)
  1588. return;
  1589. }
  1590. }
  1591. /// isSafeGEP - Check if a GEP instruction can be handled for scalar
  1592. /// replacement. It is safe when all the indices are constant, in-bounds
  1593. /// references, and when the resulting offset corresponds to an element within
  1594. /// the alloca type. The results are flagged in the Info parameter. Upon
  1595. /// return, Offset is adjusted as specified by the GEP indices.
  1596. void SROA_HLSL::isSafeGEP(GetElementPtrInst *GEPI, uint64_t &Offset,
  1597. AllocaInfo &Info) {
  1598. gep_type_iterator GEPIt = gep_type_begin(GEPI), E = gep_type_end(GEPI);
  1599. if (GEPIt == E)
  1600. return;
  1601. bool NonConstant = false;
  1602. unsigned NonConstantIdxSize = 0;
  1603. // Compute the offset due to this GEP and check if the alloca has a
  1604. // component element at that offset.
  1605. SmallVector<Value *, 8> Indices(GEPI->op_begin() + 1, GEPI->op_end());
  1606. auto indicesIt = Indices.begin();
  1607. // Walk through the GEP type indices, checking the types that this indexes
  1608. // into.
  1609. uint32_t arraySize = 0;
  1610. bool isArrayIndexing = false;
  1611. for (;GEPIt != E; ++GEPIt) {
  1612. Type *Ty = *GEPIt;
  1613. if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
      // Don't descend into structs when marking hasArrayIndexing and
      // hasVectorIndexing; the levels below won't affect scalar repl on the
      // struct.
  1616. break;
  1617. }
  1618. if (GEPIt->isArrayTy()) {
  1619. arraySize = GEPIt->getArrayNumElements();
  1620. isArrayIndexing = true;
  1621. }
  1622. if (GEPIt->isVectorTy()) {
  1623. arraySize = GEPIt->getVectorNumElements();
  1624. isArrayIndexing = false;
  1625. }
  1626. // Allow dynamic indexing
  1627. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1628. if (!IdxVal) {
  1629. // for dynamic index, use array size - 1 to check the offset
  1630. *indicesIt = Constant::getIntegerValue(
  1631. Type::getInt32Ty(GEPI->getContext()), APInt(32, arraySize - 1));
  1632. if (isArrayIndexing)
  1633. Info.hasArrayIndexing = true;
  1634. else
  1635. Info.hasVectorIndexing = true;
  1636. NonConstant = true;
  1637. }
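    // For example, a dynamic index into [8 x float] is conservatively checked
    // using the constant index 7 (the last valid element).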
  1638. indicesIt++;
  1639. }
  // Continue iterating only to finish handling the non-constant index case.
  1641. for (;GEPIt != E; ++GEPIt) {
  1642. Type *Ty = *GEPIt;
  1643. if (Ty->isArrayTy()) {
  1644. arraySize = GEPIt->getArrayNumElements();
  1645. }
  1646. if (Ty->isVectorTy()) {
  1647. arraySize = GEPIt->getVectorNumElements();
  1648. }
  1649. // Allow dynamic indexing
  1650. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  1651. if (!IdxVal) {
  1652. // for dynamic index, use array size - 1 to check the offset
  1653. *indicesIt = Constant::getIntegerValue(
  1654. Type::getInt32Ty(GEPI->getContext()), APInt(32, arraySize - 1));
  1655. NonConstant = true;
  1656. }
  1657. indicesIt++;
  1658. }
  1659. // If this GEP is non-constant then the last operand must have been a
  1660. // dynamic index into a vector. Pop this now as it has no impact on the
  1661. // constant part of the offset.
  1662. if (NonConstant)
  1663. Indices.pop_back();
  1664. const DataLayout &DL = GEPI->getModule()->getDataLayout();
  1665. Offset += DL.getIndexedOffset(GEPI->getPointerOperandType(), Indices);
  1666. if (!TypeHasComponent(Info.AI->getAllocatedType(), Offset, NonConstantIdxSize,
  1667. DL))
  1668. MarkUnsafe(Info, GEPI);
  1669. }
  1670. /// isHomogeneousAggregate - Check if type T is a struct or array containing
  1671. /// elements of the same type (which is always true for arrays). If so,
  1672. /// return true with NumElts and EltTy set to the number of elements and the
  1673. /// element type, respectively.
  1674. static bool isHomogeneousAggregate(Type *T, unsigned &NumElts, Type *&EltTy) {
  1675. if (ArrayType *AT = dyn_cast<ArrayType>(T)) {
  1676. NumElts = AT->getNumElements();
  1677. EltTy = (NumElts == 0 ? nullptr : AT->getElementType());
  1678. return true;
  1679. }
  1680. if (StructType *ST = dyn_cast<StructType>(T)) {
  1681. NumElts = ST->getNumContainedTypes();
  1682. EltTy = (NumElts == 0 ? nullptr : ST->getContainedType(0));
  1683. for (unsigned n = 1; n < NumElts; ++n) {
  1684. if (ST->getContainedType(n) != EltTy)
  1685. return false;
  1686. }
  1687. return true;
  1688. }
  1689. return false;
  1690. }
  1691. /// isCompatibleAggregate - Check if T1 and T2 are either the same type or are
  1692. /// "homogeneous" aggregates with the same element type and number of elements.
  1693. static bool isCompatibleAggregate(Type *T1, Type *T2) {
  1694. if (T1 == T2)
  1695. return true;
  1696. unsigned NumElts1, NumElts2;
  1697. Type *EltTy1, *EltTy2;
  1698. if (isHomogeneousAggregate(T1, NumElts1, EltTy1) &&
  1699. isHomogeneousAggregate(T2, NumElts2, EltTy2) && NumElts1 == NumElts2 &&
  1700. EltTy1 == EltTy2)
  1701. return true;
  1702. return false;
  1703. }
  1704. /// isSafeMemAccess - Check if a load/store/memcpy operates on the entire AI
  1705. /// alloca or has an offset and size that corresponds to a component element
  1706. /// within it. The offset checked here may have been formed from a GEP with a
  1707. /// pointer bitcasted to a different type.
  1708. ///
  1709. /// If AllowWholeAccess is true, then this allows uses of the entire alloca as a
  1710. /// unit. If false, it only allows accesses known to be in a single element.
  1711. void SROA_HLSL::isSafeMemAccess(uint64_t Offset, uint64_t MemSize,
  1712. Type *MemOpType, bool isStore, AllocaInfo &Info,
  1713. Instruction *TheAccess, bool AllowWholeAccess) {
  1714. // What hlsl cares is Info.hasVectorIndexing.
  1715. // Do nothing here.
  1716. }
  1717. /// TypeHasComponent - Return true if T has a component type with the
  1718. /// specified offset and size. If Size is zero, do not check the size.
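/// For example, for T = { i32, [4 x float] }, an access at Offset 8 with
/// Size 4 resolves to the second float of the array, so this returns true.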
  1719. bool SROA_HLSL::TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
  1720. const DataLayout &DL) {
  1721. Type *EltTy;
  1722. uint64_t EltSize;
  1723. if (StructType *ST = dyn_cast<StructType>(T)) {
  1724. const StructLayout *Layout = DL.getStructLayout(ST);
  1725. unsigned EltIdx = Layout->getElementContainingOffset(Offset);
  1726. EltTy = ST->getContainedType(EltIdx);
  1727. EltSize = DL.getTypeAllocSize(EltTy);
  1728. Offset -= Layout->getElementOffset(EltIdx);
  1729. } else if (ArrayType *AT = dyn_cast<ArrayType>(T)) {
  1730. EltTy = AT->getElementType();
  1731. EltSize = DL.getTypeAllocSize(EltTy);
  1732. if (Offset >= AT->getNumElements() * EltSize)
  1733. return false;
  1734. Offset %= EltSize;
  1735. } else if (VectorType *VT = dyn_cast<VectorType>(T)) {
  1736. EltTy = VT->getElementType();
  1737. EltSize = DL.getTypeAllocSize(EltTy);
  1738. if (Offset >= VT->getNumElements() * EltSize)
  1739. return false;
  1740. Offset %= EltSize;
  1741. } else {
  1742. return false;
  1743. }
  1744. if (Offset == 0 && (Size == 0 || EltSize == Size))
  1745. return true;
  1746. // Check if the component spans multiple elements.
  1747. if (Offset + Size > EltSize)
  1748. return false;
  1749. return TypeHasComponent(EltTy, Offset, Size, DL);
  1750. }
/// LoadVectorOrStructArray - Load a vector array like [2 x <4 x float>] from
/// four scalar arrays like [2 x float], or a struct array like
/// [2 x { <4 x float>, <4 x uint> }] from per-field arrays like
/// [2 x <4 x float>] and [2 x <4 x uint>].
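/// For example (schematically), with AT = [2 x <4 x float>] and NewElts
/// holding four pointers to [2 x float], the result is built as
///   result[i][c] = load (gep NewElts[c], ..., i)   for i = 0..1, c = 0..3.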
  1755. static Value *LoadVectorOrStructArray(ArrayType *AT, ArrayRef<Value *> NewElts,
  1756. SmallVector<Value *, 8> &idxList,
  1757. IRBuilder<> &Builder) {
  1758. Type *EltTy = AT->getElementType();
  1759. Value *retVal = llvm::UndefValue::get(AT);
  1760. Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
  1761. uint32_t arraySize = AT->getNumElements();
  1762. for (uint32_t i = 0; i < arraySize; i++) {
  1763. Constant *idx = ConstantInt::get(i32Ty, i);
  1764. idxList.emplace_back(idx);
  1765. if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
  1766. Value *EltVal = LoadVectorOrStructArray(EltAT, NewElts, idxList, Builder);
  1767. retVal = Builder.CreateInsertValue(retVal, EltVal, i);
  1768. } else {
      assert((EltTy->isVectorTy() || EltTy->isStructTy()) &&
             "must be a vector or struct type");
  1771. bool isVectorTy = EltTy->isVectorTy();
  1772. Value *retVec = llvm::UndefValue::get(EltTy);
  1773. if (isVectorTy) {
  1774. for (uint32_t c = 0; c < EltTy->getVectorNumElements(); c++) {
  1775. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1776. Value *elt = Builder.CreateLoad(GEP);
  1777. retVec = Builder.CreateInsertElement(retVec, elt, c);
  1778. }
  1779. } else {
  1780. for (uint32_t c = 0; c < EltTy->getStructNumElements(); c++) {
  1781. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1782. Value *elt = Builder.CreateLoad(GEP);
  1783. retVec = Builder.CreateInsertValue(retVec, elt, c);
  1784. }
  1785. }
  1786. retVal = Builder.CreateInsertValue(retVal, retVec, i);
  1787. }
  1788. idxList.pop_back();
  1789. }
  1790. return retVal;
  1791. }
/// StoreVectorOrStructArray - Store a vector array like [2 x <4 x float>] to
/// four scalar arrays like [2 x float], or a struct array like
/// [2 x { <4 x float>, <4 x uint> }] to per-field arrays like
/// [2 x <4 x float>] and [2 x <4 x uint>].
  1796. static void StoreVectorOrStructArray(ArrayType *AT, Value *val,
  1797. ArrayRef<Value *> NewElts,
  1798. SmallVector<Value *, 8> &idxList,
  1799. IRBuilder<> &Builder) {
  1800. Type *EltTy = AT->getElementType();
  1801. Type *i32Ty = Type::getInt32Ty(EltTy->getContext());
  1802. uint32_t arraySize = AT->getNumElements();
  1803. for (uint32_t i = 0; i < arraySize; i++) {
  1804. Value *elt = Builder.CreateExtractValue(val, i);
  1805. Constant *idx = ConstantInt::get(i32Ty, i);
  1806. idxList.emplace_back(idx);
  1807. if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
  1808. StoreVectorOrStructArray(EltAT, elt, NewElts, idxList, Builder);
  1809. } else {
      assert((EltTy->isVectorTy() || EltTy->isStructTy()) &&
             "must be a vector or struct type");
  1812. bool isVectorTy = EltTy->isVectorTy();
  1813. if (isVectorTy) {
  1814. for (uint32_t c = 0; c < EltTy->getVectorNumElements(); c++) {
  1815. Value *component = Builder.CreateExtractElement(elt, c);
  1816. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1817. Builder.CreateStore(component, GEP);
  1818. }
  1819. } else {
  1820. for (uint32_t c = 0; c < EltTy->getStructNumElements(); c++) {
  1821. Value *field = Builder.CreateExtractValue(elt, c);
  1822. Value *GEP = Builder.CreateInBoundsGEP(NewElts[c], idxList);
  1823. Builder.CreateStore(field, GEP);
  1824. }
  1825. }
  1826. }
  1827. idxList.pop_back();
  1828. }
  1829. }
  1830. /// HasPadding - Return true if the specified type has any structure or
  1831. /// alignment padding in between the elements that would be split apart
  1832. /// by SROA; return false otherwise.
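/// For example, { i8, i32 } has 24 bits of padding between its fields (the
/// i32 is aligned to offset 4), so this returns true.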
  1833. static bool HasPadding(Type *Ty, const DataLayout &DL) {
  1834. if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
  1835. Ty = ATy->getElementType();
  1836. return DL.getTypeSizeInBits(Ty) != DL.getTypeAllocSizeInBits(Ty);
  1837. }
  1838. // SROA currently handles only Arrays and Structs.
  1839. StructType *STy = cast<StructType>(Ty);
  1840. const StructLayout *SL = DL.getStructLayout(STy);
  1841. unsigned PrevFieldBitOffset = 0;
  1842. for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
  1843. unsigned FieldBitOffset = SL->getElementOffsetInBits(i);
  1844. // Check to see if there is any padding between this element and the
  1845. // previous one.
  1846. if (i) {
  1847. unsigned PrevFieldEnd =
  1848. PrevFieldBitOffset + DL.getTypeSizeInBits(STy->getElementType(i - 1));
  1849. if (PrevFieldEnd < FieldBitOffset)
  1850. return true;
  1851. }
  1852. PrevFieldBitOffset = FieldBitOffset;
  1853. }
  1854. // Check for tail padding.
  1855. if (unsigned EltCount = STy->getNumElements()) {
  1856. unsigned PrevFieldEnd =
  1857. PrevFieldBitOffset +
  1858. DL.getTypeSizeInBits(STy->getElementType(EltCount - 1));
  1859. if (PrevFieldEnd < SL->getSizeInBits())
  1860. return true;
  1861. }
  1862. return false;
  1863. }
/// isSafeAllocaToScalarRepl - Check whether the specified allocation of an
/// aggregate can be broken down into elements. Returns true if it is safe to
/// do so, false otherwise.
  1867. bool SROA_HLSL::isSafeAllocaToScalarRepl(AllocaInst *AI) {
  1868. // Loop over the use list of the alloca. We can only transform it if all of
  1869. // the users are safe to transform.
  1870. AllocaInfo Info(AI);
  1871. isSafeForScalarRepl(AI, 0, Info);
  1872. if (Info.isUnsafe) {
  1873. DEBUG(dbgs() << "Cannot transform: " << *AI << '\n');
  1874. return false;
  1875. }
  // Vector indexing requires translating the vector into an array first.
  1877. if (Info.hasVectorIndexing)
  1878. return false;
  1879. const DataLayout &DL = AI->getModule()->getDataLayout();
  1880. // Okay, we know all the users are promotable. If the aggregate is a memcpy
  1881. // source and destination, we have to be careful. In particular, the memcpy
  1882. // could be moving around elements that live in structure padding of the LLVM
  1883. // types, but may actually be used. In these cases, we refuse to promote the
  1884. // struct.
  1885. if (Info.isMemCpySrc && Info.isMemCpyDst &&
  1886. HasPadding(AI->getAllocatedType(), DL))
  1887. return false;
  1888. return true;
  1889. }
  1890. // Copy data from srcPtr to destPtr.
  1891. static void SimplePtrCopy(Value *DestPtr, Value *SrcPtr,
  1892. llvm::SmallVector<llvm::Value *, 16> &idxList,
  1893. IRBuilder<> &Builder) {
  1894. if (idxList.size() > 1) {
  1895. DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
  1896. SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
  1897. }
  1898. llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
  1899. Builder.CreateStore(ld, DestPtr);
  1900. }
  1901. // Copy srcVal to destPtr.
  1902. static void SimpleValCopy(Value *DestPtr, Value *SrcVal,
  1903. llvm::SmallVector<llvm::Value *, 16> &idxList,
  1904. IRBuilder<> &Builder) {
  1905. Value *DestGEP = Builder.CreateInBoundsGEP(DestPtr, idxList);
  1906. Value *Val = SrcVal;
  1907. // Skip beginning pointer type.
  1908. for (unsigned i = 1; i < idxList.size(); i++) {
  1909. ConstantInt *idx = cast<ConstantInt>(idxList[i]);
  1910. Type *Ty = Val->getType();
  1911. if (Ty->isAggregateType()) {
  1912. Val = Builder.CreateExtractValue(Val, idx->getLimitedValue());
  1913. }
  1914. }
  1915. Builder.CreateStore(Val, DestGEP);
  1916. }
  1917. static void SimpleCopy(Value *Dest, Value *Src,
  1918. llvm::SmallVector<llvm::Value *, 16> &idxList,
  1919. IRBuilder<> &Builder) {
  1920. if (Src->getType()->isPointerTy())
  1921. SimplePtrCopy(Dest, Src, idxList, Builder);
  1922. else
  1923. SimpleValCopy(Dest, Src, idxList, Builder);
  1924. }
  1925. // Split copy into ld/st.
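// Illustrative sketch (hypothetical types, not from the original source):
// copying a { float, <2 x float> } recurses one level per field, so the copy
// roughly becomes
//   %d0 = gep %dst, 0, 0   %s0 = gep %src, 0, 0   store (load %s0), %d0
//   %d1 = gep %dst, 0, 1   %s1 = gep %src, 0, 1   store (load %s1), %d1
// while matrix-typed fields are copied through the HLMatLoadStore intrinsics.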
  1926. static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
  1927. SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
  1928. DxilTypeSystem &typeSys,
  1929. DxilFieldAnnotation *fieldAnnotation) {
  1930. if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
  1931. Constant *idx = Constant::getIntegerValue(
  1932. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  1933. idxList.emplace_back(idx);
  1934. SplitCpy(PT->getElementType(), Dest, Src, idxList, Builder, typeSys,
  1935. fieldAnnotation);
  1936. idxList.pop_back();
  1937. } else if (HLMatrixLower::IsMatrixType(Ty)) {
1938. // If there is no fieldAnnotation, default to row major.
1939. // Since this is just a load followed immediately by a store, either orientation works.
  1940. bool bRowMajor = true;
  1941. if (fieldAnnotation) {
  1942. DXASSERT(fieldAnnotation->HasMatrixAnnotation(),
  1943. "must has matrix annotation");
  1944. bRowMajor = fieldAnnotation->GetMatrixAnnotation().Orientation ==
  1945. MatrixOrientation::RowMajor;
  1946. }
  1947. Module *M = Builder.GetInsertPoint()->getModule();
  1948. Value *DestGEP = Builder.CreateInBoundsGEP(Dest, idxList);
  1949. Value *SrcGEP = Builder.CreateInBoundsGEP(Src, idxList);
  1950. if (bRowMajor) {
  1951. Value *Load = HLModule::EmitHLOperationCall(
  1952. Builder, HLOpcodeGroup::HLMatLoadStore,
  1953. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty, {SrcGEP},
  1954. *M);
  1955. // Generate Matrix Store.
  1956. HLModule::EmitHLOperationCall(
  1957. Builder, HLOpcodeGroup::HLMatLoadStore,
  1958. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore), Ty,
  1959. {DestGEP, Load}, *M);
  1960. } else {
  1961. Value *Load = HLModule::EmitHLOperationCall(
  1962. Builder, HLOpcodeGroup::HLMatLoadStore,
  1963. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad), Ty, {SrcGEP},
  1964. *M);
  1965. // Generate Matrix Store.
  1966. HLModule::EmitHLOperationCall(
  1967. Builder, HLOpcodeGroup::HLMatLoadStore,
  1968. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
  1969. {DestGEP, Load}, *M);
  1970. }
  1971. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  1972. if (HLModule::IsHLSLObjectType(ST)) {
1973. // Avoid splitting HLSL objects.
  1974. SimpleCopy(Dest, Src, idxList, Builder);
  1975. return;
  1976. }
  1977. DxilStructAnnotation *STA = typeSys.GetStructAnnotation(ST);
  1978. DXASSERT(STA, "require annotation here");
  1979. if (STA->IsEmptyStruct())
  1980. return;
  1981. for (uint32_t i = 0; i < ST->getNumElements(); i++) {
  1982. llvm::Type *ET = ST->getElementType(i);
  1983. Constant *idx = llvm::Constant::getIntegerValue(
  1984. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  1985. idxList.emplace_back(idx);
  1986. DxilFieldAnnotation &EltAnnotation = STA->GetFieldAnnotation(i);
  1987. SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, &EltAnnotation);
  1988. idxList.pop_back();
  1989. }
  1990. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  1991. Type *ET = AT->getElementType();
  1992. for (uint32_t i = 0; i < AT->getNumElements(); i++) {
  1993. Constant *idx = Constant::getIntegerValue(
  1994. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  1995. idxList.emplace_back(idx);
  1996. SplitCpy(ET, Dest, Src, idxList, Builder, typeSys, fieldAnnotation);
  1997. idxList.pop_back();
  1998. }
  1999. } else {
  2000. SimpleCopy(Dest, Src, idxList, Builder);
  2001. }
  2002. }
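// SplitPtr collects one element pointer per leaf that stays together rather
// than splitting values. Illustrative sketch (hypothetical type, not from the
// original source): for a pointer to { <4 x float>, [2 x float] } it is
// expected to push two GEPs, (gep ptr, 0, 0) and (gep ptr, 0, 1), into
// EltPtrList.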
  2003. static void SplitPtr(Type *Ty, Value *Ptr, SmallVector<Value *, 16> &idxList,
  2004. SmallVector<Value *, 16> &EltPtrList,
  2005. IRBuilder<> &Builder) {
  2006. if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
  2007. Constant *idx = Constant::getIntegerValue(
  2008. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  2009. idxList.emplace_back(idx);
  2010. SplitPtr(PT->getElementType(), Ptr, idxList, EltPtrList, Builder);
  2011. idxList.pop_back();
  2012. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  2013. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2014. EltPtrList.emplace_back(GEP);
  2015. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2016. if (HLModule::IsHLSLObjectType(ST)) {
2017. // Avoid splitting HLSL objects.
  2018. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2019. EltPtrList.emplace_back(GEP);
  2020. return;
  2021. }
  2022. for (uint32_t i = 0; i < ST->getNumElements(); i++) {
  2023. llvm::Type *ET = ST->getElementType(i);
  2024. Constant *idx = llvm::Constant::getIntegerValue(
  2025. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  2026. idxList.emplace_back(idx);
  2027. SplitPtr(ET, Ptr, idxList, EltPtrList, Builder);
  2028. idxList.pop_back();
  2029. }
  2030. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  2031. if (AT->getNumContainedTypes() == 0) {
  2032. // Skip case like [0 x %struct].
  2033. return;
  2034. }
  2035. Type *ElTy = AT->getElementType();
  2036. SmallVector<ArrayType *, 4> nestArrayTys;
  2037. nestArrayTys.emplace_back(AT);
2038. // Support multiple levels of nested arrays.
  2039. while (ElTy->isArrayTy()) {
  2040. ArrayType *ElAT = cast<ArrayType>(ElTy);
  2041. nestArrayTys.emplace_back(ElAT);
  2042. ElTy = ElAT->getElementType();
  2043. }
  2044. if (!ElTy->isStructTy() ||
  2045. HLMatrixLower::IsMatrixType(ElTy)) {
2046. // Don't split arrays of basic types.
  2047. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2048. EltPtrList.emplace_back(GEP);
  2049. }
  2050. else {
  2051. DXASSERT(0, "Not support array of struct when split pointers.");
  2052. }
  2053. } else {
  2054. Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  2055. EltPtrList.emplace_back(GEP);
  2056. }
  2057. }
  2058. // Support case when bitcast (gep ptr, 0,0) is transformed into bitcast ptr.
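// Illustrative sketch (assuming 4-byte floats, not from the original source):
// for a memcpy of 16 bytes, a [4 x float]* operand matches at level 0 (no
// patching needed), while a { [4 x float] }* operand peels the single-element
// struct and returns level 1, which PatchZeroIdxGEP below turns into
// (gep ptr, 0, 0) plus a bitcast back to the raw pointer type.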
  2059. static unsigned MatchSizeByCheckElementType(Type *Ty, const DataLayout &DL, unsigned size, unsigned level) {
  2060. unsigned ptrSize = DL.getTypeAllocSize(Ty);
  2061. // Size match, return current level.
  2062. if (ptrSize == size) {
2063. // Don't go deeper for matrices.
  2064. if (HLMatrixLower::IsMatrixType(Ty))
  2065. return level;
2066. // For structs, go deeper if the size doesn't change.
2067. // This leaves the memcpy for a deeper level when flattening.
  2068. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2069. if (ST->getNumElements() == 1) {
  2070. return MatchSizeByCheckElementType(ST->getElementType(0), DL, size, level+1);
  2071. }
  2072. }
2073. // Don't do this for arrays.
2074. // An array will be flattened as a struct of arrays.
  2075. return level;
  2076. }
2077. // Adding a zero index cannot make ptrSize bigger.
  2078. if (ptrSize < size)
  2079. return 0;
  2080. // ptrSize > size.
  2081. // Try to use element type to make size match.
  2082. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2083. return MatchSizeByCheckElementType(ST->getElementType(0), DL, size, level+1);
  2084. } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
  2085. return MatchSizeByCheckElementType(AT->getElementType(), DL, size, level+1);
  2086. } else {
  2087. return 0;
  2088. }
  2089. }
  2090. static void PatchZeroIdxGEP(Value *Ptr, Value *RawPtr, MemCpyInst *MI,
  2091. unsigned level, IRBuilder<> &Builder) {
  2092. Value *zeroIdx = Builder.getInt32(0);
  2093. SmallVector<Value *, 2> IdxList(level + 1, zeroIdx);
  2094. Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
  2095. // Use BitCastInst::Create to prevent idxList from being optimized.
  2096. CastInst *Cast =
  2097. BitCastInst::Create(Instruction::BitCast, GEP, RawPtr->getType());
  2098. Builder.Insert(Cast);
  2099. MI->replaceUsesOfWith(RawPtr, Cast);
  2100. // Remove RawPtr if possible.
  2101. if (RawPtr->user_empty()) {
  2102. if (Instruction *I = dyn_cast<Instruction>(RawPtr)) {
  2103. I->eraseFromParent();
  2104. }
  2105. }
  2106. }
  2107. void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(MemCpyInst *MI,
  2108. const DataLayout &DL) {
  2109. Value *Dest = MI->getRawDest();
  2110. Value *Src = MI->getRawSource();
2111. // Only remove one level of bitcast generated from inlining.
  2112. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  2113. Dest = BC->getOperand(0);
  2114. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  2115. Src = BC->getOperand(0);
  2116. IRBuilder<> Builder(MI);
  2117. ConstantInt *zero = Builder.getInt32(0);
  2118. Type *DestTy = Dest->getType()->getPointerElementType();
  2119. Type *SrcTy = Src->getType()->getPointerElementType();
  2120. // Support case when bitcast (gep ptr, 0,0) is transformed into
  2121. // bitcast ptr.
  2122. // Also replace (gep ptr, 0) with ptr.
  2123. ConstantInt *Length = cast<ConstantInt>(MI->getLength());
  2124. unsigned size = Length->getLimitedValue();
  2125. if (unsigned level = MatchSizeByCheckElementType(DestTy, DL, size, 0)) {
  2126. PatchZeroIdxGEP(Dest, MI->getRawDest(), MI, level, Builder);
  2127. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Dest)) {
  2128. if (GEP->getNumIndices() == 1) {
  2129. Value *idx = *GEP->idx_begin();
  2130. if (idx == zero) {
  2131. GEP->replaceAllUsesWith(GEP->getPointerOperand());
  2132. }
  2133. }
  2134. }
  2135. if (unsigned level = MatchSizeByCheckElementType(SrcTy, DL, size, 0)) {
  2136. PatchZeroIdxGEP(Src, MI->getRawSource(), MI, level, Builder);
  2137. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
  2138. if (GEP->getNumIndices() == 1) {
  2139. Value *idx = *GEP->idx_begin();
  2140. if (idx == zero) {
  2141. GEP->replaceAllUsesWith(GEP->getPointerOperand());
  2142. }
  2143. }
  2144. }
  2145. }
  2146. void MemcpySplitter::PatchMemCpyWithZeroIdxGEP(Module &M) {
  2147. const DataLayout &DL = M.getDataLayout();
  2148. for (Function &F : M.functions()) {
  2149. for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
  2150. for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
  2151. // Avoid invalidating the iterator.
  2152. Instruction *I = BI++;
  2153. if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
  2154. PatchMemCpyWithZeroIdxGEP(MI, DL);
  2155. }
  2156. }
  2157. }
  2158. }
  2159. }
  2160. static void DeleteMemcpy(MemCpyInst *MI) {
  2161. Value *Op0 = MI->getOperand(0);
  2162. Value *Op1 = MI->getOperand(1);
  2163. // delete memcpy
  2164. MI->eraseFromParent();
  2165. if (Instruction *op0 = dyn_cast<Instruction>(Op0)) {
  2166. if (op0->user_empty())
  2167. op0->eraseFromParent();
  2168. }
  2169. if (Instruction *op1 = dyn_cast<Instruction>(Op1)) {
  2170. if (op1->user_empty())
  2171. op1->eraseFromParent();
  2172. }
  2173. }
  2174. void MemcpySplitter::SplitMemCpy(MemCpyInst *MI, const DataLayout &DL,
  2175. DxilFieldAnnotation *fieldAnnotation,
  2176. DxilTypeSystem &typeSys) {
  2177. Value *Dest = MI->getRawDest();
  2178. Value *Src = MI->getRawSource();
2179. // Only remove one level of bitcast generated from inlining.
  2180. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  2181. Dest = BC->getOperand(0);
  2182. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  2183. Src = BC->getOperand(0);
  2184. if (Dest == Src) {
  2185. // delete self copy.
  2186. DeleteMemcpy(MI);
  2187. return;
  2188. }
  2189. IRBuilder<> Builder(MI);
  2190. Type *DestTy = Dest->getType()->getPointerElementType();
  2191. Type *SrcTy = Src->getType()->getPointerElementType();
2192. // Allow copies between different address spaces.
  2193. if (DestTy != SrcTy) {
  2194. return;
  2195. }
  2196. llvm::SmallVector<llvm::Value *, 16> idxList;
2197. // Split the copy element by element.
2198. // Matrices are treated as scalar types and will not use memcpy,
2199. // so passing the (possibly null) fieldAnnotation through is safe here.
  2200. SplitCpy(Dest->getType(), Dest, Src, idxList, Builder, typeSys,
  2201. fieldAnnotation);
  2202. // delete memcpy
  2203. DeleteMemcpy(MI);
  2204. }
  2205. void MemcpySplitter::Split(llvm::Function &F) {
  2206. const DataLayout &DL = F.getParent()->getDataLayout();
  2207. // Walk all instruction in the function.
  2208. for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB) {
  2209. for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
  2210. // Avoid invalidating the iterator.
  2211. Instruction *I = BI++;
  2212. if (MemCpyInst *MI = dyn_cast<MemCpyInst>(I)) {
2213. // Matrices are treated as scalar types and will not use memcpy,
2214. // so using nullptr for fieldAnnotation is safe here.
  2215. SplitMemCpy(MI, DL, /*fieldAnnotation*/ nullptr, m_typeSys);
  2216. }
  2217. }
  2218. }
  2219. }
  2220. //===----------------------------------------------------------------------===//
  2221. // SRoA Helper
  2222. //===----------------------------------------------------------------------===//
2223. /// RewriteForGEP - Rewrite the GEP to be relative to the new element when a
2224. /// new element that is a struct field can be found. If none can be found,
2225. /// create new element GEPs and try to rewrite the GEP with the new GEPs.
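/// Illustrative sketch (hypothetical IR, not from the original source): with
/// OldVal split into %elt0 and %elt1,
///   %p = getelementptr { float, [4 x i32] }, { float, [4 x i32] }* %old, i32 0, i32 1, i32 %i
/// is rewritten to
///   %p = getelementptr [4 x i32], [4 x i32]* %elt1, i32 0, i32 %i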
  2226. void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
  2227. assert(OldVal == GEP->getPointerOperand() && "");
  2228. Value *NewPointer = nullptr;
  2229. SmallVector<Value *, 8> NewArgs;
  2230. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  2231. for (; GEPIt != E; ++GEPIt) {
  2232. if (GEPIt->isStructTy()) {
2233. // Struct indices must be constant.
  2234. ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
  2235. assert(IdxVal->getLimitedValue() < NewElts.size() && "");
  2236. NewPointer = NewElts[IdxVal->getLimitedValue()];
2237. // The idx selects NewPointer and is not part of the new GEP's indices.
  2238. GEPIt++;
  2239. break;
  2240. } else if (GEPIt->isArrayTy()) {
  2241. // Add array idx.
  2242. NewArgs.push_back(GEPIt.getOperand());
  2243. } else if (GEPIt->isPointerTy()) {
  2244. // Add pointer idx.
  2245. NewArgs.push_back(GEPIt.getOperand());
  2246. } else if (GEPIt->isVectorTy()) {
  2247. // Add vector idx.
  2248. NewArgs.push_back(GEPIt.getOperand());
  2249. } else {
  2250. llvm_unreachable("should break from structTy");
  2251. }
  2252. }
  2253. if (NewPointer) {
  2254. // Struct split.
  2255. // Add rest of idx.
  2256. for (; GEPIt != E; ++GEPIt) {
  2257. NewArgs.push_back(GEPIt.getOperand());
  2258. }
  2259. // If only 1 level struct, just use the new pointer.
  2260. Value *NewGEP = NewPointer;
  2261. if (NewArgs.size() > 1) {
  2262. NewGEP = Builder.CreateInBoundsGEP(NewPointer, NewArgs);
  2263. NewGEP->takeName(GEP);
  2264. }
  2265. assert(NewGEP->getType() == GEP->getType() && "type mismatch");
  2266. GEP->replaceAllUsesWith(NewGEP);
  2267. if (isa<Instruction>(GEP))
  2268. DeadInsts.push_back(GEP);
  2269. } else {
  2270. // End at array of basic type.
  2271. Type *Ty = GEP->getType()->getPointerElementType();
  2272. if (Ty->isVectorTy() ||
  2273. (Ty->isStructTy() && !HLModule::IsHLSLObjectType(Ty)) ||
  2274. Ty->isArrayTy()) {
  2275. SmallVector<Value *, 8> NewArgs;
  2276. NewArgs.append(GEP->idx_begin(), GEP->idx_end());
  2277. SmallVector<Value *, 8> NewGEPs;
  2278. // create new geps
  2279. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2280. Value *NewGEP = Builder.CreateGEP(nullptr, NewElts[i], NewArgs);
  2281. NewGEPs.emplace_back(NewGEP);
  2282. }
  2283. SROA_Helper helper(GEP, NewGEPs, DeadInsts);
  2284. helper.RewriteForScalarRepl(GEP, Builder);
  2285. for (Value *NewGEP : NewGEPs) {
  2286. if (NewGEP->user_empty() && isa<Instruction>(NewGEP)) {
  2287. // Delete unused newGEP.
  2288. cast<Instruction>(NewGEP)->eraseFromParent();
  2289. }
  2290. }
  2291. if (GEP->user_empty() && isa<Instruction>(GEP))
  2292. DeadInsts.push_back(GEP);
  2293. } else {
  2294. Value *vecIdx = NewArgs.back();
  2295. if (ConstantInt *immVecIdx = dyn_cast<ConstantInt>(vecIdx)) {
  2296. // Replace vecArray[arrayIdx][immVecIdx]
  2297. // with scalarArray_immVecIdx[arrayIdx]
  2298. // Pop the vecIdx.
  2299. NewArgs.pop_back();
  2300. Value *NewGEP = NewElts[immVecIdx->getLimitedValue()];
  2301. if (NewArgs.size() > 1) {
  2302. NewGEP = Builder.CreateInBoundsGEP(NewGEP, NewArgs);
  2303. NewGEP->takeName(GEP);
  2304. }
  2305. assert(NewGEP->getType() == GEP->getType() && "type mismatch");
  2306. GEP->replaceAllUsesWith(NewGEP);
  2307. if (isa<Instruction>(GEP))
  2308. DeadInsts.push_back(GEP);
  2309. } else {
  2310. // dynamic vector indexing.
  2311. assert(0 && "should not reach here");
  2312. }
  2313. }
  2314. }
  2315. }
2316. /// isVectorOrStructArray - Check if T is an array of vectors or structs.
  2317. static bool isVectorOrStructArray(Type *T) {
  2318. if (!T->isArrayTy())
  2319. return false;
  2320. T = dxilutil::GetArrayEltTy(T);
  2321. return T->isStructTy() || T->isVectorTy();
  2322. }
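// SimplifyStructValUsage forwards known element values through
// extractvalue/insertvalue chains. Illustrative sketch: once a struct load has
// been rebuilt from per-element loads, a user such as `extractvalue %agg, 1`
// is replaced directly with the corresponding element value and marked dead.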
  2323. static void SimplifyStructValUsage(Value *StructVal, std::vector<Value *> Elts,
  2324. SmallVectorImpl<Value *> &DeadInsts) {
  2325. for (User *user : StructVal->users()) {
  2326. if (ExtractValueInst *Extract = dyn_cast<ExtractValueInst>(user)) {
  2327. DXASSERT(Extract->getNumIndices() == 1, "only support 1 index case");
  2328. unsigned index = Extract->getIndices()[0];
  2329. Value *Elt = Elts[index];
  2330. Extract->replaceAllUsesWith(Elt);
  2331. DeadInsts.emplace_back(Extract);
  2332. } else if (InsertValueInst *Insert = dyn_cast<InsertValueInst>(user)) {
  2333. DXASSERT(Insert->getNumIndices() == 1, "only support 1 index case");
  2334. unsigned index = Insert->getIndices()[0];
  2335. if (Insert->getAggregateOperand() == StructVal) {
  2336. // Update field.
  2337. std::vector<Value *> NewElts = Elts;
  2338. NewElts[index] = Insert->getInsertedValueOperand();
  2339. SimplifyStructValUsage(Insert, NewElts, DeadInsts);
  2340. } else {
  2341. // Insert to another bigger struct.
  2342. IRBuilder<> Builder(Insert);
  2343. Value *TmpStructVal = UndefValue::get(StructVal->getType());
  2344. for (unsigned i = 0; i < Elts.size(); i++) {
  2345. TmpStructVal =
  2346. Builder.CreateInsertValue(TmpStructVal, Elts[i], {i});
  2347. }
  2348. Insert->replaceUsesOfWith(StructVal, TmpStructVal);
  2349. }
  2350. }
  2351. }
  2352. }
  2353. /// RewriteForLoad - Replace OldVal with flattened NewElts in LoadInst.
  2354. void SROA_Helper::RewriteForLoad(LoadInst *LI) {
  2355. Type *LIType = LI->getType();
  2356. Type *ValTy = OldVal->getType()->getPointerElementType();
  2357. IRBuilder<> Builder(LI);
  2358. if (LIType->isVectorTy()) {
  2359. // Replace:
2360. // %res = load <2 x i32>* %alloc
2361. // with:
2362. // %load.0 = load i32* %alloc.0
2363. // %insert.0 = insertelement <2 x i32> undef, i32 %load.0, 0
2364. // %load.1 = load i32* %alloc.1
2365. // %insert = insertelement <2 x i32> %insert.0, i32 %load.1, 1
  2366. Value *Insert = UndefValue::get(LIType);
  2367. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2368. Value *Load = Builder.CreateLoad(NewElts[i], "load");
  2369. Insert = Builder.CreateInsertElement(Insert, Load, i, "insert");
  2370. }
  2371. LI->replaceAllUsesWith(Insert);
  2372. DeadInsts.push_back(LI);
  2373. } else if (isCompatibleAggregate(LIType, ValTy)) {
  2374. if (isVectorOrStructArray(LIType)) {
  2375. // Replace:
  2376. // %res = load [2 x <2 x float>] * %alloc
  2377. // with:
2378. // %elt.0 = load float* (gep [2 x float]* %alloc.0, 0, 0) ; lane 0 of %res[0]
2379. // %elt.1 = load float* (gep [2 x float]* %alloc.1, 0, 0) ; lane 1 of %res[0]
2380. // %vec.0 = <2 x float> rebuilt from %elt.0 and %elt.1 via insertelement
2381. // %res = [2 x <2 x float>] rebuilt via insertvalue of %vec.0, %vec.1
  2382. // ...
  2383. Type *i32Ty = Type::getInt32Ty(LIType->getContext());
  2384. Value *zero = ConstantInt::get(i32Ty, 0);
  2385. SmallVector<Value *, 8> idxList;
  2386. idxList.emplace_back(zero);
  2387. Value *newLd =
  2388. LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
  2389. LI->replaceAllUsesWith(newLd);
  2390. DeadInsts.push_back(LI);
  2391. } else {
  2392. // Replace:
  2393. // %res = load { i32, i32 }* %alloc
  2394. // with:
  2395. // %load.0 = load i32* %alloc.0
2396. // %insert.0 = insertvalue { i32, i32 } zeroinitializer, i32 %load.0,
  2397. // 0
  2398. // %load.1 = load i32* %alloc.1
  2399. // %insert = insertvalue { i32, i32 } %insert.0, i32 %load.1, 1
  2400. // (Also works for arrays instead of structs)
  2401. Module *M = LI->getModule();
  2402. Value *Insert = UndefValue::get(LIType);
  2403. std::vector<Value *> LdElts(NewElts.size());
  2404. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2405. Value *Ptr = NewElts[i];
  2406. Type *Ty = Ptr->getType()->getPointerElementType();
  2407. Value *Load = nullptr;
  2408. if (!HLMatrixLower::IsMatrixType(Ty))
  2409. Load = Builder.CreateLoad(Ptr, "load");
  2410. else {
  2411. // Generate Matrix Load.
  2412. Load = HLModule::EmitHLOperationCall(
  2413. Builder, HLOpcodeGroup::HLMatLoadStore,
  2414. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty,
  2415. {Ptr}, *M);
  2416. }
  2417. LdElts[i] = Load;
  2418. Insert = Builder.CreateInsertValue(Insert, Load, i, "insert");
  2419. }
  2420. LI->replaceAllUsesWith(Insert);
  2421. if (LIType->isStructTy()) {
  2422. SimplifyStructValUsage(Insert, LdElts, DeadInsts);
  2423. }
  2424. DeadInsts.push_back(LI);
  2425. }
  2426. } else {
  2427. llvm_unreachable("other type don't need rewrite");
  2428. }
  2429. }
  2430. /// RewriteForStore - Replace OldVal with flattened NewElts in StoreInst.
  2431. void SROA_Helper::RewriteForStore(StoreInst *SI) {
  2432. Value *Val = SI->getOperand(0);
  2433. Type *SIType = Val->getType();
  2434. IRBuilder<> Builder(SI);
  2435. Type *ValTy = OldVal->getType()->getPointerElementType();
  2436. if (SIType->isVectorTy()) {
  2437. // Replace:
  2438. // store <2 x float> %val, <2 x float>* %alloc
  2439. // with:
2440. // %val.0 = extractelement <2 x float> %val, 0
2441. // store float %val.0, float* %alloc.0
2442. // %val.1 = extractelement <2 x float> %val, 1
2443. // store float %val.1, float* %alloc.1
  2444. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2445. Value *Extract = Builder.CreateExtractElement(Val, i, Val->getName());
  2446. Builder.CreateStore(Extract, NewElts[i]);
  2447. }
  2448. DeadInsts.push_back(SI);
  2449. } else if (isCompatibleAggregate(SIType, ValTy)) {
  2450. if (isVectorOrStructArray(SIType)) {
  2451. // Replace:
  2452. // store [2 x <2 x i32>] %val, [2 x <2 x i32>]* %alloc, align 16
  2453. // with:
  2454. // %val.0 = extractvalue [2 x <2 x i32>] %val, 0
2455. // %alloc.0.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.0,
2456. // i32 0, i32 0
2457. // %val.0.0 = extractelement <2 x i32> %val.0, i64 0
2458. // store i32 %val.0.0, i32* %alloc.0.0
2459. // %alloc.1.0 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.1,
2460. // i32 0, i32 0
2461. // %val.0.1 = extractelement <2 x i32> %val.0, i64 1
2462. // store i32 %val.0.1, i32* %alloc.1.0
2463. // %val.1 = extractvalue [2 x <2 x i32>] %val, 1
2464. // %alloc.0.1 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.0,
2465. // i32 0, i32 1
2466. // %val.1.0 = extractelement <2 x i32> %val.1, i64 0
2467. // store i32 %val.1.0, i32* %alloc.0.1
2468. // %alloc.1.1 = getelementptr inbounds [2 x i32], [2 x i32]* %alloc.1,
2469. // i32 0, i32 1
2470. // %val.1.1 = extractelement <2 x i32> %val.1, i64 1
2471. // store i32 %val.1.1, i32* %alloc.1.1
  2472. ArrayType *AT = cast<ArrayType>(SIType);
  2473. Type *i32Ty = Type::getInt32Ty(SIType->getContext());
  2474. Value *zero = ConstantInt::get(i32Ty, 0);
  2475. SmallVector<Value *, 8> idxList;
  2476. idxList.emplace_back(zero);
  2477. StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
  2478. DeadInsts.push_back(SI);
  2479. } else {
  2480. // Replace:
  2481. // store { i32, i32 } %val, { i32, i32 }* %alloc
  2482. // with:
  2483. // %val.0 = extractvalue { i32, i32 } %val, 0
  2484. // store i32 %val.0, i32* %alloc.0
  2485. // %val.1 = extractvalue { i32, i32 } %val, 1
  2486. // store i32 %val.1, i32* %alloc.1
  2487. // (Also works for arrays instead of structs)
  2488. Module *M = SI->getModule();
  2489. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2490. Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
  2491. if (!HLMatrixLower::IsMatrixType(Extract->getType())) {
  2492. Builder.CreateStore(Extract, NewElts[i]);
  2493. } else {
  2494. // Generate Matrix Store.
  2495. HLModule::EmitHLOperationCall(
  2496. Builder, HLOpcodeGroup::HLMatLoadStore,
  2497. static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore),
  2498. Extract->getType(), {NewElts[i], Extract}, *M);
  2499. }
  2500. }
  2501. DeadInsts.push_back(SI);
  2502. }
  2503. } else {
  2504. llvm_unreachable("other type don't need rewrite");
  2505. }
  2506. }
  2507. /// RewriteMemIntrin - MI is a memcpy/memset/memmove from or to AI.
  2508. /// Rewrite it to copy or set the elements of the scalarized memory.
  2509. void SROA_Helper::RewriteMemIntrin(MemIntrinsic *MI, Instruction *Inst) {
  2510. // If this is a memcpy/memmove, construct the other pointer as the
  2511. // appropriate type. The "Other" pointer is the pointer that goes to memory
  2512. // that doesn't have anything to do with the alloca that we are promoting. For
  2513. // memset, this Value* stays null.
  2514. Value *OtherPtr = nullptr;
  2515. unsigned MemAlignment = MI->getAlignment();
  2516. if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { // memmove/memcopy
  2517. if (Inst == MTI->getRawDest())
  2518. OtherPtr = MTI->getRawSource();
  2519. else {
  2520. assert(Inst == MTI->getRawSource());
  2521. OtherPtr = MTI->getRawDest();
  2522. }
  2523. }
  2524. // If there is an other pointer, we want to convert it to the same pointer
  2525. // type as AI has, so we can GEP through it safely.
  2526. if (OtherPtr) {
  2527. unsigned AddrSpace =
  2528. cast<PointerType>(OtherPtr->getType())->getAddressSpace();
  2529. // Remove bitcasts and all-zero GEPs from OtherPtr. This is an
  2530. // optimization, but it's also required to detect the corner case where
  2531. // both pointer operands are referencing the same memory, and where
2532. // OtherPtr may be a bitcast or GEP that is currently being rewritten. (This
  2533. // function is only called for mem intrinsics that access the whole
  2534. // aggregate, so non-zero GEPs are not an issue here.)
  2535. OtherPtr = OtherPtr->stripPointerCasts();
  2536. // Copying the alloca to itself is a no-op: just delete it.
  2537. if (OtherPtr == OldVal || OtherPtr == NewElts[0]) {
  2538. // This code will run twice for a no-op memcpy -- once for each operand.
  2539. // Put only one reference to MI on the DeadInsts list.
  2540. for (SmallVectorImpl<Value *>::const_iterator I = DeadInsts.begin(),
  2541. E = DeadInsts.end();
  2542. I != E; ++I)
  2543. if (*I == MI)
  2544. return;
  2545. DeadInsts.push_back(MI);
  2546. return;
  2547. }
  2548. // If the pointer is not the right type, insert a bitcast to the right
  2549. // type.
  2550. Type *NewTy =
  2551. PointerType::get(OldVal->getType()->getPointerElementType(), AddrSpace);
  2552. if (OtherPtr->getType() != NewTy)
  2553. OtherPtr = new BitCastInst(OtherPtr, NewTy, OtherPtr->getName(), MI);
  2554. }
  2555. // Process each element of the aggregate.
  2556. bool SROADest = MI->getRawDest() == Inst;
  2557. Constant *Zero = Constant::getNullValue(Type::getInt32Ty(MI->getContext()));
  2558. const DataLayout &DL = MI->getModule()->getDataLayout();
  2559. for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
  2560. // If this is a memcpy/memmove, emit a GEP of the other element address.
  2561. Value *OtherElt = nullptr;
  2562. unsigned OtherEltAlign = MemAlignment;
  2563. if (OtherPtr) {
  2564. Value *Idx[2] = {Zero,
  2565. ConstantInt::get(Type::getInt32Ty(MI->getContext()), i)};
  2566. OtherElt = GetElementPtrInst::CreateInBounds(
  2567. OtherPtr, Idx, OtherPtr->getName() + "." + Twine(i), MI);
  2568. uint64_t EltOffset;
  2569. PointerType *OtherPtrTy = cast<PointerType>(OtherPtr->getType());
  2570. Type *OtherTy = OtherPtrTy->getElementType();
  2571. if (StructType *ST = dyn_cast<StructType>(OtherTy)) {
  2572. EltOffset = DL.getStructLayout(ST)->getElementOffset(i);
  2573. } else {
  2574. Type *EltTy = cast<SequentialType>(OtherTy)->getElementType();
  2575. EltOffset = DL.getTypeAllocSize(EltTy) * i;
  2576. }
  2577. // The alignment of the other pointer is the guaranteed alignment of the
  2578. // element, which is affected by both the known alignment of the whole
  2579. // mem intrinsic and the alignment of the element. If the alignment of
2580. // the memcpy (for example) is 32 but the element is at a 4-byte offset, then the
  2581. // known alignment is just 4 bytes.
  2582. OtherEltAlign = (unsigned)MinAlign(OtherEltAlign, EltOffset);
  2583. }
  2584. Value *EltPtr = NewElts[i];
  2585. Type *EltTy = cast<PointerType>(EltPtr->getType())->getElementType();
  2586. // If we got down to a scalar, insert a load or store as appropriate.
  2587. if (EltTy->isSingleValueType()) {
  2588. if (isa<MemTransferInst>(MI)) {
  2589. if (SROADest) {
  2590. // From Other to Alloca.
  2591. Value *Elt = new LoadInst(OtherElt, "tmp", false, OtherEltAlign, MI);
  2592. new StoreInst(Elt, EltPtr, MI);
  2593. } else {
  2594. // From Alloca to Other.
  2595. Value *Elt = new LoadInst(EltPtr, "tmp", MI);
  2596. new StoreInst(Elt, OtherElt, false, OtherEltAlign, MI);
  2597. }
  2598. continue;
  2599. }
  2600. assert(isa<MemSetInst>(MI));
  2601. // If the stored element is zero (common case), just store a null
  2602. // constant.
  2603. Constant *StoreVal;
  2604. if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getArgOperand(1))) {
  2605. if (CI->isZero()) {
  2606. StoreVal = Constant::getNullValue(EltTy); // 0.0, null, 0, <0,0>
  2607. } else {
  2608. // If EltTy is a vector type, get the element type.
  2609. Type *ValTy = EltTy->getScalarType();
  2610. // Construct an integer with the right value.
  2611. unsigned EltSize = DL.getTypeSizeInBits(ValTy);
  2612. APInt OneVal(EltSize, CI->getZExtValue());
  2613. APInt TotalVal(OneVal);
  2614. // Set each byte.
  2615. for (unsigned i = 0; 8 * i < EltSize; ++i) {
  2616. TotalVal = TotalVal.shl(8);
  2617. TotalVal |= OneVal;
  2618. }
  2619. // Convert the integer value to the appropriate type.
  2620. StoreVal = ConstantInt::get(CI->getContext(), TotalVal);
  2621. if (ValTy->isPointerTy())
  2622. StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy);
  2623. else if (ValTy->isFloatingPointTy())
  2624. StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy);
  2625. assert(StoreVal->getType() == ValTy && "Type mismatch!");
  2626. // If the requested value was a vector constant, create it.
  2627. if (EltTy->isVectorTy()) {
  2628. unsigned NumElts = cast<VectorType>(EltTy)->getNumElements();
  2629. StoreVal = ConstantVector::getSplat(NumElts, StoreVal);
  2630. }
  2631. }
  2632. new StoreInst(StoreVal, EltPtr, MI);
  2633. continue;
  2634. }
  2635. // Otherwise, if we're storing a byte variable, use a memset call for
  2636. // this element.
  2637. }
  2638. unsigned EltSize = DL.getTypeAllocSize(EltTy);
  2639. if (!EltSize)
  2640. continue;
  2641. IRBuilder<> Builder(MI);
  2642. // Finally, insert the meminst for this element.
  2643. if (isa<MemSetInst>(MI)) {
  2644. Builder.CreateMemSet(EltPtr, MI->getArgOperand(1), EltSize,
  2645. MI->isVolatile());
  2646. } else {
  2647. assert(isa<MemTransferInst>(MI));
  2648. Value *Dst = SROADest ? EltPtr : OtherElt; // Dest ptr
  2649. Value *Src = SROADest ? OtherElt : EltPtr; // Src ptr
  2650. if (isa<MemCpyInst>(MI))
  2651. Builder.CreateMemCpy(Dst, Src, EltSize, OtherEltAlign,
  2652. MI->isVolatile());
  2653. else
  2654. Builder.CreateMemMove(Dst, Src, EltSize, OtherEltAlign,
  2655. MI->isVolatile());
  2656. }
  2657. }
  2658. DeadInsts.push_back(MI);
  2659. }
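// RewriteBitCast below only handles casting a struct pointer to the type of
// its first (possibly nested) field. Illustrative sketch (hypothetical types,
// not from the original source):
//   %q = bitcast %Outer* %p to %Inner*
// where %Inner is element 0 of %Outer is replaced with
//   %q = getelementptr inbounds %Outer, %Outer* %p, i32 0, i32 0
// and that GEP is then rewritten against the split elements via RewriteForGEP.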
  2660. void SROA_Helper::RewriteBitCast(BitCastInst *BCI) {
  2661. Type *DstTy = BCI->getType();
  2662. Value *Val = BCI->getOperand(0);
  2663. Type *SrcTy = Val->getType();
  2664. if (!DstTy->isPointerTy()) {
  2665. assert(0 && "Type mismatch.");
  2666. return;
  2667. }
  2668. if (!SrcTy->isPointerTy()) {
  2669. assert(0 && "Type mismatch.");
  2670. return;
  2671. }
  2672. DstTy = DstTy->getPointerElementType();
  2673. SrcTy = SrcTy->getPointerElementType();
  2674. if (!DstTy->isStructTy()) {
  2675. assert(0 && "Type mismatch.");
  2676. return;
  2677. }
  2678. if (!SrcTy->isStructTy()) {
  2679. assert(0 && "Type mismatch.");
  2680. return;
  2681. }
  2682. // Only support bitcast to parent struct type.
  2683. StructType *DstST = cast<StructType>(DstTy);
  2684. StructType *SrcST = cast<StructType>(SrcTy);
  2685. bool bTypeMatch = false;
  2686. unsigned level = 0;
  2687. while (SrcST) {
  2688. level++;
  2689. Type *EltTy = SrcST->getElementType(0);
  2690. if (EltTy == DstST) {
  2691. bTypeMatch = true;
  2692. break;
  2693. }
  2694. SrcST = dyn_cast<StructType>(EltTy);
  2695. }
  2696. if (!bTypeMatch) {
  2697. assert(0 && "Type mismatch.");
  2698. return;
  2699. }
  2700. std::vector<Value*> idxList(level+1);
  2701. ConstantInt *zeroIdx = ConstantInt::get(Type::getInt32Ty(Val->getContext()), 0);
2702. for (unsigned i = 0; i < level + 1; i++)
  2703. idxList[i] = zeroIdx;
  2704. IRBuilder<> Builder(BCI);
  2705. Instruction *GEP = cast<Instruction>(Builder.CreateInBoundsGEP(Val, idxList));
  2706. BCI->replaceAllUsesWith(GEP);
  2707. BCI->eraseFromParent();
  2708. IRBuilder<> GEPBuilder(GEP);
  2709. RewriteForGEP(cast<GEPOperator>(GEP), GEPBuilder);
  2710. }
  2711. /// RewriteCall - Replace OldVal with flattened NewElts in CallInst.
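/// Illustrative sketch (not from the original source): for an OutputStream
/// Append call whose aggregate pointer argument is the value being split, the
/// call is rebuilt in a freshly created HL function whose signature takes one
/// pointer per flattened element, so the single aggregate pointer argument is
/// replaced by pointers to each of its split element allocas.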
  2712. void SROA_Helper::RewriteCall(CallInst *CI) {
  2713. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  2714. Function *F = CI->getCalledFunction();
  2715. if (group != HLOpcodeGroup::NotHL) {
  2716. unsigned opcode = GetHLOpcode(CI);
  2717. if (group == HLOpcodeGroup::HLIntrinsic) {
  2718. IntrinsicOp IOP = static_cast<IntrinsicOp>(opcode);
  2719. switch (IOP) {
  2720. case IntrinsicOp::MOP_Append: {
2721. // Buffer Append is already expanded in code gen.
2722. // This must be an OutputStream Append here.
  2723. SmallVector<Value *, 4> flatArgs;
  2724. for (Value *arg : CI->arg_operands()) {
  2725. if (arg == OldVal) {
2726. // Flatten the arg.
2727. // Every Elt has pointer type;
2728. // for Append, that is not a problem.
  2729. for (Value *Elt : NewElts)
  2730. flatArgs.emplace_back(Elt);
  2731. } else
  2732. flatArgs.emplace_back(arg);
  2733. }
  2734. SmallVector<Type *, 4> flatParamTys;
  2735. for (Value *arg : flatArgs)
  2736. flatParamTys.emplace_back(arg->getType());
  2737. // Don't need flat return type for Append.
  2738. FunctionType *flatFuncTy =
  2739. FunctionType::get(CI->getType(), flatParamTys, false);
  2740. Function *flatF =
  2741. GetOrCreateHLFunction(*F->getParent(), flatFuncTy, group, opcode);
  2742. IRBuilder<> Builder(CI);
2743. // Append returns void, so there is no need to replace CI with flatCI.
  2744. Builder.CreateCall(flatF, flatArgs);
  2745. DeadInsts.push_back(CI);
  2746. } break;
  2747. default:
  2748. DXASSERT(0, "cannot flatten hlsl intrinsic.");
  2749. }
  2750. }
2751. // TODO: check other high-level dx operations if needed.
  2752. } else {
  2753. DXASSERT(0, "should done at inline");
  2754. }
  2755. }
2756. /// RewriteForConstExpr - Rewrite a GEP that is a ConstantExpr.
  2757. void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
  2758. if (GEPOperator *GEP = dyn_cast<GEPOperator>(CE)) {
  2759. if (OldVal == GEP->getPointerOperand()) {
  2760. // Flatten GEP.
  2761. RewriteForGEP(GEP, Builder);
  2762. return;
  2763. }
  2764. }
  2765. // Skip unused CE.
  2766. if (CE->use_empty())
  2767. return;
  2768. Instruction *constInst = CE->getAsInstruction();
  2769. Builder.Insert(constInst);
  2770. // Replace CE with constInst.
  2771. for (Value::use_iterator UI = CE->use_begin(), E = CE->use_end(); UI != E;) {
  2772. Use &TheUse = *UI++;
  2773. if (isa<Instruction>(TheUse.getUser()))
  2774. TheUse.set(constInst);
  2775. else {
  2776. RewriteForConstExpr(cast<ConstantExpr>(TheUse.getUser()), Builder);
  2777. }
  2778. }
  2779. }
  2780. /// RewriteForScalarRepl - OldVal is being split into NewElts, so rewrite
2781. /// users of V, which reference it, to use the separate elements.
  2782. void SROA_Helper::RewriteForScalarRepl(Value *V, IRBuilder<> &Builder) {
  2783. for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E;) {
  2784. Use &TheUse = *UI++;
  2785. if (ConstantExpr *CE = dyn_cast<ConstantExpr>(TheUse.getUser())) {
  2786. RewriteForConstExpr(CE, Builder);
  2787. continue;
  2788. }
  2789. Instruction *User = cast<Instruction>(TheUse.getUser());
  2790. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
  2791. IRBuilder<> Builder(GEP);
  2792. RewriteForGEP(cast<GEPOperator>(GEP), Builder);
  2793. } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User))
  2794. RewriteForLoad(ldInst);
  2795. else if (StoreInst *stInst = dyn_cast<StoreInst>(User))
  2796. RewriteForStore(stInst);
  2797. else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User))
  2798. RewriteMemIntrin(MI, cast<Instruction>(V));
  2799. else if (CallInst *CI = dyn_cast<CallInst>(User))
  2800. RewriteCall(CI);
  2801. else if (BitCastInst *BCI = dyn_cast<BitCastInst>(User))
  2802. RewriteBitCast(BCI);
  2803. else {
  2804. assert(0 && "not support.");
  2805. }
  2806. }
  2807. }
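// Illustrative sketch (hypothetical types): with nestArrayTys =
// { [3 x [2 x T]], [2 x T] } and FinalEltTy = float, CreateNestArrayTy below
// rebuilds [3 x [2 x float]].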
  2808. static ArrayType *CreateNestArrayTy(Type *FinalEltTy,
  2809. ArrayRef<ArrayType *> nestArrayTys) {
  2810. Type *newAT = FinalEltTy;
2811. for (auto ArrayTy = nestArrayTys.rbegin(), E = nestArrayTys.rend(); ArrayTy != E;
  2812. ++ArrayTy)
  2813. newAT = ArrayType::get(newAT, (*ArrayTy)->getNumElements());
  2814. return cast<ArrayType>(newAT);
  2815. }
  2816. /// DoScalarReplacement - Split V into AllocaInsts with Builder and save the new AllocaInsts into Elts.
  2817. /// Then do SROA on V.
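/// Illustrative sketch (hypothetical types, not from the original source): an
/// `alloca { float, [2 x i32] }` named %v becomes two allocas, %v.0 of type
/// float and %v.1 of type [2 x i32], and every GEP/load/store of %v is then
/// rewritten against them by RewriteForScalarRepl.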
  2818. bool SROA_Helper::DoScalarReplacement(Value *V, std::vector<Value *> &Elts,
  2819. IRBuilder<> &Builder, bool bFlatVector,
  2820. bool hasPrecise, DxilTypeSystem &typeSys,
  2821. SmallVector<Value *, 32> &DeadInsts) {
  2822. DEBUG(dbgs() << "Found inst to SROA: " << *V << '\n');
  2823. Type *Ty = V->getType();
2824. // Skip non-pointer types.
  2825. if (!Ty->isPointerTy())
  2826. return false;
  2827. Ty = Ty->getPointerElementType();
2828. // Skip non-aggregate types.
  2829. if (!Ty->isAggregateType())
  2830. return false;
  2831. // Skip matrix types.
  2832. if (HLMatrixLower::IsMatrixType(Ty))
  2833. return false;
  2834. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2835. // Skip HLSL object types.
  2836. if (HLModule::IsHLSLObjectType(ST)) {
  2837. return false;
  2838. }
  2839. unsigned numTypes = ST->getNumContainedTypes();
  2840. Elts.reserve(numTypes);
  2841. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  2842. // Skip empty struct.
  2843. if (SA && SA->IsEmptyStruct())
  2844. return true;
  2845. for (int i = 0, e = numTypes; i != e; ++i) {
  2846. AllocaInst *NA = Builder.CreateAlloca(ST->getContainedType(i), nullptr, V->getName() + "." + Twine(i));
  2847. bool markPrecise = hasPrecise;
  2848. if (SA) {
  2849. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2850. markPrecise |= FA.IsPrecise();
  2851. }
  2852. if (markPrecise)
  2853. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2854. Elts.push_back(NA);
  2855. }
  2856. } else {
  2857. ArrayType *AT = cast<ArrayType>(Ty);
  2858. if (AT->getNumContainedTypes() == 0) {
  2859. // Skip case like [0 x %struct].
  2860. return false;
  2861. }
  2862. Type *ElTy = AT->getElementType();
  2863. SmallVector<ArrayType *, 4> nestArrayTys;
  2864. nestArrayTys.emplace_back(AT);
2865. // Support multiple levels of nested arrays.
  2866. while (ElTy->isArrayTy()) {
  2867. ArrayType *ElAT = cast<ArrayType>(ElTy);
  2868. nestArrayTys.emplace_back(ElAT);
  2869. ElTy = ElAT->getElementType();
  2870. }
  2871. if (ElTy->isStructTy() &&
  2872. // Skip Matrix type.
  2873. !HLMatrixLower::IsMatrixType(ElTy)) {
  2874. if (!HLModule::IsHLSLObjectType(ElTy)) {
2875. // For an array of struct,
2876. // split it into arrays of the struct's elements.
  2877. StructType *ElST = cast<StructType>(ElTy);
  2878. unsigned numTypes = ElST->getNumContainedTypes();
  2879. Elts.reserve(numTypes);
  2880. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
  2881. // Skip empty struct.
  2882. if (SA && SA->IsEmptyStruct())
  2883. return true;
  2884. for (int i = 0, e = numTypes; i != e; ++i) {
  2885. AllocaInst *NA = Builder.CreateAlloca(
  2886. CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys),
  2887. nullptr, V->getName() + "." + Twine(i));
  2888. bool markPrecise = hasPrecise;
  2889. if (SA) {
  2890. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  2891. markPrecise |= FA.IsPrecise();
  2892. }
  2893. if (markPrecise)
  2894. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2895. Elts.push_back(NA);
  2896. }
  2897. } else {
2898. // For a local resource array without dynamic indexing,
  2899. // split it.
  2900. if (dxilutil::HasDynamicIndexing(V) ||
  2901. // Only support 1 dim split.
  2902. nestArrayTys.size() > 1)
  2903. return false;
  2904. for (int i = 0, e = AT->getNumElements(); i != e; ++i) {
  2905. AllocaInst *NA = Builder.CreateAlloca(ElTy, nullptr,
  2906. V->getName() + "." + Twine(i));
  2907. Elts.push_back(NA);
  2908. }
  2909. }
  2910. } else if (ElTy->isVectorTy()) {
  2911. // Skip vector if required.
  2912. if (!bFlatVector)
  2913. return false;
2914. // For an array of vectors,
2915. // split it into arrays of scalars.
  2916. VectorType *ElVT = cast<VectorType>(ElTy);
  2917. Elts.reserve(ElVT->getNumElements());
  2918. ArrayType *scalarArrayTy = CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
  2919. for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
  2920. AllocaInst *NA = Builder.CreateAlloca(scalarArrayTy, nullptr,
  2921. V->getName() + "." + Twine(i));
  2922. if (hasPrecise)
  2923. HLModule::MarkPreciseAttributeWithMetadata(NA);
  2924. Elts.push_back(NA);
  2925. }
  2926. } else
2927. // Skip arrays of basic types.
  2928. return false;
  2929. }
  2930. // Now that we have created the new alloca instructions, rewrite all the
  2931. // uses of the old alloca.
  2932. SROA_Helper helper(V, Elts, DeadInsts);
  2933. helper.RewriteForScalarRepl(V, Builder);
  2934. return true;
  2935. }
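// GetEltInit extracts, for element index idx, the matching slice of an
// aggregate initializer. Illustrative sketch (hypothetical types, not from the
// original source): for Ty = [2 x { float, i32 }], idx = 0 and
// EltTy = [2 x float], the result is the constant array
// [ Init[0].field0, Init[1].field0 ], i.e. the float field gathered across
// the array.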
  2936. static Constant *GetEltInit(Type *Ty, Constant *Init, unsigned idx,
  2937. Type *EltTy) {
  2938. if (isa<UndefValue>(Init))
  2939. return UndefValue::get(EltTy);
  2940. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  2941. return Init->getAggregateElement(idx);
  2942. } else if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  2943. return Init->getAggregateElement(idx);
  2944. } else {
  2945. ArrayType *AT = cast<ArrayType>(Ty);
  2946. ArrayType *EltArrayTy = cast<ArrayType>(EltTy);
  2947. std::vector<Constant *> Elts;
  2948. if (!AT->getElementType()->isArrayTy()) {
  2949. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2950. // Get Array[i]
  2951. Constant *InitArrayElt = Init->getAggregateElement(i);
  2952. // Get Array[i].idx
  2953. InitArrayElt = InitArrayElt->getAggregateElement(idx);
  2954. Elts.emplace_back(InitArrayElt);
  2955. }
  2956. return ConstantArray::get(EltArrayTy, Elts);
  2957. } else {
  2958. Type *EltTy = AT->getElementType();
  2959. ArrayType *NestEltArrayTy = cast<ArrayType>(EltArrayTy->getElementType());
  2960. // Nested array.
  2961. for (unsigned i = 0; i < AT->getNumElements(); i++) {
  2962. // Get Array[i]
  2963. Constant *InitArrayElt = Init->getAggregateElement(i);
  2964. // Get Array[i].idx
  2965. InitArrayElt = GetEltInit(EltTy, InitArrayElt, idx, NestEltArrayTy);
  2966. Elts.emplace_back(InitArrayElt);
  2967. }
  2968. return ConstantArray::get(EltArrayTy, Elts);
  2969. }
  2970. }
  2971. }
2972. /// DoScalarReplacement - Split GV into new GlobalVariables with Builder and save them into Elts.
2973. /// Then do SROA on GV.
  2974. bool SROA_Helper::DoScalarReplacement(GlobalVariable *GV, std::vector<Value *> &Elts,
  2975. IRBuilder<> &Builder, bool bFlatVector,
  2976. bool hasPrecise, DxilTypeSystem &typeSys,
  2977. SmallVector<Value *, 32> &DeadInsts) {
  2978. DEBUG(dbgs() << "Found inst to SROA: " << *GV << '\n');
  2979. Type *Ty = GV->getType();
2980. // Skip non-pointer types.
  2981. if (!Ty->isPointerTy())
  2982. return false;
  2983. Ty = Ty->getPointerElementType();
2984. // Skip non-aggregate types.
  2985. if (!Ty->isAggregateType() && !bFlatVector)
  2986. return false;
  2987. // Skip basic types.
  2988. if (Ty->isSingleValueType() && !Ty->isVectorTy())
  2989. return false;
  2990. // Skip matrix types.
  2991. if (HLMatrixLower::IsMatrixType(Ty))
  2992. return false;
  2993. Module *M = GV->getParent();
  2994. Constant *Init = GV->getInitializer();
  2995. if (!Init)
  2996. Init = UndefValue::get(Ty);
  2997. bool isConst = GV->isConstant();
  2998. GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  2999. unsigned AddressSpace = GV->getType()->getAddressSpace();
  3000. GlobalValue::LinkageTypes linkage = GV->getLinkage();
  3001. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  3002. // Skip HLSL object types.
  3003. if (HLModule::IsHLSLObjectType(ST))
  3004. return false;
  3005. unsigned numTypes = ST->getNumContainedTypes();
  3006. Elts.reserve(numTypes);
  3007. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  3008. for (int i = 0, e = numTypes; i != e; ++i) {
  3009. Constant *EltInit = GetEltInit(Ty, Init, i, ST->getElementType(i));
  3010. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3011. *M, ST->getContainedType(i), /*IsConstant*/ isConst, linkage,
  3012. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3013. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3014. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  3015. // TODO: set precise.
  3016. // if (hasPrecise || FA.IsPrecise())
  3017. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3018. Elts.push_back(EltGV);
  3019. }
  3020. } else if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
3021. // TODO: support dynamic indexing on vectors by changing them to arrays.
  3022. unsigned numElts = VT->getNumElements();
  3023. Elts.reserve(numElts);
  3024. Type *EltTy = VT->getElementType();
  3025. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  3026. for (int i = 0, e = numElts; i != e; ++i) {
  3027. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  3028. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3029. *M, EltTy, /*IsConstant*/ isConst, linkage,
  3030. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3031. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3032. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  3033. // TODO: set precise.
  3034. // if (hasPrecise || FA.IsPrecise())
  3035. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3036. Elts.push_back(EltGV);
  3037. }
  3038. } else {
  3039. ArrayType *AT = cast<ArrayType>(Ty);
  3040. if (AT->getNumContainedTypes() == 0) {
  3041. // Skip case like [0 x %struct].
  3042. return false;
  3043. }
  3044. Type *ElTy = AT->getElementType();
  3045. SmallVector<ArrayType *, 4> nestArrayTys;
  3046. nestArrayTys.emplace_back(AT);
3047. // Support multiple levels of nested arrays.
  3048. while (ElTy->isArrayTy()) {
  3049. ArrayType *ElAT = cast<ArrayType>(ElTy);
  3050. nestArrayTys.emplace_back(ElAT);
  3051. ElTy = ElAT->getElementType();
  3052. }
  3053. if (ElTy->isStructTy() &&
  3054. // Skip Matrix type.
  3055. !HLMatrixLower::IsMatrixType(ElTy)) {
3056. // For an array of struct,
3057. // split it into arrays of the struct's elements.
  3058. StructType *ElST = cast<StructType>(ElTy);
  3059. unsigned numTypes = ElST->getNumContainedTypes();
  3060. Elts.reserve(numTypes);
  3061. //DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ElST);
  3062. for (int i = 0, e = numTypes; i != e; ++i) {
  3063. Type *EltTy =
  3064. CreateNestArrayTy(ElST->getContainedType(i), nestArrayTys);
  3065. Constant *EltInit = GetEltInit(Ty, Init, i, EltTy);
  3066. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3067. *M, EltTy, /*IsConstant*/ isConst, linkage,
  3068. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3069. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3070. //DxilFieldAnnotation &FA = SA->GetFieldAnnotation(i);
  3071. // TODO: set precise.
  3072. // if (hasPrecise || FA.IsPrecise())
  3073. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3074. Elts.push_back(EltGV);
  3075. }
  3076. } else if (ElTy->isVectorTy()) {
  3077. // Skip vector if required.
  3078. if (!bFlatVector)
  3079. return false;
3080. // For an array of vectors,
3081. // split it into arrays of scalars.
  3082. VectorType *ElVT = cast<VectorType>(ElTy);
  3083. Elts.reserve(ElVT->getNumElements());
  3084. ArrayType *scalarArrayTy =
  3085. CreateNestArrayTy(ElVT->getElementType(), nestArrayTys);
  3086. for (int i = 0, e = ElVT->getNumElements(); i != e; ++i) {
  3087. Constant *EltInit = GetEltInit(Ty, Init, i, scalarArrayTy);
  3088. GlobalVariable *EltGV = new llvm::GlobalVariable(
  3089. *M, scalarArrayTy, /*IsConstant*/ isConst, linkage,
  3090. /*InitVal*/ EltInit, GV->getName() + "." + Twine(i),
  3091. /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  3092. // TODO: set precise.
  3093. // if (hasPrecise)
  3094. // HLModule::MarkPreciseAttributeWithMetadata(NA);
  3095. Elts.push_back(EltGV);
  3096. }
  3097. } else
3098. // Skip arrays of basic types.
  3099. return false;
  3100. }
3101. // Now that we have created the new global variables, rewrite all the
3102. // uses of the old global variable.
  3103. SROA_Helper helper(GV, Elts, DeadInsts);
  3104. helper.RewriteForScalarRepl(GV, Builder);
  3105. return true;
  3106. }
  3107. struct PointerStatus {
  3108. /// Keep track of what stores to the pointer look like.
  3109. enum StoredType {
  3110. /// There is no store to this pointer. It can thus be marked constant.
  3111. NotStored,
3112. /// This ptr is a global, and is stored to, but the only thing stored is
3113. /// the constant it was initialized with. This is only tracked for scalar
3114. /// globals.
  3115. InitializerStored,
  3116. /// This ptr is stored to, but only its initializer and one other value
  3117. /// is ever stored to it. If this global isStoredOnce, we track the value
  3118. /// stored to it in StoredOnceValue below. This is only tracked for scalar
  3119. /// globals.
  3120. StoredOnce,
  3121. /// This ptr is only assigned by a memcpy.
  3122. MemcopyDestOnce,
  3123. /// This ptr is stored to by multiple values or something else that we
  3124. /// cannot track.
  3125. Stored
  3126. } StoredType;
3127. /// Keep track of what loads from the pointer look like.
  3128. enum LoadedType {
3129. /// There is no load from this pointer.
  3130. NotLoaded,
3131. /// This ptr is only loaded from by a memcpy.
  3132. MemcopySrcOnce,
3133. /// This ptr is loaded from by multiple instructions or something else that we
  3134. /// cannot track.
  3135. Loaded
  3136. } LoadedType;
  3137. /// If only one value (besides the initializer constant) is ever stored to
  3138. /// this global, keep track of what value it is.
  3139. Value *StoredOnceValue;
3140. /// Memcpys in which this ptr is used.
  3141. std::unordered_set<MemCpyInst *> memcpySet;
3142. /// Memcpy which uses this ptr as its dest.
  3143. MemCpyInst *StoringMemcpy;
3144. /// Memcpy which uses this ptr as its src.
  3145. MemCpyInst *LoadingMemcpy;
  3146. /// These start out null/false. When the first accessing function is noticed,
  3147. /// it is recorded. When a second different accessing function is noticed,
  3148. /// HasMultipleAccessingFunctions is set to true.
  3149. const Function *AccessingFunction;
  3150. bool HasMultipleAccessingFunctions;
  3151. /// Size of the ptr.
  3152. unsigned Size;
3153. /// Look at all uses of the pointer and fill in the PointerStatus structure. It
3154. /// classifies how the pointer is stored to and loaded from, recursing through
3155. /// bitcasts, GEPs, and HL subscript calls.
  3156. static void analyzePointer(const Value *V, PointerStatus &PS,
  3157. DxilTypeSystem &typeSys, bool bStructElt);
  3158. PointerStatus(unsigned size)
  3159. : StoredType(NotStored), LoadedType(NotLoaded), StoredOnceValue(nullptr),
  3160. StoringMemcpy(nullptr), LoadingMemcpy(nullptr),
  3161. AccessingFunction(nullptr), HasMultipleAccessingFunctions(false),
  3162. Size(size) {}
  3163. void MarkAsStored() {
  3164. StoredType = PointerStatus::StoredType::Stored;
  3165. StoredOnceValue = nullptr;
  3166. }
  3167. void MarkAsLoaded() { LoadedType = PointerStatus::LoadedType::Loaded; }
  3168. };
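// Illustrative sketch of the classification below (hypothetical case, not from
// the original source): a global array that is written exactly once by a
// full-size memcpy and only read through loads afterwards ends up with
// StoredType == MemcopyDestOnce, StoringMemcpy pointing at that memcpy, and
// LoadedType == Loaded.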
  3169. void PointerStatus::analyzePointer(const Value *V, PointerStatus &PS,
  3170. DxilTypeSystem &typeSys, bool bStructElt) {
  3171. if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
  3172. if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
  3173. PS.StoredType = PointerStatus::StoredType::InitializerStored;
  3174. }
  3175. }
  3176. for (const User *U : V->users()) {
  3177. if (const Instruction *I = dyn_cast<Instruction>(U)) {
  3178. const Function *F = I->getParent()->getParent();
  3179. if (!PS.AccessingFunction) {
  3180. PS.AccessingFunction = F;
  3181. } else {
  3182. if (F != PS.AccessingFunction)
  3183. PS.HasMultipleAccessingFunctions = true;
  3184. }
  3185. }
  3186. if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(U)) {
  3187. analyzePointer(BC, PS, typeSys, bStructElt);
  3188. } else if (const MemCpyInst *MC = dyn_cast<MemCpyInst>(U)) {
3189. // Do not collect memcpys on struct GEP uses.
3190. // These memcpys will be flattened at the next level.
  3191. if (!bStructElt) {
  3192. MemCpyInst *MI = const_cast<MemCpyInst *>(MC);
  3193. PS.memcpySet.insert(MI);
  3194. bool bFullCopy = false;
  3195. if (ConstantInt *Length = dyn_cast<ConstantInt>(MC->getLength())) {
  3196. bFullCopy = PS.Size == Length->getLimitedValue()
  3197. || PS.Size == 0 || Length->getLimitedValue() == 0; // handle unbounded arrays
  3198. }
  3199. if (MC->getRawDest() == V) {
  3200. if (bFullCopy &&
  3201. PS.StoredType == PointerStatus::StoredType::NotStored) {
  3202. PS.StoredType = PointerStatus::StoredType::MemcopyDestOnce;
  3203. PS.StoringMemcpy = MI;
  3204. } else {
  3205. PS.MarkAsStored();
  3206. PS.StoringMemcpy = nullptr;
  3207. }
  3208. } else if (MC->getRawSource() == V) {
  3209. if (bFullCopy &&
  3210. PS.LoadedType == PointerStatus::LoadedType::NotLoaded) {
  3211. PS.LoadedType = PointerStatus::LoadedType::MemcopySrcOnce;
  3212. PS.LoadingMemcpy = MI;
  3213. } else {
  3214. PS.MarkAsLoaded();
  3215. PS.LoadingMemcpy = nullptr;
  3216. }
  3217. }
  3218. } else {
  3219. if (MC->getRawDest() == V) {
  3220. PS.MarkAsStored();
  3221. } else {
  3222. DXASSERT(MC->getRawSource() == V, "must be source here");
  3223. PS.MarkAsLoaded();
  3224. }
  3225. }
  3226. } else if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  3227. gep_type_iterator GEPIt = gep_type_begin(GEP);
  3228. gep_type_iterator GEPEnd = gep_type_end(GEP);
  3229. // Skip pointer idx.
  3230. GEPIt++;
  3231. // Struct elt will be flattened in next level.
  3232. bool bStructElt = (GEPIt != GEPEnd) && GEPIt->isStructTy();
  3233. analyzePointer(GEP, PS, typeSys, bStructElt);
  3234. } else if (const StoreInst *SI = dyn_cast<StoreInst>(U)) {
  3235. Value *V = SI->getOperand(0);
  3236. if (PS.StoredType == PointerStatus::StoredType::NotStored) {
  3237. PS.StoredType = PointerStatus::StoredType::StoredOnce;
  3238. PS.StoredOnceValue = V;
  3239. } else {
  3240. PS.MarkAsStored();
  3241. }
  3242. } else if (const LoadInst *LI = dyn_cast<LoadInst>(U)) {
  3243. PS.MarkAsLoaded();
  3244. } else if (const CallInst *CI = dyn_cast<CallInst>(U)) {
  3245. Function *F = CI->getCalledFunction();
  3246. DxilFunctionAnnotation *annotation = typeSys.GetFunctionAnnotation(F);
  3247. if (!annotation) {
  3248. HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
  3249. switch (group) {
  3250. case HLOpcodeGroup::HLMatLoadStore: {
  3251. HLMatLoadStoreOpcode opcode =
  3252. static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
  3253. switch (opcode) {
  3254. case HLMatLoadStoreOpcode::ColMatLoad:
  3255. case HLMatLoadStoreOpcode::RowMatLoad:
  3256. PS.MarkAsLoaded();
  3257. break;
  3258. case HLMatLoadStoreOpcode::ColMatStore:
  3259. case HLMatLoadStoreOpcode::RowMatStore:
  3260. PS.MarkAsStored();
  3261. break;
  3262. default:
  3263. DXASSERT(0, "invalid opcode");
  3264. PS.MarkAsStored();
  3265. PS.MarkAsLoaded();
  3266. }
  3267. } break;
  3268. case HLOpcodeGroup::HLSubscript: {
  3269. HLSubscriptOpcode opcode =
  3270. static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(CI));
  3271. switch (opcode) {
  3272. case HLSubscriptOpcode::VectorSubscript:
  3273. case HLSubscriptOpcode::ColMatElement:
  3274. case HLSubscriptOpcode::ColMatSubscript:
  3275. case HLSubscriptOpcode::RowMatElement:
  3276. case HLSubscriptOpcode::RowMatSubscript:
  3277. analyzePointer(CI, PS, typeSys, bStructElt);
  3278. break;
  3279. default:
// The rest are resource ptrs like buf[i];
// only the resource handle is read.
  3282. PS.MarkAsLoaded();
  3283. break;
  3284. }
  3285. } break;
  3286. default: {
// If we are not sure whether it is an out param, treat it as one.
  3288. PS.MarkAsStored();
  3289. PS.MarkAsLoaded();
  3290. }
  3291. }
  3292. continue;
  3293. }
  3294. unsigned argSize = F->arg_size();
  3295. for (unsigned i = 0; i < argSize; i++) {
  3296. Value *arg = CI->getArgOperand(i);
  3297. if (V == arg) {
// Do not replace struct args.
// Mark as both stored and loaded to disable replacement.
  3300. PS.MarkAsStored();
  3301. PS.MarkAsLoaded();
  3302. }
  3303. }
  3304. }
  3305. }
  3306. }
  3307. static void ReplaceConstantWithInst(Constant *C, Value *V, IRBuilder<> &Builder) {
  3308. for (auto it = C->user_begin(); it != C->user_end(); ) {
  3309. User *U = *(it++);
  3310. if (Instruction *I = dyn_cast<Instruction>(U)) {
  3311. I->replaceUsesOfWith(C, V);
  3312. } else {
  3313. ConstantExpr *CE = cast<ConstantExpr>(U);
  3314. Instruction *Inst = CE->getAsInstruction();
  3315. Builder.Insert(Inst);
  3316. Inst->replaceUsesOfWith(C, V);
  3317. ReplaceConstantWithInst(CE, Inst, Builder);
  3318. }
  3319. }
  3320. }
  3321. static void ReplaceUnboundedArrayUses(Value *V, Value *Src, IRBuilder<> &Builder) {
  3322. for (auto it = V->user_begin(); it != V->user_end(); ) {
  3323. User *U = *(it++);
  3324. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
  3325. SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  3326. Value *NewGEP = Builder.CreateGEP(Src, idxList);
  3327. GEP->replaceAllUsesWith(NewGEP);
  3328. } else if (BitCastInst *BC = dyn_cast<BitCastInst>(U)) {
  3329. BC->setOperand(0, Src);
  3330. } else {
  3331. DXASSERT(false, "otherwise unbounded array used in unexpected instruction");
  3332. }
  3333. }
  3334. }
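// ReplaceMemcpy forwards Src in place of V for a single full-size memcpy:
// if V is a constant and Src is not, uses are rewritten via
// ReplaceConstantWithInst; mismatched unbounded-array types go through
// ReplaceUnboundedArrayUses; otherwise V's uses are simply replaced with Src.
// The memcpy and any now-unused dest/src instructions are then erased.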
  3335. static void ReplaceMemcpy(Value *V, Value *Src, MemCpyInst *MC) {
  3336. if (Constant *C = dyn_cast<Constant>(V)) {
  3337. if (isa<Constant>(Src)) {
  3338. V->replaceAllUsesWith(Src);
  3339. } else {
  3340. // Replace Constant with a non-Constant.
  3341. IRBuilder<> Builder(MC);
  3342. ReplaceConstantWithInst(C, Src, Builder);
  3343. }
  3344. } else {
  3345. Type* TyV = V->getType()->getPointerElementType();
  3346. Type* TySrc = Src->getType()->getPointerElementType();
  3347. if (TyV == TySrc) {
  3348. if (V != Src)
  3349. V->replaceAllUsesWith(Src);
  3350. } else {
  3351. DXASSERT((TyV->isArrayTy() && TySrc->isArrayTy()) &&
  3352. (TyV->getArrayNumElements() == 0 ||
  3353. TySrc->getArrayNumElements() == 0),
  3354. "otherwise mismatched types in memcpy are not unbounded array");
  3355. IRBuilder<> Builder(MC);
  3356. ReplaceUnboundedArrayUses(V, Src, Builder);
  3357. }
  3358. }
  3359. Value *RawDest = MC->getOperand(0);
  3360. Value *RawSrc = MC->getOperand(1);
  3361. MC->eraseFromParent();
  3362. if (Instruction *I = dyn_cast<Instruction>(RawDest)) {
  3363. if (I->user_empty())
  3364. I->eraseFromParent();
  3365. }
  3366. if (Instruction *I = dyn_cast<Instruction>(RawSrc)) {
  3367. if (I->user_empty())
  3368. I->eraseFromParent();
  3369. }
  3370. }
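// LowerMemcpy: if V is written by exactly one full-size memcpy (dest-once) or
// read by exactly one full-size memcpy (src-once) and the replacement is safe,
// forward one pointer to the other via ReplaceMemcpy; otherwise split every
// collected memcpy into per-element copies with MemcpySplitter::SplitMemCpy.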
  3371. bool SROA_Helper::LowerMemcpy(Value *V, DxilFieldAnnotation *annotation,
  3372. DxilTypeSystem &typeSys, const DataLayout &DL,
  3373. bool bAllowReplace) {
  3374. Type *Ty = V->getType();
  3375. if (!Ty->isPointerTy()) {
  3376. return false;
  3377. }
// Get access status and collect memcpy uses.
// If there is a single full-size memcpy, replace dest with src as long as
// dest is not an out param; otherwise split (flatten) the memcpy.
  3381. unsigned size = DL.getTypeAllocSize(Ty->getPointerElementType());
  3382. PointerStatus PS(size);
  3383. const bool bStructElt = false;
  3384. PointerStatus::analyzePointer(V, PS, typeSys, bStructElt);
  3385. if (bAllowReplace && !PS.HasMultipleAccessingFunctions) {
  3386. if (PS.StoredType == PointerStatus::StoredType::MemcopyDestOnce &&
// Skip arguments: an input argument carries an incoming value, so it is not dest-once anymore.
  3388. !isa<Argument>(V)) {
  3389. // Replace with src of memcpy.
  3390. MemCpyInst *MC = PS.StoringMemcpy;
  3391. if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
  3392. Value *Src = MC->getOperand(1);
  3393. // Only remove one level bitcast generated from inline.
  3394. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Src))
  3395. Src = BC->getOperand(0);
  3396. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
// For a GEP, the base ptr could have other GEP reads/writes,
// so scanning just this one GEP is not enough.
  3399. Value *Ptr = GEP->getPointerOperand();
  3400. if (CallInst *PtrCI = dyn_cast<CallInst>(Ptr)) {
  3401. hlsl::HLOpcodeGroup group =
  3402. hlsl::GetHLOpcodeGroup(PtrCI->getCalledFunction());
  3403. if (group == HLOpcodeGroup::HLSubscript) {
  3404. HLSubscriptOpcode opcode =
  3405. static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(PtrCI));
  3406. if (opcode == HLSubscriptOpcode::CBufferSubscript) {
  3407. // Ptr from CBuffer is safe.
  3408. ReplaceMemcpy(V, Src, MC);
  3409. return true;
  3410. }
  3411. }
  3412. }
  3413. } else if (!isa<CallInst>(Src)) {
// Resource ptrs should not be replaced.
// Need to make sure Src is not updated after the current memcpy.
// For now, check that Src is stored at most once.
  3417. PointerStatus SrcPS(size);
  3418. PointerStatus::analyzePointer(Src, SrcPS, typeSys, bStructElt);
  3419. if (SrcPS.StoredType != PointerStatus::StoredType::Stored) {
  3420. ReplaceMemcpy(V, Src, MC);
  3421. return true;
  3422. }
  3423. }
  3424. }
  3425. } else if (PS.LoadedType == PointerStatus::LoadedType::MemcopySrcOnce) {
  3426. // Replace dst of memcpy.
  3427. MemCpyInst *MC = PS.LoadingMemcpy;
  3428. if (MC->getSourceAddressSpace() == MC->getDestAddressSpace()) {
  3429. Value *Dest = MC->getOperand(0);
  3430. // Only remove one level bitcast generated from inline.
  3431. if (BitCastOperator *BC = dyn_cast<BitCastOperator>(Dest))
  3432. Dest = BC->getOperand(0);
// For a GEP, the base ptr could have other GEP reads/writes,
// so scanning just this one GEP is not enough.
// Also, resource ptrs should not be replaced.
  3436. if (!isa<GEPOperator>(Dest) && !isa<CallInst>(Dest) &&
  3437. !isa<BitCastOperator>(Dest)) {
// Need to make sure Dest is not updated after the current memcpy.
// For now, check that Dest is stored at most once.
  3440. PointerStatus DestPS(size);
  3441. PointerStatus::analyzePointer(Dest, DestPS, typeSys, bStructElt);
  3442. if (DestPS.StoredType != PointerStatus::StoredType::Stored) {
  3443. ReplaceMemcpy(Dest, V, MC);
// V still needs to be flattened.
// Lower the memcpys that came from Dest.
  3446. return LowerMemcpy(V, annotation, typeSys, DL, bAllowReplace);
  3447. }
  3448. }
  3449. }
  3450. }
  3451. }
  3452. for (MemCpyInst *MC : PS.memcpySet) {
  3453. MemcpySplitter::SplitMemCpy(MC, DL, annotation, typeSys);
  3454. }
  3455. return false;
  3456. }
/// MarkEmptyStructUsers - Add instructions related to an empty struct to DeadInsts.
  3458. void SROA_Helper::MarkEmptyStructUsers(Value *V, SmallVector<Value *, 32> &DeadInsts) {
  3459. for (User *U : V->users()) {
  3460. MarkEmptyStructUsers(U, DeadInsts);
  3461. }
  3462. if (Instruction *I = dyn_cast<Instruction>(V)) {
// Only instructions with no uses need to be added here;
// DeleteDeadInstructions will delete everything reachable from them.
  3465. if (I->user_empty())
  3466. DeadInsts.emplace_back(I);
  3467. }
  3468. }
  3469. bool SROA_Helper::IsEmptyStructType(Type *Ty, DxilTypeSystem &typeSys) {
  3470. if (isa<ArrayType>(Ty))
  3471. Ty = Ty->getArrayElementType();
  3472. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  3473. if (!HLMatrixLower::IsMatrixType(Ty)) {
  3474. DxilStructAnnotation *SA = typeSys.GetStructAnnotation(ST);
  3475. if (SA && SA->IsEmptyStruct())
  3476. return true;
  3477. }
  3478. }
  3479. return false;
  3480. }
  3481. //===----------------------------------------------------------------------===//
  3482. // SROA on function parameters.
  3483. //===----------------------------------------------------------------------===//
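// The SROA_Parameter_HLSL pass below flattens aggregate function parameters
// and return values into scalars/flat arrays, rewrites every call site to
// match the flattened signatures, and then flattens static and groupshared
// globals, dropping globals that end up store-only.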
  3484. namespace {
  3485. class SROA_Parameter_HLSL : public ModulePass {
  3486. HLModule *m_pHLModule;
  3487. public:
  3488. static char ID; // Pass identification, replacement for typeid
  3489. explicit SROA_Parameter_HLSL() : ModulePass(ID) {}
  3490. const char *getPassName() const override { return "SROA Parameter HLSL"; }
  3491. bool runOnModule(Module &M) override {
// Patch memcpys to cover the case where bitcast (gep ptr, 0, 0) has been
// transformed into bitcast ptr.
  3494. MemcpySplitter::PatchMemCpyWithZeroIdxGEP(M);
  3495. m_pHLModule = &M.GetOrCreateHLModule();
  3496. // Load up debug information, to cross-reference values and the instructions
  3497. // used to load them.
  3498. m_HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  3499. std::deque<Function *> WorkList;
  3500. for (Function &F : M.functions()) {
  3501. HLOpcodeGroup group = GetHLOpcodeGroup(&F);
  3502. // Skip HL operations.
  3503. if (group != HLOpcodeGroup::NotHL || group == HLOpcodeGroup::HLExtIntrinsic) {
  3504. continue;
  3505. }
  3506. if (F.isDeclaration()) {
  3507. // Skip llvm intrinsic.
  3508. if (F.isIntrinsic())
  3509. continue;
  3510. // Skip unused external function.
  3511. if (F.user_empty())
  3512. continue;
  3513. }
  3514. // Skip void(void) functions.
  3515. if (F.getReturnType()->isVoidTy() && F.arg_size() == 0)
  3516. continue;
  3517. WorkList.emplace_back(&F);
  3518. }
  3519. // Preprocess aggregate function param used as function call arg.
  3520. for (Function *F : WorkList) {
  3521. preprocessArgUsedInCall(F);
  3522. }
  3523. // Process the worklist
  3524. while (!WorkList.empty()) {
  3525. Function *F = WorkList.front();
  3526. WorkList.pop_front();
  3527. createFlattenedFunction(F);
  3528. }
// Replace functions with their flattened versions once all functions are flattened.
  3530. for (auto Iter : funcMap)
  3531. replaceCall(Iter.first, Iter.second);
  3532. // Remove flattened functions.
  3533. for (auto Iter : funcMap) {
  3534. Function *F = Iter.first;
  3535. Function *flatF = Iter.second;
  3536. flatF->takeName(F);
  3537. F->eraseFromParent();
  3538. }
  3539. // Flatten internal global.
  3540. std::vector<GlobalVariable *> staticGVs;
  3541. for (GlobalVariable &GV : M.globals()) {
  3542. if (dxilutil::IsStaticGlobal(&GV) ||
  3543. dxilutil::IsSharedMemoryGlobal(&GV)) {
  3544. staticGVs.emplace_back(&GV);
  3545. } else {
  3546. // merge GEP use for global.
  3547. HLModule::MergeGepUse(&GV);
  3548. }
  3549. }
  3550. for (GlobalVariable *GV : staticGVs)
  3551. flattenGlobal(GV);
  3552. // Remove unused internal global.
  3553. staticGVs.clear();
  3554. for (GlobalVariable &GV : M.globals()) {
  3555. if (dxilutil::IsStaticGlobal(&GV) ||
  3556. dxilutil::IsSharedMemoryGlobal(&GV)) {
  3557. staticGVs.emplace_back(&GV);
  3558. }
  3559. }
  3560. for (GlobalVariable *GV : staticGVs) {
  3561. bool onlyStoreUse = true;
  3562. for (User *user : GV->users()) {
  3563. if (isa<StoreInst>(user))
  3564. continue;
  3565. if (isa<ConstantExpr>(user) && user->user_empty())
  3566. continue;
  3567. // Check matrix store.
  3568. if (HLMatrixLower::IsMatrixType(
  3569. GV->getType()->getPointerElementType())) {
  3570. if (CallInst *CI = dyn_cast<CallInst>(user)) {
  3571. if (GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
  3572. HLOpcodeGroup::HLMatLoadStore) {
  3573. HLMatLoadStoreOpcode opcode =
  3574. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
  3575. if (opcode == HLMatLoadStoreOpcode::ColMatStore ||
  3576. opcode == HLMatLoadStoreOpcode::RowMatStore)
  3577. continue;
  3578. }
  3579. }
  3580. }
  3581. onlyStoreUse = false;
  3582. break;
  3583. }
  3584. if (onlyStoreUse) {
  3585. for (auto UserIt = GV->user_begin(); UserIt != GV->user_end();) {
  3586. Value *User = *(UserIt++);
  3587. if (Instruction *I = dyn_cast<Instruction>(User)) {
  3588. I->eraseFromParent();
  3589. }
  3590. else {
  3591. ConstantExpr *CE = cast<ConstantExpr>(User);
  3592. CE->dropAllReferences();
  3593. }
  3594. }
  3595. GV->eraseFromParent();
  3596. }
  3597. }
  3598. return true;
  3599. }
  3600. private:
  3601. void DeleteDeadInstructions();
  3602. void preprocessArgUsedInCall(Function *F);
  3603. void moveFunctionBody(Function *F, Function *flatF);
  3604. void replaceCall(Function *F, Function *flatF);
  3605. void createFlattenedFunction(Function *F);
  3606. void createFlattenedFunctionCall(Function *F, Function *flatF, CallInst *CI);
  3607. void
  3608. flattenArgument(Function *F, Value *Arg, bool bForParam,
  3609. DxilParameterAnnotation &paramAnnotation,
  3610. std::vector<Value *> &FlatParamList,
  3611. std::vector<DxilParameterAnnotation> &FlatRetAnnotationList,
  3612. IRBuilder<> &Builder, DbgDeclareInst *DDI);
  3613. Value *castArgumentIfRequired(Value *V, Type *Ty, bool bOut,
  3614. bool hasShaderInputOutput,
  3615. DxilParamInputQual inputQual,
  3616. DxilFieldAnnotation &annotation,
  3617. std::deque<Value *> &WorkList,
  3618. IRBuilder<> &Builder);
  3619. // Replace argument which changed type when flatten.
  3620. void replaceCastArgument(Value *&NewArg, Value *OldArg,
  3621. DxilParamInputQual inputQual,
  3622. IRBuilder<> &CallBuilder, IRBuilder<> &RetBuilder);
  3623. // Replace use of parameter which changed type when flatten.
  3624. // Also add information to Arg if required.
  3625. void replaceCastParameter(Value *NewParam, Value *OldParam, Function &F,
  3626. Argument *Arg, const DxilParamInputQual inputQual,
  3627. IRBuilder<> &Builder);
  3628. void allocateSemanticIndex(
  3629. std::vector<DxilParameterAnnotation> &FlatAnnotationList,
  3630. unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap);
  3631. bool hasDynamicVectorIndexing(Value *V);
  3632. void flattenGlobal(GlobalVariable *GV);
  3633. /// DeadInsts - Keep track of instructions we have made dead, so that
  3634. /// we can remove them after we are done working.
  3635. SmallVector<Value *, 32> DeadInsts;
// Map from the original function to its flattened version.
std::unordered_map<Function *, Function *> funcMap;
// Map from the original arg/param to its flattened, cast version.
std::unordered_map<Value *, std::pair<Value*, DxilParamInputQual>> castParamMap;
// Map from the first element of a vector to the list of all elements of the vector.
std::unordered_map<Value *, SmallVector<Value*, 4> > vectorEltsMap;
// Set of row-major matrix parameters.
std::unordered_set<Value *> castRowMajorParamMap;
  3644. bool m_HasDbgInfo;
  3645. };
  3646. }
  3647. char SROA_Parameter_HLSL::ID = 0;
  3648. INITIALIZE_PASS(SROA_Parameter_HLSL, "scalarrepl-param-hlsl",
  3649. "Scalar Replacement of Aggregates HLSL (parameters)", false,
  3650. false)
  3651. /// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,
  3652. /// recursively including all their operands that become trivially dead.
  3653. void SROA_Parameter_HLSL::DeleteDeadInstructions() {
  3654. while (!DeadInsts.empty()) {
  3655. Instruction *I = cast<Instruction>(DeadInsts.pop_back_val());
  3656. for (User::op_iterator OI = I->op_begin(), E = I->op_end(); OI != E; ++OI)
  3657. if (Instruction *U = dyn_cast<Instruction>(*OI)) {
  3658. // Zero out the operand and see if it becomes trivially dead.
  3659. // (But, don't add allocas to the dead instruction list -- they are
  3660. // already on the worklist and will be deleted separately.)
  3661. *OI = nullptr;
  3662. if (isInstructionTriviallyDead(U) && !isa<AllocaInst>(U))
  3663. DeadInsts.push_back(U);
  3664. }
  3665. I->eraseFromParent();
  3666. }
  3667. }
  3668. bool SROA_Parameter_HLSL::hasDynamicVectorIndexing(Value *V) {
  3669. for (User *U : V->users()) {
  3670. if (!U->getType()->isPointerTy())
  3671. continue;
  3672. if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  3673. gep_type_iterator GEPIt = gep_type_begin(U), E = gep_type_end(U);
  3674. for (; GEPIt != E; ++GEPIt) {
  3675. if (isa<VectorType>(*GEPIt)) {
  3676. Value *VecIdx = GEPIt.getOperand();
  3677. if (!isa<ConstantInt>(VecIdx))
  3678. return true;
  3679. }
  3680. }
  3681. }
  3682. }
  3683. return false;
  3684. }
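// For example (illustrative), a GEP such as
//   getelementptr <4 x float>, <4 x float>* %p, i32 0, i32 %i
// with a non-constant %i is what hasDynamicVectorIndexing reports as dynamic,
// and it prevents flattening the vector part of the global below.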
  3685. void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
  3686. Type *Ty = GV->getType()->getPointerElementType();
  3687. // Skip basic types.
  3688. if (!Ty->isAggregateType() && !Ty->isVectorTy())
  3689. return;
  3690. std::deque<Value *> WorkList;
  3691. WorkList.push_back(GV);
  3692. // merge GEP use for global.
  3693. HLModule::MergeGepUse(GV);
  3694. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  3695. // Only used to create ConstantExpr.
  3696. IRBuilder<> Builder(m_pHLModule->GetCtx());
  3697. std::vector<Instruction*> deadAllocas;
  3698. const DataLayout &DL = GV->getParent()->getDataLayout();
  3699. unsigned debugOffset = 0;
  3700. std::unordered_map<Value*, StringRef> EltNameMap;
  3701. // Process the worklist
  3702. while (!WorkList.empty()) {
  3703. GlobalVariable *EltGV = cast<GlobalVariable>(WorkList.front());
  3704. WorkList.pop_front();
  3705. const bool bAllowReplace = true;
if (SROA_Helper::LowerMemcpy(EltGV, /*annotation*/ nullptr, dxilTypeSys, DL,
bAllowReplace)) {
continue;
}
// Flatten the global vector if there is no dynamic vector indexing.
bool bFlatVector = !hasDynamicVectorIndexing(EltGV);
  3712. std::vector<Value *> Elts;
  3713. bool SROAed = SROA_Helper::DoScalarReplacement(
  3714. EltGV, Elts, Builder, bFlatVector,
  3715. // TODO: set precise.
  3716. /*hasPrecise*/ false,
  3717. dxilTypeSys, DeadInsts);
  3718. if (SROAed) {
// Push Elts onto the worklist.
// Use rbegin so the element order does not change.
  3721. for (auto iter = Elts.rbegin(); iter != Elts.rend(); iter++) {
  3722. WorkList.push_front(*iter);
  3723. if (m_HasDbgInfo) {
  3724. StringRef EltName = (*iter)->getName().ltrim(GV->getName());
  3725. EltNameMap[*iter] = EltName;
  3726. }
  3727. }
  3728. EltGV->removeDeadConstantUsers();
// Now erase any instructions that were made dead while rewriting the
// global.
  3731. DeleteDeadInstructions();
  3732. ++NumReplaced;
  3733. } else {
  3734. // Add debug info for flattened globals.
  3735. if (m_HasDbgInfo && GV != EltGV) {
  3736. DebugInfoFinder &Finder = m_pHLModule->GetOrCreateDebugInfoFinder();
  3737. Type *Ty = EltGV->getType()->getElementType();
  3738. unsigned size = DL.getTypeAllocSizeInBits(Ty);
  3739. unsigned align = DL.getPrefTypeAlignment(Ty);
  3740. HLModule::CreateElementGlobalVariableDebugInfo(
  3741. GV, Finder, EltGV, size, align, debugOffset,
  3742. EltNameMap[EltGV]);
  3743. debugOffset += size;
  3744. }
  3745. }
  3746. }
  3747. DeleteDeadInstructions();
  3748. if (GV->user_empty()) {
  3749. GV->removeDeadConstantUsers();
  3750. GV->eraseFromParent();
  3751. }
  3752. }
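// GetEltAnnotation returns the field annotation for element idx of the
// (array-peeled) struct type, and falls back to the parent annotation for
// matrices and types without a struct annotation.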
  3753. static DxilFieldAnnotation &GetEltAnnotation(Type *Ty, unsigned idx, DxilFieldAnnotation &annotation, DxilTypeSystem &dxilTypeSys) {
  3754. while (Ty->isArrayTy())
  3755. Ty = Ty->getArrayElementType();
  3756. if (StructType *ST = dyn_cast<StructType>(Ty)) {
  3757. if (HLMatrixLower::IsMatrixType(Ty))
  3758. return annotation;
  3759. DxilStructAnnotation *SA = dxilTypeSys.GetStructAnnotation(ST);
  3760. if (SA) {
  3761. DxilFieldAnnotation &FA = SA->GetFieldAnnotation(idx);
  3762. return FA;
  3763. }
  3764. }
  3765. return annotation;
  3766. }
// Note: Semantic index allocation.
// Semantic indices are allocated based on the linear layout.
// For the following code:
/*
struct S {
float4 m;
float4 m2;
};
S s[2] : semantic;
struct S2 {
float4 m[2];
float4 m2[2];
};
S2 s2 : semantic;
*/
// The semantic indices are assigned like this:
// s[0].m : semantic0
// s[0].m2 : semantic1
// s[1].m : semantic2
// s[1].m2 : semantic3
// s2.m[0] : semantic0
// s2.m[1] : semantic1
// s2.m2[0] : semantic2
// s2.m2[1] : semantic3
// But when the argument is flattened, the result looks like this:
// float4 s_m[2], float4 s_m2[2].
// float4 s2_m[2], float4 s2_m2[2].
// To do the allocation, we need to map each element to its flattened argument.
// Say the arg index of float4 s_m[2] is 0 and that of float4 s_m2[2] is 1.
// We need to get 0 from s[0].m and s[1].m, and 1 from s[0].m2 and s[1].m2.
// Arguments with the same semantic string are allocated from the type where
// the semantic starts (S2 for s2.m[2] and s2.m2[2]).
// Iterate over each element of the type, saving the semantic index and updating it.
// The map from an element to its arg (s[0].m2 -> s_m2[2]) is done via argIdx;
// argIdx is only incremented by 1 when a struct field is finished.
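// Following the example above, that means s_m[2] ends up carrying semantic
// indices {0, 2} and s_m2[2] carries {1, 3}, while s2_m[2] carries {0, 1}
// and s2_m2[2] carries {2, 3}.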
  3802. static unsigned AllocateSemanticIndex(
  3803. Type *Ty, unsigned &semIndex, unsigned argIdx, unsigned endArgIdx,
  3804. std::vector<DxilParameterAnnotation> &FlatAnnotationList) {
  3805. if (Ty->isPointerTy()) {
  3806. return AllocateSemanticIndex(Ty->getPointerElementType(), semIndex, argIdx,
  3807. endArgIdx, FlatAnnotationList);
  3808. } else if (Ty->isArrayTy()) {
  3809. unsigned arraySize = Ty->getArrayNumElements();
  3810. unsigned updatedArgIdx = argIdx;
  3811. Type *EltTy = Ty->getArrayElementType();
  3812. for (unsigned i = 0; i < arraySize; i++) {
  3813. updatedArgIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
  3814. FlatAnnotationList);
  3815. }
  3816. return updatedArgIdx;
  3817. } else if (Ty->isStructTy() && !HLMatrixLower::IsMatrixType(Ty)) {
  3818. unsigned fieldsCount = Ty->getStructNumElements();
  3819. for (unsigned i = 0; i < fieldsCount; i++) {
  3820. Type *EltTy = Ty->getStructElementType(i);
  3821. argIdx = AllocateSemanticIndex(EltTy, semIndex, argIdx, endArgIdx,
  3822. FlatAnnotationList);
  3823. if (!(EltTy->isStructTy() && !HLMatrixLower::IsMatrixType(EltTy))) {
  3824. // Update argIdx only when it is a leaf node.
  3825. argIdx++;
  3826. }
  3827. }
  3828. return argIdx;
  3829. } else {
  3830. DXASSERT(argIdx < endArgIdx, "arg index out of bound");
  3831. DxilParameterAnnotation &paramAnnotation = FlatAnnotationList[argIdx];
  3832. // Get element size.
  3833. unsigned rows = 1;
  3834. if (paramAnnotation.HasMatrixAnnotation()) {
  3835. const DxilMatrixAnnotation &matrix =
  3836. paramAnnotation.GetMatrixAnnotation();
  3837. if (matrix.Orientation == MatrixOrientation::RowMajor) {
  3838. rows = matrix.Rows;
  3839. } else {
  3840. DXASSERT(matrix.Orientation == MatrixOrientation::ColumnMajor, "");
  3841. rows = matrix.Cols;
  3842. }
  3843. }
  3844. // Save semIndex.
  3845. for (unsigned i = 0; i < rows; i++)
  3846. paramAnnotation.AppendSemanticIndex(semIndex + i);
  3847. // Update semIndex.
  3848. semIndex += rows;
  3849. return argIdx;
  3850. }
  3851. }
  3852. void SROA_Parameter_HLSL::allocateSemanticIndex(
  3853. std::vector<DxilParameterAnnotation> &FlatAnnotationList,
  3854. unsigned startArgIndex, llvm::StringMap<Type *> &semanticTypeMap) {
  3855. unsigned endArgIndex = FlatAnnotationList.size();
  3856. // Allocate semantic index.
  3857. for (unsigned i = startArgIndex; i < endArgIndex; ++i) {
  3858. // Group by semantic names.
  3859. DxilParameterAnnotation &flatParamAnnotation = FlatAnnotationList[i];
  3860. const std::string &semantic = flatParamAnnotation.GetSemanticString();
  3861. // If semantic is undefined, an error will be emitted elsewhere. For now,
  3862. // we should avoid asserting.
  3863. if (semantic.empty())
  3864. continue;
  3865. unsigned semGroupEnd = i + 1;
  3866. while (semGroupEnd < endArgIndex &&
  3867. FlatAnnotationList[semGroupEnd].GetSemanticString() == semantic) {
  3868. ++semGroupEnd;
  3869. }
  3870. StringRef baseSemName; // The 'FOO' in 'FOO1'.
  3871. uint32_t semIndex; // The '1' in 'FOO1'
  3872. // Split semName and index.
  3873. Semantic::DecomposeNameAndIndex(semantic, &baseSemName, &semIndex);
DXASSERT(semanticTypeMap.count(semantic) > 0, "Must have semantic type");
  3875. Type *semanticTy = semanticTypeMap[semantic];
  3876. AllocateSemanticIndex(semanticTy, semIndex, /*argIdx*/ i,
  3877. /*endArgIdx*/ semGroupEnd, FlatAnnotationList);
  3878. // Update i.
  3879. i = semGroupEnd - 1;
  3880. }
  3881. }
  3882. //
  3883. // Cast parameters.
  3884. //
  3885. static void CopyHandleToResourcePtr(Value *Handle, Value *ResPtr, HLModule &HLM,
  3886. IRBuilder<> &Builder) {
  3887. // Cast it to resource.
  3888. Type *ResTy = ResPtr->getType()->getPointerElementType();
  3889. Value *Res = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  3890. (unsigned)HLCastOpcode::HandleToResCast,
  3891. ResTy, {Handle}, *HLM.GetModule());
// Store the cast resource to ResPtr.
  3893. Builder.CreateStore(Res, ResPtr);
  3894. }
  3895. static void CopyHandlePtrToResourcePtr(Value *HandlePtr, Value *ResPtr,
  3896. HLModule &HLM, IRBuilder<> &Builder) {
  3897. // Load the handle.
  3898. Value *Handle = Builder.CreateLoad(HandlePtr);
  3899. CopyHandleToResourcePtr(Handle, ResPtr, HLM, Builder);
  3900. }
  3901. static Value *CastResourcePtrToHandle(Value *Res, Type *HandleTy, HLModule &HLM,
  3902. IRBuilder<> &Builder) {
// Load the resource.
  3904. Value *LdRes = Builder.CreateLoad(Res);
  3905. Value *Handle = HLM.EmitHLOperationCall(
  3906. Builder, HLOpcodeGroup::HLCreateHandle,
  3907. /*opcode*/ 0, HandleTy, {LdRes}, *HLM.GetModule());
  3908. return Handle;
  3909. }
  3910. static void CopyResourcePtrToHandlePtr(Value *Res, Value *HandlePtr,
  3911. HLModule &HLM, IRBuilder<> &Builder) {
  3912. Type *HandleTy = HandlePtr->getType()->getPointerElementType();
  3913. Value *Handle = CastResourcePtrToHandle(Res, HandleTy, HLM, Builder);
  3914. Builder.CreateStore(Handle, HandlePtr);
  3915. }
  3916. static void CopyVectorPtrToEltsPtr(Value *VecPtr, ArrayRef<Value *> elts,
  3917. unsigned vecSize, IRBuilder<> &Builder) {
  3918. Value *Vec = Builder.CreateLoad(VecPtr);
  3919. for (unsigned i = 0; i < vecSize; i++) {
  3920. Value *Elt = Builder.CreateExtractElement(Vec, i);
  3921. Builder.CreateStore(Elt, elts[i]);
  3922. }
  3923. }
  3924. static void CopyEltsPtrToVectorPtr(ArrayRef<Value *> elts, Value *VecPtr,
  3925. Type *VecTy, unsigned vecSize,
  3926. IRBuilder<> &Builder) {
  3927. Value *Vec = UndefValue::get(VecTy);
  3928. for (unsigned i = 0; i < vecSize; i++) {
  3929. Value *Elt = Builder.CreateLoad(elts[i]);
  3930. Vec = Builder.CreateInsertElement(Vec, Elt, i);
  3931. }
  3932. Builder.CreateStore(Vec, VecPtr);
  3933. }
  3934. static void CopyMatToArrayPtr(Value *Mat, Value *ArrayPtr,
  3935. unsigned arrayBaseIdx, HLModule &HLM,
  3936. IRBuilder<> &Builder, bool bRowMajor) {
  3937. Type *Ty = Mat->getType();
  3938. // Mat val is row major.
  3939. unsigned col, row;
  3940. HLMatrixLower::GetMatrixInfo(Mat->getType(), col, row);
  3941. Type *VecTy = HLMatrixLower::LowerMatrixType(Ty);
  3942. Value *Vec =
  3943. HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  3944. (unsigned)HLCastOpcode::RowMatrixToVecCast, VecTy,
  3945. {Mat}, *HLM.GetModule());
  3946. Value *zero = Builder.getInt32(0);
  3947. for (unsigned r = 0; r < row; r++) {
  3948. for (unsigned c = 0; c < col; c++) {
  3949. unsigned rowMatIdx = HLMatrixLower::GetColMajorIdx(r, c, row);
  3950. Value *Elt = Builder.CreateExtractElement(Vec, rowMatIdx);
  3951. unsigned matIdx =
  3952. bRowMajor ? rowMatIdx : HLMatrixLower::GetColMajorIdx(r, c, row);
  3953. Value *Ptr = Builder.CreateInBoundsGEP(
  3954. ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
  3955. Builder.CreateStore(Elt, Ptr);
  3956. }
  3957. }
  3958. }
  3959. static void CopyMatPtrToArrayPtr(Value *MatPtr, Value *ArrayPtr,
  3960. unsigned arrayBaseIdx, HLModule &HLM,
  3961. IRBuilder<> &Builder, bool bRowMajor) {
  3962. Type *Ty = MatPtr->getType()->getPointerElementType();
  3963. Value *Mat = nullptr;
  3964. if (bRowMajor) {
  3965. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  3966. (unsigned)HLMatLoadStoreOpcode::RowMatLoad,
  3967. Ty, {MatPtr}, *HLM.GetModule());
  3968. } else {
  3969. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  3970. (unsigned)HLMatLoadStoreOpcode::ColMatLoad,
  3971. Ty, {MatPtr}, *HLM.GetModule());
  3972. // Matrix value should be row major.
  3973. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  3974. (unsigned)HLCastOpcode::ColMatrixToRowMatrix,
  3975. Ty, {Mat}, *HLM.GetModule());
  3976. }
  3977. CopyMatToArrayPtr(Mat, ArrayPtr, arrayBaseIdx, HLM, Builder, bRowMajor);
  3978. }
  3979. static Value *LoadArrayPtrToMat(Value *ArrayPtr, unsigned arrayBaseIdx,
  3980. Type *Ty, HLModule &HLM, IRBuilder<> &Builder,
  3981. bool bRowMajor) {
  3982. unsigned col, row;
  3983. HLMatrixLower::GetMatrixInfo(Ty, col, row);
  3984. // HLInit operands are in row major.
  3985. SmallVector<Value *, 16> Elts;
  3986. Value *zero = Builder.getInt32(0);
  3987. for (unsigned r = 0; r < row; r++) {
  3988. for (unsigned c = 0; c < col; c++) {
  3989. unsigned matIdx = bRowMajor ? HLMatrixLower::GetRowMajorIdx(r, c, col)
  3990. : HLMatrixLower::GetColMajorIdx(r, c, row);
  3991. Value *Ptr = Builder.CreateInBoundsGEP(
  3992. ArrayPtr, {zero, Builder.getInt32(arrayBaseIdx + matIdx)});
  3993. Value *Elt = Builder.CreateLoad(Ptr);
  3994. Elts.emplace_back(Elt);
  3995. }
  3996. }
  3997. return HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLInit,
  3998. /*opcode*/ 0, Ty, {Elts}, *HLM.GetModule());
  3999. }
  4000. static void CopyArrayPtrToMatPtr(Value *ArrayPtr, unsigned arrayBaseIdx,
  4001. Value *MatPtr, HLModule &HLM,
  4002. IRBuilder<> &Builder, bool bRowMajor) {
  4003. Type *Ty = MatPtr->getType()->getPointerElementType();
  4004. Value *Mat =
  4005. LoadArrayPtrToMat(ArrayPtr, arrayBaseIdx, Ty, HLM, Builder, bRowMajor);
  4006. if (bRowMajor) {
  4007. HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  4008. (unsigned)HLMatLoadStoreOpcode::RowMatStore, Ty,
  4009. {MatPtr, Mat}, *HLM.GetModule());
  4010. } else {
  4011. // Mat is row major.
  4012. // Cast it to col major before store.
  4013. Mat = HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLCast,
  4014. (unsigned)HLCastOpcode::RowMatrixToColMatrix,
  4015. Ty, {Mat}, *HLM.GetModule());
  4016. HLM.EmitHLOperationCall(Builder, HLOpcodeGroup::HLMatLoadStore,
  4017. (unsigned)HLMatLoadStoreOpcode::ColMatStore, Ty,
  4018. {MatPtr, Mat}, *HLM.GetModule());
  4019. }
  4020. }
  4021. using CopyFunctionTy = void(Value *FromPtr, Value *ToPtr, HLModule &HLM,
  4022. Type *HandleTy, IRBuilder<> &Builder,
  4023. bool bRowMajor);
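// The two CastCopyArray* helpers below walk the multi-dimensional side of the
// copy recursively: idxList accumulates the GEP indices on the multi-dim side,
// while calcIdx accumulates the linearized index on the flat 1-dim side, with
// vectors and matrices expanded to their scalar elements.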
  4024. static void
  4025. CastCopyArrayMultiDimTo1Dim(Value *FromArray, Value *ToArray, Type *CurFromTy,
  4026. std::vector<Value *> &idxList, unsigned calcIdx,
  4027. Type *HandleTy, HLModule &HLM, IRBuilder<> &Builder,
  4028. CopyFunctionTy CastCopyFn, bool bRowMajor) {
  4029. if (CurFromTy->isVectorTy()) {
  4030. // Copy vector to array.
  4031. Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
  4032. Value *V = Builder.CreateLoad(FromPtr);
  4033. unsigned vecSize = CurFromTy->getVectorNumElements();
  4034. Value *zeroIdx = Builder.getInt32(0);
  4035. for (unsigned i = 0; i < vecSize; i++) {
  4036. Value *ToPtr = Builder.CreateInBoundsGEP(
  4037. ToArray, {zeroIdx, Builder.getInt32(calcIdx++)});
  4038. Value *Elt = Builder.CreateExtractElement(V, i);
  4039. Builder.CreateStore(Elt, ToPtr);
  4040. }
  4041. } else if (HLMatrixLower::IsMatrixType(CurFromTy)) {
  4042. // Copy matrix to array.
  4043. unsigned col, row;
  4044. HLMatrixLower::GetMatrixInfo(CurFromTy, col, row);
  4045. // Calculate the offset.
  4046. unsigned offset = calcIdx * col * row;
  4047. Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
  4048. CopyMatPtrToArrayPtr(FromPtr, ToArray, offset, HLM, Builder, bRowMajor);
  4049. } else if (!CurFromTy->isArrayTy()) {
  4050. Value *FromPtr = Builder.CreateInBoundsGEP(FromArray, idxList);
  4051. Value *ToPtr = Builder.CreateInBoundsGEP(
  4052. ToArray, {Builder.getInt32(0), Builder.getInt32(calcIdx)});
  4053. CastCopyFn(FromPtr, ToPtr, HLM, HandleTy, Builder, bRowMajor);
  4054. } else {
  4055. unsigned size = CurFromTy->getArrayNumElements();
  4056. Type *FromEltTy = CurFromTy->getArrayElementType();
  4057. for (unsigned i = 0; i < size; i++) {
  4058. idxList.push_back(Builder.getInt32(i));
  4059. unsigned idx = calcIdx * size + i;
  4060. CastCopyArrayMultiDimTo1Dim(FromArray, ToArray, FromEltTy, idxList, idx,
  4061. HandleTy, HLM, Builder, CastCopyFn,
  4062. bRowMajor);
  4063. idxList.pop_back();
  4064. }
  4065. }
  4066. }
  4067. static void
  4068. CastCopyArray1DimToMultiDim(Value *FromArray, Value *ToArray, Type *CurToTy,
  4069. std::vector<Value *> &idxList, unsigned calcIdx,
  4070. Type *HandleTy, HLModule &HLM, IRBuilder<> &Builder,
  4071. CopyFunctionTy CastCopyFn, bool bRowMajor) {
  4072. if (CurToTy->isVectorTy()) {
  4073. // Copy array to vector.
  4074. Value *V = UndefValue::get(CurToTy);
  4075. unsigned vecSize = CurToTy->getVectorNumElements();
  4076. // Calculate the offset.
  4077. unsigned offset = calcIdx * vecSize;
  4078. Value *zeroIdx = Builder.getInt32(0);
  4079. Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
  4080. for (unsigned i = 0; i < vecSize; i++) {
  4081. Value *FromPtr = Builder.CreateInBoundsGEP(
  4082. FromArray, {zeroIdx, Builder.getInt32(offset++)});
  4083. Value *Elt = Builder.CreateLoad(FromPtr);
  4084. V = Builder.CreateInsertElement(V, Elt, i);
  4085. }
  4086. Builder.CreateStore(V, ToPtr);
  4087. } else if (HLMatrixLower::IsMatrixType(CurToTy)) {
  4088. // Copy array to matrix.
  4089. unsigned col, row;
  4090. HLMatrixLower::GetMatrixInfo(CurToTy, col, row);
  4091. // Calculate the offset.
  4092. unsigned offset = calcIdx * col * row;
  4093. Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
  4094. CopyArrayPtrToMatPtr(FromArray, offset, ToPtr, HLM, Builder, bRowMajor);
  4095. } else if (!CurToTy->isArrayTy()) {
  4096. Value *FromPtr = Builder.CreateInBoundsGEP(
  4097. FromArray, {Builder.getInt32(0), Builder.getInt32(calcIdx)});
  4098. Value *ToPtr = Builder.CreateInBoundsGEP(ToArray, idxList);
  4099. CastCopyFn(FromPtr, ToPtr, HLM, HandleTy, Builder, bRowMajor);
  4100. } else {
  4101. unsigned size = CurToTy->getArrayNumElements();
  4102. Type *ToEltTy = CurToTy->getArrayElementType();
  4103. for (unsigned i = 0; i < size; i++) {
  4104. idxList.push_back(Builder.getInt32(i));
  4105. unsigned idx = calcIdx * size + i;
  4106. CastCopyArray1DimToMultiDim(FromArray, ToArray, ToEltTy, idxList, idx,
  4107. HandleTy, HLM, Builder, CastCopyFn,
  4108. bRowMajor);
  4109. idxList.pop_back();
  4110. }
  4111. }
  4112. }
  4113. static void CastCopyOldPtrToNewPtr(Value *OldPtr, Value *NewPtr, HLModule &HLM,
  4114. Type *HandleTy, IRBuilder<> &Builder,
  4115. bool bRowMajor) {
  4116. Type *NewTy = NewPtr->getType()->getPointerElementType();
  4117. Type *OldTy = OldPtr->getType()->getPointerElementType();
  4118. if (NewTy == HandleTy) {
  4119. CopyResourcePtrToHandlePtr(OldPtr, NewPtr, HLM, Builder);
  4120. } else if (OldTy->isVectorTy()) {
  4121. // Copy vector to array.
  4122. Value *V = Builder.CreateLoad(OldPtr);
  4123. unsigned vecSize = OldTy->getVectorNumElements();
  4124. Value *zeroIdx = Builder.getInt32(0);
  4125. for (unsigned i = 0; i < vecSize; i++) {
  4126. Value *EltPtr = Builder.CreateGEP(NewPtr, {zeroIdx, Builder.getInt32(i)});
  4127. Value *Elt = Builder.CreateExtractElement(V, i);
  4128. Builder.CreateStore(Elt, EltPtr);
  4129. }
  4130. } else if (HLMatrixLower::IsMatrixType(OldTy)) {
  4131. CopyMatPtrToArrayPtr(OldPtr, NewPtr, /*arrayBaseIdx*/ 0, HLM, Builder,
  4132. bRowMajor);
  4133. } else if (OldTy->isArrayTy()) {
  4134. std::vector<Value *> idxList;
  4135. idxList.emplace_back(Builder.getInt32(0));
  4136. CastCopyArrayMultiDimTo1Dim(OldPtr, NewPtr, OldTy, idxList, /*calcIdx*/ 0,
  4137. HandleTy, HLM, Builder, CastCopyOldPtrToNewPtr,
  4138. bRowMajor);
  4139. }
  4140. }
  4141. static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
  4142. Type *HandleTy, IRBuilder<> &Builder,
  4143. bool bRowMajor) {
  4144. Type *NewTy = NewPtr->getType()->getPointerElementType();
  4145. Type *OldTy = OldPtr->getType()->getPointerElementType();
  4146. if (NewTy == HandleTy) {
  4147. CopyHandlePtrToResourcePtr(NewPtr, OldPtr, HLM, Builder);
  4148. } else if (OldTy->isVectorTy()) {
  4149. // Copy array to vector.
  4150. Value *V = UndefValue::get(OldTy);
  4151. unsigned vecSize = OldTy->getVectorNumElements();
  4152. Value *zeroIdx = Builder.getInt32(0);
  4153. for (unsigned i = 0; i < vecSize; i++) {
  4154. Value *EltPtr = Builder.CreateGEP(NewPtr, {zeroIdx, Builder.getInt32(i)});
  4155. Value *Elt = Builder.CreateLoad(EltPtr);
  4156. V = Builder.CreateInsertElement(V, Elt, i);
  4157. }
  4158. Builder.CreateStore(V, OldPtr);
  4159. } else if (HLMatrixLower::IsMatrixType(OldTy)) {
  4160. CopyArrayPtrToMatPtr(NewPtr, /*arrayBaseIdx*/ 0, OldPtr, HLM, Builder,
  4161. bRowMajor);
  4162. } else if (OldTy->isArrayTy()) {
  4163. std::vector<Value *> idxList;
  4164. idxList.emplace_back(Builder.getInt32(0));
  4165. CastCopyArray1DimToMultiDim(NewPtr, OldPtr, OldTy, idxList, /*calcIdx*/ 0,
  4166. HandleTy, HLM, Builder, CastCopyNewPtrToOldPtr,
  4167. bRowMajor);
  4168. }
  4169. }
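// replaceCastArgument patches a call site: it copies the original argument
// into the cast/flattened argument before the call (for in params) and copies
// results back after the call (for out params). replaceCastParameter, further
// below, does the mirrored work inside the flattened function body.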
  4170. void SROA_Parameter_HLSL::replaceCastArgument(Value *&NewArg, Value *OldArg,
  4171. DxilParamInputQual inputQual,
  4172. IRBuilder<> &CallBuilder,
  4173. IRBuilder<> &RetBuilder) {
  4174. Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  4175. Type *NewTy = NewArg->getType();
  4176. Type *OldTy = OldArg->getType();
  4177. bool bIn = inputQual == DxilParamInputQual::Inout ||
  4178. inputQual == DxilParamInputQual::In;
  4179. bool bOut = inputQual == DxilParamInputQual::Inout ||
  4180. inputQual == DxilParamInputQual::Out;
  4181. if (NewArg->getType() == HandleTy) {
  4182. Value *Handle =
  4183. CastResourcePtrToHandle(OldArg, HandleTy, *m_pHLModule, CallBuilder);
  4184. // Use Handle as NewArg.
  4185. NewArg = Handle;
  4186. } else if (vectorEltsMap.count(NewArg)) {
  4187. Type *VecTy = OldTy;
  4188. if (VecTy->isPointerTy())
  4189. VecTy = VecTy->getPointerElementType();
  4190. // Flattened vector.
  4191. SmallVector<Value *, 4> &elts = vectorEltsMap[NewArg];
  4192. unsigned vecSize = elts.size();
  4193. if (NewTy->isPointerTy()) {
  4194. if (bIn) {
  4195. // Copy OldArg to NewArg before Call.
  4196. CopyVectorPtrToEltsPtr(OldArg, elts, vecSize, CallBuilder);
  4197. }
  4198. // bOut must be true here.
  4199. // Store NewArg to OldArg after Call.
  4200. CopyEltsPtrToVectorPtr(elts, OldArg, VecTy, vecSize, RetBuilder);
  4201. } else {
  4202. // Must be in parameter.
  4203. // Copy OldArg to NewArg before Call.
  4204. Value *Vec = OldArg;
  4205. if (OldTy->isPointerTy()) {
  4206. Vec = CallBuilder.CreateLoad(OldArg);
  4207. }
  4208. for (unsigned i = 0; i < vecSize; i++) {
  4209. Value *Elt = CallBuilder.CreateExtractElement(Vec, i);
  4210. // Save elt to update arg in createFlattenedFunctionCall.
  4211. elts[i] = Elt;
  4212. }
  4213. }
  4214. // Don't need elts anymore.
  4215. vectorEltsMap.erase(NewArg);
  4216. } else if (!NewTy->isPointerTy()) {
  4217. // Ptr param is cast to non-ptr param.
  4218. // Must be in param.
  4219. // Load OldArg as NewArg before call.
  4220. NewArg = CallBuilder.CreateLoad(OldArg);
  4221. } else if (HLMatrixLower::IsMatrixType(OldTy)) {
  4222. bool bRowMajor = castRowMajorParamMap.count(NewArg);
  4223. CopyMatToArrayPtr(OldArg, NewArg, /*arrayBaseIdx*/ 0, *m_pHLModule,
  4224. CallBuilder, bRowMajor);
  4225. } else {
  4226. bool bRowMajor = castRowMajorParamMap.count(NewArg);
  4227. // NewTy is pointer type.
  4228. // Copy OldArg to NewArg before Call.
  4229. if (bIn) {
  4230. CastCopyOldPtrToNewPtr(OldArg, NewArg, *m_pHLModule, HandleTy,
  4231. CallBuilder, bRowMajor);
  4232. }
  4233. if (bOut) {
  4234. // Store NewArg to OldArg after Call.
  4235. CastCopyNewPtrToOldPtr(NewArg, OldArg, *m_pHLModule, HandleTy, RetBuilder,
  4236. bRowMajor);
  4237. }
  4238. }
  4239. }
  4240. void SROA_Parameter_HLSL::replaceCastParameter(
  4241. Value *NewParam, Value *OldParam, Function &F, Argument *Arg,
  4242. const DxilParamInputQual inputQual, IRBuilder<> &Builder) {
  4243. Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  4244. Type *HandlePtrTy = PointerType::get(HandleTy, 0);
  4245. Module &M = *m_pHLModule->GetModule();
  4246. Type *NewTy = NewParam->getType();
  4247. Type *OldTy = OldParam->getType();
  4248. bool bIn = inputQual == DxilParamInputQual::Inout ||
  4249. inputQual == DxilParamInputQual::In;
  4250. bool bOut = inputQual == DxilParamInputQual::Inout ||
  4251. inputQual == DxilParamInputQual::Out;
  4252. // Make sure InsertPoint after OldParam inst.
  4253. if (Instruction *I = dyn_cast<Instruction>(OldParam)) {
  4254. Builder.SetInsertPoint(I->getNextNode());
  4255. }
  4256. if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(OldParam)) {
  4257. // Add debug info to new param.
  4258. DIBuilder DIB(*F.getParent(), /*AllowUnresolved*/ false);
  4259. DIExpression *DDIExp = DDI->getExpression();
  4260. DIB.insertDeclare(NewParam, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
  4261. Builder.GetInsertPoint());
  4262. }
  4263. if (isa<Argument>(OldParam) && OldTy->isPointerTy()) {
  4264. // OldParam will be removed with Old function.
  4265. // Create alloca to replace it.
  4266. Value *AllocParam = Builder.CreateAlloca(OldTy->getPointerElementType());
  4267. OldParam->replaceAllUsesWith(AllocParam);
  4268. OldParam = AllocParam;
  4269. }
  4270. if (NewTy == HandleTy) {
  4271. CopyHandleToResourcePtr(NewParam, OldParam, *m_pHLModule, Builder);
  4272. // Save resource attribute.
  4273. Type *ResTy = OldTy->getPointerElementType();
  4274. MDNode *MD = HLModule::GetDxilResourceAttrib(ResTy, M);
  4275. m_pHLModule->MarkDxilResourceAttrib(Arg, MD);
  4276. } else if (vectorEltsMap.count(NewParam)) {
  4277. // Vector is flattened to scalars.
  4278. Type *VecTy = OldTy;
  4279. if (VecTy->isPointerTy())
  4280. VecTy = VecTy->getPointerElementType();
  4281. // Flattened vector.
  4282. SmallVector<Value *, 4> &elts = vectorEltsMap[NewParam];
  4283. unsigned vecSize = elts.size();
  4284. if (NewTy->isPointerTy()) {
  4285. if (bIn) {
  4286. // Copy NewParam to OldParam at entry.
  4287. CopyEltsPtrToVectorPtr(elts, OldParam, VecTy, vecSize, Builder);
  4288. }
  4289. // bOut must be true here.
  4290. // Store the OldParam to NewParam before every return.
  4291. for (auto &BB : F.getBasicBlockList()) {
  4292. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  4293. IRBuilder<> RetBuilder(RI);
  4294. CopyVectorPtrToEltsPtr(OldParam, elts, vecSize, RetBuilder);
  4295. }
  4296. }
  4297. } else {
  4298. // Must be in parameter.
  4299. // Copy NewParam to OldParam at entry.
  4300. Value *Vec = UndefValue::get(VecTy);
  4301. for (unsigned i = 0; i < vecSize; i++) {
  4302. Vec = Builder.CreateInsertElement(Vec, elts[i], i);
  4303. }
  4304. if (OldTy->isPointerTy()) {
  4305. Builder.CreateStore(Vec, OldParam);
  4306. } else {
  4307. OldParam->replaceAllUsesWith(Vec);
  4308. }
  4309. }
  4310. // Don't need elts anymore.
  4311. vectorEltsMap.erase(NewParam);
  4312. } else if (!NewTy->isPointerTy()) {
  4313. // Ptr param is cast to non-ptr param.
  4314. // Must be in param.
  4315. // Store NewParam to OldParam at entry.
  4316. Builder.CreateStore(NewParam, OldParam);
  4317. } else if (HLMatrixLower::IsMatrixType(OldTy)) {
  4318. bool bRowMajor = castRowMajorParamMap.count(NewParam);
  4319. Value *Mat = LoadArrayPtrToMat(NewParam, /*arrayBaseIdx*/ 0, OldTy,
  4320. *m_pHLModule, Builder, bRowMajor);
  4321. OldParam->replaceAllUsesWith(Mat);
  4322. } else {
  4323. bool bRowMajor = castRowMajorParamMap.count(NewParam);
  4324. // NewTy is pointer type.
  4325. if (bIn) {
  4326. // Copy NewParam to OldParam at entry.
  4327. CastCopyNewPtrToOldPtr(NewParam, OldParam, *m_pHLModule, HandleTy,
  4328. Builder, bRowMajor);
  4329. }
  4330. if (bOut) {
  4331. // Store the OldParam to NewParam before every return.
  4332. for (auto &BB : F.getBasicBlockList()) {
  4333. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  4334. IRBuilder<> RetBuilder(RI);
  4335. CastCopyOldPtrToNewPtr(OldParam, NewParam, *m_pHLModule, HandleTy,
  4336. RetBuilder, bRowMajor);
  4337. }
  4338. }
  4339. }
  4340. Type *NewEltTy = dxilutil::GetArrayEltTy(NewTy);
  4341. Type *OldEltTy = dxilutil::GetArrayEltTy(OldTy);
  4342. if (NewEltTy == HandlePtrTy) {
  4343. // Save resource attribute.
  4344. Type *ResTy = OldEltTy;
  4345. MDNode *MD = HLModule::GetDxilResourceAttrib(ResTy, M);
  4346. m_pHLModule->MarkDxilResourceAttrib(Arg, MD);
  4347. }
  4348. }
  4349. }
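// castArgumentIfRequired rewrites a value whose type must change during
// flattening (e.g. resources to handles, matrices and multi-dim arrays to flat
// arrays, vectors to scalars) and records the old/new pairing in castParamMap
// and vectorEltsMap so replaceCastParameter and replaceCastArgument can hook
// the two representations back together.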
  4350. Value *SROA_Parameter_HLSL::castArgumentIfRequired(
  4351. Value *V, Type *Ty, bool bOut, bool hasShaderInputOutput,
  4352. DxilParamInputQual inputQual, DxilFieldAnnotation &annotation,
  4353. std::deque<Value *> &WorkList, IRBuilder<> &Builder) {
  4354. Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
  4355. Module &M = *m_pHLModule->GetModule();
// Remove the pointer for vector/scalar params that are not out params.
  4357. if (V->getType()->isPointerTy() && !Ty->isAggregateType() && !bOut) {
  4358. Value *Ptr = Builder.CreateAlloca(Ty);
  4359. V->replaceAllUsesWith(Ptr);
// Create the load here so the value has the correct type.
// The Ptr will be stored with the correct value in replaceCastParameter and
// replaceCastArgument.
if (Ptr->hasOneUse()) {
// Load after the existing user for call-arg replacement;
// otherwise the call arg would load undef.
// This does not hurt parameters: the new load comes right after the first
// load, so it is still before all the other load users.
  4368. Instruction *User = cast<Instruction>(*(Ptr->user_begin()));
  4369. IRBuilder<> CallBuilder(User->getNextNode());
  4370. V = CallBuilder.CreateLoad(Ptr);
  4371. } else {
  4372. V = Builder.CreateLoad(Ptr);
  4373. }
  4374. castParamMap[V] = std::make_pair(Ptr, inputQual);
  4375. }
  4376. // Lower resource type to handle ty.
  4377. if (HLModule::IsHLSLObjectType(Ty) &&
  4378. !HLModule::IsStreamOutputPtrType(V->getType())) {
  4379. Value *Res = V;
  4380. if (!bOut) {
  4381. Value *LdRes = Builder.CreateLoad(Res);
  4382. V = m_pHLModule->EmitHLOperationCall(Builder,
  4383. HLOpcodeGroup::HLCreateHandle,
  4384. /*opcode*/ 0, HandleTy, {LdRes}, M);
  4385. } else {
  4386. V = Builder.CreateAlloca(HandleTy);
  4387. }
  4388. castParamMap[V] = std::make_pair(Res, inputQual);
  4389. } else if (Ty->isArrayTy()) {
  4390. unsigned arraySize = 1;
  4391. Type *AT = Ty;
  4392. while (AT->isArrayTy()) {
  4393. arraySize *= AT->getArrayNumElements();
  4394. AT = AT->getArrayElementType();
  4395. }
  4396. if (HLModule::IsHLSLObjectType(AT)) {
  4397. Value *Res = V;
  4398. Type *Ty = ArrayType::get(HandleTy, arraySize);
  4399. V = Builder.CreateAlloca(Ty);
  4400. castParamMap[V] = std::make_pair(Res, inputQual);
  4401. }
  4402. }
  4403. if (!hasShaderInputOutput) {
  4404. if (Ty->isVectorTy()) {
  4405. Value *OldV = V;
  4406. Type *EltTy = Ty->getVectorElementType();
  4407. unsigned vecSize = Ty->getVectorNumElements();
  4408. // Split vector into scalars.
  4409. if (OldV->getType()->isPointerTy()) {
  4410. // Split into scalar ptr.
  4411. V = Builder.CreateAlloca(EltTy);
  4412. vectorEltsMap[V].emplace_back(V);
  4413. for (unsigned i = 1; i < vecSize; i++) {
  4414. Value *Elt = Builder.CreateAlloca(EltTy);
  4415. vectorEltsMap[V].emplace_back(Elt);
  4416. }
  4417. } else {
  4418. IRBuilder<> TmpBuilder(Builder.GetInsertPoint());
  4419. // Make sure extract element after OldV.
  4420. if (Instruction *OldI = dyn_cast<Instruction>(OldV)) {
  4421. TmpBuilder.SetInsertPoint(OldI->getNextNode());
  4422. }
  4423. // Split into scalar.
  4424. V = TmpBuilder.CreateExtractElement(OldV, (uint64_t)0);
  4425. vectorEltsMap[V].emplace_back(V);
  4426. for (unsigned i = 1; i < vecSize; i++) {
  4427. Value *Elt = TmpBuilder.CreateExtractElement(OldV, i);
  4428. vectorEltsMap[V].emplace_back(Elt);
  4429. }
  4430. }
  4431. // Add to work list by reverse order.
  4432. for (unsigned i = vecSize - 1; i > 0; i--) {
  4433. Value *Elt = vectorEltsMap[V][i];
  4434. WorkList.push_front(Elt);
  4435. }
  4436. // For case OldV is from input vector ptr.
  4437. if (castParamMap.count(OldV)) {
  4438. OldV = castParamMap[OldV].first;
  4439. }
  4440. castParamMap[V] = std::make_pair(OldV, inputQual);
  4441. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  4442. unsigned col, row;
  4443. Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
  4444. Value *Mat = V;
  4445. // Cast matrix to array.
  4446. Type *AT = ArrayType::get(EltTy, col * row);
  4447. V = Builder.CreateAlloca(AT);
  4448. castParamMap[V] = std::make_pair(Mat, inputQual);
  4449. DXASSERT(annotation.HasMatrixAnnotation(), "need matrix annotation here");
  4450. if (annotation.GetMatrixAnnotation().Orientation ==
  4451. hlsl::MatrixOrientation::RowMajor) {
  4452. castRowMajorParamMap.insert(V);
  4453. }
  4454. } else if (Ty->isArrayTy()) {
  4455. unsigned arraySize = 1;
  4456. Type *AT = Ty;
  4457. unsigned dim = 0;
  4458. while (AT->isArrayTy()) {
  4459. ++dim;
  4460. arraySize *= AT->getArrayNumElements();
  4461. AT = AT->getArrayElementType();
  4462. }
  4463. if (VectorType *VT = dyn_cast<VectorType>(AT)) {
  4464. Value *VecArray = V;
  4465. Type *AT = ArrayType::get(VT->getElementType(),
  4466. arraySize * VT->getNumElements());
  4467. V = Builder.CreateAlloca(AT);
  4468. castParamMap[V] = std::make_pair(VecArray, inputQual);
  4469. } else if (HLMatrixLower::IsMatrixType(AT)) {
  4470. unsigned col, row;
  4471. Type *EltTy = HLMatrixLower::GetMatrixInfo(AT, col, row);
  4472. Value *MatArray = V;
  4473. Type *AT = ArrayType::get(EltTy, arraySize * col * row);
  4474. V = Builder.CreateAlloca(AT);
  4475. castParamMap[V] = std::make_pair(MatArray, inputQual);
  4476. DXASSERT(annotation.HasMatrixAnnotation(),
  4477. "need matrix annotation here");
  4478. if (annotation.GetMatrixAnnotation().Orientation ==
  4479. hlsl::MatrixOrientation::RowMajor) {
  4480. castRowMajorParamMap.insert(V);
  4481. }
} else if (dim > 1) {
// Flatten a multi-dim array to a 1-dim array of the element type.
Value *MultiArray = V;
V = Builder.CreateAlloca(ArrayType::get(AT, arraySize));
castParamMap[V] = std::make_pair(MultiArray, inputQual);
}
  4489. }
  4490. } else {
  4491. // Entry function matrix value parameter has major.
  4492. // Make sure its user use row major matrix value.
  4493. bool updateToColMajor = annotation.HasMatrixAnnotation() &&
  4494. annotation.GetMatrixAnnotation().Orientation ==
  4495. MatrixOrientation::ColumnMajor;
  4496. if (updateToColMajor) {
  4497. if (V->getType()->isPointerTy()) {
  4498. for (User *user : V->users()) {
  4499. CallInst *CI = dyn_cast<CallInst>(user);
  4500. if (!CI)
  4501. continue;
  4502. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  4503. if (group != HLOpcodeGroup::HLMatLoadStore)
  4504. continue;
  4505. HLMatLoadStoreOpcode opcode =
  4506. static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
  4507. Type *opcodeTy = Builder.getInt32Ty();
  4508. switch (opcode) {
  4509. case HLMatLoadStoreOpcode::RowMatLoad: {
  4510. // Update matrix function opcode to col major version.
  4511. Value *rowOpArg = ConstantInt::get(
  4512. opcodeTy,
  4513. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad));
  4514. CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
  4515. // Cast it to row major.
  4516. CallInst *RowMat = HLModule::EmitHLOperationCall(
  4517. Builder, HLOpcodeGroup::HLCast,
  4518. (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {CI}, M);
  4519. CI->replaceAllUsesWith(RowMat);
  4520. // Set arg to CI again.
  4521. RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, CI);
  4522. } break;
case HLMatLoadStoreOpcode::RowMatStore: {
// Update matrix function opcode to col major version.
Value *rowOpArg = ConstantInt::get(
opcodeTy,
static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore));
CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
Value *Mat = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
// Cast it to col major.
CallInst *RowMat = HLModule::EmitHLOperationCall(
Builder, HLOpcodeGroup::HLCast,
(unsigned)HLCastOpcode::RowMatrixToColMatrix, Ty, {Mat}, M);
CI->setArgOperand(HLOperandIndex::kMatStoreValOpIdx, RowMat);
} break;
default:
break;
}
  4537. }
  4538. } else {
  4539. CallInst *RowMat = HLModule::EmitHLOperationCall(
  4540. Builder, HLOpcodeGroup::HLCast,
  4541. (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {V}, M);
  4542. V->replaceAllUsesWith(RowMat);
  4543. // Set arg to V again.
  4544. RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, V);
  4545. }
  4546. }
  4547. }
  4548. return V;
  4549. }
void SROA_Parameter_HLSL::flattenArgument(
    Function *F, Value *Arg, bool bForParam,
    DxilParameterAnnotation &paramAnnotation,
    std::vector<Value *> &FlatParamList,
    std::vector<DxilParameterAnnotation> &FlatAnnotationList,
    IRBuilder<> &Builder, DbgDeclareInst *DDI) {
  std::deque<Value *> WorkList;
  WorkList.push_back(Arg);

  Function *Entry = m_pHLModule->GetEntryFunction();
  bool hasShaderInputOutput = F == Entry;

  if (m_pHLModule->HasDxilFunctionProps(F)) {
    hasShaderInputOutput = true;
  }

  if (m_pHLModule->HasDxilFunctionProps(Entry)) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(Entry);
    if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
      Function *patchConstantFunc = funcProps.ShaderProps.HS.patchConstantFunc;
      hasShaderInputOutput |= F == patchConstantFunc;
    }
  }

  unsigned startArgIndex = FlatAnnotationList.size();

  // Map from value to annotation.
  std::unordered_map<Value *, DxilFieldAnnotation> annotationMap;
  annotationMap[Arg] = paramAnnotation;

  DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();

  const std::string &semantic = paramAnnotation.GetSemanticString();
  bool bSemOverride = !semantic.empty();

  DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  bool bOut = inputQual == DxilParamInputQual::Out ||
              inputQual == DxilParamInputQual::Inout ||
              inputQual == DxilParamInputQual::OutStream0 ||
              inputQual == DxilParamInputQual::OutStream1 ||
              inputQual == DxilParamInputQual::OutStream2 ||
              inputQual == DxilParamInputQual::OutStream3;

  // Map from semantic string to type.
  llvm::StringMap<Type *> semanticTypeMap;
  // Original semantic type.
  if (!semantic.empty()) {
    // Unwrap the top-level array for primitive parameters.
    if (inputQual == DxilParamInputQual::InputPatch ||
        inputQual == DxilParamInputQual::OutputPatch ||
        inputQual == DxilParamInputQual::InputPrimitive) {
      Type *Ty = Arg->getType();
      if (Ty->isPointerTy())
        Ty = Ty->getPointerElementType();
      if (Ty->isArrayTy())
        semanticTypeMap[semantic] = Ty->getArrayElementType();
    } else {
      semanticTypeMap[semantic] = Arg->getType();
    }
  }

  std::vector<Instruction *> deadAllocas;

  DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
  unsigned debugOffset = 0;
  const DataLayout &DL = F->getParent()->getDataLayout();

  // Process the worklist.
  while (!WorkList.empty()) {
    Value *V = WorkList.front();
    WorkList.pop_front();

    // Do not skip unused parameters.
    DxilFieldAnnotation &annotation = annotationMap[V];

    const bool bAllowReplace = !bOut;
    SROA_Helper::LowerMemcpy(V, &annotation, dxilTypeSys, DL, bAllowReplace);

    std::vector<Value *> Elts;
    // Do not flatten vectors for the entry function currently.
    bool SROAed = SROA_Helper::DoScalarReplacement(
        V, Elts, Builder, /*bFlatVector*/ false, annotation.IsPrecise(),
        dxilTypeSys, DeadInsts);

    if (SROAed) {
      Type *Ty = V->getType()->getPointerElementType();
      // Skip empty struct parameters.
      if (SROA_Helper::IsEmptyStructType(Ty, dxilTypeSys)) {
        SROA_Helper::MarkEmptyStructUsers(V, DeadInsts);
        DeleteDeadInstructions();
        continue;
      }

      // Push Elts into the worklist.
      // Use rbegin so the order does not change.
      for (auto iter = Elts.rbegin(); iter != Elts.rend(); iter++)
        WorkList.push_front(*iter);

      bool precise = annotation.IsPrecise();
      const std::string &semantic = annotation.GetSemanticString();
      hlsl::InterpolationMode interpMode = annotation.GetInterpolationMode();

      for (unsigned i = 0; i < Elts.size(); i++) {
        Value *Elt = Elts[i];
        DxilFieldAnnotation EltAnnotation =
            GetEltAnnotation(Ty, i, annotation, dxilTypeSys);
        const std::string &eltSem = EltAnnotation.GetSemanticString();

        if (!semantic.empty()) {
          if (!eltSem.empty()) {
            // TODO: warn when the semantic in EltAnnotation is overridden.
          }
          // Just save the parent semantic here, allocate later.
          EltAnnotation.SetSemanticString(semantic);
        } else if (!eltSem.empty() && semanticTypeMap.count(eltSem) == 0) {
          Type *EltTy = dxilutil::GetArrayEltTy(Ty);
          DXASSERT(EltTy->isStructTy(),
                   "must be a struct type to have a semantic.");
          semanticTypeMap[eltSem] = EltTy->getStructElementType(i);
        }

        if (precise)
          EltAnnotation.SetPrecise();

        if (EltAnnotation.GetInterpolationMode().GetKind() ==
            DXIL::InterpolationMode::Undefined)
          EltAnnotation.SetInterpolationMode(interpMode);

        annotationMap[Elt] = EltAnnotation;
      }

      annotationMap.erase(V);

      ++NumReplaced;
      if (Instruction *I = dyn_cast<Instruction>(V))
        deadAllocas.emplace_back(I);
    } else {
      if (bSemOverride) {
        if (!annotation.GetSemanticString().empty()) {
          // TODO: warn when the semantic in EltAnnotation is overridden.
        }
        // Just save the parent semantic here, allocate later.
        annotation.SetSemanticString(semantic);
      }

      Type *Ty = V->getType();
      if (Ty->isPointerTy())
        Ty = Ty->getPointerElementType();

      // Flatten arrays of SV_Target.
      StringRef semanticStr = annotation.GetSemanticString();
      if (semanticStr.upper().find("SV_TARGET") == 0 && Ty->isArrayTy()) {
        Type *Ty = cast<ArrayType>(V->getType()->getPointerElementType());
        StringRef targetStr;
        unsigned targetIndex;
        Semantic::DecomposeNameAndIndex(semanticStr, &targetStr, &targetIndex);
        // Replace the target parameter with a local target.
        AllocaInst *localTarget = Builder.CreateAlloca(Ty);
        V->replaceAllUsesWith(localTarget);
        unsigned arraySize = 1;
        std::vector<unsigned> arraySizeList;
        while (Ty->isArrayTy()) {
          unsigned size = Ty->getArrayNumElements();
          arraySizeList.emplace_back(size);
          arraySize *= size;
          Ty = Ty->getArrayElementType();
        }
        unsigned arrayLevel = arraySizeList.size();
        std::vector<unsigned> arrayIdxList(arrayLevel, 0);

        // Create the flattened targets.
        DxilFieldAnnotation EltAnnotation = annotation;

        for (unsigned i = 0; i < arraySize; i++) {
          Value *Elt = Builder.CreateAlloca(Ty);
          EltAnnotation.SetSemanticString(targetStr.str() +
                                          std::to_string(targetIndex + i));

          // Add the semantic type.
          semanticTypeMap[EltAnnotation.GetSemanticString()] = Ty;

          annotationMap[Elt] = EltAnnotation;
          WorkList.push_front(Elt);

          // Copy the local target to the flattened target.
          std::vector<Value *> idxList(arrayLevel + 1);
          idxList[0] = Builder.getInt32(0);
          for (unsigned idx = 0; idx < arrayLevel; idx++) {
            idxList[idx + 1] = Builder.getInt32(arrayIdxList[idx]);
          }

          if (bForParam) {
            // If this is an argument, copy before each return.
            for (auto &BB : F->getBasicBlockList()) {
              TerminatorInst *TI = BB.getTerminator();
              if (isa<ReturnInst>(TI)) {
                IRBuilder<> RetBuilder(TI);
                Value *Ptr = RetBuilder.CreateGEP(localTarget, idxList);
                Value *V = RetBuilder.CreateLoad(Ptr);
                RetBuilder.CreateStore(V, Elt);
              }
            }
          } else {
            // Else, copy with Builder.
            Value *Ptr = Builder.CreateGEP(localTarget, idxList);
            Value *V = Builder.CreateLoad(Ptr);
            Builder.CreateStore(V, Elt);
          }

          // Update arrayIdxList.
          for (unsigned idx = arrayLevel; idx > 0; idx--) {
            arrayIdxList[idx - 1]++;
            if (arrayIdxList[idx - 1] < arraySizeList[idx - 1])
              break;
            arrayIdxList[idx - 1] = 0;
          }
        }
        // Don't override the flattened SV_Target.
        if (V == Arg) {
          bSemOverride = false;
        }
        continue;
      }

      // Cast vector/matrix/resource parameters.
      V = castArgumentIfRequired(V, Ty, bOut, hasShaderInputOutput, inputQual,
                                 annotation, WorkList, Builder);

      // Cannot SROA; save it to the final parameter list.
      FlatParamList.emplace_back(V);
      // Create a ParamAnnotation for V.
      FlatAnnotationList.emplace_back(DxilParameterAnnotation());
      DxilParameterAnnotation &flatParamAnnotation = FlatAnnotationList.back();

      flatParamAnnotation.SetParamInputQual(paramAnnotation.GetParamInputQual());
      flatParamAnnotation.SetInterpolationMode(annotation.GetInterpolationMode());
      flatParamAnnotation.SetSemanticString(annotation.GetSemanticString());
      flatParamAnnotation.SetCompType(annotation.GetCompType().GetKind());
      flatParamAnnotation.SetMatrixAnnotation(annotation.GetMatrixAnnotation());
      flatParamAnnotation.SetPrecise(annotation.IsPrecise());
      flatParamAnnotation.SetResourceAttribute(annotation.GetResourceAttribute());

      // Add debug info.
      if (DDI && V != Arg) {
        Value *TmpV = V;
        // If V was cast, add the debug info to the original V.
        if (castParamMap.count(V)) {
          TmpV = castParamMap[V].first;
          // One more level for a pointer to an input vector.
          // It is cast from ptr to non-ptr, then cast to scalars.
          if (castParamMap.count(TmpV)) {
            TmpV = castParamMap[TmpV].first;
          }
        }
        Type *Ty = TmpV->getType();
        if (Ty->isPointerTy())
          Ty = Ty->getPointerElementType();
        unsigned size = DL.getTypeAllocSize(Ty);
        DIExpression *DDIExp = DIB.createBitPieceExpression(debugOffset, size);
        debugOffset += size;
        DIB.insertDeclare(TmpV, DDI->getVariable(), DDIExp, DDI->getDebugLoc(),
                          Builder.GetInsertPoint());
      }

      // Flatten stream out.
      if (HLModule::IsStreamOutputPtrType(V->getType())) {
        // For stream output objects.
        // Create a value to act as the output value.
        Type *outputType =
            V->getType()->getPointerElementType()->getStructElementType(0);
        Value *outputVal = Builder.CreateAlloca(outputType);
        // For each stream.Append(data)
        // transform into
        //   d = load data
        //   store outputVal, d
        //   stream.Append(outputVal)
        for (User *user : V->users()) {
          if (CallInst *CI = dyn_cast<CallInst>(user)) {
            unsigned opcode = GetHLOpcode(CI);
            if (opcode == static_cast<unsigned>(IntrinsicOp::MOP_Append)) {
              if (CI->getNumArgOperands() ==
                  (HLOperandIndex::kStreamAppendDataOpIndex + 1)) {
                Value *data =
                    CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
                DXASSERT(data->getType()->isPointerTy(),
                         "Append value must be pointer.");
                IRBuilder<> Builder(CI);
                llvm::SmallVector<llvm::Value *, 16> idxList;
                SplitCpy(data->getType(), outputVal, data, idxList, Builder,
                         dxilTypeSys, &flatParamAnnotation);
                CI->setArgOperand(HLOperandIndex::kStreamAppendDataOpIndex,
                                  outputVal);
              } else {
                // Append has already been flattened.
                // Flatten the store to outputVal.
                // It must be a struct to be flattened.
                IRBuilder<> Builder(CI);
                llvm::SmallVector<llvm::Value *, 16> idxList;
                llvm::SmallVector<llvm::Value *, 16> EltPtrList;
                // Split.
                SplitPtr(outputVal->getType(), outputVal, idxList, EltPtrList,
                         Builder);

                unsigned eltCount = CI->getNumArgOperands() - 2;
                DXASSERT_LOCALVAR(eltCount, eltCount == EltPtrList.size(),
                                  "invalid element count");

                for (unsigned i = HLOperandIndex::kStreamAppendDataOpIndex;
                     i < CI->getNumArgOperands(); i++) {
                  Value *DataPtr = CI->getArgOperand(i);
                  Value *EltPtr =
                      EltPtrList[i - HLOperandIndex::kStreamAppendDataOpIndex];

                  llvm::SmallVector<llvm::Value *, 16> idxList;
                  SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
                           Builder, dxilTypeSys, &flatParamAnnotation);
                  CI->setArgOperand(i, EltPtr);
                }
              }
            }
          }
        }

        // Then split the output value to generate ParamQual.
        WorkList.push_front(outputVal);
      }
    }
  }

  // Now erase any instructions that were made dead while rewriting the
  // alloca.
  DeleteDeadInstructions();
  // Erase dead allocas after all uses are deleted.
  for (Instruction *I : deadAllocas)
    I->eraseFromParent();

  unsigned endArgIndex = FlatAnnotationList.size();
  if (bForParam && startArgIndex < endArgIndex) {
    DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
    if (inputQual == DxilParamInputQual::OutStream0 ||
        inputQual == DxilParamInputQual::OutStream1 ||
        inputQual == DxilParamInputQual::OutStream2 ||
        inputQual == DxilParamInputQual::OutStream3)
      startArgIndex++;

    DxilParameterAnnotation &flatParamAnnotation =
        FlatAnnotationList[startArgIndex];
    const std::string &semantic = flatParamAnnotation.GetSemanticString();
    if (!semantic.empty())
      allocateSemanticIndex(FlatAnnotationList, startArgIndex,
                            semanticTypeMap);
  }
}
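
// IsUsedAsCallArg - Returns true if V (or any GEP derived from it) is passed
// to a user-defined call, i.e. a call that is neither an HL operation nor an
// LLVM intrinsic.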
static bool IsUsedAsCallArg(Value *V) {
  for (User *U : V->users()) {
    if (CallInst *CI = dyn_cast<CallInst>(U)) {
      Function *CalledF = CI->getCalledFunction();
      HLOpcodeGroup group = GetHLOpcodeGroup(CalledF);
      // Skip HL operations.
      if (group != HLOpcodeGroup::NotHL ||
          group == HLOpcodeGroup::HLExtIntrinsic) {
        continue;
      }
      // Skip llvm intrinsics.
      if (CalledF->isIntrinsic())
        continue;
      return true;
    }
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
      if (IsUsedAsCallArg(GEP))
        return true;
    }
  }
  return false;
}
// For function parameters that are used in a function call and need to be
// flattened, replace them with a tmp alloca.
void SROA_Parameter_HLSL::preprocessArgUsedInCall(Function *F) {
  if (F->isDeclaration())
    return;

  const DataLayout &DL = m_pHLModule->GetModule()->getDataLayout();

  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(F);
  DXASSERT(pFuncAnnot, "else invalid function");

  IRBuilder<> AllocaBuilder(F->getEntryBlock().getFirstInsertionPt());

  SmallVector<ReturnInst *, 2> retList;
  for (BasicBlock &bb : F->getBasicBlockList()) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(bb.getTerminator())) {
      retList.emplace_back(RI);
    }
  }

  for (Argument &arg : F->args()) {
    Type *Ty = arg.getType();
    // Only check pointer types.
    if (!Ty->isPointerTy())
      continue;
    Ty = Ty->getPointerElementType();
    // Skip scalar types.
    if (!Ty->isAggregateType() && Ty->getScalarType() == Ty)
      continue;

    bool bUsedInCall = IsUsedAsCallArg(&arg);

    if (bUsedInCall) {
      // Create a tmp.
      Value *TmpArg = AllocaBuilder.CreateAlloca(Ty);
      // Replace the arg with the tmp.
      arg.replaceAllUsesWith(TmpArg);

      DxilParameterAnnotation &paramAnnot =
          pFuncAnnot->GetParameterAnnotation(arg.getArgNo());
      DxilParamInputQual inputQual = paramAnnot.GetParamInputQual();
      unsigned size = DL.getTypeAllocSize(Ty);
      // Copy between the arg and the tmp.
      if (inputQual == DxilParamInputQual::In ||
          inputQual == DxilParamInputQual::Inout) {
        // Copy arg to tmp.
        CallInst *argToTmp = AllocaBuilder.CreateMemCpy(TmpArg, &arg, size, 0);
        // Split the memcpy.
        MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(argToTmp), DL, nullptr,
                                    typeSys);
      }
      if (inputQual == DxilParamInputQual::Out ||
          inputQual == DxilParamInputQual::Inout) {
        for (ReturnInst *RI : retList) {
          IRBuilder<> RetBuilder(RI);
          // Copy tmp to arg.
          CallInst *tmpToArg =
              RetBuilder.CreateMemCpy(&arg, TmpArg, size, 0);
          // Split the memcpy.
          MemcpySplitter::SplitMemCpy(cast<MemCpyInst>(tmpToArg), DL, nullptr,
                                      typeSys);
        }
      }
      // TODO: support other DxilParamInputQual.
    }
  }
}
/// moveFunctionBody - Move the body of F to flatF.
void SROA_Parameter_HLSL::moveFunctionBody(Function *F, Function *flatF) {
  bool updateRetType = F->getReturnType() != flatF->getReturnType();

  // Splice the body of the old function right into the new function.
  flatF->getBasicBlockList().splice(flatF->begin(), F->getBasicBlockList());

  // Update block terminators.
  if (updateRetType) {
    for (BasicBlock &BB : flatF->getBasicBlockList()) {
      // Replace ret with ret void; the return value is now an out parameter.
      if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
        IRBuilder<> Builder(RI);
        Builder.CreateRetVoid();
        RI->eraseFromParent();
      }
    }
  }
}
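
// SplitArrayCopy - Replace whole-array stores through V with element-wise
// copies so that no aggregate store of the array remains.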
static void SplitArrayCopy(Value *V, DxilTypeSystem &typeSys,
                           DxilFieldAnnotation *fieldAnnotation) {
  for (auto U = V->user_begin(); U != V->user_end();) {
    User *user = *(U++);
    if (StoreInst *ST = dyn_cast<StoreInst>(user)) {
      Value *ptr = ST->getPointerOperand();
      Value *val = ST->getValueOperand();
      IRBuilder<> Builder(ST);
      SmallVector<Value *, 16> idxList;
      SplitCpy(ptr->getType(), ptr, val, idxList, Builder, typeSys,
               fieldAnnotation);
      ST->eraseFromParent();
    }
  }
}
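
// CheckArgUsage - Conservatively determine whether the argument is read
// (bLoad) and/or written (bStore), looking through GEPs, pointer-returning
// calls, and HL matrix load/store intrinsics.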
static void CheckArgUsage(Value *V, bool &bLoad, bool &bStore) {
  if (bLoad && bStore)
    return;
  for (User *user : V->users()) {
    if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
      bLoad = true;
    } else if (StoreInst *SI = dyn_cast<StoreInst>(user)) {
      bStore = true;
    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
      CheckArgUsage(GEP, bLoad, bStore);
    } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
      if (CI->getType()->isPointerTy())
        CheckArgUsage(CI, bLoad, bStore);
      else {
        HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
        if (group == HLOpcodeGroup::HLMatLoadStore) {
          HLMatLoadStoreOpcode opcode =
              static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
          switch (opcode) {
          case HLMatLoadStoreOpcode::ColMatLoad:
          case HLMatLoadStoreOpcode::RowMatLoad:
            bLoad = true;
            break;
          case HLMatLoadStoreOpcode::ColMatStore:
          case HLMatLoadStoreOpcode::RowMatStore:
            bStore = true;
            break;
          }
        }
      }
    }
  }
}
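
// LegalizeDxilInputOutputs - When a body writes an input or reads an output
// directly, route the access through a local temp that is copied from the
// input at entry and copied back to the output before every return.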
// Support store to input and load from output.
static void LegalizeDxilInputOutputs(Function *F,
                                     DxilFunctionAnnotation *EntryAnnotation,
                                     DxilTypeSystem &typeSys) {
  BasicBlock &EntryBlk = F->getEntryBlock();
  Module *M = F->getParent();
  // Map from output to the temp created for it.
  std::unordered_map<Argument *, Value *> outputTempMap;
  for (Argument &arg : F->args()) {
    Type *Ty = arg.getType();

    DxilParameterAnnotation &paramAnnotation =
        EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
    DxilParamInputQual qual = paramAnnotation.GetParamInputQual();

    bool isColMajor = false;

    // Skip args that are not pointers.
    if (!Ty->isPointerTy()) {
      if (HLMatrixLower::IsMatrixType(Ty)) {
        // Replace the matrix arg with a cast to vec. It will be lowered in
        // DxilGenerationPass.
        isColMajor = paramAnnotation.GetMatrixAnnotation().Orientation ==
                     MatrixOrientation::ColumnMajor;
        IRBuilder<> Builder(EntryBlk.getFirstInsertionPt());

        HLCastOpcode opcode = isColMajor ? HLCastOpcode::ColMatrixToVecCast
                                         : HLCastOpcode::RowMatrixToVecCast;
        Value *undefVal = UndefValue::get(Ty);

        Value *Cast = HLModule::EmitHLOperationCall(
            Builder, HLOpcodeGroup::HLCast, static_cast<unsigned>(opcode), Ty,
            {undefVal}, *M);
        arg.replaceAllUsesWith(Cast);
        // Set the arg as the operand.
        CallInst *CI = cast<CallInst>(Cast);
        CI->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, &arg);
      }
      continue;
    }

    Ty = Ty->getPointerElementType();

    bool bLoad = false;
    bool bStore = false;
    CheckArgUsage(&arg, bLoad, bStore);

    bool bNeedTemp = false;
    bool bStoreInputToTemp = false;
    bool bLoadOutputFromTemp = false;

    if (qual == DxilParamInputQual::In && bStore) {
      bNeedTemp = true;
      bStoreInputToTemp = true;
    } else if (qual == DxilParamInputQual::Out && bLoad) {
      bNeedTemp = true;
      bLoadOutputFromTemp = true;
    } else if (bLoad && bStore) {
      switch (qual) {
      case DxilParamInputQual::InputPrimitive:
      case DxilParamInputQual::InputPatch:
      case DxilParamInputQual::OutputPatch: {
        bNeedTemp = true;
        bStoreInputToTemp = true;
      } break;
      case DxilParamInputQual::Inout:
        break;
      default:
        DXASSERT(0, "invalid input qual here");
      }
    } else if (qual == DxilParamInputQual::Inout) {
      // Only replace inout when (bLoad && bStore) == false.
      bNeedTemp = true;
      bLoadOutputFromTemp = true;
      bStoreInputToTemp = true;
    }

    if (HLMatrixLower::IsMatrixType(Ty)) {
      bNeedTemp = true;
      if (qual == DxilParamInputQual::In)
        bStoreInputToTemp = bLoad;
      else if (qual == DxilParamInputQual::Out)
        bLoadOutputFromTemp = bStore;
      else if (qual == DxilParamInputQual::Inout) {
        bStoreInputToTemp = true;
        bLoadOutputFromTemp = true;
      }
    }

    if (bNeedTemp) {
      IRBuilder<> Builder(EntryBlk.getFirstInsertionPt());

      AllocaInst *temp = Builder.CreateAlloca(Ty);
      // Replace all uses with the temp.
      arg.replaceAllUsesWith(temp);

      // Copy the input to the temp.
      if (bStoreInputToTemp) {
        llvm::SmallVector<llvm::Value *, 16> idxList;
        // Split the copy.
        SplitCpy(temp->getType(), temp, &arg, idxList, Builder, typeSys,
                 &paramAnnotation);
      }

      // Generate the store from the temp to the output later.
      if (bLoadOutputFromTemp) {
        outputTempMap[&arg] = temp;
      }
    }
  }

  for (BasicBlock &BB : F->getBasicBlockList()) {
    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
      IRBuilder<> Builder(RI);
      // Copy temps to outputs.
      for (auto It : outputTempMap) {
        Argument *output = It.first;
        Value *temp = It.second;
        llvm::SmallVector<llvm::Value *, 16> idxList;

        DxilParameterAnnotation &paramAnnotation =
            EntryAnnotation->GetParameterAnnotation(output->getArgNo());

        auto Iter = Builder.GetInsertPoint();
        bool onlyRetBlk = false;
        if (RI != BB.begin())
          Iter--;
        else
          onlyRetBlk = true;
        // Split the copy.
        SplitCpy(output->getType(), output, temp, idxList, Builder, typeSys,
                 &paramAnnotation);
      }
      // Clone the return.
      Builder.CreateRet(RI->getReturnValue());
      RI->eraseFromParent();
    }
  }
}
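
// createFlattenedFunction - Build <F>.flat, whose signature is the flattened
// parameter list (the return value converted into out parameters, plus any
// extra clip-plane outputs), copy annotations and attributes over, and move
// the old body into it.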
void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();

  // Skip void(void) functions.
  if (F->getReturnType()->isVoidTy() && F->getArgumentList().empty()) {
    return;
  }

  // Clear the maps for cast.
  castParamMap.clear();
  vectorEltsMap.clear();

  DxilFunctionAnnotation *funcAnnotation =
      m_pHLModule->GetFunctionAnnotation(F);
  DXASSERT(funcAnnotation, "must find annotation for function");

  std::deque<Value *> WorkList;

  LLVMContext &Ctx = m_pHLModule->GetCtx();
  std::unique_ptr<BasicBlock> TmpBlockForFuncDecl;
  if (F->isDeclaration()) {
    TmpBlockForFuncDecl.reset(BasicBlock::Create(Ctx));
    // Create a return as the terminator.
    IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get());
    RetBuilder.CreateRetVoid();
  }

  std::vector<Value *> FlatParamList;
  std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
  std::vector<int> FlatParamOriArgNoList;

  const bool bForParamTrue = true;

  // Add all arguments to the worklist.
  for (Argument &Arg : F->args()) {
    // Merge GEP uses for the arg.
    HLModule::MergeGepUse(&Arg);
    // The insert point may be removed, so recreate the builder every time.
    IRBuilder<> Builder(Ctx);
    if (!F->isDeclaration()) {
      Builder.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
    } else {
      Builder.SetInsertPoint(TmpBlockForFuncDecl->getFirstInsertionPt());
    }

    unsigned prevFlatParamCount = FlatParamList.size();

    DxilParameterAnnotation &paramAnnotation =
        funcAnnotation->GetParameterAnnotation(Arg.getArgNo());
    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&Arg);
    flattenArgument(F, &Arg, bForParamTrue, paramAnnotation, FlatParamList,
                    FlatParamAnnotationList, Builder, DDI);

    unsigned newFlatParamCount = FlatParamList.size() - prevFlatParamCount;
    for (unsigned i = 0; i < newFlatParamCount; i++) {
      FlatParamOriArgNoList.emplace_back(Arg.getArgNo());
    }
  }

  Type *retType = F->getReturnType();
  std::vector<Value *> FlatRetList;
  std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
  // Split the return value and change it into out parameters.
  if (!retType->isVoidTy()) {
    IRBuilder<> Builder(Ctx);
    if (!F->isDeclaration()) {
      Builder.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
    } else {
      Builder.SetInsertPoint(TmpBlockForFuncDecl->getFirstInsertionPt());
    }

    Value *retValAddr = Builder.CreateAlloca(retType);
    DxilParameterAnnotation &retAnnotation =
        funcAnnotation->GetRetTypeAnnotation();
    Module &M = *m_pHLModule->GetModule();
    Type *voidTy = Type::getVoidTy(m_pHLModule->GetCtx());

    // Create a DbgDecl for the ret value.
    if (DISubprogram *funcDI = getDISubprogram(F)) {
      DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
      DITypeIdentifierMap EmptyMap;
      DIType *RetDIType = RetDITyRef.resolve(EmptyMap);
      DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
      DILocalVariable *RetVar = DIB.createLocalVariable(
          llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI,
          F->getName().str() + ".Ret", funcDI->getFile(), funcDI->getLine(),
          RetDIType);
      DIExpression *Expr = nullptr;
      // TODO: how to get the column?
      DILocation *DL =
          DILocation::get(F->getContext(), funcDI->getLine(), 0, funcDI);
      DIB.insertDeclare(retValAddr, RetVar, Expr, DL, Builder.GetInsertPoint());
    }

    for (BasicBlock &BB : F->getBasicBlockList()) {
      if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
        // Create a store for the return.
        IRBuilder<> RetBuilder(RI);
        if (!retAnnotation.HasMatrixAnnotation()) {
          RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
        } else {
          bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
                            MatrixOrientation::RowMajor;
          Value *RetVal = RI->getReturnValue();
          if (!isRowMajor) {
            // The matrix value is row major, but ColMatStore requires col
            // major. Cast before the store.
            RetVal = HLModule::EmitHLOperationCall(
                RetBuilder, HLOpcodeGroup::HLCast,
                static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
                RetVal->getType(), {RetVal}, M);
          }
          unsigned opcode = static_cast<unsigned>(
              isRowMajor ? HLMatLoadStoreOpcode::RowMatStore
                         : HLMatLoadStoreOpcode::ColMatStore);
          HLModule::EmitHLOperationCall(RetBuilder,
                                        HLOpcodeGroup::HLMatLoadStore, opcode,
                                        voidTy, {retValAddr, RetVal}, M);
        }
      }
    }

    // Create a fake store to keep retValAddr so it can be flattened.
    if (retValAddr->user_empty()) {
      Builder.CreateStore(UndefValue::get(retType), retValAddr);
    }

    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(retValAddr);
    flattenArgument(F, retValAddr, bForParamTrue,
                    funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
                    FlatRetAnnotationList, Builder, DDI);

    const int kRetArgNo = -1;
    for (unsigned i = 0; i < FlatRetList.size(); i++) {
      FlatParamOriArgNoList.emplace_back(kRetArgNo);
    }
  }

  // Always change the return type into parameters.
  // By doing this, there is no need to check the return when generating
  // storeOutput.
  if (FlatRetList.size() ||
      // For empty struct return types.
      !retType->isVoidTy()) {
    // The return value is flattened.
    // Change the return value into out parameters.
    retType = Type::getVoidTy(retType->getContext());
    // Merge return data into param data.
    FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(),
                         FlatRetList.end());
    FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
                                   FlatRetAnnotationList.begin(),
                                   FlatRetAnnotationList.end());
  }

  std::vector<Type *> FinalTypeList;
  for (Value *arg : FlatParamList) {
    FinalTypeList.emplace_back(arg->getType());
  }

  unsigned extraParamSize = 0;
  if (m_pHLModule->HasDxilFunctionProps(F)) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
    if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
      auto &VS = funcProps.ShaderProps.VS;
      Type *outFloatTy = Type::getFloatPtrTy(F->getContext());
      // Add an out float parameter for each clip plane.
      unsigned i = 0;
      for (; i < DXIL::kNumClipPlanes; i++) {
        if (!VS.clipPlanes[i])
          break;
        FinalTypeList.emplace_back(outFloatTy);
      }
      extraParamSize = i;
    }
  }

  FunctionType *flatFuncTy = FunctionType::get(retType, FinalTypeList, false);
  // Return if nothing changed.
  if (flatFuncTy == F->getFunctionType()) {
    // Copy semantic allocation.
    if (!FlatParamAnnotationList.empty()) {
      if (!FlatParamAnnotationList[0].GetSemanticString().empty()) {
        for (unsigned i = 0; i < FlatParamAnnotationList.size(); i++) {
          DxilParameterAnnotation &paramAnnotation =
              funcAnnotation->GetParameterAnnotation(i);
          DxilParameterAnnotation &flatParamAnnotation =
              FlatParamAnnotationList[i];
          paramAnnotation.SetSemanticIndexVec(
              flatParamAnnotation.GetSemanticIndexVec());
          paramAnnotation.SetSemanticString(
              flatParamAnnotation.GetSemanticString());
        }
      }
    }
    if (!F->isDeclaration()) {
      // Support store to input and load from output.
      LegalizeDxilInputOutputs(F, funcAnnotation, typeSys);
    }
    return;
  }

  std::string flatName = F->getName().str() + ".flat";
  DXASSERT(nullptr == F->getParent()->getFunction(flatName),
           "else overwriting existing function");
  Function *flatF =
      cast<Function>(F->getParent()->getOrInsertFunction(flatName, flatFuncTy));
  funcMap[F] = flatF;

  // Update function debug info.
  if (DISubprogram *funcDI = getDISubprogram(F))
    funcDI->replaceFunction(flatF);

  // Create a FunctionAnnotation for flatF.
  DxilFunctionAnnotation *flatFuncAnnotation =
      m_pHLModule->AddFunctionAnnotation(flatF);

  // No need to set ret info; flatF always returns void now.

  // Param info.
  for (unsigned ArgNo = 0; ArgNo < FlatParamAnnotationList.size(); ++ArgNo) {
    DxilParameterAnnotation &paramAnnotation =
        flatFuncAnnotation->GetParameterAnnotation(ArgNo);
    paramAnnotation = FlatParamAnnotationList[ArgNo];
  }

  // Function attributes and parameter attributes.
  // Remove sret first.
  if (F->hasStructRetAttr())
    F->removeFnAttr(Attribute::StructRet);
  for (Argument &arg : F->args()) {
    if (arg.hasStructRetAttr()) {
      Attribute::AttrKind SRet[] = {Attribute::StructRet};
      AttributeSet SRetAS = AttributeSet::get(Ctx, arg.getArgNo() + 1, SRet);
      arg.removeAttr(SRetAS);
    }
  }

  AttributeSet AS = F->getAttributes();
  AttrBuilder FnAttrs(AS.getFnAttributes(), AttributeSet::FunctionIndex);
  AttributeSet flatAS;
  flatAS = flatAS.addAttributes(
      Ctx, AttributeSet::FunctionIndex,
      AttributeSet::get(Ctx, AttributeSet::FunctionIndex, FnAttrs));
  if (!F->isDeclaration()) {
    // Only set param attributes for functions that have a body.
    for (unsigned ArgNo = 0; ArgNo < FlatParamAnnotationList.size(); ++ArgNo) {
      unsigned oriArgNo = FlatParamOriArgNoList[ArgNo] + 1;
      AttrBuilder paramAttr(AS, oriArgNo);
      if (oriArgNo == AttributeSet::ReturnIndex)
        paramAttr.addAttribute(Attribute::AttrKind::NoAlias);
      flatAS = flatAS.addAttributes(
          Ctx, ArgNo + 1, AttributeSet::get(Ctx, ArgNo + 1, paramAttr));
    }
  }
  flatF->setAttributes(flatAS);

  DXASSERT(flatF->arg_size() ==
               (extraParamSize + FlatParamAnnotationList.size()),
           "parameter count mismatch");
  // ShaderProps.
  if (m_pHLModule->HasDxilFunctionProps(F)) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
    std::unique_ptr<DxilFunctionProps> flatFuncProps =
        std::make_unique<DxilFunctionProps>();
    flatFuncProps->shaderKind = funcProps.shaderKind;
    flatFuncProps->ShaderProps = funcProps.ShaderProps;
    m_pHLModule->AddDxilFunctionProps(flatF, flatFuncProps);
    if (funcProps.shaderKind == ShaderModel::Kind::Vertex) {
      auto &VS = funcProps.ShaderProps.VS;
      unsigned clipArgIndex = FlatParamAnnotationList.size();
      // Add an out float SV_ClipDistance for each clip plane.
      for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
        if (!VS.clipPlanes[i])
          break;
        DxilParameterAnnotation &paramAnnotation =
            flatFuncAnnotation->GetParameterAnnotation(clipArgIndex + i);
        paramAnnotation.SetParamInputQual(DxilParamInputQual::Out);
        Twine semName = Twine("SV_ClipDistance") + Twine(i);
        paramAnnotation.SetSemanticString(semName.str());
        paramAnnotation.SetCompType(DXIL::ComponentType::F32);
        paramAnnotation.AppendSemanticIndex(i);
      }
    }
  }

  if (!F->isDeclaration()) {
    // Move the function body into flatF.
    moveFunctionBody(F, flatF);

    // Replace old parameters with flatF Arguments.
    auto argIter = flatF->arg_begin();
    auto flatArgIter = FlatParamList.begin();
    LLVMContext &Context = F->getContext();

    // Parameter casts come from the beginning of the entry block.
    IRBuilder<> Builder(flatF->getEntryBlock().getFirstInsertionPt());

    while (argIter != flatF->arg_end()) {
      Argument *Arg = argIter++;
      if (flatArgIter == FlatParamList.end()) {
        DXASSERT(extraParamSize > 0, "parameter count mismatch");
        break;
      }
      Value *flatArg = *(flatArgIter++);

      if (castParamMap.count(flatArg)) {
        replaceCastParameter(flatArg, castParamMap[flatArg].first, *flatF, Arg,
                             castParamMap[flatArg].second, Builder);
      }

      flatArg->replaceAllUsesWith(Arg);
      // Update the arg debug info.
      DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(flatArg);
      if (DDI) {
        Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(Arg));
        DDI->setArgOperand(0, VMD);
      }

      HLModule::MergeGepUse(Arg);
      // Flatten stores of array parameters.
      if (Arg->getType()->isPointerTy()) {
        Type *Ty = Arg->getType()->getPointerElementType();
        if (Ty->isArrayTy())
          SplitArrayCopy(
              Arg, typeSys,
              &flatFuncAnnotation->GetParameterAnnotation(Arg->getArgNo()));
      }
    }
    // Support store to input and load from output.
    LegalizeDxilInputOutputs(flatF, flatFuncAnnotation, typeSys);
  }
}
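
// createFlattenedFunctionCall - Rewrite one call to F so that it calls flatF
// instead: the return value becomes an out alloca that is loaded after the
// call, and each aggregate argument is copied into a temp that is then
// flattened into the new argument list.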
void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F,
                                                      Function *flatF,
                                                      CallInst *CI) {
  DxilFunctionAnnotation *funcAnnotation =
      m_pHLModule->GetFunctionAnnotation(F);
  DXASSERT(funcAnnotation, "must find annotation for function");

  // Clear the maps for cast.
  castParamMap.clear();
  vectorEltsMap.clear();

  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();

  std::vector<Value *> FlatParamList;
  std::vector<DxilParameterAnnotation> FlatParamAnnotationList;

  IRBuilder<> AllocaBuilder(
      CI->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
  IRBuilder<> CallBuilder(CI);
  IRBuilder<> RetBuilder(CI->getNextNode());

  Type *retType = F->getReturnType();
  std::vector<Value *> FlatRetList;
  std::vector<DxilParameterAnnotation> FlatRetAnnotationList;

  const bool bForParamFalse = false;

  // Split the return value and change it into out parameters.
  if (!retType->isVoidTy()) {
    Value *retValAddr = AllocaBuilder.CreateAlloca(retType);
    // Create a DbgDecl for the ret value.
    if (DISubprogram *funcDI = getDISubprogram(F)) {
      DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
      DITypeIdentifierMap EmptyMap;
      DIType *RetDIType = RetDITyRef.resolve(EmptyMap);
      DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
      DILocalVariable *RetVar = DIB.createLocalVariable(
          llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI,
          F->getName().str() + ".Ret", funcDI->getFile(), funcDI->getLine(),
          RetDIType);
      DIExpression *Expr = nullptr;
      // TODO: how to get the column?
      DILocation *DL =
          DILocation::get(F->getContext(), funcDI->getLine(), 0, funcDI);
      DIB.insertDeclare(retValAddr, RetVar, Expr, DL, CI);
    }
    DxilParameterAnnotation &retAnnotation =
        funcAnnotation->GetRetTypeAnnotation();
    // Load the ret value and replace CI.
    Value *newRetVal = nullptr;
    if (!retAnnotation.HasMatrixAnnotation()) {
      newRetVal = RetBuilder.CreateLoad(retValAddr);
    } else {
      bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
                        MatrixOrientation::RowMajor;
      unsigned opcode =
          static_cast<unsigned>(isRowMajor ? HLMatLoadStoreOpcode::RowMatLoad
                                           : HLMatLoadStoreOpcode::ColMatLoad);
      newRetVal = HLModule::EmitHLOperationCall(
          RetBuilder, HLOpcodeGroup::HLMatLoadStore, opcode, retType,
          {retValAddr}, *m_pHLModule->GetModule());
      if (!isRowMajor) {
        // ColMatLoad will return a col major matrix, but the matrix value
        // should be row major. Cast it here.
        newRetVal = HLModule::EmitHLOperationCall(
            RetBuilder, HLOpcodeGroup::HLCast,
            static_cast<unsigned>(HLCastOpcode::ColMatrixToRowMatrix), retType,
            {newRetVal}, *m_pHLModule->GetModule());
      }
    }
    CI->replaceAllUsesWith(newRetVal);
    // Flatten the ret value.
    flattenArgument(flatF, retValAddr, bForParamFalse,
                    funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
                    FlatRetAnnotationList, AllocaBuilder,
                    /*DbgDeclareInst*/ nullptr);
  }

  std::vector<Value *> args;
  for (auto &arg : CI->arg_operands()) {
    args.emplace_back(arg.get());
  }
  // Remove CI from the users of args.
  CI->dropAllReferences();

  // Add all arguments to the worklist.
  for (unsigned i = 0; i < args.size(); i++) {
    DxilParameterAnnotation &paramAnnotation =
        funcAnnotation->GetParameterAnnotation(i);
    Value *arg = args[i];
    Type *Ty = arg->getType();
    if (Ty->isPointerTy()) {
      // For a pointer argument, alloca a temp and replace it in CI.
      Value *tempArg =
          AllocaBuilder.CreateAlloca(arg->getType()->getPointerElementType());

      DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
      // TODO: support special InputQual like InputPatch.
      if (inputQual == DxilParamInputQual::In ||
          inputQual == DxilParamInputQual::Inout) {
        // Copy the in param.
        llvm::SmallVector<llvm::Value *, 16> idxList;
        // Split the copy to avoid loading the struct.
        SplitCpy(Ty, tempArg, arg, idxList, CallBuilder, typeSys,
                 &paramAnnotation);
      }

      if (inputQual == DxilParamInputQual::Out ||
          inputQual == DxilParamInputQual::Inout) {
        // Copy the out param.
        llvm::SmallVector<llvm::Value *, 16> idxList;
        // Split the copy to avoid loading the struct.
        SplitCpy(Ty, arg, tempArg, idxList, RetBuilder, typeSys,
                 &paramAnnotation);
      }
      arg = tempArg;
      flattenArgument(flatF, arg, bForParamFalse, paramAnnotation,
                      FlatParamList, FlatParamAnnotationList, AllocaBuilder,
                      /*DbgDeclareInst*/ nullptr);
    } else {
      // Flatten a vector argument into its elements.
      if (Ty->isVectorTy()) {
        unsigned vecSize = Ty->getVectorNumElements();
        for (unsigned vi = 0; vi < vecSize; vi++) {
          Value *Elt = CallBuilder.CreateExtractElement(arg, vi);
          // Cannot SROA; save it to the final parameter list.
          FlatParamList.emplace_back(Elt);
          // Create a ParamAnnotation for it.
          FlatRetAnnotationList.emplace_back(DxilParameterAnnotation());
          DxilParameterAnnotation &flatParamAnnotation =
              FlatRetAnnotationList.back();
          flatParamAnnotation = paramAnnotation;
        }
      } else if (HLMatrixLower::IsMatrixType(Ty)) {
        unsigned col, row;
        Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
        Value *Mat = arg;
        // Cast the matrix to an array.
        Type *AT = ArrayType::get(EltTy, col * row);
        arg = AllocaBuilder.CreateAlloca(AT);
        DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
        castParamMap[arg] = std::make_pair(Mat, inputQual);
        DXASSERT(paramAnnotation.HasMatrixAnnotation(),
                 "need matrix annotation here");
        if (paramAnnotation.GetMatrixAnnotation().Orientation ==
            hlsl::MatrixOrientation::RowMajor) {
          castRowMajorParamMap.insert(arg);
        }
        // Cannot SROA; save it to the final parameter list.
        FlatParamList.emplace_back(arg);
        // Create a ParamAnnotation for it.
        FlatRetAnnotationList.emplace_back(DxilParameterAnnotation());
        DxilParameterAnnotation &flatParamAnnotation =
            FlatRetAnnotationList.back();
        flatParamAnnotation = paramAnnotation;
      } else {
        // Cannot SROA; save it to the final parameter list.
        FlatParamList.emplace_back(arg);
        // Create a ParamAnnotation for it.
        FlatRetAnnotationList.emplace_back(DxilParameterAnnotation());
        DxilParameterAnnotation &flatParamAnnotation =
            FlatRetAnnotationList.back();
        flatParamAnnotation = paramAnnotation;
      }
    }
  }

  // Always change the return type into parameters.
  // By doing this, there is no need to check the return when generating
  // storeOutput.
  if (FlatRetList.size() ||
      // For empty struct return types.
      !retType->isVoidTy()) {
    // Merge return data into param data.
    FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(),
                         FlatRetList.end());
    FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
                                   FlatRetAnnotationList.begin(),
                                   FlatRetAnnotationList.end());
  }

  RetBuilder.SetInsertPoint(CI->getNextNode());
  unsigned paramSize = FlatParamList.size();
  for (unsigned i = 0; i < paramSize; i++) {
    Value *&flatArg = FlatParamList[i];
    if (castParamMap.count(flatArg)) {
      replaceCastArgument(flatArg, castParamMap[flatArg].first,
                          castParamMap[flatArg].second, CallBuilder,
                          RetBuilder);
      if (vectorEltsMap.count(flatArg) && !flatArg->getType()->isPointerTy()) {
        // Vector elements need to be updated.
        SmallVector<Value *, 4> &elts = vectorEltsMap[flatArg];
        // Back one step.
        --i;
        for (Value *elt : elts) {
          FlatParamList[++i] = elt;
        }
        // Don't need elts anymore.
        vectorEltsMap.erase(flatArg);
      }
    }
  }

  CallInst *NewCI = CallBuilder.CreateCall(flatF, FlatParamList);

  CallBuilder.SetInsertPoint(NewCI);

  CI->eraseFromParent();
}
void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
  // Update the entry function.
  if (F == m_pHLModule->GetEntryFunction()) {
    m_pHLModule->SetEntryFunction(flatF);
  }
  // Update the patch constant function.
  if (m_pHLModule->HasDxilFunctionProps(flatF)) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(flatF);
    if (funcProps.shaderKind == DXIL::ShaderKind::Hull) {
      Function *oldPatchConstantFunc =
          funcProps.ShaderProps.HS.patchConstantFunc;
      if (funcMap.count(oldPatchConstantFunc))
        funcProps.ShaderProps.HS.patchConstantFunc =
            funcMap[oldPatchConstantFunc];
    }
  }
  // TODO: flatten vector arguments and lower resource arguments when
  // flattening functions.
  for (auto it = F->user_begin(); it != F->user_end();) {
    CallInst *CI = cast<CallInst>(*(it++));
    createFlattenedFunctionCall(F, flatF, CI);
  }
}

// Public interface to the SROA_Parameter_HLSL pass
ModulePass *llvm::createSROA_Parameter_HLSL() {
  return new SROA_Parameter_HLSL();
}

//===----------------------------------------------------------------------===//
// Lower static global into Alloca.
//===----------------------------------------------------------------------===//
namespace {
class LowerStaticGlobalIntoAlloca : public ModulePass {
  HLModule *m_pHLModule;

public:
  static char ID; // Pass identification, replacement for typeid
  explicit LowerStaticGlobalIntoAlloca() : ModulePass(ID) {}
  const char *getPassName() const override {
    return "Lower static global into Alloca";
  }

  bool runOnModule(Module &M) override {
    m_pHLModule = &M.GetOrCreateHLModule();

    // Lower static globals into allocas.
    std::vector<GlobalVariable *> staticGVs;
    for (GlobalVariable &GV : M.globals()) {
      bool isStaticGlobal =
          dxilutil::IsStaticGlobal(&GV) &&
          GV.getType()->getAddressSpace() == DXIL::kDefaultAddrSpace;
      if (isStaticGlobal &&
          !GV.getType()->getElementType()->isAggregateType()) {
        staticGVs.emplace_back(&GV);
      }
    }
    bool bUpdated = false;

    const DataLayout &DL = M.getDataLayout();
    for (GlobalVariable *GV : staticGVs) {
      bUpdated |= lowerStaticGlobalIntoAlloca(GV, DL);
    }

    return bUpdated;
  }

private:
  bool lowerStaticGlobalIntoAlloca(GlobalVariable *GV, const DataLayout &DL);
};
} // namespace
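
// Lower one static global to an alloca in its single accessing function. Only
// globals that are actually stored to and accessed from exactly one function
// qualify; otherwise the global is left in place.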
bool LowerStaticGlobalIntoAlloca::lowerStaticGlobalIntoAlloca(
    GlobalVariable *GV, const DataLayout &DL) {
  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  unsigned size = DL.getTypeAllocSize(GV->getType()->getElementType());
  PointerStatus PS(size);
  GV->removeDeadConstantUsers();
  PS.analyzePointer(GV, PS, typeSys, /*bStructElt*/ false);
  bool NotStored = (PS.StoredType == PointerStatus::NotStored) ||
                   (PS.StoredType == PointerStatus::InitializerStored);
  // Make sure the GV is only used in one function.
  // Skip GVs that are never stored to.
  if (PS.HasMultipleAccessingFunctions || NotStored)
    return false;

  Function *F = const_cast<Function *>(PS.AccessingFunction);
  IRBuilder<> Builder(F->getEntryBlock().getFirstInsertionPt());
  AllocaInst *AI = Builder.CreateAlloca(GV->getType()->getElementType());

  // Store the initializer if one exists.
  if (GV->hasInitializer() && !isa<UndefValue>(GV->getInitializer())) {
    Builder.CreateStore(GV->getInitializer(), GV);
  }

  ReplaceConstantWithInst(GV, AI, Builder);
  GV->eraseFromParent();
  return true;
}

char LowerStaticGlobalIntoAlloca::ID = 0;

INITIALIZE_PASS(LowerStaticGlobalIntoAlloca, "static-global-to-alloca",
                "Lower static global into Alloca", false, false)

// Public interface to the LowerStaticGlobalIntoAlloca pass
ModulePass *llvm::createLowerStaticGlobalIntoAlloca() {
  return new LowerStaticGlobalIntoAlloca();
}

//===----------------------------------------------------------------------===//
// Lower one type to another type.
//===----------------------------------------------------------------------===//

namespace {
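// LowerTypePass - Shared driver for passes that rewrite every alloca and
// internal global of one type into another type. Subclasses decide which
// values need lowering, what the new type is, how initializers are converted,
// and how existing uses are rewritten against the new value.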
class LowerTypePass : public ModulePass {
public:
  explicit LowerTypePass(char &ID) : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

private:
  bool runOnFunction(Function &F, bool HasDbgInfo);
  AllocaInst *lowerAlloca(AllocaInst *A);
  GlobalVariable *lowerInternalGlobal(GlobalVariable *GV);

protected:
  virtual bool needToLower(Value *V) = 0;
  virtual void lowerUseWithNewValue(Value *V, Value *NewV) = 0;
  virtual Type *lowerType(Type *Ty) = 0;
  virtual Constant *lowerInitVal(Constant *InitVal, Type *NewTy) = 0;
  virtual StringRef getGlobalPrefix() = 0;
  virtual void initialize(Module &M){};
};
AllocaInst *LowerTypePass::lowerAlloca(AllocaInst *A) {
  IRBuilder<> Builder(A);
  Type *NewTy = lowerType(A->getAllocatedType());
  return Builder.CreateAlloca(NewTy);
}

GlobalVariable *LowerTypePass::lowerInternalGlobal(GlobalVariable *GV) {
  Type *NewTy = lowerType(GV->getType()->getPointerElementType());
  // Default the init val to undef.
  Constant *InitVal = UndefValue::get(NewTy);
  if (GV->hasInitializer()) {
    Constant *OldInitVal = GV->getInitializer();
    if (isa<ConstantAggregateZero>(OldInitVal))
      InitVal = ConstantAggregateZero::get(NewTy);
    else if (!isa<UndefValue>(OldInitVal)) {
      InitVal = lowerInitVal(OldInitVal, NewTy);
    }
  }

  bool isConst = GV->isConstant();
  GlobalVariable::ThreadLocalMode TLMode = GV->getThreadLocalMode();
  unsigned AddressSpace = GV->getType()->getAddressSpace();
  GlobalValue::LinkageTypes linkage = GV->getLinkage();

  Module *M = GV->getParent();
  GlobalVariable *NewGV = new llvm::GlobalVariable(
      *M, NewTy, /*IsConstant*/ isConst, linkage,
      /*InitVal*/ InitVal, GV->getName() + getGlobalPrefix(),
      /*InsertBefore*/ nullptr, TLMode, AddressSpace);
  return NewGV;
}
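
// Lower every alloca in the entry block that needToLower() accepts, migrating
// any llvm.dbg.declare to the new alloca before rewriting uses.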
bool LowerTypePass::runOnFunction(Function &F, bool HasDbgInfo) {
  std::vector<AllocaInst *> workList;
  // Scan the entry basic block, adding allocas to the worklist.
  BasicBlock &BB = F.getEntryBlock();
  for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
    if (!isa<AllocaInst>(I))
      continue;
    AllocaInst *A = cast<AllocaInst>(I);
    if (needToLower(A))
      workList.emplace_back(A);
  }
  LLVMContext &Context = F.getContext();
  for (AllocaInst *A : workList) {
    AllocaInst *NewA = lowerAlloca(A);
    if (HasDbgInfo) {
      // Add debug info.
      DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(A);
      if (DDI) {
        Value *DDIVar = MetadataAsValue::get(Context, DDI->getRawVariable());
        Value *DDIExp = MetadataAsValue::get(Context, DDI->getRawExpression());
        Value *VMD = MetadataAsValue::get(Context, ValueAsMetadata::get(NewA));
        IRBuilder<> debugBuilder(DDI);
        debugBuilder.CreateCall(DDI->getCalledFunction(),
                                {VMD, DDIVar, DDIExp});
      }
    }
    // Replace users.
    lowerUseWithNewValue(A, NewA);
    // Remove alloca.
    A->eraseFromParent();
  }
  return true;
}
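
// Process every defined function, then every static or shared-memory global
// that needs lowering, keeping debug info in sync.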
bool LowerTypePass::runOnModule(Module &M) {
  initialize(M);
  // Load up debug information, to cross-reference values and the instructions
  // used to load them.
  bool HasDbgInfo = getDebugMetadataVersionFromModule(M) != 0;
  llvm::DebugInfoFinder Finder;
  if (HasDbgInfo) {
    Finder.processModule(M);
  }

  std::vector<AllocaInst *> multiDimAllocas;
  for (Function &F : M.functions()) {
    if (F.isDeclaration())
      continue;
    runOnFunction(F, HasDbgInfo);
  }

  // Work on internal globals.
  std::vector<GlobalVariable *> vecGVs;
  for (GlobalVariable &GV : M.globals()) {
    if (dxilutil::IsStaticGlobal(&GV) || dxilutil::IsSharedMemoryGlobal(&GV)) {
      if (needToLower(&GV) && !GV.user_empty())
        vecGVs.emplace_back(&GV);
    }
  }

  for (GlobalVariable *GV : vecGVs) {
    GlobalVariable *NewGV = lowerInternalGlobal(GV);
    // Add debug info.
    if (HasDbgInfo) {
      HLModule::UpdateGlobalVariableDebugInfo(GV, Finder, NewGV);
    }
    // Replace users.
    lowerUseWithNewValue(GV, NewGV);
    // Remove GV.
    GV->removeDeadConstantUsers();
    GV->eraseFromParent();
  }

  return true;
}
} // namespace
//===----------------------------------------------------------------------===//
// DynamicIndexingVector to Array.
//===----------------------------------------------------------------------===//

namespace {
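// DynamicIndexingVectorToArray - Replace vectors that are dynamically indexed
// (and any array of vectors) with array equivalents; with ReplaceAllVectors
// set, every vector alloca/global is converted regardless of how it is
// indexed.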
class DynamicIndexingVectorToArray : public LowerTypePass {
  bool ReplaceAllVectors;

public:
  explicit DynamicIndexingVectorToArray(bool ReplaceAll = false)
      : LowerTypePass(ID), ReplaceAllVectors(ReplaceAll) {}
  static char ID; // Pass identification, replacement for typeid
  void applyOptions(PassOptions O) override;
  void dumpConfig(raw_ostream &OS) override;

protected:
  bool needToLower(Value *V) override;
  void lowerUseWithNewValue(Value *V, Value *NewV) override;
  Type *lowerType(Type *Ty) override;
  Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
  StringRef getGlobalPrefix() override { return ".v"; }

private:
  bool HasVectorDynamicIndexing(Value *V);
  void ReplaceVecGEP(Value *GEP, ArrayRef<Value *> idxList, Value *A,
                     IRBuilder<> &Builder);
  void ReplaceVecArrayGEP(Value *GEP, ArrayRef<Value *> idxList, Value *A,
                          IRBuilder<> &Builder);
  void ReplaceVectorWithArray(Value *Vec, Value *Array);
  void ReplaceVectorArrayWithArray(Value *VecArray, Value *Array);
  void ReplaceStaticIndexingOnVector(Value *V);
};

void DynamicIndexingVectorToArray::applyOptions(PassOptions O) {
  GetPassOptionBool(O, "ReplaceAllVectors", &ReplaceAllVectors,
                    ReplaceAllVectors);
}
void DynamicIndexingVectorToArray::dumpConfig(raw_ostream &OS) {
  ModulePass::dumpConfig(OS);
  OS << ",ReplaceAllVectors=" << ReplaceAllVectors;
}
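
// ReplaceStaticIndexingOnVector - For a vector pointer that is only indexed
// with constants, rewrite GEP+load into load+extractelement and GEP+store
// into load+insertelement+store, so the vector does not need to become an
// array.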
void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
  for (auto U = V->user_begin(), E = V->user_end(); U != E;) {
    Value *User = *(U++);
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      // Only work on element access for vector.
      if (GEP->getNumOperands() == 3) {
        auto Idx = GEP->idx_begin();
        // Skip the pointer idx.
        Idx++;
        ConstantInt *constIdx = cast<ConstantInt>(Idx);

        for (auto GEPU = GEP->user_begin(), GEPE = GEP->user_end();
             GEPU != GEPE;) {
          Instruction *GEPUser = cast<Instruction>(*(GEPU++));

          IRBuilder<> Builder(GEPUser);

          if (LoadInst *ldInst = dyn_cast<LoadInst>(GEPUser)) {
            // Change
            //    ld a->x
            // into
            //    b = ld a
            //    b.x
            Value *ldVal = Builder.CreateLoad(V);
            Value *Elt = Builder.CreateExtractElement(ldVal, constIdx);
            ldInst->replaceAllUsesWith(Elt);
            ldInst->eraseFromParent();
          } else {
            // Change
            //    st val, a->x
            // into
            //    tmp = ld a
            //    tmp.x = val
            //    st tmp, a
            // Must be store inst here.
            StoreInst *stInst = cast<StoreInst>(GEPUser);
            Value *val = stInst->getValueOperand();
            Value *ldVal = Builder.CreateLoad(V);
            ldVal = Builder.CreateInsertElement(ldVal, val, constIdx);
            Builder.CreateStore(ldVal, V);
            stInst->eraseFromParent();
          }
        }
        GEP->eraseFromParent();
      } else if (GEP->getNumIndices() == 1) {
        Value *Idx = *GEP->idx_begin();
        if (ConstantInt *C = dyn_cast<ConstantInt>(Idx)) {
          if (C->getLimitedValue() == 0) {
            GEP->replaceAllUsesWith(V);
            GEP->eraseFromParent();
          }
        }
      }
    }
  }
}
bool DynamicIndexingVectorToArray::needToLower(Value *V) {
  Type *Ty = V->getType()->getPointerElementType();
  if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
    if (isa<GlobalVariable>(V) || ReplaceAllVectors) {
      return true;
    }
    // Don't lower local vectors that are only statically indexed.
    if (HasVectorDynamicIndexing(V)) {
      return true;
    } else {
      // Replace the static vector indexing with whole-vector load/store.
      ReplaceStaticIndexingOnVector(V);
      return false;
    }
  } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
    // Arrays of vectors must be replaced even without dynamic indexing to
    // remove vector types from the DXIL.
    // TODO: optimize static array indexing in a later pass.
    Type *EltTy = dxilutil::GetArrayEltTy(AT);
    return isa<VectorType>(EltTy);
  }
  return false;
}
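// Re-emits a GEP that addressed the old vector pointer against the new array
// pointer. If the result still points at a vector (the array-of-vector case),
// recurse through ReplaceVectorWithArray; otherwise the new GEP replaces the
// old one directly.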
void DynamicIndexingVectorToArray::ReplaceVecGEP(Value *GEP,
                                                 ArrayRef<Value *> idxList,
                                                 Value *A,
                                                 IRBuilder<> &Builder) {
  Value *newGEP = Builder.CreateGEP(A, idxList);
  if (GEP->getType()->getPointerElementType()->isVectorTy()) {
    ReplaceVectorWithArray(GEP, newGEP);
  } else {
    GEP->replaceAllUsesWith(newGEP);
  }
}
void DynamicIndexingVectorToArray::ReplaceVectorWithArray(Value *Vec,
                                                          Value *A) {
  unsigned size =
      Vec->getType()->getPointerElementType()->getVectorNumElements();
  for (auto U = Vec->user_begin(); U != Vec->user_end();) {
    User *User = (*U++);

    // GlobalVariable user.
    if (isa<ConstantExpr>(User)) {
      if (User->user_empty())
        continue;
      if (GEPOperator *GEP = dyn_cast<GEPOperator>(User)) {
        IRBuilder<> Builder(Vec->getContext());
        SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
        ReplaceVecGEP(GEP, idxList, A, Builder);
        continue;
      }
    }
    // Instruction user.
    Instruction *UserInst = cast<Instruction>(User);
    IRBuilder<> Builder(UserInst);
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      ReplaceVecGEP(cast<GEPOperator>(GEP), idxList, A, Builder);
      GEP->eraseFromParent();
    } else if (LoadInst *ldInst = dyn_cast<LoadInst>(User)) {
      // Loading the whole vector: split the load into per-element loads.
      Value *newLd = UndefValue::get(ldInst->getType());
      Value *zero = Builder.getInt32(0);
      for (unsigned i = 0; i < size; i++) {
        Value *idx = Builder.getInt32(i);
        Value *GEP = Builder.CreateInBoundsGEP(A, {zero, idx});
        Value *Elt = Builder.CreateLoad(GEP);
        newLd = Builder.CreateInsertElement(newLd, Elt, i);
      }
      ldInst->replaceAllUsesWith(newLd);
      ldInst->eraseFromParent();
    } else if (StoreInst *stInst = dyn_cast<StoreInst>(User)) {
      Value *val = stInst->getValueOperand();
      Value *zero = Builder.getInt32(0);
      for (unsigned i = 0; i < size; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        Value *idx = Builder.getInt32(i);
        Value *GEP = Builder.CreateInBoundsGEP(A, {zero, idx});
        Builder.CreateStore(Elt, GEP);
      }
      stInst->eraseFromParent();
    } else {
      // Vector parameters should already be lowered.
      // No function call should use a vector here.
      DXASSERT(0, "not implemented yet");
    }
  }
}
void DynamicIndexingVectorToArray::ReplaceVecArrayGEP(Value *GEP,
                                                      ArrayRef<Value *> idxList,
                                                      Value *A,
                                                      IRBuilder<> &Builder) {
  Value *newGEP = Builder.CreateGEP(A, idxList);
  Type *Ty = GEP->getType()->getPointerElementType();
  if (Ty->isVectorTy()) {
    ReplaceVectorWithArray(GEP, newGEP);
  } else if (Ty->isArrayTy()) {
    ReplaceVectorArrayWithArray(GEP, newGEP);
  } else {
    DXASSERT(Ty->isSingleValueType(), "must be vector subscript here");
    GEP->replaceAllUsesWith(newGEP);
  }
}
void DynamicIndexingVectorToArray::ReplaceVectorArrayWithArray(Value *VA,
                                                               Value *A) {
  for (auto U = VA->user_begin(); U != VA->user_end();) {
    User *User = *(U++);
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      IRBuilder<> Builder(GEP);
      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      ReplaceVecArrayGEP(GEP, idxList, A, Builder);
      GEP->eraseFromParent();
    } else if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(User)) {
      IRBuilder<> Builder(GEPOp->getContext());
      SmallVector<Value *, 4> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
      ReplaceVecArrayGEP(GEPOp, idxList, A, Builder);
    } else {
      DXASSERT(0, "Array pointer should only be used by GEP");
    }
  }
}
void DynamicIndexingVectorToArray::lowerUseWithNewValue(Value *V,
                                                        Value *NewV) {
  Type *Ty = V->getType()->getPointerElementType();
  // Replace V with NewV.
  if (Ty->isVectorTy()) {
    ReplaceVectorWithArray(V, NewV);
  } else {
    ReplaceVectorArrayWithArray(V, NewV);
  }
}
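// A vector <N x T> becomes [N x T]; a (possibly nested) array of vectors
// keeps its array nesting and gains an extra [N x T] level for the innermost
// vector element. Any other type is not handled and yields nullptr.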
Type *DynamicIndexingVectorToArray::lowerType(Type *Ty) {
  if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
    return ArrayType::get(VT->getElementType(), VT->getNumElements());
  } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
    SmallVector<ArrayType *, 4> nestArrayTys;
    nestArrayTys.emplace_back(AT);

    Type *EltTy = AT->getElementType();
    // Support multiple levels of nested arrays.
    while (EltTy->isArrayTy()) {
      ArrayType *ElAT = cast<ArrayType>(EltTy);
      nestArrayTys.emplace_back(ElAT);
      EltTy = ElAT->getElementType();
    }
    if (EltTy->isVectorTy()) {
      Type *vecAT = ArrayType::get(EltTy->getVectorElementType(),
                                   EltTy->getVectorNumElements());
      return CreateNestArrayTy(vecAT, nestArrayTys);
    }
    return nullptr;
  }
  return nullptr;
}
Constant *DynamicIndexingVectorToArray::lowerInitVal(Constant *InitVal,
                                                     Type *NewTy) {
  Type *VecTy = InitVal->getType();
  ArrayType *ArrayTy = cast<ArrayType>(NewTy);
  if (VecTy->isVectorTy()) {
    SmallVector<Constant *, 4> Elts;
    for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
      Elts.emplace_back(InitVal->getAggregateElement(i));
    }
    return ConstantArray::get(ArrayTy, Elts);
  } else {
    ArrayType *AT = cast<ArrayType>(VecTy);
    ArrayType *EltArrayTy = cast<ArrayType>(ArrayTy->getElementType());
    SmallVector<Constant *, 4> Elts;
    for (unsigned i = 0; i < AT->getNumElements(); i++) {
      Constant *Elt =
          lowerInitVal(InitVal->getAggregateElement(i), EltArrayTy);
      Elts.emplace_back(Elt);
    }
    return ConstantArray::get(ArrayTy, Elts);
  }
}
bool DynamicIndexingVectorToArray::HasVectorDynamicIndexing(Value *V) {
  return dxilutil::HasDynamicIndexing(V);
}

}

char DynamicIndexingVectorToArray::ID = 0;

INITIALIZE_PASS(DynamicIndexingVectorToArray, "dynamic-vector-to-array",
                "Replace dynamic indexing vector with array", false, false)

// Public interface to the DynamicIndexingVectorToArray pass
ModulePass *
llvm::createDynamicIndexingVectorToArrayPass(bool ReplaceAllVector) {
  return new DynamicIndexingVectorToArray(ReplaceAllVector);
}
//===----------------------------------------------------------------------===//
// Flatten multi dim array into 1 dim.
//===----------------------------------------------------------------------===//

namespace {

class MultiDimArrayToOneDimArray : public LowerTypePass {
public:
  explicit MultiDimArrayToOneDimArray() : LowerTypePass(ID) {}
  static char ID; // Pass identification, replacement for typeid

protected:
  bool needToLower(Value *V) override;
  void lowerUseWithNewValue(Value *V, Value *NewV) override;
  Type *lowerType(Type *Ty) override;
  Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
  StringRef getGlobalPrefix() override { return ".1dim"; }
};
bool MultiDimArrayToOneDimArray::needToLower(Value *V) {
  Type *Ty = V->getType()->getPointerElementType();
  ArrayType *AT = dyn_cast<ArrayType>(Ty);
  if (!AT)
    return false;
  if (!isa<ArrayType>(AT->getElementType())) {
    return false;
  } else {
    // Merge all GEPs.
    HLModule::MergeGepUse(V);
    return true;
  }
}
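// Rewrites a GEP into the multi-dimensional array as a GEP into its
// one-dimensional replacement by linearizing the indices in row-major order:
// for dimensions [d0][d1][d2], indices (i0, i1, i2) become
// (i0 * d1 + i1) * d2 + i2. A trailing vector index, if present, is kept as
// the final GEP index.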
void ReplaceMultiDimGEP(User *GEP, Value *OneDim, IRBuilder<> &Builder) {
  gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);

  Value *PtrOffset = GEPIt.getOperand();
  ++GEPIt;
  Value *ArrayIdx = GEPIt.getOperand();
  ++GEPIt;
  Value *VecIdx = nullptr;
  for (; GEPIt != E; ++GEPIt) {
    if (GEPIt->isArrayTy()) {
      unsigned arraySize = GEPIt->getArrayNumElements();
      Value *V = GEPIt.getOperand();
      ArrayIdx = Builder.CreateMul(ArrayIdx, Builder.getInt32(arraySize));
      ArrayIdx = Builder.CreateAdd(V, ArrayIdx);
    } else {
      DXASSERT_NOMSG(isa<VectorType>(*GEPIt));
      VecIdx = GEPIt.getOperand();
    }
  }

  Value *NewGEP = nullptr;
  if (!VecIdx)
    NewGEP = Builder.CreateGEP(OneDim, {PtrOffset, ArrayIdx});
  else
    NewGEP = Builder.CreateGEP(OneDim, {PtrOffset, ArrayIdx, VecIdx});

  GEP->replaceAllUsesWith(NewGEP);
}
void MultiDimArrayToOneDimArray::lowerUseWithNewValue(Value *MultiDim,
                                                      Value *OneDim) {
  LLVMContext &Context = MultiDim->getContext();
  // All users should access element types.
  // Replace users of MultiDim.
  for (auto it = MultiDim->user_begin(); it != MultiDim->user_end();) {
    User *U = *(it++);
    if (U->user_empty())
      continue;
    // Must be a GEP.
    GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U);
    if (!GEP) {
      DXASSERT_NOMSG(isa<GEPOperator>(U));
      // NewGEP must be a GEPOperator too.
      // No instruction will be built.
      IRBuilder<> Builder(Context);
      ReplaceMultiDimGEP(U, OneDim, Builder);
    } else {
      IRBuilder<> Builder(GEP);
      ReplaceMultiDimGEP(U, OneDim, Builder);
    }
    if (GEP)
      GEP->eraseFromParent();
  }
}
Type *MultiDimArrayToOneDimArray::lowerType(Type *Ty) {
  ArrayType *AT = cast<ArrayType>(Ty);
  unsigned arraySize = AT->getNumElements();

  Type *EltTy = AT->getElementType();
  // Support multiple levels of nested arrays.
  while (EltTy->isArrayTy()) {
    ArrayType *ElAT = cast<ArrayType>(EltTy);
    arraySize *= ElAT->getNumElements();
    EltTy = ElAT->getElementType();
  }

  return ArrayType::get(EltTy, arraySize);
}
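// Recursively flattens a nested constant array into a flat list of its
// innermost non-array elements, in the same order the one-dimensional array
// will use.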
void FlattenMultiDimConstArray(Constant *V, std::vector<Constant *> &Elts) {
  if (!V->getType()->isArrayTy()) {
    Elts.emplace_back(V);
  } else {
    ArrayType *AT = cast<ArrayType>(V->getType());
    for (unsigned i = 0; i < AT->getNumElements(); i++) {
      FlattenMultiDimConstArray(V->getAggregateElement(i), Elts);
    }
  }
}
Constant *MultiDimArrayToOneDimArray::lowerInitVal(Constant *InitVal,
                                                   Type *NewTy) {
  if (InitVal) {
    // MultiDim array init should be done by store.
    if (isa<ConstantAggregateZero>(InitVal))
      InitVal = ConstantAggregateZero::get(NewTy);
    else if (isa<UndefValue>(InitVal))
      InitVal = UndefValue::get(NewTy);
    else {
      std::vector<Constant *> Elts;
      FlattenMultiDimConstArray(InitVal, Elts);
      InitVal = ConstantArray::get(cast<ArrayType>(NewTy), Elts);
    }
  } else {
    InitVal = UndefValue::get(NewTy);
  }
  return InitVal;
}
}

char MultiDimArrayToOneDimArray::ID = 0;

INITIALIZE_PASS(MultiDimArrayToOneDimArray, "multi-dim-one-dim",
                "Flatten multi-dim array into one-dim array", false, false)

// Public interface to the MultiDimArrayToOneDimArray pass
ModulePass *llvm::createMultiDimArrayToOneDimArrayPass() {
  return new MultiDimArrayToOneDimArray();
}
//===----------------------------------------------------------------------===//
// Lower resource into handle.
//===----------------------------------------------------------------------===//

namespace {

class ResourceToHandle : public LowerTypePass {
public:
  explicit ResourceToHandle() : LowerTypePass(ID) {}
  static char ID; // Pass identification, replacement for typeid

protected:
  bool needToLower(Value *V) override;
  void lowerUseWithNewValue(Value *V, Value *NewV) override;
  Type *lowerType(Type *Ty) override;
  Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override;
  StringRef getGlobalPrefix() override { return ".res"; }
  void initialize(Module &M) override;

private:
  void ReplaceResourceWithHandle(Value *ResPtr, Value *HandlePtr);
  void ReplaceResourceGEPWithHandleGEP(Value *GEP, ArrayRef<Value *> idxList,
                                       Value *A, IRBuilder<> &Builder);
  void ReplaceResourceArrayWithHandleArray(Value *VA, Value *A);

  Type *m_HandleTy;
  HLModule *m_pHLM;
};
void ResourceToHandle::initialize(Module &M) {
  DXASSERT(M.HasHLModule(), "require HLModule");
  m_pHLM = &M.GetHLModule();
  m_HandleTy = m_pHLM->GetOP()->GetHandleType();
}

bool ResourceToHandle::needToLower(Value *V) {
  Type *Ty = V->getType()->getPointerElementType();
  Ty = dxilutil::GetArrayEltTy(Ty);
  return (HLModule::IsHLSLObjectType(Ty) && !HLModule::IsStreamOutputType(Ty));
}
Type *ResourceToHandle::lowerType(Type *Ty) {
  if ((HLModule::IsHLSLObjectType(Ty) && !HLModule::IsStreamOutputType(Ty))) {
    return m_HandleTy;
  }

  ArrayType *AT = cast<ArrayType>(Ty);

  SmallVector<ArrayType *, 4> nestArrayTys;
  nestArrayTys.emplace_back(AT);

  Type *EltTy = AT->getElementType();
  // Support multiple levels of nested arrays.
  while (EltTy->isArrayTy()) {
    ArrayType *ElAT = cast<ArrayType>(EltTy);
    nestArrayTys.emplace_back(ElAT);
    EltTy = ElAT->getElementType();
  }

  return CreateNestArrayTy(m_HandleTy, nestArrayTys);
}
Constant *ResourceToHandle::lowerInitVal(Constant *InitVal, Type *NewTy) {
  DXASSERT(isa<UndefValue>(InitVal), "resource cannot have real init val");
  return UndefValue::get(NewTy);
}
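// Rewrites all loads and stores of a resource pointer in terms of the handle
// pointer: a load of the resource becomes a load of the handle, which either
// feeds the existing createHandle call directly or is cast back to the
// resource type for a store; a store of a resource becomes a createHandle
// call on the stored resource followed by a store of the handle.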
void ResourceToHandle::ReplaceResourceWithHandle(Value *ResPtr,
                                                 Value *HandlePtr) {
  for (auto it = ResPtr->user_begin(); it != ResPtr->user_end();) {
    User *U = *(it++);
    if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
      IRBuilder<> Builder(LI);
      Value *Handle = Builder.CreateLoad(HandlePtr);
      Type *ResTy = LI->getType();
      // Used by createHandle or Store.
      for (auto ldIt = LI->user_begin(); ldIt != LI->user_end();) {
        User *ldU = *(ldIt++);
        if (StoreInst *SI = dyn_cast<StoreInst>(ldU)) {
          Value *TmpRes = HLModule::EmitHLOperationCall(
              Builder, HLOpcodeGroup::HLCast,
              (unsigned)HLCastOpcode::HandleToResCast, ResTy, {Handle},
              *m_pHLM->GetModule());
          SI->replaceUsesOfWith(LI, TmpRes);
        } else {
          CallInst *CI = cast<CallInst>(ldU);
          DXASSERT(hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction()) ==
                       HLOpcodeGroup::HLCreateHandle,
                   "must be createHandle");
          CI->replaceAllUsesWith(Handle);
          CI->eraseFromParent();
        }
      }
      LI->eraseFromParent();
    } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
      Value *Res = SI->getValueOperand();
      IRBuilder<> Builder(SI);
      // CreateHandle from Res.
      Value *Handle = HLModule::EmitHLOperationCall(
          Builder, HLOpcodeGroup::HLCreateHandle,
          /*opcode*/ 0, m_HandleTy, {Res}, *m_pHLM->GetModule());
      // Store Handle to HandlePtr.
      Builder.CreateStore(Handle, HandlePtr);
      // Remove resource Store.
      SI->eraseFromParent();
    } else {
      DXASSERT(0, "invalid operation on resource");
    }
  }
}
void ResourceToHandle::ReplaceResourceGEPWithHandleGEP(
    Value *GEP, ArrayRef<Value *> idxList, Value *A, IRBuilder<> &Builder) {
  Value *newGEP = Builder.CreateGEP(A, idxList);
  Type *Ty = GEP->getType()->getPointerElementType();
  if (Ty->isArrayTy()) {
    ReplaceResourceArrayWithHandleArray(GEP, newGEP);
  } else {
    DXASSERT(HLModule::IsHLSLObjectType(Ty), "must be resource type here");
    ReplaceResourceWithHandle(GEP, newGEP);
  }
}
void ResourceToHandle::ReplaceResourceArrayWithHandleArray(Value *VA,
                                                           Value *A) {
  for (auto U = VA->user_begin(); U != VA->user_end();) {
    User *User = *(U++);
    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(User)) {
      IRBuilder<> Builder(GEP);
      SmallVector<Value *, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      ReplaceResourceGEPWithHandleGEP(GEP, idxList, A, Builder);
      GEP->eraseFromParent();
    } else if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(User)) {
      IRBuilder<> Builder(GEPOp->getContext());
      SmallVector<Value *, 4> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
      ReplaceResourceGEPWithHandleGEP(GEPOp, idxList, A, Builder);
    } else {
      DXASSERT(0, "Array pointer should only be used by GEP");
    }
  }
}
void ResourceToHandle::lowerUseWithNewValue(Value *V, Value *NewV) {
  Type *Ty = V->getType()->getPointerElementType();
  // Replace V with NewV.
  if (Ty->isArrayTy()) {
    ReplaceResourceArrayWithHandleArray(V, NewV);
  } else {
    ReplaceResourceWithHandle(V, NewV);
  }
}

}

char ResourceToHandle::ID = 0;

INITIALIZE_PASS(ResourceToHandle, "resource-handle",
                "Lower resource into handle", false, false)

// Public interface to the ResourceToHandle pass
ModulePass *llvm::createResourceToHandlePass() {
  return new ResourceToHandle();
}