//===------- SpirvEmitter.cpp - SPIR-V Binary Code Emitter ------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//
//
// This file implements a SPIR-V emitter class that takes in HLSL AST and emits
// SPIR-V binary words.
//
//===----------------------------------------------------------------------===//
#include "SpirvEmitter.h"

#include "AlignmentSizeCalculator.h"
#include "InitListHandler.h"
#include "RawBufferMethods.h"
#include "dxc/DXIL/DxilConstants.h"
#include "dxc/HlslIntrinsicOp.h"
#include "spirv-tools/optimizer.hpp"
#include "clang/SPIRV/AstTypeProbe.h"
#include "clang/Sema/Sema.h"
#include "llvm/ADT/StringExtras.h"

#ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
#include "clang/Basic/Version.h"
#else
namespace clang {
uint32_t getGitCommitCount() { return 0; }
const char *getGitCommitHash() { return "<unknown-hash>"; }
} // namespace clang
#endif // SUPPORT_QUERY_GIT_COMMIT_INFO

namespace clang {
namespace spirv {

namespace {

// Returns true if the given decl has the given semantic.
bool hasSemantic(const DeclaratorDecl *decl,
                 hlsl::DXIL::SemanticKind semanticKind) {
  using namespace hlsl;
  for (auto *annotation : decl->getUnusualAnnotations()) {
    if (auto *semanticDecl = dyn_cast<SemanticDecl>(annotation)) {
      llvm::StringRef semanticName;
      uint32_t semanticIndex = 0;
      Semantic::DecomposeNameAndIndex(semanticDecl->SemanticName, &semanticName,
                                      &semanticIndex);
      const auto *semantic = Semantic::GetByName(semanticName);
      if (semantic->GetKind() == semanticKind)
        return true;
    }
  }
  return false;
}
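
// Illustrative sketch (not from the original source): for an HLSL declaration
// such as `float4 pos : SV_Position`, a query along the lines of
//   hasSemantic(decl, hlsl::DXIL::SemanticKind::Position)
// would be expected to return true, since DecomposeNameAndIndex splits the
// annotation string into a semantic name ("SV_Position") and an index.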

// Returns the parameter of the given patch constant function that takes the
// hull shader OutputPatch, or nullptr if there is no such parameter.
const ParmVarDecl *patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
  for (const auto *param : pcf->parameters())
    if (hlsl::IsHLSLOutputPatchType(param->getType()))
      return param;
  return nullptr;
}

inline bool isSpirvMatrixOp(spv::Op opcode) {
  return opcode == spv::Op::OpMatrixTimesMatrix ||
         opcode == spv::Op::OpMatrixTimesVector ||
         opcode == spv::Op::OpMatrixTimesScalar;
}

/// If expr is a (RW)StructuredBuffer.Load() call, returns the buffer object
/// and writes the index expression to *index. Otherwise, returns nullptr.
// TODO: The following doesn't handle Load(int, int) yet. And it is basically a
// duplicate of doCXXMemberCallExpr.
const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
  using namespace hlsl;
  if (const auto *indexing = dyn_cast<CXXMemberCallExpr>(expr)) {
    const auto *callee = indexing->getDirectCallee();
    uint32_t opcode = static_cast<uint32_t>(IntrinsicOp::Num_Intrinsics);
    llvm::StringRef group;
    if (GetIntrinsicOp(callee, opcode, group)) {
      if (static_cast<IntrinsicOp>(opcode) == IntrinsicOp::MOP_Load) {
        const auto *object = indexing->getImplicitObjectArgument();
        if (isStructuredBuffer(object->getType())) {
          *index = indexing->getArg(0);
          return indexing->getImplicitObjectArgument();
        }
      }
    }
  }
  return nullptr;
}
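
// Illustrative sketch (names assumed, not from the original source): given
// HLSL like
//   StructuredBuffer<S> sbuf;
//   ... sbuf.Load(i) ...
// this helper returns the `sbuf` expression and stores the `i` expression
// through `index`; any other expression yields nullptr.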

/// Returns true if the given VarDecl will be translated into a SPIR-V variable
/// not in the Private or Function storage class.
inline bool isExternalVar(const VarDecl *var) {
  // Class static variables should be put in the Private storage class.
  // groupshared variables are allowed to be declared as "static". But we still
  // need to put them in the Workgroup storage class. That is, when seeing
  // "static groupshared", ignore "static".
  return var->hasExternalFormalLinkage()
             ? !var->isStaticDataMember()
             : (var->getAttr<HLSLGroupSharedAttr>() != nullptr);
}
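
// Illustrative examples (assumed HLSL declarations):
//   groupshared float a[64];        // external: goes to Workgroup
//   static groupshared float b[64]; // still external: "static" is ignored
//   static float c;                 // not external: goes to Private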

/// Returns the referenced variable's DeclContext if the given expr is
/// a DeclRefExpr referencing a ConstantBuffer/TextureBuffer. Otherwise,
/// returns nullptr.
const DeclContext *isConstantTextureBufferDeclRef(const Expr *expr) {
  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr->IgnoreParenCasts()))
    if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
      if (isConstantTextureBuffer(varDecl->getType()))
        return hlsl::GetHLSLResourceResultType(varDecl->getType())
            ->getAs<RecordType>()
            ->getDecl();

  return nullptr;
}
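
// Illustrative (type name assumed): for `ConstantBuffer<MyCBType> cb;`, a
// DeclRefExpr to `cb` would yield the RecordDecl of `MyCBType` here.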

/// Returns true if
/// * the given expr is a DeclRefExpr referencing a kind of structured or byte
///   buffer and it is a non-alias one, or
/// * the given expr is a CallExpr returning a kind of structured or byte
///   buffer, or
/// * the given expr is an ArraySubscriptExpr referencing a kind of structured
///   or byte buffer.
///
/// Note: legalization specific code
bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
  expr = expr->IgnoreParenCasts();
  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
      if (isAKindOfStructuredOrByteBuffer(varDecl->getType()))
        return isExternalVar(varDecl);
  } else if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
    if (isAKindOfStructuredOrByteBuffer(callExpr->getType()))
      return true;
  } else if (const auto *arrSubExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
    return isReferencingNonAliasStructuredOrByteBuffer(arrSubExpr->getBase());
  }
  return false;
}
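
// Illustrative sketch (declarations assumed): with a global
//   RWStructuredBuffer<S> gBuf;
// the expression `gBuf` satisfies this predicate (it is external, hence
// non-alias), as does a call expression whose return type is such a buffer;
// an ArraySubscriptExpr is checked by recursing into its base.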

/// Translates atomic HLSL opcodes into the equivalent SPIR-V opcode.
spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) {
  using namespace hlsl;
  using namespace spv;

  switch (opcode) {
  case IntrinsicOp::IOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAdd:
    return Op::OpAtomicIAdd;
  case IntrinsicOp::IOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedAnd:
    return Op::OpAtomicAnd;
  case IntrinsicOp::IOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedOr:
    return Op::OpAtomicOr;
  case IntrinsicOp::IOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedXor:
    return Op::OpAtomicXor;
  case IntrinsicOp::IOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMax:
    return Op::OpAtomicUMax;
  case IntrinsicOp::IOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedUMin:
    return Op::OpAtomicUMin;
  case IntrinsicOp::IOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMax:
    return Op::OpAtomicSMax;
  case IntrinsicOp::IOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedMin:
    return Op::OpAtomicSMin;
  case IntrinsicOp::IOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedExchange:
    return Op::OpAtomicExchange;
  default:
    // Only atomic opcodes are relevant.
    break;
  }

  assert(false && "unimplemented hlsl intrinsic opcode");
  return Op::Max;
}

// Returns true if the given opcode is an accepted binary opcode in
// OpSpecConstantOp.
bool isAcceptedSpecConstantBinaryOp(spv::Op op) {
  switch (op) {
  case spv::Op::OpIAdd:
  case spv::Op::OpISub:
  case spv::Op::OpIMul:
  case spv::Op::OpUDiv:
  case spv::Op::OpSDiv:
  case spv::Op::OpUMod:
  case spv::Op::OpSRem:
  case spv::Op::OpSMod:
  case spv::Op::OpShiftRightLogical:
  case spv::Op::OpShiftRightArithmetic:
  case spv::Op::OpShiftLeftLogical:
  case spv::Op::OpBitwiseOr:
  case spv::Op::OpBitwiseXor:
  case spv::Op::OpBitwiseAnd:
  case spv::Op::OpVectorShuffle:
  case spv::Op::OpCompositeExtract:
  case spv::Op::OpCompositeInsert:
  case spv::Op::OpLogicalOr:
  case spv::Op::OpLogicalAnd:
  case spv::Op::OpLogicalNot:
  case spv::Op::OpLogicalEqual:
  case spv::Op::OpLogicalNotEqual:
  case spv::Op::OpIEqual:
  case spv::Op::OpINotEqual:
  case spv::Op::OpULessThan:
  case spv::Op::OpSLessThan:
  case spv::Op::OpUGreaterThan:
  case spv::Op::OpSGreaterThan:
  case spv::Op::OpULessThanEqual:
  case spv::Op::OpSLessThanEqual:
  case spv::Op::OpUGreaterThanEqual:
  case spv::Op::OpSGreaterThanEqual:
    return true;
  default:
    // Accepted binary opcodes return true. Anything else is false.
    return false;
  }
  return false;
}
/// Returns true if the given expression is an accepted initializer for a spec
/// constant.
bool isAcceptedSpecConstantInit(const Expr *init) {
  // Allow numeric casts
  init = init->IgnoreParenCasts();
  if (isa<CXXBoolLiteralExpr>(init) || isa<IntegerLiteral>(init) ||
      isa<FloatingLiteral>(init))
    return true;

  // Allow the minus operator which is used to specify negative values
  if (const auto *unaryOp = dyn_cast<UnaryOperator>(init))
    return unaryOp->getOpcode() == UO_Minus &&
           isAcceptedSpecConstantInit(unaryOp->getSubExpr());

  return false;
}
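
// Illustrative sketch (assumed usage; not from the original source): all of
// the following spec-constant initializers are accepted:
//
//   [[vk::constant_id(0)]] const bool  kFlag  = true;
//   [[vk::constant_id(1)]] const int   kCount = -4;    // UO_Minus + literal
//   [[vk::constant_id(2)]] const float kScale = 1.5;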
/// Returns true if the given function parameter can act as shader stage
/// input parameter.
inline bool canActAsInParmVar(const ParmVarDecl *param) {
  // If the parameter has no in/out/inout attribute, it is defaulted to
  // an in parameter.
  return !param->hasAttr<HLSLOutAttr>() &&
         // GS output streams are marked as inout, but they should not be
         // used as in parameters.
         !hlsl::IsHLSLStreamOutputType(param->getType());
}

/// Returns true if the given function parameter can act as shader stage
/// output parameter.
inline bool canActAsOutParmVar(const ParmVarDecl *param) {
  return param->hasAttr<HLSLOutAttr>() || param->hasAttr<HLSLInOutAttr>() ||
         hlsl::IsHLSLRayQueryType(param->getType());
}
/// Returns true if the given expression is of builtin type and can be
/// evaluated to a constant zero. Returns false otherwise.
inline bool evaluatesToConstZero(const Expr *expr, ASTContext &astContext) {
  const auto type = expr->getType();
  if (!type->isBuiltinType())
    return false;

  Expr::EvalResult evalResult;
  if (expr->EvaluateAsRValue(evalResult, astContext) &&
      !evalResult.HasSideEffects) {
    const auto &val = evalResult.Val;
    return ((type->isBooleanType() && !val.getInt().getBoolValue()) ||
            (type->isIntegerType() && !val.getInt().getBoolValue()) ||
            (type->isFloatingType() && val.getFloat().isZero()));
  }

  return false;
}
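
// Illustrative sketch (assumed behavior; not from the original source):
//
//   evaluatesToConstZero(<0>, ...)      -> true   (integer zero)
//   evaluatesToConstZero(<1 - 1>, ...)  -> true   (folds to zero)
//   evaluatesToConstZero(<0.0f>, ...)   -> true   (floating-point zero)
//   evaluatesToConstZero(<param>, ...)  -> false  (not a compile-time value)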
/// Returns the real definition of the callee of the given CallExpr.
///
/// If we are calling a forward-declared function, callee will be the
/// FunctionDecl for the forward-declared function, not the actual
/// definition. The forward-declaration and definition are two completely
/// different AST nodes.
inline const FunctionDecl *getCalleeDefinition(const CallExpr *expr) {
  const auto *callee = expr->getDirectCallee();

  if (callee->isThisDeclarationADefinition())
    return callee;

  // We need to update callee to the actual definition here
  if (!callee->isDefined(callee))
    return nullptr;

  return callee;
}
/// Returns the referenced definition. The given expr is expected to be a
/// DeclRefExpr or CallExpr after ignoring casts. Returns nullptr otherwise.
const DeclaratorDecl *getReferencedDef(const Expr *expr) {
  if (!expr)
    return nullptr;

  expr = expr->IgnoreParenCasts();

  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    return dyn_cast_or_null<DeclaratorDecl>(declRefExpr->getDecl());
  }

  if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
    return getCalleeDefinition(callExpr);
  }

  return nullptr;
}
/// Returns the number of base classes if this type is a derived class/struct.
/// Returns zero otherwise.
inline uint32_t getNumBaseClasses(QualType type) {
  if (const auto *cxxDecl = type->getAsCXXRecordDecl())
    return cxxDecl->getNumBases();
  return 0;
}
/// Gets the index sequence of casting a derived object to a base object by
/// following the cast chain.
void getBaseClassIndices(const CastExpr *expr,
                         llvm::SmallVectorImpl<uint32_t> *indices) {
  assert(expr->getCastKind() == CK_UncheckedDerivedToBase ||
         expr->getCastKind() == CK_HLSLDerivedToBase);

  indices->clear();

  QualType derivedType = expr->getSubExpr()->getType();
  const auto *derivedDecl = derivedType->getAsCXXRecordDecl();

  // Go through the base cast chain: for each of the derived-to-base casts,
  // find the index of the base in question in the derived's bases.
  for (auto pathIt = expr->path_begin(), pathIe = expr->path_end();
       pathIt != pathIe; ++pathIt) {
    // The type of the base in question
    const auto baseType = (*pathIt)->getType();

    uint32_t index = 0;
    for (auto baseIt = derivedDecl->bases_begin(),
              baseIe = derivedDecl->bases_end();
         baseIt != baseIe; ++baseIt, ++index)
      if (baseIt->getType() == baseType) {
        indices->push_back(index);
        break;
      }

    assert(index < derivedDecl->getNumBases());

    // Proceed to the next base in the chain
    derivedType = baseType;
    derivedDecl = derivedType->getAsCXXRecordDecl();
  }
}
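
// Illustrative sketch (assumed example; not from the original source):
//
//   struct A {};  struct B {};
//   struct C : A, B {};
//   struct D : C {};
//
// Casting a D to a B walks the path D -> C -> B and yields the index
// sequence {0, 1}: C is base #0 of D, and B is base #1 of C.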
std::string getNamespacePrefix(const Decl *decl) {
  std::string nsPrefix = "";
  const DeclContext *dc = decl->getDeclContext();
  while (dc && !dc->isTranslationUnit()) {
    if (const NamespaceDecl *ns = dyn_cast<NamespaceDecl>(dc)) {
      if (!ns->isAnonymousNamespace()) {
        nsPrefix = ns->getName().str() + "::" + nsPrefix;
      }
    }
    dc = dc->getParent();
  }
  return nsPrefix;
}

std::string getFnName(const FunctionDecl *fn) {
  // Prefix the function name with the struct name if necessary
  std::string classOrStructName = "";
  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(fn))
    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
      classOrStructName = st->getName().str() + ".";
  return getNamespacePrefix(fn) + classOrStructName + fn->getName().str();
}
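
// Illustrative sketch (assumed example; not from the original source):
//
//   namespace MyNs { struct S { float4 compute(); }; }
//
// getFnName() on 'compute' yields "MyNs::S.compute": namespaces are joined
// with "::" and the enclosing class/struct with ".".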
} // namespace

SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
    : theCompilerInstance(ci), astContext(ci.getASTContext()),
      diags(ci.getDiagnostics()),
      spirvOptions(ci.getCodeGenOpts().SpirvOptions),
      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), spvContext(),
      featureManager(diags, spirvOptions),
      spvBuilder(astContext, spvContext, spirvOptions),
      declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager,
                   spirvOptions),
      entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
      seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
      beforeHlslLegalization(false), mainSourceFile(nullptr) {
  // Get ShaderModel from the command line hlsl profile option.
  const hlsl::ShaderModel *shaderModel =
      hlsl::ShaderModel::GetByName(ci.getCodeGenOpts().HLSLProfile.c_str());
  if (shaderModel->GetKind() == hlsl::ShaderModel::Kind::Invalid)
    emitError("unknown shader model: %0", {}) << shaderModel->GetName();

  if (spirvOptions.invertY && !shaderModel->IsVS() && !shaderModel->IsDS() &&
      !shaderModel->IsGS())
    emitError("-fvk-invert-y can only be used in VS/DS/GS", {});

  if (spirvOptions.useGlLayout && spirvOptions.useDxLayout)
    emitError("cannot specify both -fvk-use-dx-layout and -fvk-use-gl-layout",
              {});

  // Set shader model kind and hlsl major/minor version.
  spvContext.setCurrentShaderModelKind(shaderModel->GetKind());
  spvContext.setMajorVersion(shaderModel->GetMajor());
  spvContext.setMinorVersion(shaderModel->GetMinor());

  if (spirvOptions.useDxLayout) {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer;
    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::FxcSBuffer;
  } else if (spirvOptions.useGlLayout) {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::GLSLStd430;
  } else if (spirvOptions.useScalarLayout) {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::Scalar;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::Scalar;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::Scalar;
    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::Scalar;
  } else {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
    spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
  }
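
  // Illustrative note (standard std140/std430 behavior; the example itself
  // is not from the original source): with the default rules above, a
  // cbuffer member
  //   float a[4];
  // gets a 16-byte array stride (std140-style), while the same array inside
  // a StructuredBuffer uses a tightly packed 4-byte stride (std430-style).
  // -fvk-use-scalar-layout instead aligns members to their scalar alignment.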
  // Set shader model version, source file name, and source file content (if
  // needed).
  llvm::StringRef source;
  std::vector<llvm::StringRef> fileNames;
  const auto &inputFiles = ci.getFrontendOpts().Inputs;
  // File name
  if (spirvOptions.debugInfoFile && !inputFiles.empty()) {
    for (const auto &inputFile : inputFiles) {
      fileNames.push_back(inputFile.getFile());
    }
  }
  // Source code
  if (spirvOptions.debugInfoSource) {
    const auto &sm = ci.getSourceManager();
    const llvm::MemoryBuffer *mainFile =
        sm.getBuffer(sm.getMainFileID(), SourceLocation());
    source = StringRef(mainFile->getBufferStart(), mainFile->getBufferSize());
  }
  mainSourceFile = spvBuilder.setDebugSource(spvContext.getMajorVersion(),
                                             spvContext.getMinorVersion(),
                                             fileNames, source);

  // OpenCL.DebugInfo.100 DebugSource
  if (spirvOptions.debugInfoRich) {
    auto *dbgSrc = spvBuilder.createDebugSource(mainSourceFile->getString());
    // spvContext.getDebugInfo().insert() inserts a {string key, RichDebugInfo}
    // pair and returns {{string key, RichDebugInfo}, true /*Success*/}.
    // spvContext.getDebugInfo().insert().first->second is a RichDebugInfo.
    auto *richDebugInfo =
        &spvContext.getDebugInfo()
             .insert(
                 {mainSourceFile->getString(),
                  RichDebugInfo(dbgSrc,
                                spvBuilder.createDebugCompilationUnit(dbgSrc))})
             .first->second;
    spvContext.pushDebugLexicalScope(richDebugInfo,
                                     richDebugInfo->scopeStack.back());
  }

  if (spirvOptions.debugInfoTool &&
      spirvOptions.targetEnv.compare("vulkan1.1") >= 0) {
    // Emit OpModuleProcessed to indicate the commit information.
    std::string commitHash =
        std::string("dxc-commit-hash: ") + clang::getGitCommitHash();
    spvBuilder.addModuleProcessed(commitHash);

    // Emit OpModuleProcessed to indicate the command line options that were
    // used to generate this module.
    if (!spirvOptions.clOptions.empty()) {
      // Using this format: "dxc-cl-option: XXXXXX"
      std::string clOptionStr = "dxc-cl-option:" + spirvOptions.clOptions;
      spvBuilder.addModuleProcessed(clOptionStr);
    }
  }
}
void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
  // Stop translating if there are errors in previous compilation stages.
  if (context.getDiagnostics().hasErrorOccurred())
    return;

  TranslationUnitDecl *tu = context.getTranslationUnitDecl();
  uint32_t numEntryPoints = 0;

  // The entry function is the seed of the queue.
  for (auto *decl : tu->decls()) {
    if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
      if (spvContext.isLib()) {
        if (const auto *shaderAttr = funcDecl->getAttr<HLSLShaderAttr>()) {
          // If we are compiling as a library then add everything that has a
          // ShaderAttr.
          addFunctionToWorkQueue(getShaderModelKind(shaderAttr->getStage()),
                                 funcDecl, /*isEntryFunction*/ true);
          numEntryPoints++;
        }
      } else {
        if (funcDecl->getName() == entryFunctionName) {
          addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(),
                                 funcDecl, /*isEntryFunction*/ true);
          numEntryPoints++;
        }
      }
    } else {
      doDecl(decl);
    }
    if (context.getDiagnostics().hasErrorOccurred())
      return;
  }

  // Translate all functions reachable from the entry function.
  // The queue can grow in the meantime, so we need to keep evaluating
  // workQueue.size().
  for (uint32_t i = 0; i < workQueue.size(); ++i) {
    const FunctionInfo *curEntryOrCallee = workQueue[i];
    spvContext.setCurrentShaderModelKind(curEntryOrCallee->shaderModelKind);
    doDecl(curEntryOrCallee->funcDecl);
    if (context.getDiagnostics().hasErrorOccurred())
      return;
  }

  const spv_target_env targetEnv = featureManager.getTargetEnv();

  // Addressing and memory model are required in a valid SPIR-V module.
  spvBuilder.setMemoryModel(spv::AddressingModel::Logical,
                            spv::MemoryModel::GLSL450);

  // Even though the 'workQueue' grows due to the above loop, the first
  // 'numEntryPoints' entries in the 'workQueue' are the ones with the HLSL
  // 'shader' attribute, and must therefore be entry functions.
  assert(numEntryPoints <= workQueue.size());

  for (uint32_t i = 0; i < numEntryPoints; ++i) {
    // TODO: assign specific StageVars w.r.t. the entry point
    const FunctionInfo *entryInfo = workQueue[i];
    assert(entryInfo->isEntryFunction);
    spvBuilder.addEntryPoint(
        getSpirvShaderStage(entryInfo->shaderModelKind),
        entryInfo->entryFunction, entryInfo->funcDecl->getName(),
        targetEnv == SPV_ENV_VULKAN_1_2
            ? spvBuilder.getModule()->getVariables()
            : llvm::ArrayRef<SpirvVariable *>(declIdMapper.collectStageVars()));
  }

  // Add Location decorations to stage input/output variables.
  if (!declIdMapper.decorateStageIOLocations())
    return;

  // Add descriptor set and binding decorations to resource variables.
  if (!declIdMapper.decorateResourceBindings())
    return;

  // Output the constructed module.
  std::vector<uint32_t> m = spvBuilder.takeModule();

  if (!spirvOptions.codeGenHighLevel) {
    // In order to flatten composite resources, we must also unroll loops.
    // Therefore we should run legalization before optimization.
    needsLegalization = needsLegalization ||
                        declIdMapper.requiresLegalization() ||
                        spirvOptions.flattenResourceArrays ||
                        declIdMapper.requiresFlatteningCompositeResources() ||
                        spirvOptions.debugInfoRich;

    // Run legalization passes
    if (needsLegalization) {
      std::string messages;
      if (!spirvToolsLegalize(&m, &messages)) {
        emitFatalError("failed to legalize SPIR-V: %0", {}) << messages;
        emitNote("please file a bug report on "
                 "https://github.com/Microsoft/DirectXShaderCompiler/issues "
                 "with source code if possible",
                 {});
        return;
      } else if (!messages.empty()) {
        emitWarning("SPIR-V legalization: %0", {}) << messages;
      }
    }

    // Run optimization passes
    if (!spirvOptions.debugInfoRich &&
        theCompilerInstance.getCodeGenOpts().OptimizationLevel > 0) {
      std::string messages;
      if (!spirvToolsOptimize(&m, &messages)) {
        emitFatalError("failed to optimize SPIR-V: %0", {}) << messages;
        emitNote("please file a bug report on "
                 "https://github.com/Microsoft/DirectXShaderCompiler/issues "
                 "with source code if possible",
                 {});
        return;
      }
    }
  }

  // Validate the generated SPIR-V code
  if (!spirvOptions.disableValidation) {
    std::string messages;
    if (!spirvToolsValidate(&m, &messages)) {
      emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
      emitNote("please file a bug report on "
               "https://github.com/Microsoft/DirectXShaderCompiler/issues "
               "with source code if possible",
               {});
      return;
    }
  }

  theCompilerInstance.getOutStream()->write(
      reinterpret_cast<const char *>(m.data()), m.size() * 4);
}
void SpirvEmitter::doDecl(const Decl *decl) {
  if (isa<EmptyDecl>(decl) || isa<TypedefDecl>(decl))
    return;

  // Implicit decls are lazily created when needed.
  if (decl->isImplicit()) {
    return;
  }

  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    doVarDecl(varDecl);
  } else if (const auto *namespaceDecl = dyn_cast<NamespaceDecl>(decl)) {
    for (auto *subDecl : namespaceDecl->decls())
      // Note: We only emit functions as they are discovered through the call
      // graph starting from the entry-point. We should not emit unused
      // functions inside namespaces.
      if (!isa<FunctionDecl>(subDecl))
        doDecl(subDecl);
  } else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
    doFunctionDecl(funcDecl);
  } else if (const auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
    doHLSLBufferDecl(bufferDecl);
  } else if (const auto *recordDecl = dyn_cast<RecordDecl>(decl)) {
    doRecordDecl(recordDecl);
  } else if (const auto *enumDecl = dyn_cast<EnumDecl>(decl)) {
    doEnumDecl(enumDecl);
  } else {
    emitError("decl type %0 unimplemented", decl->getLocation())
        << decl->getDeclKindName();
  }
}
RichDebugInfo *
SpirvEmitter::getOrCreateRichDebugInfo(const SourceLocation &loc) {
  const StringRef file =
      astContext.getSourceManager().getPresumedLoc(loc).getFilename();
  auto &debugInfo = spvContext.getDebugInfo();
  auto it = debugInfo.find(file);
  if (it != debugInfo.end())
    return &it->second;

  auto *dbgSrc = spvBuilder.createDebugSource(file);
  // debugInfo.insert() inserts a {string key, RichDebugInfo} pair and
  // returns {{string key, RichDebugInfo}, true /*Success*/}.
  // debugInfo.insert().first->second is a RichDebugInfo.
  return &debugInfo
              .insert({file, RichDebugInfo(
                                 dbgSrc, spvBuilder.createDebugCompilationUnit(
                                             dbgSrc))})
              .first->second;
}
void SpirvEmitter::doStmt(const Stmt *stmt,
                          llvm::ArrayRef<const Attr *> attrs) {
  if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
    if (spirvOptions.debugInfoRich) {
      // Any opening of curly braces ('{') starts a CompoundStmt in the AST.
      // It also means we have a new lexical block!
      const auto loc = stmt->getLocStart();
      const auto &sm = astContext.getSourceManager();
      const uint32_t line = sm.getPresumedLineNumber(loc);
      const uint32_t column = sm.getPresumedColumnNumber(loc);
      RichDebugInfo *info = getOrCreateRichDebugInfo(loc);
      auto *debugLexicalBlock = spvBuilder.createDebugLexicalBlock(
          info->source, line, column, info->scopeStack.back());

      // Add this lexical block to the stack of lexical scopes.
      spvContext.pushDebugLexicalScope(info, debugLexicalBlock);

      // Update or add DebugScope.
      if (spvBuilder.getInsertPoint()->empty()) {
        spvBuilder.getInsertPoint()->updateDebugScope(
            new (spvContext) SpirvDebugScope(debugLexicalBlock));
      } else if (!spvBuilder.isCurrentBasicBlockTerminated()) {
        spvBuilder.createDebugScope(debugLexicalBlock);
      }

      // Iterate over sub-statements
      for (auto *st : compoundStmt->body())
        doStmt(st);

      // We are done with processing this compound statement. Remove its
      // lexical block from the stack of lexical scopes.
      spvContext.popDebugLexicalScope(info);
      if (!spvBuilder.isCurrentBasicBlockTerminated()) {
        spvBuilder.createDebugScope(spvContext.getCurrentLexicalScope());
      }
    } else {
      // Iterate over sub-statements
      for (auto *st : compoundStmt->body())
        doStmt(st);
    }
  } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
    doReturnStmt(retStmt);
  } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
    doDeclStmt(declStmt);
  } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
    doIfStmt(ifStmt, attrs);
  } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
    doSwitchStmt(switchStmt, attrs);
  } else if (isa<CaseStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (isa<DefaultStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
    doBreakStmt(breakStmt);
  } else if (const auto *theDoStmt = dyn_cast<DoStmt>(stmt)) {
    doDoStmt(theDoStmt, attrs);
  } else if (const auto *discardStmt = dyn_cast<DiscardStmt>(stmt)) {
    doDiscardStmt(discardStmt);
  } else if (const auto *continueStmt = dyn_cast<ContinueStmt>(stmt)) {
    doContinueStmt(continueStmt);
  } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
    doWhileStmt(whileStmt, attrs);
  } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
    doForStmt(forStmt, attrs);
  } else if (isa<NullStmt>(stmt)) {
    // For the null statement ";". We don't need to do anything.
  } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
    // All cases for expressions used as statements
    doExpr(expr);
  } else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
    doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
  } else {
    emitError("statement class '%0' unimplemented", stmt->getLocStart())
        << stmt->getStmtClassName() << stmt->getSourceRange();
  }
}
SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr) {
  SpirvInstruction *result = nullptr;
  expr = expr->IgnoreParens();

  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    result = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl(),
                                          expr->getLocStart());
  } else if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
    result = doMemberExpr(memberExpr);
  } else if (const auto *castExpr = dyn_cast<CastExpr>(expr)) {
    result = doCastExpr(castExpr);
  } else if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
    result = doInitListExpr(initListExpr);
  } else if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
    result =
        spvBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode);
    result->setRValue();
  } else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
    result = translateAPInt(intLiteral->getValue(), expr->getType());
    result->setRValue();
  } else if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
    result = translateAPFloat(floatLiteral->getValue(), expr->getType());
    result->setRValue();
  } else if (const auto *stringLiteral = dyn_cast<StringLiteral>(expr)) {
    result = spvBuilder.getString(stringLiteral->getString());
  } else if (const auto *compoundAssignOp =
                 dyn_cast<CompoundAssignOperator>(expr)) {
    // CompoundAssignOperator is a subclass of BinaryOperator. It should be
    // checked before BinaryOperator.
    result = doCompoundAssignOperator(compoundAssignOp);
  } else if (const auto *binOp = dyn_cast<BinaryOperator>(expr)) {
    result = doBinaryOperator(binOp);
  } else if (const auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
    result = doUnaryOperator(unaryOp);
  } else if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    result = doHLSLVectorElementExpr(vecElemExpr);
  } else if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
    result = doExtMatrixElementExpr(matElemExpr);
  } else if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
    result = doCallExpr(funcCall);
  } else if (const auto *subscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
    result = doArraySubscriptExpr(subscriptExpr);
  } else if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
    result = doConditionalOperator(condExpr);
  } else if (const auto *defaultArgExpr = dyn_cast<CXXDefaultArgExpr>(expr)) {
    result = doExpr(defaultArgExpr->getParam()->getDefaultArg());
  } else if (isa<CXXThisExpr>(expr)) {
    assert(curThis);
    result = curThis;
  } else if (isa<CXXConstructExpr>(expr)) {
    result = curThis;
  } else if (const auto *unaryExpr = dyn_cast<UnaryExprOrTypeTraitExpr>(expr)) {
    result = doUnaryExprOrTypeTraitExpr(unaryExpr);
  } else {
    emitError("expression class '%0' unimplemented", expr->getExprLoc())
        << expr->getStmtClassName() << expr->getSourceRange();
  }

  return result;
}
SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr) {
  // We are trying to load the value here, which is what an LValueToRValue
  // implicit cast is intended to do. We can ignore the cast if it exists.
  expr = expr->IgnoreParenLValueCasts();
  return loadIfGLValue(expr, doExpr(expr));
}

SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
                                              SpirvInstruction *info) {
  const auto exprType = expr->getType();

  // Do nothing if this is already an rvalue
  if (!info || info->isRValue())
    return info;

  // Check whether we are trying to load an array of opaque objects as a
  // whole. If true, we are likely to copy it as a whole. To assist
  // per-element copying, avoid the load here and return the pointer directly.
  // TODO: consider moving this hack into SPIRV-Tools as a transformation.
  if (isOpaqueArrayType(exprType))
    return info;

  // Check whether we are trying to load an externally visible structured/byte
  // buffer as a whole. If true, it means we are creating an alias for it.
  // Avoid the load and write the pointer directly to the alias variable
  // instead.
  //
  // Also for the case of alias function returns. If we are trying to load an
  // alias function return as a whole, it means we are assigning it to another
  // alias variable. Avoid the load and write the pointer directly.
  //
  // Note: legalization specific code
  if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
    return info;
  }

  if (loadIfAliasVarRef(expr, &info)) {
    // We are loading an alias variable as a whole here. This is likely for
    // wholesale assignments or function returns. Need to load the pointer.
    //
    // Note: legalization specific code
    return info;
  }

  SpirvInstruction *loadedInstr = nullptr;

  // TODO: Ouch. Very hacky. We need a special path to get the value type if
  // we are loading a whole ConstantBuffer/TextureBuffer since the normal
  // type translation path won't work.
  if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
    loadedInstr = spvBuilder.createLoad(
        declIdMapper.getCTBufferPushConstantType(declContext), info,
        expr->getExprLoc());
  } else {
    loadedInstr = spvBuilder.createLoad(exprType, info, expr->getExprLoc());
  }
  assert(loadedInstr);

  // Special-case: According to the SPIR-V Spec, there is no physical size or
  // bit pattern defined for the boolean type. Therefore an unsigned integer
  // is used to represent booleans when layout is required. In such cases,
  // after loading the uint, we should perform a comparison.
  {
    uint32_t vecSize = 1, numRows = 0, numCols = 0;
    if (info->getLayoutRule() != SpirvLayoutRule::Void &&
        isBoolOrVecMatOfBoolType(exprType)) {
      QualType uintType = astContext.UnsignedIntTy;
      if (isScalarType(exprType) || isVectorType(exprType, nullptr, &vecSize)) {
        const auto fromType =
            vecSize == 1 ? uintType
                         : astContext.getExtVectorType(uintType, vecSize);
        loadedInstr =
            castToBool(loadedInstr, fromType, exprType, expr->getLocStart());
      } else {
        const bool isMat = isMxNMatrix(exprType, nullptr, &numRows, &numCols);
        assert(isMat);
        (void)isMat;
        const clang::Type *type = exprType.getCanonicalType().getTypePtr();
        const RecordType *RT = cast<RecordType>(type);
        const ClassTemplateSpecializationDecl *templateSpecDecl =
            cast<ClassTemplateSpecializationDecl>(RT->getDecl());
        ClassTemplateDecl *templateDecl =
            templateSpecDecl->getSpecializedTemplate();
        const auto fromType = getHLSLMatrixType(
            astContext, theCompilerInstance.getSema(), templateDecl,
            astContext.UnsignedIntTy, numRows, numCols);
        loadedInstr =
            castToBool(loadedInstr, fromType, exprType, expr->getLocStart());
      }
      // Now that it is converted to Bool, it has no layout rule.
      // This result-id should be evaluated as bool from here on out.
      loadedInstr->setLayoutRule(SpirvLayoutRule::Void);
    }
  }

  loadedInstr->setRValue();
  return loadedInstr;
}
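
// Illustrative sketch (assumed example; not from the original source):
// loading "bool b;" from a cbuffer goes through the uint representation,
// roughly:
//
//   %u = OpLoad %uint %ptr
//   %b = OpINotEqual %bool %u %uint_0
//
// after which the layout rule is reset to Void, since plain bools carry no
// layout.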
SpirvInstruction *SpirvEmitter::loadIfAliasVarRef(const Expr *expr) {
  auto *instr = doExpr(expr);
  loadIfAliasVarRef(expr, &instr);
  return instr;
}

bool SpirvEmitter::loadIfAliasVarRef(const Expr *varExpr,
                                     SpirvInstruction **instr) {
  assert(instr);
  if ((*instr) && (*instr)->containsAliasComponent() &&
      isAKindOfStructuredOrByteBuffer(varExpr->getType())) {
    // Load the pointer of the aliased-to-variable if the expression has a
    // pointer to pointer type.
    if (varExpr->isGLValue()) {
      *instr = spvBuilder.createLoad(varExpr->getType(), *instr,
                                     varExpr->getExprLoc());
    }
    return true;
  }
  return false;
}
SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value,
                                           QualType fromType, QualType toType,
                                           SourceLocation srcLoc) {
  if (isFloatOrVecMatOfFloatType(toType))
    return castToFloat(value, fromType, toType, srcLoc);

  // Order matters here. Bool (vector) values will also be considered as uint
  // (vector) values. So given a bool (vector) argument,
  // isUintOrVecOfUintType() will also return true. We need to check bool
  // before uint. The opposite is not true.
  if (isBoolOrVecMatOfBoolType(toType))
    return castToBool(value, fromType, toType, srcLoc);

  if (isSintOrVecMatOfSintType(toType) || isUintOrVecMatOfUintType(toType))
    return castToInt(value, fromType, toType, srcLoc);

  emitError("casting to type %0 unimplemented", {}) << toType;
  return nullptr;
}
void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
  assert(decl->isThisDeclarationADefinition());

  // A RAII class for maintaining the current function under traversal.
  class FnEnvRAII {
  public:
    // Creates a new instance which sets fnEnv to the newFn on creation,
    // and resets fnEnv to its original value on destruction.
    FnEnvRAII(const FunctionDecl **fnEnv, const FunctionDecl *newFn)
        : oldFn(*fnEnv), fnSlot(fnEnv) {
      *fnEnv = newFn;
    }
    ~FnEnvRAII() { *fnSlot = oldFn; }

  private:
    const FunctionDecl *oldFn;
    const FunctionDecl **fnSlot;
  };

  FnEnvRAII fnEnvRAII(&curFunction, decl);

  // We are about to start translation for a new function. Clear the break
  // stack and the continue stack.
  breakStack = std::stack<SpirvBasicBlock *>();
  continueStack = std::stack<SpirvBasicBlock *>();

  // This will allow the entry-point name to be something like
  // myNamespace::myEntrypointFunc.
  std::string funcName = getFnName(decl);
  std::string debugFuncName = funcName;

  SpirvFunction *func = declIdMapper.getOrRegisterFn(decl);

  const auto iter = functionInfoMap.find(decl);
  if (iter != functionInfoMap.end()) {
    const auto &entryInfo = iter->second;
    if (entryInfo->isEntryFunction) {
      funcName = "src." + funcName;
      // Create wrapper for the entry function
      if (!emitEntryFunctionWrapper(decl, func))
        return;
    }
  }

  const QualType retType =
      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);

  spvBuilder.beginFunction(retType, decl->getLocStart(), funcName,
                           decl->hasAttr<HLSLPreciseAttr>(), func);

  auto loc = decl->getLocStart();
  RichDebugInfo *info = nullptr;
  const auto &sm = astContext.getSourceManager();
  if (spirvOptions.debugInfoRich && decl->hasBody()) {
    const uint32_t line = sm.getPresumedLineNumber(loc);
    const uint32_t column = sm.getPresumedColumnNumber(loc);
    info = getOrCreateRichDebugInfo(loc);
    auto *source = info->source;
    // Note that info->scopeStack.back() is a lexical scope of the function
    // caller.
    auto *parentScope = info->compilationUnit;
    // TODO: figure out the proper flag based on the function decl.
    // Using FlagIsPublic for now.
    uint32_t flags = 3u;
    // The line number in the source program at which the function scope
    // begins.
    auto scopeLine = sm.getPresumedLineNumber(decl->getBody()->getLocStart());
    SpirvDebugFunction *debugFunction = spvBuilder.createDebugFunction(
        decl, debugFuncName, source, line, column, parentScope, "", flags,
        scopeLine, func);
    func->setDebugScope(new (spvContext) SpirvDebugScope(debugFunction));

    spvContext.pushDebugLexicalScope(info, debugFunction);
  }

  bool isNonStaticMemberFn = false;
  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
    if (!memberFn->isStatic()) {
      // For non-static member functions, the first parameter should be the
      // object on which we are invoking this method.
      QualType valueType = memberFn->getThisType(astContext)->getPointeeType();

      // Remember the parameter for the 'this' object so later we can handle
      // CXXThisExpr correctly.
      curThis = spvBuilder.addFnParam(valueType, /*isPrecise*/ false,
                                      decl->getLocStart(), "param.this");

      if (isOrContainsAKindOfStructuredOrByteBuffer(valueType)) {
        curThis->setContainsAliasComponent(true);
        needsLegalization = true;
      }

      if (spirvOptions.debugInfoRich) {
        // Add DebugLocalVariable information
        const auto &sm = astContext.getSourceManager();
        const uint32_t line = sm.getPresumedLineNumber(loc);
        const uint32_t column = sm.getPresumedColumnNumber(loc);
        if (!info)
          info = getOrCreateRichDebugInfo(loc);
        // TODO: replace this with FlagArtificial|FlagObjectPointer.
        uint32_t flags = (1 << 5) | (1 << 8);
        auto *debugLocalVar = spvBuilder.createDebugLocalVariable(
            valueType, "this", info->source, line, column,
            info->scopeStack.back(), flags, 1);
        spvBuilder.createDebugDeclare(debugLocalVar, curThis);
      }

      isNonStaticMemberFn = true;
    }
  }

  // Create all parameters.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const ParmVarDecl *paramDecl = decl->getParamDecl(i);
    if (spvContext.isHS() && decl == patchConstFunc &&
        hlsl::IsHLSLOutputPatchType(paramDecl->getType())) {
      // Since the output patch used in hull shaders is translated to
      // a variable with Workgroup storage class, there is no need
      // to pass the variable as a function parameter in SPIR-V.
      continue;
    }
    (void)declIdMapper.createFnParam(paramDecl, i + 1 + isNonStaticMemberFn);
  }

  if (decl->hasBody()) {
    // The entry basic block.
    auto *entryLabel = spvBuilder.createBasicBlock("bb.entry");
    spvBuilder.setInsertPoint(entryLabel);

    // Process all statements in the body.
    doStmt(decl->getBody());

    // We have processed all Stmts in this function and are now in the last
    // basic block. Make sure we have a termination instruction.
    if (!spvBuilder.isCurrentBasicBlockTerminated()) {
      const auto retType = decl->getReturnType();
      const auto returnLoc = decl->getBody()->getLocEnd();

      if (retType->isVoidType()) {
        spvBuilder.createReturn(returnLoc);
      } else {
        // If the source code does not provide a proper return value for
        // some control flow path, it's undefined behavior. We just return
        // a null value here.
        spvBuilder.createReturnValue(spvBuilder.getConstantNull(retType),
                                     returnLoc);
      }
    }
  }

  spvBuilder.endFunction();

  if (spirvOptions.debugInfoRich) {
    spvContext.popDebugLexicalScope(info);
  }
}
bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
  bool success = true;

  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    const auto varType = varDecl->getType();
    if ((isSubpassInput(varType) || isSubpassInputMS(varType)) &&
        !varDecl->hasAttr<VKInputAttachmentIndexAttr>()) {
      emitError("missing vk::input_attachment_index attribute",
                varDecl->getLocation());
      success = false;
    }
  }

  if (decl->getAttr<VKInputAttachmentIndexAttr>()) {
    if (!spvContext.isPS()) {
      emitError("SubpassInput(MS) only allowed in pixel shader",
                decl->getLocation());
      success = false;
    }

    if (!decl->isExternallyVisible()) {
      emitError("SubpassInput(MS) must be externally visible",
                decl->getLocation());
      success = false;
    }

    // We only allow VKInputAttachmentIndexAttr to be attached to global
    // variables. So it should be fine to cast here.
    const auto elementType =
        hlsl::GetHLSLResourceResultType(cast<VarDecl>(decl)->getType());

    if (!isScalarType(elementType) && !isVectorType(elementType)) {
      emitError(
          "only scalar/vector types allowed as SubpassInput(MS) parameter type",
          decl->getLocation());
      // Return directly to avoid further type processing, which will hit
      // asserts when lowering the type.
      return false;
    }
  }

  // The frontend will make sure that
  // * vk::push_constant applies to global variables of struct type
  // * vk::binding applies to global variables or cbuffers/tbuffers
  // * vk::counter_binding applies to global variables of RW/Append/Consume
  //   StructuredBuffer
  // * vk::location applies to function parameters/returns and struct fields
  // So the only case we need to check for co-existence is vk::push_constant
  // and vk::binding.
  if (const auto *pcAttr = decl->getAttr<VKPushConstantAttr>()) {
    const auto loc = pcAttr->getLocation();

    if (seenPushConstantAt.isInvalid()) {
      seenPushConstantAt = loc;
    } else {
      // TODO: Actually this is slightly incorrect. The Vulkan spec says:
      //   There must be no more than one push constant block statically used
      //   per shader entry point.
      // But we are checking whether more than one push constant block is
      // defined. Tracking usage requires more work.
      emitError("cannot have more than one push constant block", loc);
      emitNote("push constant block previously defined here",
               seenPushConstantAt);
      success = false;
    }

    if (decl->hasAttr<VKBindingAttr>()) {
      emitError("vk::push_constant attribute cannot be used together with "
                "vk::binding attribute",
                loc);
      success = false;
    }
  }

  // vk::shader_record_nv is supported only on cbuffer/ConstantBuffer
  if (const auto *srbAttr = decl->getAttr<VKShaderRecordNVAttr>()) {
    const auto loc = srbAttr->getLocation();
    const HLSLBufferDecl *bufDecl = nullptr;
    bool isValidType = false;

    if ((bufDecl = dyn_cast<HLSLBufferDecl>(decl)))
      isValidType = bufDecl->isCBuffer();
    else if ((bufDecl = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())))
      isValidType = bufDecl->isCBuffer();
    else if (isa<VarDecl>(decl))
      isValidType = isConstantBuffer(dyn_cast<VarDecl>(decl)->getType());

    if (!isValidType) {
      emitError(
          "vk::shader_record_nv can be applied only to cbuffer/ConstantBuffer",
          loc);
      success = false;
    }

    if (decl->hasAttr<VKBindingAttr>()) {
      emitError("vk::shader_record_nv attribute cannot be used together with "
                "vk::binding attribute",
                loc);
      success = false;
    }
  }

  return success;
}
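
// Illustrative sketch (assumed example; not from the original source): the
// declaration below is rejected because vk::push_constant and vk::binding
// cannot coexist on the same declaration:
//
//   struct S { float4 v; };
//   [[vk::push_constant]] [[vk::binding(0, 0)]] S gPC;  // error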
void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
  // This is a cbuffer/tbuffer decl.
  // Check and emit warnings for member initializers which are not
  // supported in Vulkan
  for (const auto *member : bufferDecl->decls()) {
    if (const auto *varMember = dyn_cast<VarDecl>(member)) {
      if (!spirvOptions.noWarnIgnoredFeatures) {
        if (const auto *init = varMember->getInit())
          emitWarning("%select{tbuffer|cbuffer}0 member initializer "
                      "ignored since no Vulkan equivalent",
                      init->getExprLoc())
              << bufferDecl->isCBuffer() << init->getSourceRange();
      }

      // We cannot handle external initialization of column-major matrices
      // now.
      if (isOrContainsNonFpColMajorMatrix(astContext, spirvOptions,
                                          varMember->getType(), varMember)) {
        emitError("externally initialized non-floating-point column-major "
                  "matrices not supported yet",
                  varMember->getLocation());
      }
    }
  }

  if (!validateVKAttributes(bufferDecl))
    return;

  if (bufferDecl->hasAttr<VKShaderRecordNVAttr>()) {
    (void)declIdMapper.createShaderRecordBufferNV(bufferDecl);
  } else {
    (void)declIdMapper.createCTBuffer(bufferDecl);
  }
}
void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
  // Ignore implicit records.
  // Somehow we'll have implicit records with:
  //   static const int Length = count;
  // that can mess up the normal CodeGen.
  if (recordDecl->isImplicit())
    return;

  // Handle each static member with an inline initializer.
  // Each static member has a corresponding VarDecl inside the
  // RecordDecl. For those defined in the translation unit,
  // their VarDecls do not have an initializer.
  for (auto *subDecl : recordDecl->decls())
    if (auto *varDecl = dyn_cast<VarDecl>(subDecl))
      if (varDecl->isStaticDataMember() && varDecl->hasInit())
        doVarDecl(varDecl);
}
void SpirvEmitter::doEnumDecl(const EnumDecl *decl) {
  for (auto it = decl->enumerator_begin(); it != decl->enumerator_end(); ++it)
    declIdMapper.createEnumConstant(*it);
}
void SpirvEmitter::doVarDecl(const VarDecl *decl) {
  if (!validateVKAttributes(decl))
    return;

  const auto loc = decl->getLocation();

  // HLSL has the 'string' type which can be used for rare purposes such as
  // printf (SPIR-V's DebugPrintf). SPIR-V does not have a 'char' or 'string'
  // type, and therefore any variable of such type should not be created.
  // DeclResultIdMapper maps such a decl to an OpString instruction that
  // represents the variable's initializer literal.
  if (isStringType(decl->getType())) {
    declIdMapper.createOrUpdateStringVar(decl);
    return;
  }

  // We cannot handle external initialization of column-major matrices now.
  if (isExternalVar(decl) &&
      isOrContainsNonFpColMajorMatrix(astContext, spirvOptions,
                                      decl->getType(), decl)) {
    emitError("externally initialized non-floating-point column-major "
              "matrices not supported yet",
              loc);
  }

  // Reject arrays of RW/append/consume structured buffers. They have
  // associated counters, which are quite nasty to handle.
  if (decl->getType()->isArrayType()) {
    auto type = decl->getType();
    do {
      type = type->getAsArrayTypeUnsafe()->getElementType();
    } while (type->isArrayType());

    if (isRWAppendConsumeSBuffer(type)) {
      emitError("arrays of RW/append/consume structured buffers unsupported",
                loc);
      return;
    }
  }

  if (decl->hasAttr<VKConstantIdAttr>()) {
    // This is a VarDecl for a specialization constant.
    createSpecConstant(decl);
    return;
  }

  if (decl->hasAttr<VKPushConstantAttr>()) {
    // This is a VarDecl for a PushConstant block.
    (void)declIdMapper.createPushConstant(decl);
    return;
  }

  if (decl->hasAttr<VKShaderRecordNVAttr>()) {
    (void)declIdMapper.createShaderRecordBufferNV(decl);
    return;
  }

  // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
  // to emit their cbuffer/tbuffer as a whole and access each individual one
  // using access chains.
  // cbuffers and tbuffers are HLSLBufferDecls;
  // ConstantBuffers and TextureBuffers are not HLSLBufferDecls.
  if (const auto *bufferDecl =
          dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
    // This is a VarDecl of cbuffer/tbuffer type.
    doHLSLBufferDecl(bufferDecl);
    return;
  }

  if (isConstantTextureBuffer(decl->getType())) {
    // This is a VarDecl of ConstantBuffer/TextureBuffer type.
    (void)declIdMapper.createCTBuffer(decl);
    return;
  }

  SpirvVariable *var = nullptr;

  // The contents in externally visible variables can be updated via the
  // pipeline. They should be handled differently from file and function
  // scope variables.
  // File scope variables (static "global" and "local" variables) belong to
  // the Private storage class, while function scope variables (normal
  // "local" variables) belong to the Function storage class.
  if (isExternalVar(decl)) {
    var = declIdMapper.createExternVar(decl);
  } else {
    // We already know the variable is not externally visible here. If it
    // does not have local storage, it should be a file scope variable.
    const bool isFileScopeVar = !decl->hasLocalStorage();

    if (isFileScopeVar)
      var = declIdMapper.createFileVar(decl, llvm::None);
    else
      var = declIdMapper.createFnVar(decl, llvm::None);

    // Emit OpStore to initialize the variable
    // TODO: revert back to use OpVariable initializer

    // We should only evaluate the initializer once for a static variable.
    if (isFileScopeVar) {
      if (decl->isStaticLocal()) {
        initOnce(decl->getType(), decl->getName(), var, decl->getInit());
      } else {
        // Defer initializing these global variables to the beginning of the
        // entry function.
        toInitGloalVars.push_back(decl);
      }
    }
    // Function local variables. Just emit OpStore at the current insert
    // point.
    else if (const Expr *init = decl->getInit()) {
      if (auto *constInit = tryToEvaluateAsConst(init)) {
        spvBuilder.createStore(var, constInit, loc);
      } else {
        storeValue(var, loadIfGLValue(init), decl->getType(), loc);
      }

      // Update counter variables associated with local variables
      tryToAssignCounterVar(decl, init);
    }

    if (!isFileScopeVar && spirvOptions.debugInfoRich) {
      // Add DebugLocalVariable information
      const auto &sm = astContext.getSourceManager();
      const uint32_t line = sm.getPresumedLineNumber(loc);
      const uint32_t column = sm.getPresumedColumnNumber(loc);
      const auto *info = getOrCreateRichDebugInfo(loc);
      // TODO: replace this with the FlagIsLocal enum.
      uint32_t flags = 1 << 2;
      auto *debugLocalVar = spvBuilder.createDebugLocalVariable(
          decl->getType(), decl->getName(), info->source, line, column,
          info->scopeStack.back(), flags);
      spvBuilder.createDebugDeclare(debugLocalVar, var);
    }

    // Variables that are not externally visible and of opaque types should
    // request legalization.
    if (!needsLegalization && isOpaqueType(decl->getType()))
      needsLegalization = true;
  }

  // All variables that are of opaque struct types should request
  // legalization.
  if (!needsLegalization && isOpaqueStructType(decl->getType()))
    needsLegalization = true;
}
spv::LoopControlMask SpirvEmitter::translateLoopAttribute(const Stmt *stmt,
                                                          const Attr &attr) {
  switch (attr.getKind()) {
  case attr::HLSLLoop:
  case attr::HLSLFastOpt:
    return spv::LoopControlMask::DontUnroll;
  case attr::HLSLUnroll:
    return spv::LoopControlMask::Unroll;
  case attr::HLSLAllowUAVCondition:
    if (!spirvOptions.noWarnIgnoredFeatures) {
      emitWarning("unsupported allow_uav_condition attribute ignored",
                  stmt->getLocStart());
    }
    break;
  default:
    llvm_unreachable("found unknown loop attribute");
  }
  return spv::LoopControlMask::MaskNone;
}
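
// Illustrative sketch (assumed example; not from the original source): the
// attribute below lowers to the Unroll bit of the loop's OpLoopMerge
// instruction:
//
//   [unroll] for (int i = 0; i < 4; ++i) { ... }
//
// while [loop] and [fastopt] both map to DontUnroll.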
void SpirvEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
  assert(!spvBuilder.isCurrentBasicBlockTerminated());

  // The discard statement can only be called from a pixel shader
  if (!spvContext.isPS()) {
    emitError("discard statement may only be used in pixel shaders",
              discardStmt->getLoc());
    return;
  }

  if (featureManager.isExtensionEnabled(
          Extension::EXT_demote_to_helper_invocation)) {
    // The SPV_EXT_demote_to_helper_invocation SPIR-V extension provides a new
    // instruction, OpDemoteToHelperInvocationEXT, allowing shaders to
    // "demote" a fragment shader invocation to behave like a helper
    // invocation for its duration. The demoted invocation will have no
    // further side effects and will not output to the framebuffer, but
    // remains active and can participate in computing derivatives and in
    // subgroup operations. This is a better match for the "discard"
    // instruction in HLSL.
    spvBuilder.createDemoteToHelperInvocationEXT(discardStmt->getLoc());
  } else {
    // Note: if/when the demote behavior becomes part of the core Vulkan spec,
    // we should no longer generate OpKill for 'discard', and always generate
    // the demote behavior.
    spvBuilder.createKill(discardStmt->getLoc());
    // Some statements that alter the control flow (break, continue, return,
    // and discard) require creation of a new basic block to hold any
    // statement that may follow them.
    auto *newBB = spvBuilder.createBasicBlock();
    spvBuilder.setInsertPoint(newBB);
  }
}
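
// Illustrative sketch (assumed example; not from the original source): with
// SPV_EXT_demote_to_helper_invocation enabled, HLSL
//
//   if (alpha < 0.5) discard;
//
// lowers the discard to OpDemoteToHelperInvocationEXT and control flow
// continues; otherwise OpKill terminates the invocation, so a fresh
// (unreachable) basic block is opened for any trailing statements.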
void SpirvEmitter::doDoStmt(const DoStmt *theDoStmt,
                            llvm::ArrayRef<const Attr *> attrs) {
  // do-while loops are composed of:
  //
  //   do {
  //     <body>
  //   } while (<check>);
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though do-while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> can be performed in the <continue> basic block.
  // The final CFG should normally be like the following. Exceptions will
  // occur with non-local exits like loop breaks or early returns.
  //
  //            +----------+
  //            |  header  | <-----------------------------------+
  //            +----------+                                     |
  //                 |                                           |  (true)
  //                 v                                           |
  //             +------+       +--------------------+           |
  //             | body | ----> | continue (<check>) |-----------+
  //             +------+       +--------------------+
  //                                      |
  //                                      | (false)
  //             +-------+                |
  //             | merge | <--------------+
  //             +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(theDoStmt, *attrs.front());

  // Create basic blocks
  auto *headerBB = spvBuilder.createBasicBlock("do_while.header");
  auto *bodyBB = spvBuilder.createBasicBlock("do_while.body");
  auto *continueBB = spvBuilder.createBasicBlock("do_while.continue");
  auto *mergeBB = spvBuilder.createBasicBlock("do_while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Branch from the current insert point to the header block.
  spvBuilder.createBranch(headerBB, theDoStmt->getLocStart());
  spvBuilder.addSuccessor(headerBB);

  // Process the <header> block.
  // The header block must always branch to the body.
  spvBuilder.setInsertPoint(headerBB);
  const Stmt *body = theDoStmt->getBody();
  spvBuilder.createBranch(bodyBB,
                          body ? body->getLocStart() : theDoStmt->getLocStart(),
                          mergeBB, continueBB, loopControl);
  spvBuilder.addSuccessor(bodyBB);
  // The current basic block has an OpLoopMerge instruction. We need to set
  // its continue and merge target.
  spvBuilder.setContinueTarget(continueBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  spvBuilder.setInsertPoint(bodyBB);
  if (body) {
    doStmt(body);
  }
  if (!spvBuilder.isCurrentBasicBlockTerminated()) {
    spvBuilder.createBranch(continueBB, body ? body->getLocEnd()
                                             : theDoStmt->getLocStart());
  }
  spvBuilder.addSuccessor(continueBB);

  // Process the <continue> block. The check for whether the loop should
  // continue lies in the continue block.
  // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur
  // in a continue block of a loop, there should be no OpSelectionMerge. Only
  // an OpBranchConditional must be specified.
  spvBuilder.setInsertPoint(continueBB);
  SpirvInstruction *condition = nullptr;
  if (const Expr *check = theDoStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = spvBuilder.getConstantBool(true);
  }
  spvBuilder.createConditionalBranch(condition, headerBB, mergeBB,
                                     theDoStmt->getLocEnd());
  spvBuilder.addSuccessor(headerBB);
  spvBuilder.addSuccessor(mergeBB);

  // Set insertion point to the <merge> block for subsequent statements
  spvBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}
void SpirvEmitter::doContinueStmt(const ContinueStmt *continueStmt) {
  assert(!spvBuilder.isCurrentBasicBlockTerminated());
  auto *continueTargetBB = continueStack.top();
  spvBuilder.createBranch(continueTargetBB, continueStmt->getLocStart());
  spvBuilder.addSuccessor(continueTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. For example, StmtB and StmtC below are put inside a new
  // basic block which is unreachable:
  //
  //   while (true) {
  //     StmtA;
  //     continue;
  //     StmtB;
  //     StmtC;
  //   }
  auto *newBB = spvBuilder.createBasicBlock();
  spvBuilder.setInsertPoint(newBB);
}
void SpirvEmitter::doWhileStmt(const WhileStmt *whileStmt,
                               llvm::ArrayRef<const Attr *> attrs) {
  // While loops are composed of:
  //   while (<check>) { <body> }
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though while loops do not have an explicit continue
  // block as for-loops do, we still need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> block can take the responsibility of the <header>
  // block. The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //           +----------+
  //           |  header  | <------------------------------+
  //           | (check)  |                                 |
  //           +----------+                                 |
  //                |                                       |
  //        +-------+-------+                               |
  //        | false         | true                          |
  //        |               v                               |
  //        |           +------+     +------------------+   |
  //        |           | body | --> | continue (no-op) |---+
  //        v           +------+     +------------------+
  //    +-------+
  //    | merge |
  //    +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
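  // For example (illustrative): an attribute such as [unroll] on the loop is
  // expected to translate to spv::LoopControlMask::Unroll below, while [loop]
  // requests DontUnroll; with no attribute we emit MaskNone.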
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(whileStmt, *attrs.front());

  // Create basic blocks
  auto *checkBB = spvBuilder.createBasicBlock("while.check");
  auto *bodyBB = spvBuilder.createBasicBlock("while.body");
  auto *continueBB = spvBuilder.createBasicBlock("while.continue");
  auto *mergeBB = spvBuilder.createBasicBlock("while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <check> block
  spvBuilder.createBranch(checkBB, whileStmt->getLocStart());
  spvBuilder.addSuccessor(checkBB);
  spvBuilder.setInsertPoint(checkBB);

  // If we have:
  //   while (int a = foo()) {...}
  // we should evaluate 'a' by calling 'foo()' every single time the check has
  // to occur.
  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
    doStmt(condVarDecl);

  SpirvInstruction *condition = nullptr;
  const Expr *check = whileStmt->getCond();
  if (check) {
    condition = doExpr(check);
  } else {
    condition = spvBuilder.getConstantBool(true);
  }
  spvBuilder.createConditionalBranch(
      condition, bodyBB,
      /*false branch*/ mergeBB, whileStmt->getLocStart(),
      /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone,
      loopControl);
  spvBuilder.addSuccessor(bodyBB);
  spvBuilder.addSuccessor(mergeBB);
  // The current basic block has the OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  spvBuilder.setContinueTarget(continueBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  spvBuilder.setInsertPoint(bodyBB);
  const Stmt *body = whileStmt->getBody();
  if (body) {
    doStmt(body);
  }
  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(continueBB, whileStmt->getLocEnd());
  spvBuilder.addSuccessor(continueBB);

  // Process the <continue> block. While loops do not have an explicit
  // continue block; the continue block just branches to the <check> block.
  spvBuilder.setInsertPoint(continueBB);
  spvBuilder.createBranch(checkBB, whileStmt->getLocEnd());
  spvBuilder.addSuccessor(checkBB);

  // Set insertion point to the <merge> block for subsequent statements
  spvBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue and merge blocks.
  continueStack.pop();
  breakStack.pop();
}
void SpirvEmitter::doForStmt(const ForStmt *forStmt,
                             llvm::ArrayRef<const Attr *> attrs) {
  // For loops are composed of:
  //   for (<init>; <check>; <continue>) <body>
  //
  // To translate a for loop, we'll need to emit all <init> statements
  // in the current basic block, and then have separate basic blocks for
  // <check>, <continue>, and <body>. In addition, since SPIR-V requires
  // structured control flow, we need two more basic blocks, <header>
  // and <merge>. <header> is the block before control flow diverges,
  // while <merge> is the block where control flow subsequently converges.
  // The <check> block can take the responsibility of the <header> block.
  // The final CFG should normally be like the following. Exceptions will
  // occur with non-local exits like loop breaks or early returns.
  //
  //           +--------+
  //           |  init  |
  //           +--------+
  //                |
  //                v
  //           +----------+
  //           |  header  | <-----------------------+
  //           | (check)  |                          |
  //           +----------+                          |
  //                |                                |
  //        +-------+-------+                        |
  //        | false         | true                   |
  //        |               v                        |
  //        |           +------+     +----------+    |
  //        |           | body | --> | continue |----+
  //        v           +------+     +----------+
  //    +-------+
  //    | merge |
  //    +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
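  // For example (illustrative HLSL input):
  //
  //   for (int i = 0; i < 4; ++i) { foo(i); }
  //
  // 'int i = 0' is emitted in the current block, 'i < 4' in <check>, the call
  // in <body>, and '++i' in <continue>, which then branches back to <check>.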
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(forStmt, *attrs.front());

  // Create basic blocks
  auto *checkBB = spvBuilder.createBasicBlock("for.check");
  auto *bodyBB = spvBuilder.createBasicBlock("for.body");
  auto *continueBB = spvBuilder.createBasicBlock("for.continue");
  auto *mergeBB = spvBuilder.createBasicBlock("for.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <init> block
  if (const Stmt *initStmt = forStmt->getInit()) {
    doStmt(initStmt);
  }
  const Expr *check = forStmt->getCond();
  spvBuilder.createBranch(checkBB,
                          check ? check->getLocStart() : forStmt->getLocStart());
  spvBuilder.addSuccessor(checkBB);

  // Process the <check> block
  spvBuilder.setInsertPoint(checkBB);
  SpirvInstruction *condition = nullptr;
  if (check) {
    condition = doExpr(check);
  } else {
    condition = spvBuilder.getConstantBool(true);
  }
  const Stmt *body = forStmt->getBody();
  spvBuilder.createConditionalBranch(
      condition, bodyBB,
      /*false branch*/ mergeBB,
      check ? check->getLocEnd()
            : (body ? body->getLocStart() : forStmt->getLocStart()),
      /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone,
      loopControl);
  spvBuilder.addSuccessor(bodyBB);
  spvBuilder.addSuccessor(mergeBB);
  // The current basic block has the OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  spvBuilder.setContinueTarget(continueBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  spvBuilder.setInsertPoint(bodyBB);
  if (body) {
    doStmt(body);
  }
  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(continueBB, forStmt->getLocEnd());
  spvBuilder.addSuccessor(continueBB);

  // Process the <continue> block
  spvBuilder.setInsertPoint(continueBB);
  if (const Expr *cont = forStmt->getInc()) {
    doExpr(cont);
  }
  // <continue> should jump back to the header
  spvBuilder.createBranch(checkBB, forStmt->getLocEnd());
  spvBuilder.addSuccessor(checkBB);

  // Set insertion point to the <merge> block for subsequent statements
  spvBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}
void SpirvEmitter::doIfStmt(const IfStmt *ifStmt,
                            llvm::ArrayRef<const Attr *> attrs) {
  // if statements are composed of:
  //   if (<check>) { <then> } else { <else> }
  //
  // To translate if statements, we'll need to emit the <check> expressions
  // in the current basic block, and then create separate basic blocks for
  // <then> and <else>. Additionally, we'll need a <merge> block as per
  // SPIR-V's structured control flow requirements. Depending on whether an
  // else branch exists, the final CFG should normally be like one of the
  // following. Exceptions will occur with non-local exits like loop breaks
  // or early returns.
  //
  //         +-------+                  +-------+
  //         | check |                  | check |
  //         +-------+                  +-------+
  //             |                          |
  //     +-------+-------+            +-----+-----+
  //     | true          | false      | true      | false
  //     v               v      or    v           |
  //  +------+       +------+      +------+       |
  //  | then |       | else |      | then |       |
  //  +------+       +------+      +------+       |
  //     |               |            |           v
  //     |   +-------+   |            |       +-------+
  //     +-> | merge | <-+            +-----> | merge |
  //         +-------+                        +-------+
  { // Try to see if we can const-eval the condition
    bool condition = false;
    if (ifStmt->getCond()->EvaluateAsBooleanCondition(condition, astContext)) {
      if (condition) {
        doStmt(ifStmt->getThen());
      } else if (ifStmt->getElse()) {
        doStmt(ifStmt->getElse());
      }
      return;
    }
  }
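  // For example (illustrative): for 'if (1 < 2) { foo(); } else { bar(); }',
  // the condition folds to true at compile time, so only the then-branch is
  // emitted and no control flow is generated at all.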
  auto selectionControl = spv::SelectionControlMask::MaskNone;
  if (!attrs.empty()) {
    const Attr *attribute = attrs.front();
    switch (attribute->getKind()) {
    case attr::HLSLBranch:
      selectionControl = spv::SelectionControlMask::DontFlatten;
      break;
    case attr::HLSLFlatten:
      selectionControl = spv::SelectionControlMask::Flatten;
      break;
    default:
      // warning emitted in hlsl::ProcessStmtAttributeForHLSL
      break;
    }
  }

  if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt())
    doDeclStmt(declStmt);

  // First emit the instruction for evaluating the condition.
  auto *condition = doExpr(ifStmt->getCond());

  // Then we need to emit the instruction for the conditional branch.
  // We'll need the <label-id>s for the then/else/merge blocks to do so.
  const bool hasElse = ifStmt->getElse() != nullptr;
  auto *thenBB = spvBuilder.createBasicBlock("if.true");
  auto *mergeBB = spvBuilder.createBasicBlock("if.merge");
  auto *elseBB = hasElse ? spvBuilder.createBasicBlock("if.false") : mergeBB;

  // Create the branch instruction. This will end the current basic block.
  const auto *then = ifStmt->getThen();
  spvBuilder.createConditionalBranch(condition, thenBB, elseBB,
                                     then->getLocStart(), mergeBB,
                                     /*continue*/ nullptr, selectionControl);
  spvBuilder.addSuccessor(thenBB);
  spvBuilder.addSuccessor(elseBB);
  // The current basic block has the OpSelectionMerge instruction. We need
  // to record its merge target.
  spvBuilder.setMergeTarget(mergeBB);

  // Handle the then branch
  spvBuilder.setInsertPoint(thenBB);
  doStmt(then);
  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(mergeBB, ifStmt->getLocEnd());
  spvBuilder.addSuccessor(mergeBB);

  // Handle the else branch (if it exists)
  if (hasElse) {
    spvBuilder.setInsertPoint(elseBB);
    const auto *elseStmt = ifStmt->getElse();
    doStmt(elseStmt);
    if (!spvBuilder.isCurrentBasicBlockTerminated())
      spvBuilder.createBranch(mergeBB, elseStmt->getLocEnd());
    spvBuilder.addSuccessor(mergeBB);
  }

  // From now on, we'll emit instructions into the merge block.
  spvBuilder.setInsertPoint(mergeBB);
}
void SpirvEmitter::doReturnStmt(const ReturnStmt *stmt) {
  if (const auto *retVal = stmt->getRetValue()) {
    // Update the counter variable associated with function returns.
    tryToAssignCounterVar(curFunction, retVal);

    auto *retInfo = loadIfGLValue(retVal);
    if (!retInfo)
      return;

    auto retType = retVal->getType();
    if (retInfo->getStorageClass() != spv::StorageClass::Function &&
        retType->isStructureType()) {
      // We are returning a value that lives in a non-Function storage class.
      // We need to create a temporary variable to "convert" the value to the
      // Function storage class and then return that.
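      // For example (illustrative): returning a struct that was read straight
      // out of a cbuffer-backed variable, whose pointer is in the Uniform
      // storage class, goes through such a "temp.var.ret" copy first.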
      auto *tempVar =
          spvBuilder.addFnVar(retType, retVal->getLocEnd(), "temp.var.ret");
      storeValue(tempVar, retInfo, retType, retVal->getLocEnd());
      spvBuilder.createReturnValue(
          spvBuilder.createLoad(retType, tempVar, retVal->getLocEnd()),
          stmt->getReturnLoc());
    } else {
      spvBuilder.createReturnValue(retInfo, stmt->getReturnLoc());
    }
  } else {
    spvBuilder.createReturn(stmt->getReturnLoc());
  }

  // We are translating a ReturnStmt, so we should be in some function's body.
  assert(curFunction->hasBody());
  // If this return statement is the last statement in the function, then
  // we have no more work to do.
  if (cast<CompoundStmt>(curFunction->getBody())->body_back() == stmt)
    return;

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. In this case, the newly created basic block will contain
  // any statement that may come after an early return.
  auto *newBB = spvBuilder.createBasicBlock();
  spvBuilder.setInsertPoint(newBB);
}
void SpirvEmitter::doBreakStmt(const BreakStmt *breakStmt) {
  assert(!spvBuilder.isCurrentBasicBlockTerminated());
  auto *breakTargetBB = breakStack.top();
  spvBuilder.addSuccessor(breakTargetBB);
  spvBuilder.createBranch(breakTargetBB, breakStmt->getLocStart());

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. For example, StmtB and StmtC below are put inside a new
  // basic block which is unreachable:
  //
  //   while (true) {
  //     StmtA;
  //     break;
  //     StmtB;
  //     StmtC;
  //   }
  auto *newBB = spvBuilder.createBasicBlock();
  spvBuilder.setInsertPoint(newBB);
}
void SpirvEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
                                llvm::ArrayRef<const Attr *> attrs) {
  // Switch statements are composed of:
  //   switch (<condition variable>) {
  //     <CaseStmt>
  //     <CaseStmt>
  //     <CaseStmt>
  //     <DefaultStmt> (optional)
  //   }
  //
  //            +-------+
  //            | check |
  //            +-------+
  //                |
  //      +---------+---------+----------------+---------------+
  //      | 1                 | 2              | 3             | (others)
  //      v                   v                v               v
  //  +-------+      +-------------+       +-------+     +------------+
  //  | case1 |      | case2       |       | case3 | ... | default    |
  //  |       |      |(fallthrough)| ----> |       |     | (optional) |
  //  +-------+      +-------------+       +-------+     +------------+
  //      |                                    |               |
  //      |          +-------+                 |               |
  //      |          |       | <---------------+               |
  //      +--------> | merge |                                 |
  //                 |       | <-------------------------------+
  //                 +-------+
  //
  // If no attributes are given, or if the "forcecase" attribute was provided,
  // we'll do our best to use OpSwitch if possible.
  // If any of the cases compares to a variable (rather than an integer
  // literal), we cannot use OpSwitch because OpSwitch expects literal
  // numbers as parameters.
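  // For example (illustrative): given
  //   static const int kTwo = 2;
  //   switch (x) { case 1: ...; break; case kTwo: ...; break; }
  // the 'case kTwo' value is not an integer literal in the AST, so the whole
  // switch is lowered to a chain of if statements instead of OpSwitch.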
  const bool isAttrForceCase =
      !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase;
  const bool canUseSpirvOpSwitch =
      (attrs.empty() || isAttrForceCase) &&
      allSwitchCasesAreIntegerLiterals(switchStmt->getBody());

  if (isAttrForceCase && !canUseSpirvOpSwitch &&
      !spirvOptions.noWarnIgnoredFeatures) {
    emitWarning("ignored 'forcecase' attribute for the switch statement "
                "since one or more case values are not integer literals",
                switchStmt->getLocStart());
  }

  if (canUseSpirvOpSwitch)
    processSwitchStmtUsingSpirvOpSwitch(switchStmt);
  else
    processSwitchStmtUsingIfStmts(switchStmt);
}
SpirvInstruction *
SpirvEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
  llvm::SmallVector<SpirvInstruction *, 4> indices;
  const auto *base = collectArrayStructIndices(
      expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
  auto *info = loadIfAliasVarRef(base);

  if (!indices.empty()) {
    info = turnIntoElementPtr(base->getType(), info, expr->getType(), indices,
                              base->getExprLoc());
  }

  return info;
}

SpirvInstruction *SpirvEmitter::doBinaryOperator(const BinaryOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Handle assignment first since we need to evaluate rhs before lhs.
  // For other binary operations, we need to evaluate lhs before rhs.
  if (opcode == BO_Assign) {
    // Update the counter variable associated with the lhs of assignments.
    tryToAssignCounterVar(expr->getLHS(), expr->getRHS());

    return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
                             /*isCompoundAssignment=*/false);
  }

  // Try to optimize the floatMxN * float and floatN * float cases
  if (opcode == BO_Mul) {
    if (auto *result = tryToGenFloatMatrixScale(expr))
      return result;
    if (auto *result = tryToGenFloatVectorScale(expr))
      return result;
  }

  return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode,
                         expr->getLHS()->getType(), expr->getType(),
                         expr->getSourceRange(), expr->getOperatorLoc());
}
SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) {
  if (const auto *operatorCall = dyn_cast<CXXOperatorCallExpr>(callExpr))
    return doCXXOperatorCallExpr(operatorCall);

  if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr))
    return doCXXMemberCallExpr(memberCall);

  // Intrinsic functions such as 'dot' or 'mul'
  if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) {
    return processIntrinsicCallExpr(callExpr);
  }

  // Normal standalone functions
  return processCall(callExpr);
}

SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
  const FunctionDecl *callee = getCalleeDefinition(callExpr);

  // Note that we always want the definition because Stmts/Exprs in the
  // function body reference the parameters in the definition.
  if (!callee) {
    emitError("found undefined function", callExpr->getExprLoc());
    return nullptr;
  }
  const auto paramTypeMatchesArgType = [](QualType paramType,
                                          QualType argType) {
    if (argType == paramType)
      return true;

    if (const auto *refType = paramType->getAs<ReferenceType>())
      paramType = refType->getPointeeType();
    auto argUnqualifiedType = argType->getUnqualifiedDesugaredType();
    auto paramUnqualifiedType = paramType->getUnqualifiedDesugaredType();
    if (argUnqualifiedType == paramUnqualifiedType)
      return true;

    return false;
  };

  const auto numParams = callee->getNumParams();

  bool isNonStaticMemberCall = false;
  QualType objectType = {};             // Type of the object (if it exists)
  SpirvInstruction *objInstr = nullptr; // EvalInfo for the object (if it exists)

  llvm::SmallVector<SpirvInstruction *, 4> vars; // Variables for function call
  llvm::SmallVector<bool, 4> isTempVar;          // Temporary variable or not
  llvm::SmallVector<SpirvInstruction *, 4> args; // Evaluated arguments

  if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
    const auto *memberFn = cast<CXXMethodDecl>(memberCall->getCalleeDecl());
    isNonStaticMemberCall = !memberFn->isStatic();

    if (isNonStaticMemberCall) {
      // For non-static member calls, evaluate the object and pass it as the
      // first argument.
      const auto *object = memberCall->getImplicitObjectArgument();
      object = object->IgnoreParenNoopCasts(astContext);

      // Update the counter variable associated with the implicit object.
      tryToAssignCounterVar(getOrCreateDeclForMethodObject(memberFn), object);

      objectType = object->getType();
      objInstr = doExpr(object);

      // If the object is not already a variable, we need to create a temporary
      // variable and pass its pointer to the function. Example:
      //   getObject().objectMethod();
      // Also, any parameter passed to the member function must be of Function
      // storage class.
      if (objInstr->isRValue()) {
        args.push_back(createTemporaryVar(
            objectType, getAstTypeName(objectType),
            // May need to load to use as initializer
            loadIfGLValue(object, objInstr), object->getLocStart()));
      } else {
        // Per the SPIR-V spec, a function parameter must always be in the
        // Function storage class. If we pass a non-Function-scope argument,
        // we need legalization.
        if (objInstr->getStorageClass() != spv::StorageClass::Function)
          beforeHlslLegalization = true;

        args.push_back(objInstr);
      }

      // We do not need to create a new temporary variable for the 'this'
      // object. Use the evaluated argument.
      vars.push_back(args.back());
      isTempVar.push_back(false);
    }
  }
  // Evaluate parameters
  for (uint32_t i = 0; i < numParams; ++i) {
    // We want the argument variable here so that we can write back to it
    // later. We will do the OpLoad of this argument manually. So ignore
    // the LValueToRValue implicit cast here.
    auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts();
    const auto *param = callee->getParamDecl(i);

    // Get the evaluation info if this argument is referencing some variable
    // *as a whole*, in which case we can avoid creating the temporary variable
    // for it if it can act as an out parameter.
    SpirvInstruction *argInfo = nullptr;
    if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(arg)) {
      argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl(),
                                             arg->getLocStart());
    }
    auto *argInst = doExpr(arg);

    // If argInfo is nullptr and argInst is an rvalue, we do not have a proper
    // pointer to pass to the function. We need a temporary variable in that
    // case.
    //
    // If we have an 'out/inout' resource as a function argument, we need to
    // create a temporary variable for it because the function definition
    // expects a pointer-to-pointer argument for resources, which will be
    // resolved by legalization.
    if ((argInfo || (argInst && !argInst->isRValue())) &&
        canActAsOutParmVar(param) && !isResourceType(param) &&
        paramTypeMatchesArgType(param->getType(), arg->getType())) {
      // Per the SPIR-V spec, a function parameter must always be in the
      // Function storage class. In addition, we must pass a memory object
      // declaration as the argument. If we pass an argument that is not
      // Function scope or not a memory object declaration, we need
      // legalization.
      if (!argInfo || argInfo->getStorageClass() != spv::StorageClass::Function)
        beforeHlslLegalization = true;

      isTempVar.push_back(false);
      args.push_back(argInst);
      vars.push_back(argInfo ? argInfo : argInst);
    } else {
      // We need to create variables for holding the values to be used as
      // arguments. The variables themselves are of pointer types.
      const QualType varType =
          declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
      const std::string varName = "param.var." + param->getNameAsString();

      // Temporary "param.var.*" variables are used for OpFunctionCall purposes.
      // The 'precise' attribute on function parameters only affects
      // computations inside the function, not the variables at the call sites.
      // Therefore, we do not need to mark the "param.var.*" variables as
      // precise.
      const bool isPrecise = false;

      auto *tempVar =
          spvBuilder.addFnVar(varType, arg->getLocStart(), varName, isPrecise);

      vars.push_back(tempVar);
      isTempVar.push_back(true);
      args.push_back(argInst);

      // Update the counter variable associated with function parameters.
      tryToAssignCounterVar(param, arg);

      // Manually load the argument here
      auto *rhsVal = loadIfGLValue(arg, args.back());

      // The AST does not include cast nodes to and from the function parameter
      // type for 'out' and 'inout' cases. Example:
      //
      //   void foo(out half3 param) {...}
      //   void main() { float3 arg; foo(arg); }
      //
      // In such cases, we first do a manual cast before passing the argument
      // to the function. And we will cast back the result once the function
      // call has returned.
      if (canActAsOutParmVar(param) &&
          !paramTypeMatchesArgType(param->getType(), arg->getType())) {
        auto paramType = param->getType();
        if (const auto *refType = paramType->getAs<ReferenceType>())
          paramType = refType->getPointeeType();
        rhsVal =
            castToType(rhsVal, arg->getType(), paramType, arg->getLocStart());
      }

      // Initialize the temporary variables using the contents of the arguments.
      storeValue(tempVar, rhsVal, param->getType(), arg->getLocStart());
    }
  }
  if (beforeHlslLegalization)
    needsLegalization = true;

  assert(vars.size() == isTempVar.size());
  assert(vars.size() == args.size());

  // Push the callee into the work queue if it is not there.
  addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(), callee,
                         /*isEntryFunction*/ false);

  const QualType retType =
      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(callee);
  // Get or forward declare the function <result-id>
  SpirvFunction *func = declIdMapper.getOrRegisterFn(callee);
  auto *retVal = spvBuilder.createFunctionCall(
      retType, func, vars, callExpr->getCallee()->getExprLoc());

  // Go through all parameters and write back those marked as out/inout.
  for (uint32_t i = 0; i < numParams; ++i) {
    const auto *param = callee->getParamDecl(i);
    // For a non-static member function call, the object itself is argument
    // 0, and therefore all other argument positions are shifted by 1.
    const uint32_t index = i + isNonStaticMemberCall;

    // A resource used as a function parameter is never passed by copy. As a
    // result, even if the function parameter is marked as 'out' or 'inout',
    // there is no reason to copy back the results after the function call into
    // the resource.
    if (isTempVar[index] && canActAsOutParmVar(param) &&
        !isResourceType(param)) {
      const auto *arg = callExpr->getArg(i);
      SpirvInstruction *value = spvBuilder.createLoad(
          param->getType(), vars[index], arg->getLocStart());

      // Now we want to assign 'value' to arg. But first, in rare cases when
      // using 'out' or 'inout' where the parameter and argument have a type
      // mismatch, we need to first cast 'value' to the type of 'arg' because
      // the AST will not include a cast node.
      if (!paramTypeMatchesArgType(param->getType(), arg->getType())) {
        auto paramType = param->getType();
        if (const auto *refType = paramType->getAs<ReferenceType>())
          paramType = refType->getPointeeType();
        value =
            castToType(value, paramType, arg->getType(), arg->getLocStart());
      }

      processAssignment(arg, value, false, args[index]);
    }
  }

  return retVal;
}
SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) {
  const Expr *subExpr = expr->getSubExpr();
  const QualType subExprType = subExpr->getType();
  const QualType toType = expr->getType();
  const auto srcLoc = expr->getExprLoc();

  switch (expr->getCastKind()) {
  case CastKind::CK_LValueToRValue:
    return loadIfGLValue(subExpr);
  case CastKind::CK_NoOp:
    return doExpr(subExpr);
  case CastKind::CK_IntegralCast:
  case CastKind::CK_FloatingToIntegral:
  case CastKind::CK_HLSLCC_IntegralCast:
  case CastKind::CK_HLSLCC_FloatingToIntegral: {
    // Integer literals in the AST are represented using 64-bit APInts
    // themselves and then implicitly cast into the expected bitwidth.
    // We need special treatment of integer literals here because generating
    // a 64-bit constant and then explicitly casting it in SPIR-V requires the
    // Int64 capability. We should avoid introducing unnecessary capabilities
    // as much as possible.
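    // For example (illustrative): in 'uint x = 5;' the literal 5 is a 64-bit
    // APInt in the AST; folding the cast to a constant lets us emit a plain
    // 32-bit OpConstant instead of a 64-bit constant plus a conversion.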
    if (auto *value = tryToEvaluateAsConst(expr)) {
      value->setRValue();
      return value;
    }

    auto *value = castToInt(loadIfGLValue(subExpr), subExprType, toType,
                            subExpr->getLocStart());
    value->setRValue();
    return value;
  }
  case CastKind::CK_FloatingCast:
  case CastKind::CK_IntegralToFloating:
  case CastKind::CK_HLSLCC_FloatingCast:
  case CastKind::CK_HLSLCC_IntegralToFloating: {
    // First try to see if we can do constant folding for floating point
    // numbers, as we do for integers above.
    if (auto *value = tryToEvaluateAsConst(expr)) {
      value->setRValue();
      return value;
    }

    auto *value = castToFloat(loadIfGLValue(subExpr), subExprType, toType,
                              subExpr->getLocStart());
    value->setRValue();
    return value;
  }
  case CastKind::CK_IntegralToBoolean:
  case CastKind::CK_FloatingToBoolean:
  case CastKind::CK_HLSLCC_IntegralToBoolean:
  case CastKind::CK_HLSLCC_FloatingToBoolean: {
    // First try to see if we can do constant folding.
    if (auto *value = tryToEvaluateAsConst(expr)) {
      value->setRValue();
      return value;
    }

    auto *value = castToBool(loadIfGLValue(subExpr), subExprType, toType,
                             subExpr->getLocStart());
    value->setRValue();
    return value;
  }
  case CastKind::CK_HLSLVectorSplat: {
    const size_t size = hlsl::GetHLSLVecSize(expr->getType());
    return createVectorSplat(subExpr, size);
  }
  case CastKind::CK_HLSLVectorTruncationCast: {
    const QualType toVecType = toType;
    const QualType elemType = hlsl::GetHLSLVecElementType(toType);
    const auto toSize = hlsl::GetHLSLVecSize(toType);

    auto *composite = doExpr(subExpr);
    llvm::SmallVector<SpirvInstruction *, 4> elements;
    for (uint32_t i = 0; i < toSize; ++i) {
      elements.push_back(spvBuilder.createCompositeExtract(
          elemType, composite, {i}, expr->getExprLoc()));
    }

    auto *value = elements.front();
    if (toSize > 1) {
      value = spvBuilder.createCompositeConstruct(toVecType, elements,
                                                  expr->getExprLoc());
    }

    value->setRValue();
    return value;
  }
  case CastKind::CK_HLSLVectorToScalarCast: {
    // The underlying value should already be a vector of size 1.
    assert(hlsl::GetHLSLVecSize(subExprType) == 1);
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLVectorToMatrixCast: {
    // If the target type is already a 1xN matrix type, we just return the
    // underlying vector.
    if (is1xNMatrix(toType))
      return doExpr(subExpr);

    // A vector can have no more than 4 elements. The only remaining case
    // is casting from a size-4 vector to a size-2-by-2 matrix.
    auto *vec = loadIfGLValue(subExpr);

    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    const bool isMat = isMxNMatrix(toType, &elemType, &rowCount, &colCount);
    assert(isMat && rowCount == 2 && colCount == 2);
    (void)isMat;

    QualType vec2Type = astContext.getExtVectorType(elemType, 2);
    auto *subVec1 = spvBuilder.createVectorShuffle(vec2Type, vec, vec, {0, 1},
                                                   expr->getLocStart());
    auto *subVec2 = spvBuilder.createVectorShuffle(vec2Type, vec, vec, {2, 3},
                                                   expr->getLocStart());
    auto *mat = spvBuilder.createCompositeConstruct(toType, {subVec1, subVec2},
                                                    expr->getLocStart());
    mat->setRValue();
    return mat;
  }
  case CastKind::CK_HLSLMatrixSplat: {
    // From scalar to matrix
    uint32_t rowCount = 0, colCount = 0;
    hlsl::GetHLSLMatRowColCount(toType, rowCount, colCount);

    // Handle degenerate cases first
    if (rowCount == 1 && colCount == 1)
      return doExpr(subExpr);

    if (colCount == 1)
      return createVectorSplat(subExpr, rowCount);

    const auto vecSplat = createVectorSplat(subExpr, colCount);
    if (rowCount == 1)
      return vecSplat;

    if (isa<SpirvConstant>(vecSplat)) {
      llvm::SmallVector<SpirvConstant *, 4> vectors(
          size_t(rowCount), cast<SpirvConstant>(vecSplat));
      auto *value = spvBuilder.getConstantComposite(toType, vectors);
      value->setRValue();
      return value;
    } else {
      llvm::SmallVector<SpirvInstruction *, 4> vectors(size_t(rowCount),
                                                       vecSplat);
      auto *value = spvBuilder.createCompositeConstruct(toType, vectors,
                                                        expr->getLocEnd());
      value->setRValue();
      return value;
    }
  }
  case CastKind::CK_HLSLMatrixTruncationCast: {
    const QualType srcType = subExprType;
    auto *src = doExpr(subExpr);
    const QualType elemType = hlsl::GetHLSLMatElementType(srcType);
    llvm::SmallVector<uint32_t, 4> indexes;

    // It is possible that the source matrix is in fact a vector.
    // Example 1: Truncate float1x3 --> float1x2.
    // Example 2: Truncate float1x3 --> float1x1.
    // The front-end disallows float1x3 --> float2x1.
    {
      uint32_t srcVecSize = 0, dstVecSize = 0;
      if (isVectorType(srcType, nullptr, &srcVecSize) && isScalarType(toType)) {
        auto *val = spvBuilder.createCompositeExtract(toType, src, {0},
                                                      expr->getLocStart());
        val->setRValue();
        return val;
      }
      if (isVectorType(srcType, nullptr, &srcVecSize) &&
          isVectorType(toType, nullptr, &dstVecSize)) {
        for (uint32_t i = 0; i < dstVecSize; ++i)
          indexes.push_back(i);
        auto *val = spvBuilder.createVectorShuffle(toType, src, src, indexes,
                                                   expr->getLocStart());
        val->setRValue();
        return val;
      }
    }

    uint32_t srcRows = 0, srcCols = 0, dstRows = 0, dstCols = 0;
    hlsl::GetHLSLMatRowColCount(srcType, srcRows, srcCols);
    hlsl::GetHLSLMatRowColCount(toType, dstRows, dstCols);
    const QualType srcRowType = astContext.getExtVectorType(elemType, srcCols);
    const QualType dstRowType = astContext.getExtVectorType(elemType, dstCols);

    // Indexes to pass to OpVectorShuffle
    for (uint32_t i = 0; i < dstCols; ++i)
      indexes.push_back(i);

    llvm::SmallVector<SpirvInstruction *, 4> extractedVecs;
    for (uint32_t row = 0; row < dstRows; ++row) {
      // Extract a row
      SpirvInstruction *rowInstr = spvBuilder.createCompositeExtract(
          srcRowType, src, {row}, expr->getExprLoc());
      // Extract the necessary columns from that row.
      // The front-end ensures dstCols <= srcCols.
      // If dstCols equals srcCols, we can use the whole row directly.
      if (dstCols == 1) {
        rowInstr = spvBuilder.createCompositeExtract(elemType, rowInstr, {0},
                                                     expr->getLocStart());
      } else if (dstCols < srcCols) {
        rowInstr = spvBuilder.createVectorShuffle(
            dstRowType, rowInstr, rowInstr, indexes, expr->getLocStart());
      }
      extractedVecs.push_back(rowInstr);
    }

    auto *val = extractedVecs.front();
    if (extractedVecs.size() > 1) {
      val = spvBuilder.createCompositeConstruct(toType, extractedVecs,
                                                expr->getExprLoc());
    }
    val->setRValue();
    return val;
  }
  case CastKind::CK_HLSLMatrixToScalarCast: {
    // The underlying value should already be a 1x1 matrix.
    assert(is1x1Matrix(subExprType));
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLMatrixToVectorCast: {
    // If the underlying matrix is Mx1 or 1xM for M in {1, 2, 3, 4}, we can
    // return the underlying matrix because it'll be evaluated as a vector by
    // default.
    if (is1x1Matrix(subExprType) || is1xNMatrix(subExprType) ||
        isMx1Matrix(subExprType))
      return doExpr(subExpr);

    // A vector can have no more than 4 elements. The only remaining case
    // is casting from a 2x2 matrix to a vector of size 4.
    auto *mat = loadIfGLValue(subExpr);

    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0, elemCount = 0;
    const bool isMat =
        isMxNMatrix(subExprType, &elemType, &rowCount, &colCount);
    const bool isVec = isVectorType(toType, nullptr, &elemCount);
    assert(isMat && rowCount == 2 && colCount == 2);
    assert(isVec && elemCount == 4);
    (void)isMat;
    (void)isVec;

    QualType vec2Type = astContext.getExtVectorType(elemType, 2);
    auto *row0 = spvBuilder.createCompositeExtract(vec2Type, mat, {0}, srcLoc);
    auto *row1 = spvBuilder.createCompositeExtract(vec2Type, mat, {1}, srcLoc);
    auto *vec = spvBuilder.createVectorShuffle(toType, row0, row1, {0, 1, 2, 3},
                                               srcLoc);
    vec->setRValue();
    return vec;
  }
  case CastKind::CK_FunctionToPointerDecay:
    // Just need to return the function id
    return doExpr(subExpr);
  case CastKind::CK_FlatConversion: {
    SpirvInstruction *subExprInstr = nullptr;
    QualType evalType = subExprType;

    // Optimization: we can use OpConstantNull for cases where we want to
    // initialize an entire data structure to zeros.
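    // For example (illustrative): 'S s = (S)0;' zero-initializes every member
    // of S with a single OpConstantNull instead of constructing the composite
    // member by member.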
    if (evaluatesToConstZero(subExpr, astContext)) {
      subExprInstr = spvBuilder.getConstantNull(toType);
      subExprInstr->setRValue();
      return subExprInstr;
    }

    // Try to evaluate float literals as float rather than double.
    if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(subExpr)) {
      subExprInstr = tryToEvaluateAsFloat32(floatLiteral->getValue());
      if (subExprInstr)
        evalType = astContext.FloatTy;
    }
    // Evaluate a 'literal float' initializer type as float rather than double.
    // TODO: This could result in rounding error if the initializer is a
    // non-literal expression that requires larger than 32 bits and has the
    // 'literal float' type.
    else if (subExprType->isSpecificBuiltinType(BuiltinType::LitFloat)) {
      evalType = astContext.FloatTy;
    }
    // Try to evaluate integer literals as 32-bit int rather than 64-bit int.
    else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(subExpr)) {
      const bool isSigned = subExprType->isSignedIntegerType();
      subExprInstr = tryToEvaluateAsInt32(intLiteral->getValue(), isSigned);
      if (subExprInstr)
        evalType = isSigned ? astContext.IntTy : astContext.UnsignedIntTy;
    }
    // For assigning one array instance to another one with the same array type
    // (regardless of constness and literalness), the rhs will be wrapped in a
    // FlatConversion. Similarly for assigning a struct to another struct with
    // identical members.
    //   |- <lhs>
    //   `- ImplicitCastExpr <FlatConversion>
    //      `- ImplicitCastExpr <LValueToRValue>
    //         `- <rhs>
    else if (isSameType(astContext, toType, evalType) ||
             // We can have casts changing the shape but without affecting
             // memory order, e.g., `float4 a[2]; float b[8] = (float[8])a;`.
             // This is also represented as FlatConversion. For such cases, we
             // can rely on the InitListHandler, which can decompose
             // vectors/matrices.
             subExprType->isArrayType()) {
      auto *valInstr =
          InitListHandler(astContext, *this).processCast(toType, subExpr);
      if (valInstr)
        valInstr->setRValue();
      return valInstr;
    }
    // We can have casts changing the shape but without affecting memory order,
    // e.g., `float4 a[2]; float b[8] = (float[8])a;`. This is also represented
    // as FlatConversion. For such cases, we can rely on the InitListHandler,
    // which can decompose vectors/matrices.
    else if (subExprType->isArrayType()) {
      auto *valInstr = InitListHandler(astContext, *this)
                           .processCast(expr->getType(), subExpr);
      if (valInstr)
        valInstr->setRValue();
      return valInstr;
    }

    if (!subExprInstr)
      subExprInstr = doExpr(subExpr);

    auto *val = processFlatConversion(toType, evalType, subExprInstr,
                                      expr->getExprLoc());
    val->setRValue();
    return val;
  }
  case CastKind::CK_UncheckedDerivedToBase:
  case CastKind::CK_HLSLDerivedToBase: {
    // Find the index sequence of the base to which we are casting
    llvm::SmallVector<uint32_t, 4> baseIndices;
    getBaseClassIndices(expr, &baseIndices);

    // Turn them into SPIR-V constants
    llvm::SmallVector<SpirvInstruction *, 4> baseIndexInstructions(
        baseIndices.size(), nullptr);
    for (uint32_t i = 0; i < baseIndices.size(); ++i)
      baseIndexInstructions[i] = spvBuilder.getConstantInt(
          astContext.UnsignedIntTy, llvm::APInt(32, baseIndices[i]));

    auto *derivedInfo = doExpr(subExpr);
    return turnIntoElementPtr(subExpr->getType(), derivedInfo, expr->getType(),
                              baseIndexInstructions, subExpr->getExprLoc());
  }
  case CastKind::CK_ArrayToPointerDecay: {
    // Literal string to const string conversion falls under this category.
    if (hlsl::IsStringLiteralType(subExprType) && hlsl::IsStringType(toType)) {
      return doExpr(subExpr);
    } else {
      emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc())
          << expr->getCastKindName() << expr->getSourceRange();
      expr->dump();
      return nullptr;
    }
  }
  default:
    emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc())
        << expr->getCastKindName() << expr->getSourceRange();
    expr->dump();
    return nullptr;
  }
}
SpirvInstruction *SpirvEmitter::processFlatConversion(
    const QualType type, const QualType initType, SpirvInstruction *initInstr,
    SourceLocation srcLoc) {
  // When translating ConstantBuffer<T> or TextureBuffer<T> types, we consider
  // the underlying type (T), and therefore we should bypass the FlatConversion
  // node when accessing these types:
  //   `-MemberExpr
  //    `-ImplicitCastExpr 'const T' lvalue <FlatConversion>
  //     `-ArraySubscriptExpr 'ConstantBuffer<T>':'ConstantBuffer<T>' lvalue
  if (isConstantTextureBuffer(initType)) {
    return initInstr;
  }

  // Try to translate the canonical type first
  const auto canonicalType = type.getCanonicalType();
  if (canonicalType != type)
    return processFlatConversion(canonicalType, initType, initInstr, srcLoc);

  // Primitive types
  {
    QualType ty = {};
    if (isScalarType(type, &ty)) {
      if (const auto *builtinType = ty->getAs<BuiltinType>()) {
        switch (builtinType->getKind()) {
        case BuiltinType::Void: {
          emitError("cannot create a constant of void type", srcLoc);
          return nullptr;
        }
        case BuiltinType::Bool:
          return castToBool(initInstr, initType, ty, srcLoc);
        // Target type is an integer variant.
        case BuiltinType::Int:
        case BuiltinType::Short:
        case BuiltinType::Min12Int:
        case BuiltinType::Min16Int:
        case BuiltinType::Min16UInt:
        case BuiltinType::UShort:
        case BuiltinType::UInt:
        case BuiltinType::Long:
        case BuiltinType::LongLong:
        case BuiltinType::ULong:
        case BuiltinType::ULongLong:
          return castToInt(initInstr, initType, ty, srcLoc);
        // Target type is a float variant.
        case BuiltinType::Double:
        case BuiltinType::Float:
        case BuiltinType::Half:
        case BuiltinType::HalfFloat:
        case BuiltinType::Min10Float:
        case BuiltinType::Min16Float:
          return castToFloat(initInstr, initType, ty, srcLoc);
        default:
          emitError("flat conversion of type %0 unimplemented", srcLoc)
              << builtinType->getTypeClassName();
          return nullptr;
        }
      }
    }
  }
  // Vector types
  {
    QualType elemType = {};
    uint32_t elemCount = {};
    if (isVectorType(type, &elemType, &elemCount)) {
      auto *elem = processFlatConversion(elemType, initType, initInstr, srcLoc);
      llvm::SmallVector<SpirvInstruction *, 4> constituents(size_t(elemCount),
                                                            elem);
      return spvBuilder.createCompositeConstruct(type, constituents, srcLoc);
    }
  }

  // Matrix types
  {
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
      // By default HLSL matrices are row major, while SPIR-V matrices are
      // column major. We are mapping what HLSL semantically means by a row
      // into a column here.
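      // For example (illustrative): a float2x3 is built as 2 constituents,
      // each a 3-element vector, so each HLSL row ends up as one SPIR-V
      // column vector.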
      const QualType vecType = astContext.getExtVectorType(elemType, colCount);
      auto *elem = processFlatConversion(elemType, initType, initInstr, srcLoc);
      const llvm::SmallVector<SpirvInstruction *, 4> constituents(
          size_t(colCount), elem);
      auto *col =
          spvBuilder.createCompositeConstruct(vecType, constituents, srcLoc);
      const llvm::SmallVector<SpirvInstruction *, 4> rows(size_t(rowCount),
                                                          col);
      return spvBuilder.createCompositeConstruct(type, rows, srcLoc);
    }
  }

  // Struct type
  if (const auto *structType = type->getAs<RecordType>()) {
    const auto *decl = structType->getDecl();
    llvm::SmallVector<SpirvInstruction *, 4> fields;

    for (const auto *field : decl->fields()) {
      // There is a special case for FlatConversion. If T is a struct with only
      // one member, S, then (T)<an-instance-of-S> is allowed, which essentially
      // constructs a new T instance using the instance of S as its only member.
      // Check whether we are handling that case here first.
      if (field->getType().getCanonicalType() == initType.getCanonicalType()) {
        fields.push_back(initInstr);
      } else {
        fields.push_back(processFlatConversion(field->getType(), initType,
                                               initInstr, srcLoc));
      }
    }

    return spvBuilder.createCompositeConstruct(type, fields, srcLoc);
  }

  // Array type
  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    auto *elem = processFlatConversion(arrayType->getElementType(), initType,
                                       initInstr, srcLoc);
    llvm::SmallVector<SpirvInstruction *, 4> constituents(size_t(size), elem);
    return spvBuilder.createCompositeConstruct(type, constituents, srcLoc);
  }

  emitError("flat conversion of type %0 unimplemented", {})
      << type->getTypeClassName();
  type->dump();
  return nullptr;
}
SpirvInstruction *
SpirvEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Try to optimize the floatMxN *= float and floatN *= float cases
  if (opcode == BO_MulAssign) {
    if (auto *result = tryToGenFloatMatrixScale(expr))
      return result;
    if (auto *result = tryToGenFloatVectorScale(expr))
      return result;
  }

  const auto *rhs = expr->getRHS();
  const auto *lhs = expr->getLHS();
  SpirvInstruction *lhsPtr = nullptr;
  auto *result = processBinaryOp(
      lhs, rhs, opcode, expr->getComputationLHSType(), expr->getType(),
      expr->getSourceRange(), expr->getOperatorLoc(), &lhsPtr);
  return processAssignment(lhs, result, true, lhsPtr);
}

SpirvInstruction *
SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) {
  const auto type = expr->getType();
  const SourceLocation loc = expr->getExprLoc();
  const Expr *cond = expr->getCond();
  const Expr *falseExpr = expr->getFalseExpr();
  const Expr *trueExpr = expr->getTrueExpr();

  // According to the HLSL documentation, all operands of the ?: operator are
  // always evaluated.

  // Corner case: in HLSL, the condition of the ternary operator can be a
  // matrix of booleans, which results in selecting between components of two
  // matrices. However, a matrix of booleans is not a valid type in SPIR-V.
  // If the AST has inserted a splat of a scalar/vector to a matrix, we can
  // just use that scalar/vector as an if-clause condition.
  if (auto *cast = dyn_cast<ImplicitCastExpr>(cond))
    if (cast->getCastKind() == CK_HLSLMatrixSplat)
      cond = cast->getSubExpr();

  // If we are selecting between two SamplerState objects, none of the three
  // operands has a LValueToRValue implicit cast.
  auto *condition = loadIfGLValue(cond);
  auto *trueBranch = loadIfGLValue(trueExpr);
  auto *falseBranch = loadIfGLValue(falseExpr);

  // Corner case: in HLSL, the condition of the ternary operator can be a
  // matrix of booleans, which results in selecting between components of two
  // matrices. However, a matrix of booleans is not a valid type in SPIR-V.
  // Therefore, we need to perform OpSelect for each row of the matrix.
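  // For example (illustrative):
  //   bool2x2 c; float2x2 a, b;
  //   float2x2 r = c ? a : b;  // selects per component, one row at a time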
  {
    QualType condElemType = {}, elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(type, &elemType, &rowCount, &colCount) &&
        isMxNMatrix(cond->getType(), &condElemType) &&
        condElemType->isBooleanType()) {
      const auto rowType = astContext.getExtVectorType(elemType, colCount);
      const auto condRowType =
          astContext.getExtVectorType(condElemType, colCount);
      llvm::SmallVector<SpirvInstruction *, 4> rows;
      for (uint32_t i = 0; i < rowCount; ++i) {
        auto *condRow =
            spvBuilder.createCompositeExtract(condRowType, condition, {i}, loc);
        auto *trueRow =
            spvBuilder.createCompositeExtract(rowType, trueBranch, {i}, loc);
        auto *falseRow =
            spvBuilder.createCompositeExtract(rowType, falseBranch, {i}, loc);
        rows.push_back(
            spvBuilder.createSelect(rowType, condRow, trueRow, falseRow, loc));
      }

      auto *result = spvBuilder.createCompositeConstruct(type, rows, loc);
      result->setRValue();
      return result;
    }
  }

  // For cases where the return type is a scalar or a vector, we can use
  // OpSelect to choose between the two. OpSelect's return type must be either
  // scalar or vector.
  if (isScalarType(type) || isVectorType(type)) {
    // The SPIR-V OpSelect instruction must have a selection argument that is
    // the same size as the return type. If the return type is a vector, the
    // selection must be a vector of booleans (one per output component).
    uint32_t count = 0;
    if (isVectorType(expr->getType(), nullptr, &count) &&
        !isVectorType(expr->getCond()->getType())) {
      const llvm::SmallVector<SpirvInstruction *, 4> components(size_t(count),
                                                                condition);
      condition = spvBuilder.createCompositeConstruct(
          astContext.getExtVectorType(astContext.BoolTy, count), components,
          expr->getCond()->getLocEnd());
    }

    auto *value =
        spvBuilder.createSelect(type, condition, trueBranch, falseBranch, loc);
    value->setRValue();
    return value;
  }

  // If we can't use OpSelect, we need to create if-else control flow.
  auto *tempVar = spvBuilder.addFnVar(type, loc, "temp.var.ternary");
  auto *thenBB = spvBuilder.createBasicBlock("if.true");
  auto *mergeBB = spvBuilder.createBasicBlock("if.merge");
  auto *elseBB = spvBuilder.createBasicBlock("if.false");

  // Create the branch instruction. This will end the current basic block.
  spvBuilder.createConditionalBranch(condition, thenBB, elseBB,
                                     expr->getCond()->getLocEnd(), mergeBB);
  spvBuilder.addSuccessor(thenBB);
  spvBuilder.addSuccessor(elseBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Handle the then branch
  spvBuilder.setInsertPoint(thenBB);
  spvBuilder.createStore(tempVar, trueBranch,
                         expr->getTrueExpr()->getLocStart());
  spvBuilder.createBranch(mergeBB, expr->getTrueExpr()->getLocEnd());
  spvBuilder.addSuccessor(mergeBB);

  // Handle the else branch
  spvBuilder.setInsertPoint(elseBB);
  spvBuilder.createStore(tempVar, falseBranch,
                         expr->getFalseExpr()->getLocStart());
  spvBuilder.createBranch(mergeBB, expr->getFalseExpr()->getLocEnd());
  spvBuilder.addSuccessor(mergeBB);

  // From now on, emit instructions into the merge block.
  spvBuilder.setInsertPoint(mergeBB);
  auto *result = spvBuilder.createLoad(type, tempVar, expr->getLocEnd());
  result->setRValue();
  return result;
}
SpirvInstruction *
SpirvEmitter::processByteAddressBufferStructuredBufferGetDimensions(
    const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument();
  auto *objectInstr = loadIfAliasVarRef(object);
  const auto type = object->getType();
  const bool isBABuf = isByteAddressBuffer(type) || isRWByteAddressBuffer(type);
  const bool isStructuredBuf = isStructuredBuffer(type) ||
                               isAppendStructuredBuffer(type) ||
                               isConsumeStructuredBuffer(type);
  assert(isBABuf || isStructuredBuf);

  // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a
  // structure with only one member that is a runtime array. We need to
  // perform OpArrayLength on member 0.
  SpirvInstruction *length = spvBuilder.createArrayLength(
      astContext.UnsignedIntTy, expr->getExprLoc(), objectInstr, 0);

  // For (RW)ByteAddressBuffers, GetDimensions() must return the array length
  // in bytes, but OpArrayLength returns the number of uints in the runtime
  // array. Therefore we must multiply the result by 4.
  2626. if (isBABuf) {
  2627. length = spvBuilder.createBinaryOp(
  2628. spv::Op::OpIMul, astContext.UnsignedIntTy, length,
  2629. // TODO(jaebaek): What line info we should emit for constants?
  2630. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  2631. llvm::APInt(32, 4u)),
  2632. expr->getExprLoc());
  2633. }
  2634. spvBuilder.createStore(doExpr(expr->getArg(0)), length,
  2635. expr->getArg(0)->getLocStart());
  2636. if (isStructuredBuf) {
  2637. // For (RW)StructuredBuffer, the stride of the runtime array (which is the
  2638. // size of the struct) must also be written to the second argument.
  2639. AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
  2640. uint32_t size = 0, stride = 0;
  2641. std::tie(std::ignore, size) =
  2642. alignmentCalc.getAlignmentAndSize(type, spirvOptions.sBufferLayoutRule,
  2643. /*isRowMajor*/ llvm::None, &stride);
  2644. auto *sizeInstr = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  2645. llvm::APInt(32, size));
  2646. spvBuilder.createStore(doExpr(expr->getArg(1)), sizeInstr,
  2647. expr->getArg(1)->getLocStart());
  2648. }
  2649. return nullptr;
  2650. }
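
// Example for GetDimensions on a byte-address buffer (illustrative only;
// hypothetical HLSL, not from this file):
//
//   RWByteAddressBuffer buf;
//   uint byteLen;
//   buf.GetDimensions(byteLen);
//
// is expected to lower roughly to:
//
//   %len     = OpArrayLength %uint %buf 0  ; number of uints in member 0
//   %byteLen = OpIMul %uint %len %uint_4   ; uint count -> byte count
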
SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods(
    hlsl::IntrinsicOp opcode, const CXXMemberCallExpr *expr) {
  // The signatures of RWByteAddressBuffer atomic methods are largely:
  // void Interlocked*(in UINT dest, in UINT value);
  // void Interlocked*(in UINT dest, in UINT value, out UINT original_value);
  const auto *object = expr->getImplicitObjectArgument();
  auto *objectInfo = loadIfAliasVarRef(object);

  auto *zero =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *offset = doExpr(expr->getArg(0));

  // Right shift by 2 to convert the byte offset to a uint (word) offset.
  auto *address = spvBuilder.createBinaryOp(
      spv::Op::OpShiftRightLogical, astContext.UnsignedIntTy, offset,
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2)),
      expr->getExprLoc());
  auto *ptr =
      spvBuilder.createAccessChain(astContext.UnsignedIntTy, objectInfo,
                                   {zero, address}, object->getLocStart());

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    auto *comparator = doExpr(expr->getArg(1));
    auto *originalVal = spvBuilder.createAtomicCompareExchange(
        astContext.UnsignedIntTy, ptr, spv::Scope::Device,
        spv::MemorySemanticsMask::MaskNone, spv::MemorySemanticsMask::MaskNone,
        doExpr(expr->getArg(2)), comparator, expr->getCallee()->getExprLoc());
    if (isCompareExchange)
      spvBuilder.createStore(doExpr(expr->getArg(3)), originalVal,
                             expr->getArg(3)->getLocStart());
  } else {
    auto *value = doExpr(expr->getArg(1));
    SpirvInstruction *originalVal = spvBuilder.createAtomicOp(
        translateAtomicHlslOpcodeToSpirvOpcode(opcode),
        astContext.UnsignedIntTy, ptr, spv::Scope::Device,
        spv::MemorySemanticsMask::MaskNone, value,
        expr->getCallee()->getExprLoc());
    if (expr->getNumArgs() > 2) {
      originalVal = castToType(originalVal, astContext.UnsignedIntTy,
                               expr->getArg(2)->getType(),
                               expr->getArg(2)->getLocStart());
      spvBuilder.createStore(doExpr(expr->getArg(2)), originalVal,
                             expr->getArg(2)->getLocStart());
    }
  }

  return nullptr;
}
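
// Example for the atomic methods above (illustrative only; hypothetical HLSL,
// not from this file):
//
//   buf.InterlockedAdd(16, 1, prev);
//
// is expected to lower roughly to (scope/semantics operands abbreviated):
//
//   %addr = OpShiftRightLogical %uint %uint_16 %uint_2  ; byte -> word offset
//   %ptr  = OpAccessChain %_ptr_uint %buf %uint_0 %addr
//   %prev = OpAtomicIAdd %uint %ptr <Device> <None> %uint_1
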
SpirvInstruction *
SpirvEmitter::processGetSamplePosition(const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument()->IgnoreParens();
  auto *sampleCount = spvBuilder.createImageQuery(
      spv::Op::OpImageQuerySamples, astContext.UnsignedIntTy,
      expr->getExprLoc(), loadIfGLValue(object));
  if (!spirvOptions.noWarnEmulatedFeatures)
    emitWarning("GetSamplePosition is emulated using many SPIR-V instructions "
                "due to lack of direct SPIR-V equivalent, so it only supports "
                "standard sample settings with 1, 2, 4, 8, or 16 samples and "
                "will return float2(0, 0) for other cases",
                expr->getCallee()->getExprLoc());
  return emitGetSamplePosition(sampleCount, doExpr(expr->getArg(0)),
                               expr->getCallee()->getExprLoc());
}

SpirvInstruction *
SpirvEmitter::processSubpassLoad(const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument()->IgnoreParens();
  SpirvInstruction *sample =
      expr->getNumArgs() == 1 ? doExpr(expr->getArg(0)) : nullptr;
  auto *zero = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0));
  auto *location = spvBuilder.getConstantComposite(
      astContext.getExtVectorType(astContext.IntTy, 2), {zero, zero});

  return processBufferTextureLoad(object, location, /*constOffset*/ nullptr,
                                  /*varOffset*/ nullptr, /*lod*/ sample,
                                  /*residencyCode*/ nullptr,
                                  expr->getExprLoc());
}
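
// Example for SubpassLoad (illustrative only; hypothetical HLSL, not from
// this file):
//
//   [[vk::input_attachment_index(0)]] SubpassInput si;
//   float4 v = si.SubpassLoad();
//
// is expected to lower to an OpImageRead at the fixed coordinate (0, 0),
// since subpass inputs are always read at the current fragment location.
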
SpirvInstruction *
SpirvEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument();
  auto *objectInstr = loadIfGLValue(object);
  const auto type = object->getType();
  const auto *recType = type->getAs<RecordType>();
  assert(recType);
  const auto typeName = recType->getDecl()->getName();
  const auto numArgs = expr->getNumArgs();
  const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr;

  assert(isTexture(type) || isRWTexture(type) || isBuffer(type) ||
         isRWBuffer(type));

  // For Texture1D, arguments are either:
  // a) width
  // b) MipLevel, width, NumLevels
  // For Texture1DArray, arguments are either:
  // a) width, elements
  // b) MipLevel, width, elements, NumLevels
  // For Texture2D, arguments are either:
  // a) width, height
  // b) MipLevel, width, height, NumLevels
  // For Texture2DArray, arguments are either:
  // a) width, height, elements
  // b) MipLevel, width, height, elements, NumLevels
  // For Texture3D, arguments are either:
  // a) width, height, depth
  // b) MipLevel, width, height, depth, NumLevels
  // For Texture2DMS, arguments are: width, height, NumSamples
  // For Texture2DMSArray, arguments are: width, height, elements, NumSamples
  // For TextureCube, arguments are either:
  // a) width, height
  // b) MipLevel, width, height, NumLevels
  // For TextureCubeArray, arguments are either:
  // a) width, height, elements
  // b) MipLevel, width, height, elements, NumLevels

  // Note: the SPIR-V spec requires the return type of OpImageQuerySize(Lod)
  // to be a scalar/vector of integers. It also requires the return type of
  // OpImageQueryLevels and OpImageQuerySamples to be scalar integers.
  // The HLSL methods, however, have overloads with float output arguments.
  // Since the AST naturally won't have casting nodes for such cases, we'll
  // have to perform the cast ourselves.
  const auto storeToOutputArg = [this](const Expr *outputArg,
                                       SpirvInstruction *id, QualType type) {
    id = castToType(id, type, outputArg->getType(), outputArg->getExprLoc());
    spvBuilder.createStore(doExpr(outputArg), id, outputArg->getLocStart());
  };

  if ((typeName == "Texture1D" && numArgs > 1) ||
      (typeName == "Texture2D" && numArgs > 2) ||
      (typeName == "TextureCube" && numArgs > 2) ||
      (typeName == "Texture3D" && numArgs > 3) ||
      (typeName == "Texture1DArray" && numArgs > 2) ||
      (typeName == "TextureCubeArray" && numArgs > 3) ||
      (typeName == "Texture2DArray" && numArgs > 3)) {
    mipLevel = expr->getArg(0);
    numLevels = expr->getArg(numArgs - 1);
  }
  if (isTextureMS(type)) {
    numSamples = expr->getArg(numArgs - 1);
  }

  uint32_t querySize = numArgs;
  // If the numLevels arg is present, mipLevel must also be present. Neither
  // is queried via OpImageQuerySizeLod.
  if (numLevels)
    querySize -= 2;
  // If the numSamples arg is present, it is not queried via OpImageQuerySize.
  else if (numSamples)
    querySize -= 1;

  const QualType resultQualType =
      querySize == 1
          ? astContext.UnsignedIntTy
          : astContext.getExtVectorType(astContext.UnsignedIntTy, querySize);

  // Only Texture types use ImageQuerySizeLod.
  // TextureMS, RWTexture, Buffers, RWBuffers use ImageQuerySize.
  SpirvInstruction *lod = nullptr;
  if (isTexture(type) && !numSamples) {
    if (mipLevel) {
      // For Texture types when the mipLevel argument is present.
      lod = doExpr(mipLevel);
    } else {
      // For Texture types when the mipLevel argument is omitted.
      lod = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0));
    }
  }

  SpirvInstruction *query =
      lod ? cast<SpirvInstruction>(spvBuilder.createImageQuery(
                spv::Op::OpImageQuerySizeLod, resultQualType,
                expr->getCallee()->getExprLoc(), objectInstr, lod))
          : cast<SpirvInstruction>(spvBuilder.createImageQuery(
                spv::Op::OpImageQuerySize, resultQualType,
                expr->getCallee()->getExprLoc(), objectInstr));

  if (querySize == 1) {
    const uint32_t argIndex = mipLevel ? 1 : 0;
    storeToOutputArg(expr->getArg(argIndex), query, resultQualType);
  } else {
    for (uint32_t i = 0; i < querySize; ++i) {
      const uint32_t argIndex = mipLevel ? i + 1 : i;
      auto *component = spvBuilder.createCompositeExtract(
          astContext.UnsignedIntTy, query, {i},
          expr->getCallee()->getExprLoc());
      // If the first arg is the mipmap level, we must write the results
      // starting from Arg(i+1), not Arg(i).
      storeToOutputArg(expr->getArg(argIndex), component,
                       astContext.UnsignedIntTy);
    }
  }

  if (numLevels || numSamples) {
    const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples;
    const spv::Op opcode =
        numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples;
    auto *numLevelsSamplesQuery = spvBuilder.createImageQuery(
        opcode, astContext.UnsignedIntTy, expr->getCallee()->getExprLoc(),
        objectInstr);
    storeToOutputArg(numLevelsSamplesArg, numLevelsSamplesQuery,
                     astContext.UnsignedIntTy);
  }

  return nullptr;
}
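
// Example for texture GetDimensions (illustrative only; hypothetical HLSL,
// not from this file):
//
//   Texture2D tex;
//   uint mip = 0, w, h, levels;
//   tex.GetDimensions(mip, w, h, levels);
//
// is expected to lower roughly to:
//
//   %size   = OpImageQuerySizeLod %v2uint %img %mip   ; -> (w, h)
//   %levels = OpImageQueryLevels %uint %img
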
SpirvInstruction *
SpirvEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
                                          bool unclamped) {
  // Possible signatures are as follows:
  // Texture1D(Array).CalculateLevelOfDetail(SamplerState S, float x);
  // Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy);
  // TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz);
  // Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz);
  // Return type is always a single float (LOD).
  assert(expr->getNumArgs() == 2u);
  const auto *object = expr->getImplicitObjectArgument();
  auto *objectInfo = loadIfGLValue(object);
  auto *samplerState = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));
  auto *sampledImage = spvBuilder.createSampledImage(
      object->getType(), objectInfo, samplerState, expr->getExprLoc());

  // The result type of OpImageQueryLod must be a float2.
  const QualType queryResultType =
      astContext.getExtVectorType(astContext.FloatTy, 2u);
  auto *query =
      spvBuilder.createImageQuery(spv::Op::OpImageQueryLod, queryResultType,
                                  expr->getExprLoc(), sampledImage, coordinate);

  // The first component of the float2 contains the mipmap level that would be
  // accessed (the clamped lod); the second component contains the unclamped
  // lod.
  return spvBuilder.createCompositeExtract(astContext.FloatTy, query,
                                           {unclamped ? 1u : 0u},
                                           expr->getCallee()->getExprLoc());
}
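
// Example for CalculateLevelOfDetail (illustrative only; hypothetical HLSL,
// not from this file):
//
//   float lod = tex.CalculateLevelOfDetail(samp, uv);
//
// is expected to lower roughly to:
//
//   %simg = OpSampledImage %sampled_img %img %samp
//   %lods = OpImageQueryLod %v2float %simg %uv
//   %lod  = OpCompositeExtract %float %lods 0   ; 1 for the unclamped variant
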
SpirvInstruction *SpirvEmitter::processTextureGatherRGBACmpRGBA(
    const CXXMemberCallExpr *expr, const bool isCmp, const uint32_t component) {
  // Parameters for .Gather{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  // * SamplerState s, float2 location, int2 offset
  // * SamplerState s, float2 location, int2 offset0, int2 offset1,
  //   int2 offset2, int2 offset3
  //
  // An additional 'out uint status' parameter can appear in both of the above.
  //
  // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  // * SamplerState s, float2 location, float compare_value, int2 offset
  // * SamplerState s, float2 location, float compare_value, int2 offset0,
  //   int2 offset1, int2 offset2, int2 offset3
  //
  // An additional 'out uint status' parameter can appear in both of the above.
  //
  // TextureCube's signature is somewhat different from the rest.
  // Parameters for .Gather{Red|Green|Blue|Alpha}() for TextureCube are:
  // * SamplerState s, float2 location, out uint status
  // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() for TextureCube are:
  // * SamplerState s, float2 location, float compare_value, out uint status
  //
  // Return type is always a 4-component vector.
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();
  const auto *imageExpr = expr->getImplicitObjectArgument();
  const auto loc = expr->getCallee()->getExprLoc();
  const QualType imageType = imageExpr->getType();
  const QualType retType = callee->getReturnType();

  // If the last arg is an unsigned integer, it must be the status.
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();

  // Subtract 1 for the status arg (if it exists), subtract 1 for the
  // compare_value arg (if it exists), and subtract 2 for the SamplerState and
  // location args.
  const auto numOffsetArgs = numArgs - hasStatusArg - isCmp - 2;
  // No offset args for TextureCube, 1 or 4 offset args for the rest.
  assert(numOffsetArgs == 0 || numOffsetArgs == 1 || numOffsetArgs == 4);

  auto *image = loadIfGLValue(imageExpr);
  auto *sampler = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));
  auto *compareVal = isCmp ? doExpr(expr->getArg(2)) : nullptr;

  // Handle offsets (if any).
  bool needsEmulation = false;
  SpirvInstruction *constOffset = nullptr, *varOffset = nullptr,
                   *constOffsets = nullptr;
  if (numOffsetArgs == 1) {
    // The offset arg is not optional.
    handleOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset);
  } else if (numOffsetArgs == 4) {
    auto *offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp));
    auto *offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp));
    auto *offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp));
    auto *offset3 = tryToEvaluateAsConst(expr->getArg(5 + isCmp));

    // If any of the offsets is not constant, we then need to emulate the call
    // using 4 OpImageGather instructions. Otherwise, we can leverage the
    // ConstOffsets image operand.
    if (offset0 && offset1 && offset2 && offset3) {
      const QualType v2i32 = astContext.getExtVectorType(astContext.IntTy, 2);
      const auto offsetType = astContext.getConstantArrayType(
          v2i32, llvm::APInt(32, 4), clang::ArrayType::Normal, 0);
      constOffsets = spvBuilder.getConstantComposite(
          offsetType, {offset0, offset1, offset2, offset3});
    } else {
      needsEmulation = true;
    }
  }

  auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;

  if (needsEmulation) {
    const auto elemType = hlsl::GetHLSLVecElementType(callee->getReturnType());

    SpirvInstruction *texels[4];
    for (uint32_t i = 0; i < 4; ++i) {
      varOffset = doExpr(expr->getArg(2 + isCmp + i));
      auto *gatherRet = spvBuilder.createImageGather(
          retType, imageType, image, sampler, coordinate,
          spvBuilder.getConstantInt(astContext.IntTy,
                                    llvm::APInt(32, component, true)),
          compareVal,
          /*constOffset*/ nullptr, varOffset, /*constOffsets*/ nullptr,
          /*sampleNumber*/ nullptr, status, loc);
      texels[i] =
          spvBuilder.createCompositeExtract(elemType, gatherRet, {i}, loc);
    }
    return spvBuilder.createCompositeConstruct(
        retType, {texels[0], texels[1], texels[2], texels[3]}, loc);
  }

  return spvBuilder.createImageGather(
      retType, imageType, image, sampler, coordinate,
      spvBuilder.getConstantInt(astContext.IntTy,
                                llvm::APInt(32, component, true)),
      compareVal, constOffset, varOffset, constOffsets,
      /*sampleNumber*/ nullptr, status, loc);
}
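
// Example for GatherRed (illustrative only; hypothetical HLSL, not from this
// file):
//
//   float4 r = tex.GatherRed(samp, uv, int2(1, 0));
//
// with a constant offset is expected to lower roughly to:
//
//   %simg = OpSampledImage %sampled_img %img %samp
//   %r    = OpImageGather %v4float %simg %uv %int_0 ConstOffset %offset
//
// With four offsets that are not all compile-time constants, the call is
// instead emulated with four OpImageGather instructions whose results are
// recombined component by component, as implemented above.
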
SpirvInstruction *
SpirvEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) {
  // Signature for Texture2D/Texture2DArray:
  //
  // float4 GatherCmp(
  //   in SamplerComparisonState s,
  //   in float2 location,
  //   in float compare_value
  //   [,in int2 offset]
  //   [,out uint Status]
  // );
  //
  // Signature for TextureCube/TextureCubeArray:
  //
  // float4 GatherCmp(
  //   in SamplerComparisonState s,
  //   in float2 location,
  //   in float compare_value,
  //   out uint Status
  // );
  //
  // Other Texture types do not have the GatherCmp method.
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();
  const auto loc = expr->getExprLoc();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  const bool hasOffsetArg = (numArgs == 5) || (numArgs == 4 && !hasStatusArg);

  const auto *imageExpr = expr->getImplicitObjectArgument();
  auto *image = loadIfGLValue(imageExpr);
  auto *sampler = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));
  auto *comparator = doExpr(expr->getArg(2));
  SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType = callee->getReturnType();
  const auto imageType = imageExpr->getType();
  const auto status =
      hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;

  return spvBuilder.createImageGather(
      retType, imageType, image, sampler, coordinate,
      /*component*/ nullptr, comparator, constOffset, varOffset,
      /*constOffsets*/ nullptr,
      /*sampleNumber*/ nullptr, status, loc);
}

SpirvInstruction *SpirvEmitter::processBufferTextureLoad(
    const Expr *object, SpirvInstruction *location,
    SpirvInstruction *constOffset, SpirvInstruction *varOffset,
    SpirvInstruction *lod, SpirvInstruction *residencyCode,
    SourceLocation loc) {
  // Loading for Buffer and RWBuffer translates to an OpImageFetch.
  // The result type of an OpImageFetch must be a vec4 of float or int.
  const auto type = object->getType();
  assert(isBuffer(type) || isRWBuffer(type) || isTexture(type) ||
         isRWTexture(type) || isSubpassInput(type) || isSubpassInputMS(type));

  const bool doFetch = isBuffer(type) || isTexture(type);
  auto *objectInfo = loadIfGLValue(object);

  // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod.
  SpirvInstruction *sampleNumber = nullptr;
  if (isTextureMS(type) || isSubpassInputMS(type)) {
    sampleNumber = lod;
    lod = nullptr;
  }

  const auto sampledType = hlsl::GetHLSLResourceResultType(type);
  QualType elemType = sampledType;
  uint32_t elemCount = 1;
  bool isTemplateOverStruct = false;

  // Check whether the template type is a vector type or struct type.
  if (!isVectorType(sampledType, &elemType, &elemCount)) {
    if (sampledType->getAsStructureType()) {
      isTemplateOverStruct = true;
      // For struct type, we need to make sure it can fit into a 4-component
      // vector. Detailed failing reasons will be emitted by the function so
      // we don't need to emit errors here.
      if (!canFitIntoOneRegister(astContext, sampledType, &elemType,
                                 &elemCount))
        return nullptr;
    }
  }

  {
    // Treat a vector of size 1 the same as a scalar.
    if (hlsl::IsHLSLVecType(elemType) && hlsl::GetHLSLVecSize(elemType) == 1)
      elemType = hlsl::GetHLSLVecElementType(elemType);

    if (!elemType->isFloatingType() && !elemType->isIntegerType()) {
      emitError("loading %0 value unsupported", object->getExprLoc()) << type;
      return nullptr;
    }
  }

  // If residencyCode is nullptr, we are dealing with a Load method with 2
  // arguments which does not return the operation status.
  if (residencyCode && residencyCode->isRValue()) {
    emitError(
        "an lvalue argument should be used for returning the operation status",
        loc);
    return nullptr;
  }

  // OpImageFetch and OpImageRead can only fetch a vector of 4 elements.
  const QualType texelType = astContext.getExtVectorType(elemType, 4u);
  auto *texel = spvBuilder.createImageFetchOrRead(
      doFetch, texelType, type, objectInfo, location, lod, constOffset,
      varOffset, /*constOffsets*/ nullptr, sampleNumber, residencyCode, loc);

  // If the result type is a vec1, vec2, or vec3, some extra processing
  // (extraction) is required.
  auto *retVal = extractVecFromVec4(texel, elemCount, elemType, loc);
  if (isTemplateOverStruct) {
    // Convert to the struct so that we are consistent with types in the AST.
    retVal = convertVectorToStruct(sampledType, elemType, retVal, loc);
  }
  retVal->setRValue();
  return retVal;
}
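
// Example for a typed buffer load (illustrative only; hypothetical HLSL, not
// from this file):
//
//   Buffer<float3> buf;
//   float3 v = buf.Load(i);
//
// is expected to lower roughly to:
//
//   %texel = OpImageFetch %v4float %img %i
//   %v     = OpVectorShuffle %v3float %texel %texel 0 1 2
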
SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore(
    const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) {
  SpirvInstruction *result = nullptr;
  const auto object = expr->getImplicitObjectArgument();
  auto *objectInfo = loadIfAliasVarRef(object);
  assert(numWords >= 1 && numWords <= 4);
  if (doStore) {
    assert(isRWByteAddressBuffer(object->getType()));
    assert(expr->getNumArgs() == 2);
  } else {
    assert(isRWByteAddressBuffer(object->getType()) ||
           isByteAddressBuffer(object->getType()));
    if (expr->getNumArgs() == 2) {
      emitError(
          "(RW)ByteAddressBuffer::Load(in address, out status) not supported",
          expr->getExprLoc());
      return nullptr;
    }
  }
  const Expr *addressExpr = expr->getArg(0);
  auto *byteAddress = doExpr(addressExpr);
  const QualType addressType = addressExpr->getType();

  // The front-end prevents usage of templated Load2, Load3, Load4, Store2,
  // Store3, Store4 intrinsic functions.
  const bool isTemplatedLoadOrStore =
      (numWords == 1) &&
      (doStore ? !expr->getArg(1)->getType()->isSpecificBuiltinType(
                     BuiltinType::UInt)
               : !expr->getType()->isSpecificBuiltinType(BuiltinType::UInt));

  // Do an OpShiftRightLogical by 2 (divide by 4 to get aligned memory
  // access). The AST always casts the address to unsigned integer, so shift
  // by unsigned integer 2.
  auto *constUint2 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2));
  SpirvInstruction *address =
      spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical, addressType,
                                byteAddress, constUint2, expr->getExprLoc());

  if (isTemplatedLoadOrStore) {
    // Templated load. Need to (potentially) perform more
    // loads/casts/composite-constructs.
    uint32_t bitOffset = 0;
    if (doStore) {
      auto *values = doExpr(expr->getArg(1));
      RawBufferHandler(*this).processTemplatedStoreToBuffer(
          values, objectInfo, address, expr->getArg(1)->getType(), bitOffset);
      return nullptr;
    } else {
      RawBufferHandler rawBufferHandler(*this);
      return rawBufferHandler.processTemplatedLoadFromBuffer(
          objectInfo, address, expr->getType(), bitOffset);
    }
  }

  // Perform access chain into the RWByteAddressBuffer.
  // First index must be zero (member 0 of the struct is a
  // runtimeArray). The second index passed to OpAccessChain should be
  // the address.
  auto *constUint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));

  if (doStore) {
    auto *values = doExpr(expr->getArg(1));
    auto *curStoreAddress = address;
    for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) {
      // Extract a 32-bit word from the input.
      auto *curValue = numWords == 1
                           ? values
                           : spvBuilder.createCompositeExtract(
                                 astContext.UnsignedIntTy, values,
                                 {wordCounter}, expr->getArg(1)->getExprLoc());

      // Update the output address if necessary.
      if (wordCounter > 0) {
        auto *offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                 llvm::APInt(32, wordCounter));
        curStoreAddress =
            spvBuilder.createBinaryOp(spv::Op::OpIAdd, addressType, address,
                                      offset, expr->getCallee()->getExprLoc());
      }

      // Store the word to the right address at the output.
      auto *storePtr = spvBuilder.createAccessChain(
          astContext.UnsignedIntTy, objectInfo, {constUint0, curStoreAddress},
          object->getLocStart());
      spvBuilder.createStore(storePtr, curValue,
                             expr->getCallee()->getExprLoc());
    }
  } else {
    auto *loadPtr = spvBuilder.createAccessChain(
        astContext.UnsignedIntTy, objectInfo, {constUint0, address},
        object->getLocStart());
    result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr,
                                   expr->getCallee()->getExprLoc());
    if (numWords > 1) {
      // Load word 2, 3, and 4 where necessary. Use OpCompositeConstruct to
      // return a vector result.
      llvm::SmallVector<SpirvInstruction *, 4> values;
      values.push_back(result);
      for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) {
        auto *offset = spvBuilder.getConstantInt(
            astContext.UnsignedIntTy, llvm::APInt(32, wordCounter - 1));
        auto *newAddress =
            spvBuilder.createBinaryOp(spv::Op::OpIAdd, addressType, address,
                                      offset, expr->getCallee()->getExprLoc());
        loadPtr = spvBuilder.createAccessChain(
            astContext.UnsignedIntTy, objectInfo, {constUint0, newAddress},
            object->getLocStart());
        values.push_back(
            spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr,
                                  expr->getCallee()->getExprLoc()));
      }
      const QualType resultType =
          astContext.getExtVectorType(addressType, numWords);
      result = spvBuilder.createCompositeConstruct(resultType, values,
                                                   expr->getLocStart());
      result->setRValue();
    }
  }
  return result;
}
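
// Example for a byte-address buffer load (illustrative only; hypothetical
// HLSL, not from this file):
//
//   uint2 v = buf.Load2(addr);
//
// is expected to lower roughly to:
//
//   %w  = OpShiftRightLogical %uint %addr %uint_2  ; byte -> word offset
//   %p0 = OpAccessChain %_ptr_uint %buf %uint_0 %w
//   %v0 = OpLoad %uint %p0
//   %w1 = OpIAdd %uint %w %uint_1
//   %p1 = OpAccessChain %_ptr_uint %buf %uint_0 %w1
//   %v1 = OpLoad %uint %p1
//   %v  = OpCompositeConstruct %v2uint %v0 %v1
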
SpirvInstruction *
SpirvEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
  if (expr->getNumArgs() == 2) {
    emitError(
        "(RW)StructuredBuffer::Load(in location, out status) not supported",
        expr->getExprLoc());
    return nullptr;
  }

  const auto *buffer = expr->getImplicitObjectArgument();
  auto *info = loadIfAliasVarRef(buffer);

  const QualType structType =
      hlsl::GetHLSLResourceResultType(buffer->getType());

  auto *zero = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0));
  auto *index = doExpr(expr->getArg(0));

  return turnIntoElementPtr(buffer->getType(), info, structType, {zero, index},
                            buffer->getExprLoc());
}

SpirvInstruction *
SpirvEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
                                       bool isInc, bool loadObject) {
  auto *zero =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  auto *sOne =
      spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 1, true));
  const auto srcLoc = expr->getCallee()->getExprLoc();

  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
  if (loadObject) {
    // We don't need the object's <result-id> here since the counter is a
    // separate variable. But we still need the side effects of evaluating the
    // object, e.g., if the source code is foo(...).IncrementCounter(), we
    // still want to emit the code for foo(...).
    (void)doExpr(object);
  }

  const auto *counterPair = getFinalACSBufferCounter(object);
  if (!counterPair) {
    emitFatalError("cannot find the associated counter variable",
                   object->getExprLoc());
    return nullptr;
  }

  auto *counterPtr = spvBuilder.createAccessChain(
      astContext.IntTy, counterPair->get(spvBuilder, spvContext), {zero},
      srcLoc);

  SpirvInstruction *index = nullptr;
  if (isInc) {
    index = spvBuilder.createAtomicOp(
        spv::Op::OpAtomicIAdd, astContext.IntTy, counterPtr, spv::Scope::Device,
        spv::MemorySemanticsMask::MaskNone, sOne, srcLoc);
  } else {
    // Note that OpAtomicISub returns the value before the subtraction, so we
    // need to subtract one more time from OpAtomicISub's return value.
    auto *prev = spvBuilder.createAtomicOp(
        spv::Op::OpAtomicISub, astContext.IntTy, counterPtr, spv::Scope::Device,
        spv::MemorySemanticsMask::MaskNone, sOne, srcLoc);
    index = spvBuilder.createBinaryOp(spv::Op::OpISub, astContext.IntTy, prev,
                                      sOne, srcLoc);
  }
  return index;
}
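
// Example for IncrementCounter/DecrementCounter (illustrative only; the
// result IDs are hypothetical, scope/semantics operands abbreviated):
//
//   %ptr  = OpAccessChain %_ptr_int %counter_var %uint_0
//   %prev = OpAtomicIAdd %int %ptr <Device> <None> %int_1
//
// %prev is the pre-increment value, matching HLSL's IncrementCounter()
// semantics. DecrementCounter() uses OpAtomicISub and then subtracts one more,
// since the atomic also returns the pre-decrement value.
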
bool SpirvEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
                                         const Expr *srcExpr) {
  // We are handling associated counters here. Casts should not alter which
  // associated counter to manipulate.
  srcExpr = srcExpr->IgnoreParenCasts();

  // For parameters of forward-declared functions, we must make sure the
  // associated counter variable is created. But for forward-declared
  // functions, the translation of the real definition may not have started
  // yet.
  if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
    declIdMapper.createFnParamCounterVar(param);
  // For implicit objects of methods. Similar to the above.
  else if (const auto *thisObject = dyn_cast<ImplicitParamDecl>(dstDecl))
    declIdMapper.createFnParamCounterVar(thisObject);

  // Handle AssocCounter#1 (see the CounterVarFields comment)
  if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) {
    const auto *srcPair = getFinalACSBufferCounter(srcExpr);
    if (!srcPair) {
      emitFatalError("cannot find the associated counter variable",
                     srcExpr->getExprLoc());
      return false;
    }
    dstPair->assign(*srcPair, spvBuilder, spvContext);
    return true;
  }

  // Handle AssocCounter#3
  llvm::SmallVector<uint32_t, 4> srcIndices;
  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);

  if (dstFields && srcFields) {
    // The destination is a struct whose fields are directly alias resources.
    // But that's not necessarily true for the source, which can be deeply
    // nested structs. That means they will have different index "prefixes"
    // for all their fields, while the "prefix" for the destination is
    // effectively an empty list (since it is not nested in other structs).
    // We need to strip the index prefix from the source.
    return dstFields->assign(*srcFields, /*dstIndices=*/{}, srcIndices,
                             spvBuilder, spvContext);
  }

  // AssocCounter#2 and AssocCounter#4 for the lhs cannot happen since the lhs
  // is a stand-alone decl in this method.
  return false;
}

bool SpirvEmitter::tryToAssignCounterVar(const Expr *dstExpr,
                                         const Expr *srcExpr) {
  dstExpr = dstExpr->IgnoreParenCasts();
  srcExpr = srcExpr->IgnoreParenCasts();

  const auto *dstPair = getFinalACSBufferCounter(dstExpr);
  const auto *srcPair = getFinalACSBufferCounter(srcExpr);

  if ((dstPair == nullptr) != (srcPair == nullptr)) {
    emitFatalError("cannot handle associated counter variable assignment",
                   srcExpr->getExprLoc());
    return false;
  }

  // Handle AssocCounter#1 & AssocCounter#2
  if (dstPair && srcPair) {
    dstPair->assign(*srcPair, spvBuilder, spvContext);
    return true;
  }

  // Handle AssocCounter#3 & AssocCounter#4
  llvm::SmallVector<uint32_t, 4> dstIndices;
  llvm::SmallVector<uint32_t, 4> srcIndices;
  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);
  const auto *dstFields = getIntermediateACSBufferCounter(dstExpr, &dstIndices);

  if (dstFields && srcFields) {
    return dstFields->assign(*srcFields, dstIndices, srcIndices, spvBuilder,
                             spvContext);
  }

  return false;
}

const CounterIdAliasPair *
SpirvEmitter::getFinalACSBufferCounter(const Expr *expr) {
  // AssocCounter#1: referencing some stand-alone variable
  if (const auto *decl = getReferencedDef(expr))
    return declIdMapper.getCounterIdAliasPair(decl);

  // AssocCounter#2: referencing some non-struct field
  llvm::SmallVector<uint32_t, 4> rawIndices;
  const auto *base = collectArrayStructIndices(
      expr, /*rawIndex=*/true, &rawIndices, /*indices*/ nullptr);
  const auto *decl =
      (base && isa<CXXThisExpr>(base))
          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
          : getReferencedDef(base);
  return declIdMapper.getCounterIdAliasPair(decl, &rawIndices);
}

const CounterVarFields *SpirvEmitter::getIntermediateACSBufferCounter(
    const Expr *expr, llvm::SmallVector<uint32_t, 4> *rawIndices) {
  const auto *base = collectArrayStructIndices(expr, /*rawIndex=*/true,
                                               rawIndices, /*indices*/ nullptr);
  const auto *decl =
      (base && isa<CXXThisExpr>(base))
          // Use the decl we created to represent the implicit object
          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
          // Find the referenced decl from the original source code
          : getReferencedDef(base);
  return declIdMapper.getCounterVarFields(decl);
}

const ImplicitParamDecl *
SpirvEmitter::getOrCreateDeclForMethodObject(const CXXMethodDecl *method) {
  const auto found = thisDecls.find(method);
  if (found != thisDecls.end())
    return found->second;

  const std::string name = method->getName().str() + ".this";
  // Create a new identifier to convey the name
  auto &identifier = astContext.Idents.get(name);

  return thisDecls[method] = ImplicitParamDecl::Create(
             astContext, /*DC=*/nullptr, SourceLocation(), &identifier,
             method->getThisType(astContext)->getPointeeType());
}

SpirvInstruction *
SpirvEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
  const bool isAppend = expr->getNumArgs() == 1;

  auto *zero =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));

  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);

  auto *bufferInfo = loadIfAliasVarRef(object);

  auto *index = incDecRWACSBufferCounter(
      expr, isAppend,
      // We have already translated the object above. Avoid duplication.
      /*loadObject=*/false);

  auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());

  // If this is a variable used to communicate with the host, e.g., an
  // ACSBuffer, and its type is bool or a vector of bool, its effective type
  // for SPIR-V must be uint, not bool. We must convert it to uint here.
  bool needCast = false;
  if (bufferInfo->getLayoutRule() != SpirvLayoutRule::Void &&
      isBoolOrVecOfBoolType(bufferElemTy)) {
    uint32_t vecSize = 1;
    const bool isVec = isVectorType(bufferElemTy, nullptr, &vecSize);
    bufferElemTy =
        isVec ? astContext.getExtVectorType(astContext.UnsignedIntTy, vecSize)
              : astContext.UnsignedIntTy;
    needCast = true;
  }

  bufferInfo = turnIntoElementPtr(object->getType(), bufferInfo, bufferElemTy,
                                  {zero, index}, object->getExprLoc());

  if (isAppend) {
    // Write out the value
    auto *arg0 = doExpr(expr->getArg(0));
    if (!arg0)
      return nullptr;

    if (!arg0->isRValue()) {
      arg0 = spvBuilder.createLoad(bufferElemTy, arg0,
                                   expr->getArg(0)->getExprLoc());
    }
    if (needCast &&
        !isSameType(astContext, bufferElemTy, arg0->getAstResultType())) {
      arg0 = castToType(arg0, arg0->getAstResultType(), bufferElemTy,
                        expr->getArg(0)->getExprLoc());
    }
    storeValue(bufferInfo, arg0, bufferElemTy, expr->getCallee()->getExprLoc());
    return nullptr;
  } else {
    // Note that we are returning a pointer (lvalue) here in order to further
    // access the fields in this element, e.g., buffer.Consume().a.b. So we
    // cannot forcefully set all normal function calls as returning rvalues.
    return bufferInfo;
  }
}
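
// Example for Append (illustrative only; hypothetical HLSL, not from this
// file):
//
//   AppendStructuredBuffer<float4> buf;
//   buf.Append(v);
//
// is expected to lower roughly to (scope/semantics operands abbreviated):
//
//   %idx = OpAtomicIAdd %int %counter_ptr <Device> <None> %int_1
//   %ptr = OpAccessChain %_ptr_v4float %buf %uint_0 %idx
//          OpStore %ptr %v
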
SpirvInstruction *
SpirvEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) {
  // TODO: handle multiple stream-output objects
  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
  const auto *stream = cast<DeclRefExpr>(object)->getDecl();
  auto *value = doExpr(expr->getArg(0));

  declIdMapper.writeBackOutputStream(stream, stream->getType(), value);
  spvBuilder.createEmitVertex(expr->getExprLoc());

  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processStreamOutputRestart(const CXXMemberCallExpr *expr) {
  // TODO: handle multiple stream-output objects
  spvBuilder.createEndPrimitive(expr->getExprLoc());
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount,
                                    SpirvInstruction *sampleIndex,
                                    SourceLocation loc) {
  struct Float2 {
    float x;
    float y;
  };

  static const Float2 pos2[] = {
      {4.0 / 16.0, 4.0 / 16.0},
      {-4.0 / 16.0, -4.0 / 16.0},
  };

  static const Float2 pos4[] = {
      {-2.0 / 16.0, -6.0 / 16.0},
      {6.0 / 16.0, -2.0 / 16.0},
      {-6.0 / 16.0, 2.0 / 16.0},
      {2.0 / 16.0, 6.0 / 16.0},
  };

  static const Float2 pos8[] = {
      {1.0 / 16.0, -3.0 / 16.0}, {-1.0 / 16.0, 3.0 / 16.0},
      {5.0 / 16.0, 1.0 / 16.0},  {-3.0 / 16.0, -5.0 / 16.0},
      {-5.0 / 16.0, 5.0 / 16.0}, {-7.0 / 16.0, -1.0 / 16.0},
      {3.0 / 16.0, 7.0 / 16.0},  {7.0 / 16.0, -7.0 / 16.0},
  };

  static const Float2 pos16[] = {
      {1.0 / 16.0, 1.0 / 16.0},   {-1.0 / 16.0, -3.0 / 16.0},
      {-3.0 / 16.0, 2.0 / 16.0},  {4.0 / 16.0, -1.0 / 16.0},
      {-5.0 / 16.0, -2.0 / 16.0}, {2.0 / 16.0, 5.0 / 16.0},
      {5.0 / 16.0, 3.0 / 16.0},   {3.0 / 16.0, -5.0 / 16.0},
      {-2.0 / 16.0, 6.0 / 16.0},  {0.0 / 16.0, -7.0 / 16.0},
      {-4.0 / 16.0, -6.0 / 16.0}, {-6.0 / 16.0, 4.0 / 16.0},
      {-8.0 / 16.0, 0.0 / 16.0},  {7.0 / 16.0, -4.0 / 16.0},
      {6.0 / 16.0, 7.0 / 16.0},   {-7.0 / 16.0, -8.0 / 16.0},
  };

  // We are emitting the SPIR-V for the following HLSL source code:
  //
  //   float2 position;
  //
  //   if (count == 2) {
  //     position = pos2[index];
  //   }
  //   else if (count == 4) {
  //     position = pos4[index];
  //   }
  //   else if (count == 8) {
  //     position = pos8[index];
  //   }
  //   else if (count == 16) {
  //     position = pos16[index];
  //   }
  //   else {
  //     position = float2(0.0f, 0.0f);
  //   }

  const auto v2f32Type = astContext.getExtVectorType(astContext.FloatTy, 2);

  // Creates a SPIR-V function scope variable of type float2[len].
  const auto createArray = [this, v2f32Type, loc](const Float2 *ptr,
                                                  uint32_t len) {
    llvm::SmallVector<SpirvConstant *, 16> components;
    for (uint32_t i = 0; i < len; ++i) {
      auto *x = spvBuilder.getConstantFloat(astContext.FloatTy,
                                            llvm::APFloat(ptr[i].x));
      auto *y = spvBuilder.getConstantFloat(astContext.FloatTy,
                                            llvm::APFloat(ptr[i].y));
      components.push_back(spvBuilder.getConstantComposite(v2f32Type, {x, y}));
    }

    const auto arrType = astContext.getConstantArrayType(
        v2f32Type, llvm::APInt(32, len), clang::ArrayType::Normal, 0);
    auto *val = spvBuilder.getConstantComposite(arrType, components);

    const std::string varName =
        "var.GetSamplePosition.data." + std::to_string(len);
    auto *var = spvBuilder.addFnVar(arrType, loc, varName);
    spvBuilder.createStore(var, val, loc);
    return var;
  };

  auto *pos2Arr = createArray(pos2, 2);
  auto *pos4Arr = createArray(pos4, 4);
  auto *pos8Arr = createArray(pos8, 8);
  auto *pos16Arr = createArray(pos16, 16);

  auto *resultVar =
      spvBuilder.addFnVar(v2f32Type, loc, "var.GetSamplePosition.result");

  auto *then2BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then2");
  auto *then4BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then4");
  auto *then8BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then8");
  auto *then16BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then16");
  auto *else2BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else2");
  auto *else4BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else4");
  auto *else8BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else8");
  auto *else16BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else16");
  auto *merge2BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge2");
  auto *merge4BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge4");
  auto *merge8BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge8");
  auto *merge16BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge16");

  // if (count == 2) {
  const auto check2 = spvBuilder.createBinaryOp(
      spv::Op::OpIEqual, astContext.BoolTy, sampleCount,
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2)),
      loc);
  spvBuilder.createConditionalBranch(check2, then2BB, else2BB, loc, merge2BB);
  spvBuilder.addSuccessor(then2BB);
  spvBuilder.addSuccessor(else2BB);
  spvBuilder.setMergeTarget(merge2BB);

  //   position = pos2[index];
  // }
  spvBuilder.setInsertPoint(then2BB);
  auto *ac =
      spvBuilder.createAccessChain(v2f32Type, pos2Arr, {sampleIndex}, loc);
  spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                         loc);
  spvBuilder.createBranch(merge2BB, loc);
  spvBuilder.addSuccessor(merge2BB);

  // else if (count == 4) {
  spvBuilder.setInsertPoint(else2BB);
  const auto check4 = spvBuilder.createBinaryOp(
      spv::Op::OpIEqual, astContext.BoolTy, sampleCount,
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 4)),
      loc);
  spvBuilder.createConditionalBranch(check4, then4BB, else4BB, loc, merge4BB);
  spvBuilder.addSuccessor(then4BB);
  spvBuilder.addSuccessor(else4BB);
  spvBuilder.setMergeTarget(merge4BB);

  //   position = pos4[index];
  // }
  spvBuilder.setInsertPoint(then4BB);
  ac = spvBuilder.createAccessChain(v2f32Type, pos4Arr, {sampleIndex}, loc);
  spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                         loc);
  spvBuilder.createBranch(merge4BB, loc);
  spvBuilder.addSuccessor(merge4BB);

  // else if (count == 8) {
  spvBuilder.setInsertPoint(else4BB);
  const auto check8 = spvBuilder.createBinaryOp(
      spv::Op::OpIEqual, astContext.BoolTy, sampleCount,
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 8)),
      loc);
  spvBuilder.createConditionalBranch(check8, then8BB, else8BB, loc, merge8BB);
  spvBuilder.addSuccessor(then8BB);
  spvBuilder.addSuccessor(else8BB);
  spvBuilder.setMergeTarget(merge8BB);

  //   position = pos8[index];
  // }
  spvBuilder.setInsertPoint(then8BB);
  ac = spvBuilder.createAccessChain(v2f32Type, pos8Arr, {sampleIndex}, loc);
  spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                         loc);
  spvBuilder.createBranch(merge8BB, loc);
  spvBuilder.addSuccessor(merge8BB);

  // else if (count == 16) {
  spvBuilder.setInsertPoint(else8BB);
  const auto check16 = spvBuilder.createBinaryOp(
      spv::Op::OpIEqual, astContext.BoolTy, sampleCount,
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16)),
      loc);
  spvBuilder.createConditionalBranch(check16, then16BB, else16BB, loc,
                                     merge16BB);
  spvBuilder.addSuccessor(then16BB);
  spvBuilder.addSuccessor(else16BB);
  spvBuilder.setMergeTarget(merge16BB);

  //   position = pos16[index];
  // }
  spvBuilder.setInsertPoint(then16BB);
  ac = spvBuilder.createAccessChain(v2f32Type, pos16Arr, {sampleIndex}, loc);
  spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc),
                         loc);
  spvBuilder.createBranch(merge16BB, loc);
  spvBuilder.addSuccessor(merge16BB);

  // else {
  //   position = float2(0.0f, 0.0f);
  // }
  spvBuilder.setInsertPoint(else16BB);
  auto *zero =
      spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(0.0f));
  auto *v2f32Zero = spvBuilder.getConstantComposite(v2f32Type, {zero, zero});
  spvBuilder.createStore(resultVar, v2f32Zero, loc);
  spvBuilder.createBranch(merge16BB, loc);
  spvBuilder.addSuccessor(merge16BB);

  // Chain the nested merge blocks back together, from innermost to outermost.
  spvBuilder.setInsertPoint(merge16BB);
  spvBuilder.createBranch(merge8BB, loc);
  spvBuilder.addSuccessor(merge8BB);

  spvBuilder.setInsertPoint(merge8BB);
  spvBuilder.createBranch(merge4BB, loc);
  spvBuilder.addSuccessor(merge4BB);

  spvBuilder.setInsertPoint(merge4BB);
  spvBuilder.createBranch(merge2BB, loc);
  spvBuilder.addSuccessor(merge2BB);

  spvBuilder.setInsertPoint(merge2BB);
  return spvBuilder.createLoad(v2f32Type, resultVar, loc);
}

SpirvInstruction *
SpirvEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
  const FunctionDecl *callee = expr->getDirectCallee();

  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  if (hlsl::GetIntrinsicOp(callee, opcode, group)) {
    return processIntrinsicMemberCall(expr,
                                      static_cast<hlsl::IntrinsicOp>(opcode));
  }
  return processCall(expr);
}

void SpirvEmitter::handleOffsetInMethodCall(const CXXMemberCallExpr *expr,
                                            uint32_t index,
                                            SpirvInstruction **constOffset,
                                            SpirvInstruction **varOffset) {
  assert(constOffset && varOffset);
  // Ensure the given arg index is not out-of-range.
  assert(index < expr->getNumArgs());

  *constOffset = *varOffset = nullptr; // Initialize both first
  if ((*constOffset = tryToEvaluateAsConst(expr->getArg(index))))
    return; // Constant offset
  else
    *varOffset = doExpr(expr->getArg(index));
}

SpirvInstruction *
SpirvEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
                                         hlsl::IntrinsicOp opcode) {
  using namespace hlsl;

  SpirvInstruction *retVal = nullptr;
  switch (opcode) {
  case IntrinsicOp::MOP_Sample:
    retVal = processTextureSampleGather(expr, /*isSample=*/true);
    break;
  case IntrinsicOp::MOP_Gather:
    retVal = processTextureSampleGather(expr, /*isSample=*/false);
    break;
  case IntrinsicOp::MOP_SampleBias:
    retVal = processTextureSampleBiasLevel(expr, /*isBias=*/true);
    break;
  case IntrinsicOp::MOP_SampleLevel:
    retVal = processTextureSampleBiasLevel(expr, /*isBias=*/false);
    break;
  case IntrinsicOp::MOP_SampleGrad:
    retVal = processTextureSampleGrad(expr);
    break;
  case IntrinsicOp::MOP_SampleCmp:
    retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/true);
    break;
  case IntrinsicOp::MOP_SampleCmpLevelZero:
    retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/false);
    break;
  case IntrinsicOp::MOP_GatherRed:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 0);
    break;
  case IntrinsicOp::MOP_GatherGreen:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 1);
    break;
  case IntrinsicOp::MOP_GatherBlue:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 2);
    break;
  case IntrinsicOp::MOP_GatherAlpha:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 3);
    break;
  case IntrinsicOp::MOP_GatherCmp:
    retVal = processTextureGatherCmp(expr);
    break;
  case IntrinsicOp::MOP_GatherCmpRed:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/true, 0);
    break;
  case IntrinsicOp::MOP_Load:
    return processBufferTextureLoad(expr);
  case IntrinsicOp::MOP_Load2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false);
  case IntrinsicOp::MOP_Load3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false);
  case IntrinsicOp::MOP_Load4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false);
  case IntrinsicOp::MOP_Store:
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true);
  case IntrinsicOp::MOP_Store2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true);
  case IntrinsicOp::MOP_Store3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true);
  case IntrinsicOp::MOP_Store4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true);
  case IntrinsicOp::MOP_GetDimensions:
    retVal = processGetDimensions(expr);
    break;
  case IntrinsicOp::MOP_CalculateLevelOfDetail:
    retVal = processTextureLevelOfDetail(expr, /*unclamped*/ false);
    break;
  case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
    retVal = processTextureLevelOfDetail(expr, /*unclamped*/ true);
    break;
  case IntrinsicOp::MOP_IncrementCounter:
    retVal =
        spvBuilder.createUnaryOp(spv::Op::OpBitcast, astContext.UnsignedIntTy,
                                 incDecRWACSBufferCounter(expr, /*isInc*/ true),
                                 expr->getCallee()->getExprLoc());
    break;
  case IntrinsicOp::MOP_DecrementCounter:
    retVal = spvBuilder.createUnaryOp(
        spv::Op::OpBitcast, astContext.UnsignedIntTy,
        incDecRWACSBufferCounter(expr, /*isInc*/ false),
        expr->getCallee()->getExprLoc());
    break;
  case IntrinsicOp::MOP_Append:
    if (hlsl::IsHLSLStreamOutputType(
            expr->getImplicitObjectArgument()->getType()))
      return processStreamOutputAppend(expr);
    else
      return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_Consume:
    return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_RestartStrip:
    retVal = processStreamOutputRestart(expr);
    break;
  case IntrinsicOp::MOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedCompareExchange:
  case IntrinsicOp::MOP_InterlockedCompareStore:
    retVal = processRWByteAddressBufferAtomicMethods(opcode, expr);
    break;
  case IntrinsicOp::MOP_GetSamplePosition:
    retVal = processGetSamplePosition(expr);
    break;
  case IntrinsicOp::MOP_SubpassLoad:
    retVal = processSubpassLoad(expr);
    break;
  case IntrinsicOp::MOP_GatherCmpGreen:
  case IntrinsicOp::MOP_GatherCmpBlue:
  case IntrinsicOp::MOP_GatherCmpAlpha:
    emitError("no equivalent for %0 intrinsic method in Vulkan",
              expr->getCallee()->getExprLoc())
        << expr->getMethodDecl()->getName();
    return nullptr;
  case IntrinsicOp::MOP_TraceRayInline:
    return processTraceRayInline(expr);
  case IntrinsicOp::MOP_Abort:
  case IntrinsicOp::MOP_CandidateGeometryIndex:
  case IntrinsicOp::MOP_CandidateInstanceContributionToHitGroupIndex:
  case IntrinsicOp::MOP_CandidateInstanceID:
  case IntrinsicOp::MOP_CandidateInstanceIndex:
  case IntrinsicOp::MOP_CandidateObjectRayDirection:
  case IntrinsicOp::MOP_CandidateObjectRayOrigin:
  case IntrinsicOp::MOP_CandidateObjectToWorld3x4:
  case IntrinsicOp::MOP_CandidateObjectToWorld4x3:
  case IntrinsicOp::MOP_CandidatePrimitiveIndex:
  case IntrinsicOp::MOP_CandidateProceduralPrimitiveNonOpaque:
  case IntrinsicOp::MOP_CandidateTriangleBarycentrics:
  case IntrinsicOp::MOP_CandidateTriangleFrontFace:
  case IntrinsicOp::MOP_CandidateTriangleRayT:
  case IntrinsicOp::MOP_CandidateType:
  case IntrinsicOp::MOP_CandidateWorldToObject3x4:
  case IntrinsicOp::MOP_CandidateWorldToObject4x3:
  case IntrinsicOp::MOP_CommitNonOpaqueTriangleHit:
  case IntrinsicOp::MOP_CommitProceduralPrimitiveHit:
  case IntrinsicOp::MOP_CommittedGeometryIndex:
  case IntrinsicOp::MOP_CommittedInstanceContributionToHitGroupIndex:
  case IntrinsicOp::MOP_CommittedInstanceID:
  case IntrinsicOp::MOP_CommittedInstanceIndex:
  case IntrinsicOp::MOP_CommittedObjectRayDirection:
  case IntrinsicOp::MOP_CommittedObjectRayOrigin:
  case IntrinsicOp::MOP_CommittedObjectToWorld3x4:
  case IntrinsicOp::MOP_CommittedObjectToWorld4x3:
  case IntrinsicOp::MOP_CommittedPrimitiveIndex:
  case IntrinsicOp::MOP_CommittedRayT:
  case IntrinsicOp::MOP_CommittedStatus:
  case IntrinsicOp::MOP_CommittedTriangleBarycentrics:
  case IntrinsicOp::MOP_CommittedTriangleFrontFace:
  case IntrinsicOp::MOP_CommittedWorldToObject3x4:
  case IntrinsicOp::MOP_CommittedWorldToObject4x3:
  case IntrinsicOp::MOP_Proceed:
  case IntrinsicOp::MOP_RayFlags:
  case IntrinsicOp::MOP_RayTMin:
  case IntrinsicOp::MOP_WorldRayDirection:
  case IntrinsicOp::MOP_WorldRayOrigin:
    return processRayQueryIntrinsics(expr, opcode);
  default:
    emitError("intrinsic '%0' method unimplemented",
              expr->getCallee()->getExprLoc())
        << expr->getDirectCallee()->getName();
    return nullptr;
  }

  if (retVal)
    retVal->setRValue();

  return retVal;
}

SpirvInstruction *SpirvEmitter::createImageSample(
    QualType retType, QualType imageType, SpirvInstruction *image,
    SpirvInstruction *sampler, SpirvInstruction *coordinate,
    SpirvInstruction *compareVal, SpirvInstruction *bias, SpirvInstruction *lod,
    std::pair<SpirvInstruction *, SpirvInstruction *> grad,
    SpirvInstruction *constOffset, SpirvInstruction *varOffset,
    SpirvInstruction *constOffsets, SpirvInstruction *sample,
    SpirvInstruction *minLod, SpirvInstruction *residencyCodeId,
    SourceLocation loc) {
  // SampleDref* instructions in SPIR-V always return a scalar.
  // They also have the correct type in HLSL.
  if (compareVal) {
    return spvBuilder.createImageSample(retType, imageType, image, sampler,
                                        coordinate, compareVal, bias, lod, grad,
                                        constOffset, varOffset, constOffsets,
                                        sample, minLod, residencyCodeId, loc);
  }
  // Non-Dref Sample instructions in SPIR-V must always return a vec4.
  auto texelType = retType;
  QualType elemType = {};
  uint32_t retVecSize = 0;
  if (isVectorType(retType, &elemType, &retVecSize) && retVecSize != 4) {
    texelType = astContext.getExtVectorType(elemType, 4);
  } else if (isScalarType(retType)) {
    retVecSize = 1;
    elemType = retType;
    texelType = astContext.getExtVectorType(retType, 4);
  }
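
  // E.g., sampling a Texture2D<float2> yields a float2 in HLSL, but the
  // OpImageSample* instruction must produce a %v4float; the extra components
  // are discarded by the extraction below.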
  // The Lod and Grad image operands require explicit-lod instructions.
  // Otherwise we use implicit-lod instructions.
  const bool isExplicit = lod || (grad.first && grad.second);
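
  // E.g., .Sample() maps to OpImageSampleImplicitLod, while .SampleLevel()
  // and .SampleGrad() map to OpImageSampleExplicitLod with the Lod or Grad
  // image operands attached.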
  // Implicit-lod instructions are only allowed in pixel shaders.
  if (!spvContext.isPS() && !isExplicit)
    emitError("sampling with implicit lod is only allowed in fragment shaders",
              loc);

  auto *retVal = spvBuilder.createImageSample(
      texelType, imageType, image, sampler, coordinate, compareVal, bias, lod,
      grad, constOffset, varOffset, constOffsets, sample, minLod,
      residencyCodeId, loc);

  // Extract the smaller vector from the vec4 result if necessary.
  if (texelType != retType) {
    retVal = extractVecFromVec4(retVal, retVecSize, elemType, loc);
  }

  return retVal;
}
SpirvInstruction *
SpirvEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
                                         const bool isSample) {
  // Signatures:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, Texture3D:
  // DXGI_FORMAT Object.Sample(sampler_state S,
  //                           float Location
  //                           [, int Offset]
  //                           [, float Clamp]
  //                           [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.Sample(sampler_state S,
  //                           float Location
  //                           [, float Clamp]
  //                           [, out uint Status]);
  //
  // For Texture2D/Texture2DArray:
  // <Template Type>4 Object.Gather(sampler_state S,
  //                                float2|3|4 Location,
  //                                int2 Offset
  //                                [, uint Status]);
  //
  // For TextureCube/TextureCubeArray:
  // <Template Type>4 Object.Gather(sampler_state S,
  //                                float2|3|4 Location
  //                                [, uint Status]);
  //
  // Other Texture types do not have a Gather method.

  const auto numArgs = expr->getNumArgs();
  const auto loc = expr->getExprLoc();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();

  SpirvInstruction *clamp = nullptr;
  if (numArgs > 2 && expr->getArg(2)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(2));
  else if (numArgs > 3 && expr->getArg(3)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(3));
  const bool hasClampArg = (clamp != nullptr);
  const auto status =
      hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;

  // Subtract 1 for status (if it exists), subtract 1 for clamp (if it exists),
  // and subtract 2 for sampler_state and location.
  const bool hasOffsetArg = numArgs - hasStatusArg - hasClampArg - 2 > 0;
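
  // E.g., for t.Sample(s, uv, int2(1, 1), 0.5f, status), numArgs is 5 and both
  // hasStatusArg and hasClampArg are true, so 5 - 1 - 1 - 2 > 0 indicates that
  // an offset argument is present.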
  const auto *imageExpr = expr->getImplicitObjectArgument();
  const QualType imageType = imageExpr->getType();
  auto *image = loadIfGLValue(imageExpr);
  auto *sampler = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));

  // .Sample()/.Gather() may have a third optional parameter for offset.
  SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 2, &constOffset, &varOffset);

  const auto retType = expr->getDirectCallee()->getReturnType();
  if (isSample) {
    return createImageSample(retType, imageType, image, sampler, coordinate,
                             /*compareVal*/ nullptr, /*bias*/ nullptr,
                             /*lod*/ nullptr, std::make_pair(nullptr, nullptr),
                             constOffset, varOffset,
                             /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr,
                             /*minLod*/ clamp, status,
                             expr->getCallee()->getLocStart());
  } else {
    return spvBuilder.createImageGather(
        retType, imageType, image, sampler, coordinate,
        // Per the .Gather() doc, we return four samples of the red component.
        spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)),
        /*compareVal*/ nullptr, constOffset, varOffset,
        /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr, status, loc);
  }
}
SpirvInstruction *
SpirvEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr,
                                            const bool isBias) {
  // Signatures:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, and Texture3D:
  // DXGI_FORMAT Object.SampleBias(sampler_state S,
  //                               float Location,
  //                               float Bias
  //                               [, int Offset]
  //                               [, float Clamp]
  //                               [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.SampleBias(sampler_state S,
  //                               float Location,
  //                               float Bias
  //                               [, float Clamp]
  //                               [, out uint Status]);
  //
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, and Texture3D:
  // DXGI_FORMAT Object.SampleLevel(sampler_state S,
  //                                float Location,
  //                                float LOD
  //                                [, int Offset]
  //                                [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.SampleLevel(sampler_state S,
  //                                float Location,
  //                                float LOD
  //                                [, out uint Status]);
  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;

  SpirvInstruction *clamp = nullptr;
  // The .SampleLevel() methods do not take the clamp argument.
  if (isBias) {
    if (numArgs > 3 && expr->getArg(3)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(3));
    else if (numArgs > 4 && expr->getArg(4)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(4));
  }
  const bool hasClampArg = clamp != nullptr;

  // Subtract 1 for clamp (if it exists), 1 for status (if it exists),
  // and 3 for sampler_state, location, and Bias/LOD.
  const bool hasOffsetArg = numArgs - hasClampArg - hasStatusArg - 3 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const QualType imageType = imageExpr->getType();
  auto *image = loadIfGLValue(imageExpr);
  auto *sampler = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));
  SpirvInstruction *lod = nullptr;
  SpirvInstruction *bias = nullptr;
  if (isBias) {
    bias = doExpr(expr->getArg(2));
  } else {
    lod = doExpr(expr->getArg(2));
  }

  // If present in .SampleBias()/.SampleLevel(), the offset is the fourth
  // argument.
  SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType = expr->getDirectCallee()->getReturnType();
  return createImageSample(
      retType, imageType, image, sampler, coordinate,
      /*compareVal*/ nullptr, bias, lod, std::make_pair(nullptr, nullptr),
      constOffset, varOffset,
      /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr,
      /*minLod*/ clamp, status, expr->getCallee()->getLocStart());
}
SpirvInstruction *
SpirvEmitter::processTextureSampleGrad(const CXXMemberCallExpr *expr) {
  // Signature:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, and Texture3D:
  // DXGI_FORMAT Object.SampleGrad(sampler_state S,
  //                               float Location,
  //                               float DDX,
  //                               float DDY
  //                               [, int Offset]
  //                               [, float Clamp]
  //                               [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.SampleGrad(sampler_state S,
  //                               float Location,
  //                               float DDX,
  //                               float DDY
  //                               [, float Clamp]
  //                               [, out uint Status]);
  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;

  SpirvInstruction *clamp = nullptr;
  if (numArgs > 4 && expr->getArg(4)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(4));
  else if (numArgs > 5 && expr->getArg(5)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(5));
  const bool hasClampArg = clamp != nullptr;

  // Subtract 1 for clamp (if it exists), 1 for status (if it exists),
  // and 4 for sampler_state, location, DDX, and DDY.
  const bool hasOffsetArg = numArgs - hasClampArg - hasStatusArg - 4 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const QualType imageType = imageExpr->getType();
  auto *image = loadIfGLValue(imageExpr);
  auto *sampler = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));
  auto *ddx = doExpr(expr->getArg(2));
  auto *ddy = doExpr(expr->getArg(3));

  // If present in .SampleGrad(), the offset is the fifth argument.
  SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 4, &constOffset, &varOffset);

  const auto retType = expr->getDirectCallee()->getReturnType();
  return createImageSample(
      retType, imageType, image, sampler, coordinate,
      /*compareVal*/ nullptr, /*bias*/ nullptr,
      /*lod*/ nullptr, std::make_pair(ddx, ddy), constOffset, varOffset,
      /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr,
      /*minLod*/ clamp, status, expr->getCallee()->getLocStart());
}
SpirvInstruction *
SpirvEmitter::processTextureSampleCmpCmpLevelZero(const CXXMemberCallExpr *expr,
                                                  const bool isCmp) {
  // .SampleCmp() Signature:
  //
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray:
  // float Object.SampleCmp(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, int Offset]
  //   [, float Clamp]
  //   [, out uint Status]
  // );
  //
  // For TextureCube and TextureCubeArray:
  // float Object.SampleCmp(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, float Clamp]
  //   [, out uint Status]
  // );
  //
  // .SampleCmpLevelZero() is identical to .SampleCmp() on mipmap level 0 only.
  // It never takes a clamp argument, which is good because lod and clamp may
  // not be used together.
  //
  // .SampleCmpLevelZero() Signature:
  //
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray:
  // float Object.SampleCmpLevelZero(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, int Offset]
  //   [, out uint Status]
  // );
  //
  // For TextureCube and TextureCubeArray:
  // float Object.SampleCmpLevelZero(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, out uint Status]
  // );
  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;

  SpirvInstruction *clamp = nullptr;
  // The .SampleCmpLevelZero() methods do not take the clamp argument.
  if (isCmp) {
    if (numArgs > 3 && expr->getArg(3)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(3));
    else if (numArgs > 4 && expr->getArg(4)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(4));
  }
  const bool hasClampArg = clamp != nullptr;

  // Subtract 1 for clamp (if it exists), 1 for status (if it exists),
  // and 3 for sampler_state, location, and compare_value.
  const bool hasOffsetArg = numArgs - hasClampArg - hasStatusArg - 3 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  auto *image = loadIfGLValue(imageExpr);
  auto *sampler = doExpr(expr->getArg(0));
  auto *coordinate = doExpr(expr->getArg(1));
  auto *compareVal = doExpr(expr->getArg(2));

  // If present in .SampleCmp(), the offset is the fourth argument.
  SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  // .SampleCmpLevelZero() always samples from mipmap level 0.
  auto *lod = isCmp ? nullptr
                    : spvBuilder.getConstantFloat(astContext.FloatTy,
                                                  llvm::APFloat(0.0f));

  const auto retType = expr->getDirectCallee()->getReturnType();
  const auto imageType = imageExpr->getType();
  return createImageSample(
      retType, imageType, image, sampler, coordinate, compareVal,
      /*bias*/ nullptr, lod, std::make_pair(nullptr, nullptr), constOffset,
      varOffset,
      /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr, /*minLod*/ clamp,
      status, expr->getCallee()->getLocStart());
}
SpirvInstruction *
SpirvEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
  // Signature:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, Texture3D:
  // ret Object.Load(int Location
  //                 [, int Offset]
  //                 [, uint Status]);
  //
  // For Texture2DMS and Texture2DMSArray, there is one additional argument:
  // ret Object.Load(int Location
  //                 [, int SampleIndex]
  //                 [, int Offset]
  //                 [, uint Status]);
  //
  // For (RW)Buffer, RWTexture1D, RWTexture1DArray, RWTexture2D,
  // RWTexture2DArray, RWTexture3D:
  // ret Object.Load(int Location
  //                 [, uint Status]);
  //
  // Note: (RW)ByteAddressBuffer and (RW)StructuredBuffer types also have Load
  // methods that take an additional Status argument. However, since these
  // types are not represented as OpTypeImage in SPIR-V, we don't have a way of
  // figuring out the Residency Code for them. Therefore having the Status
  // argument for these types is not supported.
  //
  // For (RW)ByteAddressBuffer:
  // ret Object.{Load,Load2,Load3,Load4}(int Location
  //                                     [, uint Status]);
  //
  // For (RW)StructuredBuffer:
  // ret Object.Load(int Location
  //                 [, uint Status]);
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectType = object->getType();

  if (isRWByteAddressBuffer(objectType) || isByteAddressBuffer(objectType))
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);

  if (isStructuredBuffer(objectType))
    return processStructuredBufferLoad(expr);

  const auto numArgs = expr->getNumArgs();
  const auto *locationArg = expr->getArg(0);
  const bool textureMS = isTextureMS(objectType);
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr;
  auto loc = expr->getExprLoc();

  if (isBuffer(objectType) || isRWBuffer(objectType) || isRWTexture(objectType))
    return processBufferTextureLoad(object, doExpr(locationArg),
                                    /*constOffset*/ nullptr,
                                    /*varOffset*/ nullptr, /*lod*/ nullptr,
                                    /*residencyCode*/ status, loc);

  // Subtract 1 for status (if it exists), 1 for sampleIndex (if it exists),
  // and 1 for location.
  const bool hasOffsetArg = numArgs - hasStatusArg - textureMS - 1 > 0;
  if (isTexture(objectType)) {
    // .Load() has a second optional parameter for offset.
    SpirvInstruction *location = doExpr(locationArg);
    SpirvInstruction *constOffset = nullptr, *varOffset = nullptr;
    SpirvInstruction *coordinate = location, *lod = nullptr;

    if (textureMS) {
      // SampleIndex is only available when the Object is of Texture2DMS or
      // Texture2DMSArray types. In those cases, Offset will be the third
      // parameter (index 2).
      lod = doExpr(expr->getArg(1));
      if (hasOffsetArg)
        handleOffsetInMethodCall(expr, 2, &constOffset, &varOffset);
    } else {
      // For Texture Load() functions, the location parameter is a vector
      // that consists of both the coordinate and the mipmap level (via the
      // last vector element). We need to split it here since the
      // OpImageFetch SPIR-V instruction encodes them as separate arguments.
      splitVecLastElement(locationArg->getType(), location, &coordinate, &lod,
                          locationArg->getExprLoc());
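
      // E.g., for a Texture2D, t.Load(int3(x, y, mip)) splits into the
      // coordinate (x, y) and the Lod image operand mip for OpImageFetch.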
      // For textures other than Texture2DMS(Array), offset should be the
      // second parameter (index 1).
      if (hasOffsetArg)
        handleOffsetInMethodCall(expr, 1, &constOffset, &varOffset);
    }

    return processBufferTextureLoad(object, coordinate, constOffset, varOffset,
                                    lod, status, loc);
  }

  emitError("Load() of the given object type unimplemented",
            object->getExprLoc());
  return nullptr;
}
SpirvInstruction *
SpirvEmitter::processGetDimensions(const CXXMemberCallExpr *expr) {
  const auto objectType = expr->getImplicitObjectArgument()->getType();
  if (isTexture(objectType) || isRWTexture(objectType) ||
      isBuffer(objectType) || isRWBuffer(objectType)) {
    return processBufferTextureGetDimensions(expr);
  } else if (isByteAddressBuffer(objectType) ||
             isRWByteAddressBuffer(objectType) ||
             isStructuredBuffer(objectType) ||
             isAppendStructuredBuffer(objectType) ||
             isConsumeStructuredBuffer(objectType)) {
    return processByteAddressBufferStructuredBufferGetDimensions(expr);
  } else {
    emitError("GetDimensions() of the given object type unimplemented",
              expr->getExprLoc());
    return nullptr;
  }
}
SpirvInstruction *
SpirvEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
  { // Handle Buffer/RWBuffer/Texture/RWTexture indexing
    const Expr *baseExpr = nullptr;
    const Expr *indexExpr = nullptr;
    const Expr *lodExpr = nullptr;

    // For Textures, regular indexing (operator[]) uses slice 0.
    if (isBufferTextureIndexing(expr, &baseExpr, &indexExpr)) {
      auto *lod = isTexture(baseExpr->getType())
                      ? spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                                  llvm::APInt(32, 0))
                      : nullptr;
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ nullptr,
                                      /*varOffset*/ nullptr, lod,
                                      /*residencyCode*/ nullptr,
                                      expr->getExprLoc());
    }
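
    // E.g., tex[coord] above fetches from mip level 0, while
    // tex.mips[lod][coord] below fetches from the given mip level.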
    // .mips[][] or .sample[][] must use the correct slice.
    if (isTextureMipsSampleIndexing(expr, &baseExpr, &indexExpr, &lodExpr)) {
      auto *lod = doExpr(lodExpr);
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ nullptr,
                                      /*varOffset*/ nullptr, lod,
                                      /*residencyCode*/ nullptr,
                                      expr->getExprLoc());
    }
  }

  llvm::SmallVector<SpirvInstruction *, 4> indices;
  const Expr *baseExpr = collectArrayStructIndices(
      expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
  auto base = loadIfAliasVarRef(baseExpr);

  if (indices.empty())
    return base; // For indexing into size-1 vectors and 1xN matrices

  // If we are indexing into an rvalue, we first need to create a local
  // variable to hold the rvalue so that OpAccessChain can be used.
  //
  // TODO: We can optimize the codegen by emitting OpCompositeExtract if
  // all indices are constant integers.
  if (base->isRValue()) {
    base = createTemporaryVar(baseExpr->getType(), "vector", base,
                              baseExpr->getExprLoc());
  }

  return turnIntoElementPtr(baseExpr->getType(), base, expr->getType(), indices,
                            baseExpr->getExprLoc());
}
SpirvInstruction *
SpirvEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
  const Expr *baseExpr = expr->getBase();
  auto *baseInfo = doExpr(baseExpr);
  const auto layoutRule = baseInfo->getLayoutRule();
  const auto elemType = hlsl::GetHLSLMatElementType(baseExpr->getType());
  const auto accessor = expr->getEncodedElementAccess();
  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);

  // Construct a temporary vector out of all elements accessed:
  // 1. Create access chain for each element using OpAccessChain
  // 2. Load each element using OpLoad
  // 3. Create the vector using OpCompositeConstruct
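  //
  // E.g., for an lvalue float4x4 m, m._m00_m11 becomes two OpAccessChain +
  // OpLoad pairs (for elements [0][0] and [1][1]) followed by one
  // OpCompositeConstruct producing a float2.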
  llvm::SmallVector<SpirvInstruction *, 4> elements;
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0;
    SpirvInstruction *elem = nullptr;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, we are in fact indexing into a
    // vector. Only one index is needed for such cases.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    if (!baseInfo->isRValue()) {
      llvm::SmallVector<SpirvInstruction *, 2> indexInstructions(indices.size(),
                                                                 nullptr);
      // Note: the inner loop variable is named j to avoid shadowing the
      // element loop variable i above.
      for (uint32_t j = 0; j < indices.size(); ++j)
        indexInstructions[j] = spvBuilder.getConstantInt(
            astContext.IntTy, llvm::APInt(32, indices[j], true));

      if (!indices.empty()) {
        assert(!baseInfo->isRValue());
        // Load the element via access chain
        elem = spvBuilder.createAccessChain(
            elemType, baseInfo, indexInstructions, baseExpr->getLocStart());
      } else {
        // The matrix is of size 1x1. No need to use an access chain; base
        // should be the source pointer.
        elem = baseInfo;
      }
      elem = spvBuilder.createLoad(elemType, elem, baseExpr->getLocStart());
    } else { // e.g., (mat1 + mat2)._m11
      elem = spvBuilder.createCompositeExtract(elemType, baseInfo, indices,
                                               baseExpr->getLocStart());
    }
    elements.push_back(elem);
  }

  const auto size = elements.size();
  auto *value = elements.front();
  if (size > 1) {
    value = spvBuilder.createCompositeConstruct(
        astContext.getExtVectorType(elemType, size), elements,
        expr->getLocStart());
  }

  // Note: Special case: Booleans have no physical layout, and therefore when
  // layout is required booleans are represented as unsigned integers.
  // Therefore, after loading the uint we should convert it to boolean.
  if (elemType->isBooleanType() && layoutRule != SpirvLayoutRule::Void) {
    const auto fromType =
        size == 1 ? astContext.UnsignedIntTy
                  : astContext.getExtVectorType(astContext.UnsignedIntTy, size);
    const auto toType =
        size == 1 ? astContext.BoolTy
                  : astContext.getExtVectorType(astContext.BoolTy, size);
    value = castToBool(value, fromType, toType, expr->getLocStart());
  }

  value->setRValue();
  return value;
}
SpirvInstruction *
SpirvEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr) {
  const Expr *baseExpr = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(expr, &baseExpr, &accessor);

  const QualType baseType = baseExpr->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);
  const auto accessorSize = static_cast<size_t>(accessor.Count);

  // Depending on the number of elements selected, we emit different
  // instructions.
  // For vectors of size greater than 1, if we are only selecting one element,
  // a typical access chain or composite extraction is fine. But if we are
  // selecting more than one element, we must resort to vector-specific
  // operations.
  // For size-1 vectors, if we are selecting the single element multiple
  // times, we need composite construct instructions.

  if (accessorSize == 1) {
    auto *baseInfo = doExpr(baseExpr);

    if (!baseInfo || baseSize == 1) {
      // Selecting one element from a size-1 vector. The underlying vector is
      // already treated as a scalar.
      return baseInfo;
    }

    // If the base is an lvalue, we should emit an access chain instruction
    // so that we can load/store the specified element. For an rvalue base,
    // we should use composite extraction. We should check the immediate base
    // instead of the original base here since we can have something like
    // v.xyyz to turn an lvalue v into an rvalue.
    const auto type = expr->getType();
    if (!baseInfo->isRValue()) { // E.g., v.x;
      auto *index = spvBuilder.getConstantInt(
          astContext.IntTy, llvm::APInt(32, accessor.Swz0, true));
      // We need an lvalue here. Do not try to load.
      return spvBuilder.createAccessChain(type, baseInfo, {index},
                                          baseExpr->getLocStart());
    } else { // E.g., (v + w).x;
      // The original base vector may not be an rvalue. We need to load it if
      // it is an lvalue, since ImplicitCastExpr (LValueToRValue) will be
      // missing for that case.
      SpirvInstruction *result = spvBuilder.createCompositeExtract(
          type, baseInfo, {accessor.Swz0}, baseExpr->getLocStart());
      // Special case: Booleans in SPIR-V do not have a physical layout. Uint
      // is used to represent them when layout is required.
      if (expr->getType()->isBooleanType() &&
          baseInfo->getLayoutRule() != SpirvLayoutRule::Void)
        result = castToBool(result, astContext.UnsignedIntTy, astContext.BoolTy,
                            expr->getLocStart());
      return result;
    }
  }

  if (baseSize == 1) {
    // Selecting more than one element from a size-1 vector, for example,
    // <scalar>.xx. Construct the vector.
    auto *info = loadIfGLValue(baseExpr);
    const auto type = expr->getType();
    llvm::SmallVector<SpirvInstruction *, 4> components(accessorSize, info);
    info = spvBuilder.createCompositeConstruct(type, components,
                                               expr->getLocStart());
    info->setRValue();
    return info;
  }

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(accessorSize);
  // Whether we are selecting elements in the original order
  bool originalOrder = baseSize == accessorSize;
  for (uint32_t i = 0; i < accessorSize; ++i) {
    accessor.GetPosition(i, &selectors[i]);
    // We can select more elements than the vector provides. This handles
    // that case too.
    originalOrder &= selectors[i] == i;
  }
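
  // E.g., v.xyzw on a float4 selects all elements in the original order, so we
  // can simply return the base; something like v.xzy requires an
  // OpVectorShuffle below.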
  if (originalOrder)
    return doExpr(baseExpr);

  auto *info = loadIfGLValue(baseExpr);
  // Use the base for both vector operands. We are only selecting values from
  // the first one.
  return spvBuilder.createVectorShuffle(expr->getType(), info, info, selectors,
                                        expr->getLocStart());
}
SpirvInstruction *SpirvEmitter::doInitListExpr(const InitListExpr *expr) {
  if (auto *id = tryToEvaluateAsConst(expr)) {
    id->setRValue();
    return id;
  }

  auto *result = InitListHandler(astContext, *this).processInit(expr);
  result->setRValue();
  return result;
}
SpirvInstruction *SpirvEmitter::doMemberExpr(const MemberExpr *expr) {
  llvm::SmallVector<SpirvInstruction *, 4> indices;
  const Expr *base = collectArrayStructIndices(
      expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices);
  auto *instr = loadIfAliasVarRef(base);

  if (instr && !indices.empty()) {
    instr = turnIntoElementPtr(base->getType(), instr, expr->getType(), indices,
                               base->getExprLoc());
  }

  return instr;
}
SpirvVariable *SpirvEmitter::createTemporaryVar(QualType type,
                                                llvm::StringRef name,
                                                SpirvInstruction *init,
                                                SourceLocation loc) {
  // We are creating a temporary variable in the Function storage class here,
  // which means it has the Void layout rule.
  const std::string varName = "temp.var." + name.str();
  auto *var = spvBuilder.addFnVar(type, loc, varName);
  storeValue(var, init, type, loc);
  return var;
}
SpirvInstruction *SpirvEmitter::doUnaryOperator(const UnaryOperator *expr) {
  const auto opcode = expr->getOpcode();
  const auto *subExpr = expr->getSubExpr();
  const auto subType = subExpr->getType();
  auto *subValue = doExpr(subExpr);

  switch (opcode) {
  case UO_PreInc:
  case UO_PreDec:
  case UO_PostInc:
  case UO_PostDec: {
    const bool isPre = opcode == UO_PreInc || opcode == UO_PreDec;
    const bool isInc = opcode == UO_PreInc || opcode == UO_PostInc;

    const spv::Op spvOp = translateOp(isInc ? BO_Add : BO_Sub, subType);
    SpirvInstruction *originValue =
        subValue->isRValue()
            ? subValue
            : spvBuilder.createLoad(subType, subValue, subExpr->getLocStart());
    auto *one = hlsl::IsHLSLMatType(subType) ? getMatElemValueOne(subType)
                                             : getValueOne(subType);
    SpirvInstruction *incValue = nullptr;
    if (isMxNMatrix(subType)) {
      // For matrices, we can only increment/decrement each vector of it.
      const auto actOnEachVec = [this, spvOp, one,
                                 expr](uint32_t /*index*/, QualType vecType,
                                       SpirvInstruction *lhsVec) {
        auto *val = spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, one,
                                              expr->getOperatorLoc());
        val->setRValue();
        return val;
      };
      incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec,
                                           expr->getLocStart());
    } else {
      incValue = spvBuilder.createBinaryOp(spvOp, subType, originValue, one,
                                           expr->getOperatorLoc());
    }

    // If this is a RWBuffer/RWTexture assignment, OpImageWrite will be used.
    // Otherwise, store using OpStore.
    if (tryToAssignToRWBufferRWTexture(subExpr, incValue)) {
      incValue->setRValue();
      subValue = incValue;
    } else {
      spvBuilder.createStore(subValue, incValue, subExpr->getLocStart());
    }

    // Prefix increment/decrement operators return an lvalue, while postfix
    // increment/decrement operators return an rvalue.
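    // E.g., j = ++i yields the updated i itself, whereas j = i++ yields the
    // value of i loaded before the increment.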
    if (isPre) {
      return subValue;
    } else {
      originValue->setRValue();
      return originValue;
    }
  }
  case UO_Not: {
    subValue = spvBuilder.createUnaryOp(spv::Op::OpNot, subType, subValue,
                                        expr->getOperatorLoc());
    subValue->setRValue();
    return subValue;
  }
  case UO_LNot: {
    // Parsing will do the necessary casting to make sure we are applying the
    // ! operator on boolean values.
    subValue = spvBuilder.createUnaryOp(spv::Op::OpLogicalNot, subType,
                                        subValue, expr->getOperatorLoc());
    subValue->setRValue();
    return subValue;
  }
  case UO_Plus:
    // No need to do anything for the prefix + operator.
    return subValue;
  case UO_Minus: {
    // SPIR-V has two opcodes for negating values: OpSNegate and OpFNegate.
    const spv::Op spvOp = isFloatOrVecMatOfFloatType(subType)
                              ? spv::Op::OpFNegate
                              : spv::Op::OpSNegate;
    if (isMxNMatrix(subType)) {
      // For matrices, we can only negate each vector of it.
      const auto actOnEachVec = [this, spvOp, expr](uint32_t /*index*/,
                                                    QualType vecType,
                                                    SpirvInstruction *lhsVec) {
        return spvBuilder.createUnaryOp(spvOp, vecType, lhsVec,
                                        expr->getOperatorLoc());
      };
      return processEachVectorInMatrix(subExpr, subValue, actOnEachVec,
                                       expr->getLocStart());
    } else {
      subValue = spvBuilder.createUnaryOp(spvOp, subType, subValue,
                                          expr->getOperatorLoc());
      subValue->setRValue();
      return subValue;
    }
  }
  default:
    break;
  }

  emitError("unary operator '%0' unimplemented", expr->getExprLoc())
      << expr->getOpcodeStr(opcode);
  expr->dump();
  return nullptr;
}
spv::Op SpirvEmitter::translateOp(BinaryOperator::Opcode op, QualType type) {
  const bool isSintType = isSintOrVecMatOfSintType(type);
  const bool isUintType = isUintOrVecMatOfUintType(type);
  const bool isFloatType = isFloatOrVecMatOfFloatType(type);

#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                     \
  case BO_##kind: {                                                           \
    if (isSintType || isUintType) {                                           \
      return spv::Op::Op##intBinOp;                                           \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  } break

#define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp)   \
  case BO_##kind: {                                                           \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  } break

#define BIN_OP_CASE_SINT_UINT(kind, sintBinOp, uintBinOp)                     \
  case BO_##kind: {                                                           \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
  } break
  switch (op) {
  case BO_EQ: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalEqual;
    if (isSintType || isUintType)
      return spv::Op::OpIEqual;
    if (isFloatType)
      return spv::Op::OpFOrdEqual;
  } break;
  case BO_NE: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalNotEqual;
    if (isSintType || isUintType)
      return spv::Op::OpINotEqual;
    if (isFloatType)
      return spv::Op::OpFOrdNotEqual;
  } break;
  // According to the HLSL doc, all sides of the && and || expression are
  // always evaluated.
  case BO_LAnd:
    return spv::Op::OpLogicalAnd;
  case BO_LOr:
    return spv::Op::OpLogicalOr;
    BIN_OP_CASE_INT_FLOAT(Add, IAdd, FAdd);
    BIN_OP_CASE_INT_FLOAT(AddAssign, IAdd, FAdd);
    BIN_OP_CASE_INT_FLOAT(Sub, ISub, FSub);
    BIN_OP_CASE_INT_FLOAT(SubAssign, ISub, FSub);
    BIN_OP_CASE_INT_FLOAT(Mul, IMul, FMul);
    BIN_OP_CASE_INT_FLOAT(MulAssign, IMul, FMul);
    BIN_OP_CASE_SINT_UINT_FLOAT(Div, SDiv, UDiv, FDiv);
    BIN_OP_CASE_SINT_UINT_FLOAT(DivAssign, SDiv, UDiv, FDiv);
    // According to the HLSL spec, "the modulus operator returns the remainder
    // of a division." "The % operator is defined only in cases where either
    // both sides are positive or both sides are negative."
    //
    // In SPIR-V, there are two remainder operations: Op*Rem and Op*Mod. With
    // the former, the sign of a non-0 result comes from Operand 1, while
    // with the latter, from Operand 2.
    //
    // For operands with different signs, technically we can map % to either
    // Op*Rem or Op*Mod since it's undefined behavior. But it is more
    // consistent with C (HLSL started as a C derivative) and Clang frontend
    // const expression evaluation if we map % to Op*Rem.
    //
    // Note there is no OpURem in SPIR-V.
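    //
    // E.g., for -7 % 2: OpSRem yields -1 (sign follows Operand 1), whereas
    // OpSMod would yield 1 (sign follows Operand 2).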
    BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(RemAssign, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(LT, SLessThan, ULessThan, FOrdLessThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(LE, SLessThanEqual, ULessThanEqual,
                                FOrdLessThanEqual);
    BIN_OP_CASE_SINT_UINT_FLOAT(GT, SGreaterThan, UGreaterThan,
                                FOrdGreaterThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(GE, SGreaterThanEqual, UGreaterThanEqual,
                                FOrdGreaterThanEqual);
    BIN_OP_CASE_SINT_UINT(And, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(AndAssign, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(Or, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(OrAssign, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(Xor, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(XorAssign, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(Shl, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(ShlAssign, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(Shr, ShiftRightArithmetic, ShiftRightLogical);
    BIN_OP_CASE_SINT_UINT(ShrAssign, ShiftRightArithmetic, ShiftRightLogical);
  default:
    break;
  }

#undef BIN_OP_CASE_INT_FLOAT
#undef BIN_OP_CASE_SINT_UINT_FLOAT
#undef BIN_OP_CASE_SINT_UINT

  emitError("translating binary operator '%0' unimplemented", {})
      << BinaryOperator::getOpcodeStr(op);
  return spv::Op::OpNop;
}
SpirvInstruction *
SpirvEmitter::processAssignment(const Expr *lhs, SpirvInstruction *rhs,
                                const bool isCompoundAssignment,
                                SpirvInstruction *lhsPtr) {
  lhs = lhs->IgnoreParenNoopCasts(astContext);

  // Assigning to vector swizzling should be handled differently.
  if (SpirvInstruction *result = tryToAssignToVectorElements(lhs, rhs))
    return result;

  // Assigning to matrix swizzling should be handled differently.
  if (SpirvInstruction *result = tryToAssignToMatrixElements(lhs, rhs))
    return result;

  // Assigning to a RWBuffer/RWTexture should be handled differently.
  if (SpirvInstruction *result = tryToAssignToRWBufferRWTexture(lhs, rhs))
    return result;

  // Assigning to an out attribute or an indices object in a mesh shader
  // should be handled differently.
  if (SpirvInstruction *result = tryToAssignToMSOutAttrsOrIndices(lhs, rhs))
    return result;

  // Assigning to a 'string' variable. SPIR-V doesn't have a string type, and
  // we do not allow creating or modifying string variables. We do allow use
  // of string literals via OpString.
  if (isStringType(lhs->getType())) {
    emitError("string variables are immutable in SPIR-V", lhs->getExprLoc());
    return nullptr;
  }

  // Normal assignment procedure
  if (!lhsPtr)
    lhsPtr = doExpr(lhs);

  storeValue(lhsPtr, rhs, lhs->getType(), lhs->getLocStart());

  // Plain assignment returns an rvalue, while compound assignment returns an
  // lvalue.
  return isCompoundAssignment ? lhsPtr : rhs;
}
void SpirvEmitter::storeValue(SpirvInstruction *lhsPtr,
                              SpirvInstruction *rhsVal, QualType lhsValType,
                              SourceLocation loc) {
  // Defend against nullptr source or destination so errors can bubble up to
  // the user.
  if (!lhsPtr || !rhsVal)
    return;

  if (const auto *refType = lhsValType->getAs<ReferenceType>())
    lhsValType = refType->getPointeeType();

  QualType matElemType = {};
  const bool lhsIsMat = isMxNMatrix(lhsValType, &matElemType);
  const bool lhsIsFloatMat = lhsIsMat && matElemType->isFloatingType();
  const bool lhsIsNonFpMat = lhsIsMat && !matElemType->isFloatingType();

  if (isScalarType(lhsValType) || isVectorType(lhsValType) || lhsIsFloatMat) {
    // Special case: According to the SPIR-V spec, there is no physical size
    // or bit pattern defined for the boolean type. Therefore an unsigned
    // integer is used to represent booleans when layout is required. In such
    // cases, we should cast the boolean to uint before creating OpStore.
    if (isBoolOrVecOfBoolType(lhsValType) &&
        lhsPtr->getLayoutRule() != SpirvLayoutRule::Void) {
      uint32_t vecSize = 1;
      const bool isVec = isVectorType(lhsValType, nullptr, &vecSize);
      const auto toType =
          isVec ? astContext.getExtVectorType(astContext.UnsignedIntTy, vecSize)
                : astContext.UnsignedIntTy;
      const auto fromType =
          isVec ? astContext.getExtVectorType(astContext.BoolTy, vecSize)
                : astContext.BoolTy;
      rhsVal = castToInt(rhsVal, fromType, toType, {});
    }
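
    // E.g., storing true into a bool member of a cbuffer-laid-out struct
    // actually stores the uint value 1.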
    spvBuilder.createStore(lhsPtr, rhsVal, loc);
  } else if (isOpaqueType(lhsValType)) {
    // Resource types are represented using RecordType in the AST.
    // Handle them before the general RecordType.
    //
    // HLSL allows putting resource types (which translate into SPIR-V opaque
    // types) in structs, and assigning to variables of resource types. These
    // can all result in illegal SPIR-V for Vulkan. We just translate here
    // literally and let SPIRV-Tools opt do the legalization work.
    //
    // Note: legalization specific code
    if (hlsl::IsHLSLRayQueryType(lhsValType)) {
      emitError("store value of type %0 is unsupported", {}) << lhsValType;
      return;
    }
    spvBuilder.createStore(lhsPtr, rhsVal, loc);
    needsLegalization = true;
  } else if (isAKindOfStructuredOrByteBuffer(lhsValType)) {
    // The rhs should be a pointer and the lhs should be a pointer-to-pointer.
    // Directly store the pointer here and let SPIRV-Tools opt do the clean
    // up.
    //
    // Note: legalization specific code
    spvBuilder.createStore(lhsPtr, rhsVal, loc);
    needsLegalization = true;

    // For ConstantBuffers/TextureBuffers, we decompose and assign each field
    // recursively like normal structs using the following logic.
    //
    // The frontend forbids declaring ConstantBuffer<T> or TextureBuffer<T>
    // variables as function parameters/returns/variables, but happily accepts
    // assignments/returns from ConstantBuffer<T>/TextureBuffer<T> to function
    // parameters/returns/variables of type T. And ConstantBuffer<T> is
    // represented no differently than struct T.
  } else if (isOpaqueArrayType(lhsValType)) {
    // For opaque array types, we cannot perform OpLoad on the whole array and
    // then write it out as a whole; instead, we need to OpLoad each element
    // using access chains. This is to influence later SPIR-V transformations
    // to use access chains to access each opaque object; if we handled the
    // array wholesale here, the wholesale loads/stores would survive into the
    // final transformed code. Drivers don't like that.
    // TODO: consider moving this hack into SPIRV-Tools as a transformation.
    assert(!rhsVal->isRValue());

    if (!lhsValType->isConstantArrayType()) {
      spvBuilder.createStore(lhsPtr, rhsVal, loc);
      needsLegalization = true;
      return;
    }

    const auto *arrayType = astContext.getAsConstantArrayType(lhsValType);
    const auto elemType = arrayType->getElementType();
    const auto arraySize =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());

    // Do a separate load of each element via access chain
    llvm::SmallVector<SpirvInstruction *, 8> elements;
    for (uint32_t i = 0; i < arraySize; ++i) {
      auto *subRhsPtr = spvBuilder.createAccessChain(
          elemType, rhsVal,
          {spvBuilder.getConstantInt(astContext.IntTy,
                                     llvm::APInt(32, i, true))},
          loc);
      elements.push_back(spvBuilder.createLoad(elemType, subRhsPtr, loc));
    }

    // Create a new composite and write it out once
    spvBuilder.createStore(
        lhsPtr,
        spvBuilder.createCompositeConstruct(lhsValType, elements,
                                            rhsVal->getSourceLocation()),
        loc);
  } else if (lhsPtr->getLayoutRule() == rhsVal->getLayoutRule()) {
    // If lhs and rhs have the same memory layout, we should be safe to load
    // from rhs and directly store into lhs, avoiding decomposing rhs.
    // Note: this check should happen after those setting needsLegalization.
    // TODO: is this optimization always correct?
    spvBuilder.createStore(lhsPtr, rhsVal, loc);
  } else if (lhsValType->isRecordType() || lhsValType->isConstantArrayType() ||
             lhsIsNonFpMat) {
    spvBuilder.createStore(
        lhsPtr,
        reconstructValue(rhsVal, lhsValType, lhsPtr->getLayoutRule(), loc),
        loc);
  } else {
    emitError("storing value of type %0 unimplemented", {}) << lhsValType;
  }
}
SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal,
                                                 const QualType valType,
                                                 SpirvLayoutRule dstLR,
                                                 SourceLocation loc) {
  // Lambda for casting a scalar or vector of bool<-->uint in cases where one
  // side of the reconstruction (lhs or rhs) has a layout rule.
  const auto handleBooleanLayout = [this, &srcVal, dstLR,
                                    loc](SpirvInstruction *val,
                                         QualType valType) {
    // We only need to cast if we have a scalar or vector of booleans.
    if (!isBoolOrVecOfBoolType(valType))
      return val;

    SpirvLayoutRule srcLR = srcVal->getLayoutRule();
    // The source value has a layout rule, and has therefore been represented
    // as a uint. Cast it to boolean before using.
    bool shouldCastToBool =
        srcLR != SpirvLayoutRule::Void && dstLR == SpirvLayoutRule::Void;
    // The destination has a layout rule, and should therefore be represented
    // as a uint. Cast to uint before using.
    bool shouldCastToUint =
        srcLR == SpirvLayoutRule::Void && dstLR != SpirvLayoutRule::Void;
    // No boolean layout issues to take care of.
    if (!shouldCastToBool && !shouldCastToUint)
      return val;

    uint32_t vecSize = 1;
    isVectorType(valType, nullptr, &vecSize);
    QualType boolType =
        vecSize == 1 ? astContext.BoolTy
                     : astContext.getExtVectorType(astContext.BoolTy, vecSize);
    QualType uintType =
        vecSize == 1
            ? astContext.UnsignedIntTy
            : astContext.getExtVectorType(astContext.UnsignedIntTy, vecSize);

    if (shouldCastToBool)
      return castToBool(val, uintType, boolType, loc);
    if (shouldCastToUint)
      return castToInt(val, boolType, uintType, loc);

    return val;
  };

  // Lambda for cases where we want to reconstruct an array
  const auto reconstructArray = [this, &srcVal, valType, dstLR,
                                 loc](uint32_t arraySize,
                                      QualType arrayElemType) {
    llvm::SmallVector<SpirvInstruction *, 4> elements;
    for (uint32_t i = 0; i < arraySize; ++i) {
      SpirvInstruction *subSrcVal =
          spvBuilder.createCompositeExtract(arrayElemType, srcVal, {i}, loc);
      subSrcVal->setLayoutRule(srcVal->getLayoutRule());
      elements.push_back(
          reconstructValue(subSrcVal, arrayElemType, dstLR, loc));
    }
    auto *result = spvBuilder.createCompositeConstruct(
        valType, elements, srcVal->getSourceLocation());
    result->setLayoutRule(dstLR);
    return result;
  };

  // Constant arrays
  if (const auto *arrayType = astContext.getAsConstantArrayType(valType)) {
    const auto elemType = arrayType->getElementType();
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    return reconstructArray(size, elemType);
  }

  // Non-floating-point matrices
  QualType matElemType = {};
  uint32_t numRows = 0, numCols = 0;
  const bool isNonFpMat =
      isMxNMatrix(valType, &matElemType, &numRows, &numCols) &&
      !matElemType->isFloatingType();
  if (isNonFpMat) {
    // Note: This check should happen before the RecordType check.
    // Non-fp matrices are represented as arrays of vectors in SPIR-V.
    // Each array element is a vector. Get the QualType for the vector.
    const auto elemType = astContext.getExtVectorType(matElemType, numCols);
    return reconstructArray(numRows, elemType);
  }
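
  // E.g., an int2x3 matrix is reconstructed as a 2-element array whose
  // elements are 3-component integer vectors.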
  // Note: This check should happen before the RecordType check since
  // vector/matrix/resource types are represented as RecordType in the AST.
  if (hlsl::IsHLSLVecMatType(valType) || hlsl::IsHLSLResourceType(valType))
    return handleBooleanLayout(srcVal, valType);

  // Structs
  if (const auto *recordType = valType->getAs<RecordType>()) {
    uint32_t index = 0;
    llvm::SmallVector<SpirvInstruction *, 4> elements;
    for (const auto *field : recordType->getDecl()->fields()) {
      SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
          field->getType(), srcVal, {index}, loc);
      subSrcVal->setLayoutRule(srcVal->getLayoutRule());
      elements.push_back(
          reconstructValue(subSrcVal, field->getType(), dstLR, loc));
      ++index;
    }
    auto *result = spvBuilder.createCompositeConstruct(
        valType, elements, srcVal->getSourceLocation());
    result->setLayoutRule(dstLR);
    return result;
  }

  return handleBooleanLayout(srcVal, valType);
}
SpirvInstruction *SpirvEmitter::processBinaryOp(
    const Expr *lhs, const Expr *rhs, const BinaryOperatorKind opcode,
    const QualType computationType, const QualType resultType,
    SourceRange sourceRange, SourceLocation loc, SpirvInstruction **lhsInfo,
    const spv::Op mandateGenOpcode) {
  const QualType lhsType = lhs->getType();
  const QualType rhsType = rhs->getType();

  // If the operands are of matrix type, we need to dispatch the operation
  // onto each element vector iff the operands are not degenerate matrices
  // and we don't have a matrix-specific SPIR-V instruction for the operation.
  if (!isSpirvMatrixOp(mandateGenOpcode) && isMxNMatrix(lhsType)) {
    return processMatrixBinaryOp(lhs, rhs, opcode, sourceRange, loc);
  }

  // The comma operator works differently from other binary operations: there
  // is no SPIR-V instruction for it. For each comma, we must evaluate lhs and
  // rhs respectively, and return the result of rhs.
  if (opcode == BO_Comma) {
    (void)doExpr(lhs);
    return doExpr(rhs);
  }
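
  // E.g., (a, b) evaluates a only for its side effects; the whole expression
  // evaluates to b.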
  SpirvInstruction *rhsVal = nullptr, *lhsPtr = nullptr, *lhsVal = nullptr;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = loadIfGLValue(rhs);
    lhsVal = lhsPtr = doExpr(lhs);

    // This is a compound assignment. We need to load the lhs value if lhs
    // is not already an rvalue and does not generate a vector shuffle.
    if (!lhsPtr->isRValue() && !isVectorShuffle(lhs)) {
      lhsVal = loadIfGLValue(lhs, lhsPtr);
    }

    // For compound assignments, the AST does not have the proper implicit
    // cast if lhs and rhs have different types. So we need to manually cast
    // lhs to the computation type.
    if (computationType != lhsType)
      lhsVal = castToType(lhsVal, lhsType, computationType, lhs->getExprLoc());
  } else {
    // Evaluate lhs before rhs
    lhsPtr = doExpr(lhs);
    if (!lhsPtr)
      return nullptr;
    lhsVal = loadIfGLValue(lhs, lhsPtr);
    rhsVal = loadIfGLValue(rhs);
  }

  if (lhsInfo)
    *lhsInfo = lhsPtr;

  const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
                            ? translateOp(opcode, computationType)
                            : mandateGenOpcode;

  switch (opcode) {
  case BO_Shl:
  case BO_Shr:
  case BO_ShlAssign:
  case BO_ShrAssign:
    // We need to mask the RHS to make sure that we are not shifting by an
    // amount that is larger than the bit width of the LHS.
    rhsVal = spvBuilder.createBinaryOp(spv::Op::OpBitwiseAnd, computationType,
                                       rhsVal, getMaskForBitwidthValue(rhsType),
                                       loc);
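    // E.g., with a 32-bit LHS, a shift amount of 33 is masked to
    // 33 & 31 == 1.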
  4974. // Fall through
  4975. case BO_Add:
  4976. case BO_Sub:
  4977. case BO_Mul:
  4978. case BO_Div:
  4979. case BO_Rem:
  4980. case BO_LT:
  4981. case BO_LE:
  4982. case BO_GT:
  4983. case BO_GE:
  4984. case BO_EQ:
  4985. case BO_NE:
  4986. case BO_And:
  4987. case BO_Or:
  4988. case BO_Xor:
  4989. case BO_LAnd:
  4990. case BO_LOr:
  4991. case BO_AddAssign:
  4992. case BO_SubAssign:
  4993. case BO_MulAssign:
  4994. case BO_DivAssign:
  4995. case BO_RemAssign:
  4996. case BO_AndAssign:
  4997. case BO_OrAssign:
  4998. case BO_XorAssign: {
  4999. // To evaluate this expression as an OpSpecConstantOp, we need to make sure
  5000. // both operands are constant and at least one of them is a spec constant.
    if (SpirvConstant *lhsValConstant = dyn_cast<SpirvConstant>(lhsVal)) {
      if (SpirvConstant *rhsValConstant = dyn_cast<SpirvConstant>(rhsVal)) {
        if (isAcceptedSpecConstantBinaryOp(spvOp)) {
          if (lhsValConstant->isSpecConstant() ||
              rhsValConstant->isSpecConstant()) {
            auto *val = spvBuilder.createSpecConstantBinaryOp(
                spvOp, resultType, lhsVal, rhsVal, loc);
            val->setRValue();
            return val;
          }
        }
      }
    }

    // Normal binary operation
    SpirvInstruction *val = nullptr;
    if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
      val = spvBuilder.createBinaryOp(spvOp, computationType, lhsVal, rhsVal,
                                      loc);
      // For compound assignments, the AST does not have the proper implicit
      // cast if lhs and rhs have different types. So we need to manually cast
      // the result back to lhs' type.
      if (computationType != lhsType)
        val = castToType(val, computationType, lhsType, lhs->getExprLoc());
    } else {
      val = spvBuilder.createBinaryOp(spvOp, resultType, lhsVal, rhsVal, loc);
    }
    val->setRValue();

    // Propagate RelaxedPrecision
    if ((lhsVal && lhsVal->isRelaxedPrecision()) ||
        (rhsVal && rhsVal->isRelaxedPrecision()))
      val->setRelaxedPrecision();
    return val;
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
    break;
  case BO_PtrMemD:
  case BO_PtrMemI:
  case BO_Comma:
    // Unimplemented
    break;
  }

  emitError("binary operator '%0' unimplemented", lhs->getExprLoc())
      << BinaryOperator::getOpcodeStr(opcode) << sourceRange;
  return nullptr;
}

void SpirvEmitter::initOnce(QualType varType, std::string varName,
                            SpirvVariable *var, const Expr *varInit) {
  // For uninitialized resource objects, we do nothing since there are no
  // meaningful zero values for them.
  if (!varInit && hlsl::IsHLSLResourceType(varType))
    return;

  varName = "init.done." + varName;

  auto loc = varInit ? varInit->getLocStart() : SourceLocation();

  // Create a file/module visible variable to hold the initialization state.
  SpirvVariable *initDoneVar = spvBuilder.addModuleVar(
      astContext.BoolTy, spv::StorageClass::Private, /*isPrecise*/ false,
      varName, spvBuilder.getConstantBool(false));

  auto *condition = spvBuilder.createLoad(astContext.BoolTy, initDoneVar, loc);

  auto *todoBB = spvBuilder.createBasicBlock("if.init.todo");
  auto *doneBB = spvBuilder.createBasicBlock("if.init.done");

  // If initDoneVar contains true, we jump to the "done" basic block;
  // otherwise, jump to the "todo" basic block.
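  // The emitted structure is roughly (illustrative):
  //   %cond = OpLoad %bool %init.done.var
  //           OpSelectionMerge %if.init.done None
  //           OpBranchConditional %cond %if.init.done %if.init.todo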
  spvBuilder.createConditionalBranch(condition, doneBB, todoBB, loc, doneBB);
  spvBuilder.addSuccessor(todoBB);
  spvBuilder.addSuccessor(doneBB);
  spvBuilder.setMergeTarget(doneBB);

  spvBuilder.setInsertPoint(todoBB);

  // Do initialization and mark done
  if (varInit) {
    var->setStorageClass(spv::StorageClass::Private);
    storeValue(
        // Static function variables are of the Private storage class
        var, loadIfGLValue(varInit), varInit->getType(), varInit->getLocEnd());
  } else {
    spvBuilder.createStore(var, spvBuilder.getConstantNull(varType), loc);
  }

  spvBuilder.createStore(initDoneVar, spvBuilder.getConstantBool(true), loc);
  spvBuilder.createBranch(doneBB, loc);
  spvBuilder.addSuccessor(doneBB);

  spvBuilder.setInsertPoint(doneBB);
}

bool SpirvEmitter::isVectorShuffle(const Expr *expr) {
  // TODO: the following check is essentially duplicated from
  // doHLSLVectorElementExpr. Should unify them.
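  // Examples (illustrative): for a float4 'v', 'v.xy' and 'v.wzyx' are
  // vector shuffles, while 'v.x' selects a single element and 'v.xyzw'
  // selects the original vector, so neither of the latter two needs
  // OpVectorShuffle.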
  if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    const Expr *base = nullptr;
    hlsl::VectorMemberAccessPositions accessor;
    condenseVectorElementExpr(vecElemExpr, &base, &accessor);

    const auto accessorSize = accessor.Count;
    if (accessorSize == 1) {
      // Selecting only one element. OpAccessChain or OpCompositeExtract is
      // used for such cases.
      return false;
    }

    const auto baseSize = hlsl::GetHLSLVecSize(base->getType());
    if (accessorSize != baseSize)
      return true;

    for (uint32_t i = 0; i < accessorSize; ++i) {
      uint32_t position;
      accessor.GetPosition(i, &position);
      if (position != i)
        return true;
    }

    // Selecting exactly the original vector. No vector shuffle generated.
    return false;
  }

  return false;
}

bool SpirvEmitter::isTextureMipsSampleIndexing(const CXXOperatorCallExpr *expr,
                                               const Expr **base,
                                               const Expr **location,
                                               const Expr **lod) {
  if (!expr)
    return false;

  // <object>.mips[][] consists of an outer operator[] and an inner operator[]
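  // e.g. (illustrative) "tex.mips[2][coord]" reads mip level 2 at 'coord',
  // and "msTex.sample[3][coord]" reads sample 3 of a multisampled texture.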
  const CXXOperatorCallExpr *outerExpr = expr;
  if (outerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *arg0 = outerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const CXXOperatorCallExpr *innerExpr = dyn_cast<CXXOperatorCallExpr>(arg0);

  // Must have an inner operator[]
  if (!innerExpr ||
      innerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript) {
    return false;
  }

  const Expr *innerArg0 =
      innerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const MemberExpr *memberExpr = dyn_cast<MemberExpr>(innerArg0);
  if (!memberExpr)
    return false;

  // Must be accessing the member named "mips" or "sample"
  const auto &memberName =
      memberExpr->getMemberNameInfo().getName().getAsString();
  if (memberName != "mips" && memberName != "sample")
    return false;

  const Expr *object = memberExpr->getBase();
  const auto objectType = object->getType();
  if (!isTexture(objectType))
    return false;

  if (base)
    *base = object;
  if (lod)
    *lod = innerExpr->getArg(1);
  if (location)
    *location = outerExpr->getArg(1);
  return true;
}

bool SpirvEmitter::isBufferTextureIndexing(const CXXOperatorCallExpr *indexExpr,
                                           const Expr **base,
                                           const Expr **index) {
  if (!indexExpr)
    return false;

  // Must be operator[]
  if (indexExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *object = indexExpr->getArg(0);
  const auto objectType = object->getType();
  if (isBuffer(objectType) || isRWBuffer(objectType) || isTexture(objectType) ||
      isRWTexture(objectType)) {
    if (base)
      *base = object;
    if (index)
      *index = indexExpr->getArg(1);
    return true;
  }

  return false;
}

void SpirvEmitter::condenseVectorElementExpr(
    const HLSLVectorElementExpr *expr, const Expr **basePtr,
    hlsl::VectorMemberAccessPositions *flattenedAccessor) {
  llvm::SmallVector<hlsl::VectorMemberAccessPositions, 2> accessors;
  *basePtr = expr;

  // Recursively descend until we find the true base vector (the base vector
  // that does not itself have a base vector), collecting accessors in
  // reverse order along the way.
  // Example: for myVector.yxwz.yxz.xx.yx, the true base is 'myVector'.
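  // A worked example (illustrative): myVector.yxwz.yx condenses as follows.
  // .yxwz encodes positions [1,0,3,2]; applying .yx (positions [1,0]) on top
  // selects flattened[1]=0 and flattened[0]=1, yielding [0,1], i.e.
  // myVector.xy.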
  while (const auto *vecElemBase = dyn_cast<HLSLVectorElementExpr>(*basePtr)) {
    accessors.push_back(vecElemBase->getEncodedElementAccess());
    *basePtr = vecElemBase->getBase();
    // We need to skip any number of parentheses around swizzling at any level.
    while (const auto *parenExpr = dyn_cast<ParenExpr>(*basePtr))
      *basePtr = parenExpr->getSubExpr();
  }

  *flattenedAccessor = accessors.back();
  for (int32_t i = accessors.size() - 2; i >= 0; --i) {
    const auto &currentAccessor = accessors[i];

    // Apply the current level of accessor to the flattened accessor of all
    // previous levels.
    hlsl::VectorMemberAccessPositions combinedAccessor;
    for (uint32_t j = 0; j < currentAccessor.Count; ++j) {
      uint32_t currentPosition = 0;
      currentAccessor.GetPosition(j, &currentPosition);
      uint32_t previousPosition = 0;
      flattenedAccessor->GetPosition(currentPosition, &previousPosition);
      combinedAccessor.SetPosition(j, previousPosition);
    }
    combinedAccessor.Count = currentAccessor.Count;
    combinedAccessor.IsValid =
        flattenedAccessor->IsValid && currentAccessor.IsValid;

    *flattenedAccessor = combinedAccessor;
  }
}

SpirvInstruction *SpirvEmitter::createVectorSplat(const Expr *scalarExpr,
                                                  uint32_t size) {
  SpirvInstruction *scalarVal = nullptr;

  // Try to evaluate the element as constant first. If successful, then we
  // can generate constant instructions for this vector splat.
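  // e.g. (illustrative) splatting the constant 1.0f to a float3 can yield
  // OpConstantComposite %v3float %float_1 %float_1 %float_1, while splatting
  // a non-constant scalar %s yields OpCompositeConstruct %v3float %s %s %s.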
  if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
    scalarVal->setRValue();
  } else {
    scalarVal = loadIfGLValue(scalarExpr);
  }

  if (!scalarVal || size == 1) {
    // Just return the scalar value for vector splat with size 1.
    // Note that it can be used as an lvalue, so we need to carry over
    // the lvalueness for non-constant cases.
    return scalarVal;
  }

  const auto vecType = astContext.getExtVectorType(scalarExpr->getType(), size);

  // TODO: we are saying the constant has Function storage class here.
  // Should find a more meaningful one.
  if (auto *constVal = dyn_cast<SpirvConstant>(scalarVal)) {
    llvm::SmallVector<SpirvConstant *, 4> elements(size_t(size), constVal);
    auto *value = spvBuilder.getConstantComposite(vecType, elements);
    value->setRValue();
    return value;
  } else {
    llvm::SmallVector<SpirvInstruction *, 4> elements(size_t(size), scalarVal);
    auto *value = spvBuilder.createCompositeConstruct(
        vecType, elements, scalarExpr->getLocStart());
    value->setRValue();
    return value;
  }
}

void SpirvEmitter::splitVecLastElement(QualType vecType, SpirvInstruction *vec,
                                       SpirvInstruction **residual,
                                       SpirvInstruction **lastElement,
                                       SourceLocation loc) {
  assert(hlsl::IsHLSLVecType(vecType));

  const uint32_t count = hlsl::GetHLSLVecSize(vecType);
  assert(count > 1);
  const QualType elemType = hlsl::GetHLSLVecElementType(vecType);
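  // e.g. (illustrative) splitting a float4 yields a float3 residual via
  // OpVectorShuffle (indices 0,1,2) and the last element via
  // OpCompositeExtract (index 3).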
  if (count == 2) {
    *residual = spvBuilder.createCompositeExtract(elemType, vec, {0}, loc);
  } else {
    llvm::SmallVector<uint32_t, 4> indices;
    for (uint32_t i = 0; i < count - 1; ++i)
      indices.push_back(i);
    const QualType type = astContext.getExtVectorType(elemType, count - 1);
    *residual = spvBuilder.createVectorShuffle(type, vec, vec, indices, loc);
  }

  *lastElement =
      spvBuilder.createCompositeExtract(elemType, vec, {count - 1}, loc);
}

SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType structType,
                                                      QualType elemType,
                                                      SpirvInstruction *vector,
                                                      SourceLocation loc) {
  assert(structType->isStructureType());

  const auto *structDecl = structType->getAsStructureType()->getDecl();
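  // e.g. (illustrative) for "struct S { float a; float2 b; }" and a float3
  // source %v, 'a' comes from OpCompositeExtract %v 0 and 'b' from
  // OpVectorShuffle %v %v 1 2.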
  uint32_t vectorIndex = 0;
  uint32_t elemCount = 1;
  llvm::SmallVector<SpirvInstruction *, 4> members;

  for (const auto *field : structDecl->fields()) {
    if (isScalarType(field->getType())) {
      members.push_back(spvBuilder.createCompositeExtract(
          elemType, vector, {vectorIndex++}, loc));
    } else if (isVectorType(field->getType(), nullptr, &elemCount)) {
      llvm::SmallVector<uint32_t, 4> indices;
      for (uint32_t i = 0; i < elemCount; ++i)
        indices.push_back(vectorIndex++);
      members.push_back(spvBuilder.createVectorShuffle(
          astContext.getExtVectorType(elemType, elemCount), vector, vector,
          indices, loc));
    } else {
      assert(false && "unhandled type");
    }
  }

  return spvBuilder.createCompositeConstruct(structType, members,
                                             vector->getSourceLocation());
}

SpirvInstruction *
SpirvEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  const SourceRange range = expr->getSourceRange();
  QualType elemType = {};

  // We can only translate floatN * float into OpVectorTimesScalar.
  // So the result type must be floatN. Note that float1 is not a valid vector
  // in SPIR-V.
  if (!(isVectorType(type, &elemType) && elemType->isFloatingType()))
    return nullptr;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();

  // Multiplying a float vector with a float scalar will be represented in
  // the AST via a binary operation with two float vectors as operands; one of
  // the operands is from an implicit cast with kind CK_HLSLVectorSplat.
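  // e.g. (illustrative) "float4 r = v * s;" appears in the AST as
  // v4float * splat(s) and is emitted as OpVectorTimesScalar %v4float %v %s.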
  // vector * scalar
  if (hlsl::IsHLSLVecType(lhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const QualType vecType = expr->getType();
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvInstruction *lhsPtr = nullptr;
          auto *result =
              processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                              vecType, vecType, range, expr->getOperatorLoc(),
                              &lhsPtr, spv::Op::OpVectorTimesScalar);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 vecType, vecType, range,
                                 expr->getOperatorLoc(), nullptr,
                                 spv::Op::OpVectorTimesScalar);
        }
      }
    }
  }

  // scalar * vector
  if (hlsl::IsHLSLVecType(rhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const QualType vecType = expr->getType();
        // We need to switch the positions of lhs and rhs here because
        // OpVectorTimesScalar requires the first operand to be a vector and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               vecType, vecType, range, expr->getOperatorLoc(),
                               nullptr, spv::Op::OpVectorTimesScalar);
      }
    }
  }

  return nullptr;
}

SpirvInstruction *
SpirvEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  const SourceRange range = expr->getSourceRange();

  // We translate 'floatMxN * float' into OpMatrixTimesScalar.
  // We translate 'floatMx1 * float' and 'float1xN * float' using
  // OpVectorTimesScalar.
  // So the result type can be floatMxN, floatMx1, or float1xN.
  if (!hlsl::IsHLSLMatType(type) ||
      !hlsl::GetHLSLMatElementType(type)->isFloatingType() || is1x1Matrix(type))
    return nullptr;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();
  const QualType lhsType = lhs->getType();
  const QualType rhsType = rhs->getType();
  const auto selectOpcode = [](const QualType ty) {
    return isMx1Matrix(ty) || is1xNMatrix(ty) ? spv::Op::OpVectorTimesScalar
                                              : spv::Op::OpMatrixTimesScalar;
  };

  // Multiplying a float matrix with a float scalar will be represented in
  // the AST via a binary operation with two float matrices as operands; one of
  // the operands is from an implicit cast with kind CK_HLSLMatrixSplat.
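  // e.g. (illustrative) "float3x4 r = m * s;" is emitted as
  // OpMatrixTimesScalar %mat3v4float %m %s, while a float3x1 or float1x4
  // operand uses OpVectorTimesScalar since such matrices lower to vectors.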
  // matrix * scalar
  if (hlsl::IsHLSLMatType(lhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const QualType matType = expr->getType();
        const spv::Op opcode = selectOpcode(lhsType);
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvInstruction *lhsPtr = nullptr;
          auto *result = processBinaryOp(
              lhs, cast->getSubExpr(), expr->getOpcode(), matType, matType,
              range, expr->getOperatorLoc(), &lhsPtr, opcode);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 matType, matType, range,
                                 expr->getOperatorLoc(), nullptr, opcode);
        }
      }
    }
  }

  // scalar * matrix
  if (hlsl::IsHLSLMatType(rhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const QualType matType = expr->getType();
        const spv::Op opcode = selectOpcode(rhsType);
        // We need to switch the positions of lhs and rhs here because
        // OpMatrixTimesScalar requires the first operand to be a matrix and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               matType, matType, range, expr->getOperatorLoc(),
                               nullptr, opcode);
      }
    }
  }

  return nullptr;
}

SpirvInstruction *
SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
                                          SpirvInstruction *rhs) {
  // Assigning to a vector swizzling lhs is tricky if we are neither
  // writing to one element nor to all elements in their original order.
  // In such cases, we need to create a new vector swizzling involving
  // both the lhs and rhs vectors and then write the result of this swizzling
  // into the base vector of lhs.
  // For example, for vec4.yz = vec2, we need to do the following:
  //
  // %vec4Val = OpLoad %v4float %vec4
  // %vec2Val = OpLoad %v2float %vec2
  // %shuffle = OpVectorShuffle %v4float %vec4Val %vec2Val 0 4 5 3
  // OpStore %vec4 %shuffle
  //
  // When doing the vector shuffle, we use the lhs base vector as the first
  // vector and the rhs vector as the second vector. Therefore, all elements
  // in the second vector will be selected into the shuffle result.
  const auto *lhsExpr = dyn_cast<HLSLVectorElementExpr>(lhs);

  if (!lhsExpr)
    return nullptr;
  // Special case for <scalar-value>.x, which will have an AST of
  // HLSLVectorElementExpr whose base is an ImplicitCastExpr
  // (CK_HLSLVectorSplat). We just need to assign to <scalar-value>
  // for such cases.
  if (const auto *baseCast = dyn_cast<CastExpr>(lhsExpr->getBase()))
    if (baseCast->getCastKind() == CastKind::CK_HLSLVectorSplat &&
        hlsl::GetHLSLVecSize(baseCast->getType()) == 1)
      return processAssignment(baseCast->getSubExpr(), rhs, false);

  const Expr *base = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(lhsExpr, &base, &accessor);

  const QualType baseType = base->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);
  const auto accessorSize = accessor.Count;

  // Whether we are selecting the whole original vector
  bool isSelectOrigin = accessorSize == baseSize;

  // Assigning to one component
  if (accessorSize == 1) {
    if (isBufferTextureIndexing(dyn_cast_or_null<CXXOperatorCallExpr>(base))) {
      // Assigning to one component of a RWBuffer/RWTexture element.
      // We need to use OpImageWrite here: compose the new vector value first.
      auto *oldVec = doExpr(base);
      auto *newVec = spvBuilder.createCompositeInsert(
          baseType, oldVec, {accessor.Swz0}, rhs, lhs->getLocStart());
      auto *result = tryToAssignToRWBufferRWTexture(base, newVec);
      assert(result); // Definitely RWBuffer/RWTexture assignment
      (void)result;
      return rhs; // TODO: incorrect for compound assignments
    } else {
      // Assigning to one component of a mesh out attribute/indices vector
      // object.
      SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
          astContext.UnsignedIntTy, llvm::APInt(32, accessor.Swz0));
      if (tryToAssignToMSOutAttrsOrIndices(base, rhs, vecComponent))
        return rhs;
      // Assigning to one normal vector component. Nothing special, just fall
      // back to the normal CodeGen path.
      return nullptr;
    }
  }
  if (isSelectOrigin) {
    for (uint32_t i = 0; i < accessorSize; ++i) {
      uint32_t position;
      accessor.GetPosition(i, &position);
      if (position != i)
        isSelectOrigin = false;
    }
  }

  // Assigning to the original vector
  if (isSelectOrigin) {
    // Ignore this HLSLVectorElementExpr and dispatch to base
    return processAssignment(base, rhs, false);
  }

  if (tryToAssignToMSOutAttrsOrIndices(base, rhs, /*vecComponent=*/nullptr,
                                       /*noWriteBack=*/true)) {
    // Assigning to 'n' components of a mesh out attribute/indices vector
    // object.
    const QualType elemType =
        hlsl::GetHLSLVecElementType(rhs->getAstResultType());
    uint32_t i = 0;
    for (; i < accessor.Count; ++i) {
      auto *rhsElem = spvBuilder.createCompositeExtract(elemType, rhs, {i},
                                                        lhs->getLocStart());
      uint32_t position;
      accessor.GetPosition(i, &position);
      SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
          astContext.UnsignedIntTy, llvm::APInt(32, position));
      if (!tryToAssignToMSOutAttrsOrIndices(base, rhsElem, vecComponent))
        break;
    }
    assert(i == accessor.Count);
    return rhs;
  }
  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(baseSize);
  // Assume we are selecting all original elements first.
  for (uint32_t i = 0; i < baseSize; ++i) {
    selectors[i] = i;
  }

  // Now fix up the elements that actually got overwritten by the rhs vector.
  // Since we are using the rhs vector as the second vector, their indices
  // should be offset by the size of the lhs base vector.
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t position;
    accessor.GetPosition(i, &position);
    selectors[position] = baseSize + i;
  }

  auto *vec1 = doExpr(base);
  auto *vec1Val = vec1->isRValue() ? vec1
                                   : spvBuilder.createLoad(baseType, vec1,
                                                           base->getLocStart());
  auto *shuffle = spvBuilder.createVectorShuffle(baseType, vec1Val, rhs,
                                                 selectors, lhs->getLocStart());

  if (!tryToAssignToRWBufferRWTexture(base, shuffle))
    spvBuilder.createStore(vec1, shuffle, lhs->getLocStart());

  // TODO: this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}

SpirvInstruction *
SpirvEmitter::tryToAssignToRWBufferRWTexture(const Expr *lhs,
                                             SpirvInstruction *rhs) {
  const Expr *baseExpr = nullptr;
  const Expr *indexExpr = nullptr;
  const auto *lhsExpr = dyn_cast<CXXOperatorCallExpr>(lhs);
  if (isBufferTextureIndexing(lhsExpr, &baseExpr, &indexExpr)) {
    auto *loc = doExpr(indexExpr);
    const QualType imageType = baseExpr->getType();
    auto *baseInfo = doExpr(baseExpr);
    auto *image =
        spvBuilder.createLoad(imageType, baseInfo, baseExpr->getExprLoc());
    spvBuilder.createImageWrite(imageType, image, loc, rhs, lhs->getExprLoc());
    return rhs;
  }
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::tryToAssignToMatrixElements(const Expr *lhs,
                                          SpirvInstruction *rhs) {
  const auto *lhsExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
  if (!lhsExpr)
    return nullptr;

  const Expr *baseMat = lhsExpr->getBase();
  auto *base = doExpr(baseMat);
  const QualType elemType = hlsl::GetHLSLMatElementType(baseMat->getType());

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseMat->getType(), rowCount, colCount);

  // For each lhs element written to:
  //   1. Extract the corresponding rhs element using OpCompositeExtract
  //   2. Create an access chain for the lhs element using OpAccessChain
  //   3. Write using OpStore
  const auto accessor = lhsExpr->getEncodedElementAccess();
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, we are actually indexing into
    // a vector, so only one index is needed for such cases.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    llvm::SmallVector<SpirvInstruction *, 2> indexInstructions(indices.size(),
                                                               nullptr);
    for (uint32_t j = 0; j < indices.size(); ++j)
      indexInstructions[j] = spvBuilder.getConstantInt(
          astContext.IntTy, llvm::APInt(32, indices[j], true));

    // If we are writing to only one element, the rhs should already be a
    // scalar value.
    auto *rhsElem = rhs;
    if (accessor.Count > 1) {
      rhsElem = spvBuilder.createCompositeExtract(elemType, rhs, {i},
                                                  rhs->getSourceLocation());
    }

    // If the lhs is actually a matrix of size 1x1, we don't need the access
    // chain; base is already the dest pointer.
    auto *lhsElemPtr = base;
    if (!indexInstructions.empty()) {
      assert(!base->isRValue());
      // Get a pointer to the element via access chain
      lhsElemPtr = spvBuilder.createAccessChain(
          elemType, lhsElemPtr, indexInstructions, lhs->getLocStart());
    }

    spvBuilder.createStore(lhsElemPtr, rhsElem, lhs->getLocStart());
  }

  // TODO: this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}

SpirvInstruction *SpirvEmitter::tryToAssignToMSOutAttrsOrIndices(
    const Expr *lhs, SpirvInstruction *rhs, SpirvInstruction *vecComponent,
    bool noWriteBack) {
  // Early exit for non-mesh shaders.
  if (!spvContext.isMS())
    return nullptr;

  llvm::SmallVector<SpirvInstruction *, 4> indices;
  bool isMSOutAttribute = false;
  bool isMSOutAttributeBlock = false;
  bool isMSOutIndices = false;
  const Expr *base = collectArrayStructIndices(lhs, /*rawIndex*/ false,
                                               /*rawIndices*/ nullptr, &indices,
                                               &isMSOutAttribute);

  // Expecting at least one array index - early exit otherwise.
  if (!base || indices.empty())
    return nullptr;

  const DeclaratorDecl *varDecl = nullptr;
  if (isMSOutAttribute) {
    const MemberExpr *memberExpr = dyn_cast<MemberExpr>(base);
    assert(memberExpr);
    varDecl = cast<DeclaratorDecl>(memberExpr->getMemberDecl());
  } else {
    if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
      if ((varDecl = dyn_cast<DeclaratorDecl>(arg->getDecl()))) {
        if (varDecl->hasAttr<HLSLIndicesAttr>()) {
          isMSOutIndices = true;
        } else if (varDecl->hasAttr<HLSLVerticesAttr>() ||
                   varDecl->hasAttr<HLSLPrimitivesAttr>()) {
          isMSOutAttributeBlock = true;
        }
      }
    }
  }

  // Return if no out attribute or indices object was found.
  if (!(isMSOutAttribute || isMSOutAttributeBlock || isMSOutIndices)) {
    return nullptr;
  }

  // For noWriteBack, return without generating write instructions.
  if (noWriteBack) {
    return rhs;
  }

  // Add vecComponent to indices.
  if (vecComponent) {
    indices.push_back(vecComponent);
  }

  if (isMSOutAttribute) {
    assignToMSOutAttribute(varDecl, rhs, indices);
  } else if (isMSOutIndices) {
    assignToMSOutIndices(varDecl, rhs, indices);
  } else {
    assert(isMSOutAttributeBlock);
    QualType type = varDecl->getType();
    assert(isa<ConstantArrayType>(type));
    type = astContext.getAsConstantArrayType(type)->getElementType();
    assert(type->isStructureType());

    // Extract each subvalue and assign it to its corresponding member
    // attribute.
    const auto *structDecl = type->getAs<RecordType>()->getDecl();
    for (const auto *field : structDecl->fields()) {
      const auto fieldType = field->getType();
      SpirvInstruction *subValue = spvBuilder.createCompositeExtract(
          fieldType, rhs, {getNumBaseClasses(type) + field->getFieldIndex()},
          lhs->getLocStart());
      assignToMSOutAttribute(field, subValue, indices);
    }
  }

  // TODO: this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}

void SpirvEmitter::assignToMSOutAttribute(
    const DeclaratorDecl *decl, SpirvInstruction *value,
    const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
  assert(spvContext.isMS() && !indices.empty());

  // Extract the attribute index and vecComponent (if any).
  SpirvInstruction *attrIndex = indices.front();
  SpirvInstruction *vecComponent = nullptr;
  if (indices.size() > 1) {
    vecComponent = indices.back();
  }

  auto semanticInfo = declIdMapper.getStageVarSemantic(decl);
  assert(semanticInfo.isValid());
  const auto loc = decl->getLocation();

  // Specially handle writes to clip/cull distance attributes.
  if (!declIdMapper.glPerVertex.tryToAccess(
          hlsl::DXIL::SigPointKind::MSOut, semanticInfo.semantic->GetKind(),
          semanticInfo.index, attrIndex, &value, /*noWriteBack=*/false,
          vecComponent, loc)) {
    // All other attribute writes are handled below.
    auto *varInstr = declIdMapper.getStageVarInstruction(decl);
    QualType valueType = value->getAstResultType();
    varInstr = spvBuilder.createAccessChain(valueType, varInstr, indices, loc);
    spvBuilder.createStore(varInstr, value, loc);
  }
}

void SpirvEmitter::assignToMSOutIndices(
    const DeclaratorDecl *decl, SpirvInstruction *value,
    const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
  assert(spvContext.isMS() && !indices.empty());

  // Extract the vertex index and vecComponent (if any).
  SpirvInstruction *vertIndex = indices.front();
  SpirvInstruction *vecComponent = nullptr;
  if (indices.size() > 1) {
    vecComponent = indices.back();
  }

  auto *var = declIdMapper.getStageVarInstruction(decl);
  const auto *varTypeDecl = astContext.getAsConstantArrayType(decl->getType());
  QualType varType = varTypeDecl->getElementType();
  uint32_t numVertices = 1;
  if (!isVectorType(varType, nullptr, &numVertices)) {
    assert(isScalarType(varType));
  }
  QualType valueType = value->getAstResultType();
  uint32_t numValues = 1;
  if (!isVectorType(valueType, nullptr, &numValues)) {
    assert(isScalarType(valueType));
  }

  const auto loc = decl->getLocation();
  if (numVertices == 1) {
    // For "point" output topology.
    assert(numValues == 1);
    // Create an access chain for PrimitiveIndicesNV[vertIndex].
    auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
                                             {vertIndex}, loc);
    // Finally create a store for PrimitiveIndicesNV[vertIndex] = value.
    spvBuilder.createStore(ptr, value, loc);
  } else {
    // For "line" or "triangle" output topology.
    assert(numVertices == 2 || numVertices == 3);
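    // e.g. (illustrative) for "triangle" topology, primitive i occupies
    // PrimitiveIndicesNV[3*i .. 3*i+2], so writing a uint3 for primitive 2
    // stores to offsets 6, 7, and 8.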
    // Set baseOffset = vertIndex * numVertices.
    auto *baseOffset = spvBuilder.createBinaryOp(
        spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                  llvm::APInt(32, numVertices)),
        loc);
    if (vecComponent) {
      // Write an individual vector component of uint2 or uint3.
      assert(numValues == 1);
      // Set baseOffset = baseOffset + vecComponent.
      baseOffset =
          spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
                                    baseOffset, vecComponent, loc);
      // Create an access chain for PrimitiveIndicesNV[baseOffset].
      auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
                                               {baseOffset}, loc);
      // Finally create a store for PrimitiveIndicesNV[baseOffset] = value.
      spvBuilder.createStore(ptr, value, loc);
    } else {
      // Write all vector components of uint2 or uint3.
      assert(numValues == numVertices);
      auto *curOffset = baseOffset;
      for (uint32_t i = 0; i < numValues; ++i) {
        if (i != 0) {
          // Set curOffset = baseOffset + i.
          curOffset = spvBuilder.createBinaryOp(
              spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
              spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                        llvm::APInt(32, i)),
              loc);
        }
        // Create an access chain for PrimitiveIndicesNV[curOffset].
        auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
                                                 {curOffset}, loc);
        // Finally create a store for PrimitiveIndicesNV[curOffset] = value[i].
        spvBuilder.createStore(ptr,
                               spvBuilder.createCompositeExtract(
                                   astContext.UnsignedIntTy, value, {i}, loc),
                               loc);
      }
    }
  }
}

SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
    const Expr *matrix, SpirvInstruction *matrixVal,
    llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
                                          SpirvInstruction *)>
        actOnEachVector,
    SourceLocation loc) {
  const auto matType = matrix->getType();
  assert(isMxNMatrix(matType));
  const QualType vecType = getComponentVectorType(astContext, matType);
  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);

  llvm::SmallVector<SpirvInstruction *, 4> vectors;
  // Extract each component vector and do the operation on it
  for (uint32_t i = 0; i < rowCount; ++i) {
    auto *lhsVec = spvBuilder.createCompositeExtract(vecType, matrixVal, {i},
                                                     matrix->getLocStart());
    vectors.push_back(actOnEachVector(i, vecType, lhsVec));
  }

  // Construct the result matrix
  auto *val = spvBuilder.createCompositeConstruct(matType, vectors, loc);
  val->setRValue();
  return val;
}

void SpirvEmitter::createSpecConstant(const VarDecl *varDecl) {
  class SpecConstantEnvRAII {
  public:
    // Creates a new instance which sets mode to true on creation,
    // and resets mode to false on destruction.
    SpecConstantEnvRAII(bool *mode) : modeSlot(mode) { *modeSlot = true; }
    ~SpecConstantEnvRAII() { *modeSlot = false; }

  private:
    bool *modeSlot;
  };

  const QualType varType = varDecl->getType();

  bool hasError = false;

  if (!varDecl->isExternallyVisible()) {
    emitError("specialization constant must be externally visible",
              varDecl->getLocation());
    hasError = true;
  }

  if (const auto *builtinType = varType->getAs<BuiltinType>()) {
    switch (builtinType->getKind()) {
    case BuiltinType::Bool:
    case BuiltinType::Int:
    case BuiltinType::UInt:
    case BuiltinType::Float:
      break;
    default:
      emitError("unsupported specialization constant type",
                varDecl->getLocStart());
      hasError = true;
    }
  }

  const auto *init = varDecl->getInit();

  if (!init) {
    emitError("missing default value for specialization constant",
              varDecl->getLocation());
    hasError = true;
  } else if (!isAcceptedSpecConstantInit(init)) {
    emitError("unsupported specialization constant initializer",
              init->getLocStart())
        << init->getSourceRange();
    hasError = true;
  }

  if (hasError)
    return;

  SpecConstantEnvRAII specConstantEnvRAII(&isSpecConstantMode);

  const auto specConstant = doExpr(init);

  // We are not creating a variable to hold the spec constant; instead, we
  // translate the varDecl directly into the spec constant here.
  spvBuilder.decorateSpecId(
      specConstant, varDecl->getAttr<VKConstantIdAttr>()->getSpecConstId(),
      varDecl->getLocation());

  specConstant->setDebugName(varDecl->getName());
  declIdMapper.registerSpecConstant(varDecl, specConstant);
}

SpirvInstruction *
SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                    const BinaryOperatorKind opcode,
                                    SourceRange range, SourceLocation loc) {
  // TODO: some code is duplicated from processBinaryOp. Try to unify them.
  const auto lhsType = lhs->getType();
  assert(isMxNMatrix(lhsType));
  const spv::Op spvOp = translateOp(opcode, lhsType);

  SpirvInstruction *rhsVal = nullptr, *lhsPtr = nullptr, *lhsVal = nullptr;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = doExpr(rhs);
    lhsPtr = doExpr(lhs);
    lhsVal = spvBuilder.createLoad(lhsType, lhsPtr, lhs->getLocStart());
  } else {
    // Evaluate lhs before rhs
    lhsVal = lhsPtr = doExpr(lhs);
    rhsVal = doExpr(rhs);
  }
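
  // SPIR-V arithmetic instructions operate on scalars and vectors, not on
  // matrices, so e.g. (illustrative) "m1 + m2" for a float2x3 decomposes
  // into two OpFAdd's on v3float rows followed by an OpCompositeConstruct.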
  switch (opcode) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign: {
    const auto actOnEachVec = [this, spvOp, rhsVal, rhs,
                               loc](uint32_t index, QualType vecType,
                                    SpirvInstruction *lhsVec) {
      // For each vector of lhs, we need to load the corresponding vector of
      // rhs and do the operation on them.
      auto *rhsVec = spvBuilder.createCompositeExtract(vecType, rhsVal, {index},
                                                       rhs->getLocStart());
      auto *val =
          spvBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec, loc);
      val->setRValue();
      return val;
    };
    return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec,
                                     lhs->getLocStart());
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
  default:
    break;
  }

  emitError("binary operator '%0' over matrix type unimplemented",
            lhs->getExprLoc())
      << BinaryOperator::getOpcodeStr(opcode) << range;
  return nullptr;
}

const Expr *SpirvEmitter::collectArrayStructIndices(
    const Expr *expr, bool rawIndex,
    llvm::SmallVectorImpl<uint32_t> *rawIndices,
    llvm::SmallVectorImpl<SpirvInstruction *> *indices,
    bool *isMSOutAttribute) {
  assert((rawIndex && rawIndices) || (!rawIndex && indices));

  if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
    // First check whether this is referring to a static member. If it is, we
    // create a DeclRefExpr for it.
    if (auto *varDecl = dyn_cast<VarDecl>(indexing->getMemberDecl()))
      if (varDecl->isStaticDataMember())
        return DeclRefExpr::Create(
            astContext, NestedNameSpecifierLoc(), SourceLocation(), varDecl,
            /*RefersToEnclosingVariableOrCapture=*/false, SourceLocation(),
            varDecl->getType(), VK_LValue);

    const Expr *base = collectArrayStructIndices(
        indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
        rawIndices, indices, isMSOutAttribute);

    if (isMSOutAttribute && base) {
      if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
        if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
          if (varDecl->hasAttr<HLSLVerticesAttr>() ||
              varDecl->hasAttr<HLSLPrimitivesAttr>()) {
            assert(spvContext.isMS());
            *isMSOutAttribute = true;
            return expr;
          }
        }
      }
    }

    // Append the index of the current level
    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
    assert(fieldDecl);
    // If we are accessing a derived struct, we need to account for the number
    // of base structs, since they are placed as fields at the beginning of the
    // derived struct.
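    // e.g. (illustrative) for "struct Derived : Base { float f; }", member
    // 'f' lives at index getNumBaseClasses(baseType) + fieldIndex = 1 + 0 = 1,
    // since the Base sub-object occupies field 0.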
    auto baseType = indexing->getBase()->getType();
    if (baseType->isPointerType()) {
      baseType = baseType->getPointeeType();
    }
    const uint32_t index =
        getNumBaseClasses(baseType) + fieldDecl->getFieldIndex();
    if (rawIndex) {
      rawIndices->push_back(index);
    } else {
      indices->push_back(spvBuilder.getConstantInt(
          astContext.IntTy, llvm::APInt(32, index, true)));
    }

    return base;
  }
  if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
    if (rawIndex)
      return nullptr; // TODO: handle constant array index

    // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
    // cast. We need to ignore it to avoid creating OpLoad.
    const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
    const Expr *base = collectArrayStructIndices(thisBase, rawIndex, rawIndices,
                                                 indices, isMSOutAttribute);

    // The index into an array must be an integer number.
    const auto *idxExpr = indexing->getIdx();
    const auto idxExprType = idxExpr->getType();
    SpirvInstruction *thisIndex = doExpr(idxExpr);
    if (!idxExprType->isIntegerType() || idxExprType->isBooleanType()) {
      thisIndex = castToInt(thisIndex, idxExprType, astContext.UnsignedIntTy,
                            idxExpr->getExprLoc());
    }
    indices->push_back(thisIndex);
    return base;
  }
  if (const auto *indexing = dyn_cast<CXXOperatorCallExpr>(expr))
    if (indexing->getOperator() == OverloadedOperatorKind::OO_Subscript) {
      if (rawIndex)
        return nullptr; // TODO: handle constant array index

      // If this is indexing into resources, we need specific OpImage*
      // instructions for accessing. Return directly to avoid further building
      // up the access chain.
      if (isBufferTextureIndexing(indexing))
        return indexing;

      const Expr *thisBase =
          indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
      const auto thisBaseType = thisBase->getType();
      const Expr *base = collectArrayStructIndices(
          thisBase, rawIndex, rawIndices, indices, isMSOutAttribute);

      if (thisBaseType != base->getType() &&
          isAKindOfStructuredOrByteBuffer(thisBaseType)) {
        // The immediate base is a kind of structured or byte buffer. It should
        // be an alias variable. Break the normal index collecting chain.
        // Return the immediate base as the base so that we can apply other
        // hacks for legalization over it.
        //
        // Note: legalization specific code
        indices->clear();
        base = thisBase;
      }

      // If the base is a StructureType, we need to push an additional index 0
      // here. This is because we created an additional OpTypeRuntimeArray
      // in the structure.
      if (isStructuredBuffer(thisBaseType))
        indices->push_back(
            spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)));

      if ((hlsl::IsHLSLVecType(thisBaseType) &&
           (hlsl::GetHLSLVecSize(thisBaseType) == 1)) ||
          is1x1Matrix(thisBaseType) || is1xNMatrix(thisBaseType)) {
        // If this is a size-1 vector or 1xN matrix, ignore the index.
      } else {
        indices->push_back(doExpr(indexing->getArg(1)));
      }
      return base;
    }
  {
    const Expr *index = nullptr;
    // TODO: the following is duplicating the logic in doCXXMemberCallExpr.
    if (const auto *object = isStructuredBufferLoad(expr, &index)) {
      if (rawIndex)
        return nullptr; // TODO: handle constant array index

      // For object.Load(index), there should be no more indexing into the
      // object.
      indices->push_back(
          spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)));
      indices->push_back(doExpr(index));
      return object;
    }
  }

  {
    // Indexing into ConstantBuffers and TextureBuffers involves an additional
    // FlatConversion node which casts the handle to the underlying structure
    // type. We can look past the FlatConversion to continue to collect indices.
    // For example: MyConstantBufferArray[0].structMember1
    // `-MemberExpr .structMember1
    //   `-ImplicitCastExpr 'const T' lvalue <FlatConversion>
    //     `-ArraySubscriptExpr 'ConstantBuffer<T>':'ConstantBuffer<T>' lvalue
    if (auto *castExpr = dyn_cast<ImplicitCastExpr>(expr)) {
      if (castExpr->getCastKind() == CK_FlatConversion) {
        const auto *subExpr = castExpr->getSubExpr();
        const QualType subExprType = subExpr->getType();
        if (isConstantTextureBuffer(subExprType)) {
          return collectArrayStructIndices(subExpr, rawIndex, rawIndices,
                                           indices, isMSOutAttribute);
        }
      }
    }
  }

  // This is the deepest we can go. No more array or struct indexing.
  return expr;
}

SpirvInstruction *SpirvEmitter::turnIntoElementPtr(
    QualType baseType, SpirvInstruction *base, QualType elemType,
    const llvm::SmallVector<SpirvInstruction *, 4> &indices,
    SourceLocation loc) {
  // If this is an rvalue, we need a temporary object to hold it
  // so that we can get an access chain from it.
  const bool needTempVar = base->isRValue();
  SpirvInstruction *accessChainBase = base;

  if (needTempVar) {
    auto varName = getAstTypeName(baseType);
    const auto var = createTemporaryVar(baseType, varName, base, loc);
    var->setLayoutRule(SpirvLayoutRule::Void);
    var->setStorageClass(spv::StorageClass::Function);
    var->setContainsAliasComponent(base->containsAliasComponent());
    accessChainBase = var;
  }

  base = spvBuilder.createAccessChain(elemType, accessChainBase, indices, loc);

  // This part may seem weird, but it is intended:
  // If the base is originally an rvalue, the whole AST involving the base
  // is consistently set up to handle rvalues. By copying the base into
  // a temporary variable and grabbing an access chain from it, we are breaking
  // the consistency by turning the base from an rvalue into an lvalue. Keep in
  // mind that there will be no LValueToRValue casts in the AST for us
  // to rely on to load the access chain if an rvalue is expected. Therefore,
  // we must do the load here; otherwise, it would be up to the consumer of
  // this access chain to do the load, and that could be anywhere.
  if (needTempVar) {
    base = spvBuilder.createLoad(elemType, base, loc);
  }

  return base;
}

SpirvInstruction *SpirvEmitter::castToBool(SpirvInstruction *fromVal,
                                           QualType fromType,
                                           QualType toBoolType,
                                           SourceLocation loc) {
  if (isSameType(astContext, fromType, toBoolType))
    return fromVal;

  { // Special case handling for converting to a matrix of booleans.
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(fromType, &elemType, &rowCount, &colCount)) {
      const auto fromRowQualType =
          astContext.getExtVectorType(elemType, colCount);
      const auto toBoolRowQualType =
          astContext.getExtVectorType(astContext.BoolTy, colCount);
      llvm::SmallVector<SpirvInstruction *, 4> rows;
      for (uint32_t i = 0; i < rowCount; ++i) {
        auto *row = spvBuilder.createCompositeExtract(fromRowQualType, fromVal,
                                                      {i}, loc);
        rows.push_back(
            castToBool(row, fromRowQualType, toBoolRowQualType, loc));
      }
      return spvBuilder.createCompositeConstruct(toBoolType, rows, loc);
    }
  }
  // Converting to bool means comparing with the value zero.
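  // e.g. (illustrative) casting an int %v to bool becomes
  // OpINotEqual %bool %v %int_0.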
  const spv::Op spvOp = translateOp(BO_NE, fromType);
  auto *zeroVal = getValueZero(fromType);
  return spvBuilder.createBinaryOp(spvOp, toBoolType, fromVal, zeroVal, loc);
}

SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
                                          QualType fromType, QualType toIntType,
                                          SourceLocation srcLoc) {
  if (isEnumType(fromType))
    fromType = astContext.IntTy;

  if (isSameType(astContext, fromType, toIntType))
    return fromVal;
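  // Casting from bool selects between one and zero, e.g. (illustrative)
  // bool -> int becomes OpSelect %int %cond %int_1 %int_0.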
  if (isBoolOrVecOfBoolType(fromType)) {
    auto *one = getValueOne(toIntType);
    auto *zero = getValueZero(toIntType);
    return spvBuilder.createSelect(toIntType, fromVal, one, zero, srcLoc);
  }

  if (isSintOrVecOfSintType(fromType) || isUintOrVecOfUintType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    QualType convertedType = {};
    fromVal =
        convertBitwidth(fromVal, srcLoc, fromType, toIntType, &convertedType);
    // If bitwidth conversion was the only thing we needed to do, we're done.
    if (isSameScalarOrVecType(convertedType, toIntType))
      return fromVal;
    return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toIntType, fromVal,
                                    srcLoc);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    fromVal = convertBitwidth(fromVal, srcLoc, fromType, toIntType);
    if (isSintOrVecOfSintType(toIntType)) {
      return spvBuilder.createUnaryOp(spv::Op::OpConvertFToS, toIntType,
                                      fromVal, srcLoc);
    } else if (isUintOrVecOfUintType(toIntType)) {
      return spvBuilder.createUnaryOp(spv::Op::OpConvertFToU, toIntType,
                                      fromVal, srcLoc);
    } else {
      emitError("casting from floating point to integer unimplemented", srcLoc);
    }
  }
  {
    QualType elemType = {};
    uint32_t numRows = 0, numCols = 0;
    if (isMxNMatrix(fromType, &elemType, &numRows, &numCols)) {
      // The source matrix and the target matrix must have the same dimensions.
      QualType toElemType = {};
      uint32_t toNumRows = 0, toNumCols = 0;
      const bool isMat =
          isMxNMatrix(toIntType, &toElemType, &toNumRows, &toNumCols);
      assert(isMat && numRows == toNumRows && numCols == toNumCols);
      (void)isMat;
      (void)toNumRows;
      (void)toNumCols;

      // Casting to a matrix of integers: cast each row and construct a
      // composite.
      llvm::SmallVector<SpirvInstruction *, 4> castedRows;
      const QualType vecType = getComponentVectorType(astContext, fromType);
      const auto fromVecQualType =
          astContext.getExtVectorType(elemType, numCols);
      const auto toIntVecQualType =
          astContext.getExtVectorType(toElemType, numCols);
      for (uint32_t row = 0; row < numRows; ++row) {
        auto *rowId =
            spvBuilder.createCompositeExtract(vecType, fromVal, {row}, srcLoc);
        castedRows.push_back(
            castToInt(rowId, fromVecQualType, toIntVecQualType, srcLoc));
      }
      return spvBuilder.createCompositeConstruct(toIntType, castedRows, srcLoc);
    }
  }

  return nullptr;
}

SpirvInstruction *SpirvEmitter::convertBitwidth(SpirvInstruction *fromVal,
                                                SourceLocation loc,
                                                QualType fromType,
                                                QualType toType,
                                                QualType *resultType) {
  // At the moment, we will not make bitwidth conversions to/from literal int
  // and literal float types because they do not represent the intended SPIR-V
  // bitwidth.
  if (isLitTypeOrVecOfLitType(fromType) || isLitTypeOrVecOfLitType(toType))
    return fromVal;

  const auto fromBitwidth = getElementSpirvBitwidth(
      astContext, fromType, spirvOptions.enable16BitTypes);
  const auto toBitwidth = getElementSpirvBitwidth(
      astContext, toType, spirvOptions.enable16BitTypes);

  if (fromBitwidth == toBitwidth) {
    if (resultType)
      *resultType = fromType;
    return fromVal;
  }
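
  // e.g. (illustrative) widening a 16-bit half operand to a 32-bit float
  // yields OpFConvert, while a signed 16-bit int widened to 32 bits yields
  // OpSConvert.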
  // We want the 'fromType' with the 'toBitwidth'.
  const QualType targetType =
      getTypeWithCustomBitwidth(astContext, fromType, toBitwidth);
  if (resultType)
    *resultType = targetType;

  if (isFloatOrVecOfFloatType(fromType))
    return spvBuilder.createUnaryOp(spv::Op::OpFConvert, targetType, fromVal,
                                    loc);
  if (isSintOrVecOfSintType(fromType))
    return spvBuilder.createUnaryOp(spv::Op::OpSConvert, targetType, fromVal,
                                    loc);
  if (isUintOrVecOfUintType(fromType))
    return spvBuilder.createUnaryOp(spv::Op::OpUConvert, targetType, fromVal,
                                    loc);
  llvm_unreachable("invalid type passed to convertBitwidth");
}

SpirvInstruction *SpirvEmitter::castToFloat(SpirvInstruction *fromVal,
                                            QualType fromType,
                                            QualType toFloatType,
                                            SourceLocation srcLoc) {
  if (isSameType(astContext, fromType, toFloatType))
    return fromVal;

  if (isBoolOrVecOfBoolType(fromType)) {
    auto *one = getValueOne(toFloatType);
    auto *zero = getValueZero(toFloatType);
    return spvBuilder.createSelect(toFloatType, fromVal, one, zero, srcLoc);
  }

  if (isSintOrVecOfSintType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    fromVal = convertBitwidth(fromVal, srcLoc, fromType, toFloatType);
    return spvBuilder.createUnaryOp(spv::Op::OpConvertSToF, toFloatType,
                                    fromVal, srcLoc);
  }

  if (isUintOrVecOfUintType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    fromVal = convertBitwidth(fromVal, srcLoc, fromType, toFloatType);
    return spvBuilder.createUnaryOp(spv::Op::OpConvertUToF, toFloatType,
                                    fromVal, srcLoc);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    // This is the case of float-to-float conversion with different bitwidths.
    return convertBitwidth(fromVal, srcLoc, fromType, toFloatType);
  }

  // Casting matrix types
  {
    QualType elemType = {};
    uint32_t numRows = 0, numCols = 0;
    if (isMxNMatrix(fromType, &elemType, &numRows, &numCols)) {
      // The source matrix and the target matrix must have the same dimensions.
      QualType toElemType = {};
      uint32_t toNumRows = 0, toNumCols = 0;
      const auto isMat =
          isMxNMatrix(toFloatType, &toElemType, &toNumRows, &toNumCols);
      assert(isMat && numRows == toNumRows && numCols == toNumCols);
      (void)isMat;
      (void)toNumRows;
      (void)toNumCols;

      // Casting to a matrix of floats: cast each row and construct a
      // composite.
      llvm::SmallVector<SpirvInstruction *, 4> castedRows;
      const QualType vecType = getComponentVectorType(astContext, fromType);
      const auto fromVecQualType =
          astContext.getExtVectorType(elemType, numCols);
      const auto toFloatVecQualType =
          astContext.getExtVectorType(toElemType, numCols);
      for (uint32_t row = 0; row < numRows; ++row) {
        auto *rowId =
            spvBuilder.createCompositeExtract(vecType, fromVal, {row}, srcLoc);
        castedRows.push_back(
            castToFloat(rowId, fromVecQualType, toFloatVecQualType, srcLoc));
      }
      return spvBuilder.createCompositeConstruct(toFloatType, castedRows,
                                                 srcLoc);
    }
  }

  emitError("casting to floating point unimplemented", srcLoc);
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
  const FunctionDecl *callee = callExpr->getDirectCallee();
  const SourceLocation srcLoc = callExpr->getExprLoc();
  assert(hlsl::IsIntrinsicOp(callee) &&
         "processIntrinsicCallExpr was called for a non-intrinsic function.");

  const bool isFloatType = isFloatOrVecMatOfFloatType(callExpr->getType());
  const bool isSintType = isSintOrVecMatOfSintType(callExpr->getType());

  // Figure out which intrinsic function to translate.
  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  hlsl::GetIntrinsicOp(callee, opcode, group);

  GLSLstd450 glslOpcode = GLSLstd450Bad;

  SpirvInstruction *retVal = nullptr;
  6273. #define INTRINSIC_SPIRV_OP_CASE(intrinsicOp, spirvOp, doEachVec) \
  6274. case hlsl::IntrinsicOp::IOP_##intrinsicOp: { \
  6275. retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp, \
  6276. doEachVec); \
  6277. } break
  6278. #define INTRINSIC_OP_CASE(intrinsicOp, glslOp, doEachVec) \
  6279. case hlsl::IntrinsicOp::IOP_##intrinsicOp: { \
  6280. glslOpcode = GLSLstd450::GLSLstd450##glslOp; \
  6281. retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec, \
  6282. srcLoc); \
  6283. } break
  6284. #define INTRINSIC_OP_CASE_INT_FLOAT(intrinsicOp, glslIntOp, glslFloatOp, \
  6285. doEachVec) \
  6286. case hlsl::IntrinsicOp::IOP_##intrinsicOp: { \
  6287. glslOpcode = isFloatType ? GLSLstd450::GLSLstd450##glslFloatOp \
  6288. : GLSLstd450::GLSLstd450##glslIntOp; \
  6289. retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec, \
  6290. srcLoc); \
  6291. } break
  6292. #define INTRINSIC_OP_CASE_SINT_UINT(intrinsicOp, glslSintOp, glslUintOp, \
  6293. doEachVec) \
  6294. case hlsl::IntrinsicOp::IOP_##intrinsicOp: { \
  6295. glslOpcode = isSintType ? GLSLstd450::GLSLstd450##glslSintOp \
  6296. : GLSLstd450::GLSLstd450##glslUintOp; \
  6297. retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec, \
  6298. srcLoc); \
  6299. } break
  6300. #define INTRINSIC_OP_CASE_SINT_UINT_FLOAT(intrinsicOp, glslSintOp, glslUintOp, \
  6301. glslFloatOp, doEachVec) \
  6302. case hlsl::IntrinsicOp::IOP_##intrinsicOp: { \
  6303. glslOpcode = isFloatType \
  6304. ? GLSLstd450::GLSLstd450##glslFloatOp \
  6305. : isSintType ? GLSLstd450::GLSLstd450##glslSintOp \
  6306. : GLSLstd450::GLSLstd450##glslUintOp; \
  6307. retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec, \
  6308. srcLoc); \
  6309. } break
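
// As an illustration (this expansion is not part of the original dispatch
// table), an entry like INTRINSIC_OP_CASE(sin, Sin, true) expands to:
//
//   case hlsl::IntrinsicOp::IOP_sin: {
//     glslOpcode = GLSLstd450::GLSLstd450Sin;
//     retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, true,
//                                            srcLoc);
//   } break;
//
// i.e. HLSL sin() maps to the GLSL std450 Sin instruction, applied per matrix
// row when doEachVec is true.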

  switch (const auto hlslOpcode = static_cast<hlsl::IntrinsicOp>(opcode)) {
  case hlsl::IntrinsicOp::IOP_InterlockedAdd:
  case hlsl::IntrinsicOp::IOP_InterlockedAnd:
  case hlsl::IntrinsicOp::IOP_InterlockedMax:
  case hlsl::IntrinsicOp::IOP_InterlockedUMax:
  case hlsl::IntrinsicOp::IOP_InterlockedMin:
  case hlsl::IntrinsicOp::IOP_InterlockedUMin:
  case hlsl::IntrinsicOp::IOP_InterlockedOr:
  case hlsl::IntrinsicOp::IOP_InterlockedXor:
  case hlsl::IntrinsicOp::IOP_InterlockedExchange:
  case hlsl::IntrinsicOp::IOP_InterlockedCompareStore:
  case hlsl::IntrinsicOp::IOP_InterlockedCompareExchange:
    retVal = processIntrinsicInterlockedMethod(callExpr, hlslOpcode);
    break;
  case hlsl::IntrinsicOp::IOP_NonUniformResourceIndex:
    retVal = processIntrinsicNonUniformResourceIndex(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_tex1D:
  case hlsl::IntrinsicOp::IOP_tex1Dbias:
  case hlsl::IntrinsicOp::IOP_tex1Dgrad:
  case hlsl::IntrinsicOp::IOP_tex1Dlod:
  case hlsl::IntrinsicOp::IOP_tex1Dproj:
  case hlsl::IntrinsicOp::IOP_tex2D:
  case hlsl::IntrinsicOp::IOP_tex2Dbias:
  case hlsl::IntrinsicOp::IOP_tex2Dgrad:
  case hlsl::IntrinsicOp::IOP_tex2Dlod:
  case hlsl::IntrinsicOp::IOP_tex2Dproj:
  case hlsl::IntrinsicOp::IOP_tex3D:
  case hlsl::IntrinsicOp::IOP_tex3Dbias:
  case hlsl::IntrinsicOp::IOP_tex3Dgrad:
  case hlsl::IntrinsicOp::IOP_tex3Dlod:
  case hlsl::IntrinsicOp::IOP_tex3Dproj:
  case hlsl::IntrinsicOp::IOP_texCUBE:
  case hlsl::IntrinsicOp::IOP_texCUBEbias:
  case hlsl::IntrinsicOp::IOP_texCUBEgrad:
  case hlsl::IntrinsicOp::IOP_texCUBElod:
  case hlsl::IntrinsicOp::IOP_texCUBEproj: {
    emitError("deprecated %0 intrinsic function will not be supported", srcLoc)
        << callee->getName();
    return nullptr;
  }
  case hlsl::IntrinsicOp::IOP_dot:
    retVal = processIntrinsicDot(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_GroupMemoryBarrier:
    retVal = processIntrinsicMemoryBarrier(callExpr,
                                           /*isDevice*/ false,
                                           /*groupSync*/ false,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_GroupMemoryBarrierWithGroupSync:
    retVal = processIntrinsicMemoryBarrier(callExpr,
                                           /*isDevice*/ false,
                                           /*groupSync*/ true,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_DeviceMemoryBarrier:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ false,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_DeviceMemoryBarrierWithGroupSync:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ true,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_AllMemoryBarrier:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ false,
                                           /*isAllBarrier*/ true);
    break;
  case hlsl::IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ true,
                                           /*isAllBarrier*/ true);
    break;
  case hlsl::IntrinsicOp::IOP_CheckAccessFullyMapped:
    retVal = spvBuilder.createImageSparseTexelsResident(
        doExpr(callExpr->getArg(0)), srcLoc);
    break;
  case hlsl::IntrinsicOp::IOP_mul:
  case hlsl::IntrinsicOp::IOP_umul:
    retVal = processIntrinsicMul(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_all:
    retVal = processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
    break;
  case hlsl::IntrinsicOp::IOP_any:
    retVal = processIntrinsicAllOrAny(callExpr, spv::Op::OpAny);
    break;
  case hlsl::IntrinsicOp::IOP_asdouble:
  case hlsl::IntrinsicOp::IOP_asfloat:
  case hlsl::IntrinsicOp::IOP_asint:
  case hlsl::IntrinsicOp::IOP_asuint:
    retVal = processIntrinsicAsType(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_clip:
    retVal = processIntrinsicClip(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_dst:
    retVal = processIntrinsicDst(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_clamp:
  case hlsl::IntrinsicOp::IOP_uclamp:
    retVal = processIntrinsicClamp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_frexp:
    retVal = processIntrinsicFrexp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_ldexp:
    retVal = processIntrinsicLdexp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_lit:
    retVal = processIntrinsicLit(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_mad:
  case hlsl::IntrinsicOp::IOP_umad:
    retVal = processIntrinsicMad(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_modf:
    retVal = processIntrinsicModf(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_msad4:
    retVal = processIntrinsicMsad4(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_printf:
    retVal = processIntrinsicPrintf(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_sign: {
    if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
      retVal = processIntrinsicFloatSign(callExpr);
    else
      retVal =
          processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450SSign,
                                        /*actPerRowForMatrices*/ true, srcLoc);
  } break;
  case hlsl::IntrinsicOp::IOP_D3DCOLORtoUBYTE4:
    retVal = processD3DCOLORtoUBYTE4(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_isfinite:
    retVal = processIntrinsicIsFinite(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_sincos:
    retVal = processIntrinsicSinCos(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_rcp:
    retVal = processIntrinsicRcp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_saturate:
    retVal = processIntrinsicSaturate(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_log10:
    retVal = processIntrinsicLog10(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_f16tof32:
    retVal = processIntrinsicF16ToF32(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_f32tof16:
    retVal = processIntrinsicF32ToF16(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_WaveGetLaneCount: {
    featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneCount",
                                    srcLoc);
    const QualType retType = callExpr->getCallReturnType(astContext);
    auto *var =
        declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize, retType, srcLoc);
    retVal = spvBuilder.createLoad(retType, var, srcLoc);
  } break;
  case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
    featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneIndex",
                                    srcLoc);
    const QualType retType = callExpr->getCallReturnType(astContext);
    auto *var = declIdMapper.getBuiltinVar(
        spv::BuiltIn::SubgroupLocalInvocationId, retType, srcLoc);
    retVal = spvBuilder.createLoad(retType, var, srcLoc);
  } break;
  case hlsl::IntrinsicOp::IOP_WaveIsFirstLane:
    retVal = processWaveQuery(callExpr, spv::Op::OpGroupNonUniformElect);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveAllTrue:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAll);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveAnyTrue:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAny);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveBallot:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformBallot);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveAllEqual:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
    retVal = processWaveCountBits(callExpr, spv::GroupOperation::Reduce);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
  case hlsl::IntrinsicOp::IOP_WaveActiveSum:
  case hlsl::IntrinsicOp::IOP_WaveActiveUProduct:
  case hlsl::IntrinsicOp::IOP_WaveActiveProduct:
  case hlsl::IntrinsicOp::IOP_WaveActiveUMax:
  case hlsl::IntrinsicOp::IOP_WaveActiveMax:
  case hlsl::IntrinsicOp::IOP_WaveActiveUMin:
  case hlsl::IntrinsicOp::IOP_WaveActiveMin:
  case hlsl::IntrinsicOp::IOP_WaveActiveBitAnd:
  case hlsl::IntrinsicOp::IOP_WaveActiveBitOr:
  case hlsl::IntrinsicOp::IOP_WaveActiveBitXor: {
    const auto retType = callExpr->getCallReturnType(astContext);
    retVal = processWaveReductionOrPrefix(
        callExpr, translateWaveOp(hlslOpcode, retType, srcLoc),
        spv::GroupOperation::Reduce);
  } break;
  case hlsl::IntrinsicOp::IOP_WavePrefixUSum:
  case hlsl::IntrinsicOp::IOP_WavePrefixSum:
  case hlsl::IntrinsicOp::IOP_WavePrefixUProduct:
  case hlsl::IntrinsicOp::IOP_WavePrefixProduct: {
    const auto retType = callExpr->getCallReturnType(astContext);
    retVal = processWaveReductionOrPrefix(
        callExpr, translateWaveOp(hlslOpcode, retType, srcLoc),
        spv::GroupOperation::ExclusiveScan);
  } break;
  case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
    retVal = processWaveCountBits(callExpr, spv::GroupOperation::ExclusiveScan);
    break;
  case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
    retVal = processWaveBroadcast(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
    retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
    break;
  case hlsl::IntrinsicOp::IOP_abort:
  case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
  case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
    emitError("no equivalent for %0 intrinsic function in Vulkan", srcLoc)
        << callee->getName();
    return nullptr;
  }
  case hlsl::IntrinsicOp::IOP_transpose: {
    const Expr *mat = callExpr->getArg(0);
    const QualType matType = mat->getType();
    if (isVectorType(matType) || isScalarType(matType)) {
      // A 1xN or Nx1 or 1x1 matrix is a SPIR-V vector/scalar, and its
      // transpose is the vector/scalar itself.
      retVal = doExpr(mat);
    } else {
      if (hlsl::GetHLSLMatElementType(matType)->isFloatingType())
        retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpTranspose,
                                                false);
      else
        retVal = processNonFpMatrixTranspose(matType, doExpr(mat), srcLoc);
    }
    break;
  }
  // DXR raytracing intrinsics
  case hlsl::IntrinsicOp::IOP_DispatchRaysDimensions:
  case hlsl::IntrinsicOp::IOP_DispatchRaysIndex:
  case hlsl::IntrinsicOp::IOP_GeometryIndex:
  case hlsl::IntrinsicOp::IOP_HitKind:
  case hlsl::IntrinsicOp::IOP_InstanceIndex:
  case hlsl::IntrinsicOp::IOP_InstanceID:
  case hlsl::IntrinsicOp::IOP_ObjectRayDirection:
  case hlsl::IntrinsicOp::IOP_ObjectRayOrigin:
  case hlsl::IntrinsicOp::IOP_ObjectToWorld3x4:
  case hlsl::IntrinsicOp::IOP_ObjectToWorld4x3:
  case hlsl::IntrinsicOp::IOP_PrimitiveIndex:
  case hlsl::IntrinsicOp::IOP_RayFlags:
  case hlsl::IntrinsicOp::IOP_RayTCurrent:
  case hlsl::IntrinsicOp::IOP_RayTMin:
  case hlsl::IntrinsicOp::IOP_WorldRayDirection:
  case hlsl::IntrinsicOp::IOP_WorldRayOrigin:
  case hlsl::IntrinsicOp::IOP_WorldToObject3x4:
  case hlsl::IntrinsicOp::IOP_WorldToObject4x3: {
    retVal = processRayBuiltins(callExpr, hlslOpcode);
    break;
  }
  case hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch:
  case hlsl::IntrinsicOp::IOP_IgnoreHit: {
    // Any modifications made to the ray payload in an any-hit shader must be
    // preserved when calling AcceptHitAndEndSearch/IgnoreHit. Write the
    // results out to the payload, which is visible only in entry functions.
    const auto iter = functionInfoMap.find(curFunction);
    if (iter != functionInfoMap.end()) {
      const auto &entryInfo = iter->second;
      if (entryInfo->isEntryFunction) {
        const auto payloadArg = curFunction->getParamDecl(0);
        const auto payloadArgInst =
            declIdMapper.getDeclEvalInfo(payloadArg, payloadArg->getLocStart());
        auto tempLoad = spvBuilder.createLoad(
            payloadArg->getType(), payloadArgInst, payloadArg->getLocStart());
        spvBuilder.createStore(currentRayPayload, tempLoad,
                               callExpr->getExprLoc());
      }
    }
    spvBuilder.createRayTracingOpsNV(
        hlslOpcode == hlsl::IntrinsicOp::IOP_AcceptHitAndEndSearch
            ? spv::Op::OpTerminateRayNV
            : spv::Op::OpIgnoreIntersectionNV,
        QualType(), {}, srcLoc);
    break;
  }
  case hlsl::IntrinsicOp::IOP_ReportHit: {
    retVal = processReportHit(callExpr);
    break;
  }
  case hlsl::IntrinsicOp::IOP_TraceRay: {
    processTraceRay(callExpr);
    break;
  }
  case hlsl::IntrinsicOp::IOP_CallShader: {
    processCallShader(callExpr);
    break;
  }
  case hlsl::IntrinsicOp::IOP_DispatchMesh: {
    processDispatchMesh(callExpr);
    break;
  }
  case hlsl::IntrinsicOp::IOP_SetMeshOutputCounts: {
    processMeshOutputCounts(callExpr);
    break;
  }
    INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
    INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
    INTRINSIC_SPIRV_OP_CASE(ddx_fine, DPdxFine, false);
    INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
    INTRINSIC_SPIRV_OP_CASE(ddy_coarse, DPdyCoarse, false);
    INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
    INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
    INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
    INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
    INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
    INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
    INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
    INTRINSIC_OP_CASE(round, Round, true);
    INTRINSIC_OP_CASE(uabs, SAbs, true);
    INTRINSIC_OP_CASE_INT_FLOAT(abs, SAbs, FAbs, true);
    INTRINSIC_OP_CASE(acos, Acos, true);
    INTRINSIC_OP_CASE(asin, Asin, true);
    INTRINSIC_OP_CASE(atan, Atan, true);
    INTRINSIC_OP_CASE(atan2, Atan2, true);
    INTRINSIC_OP_CASE(ceil, Ceil, true);
    INTRINSIC_OP_CASE(cos, Cos, true);
    INTRINSIC_OP_CASE(cosh, Cosh, true);
    INTRINSIC_OP_CASE(cross, Cross, false);
    INTRINSIC_OP_CASE(degrees, Degrees, true);
    INTRINSIC_OP_CASE(distance, Distance, false);
    INTRINSIC_OP_CASE(determinant, Determinant, false);
    INTRINSIC_OP_CASE(exp, Exp, true);
    INTRINSIC_OP_CASE(exp2, Exp2, true);
    INTRINSIC_OP_CASE_SINT_UINT(firstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE_SINT_UINT(ufirstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE(faceforward, FaceForward, false);
    INTRINSIC_OP_CASE(firstbitlow, FindILsb, false);
    INTRINSIC_OP_CASE(floor, Floor, true);
    INTRINSIC_OP_CASE(fma, Fma, true);
    INTRINSIC_OP_CASE(frac, Fract, true);
    INTRINSIC_OP_CASE(length, Length, false);
    INTRINSIC_OP_CASE(lerp, FMix, true);
    INTRINSIC_OP_CASE(log, Log, true);
    INTRINSIC_OP_CASE(log2, Log2, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(max, SMax, UMax, FMax, true);
    INTRINSIC_OP_CASE(umax, UMax, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(min, SMin, UMin, FMin, true);
    INTRINSIC_OP_CASE(umin, UMin, true);
    INTRINSIC_OP_CASE(normalize, Normalize, false);
    INTRINSIC_OP_CASE(pow, Pow, true);
    INTRINSIC_OP_CASE(radians, Radians, true);
    INTRINSIC_OP_CASE(reflect, Reflect, false);
    INTRINSIC_OP_CASE(refract, Refract, false);
    INTRINSIC_OP_CASE(rsqrt, InverseSqrt, true);
    INTRINSIC_OP_CASE(smoothstep, SmoothStep, true);
    INTRINSIC_OP_CASE(step, Step, true);
    INTRINSIC_OP_CASE(sin, Sin, true);
    INTRINSIC_OP_CASE(sinh, Sinh, true);
    INTRINSIC_OP_CASE(tan, Tan, true);
    INTRINSIC_OP_CASE(tanh, Tanh, true);
    INTRINSIC_OP_CASE(sqrt, Sqrt, true);
    INTRINSIC_OP_CASE(trunc, Trunc, true);
  default:
    emitError("%0 intrinsic function unimplemented", srcLoc)
        << callee->getName();
    return nullptr;
  }

#undef INTRINSIC_SPIRV_OP_CASE
#undef INTRINSIC_OP_CASE
#undef INTRINSIC_OP_CASE_INT_FLOAT
#undef INTRINSIC_OP_CASE_SINT_UINT
#undef INTRINSIC_OP_CASE_SINT_UINT_FLOAT

  if (retVal)
    retVal->setRValue();

  return retVal;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
                                                hlsl::IntrinsicOp opcode) {
  // The signatures of the intrinsic atomic methods are:
  //   void Interlocked*(in R dest, in T value);
  //   void Interlocked*(in R dest, in T value, out T original_value);

  // Note: ALL Interlocked*() methods are forced to have an unsigned integer
  // 'value'. Meaning, T is forced to be 'unsigned int'. If the provided
  // parameter is not an unsigned integer, the frontend inserts an
  // 'ImplicitCastExpr' to convert it to unsigned integer. OpAtomicIAdd (and
  // other SPIR-V OpAtomic* instructions) require the pointee of 'dest' to be
  // of the same type as T. This results in invalid SPIR-V if 'dest' is a
  // signed-integer-typed resource such as RWTexture1D<int>. For example, the
  // following OpAtomicIAdd is invalid because the pointee type defined in %1
  // is a signed integer, while the value passed to the atomic add (%3) is an
  // unsigned integer.
  //
  //   %_ptr_Image_int = OpTypePointer Image %int
  //   %1 = OpImageTexelPointer %_ptr_Image_int %RWTexture1D_int %index %uint_0
  //   %2 = OpLoad %int %value
  //   %3 = OpBitcast %uint %2   <-------- Inserted by the frontend
  //   %4 = OpAtomicIAdd %int %1 %uint_1 %uint_0 %3
  //
  // In such cases, we bypass the forced IntegralCast.
  // Moreover, the frontend does not add a cast AST node to cast uint to int
  // where necessary. To ensure SPIR-V validity, we add that where necessary.
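  //
  // For example (illustrative HLSL, not taken from a real test case), the
  // situation above arises for:
  //
  //   RWTexture1D<int> myTex;
  //   InterlockedAdd(myTex[0], 1);
  //
  // Here the frontend casts the value operand to uint, while the texel
  // pointer for 'dest' has a signed-integer pointee type.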
  auto *zero =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  const auto *dest = expr->getArg(0);
  const auto srcLoc = expr->getExprLoc();
  const auto baseType = dest->getType()->getCanonicalTypeUnqualified();

  if (!baseType->isIntegerType()) {
    emitError("can only perform atomic operations on scalar integer values",
              dest->getLocStart());
    return nullptr;
  }

  const auto doArg = [baseType, this](const CallExpr *callExpr,
                                      uint32_t argIndex) {
    const Expr *valueExpr = callExpr->getArg(argIndex);
    if (const auto *castExpr = dyn_cast<ImplicitCastExpr>(valueExpr))
      if (castExpr->getCastKind() == CK_IntegralCast &&
          castExpr->getSubExpr()->getType() == baseType)
        valueExpr = castExpr->getSubExpr();

    auto *argInstr = doExpr(valueExpr);
    if (valueExpr->getType() != baseType)
      argInstr = castToInt(argInstr, valueExpr->getType(), baseType,
                           valueExpr->getExprLoc());
    return argInstr;
  };

  const auto writeToOutputArg = [&baseType, dest,
                                 this](SpirvInstruction *toWrite,
                                       const CallExpr *callExpr,
                                       uint32_t outputArgIndex) {
    const auto outputArg = callExpr->getArg(outputArgIndex);
    const auto outputArgType = outputArg->getType();
    if (baseType != outputArgType)
      toWrite =
          castToInt(toWrite, baseType, outputArgType, dest->getLocStart());
    spvBuilder.createStore(doExpr(outputArg), toWrite, callExpr->getExprLoc());
  };

  // If a vector swizzling of a texture is done as an argument of an
  // interlocked method, we need to handle the access to the texture buffer
  // element correctly. For example:
  //
  //   InterlockedAdd(myRWTexture[index].r, 1);
  //
  //   `-CallExpr
  //     |-ImplicitCastExpr
  //     | `-DeclRefExpr Function 'InterlockedAdd'
  //     |     'void (unsigned int &, unsigned int)'
  //     |-HLSLVectorElementExpr 'unsigned int' lvalue vectorcomponent r
  //     | `-ImplicitCastExpr 'vector<uint, 1>':'vector<unsigned int, 1>'
  //     |     <HLSLVectorSplat>
  //     |   `-CXXOperatorCallExpr 'unsigned int' lvalue
  const auto *cxxOpCall = dyn_cast<CXXOperatorCallExpr>(dest);
  if (const auto *vector = dyn_cast<HLSLVectorElementExpr>(dest)) {
    const Expr *base = vector->getBase();
    cxxOpCall = dyn_cast<CXXOperatorCallExpr>(base);
    if (const auto *cast = dyn_cast<CastExpr>(base)) {
      cxxOpCall = dyn_cast<CXXOperatorCallExpr>(cast->getSubExpr());
    }
  }

  // If the argument is indexing into a texture/buffer, we need to create an
  // OpImageTexelPointer instruction.
  SpirvInstruction *ptr = nullptr;
  if (cxxOpCall) {
    const Expr *base = nullptr;
    const Expr *index = nullptr;
    if (isBufferTextureIndexing(cxxOpCall, &base, &index)) {
      if (hlsl::IsHLSLResourceType(base->getType())) {
        const auto resultTy = hlsl::GetHLSLResourceResultType(base->getType());
        if (!isScalarType(resultTy, nullptr)) {
          emitError("Interlocked operation for texture buffer whose result "
                    "type is non-scalar is not allowed",
                    dest->getExprLoc());
          return nullptr;
        }
      }
      auto *baseInstr = doExpr(base);
      if (baseInstr->isRValue()) {
        // OpImageTexelPointer's Image argument must have a type of
        // OpTypePointer with Type OpTypeImage. We need to create a temporary
        // variable if the baseId is an rvalue.
        baseInstr =
            createTemporaryVar(base->getType(), getAstTypeName(base->getType()),
                               baseInstr, base->getExprLoc());
      }
      auto *coordInstr = doExpr(index);
      ptr = spvBuilder.createImageTexelPointer(baseType, baseInstr, coordInstr,
                                               zero, srcLoc);
    }
  }
  if (!ptr) {
    auto *ptrInfo = doExpr(dest);
    const auto sc = ptrInfo->getStorageClass();
    if (sc == spv::StorageClass::Private || sc == spv::StorageClass::Function) {
      emitError("using static variable or function scope variable in "
                "interlocked operation is not allowed",
                dest->getExprLoc());
      return nullptr;
    }
    ptr = ptrInfo;
  }

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::IOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::IOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    auto *comparator = doArg(expr, 1);
    auto *valueInstr = doArg(expr, 2);
    auto *originalVal = spvBuilder.createAtomicCompareExchange(
        baseType, ptr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone,
        spv::MemorySemanticsMask::MaskNone, valueInstr, comparator, srcLoc);
    if (isCompareExchange)
      writeToOutputArg(originalVal, expr, 3);
  } else {
    auto *value = doArg(expr, 1);
    // Since these atomic operations write through the provided pointer, the
    // signed vs. unsigned opcode must be decided based on the pointee type
    // of the first argument. However, the frontend decides the opcode based
    // on the second argument (value), so the HLSL opcode it provides may be
    // wrong. The following code makes sure we use the correct SPIR-V opcode.
    spv::Op atomicOp = translateAtomicHlslOpcodeToSpirvOpcode(opcode);
    if (atomicOp == spv::Op::OpAtomicUMax && baseType->isSignedIntegerType())
      atomicOp = spv::Op::OpAtomicSMax;
    if (atomicOp == spv::Op::OpAtomicSMax && baseType->isUnsignedIntegerType())
      atomicOp = spv::Op::OpAtomicUMax;
    if (atomicOp == spv::Op::OpAtomicUMin && baseType->isSignedIntegerType())
      atomicOp = spv::Op::OpAtomicSMin;
    if (atomicOp == spv::Op::OpAtomicSMin && baseType->isUnsignedIntegerType())
      atomicOp = spv::Op::OpAtomicUMin;
    auto *originalVal = spvBuilder.createAtomicOp(
        atomicOp, baseType, ptr, spv::Scope::Device,
        spv::MemorySemanticsMask::MaskNone, value, srcLoc);
    if (expr->getNumArgs() > 2)
      writeToOutputArg(originalVal, expr, 2);
  }

  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicNonUniformResourceIndex(const CallExpr *expr) {
  auto *index = doExpr(expr->getArg(0));
  // Decorate the expression in NonUniformResourceIndex() with NonUniformEXT.
  // Aside from this, we also need to eventually propagate the NonUniformEXT
  // status to the usages of this expression. This is done by the
  // NonUniformVisitor class.
  //
  // The decoration shouldn't be applied to the operand, but rather to a copy
  // of the result. Even though applying the decoration to the operand may not
  // be functionally incorrect (since adding NonUniform is more conservative),
  // it could affect performance and isn't the intent of the shader.
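  //
  // For example (illustrative HLSL), in
  //
  //   Texture2D gTextures[16];
  //   ... gTextures[NonUniformResourceIndex(idx)].Sample(...) ...
  //
  // the copy of 'idx' created below carries the NonUniform decoration, and
  // the NonUniformVisitor later marks the dependent accesses as well.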
  auto *copyInstr =
      spvBuilder.createCopyObject(expr->getType(), index, expr->getExprLoc());
  copyInstr->setNonUniform();
  return copyInstr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
  const auto loc = callExpr->getExprLoc();
  if (!spirvOptions.noWarnEmulatedFeatures)
    emitWarning("msad4 intrinsic function is emulated using many SPIR-V "
                "instructions due to lack of direct SPIR-V equivalent",
                loc);

  // Compares a 4-byte reference value and an 8-byte source value and
  // accumulates a vector of 4 sums. Each sum corresponds to the masked sum
  // of absolute differences of a different byte alignment between the
  // reference value and the source value.

  // If we have:
  //   uint  v0;     // reference
  //   uint2 v1;     // source
  //   uint4 v2;     // accum
  //   uint4 o0;     // result of msad4
  //   uint4 r0, t0; // temporary values
  //
  // Then msad4(v0, v1, v2) translates to the following SM5 assembly according
  // to fxc:
  //
  // Step 1:
  //   ushr r0.xyz, v1.xxxx, l(8, 16, 24, 0)
  // Step 2:
  //       [result], [ width ],       [ offset ],      [ insert ], [ base ]
  //   bfi t0.yzw,   l(0, 8, 16, 24), l(0, 24, 16, 8), v1.yyyy,    r0.xxyz
  //   mov t0.x, v1.x
  // Step 3:
  //   msad o0.xyzw, v0.xxxx, t0.xyzw, v2.xyzw

  const auto boolType = astContext.BoolTy;
  const auto intType = astContext.IntTy;
  const auto uintType = astContext.UnsignedIntTy;
  const auto uint4Type = astContext.getExtVectorType(uintType, 4);
  auto *reference = doExpr(callExpr->getArg(0));
  auto *source = doExpr(callExpr->getArg(1));
  auto *accum = doExpr(callExpr->getArg(2));
  const auto uint0 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  const auto uint8 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 8));
  const auto uint16 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16));
  const auto uint24 =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 24));

  // Step 1.
  auto *v1x = spvBuilder.createCompositeExtract(uintType, source, {0}, loc);
  // r0.x = v1xS8 = v1.x shifted by 8 bits
  auto *v1xS8 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical, uintType,
                                          v1x, uint8, loc);
  // r0.y = v1xS16 = v1.x shifted by 16 bits
  auto *v1xS16 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                           uintType, v1x, uint16, loc);
  // r0.z = v1xS24 = v1.x shifted by 24 bits
  auto *v1xS24 = spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                           uintType, v1x, uint24, loc);

  // Step 2.
  // Do bfi 3 times. DXIL bfi is equivalent to SPIR-V OpBitFieldInsert.
  auto *v1y = spvBuilder.createCompositeExtract(uintType, source, {1}, loc);
  // Note that t0.x = v1.x, nothing we need to do for that.
  auto *t0y =
      spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS8, /*insert*/ v1y,
                                      /*offset*/ uint24,
                                      /*width*/ uint8, loc);
  auto *t0z =
      spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS16, /*insert*/ v1y,
                                      /*offset*/ uint16,
                                      /*width*/ uint16, loc);
  auto *t0w =
      spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS24, /*insert*/ v1y,
                                      /*offset*/ uint8,
                                      /*width*/ uint24, loc);

  // Step 3. MSAD (Masked Sum of Absolute Differences)
  // Now perform MSAD four times.
  // We need to mimic this algorithm in SPIR-V:
  //
  // UINT msad( UINT ref, UINT src, UINT accum )
  // {
  //     for (UINT i = 0; i < 4; i++)
  //     {
  //         BYTE refByte, srcByte, absDiff;
  //
  //         refByte = (BYTE)(ref >> (i * 8));
  //         if (!refByte)
  //         {
  //             continue;
  //         }
  //
  //         srcByte = (BYTE)(src >> (i * 8));
  //         if (refByte >= srcByte)
  //         {
  //             absDiff = refByte - srcByte;
  //         }
  //         else
  //         {
  //             absDiff = srcByte - refByte;
  //         }
  //
  //         // The recommended overflow behavior for MSAD is
  //         // to do a 32-bit saturate. This is not
  //         // required, however, and wrapping is allowed.
  //         // So from an application point of view,
  //         // overflow behavior is undefined.
  //         if (UINT_MAX - accum < absDiff)
  //         {
  //             accum = UINT_MAX;
  //             break;
  //         }
  //         accum += absDiff;
  //     }
  //
  //     return accum;
  // }
  auto *accum0 = spvBuilder.createCompositeExtract(uintType, accum, {0}, loc);
  auto *accum1 = spvBuilder.createCompositeExtract(uintType, accum, {1}, loc);
  auto *accum2 = spvBuilder.createCompositeExtract(uintType, accum, {2}, loc);
  auto *accum3 = spvBuilder.createCompositeExtract(uintType, accum, {3}, loc);
  const llvm::SmallVector<SpirvInstruction *, 4> sources = {v1x, t0y, t0z, t0w};
  llvm::SmallVector<SpirvInstruction *, 4> accums = {accum0, accum1, accum2,
                                                     accum3};
  llvm::SmallVector<SpirvInstruction *, 4> refBytes;
  llvm::SmallVector<SpirvInstruction *, 4> signedRefBytes;
  llvm::SmallVector<SpirvInstruction *, 4> isRefByteZero;

  for (uint32_t i = 0; i < 4; ++i) {
    refBytes.push_back(spvBuilder.createBitFieldExtract(
        uintType, reference,
        /*offset*/
        spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                  llvm::APInt(32, i * 8)),
        /*count*/ uint8, /*isSigned*/ false, loc));
    signedRefBytes.push_back(spvBuilder.createUnaryOp(
        spv::Op::OpBitcast, intType, refBytes.back(), loc));
    isRefByteZero.push_back(spvBuilder.createBinaryOp(
        spv::Op::OpIEqual, boolType, refBytes.back(), uint0, loc));
  }

  for (uint32_t msadNum = 0; msadNum < 4; ++msadNum) {
    for (uint32_t byteCount = 0; byteCount < 4; ++byteCount) {
      // 'count' is always 8 because we are extracting 8 bits out of 32.
      auto *srcByte = spvBuilder.createBitFieldExtract(
          uintType, sources[msadNum],
          /*offset*/
          spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                    llvm::APInt(32, 8 * byteCount)),
          /*count*/ uint8, /*isSigned*/ false, loc);
      auto *signedSrcByte =
          spvBuilder.createUnaryOp(spv::Op::OpBitcast, intType, srcByte, loc);
      auto *sub = spvBuilder.createBinaryOp(spv::Op::OpISub, intType,
                                            signedRefBytes[byteCount],
                                            signedSrcByte, loc);
      auto *absSub = spvBuilder.createGLSLExtInst(
          intType, GLSLstd450::GLSLstd450SAbs, {sub}, loc);
      auto *diff = spvBuilder.createSelect(
          uintType, isRefByteZero[byteCount], uint0,
          spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintType, absSub, loc),
          loc);
      // As pointed out by the DXIL reference above, it is *not* required to
      // saturate the output to UINT_MAX in case of overflow. Wrapping around
      // is also allowed. For simplicity, we will wrap around at this point.
      accums[msadNum] = spvBuilder.createBinaryOp(spv::Op::OpIAdd, uintType,
                                                  accums[msadNum], diff, loc);
    }
  }
  return spvBuilder.createCompositeConstruct(uint4Type, accums, loc);
}

SpirvInstruction *SpirvEmitter::processWaveQuery(const CallExpr *callExpr,
                                                 spv::Op opcode) {
  // Signatures:
  // bool WaveIsFirstLane()
  // uint WaveGetLaneCount()
  // uint WaveGetLaneIndex()
  assert(callExpr->getNumArgs() == 0);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  const QualType retType = callExpr->getCallReturnType(astContext);
  return spvBuilder.createGroupNonUniformElect(
      opcode, retType, spv::Scope::Subgroup, callExpr->getExprLoc());
}

SpirvInstruction *SpirvEmitter::processWaveVote(const CallExpr *callExpr,
                                                spv::Op opcode) {
  // Signatures:
  // bool WaveActiveAnyTrue( bool expr )
  // bool WaveActiveAllTrue( bool expr )
  // uint4 WaveActiveBallot( bool expr )
  assert(callExpr->getNumArgs() == 1);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  auto *predicate = doExpr(callExpr->getArg(0));
  const QualType retType = callExpr->getCallReturnType(astContext);
  return spvBuilder.createGroupNonUniformUnaryOp(
      callExpr->getExprLoc(), opcode, retType, spv::Scope::Subgroup, predicate);
}

spv::Op SpirvEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
                                      SourceLocation srcLoc) {
  const bool isSintType = isSintOrVecMatOfSintType(type);
  const bool isUintType = isUintOrVecMatOfUintType(type);
  const bool isFloatType = isFloatOrVecMatOfFloatType(type);

#define WAVE_OP_CASE_INT(kind, intWaveOp)                                    \
  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                  \
    if (isSintType || isUintType) {                                          \
      return spv::Op::OpGroupNonUniform##intWaveOp;                          \
    }                                                                        \
  } break

#define WAVE_OP_CASE_INT_FLOAT(kind, intWaveOp, floatWaveOp)                 \
  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                  \
    if (isSintType || isUintType) {                                          \
      return spv::Op::OpGroupNonUniform##intWaveOp;                          \
    }                                                                        \
    if (isFloatType) {                                                       \
      return spv::Op::OpGroupNonUniform##floatWaveOp;                        \
    }                                                                        \
  } break

#define WAVE_OP_CASE_SINT_UINT_FLOAT(kind, sintWaveOp, uintWaveOp,           \
                                     floatWaveOp)                            \
  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                  \
    if (isSintType) {                                                        \
      return spv::Op::OpGroupNonUniform##sintWaveOp;                         \
    }                                                                        \
    if (isUintType) {                                                        \
      return spv::Op::OpGroupNonUniform##uintWaveOp;                         \
    }                                                                        \
    if (isFloatType) {                                                       \
      return spv::Op::OpGroupNonUniform##floatWaveOp;                        \
    }                                                                        \
  } break

  switch (op) {
    WAVE_OP_CASE_INT_FLOAT(ActiveUSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(ActiveSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(ActiveUProduct, IMul, FMul);
    WAVE_OP_CASE_INT_FLOAT(ActiveProduct, IMul, FMul);
    WAVE_OP_CASE_INT_FLOAT(PrefixUSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(PrefixSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(PrefixUProduct, IMul, FMul);
    WAVE_OP_CASE_INT_FLOAT(PrefixProduct, IMul, FMul);
    WAVE_OP_CASE_INT(ActiveBitAnd, BitwiseAnd);
    WAVE_OP_CASE_INT(ActiveBitOr, BitwiseOr);
    WAVE_OP_CASE_INT(ActiveBitXor, BitwiseXor);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMax, SMax, UMax, FMax);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMax, SMax, UMax, FMax);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMin, SMin, UMin, FMin);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMin, SMin, UMin, FMin);
  default:
    // Only simple wave ops are handled here.
    break;
  }

#undef WAVE_OP_CASE_INT_FLOAT
#undef WAVE_OP_CASE_INT
#undef WAVE_OP_CASE_SINT_UINT_FLOAT

  emitError("translating wave operator '%0' unimplemented", srcLoc)
      << static_cast<uint32_t>(op);
  return spv::Op::OpNop;
}

SpirvInstruction *
SpirvEmitter::processWaveCountBits(const CallExpr *callExpr,
                                   spv::GroupOperation groupOp) {
  // Signatures:
  // uint WaveActiveCountBits(bool bBit)
  // uint WavePrefixCountBits(bool bBit)
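  //
  // Both lower to a subgroup ballot followed by a ballot bit count, roughly
  // (illustrative disassembly; %uint_3 is the Subgroup scope):
  //
  //   %ballot = OpGroupNonUniformBallot %v4uint %uint_3 %predicate
  //   %count  = OpGroupNonUniformBallotBitCount %uint %uint_3 Reduce %ballot
  //
  // with Reduce or ExclusiveScan as the group operation, respectively.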
  assert(callExpr->getNumArgs() == 1);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  auto *predicate = doExpr(callExpr->getArg(0));
  const auto srcLoc = callExpr->getExprLoc();
  const QualType u32Type = astContext.UnsignedIntTy;
  const QualType v4u32Type = astContext.getExtVectorType(u32Type, 4);
  const QualType retType = callExpr->getCallReturnType(astContext);
  auto *ballot = spvBuilder.createGroupNonUniformUnaryOp(
      srcLoc, spv::Op::OpGroupNonUniformBallot, v4u32Type, spv::Scope::Subgroup,
      predicate);
  return spvBuilder.createGroupNonUniformUnaryOp(
      srcLoc, spv::Op::OpGroupNonUniformBallotBitCount, retType,
      spv::Scope::Subgroup, ballot,
      llvm::Optional<spv::GroupOperation>(groupOp));
}

SpirvInstruction *SpirvEmitter::processWaveReductionOrPrefix(
    const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
  // Signatures:
  // bool WaveActiveAllEqual( <type> expr )
  // <type> WaveActiveSum( <type> expr )
  // <type> WaveActiveProduct( <type> expr )
  // <int_type> WaveActiveBitAnd( <int_type> expr )
  // <int_type> WaveActiveBitOr( <int_type> expr )
  // <int_type> WaveActiveBitXor( <int_type> expr )
  // <type> WaveActiveMin( <type> expr )
  // <type> WaveActiveMax( <type> expr )
  //
  // <type> WavePrefixProduct(<type> value)
  // <type> WavePrefixSum(<type> value)
  assert(callExpr->getNumArgs() == 1);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  auto *predicate = doExpr(callExpr->getArg(0));
  const QualType retType = callExpr->getCallReturnType(astContext);
  return spvBuilder.createGroupNonUniformUnaryOp(
      callExpr->getExprLoc(), opcode, retType, spv::Scope::Subgroup, predicate,
      llvm::Optional<spv::GroupOperation>(groupOp));
}

SpirvInstruction *SpirvEmitter::processWaveBroadcast(const CallExpr *callExpr) {
  // Signatures:
  // <type> WaveReadLaneFirst(<type> expr)
  // <type> WaveReadLaneAt(<type> expr, uint laneIndex)
  const auto numArgs = callExpr->getNumArgs();
  const auto srcLoc = callExpr->getExprLoc();
  assert(numArgs == 1 || numArgs == 2);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  auto *value = doExpr(callExpr->getArg(0));
  const QualType retType = callExpr->getCallReturnType(astContext);
  if (numArgs == 2)
    // WaveReadLaneAt is in fact not a broadcast operation (even though its
    // name might incorrectly suggest so). The proper mapping to SPIR-V for it
    // is OpGroupNonUniformShuffle, *not* OpGroupNonUniformBroadcast.
    return spvBuilder.createGroupNonUniformBinaryOp(
        spv::Op::OpGroupNonUniformShuffle, retType, spv::Scope::Subgroup,
        value, doExpr(callExpr->getArg(1)), srcLoc);
  else
    return spvBuilder.createGroupNonUniformUnaryOp(
        srcLoc, spv::Op::OpGroupNonUniformBroadcastFirst, retType,
        spv::Scope::Subgroup, value);
}

SpirvInstruction *
SpirvEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
                                         hlsl::IntrinsicOp op) {
  // Signatures:
  // <type> QuadReadAcrossX(<type> localValue)
  // <type> QuadReadAcrossY(<type> localValue)
  // <type> QuadReadAcrossDiagonal(<type> localValue)
  // <type> QuadReadLaneAt(<type> sourceValue, uint quadLaneID)
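  //
  // For the first three, OpGroupNonUniformQuadSwap takes a 'Direction'
  // operand: 0 swaps horizontally (across X), 1 vertically (across Y), and
  // 2 diagonally, which is what the constant 'target' values below encode.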
  assert(callExpr->getNumArgs() == 1 || callExpr->getNumArgs() == 2);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  auto *value = doExpr(callExpr->getArg(0));
  const auto srcLoc = callExpr->getExprLoc();
  const QualType retType = callExpr->getCallReturnType(astContext);

  SpirvInstruction *target = nullptr;
  spv::Op opcode = spv::Op::OpGroupNonUniformQuadSwap;
  switch (op) {
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
    target =
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
    target =
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
    target =
        spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2));
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
    target = doExpr(callExpr->getArg(1));
    opcode = spv::Op::OpGroupNonUniformQuadBroadcast;
    break;
  default:
    llvm_unreachable("case should not appear here");
  }

  return spvBuilder.createGroupNonUniformBinaryOp(
      opcode, retType, spv::Scope::Subgroup, value, target, srcLoc);
}

SpirvInstruction *SpirvEmitter::processIntrinsicModf(const CallExpr *callExpr) {
  // Signature is: ret modf(x, ip)
  // [in]  x: the input floating-point value.
  // [out] ip: the integer portion of x.
  // [out] ret: the fractional portion of x.
  // All of the above must be a scalar, vector, or matrix with the same
  // component types. Component types can be float or int.

  // The ModfStruct SPIR-V instruction returns a struct. The first member is
  // the fractional part and the second member is the integer portion.
  // ModfStruct {
  //   <scalar or vector of float> frac;
  //   <scalar or vector of float> ip;
  // }

  // Note that if the input number (x) is not a float (i.e. 'x' is an int), it
  // is automatically converted to float before modf is invoked. Sadly, the
  // 'ip' argument is not treated the same way. Therefore, in such cases we'll
  // have to manually convert the float result into int.
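  //
  // For example (illustrative):
  //
  //   int ip;
  //   float frac = modf(2.75, ip);  // frac == 0.75, ip == 2
  //
  // Here ModfStruct produces ip as the float 2.0, which then has to be cast
  // to int manually as described above.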

  const Expr *arg = callExpr->getArg(0);
  const Expr *ipArg = callExpr->getArg(1);
  const auto loc = callExpr->getLocStart();
  const auto argType = arg->getType();
  const auto ipType = ipArg->getType();
  const auto returnType = callExpr->getType();
  auto *argInstr = doExpr(arg);

  // For scalar and vector argument types.
  {
    if (isScalarType(argType) || isVectorType(argType)) {
      // The struct members *must* have the same type.
      const auto modfStructType = spvContext.getHybridStructType(
          {HybridStructType::FieldInfo(argType, "frac"),
           HybridStructType::FieldInfo(argType, "ip")},
          "ModfStructType");
      auto *modf = spvBuilder.createGLSLExtInst(
          modfStructType, GLSLstd450::GLSLstd450ModfStruct, {argInstr}, loc);
      SpirvInstruction *ip =
          spvBuilder.createCompositeExtract(argType, modf, {1}, loc);
      // This will do nothing if the input number (x) and the ip are both of
      // the same type. Otherwise, it will convert the ip into int as
      // necessary.
      ip = castToInt(ip, argType, ipType, ipArg->getLocStart());
      processAssignment(ipArg, ip, false, nullptr);
      return spvBuilder.createCompositeExtract(argType, modf, {0}, loc);
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    QualType elemType = {};
    if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
      const auto colType = astContext.getExtVectorType(elemType, colCount);
      const auto modfStructType = spvContext.getHybridStructType(
          {HybridStructType::FieldInfo(colType, "frac"),
           HybridStructType::FieldInfo(colType, "ip")},
          "ModfStructType");
      llvm::SmallVector<SpirvInstruction *, 4> fracs;
      llvm::SmallVector<SpirvInstruction *, 4> ips;
      for (uint32_t i = 0; i < rowCount; ++i) {
        auto *curRow =
            spvBuilder.createCompositeExtract(colType, argInstr, {i}, loc);
        auto *modf = spvBuilder.createGLSLExtInst(
            modfStructType, GLSLstd450::GLSLstd450ModfStruct, {curRow}, loc);
        ips.push_back(
            spvBuilder.createCompositeExtract(colType, modf, {1}, loc));
        fracs.push_back(
            spvBuilder.createCompositeExtract(colType, modf, {0}, loc));
      }

      SpirvInstruction *ip =
          spvBuilder.createCompositeConstruct(argType, ips, loc);
      // If the 'ip' is not a float type, the AST will not contain a CastExpr
      // because this is internal to the intrinsic function. So, in such a
      // case we need to cast manually.
      if (!hlsl::GetHLSLMatElementType(ipType)->isFloatingType())
        ip = castToInt(ip, argType, ipType, ipArg->getLocStart());
      processAssignment(ipArg, ip, false, nullptr);
      return spvBuilder.createCompositeConstruct(returnType, fracs, loc);
    }
  }

  emitError("invalid argument type passed to Modf intrinsic function",
            callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *SpirvEmitter::processIntrinsicMad(const CallExpr *callExpr) {
  // Signature is: ret mad(a,b,c)
  // All of the above must be a scalar, vector, or matrix with the same
  // component types. Component types can be float or int.
  // The return value is equal to "a * b + c".

  // In the case of float arguments, we can use the GLSL extended instruction
  // set's Fma instruction with the NoContraction decoration. In the case of
  // integer arguments, we'll have to manually perform an OpIMul followed by
  // an OpIAdd (we should also apply the NoContraction decoration to these two
  // instructions to get precise arithmetic).

  // TODO: We currently don't propagate the NoContraction decoration.
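  //
  // For example, for an int argument type the code emitted below is roughly
  // (illustrative disassembly):
  //
  //   %mul = OpIMul %int %a %b    ; decorated NoContraction
  //   %add = OpIAdd %int %mul %c  ; decorated NoContraction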

  const auto loc = callExpr->getLocStart();
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const Expr *arg2 = callExpr->getArg(2);
  // All arguments and the return type are the same.
  const auto argType = arg0->getType();
  auto *arg0Instr = doExpr(arg0);
  auto *arg1Instr = doExpr(arg1);
  auto *arg2Instr = doExpr(arg2);
  auto arg0Loc = arg0->getLocStart();
  auto arg1Loc = arg1->getLocStart();
  auto arg2Loc = arg2->getLocStart();

  // For floating-point arguments, we can use the extended instruction set's
  // Fma instruction. Sadly we can't simply call processIntrinsicUsingGLSLInst
  // because we need to specifically decorate the Fma instruction with the
  // NoContraction decoration.
  if (isFloatOrVecMatOfFloatType(argType)) {
    // For matrix cases, operate on each row of the matrix.
    if (isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, loc, arg1Instr, arg2Instr, arg1Loc,
                                 arg2Loc](uint32_t index, QualType vecType,
                                          SpirvInstruction *arg0Row) {
        auto *arg1Row = spvBuilder.createCompositeExtract(vecType, arg1Instr,
                                                          {index}, arg1Loc);
        auto *arg2Row = spvBuilder.createCompositeExtract(vecType, arg2Instr,
                                                          {index}, arg2Loc);
        auto *fma = spvBuilder.createGLSLExtInst(
            vecType, GLSLstd450Fma, {arg0Row, arg1Row, arg2Row}, loc);
        spvBuilder.decorateNoContraction(fma, loc);
        return fma;
      };
      return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc);
    }
    // Non-matrix cases
    auto *fma = spvBuilder.createGLSLExtInst(
        argType, GLSLstd450Fma, {arg0Instr, arg1Instr, arg2Instr}, loc);
    spvBuilder.decorateNoContraction(fma, loc);
    return fma;
  }

  // For scalar and vector argument types.
  {
    if (isScalarType(argType) || isVectorType(argType)) {
      auto *mul = spvBuilder.createBinaryOp(spv::Op::OpIMul, argType,
                                            arg0Instr, arg1Instr, loc);
      auto *add = spvBuilder.createBinaryOp(spv::Op::OpIAdd, argType, mul,
                                            arg2Instr, loc);
      spvBuilder.decorateNoContraction(mul, loc);
      spvBuilder.decorateNoContraction(add, loc);
      return add;
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    QualType elemType = {};
    if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
      const auto colType = astContext.getExtVectorType(elemType, colCount);
      llvm::SmallVector<SpirvInstruction *, 4> resultRows;
      for (uint32_t i = 0; i < rowCount; ++i) {
        auto *rowArg0 =
            spvBuilder.createCompositeExtract(colType, arg0Instr, {i}, arg0Loc);
        auto *rowArg1 =
            spvBuilder.createCompositeExtract(colType, arg1Instr, {i}, arg1Loc);
        auto *rowArg2 =
            spvBuilder.createCompositeExtract(colType, arg2Instr, {i}, arg2Loc);
        auto *mul = spvBuilder.createBinaryOp(spv::Op::OpIMul, colType, rowArg0,
                                              rowArg1, loc);
        auto *add = spvBuilder.createBinaryOp(spv::Op::OpIAdd, colType, mul,
                                              rowArg2, loc);
        spvBuilder.decorateNoContraction(mul, loc);
        spvBuilder.decorateNoContraction(add, loc);
        resultRows.push_back(add);
      }
      return spvBuilder.createCompositeConstruct(argType, resultRows, loc);
    }
  }

  emitError("invalid argument type passed to mad intrinsic function",
            callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *SpirvEmitter::processIntrinsicLit(const CallExpr *callExpr) {
  // Signature is: float4 lit(float n_dot_l, float n_dot_h, float m)
  //
  // This function returns a lighting coefficient vector
  // (ambient, diffuse, specular, 1) where:
  //   ambient  = 1
  //   diffuse  = (n_dot_l < 0) ? 0 : n_dot_l
  //   specular = (n_dot_l < 0 || n_dot_h < 0) ? 0 : ((n_dot_h) * m)
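  //
  // For example, following the formula above (illustrative):
  //   lit(0.5, 0.3, 8.0) returns (1.0, 0.5, 2.4, 1.0).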
  auto *nDotL = doExpr(callExpr->getArg(0));
  auto *nDotH = doExpr(callExpr->getArg(1));
  auto *m = doExpr(callExpr->getArg(2));
  const auto loc = callExpr->getExprLoc();
  const QualType floatType = astContext.FloatTy;
  const QualType boolType = astContext.BoolTy;
  SpirvInstruction *floatZero =
      spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(0.0f));
  SpirvInstruction *floatOne =
      spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f));
  const QualType retType = callExpr->getType();
  auto *diffuse = spvBuilder.createGLSLExtInst(
      floatType, GLSLstd450::GLSLstd450FMax, {floatZero, nDotL}, loc);
  auto *min = spvBuilder.createGLSLExtInst(
      floatType, GLSLstd450::GLSLstd450FMin, {nDotL, nDotH}, loc);
  auto *isNeg = spvBuilder.createBinaryOp(spv::Op::OpFOrdLessThan, boolType,
                                          min, floatZero, loc);
  auto *mul =
      spvBuilder.createBinaryOp(spv::Op::OpFMul, floatType, nDotH, m, loc);
  auto *specular =
      spvBuilder.createSelect(floatType, isNeg, floatZero, mul, loc);
  return spvBuilder.createCompositeConstruct(
      retType, {floatOne, diffuse, specular, floatOne}, callExpr->getLocEnd());
}

SpirvInstruction *
SpirvEmitter::processIntrinsicFrexp(const CallExpr *callExpr) {
  // Signature is: ret frexp(x, exp)
  // [in]  x: the input floating-point value.
  // [out] exp: the calculated exponent.
  // [out] ret: the calculated mantissa.
  // All of the above must be a scalar, vector, or matrix of *float* type.

  // The FrexpStruct SPIR-V instruction returns a struct. The first member is
  // the significand (mantissa) and must be of the same type as the input
  // parameter, and the second member is the exponent and must always be a
  // scalar or vector of 32-bit *integer* type.
  // FrexpStruct {
  //   <scalar or vector of float> mantissa;
  //   <scalar or vector of integers> exponent;
  // }
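  //
  // For example (illustrative):
  //
  //   float exp;
  //   float mant = frexp(8.0, exp);  // mant == 0.5, exp == 4.0
  //                                  // since 8.0 == 0.5 * 2^4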
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto returnType = callExpr->getType();
  const auto loc = callExpr->getExprLoc();
  auto *argInstr = doExpr(arg);
  auto *expInstr = doExpr(callExpr->getArg(1));

  // For scalar and vector argument types.
  {
    uint32_t elemCount = 1;
    if (isScalarType(argType) || isVectorType(argType, nullptr, &elemCount)) {
      const QualType expType =
          elemCount == 1
              ? astContext.IntTy
              : astContext.getExtVectorType(astContext.IntTy, elemCount);
      const auto *frexpStructType = spvContext.getHybridStructType(
          {HybridStructType::FieldInfo(argType, "mantissa"),
           HybridStructType::FieldInfo(expType, "exponent")},
          "FrexpStructType");
      auto *frexp = spvBuilder.createGLSLExtInst(
          frexpStructType, GLSLstd450::GLSLstd450FrexpStruct, {argInstr}, loc);
      auto *exponentInt =
          spvBuilder.createCompositeExtract(expType, frexp, {1}, loc);
      // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
      // expects a float, a conversion must take place before writing the
      // results.
      auto *exponentFloat = spvBuilder.createUnaryOp(
          spv::Op::OpConvertSToF, returnType, exponentInt, loc);
      spvBuilder.createStore(expInstr, exponentFloat, loc);
      return spvBuilder.createCompositeExtract(argType, frexp, {0}, loc);
    }
  }
  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(argType, nullptr, &rowCount, &colCount)) {
      const auto expType =
          astContext.getExtVectorType(astContext.IntTy, colCount);
      const auto colType =
          astContext.getExtVectorType(astContext.FloatTy, colCount);
      const auto *frexpStructType = spvContext.getHybridStructType(
          {HybridStructType::FieldInfo(colType, "mantissa"),
           HybridStructType::FieldInfo(expType, "exponent")},
          "FrexpStructType");
      llvm::SmallVector<SpirvInstruction *, 4> exponents;
      llvm::SmallVector<SpirvInstruction *, 4> mantissas;
      for (uint32_t i = 0; i < rowCount; ++i) {
        auto *curRow = spvBuilder.createCompositeExtract(colType, argInstr, {i},
                                                         arg->getLocStart());
        auto *frexp = spvBuilder.createGLSLExtInst(
            frexpStructType, GLSLstd450::GLSLstd450FrexpStruct, {curRow}, loc);
        auto *exponentInt =
            spvBuilder.createCompositeExtract(expType, frexp, {1}, loc);
        // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
        // expects a float, a conversion must take place before writing the
        // results.
        auto *exponentFloat = spvBuilder.createUnaryOp(
            spv::Op::OpConvertSToF, colType, exponentInt, loc);
        exponents.push_back(exponentFloat);
        mantissas.push_back(
            spvBuilder.createCompositeExtract(colType, frexp, {0}, loc));
      }
      auto *exponentsResult =
          spvBuilder.createCompositeConstruct(returnType, exponents, loc);
      spvBuilder.createStore(expInstr, exponentsResult, loc);
      return spvBuilder.createCompositeConstruct(returnType, mantissas,
                                                 callExpr->getLocEnd());
    }
  }

  emitError("invalid argument type passed to Frexp intrinsic function",
            callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicLdexp(const CallExpr *callExpr) {
  // Signature: ret ldexp(x, exp)
  // This function uses the following formula: x * 2^exp.
  // Note that we cannot use the GLSL extended instruction Ldexp since it
  // requires the exponent to be an integer (vector), but HLSL takes a float
  // (vector) exponent. So we must calculate the result manually.
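  // For example, ldexp(1.5, 3.0) == 1.5 * 2^3 == 12.0.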
  const Expr *x = callExpr->getArg(0);
  const auto paramType = x->getType();
  auto *xInstr = doExpr(x);
  auto *expInstr = doExpr(callExpr->getArg(1));
  const auto loc = callExpr->getLocStart();
  const auto arg1Loc = callExpr->getArg(1)->getLocStart();

  // For scalar and vector argument types.
  if (isScalarType(paramType) || isVectorType(paramType)) {
    const auto twoExp = spvBuilder.createGLSLExtInst(
        paramType, GLSLstd450::GLSLstd450Exp2, {expInstr}, loc);
    return spvBuilder.createBinaryOp(spv::Op::OpFMul, paramType, xInstr, twoExp,
                                     loc);
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(paramType, nullptr, &rowCount, &colCount)) {
      const auto actOnEachVec = [this, loc, expInstr,
                                 arg1Loc](uint32_t index, QualType vecType,
                                          SpirvInstruction *xRowInstr) {
        auto *expRowInstr = spvBuilder.createCompositeExtract(vecType, expInstr,
                                                              {index}, arg1Loc);
        auto *twoExp = spvBuilder.createGLSLExtInst(
            vecType, GLSLstd450::GLSLstd450Exp2, {expRowInstr}, loc);
        return spvBuilder.createBinaryOp(spv::Op::OpFMul, vecType, xRowInstr,
                                         twoExp, loc);
      };
      return processEachVectorInMatrix(x, xInstr, actOnEachVec, loc);
    }
  }

  emitError("invalid argument type passed to ldexp intrinsic function",
            callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *SpirvEmitter::processIntrinsicDst(const CallExpr *callExpr) {
  // Signature is float4 dst(float4 src0, float4 src1)
  // result.x = 1;
  // result.y = src0.y * src1.y;
  // result.z = src0.z;
  // result.w = src1.w;
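  // For example, dst(float4(x, 2, 3, w), float4(a, 4, c, 5)) yields
  // float4(1, 8, 3, 5).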
  const QualType f32 = astContext.FloatTy;
  auto *arg0Id = doExpr(callExpr->getArg(0));
  auto *arg1Id = doExpr(callExpr->getArg(1));
  auto arg0Loc = callExpr->getArg(0)->getLocStart();
  auto arg1Loc = callExpr->getArg(1)->getLocStart();
  auto *arg0y = spvBuilder.createCompositeExtract(f32, arg0Id, {1}, arg0Loc);
  auto *arg1y = spvBuilder.createCompositeExtract(f32, arg1Id, {1}, arg1Loc);
  auto *arg0z = spvBuilder.createCompositeExtract(f32, arg0Id, {2}, arg0Loc);
  auto *arg1w = spvBuilder.createCompositeExtract(f32, arg1Id, {3}, arg1Loc);
  auto loc = callExpr->getLocEnd();
  auto *arg0yMularg1y =
      spvBuilder.createBinaryOp(spv::Op::OpFMul, f32, arg0y, arg1y, loc);
  return spvBuilder.createCompositeConstruct(
      callExpr->getType(),
      {spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f)),
       arg0yMularg1y, arg0z, arg1w},
      loc);
}

SpirvInstruction *SpirvEmitter::processIntrinsicClip(const CallExpr *callExpr) {
  // Discards the current pixel if the specified value is less than zero.
  // TODO: If the argument can be const folded and evaluated, we could
  // potentially avoid creating a branch. This would be a bit challenging for
  // matrix/vector arguments.
  assert(callExpr->getNumArgs() == 1u);
  const Expr *arg = callExpr->getArg(0);
  const auto loc = callExpr->getExprLoc();
  const auto argType = arg->getType();
  const auto boolType = astContext.BoolTy;
  SpirvInstruction *condition = nullptr;

  // Since we could not evaluate the argument as a constant, we need to branch
  // based on it. If the argument is a vector/matrix, clipping is done if *any*
  // element of the vector/matrix is less than zero.
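  // For example, clip(float2(0.5, -0.1)) discards the pixel because one
  // element is negative.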
  auto *argInstr = doExpr(arg);
  QualType elemType = {};
  uint32_t elemCount = 0, rowCount = 0, colCount = 0;
  if (isScalarType(argType)) {
    auto *zero = getValueZero(argType);
    condition = spvBuilder.createBinaryOp(spv::Op::OpFOrdLessThan, boolType,
                                          argInstr, zero, loc);
  } else if (isVectorType(argType, nullptr, &elemCount)) {
    auto *zero = getValueZero(argType);
    const QualType boolVecType =
        astContext.getExtVectorType(boolType, elemCount);
    auto *cmp = spvBuilder.createBinaryOp(spv::Op::OpFOrdLessThan, boolVecType,
                                          argInstr, zero, loc);
    condition = spvBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp, loc);
  } else if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
    const auto floatVecType = astContext.getExtVectorType(elemType, colCount);
    auto *elemZero = getValueZero(elemType);
    llvm::SmallVector<SpirvConstant *, 4> elements(size_t(colCount), elemZero);
    auto *zero = spvBuilder.getConstantComposite(floatVecType, elements);
    llvm::SmallVector<SpirvInstruction *, 4> cmpResults;
    for (uint32_t i = 0; i < rowCount; ++i) {
      auto *lhsVec =
          spvBuilder.createCompositeExtract(floatVecType, argInstr, {i}, loc);
      const auto boolColType = astContext.getExtVectorType(boolType, colCount);
      auto *cmp = spvBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                            boolColType, lhsVec, zero, loc);
      auto *any = spvBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp, loc);
      cmpResults.push_back(any);
    }
    const auto boolRowType = astContext.getExtVectorType(boolType, rowCount);
    auto *results =
        spvBuilder.createCompositeConstruct(boolRowType, cmpResults, loc);
    condition =
        spvBuilder.createUnaryOp(spv::Op::OpAny, boolType, results, loc);
  } else {
    emitError("invalid argument type passed to clip intrinsic function", loc);
    return nullptr;
  }

  // Then we need to emit the instruction for the conditional branch.
  auto *thenBB = spvBuilder.createBasicBlock("if.true");
  auto *mergeBB = spvBuilder.createBasicBlock("if.merge");
  // Create the branch instruction. This will end the current basic block.
  spvBuilder.createConditionalBranch(condition, thenBB, mergeBB, loc, mergeBB);
  spvBuilder.addSuccessor(thenBB);
  spvBuilder.addSuccessor(mergeBB);
  spvBuilder.setMergeTarget(mergeBB);
  // Handle the then branch.
  spvBuilder.setInsertPoint(thenBB);
  spvBuilder.createKill(loc);
  spvBuilder.addSuccessor(mergeBB);
  // From now on, we'll emit instructions into the merge block.
  spvBuilder.setInsertPoint(mergeBB);
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
  // According to the HLSL reference, clamp(X, Min, Max) takes 3 arguments.
  // Each one may be int, uint, or float.
  const QualType returnType = callExpr->getType();
  GLSLstd450 glslOpcode = GLSLstd450::GLSLstd450UClamp;
  if (isFloatOrVecMatOfFloatType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450FClamp;
  else if (isSintOrVecMatOfSintType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450SClamp;

  // Get the function parameters. Expect 3 parameters.
  assert(callExpr->getNumArgs() == 3u);
  const Expr *argX = callExpr->getArg(0);
  const Expr *argMin = callExpr->getArg(1);
  const Expr *argMax = callExpr->getArg(2);
  const auto loc = callExpr->getExprLoc();
  auto *argXInstr = doExpr(argX);
  auto *argMinInstr = doExpr(argMin);
  auto *argMaxInstr = doExpr(argMax);
  const auto argMinLoc = argMin->getLocStart();
  const auto argMaxLoc = argMax->getLocStart();

  // FClamp, UClamp, and SClamp do not operate on matrices, so we should
  // perform the operation on each vector of the matrix.
  if (isMxNMatrix(argX->getType())) {
    const auto actOnEachVec =
        [this, loc, glslOpcode, argMinInstr, argMaxInstr, argMinLoc, argMaxLoc](
            uint32_t index, QualType vecType, SpirvInstruction *curRow) {
          auto *minRowInstr = spvBuilder.createCompositeExtract(
              vecType, argMinInstr, {index}, argMinLoc);
          auto *maxRowInstr = spvBuilder.createCompositeExtract(
              vecType, argMaxInstr, {index}, argMaxLoc);
          return spvBuilder.createGLSLExtInst(
              vecType, glslOpcode, {curRow, minRowInstr, maxRowInstr}, loc);
        };
    return processEachVectorInMatrix(argX, argXInstr, actOnEachVec, loc);
  }

  return spvBuilder.createGLSLExtInst(
      returnType, glslOpcode, {argXInstr, argMinInstr, argMaxInstr}, loc);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicMemoryBarrier(const CallExpr *callExpr,
                                            bool isDevice, bool groupSync,
                                            bool isAllBarrier) {
  // * DeviceMemoryBarrier =
  //   OpMemoryBarrier (memScope=Device,
  //                    sem=Image|Uniform|AcquireRelease)
  //
  // * DeviceMemoryBarrierWithGroupSync =
  //   OpControlBarrier(execScope = Workgroup,
  //                    memScope=Device,
  //                    sem=Image|Uniform|AcquireRelease)
  const spv::MemorySemanticsMask deviceMemoryBarrierSema =
      spv::MemorySemanticsMask::ImageMemory |
      spv::MemorySemanticsMask::UniformMemory |
      spv::MemorySemanticsMask::AcquireRelease;

  // * GroupMemoryBarrier =
  //   OpMemoryBarrier (memScope=Workgroup,
  //                    sem = Workgroup|AcquireRelease)
  //
  // * GroupMemoryBarrierWithGroupSync =
  //   OpControlBarrier (execScope = Workgroup,
  //                     memScope = Workgroup,
  //                     sem = Workgroup|AcquireRelease)
  const spv::MemorySemanticsMask groupMemoryBarrierSema =
      spv::MemorySemanticsMask::WorkgroupMemory |
      spv::MemorySemanticsMask::AcquireRelease;

  // * AllMemoryBarrier =
  //   OpMemoryBarrier(memScope = Device,
  //                   sem = Image|Uniform|Workgroup|AcquireRelease)
  //
  // * AllMemoryBarrierWithGroupSync =
  //   OpControlBarrier(execScope = Workgroup,
  //                    memScope = Device,
  //                    sem = Image|Uniform|Workgroup|AcquireRelease)
  const spv::MemorySemanticsMask allMemoryBarrierSema =
      spv::MemorySemanticsMask::ImageMemory |
      spv::MemorySemanticsMask::UniformMemory |
      spv::MemorySemanticsMask::WorkgroupMemory |
      spv::MemorySemanticsMask::AcquireRelease;

  // Get <result-id> for the execution scope.
  // If present, the execution scope is always Workgroup!
  llvm::Optional<spv::Scope> execScope = llvm::None;
  if (groupSync) {
    execScope = spv::Scope::Workgroup;
  }

  // Get <result-id> for the memory scope.
  const spv::Scope memScope =
      (isDevice || isAllBarrier) ? spv::Scope::Device : spv::Scope::Workgroup;

  // Get <result-id> for the memory semantics.
  const auto memSemaMask = isAllBarrier ? allMemoryBarrierSema
                           : isDevice   ? deviceMemoryBarrierSema
                                        : groupMemoryBarrierSema;

  spvBuilder.createBarrier(memScope, memSemaMask, execScope,
                           callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *SpirvEmitter::processNonFpMatrixTranspose(
    QualType matType, SpirvInstruction *matrix, SourceLocation loc) {
  // The simplest way is to flatten the matrix and then construct the
  // transposed matrix from the flattened elements.
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  const bool isMat = isMxNMatrix(matType, &elemType, &numRows, &numCols);
  assert(isMat && !elemType->isFloatingType());
  (void)isMat;
  const auto colQualType = astContext.getExtVectorType(elemType, numRows);

  // You cannot perform a composite construct of an array using a few vectors.
  // The number of constituents passed to OpCompositeConstruct must be equal
  // to the number of array elements.
  llvm::SmallVector<SpirvInstruction *, 4> elems;
  for (uint32_t i = 0; i < numRows; ++i)
    for (uint32_t j = 0; j < numCols; ++j)
      elems.push_back(
          spvBuilder.createCompositeExtract(elemType, matrix, {i, j}, loc));

  llvm::SmallVector<SpirvInstruction *, 4> cols;
  for (uint32_t i = 0; i < numCols; ++i) {
    // The elements in the ith vector of the "transposed" array are at offset
    // i, i + <original-vector-size>, ...
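    // For example, a 2x3 matrix flattened row-major as [a b c d e f] yields
    // transposed column 0 == [a, d]: the elements at offsets 0 and 0 + 3.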
    llvm::SmallVector<SpirvInstruction *, 4> indexes;
    for (uint32_t j = 0; j < numRows; ++j)
      indexes.push_back(elems[i + (j * numCols)]);
    cols.push_back(
        spvBuilder.createCompositeConstruct(colQualType, indexes, loc));
  }

  auto transposeType = astContext.getConstantArrayType(
      colQualType, llvm::APInt(32, numCols), clang::ArrayType::Normal, 0);
  return spvBuilder.createCompositeConstruct(transposeType, cols, loc);
}

SpirvInstruction *SpirvEmitter::processNonFpDot(SpirvInstruction *vec1Id,
                                                SpirvInstruction *vec2Id,
                                                uint32_t vecSize,
                                                QualType elemType,
                                                SourceLocation loc) {
  llvm::SmallVector<SpirvInstruction *, 4> muls;
  for (uint32_t i = 0; i < vecSize; ++i) {
    auto *elem1 = spvBuilder.createCompositeExtract(elemType, vec1Id, {i}, loc);
    auto *elem2 = spvBuilder.createCompositeExtract(elemType, vec2Id, {i}, loc);
    muls.push_back(spvBuilder.createBinaryOp(translateOp(BO_Mul, elemType),
                                             elemType, elem1, elem2, loc));
  }
  SpirvInstruction *sum = muls[0];
  for (uint32_t i = 1; i < vecSize; ++i) {
    sum = spvBuilder.createBinaryOp(translateOp(BO_Add, elemType), elemType,
                                    sum, muls[i], loc);
  }
  return sum;
}

SpirvInstruction *SpirvEmitter::processNonFpScalarTimesMatrix(
    QualType scalarType, SpirvInstruction *scalar, QualType matrixType,
    SpirvInstruction *matrix, SourceLocation loc) {
  assert(isScalarType(scalarType));
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  const bool isMat = isMxNMatrix(matrixType, &elemType, &numRows, &numCols);
  assert(isMat);
  assert(isSameType(astContext, scalarType, elemType));
  (void)isMat;

  // We need to multiply the scalar by each vector of the matrix.
  // The front-end guarantees that the scalar and matrix element type are
  // the same. For example, if the scalar is a float, the matrix is cast
  // to a float matrix before being passed to mul(). It is also guaranteed
  // that types such as bool are cast to float or int before being
  // passed to mul().
  const auto rowType = astContext.getExtVectorType(elemType, numCols);
  llvm::SmallVector<SpirvInstruction *, 4> splat(size_t(numCols), scalar);
  auto *scalarSplat = spvBuilder.createCompositeConstruct(rowType, splat, loc);
  llvm::SmallVector<SpirvInstruction *, 4> mulRows;
  for (uint32_t row = 0; row < numRows; ++row) {
    auto *rowInstr =
        spvBuilder.createCompositeExtract(rowType, matrix, {row}, loc);
    mulRows.push_back(spvBuilder.createBinaryOp(
        translateOp(BO_Mul, scalarType), rowType, rowInstr, scalarSplat, loc));
  }
  return spvBuilder.createCompositeConstruct(matrixType, mulRows, loc);
}

SpirvInstruction *SpirvEmitter::processNonFpVectorTimesMatrix(
    QualType vecType, SpirvInstruction *vector, QualType matType,
    SpirvInstruction *matrix, SourceLocation loc,
    SpirvInstruction *matrixTranspose) {
  // This function assumes that the vector element type and matrix element
  // type are the same.
  QualType vecElemType = {}, matElemType = {};
  uint32_t vecSize = 0, numRows = 0, numCols = 0;
  const bool isVec = isVectorType(vecType, &vecElemType, &vecSize);
  const bool isMat = isMxNMatrix(matType, &matElemType, &numRows, &numCols);
  assert(isSameType(astContext, vecElemType, matElemType));
  assert(isVec);
  assert(isMat);
  assert(vecSize == numRows);
  (void)isVec;
  (void)isMat;

  // When processing vector times matrix, the vector is a row vector, and it
  // should be multiplied by the matrix *columns*. The most efficient way to
  // handle this in SPIR-V is to first transpose the matrix, and then extract
  // its rows (the original matrix's columns).
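  // For a 1xM vector v and an MxN matrix A: result[j] = dot(v, column_j(A)),
  // giving a 1xN vector.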
  if (!matrixTranspose)
    matrixTranspose = processNonFpMatrixTranspose(matType, matrix, loc);

  llvm::SmallVector<SpirvInstruction *, 4> resultElems;
  for (uint32_t col = 0; col < numCols; ++col) {
    auto *colInstr =
        spvBuilder.createCompositeExtract(vecType, matrixTranspose, {col}, loc);
    resultElems.push_back(
        processNonFpDot(vector, colInstr, vecSize, vecElemType, loc));
  }
  return spvBuilder.createCompositeConstruct(
      astContext.getExtVectorType(vecElemType, numCols), resultElems, loc);
}

SpirvInstruction *SpirvEmitter::processNonFpMatrixTimesVector(
    QualType matType, SpirvInstruction *matrix, QualType vecType,
    SpirvInstruction *vector, SourceLocation loc) {
  // This function assumes that the vector element type and matrix element
  // type are the same.
  QualType vecElemType = {}, matElemType = {};
  uint32_t vecSize = 0, numRows = 0, numCols = 0;
  const bool isVec = isVectorType(vecType, &vecElemType, &vecSize);
  const bool isMat = isMxNMatrix(matType, &matElemType, &numRows, &numCols);
  assert(isSameType(astContext, vecElemType, matElemType));
  assert(isVec);
  assert(isMat);
  assert(vecSize == numCols);
  (void)isVec;
  (void)isMat;

  // When processing matrix times vector, the vector is a column vector. So we
  // simply get each row of the matrix and perform a dot product with the
  // vector.
  llvm::SmallVector<SpirvInstruction *, 4> resultElems;
  for (uint32_t row = 0; row < numRows; ++row) {
    auto *rowInstr =
        spvBuilder.createCompositeExtract(vecType, matrix, {row}, loc);
    resultElems.push_back(
        processNonFpDot(rowInstr, vector, vecSize, vecElemType, loc));
  }
  return spvBuilder.createCompositeConstruct(
      astContext.getExtVectorType(vecElemType, numRows), resultElems, loc);
}

SpirvInstruction *SpirvEmitter::processNonFpMatrixTimesMatrix(
    QualType lhsType, SpirvInstruction *lhs, QualType rhsType,
    SpirvInstruction *rhs, SourceLocation loc) {
  // This function assumes that the two matrices' element types are the same.
  QualType lhsElemType = {}, rhsElemType = {};
  uint32_t lhsNumRows = 0, lhsNumCols = 0;
  uint32_t rhsNumRows = 0, rhsNumCols = 0;
  const bool lhsIsMat =
      isMxNMatrix(lhsType, &lhsElemType, &lhsNumRows, &lhsNumCols);
  const bool rhsIsMat =
      isMxNMatrix(rhsType, &rhsElemType, &rhsNumRows, &rhsNumCols);
  assert(isSameType(astContext, lhsElemType, rhsElemType));
  assert(lhsIsMat && rhsIsMat);
  assert(lhsNumCols == rhsNumRows);
  (void)rhsIsMat;
  (void)lhsIsMat;

  auto *rhsTranspose = processNonFpMatrixTranspose(rhsType, rhs, loc);
  const auto vecType = astContext.getExtVectorType(lhsElemType, lhsNumCols);
  llvm::SmallVector<SpirvInstruction *, 4> resultRows;
  for (uint32_t row = 0; row < lhsNumRows; ++row) {
    auto *rowInstr =
        spvBuilder.createCompositeExtract(vecType, lhs, {row}, loc);
    resultRows.push_back(processNonFpVectorTimesMatrix(
        vecType, rowInstr, rhsType, rhs, loc, rhsTranspose));
  }

  // The resulting matrix will have 'lhsNumRows' rows and 'rhsNumCols' columns.
  const auto resultColType =
      astContext.getExtVectorType(lhsElemType, rhsNumCols);
  const auto resultType = astContext.getConstantArrayType(
      resultColType, llvm::APInt(32, lhsNumRows), clang::ArrayType::Normal, 0);
  return spvBuilder.createCompositeConstruct(resultType, resultRows, loc);
}

SpirvInstruction *SpirvEmitter::processIntrinsicMul(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();

  // Get the function parameters. Expect 2 parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const QualType arg0Type = arg0->getType();
  const QualType arg1Type = arg1->getType();
  auto loc = callExpr->getExprLoc();

  // The HLSL mul() function takes 2 arguments. Each argument may be a scalar,
  // vector, or matrix. The frontend ensures that the two arguments have the
  // same component type. The only allowed component types are int and float.
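  // Note: in the matrix cases below, the operands handed to
  // OpMatrixTimesVector / OpVectorTimesMatrix / OpMatrixTimesMatrix appear in
  // the opposite order from the HLSL mul() arguments; this compensates for
  // HLSL matrices being laid out with rows and columns exchanged in the
  // generated SPIR-V.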
  // mul(scalar, vector)
  {
    uint32_t elemCount = 0;
    if (isScalarType(arg0Type) && isVectorType(arg1Type, nullptr, &elemCount)) {
      auto *arg1Id = doExpr(arg1);
      // We can use OpVectorTimesScalar if the arguments are floats.
      if (arg0Type->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnType, arg1Id, doExpr(arg0), loc);
      // Use OpIMul for integers.
      return spvBuilder.createBinaryOp(spv::Op::OpIMul, returnType,
                                       createVectorSplat(arg0, elemCount),
                                       arg1Id, loc);
    }
  }

  // mul(vector, scalar)
  {
    uint32_t elemCount = 0;
    if (isVectorType(arg0Type, nullptr, &elemCount) && isScalarType(arg1Type)) {
      auto *arg0Id = doExpr(arg0);
      // We can use OpVectorTimesScalar if the arguments are floats.
      if (arg1Type->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnType, arg0Id, doExpr(arg1), loc);
      // Use OpIMul for integers.
      return spvBuilder.createBinaryOp(spv::Op::OpIMul, returnType, arg0Id,
                                       createVectorSplat(arg1, elemCount), loc);
    }
  }

  // mul(vector, vector)
  if (isVectorType(arg0Type) && isVectorType(arg1Type)) {
    // mul( Mat(1xM), Mat(Mx1) ) results in a scalar (same as a dot product).
    if (isScalarType(returnType)) {
      return processIntrinsicDot(callExpr);
    }

    // mul( Mat(Mx1), Mat(1xN) ) results in an MxN matrix.
    QualType elemType = {};
    uint32_t numRows = 0;
    if (isMxNMatrix(returnType, &elemType, &numRows)) {
      llvm::SmallVector<SpirvInstruction *, 4> rows;
      auto *arg0Id = doExpr(arg0);
      auto *arg1Id = doExpr(arg1);
      for (uint32_t i = 0; i < numRows; ++i) {
        auto *scalar =
            spvBuilder.createCompositeExtract(elemType, arg0Id, {i}, loc);
        rows.push_back(spvBuilder.createBinaryOp(
            spv::Op::OpVectorTimesScalar, arg1Type, arg1Id, scalar, loc));
      }
      return spvBuilder.createCompositeConstruct(returnType, rows, loc);
    }

    llvm_unreachable("bad arguments passed to mul");
  }

  // All the following cases require handling arg0 and arg1 expressions first.
  auto *arg0Id = doExpr(arg0);
  auto *arg1Id = doExpr(arg1);

  // mul(scalar, scalar)
  if (isScalarType(arg0Type) && isScalarType(arg1Type))
    return spvBuilder.createBinaryOp(translateOp(BO_Mul, arg0Type), returnType,
                                     arg0Id, arg1Id, loc);

  // mul(scalar, matrix)
  {
    QualType elemType = {};
    if (isScalarType(arg0Type) && isMxNMatrix(arg1Type, &elemType)) {
      // OpMatrixTimesScalar can only be used if *both* the matrix element type
      // and the scalar type are float.
      if (arg0Type->isFloatingType() && elemType->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                         returnType, arg1Id, arg0Id, loc);
      else
        return processNonFpScalarTimesMatrix(arg0Type, arg0Id, arg1Type, arg1Id,
                                             callExpr->getExprLoc());
    }
  }

  // mul(matrix, scalar)
  {
    QualType elemType = {};
    if (isScalarType(arg1Type) && isMxNMatrix(arg0Type, &elemType)) {
      // OpMatrixTimesScalar can only be used if *both* the matrix element type
      // and the scalar type are float.
      if (arg1Type->isFloatingType() && elemType->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                         returnType, arg0Id, arg1Id, loc);
      else
        return processNonFpScalarTimesMatrix(arg1Type, arg1Id, arg0Type, arg0Id,
                                             callExpr->getExprLoc());
    }
  }

  // mul(vector, matrix)
  {
    QualType vecElemType = {}, matElemType = {};
    uint32_t elemCount = 0, numRows = 0;
    if (isVectorType(arg0Type, &vecElemType, &elemCount) &&
        isMxNMatrix(arg1Type, &matElemType, &numRows)) {
      assert(elemCount == numRows);
      if (vecElemType->isFloatingType() && matElemType->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpMatrixTimesVector,
                                         returnType, arg1Id, arg0Id, loc);
      else
        return processNonFpVectorTimesMatrix(arg0Type, arg0Id, arg1Type, arg1Id,
                                             callExpr->getExprLoc());
    }
  }

  // mul(matrix, vector)
  {
    QualType vecElemType = {}, matElemType = {};
    uint32_t elemCount = 0, numCols = 0;
    if (isMxNMatrix(arg0Type, &matElemType, nullptr, &numCols) &&
        isVectorType(arg1Type, &vecElemType, &elemCount)) {
      assert(elemCount == numCols);
      if (vecElemType->isFloatingType() && matElemType->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpVectorTimesMatrix,
                                         returnType, arg1Id, arg0Id, loc);
      else
        return processNonFpMatrixTimesVector(arg0Type, arg0Id, arg1Type, arg1Id,
                                             callExpr->getExprLoc());
    }
  }

  // mul(matrix, matrix)
  {
    // The front-end ensures that the two matrix element types match.
    QualType elemType = {};
    uint32_t lhsCols = 0, rhsRows = 0;
    if (isMxNMatrix(arg0Type, &elemType, nullptr, &lhsCols) &&
        isMxNMatrix(arg1Type, nullptr, &rhsRows, nullptr)) {
      assert(lhsCols == rhsRows);
      if (elemType->isFloatingType())
        return spvBuilder.createBinaryOp(spv::Op::OpMatrixTimesMatrix,
                                         returnType, arg1Id, arg0Id, loc);
      else
        return processNonFpMatrixTimesMatrix(arg0Type, arg0Id, arg1Type, arg1Id,
                                             callExpr->getExprLoc());
    }
  }

  emitError("invalid argument type passed to mul intrinsic function",
            callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicPrintf(const CallExpr *callExpr) {
  // C99, s6.5.2.2/6: "If the expression that denotes the called function has
  // a type that does not include a prototype, the integer promotions are
  // performed on each argument, and arguments that have type float are
  // promoted to double. These are called the default argument promotions."
  // C++: All the variadic parameters undergo default promotions before
  // they're received by the function.
  //
  // Therefore, by default, floating-point arguments will be evaluated as
  // double by this function.
  //
  // TODO: We may want to change this behavior for SPIR-V.
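  // For example, printf("%f", 0.5f) receives 0.5 as a double after the
  // default argument promotions.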
  const auto returnType = callExpr->getType();
  const auto numArgs = callExpr->getNumArgs();
  const auto loc = callExpr->getExprLoc();
  assert(numArgs >= 1u);
  llvm::SmallVector<SpirvInstruction *, 4> args;
  for (uint32_t argIndex = 0; argIndex < numArgs; ++argIndex)
    args.push_back(doExpr(callExpr->getArg(argIndex)));
  return spvBuilder.createNonSemanticDebugPrintfExtInst(
      returnType, NonSemanticDebugPrintfDebugPrintf, args, loc);
}

SpirvInstruction *SpirvEmitter::processIntrinsicDot(const CallExpr *callExpr) {
  // Get the function parameters. Expect 2 vectors as parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  auto *arg0Id = doExpr(arg0);
  auto *arg1Id = doExpr(arg1);
  QualType arg0Type = arg0->getType();
  QualType arg1Type = arg1->getType();
  uint32_t vec0Size = 0, vec1Size = 0;
  QualType vec0ComponentType = {}, vec1ComponentType = {};
  QualType returnType = {};
  const bool arg0isScalarOrVec =
      isScalarOrVectorType(arg0Type, &vec0ComponentType, &vec0Size);
  const bool arg1isScalarOrVec =
      isScalarOrVectorType(arg1Type, &vec1ComponentType, &vec1Size);
  const bool returnIsScalar = isScalarType(callExpr->getType(), &returnType);
  // Each argument should be either a vector or a scalar.
  assert(arg0isScalarOrVec && arg1isScalarOrVec);
  // The result type must be a scalar.
  assert(returnIsScalar);
  // The element type of each argument and the return type must be the same.
  assert(returnType == vec1ComponentType);
  assert(vec0ComponentType == vec1ComponentType);
  // The size of the two arguments must be equal.
  assert(vec0Size == vec1Size);
  // Acceptable vector sizes are 1, 2, 3, and 4.
  assert(vec0Size >= 1 && vec0Size <= 4);
  (void)arg0isScalarOrVec;
  (void)arg1isScalarOrVec;
  (void)returnIsScalar;
  (void)vec0ComponentType;
  (void)vec1ComponentType;
  (void)vec1Size;
  auto loc = callExpr->getLocStart();

  // According to the HLSL reference, the dot function only works on integers
  // and floats.
  assert(returnType->isFloatingType() || returnType->isIntegerType());

  // Special case: dot product of two vectors, each of size 1. That is
  // basically the same as regular multiplication of 2 scalars.
  if (vec0Size == 1) {
    const spv::Op spvOp = translateOp(BO_Mul, arg0Type);
    return spvBuilder.createBinaryOp(spvOp, returnType, arg0Id, arg1Id, loc);
  }

  // If the vectors are of type float, we can use OpDot.
  if (returnType->isFloatingType()) {
    return spvBuilder.createBinaryOp(spv::Op::OpDot, returnType, arg0Id, arg1Id,
                                     loc);
  }
  // The vector component type is integer (signed or unsigned).
  // Create all instructions necessary to perform a dot product on
  // two integer vectors. SPIR-V OpDot does not support integer vectors.
  // Therefore, we use other SPIR-V instructions (addition and
  // multiplication).
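  // For example, dot(int3(1, 2, 3), int3(4, 5, 6)) == 4 + 10 + 18 == 32.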
  else {
    SpirvInstruction *result = nullptr;
    llvm::SmallVector<SpirvInstruction *, 4> multIds;
    const spv::Op multSpvOp = translateOp(BO_Mul, arg0Type);
    const spv::Op addSpvOp = translateOp(BO_Add, arg0Type);
    // Extract members from the two vectors and multiply them.
    for (unsigned int i = 0; i < vec0Size; ++i) {
      auto *vec0member = spvBuilder.createCompositeExtract(
          returnType, arg0Id, {i}, arg0->getLocStart());
      auto *vec1member = spvBuilder.createCompositeExtract(
          returnType, arg1Id, {i}, arg1->getLocStart());
      auto *multId = spvBuilder.createBinaryOp(multSpvOp, returnType,
                                               vec0member, vec1member, loc);
      multIds.push_back(multId);
    }
    // Add all the multiplications.
    result = multIds[0];
    for (unsigned int i = 1; i < vec0Size; ++i) {
      auto *additionId = spvBuilder.createBinaryOp(addSpvOp, returnType, result,
                                                   multIds[i], loc);
      result = additionId;
    }
    return result;
  }
}

SpirvInstruction *SpirvEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
  // 'rcp' takes only 1 argument that is a scalar, vector, or matrix of type
  // float or double.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const Expr *arg = callExpr->getArg(0);
  auto *argId = doExpr(arg);
  const QualType argType = arg->getType();
  auto loc = callExpr->getLocStart();

  // For cases with a matrix argument.
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    auto *vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecOne, loc](uint32_t /*index*/,
                                                  QualType vecType,
                                                  SpirvInstruction *curRow) {
      return spvBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne, curRow,
                                       loc);
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec, loc);
  }

  // For cases with scalar or vector arguments.
  return spvBuilder.createBinaryOp(spv::Op::OpFDiv, returnType,
                                   getValueOne(argType), argId, loc);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
                                       spv::Op spvOp) {
  // 'all' and 'any' take only 1 parameter.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const Expr *arg = callExpr->getArg(0);
  const QualType argType = arg->getType();
  const auto loc = callExpr->getExprLoc();

  // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
  // Optimization: we can directly cast them to booleans. No need for
  // OpAny/OpAll.
  {
    QualType scalarType = {};
    if (isScalarType(argType, &scalarType) &&
        (scalarType->isBooleanType() || scalarType->isFloatingType() ||
         scalarType->isIntegerType()))
      return castToBool(doExpr(arg), argType, returnType, loc);
  }

  // Handle vectors larger than 1, Mx1 matrices, and 1xN matrices as arguments.
  // Cast the vector to a boolean vector, then run OpAny/OpAll on it.
  {
    QualType elemType = {};
    uint32_t size = 0;
    if (isVectorType(argType, &elemType, &size)) {
      const QualType castToBoolType =
          astContext.getExtVectorType(returnType, size);
      auto *castedToBool =
          castToBool(doExpr(arg), argType, castToBoolType, loc);
      return spvBuilder.createUnaryOp(spvOp, returnType, castedToBool, loc);
    }
  }

  // Handle MxN matrices as arguments.
  {
    QualType elemType = {};
    uint32_t matRowCount = 0, matColCount = 0;
    if (isMxNMatrix(argType, &elemType, &matRowCount, &matColCount)) {
      auto *matrix = doExpr(arg);
      const QualType vecType = getComponentVectorType(astContext, argType);
      llvm::SmallVector<SpirvInstruction *, 4> rowResults;
      for (uint32_t i = 0; i < matRowCount; ++i) {
        // Extract the row, which is a float vector of size matColCount.
        auto *rowFloatVec = spvBuilder.createCompositeExtract(
            vecType, matrix, {i}, arg->getLocStart());
        // Cast the float vector to a boolean vector.
        const auto rowFloatQualType =
            astContext.getExtVectorType(elemType, matColCount);
        const auto rowBoolQualType =
            astContext.getExtVectorType(returnType, matColCount);
        auto *rowBoolVec = castToBool(rowFloatVec, rowFloatQualType,
                                      rowBoolQualType, arg->getLocStart());
        // Perform OpAny/OpAll on the boolean vector.
        rowResults.push_back(
            spvBuilder.createUnaryOp(spvOp, returnType, rowBoolVec, loc));
      }
      // Create a new vector that is the concatenation of results of all rows.
      const QualType vecOfBools =
          astContext.getExtVectorType(astContext.BoolTy, matRowCount);
      auto *row =
          spvBuilder.createCompositeConstruct(vecOfBools, rowResults, loc);
      // Run OpAny/OpAll on the newly-created vector.
      return spvBuilder.createUnaryOp(spvOp, returnType, row, loc);
    }
  }

  // All types should be handled already.
  llvm_unreachable("Unknown argument type passed to all()/any().");
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
  // This function handles 'asint', 'asuint', 'asfloat', and 'asdouble'.
  // Method 1: ret asint(arg)
  //   arg component type = {float, uint}
  //   arg template type  = {scalar, vector, matrix}
  //   ret template type  = same as arg template type.
  //   ret component type = int
  // Method 2: ret asuint(arg)
  //   arg component type = {float, int}
  //   arg template type  = {scalar, vector, matrix}
  //   ret template type  = same as arg template type.
  //   ret component type = uint
  // Method 3: ret asfloat(arg)
  //   arg component type = {float, uint, int}
  //   arg template type  = {scalar, vector, matrix}
  //   ret template type  = same as arg template type.
  //   ret component type = float
  // Method 4: double asdouble(uint lowbits, uint highbits)
  // Method 5: double2 asdouble(uint2 lowbits, uint2 highbits)
  // Method 6:
  //   void asuint(
  //     in  double value,
  //     out uint lowbits,
  //     out uint highbits
  //   );
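  // For example, asdouble(0u, 0x3FF00000u) == 1.0, since the IEEE-754 bit
  // pattern of 1.0 is 0x3FF0000000000000.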
  const QualType returnType = callExpr->getType();
  const uint32_t numArgs = callExpr->getNumArgs();
  const Expr *arg0 = callExpr->getArg(0);
  const QualType argType = arg0->getType();
  const auto loc = callExpr->getExprLoc();

  // The Method 3 return type may be the same as the arg type, in which case
  // the call is a no-op.
  if (isSameType(astContext, returnType, argType))
    return doExpr(arg0);

  switch (numArgs) {
  case 1: {
    // Handling Methods 1, 2, and 3.
    auto *argInstr = doExpr(arg0);
    QualType fromElemType = {};
    uint32_t numRows = 0, numCols = 0;
    // For non-matrix arguments (scalar or vector), just do an OpBitcast.
    if (!isMxNMatrix(argType, &fromElemType, &numRows, &numCols)) {
      return spvBuilder.createUnaryOp(spv::Op::OpBitcast, returnType, argInstr,
                                      loc);
    }

    // The input (and hence the output) is a matrix: bitcast it row by row.
    const QualType toElemType = hlsl::GetHLSLMatElementType(returnType);
    llvm::SmallVector<SpirvInstruction *, 4> castedRows;
    const auto fromVecType = astContext.getExtVectorType(fromElemType, numCols);
    const auto toVecType = astContext.getExtVectorType(toElemType, numCols);
    for (uint32_t row = 0; row < numRows; ++row) {
      auto *rowInstr = spvBuilder.createCompositeExtract(
          fromVecType, argInstr, {row}, arg0->getLocStart());
      castedRows.push_back(spvBuilder.createUnaryOp(spv::Op::OpBitcast,
                                                    toVecType, rowInstr, loc));
    }
    return spvBuilder.createCompositeConstruct(returnType, castedRows, loc);
  }
  case 2: {
    auto *lowbits = doExpr(arg0);
    auto *highbits = doExpr(callExpr->getArg(1));
    const auto uintType = astContext.UnsignedIntTy;
    const auto doubleType = astContext.DoubleTy;
    // Handling Method 4.
    if (argType->isUnsignedIntegerType()) {
      const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
      auto *operand = spvBuilder.createCompositeConstruct(
          uintVec2Type, {lowbits, highbits}, loc);
      return spvBuilder.createUnaryOp(spv::Op::OpBitcast, doubleType, operand,
                                      loc);
    }
    // Handling Method 5.
    else {
      const auto uintVec4Type = astContext.getExtVectorType(uintType, 4);
      const auto doubleVec2Type = astContext.getExtVectorType(doubleType, 2);
      auto *operand = spvBuilder.createVectorShuffle(
          uintVec4Type, lowbits, highbits, {0, 2, 1, 3}, loc);
      return spvBuilder.createUnaryOp(spv::Op::OpBitcast, doubleVec2Type,
                                      operand, loc);
    }
  }
  case 3: {
    // Handling Method 6.
    auto *value = doExpr(arg0);
    auto *lowbits = doExpr(callExpr->getArg(1));
    auto *highbits = doExpr(callExpr->getArg(2));
    const auto uintType = astContext.UnsignedIntTy;
    const auto uintVec2Type = astContext.getExtVectorType(uintType, 2);
    auto *vecResult =
        spvBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, value, loc);
    spvBuilder.createStore(lowbits,
                           spvBuilder.createCompositeExtract(
                               uintType, vecResult, {0}, arg0->getLocStart()),
                           loc);
    spvBuilder.createStore(highbits,
                           spvBuilder.createCompositeExtract(
                               uintType, vecResult, {1}, arg0->getLocStart()),
                           loc);
    return nullptr;
  }
  default:
    emitError("unrecognized signature for %0 intrinsic function", loc)
        << callExpr->getDirectCallee()->getName();
    return nullptr;
  }
}

SpirvInstruction *
SpirvEmitter::processD3DCOLORtoUBYTE4(const CallExpr *callExpr) {
  // Should take a float4 and return an int4 by doing:
  //   int4 result = input.zyxw * 255.001953;
  // Maximum float precision makes the scaling factor 255.002.
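  // For example, input (1, 0, 0, 1) is swizzled to (0, 0, 1, 1) and scaled,
  // yielding int4(0, 0, 255, 255).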
  const auto arg = callExpr->getArg(0);
  auto *argId = doExpr(arg);
  const auto argType = arg->getType();
  auto loc = callExpr->getLocStart();
  auto *swizzle =
      spvBuilder.createVectorShuffle(argType, argId, argId, {2, 1, 0, 3}, loc);
  auto *scaled = spvBuilder.createBinaryOp(
      spv::Op::OpVectorTimesScalar, argType, swizzle,
      spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(255.002f)),
      loc);
  return castToInt(scaled, arg->getType(), callExpr->getType(), loc);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicIsFinite(const CallExpr *callExpr) {
  // Since OpIsFinite needs the Kernel capability, translation is instead done
  // using OpIsNan and OpIsInf:
  //   isFinite = !(isNan || isInf)
  const auto arg = doExpr(callExpr->getArg(0));
  const auto returnType = callExpr->getType();
  const auto loc = callExpr->getExprLoc();
  const auto isNan =
      spvBuilder.createUnaryOp(spv::Op::OpIsNan, returnType, arg, loc);
  const auto isInf =
      spvBuilder.createUnaryOp(spv::Op::OpIsInf, returnType, arg, loc);
  const auto isNanOrInf = spvBuilder.createBinaryOp(
      spv::Op::OpLogicalOr, returnType, isNan, isInf, loc);
  return spvBuilder.createUnaryOp(spv::Op::OpLogicalNot, returnType, isNanOrInf,
                                  loc);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicSinCos(const CallExpr *callExpr) {
  // Since there is no sincos equivalent in SPIR-V, we need to perform Sin
  // once and Cos once. We can reuse existing Sine/Cosine handling functions.
  CallExpr *sincosExpr =
      new (astContext) CallExpr(astContext, Stmt::StmtClass::NoStmtClass, {});
  sincosExpr->setType(callExpr->getArg(0)->getType());
  sincosExpr->setNumArgs(astContext, 1);
  sincosExpr->setArg(0, const_cast<Expr *>(callExpr->getArg(0)));
  const auto srcLoc = callExpr->getExprLoc();

  // Perform Sin and store the results in argument 1.
  auto *sin =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Sin,
                                    /*actPerRowForMatrices*/ true, srcLoc);
  spvBuilder.createStore(doExpr(callExpr->getArg(1)), sin, srcLoc);

  // Perform Cos and store the results in argument 2.
  auto *cos =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Cos,
                                    /*actPerRowForMatrices*/ true, srcLoc);
  spvBuilder.createStore(doExpr(callExpr->getArg(2)), cos, srcLoc);
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
  const auto *arg = callExpr->getArg(0);
  const auto loc = callExpr->getExprLoc();
  auto *argId = doExpr(arg);
  const auto argType = arg->getType();
  const QualType returnType = callExpr->getType();

  QualType elemType = {};
  uint32_t vecSize = 0;
  if (isScalarType(argType, &elemType)) {
    auto *floatZero = getValueZero(elemType);
    auto *floatOne = getValueOne(elemType);
    return spvBuilder.createGLSLExtInst(returnType,
                                        GLSLstd450::GLSLstd450FClamp,
                                        {argId, floatZero, floatOne}, loc);
  }

  if (isVectorType(argType, &elemType, &vecSize)) {
    auto *vecZero = getVecValueZero(elemType, vecSize);
    auto *vecOne = getVecValueOne(elemType, vecSize);
    return spvBuilder.createGLSLExtInst(returnType,
                                        GLSLstd450::GLSLstd450FClamp,
                                        {argId, vecZero, vecOne}, loc);
  }

  uint32_t numRows = 0, numCols = 0;
  if (isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    auto *vecZero = getVecValueZero(elemType, numCols);
    auto *vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, loc, vecZero,
                               vecOne](uint32_t /*index*/, QualType vecType,
                                       SpirvInstruction *curRow) {
      return spvBuilder.createGLSLExtInst(vecType, GLSLstd450::GLSLstd450FClamp,
                                          {curRow, vecZero, vecOne}, loc);
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec, loc);
  }

  emitError("invalid argument type passed to saturate intrinsic function",
            callExpr->getExprLoc());
  return nullptr;
}

SpirvInstruction *
SpirvEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
  // Import the GLSL.std.450 extended instruction set.
  const Expr *arg = callExpr->getArg(0);
  const auto loc = callExpr->getExprLoc();
  const QualType returnType = callExpr->getType();
  const QualType argType = arg->getType();
  assert(isFloatOrVecMatOfFloatType(argType));
  auto *argId = doExpr(arg);
  SpirvInstruction *floatSign = nullptr;

  // For matrices, we can perform the instruction on each vector of the matrix.
  if (isMxNMatrix(argType)) {
    const auto actOnEachVec = [this, loc](uint32_t /*index*/, QualType vecType,
                                          SpirvInstruction *curRow) {
      return spvBuilder.createGLSLExtInst(vecType, GLSLstd450::GLSLstd450FSign,
                                          {curRow}, loc);
    };
    floatSign = processEachVectorInMatrix(arg, argId, actOnEachVec, loc);
  } else {
    floatSign = spvBuilder.createGLSLExtInst(
        argType, GLSLstd450::GLSLstd450FSign, {argId}, loc);
  }

  return castToInt(floatSign, arg->getType(), returnType, arg->getLocStart());
}

SpirvInstruction *
SpirvEmitter::processIntrinsicF16ToF32(const CallExpr *callExpr) {
  // f16tof32() takes in (a vector of) uint and returns (a vector of) float.
  // The frontend should guarantee that by inserting implicit casts.
  const QualType f32Type = astContext.FloatTy;
  const QualType u32Type = astContext.UnsignedIntTy;
  const QualType v2f32Type = astContext.getExtVectorType(f32Type, 2);
  const auto loc = callExpr->getExprLoc();
  const auto *arg = callExpr->getArg(0);
  auto *argId = doExpr(arg);

  uint32_t elemCount = {};
  if (isVectorType(arg->getType(), nullptr, &elemCount)) {
    // The input is a vector. We need to handle each element separately.
    llvm::SmallVector<SpirvInstruction *, 4> elements;
    for (uint32_t i = 0; i < elemCount; ++i) {
      auto *srcElem = spvBuilder.createCompositeExtract(u32Type, argId, {i},
                                                        arg->getLocStart());
      auto *convert = spvBuilder.createGLSLExtInst(
          v2f32Type, GLSLstd450::GLSLstd450UnpackHalf2x16, srcElem, loc);
      elements.push_back(
          spvBuilder.createCompositeExtract(f32Type, convert, {0}, loc));
    }
    return spvBuilder.createCompositeConstruct(
        astContext.getExtVectorType(f32Type, elemCount), elements, loc);
  }

  auto *convert = spvBuilder.createGLSLExtInst(
      v2f32Type, GLSLstd450::GLSLstd450UnpackHalf2x16, argId, loc);
  // f16tof32() converts the float16 stored in the low half of the uint to
  // a float. So we just need to return the first component.
  return spvBuilder.createCompositeExtract(f32Type, convert, {0}, loc);
}

SpirvInstruction *
SpirvEmitter::processIntrinsicF32ToF16(const CallExpr *callExpr) {
  // f32tof16() takes in (a vector of) float and returns (a vector of) uint.
  // The frontend should guarantee that by inserting implicit casts.
  const QualType f32Type = astContext.FloatTy;
  const QualType u32Type = astContext.UnsignedIntTy;
  const QualType v2f32Type = astContext.getExtVectorType(f32Type, 2);
  auto *zero = spvBuilder.getConstantFloat(f32Type, llvm::APFloat(0.0f));
  const auto loc = callExpr->getExprLoc();
  const auto *arg = callExpr->getArg(0);
  auto *argId = doExpr(arg);

  uint32_t elemCount = {};
  if (isVectorType(arg->getType(), nullptr, &elemCount)) {
    // The input is a vector. We need to handle each element separately.
    llvm::SmallVector<SpirvInstruction *, 4> elements;
    for (uint32_t i = 0; i < elemCount; ++i) {
      auto *srcElem = spvBuilder.createCompositeExtract(f32Type, argId, {i},
                                                        arg->getLocStart());
      auto *srcVec =
          spvBuilder.createCompositeConstruct(v2f32Type, {srcElem, zero}, loc);
      elements.push_back(spvBuilder.createGLSLExtInst(
          u32Type, GLSLstd450::GLSLstd450PackHalf2x16, srcVec, loc));
    }
    return spvBuilder.createCompositeConstruct(
        astContext.getExtVectorType(u32Type, elemCount), elements, loc);
  }

  // f32tof16() stores the converted half into the low half of the uint. So we
  // need to supply another zero to take the other half.
  auto *srcVec =
      spvBuilder.createCompositeConstruct(v2f32Type, {argId, zero}, loc);
  return spvBuilder.createGLSLExtInst(
      u32Type, GLSLstd450::GLSLstd450PackHalf2x16, srcVec, loc);
}

SpirvInstruction *SpirvEmitter::processIntrinsicUsingSpirvInst(
    const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
  // Certain opcodes are only allowed in the pixel shader.
  if (!spvContext.isPS())
    switch (opcode) {
    case spv::Op::OpDPdx:
    case spv::Op::OpDPdy:
    case spv::Op::OpDPdxFine:
    case spv::Op::OpDPdyFine:
    case spv::Op::OpDPdxCoarse:
    case spv::Op::OpDPdyCoarse:
    case spv::Op::OpFwidth:
    case spv::Op::OpFwidthFine:
    case spv::Op::OpFwidthCoarse:
      needsLegalization = true;
      break;
    default:
      // Only the opcodes above need legalization; anything else leaves
      // needsLegalization unchanged.
      break;
    }

  const auto loc = callExpr->getExprLoc();
  const QualType returnType = callExpr->getType();
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    auto *argId = doExpr(arg);

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
      const auto actOnEachVec = [this, opcode, loc](uint32_t /*index*/,
                                                    QualType vecType,
                                                    SpirvInstruction *curRow) {
        return spvBuilder.createUnaryOp(opcode, vecType, curRow, loc);
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec, loc);
    }
    return spvBuilder.createUnaryOp(opcode, returnType, argId, loc);
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    auto *arg0Id = doExpr(arg0);
    auto *arg1Id = doExpr(callExpr->getArg(1));
    const auto arg1Loc = callExpr->getArg(1)->getLocStart();

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, opcode, arg1Id, loc,
                                 arg1Loc](uint32_t index, QualType vecType,
                                          SpirvInstruction *arg0Row) {
        auto *arg1Row = spvBuilder.createCompositeExtract(vecType, arg1Id,
                                                          {index}, arg1Loc);
        return spvBuilder.createBinaryOp(opcode, vecType, arg0Row, arg1Row,
                                         loc);
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec, loc);
    }
    return spvBuilder.createBinaryOp(opcode, returnType, arg0Id, arg1Id, loc);
  }

  emitError("unsupported %0 intrinsic function", loc)
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return nullptr;
}

SpirvInstruction *SpirvEmitter::processIntrinsicUsingGLSLInst(
    const CallExpr *callExpr, GLSLstd450 opcode, bool actPerRowForMatrices,
    SourceLocation loc) {
  // Import the GLSL.std.450 extended instruction set.
  const QualType returnType = callExpr->getType();
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    auto *argInstr = doExpr(arg);

    // If the instruction cannot operate on matrices directly, we can still
    // perform it on each row vector of the matrix.
    if (actPerRowForMatrices && isMxNMatrix(arg->getType())) {
      const auto actOnEachVec = [this, loc,
                                 opcode](uint32_t /*index*/, QualType vecType,
                                         SpirvInstruction *curRowInstr) {
        return spvBuilder.createGLSLExtInst(vecType, opcode, {curRowInstr},
                                            loc);
      };
      return processEachVectorInMatrix(arg, argInstr, actOnEachVec, loc);
    }
    return spvBuilder.createGLSLExtInst(returnType, opcode, {argInstr}, loc);
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    auto *arg0Instr = doExpr(arg0);
    auto *arg1Instr = doExpr(callExpr->getArg(1));
    const auto arg1Loc = callExpr->getArg(1)->getLocStart();

    // If the instruction cannot operate on matrices directly, we can still
    // perform it on each row vector of the matrix.
    if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, loc, opcode, arg1Instr,
                                 arg1Loc](uint32_t index, QualType vecType,
                                          SpirvInstruction *arg0RowInstr) {
        auto *arg1RowInstr = spvBuilder.createCompositeExtract(
            vecType, arg1Instr, {index}, arg1Loc);
        return spvBuilder.createGLSLExtInst(vecType, opcode,
                                            {arg0RowInstr, arg1RowInstr}, loc);
      };
      return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc);
    }
    return spvBuilder.createGLSLExtInst(returnType, opcode,
                                        {arg0Instr, arg1Instr}, loc);
  } else if (callExpr->getNumArgs() == 3u) {
    const Expr *arg0 = callExpr->getArg(0);
    auto *arg0Instr = doExpr(arg0);
    auto *arg1Instr = doExpr(callExpr->getArg(1));
    auto *arg2Instr = doExpr(callExpr->getArg(2));
    auto arg1Loc = callExpr->getArg(1)->getLocStart();
    auto arg2Loc = callExpr->getArg(2)->getLocStart();

    // If the instruction cannot operate on matrices directly, we can still
    // perform it on each row vector of the matrix.
    if (actPerRowForMatrices && isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, loc, opcode, arg1Instr, arg2Instr,
                                 arg1Loc,
                                 arg2Loc](uint32_t index, QualType vecType,
                                          SpirvInstruction *arg0RowInstr) {
        auto *arg1RowInstr = spvBuilder.createCompositeExtract(
            vecType, arg1Instr, {index}, arg1Loc);
        auto *arg2RowInstr = spvBuilder.createCompositeExtract(
            vecType, arg2Instr, {index}, arg2Loc);
        return spvBuilder.createGLSLExtInst(
            vecType, opcode, {arg0RowInstr, arg1RowInstr, arg2RowInstr}, loc);
      };
      return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec, loc);
    }
    return spvBuilder.createGLSLExtInst(
        returnType, opcode, {arg0Instr, arg1Instr, arg2Instr}, loc);
  }

  emitError("unsupported %0 intrinsic function", callExpr->getExprLoc())
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return nullptr;
}
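
// For example, log10(100.0) lowers to a GLSL.std.450 Log2 followed by a
// multiply: log2(100.0) * 0.30103 = 6.6439 * 0.30103, which is 2.0 within
// float precision.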
SpirvInstruction *
SpirvEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
  // Since there is no log10 instruction in SPIR-V, we can use:
  // log10(x) = log2(x) * ( 1 / log2(10) )
  // 1 / log2(10) = 0.30103
  auto loc = callExpr->getExprLoc();
  auto *scale =
      spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(0.30103f));
  auto *log2 = processIntrinsicUsingGLSLInst(
      callExpr, GLSLstd450::GLSLstd450Log2, true, loc);
  const auto returnType = callExpr->getType();
  spv::Op scaleOp = isScalarType(returnType)   ? spv::Op::OpFMul
                    : isVectorType(returnType) ? spv::Op::OpVectorTimesScalar
                                               : spv::Op::OpMatrixTimesScalar;
  return spvBuilder.createBinaryOp(scaleOp, returnType, log2, scale, loc);
}
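
// Lowers the HLSL ray tracing system-value intrinsics (DispatchRaysIndex(),
// RayTCurrent(), WorldRayOrigin(), and friends) to loads of the
// corresponding SPIR-V builtins. The 3x4 matrix variants are read through
// the single 4x3 builtin that SPIR-V defines and then transposed.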
SpirvInstruction *SpirvEmitter::processRayBuiltins(const CallExpr *callExpr,
                                                   hlsl::IntrinsicOp op) {
  spv::BuiltIn builtin = spv::BuiltIn::Max;
  bool transposeMatrix = false;
  const auto loc = callExpr->getExprLoc();
  switch (op) {
  case hlsl::IntrinsicOp::IOP_DispatchRaysDimensions:
    builtin = spv::BuiltIn::LaunchSizeNV;
    break;
  case hlsl::IntrinsicOp::IOP_DispatchRaysIndex:
    builtin = spv::BuiltIn::LaunchIdNV;
    break;
  case hlsl::IntrinsicOp::IOP_RayTCurrent:
    builtin = spv::BuiltIn::HitTNV;
    break;
  case hlsl::IntrinsicOp::IOP_RayTMin:
    builtin = spv::BuiltIn::RayTminNV;
    break;
  case hlsl::IntrinsicOp::IOP_HitKind:
    builtin = spv::BuiltIn::HitKindNV;
    break;
  case hlsl::IntrinsicOp::IOP_WorldRayDirection:
    builtin = spv::BuiltIn::WorldRayDirectionNV;
    break;
  case hlsl::IntrinsicOp::IOP_WorldRayOrigin:
    builtin = spv::BuiltIn::WorldRayOriginNV;
    break;
  case hlsl::IntrinsicOp::IOP_ObjectRayDirection:
    builtin = spv::BuiltIn::ObjectRayDirectionNV;
    break;
  case hlsl::IntrinsicOp::IOP_ObjectRayOrigin:
    builtin = spv::BuiltIn::ObjectRayOriginNV;
    break;
  case hlsl::IntrinsicOp::IOP_GeometryIndex:
    featureManager.requestExtension(Extension::KHR_ray_tracing,
                                    "GeometryIndex()", loc);
    builtin = spv::BuiltIn::RayGeometryIndexKHR;
    break;
  case hlsl::IntrinsicOp::IOP_InstanceIndex:
    builtin = spv::BuiltIn::InstanceId;
    break;
  case hlsl::IntrinsicOp::IOP_PrimitiveIndex:
    builtin = spv::BuiltIn::PrimitiveId;
    break;
  case hlsl::IntrinsicOp::IOP_InstanceID:
    builtin = spv::BuiltIn::InstanceCustomIndexNV;
    break;
  case hlsl::IntrinsicOp::IOP_RayFlags:
    builtin = spv::BuiltIn::IncomingRayFlagsNV;
    break;
  case hlsl::IntrinsicOp::IOP_ObjectToWorld3x4:
    transposeMatrix = true;
    LLVM_FALLTHROUGH;
  case hlsl::IntrinsicOp::IOP_ObjectToWorld4x3:
    builtin = spv::BuiltIn::ObjectToWorldNV;
    break;
  case hlsl::IntrinsicOp::IOP_WorldToObject3x4:
    transposeMatrix = true;
    LLVM_FALLTHROUGH;
  case hlsl::IntrinsicOp::IOP_WorldToObject4x3:
    builtin = spv::BuiltIn::WorldToObjectNV;
    break;
  default:
    emitError("ray intrinsic function unimplemented", loc);
    return nullptr;
  }

  QualType builtinType = callExpr->getType();
  if (transposeMatrix) {
    // DXR defines ObjectToWorld3x4 and WorldToObject3x4 as transposed
    // matrices, but SPIR-V only defines the non-transposed variants as
    // builtins. So read the original non-transposed builtin and transpose it.
    assert(hlsl::IsHLSLMatType(builtinType) && "Builtin should be matrix");
    const clang::Type *type = builtinType.getCanonicalType().getTypePtr();
    const RecordType *RT = cast<RecordType>(type);
    const ClassTemplateSpecializationDecl *templateSpecDecl =
        cast<ClassTemplateSpecializationDecl>(RT->getDecl());
    ClassTemplateDecl *templateDecl =
        templateSpecDecl->getSpecializedTemplate();
    builtinType = getHLSLMatrixType(astContext, theCompilerInstance.getSema(),
                                    templateDecl, astContext.FloatTy, 4, 3);
  }
  SpirvInstruction *retVal =
      declIdMapper.getBuiltinVar(builtin, builtinType, loc);
  retVal = spvBuilder.createLoad(builtinType, retVal, loc);
  if (transposeMatrix)
    retVal = spvBuilder.createUnaryOp(spv::Op::OpTranspose,
                                      callExpr->getType(), retVal, loc);
  return retVal;
}
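
// Lowers ReportHit(). The user-defined hit attribute argument is copied into
// a HitAttributeNV stage variable (created once and cached per distinct
// attribute type), and the hit parameters are passed to
// OpReportIntersectionNV.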
SpirvInstruction *SpirvEmitter::processReportHit(const CallExpr *callExpr) {
  SpirvInstruction *hitAttributeStageVar = nullptr;
  const VarDecl *hitAttributeArg = nullptr;
  QualType hitAttributeType;
  const auto args = callExpr->getArgs();
  if (callExpr->getNumArgs() != 3) {
    emitError("invalid number of arguments to ReportHit",
              callExpr->getExprLoc());
    return nullptr;
  }

  // HLSL function:
  // template<typename hitAttr>
  // ReportHit(in float, in uint, in hitAttr)
  if (const auto *implCastExpr = dyn_cast<CastExpr>(callExpr->getArg(2))) {
    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
      if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
        hitAttributeType = varDecl->getType();
        hitAttributeArg = varDecl;
        // Check whether a stage variable of the same hit attribute type was
        // already created; if so, re-use it.
        const auto iter = hitAttributeMap.find(hitAttributeType);
        if (iter == hitAttributeMap.end()) {
          hitAttributeStageVar = declIdMapper.createRayTracingNVStageVar(
              spv::StorageClass::HitAttributeNV, varDecl);
          hitAttributeMap[hitAttributeType] = hitAttributeStageVar;
        } else {
          hitAttributeStageVar = iter->second;
        }
      }
    }
  }
  assert(hitAttributeStageVar && hitAttributeArg);

  // Copy the argument to the stage variable.
  const auto hitAttributeArgInst =
      declIdMapper.getDeclEvalInfo(hitAttributeArg, callExpr->getExprLoc());
  auto tempLoad =
      spvBuilder.createLoad(hitAttributeArg->getType(), hitAttributeArgInst,
                            hitAttributeArg->getLocStart());
  spvBuilder.createStore(hitAttributeStageVar, tempLoad,
                         callExpr->getExprLoc());

  // SPIR-V instruction:
  // bool OpReportIntersection(<id> float Hit, <id> uint HitKind)
  llvm::SmallVector<SpirvInstruction *, 4> reportHitArgs;
  reportHitArgs.push_back(doExpr(args[0])); // Hit
  reportHitArgs.push_back(doExpr(args[1])); // HitKind
  return spvBuilder.createRayTracingOpsNV(spv::Op::OpReportIntersectionNV,
                                          astContext.BoolTy, reportHitArgs,
                                          callExpr->getExprLoc());
}
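
// Lowers CallShader(). The callable data argument is copied into a
// CallableDataNV stage variable (created once per distinct data type and
// decorated with a unique location), OpExecuteCallableNV is emitted with
// that location, and the data is copied back afterwards because the
// argument is inout.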
void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
  SpirvInstruction *callDataLocInst = nullptr;
  SpirvInstruction *callDataStageVar = nullptr;
  const VarDecl *callDataArg = nullptr;
  QualType callDataType;
  const auto args = callExpr->getArgs();
  if (callExpr->getNumArgs() != 2) {
    emitError("invalid number of arguments to CallShader",
              callExpr->getExprLoc());
    return;
  }

  // HLSL function:
  // template<typename CallData>
  // void CallShader(in int sbtIndex, inout CallData arg)
  if (const auto *implCastExpr = dyn_cast<CastExpr>(args[1])) {
    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
      if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
        callDataType = varDecl->getType();
        callDataArg = varDecl;
        // Check whether a stage variable of the same callable data type was
        // already created; if so, re-use it.
        const auto callDataPair = callDataMap.find(callDataType);
        if (callDataPair == callDataMap.end()) {
          int numCallDataVars = callDataMap.size();
          callDataStageVar = declIdMapper.createRayTracingNVStageVar(
              spv::StorageClass::CallableDataNV, varDecl);
          // Decorate each newly created stage variable with a unique
          // location.
          spvBuilder.decorateLocation(callDataStageVar, numCallDataVars);
          callDataLocInst = spvBuilder.getConstantInt(
              astContext.UnsignedIntTy, llvm::APInt(32, numCallDataVars));
          callDataMap[callDataType] =
              std::make_pair(callDataStageVar, callDataLocInst);
        } else {
          callDataStageVar = callDataPair->second.first;
          callDataLocInst = callDataPair->second.second;
        }
      }
    }
  }
  assert(callDataStageVar && callDataArg);

  // Copy the argument to the stage variable.
  const auto callDataArgInst =
      declIdMapper.getDeclEvalInfo(callDataArg, callExpr->getExprLoc());
  auto tempLoad = spvBuilder.createLoad(callDataArg->getType(),
                                        callDataArgInst,
                                        callDataArg->getLocStart());
  spvBuilder.createStore(callDataStageVar, tempLoad, callExpr->getExprLoc());

  // SPIR-V instruction:
  // void OpExecuteCallable(<id> int SBT Index,
  //                        <id> uint Callable Data Location Id)
  llvm::SmallVector<SpirvInstruction *, 2> callShaderArgs;
  callShaderArgs.push_back(doExpr(args[0]));
  callShaderArgs.push_back(callDataLocInst);
  spvBuilder.createRayTracingOpsNV(spv::Op::OpExecuteCallableNV, QualType(),
                                   callShaderArgs, callExpr->getExprLoc());

  // Copy the data back to the argument.
  tempLoad = spvBuilder.createLoad(callDataArg->getType(), callDataStageVar,
                                   callDataArg->getLocStart());
  spvBuilder.createStore(callDataArgInst, tempLoad, callExpr->getExprLoc());
}
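
// Lowers TraceRay(). The RayDesc argument is decomposed into the origin,
// tMin, direction, and tMax operands expected by OpTraceNV, and the inout
// payload argument is round-tripped through a RayPayloadNV stage variable
// identified by its location number.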
void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
  SpirvInstruction *rayPayloadLocInst = nullptr;
  SpirvInstruction *rayPayloadStageVar = nullptr;
  const VarDecl *rayPayloadArg = nullptr;
  QualType rayPayloadType;
  const auto args = callExpr->getArgs();
  if (callExpr->getNumArgs() != 8) {
    emitError("invalid number of arguments to TraceRay",
              callExpr->getExprLoc());
    return;
  }

  // HLSL function:
  // template<typename RayPayload>
  // void TraceRay(RaytracingAccelerationStructure rs,
  //               uint rayFlags,
  //               uint InstanceInclusionMask,
  //               uint RayContributionToHitGroupIndex,
  //               uint MultiplierForGeometryContributionToHitGroupIndex,
  //               uint MissShaderIndex,
  //               RayDesc ray,
  //               inout RayPayload p)
  // where RayDesc = {float3 origin, float tMin, float3 direction, float tMax}
  if (const auto *implCastExpr = dyn_cast<CastExpr>(args[7])) {
    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
      if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
        rayPayloadType = varDecl->getType();
        rayPayloadArg = varDecl;
        // Check whether a stage variable of the same payload type was already
        // created; if so, re-use it.
        const auto rayPayloadPair = rayPayloadMap.find(rayPayloadType);
        if (rayPayloadPair == rayPayloadMap.end()) {
          int numPayloadVars = rayPayloadMap.size();
          rayPayloadStageVar = declIdMapper.createRayTracingNVStageVar(
              spv::StorageClass::RayPayloadNV, varDecl);
          // Decorate each newly created stage variable with a unique
          // location.
          spvBuilder.decorateLocation(rayPayloadStageVar, numPayloadVars);
          rayPayloadLocInst = spvBuilder.getConstantInt(
              astContext.UnsignedIntTy, llvm::APInt(32, numPayloadVars));
          rayPayloadMap[rayPayloadType] =
              std::make_pair(rayPayloadStageVar, rayPayloadLocInst);
        } else {
          rayPayloadStageVar = rayPayloadPair->second.first;
          rayPayloadLocInst = rayPayloadPair->second.second;
        }
      }
    }
  }
  assert(rayPayloadStageVar && rayPayloadArg);

  const auto floatType = astContext.FloatTy;
  const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);

  // Decompose the ray description into the operands SPIR-V expects.
  SpirvInstruction *rayDescArg = doExpr(args[6]);
  const auto loc = args[6]->getLocStart();
  const auto origin =
      spvBuilder.createCompositeExtract(vecType, rayDescArg, {0}, loc);
  const auto tMin =
      spvBuilder.createCompositeExtract(floatType, rayDescArg, {1}, loc);
  const auto direction =
      spvBuilder.createCompositeExtract(vecType, rayDescArg, {2}, loc);
  const auto tMax =
      spvBuilder.createCompositeExtract(floatType, rayDescArg, {3}, loc);

  // Copy the argument to the stage variable.
  const auto rayPayloadArgInst =
      declIdMapper.getDeclEvalInfo(rayPayloadArg, rayPayloadArg->getLocStart());
  auto tempLoad =
      spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadArgInst,
                            rayPayloadArg->getLocStart());
  spvBuilder.createStore(rayPayloadStageVar, tempLoad, callExpr->getExprLoc());

  // SPIR-V instruction:
  // void OpTraceNV(<id> AccelerationStructureNV acStruct,
  //                <id> uint Ray Flags,
  //                <id> uint Cull Mask,
  //                <id> uint SBT Offset,
  //                <id> uint SBT Stride,
  //                <id> uint Miss Index,
  //                <id> vec3 Ray Origin,
  //                <id> float Ray Tmin,
  //                <id> vec3 Ray Direction,
  //                <id> float Ray Tmax,
  //                <id> uint RayPayload number)
  llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
  for (int ii = 0; ii < 6; ii++) {
    traceArgs.push_back(doExpr(args[ii]));
  }
  traceArgs.push_back(origin);
  traceArgs.push_back(tMin);
  traceArgs.push_back(direction);
  traceArgs.push_back(tMax);
  traceArgs.push_back(rayPayloadLocInst);
  spvBuilder.createRayTracingOpsNV(spv::Op::OpTraceNV, QualType(), traceArgs,
                                   callExpr->getExprLoc());

  // Copy the payload data back to the argument.
  tempLoad = spvBuilder.createLoad(rayPayloadArg->getType(),
                                   rayPayloadStageVar,
                                   rayPayloadArg->getLocStart());
  spvBuilder.createStore(rayPayloadArgInst, tempLoad, callExpr->getExprLoc());
}
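
// For example, DispatchMesh(4, 2, 1, payload) emits a group barrier, stores
// TaskCountNV = 4 * 2 * 1 = 8, and writes the groupshared payload out
// through a PerTaskNV attribute block.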
void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
  // HLSL function:
  // void DispatchMesh(uint ThreadGroupCountX,
  //                   uint ThreadGroupCountY,
  //                   uint ThreadGroupCountZ,
  //                   groupshared <structType> MeshPayload);
  assert(callExpr->getNumArgs() == 4);
  const auto args = callExpr->getArgs();
  const auto loc = callExpr->getExprLoc();

  // 1) Create a barrier: GroupMemoryBarrierWithGroupSync().
  processIntrinsicMemoryBarrier(callExpr,
                                /*isDevice*/ false,
                                /*groupSync*/ true,
                                /*isAllBarrier*/ false);

  // 2) Set TaskCountNV = threadX * threadY * threadZ.
  auto *threadX = doExpr(args[0]);
  auto *threadY = doExpr(args[1]);
  auto *threadZ = doExpr(args[2]);
  auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
                                         astContext.UnsignedIntTy, loc);
  auto *taskCount = spvBuilder.createBinaryOp(
      spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
      spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
                                threadY, threadZ, loc),
      loc);
  spvBuilder.createStore(var, taskCount, loc);

  // 3) Create the PerTaskNV out attribute block and store the MeshPayload.
  const auto *sigPoint =
      hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
  spv::StorageClass sc = spv::StorageClass::Output;
  auto *payloadArg = doExpr(args[3]);
  bool isValid = false;
  if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
    if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
      if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
        if (paramDecl->hasAttr<HLSLGroupSharedAttr>()) {
          isValid = declIdMapper.createPayloadStageVars(
              sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
              "out.var", &payloadArg);
        }
      }
    }
  }
  if (!isValid) {
    emitError("expected groupshared object as argument to DispatchMesh()",
              args[3]->getExprLoc());
  }
}
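
// Lowers SetMeshOutputCounts(). Note that only the primitive count is
// written to a builtin (PrimitiveCountNV) below; the numVertices argument
// has no dedicated builtin to store to here.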
void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
  // HLSL function:
  // void SetMeshOutputCounts(uint numVertices, uint numPrimitives);
  assert(callExpr->getNumArgs() == 2);
  const auto args = callExpr->getArgs();
  const auto loc = callExpr->getExprLoc();
  auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
                                         astContext.UnsignedIntTy, loc);
  spvBuilder.createStore(var, doExpr(args[1]), loc);
}
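
// Returns a constant zero of the given scalar, vector, or matrix type,
// recursing element-wise for composite types.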
SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
  {
    QualType scalarType = {};
    if (isScalarType(type, &scalarType)) {
      if (scalarType->isBooleanType()) {
        return spvBuilder.getConstantBool(false);
      }
      if (scalarType->isIntegerType()) {
        return spvBuilder.getConstantInt(scalarType, llvm::APInt(32, 0));
      }
      if (scalarType->isFloatingType()) {
        return spvBuilder.getConstantFloat(scalarType, llvm::APFloat(0.0f));
      }
    }
  }
  {
    QualType elemType = {};
    uint32_t size = {};
    if (isVectorType(type, &elemType, &size)) {
      return getVecValueZero(elemType, size);
    }
  }
  {
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
      auto *row = getVecValueZero(elemType, colCount);
      llvm::SmallVector<SpirvConstant *, 4> rows((size_t)rowCount, row);
      return spvBuilder.getConstantComposite(type, rows);
    }
  }
  emitError("getting value 0 for type %0 unimplemented", {})
      << type.getAsString();
  return nullptr;
}

SpirvConstant *SpirvEmitter::getVecValueZero(QualType elemType,
                                             uint32_t size) {
  auto *elemZeroId = getValueZero(elemType);
  if (size == 1)
    return elemZeroId;
  llvm::SmallVector<SpirvConstant *, 4> elements(size_t(size), elemZeroId);
  const QualType vecType = astContext.getExtVectorType(elemType, size);
  return spvBuilder.getConstantComposite(vecType, elements);
}

SpirvConstant *SpirvEmitter::getValueOne(QualType type) {
  {
    QualType scalarType = {};
    if (isScalarType(type, &scalarType)) {
      if (scalarType->isBooleanType()) {
        return spvBuilder.getConstantBool(true);
      }
      if (scalarType->isIntegerType()) {
        return spvBuilder.getConstantInt(scalarType, llvm::APInt(32, 1));
      }
      if (scalarType->isFloatingType()) {
        return spvBuilder.getConstantFloat(scalarType, llvm::APFloat(1.0f));
      }
    }
  }
  {
    QualType elemType = {};
    uint32_t size = {};
    if (isVectorType(type, &elemType, &size)) {
      return getVecValueOne(elemType, size);
    }
  }
  emitError("getting value 1 for type %0 unimplemented", {}) << type;
  return nullptr;
}

SpirvConstant *SpirvEmitter::getVecValueOne(QualType elemType, uint32_t size) {
  auto *elemOne = getValueOne(elemType);
  if (size == 1)
    return elemOne;
  llvm::SmallVector<SpirvConstant *, 4> elements(size_t(size), elemOne);
  const QualType vecType = astContext.getExtVectorType(elemType, size);
  return spvBuilder.getConstantComposite(vecType, elements);
}

SpirvConstant *SpirvEmitter::getMatElemValueOne(QualType type) {
  assert(hlsl::IsHLSLMatType(type));
  const auto elemType = hlsl::GetHLSLMatElementType(type);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);

  if (rowCount == 1 && colCount == 1)
    return getValueOne(elemType);
  if (colCount == 1)
    return getVecValueOne(elemType, rowCount);
  return getVecValueOne(elemType, colCount);
}
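
// Returns a constant whose element value is bitwidth - 1 (e.g. 31 for
// 32-bit elements), splatted to a vector when needed; callers typically use
// it to mask shift amounts so they stay within the element bit width.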
SpirvConstant *SpirvEmitter::getMaskForBitwidthValue(QualType type) {
  QualType elemType = {};
  uint32_t count = 1;

  if (isScalarType(type, &elemType) || isVectorType(type, &elemType, &count)) {
    const auto bitwidth = getElementSpirvBitwidth(
        astContext, elemType, spirvOptions.enable16BitTypes);
    SpirvConstant *mask = spvBuilder.getConstantInt(
        elemType,
        llvm::APInt(bitwidth, bitwidth - 1, elemType->isSignedIntegerType()));
    if (count == 1)
      return mask;

    const QualType resultType = astContext.getExtVectorType(elemType, count);
    llvm::SmallVector<SpirvConstant *, 4> elements(size_t(count), mask);
    return spvBuilder.getConstantComposite(resultType, elements);
  }

  assert(false && "this method only supports scalars and vectors");
  return nullptr;
}

SpirvConstant *SpirvEmitter::translateAPValue(const APValue &value,
                                              const QualType targetType) {
  SpirvConstant *result = nullptr;

  if (targetType->isBooleanType()) {
    result = spvBuilder.getConstantBool(value.getInt().getBoolValue(),
                                        isSpecConstantMode);
  } else if (targetType->isIntegerType()) {
    result = translateAPInt(value.getInt(), targetType);
  } else if (targetType->isFloatingType()) {
    result = translateAPFloat(value.getFloat(), targetType);
  } else if (hlsl::IsHLSLVecType(targetType)) {
    const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
    const auto numElements = value.getVectorLength();
    // Special case for vectors of size 1. SPIR-V does not support this vector
    // size, so we need to translate it to scalar values.
    if (numElements == 1) {
      result = translateAPValue(value.getVectorElt(0), elemType);
    } else {
      llvm::SmallVector<SpirvConstant *, 4> elements;
      for (uint32_t i = 0; i < numElements; ++i) {
        elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
      }
      result = spvBuilder.getConstantComposite(targetType, elements);
    }
  }

  if (result)
    return result;

  emitError("APValue of type %0 unimplemented", {}) << value.getKind();
  value.dump();
  return nullptr;
}

SpirvConstant *SpirvEmitter::translateAPInt(const llvm::APInt &intValue,
                                            QualType targetType) {
  return spvBuilder.getConstantInt(targetType, intValue, isSpecConstantMode);
}

bool SpirvEmitter::isLiteralLargerThan32Bits(const Expr *expr) {
  if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
    const bool isSigned = expr->getType()->isSignedIntegerType();
    const llvm::APInt &value = intLiteral->getValue();
    return (isSigned && !value.isSignedIntN(32)) ||
           (!isSigned && !value.isIntN(32));
  }

  if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
    llvm::APFloat value = floatLiteral->getValue();
    const auto &semantics = value.getSemantics();
    // Regular 'half' and 'float' can be represented in 32 bits, so they are
    // not larger than 32 bits.
    if (&semantics == &llvm::APFloat::IEEEsingle ||
        &semantics == &llvm::APFloat::IEEEhalf)
      return false;

    // See if a 'double' value can be represented in 32 bits without losing
    // info.
    bool losesInfo = false;
    const auto convertStatus =
        value.convert(llvm::APFloat::IEEEsingle,
                      llvm::APFloat::rmNearestTiesToEven, &losesInfo);
    if (convertStatus != llvm::APFloat::opOK &&
        convertStatus != llvm::APFloat::opInexact)
      return true;
  }

  return false;
}

SpirvConstant *SpirvEmitter::tryToEvaluateAsInt32(const llvm::APInt &intValue,
                                                  bool isSigned) {
  if (isSigned && intValue.isSignedIntN(32)) {
    return spvBuilder.getConstantInt(astContext.IntTy, intValue);
  }
  if (!isSigned && intValue.isIntN(32)) {
    return spvBuilder.getConstantInt(astContext.UnsignedIntTy, intValue);
  }

  // Couldn't evaluate as a 32-bit int without losing information.
  return nullptr;
}

SpirvConstant *
SpirvEmitter::tryToEvaluateAsFloat32(const llvm::APFloat &floatValue) {
  const auto &semantics = floatValue.getSemantics();
  // If the given value is already a 32-bit float, there is no need to convert.
  if (&semantics == &llvm::APFloat::IEEEsingle) {
    return spvBuilder.getConstantFloat(astContext.FloatTy, floatValue,
                                       isSpecConstantMode);
  }

  // Try to see if this literal float can be represented in 32 bits.
  // Since the convert function below may modify the fp value, we call it on a
  // temporary copy.
  llvm::APFloat eval = floatValue;
  bool losesInfo = false;
  const auto convertStatus =
      eval.convert(llvm::APFloat::IEEEsingle,
                   llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  if (convertStatus == llvm::APFloat::opOK && !losesInfo)
    return spvBuilder.getConstantFloat(astContext.FloatTy,
                                       llvm::APFloat(eval.convertToFloat()));

  // Couldn't evaluate as a 32-bit float without losing information.
  return nullptr;
}

SpirvConstant *SpirvEmitter::translateAPFloat(llvm::APFloat floatValue,
                                              QualType targetType) {
  return spvBuilder.getConstantFloat(targetType, floatValue,
                                     isSpecConstantMode);
}

SpirvConstant *SpirvEmitter::tryToEvaluateAsConst(const Expr *expr) {
  Expr::EvalResult evalResult;
  if (expr->EvaluateAsRValue(evalResult, astContext) &&
      !evalResult.HasSideEffects) {
    return translateAPValue(evalResult.Val, expr->getType());
  }
  return nullptr;
}
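
// Maps a stage name to a shader model kind by dispatching on its first one
// or two characters: e.g. "vertex" -> Vertex, "compute" -> Compute,
// "closesthit" -> ClosestHit, "callable" -> Callable, "anyhit" -> AnyHit,
// "amplification" -> Amplification.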
hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
  hlsl::ShaderModel::Kind smk = hlsl::ShaderModel::Kind::Invalid;
  switch (stageName[0]) {
  case 'c':
    switch (stageName[1]) {
    case 'o':
      smk = hlsl::ShaderModel::Kind::Compute;
      break;
    case 'l':
      smk = hlsl::ShaderModel::Kind::ClosestHit;
      break;
    case 'a':
      smk = hlsl::ShaderModel::Kind::Callable;
      break;
    default:
      smk = hlsl::ShaderModel::Kind::Invalid;
      break;
    }
    break;
  case 'v':
    smk = hlsl::ShaderModel::Kind::Vertex;
    break;
  case 'h':
    smk = hlsl::ShaderModel::Kind::Hull;
    break;
  case 'd':
    smk = hlsl::ShaderModel::Kind::Domain;
    break;
  case 'g':
    smk = hlsl::ShaderModel::Kind::Geometry;
    break;
  case 'p':
    smk = hlsl::ShaderModel::Kind::Pixel;
    break;
  case 'r':
    smk = hlsl::ShaderModel::Kind::RayGeneration;
    break;
  case 'i':
    smk = hlsl::ShaderModel::Kind::Intersection;
    break;
  case 'a':
    switch (stageName[1]) {
    case 'm':
      smk = hlsl::ShaderModel::Kind::Amplification;
      break;
    case 'n':
      smk = hlsl::ShaderModel::Kind::AnyHit;
      break;
    default:
      smk = hlsl::ShaderModel::Kind::Invalid;
      break;
    }
    break;
  case 'm':
    switch (stageName[1]) {
    case 'e':
      smk = hlsl::ShaderModel::Kind::Mesh;
      break;
    case 'i':
      smk = hlsl::ShaderModel::Kind::Miss;
      break;
    default:
      smk = hlsl::ShaderModel::Kind::Invalid;
      break;
    }
    break;
  default:
    smk = hlsl::ShaderModel::Kind::Invalid;
    break;
  }
  if (smk == hlsl::ShaderModel::Kind::Invalid) {
    llvm_unreachable("unknown stage name");
  }
  return smk;
}

spv::ExecutionModel
SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
  switch (smk) {
  case hlsl::ShaderModel::Kind::Vertex:
    return spv::ExecutionModel::Vertex;
  case hlsl::ShaderModel::Kind::Hull:
    return spv::ExecutionModel::TessellationControl;
  case hlsl::ShaderModel::Kind::Domain:
    return spv::ExecutionModel::TessellationEvaluation;
  case hlsl::ShaderModel::Kind::Geometry:
    return spv::ExecutionModel::Geometry;
  case hlsl::ShaderModel::Kind::Pixel:
    return spv::ExecutionModel::Fragment;
  case hlsl::ShaderModel::Kind::Compute:
    return spv::ExecutionModel::GLCompute;
  case hlsl::ShaderModel::Kind::RayGeneration:
    return spv::ExecutionModel::RayGenerationNV;
  case hlsl::ShaderModel::Kind::Intersection:
    return spv::ExecutionModel::IntersectionNV;
  case hlsl::ShaderModel::Kind::AnyHit:
    return spv::ExecutionModel::AnyHitNV;
  case hlsl::ShaderModel::Kind::ClosestHit:
    return spv::ExecutionModel::ClosestHitNV;
  case hlsl::ShaderModel::Kind::Miss:
    return spv::ExecutionModel::MissNV;
  case hlsl::ShaderModel::Kind::Callable:
    return spv::ExecutionModel::CallableNV;
  case hlsl::ShaderModel::Kind::Mesh:
    return spv::ExecutionModel::MeshNV;
  case hlsl::ShaderModel::Kind::Amplification:
    return spv::ExecutionModel::TaskNV;
  default:
    llvm_unreachable("invalid shader model kind");
    break;
  }
}
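
// Translates the [maxvertexcount()] and [instance()] attributes and the
// stream/primitive-type parameter modifiers of a geometry shader entry
// point into execution modes. *arraySize receives the input vertex count
// implied by the input primitive type: 1 (point), 2 (line), 3 (triangle),
// 4 (lineadj), or 6 (triangleadj).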
bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
                                                   uint32_t *arraySize) {
  bool success = true;
  assert(spvContext.isGS());
  if (auto *vcAttr = decl->getAttr<HLSLMaxVertexCountAttr>()) {
    spvBuilder.addExecutionMode(
        entryFunction, spv::ExecutionMode::OutputVertices,
        {static_cast<uint32_t>(vcAttr->getCount())}, decl->getLocation());
  }

  uint32_t invocations = 1;
  if (auto *instanceAttr = decl->getAttr<HLSLInstanceAttr>()) {
    invocations = static_cast<uint32_t>(instanceAttr->getCount());
  }
  spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::Invocations,
                              {invocations}, decl->getLocation());

  // Only one primitive type is permitted for the geometry shader.
  bool outPoint = false, outLine = false, outTriangle = false, inPoint = false,
       inLine = false, inTriangle = false, inLineAdj = false,
       inTriangleAdj = false;
  for (const auto *param : decl->params()) {
    // Add an execution mode based on the output stream type. Do not add an
    // execution mode more than once.
    if (param->hasAttr<HLSLInOutAttr>()) {
      const auto paramType = param->getType();
      if (hlsl::IsHLSLTriangleStreamType(paramType) && !outTriangle) {
        spvBuilder.addExecutionMode(entryFunction,
                                    spv::ExecutionMode::OutputTriangleStrip,
                                    {}, param->getLocation());
        outTriangle = true;
      } else if (hlsl::IsHLSLLineStreamType(paramType) && !outLine) {
        spvBuilder.addExecutionMode(entryFunction,
                                    spv::ExecutionMode::OutputLineStrip, {},
                                    param->getLocation());
        outLine = true;
      } else if (hlsl::IsHLSLPointStreamType(paramType) && !outPoint) {
        spvBuilder.addExecutionMode(entryFunction,
                                    spv::ExecutionMode::OutputPoints, {},
                                    param->getLocation());
        outPoint = true;
      }
      // An output stream parameter will not have the input primitive type
      // attributes, so we can continue to the next parameter.
      continue;
    }

    // Add an execution mode based on the input primitive type. Do not add an
    // execution mode more than once.
    if (param->hasAttr<HLSLPointAttr>() && !inPoint) {
      spvBuilder.addExecutionMode(entryFunction,
                                  spv::ExecutionMode::InputPoints, {},
                                  param->getLocation());
      *arraySize = 1;
      inPoint = true;
    } else if (param->hasAttr<HLSLLineAttr>() && !inLine) {
      spvBuilder.addExecutionMode(entryFunction,
                                  spv::ExecutionMode::InputLines, {},
                                  param->getLocation());
      *arraySize = 2;
      inLine = true;
    } else if (param->hasAttr<HLSLTriangleAttr>() && !inTriangle) {
      spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::Triangles,
                                  {}, param->getLocation());
      *arraySize = 3;
      inTriangle = true;
    } else if (param->hasAttr<HLSLLineAdjAttr>() && !inLineAdj) {
      spvBuilder.addExecutionMode(entryFunction,
                                  spv::ExecutionMode::InputLinesAdjacency, {},
                                  param->getLocation());
      *arraySize = 4;
      inLineAdj = true;
    } else if (param->hasAttr<HLSLTriangleAdjAttr>() && !inTriangleAdj) {
      spvBuilder.addExecutionMode(entryFunction,
                                  spv::ExecutionMode::InputTrianglesAdjacency,
                                  {}, param->getLocation());
      *arraySize = 6;
      inTriangleAdj = true;
    }
  }

  if (inPoint + inLine + inLineAdj + inTriangle + inTriangleAdj > 1) {
    emitError("only one input primitive type can be specified in the geometry "
              "shader",
              {});
    success = false;
  }
  if (outPoint + outTriangle + outLine > 1) {
    emitError("only one output primitive type can be specified in the "
              "geometry shader",
              {});
    success = false;
  }

  return success;
}

void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
  spvBuilder.addExecutionMode(entryFunction,
                              spv::ExecutionMode::OriginUpperLeft, {},
                              decl->getLocation());
  if (decl->getAttr<HLSLEarlyDepthStencilAttr>()) {
    spvBuilder.addExecutionMode(entryFunction,
                                spv::ExecutionMode::EarlyFragmentTests, {},
                                decl->getLocation());
  }
  if (decl->getAttr<VKPostDepthCoverageAttr>()) {
    spvBuilder.addExecutionMode(entryFunction,
                                spv::ExecutionMode::PostDepthCoverage, {},
                                decl->getLocation());
  }
}

void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
  // If not explicitly specified, x, y, and z should be defaulted to 1.
  uint32_t x = 1, y = 1, z = 1;
  if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
    x = static_cast<uint32_t>(numThreadsAttr->getX());
    y = static_cast<uint32_t>(numThreadsAttr->getY());
    z = static_cast<uint32_t>(numThreadsAttr->getZ());
  } else {
    emitError("thread group size [numthreads(x,y,z)] is missing from the "
              "entry-point function",
              decl->getLocation());
    return;
  }
  spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
                              {x, y, z}, decl->getLocation());
}

bool SpirvEmitter::processTessellationShaderAttributes(
    const FunctionDecl *decl, uint32_t *numOutputControlPoints) {
  assert(spvContext.isHS() || spvContext.isDS());
  using namespace spv;

  if (auto *domain = decl->getAttr<HLSLDomainAttr>()) {
    const auto domainType = domain->getDomainType().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(domainType)
            .Case("tri", ExecutionMode::Triangles)
            .Case("quad", ExecutionMode::Quads)
            .Case("isoline", ExecutionMode::Isolines)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown domain type specified for entry function",
                domain->getLocation());
      return false;
    }
    spvBuilder.addExecutionMode(entryFunction, hsExecMode, {},
                                decl->getLocation());
  }

  // Early return for domain shaders, as they only take the 'domain'
  // attribute.
  if (spvContext.isDS())
    return true;

  if (auto *partitioning = decl->getAttr<HLSLPartitioningAttr>()) {
    const auto scheme = partitioning->getScheme().lower();
    if (scheme == "pow2") {
      emitError("pow2 partitioning scheme is not supported since there is no "
                "equivalent in Vulkan",
                partitioning->getLocation());
      return false;
    }
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(scheme)
            .Case("fractional_even", ExecutionMode::SpacingFractionalEven)
            .Case("fractional_odd", ExecutionMode::SpacingFractionalOdd)
            .Case("integer", ExecutionMode::SpacingEqual)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown partitioning scheme in hull shader",
                partitioning->getLocation());
      return false;
    }
    spvBuilder.addExecutionMode(entryFunction, hsExecMode, {},
                                decl->getLocation());
  }
  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
    const auto topology = outputTopology->getTopology().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(topology)
            .Case("point", ExecutionMode::PointMode)
            .Case("triangle_cw", ExecutionMode::VertexOrderCw)
            .Case("triangle_ccw", ExecutionMode::VertexOrderCcw)
            .Default(ExecutionMode::Max);
    // TODO: There is no SPIR-V equivalent for the "line" topology. Is it the
    // default?
    if (topology != "line") {
      if (hsExecMode != spv::ExecutionMode::Max) {
        spvBuilder.addExecutionMode(entryFunction, hsExecMode, {},
                                    decl->getLocation());
      } else {
        emitError("unknown output topology in hull shader",
                  outputTopology->getLocation());
        return false;
      }
    }
  }
  if (auto *controlPoints = decl->getAttr<HLSLOutputControlPointsAttr>()) {
    *numOutputControlPoints = controlPoints->getCount();
    spvBuilder.addExecutionMode(entryFunction,
                                spv::ExecutionMode::OutputVertices,
                                {*numOutputControlPoints},
                                decl->getLocation());
  }
  if (auto *pcf = decl->getAttr<HLSLPatchConstantFuncAttr>()) {
    llvm::StringRef pcf_name = pcf->getFunctionName();
    for (auto *decl : astContext.getTranslationUnitDecl()->decls())
      if (auto *funcDecl = dyn_cast<FunctionDecl>(decl))
        if (astContext.IsPatchConstantFunctionDecl(funcDecl) &&
            funcDecl->getName() == pcf_name)
          patchConstFunc = funcDecl;
  }

  return true;
}
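
// Emits the wrapper entry function for ray tracing stages. The wrapper
// initializes globals, creates a stage variable for each entry function
// parameter (IncomingRayPayloadNV, HitAttributeNV, or IncomingCallableDataNV,
// depending on the stage), copies it into a temporary, calls the real entry
// function, and copies inout data back.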
bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
    const FunctionDecl *decl, SpirvFunction *entryFuncInstr) {
  // The entry basic block.
  auto *entryLabel = spvBuilder.createBasicBlock();
  spvBuilder.setInsertPoint(entryLabel);

  // Initialize all global variables at the beginning of the wrapper.
  for (const VarDecl *varDecl : toInitGloalVars) {
    const auto varInfo =
        declIdMapper.getDeclEvalInfo(varDecl, varDecl->getLocation());
    if (const auto *init = varDecl->getInit()) {
      storeValue(varInfo, loadIfGLValue(init), varDecl->getType(),
                 init->getLocStart());

      // Update the counter variable associated with global variables.
      tryToAssignCounterVar(varDecl, init);
    }
    // If not explicitly initialized, initialize with zero values, unless the
    // variable is a resource object.
    else if (!hlsl::IsHLSLResourceType(varDecl->getType())) {
      auto *nullValue = spvBuilder.getConstantNull(varDecl->getType());
      spvBuilder.createStore(varInfo, nullValue, varDecl->getLocation());
    }
  }

  // Create temporary variables for holding function call arguments.
  llvm::SmallVector<SpirvInstruction *, 4> params;
  llvm::SmallVector<QualType, 4> paramTypes;
  llvm::SmallVector<SpirvInstruction *, 4> stageVars;
  hlsl::ShaderModel::Kind sKind = spvContext.getCurrentShaderModelKind();
  for (uint32_t i = 0; i < decl->getNumParams(); i++) {
    const auto param = decl->getParamDecl(i);
    const auto paramType = param->getType();
    std::string tempVarName = "param.var." + param->getNameAsString();
    auto *tempVar =
        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName,
                            param->hasAttr<HLSLPreciseAttr>());
    SpirvVariable *curStageVar = nullptr;

    params.push_back(tempVar);
    paramTypes.push_back(paramType);

    // The order of arguments is fixed:
    // Any-Hit/Closest-Hit : Arg 0 = rayPayload(inout), Arg 1 = attribute(in)
    // Miss                : Arg 0 = rayPayload(inout)
    // Callable            : Arg 0 = callable data(inout)
    // Raygeneration/Intersection : no arguments allowed
    if (sKind == hlsl::ShaderModel::Kind::RayGeneration) {
      assert(false &&
             "raygeneration shaders take no entry function arguments");
    } else if (sKind == hlsl::ShaderModel::Kind::Intersection) {
      assert(false &&
             "intersection shaders take no entry function arguments");
    } else if (sKind == hlsl::ShaderModel::Kind::ClosestHit ||
               sKind == hlsl::ShaderModel::Kind::AnyHit) {
      // Generate rayPayloadInNV and hitAttributeNV stage variables.
      if (i == 0) {
        // The first argument is always the ray payload.
        curStageVar = declIdMapper.createRayTracingNVStageVar(
            spv::StorageClass::IncomingRayPayloadNV, param);
        currentRayPayload = curStageVar;
      } else {
        // The second argument is always the hit attribute.
        curStageVar = declIdMapper.createRayTracingNVStageVar(
            spv::StorageClass::HitAttributeNV, param);
      }
    } else if (sKind == hlsl::ShaderModel::Kind::Miss) {
      // Generate a rayPayloadInNV stage variable.
      // The first and only argument is the ray payload.
      curStageVar = declIdMapper.createRayTracingNVStageVar(
          spv::StorageClass::IncomingRayPayloadNV, param);
    } else if (sKind == hlsl::ShaderModel::Kind::Callable) {
      curStageVar = declIdMapper.createRayTracingNVStageVar(
          spv::StorageClass::IncomingCallableDataNV, param);
    }

    if (curStageVar != nullptr) {
      stageVars.push_back(curStageVar);
      // Copy data to the temporary.
      auto *tempLoadInst =
          spvBuilder.createLoad(paramType, curStageVar, param->getLocation());
      spvBuilder.createStore(tempVar, tempLoadInst, param->getLocation());
    }
  }

  // Call the original entry function.
  const QualType retType = decl->getReturnType();
  spvBuilder.createFunctionCall(retType, entryFuncInstr, params,
                                decl->getLocStart());

  // Write certain output variables back.
  if (sKind == hlsl::ShaderModel::Kind::ClosestHit ||
      sKind == hlsl::ShaderModel::Kind::AnyHit ||
      sKind == hlsl::ShaderModel::Kind::Miss ||
      sKind == hlsl::ShaderModel::Kind::Callable) {
    // Write results back to IncomingRayPayloadNV/IncomingCallableDataNV.
    auto *tempLoad = spvBuilder.createLoad(paramTypes[0], params[0],
                                           decl->getBody()->getLocEnd());
    spvBuilder.createStore(stageVars[0], tempLoad,
                           decl->getBody()->getLocEnd());
  }

  spvBuilder.createReturn(decl->getBody()->getLocEnd());
  spvBuilder.endFunction();

  return true;
}
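
// Validates and lowers the attributes and parameter modifiers of mesh and
// amplification shader entry points: numthreads becomes the LocalSize
// execution mode, the output topology selects OutputPoints/OutputLinesNV/
// OutputTrianglesNV, and the vertices/indices/primitives/payload parameters
// are checked for element type, array shape, and multiplicity.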
bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
    const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
  if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
    uint32_t x, y, z;
    x = static_cast<uint32_t>(numThreadsAttr->getX());
    y = static_cast<uint32_t>(numThreadsAttr->getY());
    z = static_cast<uint32_t>(numThreadsAttr->getZ());
    spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
                                {x, y, z}, decl->getLocation());
  }

  // Early return for amplification shaders, as they only take the
  // 'numthreads' attribute.
  if (spvContext.isAS())
    return true;

  spv::ExecutionMode outputPrimitive = spv::ExecutionMode::Max;
  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
    const auto topology = outputTopology->getTopology().lower();
    outputPrimitive =
        llvm::StringSwitch<spv::ExecutionMode>(topology)
            .Case("point", spv::ExecutionMode::OutputPoints)
            .Case("line", spv::ExecutionMode::OutputLinesNV)
            .Case("triangle", spv::ExecutionMode::OutputTrianglesNV)
            .Default(spv::ExecutionMode::Max);
    if (outputPrimitive != spv::ExecutionMode::Max) {
      spvBuilder.addExecutionMode(entryFunction, outputPrimitive, {},
                                  decl->getLocation());
    } else {
      emitError("unknown output topology in mesh shader",
                outputTopology->getLocation());
      return false;
    }
  }

  uint32_t numVertices = 0;
  uint32_t numIndices = 0;
  uint32_t numPrimitives = 0;
  bool payloadDeclSeen = false;

  for (uint32_t i = 0; i < decl->getNumParams(); i++) {
    const auto param = decl->getParamDecl(i);
    const auto paramType = param->getType();
    const auto paramLoc = param->getLocation();

    if (param->hasAttr<HLSLVerticesAttr>() ||
        param->hasAttr<HLSLIndicesAttr>() ||
        param->hasAttr<HLSLPrimitivesAttr>()) {
      uint32_t arraySize = 0;
      if (const auto *arrayType =
              astContext.getAsConstantArrayType(paramType)) {
        const auto eleType =
            arrayType->getElementType()->getCanonicalTypeUnqualified();
        if (param->hasAttr<HLSLIndicesAttr>()) {
          switch (outputPrimitive) {
          case spv::ExecutionMode::OutputPoints:
            if (eleType != astContext.UnsignedIntTy) {
              emitError("expected 1D array of uint type", paramLoc);
              return false;
            }
            break;
          case spv::ExecutionMode::OutputLinesNV: {
            QualType baseType;
            uint32_t length;
            if (!isVectorType(eleType, &baseType, &length) ||
                baseType != astContext.UnsignedIntTy || length != 2) {
              emitError("expected 1D array of uint2 type", paramLoc);
              return false;
            }
            break;
          }
          case spv::ExecutionMode::OutputTrianglesNV: {
            QualType baseType;
            uint32_t length;
            if (!isVectorType(eleType, &baseType, &length) ||
                baseType != astContext.UnsignedIntTy || length != 3) {
              emitError("expected 1D array of uint3 type", paramLoc);
              return false;
            }
            break;
          }
          default:
            assert(false && "unexpected spirv execution mode");
          }
        } else if (!eleType->isStructureType()) {
          // vertices/primitives objects
          emitError("expected 1D array of struct type", paramLoc);
          return false;
        }
        arraySize = static_cast<uint32_t>(arrayType->getSize().getZExtValue());
      } else {
        emitError("expected 1D array of indices/vertices/primitives object",
                  paramLoc);
        return false;
      }

      if (param->hasAttr<HLSLVerticesAttr>()) {
        if (numVertices != 0) {
          emitError("only one object with 'vertices' modifier is allowed",
                    paramLoc);
          return false;
        }
        numVertices = arraySize;
      } else if (param->hasAttr<HLSLIndicesAttr>()) {
        if (numIndices != 0) {
          emitError("only one object with 'indices' modifier is allowed",
                    paramLoc);
          return false;
        }
        numIndices = arraySize;
      } else if (param->hasAttr<HLSLPrimitivesAttr>()) {
        if (numPrimitives != 0) {
          emitError("only one object with 'primitives' modifier is allowed",
                    paramLoc);
          return false;
        }
        numPrimitives = arraySize;
      }
    } else if (param->hasAttr<HLSLPayloadAttr>()) {
      if (payloadDeclSeen) {
        emitError("only one object with 'payload' modifier is allowed",
                  paramLoc);
        return false;
      }
      payloadDeclSeen = true;
      if (!paramType->isStructureType()) {
        emitError("expected payload of struct type", paramLoc);
        return false;
      }
    }
  }

  // A vertex attribute array is a mandatory parameter of the mesh entry
  // function.
  if (numVertices != 0) {
    *outVerticesArraySize = numVertices;
    spvBuilder.addExecutionMode(
        entryFunction, spv::ExecutionMode::OutputVertices,
        {static_cast<uint32_t>(numVertices)}, decl->getLocation());
  } else {
    emitError("expected vertices object declaration", decl->getLocation());
    return false;
  }

  // A vertex indices array is a mandatory parameter of the mesh entry
  // function.
  if (numIndices != 0) {
    spvBuilder.addExecutionMode(
        entryFunction, spv::ExecutionMode::OutputPrimitivesNV,
        {static_cast<uint32_t>(numIndices)}, decl->getLocation());
    // A primitive attribute array is an optional parameter of the mesh entry
    // function, but its array size should match the indices array.
    if (numPrimitives != 0 && numPrimitives != numIndices) {
      emitError("array size of primitives object should match 'indices' "
                "object",
                decl->getLocation());
      return false;
    }
  } else {
    emitError("expected indices object declaration", decl->getLocation());
    return false;
  }

  return true;
}
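
// Emits the SPIR-V entry point wrapper around the HLSL entry function for
// the rasterization and compute stages: it applies stage attributes,
// declares stage input/output variables, copies inputs into temporary
// variables, calls the real entry function, and writes outputs back.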
  9821. bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
  9822. SpirvFunction *entryFuncInstr) {
  9823. // HS specific attributes
  9824. uint32_t numOutputControlPoints = 0;
  9825. SpirvInstruction *outputControlPointIdVal =
  9826. nullptr; // SV_OutputControlPointID value
  9827. SpirvInstruction *primitiveIdVar = nullptr; // SV_PrimitiveID variable
  9828. SpirvInstruction *viewIdVar = nullptr; // SV_ViewID variable
  9829. SpirvInstruction *hullMainInputPatchParam =
  9830. nullptr; // Temporary parameter for InputPatch<>
  9831. // The array size of per-vertex input/output variables
  9832. // Used by HS/DS/GS for the additional arrayness, zero means not an array.
  9833. uint32_t inputArraySize = 0;
  9834. uint32_t outputArraySize = 0;
  9835. // The wrapper entry function surely does not have pre-assigned <result-id>
  9836. // for it like other functions that got added to the work queue following
  9837. // function calls. And the wrapper is the entry function.
  9838. entryFunction = spvBuilder.beginFunction(
  9839. astContext.VoidTy, decl->getLocStart(), decl->getName());
  9840. // Specify that entryFunction is an entry function wrapper.
  9841. entryFunction->setEntryFunctionWrapper();
  9842. // Note this should happen before using declIdMapper for other tasks.
  9843. declIdMapper.setEntryFunction(entryFunction);
  9844. // Set entryFunction for current entry point.
  9845. auto iter = functionInfoMap.find(decl);
  9846. assert(iter != functionInfoMap.end());
  9847. auto &entryInfo = iter->second;
  9848. assert(entryInfo->isEntryFunction);
  9849. entryInfo->entryFunction = entryFunction;
  9850. if (spvContext.isRay()) {
  9851. return emitEntryFunctionWrapperForRayTracing(decl, entryFuncInstr);
  9852. }
  9853. // Handle attributes specific to each shader stage
  9854. if (spvContext.isPS()) {
  9855. processPixelShaderAttributes(decl);
  9856. } else if (spvContext.isCS()) {
  9857. processComputeShaderAttributes(decl);
  9858. } else if (spvContext.isHS()) {
  9859. if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
  9860. return false;
  9861. // The input array size for HS is specified in the InputPatch parameter.
  9862. for (const auto *param : decl->params())
  9863. if (hlsl::IsHLSLInputPatchType(param->getType())) {
  9864. inputArraySize = hlsl::GetHLSLInputPatchCount(param->getType());
  9865. break;
  9866. }
  9867. outputArraySize = numOutputControlPoints;
  9868. } else if (spvContext.isDS()) {
  9869. if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
  9870. return false;
  9871. // The input array size for HS is specified in the OutputPatch parameter.
  9872. for (const auto *param : decl->params())
  9873. if (hlsl::IsHLSLOutputPatchType(param->getType())) {
  9874. inputArraySize = hlsl::GetHLSLOutputPatchCount(param->getType());
  9875. break;
  9876. }
  9877. // The per-vertex output of DS is not an array.
  9878. } else if (spvContext.isGS()) {
  9879. if (!processGeometryShaderAttributes(decl, &inputArraySize))
  9880. return false;
  9881. // The per-vertex output of GS is not an array.
  9882. } else if (spvContext.isMS() || spvContext.isAS()) {
  9883. if (!processMeshOrAmplificationShaderAttributes(decl, &outputArraySize))
  9884. return false;
  9885. }
  9886. // Go through all parameters and record the declaration of SV_ClipDistance
  9887. // and SV_CullDistance. We need to do this extra step because in HLSL we
  9888. // can declare multiple SV_ClipDistance/SV_CullDistance variables of float
  9889. // or vector of float types, but we can only have one single float array
  9890. // for the ClipDistance/CullDistance builtin. So we need to group all
  9891. // SV_ClipDistance/SV_CullDistance variables into one float array, thus we
  9892. // need to calculate the total size of the array and the offset of each
  9893. // variable within that array.
  9894. // Also go through all parameters to record the semantic strings provided for
  9895. // the builtins in gl_PerVertex.
  for (const auto *param : decl->params()) {
    if (canActAsInParmVar(param))
      if (!declIdMapper.glPerVertex.recordGlPerVertexDeclFacts(param, true))
        return false;
    if (canActAsOutParmVar(param))
      if (!declIdMapper.glPerVertex.recordGlPerVertexDeclFacts(param, false))
        return false;
  }
  // Also consider the SV_ClipDistance/SV_CullDistance in the return type
  if (!declIdMapper.glPerVertex.recordGlPerVertexDeclFacts(decl, false))
    return false;

  // Calculate the total size of the ClipDistance/CullDistance array and the
  // offset of SV_ClipDistance/SV_CullDistance variables within the array.
  declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();

  if (!spvContext.isCS() && !spvContext.isAS()) {
    // Generate stand-alone builtins of Position, ClipDistance, and
    // CullDistance, which belong to gl_PerVertex.
    declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
  }

  // The entry basic block.
  auto *entryLabel = spvBuilder.createBasicBlock();
  spvBuilder.setInsertPoint(entryLabel);

  // Initialize all global variables at the beginning of the wrapper
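  // For example (illustrative), a file-scope 'static float4 gTint = ...;'
  // has its initializer stored here, so it takes effect before the source
  // entry function body runs.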
  for (const VarDecl *varDecl : toInitGloalVars) {
    // SPIR-V does not have string variables
    if (isStringType(varDecl->getType()))
      continue;

    const auto varInfo =
        declIdMapper.getDeclEvalInfo(varDecl, varDecl->getLocation());
    if (const auto *init = varDecl->getInit()) {
      storeValue(varInfo, loadIfGLValue(init), varDecl->getType(),
                 init->getLocStart());

      // Update counter variable associated with global variables
      tryToAssignCounterVar(varDecl, init);
    }
    // If not explicitly initialized, initialize with zero values, unless
    // they are resource objects
    else if (!hlsl::IsHLSLResourceType(varDecl->getType())) {
      auto *nullValue = spvBuilder.getConstantNull(varDecl->getType());
      spvBuilder.createStore(varInfo, nullValue, varDecl->getLocation());
    }
  }

  // Create temporary variables for holding function call arguments
  llvm::SmallVector<SpirvInstruction *, 4> params;
  for (const auto *param : decl->params()) {
    const auto paramType = param->getType();
    std::string tempVarName = "param.var." + param->getNameAsString();
    auto *tempVar =
        spvBuilder.addFnVar(paramType, param->getLocation(), tempVarName,
                            param->hasAttr<HLSLPreciseAttr>());

    params.push_back(tempVar);

    // Create the stage input variable for parameters not marked as pure out
    // and initialize the corresponding temporary variable.
    // Also do not create input variables for output stream objects of
    // geometry shaders (e.g. TriangleStream), which are required to be
    // marked as 'inout'.
    if (canActAsInParmVar(param)) {
      if (spvContext.isHS() && hlsl::IsHLSLInputPatchType(paramType)) {
        // Record the temporary variable holding InputPatch. It may be used
        // later in the patch constant function.
        hullMainInputPatchParam = tempVar;
      }

      SpirvInstruction *loadedValue = nullptr;
      if (!declIdMapper.createStageInputVar(param, &loadedValue, false))
        return false;

      // Only initialize the temporary variable if the parameter is indeed
      // used.
      if (param->isUsed()) {
        spvBuilder.createStore(tempVar, loadedValue, param->getLocation());
      }

      // Record the temporary variable holding SV_OutputControlPointID,
      // SV_PrimitiveID, and SV_ViewID. It may be used later in the patch
      // constant function.
      if (hasSemantic(param, hlsl::DXIL::SemanticKind::OutputControlPointID))
        outputControlPointIdVal = loadedValue;
      else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID))
        primitiveIdVar = tempVar;
      else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID))
        viewIdVar = tempVar;
    }
  }

  // Call the original entry function
  const QualType retType = decl->getReturnType();
  auto *retVal = spvBuilder.createFunctionCall(retType, entryFuncInstr, params,
                                               decl->getLocStart());

  // Create and write stage output variables for the return value. Hull
  // shaders are a special case, since they operate differently in two ways:
  // 1- Their return value is in fact an array, and each invocation should
  //    write to the proper offset in the array.
  // 2- The patch constant function must be called *once* after all
  //    invocations of the main entry point function are done.
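  // Illustrative HLSL hull shader main (shape assumed for this example):
  //   [outputcontrolpoints(3)]
  //   CtrlPt main(InputPatch<CtrlPt, 3> ip,
  //               uint id : SV_OutputControlPointID) { ... }
  // Each of the 3 invocations writes its return value to output[id].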
  if (spvContext.isHS()) {
    // Create stage output variables out of the return type.
    if (!declIdMapper.createStageOutputVar(decl, numOutputControlPoints,
                                           outputControlPointIdVal, retVal))
      return false;
    if (!processHSEntryPointOutputAndPCF(
            decl, retType, retVal, numOutputControlPoints,
            outputControlPointIdVal, primitiveIdVar, viewIdVar,
            hullMainInputPatchParam))
      return false;
  } else {
    if (!declIdMapper.createStageOutputVar(decl, retVal, /*forPCF*/ false))
      return false;
  }

  // Create and write stage output variables for parameters marked as
  // out/inout
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const auto *param = decl->getParamDecl(i);
    if (canActAsOutParmVar(param)) {
      // Load the value from the parameter after the function call
      SpirvInstruction *loadedParam = nullptr;

      // No need to write back the value if the parameter is not used at all
      // in the original entry function.
      //
      // Write-back of stage output variables in GS is manually controlled by
      // the .Append() intrinsic method. There is no need to load the
      // parameter since we won't write back here.
      if (param->isUsed() && !spvContext.isGS())
        loadedParam = spvBuilder.createLoad(param->getType(), params[i],
                                            param->getLocStart());

      if (!declIdMapper.createStageOutputVar(param, loadedParam, false))
        return false;
    }
  }

  // For the entry point wrapper, it is better not to specify a
  // SourceLocation for the return statement: it is not the location of an
  // actual return in the source, and emitting the location of the end of
  // the entry function would be confusing. It is better to emit the debug
  // line just before OpFunctionEnd.
  spvBuilder.createReturn(/* SourceLocation */ {});
  spvBuilder.endFunction();

  // For Hull shaders, there is no explicit call to the PCF in the HLSL
  // source. We should invoke a translation of the PCF manually.
  if (spvContext.isHS())
    doDecl(patchConstFunc);

  return true;
}

bool SpirvEmitter::processHSEntryPointOutputAndPCF(
    const FunctionDecl *hullMainFuncDecl, QualType retType,
    SpirvInstruction *retVal, uint32_t numOutputControlPoints,
    SpirvInstruction *outputControlPointId, SpirvInstruction *primitiveId,
    SpirvInstruction *viewId, SpirvInstruction *hullMainInputPatch) {
  // This method may only be called for Hull shaders.
  assert(spvContext.isHS());

  auto loc = hullMainFuncDecl->getLocation();
  auto locEnd = hullMainFuncDecl->getLocEnd();

  // For Hull shaders, the real output is an array of size
  // numOutputControlPoints. The results of the main should be written to the
  // correct offset in the array (based on InvocationID).
  if (!numOutputControlPoints) {
    emitError("number of output control points cannot be zero", loc);
    return false;
  }
  // TODO: We should be able to handle cases where the SV_OutputControlPointID
  // is not provided.
  if (!outputControlPointId) {
    emitError(
        "SV_OutputControlPointID semantic must be provided in hull shader",
        loc);
    return false;
  }
  if (!patchConstFunc) {
    emitError("patch constant function not defined in hull shader", loc);
    return false;
  }

  SpirvInstruction *hullMainOutputPatch = nullptr;
  // If the patch constant function (PCF) takes the result of the Hull main
  // entry point, create a temporary function-scope variable and write the
  // results to it, so it can be passed to the PCF.
  if (const auto *param = patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
    hullMainOutputPatch = declIdMapper.createHullMainOutputPatch(
        param, retType, numOutputControlPoints, locEnd);
    auto *tempLocation = spvBuilder.createAccessChain(
        retType, hullMainOutputPatch, {outputControlPointId}, locEnd);
    spvBuilder.createStore(tempLocation, retVal, locEnd);
  }

  // Now create a barrier before calling the Patch Constant Function (PCF).
  // Flags are:
  //   Execution scope  = Workgroup (2)
  //   Memory scope     = Invocation (4)
  //   Memory semantics = None (0)
  spvBuilder.createBarrier(spv::Scope::Invocation,
                           spv::MemorySemanticsMask::MaskNone,
                           spv::Scope::Workgroup, {});

  // The PCF should be called only once. Therefore, we check the invocation
  // ID, and we only allow ID 0 to call the PCF.
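  // The emitted guard looks roughly like this (illustrative SPIR-V; ids are
  // arbitrary):
  //   %cond = OpIEqual %bool %invocation_id %uint_0
  //           OpSelectionMerge %if_merge None
  //           OpBranchConditional %cond %if_true %if_merge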
  auto *condition = spvBuilder.createBinaryOp(
      spv::Op::OpIEqual, astContext.BoolTy, outputControlPointId,
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)),
      loc);
  auto *thenBB = spvBuilder.createBasicBlock("if.true");
  auto *mergeBB = spvBuilder.createBasicBlock("if.merge");
  spvBuilder.createConditionalBranch(condition, thenBB, mergeBB, loc, mergeBB);
  spvBuilder.addSuccessor(thenBB);
  spvBuilder.addSuccessor(mergeBB);
  spvBuilder.setMergeTarget(mergeBB);

  spvBuilder.setInsertPoint(thenBB);

  // Call the PCF. Since the function is not explicitly called, we must first
  // register an ID for it.
  SpirvFunction *pcfId = declIdMapper.getOrRegisterFn(patchConstFunc);
  const QualType pcfRetType = patchConstFunc->getReturnType();

  std::vector<SpirvInstruction *> pcfParams;

  // A lambda for creating a stage input variable and its associated
  // temporary variable for a function call. It also initializes the
  // temporary variable using the contents loaded from the stage input
  // variable, and returns the <result-id> of the temporary variable.
  const auto createParmVarAndInitFromStageInputVar =
      [this](const ParmVarDecl *param) {
        const QualType type = param->getType();
        std::string tempVarName = "param.var." + param->getNameAsString();
        auto paramLoc = param->getLocation();
        auto *tempVar = spvBuilder.addFnVar(type, paramLoc, tempVarName,
                                            param->hasAttr<HLSLPreciseAttr>());
        SpirvInstruction *loadedValue = nullptr;
        declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
        spvBuilder.createStore(tempVar, loadedValue, paramLoc);
        return tempVar;
      };

  for (const auto *param : patchConstFunc->parameters()) {
    // Note: According to the HLSL reference, the PCF takes an InputPatch of
    // ControlPoints as well as the PatchID (PrimitiveID). This does not
    // necessarily mean that they are present. There is also no requirement
    // on the order of parameters passed to the PCF.
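    // A typical PCF signature looks like this (illustrative; names assumed):
    //   PatchConstants pcf(InputPatch<CtrlPt, 3> patch,
    //                      uint pid : SV_PrimitiveID);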
    if (hlsl::IsHLSLInputPatchType(param->getType())) {
      pcfParams.push_back(hullMainInputPatch);
    } else if (hlsl::IsHLSLOutputPatchType(param->getType())) {
      // Since the output patch used in hull shaders is translated to
      // a variable with Workgroup storage class, there is no need
      // to pass the variable as a function parameter in SPIR-V.
      continue;
    } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
      if (!primitiveId) {
        primitiveId = createParmVarAndInitFromStageInputVar(param);
      }
      pcfParams.push_back(primitiveId);
    } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID)) {
      if (!viewId) {
        viewId = createParmVarAndInitFromStageInputVar(param);
      }
      pcfParams.push_back(viewId);
    } else {
      emitError("patch constant function parameter '%0' unknown",
                param->getLocation())
          << param->getName();
    }
  }
  auto *pcfResultId = spvBuilder.createFunctionCall(
      pcfRetType, pcfId, {pcfParams}, hullMainFuncDecl->getLocStart());

  if (!declIdMapper.createStageOutputVar(patchConstFunc, pcfResultId,
                                         /*forPCF*/ true))
    return false;

  spvBuilder.createBranch(mergeBB, locEnd);
  spvBuilder.addSuccessor(mergeBB);
  spvBuilder.setInsertPoint(mergeBB);
  return true;
}

bool SpirvEmitter::allSwitchCasesAreIntegerLiterals(const Stmt *root) {
  if (!root)
    return false;

  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  if (!caseStmt && !compoundStmt)
    return true;

  if (caseStmt) {
    const Expr *caseExpr = caseStmt->getLHS();
    return caseExpr && caseExpr->isEvaluatable(astContext);
  }

  // Recurse down if facing a compound statement.
  for (auto *st : compoundStmt->body())
    if (!allSwitchCasesAreIntegerLiterals(st))
      return false;

  return true;
}

void SpirvEmitter::discoverAllCaseStmtInSwitchStmt(
    const Stmt *root, SpirvBasicBlock **defaultBB,
    std::vector<std::pair<uint32_t, SpirvBasicBlock *>> *targets) {
  if (!root)
    return;

  // A switch case can only appear in DefaultStmt, CaseStmt, or
  // CompoundStmt. For the rest, we can just return.
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  if (!defaultStmt && !caseStmt && !compoundStmt)
    return;

  // Recurse down if facing a compound statement.
  if (compoundStmt) {
    for (auto *st : compoundStmt->body())
      discoverAllCaseStmtInSwitchStmt(st, defaultBB, targets);
    return;
  }

  std::string caseLabel;
  uint32_t caseValue = 0;
  if (defaultStmt) {
    // This is the default branch.
    caseLabel = "switch.default";
  } else if (caseStmt) {
    // This is a non-default case.
    // When using OpSwitch, we only allow integer literal cases, e.g.:
    //   case <literal_integer>: {...; break;}
    const Expr *caseExpr = caseStmt->getLHS();
    assert(caseExpr && caseExpr->isEvaluatable(astContext));
    auto bitWidth = astContext.getIntWidth(caseExpr->getType());
    if (bitWidth != 32)
      emitError(
          "non-32bit integer case value in switch statement unimplemented",
          caseExpr->getExprLoc());
    Expr::EvalResult evalResult;
    caseExpr->EvaluateAsRValue(evalResult, astContext);
    const int64_t value = evalResult.Val.getInt().getSExtValue();
    caseValue = static_cast<uint32_t>(value);
    caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
                llvm::itostr(std::abs(value));
  }
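  // The generated block label thus encodes the case value, e.g. 'case 7:'
  // becomes "switch.7" and 'case -3:' becomes "switch.n3".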
  auto *caseBB = spvBuilder.createBasicBlock(caseLabel);
  spvBuilder.addSuccessor(caseBB);
  stmtBasicBlock[root] = caseBB;

  // Add all cases to the 'targets' vector.
  if (caseStmt)
    targets->emplace_back(caseValue, caseBB);

  // The default label is not part of the 'targets' vector that is passed
  // to the OpSwitch instruction.
  // If a default statement was discovered, return its label via defaultBB.
  if (defaultStmt)
    *defaultBB = caseBB;

  // Process cases nested in other cases. This happens when we have
  // fall-through cases. For example:
  //   case 1: case 2: ...; break;
  // will result in the CaseStmt for case 2 nested in the one for case 1.
  discoverAllCaseStmtInSwitchStmt(caseStmt ? caseStmt->getSubStmt()
                                           : defaultStmt->getSubStmt(),
                                  defaultBB, targets);
}

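// For example (illustrative), flattening
//   switch (x) { case 1: a(); break; default: b(); }
// yields the flat vector {Case1, CallExpr a(), Break, Default, CallExpr b()}.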
void SpirvEmitter::flattenSwitchStmtAST(const Stmt *root,
                                        std::vector<const Stmt *> *flatSwitch) {
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);

  if (!compoundStmt) {
    flatSwitch->push_back(root);
  }

  if (compoundStmt) {
    for (const auto *st : compoundStmt->body())
      flattenSwitchStmtAST(st, flatSwitch);
  } else if (caseStmt) {
    flattenSwitchStmtAST(caseStmt->getSubStmt(), flatSwitch);
  } else if (defaultStmt) {
    flattenSwitchStmtAST(defaultStmt->getSubStmt(), flatSwitch);
  }
}

void SpirvEmitter::processCaseStmtOrDefaultStmt(const Stmt *stmt) {
  auto *caseStmt = dyn_cast<CaseStmt>(stmt);
  auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
  assert(caseStmt || defaultStmt);

  auto *caseBB = stmtBasicBlock[stmt];
  if (!spvBuilder.isCurrentBasicBlockTerminated()) {
    // We are about to handle the case passed in as parameter. If the current
    // basic block is not terminated, it means the previous case is a
    // fall-through case. We need to link it to the case to be processed.
    spvBuilder.createBranch(caseBB, stmt->getLocStart());
    spvBuilder.addSuccessor(caseBB);
  }
  spvBuilder.setInsertPoint(caseBB);
  doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
}

void SpirvEmitter::processSwitchStmtUsingSpirvOpSwitch(
    const SwitchStmt *switchStmt) {
  const SourceLocation srcLoc = switchStmt->getSwitchLoc();

  // First handle the condition variable DeclStmt if one exists.
  // For example: handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  auto *selector = doExpr(switchStmt->getCond());

  // We need a merge block regardless of the number of switch cases.
  // Since OpSwitch always requires a default label, if the switch statement
  // does not have a default branch, we use the merge block as the default
  // target.
  auto *mergeBB = spvBuilder.createBasicBlock("switch.merge");
  spvBuilder.setMergeTarget(mergeBB);
  breakStack.push(mergeBB);
  auto *defaultBB = mergeBB;

  // (literal, labelId) pairs to pass to the OpSwitch instruction.
  std::vector<std::pair<uint32_t, SpirvBasicBlock *>> targets;
  discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);

  // Create the OpSelectionMerge and OpSwitch.
  spvBuilder.createSwitch(mergeBB, selector, defaultBB, targets, srcLoc);

  // Handle the switch body.
  doStmt(switchStmt->getBody());

  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(mergeBB, switchStmt->getLocEnd());
  spvBuilder.setInsertPoint(mergeBB);
  breakStack.pop();
}

void SpirvEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
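  // Conceptually (illustrative), this rewrites
  //   switch (x) { case 1: a(); break; default: b(); }
  // into the equivalent
  //   if (x == 1) { a(); } else { b(); }
  // by synthesizing faux IfStmt/BinaryOperator AST nodes below.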
  std::vector<const Stmt *> flatSwitch;
  flattenSwitchStmtAST(switchStmt->getBody(), &flatSwitch);

  // First handle the condition variable DeclStmt if one exists.
  // For example: handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  // Figure out the indexes of CaseStmts (and DefaultStmt if it exists) in
  // the flattened switch AST.
  // For instance, for the following flat vector:
  //   +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  //   |Case1|Stmt1|Case2|Stmt2|Break|Case3|Case4|Stmt4|Break|Default|Stmt5|
  //   +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  // the indexes are: {0, 2, 5, 6, 9}
  std::vector<uint32_t> caseStmtLocs;
  for (uint32_t i = 0; i < flatSwitch.size(); ++i)
    if (isa<CaseStmt>(flatSwitch[i]) || isa<DefaultStmt>(flatSwitch[i]))
      caseStmtLocs.push_back(i);

  IfStmt *prevIfStmt = nullptr;
  IfStmt *rootIfStmt = nullptr;
  CompoundStmt *defaultBody = nullptr;

  // For each case, start at its index in the vector and go forward,
  // accumulating statements until a BreakStmt or the end of the vector is
  // reached.
  for (auto curCaseIndex : caseStmtLocs) {
    const Stmt *curCase = flatSwitch[curCaseIndex];

    // CompoundStmt to hold all statements for this case.
    CompoundStmt *cs = new (astContext) CompoundStmt(Stmt::EmptyShell());

    // Accumulate all non-case/default/break statements as the body for the
    // current case.
    std::vector<Stmt *> statements;
    for (unsigned i = curCaseIndex + 1;
         i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]); ++i) {
      if (!isa<CaseStmt>(flatSwitch[i]) && !isa<DefaultStmt>(flatSwitch[i]))
        statements.push_back(const_cast<Stmt *>(flatSwitch[i]));
    }
    if (!statements.empty())
      cs->setStmts(astContext, statements.data(), statements.size());

    // For non-default cases, generate the IfStmt that compares the switch
    // value to the case value.
    if (auto *caseStmt = dyn_cast<CaseStmt>(curCase)) {
      IfStmt *curIf = new (astContext) IfStmt(Stmt::EmptyShell());
      BinaryOperator *bo = new (astContext) BinaryOperator(Stmt::EmptyShell());
      bo->setLHS(const_cast<Expr *>(switchStmt->getCond()));
      bo->setRHS(const_cast<Expr *>(caseStmt->getLHS()));
      bo->setOpcode(BO_EQ);
      bo->setType(astContext.getLogicalOperationType());
      curIf->setCond(bo);
      curIf->setThen(cs);
      // No condition variable is associated with this faux if statement.
      curIf->setConditionVariable(astContext, nullptr);
      // Each if statement is the "else" of the previous if statement.
      if (prevIfStmt)
        prevIfStmt->setElse(curIf);
      else
        rootIfStmt = curIf;
      prevIfStmt = curIf;
    } else {
      // Record the DefaultStmt body as it will be used as the body of the
      // "else" block in the if-elseif-...-else pattern.
      defaultBody = cs;
    }
  }

  // If a default case exists, it is the "else" of the last if statement.
  if (prevIfStmt)
    prevIfStmt->setElse(defaultBody);

  // Since all else-if and else statements are child nodes of the first
  // IfStmt, we only need to call doStmt for the first IfStmt.
  if (rootIfStmt)
    doStmt(rootIfStmt);
  // If there is no CaseStmt and only one DefaultStmt, there will be no if
  // statements. In that case the switch only executes the body of the
  // default case.
  else if (defaultBody)
    doStmt(defaultBody);
}

SpirvInstruction *SpirvEmitter::extractVecFromVec4(SpirvInstruction *from,
                                                   uint32_t targetVecSize,
                                                   QualType targetElemType,
                                                   SourceLocation loc) {
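  // For example (illustrative), a Texture2D<float2> sample is produced as a
  // float4 and then narrowed here to a float2 via OpVectorShuffle {0, 1}.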
  assert(targetVecSize > 0 && targetVecSize < 5);
  const QualType retType =
      targetVecSize == 1
          ? targetElemType
          : astContext.getExtVectorType(targetElemType, targetVecSize);
  switch (targetVecSize) {
  case 1:
    return spvBuilder.createCompositeExtract(retType, from, {0}, loc);
  case 2:
    return spvBuilder.createVectorShuffle(retType, from, from, {0, 1}, loc);
  case 3:
    return spvBuilder.createVectorShuffle(retType, from, from, {0, 1, 2}, loc);
  case 4:
    return from;
  default:
    llvm_unreachable("vector element count must be 1, 2, 3, or 4");
  }
}

void SpirvEmitter::addFunctionToWorkQueue(hlsl::DXIL::ShaderKind shaderKind,
                                          const clang::FunctionDecl *fnDecl,
                                          bool isEntryFunction) {
  // Only update the workQueue and the function info map if the given
  // FunctionDecl hasn't been added already.
  if (functionInfoMap.find(fnDecl) == functionInfoMap.end()) {
    // Note: The function is just discovered and is being added to the
    // workQueue, therefore it does not have the entryFunction SPIR-V
    // instruction yet (use nullptr).
    auto *fnInfo = new (spvContext) FunctionInfo(
        shaderKind, fnDecl, /*entryFunction*/ nullptr, isEntryFunction);
    functionInfoMap[fnDecl] = fnInfo;
    workQueue.push_back(fnInfo);
  }
}

SpirvInstruction *
SpirvEmitter::processTraceRayInline(const CXXMemberCallExpr *expr) {
  emitWarning("SPV_KHR_ray_query is currently a provisional extension and "
              "might change in ways that are not backwards compatible",
              expr->getExprLoc());
  const auto object = expr->getImplicitObjectArgument();
  uint32_t templateFlags = hlsl::GetHLSLResourceTemplateUInt(object->getType());
  const auto constFlags = spvBuilder.getConstantInt(
      astContext.UnsignedIntTy, llvm::APInt(32, templateFlags));
  SpirvInstruction *rayqueryObj = loadIfAliasVarRef(object);

  const auto args = expr->getArgs();
  if (expr->getNumArgs() != 4) {
    emitError("invalid number of arguments to TraceRayInline",
              expr->getExprLoc());
  }

  // HLSL function:
  //   void RayQuery::TraceRayInline(
  //       RaytracingAccelerationStructure AccelerationStructure,
  //       uint RayFlags,
  //       uint InstanceInclusionMask,
  //       RayDesc Ray);
  //
  // SPIR-V instruction:
  //   void OpRayQueryInitializeKHR(<id> RayQuery,
  //                                <id> Acceleration Structure,
  //                                <id> RayFlags,
  //                                <id> CullMask,
  //                                <id> RayOrigin,
  //                                <id> RayTmin,
  //                                <id> RayDirection,
  //                                <id> RayTmax)
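  // RayDesc is laid out as {Origin : float3, TMin : float,
  // Direction : float3, TMax : float}; its four fields are extracted by
  // index below to match the flattened SPIR-V operand order.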
  const auto accelStructure = doExpr(args[0]);

  SpirvInstruction *rayFlags = nullptr;
  if ((rayFlags = tryToEvaluateAsConst(args[1]))) {
    rayFlags->setRValue();
  } else {
    rayFlags = doExpr(args[1]);
  }

  // If the ray flags argument is a compile-time constant, fold it into the
  // template flags so we can tell whether any cull-related flags are set.
  if (auto *constArgFlags = dyn_cast<SpirvConstantInteger>(rayFlags)) {
    auto interRayFlags = constArgFlags->getValue().getZExtValue();
    templateFlags |= interRayFlags;
  }
  bool hasCullFlags =
      templateFlags & (uint32_t(hlsl::DXIL::RayFlag::SkipTriangles) |
                       uint32_t(hlsl::DXIL::RayFlag::SkipProceduralPrimitives));

  auto loc = args[1]->getLocStart();
  rayFlags =
      spvBuilder.createBinaryOp(spv::Op::OpBitwiseOr, astContext.UnsignedIntTy,
                                constFlags, rayFlags, loc);
  const auto cullMask = doExpr(args[2]);

  // Extract the ray description to match the SPIR-V operands.
  const auto floatType = astContext.FloatTy;
  const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);
  SpirvInstruction *rayDescArg = doExpr(args[3]);
  loc = args[3]->getLocStart();
  const auto origin =
      spvBuilder.createCompositeExtract(vecType, rayDescArg, {0}, loc);
  const auto tMin =
      spvBuilder.createCompositeExtract(floatType, rayDescArg, {1}, loc);
  const auto direction =
      spvBuilder.createCompositeExtract(vecType, rayDescArg, {2}, loc);
  const auto tMax =
      spvBuilder.createCompositeExtract(floatType, rayDescArg, {3}, loc);

  llvm::SmallVector<SpirvInstruction *, 8> traceArgs = {
      rayqueryObj, accelStructure, rayFlags,  cullMask,
      origin,      tMin,           direction, tMax};

  return spvBuilder.createRayQueryOpsKHR(spv::Op::OpRayQueryInitializeKHR,
                                         QualType(), traceArgs, hasCullFlags,
                                         expr->getExprLoc());
}

SpirvInstruction *
SpirvEmitter::processRayQueryIntrinsics(const CXXMemberCallExpr *expr,
                                        hlsl::IntrinsicOp opcode) {
  emitWarning("SPV_KHR_ray_query is currently a provisional extension and "
              "might change in ways that are not backwards compatible",
              expr->getExprLoc());
  const auto object = expr->getImplicitObjectArgument();
  SpirvInstruction *rayqueryObj = loadIfAliasVarRef(object);
  const auto args = expr->getArgs();

  llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
  traceArgs.push_back(rayqueryObj);
  for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
    traceArgs.push_back(doExpr(args[i]));
  }

  spv::Op spvCode = spv::Op::Max;
  QualType exprType = expr->getType();
  exprType = exprType->isVoidType() ? QualType() : exprType;

  const auto candidateIntersection =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
  const auto committedIntersection =
      spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 1));
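  // These literals select the intersection-type operand of the
  // OpRayQueryGet* instructions: 0 = candidate, 1 = committed (per
  // SPV_KHR_ray_query).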
  bool transposeMatrix = false;
  bool logicalNot = false;
  using namespace hlsl;
  switch (opcode) {
  case IntrinsicOp::MOP_Proceed:
    spvCode = spv::Op::OpRayQueryProceedKHR;
    break;
  case IntrinsicOp::MOP_Abort:
    spvCode = spv::Op::OpRayQueryTerminateKHR;
    exprType = QualType();
    break;
  case IntrinsicOp::MOP_CandidateGeometryIndex:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR;
    break;
  case IntrinsicOp::MOP_CandidateInstanceContributionToHitGroupIndex:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::
        OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR;
    break;
  case IntrinsicOp::MOP_CandidateInstanceID:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR;
    break;
  case IntrinsicOp::MOP_CandidateInstanceIndex:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionInstanceIdKHR;
    break;
  case IntrinsicOp::MOP_CandidateObjectRayDirection:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR;
    break;
  case IntrinsicOp::MOP_CandidateObjectRayOrigin:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR;
    break;
  case IntrinsicOp::MOP_CandidateObjectToWorld3x4:
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR;
    traceArgs.push_back(candidateIntersection);
    transposeMatrix = true;
    break;
  case IntrinsicOp::MOP_CandidateObjectToWorld4x3:
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR;
    traceArgs.push_back(candidateIntersection);
    break;
  case IntrinsicOp::MOP_CandidatePrimitiveIndex:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR;
    break;
  case IntrinsicOp::MOP_CandidateProceduralPrimitiveNonOpaque:
    spvCode = spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR;
    logicalNot = true;
    break;
  case IntrinsicOp::MOP_CandidateTriangleBarycentrics:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionBarycentricsKHR;
    break;
  case IntrinsicOp::MOP_CandidateTriangleFrontFace:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionFrontFaceKHR;
    break;
  case IntrinsicOp::MOP_CandidateTriangleRayT:
    traceArgs.push_back(candidateIntersection);
    spvCode = spv::Op::OpRayQueryGetIntersectionTKHR;
    break;
  case IntrinsicOp::MOP_CandidateType:
    spvCode = spv::Op::OpRayQueryGetIntersectionTypeKHR;
    traceArgs.push_back(candidateIntersection);
    break;
  case IntrinsicOp::MOP_CandidateWorldToObject4x3:
    spvCode = spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR;
    traceArgs.push_back(candidateIntersection);
    break;
  case IntrinsicOp::MOP_CandidateWorldToObject3x4:
    spvCode = spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR;
    traceArgs.push_back(candidateIntersection);
    transposeMatrix = true;
    break;
  case IntrinsicOp::MOP_CommitNonOpaqueTriangleHit:
    spvCode = spv::Op::OpRayQueryConfirmIntersectionKHR;
    exprType = QualType();
    break;
  case IntrinsicOp::MOP_CommitProceduralPrimitiveHit:
    spvCode = spv::Op::OpRayQueryGenerateIntersectionKHR;
    exprType = QualType();
    break;
  case IntrinsicOp::MOP_CommittedGeometryIndex:
    spvCode = spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedInstanceContributionToHitGroupIndex:
    spvCode = spv::Op::
        OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedInstanceID:
    spvCode = spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedInstanceIndex:
    spvCode = spv::Op::OpRayQueryGetIntersectionInstanceIdKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedObjectRayDirection:
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedObjectRayOrigin:
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedObjectToWorld3x4:
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR;
    traceArgs.push_back(committedIntersection);
    transposeMatrix = true;
    break;
  case IntrinsicOp::MOP_CommittedObjectToWorld4x3:
    spvCode = spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedPrimitiveIndex:
    spvCode = spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedRayT:
    spvCode = spv::Op::OpRayQueryGetIntersectionTKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedStatus:
    spvCode = spv::Op::OpRayQueryGetIntersectionTypeKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedTriangleBarycentrics:
    spvCode = spv::Op::OpRayQueryGetIntersectionBarycentricsKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedTriangleFrontFace:
    spvCode = spv::Op::OpRayQueryGetIntersectionFrontFaceKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_CommittedWorldToObject3x4:
    spvCode = spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR;
    traceArgs.push_back(committedIntersection);
    transposeMatrix = true;
    break;
  case IntrinsicOp::MOP_CommittedWorldToObject4x3:
    spvCode = spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR;
    traceArgs.push_back(committedIntersection);
    break;
  case IntrinsicOp::MOP_RayFlags:
    spvCode = spv::Op::OpRayQueryGetRayFlagsKHR;
    break;
  case IntrinsicOp::MOP_RayTMin:
    spvCode = spv::Op::OpRayQueryGetRayTMinKHR;
    break;
  case IntrinsicOp::MOP_WorldRayDirection:
    spvCode = spv::Op::OpRayQueryGetWorldRayDirectionKHR;
    break;
  case IntrinsicOp::MOP_WorldRayOrigin:
    spvCode = spv::Op::OpRayQueryGetWorldRayOriginKHR;
    break;
  default:
    emitError("intrinsic '%0' method unimplemented",
              expr->getCallee()->getExprLoc())
        << expr->getDirectCallee()->getName();
    return nullptr;
  }
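  // HLSL exposes both 3x4 and 4x3 transform accessors, while the SPIR-V ops
  // only return the 4x3 form; the 3x4 variants are produced by transposing
  // the 4x3 result below.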
  if (transposeMatrix) {
    assert(hlsl::IsHLSLMatType(exprType) && "intrinsic should be matrix");
    const clang::Type *type = exprType.getCanonicalType().getTypePtr();
    const RecordType *RT = cast<RecordType>(type);
    const ClassTemplateSpecializationDecl *templateSpecDecl =
        cast<ClassTemplateSpecializationDecl>(RT->getDecl());
    ClassTemplateDecl *templateDecl =
        templateSpecDecl->getSpecializedTemplate();
    exprType = getHLSLMatrixType(astContext, theCompilerInstance.getSema(),
                                 templateDecl, astContext.FloatTy, 4, 3);
  }
  const auto loc = expr->getExprLoc();
  SpirvInstruction *retVal =
      spvBuilder.createRayQueryOpsKHR(spvCode, exprType, traceArgs, false, loc);

  if (transposeMatrix) {
    retVal = spvBuilder.createUnaryOp(spv::Op::OpTranspose, expr->getType(),
                                      retVal, loc);
  }
  if (logicalNot) {
    retVal = spvBuilder.createUnaryOp(spv::Op::OpLogicalNot, expr->getType(),
                                      retVal, loc);
  }
  retVal->setRValue();
  return retVal;
}

bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
                                      std::string *messages) {
  spvtools::SpirvTools tools(featureManager.getTargetEnv());

  tools.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  spvtools::ValidatorOptions options;
  options.SetBeforeHlslLegalization(needsLegalization ||
                                    declIdMapper.requiresLegalization());
  // GL: strict block layout rules
  // VK: relaxed block layout rules
  // DX: skip block layout rules
  if (spirvOptions.useScalarLayout || spirvOptions.useDxLayout) {
    options.SetScalarBlockLayout(true);
  } else if (spirvOptions.useGlLayout) {
    // spirv-val checks the strict GL rules by default.
  } else {
    options.SetRelaxBlockLayout(true);
  }

  return tools.Validate(mod->data(), mod->size(), options);
}

bool SpirvEmitter::spirvToolsOptimize(std::vector<uint32_t> *mod,
                                      std::string *messages) {
  spvtools::Optimizer optimizer(featureManager.getTargetEnv());

  optimizer.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  spvtools::OptimizerOptions options;
  options.set_run_validator(false);

  if (spirvOptions.optConfig.empty()) {
    // Add performance passes.
    optimizer.RegisterPerformancePasses();

    // Add flattening of resources if needed.
    if (spirvOptions.flattenResourceArrays ||
        declIdMapper.requiresFlatteningCompositeResources()) {
      optimizer.RegisterPass(spvtools::CreateDescriptorScalarReplacementPass());
      // ADCE should be run after desc_sroa in order to remove potentially
      // illegal types such as structures containing opaque types.
      optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
    }

    // Add compact ID pass.
    optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
  } else {
    // Command line options use llvm::SmallVector and llvm::StringRef,
    // whereas the SPIR-V optimizer uses std::vector and std::string.
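    // These are the flags given on the command line via -Oconfig, e.g.
    // "-Oconfig=--loop-unroll,--eliminate-dead-code-aggressive"
    // (illustrative).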
    std::vector<std::string> stdFlags;
    for (const auto &f : spirvOptions.optConfig)
      stdFlags.push_back(f.str());
    if (!optimizer.RegisterPassesFromFlags(stdFlags))
      return false;
  }

  return optimizer.Run(mod->data(), mod->size(), mod, options);
}

bool SpirvEmitter::spirvToolsLegalize(std::vector<uint32_t> *mod,
                                      std::string *messages) {
  spvtools::Optimizer optimizer(featureManager.getTargetEnv());

  optimizer.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  spvtools::OptimizerOptions options;
  options.set_run_validator(false);

  optimizer.RegisterLegalizationPasses();
  optimizer.RegisterPass(spvtools::CreateReplaceInvalidOpcodePass());
  optimizer.RegisterPass(spvtools::CreateCompactIdsPass());

  return optimizer.Run(mod->data(), mod->size(), mod, options);
}

SpirvInstruction *
SpirvEmitter::doUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *expr) {
  // TODO: We support only sizeof(); support other kinds as well.
  if (expr->getKind() != clang::UnaryExprOrTypeTrait::UETT_SizeOf) {
    emitError("expression class '%0' unimplemented", expr->getExprLoc())
        << expr->getStmtClassName();
    return nullptr;
  }
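  // For example (illustrative), 'sizeof(float4)' folds to the constant 16
  // under the default (Void) layout rule used below.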
  AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
  uint32_t size = 0, stride = 0;
  std::tie(std::ignore, size) = alignmentCalc.getAlignmentAndSize(
      expr->getArgumentType(), SpirvLayoutRule::Void,
      /*isRowMajor*/ llvm::None, &stride);
  auto *sizeConst = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
                                              llvm::APInt(32, size));
  sizeConst->setRValue();
  return sizeConst;
}

} // end namespace spirv
} // end namespace clang