CGHLSLMS.cpp 292 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
7727872797280728172827283728472857286728772887289729072917292729372947295729672977298729973007301730273037304730573067307730873097310731173127313731473157316731773187319732073217322732373247325732673277328732973307331733273337334733573367337733873397340734173427343734473457346734773487349735073517352735373547355735673577358735973607361736273637364736573667367736873697370737173727373737473757376737773787379738073817382738373847385738673877388738973907391739273937394739573967397739873997400740174027403740474057406740774087409741074117412741374147415741674177418741974207421742274237424742574267427742874297430743174327433743474357436743774387439744074417442744374447445744674477448744974507451745274537454745574567457745874597460746174627463746474657466746774687469747074717472747374747475747674777478747974807481748274837484748574867487748874897490749174927493749474957496749774987499750075017502750375047505750675077508750975107511751275137514751575167517751875197520752175227523752475257526752775287529753075317532753375347535753675377538753975407541754275437544754575467547754875497550755175527553755475557556755775587559756075617562756375647565756675677568756975707571757275737574757575767577757875797580758175827583758475857586758775887589759075917592759375947595759675977598759976007601760276037604760576067607760876097610761176127613761476157616761776187619762076217622762376247625762676277628762976307631763276337634763576367637763876397640764176427643764476457646764776487649765076517652765376547655765676577658765976607661766276637664766576667667766876697670767176727673767476757676767776787679768076817682768376847685768676877688768976907691769276937694769576967697769876997700770177027703770477057706770777087709771077117712771377147715771677177718771977207721772277237724772577267727772877297730773177327733773477357736773777387739774077417742774377447745774677477748774977507751
  1. //===----- CGHLSLMS.cpp - Interface to HLSL Runtime ----------------===//
  2. ///////////////////////////////////////////////////////////////////////////////
  3. // //
  4. // CGHLSLMS.cpp //
  5. // Copyright (C) Microsoft Corporation. All rights reserved. //
  6. // This file is distributed under the University of Illinois Open Source //
  7. // License. See LICENSE.TXT for details. //
  8. // //
  9. // This provides a class for HLSL code generation. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include "CGHLSLRuntime.h"
  13. #include "CodeGenFunction.h"
  14. #include "CodeGenModule.h"
  15. #include "CGRecordLayout.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/HLSL/HLMatrixType.h"
  18. #include "dxc/HLSL/HLModule.h"
  19. #include "dxc/DXIL/DxilUtil.h"
  20. #include "dxc/HLSL/HLOperations.h"
  21. #include "dxc/DXIL/DxilOperations.h"
  22. #include "dxc/DXIL/DxilTypeSystem.h"
  23. #include "clang/AST/DeclTemplate.h"
  24. #include "clang/AST/HlslTypes.h"
  25. #include "clang/Frontend/CodeGenOptions.h"
  26. #include "clang/Lex/HLSLMacroExpander.h"
  27. #include "llvm/ADT/STLExtras.h"
  28. #include "llvm/ADT/StringSwitch.h"
  29. #include "llvm/ADT/SmallPtrSet.h"
  30. #include "llvm/ADT/StringSet.h"
  31. #include "llvm/IR/Constants.h"
  32. #include "llvm/IR/IRBuilder.h"
  33. #include "llvm/IR/GetElementPtrTypeIterator.h"
  34. #include "llvm/Transforms/Utils/Cloning.h"
  35. #include "llvm/IR/InstIterator.h"
  36. #include <memory>
  37. #include <unordered_map>
  38. #include <unordered_set>
  39. #include <set>
  40. #include "dxc/DxilRootSignature/DxilRootSignature.h"
  41. #include "dxc/DXIL/DxilCBuffer.h"
  42. #include "clang/Parse/ParseHLSL.h" // root sig would be in Parser if part of lang
  43. #include "dxc/Support/WinIncludes.h" // stream support
  44. #include "dxc/dxcapi.h" // stream support
  45. #include "dxc/HLSL/HLSLExtensionsCodegenHelper.h"
  46. #include "dxc/HLSL/DxilGenerationPass.h" // support pause/resume passes
  47. #include "dxc/HLSL/DxilExportMap.h"
  48. using namespace clang;
  49. using namespace CodeGen;
  50. using namespace hlsl;
  51. using namespace llvm;
  52. using std::unique_ptr;
  53. static const bool KeepUndefinedTrue = true; // Keep interpolation mode undefined if not set explicitly.
  54. namespace {
  55. /// Use this class to represent HLSL cbuffer in high-level DXIL.
/// Use this class to represent HLSL cbuffer in high-level DXIL.
/// Extends DxilCBuffer with the list of constant declarations that live
/// inside the buffer, each tracked as a DxilResourceBase.
class HLCBuffer : public DxilCBuffer {
public:
  HLCBuffer() = default;
  virtual ~HLCBuffer() = default;

  // Takes ownership of pItem and assigns it the next sequential constant ID.
  void AddConst(std::unique_ptr<DxilResourceBase> &pItem);

  // Accessor for the constants added so far, in insertion order.
  std::vector<std::unique_ptr<DxilResourceBase>> &GetConstants();

private:
  std::vector<std::unique_ptr<DxilResourceBase>>
      constants; // constants inside const buffer
};
  65. //------------------------------------------------------------------------------
  66. //
  67. // HLCBuffer methods.
  68. //
  69. void HLCBuffer::AddConst(std::unique_ptr<DxilResourceBase> &pItem) {
  70. pItem->SetID(constants.size());
  71. constants.push_back(std::move(pItem));
  72. }
// Accessor for the constant declarations recorded inside this cbuffer.
std::vector<std::unique_ptr<DxilResourceBase>> &HLCBuffer::GetConstants() {
  return constants;
}
/// Concrete HLSL runtime for the Microsoft-style codegen path.  While clang
/// CodeGen walks the AST, this class builds the high-level DXIL module
/// (HLModule): it records resources, cbuffers, entry points, semantics and
/// type annotations, and lowers HLSL-specific expressions such as matrix
/// operations, init lists and out-parameter conversions.
class CGMSHLSLRuntime : public CGHLSLRuntime {
private:
  /// Convenience reference to LLVM Context
  llvm::LLVMContext &Context;
  /// Convenience reference to the current module
  llvm::Module &TheModule;

  // High-level DXIL module under construction; owned by TheModule.
  HLModule *m_pHLModule;
  // Opaque struct type created for cbuffer handles ("ConstantBuffer").
  llvm::Type *CBufferType;
  // Index of the implicit "$Globals" cbuffer within the HLModule.
  uint32_t globalCBIndex;
  // TODO: make sure how minprec works
  llvm::DataLayout dataLayout;
  // decl map to constant id for program
  llvm::DenseMap<HLSLBufferDecl *, uint32_t> constantBufMap;
  // Map for resource type to resource metadata value.
  std::unordered_map<llvm::Type *, MDNode *> resMetadataMap;
  // Map from Constant to register bindings.
  llvm::DenseMap<llvm::Constant *,
                 llvm::SmallVector<std::pair<DXIL::ResourceClass, unsigned>, 1>>
      constantRegBindingMap;

  bool m_bDebugInfo; // Set when generating full debug info.
  bool m_bIsLib;     // Set when targeting a library shader model.

  // For library, m_ExportMap maps from internal name to zero or more renames
  dxilutil::ExportMap m_ExportMap;

  // Returns the implicit "$Globals" cbuffer.
  HLCBuffer &GetGlobalCBuffer() {
    return *static_cast<HLCBuffer *>(&(m_pHLModule->GetCBuffer(globalCBIndex)));
  }
  void AddConstant(VarDecl *constDecl, HLCBuffer &CB);
  uint32_t AddSampler(VarDecl *samplerDecl);
  uint32_t AddUAVSRV(VarDecl *decl, hlsl::DxilResourceBase::Class resClass);
  bool SetUAVSRV(SourceLocation loc, hlsl::DxilResourceBase::Class resClass,
                 DxilResource *hlslRes, QualType QualTy);
  uint32_t AddCBuffer(HLSLBufferDecl *D);
  hlsl::DxilResourceBase::Class TypeToClass(clang::QualType Ty);

  // Subobject (raytracing state object) construction helpers.
  void CreateSubobject(DXIL::SubobjectKind kind, const StringRef name,
                       clang::Expr **args, unsigned int argCount,
                       DXIL::HitGroupType hgType = (DXIL::HitGroupType)(-1));
  bool GetAsConstantString(clang::Expr *expr, StringRef *value,
                           bool failWhenEmpty = false);
  bool GetAsConstantUInt32(clang::Expr *expr, uint32_t *value);
  std::vector<StringRef> ParseSubobjectExportsAssociations(StringRef exports);

  // Save the entryFunc so don't need to find it with original name.
  struct EntryFunctionInfo {
    clang::SourceLocation SL = clang::SourceLocation();
    llvm::Function *Func = nullptr;
  };

  EntryFunctionInfo Entry;

  // Map to save patch constant functions
  struct PatchConstantInfo {
    clang::SourceLocation SL = clang::SourceLocation();
    llvm::Function *Func = nullptr;
    std::uint32_t NumOverloads = 0;
  };

  StringMap<PatchConstantInfo> patchConstantFunctionMap;
  std::unordered_map<Function *, std::unique_ptr<DxilFunctionProps>>
      patchConstantFunctionPropsMap;
  bool IsPatchConstantFunction(const Function *F);

  std::unordered_map<Function *, const clang::HLSLPatchConstantFuncAttr *>
      HSEntryPatchConstantFuncAttr;

  // Map to save entry functions.
  StringMap<EntryFunctionInfo> entryFunctionMap;

  // Map to save static global init exp.
  std::unordered_map<Expr *, GlobalVariable *> staticConstGlobalInitMap;
  std::unordered_map<GlobalVariable *, std::vector<Constant *>>
      staticConstGlobalInitListMap;
  std::unordered_map<GlobalVariable *, Function *> staticConstGlobalCtorMap;

  // List for functions with clip plane.
  std::vector<Function *> clipPlaneFuncList;
  std::unordered_map<Value *, DebugLoc> debugInfoMap;

  // Root signature version selected from the language options.
  DxilRootSignatureVersion rootSigVer;

  Value *EmitHLSLMatrixLoad(CGBuilderTy &Builder, Value *Ptr, QualType Ty);
  void EmitHLSLMatrixStore(CGBuilderTy &Builder, Value *Val, Value *DestPtr,
                           QualType Ty);

  // Flatten the val into scalar val and push into elts and eltTys.
  void FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Value *, 4> &elts,
                            SmallVector<QualType, 4> &eltTys, QualType Ty,
                            Value *val);
  // Push every value on InitListExpr into EltValList and EltTyList.
  void ScanInitList(CodeGenFunction &CGF, InitListExpr *E,
                    SmallVector<Value *, 4> &EltValList,
                    SmallVector<QualType, 4> &EltTyList);

  void FlattenAggregatePtrToGepList(CodeGenFunction &CGF, Value *Ptr,
                                    SmallVector<Value *, 4> &idxList,
                                    clang::QualType Type, llvm::Type *Ty,
                                    SmallVector<Value *, 4> &GepList,
                                    SmallVector<QualType, 4> &EltTyList);
  void LoadElements(CodeGenFunction &CGF, ArrayRef<Value *> Ptrs,
                    ArrayRef<QualType> QualTys, SmallVector<Value *, 4> &Vals);
  void ConvertAndStoreElements(CodeGenFunction &CGF, ArrayRef<Value *> SrcVals,
                               ArrayRef<QualType> SrcQualTys,
                               ArrayRef<Value *> DstPtrs,
                               ArrayRef<QualType> DstQualTys);

  // Element-wise aggregate copy (private overload used during flattening).
  void EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
                             llvm::Value *DestPtr,
                             SmallVector<Value *, 4> &idxList,
                             clang::QualType SrcType, clang::QualType DestType,
                             llvm::Type *Ty);

  void EmitHLSLSplat(CodeGenFunction &CGF, Value *SrcVal, llvm::Value *DestPtr,
                     SmallVector<Value *, 4> &idxList, QualType Type,
                     QualType SrcType, llvm::Type *Ty);

  void EmitHLSLRootSignature(CodeGenFunction &CGF, HLSLRootSignatureAttr *RSA,
                             llvm::Function *Fn) override;

  // Validate parameter semantics against the target signature point.
  void CheckParameterAnnotation(SourceLocation SLoc,
                                const DxilParameterAnnotation &paramInfo,
                                bool isPatchConstantFunction);
  void CheckParameterAnnotation(SourceLocation SLoc,
                                DxilParamInputQual paramQual,
                                llvm::StringRef semFullName,
                                bool isPatchConstantFunction);
  void RemapObsoleteSemantic(DxilParameterAnnotation &paramInfo,
                             bool isPatchConstantFunction);

  void SetEntryFunction();
  SourceLocation SetSemantic(const NamedDecl *decl,
                             DxilParameterAnnotation &paramInfo);

  hlsl::InterpolationMode GetInterpMode(const Decl *decl, CompType compType,
                                        bool bKeepUndefined);
  hlsl::CompType GetCompType(const BuiltinType *BT);

  // save intrinsic opcode
  std::vector<std::pair<Function *, unsigned>> m_IntrinsicMap;
  void AddHLSLIntrinsicOpcodeToFunction(Function *, unsigned opcode);

  // Type annotation related.
  unsigned ConstructStructAnnotation(DxilStructAnnotation *annotation,
                                     const RecordDecl *RD,
                                     DxilTypeSystem &dxilTypeSys);
  unsigned AddTypeAnnotation(QualType Ty, DxilTypeSystem &dxilTypeSys,
                             unsigned &arrayEltSize);
  MDNode *GetOrAddResTypeMD(QualType resTy);
  void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotation,
                                          QualType fieldTy,
                                          bool bDefaultRowMajor);

  std::unordered_map<Constant *, DxilFieldAnnotation> m_ConstVarAnnotationMap;

public:
  CGMSHLSLRuntime(CodeGenModule &CGM);

  /// Add resource to the program
  void addResource(Decl *D) override;

  void SetPatchConstantFunction(const EntryFunctionInfo &EntryFunc);
  void SetPatchConstantFunctionWithAttr(
      const EntryFunctionInfo &EntryFunc,
      const clang::HLSLPatchConstantFuncAttr *PatchConstantFuncAttr);

  void addSubobject(Decl *D) override;

  void FinishCodeGen() override;

  bool IsTrivalInitListExpr(CodeGenFunction &CGF, InitListExpr *E) override;
  Value *EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E,
                              Value *DestPtr) override;
  Constant *EmitHLSLConstInitListExpr(CodeGenModule &CGM,
                                      InitListExpr *E) override;

  RValue EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF, const FunctionDecl *FD,
                                 const CallExpr *E,
                                 ReturnValueSlot ReturnValue) override;

  // Copy-in/copy-out support for out/inout call arguments.
  void EmitHLSLOutParamConversionInit(
      CodeGenFunction &CGF, const FunctionDecl *FD, const CallExpr *E,
      llvm::SmallVector<LValue, 8> &castArgList,
      llvm::SmallVector<const Stmt *, 8> &argList,
      const std::function<void(const VarDecl *, llvm::Value *)> &TmpArgMap)
      override;
  void EmitHLSLOutParamConversionCopyBack(
      CodeGenFunction &CGF, llvm::SmallVector<LValue, 8> &castArgList) override;

  Value *EmitHLSLMatrixOperationCall(CodeGenFunction &CGF, const clang::Expr *E,
                                     llvm::Type *RetType,
                                     ArrayRef<Value *> paramList) override;

  void EmitHLSLDiscard(CodeGenFunction &CGF) override;

  Value *EmitHLSLMatrixSubscript(CodeGenFunction &CGF, llvm::Type *RetType,
                                 Value *Ptr, Value *Idx, QualType Ty) override;
  Value *EmitHLSLMatrixElement(CodeGenFunction &CGF, llvm::Type *RetType,
                               ArrayRef<Value *> paramList,
                               QualType Ty) override;

  Value *EmitHLSLMatrixLoad(CodeGenFunction &CGF, Value *Ptr,
                            QualType Ty) override;
  void EmitHLSLMatrixStore(CodeGenFunction &CGF, Value *Val, Value *DestPtr,
                           QualType Ty) override;

  void EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
                             llvm::Value *DestPtr,
                             clang::QualType Ty) override;
  void EmitHLSLAggregateStore(CodeGenFunction &CGF, llvm::Value *Val,
                              llvm::Value *DestPtr,
                              clang::QualType Ty) override;

  void EmitHLSLFlatConversion(CodeGenFunction &CGF, Value *Val, Value *DestPtr,
                              QualType Ty, QualType SrcTy) override;
  Value *EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src, QualType SrcType,
                             QualType DstType) override;

  void EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF,
                                           llvm::Value *SrcPtr,
                                           clang::QualType SrcTy,
                                           llvm::Value *DestPtr,
                                           clang::QualType DestTy) override;

  void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
  void EmitHLSLFunctionProlog(llvm::Function *,
                              const FunctionDecl *FD) override;

  void AddControlFlowHint(CodeGenFunction &CGF, const Stmt &S,
                          llvm::TerminatorInst *TI,
                          ArrayRef<const Attr *> Attrs) override;

  void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D,
                     llvm::Value *V) override;

  /// Get or add constant to the program
  HLCBuffer &GetOrCreateCBuffer(HLSLBufferDecl *D);
};
  269. }
  270. void clang::CompileRootSignature(
  271. StringRef rootSigStr, DiagnosticsEngine &Diags, SourceLocation SLoc,
  272. hlsl::DxilRootSignatureVersion rootSigVer,
  273. hlsl::DxilRootSignatureCompilationFlags flags,
  274. hlsl::RootSignatureHandle *pRootSigHandle) {
  275. std::string OSStr;
  276. llvm::raw_string_ostream OS(OSStr);
  277. hlsl::DxilVersionedRootSignatureDesc *D = nullptr;
  278. if (ParseHLSLRootSignature(rootSigStr.data(), rootSigStr.size(), rootSigVer,
  279. flags, &D, SLoc, Diags)) {
  280. CComPtr<IDxcBlob> pSignature;
  281. CComPtr<IDxcBlobEncoding> pErrors;
  282. hlsl::SerializeRootSignature(D, &pSignature, &pErrors, false);
  283. if (pSignature == nullptr) {
  284. assert(pErrors != nullptr && "else serialize failed with no msg");
  285. ReportHLSLRootSigError(Diags, SLoc, (char *)pErrors->GetBufferPointer(),
  286. pErrors->GetBufferSize());
  287. hlsl::DeleteRootSignature(D);
  288. } else {
  289. pRootSigHandle->Assign(D, pSignature);
  290. }
  291. }
  292. }
  293. //------------------------------------------------------------------------------
  294. //
  295. // CGMSHLSLRuntime methods.
  296. //
// Construct the runtime: validate the target profile and validator version,
// create the HLModule attached to the LLVM module, translate codegen/lang
// options into HLOptions, create the implicit "$Globals" cbuffer, and parse
// the -exports option for library targets.  On a fatal option error this
// reports a diagnostic and returns early, leaving the runtime partially
// initialized (callers are expected to stop after the diagnostic).
CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM)
    : CGHLSLRuntime(CGM), Context(CGM.getLLVMContext()),
      TheModule(CGM.getModule()),
      CBufferType(
          llvm::StructType::create(TheModule.getContext(), "ConstantBuffer")),
      // Data layout depends on min-precision mode (legacy vs. native 16-bit).
      dataLayout(CGM.getLangOpts().UseMinPrecision
                     ? hlsl::DXIL::kLegacyLayoutString
                     : hlsl::DXIL::kNewLayoutString),
      Entry() {
  const hlsl::ShaderModel *SM =
      hlsl::ShaderModel::GetByName(CGM.getCodeGenOpts().HLSLProfile.c_str());
  // Only accept valid, 6.0 shader model.
  if (!SM->IsValid() || SM->GetMajor() != 6) {
    DiagnosticsEngine &Diags = CGM.getDiags();
    unsigned DiagID =
        Diags.getCustomDiagID(DiagnosticsEngine::Error, "invalid profile %0");
    Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLProfile;
    return;
  }
  // A validator major version of 0 means "no validation requested" and skips
  // the minimum-version check.
  if (CGM.getCodeGenOpts().HLSLValidatorMajorVer != 0) {
    // Check validator version against minimum for target profile:
    unsigned MinMajor, MinMinor;
    SM->GetMinValidatorVersion(MinMajor, MinMinor);
    if (DXIL::CompareVersions(CGM.getCodeGenOpts().HLSLValidatorMajorVer,
                              CGM.getCodeGenOpts().HLSLValidatorMinorVer,
                              MinMajor, MinMinor) < 0) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID = Diags.getCustomDiagID(
          DiagnosticsEngine::Error,
          "validator version %0,%1 does not support target profile.");
      Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLValidatorMajorVer
                           << CGM.getCodeGenOpts().HLSLValidatorMinorVer;
      return;
    }
  }
  m_bIsLib = SM->IsLib();
  // TODO: add AllResourceBound.
  if (CGM.getCodeGenOpts().HLSLAvoidControlFlow &&
      !CGM.getCodeGenOpts().HLSLAllResourcesBound) {
    if (SM->IsSM51Plus()) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID =
          Diags.getCustomDiagID(DiagnosticsEngine::Error,
                                "Gfa option cannot be used in SM_5_1+ unless "
                                "all_resources_bound flag is specified");
      Diags.Report(DiagID);
    }
  }
  // Create HLModule.
  const bool skipInit = true;
  m_pHLModule = &TheModule.GetOrCreateHLModule(skipInit);
  // Set Option.
  HLOptions opts;
  // NOTE(review): bIEEEStrict is sourced from the UnsafeFPMath codegen flag;
  // presumably that option carries the -Gis state on this path — confirm.
  opts.bIEEEStrict = CGM.getCodeGenOpts().UnsafeFPMath;
  opts.bDefaultRowMajor = CGM.getCodeGenOpts().HLSLDefaultRowMajor;
  opts.bDisableOptimizations = CGM.getCodeGenOpts().DisableLLVMOpts;
  opts.bLegacyCBufferLoad = !CGM.getCodeGenOpts().HLSLNotUseLegacyCBufLoad;
  opts.bAllResourcesBound = CGM.getCodeGenOpts().HLSLAllResourcesBound;
  opts.PackingStrategy = CGM.getCodeGenOpts().HLSLSignaturePackingStrategy;
  opts.bLegacyResourceReservation =
      CGM.getCodeGenOpts().HLSLLegacyResourceReservation;
  opts.bUseMinPrecision = CGM.getLangOpts().UseMinPrecision;
  opts.bDX9CompatMode = CGM.getLangOpts().EnableDX9CompatMode;
  opts.bFXCCompatMode = CGM.getLangOpts().EnableFXCCompatMode;
  m_pHLModule->SetHLOptions(opts);
  m_pHLModule->GetOP()->SetMinPrecision(opts.bUseMinPrecision);
  m_pHLModule->GetTypeSystem().SetMinPrecision(opts.bUseMinPrecision);
  m_pHLModule->SetAutoBindingSpace(CGM.getCodeGenOpts().HLSLDefaultSpace);
  m_pHLModule->SetValidatorVersion(CGM.getCodeGenOpts().HLSLValidatorMajorVer,
                                   CGM.getCodeGenOpts().HLSLValidatorMinorVer);

  m_bDebugInfo =
      CGM.getCodeGenOpts().getDebugInfo() == CodeGenOptions::FullDebugInfo;

  // set profile
  m_pHLModule->SetShaderModel(SM);
  // set entry name (libraries have multiple entries, set later per function)
  if (!SM->IsLib())
    m_pHLModule->SetEntryFunctionName(CGM.getCodeGenOpts().HLSLEntryFunction);

  // set root signature version.
  if (CGM.getLangOpts().RootSigMinor == 0) {
    rootSigVer = hlsl::DxilRootSignatureVersion::Version_1_0;
  } else {
    DXASSERT(CGM.getLangOpts().RootSigMinor == 1,
             "else CGMSHLSLRuntime Constructor needs to be updated");
    rootSigVer = hlsl::DxilRootSignatureVersion::Version_1_1;
  }
  DXASSERT(CGM.getLangOpts().RootSigMajor == 1,
           "else CGMSHLSLRuntime Constructor needs to be updated");

  // add globalCB — the implicit "$Globals" cbuffer holding loose globals.
  unique_ptr<HLCBuffer> CB = llvm::make_unique<HLCBuffer>();
  std::string globalCBName = "$Globals";
  CB->SetGlobalSymbol(nullptr);
  CB->SetGlobalName(globalCBName);
  globalCBIndex = m_pHLModule->GetCBuffers().size();
  CB->SetID(globalCBIndex);
  CB->SetRangeSize(1);
  // UINT_MAX lower bound means "not yet assigned a register".
  CB->SetLowerBound(UINT_MAX);
  DXVERIFY_NOMSG(globalCBIndex == m_pHLModule->AddCBuffer(std::move(CB)));

  // set Float Denorm Mode
  m_pHLModule->SetFloat32DenormMode(CGM.getCodeGenOpts().HLSLFloat32DenormMode);

  // set DefaultLinkage
  m_pHLModule->SetDefaultLinkage(CGM.getCodeGenOpts().DefaultLinkage);

  // Fill in m_ExportMap, which maps from internal name to zero or more renames
  m_ExportMap.clear();
  std::string errors;
  llvm::raw_string_ostream os(errors);
  if (!m_ExportMap.ParseExports(CGM.getCodeGenOpts().HLSLLibraryExports, os)) {
    DiagnosticsEngine &Diags = CGM.getDiags();
    unsigned DiagID = Diags.getCustomDiagID(
        DiagnosticsEngine::Error, "Error parsing -exports options: %0");
    Diags.Report(DiagID) << os.str();
  }
}
  404. void CGMSHLSLRuntime::AddHLSLIntrinsicOpcodeToFunction(Function *F,
  405. unsigned opcode) {
  406. m_IntrinsicMap.emplace_back(F,opcode);
  407. }
  408. void CGMSHLSLRuntime::CheckParameterAnnotation(
  409. SourceLocation SLoc, const DxilParameterAnnotation &paramInfo,
  410. bool isPatchConstantFunction) {
  411. if (!paramInfo.HasSemanticString()) {
  412. return;
  413. }
  414. llvm::StringRef semFullName = paramInfo.GetSemanticStringRef();
  415. DxilParamInputQual paramQual = paramInfo.GetParamInputQual();
  416. if (paramQual == DxilParamInputQual::Inout) {
  417. CheckParameterAnnotation(SLoc, DxilParamInputQual::In, semFullName, isPatchConstantFunction);
  418. CheckParameterAnnotation(SLoc, DxilParamInputQual::Out, semFullName, isPatchConstantFunction);
  419. return;
  420. }
  421. CheckParameterAnnotation(SLoc, paramQual, semFullName, isPatchConstantFunction);
  422. }
  423. void CGMSHLSLRuntime::CheckParameterAnnotation(
  424. SourceLocation SLoc, DxilParamInputQual paramQual, llvm::StringRef semFullName,
  425. bool isPatchConstantFunction) {
  426. const ShaderModel *SM = m_pHLModule->GetShaderModel();
  427. DXIL::SigPointKind sigPoint = SigPointFromInputQual(
  428. paramQual, SM->GetKind(), isPatchConstantFunction);
  429. llvm::StringRef semName;
  430. unsigned semIndex;
  431. Semantic::DecomposeNameAndIndex(semFullName, &semName, &semIndex);
  432. const Semantic *pSemantic =
  433. Semantic::GetByName(semName, sigPoint, SM->GetMajor(), SM->GetMinor());
  434. if (pSemantic->IsInvalid()) {
  435. DiagnosticsEngine &Diags = CGM.getDiags();
  436. unsigned DiagID =
  437. Diags.getCustomDiagID(DiagnosticsEngine::Error, "invalid semantic '%0' for %1 %2.%3");
  438. Diags.Report(SLoc, DiagID) << semName << SM->GetKindName() << SM->GetMajor() << SM->GetMinor();
  439. }
  440. }
  441. SourceLocation
  442. CGMSHLSLRuntime::SetSemantic(const NamedDecl *decl,
  443. DxilParameterAnnotation &paramInfo) {
  444. for (const hlsl::UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  445. if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
  446. const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
  447. paramInfo.SetSemanticString(sd->SemanticName);
  448. return it->Loc;
  449. }
  450. }
  451. return SourceLocation();
  452. }
  453. static DXIL::TessellatorDomain StringToDomain(StringRef domain) {
  454. if (domain == "isoline")
  455. return DXIL::TessellatorDomain::IsoLine;
  456. if (domain == "tri")
  457. return DXIL::TessellatorDomain::Tri;
  458. if (domain == "quad")
  459. return DXIL::TessellatorDomain::Quad;
  460. return DXIL::TessellatorDomain::Undefined;
  461. }
  462. static DXIL::TessellatorPartitioning StringToPartitioning(StringRef partition) {
  463. if (partition == "integer")
  464. return DXIL::TessellatorPartitioning::Integer;
  465. if (partition == "pow2")
  466. return DXIL::TessellatorPartitioning::Pow2;
  467. if (partition == "fractional_even")
  468. return DXIL::TessellatorPartitioning::FractionalEven;
  469. if (partition == "fractional_odd")
  470. return DXIL::TessellatorPartitioning::FractionalOdd;
  471. return DXIL::TessellatorPartitioning::Undefined;
  472. }
  473. static DXIL::TessellatorOutputPrimitive
  474. StringToTessOutputPrimitive(StringRef primitive) {
  475. if (primitive == "point")
  476. return DXIL::TessellatorOutputPrimitive::Point;
  477. if (primitive == "line")
  478. return DXIL::TessellatorOutputPrimitive::Line;
  479. if (primitive == "triangle_cw")
  480. return DXIL::TessellatorOutputPrimitive::TriangleCW;
  481. if (primitive == "triangle_ccw")
  482. return DXIL::TessellatorOutputPrimitive::TriangleCCW;
  483. return DXIL::TessellatorOutputPrimitive::Undefined;
  484. }
  485. static DXIL::MeshOutputTopology
  486. StringToMeshOutputTopology(StringRef topology) {
  487. if (topology == "line")
  488. return DXIL::MeshOutputTopology::Line;
  489. if (topology == "triangle")
  490. return DXIL::MeshOutputTopology::Triangle;
  491. return DXIL::MeshOutputTopology::Undefined;
  492. }
  493. static unsigned RoundToAlign(unsigned num, unsigned mod) {
  494. // round num to next highest mod
  495. if (mod != 0)
  496. return mod * ((num + mod - 1) / mod);
  497. return num;
  498. }
  499. // Align cbuffer offset in legacy mode (16 bytes per row).
  500. static unsigned AlignBufferOffsetInLegacy(unsigned offset, unsigned size,
  501. unsigned scalarSizeInBytes,
  502. bool bNeedNewRow) {
  503. if (unsigned remainder = (offset & 0xf)) {
  504. // Start from new row
  505. if (remainder + size > 16 || bNeedNewRow) {
  506. return offset + 16 - remainder;
  507. }
  508. // If not, naturally align data
  509. return RoundToAlign(offset, scalarSizeInBytes);
  510. }
  511. return offset;
  512. }
// Compute the legacy-cbuffer-aligned offset for a value of type Ty placed at
// baseOffset.  Arrays, non-vector structs/classes, and matrices occupying
// more than one register row must start on a fresh 16-byte row; everything
// else aligns naturally to its scalar element size (2/4/8 bytes).
static unsigned AlignBaseOffset(unsigned baseOffset, unsigned size,
                                QualType Ty, bool bDefaultRowMajor) {
  bool needNewAlign = Ty->isArrayType();

  if (IsHLSLMatType(Ty)) {
    bool bRowMajor = false;
    if (!hlsl::HasHLSLMatOrientation(Ty, &bRowMajor))
      bRowMajor = bDefaultRowMajor;

    unsigned row, col;
    hlsl::GetHLSLMatRowColCount(Ty, row, col);

    // A matrix spanning more than one register row starts on a new row.
    needNewAlign |= !bRowMajor && col > 1;
    needNewAlign |= bRowMajor && row > 1;
  } else if (Ty->isStructureOrClassType() && !hlsl::IsHLSLVecType(Ty)) {
    needNewAlign = true;
  }

  // Determine the scalar element size; default 4 bytes (int/float).
  unsigned scalarSizeInBytes = 4;
  const clang::BuiltinType *BT = Ty->getAs<clang::BuiltinType>();
  if (hlsl::IsHLSLVecMatType(Ty)) {
    BT = hlsl::GetElementTypeOrType(Ty)->getAs<clang::BuiltinType>();
  }
  if (BT) {
    if (BT->getKind() == clang::BuiltinType::Kind::Double ||
        BT->getKind() == clang::BuiltinType::Kind::LongLong)
      scalarSizeInBytes = 8;
    else if (BT->getKind() == clang::BuiltinType::Kind::Half ||
             BT->getKind() == clang::BuiltinType::Kind::Short ||
             BT->getKind() == clang::BuiltinType::Kind::UShort)
      scalarSizeInBytes = 2;
  }

  return AlignBufferOffsetInLegacy(baseOffset, size, scalarSizeInBytes,
                                   needNewAlign);
}
  543. static unsigned AlignBaseOffset(QualType Ty, unsigned baseOffset,
  544. bool bDefaultRowMajor,
  545. CodeGen::CodeGenModule &CGM,
  546. llvm::DataLayout &layout) {
  547. QualType paramTy = Ty.getCanonicalType();
  548. if (const ReferenceType *RefType = dyn_cast<ReferenceType>(paramTy))
  549. paramTy = RefType->getPointeeType();
  550. // Get size.
  551. llvm::Type *Type = CGM.getTypes().ConvertType(paramTy);
  552. unsigned size = layout.getTypeAllocSize(Type);
  553. return AlignBaseOffset(baseOffset, size, paramTy, bDefaultRowMajor);
  554. }
  555. static unsigned GetMatrixSizeInCB(QualType Ty, bool defaultRowMajor,
  556. bool b64Bit) {
  557. bool bRowMajor;
  558. if (!hlsl::HasHLSLMatOrientation(Ty, &bRowMajor))
  559. bRowMajor = defaultRowMajor;
  560. unsigned row, col;
  561. hlsl::GetHLSLMatRowColCount(Ty, row, col);
  562. unsigned EltSize = b64Bit ? 8 : 4;
  563. // Align to 4 * 4bytes.
  564. unsigned alignment = 4 * 4;
  565. if (bRowMajor) {
  566. unsigned rowSize = EltSize * col;
  567. // 3x64bit or 4x64bit align to 32 bytes.
  568. if (rowSize > alignment)
  569. alignment <<= 1;
  570. return alignment * (row - 1) + col * EltSize;
  571. } else {
  572. unsigned rowSize = EltSize * row;
  573. // 3x64bit or 4x64bit align to 32 bytes.
  574. if (rowSize > alignment)
  575. alignment <<= 1;
  576. return alignment * (col - 1) + row * EltSize;
  577. }
  578. }
  579. static CompType::Kind BuiltinTyToCompTy(const BuiltinType *BTy, bool bSNorm,
  580. bool bUNorm) {
  581. CompType::Kind kind = CompType::Kind::Invalid;
  582. switch (BTy->getKind()) {
  583. case BuiltinType::UInt:
  584. kind = CompType::Kind::U32;
  585. break;
  586. case BuiltinType::Min16UInt: // HLSL Change
  587. case BuiltinType::UShort:
  588. kind = CompType::Kind::U16;
  589. break;
  590. case BuiltinType::ULongLong:
  591. kind = CompType::Kind::U64;
  592. break;
  593. case BuiltinType::Int:
  594. kind = CompType::Kind::I32;
  595. break;
  596. // HLSL Changes begin
  597. case BuiltinType::Min12Int:
  598. case BuiltinType::Min16Int:
  599. // HLSL Changes end
  600. case BuiltinType::Short:
  601. kind = CompType::Kind::I16;
  602. break;
  603. case BuiltinType::LongLong:
  604. kind = CompType::Kind::I64;
  605. break;
  606. // HLSL Changes begin
  607. case BuiltinType::Min10Float:
  608. case BuiltinType::Min16Float:
  609. // HLSL Changes end
  610. case BuiltinType::Half:
  611. if (bSNorm)
  612. kind = CompType::Kind::SNormF16;
  613. else if (bUNorm)
  614. kind = CompType::Kind::UNormF16;
  615. else
  616. kind = CompType::Kind::F16;
  617. break;
  618. case BuiltinType::HalfFloat: // HLSL Change
  619. case BuiltinType::Float:
  620. if (bSNorm)
  621. kind = CompType::Kind::SNormF32;
  622. else if (bUNorm)
  623. kind = CompType::Kind::UNormF32;
  624. else
  625. kind = CompType::Kind::F32;
  626. break;
  627. case BuiltinType::Double:
  628. if (bSNorm)
  629. kind = CompType::Kind::SNormF64;
  630. else if (bUNorm)
  631. kind = CompType::Kind::UNormF64;
  632. else
  633. kind = CompType::Kind::F64;
  634. break;
  635. case BuiltinType::Bool:
  636. kind = CompType::Kind::I1;
  637. break;
  638. default:
  639. // Other types not used by HLSL.
  640. break;
  641. }
  642. return kind;
  643. }
  644. static DxilSampler::SamplerKind KeywordToSamplerKind(llvm::StringRef keyword) {
  645. // TODO: refactor for faster search (switch by 1/2/3 first letters, then
  646. // compare)
  647. return llvm::StringSwitch<DxilSampler::SamplerKind>(keyword)
  648. .Case("SamplerState", DxilSampler::SamplerKind::Default)
  649. .Case("SamplerComparisonState", DxilSampler::SamplerKind::Comparison)
  650. .Default(DxilSampler::SamplerKind::Invalid);
  651. }
// Return the resource-type metadata node for an HLSL resource type, creating
// and caching it (keyed by the converted llvm::Type) on first use.  Returns
// nullptr for non-record types and for resource classes with no metadata
// representation (e.g. GS output streams).
MDNode *CGMSHLSLRuntime::GetOrAddResTypeMD(QualType resTy) {
  const RecordType *RT = resTy->getAs<RecordType>();
  if (!RT)
    return nullptr;
  RecordDecl *RD = RT->getDecl();
  SourceLocation loc = RD->getLocation();

  hlsl::DxilResourceBase::Class resClass = TypeToClass(resTy);
  llvm::Type *Ty = CGM.getTypes().ConvertType(resTy);
  // Cache hit: the metadata for this lowered type was already built.
  auto it = resMetadataMap.find(Ty);
  if (it != resMetadataMap.end())
    return it->second;

  // Save resource type metadata.
  switch (resClass) {
  case DXIL::ResourceClass::UAV: {
    DxilResource UAV;
    // TODO: save globalcoherent to variable in EmitHLSLBuiltinCallExpr.
    SetUAVSRV(loc, resClass, &UAV, resTy);
    // Set global symbol to save type.
    UAV.SetGlobalSymbol(UndefValue::get(Ty));
    MDNode *MD = m_pHLModule->DxilUAVToMDNode(UAV);
    resMetadataMap[Ty] = MD;
    return MD;
  } break;
  case DXIL::ResourceClass::SRV: {
    DxilResource SRV;
    SetUAVSRV(loc, resClass, &SRV, resTy);
    // Set global symbol to save type.
    SRV.SetGlobalSymbol(UndefValue::get(Ty));
    MDNode *MD = m_pHLModule->DxilSRVToMDNode(SRV);
    resMetadataMap[Ty] = MD;
    return MD;
  } break;
  case DXIL::ResourceClass::Sampler: {
    DxilSampler S;
    // The sampler kind is derived from the record's type name.
    DxilSampler::SamplerKind kind = KeywordToSamplerKind(RD->getName());
    S.SetSamplerKind(kind);
    // Set global symbol to save type.
    S.SetGlobalSymbol(UndefValue::get(Ty));
    MDNode *MD = m_pHLModule->DxilSamplerToMDNode(S);
    resMetadataMap[Ty] = MD;
    return MD;
  }
  default:
    // Skip OutputStream for GS.
    return nullptr;
  }
}
  699. namespace {
  700. MatrixOrientation GetMatrixMajor(QualType Ty, bool bDefaultRowMajor) {
  701. DXASSERT(hlsl::IsHLSLMatType(Ty), "");
  702. bool bIsRowMajor = bDefaultRowMajor;
  703. HasHLSLMatOrientation(Ty, &bIsRowMajor);
  704. return bIsRowMajor ? MatrixOrientation::RowMajor
  705. : MatrixOrientation::ColumnMajor;
  706. }
  707. QualType GetArrayEltType(ASTContext& Context, QualType Ty) {
  708. while (const clang::ArrayType *ArrayTy = Context.getAsArrayType(Ty))
  709. Ty = ArrayTy->getElementType();
  710. return Ty;
  711. }
  712. } // namespace
// Fill in the type-derived parts of a field annotation: matrix orientation
// and dimensions, resource-type metadata, and the component type (including
// snorm/unorm qualification).  References and array dimensions are stripped
// first so the annotation describes the element type.
void CGMSHLSLRuntime::ConstructFieldAttributedAnnotation(
    DxilFieldAnnotation &fieldAnnotation, QualType fieldTy,
    bool bDefaultRowMajor) {
  QualType Ty = fieldTy;
  if (Ty->isReferenceType())
    Ty = Ty.getNonReferenceType();

  // Get element type.
  Ty = GetArrayEltType(CGM.getContext(), Ty);

  QualType EltTy = Ty;
  if (hlsl::IsHLSLMatType(Ty)) {
    DxilMatrixAnnotation Matrix;
    Matrix.Orientation = GetMatrixMajor(Ty, bDefaultRowMajor);

    hlsl::GetHLSLMatRowColCount(Ty, Matrix.Rows, Matrix.Cols);
    fieldAnnotation.SetMatrixAnnotation(Matrix);
    EltTy = hlsl::GetHLSLMatElementType(Ty);
  }

  if (hlsl::IsHLSLVecType(Ty))
    EltTy = hlsl::GetHLSLVecElementType(Ty);

  if (IsHLSLResourceType(Ty)) {
    MDNode *MD = GetOrAddResTypeMD(Ty);
    fieldAnnotation.SetResourceAttribute(MD);
  }

  // unorm is implied when the type has a norm qualifier that is not snorm.
  bool bSNorm = false;
  bool bUNorm = false;
  if (HasHLSLUNormSNorm(Ty, &bSNorm) && !bSNorm)
    bUNorm = true;

  if (EltTy->isBuiltinType()) {
    const BuiltinType *BTy = EltTy->getAs<BuiltinType>();
    CompType::Kind kind = BuiltinTyToCompTy(BTy, bSNorm, bUNorm);
    fieldAnnotation.SetCompType(kind);
  } else if (EltTy->isEnumeralType()) {
    // Enums annotate as their underlying integer type.
    const EnumType *ETy = EltTy->getAs<EnumType>();
    QualType type = ETy->getDecl()->getIntegerType();
    if (const BuiltinType *BTy =
            dyn_cast<BuiltinType>(type->getCanonicalTypeInternal()))
      fieldAnnotation.SetCompType(BuiltinTyToCompTy(BTy, bSNorm, bUNorm));
  } else {
    DXASSERT(!bSNorm && !bUNorm,
             "snorm/unorm on invalid type, validate at handleHLSLTypeAttr");
  }
}
  754. static void ConstructFieldInterpolation(DxilFieldAnnotation &fieldAnnotation,
  755. FieldDecl *fieldDecl) {
  756. // Keep undefined for interpMode here.
  757. InterpolationMode InterpMode = {fieldDecl->hasAttr<HLSLNoInterpolationAttr>(),
  758. fieldDecl->hasAttr<HLSLLinearAttr>(),
  759. fieldDecl->hasAttr<HLSLNoPerspectiveAttr>(),
  760. fieldDecl->hasAttr<HLSLCentroidAttr>(),
  761. fieldDecl->hasAttr<HLSLSampleAttr>()};
  762. if (InterpMode.GetKind() != InterpolationMode::Kind::Undefined)
  763. fieldAnnotation.SetInterpolationMode(InterpMode);
  764. }
// Populate `annotation` for record `RD`: save template arguments (for
// specializations), lay out base classes and fields at offsets produced by
// AlignBaseOffset, and attach per-field metadata (component type, matrix
// orientation, interpolation, semantic string, packoffset, precise).
// Returns the struct's total constant-buffer size in bytes; 0 either marks
// an empty struct or signals a diagnosed error (register assignment on a
// member).
unsigned CGMSHLSLRuntime::ConstructStructAnnotation(DxilStructAnnotation *annotation,
                                                    const RecordDecl *RD,
                                                    DxilTypeSystem &dxilTypeSys) {
  unsigned fieldIdx = 0;
  unsigned offset = 0;
  bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
    // If template, save template args
    if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
            dyn_cast<ClassTemplateSpecializationDecl>(CXXRD)) {
      const clang::TemplateArgumentList &args = templateSpecializationDecl->getTemplateInstantiationArgs();
      for (unsigned i = 0; i < args.size(); ++i) {
        DxilTemplateArgAnnotation &argAnnotation = annotation->GetTemplateArgAnnotation(i);
        const clang::TemplateArgument &arg = args[i];
        switch (arg.getKind()) {
        case clang::TemplateArgument::ArgKind::Type:
          argAnnotation.SetType(CGM.getTypes().ConvertType(arg.getAsType()));
          break;
        case clang::TemplateArgument::ArgKind::Integral:
          argAnnotation.SetIntegral(arg.getAsIntegral().getExtValue());
          break;
        default:
          // Other template-argument kinds carry no annotation payload here.
          break;
        }
      }
    }
    if (CXXRD->getNumBases()) {
      // Add base as field.
      for (const auto &I : CXXRD->bases()) {
        const CXXRecordDecl *BaseDecl =
            cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
        std::string fieldSemName = "";
        QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
        // Align offset.
        offset = AlignBaseOffset(parentTy, offset, bDefaultRowMajor, CGM,
                                 dataLayout);
        unsigned CBufferOffset = offset;
        unsigned arrayEltSize = 0;
        // Process field to make sure the size of field is ready.
        unsigned size =
            AddTypeAnnotation(parentTy, dxilTypeSys, arrayEltSize);
        // Update offset.
        offset += size;
        // Zero-size (empty) bases consume no field slot in the annotation.
        if (size > 0) {
          DxilFieldAnnotation &fieldAnnotation =
              annotation->GetFieldAnnotation(fieldIdx++);
          fieldAnnotation.SetCBufferOffset(CBufferOffset);
          fieldAnnotation.SetFieldName(BaseDecl->getNameAsString());
        }
      }
    }
  }
  for (auto fieldDecl : RD->fields()) {
    std::string fieldSemName = "";
    QualType fieldTy = fieldDecl->getType();
    DXASSERT(!fieldDecl->isBitField(), "We should have already ensured we have no bitfields.");
    // Align offset.
    offset = AlignBaseOffset(fieldTy, offset, bDefaultRowMajor, CGM, dataLayout);
    unsigned CBufferOffset = offset;
    // Try to get info from fieldDecl.
    for (const hlsl::UnusualAnnotation *it :
         fieldDecl->getUnusualAnnotations()) {
      switch (it->getKind()) {
      case hlsl::UnusualAnnotation::UA_SemanticDecl: {
        const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
        fieldSemName = sd->SemanticName;
      } break;
      case hlsl::UnusualAnnotation::UA_ConstantPacking: {
        // packoffset overrides the running offset: Subcomponent is scaled by
        // 4 (components per register), ComponentOffset added, then the whole
        // value converted from components to bytes (x4).
        const hlsl::ConstantPacking *cp = cast<hlsl::ConstantPacking>(it);
        CBufferOffset = cp->Subcomponent << 2;
        CBufferOffset += cp->ComponentOffset;
        // Change to byte.
        CBufferOffset <<= 2;
      } break;
      case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
        // register assignment only works on global constant.
        DiagnosticsEngine &Diags = CGM.getDiags();
        unsigned DiagID = Diags.getCustomDiagID(
            DiagnosticsEngine::Error,
            "location semantics cannot be specified on members.");
        Diags.Report(it->Loc, DiagID);
        return 0;
      } break;
      default:
        llvm_unreachable("only semantic for input/output");
        break;
      }
    }
    unsigned arrayEltSize = 0;
    // Process field to make sure the size of field is ready.
    unsigned size = AddTypeAnnotation(fieldDecl->getType(), dxilTypeSys, arrayEltSize);
    // Update offset.
    offset += size;
    DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(fieldIdx++);
    ConstructFieldAttributedAnnotation(fieldAnnotation, fieldTy, bDefaultRowMajor);
    ConstructFieldInterpolation(fieldAnnotation, fieldDecl);
    if (fieldDecl->hasAttr<HLSLPreciseAttr>())
      fieldAnnotation.SetPrecise();
    fieldAnnotation.SetCBufferOffset(CBufferOffset);
    fieldAnnotation.SetFieldName(fieldDecl->getName());
    if (!fieldSemName.empty())
      fieldAnnotation.SetSemanticString(fieldSemName);
  }
  annotation->SetCBufferSize(offset);
  if (offset == 0) {
    annotation->MarkEmptyStruct();
  }
  return offset;
}
  874. static bool IsElementInputOutputType(QualType Ty) {
  875. return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty) || Ty->isEnumeralType();
  876. }
  877. static unsigned GetNumTemplateArgsForRecordDecl(const RecordDecl *RD) {
  878. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  879. if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
  880. dyn_cast<ClassTemplateSpecializationDecl>(CXXRD)) {
  881. const clang::TemplateArgumentList &args = templateSpecializationDecl->getTemplateInstantiationArgs();
  882. return args.size();
  883. }
  884. }
  885. return 0;
  886. }
  887. // Return the size for constant buffer of each decl.
  888. unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
  889. DxilTypeSystem &dxilTypeSys,
  890. unsigned &arrayEltSize) {
  891. QualType paramTy = Ty.getCanonicalType();
  892. if (const ReferenceType *RefType = dyn_cast<ReferenceType>(paramTy))
  893. paramTy = RefType->getPointeeType();
  894. // Get size.
  895. llvm::Type *Type = CGM.getTypes().ConvertType(paramTy);
  896. unsigned size = dataLayout.getTypeAllocSize(Type);
  897. if (IsHLSLMatType(Ty)) {
  898. llvm::Type *EltTy = HLMatrixType::cast(Type).getElementTypeForReg();
  899. bool b64Bit = dataLayout.getTypeAllocSize(EltTy) == 8;
  900. size = GetMatrixSizeInCB(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor,
  901. b64Bit);
  902. }
  903. // Skip element types.
  904. if (IsElementInputOutputType(paramTy))
  905. return size;
  906. else if (IsHLSLStreamOutputType(Ty)) {
  907. return AddTypeAnnotation(GetHLSLOutputPatchElementType(Ty), dxilTypeSys,
  908. arrayEltSize);
  909. } else if (IsHLSLInputPatchType(Ty))
  910. return AddTypeAnnotation(GetHLSLInputPatchElementType(Ty), dxilTypeSys,
  911. arrayEltSize);
  912. else if (IsHLSLOutputPatchType(Ty))
  913. return AddTypeAnnotation(GetHLSLOutputPatchElementType(Ty), dxilTypeSys,
  914. arrayEltSize);
  915. else if (const RecordType *RT = paramTy->getAsStructureType()) {
  916. RecordDecl *RD = RT->getDecl();
  917. llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
  918. // Skip if already created.
  919. if (DxilStructAnnotation *annotation = dxilTypeSys.GetStructAnnotation(ST)) {
  920. unsigned structSize = annotation->GetCBufferSize();
  921. return structSize;
  922. }
  923. DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
  924. GetNumTemplateArgsForRecordDecl(RT->getDecl()));
  925. return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
  926. } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
  927. // For this pointer.
  928. RecordDecl *RD = RT->getDecl();
  929. llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
  930. // Skip if already created.
  931. if (DxilStructAnnotation *annotation = dxilTypeSys.GetStructAnnotation(ST)) {
  932. unsigned structSize = annotation->GetCBufferSize();
  933. return structSize;
  934. }
  935. DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST,
  936. GetNumTemplateArgsForRecordDecl(RT->getDecl()));
  937. return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
  938. } else if (IsHLSLResourceType(Ty)) {
  939. // Save result type info.
  940. AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
  941. // Resource don't count for cbuffer size.
  942. return 0;
  943. } else if (IsStringType(Ty)) {
  944. // string won't be included in cbuffer
  945. return 0;
  946. } else {
  947. unsigned arraySize = 0;
  948. QualType arrayElementTy = Ty;
  949. if (Ty->isConstantArrayType()) {
  950. const ConstantArrayType *arrayTy =
  951. CGM.getContext().getAsConstantArrayType(Ty);
  952. DXASSERT(arrayTy != nullptr, "Must array type here");
  953. arraySize = arrayTy->getSize().getLimitedValue();
  954. arrayElementTy = arrayTy->getElementType();
  955. }
  956. else if (Ty->isIncompleteArrayType()) {
  957. const IncompleteArrayType *arrayTy = CGM.getContext().getAsIncompleteArrayType(Ty);
  958. arrayElementTy = arrayTy->getElementType();
  959. } else {
  960. DXASSERT(0, "Must array type here");
  961. }
  962. unsigned elementSize = AddTypeAnnotation(arrayElementTy, dxilTypeSys, arrayEltSize);
  963. // Only set arrayEltSize once.
  964. if (arrayEltSize == 0)
  965. arrayEltSize = elementSize;
  966. // Align to 4 * 4bytes.
  967. unsigned alignedSize = (elementSize + 15) & 0xfffffff0;
  968. return alignedSize * (arraySize - 1) + elementSize;
  969. }
  970. }
  971. static DxilResource::Kind KeywordToKind(StringRef keyword) {
  972. // TODO: refactor for faster search (switch by 1/2/3 first letters, then
  973. // compare)
  974. if (keyword == "Texture1D" || keyword == "RWTexture1D" || keyword == "RasterizerOrderedTexture1D")
  975. return DxilResource::Kind::Texture1D;
  976. if (keyword == "Texture2D" || keyword == "RWTexture2D" || keyword == "RasterizerOrderedTexture2D")
  977. return DxilResource::Kind::Texture2D;
  978. if (keyword == "Texture2DMS" || keyword == "RWTexture2DMS")
  979. return DxilResource::Kind::Texture2DMS;
  980. if (keyword == "FeedbackTexture2D")
  981. return DxilResource::Kind::FeedbackTexture2D;
  982. if (keyword == "Texture3D" || keyword == "RWTexture3D" || keyword == "RasterizerOrderedTexture3D")
  983. return DxilResource::Kind::Texture3D;
  984. if (keyword == "TextureCube" || keyword == "RWTextureCube")
  985. return DxilResource::Kind::TextureCube;
  986. if (keyword == "Texture1DArray" || keyword == "RWTexture1DArray" || keyword == "RasterizerOrderedTexture1DArray")
  987. return DxilResource::Kind::Texture1DArray;
  988. if (keyword == "Texture2DArray" || keyword == "RWTexture2DArray" || keyword == "RasterizerOrderedTexture2DArray")
  989. return DxilResource::Kind::Texture2DArray;
  990. if (keyword == "FeedbackTexture2DArray")
  991. return DxilResource::Kind::FeedbackTexture2DArray;
  992. if (keyword == "Texture2DMSArray" || keyword == "RWTexture2DMSArray")
  993. return DxilResource::Kind::Texture2DMSArray;
  994. if (keyword == "TextureCubeArray" || keyword == "RWTextureCubeArray")
  995. return DxilResource::Kind::TextureCubeArray;
  996. if (keyword == "ByteAddressBuffer" || keyword == "RWByteAddressBuffer" || keyword == "RasterizerOrderedByteAddressBuffer")
  997. return DxilResource::Kind::RawBuffer;
  998. if (keyword == "StructuredBuffer" || keyword == "RWStructuredBuffer" || keyword == "RasterizerOrderedStructuredBuffer")
  999. return DxilResource::Kind::StructuredBuffer;
  1000. if (keyword == "AppendStructuredBuffer" || keyword == "ConsumeStructuredBuffer")
  1001. return DxilResource::Kind::StructuredBuffer;
  1002. // TODO: this is not efficient.
  1003. bool isBuffer = keyword == "Buffer";
  1004. isBuffer |= keyword == "RWBuffer";
  1005. isBuffer |= keyword == "RasterizerOrderedBuffer";
  1006. if (isBuffer)
  1007. return DxilResource::Kind::TypedBuffer;
  1008. if (keyword == "RaytracingAccelerationStructure")
  1009. return DxilResource::Kind::RTAccelerationStructure;
  1010. return DxilResource::Kind::Invalid;
  1011. }
  1012. void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
  1013. // Add hlsl intrinsic attr
  1014. unsigned intrinsicOpcode;
  1015. StringRef intrinsicGroup;
  1016. llvm::FunctionType *FT = F->getFunctionType();
  1017. auto AddResourceMetadata = [&](QualType qTy, llvm::Type *Ty) {
  1018. hlsl::DxilResourceBase::Class resClass = TypeToClass(qTy);
  1019. if (resClass != hlsl::DxilResourceBase::Class::Invalid) {
  1020. if (!resMetadataMap.count(Ty)) {
  1021. MDNode *Meta = GetOrAddResTypeMD(qTy);
  1022. DXASSERT(Meta, "else invalid resource type");
  1023. resMetadataMap[Ty] = Meta;
  1024. }
  1025. }
  1026. };
  1027. if (hlsl::GetIntrinsicOp(FD, intrinsicOpcode, intrinsicGroup)) {
  1028. AddHLSLIntrinsicOpcodeToFunction(F, intrinsicOpcode);
  1029. F->addFnAttr(hlsl::HLPrefix, intrinsicGroup);
  1030. unsigned iParamOffset = 0; // skip this on llvm function
  1031. // Save resource type annotation.
  1032. if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD)) {
  1033. iParamOffset = 1;
  1034. const CXXRecordDecl *RD = MD->getParent();
  1035. // For nested case like sample_slice_type.
  1036. if (const CXXRecordDecl *PRD =
  1037. dyn_cast<CXXRecordDecl>(RD->getDeclContext())) {
  1038. RD = PRD;
  1039. }
  1040. QualType recordTy = MD->getASTContext().getRecordType(RD);
  1041. llvm::Type *Ty = CGM.getTypes().ConvertType(recordTy);
  1042. AddResourceMetadata(recordTy, Ty);
  1043. }
  1044. // Add metadata for any resources found in parameters
  1045. for (unsigned iParam = 0; iParam < FD->getNumParams(); iParam++) {
  1046. llvm::Type *Ty = FT->getParamType(iParam + iParamOffset);
  1047. if (!Ty->isPointerTy())
  1048. continue; // not a resource
  1049. Ty = Ty->getPointerElementType();
  1050. QualType paramTy = FD->getParamDecl(iParam)->getType();
  1051. AddResourceMetadata(paramTy, Ty);
  1052. }
  1053. StringRef lower;
  1054. if (hlsl::GetIntrinsicLowering(FD, lower))
  1055. hlsl::SetHLLowerStrategy(F, lower);
  1056. // Don't need to add FunctionQual for intrinsic function.
  1057. return;
  1058. }
  1059. if (m_pHLModule->GetFloat32DenormMode() == DXIL::Float32DenormMode::FTZ) {
  1060. F->addFnAttr(DXIL::kFP32DenormKindString, DXIL::kFP32DenormValueFtzString);
  1061. }
  1062. else if (m_pHLModule->GetFloat32DenormMode() == DXIL::Float32DenormMode::Preserve) {
  1063. F->addFnAttr(DXIL::kFP32DenormKindString, DXIL::kFP32DenormValuePreserveString);
  1064. }
  1065. else if (m_pHLModule->GetFloat32DenormMode() == DXIL::Float32DenormMode::Any) {
  1066. F->addFnAttr(DXIL::kFP32DenormKindString, DXIL::kFP32DenormValueAnyString);
  1067. }
  1068. // Set entry function
  1069. const std::string &entryName = m_pHLModule->GetEntryFunctionName();
  1070. bool isEntry = FD->getNameAsString() == entryName;
  1071. if (isEntry) {
  1072. Entry.Func = F;
  1073. Entry.SL = FD->getLocation();
  1074. }
  1075. DiagnosticsEngine &Diags = CGM.getDiags();
  1076. std::unique_ptr<DxilFunctionProps> funcProps =
  1077. llvm::make_unique<DxilFunctionProps>();
  1078. funcProps->shaderKind = DXIL::ShaderKind::Invalid;
  1079. bool isCS = false;
  1080. bool isGS = false;
  1081. bool isHS = false;
  1082. bool isDS = false;
  1083. bool isVS = false;
  1084. bool isPS = false;
  1085. bool isRay = false;
  1086. bool isMS = false;
  1087. bool isAS = false;
  1088. if (const HLSLShaderAttr *Attr = FD->getAttr<HLSLShaderAttr>()) {
  1089. // Stage is already validate in HandleDeclAttributeForHLSL.
  1090. // Here just check first letter (or two).
  1091. switch (Attr->getStage()[0]) {
  1092. case 'c':
  1093. switch (Attr->getStage()[1]) {
  1094. case 'o':
  1095. isCS = true;
  1096. funcProps->shaderKind = DXIL::ShaderKind::Compute;
  1097. break;
  1098. case 'l':
  1099. isRay = true;
  1100. funcProps->shaderKind = DXIL::ShaderKind::ClosestHit;
  1101. break;
  1102. case 'a':
  1103. isRay = true;
  1104. funcProps->shaderKind = DXIL::ShaderKind::Callable;
  1105. break;
  1106. default:
  1107. break;
  1108. }
  1109. break;
  1110. case 'v':
  1111. isVS = true;
  1112. funcProps->shaderKind = DXIL::ShaderKind::Vertex;
  1113. break;
  1114. case 'h':
  1115. isHS = true;
  1116. funcProps->shaderKind = DXIL::ShaderKind::Hull;
  1117. break;
  1118. case 'd':
  1119. isDS = true;
  1120. funcProps->shaderKind = DXIL::ShaderKind::Domain;
  1121. break;
  1122. case 'g':
  1123. isGS = true;
  1124. funcProps->shaderKind = DXIL::ShaderKind::Geometry;
  1125. break;
  1126. case 'p':
  1127. isPS = true;
  1128. funcProps->shaderKind = DXIL::ShaderKind::Pixel;
  1129. break;
  1130. case 'r':
  1131. isRay = true;
  1132. funcProps->shaderKind = DXIL::ShaderKind::RayGeneration;
  1133. break;
  1134. case 'i':
  1135. isRay = true;
  1136. funcProps->shaderKind = DXIL::ShaderKind::Intersection;
  1137. break;
  1138. case 'a':
  1139. switch (Attr->getStage()[1]) {
  1140. case 'm':
  1141. isAS = true;
  1142. funcProps->shaderKind = DXIL::ShaderKind::Amplification;
  1143. break;
  1144. case 'n':
  1145. isRay = true;
  1146. funcProps->shaderKind = DXIL::ShaderKind::AnyHit;
  1147. break;
  1148. default:
  1149. break;
  1150. }
  1151. break;
  1152. case 'm':
  1153. switch (Attr->getStage()[1]) {
  1154. case 'e':
  1155. isMS = true;
  1156. funcProps->shaderKind = DXIL::ShaderKind::Mesh;
  1157. break;
  1158. case 'i':
  1159. isRay = true;
  1160. funcProps->shaderKind = DXIL::ShaderKind::Miss;
  1161. break;
  1162. default:
  1163. break;
  1164. }
  1165. break;
  1166. default:
  1167. break;
  1168. }
  1169. if (funcProps->shaderKind == DXIL::ShaderKind::Invalid) {
  1170. unsigned DiagID = Diags.getCustomDiagID(
  1171. DiagnosticsEngine::Error, "Invalid profile for shader attribute");
  1172. Diags.Report(Attr->getLocation(), DiagID);
  1173. }
  1174. if (isEntry && isRay) {
  1175. unsigned DiagID = Diags.getCustomDiagID(
  1176. DiagnosticsEngine::Error, "Ray function cannot be used as a global entry point");
  1177. Diags.Report(Attr->getLocation(), DiagID);
  1178. }
  1179. }
  1180. // Save patch constant function to patchConstantFunctionMap.
  1181. bool isPatchConstantFunction = false;
  1182. if (!isEntry && CGM.getContext().IsPatchConstantFunctionDecl(FD)) {
  1183. isPatchConstantFunction = true;
  1184. auto &PCI = patchConstantFunctionMap[FD->getName()];
  1185. PCI.SL = FD->getLocation();
  1186. PCI.Func = F;
  1187. ++PCI.NumOverloads;
  1188. for (ParmVarDecl *parmDecl : FD->parameters()) {
  1189. QualType Ty = parmDecl->getType();
  1190. if (IsHLSLOutputPatchType(Ty)) {
  1191. funcProps->ShaderProps.HS.outputControlPoints =
  1192. GetHLSLOutputPatchCount(parmDecl->getType());
  1193. } else if (IsHLSLInputPatchType(Ty)) {
  1194. funcProps->ShaderProps.HS.inputControlPoints =
  1195. GetHLSLInputPatchCount(parmDecl->getType());
  1196. }
  1197. }
  1198. // Mark patch constant functions that cannot be linked as exports
  1199. // InternalLinkage. Patch constant functions that are actually used
  1200. // will be set back to ExternalLinkage in FinishCodeGen.
  1201. if (funcProps->ShaderProps.HS.outputControlPoints ||
  1202. funcProps->ShaderProps.HS.inputControlPoints) {
  1203. PCI.Func->setLinkage(GlobalValue::InternalLinkage);
  1204. }
  1205. funcProps->shaderKind = DXIL::ShaderKind::Hull;
  1206. }
  1207. const ShaderModel *SM = m_pHLModule->GetShaderModel();
  1208. if (isEntry) {
  1209. funcProps->shaderKind = SM->GetKind();
  1210. if (funcProps->shaderKind == DXIL::ShaderKind::Mesh) {
  1211. isMS = true;
  1212. }
  1213. else if (funcProps->shaderKind == DXIL::ShaderKind::Amplification) {
  1214. isAS = true;
  1215. }
  1216. }
  1217. // Geometry shader.
  1218. if (const HLSLMaxVertexCountAttr *Attr =
  1219. FD->getAttr<HLSLMaxVertexCountAttr>()) {
  1220. isGS = true;
  1221. funcProps->shaderKind = DXIL::ShaderKind::Geometry;
  1222. funcProps->ShaderProps.GS.maxVertexCount = Attr->getCount();
  1223. funcProps->ShaderProps.GS.inputPrimitive = DXIL::InputPrimitive::Undefined;
  1224. if (isEntry && !SM->IsGS()) {
  1225. unsigned DiagID =
  1226. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1227. "attribute maxvertexcount only valid for GS.");
  1228. Diags.Report(Attr->getLocation(), DiagID);
  1229. return;
  1230. }
  1231. }
  1232. if (const HLSLInstanceAttr *Attr = FD->getAttr<HLSLInstanceAttr>()) {
  1233. unsigned instanceCount = Attr->getCount();
  1234. funcProps->ShaderProps.GS.instanceCount = instanceCount;
  1235. if (isEntry && !SM->IsGS()) {
  1236. unsigned DiagID =
  1237. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1238. "attribute maxvertexcount only valid for GS.");
  1239. Diags.Report(Attr->getLocation(), DiagID);
  1240. return;
  1241. }
  1242. } else {
  1243. // Set default instance count.
  1244. if (isGS)
  1245. funcProps->ShaderProps.GS.instanceCount = 1;
  1246. }
  1247. // Compute shader
  1248. if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
  1249. if (isMS) {
  1250. funcProps->ShaderProps.MS.numThreads[0] = Attr->getX();
  1251. funcProps->ShaderProps.MS.numThreads[1] = Attr->getY();
  1252. funcProps->ShaderProps.MS.numThreads[2] = Attr->getZ();
  1253. } else if (isAS) {
  1254. funcProps->ShaderProps.AS.numThreads[0] = Attr->getX();
  1255. funcProps->ShaderProps.AS.numThreads[1] = Attr->getY();
  1256. funcProps->ShaderProps.AS.numThreads[2] = Attr->getZ();
  1257. } else {
  1258. isCS = true;
  1259. funcProps->shaderKind = DXIL::ShaderKind::Compute;
  1260. funcProps->ShaderProps.CS.numThreads[0] = Attr->getX();
  1261. funcProps->ShaderProps.CS.numThreads[1] = Attr->getY();
  1262. funcProps->ShaderProps.CS.numThreads[2] = Attr->getZ();
  1263. }
  1264. if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
  1265. unsigned DiagID = Diags.getCustomDiagID(
  1266. DiagnosticsEngine::Error, "attribute numthreads only valid for CS/MS/AS.");
  1267. Diags.Report(Attr->getLocation(), DiagID);
  1268. return;
  1269. }
  1270. }
  1271. // Hull shader.
  1272. if (const HLSLPatchConstantFuncAttr *Attr =
  1273. FD->getAttr<HLSLPatchConstantFuncAttr>()) {
  1274. if (isEntry && !SM->IsHS()) {
  1275. unsigned DiagID = Diags.getCustomDiagID(
  1276. DiagnosticsEngine::Error,
  1277. "attribute patchconstantfunc only valid for HS.");
  1278. Diags.Report(Attr->getLocation(), DiagID);
  1279. return;
  1280. }
  1281. isHS = true;
  1282. funcProps->shaderKind = DXIL::ShaderKind::Hull;
  1283. HSEntryPatchConstantFuncAttr[F] = Attr;
  1284. } else {
  1285. // TODO: This is a duplicate check. We also have this check in
  1286. // hlsl::DiagnoseTranslationUnit(clang::Sema*).
  1287. if (isEntry && SM->IsHS()) {
  1288. unsigned DiagID = Diags.getCustomDiagID(
  1289. DiagnosticsEngine::Error,
  1290. "HS entry point must have the patchconstantfunc attribute");
  1291. Diags.Report(FD->getLocation(), DiagID);
  1292. return;
  1293. }
  1294. }
  1295. if (const HLSLOutputControlPointsAttr *Attr =
  1296. FD->getAttr<HLSLOutputControlPointsAttr>()) {
  1297. if (isHS) {
  1298. funcProps->ShaderProps.HS.outputControlPoints = Attr->getCount();
  1299. } else if (isEntry && !SM->IsHS()) {
  1300. unsigned DiagID = Diags.getCustomDiagID(
  1301. DiagnosticsEngine::Error,
  1302. "attribute outputcontrolpoints only valid for HS.");
  1303. Diags.Report(Attr->getLocation(), DiagID);
  1304. return;
  1305. }
  1306. }
  1307. if (const HLSLPartitioningAttr *Attr = FD->getAttr<HLSLPartitioningAttr>()) {
  1308. if (isHS) {
  1309. DXIL::TessellatorPartitioning partition =
  1310. StringToPartitioning(Attr->getScheme());
  1311. funcProps->ShaderProps.HS.partition = partition;
  1312. } else if (isEntry && !SM->IsHS()) {
  1313. unsigned DiagID =
  1314. Diags.getCustomDiagID(DiagnosticsEngine::Warning,
  1315. "attribute partitioning only valid for HS.");
  1316. Diags.Report(Attr->getLocation(), DiagID);
  1317. }
  1318. }
  1319. if (const HLSLOutputTopologyAttr *Attr =
  1320. FD->getAttr<HLSLOutputTopologyAttr>()) {
  1321. if (isHS) {
  1322. DXIL::TessellatorOutputPrimitive primitive =
  1323. StringToTessOutputPrimitive(Attr->getTopology());
  1324. funcProps->ShaderProps.HS.outputPrimitive = primitive;
  1325. }
  1326. else if (isMS) {
  1327. DXIL::MeshOutputTopology topology =
  1328. StringToMeshOutputTopology(Attr->getTopology());
  1329. funcProps->ShaderProps.MS.outputTopology = topology;
  1330. }
  1331. else if (isEntry && !SM->IsHS() && !SM->IsMS()) {
  1332. unsigned DiagID =
  1333. Diags.getCustomDiagID(DiagnosticsEngine::Warning,
  1334. "attribute outputtopology only valid for HS and MS.");
  1335. Diags.Report(Attr->getLocation(), DiagID);
  1336. }
  1337. }
  1338. if (isHS) {
  1339. funcProps->ShaderProps.HS.maxTessFactor = DXIL::kHSMaxTessFactorUpperBound;
  1340. funcProps->ShaderProps.HS.inputControlPoints = DXIL::kHSDefaultInputControlPointCount;
  1341. }
  1342. if (const HLSLMaxTessFactorAttr *Attr =
  1343. FD->getAttr<HLSLMaxTessFactorAttr>()) {
  1344. if (isHS) {
  1345. // TODO: change getFactor to return float.
  1346. llvm::APInt intV(32, Attr->getFactor());
  1347. funcProps->ShaderProps.HS.maxTessFactor = intV.bitsToFloat();
  1348. } else if (isEntry && !SM->IsHS()) {
  1349. unsigned DiagID =
  1350. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1351. "attribute maxtessfactor only valid for HS.");
  1352. Diags.Report(Attr->getLocation(), DiagID);
  1353. return;
  1354. }
  1355. }
  1356. // Hull or domain shader.
  1357. if (const HLSLDomainAttr *Attr = FD->getAttr<HLSLDomainAttr>()) {
  1358. if (isEntry && !SM->IsHS() && !SM->IsDS()) {
  1359. unsigned DiagID =
  1360. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1361. "attribute domain only valid for HS or DS.");
  1362. Diags.Report(Attr->getLocation(), DiagID);
  1363. return;
  1364. }
  1365. isDS = !isHS;
  1366. if (isDS)
  1367. funcProps->shaderKind = DXIL::ShaderKind::Domain;
  1368. DXIL::TessellatorDomain domain = StringToDomain(Attr->getDomainType());
  1369. if (isHS)
  1370. funcProps->ShaderProps.HS.domain = domain;
  1371. else
  1372. funcProps->ShaderProps.DS.domain = domain;
  1373. }
  1374. // Vertex shader.
  1375. if (const HLSLClipPlanesAttr *Attr = FD->getAttr<HLSLClipPlanesAttr>()) {
  1376. if (isEntry && !SM->IsVS()) {
  1377. unsigned DiagID = Diags.getCustomDiagID(
  1378. DiagnosticsEngine::Error, "attribute clipplane only valid for VS.");
  1379. Diags.Report(Attr->getLocation(), DiagID);
  1380. return;
  1381. }
  1382. isVS = true;
  1383. // The real job is done at EmitHLSLFunctionProlog where debug info is
  1384. // available. Only set shader kind here.
  1385. funcProps->shaderKind = DXIL::ShaderKind::Vertex;
  1386. }
  1387. // Pixel shader.
  1388. if (const HLSLEarlyDepthStencilAttr *Attr =
  1389. FD->getAttr<HLSLEarlyDepthStencilAttr>()) {
  1390. if (isEntry && !SM->IsPS()) {
  1391. unsigned DiagID = Diags.getCustomDiagID(
  1392. DiagnosticsEngine::Error,
  1393. "attribute earlydepthstencil only valid for PS.");
  1394. Diags.Report(Attr->getLocation(), DiagID);
  1395. return;
  1396. }
  1397. isPS = true;
  1398. funcProps->ShaderProps.PS.EarlyDepthStencil = true;
  1399. funcProps->shaderKind = DXIL::ShaderKind::Pixel;
  1400. }
  1401. const unsigned profileAttributes = isCS + isHS + isDS + isGS + isVS + isPS + isRay + isMS + isAS;
  1402. // TODO: check this in front-end and report error.
  1403. DXASSERT(profileAttributes < 2, "profile attributes are mutual exclusive");
  1404. if (isEntry) {
  1405. switch (funcProps->shaderKind) {
  1406. case ShaderModel::Kind::Compute:
  1407. case ShaderModel::Kind::Hull:
  1408. case ShaderModel::Kind::Domain:
  1409. case ShaderModel::Kind::Geometry:
  1410. case ShaderModel::Kind::Vertex:
  1411. case ShaderModel::Kind::Pixel:
  1412. case ShaderModel::Kind::Mesh:
  1413. case ShaderModel::Kind::Amplification:
  1414. DXASSERT(funcProps->shaderKind == SM->GetKind(),
  1415. "attribute profile not match entry function profile");
  1416. break;
  1417. case ShaderModel::Kind::Library:
  1418. case ShaderModel::Kind::Invalid:
  1419. // Non-shader stage shadermodels don't have entry points.
  1420. break;
  1421. }
  1422. }
  1423. DxilFunctionAnnotation *FuncAnnotation =
  1424. m_pHLModule->AddFunctionAnnotation(F);
  1425. bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  1426. // Param Info
  1427. unsigned streamIndex = 0;
  1428. unsigned inputPatchCount = 0;
  1429. unsigned outputPatchCount = 0;
  1430. unsigned ArgNo = 0;
  1431. unsigned ParmIdx = 0;
  1432. if (const CXXMethodDecl *MethodDecl = dyn_cast<CXXMethodDecl>(FD)) {
  1433. if (MethodDecl->isInstance()) {
  1434. QualType ThisTy = MethodDecl->getThisType(FD->getASTContext());
  1435. DxilParameterAnnotation &paramAnnotation =
  1436. FuncAnnotation->GetParameterAnnotation(ArgNo++);
  1437. // Construct annoation for this pointer.
  1438. ConstructFieldAttributedAnnotation(paramAnnotation, ThisTy,
  1439. bDefaultRowMajor);
  1440. }
  1441. }
  1442. // Ret Info
  1443. QualType retTy = FD->getReturnType();
  1444. DxilParameterAnnotation *pRetTyAnnotation = nullptr;
  1445. if (F->getReturnType()->isVoidTy() && !retTy->isVoidType()) {
  1446. // SRet.
  1447. pRetTyAnnotation = &FuncAnnotation->GetParameterAnnotation(ArgNo++);
  1448. } else {
  1449. pRetTyAnnotation = &FuncAnnotation->GetRetTypeAnnotation();
  1450. }
  1451. DxilParameterAnnotation &retTyAnnotation = *pRetTyAnnotation;
  1452. // keep Undefined here, we cannot decide for struct
  1453. retTyAnnotation.SetInterpolationMode(
  1454. GetInterpMode(FD, CompType::Kind::Invalid, /*bKeepUndefined*/ true)
  1455. .GetKind());
  1456. SourceLocation retTySemanticLoc = SetSemantic(FD, retTyAnnotation);
  1457. retTyAnnotation.SetParamInputQual(DxilParamInputQual::Out);
  1458. if (isEntry) {
  1459. if (CGM.getLangOpts().EnableDX9CompatMode && retTyAnnotation.HasSemanticString()) {
  1460. RemapObsoleteSemantic(retTyAnnotation, /*isPatchConstantFunction*/ false);
  1461. }
  1462. CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation,
  1463. /*isPatchConstantFunction*/ false);
  1464. }
  1465. if (isRay && !retTy->isVoidType()) {
  1466. Diags.Report(FD->getLocation(), Diags.getCustomDiagID(
  1467. DiagnosticsEngine::Error, "return type for ray tracing shaders must be void"));
  1468. }
  1469. ConstructFieldAttributedAnnotation(retTyAnnotation, retTy, bDefaultRowMajor);
  1470. if (FD->hasAttr<HLSLPreciseAttr>())
  1471. retTyAnnotation.SetPrecise();
  1472. if (isRay) {
  1473. funcProps->ShaderProps.Ray.payloadSizeInBytes = 0;
  1474. funcProps->ShaderProps.Ray.attributeSizeInBytes = 0;
  1475. }
  1476. bool hasOutIndices = false;
  1477. bool hasOutVertices = false;
  1478. bool hasOutPrimitives = false;
  1479. bool hasInPayload = false;
  1480. for (; ArgNo < F->arg_size(); ++ArgNo, ++ParmIdx) {
  1481. DxilParameterAnnotation &paramAnnotation =
  1482. FuncAnnotation->GetParameterAnnotation(ArgNo);
  1483. const ParmVarDecl *parmDecl = FD->getParamDecl(ParmIdx);
  1484. QualType fieldTy = parmDecl->getType();
  1485. // if parameter type is a typedef, try to desugar it first.
  1486. if (isa<TypedefType>(fieldTy.getTypePtr()))
  1487. fieldTy = fieldTy.getDesugaredType(FD->getASTContext());
  1488. ConstructFieldAttributedAnnotation(paramAnnotation, fieldTy,
  1489. bDefaultRowMajor);
  1490. if (parmDecl->hasAttr<HLSLPreciseAttr>())
  1491. paramAnnotation.SetPrecise();
  1492. // keep Undefined here, we cannot decide for struct
  1493. InterpolationMode paramIM =
  1494. GetInterpMode(parmDecl, CompType::Kind::Invalid, KeepUndefinedTrue);
  1495. paramAnnotation.SetInterpolationMode(paramIM);
  1496. SourceLocation paramSemanticLoc = SetSemantic(parmDecl, paramAnnotation);
  1497. DxilParamInputQual dxilInputQ = DxilParamInputQual::In;
  1498. if (parmDecl->hasAttr<HLSLInOutAttr>())
  1499. dxilInputQ = DxilParamInputQual::Inout;
  1500. else if (parmDecl->hasAttr<HLSLOutAttr>())
  1501. dxilInputQ = DxilParamInputQual::Out;
  1502. if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLInAttr>())
  1503. dxilInputQ = DxilParamInputQual::Inout;
  1504. if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLIndicesAttr>()) {
  1505. if (hasOutIndices) {
  1506. unsigned DiagID = Diags.getCustomDiagID(
  1507. DiagnosticsEngine::Error,
  1508. "multiple out indices parameters not allowed");
  1509. Diags.Report(parmDecl->getLocation(), DiagID);
  1510. continue;
  1511. }
  1512. const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
  1513. if (CAT == nullptr) {
  1514. unsigned DiagID = Diags.getCustomDiagID(
  1515. DiagnosticsEngine::Error,
  1516. "indices output is not an constant-length array");
  1517. Diags.Report(parmDecl->getLocation(), DiagID);
  1518. continue;
  1519. }
  1520. unsigned count = CAT->getSize().getZExtValue();
  1521. if (count > DXIL::kMaxMSOutputPrimitiveCount) {
  1522. unsigned DiagID = Diags.getCustomDiagID(
  1523. DiagnosticsEngine::Error,
  1524. "max primitive count should not exceed %0");
  1525. Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputPrimitiveCount;
  1526. continue;
  1527. }
  1528. if (funcProps->ShaderProps.MS.maxPrimitiveCount != 0 &&
  1529. funcProps->ShaderProps.MS.maxPrimitiveCount != count) {
  1530. unsigned DiagID = Diags.getCustomDiagID(
  1531. DiagnosticsEngine::Error,
  1532. "max primitive count mismatch");
  1533. Diags.Report(parmDecl->getLocation(), DiagID);
  1534. continue;
  1535. }
  1536. // Get element type.
  1537. QualType arrayEleTy = CAT->getElementType();
  1538. if (hlsl::IsHLSLVecType(arrayEleTy)) {
  1539. QualType vecEltTy = hlsl::GetHLSLVecElementType(arrayEleTy);
  1540. if (!vecEltTy->isUnsignedIntegerType() || CGM.getContext().getTypeSize(vecEltTy) != 32) {
  1541. unsigned DiagID = Diags.getCustomDiagID(
  1542. DiagnosticsEngine::Error,
  1543. "the element of out_indices array must be uint2 for line output or uint3 for triangle output");
  1544. Diags.Report(parmDecl->getLocation(), DiagID);
  1545. continue;
  1546. }
  1547. unsigned vecEltCount = hlsl::GetHLSLVecSize(arrayEleTy);
  1548. if (funcProps->ShaderProps.MS.outputTopology == DXIL::MeshOutputTopology::Line && vecEltCount != 2) {
  1549. unsigned DiagID = Diags.getCustomDiagID(
  1550. DiagnosticsEngine::Error,
  1551. "the element of out_indices array in a mesh shader whose output topology is line must be uint2");
  1552. Diags.Report(parmDecl->getLocation(), DiagID);
  1553. continue;
  1554. }
  1555. if (funcProps->ShaderProps.MS.outputTopology == DXIL::MeshOutputTopology::Triangle && vecEltCount != 3) {
  1556. unsigned DiagID = Diags.getCustomDiagID(
  1557. DiagnosticsEngine::Error,
  1558. "the element of out_indices array in a mesh shader whose output topology is triangle must be uint3");
  1559. Diags.Report(parmDecl->getLocation(), DiagID);
  1560. continue;
  1561. }
  1562. } else {
  1563. unsigned DiagID = Diags.getCustomDiagID(
  1564. DiagnosticsEngine::Error,
  1565. "the element of out_indices array must be uint2 for line output or uint3 for triangle output");
  1566. Diags.Report(parmDecl->getLocation(), DiagID);
  1567. continue;
  1568. }
  1569. dxilInputQ = DxilParamInputQual::OutIndices;
  1570. funcProps->ShaderProps.MS.maxPrimitiveCount = count;
  1571. hasOutIndices = true;
  1572. }
  1573. if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLVerticesAttr>()) {
  1574. if (hasOutVertices) {
  1575. unsigned DiagID = Diags.getCustomDiagID(
  1576. DiagnosticsEngine::Error,
  1577. "multiple out vertices parameters not allowed");
  1578. Diags.Report(parmDecl->getLocation(), DiagID);
  1579. continue;
  1580. }
  1581. const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
  1582. if (CAT == nullptr) {
  1583. unsigned DiagID = Diags.getCustomDiagID(
  1584. DiagnosticsEngine::Error,
  1585. "vertices output is not an constant-length array");
  1586. Diags.Report(parmDecl->getLocation(), DiagID);
  1587. continue;
  1588. }
  1589. unsigned count = CAT->getSize().getZExtValue();
  1590. if (count > DXIL::kMaxMSOutputVertexCount) {
  1591. unsigned DiagID = Diags.getCustomDiagID(
  1592. DiagnosticsEngine::Error,
  1593. "max vertex count should not exceed %0");
  1594. Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputVertexCount;
  1595. continue;
  1596. }
  1597. dxilInputQ = DxilParamInputQual::OutVertices;
  1598. funcProps->ShaderProps.MS.maxVertexCount = count;
  1599. hasOutVertices = true;
  1600. }
  1601. if (parmDecl->hasAttr<HLSLOutAttr>() && parmDecl->hasAttr<HLSLPrimitivesAttr>()) {
  1602. if (hasOutPrimitives) {
  1603. unsigned DiagID = Diags.getCustomDiagID(
  1604. DiagnosticsEngine::Error,
  1605. "multiple out primitives parameters not allowed");
  1606. Diags.Report(parmDecl->getLocation(), DiagID);
  1607. continue;
  1608. }
  1609. const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(fieldTy.getCanonicalType());
  1610. if (CAT == nullptr) {
  1611. unsigned DiagID = Diags.getCustomDiagID(
  1612. DiagnosticsEngine::Error,
  1613. "primitives output is not an constant-length array");
  1614. Diags.Report(parmDecl->getLocation(), DiagID);
  1615. continue;
  1616. }
  1617. unsigned count = CAT->getSize().getZExtValue();
  1618. if (count > DXIL::kMaxMSOutputPrimitiveCount) {
  1619. unsigned DiagID = Diags.getCustomDiagID(
  1620. DiagnosticsEngine::Error,
  1621. "max primitive count should not exceed %0");
  1622. Diags.Report(parmDecl->getLocation(), DiagID) << DXIL::kMaxMSOutputPrimitiveCount;
  1623. continue;
  1624. }
  1625. if (funcProps->ShaderProps.MS.maxPrimitiveCount != 0 &&
  1626. funcProps->ShaderProps.MS.maxPrimitiveCount != count) {
  1627. unsigned DiagID = Diags.getCustomDiagID(
  1628. DiagnosticsEngine::Error,
  1629. "max primitive count mismatch");
  1630. Diags.Report(parmDecl->getLocation(), DiagID);
  1631. continue;
  1632. }
  1633. dxilInputQ = DxilParamInputQual::OutPrimitives;
  1634. funcProps->ShaderProps.MS.maxPrimitiveCount = count;
  1635. hasOutPrimitives = true;
  1636. }
  1637. if (parmDecl->hasAttr<HLSLInAttr>() && parmDecl->hasAttr<HLSLPayloadAttr>()) {
  1638. if (hasInPayload) {
  1639. unsigned DiagID = Diags.getCustomDiagID(
  1640. DiagnosticsEngine::Error,
  1641. "multiple in payload parameters not allowed");
  1642. Diags.Report(parmDecl->getLocation(), DiagID);
  1643. continue;
  1644. }
  1645. dxilInputQ = DxilParamInputQual::InPayload;
  1646. DataLayout DL(&this->TheModule);
  1647. funcProps->ShaderProps.MS.payloadSizeInBytes = DL.getTypeAllocSize(
  1648. F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
  1649. hasInPayload = true;
  1650. }
  1651. DXIL::InputPrimitive inputPrimitive = DXIL::InputPrimitive::Undefined;
  1652. if (IsHLSLOutputPatchType(parmDecl->getType())) {
  1653. outputPatchCount++;
  1654. if (dxilInputQ != DxilParamInputQual::In) {
  1655. unsigned DiagID = Diags.getCustomDiagID(
  1656. DiagnosticsEngine::Error,
  1657. "OutputPatch should not be out/inout parameter");
  1658. Diags.Report(parmDecl->getLocation(), DiagID);
  1659. continue;
  1660. }
  1661. dxilInputQ = DxilParamInputQual::OutputPatch;
  1662. if (isDS)
  1663. funcProps->ShaderProps.DS.inputControlPoints =
  1664. GetHLSLOutputPatchCount(parmDecl->getType());
  1665. } else if (IsHLSLInputPatchType(parmDecl->getType())) {
  1666. inputPatchCount++;
  1667. if (dxilInputQ != DxilParamInputQual::In) {
  1668. unsigned DiagID = Diags.getCustomDiagID(
  1669. DiagnosticsEngine::Error,
  1670. "InputPatch should not be out/inout parameter");
  1671. Diags.Report(parmDecl->getLocation(), DiagID);
  1672. continue;
  1673. }
  1674. dxilInputQ = DxilParamInputQual::InputPatch;
  1675. if (isHS) {
  1676. funcProps->ShaderProps.HS.inputControlPoints =
  1677. GetHLSLInputPatchCount(parmDecl->getType());
  1678. } else if (isGS) {
  1679. inputPrimitive = (DXIL::InputPrimitive)(
  1680. (unsigned)DXIL::InputPrimitive::ControlPointPatch1 +
  1681. GetHLSLInputPatchCount(parmDecl->getType()) - 1);
  1682. }
  1683. } else if (IsHLSLStreamOutputType(parmDecl->getType())) {
  1684. // TODO: validation this at ASTContext::getFunctionType in
  1685. // AST/ASTContext.cpp
  1686. DXASSERT(dxilInputQ == DxilParamInputQual::Inout,
  1687. "stream output parameter must be inout");
  1688. switch (streamIndex) {
  1689. case 0:
  1690. dxilInputQ = DxilParamInputQual::OutStream0;
  1691. break;
  1692. case 1:
  1693. dxilInputQ = DxilParamInputQual::OutStream1;
  1694. break;
  1695. case 2:
  1696. dxilInputQ = DxilParamInputQual::OutStream2;
  1697. break;
  1698. case 3:
  1699. default:
  1700. // TODO: validation this at ASTContext::getFunctionType in
  1701. // AST/ASTContext.cpp
  1702. DXASSERT(streamIndex == 3, "stream number out of bound");
  1703. dxilInputQ = DxilParamInputQual::OutStream3;
  1704. break;
  1705. }
  1706. DXIL::PrimitiveTopology &streamTopology =
  1707. funcProps->ShaderProps.GS.streamPrimitiveTopologies[streamIndex];
  1708. if (IsHLSLPointStreamType(parmDecl->getType()))
  1709. streamTopology = DXIL::PrimitiveTopology::PointList;
  1710. else if (IsHLSLLineStreamType(parmDecl->getType()))
  1711. streamTopology = DXIL::PrimitiveTopology::LineStrip;
  1712. else {
  1713. DXASSERT(IsHLSLTriangleStreamType(parmDecl->getType()),
  1714. "invalid StreamType");
  1715. streamTopology = DXIL::PrimitiveTopology::TriangleStrip;
  1716. }
  1717. if (streamIndex > 0) {
  1718. bool bAllPoint =
  1719. streamTopology == DXIL::PrimitiveTopology::PointList &&
  1720. funcProps->ShaderProps.GS.streamPrimitiveTopologies[0] ==
  1721. DXIL::PrimitiveTopology::PointList;
  1722. if (!bAllPoint) {
  1723. unsigned DiagID = Diags.getCustomDiagID(
  1724. DiagnosticsEngine::Error, "when multiple GS output streams are "
  1725. "used they must be pointlists.");
  1726. Diags.Report(FD->getLocation(), DiagID);
  1727. }
  1728. }
  1729. streamIndex++;
  1730. }
  1731. unsigned GsInputArrayDim = 0;
  1732. if (parmDecl->hasAttr<HLSLTriangleAttr>()) {
  1733. inputPrimitive = DXIL::InputPrimitive::Triangle;
  1734. GsInputArrayDim = 3;
  1735. } else if (parmDecl->hasAttr<HLSLTriangleAdjAttr>()) {
  1736. inputPrimitive = DXIL::InputPrimitive::TriangleWithAdjacency;
  1737. GsInputArrayDim = 6;
  1738. } else if (parmDecl->hasAttr<HLSLPointAttr>()) {
  1739. inputPrimitive = DXIL::InputPrimitive::Point;
  1740. GsInputArrayDim = 1;
  1741. } else if (parmDecl->hasAttr<HLSLLineAdjAttr>()) {
  1742. inputPrimitive = DXIL::InputPrimitive::LineWithAdjacency;
  1743. GsInputArrayDim = 4;
  1744. } else if (parmDecl->hasAttr<HLSLLineAttr>()) {
  1745. inputPrimitive = DXIL::InputPrimitive::Line;
  1746. GsInputArrayDim = 2;
  1747. }
  1748. if (inputPrimitive != DXIL::InputPrimitive::Undefined) {
  1749. // Set to InputPrimitive for GS.
  1750. dxilInputQ = DxilParamInputQual::InputPrimitive;
  1751. if (funcProps->ShaderProps.GS.inputPrimitive ==
  1752. DXIL::InputPrimitive::Undefined) {
  1753. funcProps->ShaderProps.GS.inputPrimitive = inputPrimitive;
  1754. } else if (funcProps->ShaderProps.GS.inputPrimitive != inputPrimitive) {
  1755. unsigned DiagID = Diags.getCustomDiagID(
  1756. DiagnosticsEngine::Error, "input parameter conflicts with geometry "
  1757. "specifier of previous input parameters");
  1758. Diags.Report(parmDecl->getLocation(), DiagID);
  1759. }
  1760. }
  1761. if (GsInputArrayDim != 0) {
  1762. QualType Ty = parmDecl->getType();
  1763. if (!Ty->isConstantArrayType()) {
  1764. unsigned DiagID = Diags.getCustomDiagID(
  1765. DiagnosticsEngine::Error,
  1766. "input types for geometry shader must be constant size arrays");
  1767. Diags.Report(parmDecl->getLocation(), DiagID);
  1768. } else {
  1769. const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty);
  1770. if (CAT->getSize().getLimitedValue() != GsInputArrayDim) {
  1771. StringRef primtiveNames[] = {
  1772. "invalid", // 0
  1773. "point", // 1
  1774. "line", // 2
  1775. "triangle", // 3
  1776. "lineadj", // 4
  1777. "invalid", // 5
  1778. "triangleadj", // 6
  1779. };
  1780. DXASSERT(GsInputArrayDim < llvm::array_lengthof(primtiveNames),
  1781. "Invalid array dim");
  1782. unsigned DiagID = Diags.getCustomDiagID(
  1783. DiagnosticsEngine::Error, "array dimension for %0 must be %1");
  1784. Diags.Report(parmDecl->getLocation(), DiagID)
  1785. << primtiveNames[GsInputArrayDim] << GsInputArrayDim;
  1786. }
  1787. }
  1788. }
  1789. // Validate Ray Tracing function parameter (some validation may be pushed into front end)
  1790. if (isRay) {
  1791. switch (funcProps->shaderKind) {
  1792. case DXIL::ShaderKind::RayGeneration:
  1793. case DXIL::ShaderKind::Intersection:
  1794. // RayGeneration and Intersection shaders are not allowed to have any input parameters
  1795. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1796. DiagnosticsEngine::Error, "parameters are not allowed for %0 shader"))
  1797. << (funcProps->shaderKind == DXIL::ShaderKind::RayGeneration ?
  1798. "raygeneration" : "intersection");
  1799. break;
  1800. case DXIL::ShaderKind::AnyHit:
  1801. case DXIL::ShaderKind::ClosestHit:
  1802. if (0 == ArgNo && dxilInputQ != DxilParamInputQual::Inout) {
  1803. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1804. DiagnosticsEngine::Error,
  1805. "ray payload parameter must be inout"));
  1806. } else if (1 == ArgNo && dxilInputQ != DxilParamInputQual::In) {
  1807. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1808. DiagnosticsEngine::Error,
  1809. "intersection attributes parameter must be in"));
  1810. } else if (ArgNo > 1) {
  1811. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1812. DiagnosticsEngine::Error,
  1813. "too many parameters, expected payload and attributes parameters only."));
  1814. }
  1815. if (ArgNo < 2) {
  1816. if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
  1817. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1818. DiagnosticsEngine::Error,
  1819. "payload and attribute structures must be user defined types with only numeric contents."));
  1820. } else {
  1821. DataLayout DL(&this->TheModule);
  1822. unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
  1823. if (0 == ArgNo)
  1824. funcProps->ShaderProps.Ray.payloadSizeInBytes = size;
  1825. else
  1826. funcProps->ShaderProps.Ray.attributeSizeInBytes = size;
  1827. }
  1828. }
  1829. break;
  1830. case DXIL::ShaderKind::Miss:
  1831. if (ArgNo > 0) {
  1832. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1833. DiagnosticsEngine::Error,
  1834. "only one parameter (ray payload) allowed for miss shader"));
  1835. } else if (dxilInputQ != DxilParamInputQual::Inout) {
  1836. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1837. DiagnosticsEngine::Error,
  1838. "ray payload parameter must be declared inout"));
  1839. }
  1840. if (ArgNo < 1) {
  1841. if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
  1842. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1843. DiagnosticsEngine::Error,
  1844. "ray payload parameter must be a user defined type with only numeric contents."));
  1845. } else {
  1846. DataLayout DL(&this->TheModule);
  1847. unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
  1848. funcProps->ShaderProps.Ray.payloadSizeInBytes = size;
  1849. }
  1850. }
  1851. break;
  1852. case DXIL::ShaderKind::Callable:
  1853. if (ArgNo > 0) {
  1854. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1855. DiagnosticsEngine::Error,
  1856. "only one parameter allowed for callable shader"));
  1857. } else if (dxilInputQ != DxilParamInputQual::Inout) {
  1858. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1859. DiagnosticsEngine::Error,
  1860. "callable parameter must be declared inout"));
  1861. }
  1862. if (ArgNo < 1) {
  1863. if (!IsHLSLNumericUserDefinedType(parmDecl->getType())) {
  1864. Diags.Report(parmDecl->getLocation(), Diags.getCustomDiagID(
  1865. DiagnosticsEngine::Error,
  1866. "callable parameter must be a user defined type with only numeric contents."));
  1867. } else {
  1868. DataLayout DL(&this->TheModule);
  1869. unsigned size = DL.getTypeAllocSize(F->getFunctionType()->getFunctionParamType(ArgNo)->getPointerElementType());
  1870. funcProps->ShaderProps.Ray.paramSizeInBytes = size;
  1871. }
  1872. }
  1873. break;
  1874. }
  1875. }
  1876. paramAnnotation.SetParamInputQual(dxilInputQ);
  1877. if (isEntry) {
  1878. if (CGM.getLangOpts().EnableDX9CompatMode && paramAnnotation.HasSemanticString()) {
  1879. RemapObsoleteSemantic(paramAnnotation, /*isPatchConstantFunction*/ false);
  1880. }
  1881. CheckParameterAnnotation(paramSemanticLoc, paramAnnotation,
  1882. /*isPatchConstantFunction*/ false);
  1883. }
  1884. }
  1885. if (inputPatchCount > 1) {
  1886. unsigned DiagID = Diags.getCustomDiagID(
  1887. DiagnosticsEngine::Error, "may only have one InputPatch parameter");
  1888. Diags.Report(FD->getLocation(), DiagID);
  1889. }
  1890. if (outputPatchCount > 1) {
  1891. unsigned DiagID = Diags.getCustomDiagID(
  1892. DiagnosticsEngine::Error, "may only have one OutputPatch parameter");
  1893. Diags.Report(FD->getLocation(), DiagID);
  1894. }
  1895. // If Shader is a ray shader that requires parameters, make sure size is non-zero
  1896. if (isRay) {
  1897. bool bNeedsAttributes = false;
  1898. bool bNeedsPayload = false;
  1899. switch (funcProps->shaderKind) {
  1900. case DXIL::ShaderKind::AnyHit:
  1901. case DXIL::ShaderKind::ClosestHit:
  1902. bNeedsAttributes = true;
  1903. case DXIL::ShaderKind::Miss:
  1904. bNeedsPayload = true;
  1905. case DXIL::ShaderKind::Callable:
  1906. if (0 == funcProps->ShaderProps.Ray.payloadSizeInBytes) {
  1907. unsigned DiagID = bNeedsPayload ?
  1908. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1909. "shader must include inout payload structure parameter.") :
  1910. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1911. "shader must include inout parameter structure.");
  1912. Diags.Report(FD->getLocation(), DiagID);
  1913. }
  1914. }
  1915. if (bNeedsAttributes &&
  1916. 0 == funcProps->ShaderProps.Ray.attributeSizeInBytes) {
  1917. Diags.Report(FD->getLocation(), Diags.getCustomDiagID(
  1918. DiagnosticsEngine::Error,
  1919. "shader must include attributes structure parameter."));
  1920. }
  1921. }
  1922. // Type annotation for parameters and return type.
  1923. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  1924. unsigned arrayEltSize = 0;
  1925. AddTypeAnnotation(FD->getReturnType(), dxilTypeSys, arrayEltSize);
  1926. // Type annotation for this pointer.
  1927. if (const CXXMethodDecl *MFD = dyn_cast<CXXMethodDecl>(FD)) {
  1928. const CXXRecordDecl *RD = MFD->getParent();
  1929. QualType Ty = CGM.getContext().getTypeDeclType(RD);
  1930. AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
  1931. }
  1932. for (const ValueDecl *param : FD->params()) {
  1933. QualType Ty = param->getType();
  1934. AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
  1935. }
  1936. // clear isExportedEntry if not exporting entry
  1937. bool isExportedEntry = profileAttributes != 0;
  1938. if (isExportedEntry) {
  1939. // use unmangled or mangled name depending on which is used for final entry function
  1940. StringRef name = isRay ? F->getName() : FD->getName();
  1941. if (!m_ExportMap.IsExported(name)) {
  1942. isExportedEntry = false;
  1943. }
  1944. }
  1945. // Only add functionProps when exist.
  1946. if (isExportedEntry || isEntry)
  1947. m_pHLModule->AddDxilFunctionProps(F, funcProps);
  1948. if (isPatchConstantFunction)
  1949. patchConstantFunctionPropsMap[F] = std::move(funcProps);
  1950. // Save F to entry map.
  1951. if (isExportedEntry) {
  1952. if (entryFunctionMap.count(FD->getName())) {
  1953. DiagnosticsEngine &Diags = CGM.getDiags();
  1954. unsigned DiagID = Diags.getCustomDiagID(
  1955. DiagnosticsEngine::Error,
  1956. "redefinition of %0");
  1957. Diags.Report(FD->getLocStart(), DiagID) << FD->getName();
  1958. }
  1959. auto &Entry = entryFunctionMap[FD->getNameAsString()];
  1960. Entry.SL = FD->getLocation();
  1961. Entry.Func= F;
  1962. }
  1963. // Add target-dependent experimental function attributes
  1964. for (const auto &Attr : FD->specific_attrs<HLSLExperimentalAttr>()) {
  1965. F->addFnAttr(Twine("exp-", Attr->getName()).str(), Attr->getValue());
  1966. }
  1967. }
  1968. void CGMSHLSLRuntime::RemapObsoleteSemantic(DxilParameterAnnotation &paramInfo, bool isPatchConstantFunction) {
  1969. DXASSERT(CGM.getLangOpts().EnableDX9CompatMode, "should be used only in back-compat mode");
  1970. const ShaderModel *SM = m_pHLModule->GetShaderModel();
  1971. DXIL::SigPointKind sigPointKind = SigPointFromInputQual(paramInfo.GetParamInputQual(), SM->GetKind(), isPatchConstantFunction);
  1972. hlsl::RemapObsoleteSemantic(paramInfo, sigPointKind, CGM.getLLVMContext());
  1973. }
// Per-function prolog work that must run after debug info exists:
// (1) materialize the globals referenced by the [clipplanes(...)] attribute
//     into the function's VS shader props, and
// (2) internalize the function's linkage according to DefaultLinkage when it
//     is not an entry (has no DxilFunctionProps).
void CGMSHLSLRuntime::EmitHLSLFunctionProlog(Function *F, const FunctionDecl *FD) {
  // Clip-plane support needs debug info, which is not yet available when the
  // function attributes are created — so it is handled here instead.
  if (const HLSLClipPlanesAttr *Attr = FD->getAttr<HLSLClipPlanesAttr>()) {
    DxilFunctionProps &funcProps = m_pHLModule->GetDxilFunctionProps(F);
    // Initialize all clip-plane slots to null.
    memset(funcProps.ShaderProps.VS.clipPlanes, 0, sizeof(funcProps.ShaderProps.VS.clipPlanes));
    // For each clip-plane expression, record the constant address backing it
    // in slot `idx`, and (when debug info is on) remember the expression's
    // debug location keyed by that address.
    auto AddClipPlane = [&](Expr *clipPlane, unsigned idx) {
      if (DeclRefExpr *decl = dyn_cast<DeclRefExpr>(clipPlane)) {
        // Clip plane names a global variable directly.
        const VarDecl *VD = cast<VarDecl>(decl->getDecl());
        Constant *clipPlaneVal = CGM.GetAddrOfGlobalVar(VD);
        funcProps.ShaderProps.VS.clipPlanes[idx] = clipPlaneVal;
        if (m_bDebugInfo) {
          CodeGenFunction CGF(CGM);
          ApplyDebugLocation applyDebugLoc(CGF, clipPlane);
          debugInfoMap[clipPlaneVal] = CGF.Builder.getCurrentDebugLocation();
        }
      } else {
        // Must be a MemberExpr (e.g. a field of a global struct).
        const MemberExpr *ME = cast<MemberExpr>(clipPlane);
        CodeGenFunction CGF(CGM);
        CodeGen::LValue LV = CGF.EmitMemberExpr(ME);
        Value *addr = LV.getAddress();
        funcProps.ShaderProps.VS.clipPlanes[idx] = cast<Constant>(addr);
        if (m_bDebugInfo) {
          // NOTE(review): this fresh CodeGenFunction shadows the outer one —
          // presumably to get a clean builder for the debug location; confirm.
          CodeGenFunction CGF(CGM);
          ApplyDebugLocation applyDebugLoc(CGF, clipPlane);
          debugInfoMap[addr] = CGF.Builder.getCurrentDebugLocation();
        }
      }
    };
    // The attribute exposes up to six clip planes as separate accessors.
    if (Expr *clipPlane = Attr->getClipPlane1())
      AddClipPlane(clipPlane, 0);
    if (Expr *clipPlane = Attr->getClipPlane2())
      AddClipPlane(clipPlane, 1);
    if (Expr *clipPlane = Attr->getClipPlane3())
      AddClipPlane(clipPlane, 2);
    if (Expr *clipPlane = Attr->getClipPlane4())
      AddClipPlane(clipPlane, 3);
    if (Expr *clipPlane = Attr->getClipPlane5())
      AddClipPlane(clipPlane, 4);
    if (Expr *clipPlane = Attr->getClipPlane6())
      AddClipPlane(clipPlane, 5);
    clipPlaneFuncList.emplace_back(F);
  }
  // Update function linkage based on DefaultLinkage.
  // We will take care of patch constant functions later, once identified for certain.
  if (!m_pHLModule->HasDxilFunctionProps(F)) {
    if (F->getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage) {
      if (!FD->hasAttr<HLSLExportAttr>()) {
        switch (CGM.getCodeGenOpts().DefaultLinkage) {
        case DXIL::DefaultLinkage::Default:
          // Internalize unless targeting the offline (library-style) minor
          // version — presumably lib targets keep exports; confirm.
          if (m_pHLModule->GetShaderModel()->GetMinor() != ShaderModel::kOfflineMinor)
            F->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
          break;
        case DXIL::DefaultLinkage::Internal:
          F->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
          break;
        }
      }
    }
  }
}
  2037. void CGMSHLSLRuntime::AddControlFlowHint(CodeGenFunction &CGF, const Stmt &S,
  2038. llvm::TerminatorInst *TI,
  2039. ArrayRef<const Attr *> Attrs) {
  2040. // Build hints.
  2041. bool bNoBranchFlatten = true;
  2042. bool bBranch = false;
  2043. bool bFlatten = false;
  2044. std::vector<DXIL::ControlFlowHint> hints;
  2045. for (const auto *Attr : Attrs) {
  2046. if (isa<HLSLBranchAttr>(Attr)) {
  2047. hints.emplace_back(DXIL::ControlFlowHint::Branch);
  2048. bNoBranchFlatten = false;
  2049. bBranch = true;
  2050. }
  2051. else if (isa<HLSLFlattenAttr>(Attr)) {
  2052. hints.emplace_back(DXIL::ControlFlowHint::Flatten);
  2053. bNoBranchFlatten = false;
  2054. bFlatten = true;
  2055. } else if (isa<HLSLForceCaseAttr>(Attr)) {
  2056. if (isa<SwitchStmt>(&S)) {
  2057. hints.emplace_back(DXIL::ControlFlowHint::ForceCase);
  2058. }
  2059. }
  2060. // Ignore fastopt, allow_uav_condition and call for now.
  2061. }
  2062. if (bNoBranchFlatten) {
  2063. // CHECK control flow option.
  2064. if (CGF.CGM.getCodeGenOpts().HLSLPreferControlFlow)
  2065. hints.emplace_back(DXIL::ControlFlowHint::Branch);
  2066. else if (CGF.CGM.getCodeGenOpts().HLSLAvoidControlFlow)
  2067. hints.emplace_back(DXIL::ControlFlowHint::Flatten);
  2068. }
  2069. if (bFlatten && bBranch) {
  2070. DiagnosticsEngine &Diags = CGM.getDiags();
  2071. unsigned DiagID = Diags.getCustomDiagID(
  2072. DiagnosticsEngine::Error,
  2073. "can't use branch and flatten attributes together");
  2074. Diags.Report(S.getLocStart(), DiagID);
  2075. }
  2076. if (hints.size()) {
  2077. // Add meta data to the instruction.
  2078. MDNode *hintsNode = DxilMDHelper::EmitControlFlowHints(Context, hints);
  2079. TI->setMetadata(DxilMDHelper::kDxilControlFlowHintMDName, hintsNode);
  2080. }
  2081. }
  2082. void CGMSHLSLRuntime::FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D, llvm::Value *V) {
  2083. if (D.hasAttr<HLSLPreciseAttr>()) {
  2084. AllocaInst *AI = cast<AllocaInst>(V);
  2085. HLModule::MarkPreciseAttributeWithMetadata(AI);
  2086. }
  2087. // Add type annotation for local variable.
  2088. DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  2089. unsigned arrayEltSize = 0;
  2090. AddTypeAnnotation(D.getType(), typeSys, arrayEltSize);
  2091. }
  2092. hlsl::InterpolationMode CGMSHLSLRuntime::GetInterpMode(const Decl *decl,
  2093. CompType compType,
  2094. bool bKeepUndefined) {
  2095. InterpolationMode Interp(
  2096. decl->hasAttr<HLSLNoInterpolationAttr>(), decl->hasAttr<HLSLLinearAttr>(),
  2097. decl->hasAttr<HLSLNoPerspectiveAttr>(), decl->hasAttr<HLSLCentroidAttr>(),
  2098. decl->hasAttr<HLSLSampleAttr>());
  2099. DXASSERT(Interp.IsValid(), "otherwise front-end missing validation");
  2100. if (Interp.IsUndefined() && !bKeepUndefined) {
  2101. // Type-based default: linear for floats, constant for others.
  2102. if (compType.IsFloatTy())
  2103. Interp = InterpolationMode::Kind::Linear;
  2104. else
  2105. Interp = InterpolationMode::Kind::Constant;
  2106. }
  2107. return Interp;
  2108. }
  2109. hlsl::CompType CGMSHLSLRuntime::GetCompType(const BuiltinType *BT) {
  2110. hlsl::CompType ElementType = hlsl::CompType::getInvalid();
  2111. switch (BT->getKind()) {
  2112. case BuiltinType::Bool:
  2113. ElementType = hlsl::CompType::getI1();
  2114. break;
  2115. case BuiltinType::Double:
  2116. ElementType = hlsl::CompType::getF64();
  2117. break;
  2118. case BuiltinType::HalfFloat: // HLSL Change
  2119. case BuiltinType::Float:
  2120. ElementType = hlsl::CompType::getF32();
  2121. break;
  2122. // HLSL Changes begin
  2123. case BuiltinType::Min10Float:
  2124. case BuiltinType::Min16Float:
  2125. // HLSL Changes end
  2126. case BuiltinType::Half:
  2127. ElementType = hlsl::CompType::getF16();
  2128. break;
  2129. case BuiltinType::Int:
  2130. ElementType = hlsl::CompType::getI32();
  2131. break;
  2132. case BuiltinType::LongLong:
  2133. ElementType = hlsl::CompType::getI64();
  2134. break;
  2135. // HLSL Changes begin
  2136. case BuiltinType::Min12Int:
  2137. case BuiltinType::Min16Int:
  2138. // HLSL Changes end
  2139. case BuiltinType::Short:
  2140. ElementType = hlsl::CompType::getI16();
  2141. break;
  2142. case BuiltinType::UInt:
  2143. ElementType = hlsl::CompType::getU32();
  2144. break;
  2145. case BuiltinType::ULongLong:
  2146. ElementType = hlsl::CompType::getU64();
  2147. break;
  2148. case BuiltinType::Min16UInt: // HLSL Change
  2149. case BuiltinType::UShort:
  2150. ElementType = hlsl::CompType::getU16();
  2151. break;
  2152. default:
  2153. llvm_unreachable("unsupported type");
  2154. break;
  2155. }
  2156. return ElementType;
  2157. }
/// Add resource to the program.
/// Dispatches a global declaration to the appropriate module-level bucket:
/// cbuffer declarations, samplers, SRV/UAV resources, groupshared variables,
/// static globals (type annotation only), or the implicit global constant
/// buffer for plain constants.
void CGMSHLSLRuntime::addResource(Decl *D) {
  if (HLSLBufferDecl *BD = dyn_cast<HLSLBufferDecl>(D))
    GetOrCreateCBuffer(BD);
  else if (VarDecl *VD = dyn_cast<VarDecl>(D)) {
    hlsl::DxilResourceBase::Class resClass = TypeToClass(VD->getType());
    // Skip a resource-classed decl that has an initializer — presumably it
    // aliases another resource rather than declaring a new one; confirm.
    if (VD->hasInit() && resClass != DXIL::ResourceClass::Invalid)
      return;
    // Static (internal-linkage) globals are not bindable resources; only
    // record a type annotation for them, then return.
    if (!VD->hasExternalFormalLinkage()) {
      if (VD->hasInit() && VD->getType().isConstQualified()) {
        Expr* InitExp = VD->getInit();
        GlobalVariable *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD));
        // Only save const static global of struct type; keyed by the init
        // expression so later codegen can find the backing global.
        if (GV->getType()->getElementType()->isStructTy()) {
          staticConstGlobalInitMap[InitExp] = GV;
        }
      }
      // Add type annotation for static global variable.
      DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
      unsigned arrayEltSize = 0;
      AddTypeAnnotation(VD->getType(), typeSys, arrayEltSize);
      return;
    }
    // groupshared variables are tracked separately from resources.
    if (D->hasAttr<HLSLGroupSharedAttr>()) {
      GlobalVariable *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD));
      DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
      unsigned arraySize = 0;
      AddTypeAnnotation(VD->getType(), dxilTypeSys, arraySize);
      m_pHLModule->AddGroupSharedVariable(GV);
      return;
    }
    switch (resClass) {
    case hlsl::DxilResourceBase::Class::Sampler:
      AddSampler(VD);
      break;
    case hlsl::DxilResourceBase::Class::UAV:
    case hlsl::DxilResourceBase::Class::SRV:
      AddUAVSRV(VD, resClass);
      break;
    case hlsl::DxilResourceBase::Class::Invalid: {
      // Normal global constant: fold into the implicit global cbuffer ($Globals).
      HLCBuffer &globalCB = GetGlobalCBuffer();
      AddConstant(VD, globalCB);
      break;
    }
    case DXIL::ResourceClass::CBuffer:
      // cbuffers arrive as HLSLBufferDecl and are handled above.
      DXASSERT(0, "cbuffer should not be here");
      break;
    }
  }
}
  2211. /// Add subobject to the module
  2212. void CGMSHLSLRuntime::addSubobject(Decl *D) {
  2213. VarDecl *VD = dyn_cast<VarDecl>(D);
  2214. DXASSERT(VD != nullptr, "must be a global variable");
  2215. DXIL::SubobjectKind subobjKind;
  2216. DXIL::HitGroupType hgType;
  2217. if (!hlsl::GetHLSLSubobjectKind(VD->getType(), subobjKind, hgType)) {
  2218. DXASSERT(false, "not a valid subobject declaration");
  2219. return;
  2220. }
  2221. Expr *initExpr = const_cast<Expr*>(VD->getAnyInitializer());
  2222. if (!initExpr) {
  2223. DiagnosticsEngine &Diags = CGM.getDiags();
  2224. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "subobject needs to be initialized");
  2225. Diags.Report(D->getLocStart(), DiagID);
  2226. return;
  2227. }
  2228. if (InitListExpr *initListExpr = dyn_cast<InitListExpr>(initExpr)) {
  2229. try {
  2230. CreateSubobject(subobjKind, VD->getName(), initListExpr->getInits(), initListExpr->getNumInits(), hgType);
  2231. } catch (hlsl::Exception&) {
  2232. DiagnosticsEngine &Diags = CGM.getDiags();
  2233. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "internal error creating subobject");
  2234. Diags.Report(initExpr->getLocStart(), DiagID);
  2235. return;
  2236. }
  2237. }
  2238. else {
  2239. DiagnosticsEngine &Diags = CGM.getDiags();
  2240. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "expected initialization list");
  2241. Diags.Report(initExpr->getLocStart(), DiagID);
  2242. return;
  2243. }
  2244. }
  2245. // TODO: collect such helper utility functions in one place.
  2246. static DxilResourceBase::Class KeywordToClass(const std::string &keyword) {
  2247. // TODO: refactor for faster search (switch by 1/2/3 first letters, then
  2248. // compare)
  2249. if (keyword == "SamplerState")
  2250. return DxilResourceBase::Class::Sampler;
  2251. if (keyword == "SamplerComparisonState")
  2252. return DxilResourceBase::Class::Sampler;
  2253. if (keyword == "ConstantBuffer")
  2254. return DxilResourceBase::Class::CBuffer;
  2255. if (keyword == "TextureBuffer")
  2256. return DxilResourceBase::Class::SRV;
  2257. bool isSRV = keyword == "Buffer";
  2258. isSRV |= keyword == "ByteAddressBuffer";
  2259. isSRV |= keyword == "RaytracingAccelerationStructure";
  2260. isSRV |= keyword == "StructuredBuffer";
  2261. isSRV |= keyword == "Texture1D";
  2262. isSRV |= keyword == "Texture1DArray";
  2263. isSRV |= keyword == "Texture2D";
  2264. isSRV |= keyword == "Texture2DArray";
  2265. isSRV |= keyword == "Texture3D";
  2266. isSRV |= keyword == "TextureCube";
  2267. isSRV |= keyword == "TextureCubeArray";
  2268. isSRV |= keyword == "Texture2DMS";
  2269. isSRV |= keyword == "Texture2DMSArray";
  2270. if (isSRV)
  2271. return DxilResourceBase::Class::SRV;
  2272. bool isUAV = keyword == "RWBuffer";
  2273. isUAV |= keyword == "RWByteAddressBuffer";
  2274. isUAV |= keyword == "RWStructuredBuffer";
  2275. isUAV |= keyword == "RWTexture1D";
  2276. isUAV |= keyword == "RWTexture1DArray";
  2277. isUAV |= keyword == "RWTexture2D";
  2278. isUAV |= keyword == "RWTexture2DArray";
  2279. isUAV |= keyword == "RWTexture3D";
  2280. isUAV |= keyword == "RWTextureCube";
  2281. isUAV |= keyword == "RWTextureCubeArray";
  2282. isUAV |= keyword == "RWTexture2DMS";
  2283. isUAV |= keyword == "RWTexture2DMSArray";
  2284. isUAV |= keyword == "AppendStructuredBuffer";
  2285. isUAV |= keyword == "ConsumeStructuredBuffer";
  2286. isUAV |= keyword == "RasterizerOrderedBuffer";
  2287. isUAV |= keyword == "RasterizerOrderedByteAddressBuffer";
  2288. isUAV |= keyword == "RasterizerOrderedStructuredBuffer";
  2289. isUAV |= keyword == "RasterizerOrderedTexture1D";
  2290. isUAV |= keyword == "RasterizerOrderedTexture1DArray";
  2291. isUAV |= keyword == "RasterizerOrderedTexture2D";
  2292. isUAV |= keyword == "RasterizerOrderedTexture2DArray";
  2293. isUAV |= keyword == "RasterizerOrderedTexture3D";
  2294. isUAV |= keyword == "FeedbackTexture2D";
  2295. isUAV |= keyword == "FeedbackTexture2DArray";
  2296. if (isUAV)
  2297. return DxilResourceBase::Class::UAV;
  2298. return DxilResourceBase::Class::Invalid;
  2299. }
  2300. // This should probably be refactored to ASTContextHLSL, and follow types
  2301. // rather than do string comparisons.
  2302. DXIL::ResourceClass
  2303. hlsl::GetResourceClassForType(const clang::ASTContext &context,
  2304. clang::QualType Ty) {
  2305. Ty = Ty.getCanonicalType();
  2306. if (const clang::ArrayType *arrayType = context.getAsArrayType(Ty)) {
  2307. return GetResourceClassForType(context, arrayType->getElementType());
  2308. } else if (const RecordType *RT = Ty->getAsStructureType()) {
  2309. return KeywordToClass(RT->getDecl()->getName());
  2310. } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
  2311. if (const ClassTemplateSpecializationDecl *templateDecl =
  2312. dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl())) {
  2313. return KeywordToClass(templateDecl->getName());
  2314. }
  2315. }
  2316. return hlsl::DxilResourceBase::Class::Invalid;
  2317. }
// Thin wrapper: classify |Ty| as SRV/UAV/CBuffer/Sampler/Invalid using the
// module's ASTContext.
hlsl::DxilResourceBase::Class CGMSHLSLRuntime::TypeToClass(clang::QualType Ty) {
  return hlsl::GetResourceClassForType(CGM.getContext(), Ty);
}
namespace {
// Strip array dimensions from a resource declaration's type: |ElemType|
// receives the non-array element type and |rangeSize| the product of all
// constant dimensions (UINT_MAX once any dimension is unbounded).
void GetResourceDeclElemTypeAndRangeSize(CodeGenModule &CGM, HLModule &HL, VarDecl &VD,
    QualType &ElemType, unsigned& rangeSize) {
  // We can't canonicalize nor desugar the type without losing the 'snorm' in Buffer<snorm float>
  ElemType = VD.getType();
  rangeSize = 1;
  while (const clang::ArrayType *arrayType = CGM.getContext().getAsArrayType(ElemType)) {
    // Once the range is unbounded (UINT_MAX) stop accumulating sizes but
    // keep peeling element types.
    if (rangeSize != UINT_MAX) {
      if (arrayType->isConstantArrayType()) {
        rangeSize *= cast<ConstantArrayType>(arrayType)->getSize().getLimitedValue();
      }
      else {
        // Unbounded dimension: the legacy reservation scheme cannot express
        // an unbounded range, so diagnose when that option is active.
        if (HL.GetHLOptions().bLegacyResourceReservation) {
          DiagnosticsEngine &Diags = CGM.getDiags();
          unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
              "unbounded resources are not supported with -flegacy-resource-reservation");
          Diags.Report(VD.getLocation(), DiagID);
        }
        rangeSize = UINT_MAX;
      }
    }
    ElemType = arrayType->getElementType();
  }
}
}
  2346. static void InitFromUnusualAnnotations(DxilResourceBase &Resource, NamedDecl &Decl) {
  2347. for (hlsl::UnusualAnnotation* It : Decl.getUnusualAnnotations()) {
  2348. switch (It->getKind()) {
  2349. case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
  2350. hlsl::RegisterAssignment* RegAssign = cast<hlsl::RegisterAssignment>(It);
  2351. if (RegAssign->RegisterType) {
  2352. Resource.SetLowerBound(RegAssign->RegisterNumber);
  2353. // For backcompat, don't auto-assign the register space if there's an
  2354. // explicit register type.
  2355. Resource.SetSpaceID(RegAssign->RegisterSpace.getValueOr(0));
  2356. }
  2357. else {
  2358. Resource.SetSpaceID(RegAssign->RegisterSpace.getValueOr(UINT_MAX));
  2359. }
  2360. break;
  2361. }
  2362. case hlsl::UnusualAnnotation::UA_SemanticDecl:
  2363. // Ignore Semantics
  2364. break;
  2365. case hlsl::UnusualAnnotation::UA_ConstantPacking:
  2366. // Should be handled by front-end
  2367. llvm_unreachable("packoffset on resource");
  2368. break;
  2369. default:
  2370. llvm_unreachable("unknown UnusualAnnotation on resource");
  2371. break;
  2372. }
  2373. }
  2374. }
// Create a DxilSampler for a global sampler declaration, fill it in from the
// declaration's type and register annotations, and add it to the HL module.
// Returns the new sampler's ID.
uint32_t CGMSHLSLRuntime::AddSampler(VarDecl *samplerDecl) {
  llvm::GlobalVariable *val =
      cast<llvm::GlobalVariable>(CGM.GetAddrOfGlobalVar(samplerDecl));
  unique_ptr<DxilSampler> hlslRes(new DxilSampler);
  // UINT_MAX marks bindings as "not yet assigned"; InitFromUnusualAnnotations
  // below may override from register() annotations.
  hlslRes->SetLowerBound(UINT_MAX);
  hlslRes->SetSpaceID(UINT_MAX);
  hlslRes->SetGlobalSymbol(val);
  hlslRes->SetGlobalName(samplerDecl->getName());
  // Peel array dimensions to get the element type and total range size.
  QualType VarTy;
  unsigned rangeSize;
  GetResourceDeclElemTypeAndRangeSize(CGM, *m_pHLModule, *samplerDecl,
      VarTy, rangeSize);
  hlslRes->SetRangeSize(rangeSize);
  // NOTE(review): assumes the element type is a sampler record
  // (SamplerState / SamplerComparisonState); RT is not null-checked here —
  // presumably guaranteed by earlier classification. TODO confirm.
  const RecordType *RT = VarTy->getAs<RecordType>();
  DxilSampler::SamplerKind kind = KeywordToSamplerKind(RT->getDecl()->getName());
  hlslRes->SetSamplerKind(kind);
  InitFromUnusualAnnotations(*hlslRes, *samplerDecl);
  hlslRes->SetID(m_pHLModule->GetSamplers().size());
  return m_pHLModule->AddSampler(std::move(hlslRes));
}
  2395. bool CGMSHLSLRuntime::GetAsConstantUInt32(clang::Expr *expr, uint32_t *value) {
  2396. APSInt result;
  2397. if (!expr->EvaluateAsInt(result, CGM.getContext())) {
  2398. DiagnosticsEngine &Diags = CGM.getDiags();
  2399. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  2400. "cannot convert to constant unsigned int");
  2401. Diags.Report(expr->getLocStart(), DiagID);
  2402. return false;
  2403. }
  2404. *value = result.getLimitedValue(UINT32_MAX);
  2405. return true;
  2406. }
// Evaluate |expr| as a compile-time string literal. On success, *value
// receives the literal's bytes and true is returned. When |failWhenEmpty|
// is set, an empty literal is also rejected. Every failure path emits a
// diagnostic and returns false.
bool CGMSHLSLRuntime::GetAsConstantString(clang::Expr *expr, StringRef *value, bool failWhenEmpty /*=false*/) {
  Expr::EvalResult result;
  DiagnosticsEngine &Diags = CGM.getDiags();
  unsigned DiagID = 0;
  if (expr->EvaluateAsRValue(result, CGM.getContext())) {
    if (result.Val.isLValue()) {
      // A constant string evaluates to an lvalue whose base is the literal
      // itself, at offset zero.
      DXASSERT_NOMSG(result.Val.getLValueOffset().isZero());
      DXASSERT_NOMSG(result.Val.getLValueCallIndex() == 0);
      const Expr *evExpr = result.Val.getLValueBase().get<const Expr *>();
      if (const StringLiteral *strLit = dyn_cast<const StringLiteral>(evExpr)) {
        *value = strLit->getBytes();
        if (!failWhenEmpty || !(*value).empty()) {
          return true;
        }
        // Literal was empty but the caller required a non-empty string.
        DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "empty string not expected here");
      }
    }
  }
  // Fall-through: evaluation failed or the value was not a string literal;
  // use the generic message unless the empty-string one was already chosen.
  if (!DiagID)
    DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "cannot convert to constant string");
  Diags.Report(expr->getLocStart(), DiagID);
  return false;
}
  2430. std::vector<StringRef> CGMSHLSLRuntime::ParseSubobjectExportsAssociations(StringRef exports) {
  2431. std::vector<StringRef> parsedExports;
  2432. const char *pData = exports.data();
  2433. const char *pEnd = pData + exports.size();
  2434. const char *pLast = pData;
  2435. while (pData < pEnd) {
  2436. if (*pData == ';') {
  2437. if (pLast < pData) {
  2438. parsedExports.emplace_back(StringRef(pLast, pData - pLast));
  2439. }
  2440. pLast = pData + 1;
  2441. }
  2442. pData++;
  2443. }
  2444. if (pLast < pData) {
  2445. parsedExports.emplace_back(StringRef(pLast, pData - pLast));
  2446. }
  2447. return std::move(parsedExports);
  2448. }
// Create one DXR subobject of |kind| named |name| from the initializer-list
// arguments |args|/|argCount|. |hgType| applies only to HitGroup subobjects.
// Argument-evaluation failures are diagnosed inside the GetAsConstant*
// helpers, after which this function simply returns without creating the
// subobject.
void CGMSHLSLRuntime::CreateSubobject(DXIL::SubobjectKind kind, const StringRef name,
                                      clang::Expr **args, unsigned int argCount,
                                      DXIL::HitGroupType hgType /*= (DXIL::HitGroupType)(-1)*/) {
  // Lazily create the module's subobject container.
  DxilSubobjects *subobjects = m_pHLModule->GetSubobjects();
  if (!subobjects) {
    subobjects = new DxilSubobjects();
    m_pHLModule->ResetSubobjects(subobjects);
  }
  // Default to global root signature; the LocalRootSignature case overrides
  // this before falling through to the shared compile path.
  DxilRootSignatureCompilationFlags flags = DxilRootSignatureCompilationFlags::GlobalRootSignature;
  switch (kind) {
  case DXIL::SubobjectKind::StateObjectConfig: {
    // NOTE: this local intentionally shadows the root-signature `flags`
    // above; it holds the state-object config bitmask instead.
    uint32_t flags;
    DXASSERT_NOMSG(argCount == 1);
    if (GetAsConstantUInt32(args[0], &flags)) {
      subobjects->CreateStateObjectConfig(name, flags);
    }
    break;
  }
  case DXIL::SubobjectKind::LocalRootSignature:
    flags = DxilRootSignatureCompilationFlags::LocalRootSignature;
    __fallthrough; // shares the compile-and-create path with the global case
  case DXIL::SubobjectKind::GlobalRootSignature: {
    DXASSERT_NOMSG(argCount == 1);
    // The single argument is the root signature text; it must be non-empty.
    StringRef signature;
    if (!GetAsConstantString(args[0], &signature, true))
      return;
    RootSignatureHandle RootSigHandle;
    CompileRootSignature(signature, CGM.getDiags(), args[0]->getLocStart(), rootSigVer, flags, &RootSigHandle);
    if (!RootSigHandle.IsEmpty()) {
      RootSigHandle.EnsureSerializedAvailable();
      subobjects->CreateRootSignature(name, kind == DXIL::SubobjectKind::LocalRootSignature,
          RootSigHandle.GetSerializedBytes(), RootSigHandle.GetSerializedSize(), &signature);
    }
    break;
  }
  case DXIL::SubobjectKind::SubobjectToExportsAssociation: {
    DXASSERT_NOMSG(argCount == 2);
    // args: (subobject name — required non-empty, export list — may be empty).
    StringRef subObjName, exports;
    if (!GetAsConstantString(args[0], &subObjName, true) ||
        !GetAsConstantString(args[1], &exports, false))
      return;
    std::vector<StringRef> exportList = ParseSubobjectExportsAssociations(exports);
    subobjects->CreateSubobjectToExportsAssociation(name, subObjName, exportList.data(), exportList.size());
    break;
  }
  case DXIL::SubobjectKind::RaytracingShaderConfig: {
    DXASSERT_NOMSG(argCount == 2);
    uint32_t maxPayloadSize;
    uint32_t MaxAttributeSize;
    if (!GetAsConstantUInt32(args[0], &maxPayloadSize) ||
        !GetAsConstantUInt32(args[1], &MaxAttributeSize))
      return;
    subobjects->CreateRaytracingShaderConfig(name, maxPayloadSize, MaxAttributeSize);
    break;
  }
  case DXIL::SubobjectKind::RaytracingPipelineConfig: {
    DXASSERT_NOMSG(argCount == 1);
    uint32_t maxTraceRecursionDepth;
    if (!GetAsConstantUInt32(args[0], &maxTraceRecursionDepth))
      return;
    subobjects->CreateRaytracingPipelineConfig(name, maxTraceRecursionDepth);
    break;
  }
  case DXIL::SubobjectKind::HitGroup: {
    switch (hgType) {
    case DXIL::HitGroupType::Triangle: {
      // args: (anyhit, closesthit) — either may be empty; no intersection.
      DXASSERT_NOMSG(argCount == 2);
      StringRef anyhit, closesthit;
      if (!GetAsConstantString(args[0], &anyhit) ||
          !GetAsConstantString(args[1], &closesthit))
        return;
      subobjects->CreateHitGroup(name, DXIL::HitGroupType::Triangle, anyhit, closesthit, llvm::StringRef(""));
      break;
    }
    case DXIL::HitGroupType::ProceduralPrimitive: {
      // args: (anyhit, closesthit, intersection) — intersection is required.
      DXASSERT_NOMSG(argCount == 3);
      StringRef anyhit, closesthit, intersection;
      if (!GetAsConstantString(args[0], &anyhit) ||
          !GetAsConstantString(args[1], &closesthit) ||
          !GetAsConstantString(args[2], &intersection, true))
        return;
      subobjects->CreateHitGroup(name, DXIL::HitGroupType::ProceduralPrimitive, anyhit, closesthit, intersection);
      break;
    }
    default:
      llvm_unreachable("unknown HitGroupType");
    }
    break;
  }
  case DXIL::SubobjectKind::RaytracingPipelineConfig1: {
    DXASSERT_NOMSG(argCount == 2);
    uint32_t maxTraceRecursionDepth;
    uint32_t raytracingPipelineFlags;
    if (!GetAsConstantUInt32(args[0], &maxTraceRecursionDepth))
      return;
    if (!GetAsConstantUInt32(args[1], &raytracingPipelineFlags))
      return;
    subobjects->CreateRaytracingPipelineConfig1(name, maxTraceRecursionDepth, raytracingPipelineFlags);
    break;
  }
  default:
    llvm_unreachable("unknown SubobjectKind");
    break;
  }
}
  2554. static void CollectScalarTypes(std::vector<QualType> &ScalarTys, QualType Ty) {
  2555. if (Ty->isRecordType()) {
  2556. if (hlsl::IsHLSLMatType(Ty)) {
  2557. QualType EltTy = hlsl::GetHLSLMatElementType(Ty);
  2558. unsigned row = 0;
  2559. unsigned col = 0;
  2560. hlsl::GetRowsAndCols(Ty, row, col);
  2561. unsigned size = col*row;
  2562. for (unsigned i = 0; i < size; i++) {
  2563. CollectScalarTypes(ScalarTys, EltTy);
  2564. }
  2565. } else if (hlsl::IsHLSLVecType(Ty)) {
  2566. QualType EltTy = hlsl::GetHLSLVecElementType(Ty);
  2567. unsigned row = 0;
  2568. unsigned col = 0;
  2569. hlsl::GetRowsAndColsForAny(Ty, row, col);
  2570. unsigned size = col;
  2571. for (unsigned i = 0; i < size; i++) {
  2572. CollectScalarTypes(ScalarTys, EltTy);
  2573. }
  2574. } else {
  2575. const RecordType *RT = Ty->getAsStructureType();
  2576. // For CXXRecord.
  2577. if (!RT)
  2578. RT = Ty->getAs<RecordType>();
  2579. RecordDecl *RD = RT->getDecl();
  2580. for (FieldDecl *field : RD->fields())
  2581. CollectScalarTypes(ScalarTys, field->getType());
  2582. }
  2583. } else if (Ty->isArrayType()) {
  2584. const clang::ArrayType *AT = Ty->getAsArrayTypeUnsafe();
  2585. QualType EltTy = AT->getElementType();
  2586. // Set it to 5 for unsized array.
  2587. unsigned size = 5;
  2588. if (AT->isConstantArrayType()) {
  2589. size = cast<ConstantArrayType>(AT)->getSize().getLimitedValue();
  2590. }
  2591. for (unsigned i=0;i<size;i++) {
  2592. CollectScalarTypes(ScalarTys, EltTy);
  2593. }
  2594. } else {
  2595. ScalarTys.emplace_back(Ty);
  2596. }
  2597. }
// Fill in |hlslRes| from the HLSL resource type |QualTy|: resource kind,
// component type, sample count (MS textures), element stride (buffers),
// ROV flag, RW flag and resource ID. Returns false after emitting a
// diagnostic when the declaration is invalid.
bool CGMSHLSLRuntime::SetUAVSRV(SourceLocation loc,
                                hlsl::DxilResourceBase::Class resClass,
                                DxilResource *hlslRes, QualType QualTy) {
  RecordDecl *RD = QualTy->getAs<RecordType>()->getDecl();
  // Resource kind is derived from the record's keyword name.
  hlsl::DxilResource::Kind kind = KeywordToKind(RD->getName());
  DXASSERT_NOMSG(kind != hlsl::DxilResource::Kind::Invalid);
  hlslRes->SetKind(kind);
  // Type annotation for result type of resource.
  DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  unsigned arrayEltSize = 0;
  AddTypeAnnotation(QualType(RD->getTypeForDecl(),0), dxilTypeSys, arrayEltSize);
  // Texture2DMS<T, N>: the second template argument is the sample count.
  if (kind == hlsl::DxilResource::Kind::Texture2DMS ||
      kind == hlsl::DxilResource::Kind::Texture2DMSArray) {
    const ClassTemplateSpecializationDecl *templateDecl =
        cast<ClassTemplateSpecializationDecl>(RD);
    const clang::TemplateArgument &sampleCountArg =
        templateDecl->getTemplateArgs()[1];
    uint32_t sampleCount = sampleCountArg.getAsIntegral().getLimitedValue();
    hlslRes->SetSampleCount(sampleCount);
  }
  // Textures only accept scalar or vector texel types.
  if (hlsl::DxilResource::IsAnyTexture(kind)) {
    const ClassTemplateSpecializationDecl *templateDecl = cast<ClassTemplateSpecializationDecl>(RD);
    const clang::TemplateArgument &texelTyArg = templateDecl->getTemplateArgs()[0];
    llvm::Type *texelTy = CGM.getTypes().ConvertType(texelTyArg.getAsType());
    if (!texelTy->isFloatingPointTy() && !texelTy->isIntegerTy()
        && !hlsl::IsHLSLVecType(texelTyArg.getAsType())) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
          "texture resource texel type must be scalar or vector");
      Diags.Report(loc, DiagID);
      return false;
    }
  }
  // Derive the component type for typed (non-structured) resources.
  QualType resultTy = hlsl::GetHLSLResourceResultType(QualTy);
  if (kind != hlsl::DxilResource::Kind::StructuredBuffer && !resultTy.isNull()) {
    QualType Ty = resultTy;
    QualType EltTy = Ty;
    if (hlsl::IsHLSLVecType(Ty)) {
      EltTy = hlsl::GetHLSLVecElementType(Ty);
    } else if (hlsl::IsHLSLMatType(Ty)) {
      EltTy = hlsl::GetHLSLMatElementType(Ty);
    } else if (hlsl::IsHLSLAggregateType(resultTy)) {
      // Struct or array in a none-struct resource.
      // Flatten to scalars: must be 1..4 elements, all of the same type.
      std::vector<QualType> ScalarTys;
      CollectScalarTypes(ScalarTys, resultTy);
      unsigned size = ScalarTys.size();
      if (size == 0) {
        DiagnosticsEngine &Diags = CGM.getDiags();
        unsigned DiagID = Diags.getCustomDiagID(
            DiagnosticsEngine::Error,
            "object's templated type must have at least one element");
        Diags.Report(loc, DiagID);
        return false;
      }
      if (size > 4) {
        DiagnosticsEngine &Diags = CGM.getDiags();
        unsigned DiagID = Diags.getCustomDiagID(
            DiagnosticsEngine::Error, "elements of typed buffers and textures "
                                      "must fit in four 32-bit quantities");
        Diags.Report(loc, DiagID);
        return false;
      }
      EltTy = ScalarTys[0];
      for (QualType ScalarTy : ScalarTys) {
        if (ScalarTy != EltTy) {
          DiagnosticsEngine &Diags = CGM.getDiags();
          unsigned DiagID = Diags.getCustomDiagID(
              DiagnosticsEngine::Error,
              "all template type components must have the same type");
          Diags.Report(loc, DiagID);
          return false;
        }
      }
    }
    bool bSNorm = false;
    bool bHasNormAttribute = hlsl::HasHLSLUNormSNorm(Ty, &bSNorm);
    if (const BuiltinType *BTy = EltTy->getAs<BuiltinType>()) {
      CompType::Kind kind = BuiltinTyToCompTy(BTy, bHasNormAttribute && bSNorm, bHasNormAttribute && !bSNorm);
      // 64bits types are implemented with u32.
      if (kind == CompType::Kind::U64 || kind == CompType::Kind::I64 ||
          kind == CompType::Kind::SNormF64 ||
          kind == CompType::Kind::UNormF64 || kind == CompType::Kind::F64) {
        kind = CompType::Kind::U32;
      }
      hlslRes->SetCompType(kind);
    } else {
      DXASSERT(!bHasNormAttribute, "snorm/unorm on invalid type");
    }
  }
  // FeedbackTexture2D<Kind>: the template UInt selects the feedback type.
  if (hlslRes->IsFeedbackTexture()) {
    hlslRes->SetSamplerFeedbackType(
        static_cast<DXIL::SamplerFeedbackType>(hlsl::GetHLSLResourceTemplateUInt(QualTy)));
  }
  // Rasterizer-ordered views are identified purely by keyword prefix.
  hlslRes->SetROV(RD->getName().startswith("RasterizerOrdered"));
  // Typed/structured buffers record their element stride in bytes.
  if (kind == hlsl::DxilResource::Kind::TypedBuffer ||
      kind == hlsl::DxilResource::Kind::StructuredBuffer) {
    const ClassTemplateSpecializationDecl *templateDecl =
        cast<ClassTemplateSpecializationDecl>(RD);
    const clang::TemplateArgument &retTyArg =
        templateDecl->getTemplateArgs()[0];
    llvm::Type *retTy = CGM.getTypes().ConvertType(retTyArg.getAsType());
    uint32_t strideInBytes = dataLayout.getTypeAllocSize(retTy);
    hlslRes->SetElementStride(strideInBytes);
  }
  if (resClass == hlsl::DxilResourceBase::Class::SRV) {
    // globallycoherent is meaningless on a read-only view.
    if (hlslRes->IsGloballyCoherent()) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID = Diags.getCustomDiagID(
          DiagnosticsEngine::Error, "globallycoherent can only be used with "
                                    "Unordered Access View buffers.");
      Diags.Report(loc, DiagID);
      return false;
    }
    hlslRes->SetRW(false);
    hlslRes->SetID(m_pHLModule->GetSRVs().size());
  } else {
    hlslRes->SetRW(true);
    hlslRes->SetID(m_pHLModule->GetUAVs().size());
  }
  return true;
}
  2719. uint32_t CGMSHLSLRuntime::AddUAVSRV(VarDecl *decl,
  2720. hlsl::DxilResourceBase::Class resClass) {
  2721. llvm::GlobalVariable *val =
  2722. cast<llvm::GlobalVariable>(CGM.GetAddrOfGlobalVar(decl));
  2723. unique_ptr<HLResource> hlslRes(new HLResource);
  2724. hlslRes->SetLowerBound(UINT_MAX);
  2725. hlslRes->SetSpaceID(UINT_MAX);
  2726. hlslRes->SetGlobalSymbol(val);
  2727. hlslRes->SetGlobalName(decl->getName());
  2728. QualType VarTy;
  2729. unsigned rangeSize;
  2730. GetResourceDeclElemTypeAndRangeSize(CGM, *m_pHLModule, *decl,
  2731. VarTy, rangeSize);
  2732. hlslRes->SetRangeSize(rangeSize);
  2733. InitFromUnusualAnnotations(*hlslRes, *decl);
  2734. if (decl->hasAttr<HLSLGloballyCoherentAttr>()) {
  2735. hlslRes->SetGloballyCoherent(true);
  2736. }
  2737. if (!SetUAVSRV(decl->getLocation(), resClass, hlslRes.get(), VarTy))
  2738. return 0;
  2739. if (resClass == hlsl::DxilResourceBase::Class::SRV) {
  2740. return m_pHLModule->AddSRV(std::move(hlslRes));
  2741. } else {
  2742. return m_pHLModule->AddUAV(std::move(hlslRes));
  2743. }
  2744. }
  2745. static bool IsResourceInType(const clang::ASTContext &context,
  2746. clang::QualType Ty) {
  2747. Ty = Ty.getCanonicalType();
  2748. if (const clang::ArrayType *arrayType = context.getAsArrayType(Ty)) {
  2749. return IsResourceInType(context, arrayType->getElementType());
  2750. } else if (const RecordType *RT = Ty->getAsStructureType()) {
  2751. if (KeywordToClass(RT->getDecl()->getName()) != DxilResourceBase::Class::Invalid)
  2752. return true;
  2753. const CXXRecordDecl* typeRecordDecl = RT->getAsCXXRecordDecl();
  2754. if (typeRecordDecl && !typeRecordDecl->isImplicit()) {
  2755. for (auto field : typeRecordDecl->fields()) {
  2756. if (IsResourceInType(context, field->getType()))
  2757. return true;
  2758. }
  2759. }
  2760. } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
  2761. if (const ClassTemplateSpecializationDecl *templateDecl =
  2762. dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl())) {
  2763. if (KeywordToClass(templateDecl->getName()) != DxilResourceBase::Class::Invalid)
  2764. return true;
  2765. }
  2766. }
  2767. return false; // no resources found
  2768. }
// Add one constant declaration to cbuffer |CB|. Statics inside a cbuffer are
// emitted as ordinary globals instead. Resource-typed members are rejected
// for buffer view arrays. packoffset / register(c#) annotations become a
// user-specified byte offset; t/u/s register annotations are recorded for
// later resource binding. The constant is modeled as an Invalid-class
// DxilResourceBase whose "lower bound" is its byte offset within the buffer.
void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
  if (constDecl->getStorageClass() == SC_Static) {
    // For static inside cbuffer, take as global static.
    // Don't add to cbuffer.
    CGM.EmitGlobal(constDecl);
    // Add type annotation for static global types.
    // May need it when cast from cbuf.
    DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
    unsigned arraySize = 0;
    AddTypeAnnotation(constDecl->getType(), dxilTypeSys, arraySize);
    return;
  }
  // Search defined structure for resource objects and fail
  if (CB.GetRangeSize() > 1 &&
      IsResourceInType(CGM.getContext(), constDecl->getType())) {
    DiagnosticsEngine &Diags = CGM.getDiags();
    unsigned DiagID = Diags.getCustomDiagID(
        DiagnosticsEngine::Error,
        "object types not supported in cbuffer/tbuffer view arrays.");
    Diags.Report(constDecl->getLocation(), DiagID);
    return;
  }
  llvm::Constant *constVal = CGM.GetAddrOfGlobalVar(constDecl);
  // Per-constant list of explicit t/u/s register bindings, filled below.
  auto &regBindings = constantRegBindingMap[constVal];

  bool isGlobalCB = CB.GetID() == globalCBIndex;
  uint32_t offset = 0;
  bool userOffset = false;
  for (hlsl::UnusualAnnotation *it : constDecl->getUnusualAnnotations()) {
    switch (it->getKind()) {
    case hlsl::UnusualAnnotation::UA_ConstantPacking: {
      // packoffset is only valid inside a named cbuffer, not on $Globals.
      if (!isGlobalCB) {
        // TODO: check cannot mix packoffset elements with nonpackoffset
        // elements in a cbuffer.
        hlsl::ConstantPacking *cp = cast<hlsl::ConstantPacking>(it);
        // Offset in 32-bit components: register index * 4 + xyzw component.
        offset = cp->Subcomponent << 2;
        offset += cp->ComponentOffset;
        // Change to byte.
        offset <<= 2;
        userOffset = true;
      } else {
        DiagnosticsEngine &Diags = CGM.getDiags();
        unsigned DiagID = Diags.getCustomDiagID(
            DiagnosticsEngine::Error,
            "packoffset is only allowed in a constant buffer.");
        Diags.Report(it->Loc, DiagID);
      }
      break;
    }
    case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
      RegisterAssignment *ra = cast<RegisterAssignment>(it);
      if (isGlobalCB) {
        if (ra->RegisterSpace.hasValue()) {
          DiagnosticsEngine& Diags = CGM.getDiags();
          unsigned DiagID = Diags.getCustomDiagID(
              DiagnosticsEngine::Error,
              "register space cannot be specified on global constants.");
          Diags.Report(it->Loc, DiagID);
        }
        // register(c#) on a $Globals constant: byte offset = # * 16.
        offset = ra->RegisterNumber << 2;
        // Change to byte.
        offset <<= 2;
        userOffset = true;
      }
      // Record explicit resource-register bindings regardless of which
      // buffer the constant lives in.
      switch (ra->RegisterType) {
      default:
        break;
      case 't':
        regBindings.emplace_back(
            std::make_pair(DXIL::ResourceClass::SRV, ra->RegisterNumber));
        break;
      case 'u':
        regBindings.emplace_back(
            std::make_pair(DXIL::ResourceClass::UAV, ra->RegisterNumber));
        break;
      case 's':
        regBindings.emplace_back(
            std::make_pair(DXIL::ResourceClass::Sampler, ra->RegisterNumber));
        break;
      }
      break;
    }
    case hlsl::UnusualAnnotation::UA_SemanticDecl:
      // skip semantic on constant
      break;
    }
  }

  std::unique_ptr<DxilResourceBase> pHlslConst = llvm::make_unique<DxilResourceBase>(DXIL::ResourceClass::Invalid);
  // UINT_MAX lower bound = offset to be allocated later.
  pHlslConst->SetLowerBound(UINT_MAX);
  pHlslConst->SetSpaceID(0);
  pHlslConst->SetGlobalSymbol(cast<llvm::GlobalVariable>(constVal));
  pHlslConst->SetGlobalName(constDecl->getName());
  if (userOffset) {
    pHlslConst->SetLowerBound(offset);
  }

  DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  // Just add type annotation here.
  // Offset will be allocated later.
  QualType Ty = constDecl->getType();
  if (CB.GetRangeSize() != 1) {
    // For buffer view arrays, annotate the array element type instead.
    while (Ty->isArrayType()) {
      Ty = Ty->getAsArrayTypeUnsafe()->getElementType();
    }
  }
  unsigned arrayEltSize = 0;
  unsigned size = AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
  pHlslConst->SetRangeSize(size);

  CB.AddConst(pHlslConst);

  // Save fieldAnnotation for the const var.
  DxilFieldAnnotation fieldAnnotation;
  if (userOffset)
    fieldAnnotation.SetCBufferOffset(offset);

  // Get the nested element type.
  if (Ty->isArrayType()) {
    while (const ConstantArrayType *arrayTy =
               CGM.getContext().getAsConstantArrayType(Ty)) {
      Ty = arrayTy->getElementType();
    }
  }
  bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  ConstructFieldAttributedAnnotation(fieldAnnotation, Ty, bDefaultRowMajor);
  m_ConstVarAnnotationMap[constVal] = fieldAnnotation;
}
// Create an HLCBuffer for a cbuffer/tbuffer declaration (or a
// ConstantBuffer<T>/TextureBuffer<T> view), add its member constants, and
// register it with the HL module. Returns the new buffer's ID.
uint32_t CGMSHLSLRuntime::AddCBuffer(HLSLBufferDecl *D) {
  unique_ptr<HLCBuffer> CB = llvm::make_unique<HLCBuffer>();

  // setup the CB
  CB->SetGlobalSymbol(nullptr);
  CB->SetGlobalName(D->getNameAsString());
  // UINT_MAX = binding not assigned yet; annotations below may set them.
  CB->SetSpaceID(UINT_MAX);
  CB->SetLowerBound(UINT_MAX);
  if (!D->isCBuffer()) {
    CB->SetKind(DXIL::ResourceKind::TBuffer);
  }

  // the global variable will only used once by the createHandle?
  // SetHandle(llvm::Value *pHandle);

  InitFromUnusualAnnotations(*CB, *D);
  // Add constant
  if (D->isConstantBufferView()) {
    // Buffer view: a single declaration whose array dimensions become the
    // buffer's range size (UINT_MAX for an unbounded array).
    VarDecl *constDecl = cast<VarDecl>(*D->decls_begin());
    CB->SetRangeSize(1);
    QualType Ty = constDecl->getType();
    if (Ty->isArrayType()) {
      if (!Ty->isIncompleteArrayType()) {
        // Flatten multi-dimensional constant arrays into one range size.
        unsigned arraySize = 1;
        while (Ty->isArrayType()) {
          Ty = Ty->getCanonicalTypeUnqualified();
          const ConstantArrayType *AT = cast<ConstantArrayType>(Ty);
          arraySize *= AT->getSize().getLimitedValue();
          Ty = AT->getElementType();
        }
        CB->SetRangeSize(arraySize);
      } else {
        CB->SetRangeSize(UINT_MAX);
      }
    }
    AddConstant(constDecl, *CB.get());
  } else {
    // Regular cbuffer/tbuffer: walk every member declaration.
    auto declsEnds = D->decls_end();
    CB->SetRangeSize(1);
    for (auto it = D->decls_begin(); it != declsEnds; it++) {
      if (VarDecl *constDecl = dyn_cast<VarDecl>(*it)) {
        AddConstant(constDecl, *CB.get());
      } else if (isa<EmptyDecl>(*it)) {
        // Nothing to do for this declaration.
      } else if (isa<CXXRecordDecl>(*it)) {
        // Nothing to do for this declaration.
      } else if (isa<FunctionDecl>(*it)) {
        // A function within an cbuffer is effectively a top-level function,
        // as it only refers to globally scoped declarations.
        this->CGM.EmitTopLevelDecl(*it);
      } else {
        // Remaining case: a nested cbuffer declaration.
        HLSLBufferDecl *inner = cast<HLSLBufferDecl>(*it);
        GetOrCreateCBuffer(inner);
      }
    }
  }

  CB->SetID(m_pHLModule->GetCBuffers().size());
  return m_pHLModule->AddCBuffer(std::move(CB));
}
  2947. HLCBuffer &CGMSHLSLRuntime::GetOrCreateCBuffer(HLSLBufferDecl *D) {
  2948. if (constantBufMap.count(D) != 0) {
  2949. uint32_t cbIndex = constantBufMap[D];
  2950. return *static_cast<HLCBuffer*>(&(m_pHLModule->GetCBuffer(cbIndex)));
  2951. }
  2952. uint32_t cbID = AddCBuffer(D);
  2953. constantBufMap[D] = cbID;
  2954. return *static_cast<HLCBuffer*>(&(m_pHLModule->GetCBuffer(cbID)));
  2955. }
  2956. bool CGMSHLSLRuntime::IsPatchConstantFunction(const Function *F) {
  2957. DXASSERT_NOMSG(F != nullptr);
  2958. for (auto && p : patchConstantFunctionMap) {
  2959. if (p.second.Func == F) return true;
  2960. }
  2961. return false;
  2962. }
  2963. void CGMSHLSLRuntime::SetEntryFunction() {
  2964. if (Entry.Func == nullptr) {
  2965. DiagnosticsEngine &Diags = CGM.getDiags();
  2966. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  2967. "cannot find entry function %0");
  2968. Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLEntryFunction;
  2969. return;
  2970. }
  2971. m_pHLModule->SetEntryFunction(Entry.Func);
  2972. }
  2973. // Here the size is CB size.
  2974. // Offset still needs to be aligned based on type since this
  2975. // is the legacy cbuffer global path.
  2976. static unsigned AlignCBufferOffset(unsigned offset, unsigned size, llvm::Type *Ty, bool bRowMajor) {
  2977. DXASSERT(!(offset & 1), "otherwise we have an invalid offset.");
  2978. bool bNeedNewRow = Ty->isArrayTy();
  2979. if (!bNeedNewRow && Ty->isStructTy()) {
  2980. if (HLMatrixType mat = HLMatrixType::dyn_cast(Ty)) {
  2981. bNeedNewRow |= !bRowMajor && mat.getNumColumns() > 1;
  2982. bNeedNewRow |= bRowMajor && mat.getNumRows() > 1;
  2983. } else {
  2984. bNeedNewRow = true;
  2985. }
  2986. }
  2987. unsigned scalarSizeInBytes = Ty->getScalarSizeInBits() / 8;
  2988. return AlignBufferOffsetInLegacy(offset, size, scalarSizeInBytes, bNeedNewRow);
  2989. }
  2990. static unsigned AllocateDxilConstantBuffer(HLCBuffer &CB,
  2991. std::unordered_map<Constant*, DxilFieldAnnotation> &constVarAnnotationMap) {
  2992. unsigned offset = 0;
  2993. // Scan user allocated constants first.
  2994. // Update offset.
  2995. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2996. if (C->GetLowerBound() == UINT_MAX)
  2997. continue;
  2998. unsigned size = C->GetRangeSize();
  2999. unsigned nextOffset = size + C->GetLowerBound();
  3000. if (offset < nextOffset)
  3001. offset = nextOffset;
  3002. }
  3003. // Alloc after user allocated constants.
  3004. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  3005. if (C->GetLowerBound() != UINT_MAX)
  3006. continue;
  3007. unsigned size = C->GetRangeSize();
  3008. llvm::Type *Ty = C->GetGlobalSymbol()->getType()->getPointerElementType();
  3009. auto fieldAnnotation = constVarAnnotationMap.at(C->GetGlobalSymbol());
  3010. bool bRowMajor = HLMatrixType::isa(Ty)
  3011. ? fieldAnnotation.GetMatrixAnnotation().Orientation == MatrixOrientation::RowMajor
  3012. : false;
  3013. // Align offset.
  3014. offset = AlignCBufferOffset(offset, size, Ty, bRowMajor);
  3015. if (C->GetLowerBound() == UINT_MAX) {
  3016. C->SetLowerBound(offset);
  3017. }
  3018. offset += size;
  3019. }
  3020. return offset;
  3021. }
  3022. static void AddRegBindingsForResourceInConstantBuffer(
  3023. HLModule *pHLModule,
  3024. llvm::DenseMap<llvm::Constant *,
  3025. llvm::SmallVector<std::pair<DXIL::ResourceClass, unsigned>,
  3026. 1>> &constantRegBindingMap) {
  3027. for (unsigned i = 0; i < pHLModule->GetCBuffers().size(); i++) {
  3028. HLCBuffer &CB = *static_cast<HLCBuffer *>(&(pHLModule->GetCBuffer(i)));
  3029. auto &Constants = CB.GetConstants();
  3030. for (unsigned j = 0; j < Constants.size(); j++) {
  3031. const std::unique_ptr<DxilResourceBase> &C = Constants[j];
  3032. Constant *CGV = C->GetGlobalSymbol();
  3033. auto &regBindings = constantRegBindingMap[CGV];
  3034. if (regBindings.empty())
  3035. continue;
  3036. unsigned Srv = UINT_MAX;
  3037. unsigned Uav = UINT_MAX;
  3038. unsigned Sampler = UINT_MAX;
  3039. for (auto it : regBindings) {
  3040. unsigned RegNum = it.second;
  3041. switch (it.first) {
  3042. case DXIL::ResourceClass::SRV:
  3043. Srv = RegNum;
  3044. break;
  3045. case DXIL::ResourceClass::UAV:
  3046. Uav = RegNum;
  3047. break;
  3048. case DXIL::ResourceClass::Sampler:
  3049. Sampler = RegNum;
  3050. break;
  3051. default:
  3052. DXASSERT(0, "invalid resource class");
  3053. break;
  3054. }
  3055. }
  3056. pHLModule->AddRegBinding(CB.GetID(), j, Srv, Uav, Sampler);
  3057. }
  3058. }
  3059. }
  3060. static void AllocateDxilConstantBuffers(HLModule *pHLModule,
  3061. std::unordered_map<Constant*, DxilFieldAnnotation> &constVarAnnotationMap) {
  3062. for (unsigned i = 0; i < pHLModule->GetCBuffers().size(); i++) {
  3063. HLCBuffer &CB = *static_cast<HLCBuffer*>(&(pHLModule->GetCBuffer(i)));
  3064. unsigned size = AllocateDxilConstantBuffer(CB, constVarAnnotationMap);
  3065. CB.SetSize(size);
  3066. }
  3067. }
// Replace every use of V that lives (directly, or through constant
// GEP/bitcast operators) inside function F with NewV.  Uses in other
// functions are untouched.  Builder supplies the insertion point for any
// instruction clones of constant operators.
static void ReplaceUseInFunction(Value *V, Value *NewV, Function *F,
    IRBuilder<> &Builder) {
  for (auto U = V->user_begin(); U != V->user_end(); ) {
    // Advance the iterator before mutating: setOperand below removes this
    // use from V's use list.
    User *user = *(U++);
    if (Instruction *I = dyn_cast<Instruction>(user)) {
      if (I->getParent()->getParent() == F) {
        // replace use with GEP if in F
        for (unsigned i = 0; i < I->getNumOperands(); i++) {
          if (I->getOperand(i) == V)
            I->setOperand(i, NewV);
        }
      }
    } else {
      // For constant operator, create local clone which use GEP.
      // Only support GEP and bitcast.
      if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(user)) {
        // Rebuild the GEP on NewV with the same indices, then recurse so the
        // GEP operator's own users get rewritten to the instruction clone.
        std::vector<Value *> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
        Value *NewGEP = Builder.CreateInBoundsGEP(NewV, idxList);
        ReplaceUseInFunction(GEPOp, NewGEP, F, Builder);
      } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(user)) {
        // Change the init val into NewV with Store.
        GV->setInitializer(nullptr);
        Builder.CreateStore(NewV, GV);
      } else {
        // Must be bitcast here.
        BitCastOperator *BC = cast<BitCastOperator>(user);
        Value *NewBC = Builder.CreateBitCast(NewV, BC->getType());
        ReplaceUseInFunction(BC, NewBC, F, Builder);
      }
    }
  }
}
  3100. void MarkUsedFunctionForConst(Value *V, std::unordered_set<Function*> &usedFunc) {
  3101. for (auto U = V->user_begin(); U != V->user_end();) {
  3102. User *user = *(U++);
  3103. if (Instruction *I = dyn_cast<Instruction>(user)) {
  3104. Function *F = I->getParent()->getParent();
  3105. usedFunc.insert(F);
  3106. } else {
  3107. // For constant operator, create local clone which use GEP.
  3108. // Only support GEP and bitcast.
  3109. if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(user)) {
  3110. MarkUsedFunctionForConst(GEPOp, usedFunc);
  3111. } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(user)) {
  3112. MarkUsedFunctionForConst(GV, usedFunc);
  3113. } else {
  3114. // Must be bitcast here.
  3115. BitCastOperator *BC = cast<BitCastOperator>(user);
  3116. MarkUsedFunctionForConst(BC, usedFunc);
  3117. }
  3118. }
  3119. }
  3120. }
  3121. static Function * GetOrCreateHLCreateHandle(HLModule &HLM, llvm::Type *HandleTy,
  3122. ArrayRef<Value*> paramList, MDNode *MD) {
  3123. SmallVector<llvm::Type *, 4> paramTyList;
  3124. for (Value *param : paramList) {
  3125. paramTyList.emplace_back(param->getType());
  3126. }
  3127. llvm::FunctionType *funcTy =
  3128. llvm::FunctionType::get(HandleTy, paramTyList, false);
  3129. llvm::Module &M = *HLM.GetModule();
  3130. Function *CreateHandle = GetOrCreateHLFunctionWithBody(M, funcTy, HLOpcodeGroup::HLCreateHandle,
  3131. /*opcode*/ 0, "");
  3132. if (CreateHandle->empty()) {
  3133. // Add body.
  3134. BasicBlock *BB =
  3135. BasicBlock::Create(CreateHandle->getContext(), "Entry", CreateHandle);
  3136. IRBuilder<> Builder(BB);
  3137. // Just return undef to make a body.
  3138. Builder.CreateRet(UndefValue::get(HandleTy));
  3139. // Mark resource attribute.
  3140. HLM.MarkDxilResourceAttrib(CreateHandle, MD);
  3141. }
  3142. return CreateHandle;
  3143. }
// Create the global variable and HL subscript plumbing for one cbuffer, and
// rewrite every use of the cbuffer's member globals into GEPs off a
// CBufferSubscript call.  Returns false (creating nothing) when no member of
// the cbuffer has any use.
static bool CreateCBufferVariable(HLCBuffer &CB,
    HLModule &HLM, llvm::Type *HandleTy) {
  bool bUsed = false;
  // Build Struct for CBuffer.
  SmallVector<llvm::Type*, 4> Elements;
  for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
    Value *GV = C->GetGlobalSymbol();
    if (GV->hasNUsesOrMore(1))
      bUsed = true;
    // Global variable must be pointer type.
    llvm::Type *Ty = GV->getType()->getPointerElementType();
    Elements.emplace_back(Ty);
  }
  // Don't create CBuffer variable for unused cbuffer.
  if (!bUsed)
    return false;
  llvm::Module &M = *HLM.GetModule();
  // RangeSize != 1 marks an array of ConstantBuffer (a resource range).
  bool isCBArray = CB.GetRangeSize() != 1;
  llvm::GlobalVariable *cbGV = nullptr;
  llvm::Type *cbTy = nullptr;
  // Number of array dimensions to consume from user GEPs in the array case.
  unsigned cbIndexDepth = 0;
  if (!isCBArray) {
    // Normal cbuffer: one struct with a field per constant.
    llvm::StructType *CBStructTy =
        llvm::StructType::create(Elements, CB.GetGlobalName());
    cbGV = new llvm::GlobalVariable(M, CBStructTy, /*IsConstant*/ true,
                                    llvm::GlobalValue::ExternalLinkage,
                                    /*InitVal*/ nullptr, CB.GetGlobalName());
    cbTy = cbGV->getType();
  } else {
    // For array of ConstantBuffer, create array of struct instead of struct of
    // array.
    DXASSERT(CB.GetConstants().size() == 1,
             "ConstantBuffer should have 1 constant");
    Value *GV = CB.GetConstants()[0]->GetGlobalSymbol();
    llvm::Type *CBEltTy =
        GV->getType()->getPointerElementType()->getArrayElementType();
    cbIndexDepth = 1;
    // Peel nested array dimensions to find the element type and depth.
    while (CBEltTy->isArrayTy()) {
      CBEltTy = CBEltTy->getArrayElementType();
      cbIndexDepth++;
    }
    // Add one level struct type to match normal case.
    llvm::StructType *CBStructTy =
        llvm::StructType::create({CBEltTy}, CB.GetGlobalName());
    llvm::ArrayType *CBArrayTy =
        llvm::ArrayType::get(CBStructTy, CB.GetRangeSize());
    cbGV = new llvm::GlobalVariable(M, CBArrayTy, /*IsConstant*/ true,
                                    llvm::GlobalValue::ExternalLinkage,
                                    /*InitVal*/ nullptr, CB.GetGlobalName());
    // Subscript returns a pointer to one element struct, not the whole array.
    cbTy = llvm::PointerType::get(CBStructTy,
                                  cbGV->getType()->getPointerAddressSpace());
  }
  CB.SetGlobalSymbol(cbGV);
  llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  llvm::Type *idxTy = opcodeTy;
  Constant *zeroIdx = ConstantInt::get(opcodeTy, 0);
  MDNode *MD = HLM.DxilCBufferToMDNode(CB);
  // {opcode, resource, index}; the index slot is patched per use below.
  Value *HandleArgs[] = { zeroIdx, cbGV, zeroIdx };
  Function *CreateHandleFunc =
      GetOrCreateHLCreateHandle(HLM, HandleTy, HandleArgs, MD);
  llvm::FunctionType *SubscriptFuncTy =
      llvm::FunctionType::get(cbTy, { opcodeTy, HandleTy, idxTy}, false);
  Function *subscriptFunc =
      GetOrCreateHLFunction(M, SubscriptFuncTy, HLOpcodeGroup::HLSubscript,
                            (unsigned)HLSubscriptOpcode::CBufferSubscript);
  Constant *opArg =
      ConstantInt::get(opcodeTy, (unsigned)HLSubscriptOpcode::CBufferSubscript);
  // {opcode, handle (filled per function), index}.
  Value *args[] = { opArg, nullptr, zeroIdx };
  llvm::LLVMContext &Context = M.getContext();
  llvm::Type *i32Ty = llvm::Type::getInt32Ty(Context);
  Value *zero = ConstantInt::get(i32Ty, (uint64_t)0);
  // Per-constant field index and the set of functions that use each constant.
  std::vector<Value *> indexArray(CB.GetConstants().size());
  std::vector<std::unordered_set<Function*>>
      constUsedFuncList(CB.GetConstants().size());
  for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
    Value *idx = ConstantInt::get(i32Ty, C->GetID());
    indexArray[C->GetID()] = idx;
    Value *GV = C->GetGlobalSymbol();
    MarkUsedFunctionForConst(GV, constUsedFuncList[C->GetID()]);
  }
  for (Function &F : M.functions()) {
    if (F.isDeclaration())
      continue;
    // Skip HL intrinsic wrapper functions; only rewrite user code.
    if (GetHLOpcodeGroupByName(&F) != HLOpcodeGroup::NotHL)
      continue;
    IRBuilder<> Builder(F.getEntryBlock().getFirstInsertionPt());
    // create HL subscript to make all the use of cbuffer start from it.
    HandleArgs[HLOperandIndex::kCreateHandleResourceOpIdx] = cbGV;
    CallInst *Handle = Builder.CreateCall(CreateHandleFunc, HandleArgs);
    args[HLOperandIndex::kSubscriptObjectOpIdx] = Handle;
    Instruction *cbSubscript =
        cast<Instruction>(Builder.CreateCall(subscriptFunc, {args}));
    // Replace constant var with GEP pGV
    for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
      Value *GV = C->GetGlobalSymbol();
      if (constUsedFuncList[C->GetID()].count(&F) == 0)
        continue;
      Value *idx = indexArray[C->GetID()];
      if (!isCBArray) {
        // Simple case: member access is subscript-result[0][fieldIdx].
        Instruction *GEP = cast<Instruction>(
            Builder.CreateInBoundsGEP(cbSubscript, {zero, idx}));
        // TODO: make sure the debug info is synced to GEP.
        // GEP->setDebugLoc(GV);
        ReplaceUseInFunction(GV, GEP, &F, Builder);
        // Delete if no use in F.
        if (GEP->user_empty())
          GEP->eraseFromParent();
      } else {
        // Array of ConstantBuffer: each GEP user indexes into the array, so
        // each use needs its own handle created with the flattened array
        // index, followed by a subscript and a rebuilt GEP.
        for (auto U = GV->user_begin(); U != GV->user_end();) {
          // Advance before rewriting; the use list is mutated below.
          User *user = *(U++);
          if (user->user_empty())
            continue;
          Instruction *I = dyn_cast<Instruction>(user);
          if (I && I->getParent()->getParent() != &F)
            continue;
          // Instruction users get a builder at their own position; constant
          // operators reuse the entry-block builder.
          IRBuilder<> *instBuilder = &Builder;
          unique_ptr<IRBuilder<>> B;
          if (I) {
            B = llvm::make_unique<IRBuilder<>>(I);
            instBuilder = B.get();
          }
          GEPOperator *GEPOp = cast<GEPOperator>(user);
          std::vector<Value *> idxList;
          DXASSERT(GEPOp->getNumIndices() >= 1 + cbIndexDepth,
                   "must indexing ConstantBuffer array");
          idxList.reserve(GEPOp->getNumIndices() - (cbIndexDepth - 1));
          gep_type_iterator GI = gep_type_begin(*GEPOp),
                            E = gep_type_end(*GEPOp);
          idxList.push_back(GI.getOperand());
          // change array index with 0 for struct index.
          idxList.push_back(zero);
          GI++;
          Value *arrayIdx = GI.getOperand();
          GI++;
          // Flatten nested array indices into a single linear index
          // (row-major style: idx = idx * dim + next).
          for (unsigned curIndex = 1; GI != E && curIndex < cbIndexDepth;
               ++GI, ++curIndex) {
            arrayIdx = instBuilder->CreateMul(
                arrayIdx, Builder.getInt32(GI->getArrayNumElements()));
            arrayIdx = instBuilder->CreateAdd(arrayIdx, GI.getOperand());
          }
          // Remaining indices address within the element struct.
          for (; GI != E; ++GI) {
            idxList.push_back(GI.getOperand());
          }
          HandleArgs[HLOperandIndex::kCreateHandleIndexOpIdx] = arrayIdx;
          CallInst *Handle =
              instBuilder->CreateCall(CreateHandleFunc, HandleArgs);
          args[HLOperandIndex::kSubscriptObjectOpIdx] = Handle;
          args[HLOperandIndex::kSubscriptIndexOpIdx] = arrayIdx;
          Instruction *cbSubscript =
              cast<Instruction>(instBuilder->CreateCall(subscriptFunc, {args}));
          Instruction *NewGEP = cast<Instruction>(
              instBuilder->CreateInBoundsGEP(cbSubscript, idxList));
          ReplaceUseInFunction(GEPOp, NewGEP, &F, *instBuilder);
        }
      }
    }
    // Delete if no use in F.
    if (cbSubscript->user_empty()) {
      cbSubscript->eraseFromParent();
      Handle->eraseFromParent();
    } else {
      // merge GEP use for cbSubscript.
      HLModule::MergeGepUse(cbSubscript);
    }
  }
  return true;
}
  3308. static void ConstructCBufferAnnotation(
  3309. HLCBuffer &CB, DxilTypeSystem &dxilTypeSys,
  3310. std::unordered_map<Constant *, DxilFieldAnnotation> &AnnotationMap) {
  3311. Value *GV = CB.GetGlobalSymbol();
  3312. llvm::StructType *CBStructTy =
  3313. dyn_cast<llvm::StructType>(GV->getType()->getPointerElementType());
  3314. if (!CBStructTy) {
  3315. // For Array of ConstantBuffer.
  3316. llvm::ArrayType *CBArrayTy =
  3317. cast<llvm::ArrayType>(GV->getType()->getPointerElementType());
  3318. CBStructTy = cast<llvm::StructType>(CBArrayTy->getArrayElementType());
  3319. }
  3320. DxilStructAnnotation *CBAnnotation =
  3321. dxilTypeSys.AddStructAnnotation(CBStructTy);
  3322. CBAnnotation->SetCBufferSize(CB.GetSize());
  3323. // Set fieldAnnotation for each constant var.
  3324. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  3325. Constant *GV = C->GetGlobalSymbol();
  3326. DxilFieldAnnotation &fieldAnnotation =
  3327. CBAnnotation->GetFieldAnnotation(C->GetID());
  3328. fieldAnnotation = AnnotationMap[GV];
  3329. // This is after CBuffer allocation.
  3330. fieldAnnotation.SetCBufferOffset(C->GetLowerBound());
  3331. fieldAnnotation.SetFieldName(C->GetGlobalName());
  3332. }
  3333. }
  3334. static void ConstructCBuffer(
  3335. HLModule *pHLModule,
  3336. llvm::Type *CBufferType,
  3337. std::unordered_map<Constant *, DxilFieldAnnotation> &AnnotationMap) {
  3338. DxilTypeSystem &dxilTypeSys = pHLModule->GetTypeSystem();
  3339. llvm::Type *HandleTy = pHLModule->GetOP()->GetHandleType();
  3340. for (unsigned i = 0; i < pHLModule->GetCBuffers().size(); i++) {
  3341. HLCBuffer &CB = *static_cast<HLCBuffer*>(&(pHLModule->GetCBuffer(i)));
  3342. if (CB.GetConstants().size() == 0) {
  3343. // Create Fake variable for cbuffer which is empty.
  3344. llvm::GlobalVariable *pGV = new llvm::GlobalVariable(
  3345. *pHLModule->GetModule(), CBufferType, true,
  3346. llvm::GlobalValue::ExternalLinkage, nullptr, CB.GetGlobalName());
  3347. CB.SetGlobalSymbol(pGV);
  3348. } else {
  3349. bool bCreated =
  3350. CreateCBufferVariable(CB, *pHLModule, HandleTy);
  3351. if (bCreated)
  3352. ConstructCBufferAnnotation(CB, dxilTypeSys, AnnotationMap);
  3353. else {
  3354. // Create Fake variable for cbuffer which is unused.
  3355. llvm::GlobalVariable *pGV = new llvm::GlobalVariable(
  3356. *pHLModule->GetModule(), CBufferType, true,
  3357. llvm::GlobalValue::ExternalLinkage, nullptr, CB.GetGlobalName());
  3358. CB.SetGlobalSymbol(pGV);
  3359. }
  3360. }
  3361. // Clear the constants which useless now.
  3362. CB.GetConstants().clear();
  3363. }
  3364. }
  3365. static void ReplaceBoolVectorSubscript(CallInst *CI) {
  3366. Value *Ptr = CI->getArgOperand(0);
  3367. Value *Idx = CI->getArgOperand(1);
  3368. Value *IdxList[] = {ConstantInt::get(Idx->getType(), 0), Idx};
  3369. for (auto It = CI->user_begin(), E = CI->user_end(); It != E;) {
  3370. Instruction *user = cast<Instruction>(*(It++));
  3371. IRBuilder<> Builder(user);
  3372. Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
  3373. if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
  3374. Value *NewLd = Builder.CreateLoad(GEP);
  3375. Value *cast = Builder.CreateZExt(NewLd, LI->getType());
  3376. LI->replaceAllUsesWith(cast);
  3377. LI->eraseFromParent();
  3378. } else {
  3379. // Must be a store inst here.
  3380. StoreInst *SI = cast<StoreInst>(user);
  3381. Value *V = SI->getValueOperand();
  3382. Value *cast =
  3383. Builder.CreateICmpNE(V, llvm::ConstantInt::get(V->getType(), 0));
  3384. Builder.CreateStore(cast, GEP);
  3385. SI->eraseFromParent();
  3386. }
  3387. }
  3388. CI->eraseFromParent();
  3389. }
  3390. static void ReplaceBoolVectorSubscript(Function *F) {
  3391. for (auto It = F->user_begin(), E = F->user_end(); It != E; ) {
  3392. User *user = *(It++);
  3393. CallInst *CI = cast<CallInst>(user);
  3394. ReplaceBoolVectorSubscript(CI);
  3395. }
  3396. }
// Add function body for intrinsic if possible.
// Creates the HL op function matching funcTy for the given opcode group and
// opcode.  For Append/Consume and sincos, a body is synthesized in terms of
// simpler HL ops; everything else is a plain declaration.  ReadNone/ReadOnly
// attributes are copied from the original function F.
static Function *CreateOpFunction(llvm::Module &M, Function *F,
                                  llvm::FunctionType *funcTy,
                                  HLOpcodeGroup group, unsigned opcode) {
  Function *opFunc = nullptr;
  llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  if (group == HLOpcodeGroup::HLIntrinsic) {
    IntrinsicOp intriOp = static_cast<IntrinsicOp>(opcode);
    switch (intriOp) {
    case IntrinsicOp::MOP_Append:
    case IntrinsicOp::MOP_Consume: {
      bool bAppend = intriOp == IntrinsicOp::MOP_Append;
      llvm::Type *handleTy = funcTy->getParamType(HLOperandIndex::kHandleOpIdx);
      // Don't generate body for OutputStream::Append.
      if (bAppend && HLModule::IsStreamOutputPtrType(handleTy)) {
        opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
        break;
      }
      opFunc = GetOrCreateHLFunctionWithBody(M, funcTy, group, opcode,
                                             bAppend ? "append" : "consume");
      // Append/Consume lower to: counter = Inc/DecrementCounter(buf);
      // then a DefaultSubscript at that counter, stored to or loaded from.
      llvm::Type *counterTy = llvm::Type::getInt32Ty(M.getContext());
      llvm::FunctionType *IncCounterFuncTy =
          llvm::FunctionType::get(counterTy, {opcodeTy, handleTy}, false);
      unsigned counterOpcode = bAppend ? (unsigned)IntrinsicOp::MOP_IncrementCounter:
          (unsigned)IntrinsicOp::MOP_DecrementCounter;
      Function *incCounterFunc =
          GetOrCreateHLFunction(M, IncCounterFuncTy, group,
                                counterOpcode);
      llvm::Type *idxTy = counterTy;
      // Append takes the value as a parameter; Consume returns it.
      llvm::Type *valTy = bAppend ?
          funcTy->getParamType(HLOperandIndex::kAppendValOpIndex):funcTy->getReturnType();
      // Return type for subscript should be pointer type, hence in memory representation
      llvm::Type *subscriptTy = valTy;
      bool isBoolScalarOrVector = false;
      if (!subscriptTy->isPointerTy()) {
        // Bool values are stored as i32 in memory; remember to convert.
        if (subscriptTy->getScalarType()->isIntegerTy(1)) {
          isBoolScalarOrVector = true;
          llvm::Type *memReprType = llvm::IntegerType::get(subscriptTy->getContext(), 32);
          subscriptTy = subscriptTy->isVectorTy()
              ? llvm::VectorType::get(memReprType, subscriptTy->getVectorNumElements())
              : memReprType;
        }
        subscriptTy = llvm::PointerType::get(subscriptTy, 0);
      }
      llvm::FunctionType *SubscriptFuncTy =
          llvm::FunctionType::get(subscriptTy, {opcodeTy, handleTy, idxTy}, false);
      Function *subscriptFunc =
          GetOrCreateHLFunction(M, SubscriptFuncTy, HLOpcodeGroup::HLSubscript,
                                (unsigned)HLSubscriptOpcode::DefaultSubscript);
      BasicBlock *BB = BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
      IRBuilder<> Builder(BB);
      auto argIter = opFunc->args().begin();
      // Skip the opcode arg.
      argIter++;
      Argument *thisArg = argIter++;
      // int counter = IncrementCounter/DecrementCounter(Buf);
      Value *incCounterOpArg =
          ConstantInt::get(idxTy, counterOpcode);
      Value *counter =
          Builder.CreateCall(incCounterFunc, {incCounterOpArg, thisArg});
      // Buf[counter];
      Value *subscriptOpArg = ConstantInt::get(
          idxTy, (unsigned)HLSubscriptOpcode::DefaultSubscript);
      Value *subscript =
          Builder.CreateCall(subscriptFunc, {subscriptOpArg, thisArg, counter});
      if (bAppend) {
        Argument *valArg = argIter;
        // Buf[counter] = val;
        if (valTy->isPointerTy()) {
          // Aggregate value: copy through memory.
          unsigned size = M.getDataLayout().getTypeAllocSize(subscript->getType()->getPointerElementType());
          Builder.CreateMemCpy(subscript, valArg, size, 1);
        }
        else {
          Value *storedVal = valArg;
          // Convert to memory representation
          if (isBoolScalarOrVector)
            storedVal = Builder.CreateZExt(storedVal, subscriptTy->getPointerElementType(), "frombool");
          Builder.CreateStore(storedVal, subscript);
        }
        Builder.CreateRetVoid();
      } else {
        // return Buf[counter];
        if (valTy->isPointerTy())
          Builder.CreateRet(subscript);
        else {
          Value *retVal = Builder.CreateLoad(subscript);
          // Convert to register representation
          if (isBoolScalarOrVector)
            retVal = Builder.CreateICmpNE(retVal, Constant::getNullValue(retVal->getType()), "tobool");
          Builder.CreateRet(retVal);
        }
      }
    } break;
    case IntrinsicOp::IOP_sincos: {
      // sincos(v, s, c) lowers to: *s = sin(v); *c = cos(v);
      opFunc = GetOrCreateHLFunctionWithBody(M, funcTy, group, opcode, "sincos");
      llvm::Type *valTy = funcTy->getParamType(HLOperandIndex::kTrinaryOpSrc0Idx);
      llvm::FunctionType *sinFuncTy =
          llvm::FunctionType::get(valTy, {opcodeTy, valTy}, false);
      unsigned sinOp = static_cast<unsigned>(IntrinsicOp::IOP_sin);
      unsigned cosOp = static_cast<unsigned>(IntrinsicOp::IOP_cos);
      Function *sinFunc = GetOrCreateHLFunction(M, sinFuncTy, group, sinOp);
      Function *cosFunc = GetOrCreateHLFunction(M, sinFuncTy, group, cosOp);
      BasicBlock *BB = BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
      IRBuilder<> Builder(BB);
      auto argIter = opFunc->args().begin();
      // Skip the opcode arg.
      argIter++;
      Argument *valArg = argIter++;
      Argument *sinPtrArg = argIter++;
      Argument *cosPtrArg = argIter++;
      Value *sinOpArg =
          ConstantInt::get(opcodeTy, sinOp);
      Value *sinVal = Builder.CreateCall(sinFunc, {sinOpArg, valArg});
      Builder.CreateStore(sinVal, sinPtrArg);
      Value *cosOpArg =
          ConstantInt::get(opcodeTy, cosOp);
      Value *cosVal = Builder.CreateCall(cosFunc, {cosOpArg, valArg});
      Builder.CreateStore(cosVal, cosPtrArg);
      // Ret.
      Builder.CreateRetVoid();
    } break;
    default:
      opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
      break;
    }
  }
  else if (group == HLOpcodeGroup::HLExtIntrinsic) {
    // Extension intrinsics keep their original name and group attribute.
    llvm::StringRef fnName = F->getName();
    llvm::StringRef groupName = GetHLOpcodeGroupNameByAttr(F);
    opFunc = GetOrCreateHLFunction(M, funcTy, group, &groupName, &fnName, opcode);
  }
  else {
    opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
  }
  // Add attribute
  if (F->hasFnAttribute(Attribute::ReadNone))
    opFunc->addFnAttr(Attribute::ReadNone);
  if (F->hasFnAttribute(Attribute::ReadOnly))
    opFunc->addFnAttr(Attribute::ReadOnly);
  return opFunc;
}
  3538. static Value *CreateHandleFromResPtr(
  3539. Value *ResPtr, HLModule &HLM, llvm::Type *HandleTy,
  3540. std::unordered_map<llvm::Type *, MDNode *> &resMetaMap,
  3541. IRBuilder<> &Builder) {
  3542. llvm::Type *objTy = ResPtr->getType()->getPointerElementType();
  3543. DXASSERT(resMetaMap.count(objTy), "cannot find resource type");
  3544. MDNode *MD = resMetaMap[objTy];
  3545. // Load to make sure resource only have Ld/St use so mem2reg could remove
  3546. // temp resource.
  3547. Value *ldObj = Builder.CreateLoad(ResPtr);
  3548. Value *opcode = Builder.getInt32(0);
  3549. Value *args[] = {opcode, ldObj};
  3550. Function *CreateHandle = GetOrCreateHLCreateHandle(HLM, HandleTy, args, MD);
  3551. CallInst *Handle = Builder.CreateCall(CreateHandle, args);
  3552. return Handle;
  3553. }
// Rewrite intrinsic function F into the HL-op form: a new function whose
// first parameter is the i32 opcode, with resource pointer parameters turned
// into handles.  Every call site is rewritten accordingly, and F is erased.
// VectorSubscript and DoubleSubscript opcodes receive special handling.
static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
                                       unsigned opcode, llvm::Type *HandleTy,
                                       std::unordered_map<llvm::Type *, MDNode*> &resMetaMap) {
  llvm::Module &M = *HLM.GetModule();
  llvm::FunctionType *oldFuncTy = F->getFunctionType();
  SmallVector<llvm::Type *, 4> paramTyList;
  // Add the opcode param
  llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  paramTyList.emplace_back(opcodeTy);
  paramTyList.append(oldFuncTy->param_begin(), oldFuncTy->param_end());
  for (unsigned i = 1; i < paramTyList.size(); i++) {
    llvm::Type *Ty = paramTyList[i];
    if (Ty->isPointerTy()) {
      Ty = Ty->getPointerElementType();
      if (dxilutil::IsHLSLResourceType(Ty)) {
        // Use handle type for resource type.
        // This will make sure temp object variable only used by createHandle.
        paramTyList[i] = HandleTy;
      }
    }
  }
  HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(F);
  if (group == HLOpcodeGroup::HLSubscript &&
      opcode == static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)) {
    llvm::FunctionType *FT = F->getFunctionType();
    llvm::Type *VecArgTy = FT->getParamType(0);
    llvm::VectorType *VType =
        cast<llvm::VectorType>(VecArgTy->getPointerElementType());
    llvm::Type *Ty = VType->getElementType();
    DXASSERT(Ty->isIntegerTy(), "Only bool could use VectorSubscript");
    llvm::IntegerType *ITy = cast<IntegerType>(Ty);
    DXASSERT_LOCALVAR(ITy, ITy->getBitWidth() == 1, "Only bool could use VectorSubscript");
    // The return type is i8*.
    // Replace all uses with i1*.
    ReplaceBoolVectorSubscript(F);
    return;
  }
  bool isDoubleSubscriptFunc = group == HLOpcodeGroup::HLSubscript &&
      opcode == static_cast<unsigned>(HLSubscriptOpcode::DoubleSubscript);
  llvm::Type *RetTy = oldFuncTy->getReturnType();
  if (isDoubleSubscriptFunc) {
    // Double subscript (e.g. tex[mip][coord]): fold the two chained
    // subscript calls into one HL call taking both indices.
    CallInst *doubleSub = cast<CallInst>(*F->user_begin());
    // Change currentIdx type into coord type.
    auto U = doubleSub->user_begin();
    Value *user = *U;
    CallInst *secSub = cast<CallInst>(user);
    unsigned coordIdx = HLOperandIndex::kSubscriptIndexOpIdx;
    // opcode operand not add yet, so the index need -1.
    if (GetHLOpcodeGroupByName(secSub->getCalledFunction()) == HLOpcodeGroup::NotHL)
      coordIdx -= 1;
    Value *coord = secSub->getArgOperand(coordIdx);
    llvm::Type *coordTy = coord->getType();
    paramTyList[HLOperandIndex::kSubscriptIndexOpIdx] = coordTy;
    // Add the sampleIdx or mipLevel parameter to the end.
    paramTyList.emplace_back(opcodeTy);
    // Change return type to be resource ret type.
    // opcode operand not add yet, so the index need -1.
    Value *objPtr = doubleSub->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx-1);
    // Must be a GEP
    GEPOperator *objGEP = cast<GEPOperator>(objPtr);
    // Walk the GEP's type sequence to find the resource type it addresses.
    gep_type_iterator GEPIt = gep_type_begin(objGEP), E = gep_type_end(objGEP);
    llvm::Type *resTy = nullptr;
    while (GEPIt != E) {
      if (dxilutil::IsHLSLResourceType(*GEPIt)) {
        resTy = *GEPIt;
        break;
      }
      GEPIt++;
    }
    DXASSERT(resTy, "must find the resource type");
    // Change object type to handle type.
    paramTyList[HLOperandIndex::kSubscriptObjectOpIdx] = HandleTy;
    // Change RetTy into pointer of resource reture type.
    RetTy = cast<StructType>(resTy)->getElementType(0)->getPointerTo();
    llvm::Type *sliceTy = objGEP->getType()->getPointerElementType();
    DXIL::ResourceClass RC = HLM.GetResourceClass(sliceTy);
    DXIL::ResourceKind RK = HLM.GetResourceKind(sliceTy);
    HLM.AddResourceTypeAnnotation(resTy, RC, RK);
  }
  llvm::FunctionType *funcTy =
      llvm::FunctionType::get(RetTy, paramTyList, false);
  Function *opFunc = CreateOpFunction(M, F, funcTy, group, opcode);
  StringRef lower = hlsl::GetHLLowerStrategy(F);
  if (!lower.empty())
    hlsl::SetHLLowerStrategy(opFunc, lower);
  // Rewrite every call site of F to call opFunc with the opcode prepended.
  for (auto user = F->user_begin(); user != F->user_end();) {
    // User must be a call.  Advance before erasing the call below.
    CallInst *oldCI = cast<CallInst>(*(user++));
    SmallVector<Value *, 4> opcodeParamList;
    Value *opcodeConst = Constant::getIntegerValue(opcodeTy, APInt(32, opcode));
    opcodeParamList.emplace_back(opcodeConst);
    opcodeParamList.append(oldCI->arg_operands().begin(),
                           oldCI->arg_operands().end());
    IRBuilder<> Builder(oldCI);
    if (isDoubleSubscriptFunc) {
      // Change obj to the resource pointer.
      Value *objVal = opcodeParamList[HLOperandIndex::kSubscriptObjectOpIdx];
      GEPOperator *objGEP = cast<GEPOperator>(objVal);
      SmallVector<Value *, 8> IndexList;
      IndexList.append(objGEP->idx_begin(), objGEP->idx_end());
      Value *lastIndex = IndexList.back();
      ConstantInt *constIndex = cast<ConstantInt>(lastIndex);
      DXASSERT_LOCALVAR(constIndex, constIndex->getLimitedValue() == 1, "last index must 1");
      // Remove the last index.
      IndexList.pop_back();
      objVal = objGEP->getPointerOperand();
      if (IndexList.size() > 1)
        objVal = Builder.CreateInBoundsGEP(objVal, IndexList);
      Value *Handle =
          CreateHandleFromResPtr(objVal, HLM, HandleTy, resMetaMap, Builder);
      // Change obj to the resource pointer.
      opcodeParamList[HLOperandIndex::kSubscriptObjectOpIdx] = Handle;
      // Set idx and mipIdx.
      Value *mipIdx = opcodeParamList[HLOperandIndex::kSubscriptIndexOpIdx];
      auto U = oldCI->user_begin();
      Value *user = *U;
      CallInst *secSub = cast<CallInst>(user);
      unsigned idxOpIndex = HLOperandIndex::kSubscriptIndexOpIdx;
      if (GetHLOpcodeGroupByName(secSub->getCalledFunction()) == HLOpcodeGroup::NotHL)
        idxOpIndex--;
      Value *idx = secSub->getArgOperand(idxOpIndex);
      DXASSERT(secSub->hasOneUse(), "subscript should only has one use");
      // Add the sampleIdx or mipLevel parameter to the end.
      opcodeParamList[HLOperandIndex::kSubscriptIndexOpIdx] = idx;
      opcodeParamList.emplace_back(mipIdx);
      // Insert new call before secSub to make sure idx is ready to use.
      Builder.SetInsertPoint(secSub);
    }
    // Convert resource pointer arguments into handles at the call site.
    for (unsigned i = 1; i < opcodeParamList.size(); i++) {
      Value *arg = opcodeParamList[i];
      llvm::Type *Ty = arg->getType();
      if (Ty->isPointerTy()) {
        Ty = Ty->getPointerElementType();
        if (dxilutil::IsHLSLResourceType(Ty)) {
          // Use object type directly, not by pointer.
          // This will make sure temp object variable only used by ld/st.
          if (GEPOperator *argGEP = dyn_cast<GEPOperator>(arg)) {
            std::vector<Value*> idxList(argGEP->idx_begin(), argGEP->idx_end());
            // Create instruction to avoid GEPOperator.
            GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds(argGEP->getPointerOperand(),
                                                                       idxList);
            Builder.Insert(GEP);
            arg = GEP;
          }
          Value *Handle = CreateHandleFromResPtr(arg, HLM, HandleTy,
                                                 resMetaMap, Builder);
          opcodeParamList[i] = Handle;
        }
      }
    }
    Value *CI = Builder.CreateCall(opFunc, opcodeParamList);
    if (!isDoubleSubscriptFunc) {
      // replace new call and delete the old call
      oldCI->replaceAllUsesWith(CI);
      oldCI->eraseFromParent();
    } else {
      // For double script.
      // Replace single users use with new CI.
      auto U = oldCI->user_begin();
      Value *user = *U;
      CallInst *secSub = cast<CallInst>(user);
      secSub->replaceAllUsesWith(CI);
      secSub->eraseFromParent();
      oldCI->eraseFromParent();
    }
  }
  // delete the function
  F->eraseFromParent();
}
  3723. static void AddOpcodeParamForIntrinsics(HLModule &HLM
  3724. , std::vector<std::pair<Function *, unsigned>> &intrinsicMap,
  3725. std::unordered_map<llvm::Type *, MDNode*> &resMetaMap) {
  3726. llvm::Type *HandleTy = HLM.GetOP()->GetHandleType();
  3727. for (auto mapIter : intrinsicMap) {
  3728. Function *F = mapIter.first;
  3729. if (F->user_empty()) {
  3730. // delete the function
  3731. F->eraseFromParent();
  3732. continue;
  3733. }
  3734. unsigned opcode = mapIter.second;
  3735. AddOpcodeParamForIntrinsic(HLM, F, opcode, HandleTy, resMetaMap);
  3736. }
  3737. }
  3738. static Value *CastLdValue(Value *Ptr, llvm::Type *FromTy, llvm::Type *ToTy, IRBuilder<> &Builder) {
  3739. if (ToTy->isVectorTy()) {
  3740. unsigned vecSize = ToTy->getVectorNumElements();
  3741. if (vecSize == 1 && ToTy->getVectorElementType() == FromTy) {
  3742. Value *V = Builder.CreateLoad(Ptr);
  3743. // ScalarToVec1Splat
  3744. // Change scalar into vec1.
  3745. Value *Vec1 = UndefValue::get(ToTy);
  3746. return Builder.CreateInsertElement(Vec1, V, (uint64_t)0);
  3747. } else if (vecSize == 1 && FromTy->isIntegerTy()
  3748. && ToTy->getVectorElementType()->isIntegerTy(1)) {
  3749. // load(bitcast i32* to <1 x i1>*)
  3750. // Rewrite to
  3751. // insertelement(icmp ne (load i32*), 0)
  3752. Value *IntV = Builder.CreateLoad(Ptr);
  3753. Value *BoolV = Builder.CreateICmpNE(IntV, ConstantInt::get(IntV->getType(), 0), "tobool");
  3754. Value *Vec1 = UndefValue::get(ToTy);
  3755. return Builder.CreateInsertElement(Vec1, BoolV, (uint64_t)0);
  3756. } else if (FromTy->isVectorTy() && vecSize == 1) {
  3757. Value *V = Builder.CreateLoad(Ptr);
  3758. // VectorTrunc
  3759. // Change vector into vec1.
  3760. int mask[] = {0};
  3761. return Builder.CreateShuffleVector(V, V, mask);
  3762. } else if (FromTy->isArrayTy()) {
  3763. llvm::Type *FromEltTy = FromTy->getArrayElementType();
  3764. llvm::Type *ToEltTy = ToTy->getVectorElementType();
  3765. if (FromTy->getArrayNumElements() == vecSize && FromEltTy == ToEltTy) {
  3766. // ArrayToVector.
  3767. Value *NewLd = UndefValue::get(ToTy);
  3768. Value *zeroIdx = Builder.getInt32(0);
  3769. for (unsigned i = 0; i < vecSize; i++) {
  3770. Value *GEP = Builder.CreateInBoundsGEP(
  3771. Ptr, {zeroIdx, Builder.getInt32(i)});
  3772. Value *Elt = Builder.CreateLoad(GEP);
  3773. NewLd = Builder.CreateInsertElement(NewLd, Elt, i);
  3774. }
  3775. return NewLd;
  3776. }
  3777. }
  3778. } else if (FromTy == Builder.getInt1Ty()) {
  3779. Value *V = Builder.CreateLoad(Ptr);
  3780. // BoolCast
  3781. DXASSERT_NOMSG(ToTy->isIntegerTy());
  3782. return Builder.CreateZExt(V, ToTy);
  3783. }
  3784. return nullptr;
  3785. }
  3786. static Value *CastStValue(Value *Ptr, Value *V, llvm::Type *FromTy, llvm::Type *ToTy, IRBuilder<> &Builder) {
  3787. if (ToTy->isVectorTy()) {
  3788. unsigned vecSize = ToTy->getVectorNumElements();
  3789. if (vecSize == 1 && ToTy->getVectorElementType() == FromTy) {
  3790. // ScalarToVec1Splat
  3791. // Change vec1 back to scalar.
  3792. Value *Elt = Builder.CreateExtractElement(V, (uint64_t)0);
  3793. return Elt;
  3794. } else if (FromTy->isVectorTy() && vecSize == 1) {
  3795. // VectorTrunc
  3796. // Change vec1 into vector.
  3797. // Should not happen.
  3798. // Reported error at Sema::ImpCastExprToType.
  3799. DXASSERT_NOMSG(0);
  3800. } else if (FromTy->isArrayTy()) {
  3801. llvm::Type *FromEltTy = FromTy->getArrayElementType();
  3802. llvm::Type *ToEltTy = ToTy->getVectorElementType();
  3803. if (FromTy->getArrayNumElements() == vecSize && FromEltTy == ToEltTy) {
  3804. // ArrayToVector.
  3805. Value *zeroIdx = Builder.getInt32(0);
  3806. for (unsigned i = 0; i < vecSize; i++) {
  3807. Value *Elt = Builder.CreateExtractElement(V, i);
  3808. Value *GEP = Builder.CreateInBoundsGEP(
  3809. Ptr, {zeroIdx, Builder.getInt32(i)});
  3810. Builder.CreateStore(Elt, GEP);
  3811. }
  3812. // The store already done.
  3813. // Return null to ignore use of the return value.
  3814. return nullptr;
  3815. }
  3816. }
  3817. } else if (FromTy == Builder.getInt1Ty()) {
  3818. // BoolCast
  3819. // Change i1 to ToTy.
  3820. DXASSERT_NOMSG(ToTy->isIntegerTy());
  3821. Value *CastV = Builder.CreateICmpNE(V, ConstantInt::get(V->getType(), 0));
  3822. return CastV;
  3823. }
  3824. return nullptr;
  3825. }
  3826. static bool SimplifyBitCastLoad(LoadInst *LI, llvm::Type *FromTy, llvm::Type *ToTy, Value *Ptr) {
  3827. IRBuilder<> Builder(LI);
  3828. // Cast FromLd to ToTy.
  3829. Value *CastV = CastLdValue(Ptr, FromTy, ToTy, Builder);
  3830. if (CastV) {
  3831. LI->replaceAllUsesWith(CastV);
  3832. return true;
  3833. } else {
  3834. return false;
  3835. }
  3836. }
  3837. static bool SimplifyBitCastStore(StoreInst *SI, llvm::Type *FromTy, llvm::Type *ToTy, Value *Ptr) {
  3838. IRBuilder<> Builder(SI);
  3839. Value *V = SI->getValueOperand();
  3840. // Cast Val to FromTy.
  3841. Value *CastV = CastStValue(Ptr, V, FromTy, ToTy, Builder);
  3842. if (CastV) {
  3843. Builder.CreateStore(CastV, Ptr);
  3844. return true;
  3845. } else {
  3846. return false;
  3847. }
  3848. }
  3849. static bool SimplifyBitCastGEP(GEPOperator *GEP, llvm::Type *FromTy, llvm::Type *ToTy, Value *Ptr) {
  3850. if (ToTy->isVectorTy()) {
  3851. unsigned vecSize = ToTy->getVectorNumElements();
  3852. if (vecSize == 1 && ToTy->getVectorElementType() == FromTy) {
  3853. // ScalarToVec1Splat
  3854. GEP->replaceAllUsesWith(Ptr);
  3855. return true;
  3856. } else if (FromTy->isVectorTy() && vecSize == 1) {
  3857. // VectorTrunc
  3858. DXASSERT_NOMSG(
  3859. !isa<llvm::VectorType>(GEP->getType()->getPointerElementType()));
  3860. IRBuilder<> Builder(FromTy->getContext());
  3861. if (Instruction *I = dyn_cast<Instruction>(GEP))
  3862. Builder.SetInsertPoint(I);
  3863. std::vector<Value *> idxList(GEP->idx_begin(), GEP->idx_end());
  3864. Value *NewGEP = Builder.CreateInBoundsGEP(Ptr, idxList);
  3865. GEP->replaceAllUsesWith(NewGEP);
  3866. return true;
  3867. } else if (FromTy->isArrayTy()) {
  3868. llvm::Type *FromEltTy = FromTy->getArrayElementType();
  3869. llvm::Type *ToEltTy = ToTy->getVectorElementType();
  3870. if (FromTy->getArrayNumElements() == vecSize && FromEltTy == ToEltTy) {
  3871. // ArrayToVector.
  3872. }
  3873. }
  3874. } else if (FromTy == llvm::Type::getInt1Ty(FromTy->getContext())) {
  3875. // BoolCast
  3876. }
  3877. return false;
  3878. }
  3879. typedef SmallPtrSet<Instruction *, 4> SmallInstSet;
  3880. static void SimplifyBitCast(BitCastOperator *BC, SmallInstSet &deadInsts) {
  3881. Value *Ptr = BC->getOperand(0);
  3882. llvm::Type *FromTy = Ptr->getType();
  3883. llvm::Type *ToTy = BC->getType();
  3884. if (!FromTy->isPointerTy() || !ToTy->isPointerTy())
  3885. return;
  3886. FromTy = FromTy->getPointerElementType();
  3887. ToTy = ToTy->getPointerElementType();
  3888. // Take care case like %2 = bitcast %struct.T* %1 to <1 x float>*.
  3889. bool GEPCreated = false;
  3890. if (FromTy->isStructTy()) {
  3891. IRBuilder<> Builder(FromTy->getContext());
  3892. if (Instruction *I = dyn_cast<Instruction>(BC))
  3893. Builder.SetInsertPoint(I);
  3894. Value *zeroIdx = Builder.getInt32(0);
  3895. unsigned nestLevel = 1;
  3896. while (llvm::StructType *ST = dyn_cast<llvm::StructType>(FromTy)) {
  3897. if (ST->getNumElements() == 0) break;
  3898. FromTy = ST->getElementType(0);
  3899. nestLevel++;
  3900. }
  3901. std::vector<Value *> idxList(nestLevel, zeroIdx);
  3902. Ptr = Builder.CreateGEP(Ptr, idxList);
  3903. GEPCreated = true;
  3904. }
  3905. for (User *U : BC->users()) {
  3906. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  3907. if (SimplifyBitCastLoad(LI, FromTy, ToTy, Ptr)) {
  3908. LI->dropAllReferences();
  3909. deadInsts.insert(LI);
  3910. }
  3911. } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
  3912. if (SimplifyBitCastStore(SI, FromTy, ToTy, Ptr)) {
  3913. SI->dropAllReferences();
  3914. deadInsts.insert(SI);
  3915. }
  3916. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
  3917. if (SimplifyBitCastGEP(GEP, FromTy, ToTy, Ptr))
  3918. if (Instruction *I = dyn_cast<Instruction>(GEP)) {
  3919. I->dropAllReferences();
  3920. deadInsts.insert(I);
  3921. }
  3922. } else if (dyn_cast<CallInst>(U)) {
  3923. // Skip function call.
  3924. } else if (dyn_cast<BitCastInst>(U)) {
  3925. // Skip bitcast.
  3926. } else if (dyn_cast<AddrSpaceCastInst>(U)) {
  3927. // Skip addrspacecast.
  3928. } else {
  3929. DXASSERT(0, "not support yet");
  3930. }
  3931. }
  3932. // We created a GEP instruction but didn't end up consuming it, so delete it.
  3933. if (GEPCreated && Ptr->use_empty()) {
  3934. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
  3935. GEP->eraseFromParent();
  3936. else
  3937. cast<Constant>(Ptr)->destroyConstant();
  3938. }
  3939. }
  3940. typedef float(__cdecl *FloatUnaryEvalFuncType)(float);
  3941. typedef double(__cdecl *DoubleUnaryEvalFuncType)(double);
  3942. typedef APInt(__cdecl *IntBinaryEvalFuncType)(const APInt&, const APInt&);
  3943. typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
  3944. typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);
  3945. static Value * EvalUnaryIntrinsic(ConstantFP *fpV,
  3946. FloatUnaryEvalFuncType floatEvalFunc,
  3947. DoubleUnaryEvalFuncType doubleEvalFunc) {
  3948. llvm::Type *Ty = fpV->getType();
  3949. Value *Result = nullptr;
  3950. if (Ty->isDoubleTy()) {
  3951. double dV = fpV->getValueAPF().convertToDouble();
  3952. Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV));
  3953. Result = dResult;
  3954. } else {
  3955. DXASSERT_NOMSG(Ty->isFloatTy());
  3956. float fV = fpV->getValueAPF().convertToFloat();
  3957. Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV));
  3958. Result = dResult;
  3959. }
  3960. return Result;
  3961. }
  3962. static Value * EvalBinaryIntrinsic(Constant *cV0, Constant *cV1,
  3963. FloatBinaryEvalFuncType floatEvalFunc,
  3964. DoubleBinaryEvalFuncType doubleEvalFunc,
  3965. IntBinaryEvalFuncType intEvalFunc) {
  3966. llvm::Type *Ty = cV0->getType();
  3967. Value *Result = nullptr;
  3968. if (Ty->isDoubleTy()) {
  3969. ConstantFP *fpV0 = cast<ConstantFP>(cV0);
  3970. ConstantFP *fpV1 = cast<ConstantFP>(cV1);
  3971. double dV0 = fpV0->getValueAPF().convertToDouble();
  3972. double dV1 = fpV1->getValueAPF().convertToDouble();
  3973. Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV0, dV1));
  3974. Result = dResult;
  3975. } else if (Ty->isFloatTy()) {
  3976. ConstantFP *fpV0 = cast<ConstantFP>(cV0);
  3977. ConstantFP *fpV1 = cast<ConstantFP>(cV1);
  3978. float fV0 = fpV0->getValueAPF().convertToFloat();
  3979. float fV1 = fpV1->getValueAPF().convertToFloat();
  3980. Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV0, fV1));
  3981. Result = dResult;
  3982. } else {
  3983. DXASSERT_NOMSG(Ty->isIntegerTy());
  3984. DXASSERT_NOMSG(intEvalFunc);
  3985. ConstantInt *ciV0 = cast<ConstantInt>(cV0);
  3986. ConstantInt *ciV1 = cast<ConstantInt>(cV1);
  3987. const APInt& iV0 = ciV0->getValue();
  3988. const APInt& iV1 = ciV1->getValue();
  3989. Value *dResult = ConstantInt::get(Ty, intEvalFunc(iV0, iV1));
  3990. Result = dResult;
  3991. }
  3992. return Result;
  3993. }
  3994. static Value * EvalUnaryIntrinsic(CallInst *CI,
  3995. FloatUnaryEvalFuncType floatEvalFunc,
  3996. DoubleUnaryEvalFuncType doubleEvalFunc) {
  3997. Value *V = CI->getArgOperand(0);
  3998. llvm::Type *Ty = CI->getType();
  3999. Value *Result = nullptr;
  4000. if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
  4001. Result = UndefValue::get(Ty);
  4002. Constant *CV = cast<Constant>(V);
  4003. IRBuilder<> Builder(CI);
  4004. for (unsigned i=0;i<VT->getNumElements();i++) {
  4005. ConstantFP *fpV = cast<ConstantFP>(CV->getAggregateElement(i));
  4006. Value *EltResult = EvalUnaryIntrinsic(fpV, floatEvalFunc, doubleEvalFunc);
  4007. Result = Builder.CreateInsertElement(Result, EltResult, i);
  4008. }
  4009. } else {
  4010. ConstantFP *fpV = cast<ConstantFP>(V);
  4011. Result = EvalUnaryIntrinsic(fpV, floatEvalFunc, doubleEvalFunc);
  4012. }
  4013. CI->replaceAllUsesWith(Result);
  4014. CI->eraseFromParent();
  4015. return Result;
  4016. }
  4017. static Value * EvalBinaryIntrinsic(CallInst *CI,
  4018. FloatBinaryEvalFuncType floatEvalFunc,
  4019. DoubleBinaryEvalFuncType doubleEvalFunc,
  4020. IntBinaryEvalFuncType intEvalFunc = nullptr) {
  4021. Value *V0 = CI->getArgOperand(0);
  4022. Value *V1 = CI->getArgOperand(1);
  4023. llvm::Type *Ty = CI->getType();
  4024. Value *Result = nullptr;
  4025. if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
  4026. Result = UndefValue::get(Ty);
  4027. Constant *CV0 = cast<Constant>(V0);
  4028. Constant *CV1 = cast<Constant>(V1);
  4029. IRBuilder<> Builder(CI);
  4030. for (unsigned i=0;i<VT->getNumElements();i++) {
  4031. Constant *cV0 = cast<Constant>(CV0->getAggregateElement(i));
  4032. Constant *cV1 = cast<Constant>(CV1->getAggregateElement(i));
  4033. Value *EltResult = EvalBinaryIntrinsic(cV0, cV1, floatEvalFunc, doubleEvalFunc, intEvalFunc);
  4034. Result = Builder.CreateInsertElement(Result, EltResult, i);
  4035. }
  4036. } else {
  4037. Constant *cV0 = cast<Constant>(V0);
  4038. Constant *cV1 = cast<Constant>(V1);
  4039. Result = EvalBinaryIntrinsic(cV0, cV1, floatEvalFunc, doubleEvalFunc, intEvalFunc);
  4040. }
  4041. CI->replaceAllUsesWith(Result);
  4042. CI->eraseFromParent();
  4043. return Result;
  4044. CI->eraseFromParent();
  4045. return Result;
  4046. }
  4047. static Value * TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp) {
  4048. switch (intriOp) {
  4049. case IntrinsicOp::IOP_tan: {
  4050. return EvalUnaryIntrinsic(CI, tanf, tan);
  4051. } break;
  4052. case IntrinsicOp::IOP_tanh: {
  4053. return EvalUnaryIntrinsic(CI, tanhf, tanh);
  4054. } break;
  4055. case IntrinsicOp::IOP_sin: {
  4056. return EvalUnaryIntrinsic(CI, sinf, sin);
  4057. } break;
  4058. case IntrinsicOp::IOP_sinh: {
  4059. return EvalUnaryIntrinsic(CI, sinhf, sinh);
  4060. } break;
  4061. case IntrinsicOp::IOP_cos: {
  4062. return EvalUnaryIntrinsic(CI, cosf, cos);
  4063. } break;
  4064. case IntrinsicOp::IOP_cosh: {
  4065. return EvalUnaryIntrinsic(CI, coshf, cosh);
  4066. } break;
  4067. case IntrinsicOp::IOP_asin: {
  4068. return EvalUnaryIntrinsic(CI, asinf, asin);
  4069. } break;
  4070. case IntrinsicOp::IOP_acos: {
  4071. return EvalUnaryIntrinsic(CI, acosf, acos);
  4072. } break;
  4073. case IntrinsicOp::IOP_atan: {
  4074. return EvalUnaryIntrinsic(CI, atanf, atan);
  4075. } break;
  4076. case IntrinsicOp::IOP_atan2: {
  4077. Value *V0 = CI->getArgOperand(0);
  4078. ConstantFP *fpV0 = cast<ConstantFP>(V0);
  4079. Value *V1 = CI->getArgOperand(1);
  4080. ConstantFP *fpV1 = cast<ConstantFP>(V1);
  4081. llvm::Type *Ty = CI->getType();
  4082. Value *Result = nullptr;
  4083. if (Ty->isDoubleTy()) {
  4084. double dV0 = fpV0->getValueAPF().convertToDouble();
  4085. double dV1 = fpV1->getValueAPF().convertToDouble();
  4086. Value *atanV = ConstantFP::get(CI->getType(), atan2(dV0, dV1));
  4087. CI->replaceAllUsesWith(atanV);
  4088. Result = atanV;
  4089. } else {
  4090. DXASSERT_NOMSG(Ty->isFloatTy());
  4091. float fV0 = fpV0->getValueAPF().convertToFloat();
  4092. float fV1 = fpV1->getValueAPF().convertToFloat();
  4093. Value *atanV = ConstantFP::get(CI->getType(), atan2f(fV0, fV1));
  4094. CI->replaceAllUsesWith(atanV);
  4095. Result = atanV;
  4096. }
  4097. CI->eraseFromParent();
  4098. return Result;
  4099. } break;
  4100. case IntrinsicOp::IOP_sqrt: {
  4101. return EvalUnaryIntrinsic(CI, sqrtf, sqrt);
  4102. } break;
  4103. case IntrinsicOp::IOP_rsqrt: {
  4104. auto rsqrtF = [](float v) -> float { return 1.0 / sqrtf(v); };
  4105. auto rsqrtD = [](double v) -> double { return 1.0 / sqrt(v); };
  4106. return EvalUnaryIntrinsic(CI, rsqrtF, rsqrtD);
  4107. } break;
  4108. case IntrinsicOp::IOP_exp: {
  4109. return EvalUnaryIntrinsic(CI, expf, exp);
  4110. } break;
  4111. case IntrinsicOp::IOP_exp2: {
  4112. return EvalUnaryIntrinsic(CI, exp2f, exp2);
  4113. } break;
  4114. case IntrinsicOp::IOP_log: {
  4115. return EvalUnaryIntrinsic(CI, logf, log);
  4116. } break;
  4117. case IntrinsicOp::IOP_log10: {
  4118. return EvalUnaryIntrinsic(CI, log10f, log10);
  4119. } break;
  4120. case IntrinsicOp::IOP_log2: {
  4121. return EvalUnaryIntrinsic(CI, log2f, log2);
  4122. } break;
  4123. case IntrinsicOp::IOP_pow: {
  4124. return EvalBinaryIntrinsic(CI, powf, pow);
  4125. } break;
  4126. case IntrinsicOp::IOP_max: {
  4127. auto maxF = [](float a, float b) -> float { return a > b ? a:b; };
  4128. auto maxD = [](double a, double b) -> double { return a > b ? a:b; };
  4129. auto imaxI = [](const APInt &a, const APInt &b) -> APInt { return a.sgt(b) ? a : b; };
  4130. return EvalBinaryIntrinsic(CI, maxF, maxD, imaxI);
  4131. } break;
  4132. case IntrinsicOp::IOP_min: {
  4133. auto minF = [](float a, float b) -> float { return a < b ? a:b; };
  4134. auto minD = [](double a, double b) -> double { return a < b ? a:b; };
  4135. auto iminI = [](const APInt &a, const APInt &b) -> APInt { return a.slt(b) ? a : b; };
  4136. return EvalBinaryIntrinsic(CI, minF, minD, iminI);
  4137. } break;
  4138. case IntrinsicOp::IOP_umax: {
  4139. DXASSERT_NOMSG(CI->getArgOperand(0)->getType()->getScalarType()->isIntegerTy());
  4140. auto umaxI = [](const APInt &a, const APInt &b) -> APInt { return a.ugt(b) ? a : b; };
  4141. return EvalBinaryIntrinsic(CI, nullptr, nullptr, umaxI);
  4142. } break;
  4143. case IntrinsicOp::IOP_umin: {
  4144. DXASSERT_NOMSG(CI->getArgOperand(0)->getType()->getScalarType()->isIntegerTy());
  4145. auto uminI = [](const APInt &a, const APInt &b) -> APInt { return a.ult(b) ? a : b; };
  4146. return EvalBinaryIntrinsic(CI, nullptr, nullptr, uminI);
  4147. } break;
  4148. case IntrinsicOp::IOP_rcp: {
  4149. auto rcpF = [](float v) -> float { return 1.0 / v; };
  4150. auto rcpD = [](double v) -> double { return 1.0 / v; };
  4151. return EvalUnaryIntrinsic(CI, rcpF, rcpD);
  4152. } break;
  4153. case IntrinsicOp::IOP_ceil: {
  4154. return EvalUnaryIntrinsic(CI, ceilf, ceil);
  4155. } break;
  4156. case IntrinsicOp::IOP_floor: {
  4157. return EvalUnaryIntrinsic(CI, floorf, floor);
  4158. } break;
  4159. case IntrinsicOp::IOP_round: {
  4160. return EvalUnaryIntrinsic(CI, roundf, round);
  4161. } break;
  4162. case IntrinsicOp::IOP_trunc: {
  4163. return EvalUnaryIntrinsic(CI, truncf, trunc);
  4164. } break;
  4165. case IntrinsicOp::IOP_frac: {
  4166. auto fracF = [](float v) -> float {
  4167. return v - floor(v);
  4168. };
  4169. auto fracD = [](double v) -> double {
  4170. return v - floor(v);
  4171. };
  4172. return EvalUnaryIntrinsic(CI, fracF, fracD);
  4173. } break;
  4174. case IntrinsicOp::IOP_isnan: {
  4175. Value *V = CI->getArgOperand(0);
  4176. ConstantFP *fV = cast<ConstantFP>(V);
  4177. bool isNan = fV->getValueAPF().isNaN();
  4178. Constant *cNan = ConstantInt::get(CI->getType(), isNan ? 1 : 0);
  4179. CI->replaceAllUsesWith(cNan);
  4180. CI->eraseFromParent();
  4181. return cNan;
  4182. } break;
  4183. default:
  4184. return nullptr;
  4185. }
  4186. }
// Apply per-instruction cleanups that make later HL-to-DXIL lowering easier:
// simplify pointer bitcasts (including ones folded into constant
// expressions) and clamp shift amounts to the operand's bit width.
static void SimpleTransformForHLDXIR(Instruction *I,
                                     SmallInstSet &deadInsts) {
  unsigned opcode = I->getOpcode();
  switch (opcode) {
  case Instruction::BitCast: {
    BitCastOperator *BCI = cast<BitCastOperator>(I);
    SimplifyBitCast(BCI, deadInsts);
  } break;
  case Instruction::Load: {
    LoadInst *ldInst = cast<LoadInst>(I);
    // Matrix loads are expected to go through the HL matrix load/store ops,
    // never plain LLVM loads.
    DXASSERT(!HLMatrixType::isa(ldInst->getType()),
             "matrix load should use HL LdStMatrix");
    Value *Ptr = ldInst->getPointerOperand();
    // The bitcast may be a ConstantExpr rather than an instruction;
    // simplify it there too.
    if (ConstantExpr *CE = dyn_cast_or_null<ConstantExpr>(Ptr)) {
      if (BitCastOperator *BCO = dyn_cast<BitCastOperator>(CE)) {
        SimplifyBitCast(BCO, deadInsts);
      }
    }
  } break;
  case Instruction::Store: {
    StoreInst *stInst = cast<StoreInst>(I);
    Value *V = stInst->getValueOperand();
    DXASSERT_LOCALVAR(V, !HLMatrixType::isa(V->getType()),
                      "matrix store should use HL LdStMatrix");
    Value *Ptr = stInst->getPointerOperand();
    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) {
      if (BitCastOperator *BCO = dyn_cast<BitCastOperator>(CE)) {
        SimplifyBitCast(BCO, deadInsts);
      }
    }
  } break;
  case Instruction::LShr:
  case Instruction::AShr:
  case Instruction::Shl: {
    llvm::BinaryOperator *BO = cast<llvm::BinaryOperator>(I);
    Value *op2 = BO->getOperand(1);
    IntegerType *Ty = cast<IntegerType>(BO->getType()->getScalarType());
    unsigned bitWidth = Ty->getBitWidth();
    // Clamp op2 to 0 ~ bitWidth-1 (the mask assumes bitWidth is a power of
    // two).
    if (ConstantInt *cOp2 = dyn_cast<ConstantInt>(op2)) {
      // Constant shift amount: fold the mask at compile time.
      unsigned iOp2 = cOp2->getLimitedValue();
      unsigned clampedOp2 = iOp2 & (bitWidth - 1);
      if (iOp2 != clampedOp2) {
        BO->setOperand(1, ConstantInt::get(op2->getType(), clampedOp2));
      }
    } else {
      // Dynamic shift amount: mask it with an explicit 'and'.
      Value *mask = ConstantInt::get(op2->getType(), bitWidth - 1);
      IRBuilder<> Builder(I);
      op2 = Builder.CreateAnd(op2, mask);
      BO->setOperand(1, op2);
    }
  } break;
  }
}
// Do simple transform to make later lower pass easier.
// Walks every instruction in the module applying the per-instruction
// transform, then simplifies bitcast users of static globals.  Dead
// instructions are collected first and erased afterwards so iteration stays
// valid.
static void SimpleTransformForHLDXIR(llvm::Module *pM) {
  SmallInstSet deadInsts;
  for (Function &F : pM->functions()) {
    for (BasicBlock &BB : F.getBasicBlockList()) {
      for (BasicBlock::iterator Iter = BB.begin(); Iter != BB.end(); ) {
        // Advance before transforming; erasure of *Iter itself is deferred
        // via deadInsts.
        Instruction *I = (Iter++);
        if (deadInsts.count(I))
          continue; // Skip dead instructions
        SimpleTransformForHLDXIR(I, deadInsts);
      }
    }
  }
  // Drop operands first so mutually-referencing dead instructions can be
  // erased in any order.
  for (Instruction * I : deadInsts)
    I->dropAllReferences();
  for (Instruction * I : deadInsts)
    I->eraseFromParent();
  deadInsts.clear();
  // Second pass: bitcast users of static globals (typically constant
  // expressions rather than instructions).
  for (GlobalVariable &GV : pM->globals()) {
    if (dxilutil::IsStaticGlobal(&GV)) {
      for (User *U : GV.users()) {
        if (BitCastOperator *BCO = dyn_cast<BitCastOperator>(U)) {
          SimplifyBitCast(BCO, deadInsts);
        }
      }
    }
  }
  for (Instruction * I : deadInsts)
    I->dropAllReferences();
  for (Instruction * I : deadInsts)
    I->eraseFromParent();
}
// Clone Orig into llvmModule under Name, mapping arguments one-to-one and
// copying its HLSL function annotation from SrcTypeSys into TypeSys.
// Returns the new function (created with external linkage).
static Function *CloneFunction(Function *Orig,
                               const llvm::Twine &Name,
                               llvm::Module *llvmModule,
                               hlsl::DxilTypeSystem &TypeSys,
                               hlsl::DxilTypeSystem &SrcTypeSys) {
  Function *F = Function::Create(Orig->getFunctionType(),
                                 GlobalValue::LinkageTypes::ExternalLinkage,
                                 Name, llvmModule);
  SmallVector<ReturnInst *, 2> Returns;
  ValueToValueMapTy vmap;
  // Map params.
  auto entryParamIt = F->arg_begin();
  for (Argument &param : Orig->args()) {
    vmap[&param] = (entryParamIt++);
  }
  llvm::CloneFunctionInto(F, Orig, vmap, /*ModuleLevelChanges*/ false, Returns);
  TypeSys.CopyFunctionAnnotation(F, Orig, SrcTypeSys);
  return F;
}
  4292. // Clone shader entry function to be called by other functions.
  4293. // The original function will be used as shader entry.
  4294. static void CloneShaderEntry(Function *ShaderF, StringRef EntryName,
  4295. HLModule &HLM) {
  4296. Function *F = CloneFunction(ShaderF, "", HLM.GetModule(),
  4297. HLM.GetTypeSystem(), HLM.GetTypeSystem());
  4298. F->takeName(ShaderF);
  4299. F->setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
  4300. // Set to name before mangled.
  4301. ShaderF->setName(EntryName);
  4302. DxilFunctionAnnotation *annot = HLM.GetFunctionAnnotation(F);
  4303. DxilParameterAnnotation &cloneRetAnnot = annot->GetRetTypeAnnotation();
  4304. // Clear semantic for cloned one.
  4305. cloneRetAnnot.SetSemanticString("");
  4306. cloneRetAnnot.SetSemanticIndexVec({});
  4307. for (unsigned i = 0; i < annot->GetNumParameters(); i++) {
  4308. DxilParameterAnnotation &cloneParamAnnot = annot->GetParameterAnnotation(i);
  4309. // Clear semantic for cloned one.
  4310. cloneParamAnnot.SetSemanticString("");
  4311. cloneParamAnnot.SetSemanticIndexVec({});
  4312. }
  4313. }
  4314. // For case like:
  4315. //cbuffer A {
  4316. // float a;
  4317. // int b;
  4318. //}
  4319. //
  4320. //const static struct {
  4321. // float a;
  4322. // int b;
  4323. //} ST = { a, b };
  4324. // Replace user of ST with a and b.
  4325. static bool ReplaceConstStaticGlobalUser(GEPOperator *GEP,
  4326. std::vector<Constant *> &InitList,
  4327. IRBuilder<> &Builder) {
  4328. if (GEP->getNumIndices() < 2) {
  4329. // Don't use sub element.
  4330. return false;
  4331. }
  4332. SmallVector<Value *, 4> idxList;
  4333. auto iter = GEP->idx_begin();
  4334. idxList.emplace_back(*(iter++));
  4335. ConstantInt *subIdx = dyn_cast<ConstantInt>(*(iter++));
  4336. DXASSERT(subIdx, "else dynamic indexing on struct field");
  4337. unsigned subIdxImm = subIdx->getLimitedValue();
  4338. DXASSERT(subIdxImm < InitList.size(), "else struct index out of bound");
  4339. Constant *subPtr = InitList[subIdxImm];
  4340. // Move every idx to idxList except idx for InitList.
  4341. while (iter != GEP->idx_end()) {
  4342. idxList.emplace_back(*(iter++));
  4343. }
  4344. Value *NewGEP = Builder.CreateGEP(subPtr, idxList);
  4345. GEP->replaceAllUsesWith(NewGEP);
  4346. return true;
  4347. }
// For each const static global initialized from cbuffer fields, redirect its
// GEP users straight at the per-field initializer constants; when every user
// was replaced, neuter the now-useless constructor (leaving just ret void).
static void ReplaceConstStaticGlobals(
    std::unordered_map<GlobalVariable *, std::vector<Constant *>>
        &staticConstGlobalInitListMap,
    std::unordered_map<GlobalVariable *, Function *>
        &staticConstGlobalCtorMap) {
  for (auto &iter : staticConstGlobalInitListMap) {
    GlobalVariable *GV = iter.first;
    std::vector<Constant *> &InitList = iter.second;
    LLVMContext &Ctx = GV->getContext();
    // Do the replace.
    bool bPass = true;
    for (User *U : GV->users()) {
      IRBuilder<> Builder(Ctx);
      // GEP instructions need an insert point; constant GEP expressions
      // fold without one.
      if (GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
        Builder.SetInsertPoint(GEPInst);
        bPass &= ReplaceConstStaticGlobalUser(cast<GEPOperator>(GEPInst), InitList, Builder);
      } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
        bPass &= ReplaceConstStaticGlobalUser(GEP, InitList, Builder);
      } else {
        DXASSERT(false, "invalid user of const static global");
      }
    }
    // Clear the Ctor which is useless now.
    if (bPass) {
      Function *Ctor = staticConstGlobalCtorMap[GV];
      // Replace the ctor body with a single empty block returning void.
      Ctor->getBasicBlockList().clear();
      BasicBlock *Entry = BasicBlock::Create(Ctx, "", Ctor);
      IRBuilder<> Builder(Entry);
      Builder.CreateRetVoid();
    }
  }
}
  4380. bool BuildImmInit(Function *Ctor) {
  4381. GlobalVariable *GV = nullptr;
  4382. SmallVector<Constant *, 4> ImmList;
  4383. bool allConst = true;
  4384. for (inst_iterator I = inst_begin(Ctor), E = inst_end(Ctor); I != E; ++I) {
  4385. if (StoreInst *SI = dyn_cast<StoreInst>(&(*I))) {
  4386. Value *V = SI->getValueOperand();
  4387. if (!isa<Constant>(V) || V->getType()->isPointerTy()) {
  4388. allConst = false;
  4389. break;
  4390. }
  4391. ImmList.emplace_back(cast<Constant>(V));
  4392. Value *Ptr = SI->getPointerOperand();
  4393. if (GEPOperator *GepOp = dyn_cast<GEPOperator>(Ptr)) {
  4394. Ptr = GepOp->getPointerOperand();
  4395. if (GlobalVariable *pGV = dyn_cast<GlobalVariable>(Ptr)) {
  4396. if (GV == nullptr)
  4397. GV = pGV;
  4398. else {
  4399. DXASSERT(GV == pGV, "else pointer mismatch");
  4400. }
  4401. }
  4402. }
  4403. } else {
  4404. if (!isa<ReturnInst>(*I)) {
  4405. allConst = false;
  4406. break;
  4407. }
  4408. }
  4409. }
  4410. if (!allConst)
  4411. return false;
  4412. if (!GV)
  4413. return false;
  4414. llvm::Type *Ty = GV->getType()->getElementType();
  4415. llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty);
  4416. // TODO: support other types.
  4417. if (!AT)
  4418. return false;
  4419. if (ImmList.size() != AT->getNumElements())
  4420. return false;
  4421. Constant *Init = llvm::ConstantArray::get(AT, ImmList);
  4422. GV->setInitializer(Init);
  4423. return true;
  4424. }
// Process the global-constructor list named globalName (llvm.global_ctors
// style array): constructors whose stores fold into immediate initializers
// (BuildImmInit) are dropped; the rest get an explicit call emitted at
// InsertPt.  The ctor-list global itself is then removed.
void ProcessCtorFunctions(llvm::Module &M, StringRef globalName,
                          Instruction *InsertPt) {
  // add global call to entry func
  GlobalVariable *GV = M.getGlobalVariable(globalName);
  if (GV) {
    if (ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer())) {
      IRBuilder<> Builder(InsertPt);
      for (User::op_iterator i = CA->op_begin(), e = CA->op_end(); i != e;
           ++i) {
        if (isa<ConstantAggregateZero>(*i))
          continue;
        // Each entry is a { priority, function, ... } struct; operand 1 is
        // the constructor function.
        ConstantStruct *CS = cast<ConstantStruct>(*i);
        if (isa<ConstantPointerNull>(CS->getOperand(1)))
          continue;
        // Must have a function or null ptr.
        if (!isa<Function>(CS->getOperand(1)))
          continue;
        Function *Ctor = cast<Function>(CS->getOperand(1));
        DXASSERT(Ctor->getReturnType()->isVoidTy() && Ctor->arg_size() == 0,
                 "function type must be void (void)");
        // NOTE(review): BuildImmInit is invoked on the function *called by*
        // the ctor (F), not on the ctor itself — i.e. the ctor is assumed to
        // be a thin wrapper around the real initializer; confirm.
        for (inst_iterator I = inst_begin(Ctor), E = inst_end(Ctor); I != E;
             ++I) {
          if (CallInst *CI = dyn_cast<CallInst>(&(*I))) {
            Function *F = CI->getCalledFunction();
            // Try to build imm initilizer.
            // If not work, add global call to entry func.
            if (BuildImmInit(F) == false) {
              Builder.CreateCall(F);
            }
          } else {
            DXASSERT(isa<ReturnInst>(&(*I)),
                     "else invalid Global constructor function");
          }
        }
      }
      // remove the GV
      GV->eraseFromParent();
    }
  }
}
// Resolve and install the patch-constant function for a hull-shader entry,
// using the patchconstantfunc attribute recorded by AddHLSLFunctionInfo().
void CGMSHLSLRuntime::SetPatchConstantFunction(const EntryFunctionInfo &EntryFunc) {
  auto AttrsIter = HSEntryPatchConstantFuncAttr.find(EntryFunc.Func);
  DXASSERT(AttrsIter != HSEntryPatchConstantFuncAttr.end(),
           "we have checked this in AddHLSLFunctionInfo()");
  // NOTE(review): this forwards the member `Entry`, not the `EntryFunc`
  // parameter used for the lookup above — confirm the two always refer to
  // the same entry at this call site.
  SetPatchConstantFunctionWithAttr(Entry, AttrsIter->second);
}
// Looks up the patch constant function named by the [patchconstantfunc]
// attribute, binds it to the hull-shader entry, and validates it:
// - errors if no function with that name was recorded;
// - warns when multiple overloads exist (the recorded one is used);
// - errors on any inout parameter;
// - errors when its input/output control point counts conflict with the
//   hull shader's own counts.
void CGMSHLSLRuntime::SetPatchConstantFunctionWithAttr(
    const EntryFunctionInfo &EntryFunc,
    const clang::HLSLPatchConstantFuncAttr *PatchConstantFuncAttr) {
  StringRef funcName = PatchConstantFuncAttr->getFunctionName();
  // NOTE: `Entry` here is a map iterator into patchConstantFunctionMap, not
  // the entry-function member of the runtime.
  auto Entry = patchConstantFunctionMap.find(funcName);
  if (Entry == patchConstantFunctionMap.end()) {
    DiagnosticsEngine &Diags = CGM.getDiags();
    unsigned DiagID = Diags.getCustomDiagID(
        DiagnosticsEngine::Error, "Cannot find patchconstantfunc %0.");
    Diags.Report(PatchConstantFuncAttr->getLocation(), DiagID) << funcName;
    return;
  }
  if (Entry->second.NumOverloads != 1) {
    // Ambiguous name: warn, point at the overload that will be used.
    DiagnosticsEngine &Diags = CGM.getDiags();
    unsigned DiagID = Diags.getCustomDiagID(
        DiagnosticsEngine::Warning, "Multiple overloads of patchconstantfunc %0.");
    unsigned NoteID = Diags.getCustomDiagID(DiagnosticsEngine::Note,
                                            "This overload was selected.");
    Diags.Report(PatchConstantFuncAttr->getLocation(), DiagID) << funcName;
    Diags.Report(Entry->second.SL, NoteID);
  }
  Function *patchConstFunc = Entry->second.Func;
  DXASSERT(m_pHLModule->HasDxilFunctionProps(EntryFunc.Func),
           " else AddHLSLFunctionInfo did not save the dxil function props for the "
           "HS entry.");
  DxilFunctionProps *HSProps = &m_pHLModule->GetDxilFunctionProps(EntryFunc.Func);
  m_pHLModule->SetPatchConstantFunctionForHS(EntryFunc.Func, patchConstFunc);
  DXASSERT_NOMSG(patchConstantFunctionPropsMap.count(patchConstFunc));
  // Check no inout parameter for patch constant function.
  DxilFunctionAnnotation *patchConstFuncAnnotation =
      m_pHLModule->GetFunctionAnnotation(patchConstFunc);
  for (unsigned i = 0; i < patchConstFuncAnnotation->GetNumParameters(); i++) {
    if (patchConstFuncAnnotation->GetParameterAnnotation(i)
            .GetParamInputQual() == DxilParamInputQual::Inout) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID = Diags.getCustomDiagID(
          DiagnosticsEngine::Error,
          "Patch Constant function %0 should not have inout param.");
      Diags.Report(Entry->second.SL, DiagID) << funcName;
    }
  }
  // Input/Output control point validation.
  if (patchConstantFunctionPropsMap.count(patchConstFunc)) {
    const DxilFunctionProps &patchProps =
        *patchConstantFunctionPropsMap[patchConstFunc];
    // A count of 0 on the patch constant function side means "unspecified"
    // and is not checked against the HS.
    if (patchProps.ShaderProps.HS.inputControlPoints != 0 &&
        patchProps.ShaderProps.HS.inputControlPoints !=
            HSProps->ShaderProps.HS.inputControlPoints) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID =
          Diags.getCustomDiagID(DiagnosticsEngine::Error,
                                "Patch constant function's input patch input "
                                "should have %0 elements, but has %1.");
      Diags.Report(Entry->second.SL, DiagID)
          << HSProps->ShaderProps.HS.inputControlPoints
          << patchProps.ShaderProps.HS.inputControlPoints;
    }
    if (patchProps.ShaderProps.HS.outputControlPoints != 0 &&
        patchProps.ShaderProps.HS.outputControlPoints !=
            HSProps->ShaderProps.HS.outputControlPoints) {
      DiagnosticsEngine &Diags = CGM.getDiags();
      unsigned DiagID =
          Diags.getCustomDiagID(DiagnosticsEngine::Error,
                                "Patch constant function's output patch input "
                                "should have %0 elements, but has %1.");
      Diags.Report(Entry->second.SL, DiagID)
          << HSProps->ShaderProps.HS.outputControlPoints
          << patchProps.ShaderProps.HS.outputControlPoints;
    }
  }
}
  4547. static void ReportDisallowedTypeInExportParam(CodeGenModule &CGM, StringRef name) {
  4548. DiagnosticsEngine &Diags = CGM.getDiags();
  4549. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  4550. "Exported function %0 must not contain a resource in parameter or return type.");
  4551. std::string escaped;
  4552. llvm::raw_string_ostream os(escaped);
  4553. dxilutil::PrintEscapedString(name, os);
  4554. Diags.Report(DiagID) << os.str();
  4555. }
  4556. // Returns true a global value is being updated
  4557. static bool GlobalHasStoreUserRec(Value *V, std::set<Value *> &visited) {
  4558. bool isWriteEnabled = false;
  4559. if (V && visited.find(V) == visited.end()) {
  4560. visited.insert(V);
  4561. for (User *U : V->users()) {
  4562. if (isa<StoreInst>(U)) {
  4563. return true;
  4564. } else if (CallInst* CI = dyn_cast<CallInst>(U)) {
  4565. Function *F = CI->getCalledFunction();
  4566. if (!F->isIntrinsic()) {
  4567. HLOpcodeGroup hlGroup = GetHLOpcodeGroup(F);
  4568. switch (hlGroup) {
  4569. case HLOpcodeGroup::NotHL:
  4570. return true;
  4571. case HLOpcodeGroup::HLMatLoadStore:
  4572. {
  4573. HLMatLoadStoreOpcode opCode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
  4574. if (opCode == HLMatLoadStoreOpcode::ColMatStore || opCode == HLMatLoadStoreOpcode::RowMatStore)
  4575. return true;
  4576. break;
  4577. }
  4578. case HLOpcodeGroup::HLCast:
  4579. case HLOpcodeGroup::HLSubscript:
  4580. if (GlobalHasStoreUserRec(U, visited))
  4581. return true;
  4582. break;
  4583. default:
  4584. break;
  4585. }
  4586. }
  4587. } else if (isa<GEPOperator>(U) || isa<PHINode>(U) || isa<SelectInst>(U)) {
  4588. if (GlobalHasStoreUserRec(U, visited))
  4589. return true;
  4590. }
  4591. }
  4592. }
  4593. return isWriteEnabled;
  4594. }
  4595. // Returns true if any of the direct user of a global is a store inst
  4596. // otherwise recurse through the remaining users and check if any GEP
  4597. // exists and which in turn has a store inst as user.
  4598. static bool GlobalHasStoreUser(GlobalVariable *GV) {
  4599. std::set<Value *> visited;
  4600. Value *V = cast<Value>(GV);
  4601. return GlobalHasStoreUserRec(V, visited);
  4602. }
  4603. static GlobalVariable *CreateStaticGlobal(llvm::Module *M, GlobalVariable *GV) {
  4604. Constant *GC = M->getOrInsertGlobal(GV->getName().str() + ".static.copy",
  4605. GV->getType()->getPointerElementType());
  4606. GlobalVariable *NGV = cast<GlobalVariable>(GC);
  4607. if (GV->hasInitializer()) {
  4608. NGV->setInitializer(GV->getInitializer());
  4609. } else {
  4610. // The copy being static, it should be initialized per llvm rules
  4611. NGV->setInitializer(Constant::getNullValue(GV->getType()->getPointerElementType()));
  4612. }
  4613. // static global should have internal linkage
  4614. NGV->setLinkage(GlobalValue::InternalLinkage);
  4615. return NGV;
  4616. }
// /Gec back-compat support: for each non-internal, non-const global that is
// written to (and is not an HLSL object or groupshared), creates a writable
// internal ".static.copy", redirects all uses to the copy, and emits a
// memcpy from the (now-const) original into the copy at the start of entry
// function EF.
static void CreateWriteEnabledStaticGlobals(llvm::Module *M,
                                            llvm::Function *EF) {
  std::vector<GlobalVariable *> worklist;
  for (GlobalVariable &GV : M->globals()) {
    if (!GV.isConstant() && GV.getLinkage() != GlobalValue::InternalLinkage &&
        // skip globals which are HLSL objects or group shared
        !dxilutil::IsHLSLObjectType(GV.getType()->getPointerElementType()) &&
        !dxilutil::IsSharedMemoryGlobal(&GV)) {
      if (GlobalHasStoreUser(&GV))
        worklist.emplace_back(&GV);
      // TODO: Ensure that constant globals aren't using initializer
      // The original is marked const even when a writable copy is made; it
      // then serves only as the memcpy source below.
      GV.setConstant(true);
    }
  }
  IRBuilder<> Builder(
      dxilutil::FirstNonAllocaInsertionPt(&EF->getEntryBlock()));
  for (GlobalVariable *GV : worklist) {
    GlobalVariable *NGV = CreateStaticGlobal(M, GV);
    // Redirect all existing uses to the writable copy first, so the memcpy
    // built below is the only remaining use of the original.
    GV->replaceAllUsesWith(NGV);
    // insert memcpy in all entryblocks
    uint64_t size = M->getDataLayout().getTypeAllocSize(
        GV->getType()->getPointerElementType());
    Builder.CreateMemCpy(NGV, GV, size, 1);
  }
}
// Translate RayQuery constructor. From:
//  %call = call %"RayQuery<flags>" @<constructor>(%"RayQuery<flags>" %ptr)
// To:
//  i32 %handle = AllocateRayQuery(i32 <IntrinsicOp::IOP_AllocateRayQuery>, i32 %flags)
//  %gep = GEP %"RayQuery<flags>" %ptr, 0, 0
//  store i32* %gep, i32 %handle
//  ; and replace uses of %call with %ptr
void TranslateRayQueryConstructor(llvm::Module &M) {
  // First pass: collect matching constructor functions; rewriting while
  // iterating M.functions() would invalidate the traversal.
  SmallVector<Function *, 4> Constructors;
  for (auto &F : M.functions()) {
    // Match templated RayQuery constructor instantiation by prefix and
    // signature. It should be impossible to achieve the same signature from
    // HLSL.
    if (!F.getName().startswith("\01??0?$RayQuery@$"))
      continue;
    if (llvm::Type *Ty = F.getReturnType(); false)
      ;
    llvm::Type *Ty = F.getReturnType();
    if (!Ty->isPointerTy() ||
        !dxilutil::IsHLSLRayQueryType(Ty->getPointerElementType()))
      continue;
    // One argument (the object pointer) returned unchanged.
    if (F.arg_size() != 1 || Ty != F.arg_begin()->getType())
      continue;
    Constructors.emplace_back(&F);
  }
  // Second pass: replace every call site of each constructor.
  for (auto pConstructorFunc : Constructors) {
    llvm::IntegerType *i32Ty = llvm::Type::getInt32Ty(M.getContext());
    llvm::ConstantInt *i32Zero =
        llvm::ConstantInt::get(i32Ty, (uint64_t)0, false);
    llvm::FunctionType *funcTy =
        llvm::FunctionType::get(i32Ty, {i32Ty, i32Ty}, false);
    unsigned opcode = (unsigned)IntrinsicOp::IOP_AllocateRayQuery;
    llvm::ConstantInt *opVal = llvm::ConstantInt::get(i32Ty, opcode, false);
    Function *opFunc =
        GetOrCreateHLFunction(M, funcTy, HLOpcodeGroup::HLIntrinsic, opcode);
    // Each iteration erases the call it rewrites, so the user list shrinks
    // until empty.
    while (!pConstructorFunc->user_empty()) {
      Value *V = *pConstructorFunc->user_begin();
      llvm::CallInst *CI = cast<CallInst>(V); // Must be call
      llvm::Value *pThis = CI->getArgOperand(0);
      llvm::StructType *pRQType =
          cast<llvm::StructType>(pThis->getType()->getPointerElementType());
      // The ray flags are the RayQuery type's single integral template arg,
      // recovered from the struct annotation.
      DxilStructAnnotation *SA =
          M.GetHLModule().GetTypeSystem().GetStructAnnotation(pRQType);
      DXASSERT(SA, "otherwise, could not find type annoation for RayQuery "
                   "specialization");
      DXASSERT(SA->GetNumTemplateArgs() == 1 &&
                   SA->GetTemplateArgAnnotation(0).IsIntegral(),
               "otherwise, RayQuery has changed, or lacks template args");
      llvm::IRBuilder<> Builder(CI);
      llvm::Value *rayFlags =
          Builder.getInt32(SA->GetTemplateArgAnnotation(0).GetIntegral());
      llvm::Value *Call =
          Builder.CreateCall(opFunc, {opVal, rayFlags}, pThis->getName());
      // Store the allocated handle into the object's first (i32) field.
      llvm::Value *GEP = Builder.CreateInBoundsGEP(pThis, {i32Zero, i32Zero});
      Builder.CreateStore(Call, GEP);
      CI->replaceAllUsesWith(pThis);
      CI->eraseFromParent();
    }
    pConstructorFunc->eraseFromParent();
  }
}
// Final fix-up pass over the high-level module once all declarations have
// been emitted: resolves the entry (or clones library entries), materializes
// cbuffers and clip-plane copies, rewrites RayQuery constructors and global
// ctors, assigns linkage, processes the export map, validates exported
// signatures, and records root-signature/extension data. Pass ordering here
// is load-bearing; steps are annotated inline.
void CGMSHLSLRuntime::FinishCodeGen() {
  // Library don't have entry.
  if (!m_bIsLib) {
    SetEntryFunction();
    // If at this point we haven't determined the entry function it's an error.
    if (m_pHLModule->GetEntryFunction() == nullptr) {
      assert(CGM.getDiags().hasErrorOccurred() &&
             "else SetEntryFunction should have reported this condition");
      return;
    }
    // In back-compat mode (with /Gec flag) create a static global for each
    // const global to allow writing to it.
    // TODO: Verify the behavior of static globals in hull shader
    if (CGM.getLangOpts().EnableDX9CompatMode &&
        CGM.getLangOpts().HLSLVersion <= 2016)
      CreateWriteEnabledStaticGlobals(m_pHLModule->GetModule(),
                                      m_pHLModule->GetEntryFunction());
    if (m_pHLModule->GetShaderModel()->IsHS()) {
      SetPatchConstantFunction(Entry);
    }
  } else {
    // Library: clone each recorded entry under its export name and bind any
    // hull-shader patch constant functions.
    for (auto &it : entryFunctionMap) {
      // skip clone if RT entry
      if (m_pHLModule->GetDxilFunctionProps(it.second.Func).IsRay())
        continue;
      // TODO: change flattened function names to dx.entry.<name>:
      // std::string entryName = (Twine(dxilutil::EntryPrefix) + it.getKey()).str();
      CloneShaderEntry(it.second.Func, it.getKey(), *m_pHLModule);
      auto AttrIter = HSEntryPatchConstantFuncAttr.find(it.second.Func);
      if (AttrIter != HSEntryPatchConstantFuncAttr.end()) {
        SetPatchConstantFunctionWithAttr(it.second, AttrIter->second);
      }
    }
  }
  ReplaceConstStaticGlobals(staticConstGlobalInitListMap,
                            staticConstGlobalCtorMap);
  // Create copy for clip plane.
  for (Function *F : clipPlaneFuncList) {
    DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(F);
    IRBuilder<> Builder(F->getEntryBlock().getFirstInsertionPt());
    for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
      Value *clipPlane = props.ShaderProps.VS.clipPlanes[i];
      if (!clipPlane)
        continue;
      if (m_bDebugInfo) {
        Builder.SetCurrentDebugLocation(debugInfoMap[clipPlane]);
      }
      llvm::Type *Ty = clipPlane->getType()->getPointerElementType();
      // Constant *zeroInit = ConstantFP::get(Ty, 0);
      GlobalVariable *GV = new llvm::GlobalVariable(
          TheModule, Ty, /*IsConstant*/ false, // constant false to store.
          llvm::GlobalValue::ExternalLinkage,
          /*InitVal*/ nullptr, Twine("SV_ClipPlane") + Twine(i));
      // Copy the current clip-plane value into the new global and retarget
      // the props at the global.
      Value *initVal = Builder.CreateLoad(clipPlane);
      Builder.CreateStore(initVal, GV);
      props.ShaderProps.VS.clipPlanes[i] = GV;
    }
  }
  // Add Reg bindings for resource in cb.
  AddRegBindingsForResourceInConstantBuffer(m_pHLModule, constantRegBindingMap);
  // Allocate constant buffers.
  AllocateDxilConstantBuffers(m_pHLModule, m_ConstVarAnnotationMap);
  // TODO: create temp variable for constant which has store use.
  // Create Global variable and type annotation for each CBuffer.
  ConstructCBuffer(m_pHLModule, CBufferType, m_ConstVarAnnotationMap);
  // Translate calls to RayQuery constructor into hl Allocate calls
  TranslateRayQueryConstructor(*m_pHLModule->GetModule());
  if (!m_bIsLib) {
    // need this for "llvm.global_dtors"?
    ProcessCtorFunctions(TheModule, "llvm.global_ctors",
                         Entry.Func->getEntryBlock().getFirstInsertionPt());
  }
  // translate opcode into parameter for intrinsic functions
  AddOpcodeParamForIntrinsics(*m_pHLModule, m_IntrinsicMap, resMetadataMap);
  // Register patch constant functions referenced by exported Hull Shaders
  if (m_bIsLib && !m_ExportMap.empty()) {
    for (auto &it : entryFunctionMap) {
      if (m_pHLModule->HasDxilFunctionProps(it.second.Func)) {
        const DxilFunctionProps &props =
            m_pHLModule->GetDxilFunctionProps(it.second.Func);
        if (props.IsHS())
          m_ExportMap.RegisterExportedFunction(
              props.ShaderProps.HS.patchConstantFunc);
      }
    }
  }
  // Pin entry point and constant buffers, mark everything else internal.
  for (Function &f : m_pHLModule->GetModule()->functions()) {
    if (!m_bIsLib) {
      if (&f == m_pHLModule->GetEntryFunction() ||
          IsPatchConstantFunction(&f) || f.isDeclaration()) {
        // A remaining external declaration that is neither an LLVM intrinsic
        // nor an HL helper cannot be resolved in a non-library profile.
        if (f.isDeclaration() && !f.isIntrinsic() &&
            GetHLOpcodeGroup(&f) == HLOpcodeGroup::NotHL) {
          DiagnosticsEngine &Diags = CGM.getDiags();
          unsigned DiagID = Diags.getCustomDiagID(
              DiagnosticsEngine::Error,
              "External function used in non-library profile: %0");
          std::string escaped;
          llvm::raw_string_ostream os(escaped);
          dxilutil::PrintEscapedString(f.getName(), os);
          Diags.Report(DiagID) << os.str();
          return;
        }
        f.setLinkage(GlobalValue::LinkageTypes::ExternalLinkage);
      } else {
        f.setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
      }
    }
    // Skip no inline functions.
    if (f.hasFnAttribute(llvm::Attribute::NoInline))
      continue;
    // Always inline for used functions.
    if (!f.user_empty() && !f.isDeclaration())
      f.addFnAttr(llvm::Attribute::AlwaysInline);
  }
  // Apply explicit -exports: collect candidate functions, then report any
  // collisions or exports with no matching target.
  if (m_bIsLib && !m_ExportMap.empty()) {
    m_ExportMap.BeginProcessing();
    for (Function &f : m_pHLModule->GetModule()->functions()) {
      if (f.isDeclaration() || f.isIntrinsic() ||
          GetHLOpcodeGroup(&f) != HLOpcodeGroup::NotHL)
        continue;
      m_ExportMap.ProcessFunction(&f, true);
    }
    // TODO: add subobject export names here.
    if (!m_ExportMap.EndProcessing()) {
      for (auto &name : m_ExportMap.GetNameCollisions()) {
        DiagnosticsEngine &Diags = CGM.getDiags();
        unsigned DiagID = Diags.getCustomDiagID(
            DiagnosticsEngine::Error,
            "Export name collides with another export: %0");
        std::string escaped;
        llvm::raw_string_ostream os(escaped);
        dxilutil::PrintEscapedString(name, os);
        Diags.Report(DiagID) << os.str();
      }
      for (auto &name : m_ExportMap.GetUnusedExports()) {
        DiagnosticsEngine &Diags = CGM.getDiags();
        unsigned DiagID = Diags.getCustomDiagID(
            DiagnosticsEngine::Error, "Could not find target for export: %0");
        std::string escaped;
        llvm::raw_string_ostream os(escaped);
        dxilutil::PrintEscapedString(name, os);
        Diags.Report(DiagID) << os.str();
      }
    }
  }
  // Apply renames from the export map: the original function takes one name,
  // additional names are handled by cloning.
  for (auto &it : m_ExportMap.GetFunctionRenames()) {
    Function *F = it.first;
    auto &renames = it.second;
    if (renames.empty())
      continue;
    // Rename the original, if necessary, then clone the rest
    if (renames.find(F->getName()) == renames.end())
      F->setName(*renames.begin());
    for (auto &itName : renames) {
      if (F->getName() != itName) {
        Function *pClone = CloneFunction(F, itName, m_pHLModule->GetModule(),
                                         m_pHLModule->GetTypeSystem(),
                                         m_pHLModule->GetTypeSystem());
        // add DxilFunctionProps if entry
        if (m_pHLModule->HasDxilFunctionProps(F)) {
          DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(F);
          auto newProps = llvm::make_unique<DxilFunctionProps>(props);
          m_pHLModule->AddDxilFunctionProps(pClone, newProps);
        }
      }
    }
  }
  if (CGM.getCodeGenOpts().ExportShadersOnly) {
    for (Function &f : m_pHLModule->GetModule()->functions()) {
      // Skip declarations, intrinsics, shaders, and non-external linkage
      if (f.isDeclaration() || f.isIntrinsic() ||
          GetHLOpcodeGroup(&f) != HLOpcodeGroup::NotHL ||
          m_pHLModule->HasDxilFunctionProps(&f) ||
          m_pHLModule->IsPatchConstantShader(&f) ||
          f.getLinkage() != GlobalValue::LinkageTypes::ExternalLinkage)
        continue;
      // Mark non-shader user functions as InternalLinkage
      f.setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
    }
  }
  // Now iterate hull shaders and make sure their corresponding patch constant
  // functions are marked ExternalLinkage:
  for (Function &f : m_pHLModule->GetModule()->functions()) {
    if (f.isDeclaration() || f.isIntrinsic() ||
        GetHLOpcodeGroup(&f) != HLOpcodeGroup::NotHL ||
        f.getLinkage() != GlobalValue::LinkageTypes::ExternalLinkage ||
        !m_pHLModule->HasDxilFunctionProps(&f))
      continue;
    DxilFunctionProps &props = m_pHLModule->GetDxilFunctionProps(&f);
    if (!props.IsHS())
      continue;
    Function *PCFunc = props.ShaderProps.HS.patchConstantFunc;
    if (PCFunc->getLinkage() != GlobalValue::LinkageTypes::ExternalLinkage)
      PCFunc->setLinkage(GlobalValue::LinkageTypes::ExternalLinkage);
  }
  // Disallow resource arguments in (non-entry) function exports
  // unless offline linking target.
  if (m_bIsLib && m_pHLModule->GetShaderModel()->GetMinor() !=
                      ShaderModel::kOfflineMinor) {
    for (Function &f : m_pHLModule->GetModule()->functions()) {
      // Skip llvm intrinsics, non-external linkage, entry/patch constant
      // func, and HL intrinsics
      if (!f.isIntrinsic() &&
          f.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage &&
          !m_pHLModule->HasDxilFunctionProps(&f) &&
          !m_pHLModule->IsPatchConstantShader(&f) &&
          GetHLOpcodeGroup(&f) == HLOpcodeGroup::NotHL) {
        // Verify no resources in param/return types
        if (dxilutil::ContainsHLSLObjectType(f.getReturnType())) {
          ReportDisallowedTypeInExportParam(CGM, f.getName());
          continue;
        }
        for (auto &Arg : f.args()) {
          if (dxilutil::ContainsHLSLObjectType(Arg.getType())) {
            ReportDisallowedTypeInExportParam(CGM, f.getName());
            break;
          }
        }
      }
    }
  }
  // Do simple transform to make later lower pass easier.
  SimpleTransformForHLDXIR(m_pHLModule->GetModule());
  // Handle lang extensions if provided.
  if (CGM.getCodeGenOpts().HLSLExtensionsCodegen) {
    // Add semantic defines for extensions if any are available.
    HLSLExtensionsCodegenHelper::SemanticDefineErrorList errors =
        CGM.getCodeGenOpts().HLSLExtensionsCodegen->WriteSemanticDefines(
            m_pHLModule->GetModule());
    DiagnosticsEngine &Diags = CGM.getDiags();
    for (const HLSLExtensionsCodegenHelper::SemanticDefineError &error :
         errors) {
      DiagnosticsEngine::Level level = DiagnosticsEngine::Error;
      if (error.IsWarning())
        level = DiagnosticsEngine::Warning;
      unsigned DiagID = Diags.getCustomDiagID(level, "%0");
      Diags.Report(SourceLocation::getFromRawEncoding(error.Location()),
                   DiagID)
          << error.Message();
    }
    // Add root signature from a #define. Overrides root signature in
    // function attribute.
    {
      using Status = HLSLExtensionsCodegenHelper::CustomRootSignature::Status;
      HLSLExtensionsCodegenHelper::CustomRootSignature customRootSig;
      Status status =
          CGM.getCodeGenOpts().HLSLExtensionsCodegen->GetCustomRootSignature(
              &customRootSig);
      if (status == Status::FOUND) {
        RootSignatureHandle RootSigHandle;
        CompileRootSignature(
            customRootSig.RootSignature, Diags,
            SourceLocation::getFromRawEncoding(
                customRootSig.EncodedSourceLocation),
            rootSigVer, DxilRootSignatureCompilationFlags::GlobalRootSignature,
            &RootSigHandle);
        if (!RootSigHandle.IsEmpty()) {
          RootSigHandle.EnsureSerializedAvailable();
          m_pHLModule->SetSerializedRootSignature(
              RootSigHandle.GetSerializedBytes(),
              RootSigHandle.GetSerializedSize());
        }
      }
    }
  }
  // At this point, we have a high-level DXIL module - record this.
  SetPauseResumePasses(*m_pHLModule->GetModule(), "hlsl-hlemit",
                       "hlsl-hlensure");
}
  4941. RValue CGMSHLSLRuntime::EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF,
  4942. const FunctionDecl *FD,
  4943. const CallExpr *E,
  4944. ReturnValueSlot ReturnValue) {
  4945. const Decl *TargetDecl = E->getCalleeDecl();
  4946. llvm::Value *Callee = CGF.EmitScalarExpr(E->getCallee());
  4947. RValue RV = CGF.EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue,
  4948. TargetDecl);
  4949. if (RV.isScalar() && RV.getScalarVal() != nullptr) {
  4950. if (CallInst *CI = dyn_cast<CallInst>(RV.getScalarVal())) {
  4951. Function *F = CI->getCalledFunction();
  4952. HLOpcodeGroup group = hlsl::GetHLOpcodeGroup(F);
  4953. if (group == HLOpcodeGroup::HLIntrinsic) {
  4954. bool allOperandImm = true;
  4955. for (auto &operand : CI->arg_operands()) {
  4956. bool isImm = isa<ConstantInt>(operand) || isa<ConstantFP>(operand) ||
  4957. isa<ConstantAggregateZero>(operand) || isa<ConstantDataVector>(operand);
  4958. if (!isImm) {
  4959. allOperandImm = false;
  4960. break;
  4961. } else if (operand->getType()->isHalfTy()) {
  4962. // Not support half Eval yet.
  4963. allOperandImm = false;
  4964. break;
  4965. }
  4966. }
  4967. if (allOperandImm) {
  4968. unsigned intrinsicOpcode;
  4969. StringRef intrinsicGroup;
  4970. hlsl::GetIntrinsicOp(FD, intrinsicOpcode, intrinsicGroup);
  4971. IntrinsicOp opcode = static_cast<IntrinsicOp>(intrinsicOpcode);
  4972. if (Value *Result = TryEvalIntrinsic(CI, opcode)) {
  4973. RV = RValue::get(Result);
  4974. }
  4975. }
  4976. }
  4977. }
  4978. }
  4979. return RV;
  4980. }
  4981. static HLOpcodeGroup GetHLOpcodeGroup(const clang::Stmt::StmtClass stmtClass) {
  4982. switch (stmtClass) {
  4983. case Stmt::CStyleCastExprClass:
  4984. case Stmt::ImplicitCastExprClass:
  4985. case Stmt::CXXFunctionalCastExprClass:
  4986. return HLOpcodeGroup::HLCast;
  4987. case Stmt::InitListExprClass:
  4988. return HLOpcodeGroup::HLInit;
  4989. case Stmt::BinaryOperatorClass:
  4990. case Stmt::CompoundAssignOperatorClass:
  4991. return HLOpcodeGroup::HLBinOp;
  4992. case Stmt::UnaryOperatorClass:
  4993. return HLOpcodeGroup::HLUnOp;
  4994. case Stmt::ExtMatrixElementExprClass:
  4995. return HLOpcodeGroup::HLSubscript;
  4996. case Stmt::CallExprClass:
  4997. return HLOpcodeGroup::HLIntrinsic;
  4998. case Stmt::ConditionalOperatorClass:
  4999. return HLOpcodeGroup::HLSelect;
  5000. default:
  5001. llvm_unreachable("not support operation");
  5002. }
  5003. }
// NOTE: This table must match BinaryOperator::Opcode
// Indexed positionally by clang::BinaryOperatorKind; the trailing comment on
// each entry names the clang opcode that slot corresponds to.
static const HLBinaryOpcode BinaryOperatorKindMap[] = {
    HLBinaryOpcode::Invalid, // PtrMemD
    HLBinaryOpcode::Invalid, // PtrMemI
    HLBinaryOpcode::Mul,     // Mul
    HLBinaryOpcode::Div,     // Div
    HLBinaryOpcode::Rem,     // Rem
    HLBinaryOpcode::Add,     // Add
    HLBinaryOpcode::Sub,     // Sub
    HLBinaryOpcode::Shl,     // Shl
    HLBinaryOpcode::Shr,     // Shr
    HLBinaryOpcode::LT,      // LT
    HLBinaryOpcode::GT,      // GT
    HLBinaryOpcode::LE,      // LE
    HLBinaryOpcode::GE,      // GE
    HLBinaryOpcode::EQ,      // EQ
    HLBinaryOpcode::NE,      // NE
    HLBinaryOpcode::And,     // And
    HLBinaryOpcode::Xor,     // Xor
    HLBinaryOpcode::Or,      // Or
    HLBinaryOpcode::LAnd,    // LAnd
    HLBinaryOpcode::LOr,     // LOr
    HLBinaryOpcode::Invalid, // Assign,
    // The assign part is done by matrix store
    HLBinaryOpcode::Mul, // MulAssign
    HLBinaryOpcode::Div, // DivAssign
    HLBinaryOpcode::Rem, // RemAssign
    HLBinaryOpcode::Add, // AddAssign
    HLBinaryOpcode::Sub, // SubAssign
    HLBinaryOpcode::Shl, // ShlAssign
    HLBinaryOpcode::Shr, // ShrAssign
    HLBinaryOpcode::And, // AndAssign
    HLBinaryOpcode::Xor, // XorAssign
    HLBinaryOpcode::Or,  // OrAssign
    HLBinaryOpcode::Invalid, // Comma
};
// NOTE: This table must match UnaryOperator::Opcode
// Indexed positionally by clang::UnaryOperatorKind; Invalid marks operators
// that have no HL counterpart.
static const HLUnaryOpcode UnaryOperatorKindMap[] = {
    HLUnaryOpcode::PostInc, // PostInc
    HLUnaryOpcode::PostDec, // PostDec
    HLUnaryOpcode::PreInc,  // PreInc
    HLUnaryOpcode::PreDec,  // PreDec
    HLUnaryOpcode::Invalid, // AddrOf,
    HLUnaryOpcode::Invalid, // Deref,
    HLUnaryOpcode::Plus,    // Plus
    HLUnaryOpcode::Minus,   // Minus
    HLUnaryOpcode::Not,     // Not
    HLUnaryOpcode::LNot,    // LNot
    HLUnaryOpcode::Invalid, // Real,
    HLUnaryOpcode::Invalid, // Imag,
    HLUnaryOpcode::Invalid, // Extension
};
  5040. static unsigned GetHLOpcode(const Expr *E) {
  5041. switch (E->getStmtClass()) {
  5042. case Stmt::CompoundAssignOperatorClass:
  5043. case Stmt::BinaryOperatorClass: {
  5044. const clang::BinaryOperator *binOp = cast<clang::BinaryOperator>(E);
  5045. HLBinaryOpcode binOpcode = BinaryOperatorKindMap[binOp->getOpcode()];
  5046. if (HasUnsignedOpcode(binOpcode)) {
  5047. if (hlsl::IsHLSLUnsigned(binOp->getLHS()->getType())) {
  5048. binOpcode = GetUnsignedOpcode(binOpcode);
  5049. }
  5050. }
  5051. return static_cast<unsigned>(binOpcode);
  5052. }
  5053. case Stmt::UnaryOperatorClass: {
  5054. const UnaryOperator *unOp = cast<clang::UnaryOperator>(E);
  5055. HLUnaryOpcode unOpcode = UnaryOperatorKindMap[unOp->getOpcode()];
  5056. return static_cast<unsigned>(unOpcode);
  5057. }
  5058. case Stmt::ImplicitCastExprClass:
  5059. case Stmt::CStyleCastExprClass: {
  5060. const CastExpr *CE = cast<CastExpr>(E);
  5061. bool toUnsigned = hlsl::IsHLSLUnsigned(E->getType());
  5062. bool fromUnsigned = hlsl::IsHLSLUnsigned(CE->getSubExpr()->getType());
  5063. if (toUnsigned && fromUnsigned)
  5064. return static_cast<unsigned>(HLCastOpcode::UnsignedUnsignedCast);
  5065. else if (toUnsigned)
  5066. return static_cast<unsigned>(HLCastOpcode::ToUnsignedCast);
  5067. else if (fromUnsigned)
  5068. return static_cast<unsigned>(HLCastOpcode::FromUnsignedCast);
  5069. else
  5070. return static_cast<unsigned>(HLCastOpcode::DefaultCast);
  5071. }
  5072. default:
  5073. return 0;
  5074. }
  5075. }
  5076. static Value *
  5077. EmitHLSLMatrixOperationCallImp(CGBuilderTy &Builder, HLOpcodeGroup group,
  5078. unsigned opcode, llvm::Type *RetType,
  5079. ArrayRef<Value *> paramList, llvm::Module &M) {
  5080. SmallVector<llvm::Type *, 4> paramTyList;
  5081. // Add the opcode param
  5082. llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  5083. paramTyList.emplace_back(opcodeTy);
  5084. for (Value *param : paramList) {
  5085. paramTyList.emplace_back(param->getType());
  5086. }
  5087. llvm::FunctionType *funcTy =
  5088. llvm::FunctionType::get(RetType, paramTyList, false);
  5089. Function *opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
  5090. SmallVector<Value *, 4> opcodeParamList;
  5091. Value *opcodeConst = Constant::getIntegerValue(opcodeTy, APInt(32, opcode));
  5092. opcodeParamList.emplace_back(opcodeConst);
  5093. opcodeParamList.append(paramList.begin(), paramList.end());
  5094. return Builder.CreateCall(opFunc, opcodeParamList);
  5095. }
// Emits an array initialization. paramList[0] is the destination array
// pointer and the remaining params are the element values. When the element
// count and types line up exactly, direct GEP+store pairs are emitted and
// nullptr is returned (the result is unused); otherwise, and for matrix
// inits (non-void RetType), an HL init call is emitted for a later lowering
// pass to expand.
static Value *EmitHLSLArrayInit(CGBuilderTy &Builder, HLOpcodeGroup group,
                                unsigned opcode, llvm::Type *RetType,
                                ArrayRef<Value *> paramList, llvm::Module &M) {
  // It's a matrix init.
  if (!RetType->isVoidTy())
    return EmitHLSLMatrixOperationCallImp(Builder, group, opcode, RetType,
                                          paramList, M);
  Value *arrayPtr = paramList[0];
  llvm::ArrayType *AT =
      cast<llvm::ArrayType>(arrayPtr->getType()->getPointerElementType());
  // Avoid the arrayPtr.
  unsigned paramSize = paramList.size() - 1;
  // Support simple case here.
  if (paramSize == AT->getArrayNumElements()) {
    bool typeMatch = true;
    llvm::Type *EltTy = AT->getArrayElementType();
    if (EltTy->isAggregateType()) {
      // Aggregate Type use pointer in initList.
      EltTy = llvm::PointerType::get(EltTy, 0);
    }
    // Every element value must match the (possibly pointer-adjusted)
    // element type for the direct store path.
    for (unsigned i = 1; i < paramList.size(); i++) {
      if (paramList[i]->getType() != EltTy) {
        typeMatch = false;
        break;
      }
    }
    // Both size and type match.
    if (typeMatch) {
      bool isPtr = EltTy->isPointerTy();
      llvm::Type *i32Ty = llvm::Type::getInt32Ty(EltTy->getContext());
      Constant *zero = ConstantInt::get(i32Ty, 0);
      for (unsigned i = 1; i < paramList.size(); i++) {
        // paramList index i corresponds to array element i-1.
        Constant *idx = ConstantInt::get(i32Ty, i - 1);
        Value *GEP = Builder.CreateInBoundsGEP(arrayPtr, {zero, idx});
        Value *Elt = paramList[i];
        if (isPtr) {
          // Aggregate elements arrive as pointers; load the value to store.
          Elt = Builder.CreateLoad(Elt);
        }
        Builder.CreateStore(Elt, GEP);
      }
      // The return value will not be used.
      return nullptr;
    }
  }
  // Other case will be lowered in later pass.
  return EmitHLSLMatrixOperationCallImp(Builder, group, opcode, RetType,
                                        paramList, M);
}
// Recursively flattens |val| (whose HLSL type is |Ty|) into parallel lists of
// scalar register values (|elts|) and their HLSL element types (|eltTys|).
// Pointers to aggregates are walked via GEPs; matrices are cast to a flat
// vector first; HLSL objects are kept whole as a single loaded element.
void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Value *, 4> &elts,
                                           SmallVector<QualType, 4> &eltTys,
                                           QualType Ty, Value *val) {
  CGBuilderTy &Builder = CGF.Builder;
  llvm::Type *valTy = val->getType();
  if (valTy->isPointerTy()) {
    llvm::Type *valEltTy = valTy->getPointerElementType();
    if (valEltTy->isVectorTy() ||
        valEltTy->isSingleValueType()) {
      // Scalar or vector in memory: load it and flatten the register value.
      Value *ldVal = Builder.CreateLoad(val);
      FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
    } else if (HLMatrixType::isa(valEltTy)) {
      // Matrix in memory: use the HL matrix load so orientation is handled.
      Value *ldVal = EmitHLSLMatrixLoad(Builder, val, Ty);
      FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
    } else {
      llvm::Type *i32Ty = llvm::Type::getInt32Ty(valTy->getContext());
      Value *zero = ConstantInt::get(i32Ty, 0);
      if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(valEltTy)) {
        // Array: recurse into each element through a GEP.
        QualType EltTy = Ty->getAsArrayTypeUnsafe()->getElementType();
        for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
          Value *gepIdx = ConstantInt::get(i32Ty, i);
          Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
          FlattenValToInitList(CGF, elts, eltTys, EltTy, EltPtr);
        }
      } else {
        // Struct.
        StructType *ST = cast<StructType>(valEltTy);
        if (dxilutil::IsHLSLObjectType(ST)) {
          // Save object directly like basic type.
          elts.emplace_back(Builder.CreateLoad(val));
          eltTys.emplace_back(Ty);
        } else {
          RecordDecl *RD = Ty->getAsStructureType()->getDecl();
          const CGRecordLayout& RL = CGF.getTypes().getCGRecordLayout(RD);
          // Take care base: flatten non-empty base-class subobjects first,
          // using the record layout to find their LLVM field indices.
          if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
            if (CXXRD->getNumBases()) {
              for (const auto &I : CXXRD->bases()) {
                const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
                    I.getType()->castAs<RecordType>()->getDecl());
                if (BaseDecl->field_empty())
                  continue;
                QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
                unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
                Value *gepIdx = ConstantInt::get(i32Ty, i);
                Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
                FlattenValToInitList(CGF, elts, eltTys, parentTy, EltPtr);
              }
            }
          }
          // Then flatten the record's own fields in declaration order.
          for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
               fieldIter != fieldEnd; ++fieldIter) {
            unsigned i = RL.getLLVMFieldNo(*fieldIter);
            Value *gepIdx = ConstantInt::get(i32Ty, i);
            Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
            FlattenValToInitList(CGF, elts, eltTys, fieldIter->getType(), EltPtr);
          }
        }
      }
    }
  } else {
    // Register value. Matrices are lowered to a vector first.
    if (HLMatrixType MatTy = HLMatrixType::dyn_cast(valTy)) {
      llvm::Type *EltTy = MatTy.getElementTypeForReg();
      // All matrix Value should be row major.
      // Init list is row major in scalar.
      // So the order is match here, just cast to vector.
      unsigned matSize = MatTy.getNumElements();
      bool isRowMajor = hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
      HLCastOpcode opcode = isRowMajor ? HLCastOpcode::RowMatrixToVecCast
                                       : HLCastOpcode::ColMatrixToVecCast;
      // Cast to vector.
      val = EmitHLSLMatrixOperationCallImp(
          Builder, HLOpcodeGroup::HLCast,
          static_cast<unsigned>(opcode),
          llvm::VectorType::get(EltTy, matSize), {val}, TheModule);
      valTy = val->getType();
    }
    if (valTy->isVectorTy()) {
      // Scalarize the vector into individual extracted elements.
      QualType EltTy = hlsl::GetElementTypeOrType(Ty);
      unsigned vecSize = valTy->getVectorNumElements();
      for (unsigned i = 0; i < vecSize; i++) {
        Value *Elt = Builder.CreateExtractElement(val, i);
        elts.emplace_back(Elt);
        eltTys.emplace_back(EltTy);
      }
    } else {
      DXASSERT(valTy->isSingleValueType(), "must be single value type here");
      elts.emplace_back(val);
      eltTys.emplace_back(Ty);
    }
  }
}
  5236. static Value* ConvertScalarOrVector(CGBuilderTy& Builder, CodeGenTypes &Types,
  5237. Value *Val, QualType SrcQualTy, QualType DstQualTy) {
  5238. llvm::Type *SrcTy = Val->getType();
  5239. llvm::Type *DstTy = Types.ConvertType(DstQualTy);
  5240. DXASSERT(Val->getType() == Types.ConvertType(SrcQualTy), "QualType/Value mismatch!");
  5241. DXASSERT((SrcTy->isIntOrIntVectorTy() || SrcTy->isFPOrFPVectorTy())
  5242. && (DstTy->isIntOrIntVectorTy() || DstTy->isFPOrFPVectorTy()),
  5243. "EmitNumericConversion can only be used with int/float scalars/vectors.");
  5244. if (SrcTy == DstTy) return Val; // Valid no-op, including uint to int / int to uint
  5245. DXASSERT(SrcTy->isVectorTy()
  5246. ? (DstTy->isVectorTy() && SrcTy->getVectorNumElements() == DstTy->getVectorNumElements())
  5247. : !DstTy->isVectorTy(),
  5248. "EmitNumericConversion can only cast between scalars or vectors of matching sizes");
  5249. // Conversions to bools are comparisons
  5250. if (DstTy->getScalarSizeInBits() == 1) {
  5251. // fcmp une is what regular clang uses in C++ for (bool)f;
  5252. return SrcTy->isIntOrIntVectorTy()
  5253. ? Builder.CreateICmpNE(Val, llvm::Constant::getNullValue(SrcTy), "tobool")
  5254. : Builder.CreateFCmpUNE(Val, llvm::Constant::getNullValue(SrcTy), "tobool");
  5255. }
  5256. // Cast necessary
  5257. auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
  5258. SrcTy, hlsl::IsHLSLUnsigned(SrcQualTy), DstTy, hlsl::IsHLSLUnsigned(DstQualTy)));
  5259. return Builder.CreateCast(CastOp, Val, DstTy);
  5260. }
  5261. static Value* ConvertScalarOrVector(CodeGenFunction &CGF,
  5262. Value *Val, QualType SrcQualTy, QualType DstQualTy) {
  5263. return ConvertScalarOrVector(CGF.Builder, CGF.getTypes(), Val, SrcQualTy, DstQualTy);
  5264. }
  5265. // Cast elements in initlist if not match the target type.
  5266. // idx is current element index in initlist, Ty is target type.
  5267. // TODO: Stop handling missing cast here. Handle the casting of non-scalar values
  5268. // to their destination type in init list expressions at AST level.
  5269. static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVector<QualType, 4> &eltTys, unsigned &idx, QualType Ty, CodeGenFunction &CGF) {
  5270. if (Ty->isArrayType()) {
  5271. const clang::ArrayType *AT = Ty->getAsArrayTypeUnsafe();
  5272. // Must be ConstantArrayType here.
  5273. unsigned arraySize = cast<ConstantArrayType>(AT)->getSize().getLimitedValue();
  5274. QualType EltTy = AT->getElementType();
  5275. for (unsigned i = 0; i < arraySize; i++)
  5276. AddMissingCastOpsInInitList(elts, eltTys, idx, EltTy, CGF);
  5277. } else if (IsHLSLVecType(Ty)) {
  5278. QualType EltTy = GetHLSLVecElementType(Ty);
  5279. unsigned vecSize = GetHLSLVecSize(Ty);
  5280. for (unsigned i=0;i< vecSize;i++)
  5281. AddMissingCastOpsInInitList(elts, eltTys, idx, EltTy, CGF);
  5282. } else if (IsHLSLMatType(Ty)) {
  5283. QualType EltTy = GetHLSLMatElementType(Ty);
  5284. unsigned row, col;
  5285. GetHLSLMatRowColCount(Ty, row, col);
  5286. unsigned matSize = row*col;
  5287. for (unsigned i = 0; i < matSize; i++)
  5288. AddMissingCastOpsInInitList(elts, eltTys, idx, EltTy, CGF);
  5289. } else if (Ty->isRecordType()) {
  5290. if (dxilutil::IsHLSLObjectType(CGF.ConvertType(Ty))) {
  5291. // Skip hlsl object.
  5292. idx++;
  5293. } else {
  5294. const RecordType *RT = Ty->getAsStructureType();
  5295. // For CXXRecord.
  5296. if (!RT)
  5297. RT = Ty->getAs<RecordType>();
  5298. RecordDecl *RD = RT->getDecl();
  5299. // Take care base.
  5300. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  5301. if (CXXRD->getNumBases()) {
  5302. for (const auto &I : CXXRD->bases()) {
  5303. const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
  5304. I.getType()->castAs<RecordType>()->getDecl());
  5305. if (BaseDecl->field_empty())
  5306. continue;
  5307. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  5308. AddMissingCastOpsInInitList(elts, eltTys, idx, parentTy, CGF);
  5309. }
  5310. }
  5311. }
  5312. for (FieldDecl *field : RD->fields())
  5313. AddMissingCastOpsInInitList(elts, eltTys, idx, field->getType(), CGF);
  5314. }
  5315. }
  5316. else {
  5317. // Basic type.
  5318. elts[idx] = ConvertScalarOrVector(CGF, elts[idx], eltTys[idx], Ty);
  5319. idx++;
  5320. }
  5321. }
// Stores the flattened init-list values |elts| (starting at |idx|) into the
// memory at |DestPtr|, recursing over the structure of |Type|. |idx| is
// advanced past every element consumed. Matrix stores go through the HL
// matrix intrinsics so row/column-major orientation is respected.
static void StoreInitListToDestPtr(Value *DestPtr,
                                   SmallVector<Value *, 4> &elts, unsigned &idx,
                                   QualType Type, bool bDefaultRowMajor,
                                   CodeGenFunction &CGF, llvm::Module &M) {
  CodeGenTypes &Types = CGF.getTypes();
  CGBuilderTy &Builder = CGF.Builder;
  llvm::Type *Ty = DestPtr->getType()->getPointerElementType();
  if (Ty->isVectorTy()) {
    // Build the full vector in a register, then store it with one store.
    llvm::Type *RegTy = CGF.ConvertType(Type);
    Value *Result = UndefValue::get(RegTy);
    for (unsigned i = 0; i < RegTy->getVectorNumElements(); i++)
      Result = Builder.CreateInsertElement(Result, elts[idx + i], i);
    // Convert to the memory representation (e.g. bool i1 -> i32) before storing.
    Result = CGF.EmitToMemory(Result, Type);
    Builder.CreateStore(Result, DestPtr);
    idx += Ty->getVectorNumElements();
  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
    bool isRowMajor = hlsl::IsHLSLMatRowMajor(Type, bDefaultRowMajor);
    // Copy the matrix elements for the HLInit call.
    std::vector<Value *> matInitList(MatTy.getNumElements());
    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
        unsigned matIdx = c * MatTy.getNumRows() + r;
        matInitList[matIdx] = elts[idx + matIdx];
      }
    }
    idx += MatTy.getNumElements();
    Value *matVal =
        EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLInit,
                                       /*opcode*/ 0, Ty, matInitList, M);
    // matVal return from HLInit is row major.
    // If DestPtr is row major, just store it directly.
    if (!isRowMajor) {
      // ColMatStore need a col major value.
      // Cast row major matrix into col major.
      // Then store it.
      Value *colMatVal = EmitHLSLMatrixOperationCallImp(
          Builder, HLOpcodeGroup::HLCast,
          static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix), Ty,
          {matVal}, M);
      EmitHLSLMatrixOperationCallImp(
          Builder, HLOpcodeGroup::HLMatLoadStore,
          static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
          {DestPtr, colMatVal}, M);
    } else {
      EmitHLSLMatrixOperationCallImp(
          Builder, HLOpcodeGroup::HLMatLoadStore,
          static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore), Ty,
          {DestPtr, matVal}, M);
    }
  } else if (Ty->isStructTy()) {
    if (dxilutil::IsHLSLObjectType(Ty)) {
      // HLSL objects are stored whole as one element.
      Builder.CreateStore(elts[idx], DestPtr);
      idx++;
    } else {
      Constant *zero = Builder.getInt32(0);
      const RecordType *RT = Type->getAsStructureType();
      // For CXXRecord.
      if (!RT)
        RT = Type->getAs<RecordType>();
      RecordDecl *RD = RT->getDecl();
      const CGRecordLayout &RL = Types.getCGRecordLayout(RD);
      // Take care base: store non-empty base-class subobjects first, at the
      // LLVM field index the record layout assigns them.
      if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
        if (CXXRD->getNumBases()) {
          for (const auto &I : CXXRD->bases()) {
            const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
                I.getType()->castAs<RecordType>()->getDecl());
            if (BaseDecl->field_empty())
              continue;
            QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
            unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
            Constant *gepIdx = Builder.getInt32(i);
            Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
            StoreInitListToDestPtr(GEP, elts, idx, parentTy,
                                   bDefaultRowMajor, CGF, M);
          }
        }
      }
      // Then the record's own fields in declaration order.
      for (FieldDecl *field : RD->fields()) {
        unsigned i = RL.getLLVMFieldNo(field);
        Constant *gepIdx = Builder.getInt32(i);
        Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
        StoreInitListToDestPtr(GEP, elts, idx, field->getType(),
                               bDefaultRowMajor, CGF, M);
      }
    }
  } else if (Ty->isArrayTy()) {
    // Array: recurse into each element through a GEP.
    Constant *zero = Builder.getInt32(0);
    QualType EltType = Type->getAsArrayTypeUnsafe()->getElementType();
    for (unsigned i = 0; i < Ty->getArrayNumElements(); i++) {
      Constant *gepIdx = Builder.getInt32(i);
      Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
      StoreInitListToDestPtr(GEP, elts, idx, EltType, bDefaultRowMajor,
                             CGF, M);
    }
  } else {
    DXASSERT(Ty->isSingleValueType(), "invalid type");
    // Leaf scalar. A bool in register form (i1) must be widened to the
    // destination's memory representation before the store.
    llvm::Type *i1Ty = Builder.getInt1Ty();
    Value *V = elts[idx];
    if (V->getType() == i1Ty &&
        DestPtr->getType()->getPointerElementType() != i1Ty) {
      V = Builder.CreateZExt(V, DestPtr->getType()->getPointerElementType());
    }
    Builder.CreateStore(V, DestPtr);
    idx++;
  }
}
  5428. void CGMSHLSLRuntime::ScanInitList(CodeGenFunction &CGF, InitListExpr *E,
  5429. SmallVector<Value *, 4> &EltValList,
  5430. SmallVector<QualType, 4> &EltTyList) {
  5431. unsigned NumInitElements = E->getNumInits();
  5432. for (unsigned i = 0; i != NumInitElements; ++i) {
  5433. Expr *init = E->getInit(i);
  5434. QualType iType = init->getType();
  5435. if (InitListExpr *initList = dyn_cast<InitListExpr>(init)) {
  5436. ScanInitList(CGF, initList, EltValList, EltTyList);
  5437. } else if (CodeGenFunction::hasScalarEvaluationKind(iType)) {
  5438. llvm::Value *initVal = CGF.EmitScalarExpr(init);
  5439. FlattenValToInitList(CGF, EltValList, EltTyList, iType, initVal);
  5440. } else {
  5441. AggValueSlot Slot =
  5442. CGF.CreateAggTemp(init->getType(), "Agg.InitList.tmp");
  5443. CGF.EmitAggExpr(init, Slot);
  5444. llvm::Value *aggPtr = Slot.getAddr();
  5445. FlattenValToInitList(CGF, EltValList, EltTyList, iType, aggPtr);
  5446. }
  5447. }
  5448. }
  5449. // Is Type of E match Ty.
  5450. static bool ExpTypeMatch(Expr *E, QualType Ty, ASTContext &Ctx, CodeGenTypes &Types) {
  5451. if (InitListExpr *initList = dyn_cast<InitListExpr>(E)) {
  5452. unsigned NumInitElements = initList->getNumInits();
  5453. // Skip vector and matrix type.
  5454. if (Ty->isVectorType())
  5455. return false;
  5456. if (hlsl::IsHLSLVecMatType(Ty))
  5457. return false;
  5458. if (Ty->isStructureOrClassType()) {
  5459. RecordDecl *record = Ty->castAs<RecordType>()->getDecl();
  5460. bool bMatch = true;
  5461. unsigned i = 0;
  5462. for (auto it = record->field_begin(), end = record->field_end();
  5463. it != end; it++) {
  5464. if (i == NumInitElements) {
  5465. bMatch = false;
  5466. break;
  5467. }
  5468. Expr *init = initList->getInit(i++);
  5469. QualType EltTy = it->getType();
  5470. bMatch &= ExpTypeMatch(init, EltTy, Ctx, Types);
  5471. if (!bMatch)
  5472. break;
  5473. }
  5474. bMatch &= i == NumInitElements;
  5475. if (bMatch && initList->getType()->isVoidType()) {
  5476. initList->setType(Ty);
  5477. }
  5478. return bMatch;
  5479. } else if (Ty->isArrayType() && !Ty->isIncompleteArrayType()) {
  5480. const ConstantArrayType *AT = Ctx.getAsConstantArrayType(Ty);
  5481. QualType EltTy = AT->getElementType();
  5482. unsigned size = AT->getSize().getZExtValue();
  5483. if (size != NumInitElements)
  5484. return false;
  5485. bool bMatch = true;
  5486. for (unsigned i = 0; i != NumInitElements; ++i) {
  5487. Expr *init = initList->getInit(i);
  5488. bMatch &= ExpTypeMatch(init, EltTy, Ctx, Types);
  5489. if (!bMatch)
  5490. break;
  5491. }
  5492. if (bMatch && initList->getType()->isVoidType()) {
  5493. initList->setType(Ty);
  5494. }
  5495. return bMatch;
  5496. } else {
  5497. return false;
  5498. }
  5499. } else {
  5500. llvm::Type *ExpTy = Types.ConvertType(E->getType());
  5501. llvm::Type *TargetTy = Types.ConvertType(Ty);
  5502. return ExpTy == TargetTy;
  5503. }
  5504. }
  5505. bool CGMSHLSLRuntime::IsTrivalInitListExpr(CodeGenFunction &CGF,
  5506. InitListExpr *E) {
  5507. QualType Ty = E->getType();
  5508. bool result = ExpTypeMatch(E, Ty, CGF.getContext(), CGF.getTypes());
  5509. if (result) {
  5510. auto iter = staticConstGlobalInitMap.find(E);
  5511. if (iter != staticConstGlobalInitMap.end()) {
  5512. GlobalVariable * GV = iter->second;
  5513. auto &InitConstants = staticConstGlobalInitListMap[GV];
  5514. // Add Constant to InitList.
  5515. for (unsigned i=0;i<E->getNumInits();i++) {
  5516. Expr *Expr = E->getInit(i);
  5517. if (ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(Expr)) {
  5518. if (Cast->getCastKind() == CK_LValueToRValue) {
  5519. Expr = Cast->getSubExpr();
  5520. }
  5521. }
  5522. // Only do this on lvalue, if not lvalue, it will not be constant
  5523. // anyway.
  5524. if (Expr->isLValue()) {
  5525. LValue LV = CGF.EmitLValue(Expr);
  5526. if (LV.isSimple()) {
  5527. Constant *SrcPtr = dyn_cast<Constant>(LV.getAddress());
  5528. if (SrcPtr && !isa<UndefValue>(SrcPtr)) {
  5529. InitConstants.emplace_back(SrcPtr);
  5530. continue;
  5531. }
  5532. }
  5533. }
  5534. // Only support simple LV and Constant Ptr case.
  5535. // Other case just go normal path.
  5536. InitConstants.clear();
  5537. break;
  5538. }
  5539. if (InitConstants.empty())
  5540. staticConstGlobalInitListMap.erase(GV);
  5541. else
  5542. staticConstGlobalCtorMap[GV] = CGF.CurFn;
  5543. }
  5544. }
  5545. return result;
  5546. }
  5547. Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E,
  5548. // The destPtr when emiting aggregate init, for normal case, it will be null.
  5549. Value *DestPtr) {
  5550. if (DestPtr && E->getNumInits() == 1) {
  5551. llvm::Type *ExpTy = CGF.ConvertType(E->getType());
  5552. llvm::Type *TargetTy = CGF.ConvertType(E->getInit(0)->getType());
  5553. if (ExpTy == TargetTy) {
  5554. Expr *Expr = E->getInit(0);
  5555. LValue LV = CGF.EmitLValue(Expr);
  5556. if (LV.isSimple()) {
  5557. Value *SrcPtr = LV.getAddress();
  5558. SmallVector<Value *, 4> idxList;
  5559. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, Expr->getType(),
  5560. E->getType(), SrcPtr->getType());
  5561. return nullptr;
  5562. }
  5563. }
  5564. }
  5565. SmallVector<Value *, 4> EltValList;
  5566. SmallVector<QualType, 4> EltTyList;
  5567. ScanInitList(CGF, E, EltValList, EltTyList);
  5568. QualType ResultTy = E->getType();
  5569. unsigned idx = 0;
  5570. // Create cast if need.
  5571. AddMissingCastOpsInInitList(EltValList, EltTyList, idx, ResultTy, CGF);
  5572. DXASSERT(idx == EltValList.size(), "size must match");
  5573. llvm::Type *RetTy = CGF.ConvertType(ResultTy);
  5574. if (DestPtr) {
  5575. SmallVector<Value *, 4> ParamList;
  5576. DXASSERT_NOMSG(RetTy->isAggregateType());
  5577. ParamList.emplace_back(DestPtr);
  5578. ParamList.append(EltValList.begin(), EltValList.end());
  5579. idx = 0;
  5580. bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  5581. StoreInitListToDestPtr(DestPtr, EltValList, idx, ResultTy,
  5582. bDefaultRowMajor, CGF, TheModule);
  5583. return nullptr;
  5584. }
  5585. if (IsHLSLVecType(ResultTy)) {
  5586. Value *Result = UndefValue::get(RetTy);
  5587. for (unsigned i = 0; i < RetTy->getVectorNumElements(); i++)
  5588. Result = CGF.Builder.CreateInsertElement(Result, EltValList[i], i);
  5589. return Result;
  5590. } else {
  5591. // Must be matrix here.
  5592. DXASSERT(IsHLSLMatType(ResultTy), "must be matrix type here.");
  5593. return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLInit,
  5594. /*opcode*/ 0, RetTy, EltValList,
  5595. TheModule);
  5596. }
  5597. }
// Recursively flattens constant |C| (whose HLSL type is |QualTy|) into
// parallel lists of scalar constants and their QualTypes, in memory packing
// order (matrix elements are reordered per the matrix orientation).
static void FlatConstToList(CodeGenTypes &Types, bool bDefaultRowMajor,
                            Constant *C, QualType QualTy,
                            SmallVectorImpl<Constant *> &EltVals, SmallVectorImpl<QualType> &EltQualTys) {
  llvm::Type *Ty = C->getType();
  DXASSERT(Types.ConvertTypeForMem(QualTy) == Ty, "QualType/Type mismatch!");
  if (llvm::VectorType *VecTy = dyn_cast<llvm::VectorType>(Ty)) {
    DXASSERT(hlsl::IsHLSLVecType(QualTy), "QualType/Type mismatch!");
    // Vector: emit each component with the vector's element QualType.
    QualType VecElemQualTy = hlsl::GetHLSLVecElementType(QualTy);
    for (unsigned i = 0; i < VecTy->getNumElements(); i++) {
      EltVals.emplace_back(C->getAggregateElement(i));
      EltQualTys.emplace_back(VecElemQualTy);
    }
  } else if (HLMatrixType::isa(Ty)) {
    DXASSERT(hlsl::IsHLSLMatType(QualTy), "QualType/Type mismatch!");
    // matrix type is struct { [rowcount x <colcount x T>] };
    // Strip the struct level here.
    Constant *RowArrayVal = C->getAggregateElement((unsigned)0);
    QualType MatEltQualTy = hlsl::GetHLSLMatElementType(QualTy);
    unsigned RowCount, ColCount;
    hlsl::GetHLSLMatRowColCount(QualTy, RowCount, ColCount);
    // Get all the elements from the array of row vectors.
    // Matrices are never in memory representation so convert as needed.
    SmallVector<Constant *, 16> MatElts;
    for (unsigned r = 0; r < RowCount; ++r) {
      Constant *RowVec = RowArrayVal->getAggregateElement(r);
      for (unsigned c = 0; c < ColCount; ++c) {
        Constant *MatElt = RowVec->getAggregateElement(c);
        if (MatEltQualTy->isBooleanType()) {
          DXASSERT(MatElt->getType()->isIntegerTy(1),
                   "Matrix elements should be in their register representation.");
          // Widen the i1 register bool to its memory representation.
          MatElt = llvm::ConstantExpr::getZExt(MatElt, Types.ConvertTypeForMem(MatEltQualTy));
        }
        MatElts.emplace_back(MatElt);
      }
    }
    // Return the elements in the order respecting the orientation.
    // Constant initializers are used as the initial value for static variables,
    // which live in memory. This is why they have to respect memory packing order.
    bool IsRowMajor = hlsl::IsHLSLMatRowMajor(QualTy, bDefaultRowMajor);
    for (unsigned r = 0; r < RowCount; ++r) {
      for (unsigned c = 0; c < ColCount; ++c) {
        unsigned Idx = IsRowMajor ? (r * ColCount + c) : (c * RowCount + r);
        EltVals.emplace_back(MatElts[Idx]);
        EltQualTys.emplace_back(MatEltQualTy);
      }
    }
  }
  else if (const clang::ConstantArrayType *ClangArrayTy = Types.getContext().getAsConstantArrayType(QualTy)) {
    // Array: recurse into each element constant.
    QualType ArrayEltQualTy = ClangArrayTy->getElementType();
    uint64_t ArraySize = ClangArrayTy->getSize().getLimitedValue();
    DXASSERT(cast<llvm::ArrayType>(Ty)->getArrayNumElements() == ArraySize, "QualType/Type mismatch!");
    for (unsigned i = 0; i < ArraySize; i++) {
      FlatConstToList(Types, bDefaultRowMajor, C->getAggregateElement(i), ArrayEltQualTy,
                      EltVals, EltQualTys);
    }
  }
  else if (const clang::RecordType* RecordTy = QualTy->getAs<clang::RecordType>()) {
    DXASSERT(dyn_cast<llvm::StructType>(Ty) != nullptr, "QualType/Type mismatch!");
    RecordDecl *RecordDecl = RecordTy->getDecl();
    const CGRecordLayout &RL = Types.getCGRecordLayout(RecordDecl);
    // Take care base: flatten non-empty base-class subobjects first, found
    // via their LLVM field index in the record layout.
    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RecordDecl)) {
      if (CXXRD->getNumBases()) {
        for (const auto &I : CXXRD->bases()) {
          const CXXRecordDecl *BaseDecl =
              cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
          if (BaseDecl->field_empty())
            continue;
          QualType BaseQualTy = QualType(BaseDecl->getTypeForDecl(), 0);
          unsigned BaseFieldIdx = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
          FlatConstToList(Types, bDefaultRowMajor,
                          C->getAggregateElement(BaseFieldIdx), BaseQualTy, EltVals, EltQualTys);
        }
      }
    }
    // Then the record's own fields in declaration order.
    for (auto FieldIt = RecordDecl->field_begin(), fieldEnd = RecordDecl->field_end();
         FieldIt != fieldEnd; ++FieldIt) {
      unsigned FieldIndex = RL.getLLVMFieldNo(*FieldIt);
      FlatConstToList(Types, bDefaultRowMajor,
                      C->getAggregateElement(FieldIndex), FieldIt->getType(), EltVals, EltQualTys);
    }
  }
  else {
    // At this point, we should have scalars in their memory representation
    DXASSERT_NOMSG(QualTy->isBuiltinType());
    EltVals.emplace_back(C);
    EltQualTys.emplace_back(QualTy);
  }
}
  5687. static bool ScanConstInitList(CodeGenModule &CGM, bool bDefaultRowMajor,
  5688. InitListExpr *InitList,
  5689. SmallVectorImpl<Constant *> &EltVals,
  5690. SmallVectorImpl<QualType> &EltQualTys) {
  5691. unsigned NumInitElements = InitList->getNumInits();
  5692. for (unsigned i = 0; i != NumInitElements; ++i) {
  5693. Expr *InitExpr = InitList->getInit(i);
  5694. QualType InitQualTy = InitExpr->getType();
  5695. if (InitListExpr *SubInitList = dyn_cast<InitListExpr>(InitExpr)) {
  5696. if (!ScanConstInitList(CGM, bDefaultRowMajor, SubInitList, EltVals, EltQualTys))
  5697. return false;
  5698. } else if (DeclRefExpr *DeclRef = dyn_cast<DeclRefExpr>(InitExpr)) {
  5699. if (VarDecl *Var = dyn_cast<VarDecl>(DeclRef->getDecl())) {
  5700. if (!Var->hasInit())
  5701. return false;
  5702. if (Constant *InitVal = CGM.EmitConstantInit(*Var)) {
  5703. FlatConstToList(CGM.getTypes(), bDefaultRowMajor,
  5704. InitVal, InitQualTy, EltVals, EltQualTys);
  5705. } else {
  5706. return false;
  5707. }
  5708. } else {
  5709. return false;
  5710. }
  5711. } else if (hlsl::IsHLSLMatType(InitQualTy)) {
  5712. return false;
  5713. } else if (CodeGenFunction::hasScalarEvaluationKind(InitQualTy)) {
  5714. if (Constant *InitVal = CGM.EmitConstantExpr(InitExpr, InitQualTy)) {
  5715. FlatConstToList(CGM.getTypes(), bDefaultRowMajor, InitVal, InitQualTy, EltVals, EltQualTys);
  5716. } else {
  5717. return false;
  5718. }
  5719. } else {
  5720. return false;
  5721. }
  5722. }
  5723. return true;
  5724. }
// Forward declaration: builds one constant of type |QualTy| by consuming
// scalar constants from EltVals/EltQualTys starting at EltIdx (advanced as
// elements are consumed). |MemRepr| selects memory vs register representation
// for bool leaves.
static Constant *BuildConstInitializer(CodeGenTypes &Types, bool bDefaultRowMajor,
                                       QualType QualTy, bool MemRepr,
                                       SmallVectorImpl<Constant *> &EltVals, SmallVectorImpl<QualType> &EltQualTys, unsigned &EltIdx);
  5728. static Constant *BuildConstMatrix(CodeGenTypes &Types, bool bDefaultRowMajor, QualType QualTy,
  5729. SmallVectorImpl<Constant *> &EltVals, SmallVectorImpl<QualType> &EltQualTys, unsigned &EltIdx) {
  5730. QualType MatEltTy = hlsl::GetHLSLMatElementType(QualTy);
  5731. unsigned RowCount, ColCount;
  5732. hlsl::GetHLSLMatRowColCount(QualTy, RowCount, ColCount);
  5733. bool IsRowMajor = hlsl::IsHLSLMatRowMajor(QualTy, bDefaultRowMajor);
  5734. // Save initializer elements first.
  5735. // Matrix initializer is row major.
  5736. SmallVector<Constant *, 16> RowMajorMatElts;
  5737. for (unsigned i = 0; i < RowCount * ColCount; i++) {
  5738. // Matrix elements are never in their memory representation,
  5739. // to preserve type information for later lowering.
  5740. bool MemRepr = false;
  5741. RowMajorMatElts.emplace_back(BuildConstInitializer(
  5742. Types, bDefaultRowMajor, MatEltTy, MemRepr,
  5743. EltVals, EltQualTys, EltIdx));
  5744. }
  5745. SmallVector<Constant *, 16> FinalMatElts;
  5746. if (IsRowMajor) {
  5747. FinalMatElts = RowMajorMatElts;
  5748. }
  5749. else {
  5750. // Cast row major to col major.
  5751. for (unsigned c = 0; c < ColCount; c++) {
  5752. for (unsigned r = 0; r < RowCount; r++) {
  5753. FinalMatElts.emplace_back(RowMajorMatElts[r * ColCount + c]);
  5754. }
  5755. }
  5756. }
  5757. // The type is vector<element, col>[row].
  5758. SmallVector<Constant *, 4> Rows;
  5759. unsigned idx = 0;
  5760. for (unsigned r = 0; r < RowCount; r++) {
  5761. SmallVector<Constant *, 4> RowElts;
  5762. for (unsigned c = 0; c < ColCount; c++) {
  5763. RowElts.emplace_back(FinalMatElts[idx++]);
  5764. }
  5765. Rows.emplace_back(llvm::ConstantVector::get(RowElts));
  5766. }
  5767. Constant *RowArray = llvm::ConstantArray::get(
  5768. llvm::ArrayType::get(Rows[0]->getType(), Rows.size()), Rows);
  5769. return llvm::ConstantStruct::get(cast<llvm::StructType>(Types.ConvertType(QualTy)), RowArray);
  5770. }
  5771. static Constant *BuildConstStruct(CodeGenTypes &Types, bool bDefaultRowMajor, QualType QualTy,
  5772. SmallVectorImpl<Constant *> &EltVals, SmallVectorImpl<QualType> &EltQualTys, unsigned &EltIdx) {
  5773. const RecordDecl *Record = QualTy->castAs<RecordType>()->getDecl();
  5774. bool MemRepr = true; // Structs are always in their memory representation
  5775. SmallVector<Constant *, 4> FieldVals;
  5776. if (const CXXRecordDecl *CXXRecord = dyn_cast<CXXRecordDecl>(Record)) {
  5777. if (CXXRecord->getNumBases()) {
  5778. // Add base as field.
  5779. for (const auto &BaseSpec : CXXRecord->bases()) {
  5780. const CXXRecordDecl *BaseDecl =
  5781. cast<CXXRecordDecl>(BaseSpec.getType()->castAs<RecordType>()->getDecl());
  5782. // Skip empty struct.
  5783. if (BaseDecl->field_empty())
  5784. continue;
  5785. // Add base as a whole constant. Not as element.
  5786. FieldVals.emplace_back(BuildConstInitializer(Types, bDefaultRowMajor,
  5787. BaseSpec.getType(), MemRepr, EltVals, EltQualTys, EltIdx));
  5788. }
  5789. }
  5790. }
  5791. for (auto FieldIt = Record->field_begin(), FieldEnd = Record->field_end();
  5792. FieldIt != FieldEnd; ++FieldIt) {
  5793. FieldVals.emplace_back(BuildConstInitializer(Types, bDefaultRowMajor,
  5794. FieldIt->getType(), MemRepr, EltVals, EltQualTys, EltIdx));
  5795. }
  5796. return llvm::ConstantStruct::get(cast<llvm::StructType>(Types.ConvertTypeForMem(QualTy)), FieldVals);
  5797. }
// Builds one constant of type |QualTy| by consuming scalar constants from
// EltVals/EltQualTys at |EltIdx|, recursing for vectors, arrays, matrices and
// structs. |MemRepr| selects the memory representation for bool leaves
// (i32-like zext) versus the register representation (i1).
static Constant *BuildConstInitializer(CodeGenTypes &Types, bool bDefaultRowMajor,
                                       QualType QualTy, bool MemRepr,
                                       SmallVectorImpl<Constant *> &EltVals, SmallVectorImpl<QualType> &EltQualTys, unsigned &EltIdx) {
  if (hlsl::IsHLSLVecType(QualTy)) {
    // Vector: build each component, then assemble the constant vector.
    QualType VecEltQualTy = hlsl::GetHLSLVecElementType(QualTy);
    unsigned VecSize = hlsl::GetHLSLVecSize(QualTy);
    SmallVector<Constant *, 4> VecElts;
    for (unsigned i = 0; i < VecSize; i++) {
      VecElts.emplace_back(BuildConstInitializer(Types, bDefaultRowMajor,
                                                 VecEltQualTy, MemRepr,
                                                 EltVals, EltQualTys, EltIdx));
    }
    return llvm::ConstantVector::get(VecElts);
  }
  else if (const clang::ConstantArrayType *ArrayTy = Types.getContext().getAsConstantArrayType(QualTy)) {
    // Array: build each element, then assemble the constant array.
    QualType ArrayEltQualTy = QualType(ArrayTy->getArrayElementTypeNoTypeQual(), 0);
    uint64_t ArraySize = ArrayTy->getSize().getLimitedValue();
    SmallVector<Constant *, 4> ArrayElts;
    for (unsigned i = 0; i < ArraySize; i++) {
      ArrayElts.emplace_back(BuildConstInitializer(Types, bDefaultRowMajor,
                                                   ArrayEltQualTy, true, // Array elements must be in their memory representation
                                                   EltVals, EltQualTys, EltIdx));
    }
    return llvm::ConstantArray::get(
        cast<llvm::ArrayType>(Types.ConvertTypeForMem(QualTy)), ArrayElts);
  }
  else if (hlsl::IsHLSLMatType(QualTy)) {
    return BuildConstMatrix(Types, bDefaultRowMajor, QualTy,
                            EltVals, EltQualTys, EltIdx);
  }
  else if (QualTy->getAs<clang::RecordType>() != nullptr) {
    return BuildConstStruct(Types, bDefaultRowMajor, QualTy,
                            EltVals, EltQualTys, EltIdx);
  } else {
    // Leaf scalar: consume one element from the flattened list.
    DXASSERT_NOMSG(QualTy->isBuiltinType());
    Constant *EltVal = EltVals[EltIdx];
    QualType EltQualTy = EltQualTys[EltIdx];
    EltIdx++;
    // Initializer constants are in their memory representation.
    if (EltQualTy == QualTy && MemRepr) return EltVal;
    CGBuilderTy Builder(EltVal->getContext());
    if (EltQualTy->isBooleanType()) {
      // Convert to register representation
      // We don't have access to CodeGenFunction::EmitFromMemory here
      DXASSERT_NOMSG(!EltVal->getType()->isIntegerTy(1));
      EltVal = cast<Constant>(Builder.CreateICmpNE(EltVal, Constant::getNullValue(EltVal->getType())));
    }
    Constant *Result = cast<Constant>(ConvertScalarOrVector(Builder, Types, EltVal, EltQualTy, QualTy));
    if (QualTy->isBooleanType() && MemRepr) {
      // Convert back to the memory representation
      // We don't have access to CodeGenFunction::EmitToMemory here
      DXASSERT_NOMSG(Result->getType()->isIntegerTy(1));
      Result = cast<Constant>(Builder.CreateZExt(Result, Types.ConvertTypeForMem(QualTy)));
    }
    return Result;
  }
}
  5855. Constant *CGMSHLSLRuntime::EmitHLSLConstInitListExpr(CodeGenModule &CGM,
  5856. InitListExpr *E) {
  5857. bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  5858. SmallVector<Constant *, 4> EltVals;
  5859. SmallVector<QualType, 4> EltQualTys;
  5860. if (!ScanConstInitList(CGM, bDefaultRowMajor, E, EltVals, EltQualTys))
  5861. return nullptr;
  5862. QualType QualTy = E->getType();
  5863. unsigned EltIdx = 0;
  5864. bool MemRepr = true;
  5865. return BuildConstInitializer(CGM.getTypes(), bDefaultRowMajor,
  5866. QualTy, MemRepr, EltVals, EltQualTys, EltIdx);
  5867. }
  5868. Value *CGMSHLSLRuntime::EmitHLSLMatrixOperationCall(
  5869. CodeGenFunction &CGF, const clang::Expr *E, llvm::Type *RetType,
  5870. ArrayRef<Value *> paramList) {
  5871. HLOpcodeGroup group = GetHLOpcodeGroup(E->getStmtClass());
  5872. unsigned opcode = GetHLOpcode(E);
  5873. if (group == HLOpcodeGroup::HLInit)
  5874. return EmitHLSLArrayInit(CGF.Builder, group, opcode, RetType, paramList,
  5875. TheModule);
  5876. else
  5877. return EmitHLSLMatrixOperationCallImp(CGF.Builder, group, opcode, RetType,
  5878. paramList, TheModule);
  5879. }
  5880. void CGMSHLSLRuntime::EmitHLSLDiscard(CodeGenFunction &CGF) {
  5881. EmitHLSLMatrixOperationCallImp(
  5882. CGF.Builder, HLOpcodeGroup::HLIntrinsic,
  5883. static_cast<unsigned>(IntrinsicOp::IOP_clip),
  5884. llvm::Type::getVoidTy(CGF.getLLVMContext()),
  5885. {ConstantFP::get(llvm::Type::getFloatTy(CGF.getLLVMContext()), -1.0f)},
  5886. TheModule);
  5887. }
  5888. static llvm::Type *MergeIntType(llvm::IntegerType *T0, llvm::IntegerType *T1) {
  5889. if (T0->getBitWidth() > T1->getBitWidth())
  5890. return T0;
  5891. else
  5892. return T1;
  5893. }
  5894. static Value *CreateExt(CGBuilderTy &Builder, Value *Src, llvm::Type *DstTy,
  5895. bool bSigned) {
  5896. if (bSigned)
  5897. return Builder.CreateSExt(Src, DstTy);
  5898. else
  5899. return Builder.CreateZExt(Src, DstTy);
  5900. }
  5901. // For integer literal, try to get lowest precision.
  5902. static Value *CalcHLSLLiteralToLowestPrecision(CGBuilderTy &Builder, Value *Src,
  5903. bool bSigned) {
  5904. if (ConstantInt *CI = dyn_cast<ConstantInt>(Src)) {
  5905. APInt v = CI->getValue();
  5906. switch (v.getActiveWords()) {
  5907. case 4:
  5908. return Builder.getInt32(v.getLimitedValue());
  5909. case 8:
  5910. return Builder.getInt64(v.getLimitedValue());
  5911. case 2:
  5912. // TODO: use low precision type when support it in dxil.
  5913. // return Builder.getInt16(v.getLimitedValue());
  5914. return Builder.getInt32(v.getLimitedValue());
  5915. case 1:
  5916. // TODO: use precision type when support it in dxil.
  5917. // return Builder.getInt8(v.getLimitedValue());
  5918. return Builder.getInt32(v.getLimitedValue());
  5919. default:
  5920. return nullptr;
  5921. }
  5922. } else if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
  5923. if (SI->getType()->isIntegerTy()) {
  5924. Value *T = SI->getTrueValue();
  5925. Value *F = SI->getFalseValue();
  5926. Value *lowT = CalcHLSLLiteralToLowestPrecision(Builder, T, bSigned);
  5927. Value *lowF = CalcHLSLLiteralToLowestPrecision(Builder, F, bSigned);
  5928. if (lowT && lowF && lowT != T && lowF != F) {
  5929. llvm::IntegerType *TTy = cast<llvm::IntegerType>(lowT->getType());
  5930. llvm::IntegerType *FTy = cast<llvm::IntegerType>(lowF->getType());
  5931. llvm::Type *Ty = MergeIntType(TTy, FTy);
  5932. if (TTy != Ty) {
  5933. lowT = CreateExt(Builder, lowT, Ty, bSigned);
  5934. }
  5935. if (FTy != Ty) {
  5936. lowF = CreateExt(Builder, lowF, Ty, bSigned);
  5937. }
  5938. Value *Cond = SI->getCondition();
  5939. return Builder.CreateSelect(Cond, lowT, lowF);
  5940. }
  5941. }
  5942. } else if (llvm::BinaryOperator *BO = dyn_cast<llvm::BinaryOperator>(Src)) {
  5943. Value *Src0 = BO->getOperand(0);
  5944. Value *Src1 = BO->getOperand(1);
  5945. Value *CastSrc0 = CalcHLSLLiteralToLowestPrecision(Builder, Src0, bSigned);
  5946. Value *CastSrc1 = CalcHLSLLiteralToLowestPrecision(Builder, Src1, bSigned);
  5947. if (Src0 != CastSrc0 && Src1 != CastSrc1 && CastSrc0 && CastSrc1 &&
  5948. CastSrc0->getType() == CastSrc1->getType()) {
  5949. llvm::IntegerType *Ty0 = cast<llvm::IntegerType>(CastSrc0->getType());
  5950. llvm::IntegerType *Ty1 = cast<llvm::IntegerType>(CastSrc0->getType());
  5951. llvm::Type *Ty = MergeIntType(Ty0, Ty1);
  5952. if (Ty0 != Ty) {
  5953. CastSrc0 = CreateExt(Builder, CastSrc0, Ty, bSigned);
  5954. }
  5955. if (Ty1 != Ty) {
  5956. CastSrc1 = CreateExt(Builder, CastSrc1, Ty, bSigned);
  5957. }
  5958. return Builder.CreateBinOp(BO->getOpcode(), CastSrc0, CastSrc1);
  5959. }
  5960. }
  5961. return nullptr;
  5962. }
// Lower a cast of an HLSL literal (precision-less) value Src from SrcType to
// DstType. Integer/float constants are folded directly to constants of the
// destination type; selects over constants are rebuilt at the destination
// type; integer binary operators are re-evaluated at the lowest usable
// precision. Returns nullptr when no literal-specific lowering applies, in
// which case the caller falls back to the normal cast path.
Value *CGMSHLSLRuntime::EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src,
                                            QualType SrcType,
                                            QualType DstType) {
  auto &Builder = CGF.Builder;
  llvm::Type *DstTy = CGF.ConvertType(DstType);
  bool bSrcSigned = SrcType->isSignedIntegerType();
  if (ConstantInt *CI = dyn_cast<ConstantInt>(Src)) {
    APInt v = CI->getValue();
    if (llvm::IntegerType *IT = dyn_cast<llvm::IntegerType>(DstTy)) {
      // Integer literal -> integer: truncate to the destination width and
      // fold to a constant of that width.
      v = v.trunc(IT->getBitWidth());
      switch (IT->getBitWidth()) {
      case 32:
        return Builder.getInt32(v.getLimitedValue());
      case 64:
        return Builder.getInt64(v.getLimitedValue());
      case 16:
        return Builder.getInt16(v.getLimitedValue());
      case 8:
        return Builder.getInt8(v.getLimitedValue());
      default:
        return nullptr;
      }
    } else {
      DXASSERT_NOMSG(DstTy->isFloatingPointTy());
      // Integer literal -> float: recover the signed value, then fold for
      // float/double; other FP widths emit a runtime conversion instead.
      int64_t val = v.getLimitedValue();
      if (v.isNegative())
        val = 0 - v.abs().getLimitedValue();
      if (DstTy->isDoubleTy())
        return ConstantFP::get(DstTy, (double)val);
      else if (DstTy->isFloatTy())
        return ConstantFP::get(DstTy, (float)val);
      else {
        if (bSrcSigned)
          return Builder.CreateSIToFP(Src, DstTy);
        else
          return Builder.CreateUIToFP(Src, DstTy);
      }
    }
  } else if (ConstantFP *CF = dyn_cast<ConstantFP>(Src)) {
    APFloat v = CF->getValueAPF();
    double dv = v.convertToDouble();
    if (llvm::IntegerType *IT = dyn_cast<llvm::IntegerType>(DstTy)) {
      // Float literal -> integer: fold via C++ double-to-integer conversion.
      switch (IT->getBitWidth()) {
      case 32:
        return Builder.getInt32(dv);
      case 64:
        return Builder.getInt64(dv);
      case 16:
        return Builder.getInt16(dv);
      case 8:
        return Builder.getInt8(dv);
      default:
        return nullptr;
      }
    } else {
      // Float literal -> float: fold when the target is 'float'; other FP
      // targets are handled with a truncating conversion instruction.
      if (DstTy->isFloatTy()) {
        float fv = dv;
        return ConstantFP::get(DstTy->getContext(), APFloat(fv));
      } else {
        return Builder.CreateFPTrunc(Src, DstTy);
      }
    }
  } else if (dyn_cast<UndefValue>(Src)) {
    // Undef casts to undef of the destination type.
    return UndefValue::get(DstTy);
  } else {
    Instruction *I = cast<Instruction>(Src);
    if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
      // Select over two constants: rebuild the select with both arms folded
      // to the destination type so the literal select stays a select.
      Value *T = SI->getTrueValue();
      Value *F = SI->getFalseValue();
      Value *Cond = SI->getCondition();
      if (isa<llvm::ConstantInt>(T) && isa<llvm::ConstantInt>(F)) {
        llvm::APInt lhs = cast<llvm::ConstantInt>(T)->getValue();
        llvm::APInt rhs = cast<llvm::ConstantInt>(F)->getValue();
        if (DstTy == Builder.getInt32Ty()) {
          T = Builder.getInt32(lhs.getLimitedValue());
          F = Builder.getInt32(rhs.getLimitedValue());
          Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
          return Sel;
        } else if (DstTy->isFloatingPointTy()) {
          T = ConstantFP::get(DstTy, int64_t(lhs.getLimitedValue()));
          F = ConstantFP::get(DstTy, int64_t(rhs.getLimitedValue()));
          Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
          return Sel;
        }
      } else if (isa<llvm::ConstantFP>(T) && isa<llvm::ConstantFP>(F)) {
        llvm::APFloat lhs = cast<llvm::ConstantFP>(T)->getValueAPF();
        llvm::APFloat rhs = cast<llvm::ConstantFP>(F)->getValueAPF();
        double ld = lhs.convertToDouble();
        double rd = rhs.convertToDouble();
        if (DstTy->isFloatTy()) {
          float lf = ld;
          float rf = rd;
          T = ConstantFP::get(DstTy->getContext(), APFloat(lf));
          F = ConstantFP::get(DstTy->getContext(), APFloat(rf));
          Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
          return Sel;
        } else if (DstTy == Builder.getInt32Ty()) {
          T = Builder.getInt32(ld);
          F = Builder.getInt32(rd);
          Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
          return Sel;
        } else if (DstTy == Builder.getInt64Ty()) {
          T = Builder.getInt64(ld);
          F = Builder.getInt64(rd);
          Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
          return Sel;
        }
      }
    } else if (llvm::BinaryOperator *BO = dyn_cast<llvm::BinaryOperator>(I)) {
      // For integer binary operator, do the calc on lowest precision, then cast
      // to dstTy.
      if (I->getType()->isIntegerTy()) {
        bool bSigned = DstType->isSignedIntegerType();
        Value *CastResult =
            CalcHLSLLiteralToLowestPrecision(Builder, BO, bSigned);
        if (!CastResult)
          return nullptr;
        if (dyn_cast<llvm::IntegerType>(DstTy)) {
          if (DstTy == CastResult->getType()) {
            return CastResult;
          } else {
            // Destination signedness decides the extension on resize.
            if (bSigned)
              return Builder.CreateSExtOrTrunc(CastResult, DstTy);
            else
              return Builder.CreateZExtOrTrunc(CastResult, DstTy);
          }
        } else {
          // Float destination: source signedness decides the conversion.
          if (bSrcSigned)
            return Builder.CreateSIToFP(CastResult, DstTy);
          else
            return Builder.CreateUIToFP(CastResult, DstTy);
        }
      }
    }
    // TODO: support other opcode if need.
    return nullptr;
  }
}
  6101. // For case like ((float3xfloat3)mat4x4).m21 or ((float3xfloat3)mat4x4)[1], just
  6102. // treat it like mat4x4.m21 or mat4x4[1].
  6103. static Value *GetOriginMatrixOperandAndUpdateMatSize(Value *Ptr, unsigned &row,
  6104. unsigned &col) {
  6105. if (CallInst *Mat = dyn_cast<CallInst>(Ptr)) {
  6106. HLOpcodeGroup OpcodeGroup =
  6107. GetHLOpcodeGroupByName(Mat->getCalledFunction());
  6108. if (OpcodeGroup == HLOpcodeGroup::HLCast) {
  6109. HLCastOpcode castOpcode = static_cast<HLCastOpcode>(GetHLOpcode(Mat));
  6110. if (castOpcode == HLCastOpcode::DefaultCast) {
  6111. Ptr = Mat->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  6112. // Remove the cast which is useless now.
  6113. Mat->eraseFromParent();
  6114. // Update row and col.
  6115. HLMatrixType matTy =
  6116. HLMatrixType::cast(Ptr->getType()->getPointerElementType());
  6117. row = matTy.getNumRows();
  6118. col = matTy.getNumColumns();
  6119. // Don't update RetTy and DxilGeneration pass will do the right thing.
  6120. return Ptr;
  6121. }
  6122. }
  6123. }
  6124. return nullptr;
  6125. }
// Lower a dynamic matrix row subscript (mat[Idx]) into an HLSubscript call
// whose trailing operands are the flattened element indices of the selected
// row, computed for the matrix's declared orientation.
Value *CGMSHLSLRuntime::EmitHLSLMatrixSubscript(CodeGenFunction &CGF,
                                                llvm::Type *RetType,
                                                llvm::Value *Ptr,
                                                llvm::Value *Idx,
                                                clang::QualType Ty) {
  bool isRowMajor =
      hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
  unsigned opcode =
      isRowMajor ? static_cast<unsigned>(HLSubscriptOpcode::RowMatSubscript)
                 : static_cast<unsigned>(HLSubscriptOpcode::ColMatSubscript);
  Value *matBase = Ptr;
  DXASSERT(matBase->getType()->isPointerTy(),
           "matrix subscript should return pointer");
  // Result pointer must live in the same address space as the matrix.
  RetType =
      llvm::PointerType::get(RetType->getPointerElementType(),
                             matBase->getType()->getPointerAddressSpace());
  unsigned row, col;
  hlsl::GetHLSLMatRowColCount(Ty, row, col);
  unsigned resultCol = col;
  // If Ptr is a shrinking matrix cast, subscript the original matrix instead.
  if (Value *OriginPtr = GetOriginMatrixOperandAndUpdateMatSize(Ptr, row, col)) {
    Ptr = OriginPtr;
    // Update col to result col to get correct result size.
    col = resultCol;
  }
  // Lower mat[Idx] into real idx.
  SmallVector<Value *, 8> args;
  args.emplace_back(Ptr);
  if (isRowMajor) {
    // Row-major layout: element (r, c) lives at r * col + c.
    Value *cCol = ConstantInt::get(Idx->getType(), col);
    Value *Base = CGF.Builder.CreateMul(cCol, Idx);
    for (unsigned i = 0; i < col; i++) {
      Value *c = ConstantInt::get(Idx->getType(), i);
      // r * col + c
      Value *matIdx = CGF.Builder.CreateAdd(Base, c);
      args.emplace_back(matIdx);
    }
  } else {
    // Column-major layout: element (r, c) lives at c * row + r.
    for (unsigned i = 0; i < col; i++) {
      Value *cMulRow = ConstantInt::get(Idx->getType(), i * row);
      // c * row + r
      Value *matIdx = CGF.Builder.CreateAdd(cMulRow, Idx);
      args.emplace_back(matIdx);
    }
  }
  Value *matSub =
      EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
                                     opcode, RetType, args, TheModule);
  return matSub;
}
// Lower a matrix element accessor (e.g. mat._m00_m11) into an HLSubscript
// element call. The incoming index operand is a constant vector of
// interleaved (row, col) pairs; it is rewritten into a half-length vector of
// flattened element offsets for the matrix's declared orientation.
Value *CGMSHLSLRuntime::EmitHLSLMatrixElement(CodeGenFunction &CGF,
                                              llvm::Type *RetType,
                                              ArrayRef<Value *> paramList,
                                              QualType Ty) {
  bool isRowMajor =
      hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
  unsigned opcode =
      isRowMajor ? static_cast<unsigned>(HLSubscriptOpcode::RowMatElement)
                 : static_cast<unsigned>(HLSubscriptOpcode::ColMatElement);
  Value *matBase = paramList[0];
  DXASSERT(matBase->getType()->isPointerTy(),
           "matrix element should return pointer");
  // Result pointer must live in the same address space as the matrix.
  RetType =
      llvm::PointerType::get(RetType->getPointerElementType(),
                             matBase->getType()->getPointerAddressSpace());
  Value *idx = paramList[HLOperandIndex::kMatSubscriptSubOpIdx-1];
  // Lower _m00 into real idx.
  // -1 to avoid opcode param which is added in EmitHLSLMatrixOperationCallImp.
  Value *args[] = {paramList[HLOperandIndex::kMatSubscriptMatOpIdx - 1],
                   paramList[HLOperandIndex::kMatSubscriptSubOpIdx - 1]};
  unsigned row, col;
  hlsl::GetHLSLMatRowColCount(Ty, row, col);
  // If the matrix operand is a shrinking cast, address the original matrix.
  Value *Ptr = paramList[0];
  if (Value *OriginPtr = GetOriginMatrixOperandAndUpdateMatSize(Ptr, row, col)) {
    args[0] = OriginPtr;
  }
  // For all zero idx. Still all zero idx.
  if (ConstantAggregateZero *zeros = dyn_cast<ConstantAggregateZero>(idx)) {
    // (row, col) pairs collapse to single offsets, so halve the vector length.
    Constant *zero = zeros->getAggregateElement((unsigned)0);
    std::vector<Constant *> elts(zeros->getNumElements() >> 1, zero);
    args[HLOperandIndex::kMatSubscriptSubOpIdx - 1] = ConstantVector::get(elts);
  } else {
    // General case: fold each (row, col) pair into one flattened offset.
    ConstantDataSequential *elts = cast<ConstantDataSequential>(idx);
    unsigned count = elts->getNumElements();
    std::vector<Constant *> idxs(count >> 1);
    for (unsigned i = 0; i < count; i += 2) {
      unsigned rowIdx = elts->getElementAsInteger(i);
      unsigned colIdx = elts->getElementAsInteger(i + 1);
      unsigned matIdx = 0;
      if (isRowMajor) {
        matIdx = rowIdx * col + colIdx;
      } else {
        matIdx = colIdx * row + rowIdx;
      }
      idxs[i >> 1] = CGF.Builder.getInt32(matIdx);
    }
    args[HLOperandIndex::kMatSubscriptSubOpIdx - 1] = ConstantVector::get(idxs);
  }
  return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
                                        opcode, RetType, args, TheModule);
}
  6226. Value *CGMSHLSLRuntime::EmitHLSLMatrixLoad(CGBuilderTy &Builder, Value *Ptr,
  6227. QualType Ty) {
  6228. bool isRowMajor =
  6229. hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
  6230. unsigned opcode =
  6231. isRowMajor
  6232. ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad)
  6233. : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad);
  6234. Value *matVal = EmitHLSLMatrixOperationCallImp(
  6235. Builder, HLOpcodeGroup::HLMatLoadStore, opcode,
  6236. Ptr->getType()->getPointerElementType(), {Ptr}, TheModule);
  6237. if (!isRowMajor) {
  6238. // ColMatLoad will return a col major matrix.
  6239. // All matrix Value should be row major.
  6240. // Cast it to row major.
  6241. matVal = EmitHLSLMatrixOperationCallImp(
  6242. Builder, HLOpcodeGroup::HLCast,
  6243. static_cast<unsigned>(HLCastOpcode::ColMatrixToRowMatrix),
  6244. matVal->getType(), {matVal}, TheModule);
  6245. }
  6246. return matVal;
  6247. }
  6248. void CGMSHLSLRuntime::EmitHLSLMatrixStore(CGBuilderTy &Builder, Value *Val,
  6249. Value *DestPtr, QualType Ty) {
  6250. bool isRowMajor =
  6251. hlsl::IsHLSLMatRowMajor(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
  6252. unsigned opcode =
  6253. isRowMajor
  6254. ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore)
  6255. : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore);
  6256. if (!isRowMajor) {
  6257. Value *ColVal = nullptr;
  6258. // If Val is casted from col major. Just use the original col major val.
  6259. if (CallInst *CI = dyn_cast<CallInst>(Val)) {
  6260. hlsl::HLOpcodeGroup group =
  6261. hlsl::GetHLOpcodeGroupByName(CI->getCalledFunction());
  6262. if (group == HLOpcodeGroup::HLCast) {
  6263. HLCastOpcode castOp = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
  6264. if (castOp == HLCastOpcode::ColMatrixToRowMatrix) {
  6265. ColVal = CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx);
  6266. }
  6267. }
  6268. }
  6269. if (ColVal) {
  6270. Val = ColVal;
  6271. } else {
  6272. // All matrix Value should be row major.
  6273. // ColMatStore need a col major value.
  6274. // Cast it to row major.
  6275. Val = EmitHLSLMatrixOperationCallImp(
  6276. Builder, HLOpcodeGroup::HLCast,
  6277. static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
  6278. Val->getType(), {Val}, TheModule);
  6279. }
  6280. }
  6281. EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLMatLoadStore, opcode,
  6282. Val->getType(), {DestPtr, Val}, TheModule);
  6283. }
  6284. Value *CGMSHLSLRuntime::EmitHLSLMatrixLoad(CodeGenFunction &CGF, Value *Ptr,
  6285. QualType Ty) {
  6286. return EmitHLSLMatrixLoad(CGF.Builder, Ptr, Ty);
  6287. }
  6288. void CGMSHLSLRuntime::EmitHLSLMatrixStore(CodeGenFunction &CGF, Value *Val,
  6289. Value *DestPtr, QualType Ty) {
  6290. EmitHLSLMatrixStore(CGF.Builder, Val, DestPtr, Ty);
  6291. }
  6292. // Copy data from srcPtr to destPtr.
  6293. static void SimplePtrCopy(Value *DestPtr, Value *SrcPtr,
  6294. ArrayRef<Value *> idxList, CGBuilderTy &Builder) {
  6295. if (idxList.size() > 1) {
  6296. DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
  6297. SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
  6298. }
  6299. llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
  6300. Builder.CreateStore(ld, DestPtr);
  6301. }
  6302. // Get Element val from SrvVal with extract value.
  6303. static Value *GetEltVal(Value *SrcVal, ArrayRef<Value*> idxList,
  6304. CGBuilderTy &Builder) {
  6305. Value *Val = SrcVal;
  6306. // Skip beginning pointer type.
  6307. for (unsigned i = 1; i < idxList.size(); i++) {
  6308. ConstantInt *idx = cast<ConstantInt>(idxList[i]);
  6309. llvm::Type *Ty = Val->getType();
  6310. if (Ty->isAggregateType()) {
  6311. Val = Builder.CreateExtractValue(Val, idx->getLimitedValue());
  6312. }
  6313. }
  6314. return Val;
  6315. }
  6316. // Copy srcVal to destPtr.
  6317. static void SimpleValCopy(Value *DestPtr, Value *SrcVal,
  6318. ArrayRef<Value*> idxList,
  6319. CGBuilderTy &Builder) {
  6320. Value *DestGEP = Builder.CreateInBoundsGEP(DestPtr, idxList);
  6321. Value *Val = GetEltVal(SrcVal, idxList, Builder);
  6322. Builder.CreateStore(Val, DestGEP);
  6323. }
  6324. static void SimpleCopy(Value *Dest, Value *Src,
  6325. ArrayRef<Value *> idxList,
  6326. CGBuilderTy &Builder) {
  6327. if (Src->getType()->isPointerTy())
  6328. SimplePtrCopy(Dest, Src, idxList, Builder);
  6329. else
  6330. SimpleValCopy(Dest, Src, idxList, Builder);
  6331. }
// Recursively flatten the aggregate pointed to by Ptr into a list of GEPs to
// its scalar leaves (GepList) with their matching clang element types
// (EltTyList). idxList carries the GEP index path built up during recursion.
// Matrices are flattened through EmitHLSLMatrixElement; HLSL object structs
// are kept whole.
void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
    CodeGenFunction &CGF, Value *Ptr, SmallVector<Value *, 4> &idxList,
    clang::QualType Type, llvm::Type *Ty, SmallVector<Value *, 4> &GepList,
    SmallVector<QualType, 4> &EltTyList) {
  if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
    // Step through the pointer with a leading 0 index, then recurse on the
    // pointee type.
    Constant *idx = Constant::getIntegerValue(
        IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
    idxList.emplace_back(idx);
    FlattenAggregatePtrToGepList(CGF, Ptr, idxList, Type, PT->getElementType(),
                                 GepList, EltTyList);
    idxList.pop_back();
  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
    // Use matLd/St for matrix.
    llvm::Type *EltTy = MatTy.getElementTypeForReg();
    llvm::PointerType *EltPtrTy =
        llvm::PointerType::get(EltTy, Ptr->getType()->getPointerAddressSpace());
    QualType EltQualTy = hlsl::GetHLSLMatElementType(Type);
    Value *matPtr = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
    // Flatten matrix to elements: one element accessor per (row, col).
    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
        ConstantInt *cRow = CGF.Builder.getInt32(r);
        ConstantInt *cCol = CGF.Builder.getInt32(c);
        Constant *CV = llvm::ConstantVector::get({cRow, cCol});
        GepList.push_back(
            EmitHLSLMatrixElement(CGF, EltPtrTy, {matPtr, CV}, Type));
        EltTyList.push_back(EltQualTy);
      }
    }
  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
    if (dxilutil::IsHLSLObjectType(ST)) {
      // Avoid split HLSL object.
      Value *GEP = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
      GepList.push_back(GEP);
      EltTyList.push_back(Type);
      return;
    }
    const clang::RecordType *RT = Type->getAsStructureType();
    RecordDecl *RD = RT->getDecl();
    const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
      if (CXXRD->getNumBases()) {
        // Add base as field: recurse into each non-empty base subobject at
        // its LLVM field index.
        for (const auto &I : CXXRD->bases()) {
          const CXXRecordDecl *BaseDecl =
              cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
          // Skip empty struct.
          if (BaseDecl->field_empty())
            continue;
          QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
          llvm::Type *parentType = CGF.ConvertType(parentTy);
          unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
          Constant *idx = llvm::Constant::getIntegerValue(
              IntegerType::get(Ty->getContext(), 32), APInt(32, i));
          idxList.emplace_back(idx);
          FlattenAggregatePtrToGepList(CGF, Ptr, idxList, parentTy, parentType,
                                       GepList, EltTyList);
          idxList.pop_back();
        }
      }
    }
    // Recurse into each field at its LLVM field index (which may differ from
    // the declaration order index).
    for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
         fieldIter != fieldEnd; ++fieldIter) {
      unsigned i = RL.getLLVMFieldNo(*fieldIter);
      llvm::Type *ET = ST->getElementType(i);
      Constant *idx = llvm::Constant::getIntegerValue(
          IntegerType::get(Ty->getContext(), 32), APInt(32, i));
      idxList.emplace_back(idx);
      FlattenAggregatePtrToGepList(CGF, Ptr, idxList, fieldIter->getType(), ET,
                                   GepList, EltTyList);
      idxList.pop_back();
    }
  } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
    // Recurse into each array element.
    llvm::Type *ET = AT->getElementType();
    QualType EltType = CGF.getContext().getBaseElementType(Type);
    for (uint32_t i = 0; i < AT->getNumElements(); i++) {
      Constant *idx = Constant::getIntegerValue(
          IntegerType::get(Ty->getContext(), 32), APInt(32, i));
      idxList.emplace_back(idx);
      FlattenAggregatePtrToGepList(CGF, Ptr, idxList, EltType, ET, GepList,
                                   EltTyList);
      idxList.pop_back();
    }
  } else if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
    // Flatten vector too.
    QualType EltTy = hlsl::GetHLSLVecElementType(Type);
    for (uint32_t i = 0; i < VT->getNumElements(); i++) {
      Constant *idx = CGF.Builder.getInt32(i);
      idxList.emplace_back(idx);
      Value *GEP = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
      GepList.push_back(GEP);
      EltTyList.push_back(EltTy);
      idxList.pop_back();
    }
  } else {
    // Scalar leaf: emit the GEP for the accumulated index path.
    Value *GEP = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
    GepList.push_back(GEP);
    EltTyList.push_back(Type);
  }
}
  6432. void CGMSHLSLRuntime::LoadElements(CodeGenFunction &CGF,
  6433. ArrayRef<Value *> Ptrs, ArrayRef<QualType> QualTys,
  6434. SmallVector<Value *, 4> &Vals) {
  6435. for (size_t i = 0, e = Ptrs.size(); i < e; i++) {
  6436. Value *Ptr = Ptrs[i];
  6437. llvm::Type *Ty = Ptr->getType()->getPointerElementType();
  6438. DXASSERT_LOCALVAR(Ty, Ty->isIntegerTy() || Ty->isFloatingPointTy(), "Expected only element types.");
  6439. Value *Val = CGF.Builder.CreateLoad(Ptr);
  6440. Val = CGF.EmitFromMemory(Val, QualTys[i]);
  6441. Vals.push_back(Val);
  6442. }
  6443. }
  6444. void CGMSHLSLRuntime::ConvertAndStoreElements(CodeGenFunction &CGF,
  6445. ArrayRef<Value *> SrcVals, ArrayRef<QualType> SrcQualTys,
  6446. ArrayRef<Value *> DstPtrs, ArrayRef<QualType> DstQualTys) {
  6447. for (size_t i = 0, e = DstPtrs.size(); i < e; i++) {
  6448. Value *DstPtr = DstPtrs[i];
  6449. QualType DstQualTy = DstQualTys[i];
  6450. Value *SrcVal = SrcVals[i];
  6451. QualType SrcQualTy = SrcQualTys[i];
  6452. DXASSERT(SrcVal->getType()->isIntegerTy() || SrcVal->getType()->isFloatingPointTy(),
  6453. "Expected only element types.");
  6454. llvm::Value *Result = ConvertScalarOrVector(CGF, SrcVal, SrcQualTy, DstQualTy);
  6455. Result = CGF.EmitToMemory(Result, DstQualTy);
  6456. CGF.Builder.CreateStore(Result, DstPtr);
  6457. }
  6458. }
  6459. static bool AreMatrixArrayOrientationMatching(ASTContext& Context,
  6460. HLModule &Module, QualType LhsTy, QualType RhsTy) {
  6461. while (const clang::ArrayType *LhsArrayTy = Context.getAsArrayType(LhsTy)) {
  6462. LhsTy = LhsArrayTy->getElementType();
  6463. RhsTy = Context.getAsArrayType(RhsTy)->getElementType();
  6464. }
  6465. bool LhsRowMajor, RhsRowMajor;
  6466. LhsRowMajor = RhsRowMajor = Module.GetHLOptions().bDefaultRowMajor;
  6467. HasHLSLMatOrientation(LhsTy, &LhsRowMajor);
  6468. HasHLSLMatOrientation(RhsTy, &RhsRowMajor);
  6469. return LhsRowMajor == RhsRowMajor;
  6470. }
  6471. static llvm::Value *CreateInBoundsGEPIfNeeded(llvm::Value *Ptr, ArrayRef<Value*> IdxList, CGBuilderTy &Builder) {
  6472. DXASSERT(IdxList.size() > 0, "Invalid empty GEP index list");
  6473. // If the GEP list is a single zero, it's a no-op, so save us the trouble.
  6474. if (IdxList.size() == 1) {
  6475. if (ConstantInt *FirstIdx = dyn_cast<ConstantInt>(IdxList[0])) {
  6476. if (FirstIdx->isZero()) return Ptr;
  6477. }
  6478. }
  6479. return Builder.CreateInBoundsGEP(Ptr, IdxList);
  6480. }
// Copy data from SrcPtr to DestPtr.
// For matrix, use MatLoad/MatStore.
// For matrix array, EmitHLSLAggregateCopy on each element.
// For struct or array, use memcpy.
// Other just load/store.
// idxList accumulates the GEP index path during recursion.
void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
    CodeGenFunction &CGF, llvm::Value *SrcPtr, llvm::Value *DestPtr,
    SmallVector<Value *, 4> &idxList, clang::QualType SrcType,
    clang::QualType DestType, llvm::Type *Ty) {
  if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
    // Step through the pointer with a leading 0 index, then recurse on the
    // pointee type.
    Constant *idx = Constant::getIntegerValue(
        IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
    idxList.emplace_back(idx);
    EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, SrcType, DestType,
                          PT->getElementType());
    idxList.pop_back();
  } else if (HLMatrixType::isa(Ty)) {
    // Use matLd/St for matrix.
    Value *SrcMatPtr = CreateInBoundsGEPIfNeeded(SrcPtr, idxList, CGF.Builder);
    Value *DestMatPtr = CreateInBoundsGEPIfNeeded(DestPtr, idxList, CGF.Builder);
    Value *ldMat = EmitHLSLMatrixLoad(CGF, SrcMatPtr, SrcType);
    EmitHLSLMatrixStore(CGF, ldMat, DestMatPtr, DestType);
  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
    if (dxilutil::IsHLSLObjectType(ST)) {
      // Avoid split HLSL object.
      SimpleCopy(DestPtr, SrcPtr, idxList, CGF.Builder);
      return;
    }
    Value *SrcStructPtr = CreateInBoundsGEPIfNeeded(SrcPtr, idxList, CGF.Builder);
    Value *DestStructPtr = CreateInBoundsGEPIfNeeded(DestPtr, idxList, CGF.Builder);
    unsigned size = this->TheModule.getDataLayout().getTypeAllocSize(ST);
    // Memcpy struct.
    CGF.Builder.CreateMemCpy(DestStructPtr, SrcStructPtr, size, 1);
  } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
    // Arrays can be memcpy'd unless they hold matrices whose source and
    // destination orientations differ.
    if (!HLMatrixType::isMatrixArray(Ty)
        || AreMatrixArrayOrientationMatching(CGF.getContext(), *m_pHLModule, SrcType, DestType)) {
      Value *SrcArrayPtr = CreateInBoundsGEPIfNeeded(SrcPtr, idxList, CGF.Builder);
      Value *DestArrayPtr = CreateInBoundsGEPIfNeeded(DestPtr, idxList, CGF.Builder);
      unsigned size = this->TheModule.getDataLayout().getTypeAllocSize(AT);
      // Memcpy non-matrix array.
      CGF.Builder.CreateMemCpy(DestArrayPtr, SrcArrayPtr, size, 1);
    } else {
      // Copy matrix arrays elementwise if orientation changes are needed.
      llvm::Type *ET = AT->getElementType();
      QualType EltDestType = CGF.getContext().getBaseElementType(DestType);
      QualType EltSrcType = CGF.getContext().getBaseElementType(SrcType);
      for (uint32_t i = 0; i < AT->getNumElements(); i++) {
        Constant *idx = Constant::getIntegerValue(
            IntegerType::get(Ty->getContext(), 32), APInt(32, i));
        idxList.emplace_back(idx);
        EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, EltSrcType,
                              EltDestType, ET);
        idxList.pop_back();
      }
    }
  } else {
    // Scalar/vector leaf: plain load+store (or extractvalue+store).
    SimpleCopy(DestPtr, SrcPtr, idxList, CGF.Builder);
  }
}
  6540. void CGMSHLSLRuntime::EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
  6541. llvm::Value *DestPtr,
  6542. clang::QualType Ty) {
  6543. SmallVector<Value *, 4> idxList;
  6544. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, Ty, Ty, SrcPtr->getType());
  6545. }
  6546. // Make sure all element type of struct is same type.
  6547. static bool IsStructWithSameElementType(llvm::StructType *ST, llvm::Type *Ty) {
  6548. for (llvm::Type *EltTy : ST->elements()) {
  6549. if (StructType *EltSt = dyn_cast<StructType>(EltTy)) {
  6550. if (!IsStructWithSameElementType(EltSt, Ty))
  6551. return false;
  6552. } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(EltTy)) {
  6553. llvm::Type *ArrayEltTy = dxilutil::GetArrayEltTy(AT);
  6554. if (ArrayEltTy == Ty) {
  6555. continue;
  6556. } else if (StructType *EltSt = dyn_cast<StructType>(EltTy)) {
  6557. if (!IsStructWithSameElementType(EltSt, Ty))
  6558. return false;
  6559. } else {
  6560. return false;
  6561. }
  6562. } else if (EltTy != Ty)
  6563. return false;
  6564. }
  6565. return true;
  6566. }
  6567. // To memcpy, need element type match.
  6568. // For struct type, the layout should match in cbuffer layout.
  6569. // struct { float2 x; float3 y; } will not match struct { float3 x; float2 y; }.
  6570. // struct { float2 x; float3 y; } will not match array of float.
  6571. static bool IsTypeMatchForMemcpy(llvm::Type *SrcTy, llvm::Type *DestTy) {
  6572. llvm::Type *SrcEltTy = dxilutil::GetArrayEltTy(SrcTy);
  6573. llvm::Type *DestEltTy = dxilutil::GetArrayEltTy(DestTy);
  6574. if (SrcEltTy == DestEltTy)
  6575. return true;
  6576. llvm::StructType *SrcST = dyn_cast<llvm::StructType>(SrcEltTy);
  6577. llvm::StructType *DestST = dyn_cast<llvm::StructType>(DestEltTy);
  6578. if (SrcST && DestST) {
  6579. // Only allow identical struct.
  6580. return SrcST->isLayoutIdentical(DestST);
  6581. } else if (!SrcST && !DestST) {
  6582. // For basic type, if one is array, one is not array, layout is different.
  6583. // If both array, type mismatch. If both basic, copy should be fine.
  6584. // So all return false.
  6585. return false;
  6586. } else {
  6587. // One struct, one basic type.
  6588. // Make sure all struct element match the basic type and basic type is
  6589. // vector4.
  6590. llvm::StructType *ST = SrcST ? SrcST : DestST;
  6591. llvm::Type *Ty = SrcST ? DestEltTy : SrcEltTy;
  6592. if (!Ty->isVectorTy())
  6593. return false;
  6594. if (Ty->getVectorNumElements() != 4)
  6595. return false;
  6596. return IsStructWithSameElementType(ST, Ty);
  6597. }
  6598. }
// Copies SrcPtr's aggregate into DestPtr, taking three memcpy fast paths when
// layouts permit and otherwise falling back to a full flatten/convert/store
// element-wise copy.
void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
                                                          clang::QualType SrcTy,
                                                          llvm::Value *DestPtr,
                                                          clang::QualType DestTy) {
  llvm::Type *SrcPtrTy = SrcPtr->getType()->getPointerElementType();
  llvm::Type *DestPtrTy = DestPtr->getType()->getPointerElementType();

  bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  if (SrcPtrTy == DestPtrTy) {
    // Same LLVM type. A memcpy is still unsafe if the two sides are matrix
    // arrays with different majorness (row vs column), since the in-memory
    // element order would then differ.
    bool bMatArrayRotate = false;
    if (HLMatrixType::isMatrixArrayPtr(SrcPtr->getType())) {
      QualType SrcEltTy = GetArrayEltType(CGM.getContext(), SrcTy);
      QualType DestEltTy = GetArrayEltType(CGM.getContext(), DestTy);
      if (GetMatrixMajor(SrcEltTy, bDefaultRowMajor) !=
          GetMatrixMajor(DestEltTy, bDefaultRowMajor)) {
        bMatArrayRotate = true;
      }
    }
    if (!bMatArrayRotate) {
      // Memcpy if type is match.
      unsigned size = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
      CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, size, 1);
      return;
    }
  } else if (dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(SrcPtrTy)) &&
             dxilutil::IsHLSLObjectType(dxilutil::GetArrayEltTy(DestPtrTy))) {
    // Both sides are (arrays of) HLSL object types: copy the larger of the
    // two allocation sizes. NOTE(review): std::max here vs std::min below for
    // the global-variable path appears deliberate — confirm against callers.
    unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
    unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
    CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::max(sizeSrc, sizeDest), 1);
    return;
  } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(DestPtr)) {
    // Destination is an internal global whose layout is memcpy-compatible
    // with the source: copy the overlapping (smaller) size.
    if (GV->isInternalLinkage(GV->getLinkage()) &&
        IsTypeMatchForMemcpy(SrcPtrTy, DestPtrTy)) {
      unsigned sizeSrc = TheModule.getDataLayout().getTypeAllocSize(SrcPtrTy);
      unsigned sizeDest = TheModule.getDataLayout().getTypeAllocSize(DestPtrTy);
      CGF.Builder.CreateMemCpy(DestPtr, SrcPtr, std::min(sizeSrc, sizeDest), 1);
      return;
    }
  }

  // It is possible to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore
  // the same way. But split value to scalar will generate many instruction when
  // src type is same as dest type.
  // Slow path: flatten the source into a list of element pointers, load each
  // element, then flatten the destination and store the converted elements.
  SmallVector<Value *, 4> GEPIdxStack;
  SmallVector<Value *, 4> SrcPtrs;
  SmallVector<QualType, 4> SrcQualTys;
  FlattenAggregatePtrToGepList(CGF, SrcPtr, GEPIdxStack, SrcTy, SrcPtr->getType(),
                               SrcPtrs, SrcQualTys);

  SmallVector<Value *, 4> SrcVals;
  LoadElements(CGF, SrcPtrs, SrcQualTys, SrcVals);

  GEPIdxStack.clear();
  SmallVector<Value *, 4> DstPtrs;
  SmallVector<QualType, 4> DstQualTys;
  FlattenAggregatePtrToGepList(CGF, DestPtr, GEPIdxStack, DestTy,
                               DestPtr->getType(), DstPtrs, DstQualTys);

  ConvertAndStoreElements(CGF, SrcVals, SrcQualTys, DstPtrs, DstQualTys);
}
  6654. void CGMSHLSLRuntime::EmitHLSLAggregateStore(CodeGenFunction &CGF, llvm::Value *SrcVal,
  6655. llvm::Value *DestPtr,
  6656. clang::QualType Ty) {
  6657. DXASSERT(0, "aggregate return type will use SRet, no aggregate store should exist");
  6658. }
  6659. // Either copies a scalar to a scalar, a scalar to a vector, or splats a scalar to a vector
  6660. static void SimpleFlatValCopy(CodeGenFunction &CGF,
  6661. Value *SrcVal, QualType SrcQualTy, Value *DstPtr, QualType DstQualTy) {
  6662. DXASSERT(SrcVal->getType() == CGF.ConvertType(SrcQualTy), "QualType/Type mismatch!");
  6663. llvm::Type *DstTy = DstPtr->getType()->getPointerElementType();
  6664. DXASSERT(DstTy == CGF.ConvertTypeForMem(DstQualTy), "QualType/Type mismatch!");
  6665. llvm::VectorType *DstVecTy = dyn_cast<llvm::VectorType>(DstTy);
  6666. QualType DstScalarQualTy = DstQualTy;
  6667. if (DstVecTy) {
  6668. DstScalarQualTy = hlsl::GetHLSLVecElementType(DstQualTy);
  6669. }
  6670. Value *ResultScalar = ConvertScalarOrVector(CGF, SrcVal, SrcQualTy, DstScalarQualTy);
  6671. ResultScalar = CGF.EmitToMemory(ResultScalar, DstScalarQualTy);
  6672. if (DstVecTy) {
  6673. llvm::VectorType *DstScalarVecTy = llvm::VectorType::get(ResultScalar->getType(), 1);
  6674. Value *ResultScalarVec = CGF.Builder.CreateInsertElement(
  6675. UndefValue::get(DstScalarVecTy), ResultScalar, (uint64_t)0);
  6676. std::vector<int> ShufIdx(DstVecTy->getNumElements(), 0);
  6677. Value *ResultVec = CGF.Builder.CreateShuffleVector(ResultScalarVec, ResultScalarVec, ShufIdx);
  6678. CGF.Builder.CreateStore(ResultVec, DstPtr);
  6679. } else
  6680. CGF.Builder.CreateStore(ResultScalar, DstPtr);
  6681. }
// Recursively splats the single scalar SrcVal into every leaf element of the
// aggregate pointed to by DestPtr. Ty is the LLVM type currently being
// walked; idxList accumulates the GEP index path to the current position and
// is pushed/popped around each recursive step.
void CGMSHLSLRuntime::EmitHLSLSplat(
    CodeGenFunction &CGF, Value *SrcVal, llvm::Value *DestPtr,
    SmallVector<Value *, 4> &idxList, QualType Type, QualType SrcType,
    llvm::Type *Ty) {
  if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
    // Step through the pointer with a leading 0 index and recurse on the
    // pointee type.
    idxList.emplace_back(CGF.Builder.getInt32(0));
    EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, Type,
                  SrcType, PT->getElementType());
    idxList.pop_back();
  } else if (HLMatrixType MatTy = HLMatrixType::dyn_cast(Ty)) {
    // Use matLd/St for matrix.
    Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
    llvm::Type *EltTy = MatTy.getElementTypeForReg();
    llvm::VectorType *VT1 = llvm::VectorType::get(EltTy, 1);
    // Convert the scalar to the matrix element type before splatting.
    SrcVal = ConvertScalarOrVector(CGF, SrcVal, SrcType, hlsl::GetHLSLMatElementType(Type));
    // Splat the value
    Value *V1 = CGF.Builder.CreateInsertElement(UndefValue::get(VT1), SrcVal,
                                                (uint64_t)0);
    std::vector<int> shufIdx(MatTy.getNumElements(), 0);
    Value *VecMat = CGF.Builder.CreateShuffleVector(V1, V1, shufIdx);
    // Wrap the splatted vector in an HLInit call and store it as a matrix.
    Value *MatInit = EmitHLSLMatrixOperationCallImp(
        CGF.Builder, HLOpcodeGroup::HLInit, 0, Ty, {VecMat}, TheModule);
    EmitHLSLMatrixStore(CGF, MatInit, dstGEP, Type);
  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
    DXASSERT(!dxilutil::IsHLSLObjectType(ST), "cannot cast to hlsl object, Sema should reject");
    const clang::RecordType *RT = Type->getAsStructureType();
    RecordDecl *RD = RT->getDecl();
    const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
    // Take care base.
    if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
      if (CXXRD->getNumBases()) {
        // Splat into each non-empty base class sub-object first.
        for (const auto &I : CXXRD->bases()) {
          const CXXRecordDecl *BaseDecl =
              cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
          if (BaseDecl->field_empty())
            continue;
          QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
          unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
          llvm::Type *ET = ST->getElementType(i);
          Constant *idx = llvm::Constant::getIntegerValue(
              IntegerType::get(Ty->getContext(), 32), APInt(32, i));
          idxList.emplace_back(idx);
          EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, parentTy, SrcType, ET);
          idxList.pop_back();
        }
      }
    }
    // Then splat into each direct field, mapping decl fields to LLVM struct
    // element indices via the record layout.
    for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
         fieldIter != fieldEnd; ++fieldIter) {
      unsigned i = RL.getLLVMFieldNo(*fieldIter);
      llvm::Type *ET = ST->getElementType(i);
      Constant *idx = llvm::Constant::getIntegerValue(
          IntegerType::get(Ty->getContext(), 32), APInt(32, i));
      idxList.emplace_back(idx);
      EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, fieldIter->getType(), SrcType, ET);
      idxList.pop_back();
    }
  } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
    // Splat into every array element.
    llvm::Type *ET = AT->getElementType();
    QualType EltType = CGF.getContext().getBaseElementType(Type);
    for (uint32_t i = 0; i < AT->getNumElements(); i++) {
      Constant *idx = Constant::getIntegerValue(
          IntegerType::get(Ty->getContext(), 32), APInt(32, i));
      idxList.emplace_back(idx);
      EmitHLSLSplat(CGF, SrcVal, DestPtr, idxList, EltType, SrcType, ET);
      idxList.pop_back();
    }
  } else {
    // Leaf scalar/vector: GEP to the element and do the simple copy/splat.
    DestPtr = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
    SimpleFlatValCopy(CGF, SrcVal, SrcType, DestPtr, Type);
  }
}
  6754. void CGMSHLSLRuntime::EmitHLSLFlatConversion(CodeGenFunction &CGF,
  6755. Value *Val,
  6756. Value *DestPtr,
  6757. QualType Ty,
  6758. QualType SrcTy) {
  6759. SmallVector<Value *, 4> SrcVals;
  6760. SmallVector<QualType, 4> SrcQualTys;
  6761. FlattenValToInitList(CGF, SrcVals, SrcQualTys, SrcTy, Val);
  6762. if (SrcVals.size() == 1) {
  6763. // Perform a splat
  6764. SmallVector<Value *, 4> GEPIdxStack;
  6765. GEPIdxStack.emplace_back(CGF.Builder.getInt32(0)); // Add first 0 for DestPtr.
  6766. EmitHLSLSplat(
  6767. CGF, SrcVals[0], DestPtr, GEPIdxStack, Ty, SrcQualTys[0],
  6768. DestPtr->getType()->getPointerElementType());
  6769. }
  6770. else {
  6771. SmallVector<Value *, 4> GEPIdxStack;
  6772. SmallVector<Value *, 4> DstPtrs;
  6773. SmallVector<QualType, 4> DstQualTys;
  6774. FlattenAggregatePtrToGepList(CGF, DestPtr, GEPIdxStack, Ty, DestPtr->getType(), DstPtrs, DstQualTys);
  6775. ConvertAndStoreElements(CGF, SrcVals, SrcQualTys, DstPtrs, DstQualTys);
  6776. }
  6777. }
  6778. void CGMSHLSLRuntime::EmitHLSLRootSignature(CodeGenFunction &CGF,
  6779. HLSLRootSignatureAttr *RSA,
  6780. Function *Fn) {
  6781. // Only parse root signature for entry function.
  6782. if (Fn != Entry.Func)
  6783. return;
  6784. StringRef StrRef = RSA->getSignatureName();
  6785. DiagnosticsEngine &Diags = CGF.getContext().getDiagnostics();
  6786. SourceLocation SLoc = RSA->getLocation();
  6787. RootSignatureHandle RootSigHandle;
  6788. clang::CompileRootSignature(StrRef, Diags, SLoc, rootSigVer, DxilRootSignatureCompilationFlags::GlobalRootSignature, &RootSigHandle);
  6789. if (!RootSigHandle.IsEmpty()) {
  6790. RootSigHandle.EnsureSerializedAvailable();
  6791. m_pHLModule->SetSerializedRootSignature(RootSigHandle.GetSerializedBytes(),
  6792. RootSigHandle.GetSerializedSize());
  6793. }
  6794. }
// Sets up copy-in/copy-out semantics for a call: for each parameter that
// needs it, emits the argument, creates a temporary alloca of the parameter
// type, copies the (converted) argument value in, rewrites argList[i] to
// reference the temporary, and records out-params in castArgList as
// (tmpLV, argLV) pairs for EmitHLSLOutParamConversionCopyBack.
void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
    CodeGenFunction &CGF, const FunctionDecl *FD, const CallExpr *E,
    llvm::SmallVector<LValue, 8> &castArgList,
    llvm::SmallVector<const Stmt *, 8> &argList,
    const std::function<void(const VarDecl *, llvm::Value *)> &TmpArgMap) {
  // Special case: skip first argument of CXXOperatorCall (it is "this").
  unsigned ArgsToSkip = isa<CXXOperatorCallExpr>(E) ? 1 : 0;
  for (uint32_t i = 0; i < FD->getNumParams(); i++) {
    const ParmVarDecl *Param = FD->getParamDecl(i);
    const Expr *Arg = E->getArg(i+ArgsToSkip);
    QualType ParamTy = Param->getType().getNonReferenceType();
    bool isObject = dxilutil::IsHLSLObjectType(CGF.ConvertTypeForMem(ParamTy));
    // Aggregates here are arrays/records that are not HLSL vectors/matrices.
    bool isAggregateType = !isObject &&
      (ParamTy->isArrayType() || ParamTy->isRecordType()) &&
      !hlsl::IsHLSLVecMatType(ParamTy);

    bool EmitRValueAgg = false;
    bool RValOnRef = false;
    if (!Param->isModifierOut()) {
      // In-only parameter: decide whether any temporary is needed at all.
      if (!isAggregateType && !isObject) {
        if (Arg->isRValue() && Param->getType()->isReferenceType()) {
          // RValue on a reference type.
          if (const CStyleCastExpr *cCast = dyn_cast<CStyleCastExpr>(Arg)) {
            // TODO: Evolving this to warn then fail in future language versions.
            // Allow special case like cast uint to uint for back-compat.
            if (cCast->getCastKind() == CastKind::CK_NoOp) {
              if (const ImplicitCastExpr *cast =
                      dyn_cast<ImplicitCastExpr>(cCast->getSubExpr())) {
                if (cast->getCastKind() == CastKind::CK_LValueToRValue) {
                  // update the arg
                  argList[i] = cast->getSubExpr();
                  continue;
                }
              }
            }
          }
          // EmitLValue will report error.
          // Mark RValOnRef to create tmpArg for it.
          RValOnRef = true;
        } else {
          continue;
        }
      } else if (isAggregateType) {
        // aggregate in-only - emit RValue, unless LValueToRValue cast
        EmitRValueAgg = true;
        if (const ImplicitCastExpr *cast =
                dyn_cast<ImplicitCastExpr>(Arg)) {
          if (cast->getCastKind() == CastKind::CK_LValueToRValue) {
            EmitRValueAgg = false;
          }
        }
      } else {
        // Must be object
        DXASSERT(isObject, "otherwise, flow condition changed, breaking assumption");
        // in-only objects should be skipped to preserve previous behavior.
        continue;
      }
    }

    // Skip unbounded array, since we cannot preserve copy-in copy-out
    // semantics for these.
    if (ParamTy->isIncompleteArrayType()) {
      continue;
    }

    if (!Param->isModifierOut() && !RValOnRef) {
      // No need to copy arg to in-only param for hlsl intrinsic.
      if (const FunctionDecl *Callee = E->getDirectCallee()) {
        if (Callee->hasAttr<HLSLIntrinsicAttr>())
          continue;
      }
    }

    // get original arg
    // FIXME: This will not emit in correct argument order with the other
    // arguments. This should be integrated into
    // CodeGenFunction::EmitCallArg if possible.
    RValue argRV; // emit this if aggregate arg on in-only param
    LValue argLV; // otherwise, we may emit this
    llvm::Value *argAddr = nullptr;
    QualType argType = Arg->getType();
    CharUnits argAlignment;
    if (EmitRValueAgg) {
      argRV = CGF.EmitAnyExprToTemp(Arg);
      argAddr = argRV.getAggregateAddr(); // must be alloca
      argAlignment = CharUnits::fromQuantity(cast<AllocaInst>(argAddr)->getAlignment());
      argLV = LValue::MakeAddr(argAddr, ParamTy, argAlignment, CGF.getContext());
    } else {
      argLV = CGF.EmitLValue(Arg);
      if (argLV.isSimple())
        argAddr = argLV.getAddress();
      argType = argLV.getType(); // TBD: Can this be different than Arg->getType()?
      argAlignment = argLV.getAlignment();
    }
    // After emit Arg, we must update the argList[i],
    // otherwise we get double emit of the expression.

    // create temp Var
    VarDecl *tmpArg =
        VarDecl::Create(CGF.getContext(), const_cast<FunctionDecl *>(FD),
                        SourceLocation(), SourceLocation(),
                        /*IdentifierInfo*/ nullptr, ParamTy,
                        CGF.getContext().getTrivialTypeSourceInfo(ParamTy),
                        StorageClass::SC_Auto);

    // Aggregate type will be indirect param convert to pointer type.
    // So don't update to ReferenceType, use RValue for it.
    const DeclRefExpr *tmpRef = DeclRefExpr::Create(
        CGF.getContext(), NestedNameSpecifierLoc(), SourceLocation(), tmpArg,
        /*enclosing*/ false, tmpArg->getLocation(), ParamTy,
        (isAggregateType || isObject) ? VK_RValue : VK_LValue);

    // must update the arg, since we did emit Arg, else we get double emit.
    argList[i] = tmpRef;

    // create alloc for the tmp arg
    Value *tmpArgAddr = nullptr;
    BasicBlock *InsertBlock = CGF.Builder.GetInsertBlock();
    Function *F = InsertBlock->getParent();

    // Make sure the alloca is in entry block to stop inline create stacksave.
    IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(F));
    tmpArgAddr = AllocaBuilder.CreateAlloca(CGF.ConvertTypeForMem(ParamTy));

    // add it to local decl map
    TmpArgMap(tmpArg, tmpArgAddr);

    LValue tmpLV = LValue::MakeAddr(tmpArgAddr, ParamTy, argAlignment,
                                    CGF.getContext());

    // save for cast after call
    if (Param->isModifierOut()) {
      castArgList.emplace_back(tmpLV);
      castArgList.emplace_back(argLV);
    }

    // cast before the call
    if (Param->isModifierIn() &&
        // Don't copy object
        !isObject) {
      QualType ArgTy = Arg->getType();
      Value *outVal = nullptr;
      if (!isAggregateType) {
        // Scalar/vector/matrix: load, convert to the parameter type, store
        // into the temporary.
        if (!IsHLSLMatType(ParamTy)) {
          RValue outRVal = CGF.EmitLoadOfLValue(argLV, SourceLocation());
          outVal = outRVal.getScalarVal();
        } else {
          DXASSERT(argAddr, "should be RV or simple LV");
          outVal = EmitHLSLMatrixLoad(CGF, argAddr, ArgTy);
        }
        llvm::Type *ToTy = tmpArgAddr->getType()->getPointerElementType();
        if (HLMatrixType::isa(ToTy)) {
          Value *castVal = CGF.Builder.CreateBitCast(outVal, ToTy);
          EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
        }
        else {
          Value *castVal = ConvertScalarOrVector(CGF, outVal, argType, ParamTy);
          castVal = CGF.EmitToMemory(castVal, ParamTy);
          CGF.Builder.CreateStore(castVal, tmpArgAddr);
        }
      } else {
        // Aggregate: element-wise copy from the argument into the temporary.
        DXASSERT(argAddr, "should be RV or simple LV");
        SmallVector<Value *, 4> idxList;
        EmitHLSLAggregateCopy(CGF, argAddr, tmpArgAddr,
                              idxList, ArgTy, ParamTy,
                              argAddr->getType());
      }
    }
  }
}
// Copies the values of out/inout temporaries back into the original argument
// lvalues after the call. castArgList holds (tmpLV, argLV) pairs as recorded
// by EmitHLSLOutParamConversionInit, hence the stride-2 loop.
void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
    CodeGenFunction &CGF, llvm::SmallVector<LValue, 8> &castArgList) {
  for (uint32_t i = 0; i < castArgList.size(); i += 2) {
    // cast after the call
    LValue tmpLV = castArgList[i];
    LValue argLV = castArgList[i + 1];
    QualType ArgTy = argLV.getType().getNonReferenceType();
    QualType ParamTy = tmpLV.getType().getNonReferenceType();
    Value *tmpArgAddr = tmpLV.getAddress();

    Value *outVal = nullptr;

    bool isAggregateTy = hlsl::IsHLSLAggregateType(ArgTy);

    bool isObject = dxilutil::IsHLSLObjectType(
        tmpArgAddr->getType()->getPointerElementType());
    if (!isObject) {
      if (!isAggregateTy) {
        // Scalar/vector/matrix: load the temporary's value, then adapt its
        // LLVM type to the argument's type before storing back.
        if (!IsHLSLMatType(ParamTy))
          outVal = CGF.Builder.CreateLoad(tmpArgAddr);
        else
          outVal = EmitHLSLMatrixLoad(CGF, tmpArgAddr, ParamTy);

        outVal = CGF.EmitFromMemory(outVal, ParamTy);

        llvm::Type *ToTy = CGF.ConvertType(ArgTy);
        llvm::Type *FromTy = outVal->getType();
        Value *castVal = outVal;
        if (ToTy == FromTy) {
          // Don't need cast.
        } else if (ToTy->getScalarType() == FromTy->getScalarType()) {
          // Same element type: only a scalar <-> 1-element-vector wrapper
          // difference is expected here.
          if (ToTy->getScalarType() == ToTy) {
            // Destination is scalar; source must be a 1-element vector.
            DXASSERT(FromTy->isVectorTy() &&
                     FromTy->getVectorNumElements() == 1,
                     "must be vector of 1 element");
            castVal = CGF.Builder.CreateExtractElement(outVal, (uint64_t)0);
          } else {
            // Destination is a 1-element vector; source must be scalar.
            DXASSERT(!FromTy->isVectorTy(), "must be scalar type");
            DXASSERT(ToTy->isVectorTy() && ToTy->getVectorNumElements() == 1,
                     "must be vector of 1 element");
            castVal = UndefValue::get(ToTy);
            castVal =
                CGF.Builder.CreateInsertElement(castVal, outVal, (uint64_t)0);
          }
        } else {
          // Different element types: full scalar/vector conversion.
          castVal = ConvertScalarOrVector(CGF,
              outVal, tmpLV.getType(), argLV.getType());
        }
        if (!HLMatrixType::isa(ToTy))
          CGF.EmitStoreThroughLValue(RValue::get(castVal), argLV);
        else {
          Value *destPtr = argLV.getAddress();
          EmitHLSLMatrixStore(CGF, castVal, destPtr, ArgTy);
        }
      } else {
        // Aggregate: element-wise copy from the temporary back to the arg.
        SmallVector<Value *, 4> idxList;
        EmitHLSLAggregateCopy(CGF, tmpLV.getAddress(), argLV.getAddress(),
                              idxList, ParamTy, ArgTy,
                              argLV.getAddress()->getType());
      }
    } else
      // Objects are never copied; redirect all uses of the temporary to the
      // original argument address instead.
      tmpArgAddr->replaceAllUsesWith(argLV.getAddress());
  }
}
  7011. CGHLSLRuntime *CodeGen::CreateMSHLSLRuntime(CodeGenModule &CGM) {
  7012. return new CGMSHLSLRuntime(CGM);
  7013. }