//===------- SPIRVEmitter.cpp - SPIR-V Binary Code Emitter ------*- C++ -*-===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//
//
// This file implements a SPIR-V emitter class that takes in HLSL AST and emits
// SPIR-V binary words.
//
//===----------------------------------------------------------------------===//

#include "SPIRVEmitter.h"

#include "dxc/HlslIntrinsicOp.h"
#include "spirv-tools/optimizer.hpp"
#include "llvm/ADT/StringExtras.h"

#include "InitListHandler.h"

namespace clang {
namespace spirv {

namespace {

// Returns true if the given decl has the given semantic.
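// For example (illustrative HLSL), for a parameter declared as
// "float4 pos : SV_Position", calling this with
// hlsl::DXIL::SemanticKind::Position is expected to return true.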
bool hasSemantic(const DeclaratorDecl *decl,
                 hlsl::DXIL::SemanticKind semanticKind) {
  using namespace hlsl;
  for (auto *annotation : decl->getUnusualAnnotations()) {
    if (auto *semanticDecl = dyn_cast<SemanticDecl>(annotation)) {
      llvm::StringRef semanticName;
      uint32_t semanticIndex = 0;
      Semantic::DecomposeNameAndIndex(semanticDecl->SemanticName, &semanticName,
                                      &semanticIndex);
      const auto *semantic = Semantic::GetByName(semanticName);
      if (semantic->GetKind() == semanticKind)
        return true;
    }
  }
  return false;
}

bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
  for (const auto *param : pcf->parameters())
    if (hlsl::IsHLSLOutputPatchType(param->getType()))
      return true;
  return false;
}

// TODO: Maybe we should move these type probing functions to TypeTranslator.

/// Returns true if the two types are the same scalar or vector type.
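/// For example, (float3, float3) yields true, while (float3, int3) and
/// (float3, float4) yield false (illustrative cases).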
bool isSameScalarOrVecType(QualType type1, QualType type2) {
  {
    QualType scalarType1 = {}, scalarType2 = {};
    if (TypeTranslator::isScalarType(type1, &scalarType1) &&
        TypeTranslator::isScalarType(type2, &scalarType2))
      return scalarType1.getCanonicalType() == scalarType2.getCanonicalType();
  }

  {
    QualType elemType1 = {}, elemType2 = {};
    uint32_t count1 = {}, count2 = {};
    if (TypeTranslator::isVectorType(type1, &elemType1, &count1) &&
        TypeTranslator::isVectorType(type2, &elemType2, &count2))
      return count1 == count2 &&
             elemType1.getCanonicalType() == elemType2.getCanonicalType();
  }

  return false;
}

/// Returns true if the given type is a bool or vector of bool type.
bool isBoolOrVecOfBoolType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isBooleanType();
}

/// Returns true if the given type is a signed integer or vector of signed
/// integer type.
bool isSintOrVecOfSintType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isSignedIntegerType();
}

/// Returns true if the given type is an unsigned integer or vector of unsigned
/// integer type.
bool isUintOrVecOfUintType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isUnsignedIntegerType();
}

/// Returns true if the given type is a float or vector of float type.
bool isFloatOrVecOfFloatType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isFloatingType();
}

/// Returns true if the given type is a bool or vector/matrix of bool type.
bool isBoolOrVecMatOfBoolType(QualType type) {
  return isBoolOrVecOfBoolType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isBooleanType());
}

/// Returns true if the given type is a signed integer or vector/matrix of
/// signed integer type.
bool isSintOrVecMatOfSintType(QualType type) {
  return isSintOrVecOfSintType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
}

/// Returns true if the given type is an unsigned integer or vector/matrix of
/// unsigned integer type.
bool isUintOrVecMatOfUintType(QualType type) {
  return isUintOrVecOfUintType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
}

/// Returns true if the given type is a float or vector/matrix of float type.
bool isFloatOrVecMatOfFloatType(QualType type) {
  return isFloatOrVecOfFloatType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isFloatingType());
}

bool isSpirvMatrixOp(spv::Op opcode) {
  switch (opcode) {
  case spv::Op::OpMatrixTimesMatrix:
  case spv::Op::OpMatrixTimesVector:
  case spv::Op::OpMatrixTimesScalar:
    return true;
  default:
    break;
  }
  return false;
}

/// If expr is a (RW)StructuredBuffer.Load(), returns the object and writes
/// the index into *index. Otherwise, returns nullptr.
// TODO: The following doesn't handle Load(int, int) yet. And it is basically a
// duplicate of doCXXMemberCallExpr.
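// For example (illustrative HLSL):
//   StructuredBuffer<float4> buffer;
//   float4 val = buffer.Load(index);
// Here the returned Expr is the one for 'buffer', and '*index' receives the
// Expr for 'index'.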
const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
  using namespace hlsl;

  if (const auto *indexing = dyn_cast<CXXMemberCallExpr>(expr)) {
    const auto *callee = indexing->getDirectCallee();
    uint32_t opcode = static_cast<uint32_t>(IntrinsicOp::Num_Intrinsics);
    llvm::StringRef group;

    if (GetIntrinsicOp(callee, opcode, group)) {
      if (static_cast<IntrinsicOp>(opcode) == IntrinsicOp::MOP_Load) {
        const auto *object = indexing->getImplicitObjectArgument();
        if (TypeTranslator::isStructuredBuffer(object->getType())) {
          *index = indexing->getArg(0);
          return indexing->getImplicitObjectArgument();
        }
      }
    }
  }

  return nullptr;
}

bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);

  optimizer.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass());
  optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass());
  optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass());
  optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass());
  optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
  optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
  optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass());
  optimizer.RegisterPass(spvtools::CreateBlockMergePass());
  optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass());
  optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
  optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
  optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass());
  optimizer.RegisterPass(spvtools::CreateCFGCleanupPass());
  optimizer.RegisterPass(spvtools::CreateDeadVariableEliminationPass());
  optimizer.RegisterPass(spvtools::CreateEliminateDeadConstantPass());
  optimizer.RegisterPass(spvtools::CreateCompactIdsPass());

  return optimizer.Run(module->data(), module->size(), module);
}

/// Translates atomic HLSL opcodes into the equivalent SPIR-V opcode.
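/// For example, the HLSL intrinsic call InterlockedAdd(dest, val)
/// (IOP_InterlockedAdd) maps to OpAtomicIAdd, per the switch below.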
spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) {
  using namespace hlsl;
  using namespace spv;

  switch (opcode) {
  case IntrinsicOp::IOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAdd:
    return Op::OpAtomicIAdd;
  case IntrinsicOp::IOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedAnd:
    return Op::OpAtomicAnd;
  case IntrinsicOp::IOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedOr:
    return Op::OpAtomicOr;
  case IntrinsicOp::IOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedXor:
    return Op::OpAtomicXor;
  case IntrinsicOp::IOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMax:
    return Op::OpAtomicUMax;
  case IntrinsicOp::IOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedUMin:
    return Op::OpAtomicUMin;
  case IntrinsicOp::IOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMax:
    return Op::OpAtomicSMax;
  case IntrinsicOp::IOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedMin:
    return Op::OpAtomicSMin;
  case IntrinsicOp::IOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedExchange:
    return Op::OpAtomicExchange;
  }

  assert(false && "unimplemented hlsl intrinsic opcode");
  return Op::Max;
}

/// Returns true if the given function parameter can act as shader stage
/// input parameter.
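/// For example (illustrative HLSL), in "float4 main(float4 color : COLOR)",
/// the parameter "color" has no in/out/inout keyword and thus defaults to an
/// "in" parameter, so it can act as a stage input variable.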
inline bool canActAsInParmVar(const ParmVarDecl *param) {
  // If the parameter has no in/out/inout attribute, it is defaulted to
  // an in parameter.
  return !param->hasAttr<HLSLOutAttr>() &&
         // GS output streams are marked as inout, but they should not be
         // used as in parameters.
         !hlsl::IsHLSLStreamOutputType(param->getType());
}

/// Returns true if the given function parameter can act as shader stage
/// output parameter.
inline bool canActAsOutParmVar(const ParmVarDecl *param) {
  return param->hasAttr<HLSLOutAttr>() || param->hasAttr<HLSLInOutAttr>();
}

} // namespace

SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
                           const EmitSPIRVOptions &options)
    : theCompilerInstance(ci), astContext(ci.getASTContext()),
      diags(ci.getDiagnostics()), spirvOptions(options),
      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
      shaderModel(*hlsl::ShaderModel::GetByName(
          ci.getCodeGenOpts().HLSLProfile.c_str())),
      theContext(), theBuilder(&theContext),
      declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
      typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
      curFunction(nullptr), curThis(0), needsLegalization(false) {
  if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
    emitError("unknown shader model: %0", {}) << shaderModel.GetName();
}

void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
  // Stop translating if there are errors in previous compilation stages.
  if (context.getDiagnostics().hasErrorOccurred())
    return;

  TranslationUnitDecl *tu = context.getTranslationUnitDecl();

  // The entry function is the seed of the queue.
  for (auto *decl : tu->decls()) {
    if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
      if (funcDecl->getName() == entryFunctionName) {
        workQueue.insert(funcDecl);
      }
      if (context.IsPatchConstantFunctionDecl(funcDecl)) {
        patchConstFunc = funcDecl;
      }
    } else if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
      if (isa<HLSLBufferDecl>(varDecl->getDeclContext())) {
        // This is a VarDecl of a ConstantBuffer/TextureBuffer type.
        (void)declIdMapper.createCTBuffer(varDecl);
      } else {
        doVarDecl(varDecl);
      }
    } else if (auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
      // This is a cbuffer/tbuffer decl.
      (void)declIdMapper.createCTBuffer(bufferDecl);
    }
  }

  // Translate all functions reachable from the entry function.
  // The queue can grow in the meanwhile; so we need to keep evaluating
  // workQueue.size().
  for (uint32_t i = 0; i < workQueue.size(); ++i) {
    doDecl(workQueue[i]);
  }

  if (context.getDiagnostics().hasErrorOccurred())
    return;

  AddRequiredCapabilitiesForShaderModel();

  // Addressing and memory model are required in a valid SPIR-V module.
  theBuilder.setAddressingModel(spv::AddressingModel::Logical);
  theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);

  theBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunctionId,
                           entryFunctionName, declIdMapper.collectStageVars());

  AddExecutionModeForEntryPoint(entryFunctionId);

  // Add Location decorations to stage input/output variables.
  if (!declIdMapper.decorateStageIOLocations())
    return;

  // Add descriptor set and binding decorations to resource variables.
  if (!declIdMapper.decorateResourceBindings())
    return;

  // Output the constructed module.
  std::vector<uint32_t> m = theBuilder.takeModule();

  const auto optLevel = theCompilerInstance.getCodeGenOpts().OptimizationLevel;
  if (needsLegalization || optLevel > 0) {
    if (needsLegalization && optLevel == 0)
      emitWarning("-O0 ignored since SPIR-V legalization required", {});

    std::string messages;
    if (!spirvToolsOptimize(&m, &messages)) {
      emitFatalError("failed to legalize/optimize SPIR-V: %0", {}) << messages;
      return;
    }
  }

  theCompilerInstance.getOutStream()->write(
      reinterpret_cast<const char *>(m.data()), m.size() * 4);
}

void SPIRVEmitter::doDecl(const Decl *decl) {
  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    doVarDecl(varDecl);
  } else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
    doFunctionDecl(funcDecl);
  } else if (isa<HLSLBufferDecl>(decl)) {
    llvm_unreachable("HLSLBufferDecl should not be handled here");
  } else {
    emitError("decl type %0 unimplemented", decl->getLocation())
        << decl->getDeclKindName();
  }
}

void SPIRVEmitter::doStmt(const Stmt *stmt,
                          llvm::ArrayRef<const Attr *> attrs) {
  if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
    for (auto *st : compoundStmt->body())
      doStmt(st);
  } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
    doReturnStmt(retStmt);
  } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
    doDeclStmt(declStmt);
  } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
    doIfStmt(ifStmt);
  } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
    doSwitchStmt(switchStmt, attrs);
  } else if (isa<CaseStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (isa<DefaultStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
    doBreakStmt(breakStmt);
  } else if (const auto *theDoStmt = dyn_cast<DoStmt>(stmt)) {
    doDoStmt(theDoStmt, attrs);
  } else if (const auto *discardStmt = dyn_cast<DiscardStmt>(stmt)) {
    doDiscardStmt(discardStmt);
  } else if (const auto *continueStmt = dyn_cast<ContinueStmt>(stmt)) {
    doContinueStmt(continueStmt);
  } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
    doWhileStmt(whileStmt, attrs);
  } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
    doForStmt(forStmt, attrs);
  } else if (isa<NullStmt>(stmt)) {
    // For the null statement ";", we don't need to do anything.
  } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
    // All cases for expressions used as statements.
    doExpr(expr);
  } else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
    doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
  } else {
    emitError("statement class '%0' unimplemented", stmt->getLocStart())
        << stmt->getStmtClassName() << stmt->getSourceRange();
  }
}

SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    return declIdMapper.getDeclResultId(declRefExpr->getFoundDecl());
  }

  if (const auto *parenExpr = dyn_cast<ParenExpr>(expr)) {
    // Just need to return what's inside the parentheses.
    return doExpr(parenExpr->getSubExpr());
  }

  if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
    return doMemberExpr(memberExpr);
  }

  if (const auto *castExpr = dyn_cast<CastExpr>(expr)) {
    return doCastExpr(castExpr);
  }

  if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
    return doInitListExpr(initListExpr);
  }

  if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
    const bool value = boolLiteral->getValue();
    return SpirvEvalInfo::withConst(theBuilder.getConstantBool(value));
  }

  if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
    return SpirvEvalInfo::withConst(
        translateAPInt(intLiteral->getValue(), expr->getType()));
  }

  if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
    return SpirvEvalInfo::withConst(
        translateAPFloat(floatLiteral->getValue(), expr->getType()));
  }

  // CompoundAssignOperator is a subclass of BinaryOperator. It should be
  // checked before BinaryOperator.
  if (const auto *compoundAssignOp = dyn_cast<CompoundAssignOperator>(expr)) {
    return doCompoundAssignOperator(compoundAssignOp);
  }

  if (const auto *binOp = dyn_cast<BinaryOperator>(expr)) {
    return doBinaryOperator(binOp);
  }

  if (const auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
    return doUnaryOperator(unaryOp);
  }

  if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    return doHLSLVectorElementExpr(vecElemExpr);
  }

  if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
    return doExtMatrixElementExpr(matElemExpr);
  }

  if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
    return doCallExpr(funcCall);
  }

  if (const auto *subscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
    return doArraySubscriptExpr(subscriptExpr);
  }

  if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
    return doConditionalOperator(condExpr);
  }

  if (isa<CXXThisExpr>(expr)) {
    assert(curThis);
    return curThis;
  }

  emitError("expression class '%0' unimplemented", expr->getExprLoc())
      << expr->getStmtClassName() << expr->getSourceRange();
  return 0;
}

SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) {
  auto info = doExpr(expr);
  if (expr->isGLValue())
    info.resultId = theBuilder.createLoad(
        typeTranslator.translateType(expr->getType(), info.layoutRule),
        info.resultId);
  return info;
}

uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
                                  QualType toType, SourceLocation srcLoc) {
  if (isFloatOrVecOfFloatType(toType))
    return castToFloat(value, fromType, toType, srcLoc);

  // Order matters here. Bool (vector) values will also be considered as uint
  // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType()
  // will also return true. We need to check bool before uint. The opposite is
  // not true.
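  // For example, when casting to bool via "(bool)someUint", Clang's
  // isUnsignedIntegerType() also returns true for bool, so checking uint
  // first would wrongly emit an integer cast instead of castToBool
  // (illustrative case).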
  if (isBoolOrVecOfBoolType(toType))
    return castToBool(value, fromType, toType);

  if (isSintOrVecOfSintType(toType) || isUintOrVecOfUintType(toType))
    return castToInt(value, fromType, toType, srcLoc);

  emitError("casting to type %0 unimplemented", {}) << toType;
  return 0;
}

void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
  // We are about to start translation for a new function. Clear the break
  // stack and the continue stack.
  breakStack = std::stack<uint32_t>();
  continueStack = std::stack<uint32_t>();

  curFunction = decl;

  std::string funcName = decl->getName();
  uint32_t funcId = 0;

  if (funcName == entryFunctionName) {
    // The entry function surely does not have a pre-assigned <result-id>
    // like other functions that got added to the work queue following
    // function calls.
    funcId = theContext.takeNextId();
    funcName = "src." + funcName;

    // Create wrapper for the entry function.
    if (!emitEntryFunctionWrapper(decl, funcId))
      return;
  } else {
    // Non-entry functions are added to the work queue following function
    // calls. We have already assigned <result-id>s for them when translating
    // their call sites. Query it here.
    funcId = declIdMapper.getDeclResultId(decl).resultId;
  }

  if (!needsLegalization &&
      TypeTranslator::isOpaqueStructType(decl->getReturnType()))
    needsLegalization = true;

  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());

  // Construct the function signature.
  llvm::SmallVector<uint32_t, 4> paramTypes;

  bool isNonStaticMemberFn = false;
  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
    isNonStaticMemberFn = !memberFn->isStatic();

    if (isNonStaticMemberFn) {
      // For a non-static member function, the first parameter should be the
      // object on which we are invoking this method.
      const uint32_t valueType = typeTranslator.translateType(
          memberFn->getThisType(astContext)->getPointeeType());
      const uint32_t ptrType =
          theBuilder.getPointerType(valueType, spv::StorageClass::Function);
      paramTypes.push_back(ptrType);
    }

    // Prefix the function name with the struct name.
    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
      funcName = st->getName().str() + "." + funcName;
  }

  for (const auto *param : decl->params()) {
    const uint32_t valueType = typeTranslator.translateType(param->getType());
    const uint32_t ptrType =
        theBuilder.getPointerType(valueType, spv::StorageClass::Function);
    paramTypes.push_back(ptrType);

    if (!needsLegalization &&
        TypeTranslator::isOpaqueStructType(param->getType()))
      needsLegalization = true;
  }

  const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
  theBuilder.beginFunction(funcType, retType, funcName, funcId);

  if (isNonStaticMemberFn) {
    // Remember the parameter for the this object so later we can handle
    // CXXThisExpr correctly.
    curThis = theBuilder.addFnParam(paramTypes[0], "param.this");
  }

  // Create all parameters.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const ParmVarDecl *paramDecl = decl->getParamDecl(i);
    (void)declIdMapper.createFnParam(paramTypes[i + isNonStaticMemberFn],
                                     paramDecl);
  }

  if (decl->hasBody()) {
    // The entry basic block.
    const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
    theBuilder.setInsertPoint(entryLabel);

    // Process all statements in the body.
    doStmt(decl->getBody());

    // We have processed all Stmts in this function and are now in the last
    // basic block. Make sure we have OpReturn if missing.
    if (!theBuilder.isCurrentBasicBlockTerminated()) {
      theBuilder.createReturn();
    }
  }

  theBuilder.endFunction();

  curFunction = nullptr;
}

void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
  uint32_t varId = 0;

  // The contents in externally visible variables can be updated via the
  // pipeline. They should be handled differently from file and function scope
  // variables.
  // File scope variables (static "global" and "local" variables) belong to
  // the Private storage class, while function scope variables (normal "local"
  // variables) belong to the Function storage class.
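  // For example (illustrative HLSL):
  //   static float gAccum;           // file scope -> Private storage class
  //   void foo() { float tmp; }      // function scope -> Function storage class
  //   Texture2D gTex;                // externally visible -> external resource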
  if (!decl->isExternallyVisible() || decl->isStaticDataMember()) {
    // Note: cannot move varType outside of this scope because it generates
    // SPIR-V types without decorations, while externally visible variables
    // should have SPIR-V types with decorations.
    const uint32_t varType = typeTranslator.translateType(decl->getType());

    // We already know the variable is not externally visible here. If it does
    // not have local storage, it should be a file scope variable.
    const bool isFileScopeVar = !decl->hasLocalStorage();

    // Handle initializer. SPIR-V requires that "initializer must be an <id>
    // from a constant instruction or a global (module scope) OpVariable
    // instruction."
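    // For example, "static const int kCount = 4;" folds to an OpConstant and
    // can serve as the OpVariable initializer directly, while
    // "static int n = bar();" cannot and needs an explicit OpStore
    // (illustrative cases).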
    llvm::Optional<uint32_t> constInit;
    if (decl->hasInit()) {
      if (const uint32_t id = tryToEvaluateAsConst(decl->getInit()))
        constInit = llvm::Optional<uint32_t>(id);
    } else if (isFileScopeVar) {
      // For static variables, if no initializers are provided, we should
      // initialize them to zero values.
      constInit = llvm::Optional<uint32_t>(theBuilder.getConstantNull(varType));
    }

    if (isFileScopeVar)
      varId = declIdMapper.createFileVar(varType, decl, constInit);
    else
      varId = declIdMapper.createFnVar(varType, decl, constInit);

    // If we cannot evaluate the initializer as a constant expression, we'll
    // need to use OpStore to write the initializer to the variable.
    // Also we should only evaluate the initializer once for a static variable.
    if (decl->hasInit() && !constInit.hasValue()) {
      if (isFileScopeVar) {
        if (decl->isStaticLocal()) {
          initOnce(decl->getName(), varId, decl->getInit());
        } else {
          // Defer initializing these global variables to the beginning of the
          // entry function.
          toInitGloalVars.push_back(decl);
        }
      } else {
        storeValue(varId, loadIfGLValue(decl->getInit()), decl->getType());
      }
    }
  } else {
    varId = declIdMapper.createExternVar(decl);
  }

  if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
    theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
  }

  if (!needsLegalization &&
      TypeTranslator::isOpaqueStructType(decl->getType()))
    needsLegalization = true;
}

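/// For example (illustrative HLSL), the loop attributes map as follows per
/// the switch below:
///   [unroll]  for (int i = 0; i < 4; ++i) { ... }  -> LoopControlMask::Unroll
///   [loop]    while (cond) { ... }                 -> LoopControlMask::DontUnroll
///   [fastopt] do { ... } while (cond);             -> LoopControlMask::DontUnroll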
spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Stmt *stmt,
                                                          const Attr &attr) {
  switch (attr.getKind()) {
  case attr::HLSLLoop:
  case attr::HLSLFastOpt:
    return spv::LoopControlMask::DontUnroll;
  case attr::HLSLUnroll:
    return spv::LoopControlMask::Unroll;
  case attr::HLSLAllowUAVCondition:
    emitWarning("unsupported allow_uav_condition attribute ignored",
                stmt->getLocStart());
    break;
  default:
    llvm_unreachable("found unknown loop attribute");
  }
  return spv::LoopControlMask::MaskNone;
}

void SPIRVEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  theBuilder.createKill();
  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them.
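  // For example, in (illustrative HLSL):
  //   if (alpha < 0.5) { discard; color = 0; }
  // the assignment following "discard" lands in the new, unreachable block.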
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt,
                            llvm::ArrayRef<const Attr *> attrs) {
  // do-while loops are composed of:
  //
  //   do {
  //     <body>
  //   } while (<check>);
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though do-while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> can be performed in the <continue> basic block.
  // The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //          +----------+
  //          |  header  | <-------------------------------+
  //          +----------+                                 |
  //               |                                       | (true)
  //               v                                       |
  //           +------+        +--------------------+      |
  //           | body | -----> | continue (<check>) |------+
  //           +------+        +--------------------+
  //                                     |
  //                                     | (false)
  //           +-------+                 |
  //           | merge | <---------------+
  //           +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.

  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(theDoStmt, *attrs.front());

  // Create basic blocks.
  const uint32_t headerBB = theBuilder.createBasicBlock("do_while.header");
  const uint32_t bodyBB = theBuilder.createBasicBlock("do_while.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("do_while.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("do_while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Branch from the current insert point to the header block.
  theBuilder.createBranch(headerBB);
  theBuilder.addSuccessor(headerBB);

  // Process the <header> block.
  // The header block must always branch to the body.
  theBuilder.setInsertPoint(headerBB);
  theBuilder.createBranch(bodyBB, mergeBB, continueBB, loopControl);
  theBuilder.addSuccessor(bodyBB);
  // The current basic block has OpLoopMerge instruction. We need to set its
  // continue and merge target.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block.
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = theDoStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block. The check for whether the loop should
  // continue lies in the continue block.
  // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur
  // in a continue block of a loop, there should be no OpSelectionMerge. Only
  // an OpBranchConditional must be specified.
  theBuilder.setInsertPoint(continueBB);
  uint32_t condition = 0;
  if (const Expr *check = theDoStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, headerBB, mergeBB);
  theBuilder.addSuccessor(headerBB);
  theBuilder.addSuccessor(mergeBB);

  // Set insertion point to the <merge> block for subsequent statements.
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}

void SPIRVEmitter::doContinueStmt(const ContinueStmt *continueStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  const uint32_t continueTargetBB = continueStack.top();
  theBuilder.createBranch(continueTargetBB);
  theBuilder.addSuccessor(continueTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. For example: StmtB and StmtC below are put inside a new
  // basic block which is unreachable.
  //
  //   while (true) {
  //     StmtA;
  //     continue;
  //     StmtB;
  //     StmtC;
  //   }
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

  714. void SPIRVEmitter::doWhileStmt(const WhileStmt *whileStmt,
  715. llvm::ArrayRef<const Attr *> attrs) {
  716. // While loops are composed of:
  717. // while (<check>) { <body> }
  718. //
  719. // SPIR-V requires loops to have a merge basic block as well as a continue
  720. // basic block. Even though while loops do not have an explicit continue
  721. // block as in for-loops, we still do need to create a continue block.
  722. //
  723. // Since SPIR-V requires structured control flow, we need two more basic
  724. // blocks, <header> and <merge>. <header> is the block before control flow
  725. // diverges, and <merge> is the block where control flow subsequently
  726. // converges. The <check> block can take the responsibility of the <header>
  727. // block. The final CFG should normally be like the following. Exceptions
  728. // will occur with non-local exits like loop breaks or early returns.
  729. //
  730. // +----------+
  731. // | header | <------------------+
  732. // | (check) | |
  733. // +----------+ |
  734. // | |
  735. // +-------+-------+ |
  736. // | false | true |
  737. // | v |
  738. // | +------+ +------------------+
  739. // | | body | --> | continue (no-op) |
  740. // v +------+ +------------------+
  741. // +-------+
  742. // | merge |
  743. // +-------+
  744. //
  745. // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  746. const spv::LoopControlMask loopControl =
  747. attrs.empty() ? spv::LoopControlMask::MaskNone
  748. : translateLoopAttribute(whileStmt, *attrs.front());
  749. // Create basic blocks
  750. const uint32_t checkBB = theBuilder.createBasicBlock("while.check");
  751. const uint32_t bodyBB = theBuilder.createBasicBlock("while.body");
  752. const uint32_t continueBB = theBuilder.createBasicBlock("while.continue");
  753. const uint32_t mergeBB = theBuilder.createBasicBlock("while.merge");
  754. // Make sure any continue statements branch to the continue block, and any
  755. // break statements branch to the merge block.
  756. continueStack.push(continueBB);
  757. breakStack.push(mergeBB);
  758. // Process the <check> block
  759. theBuilder.createBranch(checkBB);
  760. theBuilder.addSuccessor(checkBB);
  761. theBuilder.setInsertPoint(checkBB);
  762. // If we have:
  763. // while (int a = foo()) {...}
  764. // we should evaluate 'a' by calling 'foo()' every single time the check has
  765. // to occur.
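  // Since that DeclStmt is emitted inside the <check> block, which every back
  // edge re-enters, the call to 'foo()' (and the store to 'a') is indeed
  // re-executed on each iteration.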
  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
    doStmt(condVarDecl);

  uint32_t condition = 0;
  if (const Expr *check = whileStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, bodyBB,
                                     /*false branch*/ mergeBB,
                                     /*merge*/ mergeBB, continueBB,
                                     spv::SelectionControlMask::MaskNone,
                                     loopControl);
  theBuilder.addSuccessor(bodyBB);
  theBuilder.addSuccessor(mergeBB);
  // The current basic block has the OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block.
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = whileStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block. While loops do not have an explicit
  // continue block, so ours is a no-op that just branches back to the
  // <check> block.
  theBuilder.setInsertPoint(continueBB);
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);

  // Set the insertion point to the <merge> block for subsequent statements.
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue and merge blocks.
  continueStack.pop();
  breakStack.pop();
}

void SPIRVEmitter::doForStmt(const ForStmt *forStmt,
                             llvm::ArrayRef<const Attr *> attrs) {
  // for loops are composed of:
  //   for (<init>; <check>; <continue>) <body>
  //
  // To translate a for loop, we'll need to emit all <init> statements
  // in the current basic block, and then have separate basic blocks for
  // <check>, <continue>, and <body>. Besides, since SPIR-V requires
  // structured control flow, we need two more basic blocks, <header>
  // and <merge>. <header> is the block before control flow diverges,
  // while <merge> is the block where control flow subsequently converges.
  // The <check> block can take the responsibility of the <header> block.
  // The final CFG should normally be like the following. Exceptions will
  // occur with non-local exits like loop breaks or early returns.
  //
  //    +--------+
  //    |  init  |
  //    +--------+
  //         |
  //         v
  //    +----------+
  //    |  header  | <---------------------+
  //    | (check)  |                       |
  //    +----------+                       |
  //         |                             |
  //   +-----+-----+                       |
  //   | false     | true                  |
  //   |           v                       |
  //   |       +------+     +----------+   |
  //   |       | body | --> | continue | --+
  //   v       +------+     +----------+
  // +-------+
  // | merge |
  // +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(forStmt, *attrs.front());

  // Create basic blocks
  const uint32_t checkBB = theBuilder.createBasicBlock("for.check");
  const uint32_t bodyBB = theBuilder.createBasicBlock("for.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("for.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("for.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <init> block
  if (const Stmt *initStmt = forStmt->getInit()) {
    doStmt(initStmt);
  }
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);

  // Process the <check> block
  theBuilder.setInsertPoint(checkBB);
  uint32_t condition;
  if (const Expr *check = forStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, bodyBB,
                                     /*false branch*/ mergeBB,
                                     /*merge*/ mergeBB, continueBB,
                                     spv::SelectionControlMask::MaskNone,
                                     loopControl);
  theBuilder.addSuccessor(bodyBB);
  theBuilder.addSuccessor(mergeBB);
  // The current basic block has the OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block.
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = forStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block
  theBuilder.setInsertPoint(continueBB);
  if (const Expr *cont = forStmt->getInc()) {
    doExpr(cont);
  }
  theBuilder.createBranch(checkBB); // <continue> should jump back to header
  theBuilder.addSuccessor(checkBB);

  // Set the insertion point to the <merge> block for subsequent statements.
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}

void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
  // if statements are composed of:
  //   if (<check>) { <then> } else { <else> }
  //
  // To translate if statements, we'll need to emit the <check> expressions
  // in the current basic block, and then create separate basic blocks for
  // <then> and <else>. Additionally, we'll need a <merge> block as per
  // SPIR-V's structured control flow requirements. Depending on whether the
  // else branch exists, the final CFG should normally look like one of the
  // following. Exceptions will occur with non-local exits like loop breaks
  // or early returns.
  //
  //             +-------+                         +-------+
  //             | check |                         | check |
  //             +-------+                         +-------+
  //                 |                                 |
  //         +-------+-------+                   +-----+-----+
  //         | true          | false             | true      | false
  //         v               v          or       v           |
  //     +------+         +------+            +------+       |
  //     | then |         | else |            | then |       |
  //     +------+         +------+            +------+       |
  //         |               |                    |          v
  //         |   +-------+   |                    |    +-------+
  //         +-> | merge | <-+                    +--> | merge |
  //             +-------+                             +-------+

  { // Try to see if we can const-eval the condition
    bool condition = false;
    if (ifStmt->getCond()->EvaluateAsBooleanCondition(condition, astContext)) {
      if (condition) {
        doStmt(ifStmt->getThen());
      } else if (ifStmt->getElse()) {
        doStmt(ifStmt->getElse());
      }
      return;
    }
  }
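  // E.g., `if (true) { A; } else { B; }` emits only the code for A and no
  // branch instructions or extra basic blocks at all.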
  if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt())
    doDeclStmt(declStmt);

  // First emit the instruction for evaluating the condition.
  const uint32_t condition = doExpr(ifStmt->getCond());

  // Then we need to emit the instruction for the conditional branch.
  // We'll need the <label-id>s for the then/else/merge blocks to do so.
  const bool hasElse = ifStmt->getElse() != nullptr;
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  const uint32_t elseBB =
      hasElse ? theBuilder.createBasicBlock("if.false") : mergeBB;

  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(elseBB);
  // The current basic block has the OpSelectionMerge instruction. We need
  // to record its merge target.
  theBuilder.setMergeTarget(mergeBB);

  // Handle the then branch
  theBuilder.setInsertPoint(thenBB);
  doStmt(ifStmt->getThen());
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);

  // Handle the else branch (if it exists)
  if (hasElse) {
    theBuilder.setInsertPoint(elseBB);
    doStmt(ifStmt->getElse());
    if (!theBuilder.isCurrentBasicBlockTerminated())
      theBuilder.createBranch(mergeBB);
    theBuilder.addSuccessor(mergeBB);
  }

  // From now on, we'll emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
}

void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
  if (const auto *retVal = stmt->getRetValue()) {
    const auto retInfo = doExpr(retVal);
    const auto retType = retVal->getType();
    if (retInfo.storageClass != spv::StorageClass::Function &&
        retType->isStructureType()) {
      // We are returning some value from a non-Function storage class. We
      // need to create a temporary variable to "convert" the value to the
      // Function storage class and then return that.
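      // E.g. (illustrative SPIR-V, ids made up), returning a struct S that
      // lives in a non-Function storage class:
      //   %temp_var_ret = OpVariable %_ptr_Function_S Function
      //   ...storeValue() copies the value member by member...
      //   %val = OpLoad %S %temp_var_ret
      //          OpReturnValue %val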
      const uint32_t valType = typeTranslator.translateType(retType);
      const uint32_t tempVar = theBuilder.addFnVar(valType, "temp.var.ret");
      storeValue(tempVar, retInfo, retType);
      theBuilder.createReturnValue(theBuilder.createLoad(valType, tempVar));
    } else {
      theBuilder.createReturnValue(retInfo);
    }
  } else {
    theBuilder.createReturn();
  }

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require the creation of a new basic block to hold any statements
  // that may follow them. In this case, the newly created basic block will
  // contain any statements that may come after an early return.
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doBreakStmt(const BreakStmt *breakStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  uint32_t breakTargetBB = breakStack.top();
  theBuilder.addSuccessor(breakTargetBB);
  theBuilder.createBranch(breakTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require the creation of a new basic block to hold any statements
  // that may follow them. For example: StmtB and StmtC below are put inside a
  // new basic block which is unreachable.
  //
  // while (true) {
  //   StmtA;
  //   break;
  //   StmtB;
  //   StmtC;
  // }
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
                                llvm::ArrayRef<const Attr *> attrs) {
  // Switch statements are composed of:
  //   switch (<condition variable>) {
  //     <CaseStmt>
  //     <CaseStmt>
  //     <CaseStmt>
  //     <DefaultStmt> (optional)
  //   }
  //
  //              +-------+
  //              | check |
  //              +-------+
  //                  |
  //       +------------------+-------------------+----------------+
  //       | 1                | 2                 | 3              | (others)
  //       v                  v                   v                v
  //   +-------+      +-------------+         +-------+      +------------+
  //   | case1 |      | case2       |         | case3 | ...  | default    |
  //   |       |      |(fallthrough)|-------->|       |      | (optional) |
  //   +-------+      +-------------+         +-------+      +------------+
  //       |                                      |                 |
  //       |          +-------+                   |                 |
  //       |          |       | <-----------------+                 |
  //       +--------> | merge |                                     |
  //                  |       | <-----------------------------------+
  //                  +-------+
  //
  // If no attributes are given, or if the "forcecase" attribute was provided,
  // we'll do our best to use OpSwitch if possible.
  // If any of the cases compares to a variable (rather than an integer
  // literal), we cannot use OpSwitch because OpSwitch expects literal
  // numbers as parameters.
  const bool isAttrForceCase =
      !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase;
  const bool canUseSpirvOpSwitch =
      (attrs.empty() || isAttrForceCase) &&
      allSwitchCasesAreIntegerLiterals(switchStmt->getBody());

  if (isAttrForceCase && !canUseSpirvOpSwitch)
    emitWarning("ignored 'forcecase' attribute for the switch statement "
                "since one or more case values are not integer literals",
                switchStmt->getLocStart());

  if (canUseSpirvOpSwitch)
    processSwitchStmtUsingSpirvOpSwitch(switchStmt);
  else
    processSwitchStmtUsingIfStmts(switchStmt);
}

SpirvEvalInfo
SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
  llvm::SmallVector<uint32_t, 4> indices;
  const auto *base = collectArrayStructIndices(expr, &indices);
  auto info = doExpr(base);

  if (!indices.empty()) {
    const uint32_t ptrType = theBuilder.getPointerType(
        typeTranslator.translateType(expr->getType(), info.layoutRule),
        info.storageClass);
    info.resultId = theBuilder.createAccessChain(ptrType, info, indices);
  }

  return info;
}

SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Handle assignment first since we need to evaluate rhs before lhs.
  // For other binary operations, we need to evaluate lhs before rhs.
  if (opcode == BO_Assign) {
    return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
                             false);
  }

  // Try to optimize the floatMxN * float and floatN * float cases
  if (opcode == BO_Mul) {
    if (const SpirvEvalInfo result = tryToGenFloatMatrixScale(expr))
      return result;
    if (const SpirvEvalInfo result = tryToGenFloatVectorScale(expr))
      return result;
  }

  const uint32_t resultType = typeTranslator.translateType(expr->getType());
  return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode, resultType,
                         expr->getSourceRange());
}

SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
  if (const auto *operatorCall = dyn_cast<CXXOperatorCallExpr>(callExpr))
    return doCXXOperatorCallExpr(operatorCall);

  if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr))
    return doCXXMemberCallExpr(memberCall);

  // Intrinsic functions such as 'dot' or 'mul'
  if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) {
    return processIntrinsicCallExpr(callExpr);
  }

  // Normal standalone functions
  return processCall(callExpr);
}

uint32_t SPIRVEmitter::processCall(const CallExpr *callExpr) {
  const FunctionDecl *callee = callExpr->getDirectCallee();

  if (callee) {
    const auto numParams = callee->getNumParams();
    bool isNonStaticMemberCall = false;

    llvm::SmallVector<uint32_t, 4> params; // Temporary variables
    llvm::SmallVector<uint32_t, 4> args;   // Evaluated arguments

    if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
      isNonStaticMemberCall =
          !cast<CXXMethodDecl>(memberCall->getCalleeDecl())->isStatic();
      if (isNonStaticMemberCall) {
        // For non-static member calls, evaluate the object and pass it as the
        // first argument.
        const auto *object = memberCall->getImplicitObjectArgument();
        args.push_back(doExpr(object));
        // We do not need to create a new temporary variable for the 'this'
        // object. Use the evaluated argument.
        params.push_back(args.back());
      }
    }

    // Evaluate parameters
    for (uint32_t i = 0; i < numParams; ++i) {
      const auto *arg = callExpr->getArg(i);
      const auto *param = callee->getParamDecl(i);

      // We need to create variables for holding the values to be used as
      // arguments. The variables themselves are of pointer types.
      const uint32_t varType = typeTranslator.translateType(arg->getType());
      const std::string varName = "param.var." + param->getNameAsString();
      const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);

      params.push_back(tempVarId);
      args.push_back(doExpr(arg));

      if (canActAsOutParmVar(param)) {
        // The current parameter is marked as out/inout. The argument then is
        // essentially passed in by reference. We need to load the value
        // explicitly here since the AST won't inject an LValueToRValue
        // implicit cast for this case.
        const uint32_t value = theBuilder.createLoad(varType, args.back());
        theBuilder.createStore(tempVarId, value);
      } else {
        theBuilder.createStore(tempVarId, args.back());
      }
    }
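    // E.g. (illustrative SPIR-V, ids made up), for `void foo(inout float f)`
    // called as `foo(x)`, the net effect of the copy-in above and the
    // copy-out below is:
    //   %param_var_f = OpVariable %_ptr_Function_float Function
    //   %0 = OpLoad %float %x
    //        OpStore %param_var_f %0            ; copy-in
    //   %1 = OpFunctionCall %void %foo %param_var_f
    //   %2 = OpLoad %float %param_var_f
    //        OpStore %x %2                      ; copy-out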
    // Push the callee into the work queue if it is not there.
    if (!workQueue.count(callee)) {
      workQueue.insert(callee);
    }

    const uint32_t retType = typeTranslator.translateType(callExpr->getType());
    // Get or forward declare the function <result-id>
    const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);

    const uint32_t retVal =
        theBuilder.createFunctionCall(retType, funcId, params);

    // Go through all parameters and write back those marked as out/inout
    for (uint32_t i = 0; i < numParams; ++i) {
      const auto *param = callee->getParamDecl(i);
      if (canActAsOutParmVar(param)) {
        const uint32_t index = i + isNonStaticMemberCall;
        const uint32_t typeId = typeTranslator.translateType(param->getType());
        const uint32_t value = theBuilder.createLoad(typeId, params[index]);
        theBuilder.createStore(args[index], value);
      }
    }

    return retVal;
  }

  emitError("calling non-function unimplemented", callExpr->getExprLoc());
  return 0;
}

SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
  const Expr *subExpr = expr->getSubExpr();
  const QualType toType = expr->getType();

  switch (expr->getCastKind()) {
  case CastKind::CK_LValueToRValue: {
    auto info = doExpr(subExpr);

    // There are cases where the AST includes unnecessary LValueToRValue nodes
    // in the tree. To make sure we emit the correct SPIR-V, we should bypass
    // such casts.
    if (subExpr->IgnoreParenNoopCasts(astContext)->isRValue())
      return info;

    if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr) ||
        isBufferTextureIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr)) ||
        isTextureMipsSampleIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr))) {
      // Reaching here means the vector/matrix/Buffer/RWBuffer/RWTexture
      // element accessing operation is an lvalue. For vector element
      // accessing, if we generated a vector shuffle for it and are trying to
      // use it as an rvalue, we cannot do the load here as normal; we need
      // the upper nodes in the AST tree to handle it properly. For matrix
      // element accessing, the load should have already happened after
      // creating the access chain for each element. For (RW)Buffer/RWTexture
      // element accessing, the load should have already happened using
      // OpImageFetch.
      return info;
    }

    // Using an lvalue as an rvalue means we need to OpLoad the contents from
    // the parameter/variable first.
    info.resultId = theBuilder.createLoad(
        typeTranslator.translateType(expr->getType(), info.layoutRule), info);
    return info;
  }
  case CastKind::CK_NoOp:
    return doExpr(subExpr);
  case CastKind::CK_IntegralCast:
  case CastKind::CK_FloatingToIntegral:
  case CastKind::CK_HLSLCC_IntegralCast:
  case CastKind::CK_HLSLCC_FloatingToIntegral: {
    // Integer literals in the AST are represented using 64-bit APInt
    // themselves and are then implicitly cast into the expected bitwidth.
    // We need special treatment of integer literals here because generating
    // a 64-bit constant and then explicitly casting it in SPIR-V requires the
    // Int64 capability. We should avoid introducing unnecessary capabilities
    // wherever we can.
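    // E.g., for `uint x = 5;` the literal is evaluated directly to a 32-bit
    // constant below (via translateAPInt) instead of emitting a 64-bit
    // constant plus a conversion instruction.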
    llvm::APSInt intValue;
    if (expr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects)) {
      return translateAPInt(intValue, toType);
    }

    return castToInt(doExpr(subExpr), subExpr->getType(), toType,
                     subExpr->getExprLoc());
  }
  case CastKind::CK_FloatingCast:
  case CastKind::CK_IntegralToFloating:
  case CastKind::CK_HLSLCC_FloatingCast:
  case CastKind::CK_HLSLCC_IntegralToFloating: {
    // First try to see if we can do constant folding for floating-point
    // numbers like what we are doing for integers above.
    Expr::EvalResult evalResult;
    if (expr->EvaluateAsRValue(evalResult, astContext) &&
        !evalResult.HasSideEffects) {
      return translateAPFloat(evalResult.Val.getFloat(), toType);
    }

    return castToFloat(doExpr(subExpr), subExpr->getType(), toType,
                       subExpr->getExprLoc());
  }
  case CastKind::CK_IntegralToBoolean:
  case CastKind::CK_FloatingToBoolean:
  case CastKind::CK_HLSLCC_IntegralToBoolean:
  case CastKind::CK_HLSLCC_FloatingToBoolean: {
    // First try to see if we can do constant folding.
    bool boolVal;
    if (!expr->HasSideEffects(astContext) &&
        expr->EvaluateAsBooleanCondition(boolVal, astContext)) {
      return theBuilder.getConstantBool(boolVal);
    }

    return castToBool(doExpr(subExpr), subExpr->getType(), toType);
  }
  case CastKind::CK_HLSLVectorSplat: {
    const size_t size = hlsl::GetHLSLVecSize(expr->getType());
    return createVectorSplat(subExpr, size);
  }
  case CastKind::CK_HLSLVectorTruncationCast: {
    const uint32_t toVecTypeId = typeTranslator.translateType(toType);
    const uint32_t elemTypeId =
        typeTranslator.translateType(hlsl::GetHLSLVecElementType(toType));
    const auto toSize = hlsl::GetHLSLVecSize(toType);

    const uint32_t composite = doExpr(subExpr);
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < toSize; ++i) {
      elements.push_back(
          theBuilder.createCompositeExtract(elemTypeId, composite, {i}));
    }

    if (toSize == 1) {
      return elements.front();
    }

    return theBuilder.createCompositeConstruct(toVecTypeId, elements);
  }
  case CastKind::CK_HLSLVectorToScalarCast: {
    // The underlying value should already be a vector of size 1.
    assert(hlsl::GetHLSLVecSize(subExpr->getType()) == 1);
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLVectorToMatrixCast: {
    // The target type should already be a 1xN matrix type.
    assert(TypeTranslator::is1xNMatrix(toType));
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLMatrixSplat: {
    // From scalar to matrix
    uint32_t rowCount = 0, colCount = 0;
    hlsl::GetHLSLMatRowColCount(toType, rowCount, colCount);

    // Handle degenerate cases first
    if (rowCount == 1 && colCount == 1)
      return doExpr(subExpr);

    if (colCount == 1)
      return createVectorSplat(subExpr, rowCount);

    const auto vecSplat = createVectorSplat(subExpr, colCount);
    if (rowCount == 1)
      return vecSplat;

    const uint32_t matType = typeTranslator.translateType(toType);
    llvm::SmallVector<uint32_t, 4> vectors(size_t(rowCount), vecSplat);

    if (vecSplat.isConst) {
      return SpirvEvalInfo::withConst(
          theBuilder.getConstantComposite(matType, vectors));
    } else {
      return theBuilder.createCompositeConstruct(matType, vectors);
    }
  }
  case CastKind::CK_HLSLMatrixTruncationCast: {
    const QualType srcType = subExpr->getType();
    const uint32_t srcId = doExpr(subExpr);
    const QualType elemType = hlsl::GetHLSLMatElementType(srcType);
    const uint32_t dstTypeId = typeTranslator.translateType(toType);
    llvm::SmallVector<uint32_t, 4> indexes;

    // It is possible that the source matrix is in fact a vector.
    // For example: Truncate float1x3 --> float1x2.
    // The front-end disallows float1x3 --> float2x1.
    {
      uint32_t srcVecSize = 0, dstVecSize = 0;
      if (TypeTranslator::isVectorType(srcType, nullptr, &srcVecSize) &&
          TypeTranslator::isVectorType(toType, nullptr, &dstVecSize)) {
        for (uint32_t i = 0; i < dstVecSize; ++i)
          indexes.push_back(i);
        return theBuilder.createVectorShuffle(dstTypeId, srcId, srcId,
                                              indexes);
      }
    }

    uint32_t srcRows = 0, srcCols = 0, dstRows = 0, dstCols = 0;
    hlsl::GetHLSLMatRowColCount(srcType, srcRows, srcCols);
    hlsl::GetHLSLMatRowColCount(toType, dstRows, dstCols);
    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
    const uint32_t srcRowType = theBuilder.getVecType(elemTypeId, srcCols);

    // Indexes to pass to OpVectorShuffle
    for (uint32_t i = 0; i < dstCols; ++i)
      indexes.push_back(i);

    llvm::SmallVector<uint32_t, 4> extractedVecs;
    for (uint32_t row = 0; row < dstRows; ++row) {
      // Extract a row
      uint32_t rowId =
          theBuilder.createCompositeExtract(srcRowType, srcId, {row});
      // Extract the necessary columns from that row.
      // The front-end ensures dstCols <= srcCols.
      // If dstCols equals srcCols, we can use the whole row directly.
      if (dstCols == 1) {
        rowId = theBuilder.createCompositeExtract(elemTypeId, rowId, {0});
      } else if (dstCols < srcCols) {
        rowId = theBuilder.createVectorShuffle(
            theBuilder.getVecType(elemTypeId, dstCols), rowId, rowId, indexes);
      }
      extractedVecs.push_back(rowId);
    }

    if (extractedVecs.size() == 1)
      return extractedVecs.front();

    return theBuilder.createCompositeConstruct(
        typeTranslator.translateType(toType), extractedVecs);
  }
  case CastKind::CK_HLSLMatrixToScalarCast: {
    // The underlying value should already be a 1x1 matrix.
    assert(TypeTranslator::is1x1Matrix(subExpr->getType()));
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLMatrixToVectorCast: {
    // The underlying value should already be a 1xN or Mx1 matrix.
    assert(TypeTranslator::is1xNMatrix(subExpr->getType()) ||
           TypeTranslator::isMx1Matrix(subExpr->getType()));
    return doExpr(subExpr);
  }
  case CastKind::CK_FunctionToPointerDecay:
    // Just need to return the function id
    return doExpr(subExpr);
  case CastKind::CK_FlatConversion: {
    // Optimization: we can use OpConstantNull for cases where we want to
    // initialize an entire data structure to zeros.
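    // E.g. (illustrative), `S s = (S)0;` for a struct S lowers to a single
    //   %init = OpConstantNull %S
    // instead of a member-by-member zero initialization.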
    llvm::APSInt intValue;
    if (subExpr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects) &&
        intValue.getExtValue() == 0) {
      return theBuilder.getConstantNull(typeTranslator.translateType(toType));
    } else {
      return processFlatConversion(toType, subExpr->getType(), doExpr(subExpr),
                                   expr->getExprLoc());
    }
  }
  default:
    emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc())
        << expr->getCastKindName() << expr->getSourceRange();
    expr->dump();
    return 0;
  }
}

uint32_t SPIRVEmitter::processFlatConversion(const QualType type,
                                             const QualType initType,
                                             const uint32_t initId,
                                             SourceLocation srcLoc) {
  // Try to translate the canonical type first
  const auto canonicalType = type.getCanonicalType();
  if (canonicalType != type)
    return processFlatConversion(canonicalType, initType, initId, srcLoc);

  // Primitive types
  {
    QualType ty = {};
    if (TypeTranslator::isScalarType(type, &ty)) {
      if (const auto *builtinType = ty->getAs<BuiltinType>()) {
        switch (builtinType->getKind()) {
        case BuiltinType::Void: {
          emitError("cannot create a constant of void type", srcLoc);
          return 0;
        }
        case BuiltinType::Bool:
          return castToBool(initId, initType, ty);
        // int, min16int (short), and min12int are all translated to 32-bit
        // signed integers in SPIR-V.
        case BuiltinType::Int:
        case BuiltinType::Short:
        case BuiltinType::Min12Int:
        case BuiltinType::UShort:
        case BuiltinType::UInt:
          return castToInt(initId, initType, ty, srcLoc);
        // float, min16float (half), and min10float are all translated to
        // 32-bit float in SPIR-V.
        case BuiltinType::Float:
        case BuiltinType::Half:
        case BuiltinType::Min10Float:
          return castToFloat(initId, initType, ty, srcLoc);
        default:
          emitError("flat conversion of type %0 unimplemented", srcLoc)
              << builtinType->getTypeClassName();
          return 0;
        }
      }
    }
  }

  // Vector types
  {
    QualType elemType = {};
    uint32_t elemCount = {};
    if (TypeTranslator::isVectorType(type, &elemType, &elemCount)) {
      const uint32_t elemId =
          processFlatConversion(elemType, initType, initId, srcLoc);
      llvm::SmallVector<uint32_t, 4> constituents(size_t(elemCount), elemId);
      return theBuilder.createCompositeConstruct(
          typeTranslator.translateType(type), constituents);
    }
  }

  // Matrix types
  {
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
      if (!elemType->isFloatingType()) {
        emitError("non-floating-point matrix type unimplemented", {});
        return 0;
      }

      // By default HLSL matrices are row major, while SPIR-V matrices are
      // column major. We are mapping what HLSL semantically means as a row
      // into a column here.
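      // E.g., flat-converting a scalar to float2x3 builds two v3float
      // constituents (one per HLSL row) below and composes them into the
      // matrix type.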
      const uint32_t vecType = theBuilder.getVecType(
          typeTranslator.translateType(elemType), colCount);
      const uint32_t elemId =
          processFlatConversion(elemType, initType, initId, srcLoc);
      const llvm::SmallVector<uint32_t, 4> constituents(size_t(colCount),
                                                        elemId);
      const uint32_t colId =
          theBuilder.createCompositeConstruct(vecType, constituents);
      const llvm::SmallVector<uint32_t, 4> rows(size_t(rowCount), colId);
      return theBuilder.createCompositeConstruct(
          typeTranslator.translateType(type), rows);
    }
  }

  // Struct type
  if (const auto *structType = type->getAs<RecordType>()) {
    const auto *decl = structType->getDecl();
    llvm::SmallVector<uint32_t, 4> fields;
    for (const auto *field : decl->fields())
      fields.push_back(
          processFlatConversion(field->getType(), initType, initId, srcLoc));
    return theBuilder.createCompositeConstruct(
        typeTranslator.translateType(type), fields);
  }

  // Array type
  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    const uint32_t elemId = processFlatConversion(arrayType->getElementType(),
                                                  initType, initId, srcLoc);
    llvm::SmallVector<uint32_t, 4> constituents(size_t(size), elemId);
    return theBuilder.createCompositeConstruct(
        typeTranslator.translateType(type), constituents);
  }

  emitError("flat conversion of type %0 unimplemented", {})
      << type->getTypeClassName();
  type->dump();
  return 0;
}

SpirvEvalInfo
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Try to optimize the floatMxN *= float and floatN *= float cases
  if (opcode == BO_MulAssign) {
    if (const SpirvEvalInfo result = tryToGenFloatMatrixScale(expr))
      return result;
    if (const SpirvEvalInfo result = tryToGenFloatVectorScale(expr))
      return result;
  }

  const auto *rhs = expr->getRHS();
  const auto *lhs = expr->getLHS();

  SpirvEvalInfo lhsPtr = 0;
  const uint32_t resultType = typeTranslator.translateType(expr->getType());
  const auto result = processBinaryOp(lhs, rhs, opcode, resultType,
                                      expr->getSourceRange(), &lhsPtr);
  return processAssignment(lhs, result, true, lhsPtr);
}

uint32_t SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
  // According to the HLSL doc, all sides of the ?: expression are always
  // evaluated.
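  // Hence we lower `cond ? a : b` to a single OpSelect rather than emitting
  // control flow, e.g. (illustrative):
  //   %result = OpSelect %type %cond %a %b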
  const uint32_t type = typeTranslator.translateType(expr->getType());
  const uint32_t condition = doExpr(expr->getCond());
  const uint32_t trueBranch = doExpr(expr->getTrueExpr());
  const uint32_t falseBranch = doExpr(expr->getFalseExpr());

  return theBuilder.createSelect(type, condition, trueBranch, falseBranch);
}

uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
    const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectId = loadIfGLValue(object);
  const auto type = object->getType();
  const bool isByteAddressBuffer = TypeTranslator::isByteAddressBuffer(type) ||
                                   TypeTranslator::isRWByteAddressBuffer(type);
  const bool isStructuredBuffer =
      TypeTranslator::isStructuredBuffer(type) ||
      TypeTranslator::isAppendStructuredBuffer(type) ||
      TypeTranslator::isConsumeStructuredBuffer(type);
  assert(isByteAddressBuffer || isStructuredBuffer);

  // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a
  // structure with only one member that is a runtime array. We need to
  // perform OpArrayLength on member 0.
  const auto uintType = theBuilder.getUint32Type();
  uint32_t length =
      theBuilder.createBinaryOp(spv::Op::OpArrayLength, uintType, objectId, 0);

  // For (RW)ByteAddressBuffers, GetDimensions() must return the array length
  // in bytes, but OpArrayLength returns the number of uints in the runtime
  // array. Therefore we must multiply the result by 4.
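  // E.g. (illustrative):
  //   %len   = OpArrayLength %uint %buffer 0
  //   %bytes = OpIMul %uint %len %uint_4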
  if (isByteAddressBuffer) {
    length = theBuilder.createBinaryOp(spv::Op::OpIMul, uintType, length,
                                       theBuilder.getConstantUint32(4u));
  }
  theBuilder.createStore(doExpr(expr->getArg(0)), length);

  if (isStructuredBuffer) {
    // For (RW)StructuredBuffer, the stride of the runtime array (which is the
    // size of the struct) must also be written to the second argument.
    uint32_t size = 0, stride = 0;
    std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize(
        type, LayoutRule::GLSLStd430, /*isRowMajor*/ false, &stride);
    const auto sizeId = theBuilder.getConstantUint32(size);
    theBuilder.createStore(doExpr(expr->getArg(1)), sizeId);
  }

  return 0;
}

uint32_t SPIRVEmitter::processRWByteAddressBufferAtomicMethods(
    hlsl::IntrinsicOp opcode, const CXXMemberCallExpr *expr) {
  // The signatures of RWByteAddressBuffer atomic methods are largely:
  //   void Interlocked*(in UINT dest, in UINT value);
  //   void Interlocked*(in UINT dest, in UINT value, out UINT original_value);
  const auto *object = expr->getImplicitObjectArgument();
  // We do not need to load the object since we are using its pointers.
  const auto objectInfo = doExpr(object);

  const auto uintType = theBuilder.getUint32Type();
  const uint32_t zero = theBuilder.getConstantUint32(0);

  const uint32_t offset = doExpr(expr->getArg(0));
  // Right shift by 2 to convert the byte offset to a uint32_t offset.
  const uint32_t address =
      theBuilder.createBinaryOp(spv::Op::OpShiftRightLogical, uintType, offset,
                                theBuilder.getConstantUint32(2));
  const auto ptrType =
      theBuilder.getPointerType(uintType, objectInfo.storageClass);
  const uint32_t ptr =
      theBuilder.createAccessChain(ptrType, objectInfo, {zero, address});

  const uint32_t scope = theBuilder.getConstantUint32(1); // Device

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    const uint32_t comparator = doExpr(expr->getArg(1));
    const uint32_t originalVal = theBuilder.createAtomicCompareExchange(
        uintType, ptr, scope, zero, zero, doExpr(expr->getArg(2)), comparator);
    if (isCompareExchange)
      theBuilder.createStore(doExpr(expr->getArg(3)), originalVal);
  } else {
    const uint32_t value = doExpr(expr->getArg(1));
    const uint32_t originalVal = theBuilder.createAtomicOp(
        translateAtomicHlslOpcodeToSpirvOpcode(opcode), uintType, ptr, scope,
        zero, value);
    if (expr->getNumArgs() > 2)
      theBuilder.createStore(doExpr(expr->getArg(2)), originalVal);
  }

  return 0;
}

uint32_t
SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
  theBuilder.requireCapability(spv::Capability::ImageQuery);
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectId = loadIfGLValue(object);
  const auto type = object->getType();
  const auto *recType = type->getAs<RecordType>();
  assert(recType);
  const auto typeName = recType->getDecl()->getName();
  const auto numArgs = expr->getNumArgs();
  const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr;

  assert(TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) ||
         TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type));

  // For Texture1D, arguments are either:
  //   a) width
  //   b) MipLevel, width, NumLevels
  // For Texture1DArray, arguments are either:
  //   a) width, elements
  //   b) MipLevel, width, elements, NumLevels
  // For Texture2D, arguments are either:
  //   a) width, height
  //   b) MipLevel, width, height, NumLevels
  // For Texture2DArray, arguments are either:
  //   a) width, height, elements
  //   b) MipLevel, width, height, elements, NumLevels
  // For Texture3D, arguments are either:
  //   a) width, height, depth
  //   b) MipLevel, width, height, depth, NumLevels
  // For Texture2DMS, arguments are: width, height, NumSamples
  // For Texture2DMSArray, arguments are: width, height, elements, NumSamples
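  // E.g., Texture2D::GetDimensions(MipLevel, width, height, NumLevels) has
  // four arguments, but the size query itself only yields (width, height),
  // so querySize is reduced to 2 below.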
  if ((typeName == "Texture1D" && numArgs > 1) ||
      (typeName == "Texture2D" && numArgs > 2) ||
      (typeName == "Texture3D" && numArgs > 3) ||
      (typeName == "Texture1DArray" && numArgs > 2) ||
      (typeName == "Texture2DArray" && numArgs > 3)) {
    mipLevel = expr->getArg(0);
    numLevels = expr->getArg(numArgs - 1);
  }
  if (TypeTranslator::isTextureMS(type)) {
    numSamples = expr->getArg(numArgs - 1);
  }

  uint32_t querySize = numArgs;
  // If the numLevels arg is present, mipLevel must also be present. Neither
  // is part of the size query itself.
  if (numLevels)
    querySize -= 2;
  // Similarly, numSamples is not part of the size query.
  else if (numSamples)
    querySize -= 1;

  const uint32_t uintId = theBuilder.getUint32Type();
  const uint32_t resultTypeId =
      querySize == 1 ? uintId : theBuilder.getVecType(uintId, querySize);

  // Only Texture types use ImageQuerySizeLod.
  // TextureMS, RWTexture, Buffers, and RWBuffers use ImageQuerySize.
  uint32_t lod = 0;
  if (TypeTranslator::isTexture(type) && !numSamples) {
    if (mipLevel) {
      // For Texture types when the mipLevel argument is present.
      lod = doExpr(mipLevel);
    } else {
      // For Texture types when the mipLevel argument is omitted.
      lod = theBuilder.getConstantInt32(0);
    }
  }

  const uint32_t query =
      lod ? theBuilder.createBinaryOp(spv::Op::OpImageQuerySizeLod,
                                      resultTypeId, objectId, lod)
          : theBuilder.createUnaryOp(spv::Op::OpImageQuerySize, resultTypeId,
                                     objectId);

  if (querySize == 1) {
    const uint32_t argIndex = mipLevel ? 1 : 0;
    theBuilder.createStore(doExpr(expr->getArg(argIndex)), query);
  } else {
    for (uint32_t i = 0; i < querySize; ++i) {
      const uint32_t component =
          theBuilder.createCompositeExtract(uintId, query, {i});
      // If the first arg is the mipmap level, we must write the results
      // starting from Arg(i+1), not Arg(i).
      const uint32_t argIndex = mipLevel ? i + 1 : i;
      theBuilder.createStore(doExpr(expr->getArg(argIndex)), component);
    }
  }

  if (numLevels || numSamples) {
    const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples;
    const spv::Op opcode =
        numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples;
    const uint32_t resultType =
        typeTranslator.translateType(numLevelsSamplesArg->getType());
    const uint32_t numLevelsSamplesQuery =
        theBuilder.createUnaryOp(opcode, resultType, objectId);
    theBuilder.createStore(doExpr(numLevelsSamplesArg), numLevelsSamplesQuery);
  }

  return 0;
}

uint32_t
SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr) {
  // Possible signatures are as follows:
  //   Texture1D(Array).CalculateLevelOfDetail(SamplerState S, float x);
  //   Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy);
  //   TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz);
  //   Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz);
  // The return type is always a single float (LOD).
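  // The lowering is roughly (illustrative ids):
  //   %si  = OpSampledImage %sampledImageType %texture %sampler
  //   %lod = OpImageQueryLod %v2float %si %coordinate
  //   %res = OpCompositeExtract %float %lod 0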
  assert(expr->getNumArgs() == 2u);
  theBuilder.requireCapability(spv::Capability::ImageQuery);
  const auto *object = expr->getImplicitObjectArgument();
  const uint32_t objectId = loadIfGLValue(object);
  const uint32_t samplerState = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t sampledImageType = theBuilder.getSampledImageType(
      typeTranslator.translateType(object->getType()));
  const uint32_t sampledImage = theBuilder.createBinaryOp(
      spv::Op::OpSampledImage, sampledImageType, objectId, samplerState);

  // The result type of OpImageQueryLod must be a float2.
  const uint32_t queryResultType =
      theBuilder.getVecType(theBuilder.getFloat32Type(), 2u);
  const uint32_t query = theBuilder.createBinaryOp(
      spv::Op::OpImageQueryLod, queryResultType, sampledImage, coordinate);
  // The first component of the float2 contains the mipmap array layer.
  return theBuilder.createCompositeExtract(theBuilder.getFloat32Type(), query,
                                           {0});
}

uint32_t SPIRVEmitter::processTextureGatherRGBACmpRGBA(
    const CXXMemberCallExpr *expr, const bool isCmp, const uint32_t component) {
  // Parameters for .Gather{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  // * SamplerState s, float2 location, int2 offset
  // * SamplerState s, float2 location, int2 offset0, int2 offset1,
  //   int2 offset2, int2 offset3
  //
  // An additional out uint status parameter can appear in both of the above,
  // which we do not support yet.
  //
  // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  // * SamplerState s, float2 location, int2 offset
  // * SamplerState s, float2 location, int2 offset0, int2 offset1,
  //   int2 offset2, int2 offset3
  //
  // An additional out uint status parameter can appear in both of the above,
  // which we do not support yet.
  //
  // The return type is always a 4-component vector.
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();

  if (numArgs != 3 + isCmp && numArgs != 6 + isCmp) {
    emitError("unsupported '%0' method call with status parameter",
              expr->getExprLoc())
        << callee->getName() << expr->getSourceRange();
    return 0;
  }

  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t compareVal = isCmp ? doExpr(expr->getArg(2)) : 0;

  uint32_t constOffset = 0, varOffset = 0, constOffsets = 0;
  if (numArgs == 3 + isCmp) {
    // One offset parameter
    handleOptionalOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset);
  } else {
    // Four offset parameters
    const auto offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp));
    const auto offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp));
    const auto offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp));
    const auto offset3 = tryToEvaluateAsConst(expr->getArg(5 + isCmp));

    // Make sure we can generate the ConstOffsets image operands in SPIR-V.
    if (!offset0 || !offset1 || !offset2 || !offset3) {
      emitError("all offset parameters to '%0' method call must be constants",
                expr->getExprLoc())
          << callee->getName() << expr->getSourceRange();
      return 0;
    }
    const uint32_t v2i32 = theBuilder.getVecType(theBuilder.getInt32Type(), 2);
    const uint32_t offsetType =
        theBuilder.getArrayType(v2i32, theBuilder.getConstantUint32(4));
    constOffsets = theBuilder.getConstantComposite(
        offsetType, {offset0, offset1, offset2, offset3});
  }

  const auto retType = typeTranslator.translateType(callee->getReturnType());
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return theBuilder.createImageGather(
      retType, imageType, image, sampler, coordinate,
      theBuilder.getConstantInt32(component), compareVal, constOffset,
      varOffset, constOffsets, /*sampleNumber*/ 0);
}

uint32_t SPIRVEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) {
  // Signature:
  //
  // float4 GatherCmp(
  //   in SamplerComparisonState s,
  //   in float2 location,
  //   in float compare_value
  //   [, in int2 offset]
  // );
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();

  if (expr->getNumArgs() > 4) {
    emitError("unsupported '%0' method call with status parameter",
              expr->getExprLoc())
        << callee->getName() << expr->getSourceRange();
    return 0;
  }

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t comparator = doExpr(expr->getArg(2));
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType = typeTranslator.translateType(callee->getReturnType());
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return theBuilder.createImageGather(
      retType, imageType, image, sampler, coordinate,
      /*component*/ 0, comparator, constOffset, varOffset, /*constOffsets*/ 0,
      /*sampleNumber*/ 0);
}

uint32_t SPIRVEmitter::processBufferTextureLoad(const Expr *object,
                                                const uint32_t locationId,
                                                uint32_t constOffset,
                                                uint32_t varOffset,
                                                uint32_t lod) {
  // Loading for Buffer and RWBuffer translates to an OpImageFetch.
  // The result type of an OpImageFetch must be a vec4 of float or int.
  const auto type = object->getType();
  assert(TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type) ||
         TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type));
  const bool doFetch =
      TypeTranslator::isBuffer(type) || TypeTranslator::isTexture(type);

  const uint32_t objectId = loadIfGLValue(object);

  // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod.
  uint32_t sampleNumber = 0;
  if (TypeTranslator::isTextureMS(type)) {
    sampleNumber = lod;
    lod = 0;
  }

  const auto sampledType = hlsl::GetHLSLResourceResultType(type);
  QualType elemType = sampledType;
  uint32_t elemCount = 1;
  uint32_t elemTypeId = 0;
  (void)TypeTranslator::isVectorType(sampledType, &elemType, &elemCount);
  if (elemType->isFloatingType()) {
    elemTypeId = theBuilder.getFloat32Type();
  } else if (elemType->isSignedIntegerType()) {
    elemTypeId = theBuilder.getInt32Type();
  } else if (elemType->isUnsignedIntegerType()) {
    elemTypeId = theBuilder.getUint32Type();
  } else {
    emitError("buffer/texture type unimplemented", object->getExprLoc());
    return 0;
  }

  const uint32_t resultTypeId =
      elemCount == 1 ? elemTypeId
                     : theBuilder.getVecType(elemTypeId, elemCount);

  // OpImageFetch can only fetch a vector of 4 elements. OpImageRead can load a
  // vector of any size.
  const uint32_t fetchTypeId = theBuilder.getVecType(elemTypeId, 4u);
  const uint32_t texel = theBuilder.createImageFetchOrRead(
      doFetch, doFetch ? fetchTypeId : resultTypeId, objectId, locationId, lod,
      constOffset, varOffset, /*constOffsets*/ 0, sampleNumber);

  // OpImageRead can load a vector of any size, so we can return the result of
  // the instruction directly.
  if (!doFetch) {
    return texel;
  }

  // OpImageFetch can only fetch vec4. If the result type is a vec1, vec2, or
  // vec3, some extra processing (extraction) is required.
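  // E.g., a Buffer<float2> load fetches a %v4float texel and then extracts
  // the first two components with an OpVectorShuffle.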
  switch (elemCount) {
  case 1:
    return theBuilder.createCompositeExtract(elemTypeId, texel, {0});
  case 2:
    return theBuilder.createVectorShuffle(resultTypeId, texel, texel, {0, 1});
  case 3:
    return theBuilder.createVectorShuffle(resultTypeId, texel, texel,
                                          {0, 1, 2});
  case 4:
    return texel;
  }

  llvm_unreachable("Element count of a vector must be 1, 2, 3, or 4.");
}

uint32_t SPIRVEmitter::processByteAddressBufferLoadStore(
    const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) {
  uint32_t resultId = 0;
  const auto object = expr->getImplicitObjectArgument();
  const auto type = object->getType();
  const auto objectInfo = doExpr(object);
  assert(numWords >= 1 && numWords <= 4);

  if (doStore) {
    assert(typeTranslator.isRWByteAddressBuffer(type));
    assert(expr->getNumArgs() == 2);
  } else {
    assert(typeTranslator.isRWByteAddressBuffer(type) ||
           typeTranslator.isByteAddressBuffer(type));
    if (expr->getNumArgs() == 2) {
      emitError(
          "(RW)ByteAddressBuffer::Load(in address, out status) unimplemented",
          expr->getExprLoc());
      return 0;
    }
  }

  const Expr *addressExpr = expr->getArg(0);
  const uint32_t byteAddress = doExpr(addressExpr);
  const uint32_t addressTypeId =
      typeTranslator.translateType(addressExpr->getType());

  // Do an OpShiftRightLogical by 2 (divide by 4 to get aligned memory
  // access). The AST always casts the address to an unsigned integer, so
  // shift by the unsigned integer 2.
  const uint32_t constUint2 = theBuilder.getConstantUint32(2);
  const uint32_t address = theBuilder.createBinaryOp(
      spv::Op::OpShiftRightLogical, addressTypeId, byteAddress, constUint2);

  // Perform an access chain into the (RW)ByteAddressBuffer.
  // The first index must be zero (member 0 of the struct is the runtime
  // array). The second index passed to OpAccessChain should be the address.
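  // E.g. (illustrative SPIR-V, ids made up), Load2 at byte address %addr:
  //   %idx0 = OpShiftRightLogical %uint %addr %uint_2
  //   %ptr0 = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 %idx0
  //   %w0   = OpLoad %uint %ptr0
  //   %idx1 = OpIAdd %uint %idx0 %uint_1
  //   %ptr1 = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 %idx1
  //   %w1   = OpLoad %uint %ptr1
  //   %res  = OpCompositeConstruct %v2uint %w0 %w1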
  const uint32_t uintTypeId = theBuilder.getUint32Type();
  const uint32_t ptrType =
      theBuilder.getPointerType(uintTypeId, objectInfo.storageClass);
  const uint32_t constUint0 = theBuilder.getConstantUint32(0);

  if (doStore) {
    const uint32_t valuesId = doExpr(expr->getArg(1));
    uint32_t curStoreAddress = address;
    for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) {
      // Extract a 32-bit word from the input.
      const uint32_t curValue = numWords == 1
                                    ? valuesId
                                    : theBuilder.createCompositeExtract(
                                          uintTypeId, valuesId, {wordCounter});

      // Update the output address if necessary.
      if (wordCounter > 0) {
        const uint32_t offset = theBuilder.getConstantUint32(wordCounter);
        curStoreAddress = theBuilder.createBinaryOp(
            spv::Op::OpIAdd, addressTypeId, address, offset);
      }

      // Store the word to the right address at the output.
      const uint32_t storePtr = theBuilder.createAccessChain(
          ptrType, objectInfo, {constUint0, curStoreAddress});
      theBuilder.createStore(storePtr, curValue);
    }
  } else {
    uint32_t loadPtr = theBuilder.createAccessChain(ptrType, objectInfo,
                                                    {constUint0, address});
    resultId = theBuilder.createLoad(uintTypeId, loadPtr);
    if (numWords > 1) {
      // Load words 2, 3, and 4 where necessary. Use OpCompositeConstruct to
      // return a vector result.
      llvm::SmallVector<uint32_t, 4> values;
      values.push_back(resultId);
      for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) {
        const uint32_t offset = theBuilder.getConstantUint32(wordCounter - 1);
        const uint32_t newAddress = theBuilder.createBinaryOp(
            spv::Op::OpIAdd, addressTypeId, address, offset);
        loadPtr = theBuilder.createAccessChain(ptrType, objectInfo,
                                               {constUint0, newAddress});
        values.push_back(theBuilder.createLoad(uintTypeId, loadPtr));
      }
      const uint32_t resultType =
          theBuilder.getVecType(addressTypeId, numWords);
      resultId = theBuilder.createCompositeConstruct(resultType, values);
    }
  }
  return resultId;
}

SpirvEvalInfo
SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
  if (expr->getNumArgs() == 2) {
    emitError("(RW)StructuredBuffer::Load(int, int) unimplemented",
              expr->getExprLoc());
    return 0;
  }

  const auto *buffer = expr->getImplicitObjectArgument();
  auto info = doExpr(buffer);

  const QualType structType =
      hlsl::GetHLSLResourceResultType(buffer->getType());
  const uint32_t ptrType = theBuilder.getPointerType(
      typeTranslator.translateType(structType, info.layoutRule),
      info.storageClass);
  const uint32_t zero = theBuilder.getConstantInt32(0);
  const uint32_t index = doExpr(expr->getArg(0));

  info.resultId = theBuilder.createAccessChain(ptrType, info, {zero, index});
  return info;
}

uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
                                                bool isInc) {
  const uint32_t i32Type = theBuilder.getInt32Type();
  const uint32_t one = theBuilder.getConstantUint32(1);  // As scope: Device
  const uint32_t zero = theBuilder.getConstantUint32(0); // As memory sema: None
  const uint32_t sOne = theBuilder.getConstantInt32(1);

  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
  const auto *buffer = cast<DeclRefExpr>(object)->getDecl();

  const uint32_t counterVar = declIdMapper.getOrCreateCounterId(buffer);
  const uint32_t counterPtrType = theBuilder.getPointerType(
      theBuilder.getInt32Type(), spv::StorageClass::Uniform);
  const uint32_t counterPtr =
      theBuilder.createAccessChain(counterPtrType, counterVar, {zero});

  uint32_t index = 0;
  if (isInc) {
    index = theBuilder.createAtomicOp(spv::Op::OpAtomicIAdd, i32Type,
                                      counterPtr, one, zero, sOne);
  } else {
    // Note that OpAtomicISub returns the value before the subtraction, so we
    // need to do the subtraction again on OpAtomicISub's return value.
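    // For example (a sketch): if the counter currently holds 5, OpAtomicISub
    // returns 5 and stores 4; the element index we want is 5 - 1 == 4.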
    const auto prev = theBuilder.createAtomicOp(spv::Op::OpAtomicISub, i32Type,
                                                counterPtr, one, zero, sOne);
    index = theBuilder.createBinaryOp(spv::Op::OpISub, i32Type, prev, sOne);
  }

  return index;
}

SpirvEvalInfo
SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
  const bool isAppend = expr->getNumArgs() == 1;

  const uint32_t zero = theBuilder.getConstantUint32(0);

  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
  const auto *buffer = cast<DeclRefExpr>(object)->getDecl();

  uint32_t index = incDecRWACSBufferCounter(expr, isAppend);
  auto bufferInfo = declIdMapper.getDeclResultId(buffer);

  const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());
  const uint32_t bufferElemType =
      typeTranslator.translateType(bufferElemTy, bufferInfo.layoutRule);
  // Get the pointer inside the {Append|Consume}StructuredBuffer
  const uint32_t bufferElemPtrType =
      theBuilder.getPointerType(bufferElemType, bufferInfo.storageClass);
  const uint32_t bufferElemPtr = theBuilder.createAccessChain(
      bufferElemPtrType, bufferInfo.resultId, {zero, index});

  if (isAppend) {
    // Write out the value
    bufferInfo.resultId = bufferElemPtr;
    storeValue(bufferInfo, doExpr(expr->getArg(0)), bufferElemTy);
    return 0;
  } else {
    // For some reason, if the element type is not a structure type, the
    // return value of .Consume() is not labelled as an xvalue, which causes
    // the OpLoad instruction to be missing. Load directly here.
    if (bufferElemTy->isStructureType())
      bufferInfo.resultId = bufferElemPtr;
    else
      bufferInfo.resultId =
          theBuilder.createLoad(bufferElemType, bufferElemPtr);
    return bufferInfo;
  }
}

uint32_t
SPIRVEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) {
  // TODO: handle multiple stream-output objects
  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
  const auto *stream = cast<DeclRefExpr>(object)->getDecl();
  const uint32_t value = doExpr(expr->getArg(0));

  declIdMapper.writeBackOutputStream(stream, value);
  theBuilder.createEmitVertex();

  return 0;
}

uint32_t
SPIRVEmitter::processStreamOutputRestart(const CXXMemberCallExpr *expr) {
  // TODO: handle multiple stream-output objects
  theBuilder.createEndPrimitive();
  return 0;
}

SpirvEvalInfo SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
  const FunctionDecl *callee = expr->getDirectCallee();

  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  if (hlsl::GetIntrinsicOp(callee, opcode, group)) {
    return processIntrinsicMemberCall(expr,
                                      static_cast<hlsl::IntrinsicOp>(opcode));
  }

  return processCall(expr);
}

void SPIRVEmitter::handleOptionalOffsetInMethodCall(
    const CXXMemberCallExpr *expr, uint32_t index, uint32_t *constOffset,
    uint32_t *varOffset) {
  *constOffset = *varOffset = 0;         // Initialize both first
  if (expr->getNumArgs() == index + 1) { // Has offset argument
    if ((*constOffset = tryToEvaluateAsConst(expr->getArg(index))))
      return; // Constant offset
    else
      *varOffset = doExpr(expr->getArg(index));
  }
}

SpirvEvalInfo
SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
                                         hlsl::IntrinsicOp opcode) {
  using namespace hlsl;

  switch (opcode) {
  case IntrinsicOp::MOP_Sample:
    return processTextureSampleGather(expr, /*isSample=*/true);
  case IntrinsicOp::MOP_Gather:
    return processTextureSampleGather(expr, /*isSample=*/false);
  case IntrinsicOp::MOP_SampleBias:
    return processTextureSampleBiasLevel(expr, /*isBias=*/true);
  case IntrinsicOp::MOP_SampleLevel:
    return processTextureSampleBiasLevel(expr, /*isBias=*/false);
  case IntrinsicOp::MOP_SampleGrad:
    return processTextureSampleGrad(expr);
  case IntrinsicOp::MOP_SampleCmp:
    return processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/true);
  case IntrinsicOp::MOP_SampleCmpLevelZero:
    return processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/false);
  case IntrinsicOp::MOP_GatherRed:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 0);
  case IntrinsicOp::MOP_GatherGreen:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 1);
  case IntrinsicOp::MOP_GatherBlue:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 2);
  case IntrinsicOp::MOP_GatherAlpha:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 3);
  case IntrinsicOp::MOP_GatherCmp:
    return processTextureGatherCmp(expr);
  case IntrinsicOp::MOP_GatherCmpRed:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/true, 0);
  case IntrinsicOp::MOP_Load:
    return processBufferTextureLoad(expr);
  case IntrinsicOp::MOP_Load2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false);
  case IntrinsicOp::MOP_Load3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false);
  case IntrinsicOp::MOP_Load4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false);
  case IntrinsicOp::MOP_Store:
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true);
  case IntrinsicOp::MOP_Store2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true);
  case IntrinsicOp::MOP_Store3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true);
  case IntrinsicOp::MOP_Store4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true);
  case IntrinsicOp::MOP_GetDimensions:
    return processGetDimensions(expr);
  case IntrinsicOp::MOP_CalculateLevelOfDetail:
    return processTextureLevelOfDetail(expr);
  case IntrinsicOp::MOP_IncrementCounter:
    return theBuilder.createUnaryOp(
        spv::Op::OpBitcast, theBuilder.getUint32Type(),
        incDecRWACSBufferCounter(expr, /*isInc*/ true));
  case IntrinsicOp::MOP_DecrementCounter:
    return theBuilder.createUnaryOp(
        spv::Op::OpBitcast, theBuilder.getUint32Type(),
        incDecRWACSBufferCounter(expr, /*isInc*/ false));
  case IntrinsicOp::MOP_Append:
    if (hlsl::IsHLSLStreamOutputType(
            expr->getImplicitObjectArgument()->getType()))
      return processStreamOutputAppend(expr);
    else
      return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_Consume:
    return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_RestartStrip:
    return processStreamOutputRestart(expr);
  case IntrinsicOp::MOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedCompareExchange:
  case IntrinsicOp::MOP_InterlockedCompareStore:
    return processRWByteAddressBufferAtomicMethods(opcode, expr);
  }

  emitError("intrinsic '%0' method unimplemented", expr->getExprLoc())
      << expr->getDirectCallee()->getName();
  return 0;
}

uint32_t SPIRVEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
                                                  const bool isSample) {
  // Signatures:
  // DXGI_FORMAT Object.Sample(sampler_state S,
  //                           float Location
  //                           [, int Offset]);
  //
  // <Template Type>4 Object.Gather(sampler_state S,
  //                                float2|3|4 Location
  //                                [, int2 Offset]);
  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  // .Sample()/.Gather() has a third optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 2, &constOffset, &varOffset);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  if (isSample) {
    return theBuilder.createImageSample(
        retType, imageType, image, sampler, coordinate, /*compareVal*/ 0,
        /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
        /*constOffsets*/ 0, /*sampleNumber*/ 0);
  } else {
    return theBuilder.createImageGather(
        retType, imageType, image, sampler, coordinate,
        // .Gather() doc says we return four components of red data.
        theBuilder.getConstantInt32(0), /*compareVal*/ 0, constOffset,
        varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0);
  }
}

uint32_t
SPIRVEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr,
                                            const bool isBias) {
  // Signatures:
  // DXGI_FORMAT Object.SampleBias(sampler_state S,
  //                               float Location,
  //                               float Bias
  //                               [, int Offset]);
  //
  // DXGI_FORMAT Object.SampleLevel(sampler_state S,
  //                                float Location,
  //                                float LOD
  //                                [, int Offset]);
  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  uint32_t lod = 0;
  uint32_t bias = 0;
  if (isBias) {
    bias = doExpr(expr->getArg(2));
  } else {
    lod = doExpr(expr->getArg(2));
  }
  // .SampleBias()/.SampleLevel() has a fourth optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  return theBuilder.createImageSample(
      retType, imageType, image, sampler, coordinate, /*compareVal*/ 0, bias,
      lod, std::make_pair(0, 0), constOffset, varOffset, /*constOffsets*/ 0,
      /*sampleNumber*/ 0);
}

uint32_t SPIRVEmitter::processTextureSampleGrad(const CXXMemberCallExpr *expr) {
  // Signature:
  // DXGI_FORMAT Object.SampleGrad(sampler_state S,
  //                               float Location,
  //                               float DDX,
  //                               float DDY
  //                               [, int Offset]);
  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t ddx = doExpr(expr->getArg(2));
  const uint32_t ddy = doExpr(expr->getArg(3));
  // .SampleGrad() has a fifth optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 4, &constOffset, &varOffset);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  return theBuilder.createImageSample(
      retType, imageType, image, sampler, coordinate, /*compareVal*/ 0,
      /*bias*/ 0, /*lod*/ 0, std::make_pair(ddx, ddy), constOffset, varOffset,
      /*constOffsets*/ 0,
      /*sampleNumber*/ 0);
}

uint32_t
SPIRVEmitter::processTextureSampleCmpCmpLevelZero(const CXXMemberCallExpr *expr,
                                                  const bool isCmp) {
  // .SampleCmp() Signature:
  //
  // float Object.SampleCmp(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue,
  //   [int Offset]
  // );
  //
  // .SampleCmpLevelZero() is identical to .SampleCmp() on mipmap level 0 only.
  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t compareVal = doExpr(expr->getArg(2));
  // .SampleCmp() has a fourth optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 3, &constOffset, &varOffset);
  const uint32_t lod = isCmp ? 0 : theBuilder.getConstantFloat32(0);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return theBuilder.createImageSample(
      retType, imageType, image, sampler, coordinate, compareVal, /*bias*/ 0,
      lod, std::make_pair(0, 0), constOffset, varOffset,
      /*constOffsets*/ 0, /*sampleNumber*/ 0);
}

SpirvEvalInfo
SPIRVEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
  // Signature:
  // ret Object.Load(int Location
  //                 [, int SampleIndex]
  //                 [, int Offset]);
  const auto *object = expr->getImplicitObjectArgument();
  const auto *location = expr->getArg(0);
  const auto objectType = object->getType();

  if (typeTranslator.isRWByteAddressBuffer(objectType) ||
      typeTranslator.isByteAddressBuffer(objectType))
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);

  if (TypeTranslator::isStructuredBuffer(objectType))
    return processStructuredBufferLoad(expr);

  if (TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType) ||
      TypeTranslator::isRWTexture(objectType))
    return processBufferTextureLoad(object, doExpr(location));

  if (TypeTranslator::isTexture(objectType)) {
    // .Load() has a second optional parameter for offset.
    const auto locationId = doExpr(location);
    uint32_t constOffset = 0, varOffset = 0;
    uint32_t coordinate = locationId, lod = 0;

    if (TypeTranslator::isTextureMS(objectType)) {
      // SampleIndex is only available when the Object is of Texture2DMS or
      // Texture2DMSArray types. Under those cases, Offset will be the third
      // parameter (index 2).
      lod = doExpr(expr->getArg(1));
      handleOptionalOffsetInMethodCall(expr, 2, &constOffset, &varOffset);
    } else {
      // For Texture Load() functions, the location parameter is a vector
      // that consists of both the coordinate and the mipmap level (via the
      // last vector element). We need to split it here since the
      // OpImageFetch SPIR-V instruction encodes them as separate arguments.
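      // For example (a sketch): for a Texture2D, Load(int3(x, y, mip)) uses
      // (x, y) as the coordinate and mip as the Lod image operand.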
      splitVecLastElement(location->getType(), locationId, &coordinate, &lod);
      // For textures other than Texture2DMS(Array), offset should be the
      // second parameter (index 1).
      handleOptionalOffsetInMethodCall(expr, 1, &constOffset, &varOffset);
    }

    return processBufferTextureLoad(object, coordinate, constOffset, varOffset,
                                    lod);
  }
  emitError("Load() of the given object type unimplemented",
            object->getExprLoc());
  return 0;
}

uint32_t SPIRVEmitter::processGetDimensions(const CXXMemberCallExpr *expr) {
  const auto objectType = expr->getImplicitObjectArgument()->getType();
  if (TypeTranslator::isTexture(objectType) ||
      TypeTranslator::isRWTexture(objectType) ||
      TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType)) {
    return processBufferTextureGetDimensions(expr);
  } else if (TypeTranslator::isByteAddressBuffer(objectType) ||
             TypeTranslator::isRWByteAddressBuffer(objectType) ||
             TypeTranslator::isStructuredBuffer(objectType) ||
             TypeTranslator::isAppendStructuredBuffer(objectType) ||
             TypeTranslator::isConsumeStructuredBuffer(objectType)) {
    return processByteAddressBufferStructuredBufferGetDimensions(expr);
  } else {
    emitError("GetDimensions() of the given object type unimplemented",
              expr->getExprLoc());
    return 0;
  }
}

SpirvEvalInfo
SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
  { // Handle Buffer/RWBuffer/Texture/RWTexture indexing
    const Expr *baseExpr = nullptr;
    const Expr *indexExpr = nullptr;
    const Expr *lodExpr = nullptr;

    // For Textures, regular indexing (operator[]) uses slice 0.
    if (isBufferTextureIndexing(expr, &baseExpr, &indexExpr)) {
      const uint32_t lod = TypeTranslator::isTexture(baseExpr->getType())
                               ? theBuilder.getConstantUint32(0)
                               : 0;
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ 0, /*varOffset*/ 0, lod);
    }

    // .mips[][] or .sample[][] must use the correct slice.
    if (isTextureMipsSampleIndexing(expr, &baseExpr, &indexExpr, &lodExpr)) {
      const uint32_t lod = doExpr(lodExpr);
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ 0, /*varOffset*/ 0, lod);
    }
  }

  llvm::SmallVector<uint32_t, 4> indices;
  const Expr *baseExpr = collectArrayStructIndices(expr, &indices);

  auto base = doExpr(baseExpr);
  if (indices.empty())
    return base; // For indexing into size-1 vectors and 1xN matrices

  // If we are indexing into an rvalue, to use OpAccessChain, we first need
  // to create a local variable to hold the rvalue.
  //
  // TODO: We can optimize the codegen by emitting OpCompositeExtract if
  // all indices are constant integers.
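  // For example (a sketch): for (mat1 + mat2)[0], the sum is an rvalue, so we
  // store it into temp.var below and run OpAccessChain on that variable.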
  if (!baseExpr->isGLValue()) {
    const uint32_t baseType = typeTranslator.translateType(baseExpr->getType());
    const uint32_t tempVar = theBuilder.addFnVar(baseType, "temp.var");
    theBuilder.createStore(tempVar, base);
    base = tempVar;
  }

  const uint32_t ptrType = theBuilder.getPointerType(
      typeTranslator.translateType(expr->getType(), base.layoutRule),
      base.storageClass);

  base.resultId = theBuilder.createAccessChain(ptrType, base, indices);
  return base;
}

SpirvEvalInfo
SPIRVEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
  const Expr *baseExpr = expr->getBase();
  const auto baseInfo = doExpr(baseExpr);
  const auto accessor = expr->getEncodedElementAccess();
  const uint32_t elemType = typeTranslator.translateType(
      hlsl::GetHLSLMatElementType(baseExpr->getType()));

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);

  // Construct a temporary vector out of all elements accessed:
  // 1. Create access chain for each element using OpAccessChain
  // 2. Load each element using OpLoad
  // 3. Create the vector using OpCompositeConstruct
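  // For example (a sketch): mat._m00_m11 loads mat[0][0] and mat[1][1]
  // separately and then builds a float2 with OpCompositeConstruct.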

  llvm::SmallVector<uint32_t, 4> elements;
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0, elem = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, then we are indexing into a
    // vector. Only one index is needed for such cases.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    if (baseExpr->isGLValue()) {
      for (uint32_t i = 0; i < indices.size(); ++i)
        indices[i] = theBuilder.getConstantInt32(indices[i]);

      const uint32_t ptrType =
          theBuilder.getPointerType(elemType, baseInfo.storageClass);
      if (!indices.empty()) {
        // Load the element via access chain
        elem = theBuilder.createAccessChain(ptrType, baseInfo, indices);
      } else {
        // The matrix is of size 1x1. No need to use access chain; base should
        // be the source pointer.
        elem = baseInfo;
      }
      elem = theBuilder.createLoad(elemType, elem);
    } else { // e.g., (mat1 + mat2)._m11
      elem = theBuilder.createCompositeExtract(elemType, baseInfo, indices);
    }
    elements.push_back(elem);
  }

  if (elements.size() == 1)
    return elements.front();

  const uint32_t vecType = theBuilder.getVecType(elemType, elements.size());
  return theBuilder.createCompositeConstruct(vecType, elements);
}

SpirvEvalInfo
SPIRVEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr) {
  const Expr *baseExpr = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(expr, &baseExpr, &accessor);

  const QualType baseType = baseExpr->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);

  const uint32_t type = typeTranslator.translateType(expr->getType());
  const auto accessorSize = accessor.Count;
  // Depending on the number of elements selected, we emit different
  // instructions.
  // For vectors of size greater than 1, if we are only selecting one element,
  // a typical access chain or composite extraction should be fine. But if we
  // are selecting more than one element, we must resort to vector-specific
  // operations.
  // For size-1 vectors, if we are selecting the single element multiple
  // times, we need composite construct instructions.
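  // For example (a sketch): v.y on an lvalue float4 yields an OpAccessChain;
  // (v + w).y yields an OpCompositeExtract; s.xxx on a size-1 vector yields
  // an OpCompositeConstruct; v.yx yields an OpVectorShuffle.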

  if (accessorSize == 1) {
    if (baseSize == 1) {
      // Selecting one element from a size-1 vector. The underlying vector is
      // already treated as a scalar.
      return doExpr(baseExpr);
    }
    // If the base is an lvalue, we should emit an access chain instruction
    // so that we can load/store the specified element. For an rvalue base,
    // we should use composite extraction. We should check the immediate base
    // instead of the original base here since we can have something like
    // v.xyyz to turn an lvalue v into an rvalue.
    if (expr->getBase()->isGLValue()) { // E.g., v.x;
      const auto baseInfo = doExpr(baseExpr);
      const uint32_t ptrType =
          theBuilder.getPointerType(type, baseInfo.storageClass);
      const uint32_t index = theBuilder.getConstantInt32(accessor.Swz0);
      // We need an lvalue here. Do not try to load.
      return theBuilder.createAccessChain(ptrType, baseInfo, {index});
    } else { // E.g., (v + w).x;
      // The original base vector may not be an rvalue. We need to load it if
      // it is an lvalue since the ImplicitCastExpr (LValueToRValue) will be
      // missing for that case.
      return theBuilder.createCompositeExtract(type, loadIfGLValue(baseExpr),
                                               {accessor.Swz0});
    }
  }

  if (baseSize == 1) {
    // Selecting the single element from a size-1 vector multiple times.
    // Construct the vector.
    llvm::SmallVector<uint32_t, 4> components(
        static_cast<size_t>(accessorSize), loadIfGLValue(baseExpr));
    return theBuilder.createCompositeConstruct(type, components);
  }

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(accessorSize);
  // Whether we are selecting elements in the original order
  bool originalOrder = baseSize == accessorSize;
  for (uint32_t i = 0; i < accessorSize; ++i) {
    accessor.GetPosition(i, &selectors[i]);
    // We can select more elements than the vector provides. This handles
    // that case too.
    originalOrder &= selectors[i] == i;
  }
  if (originalOrder)
    return doExpr(baseExpr);

  const uint32_t baseVal = loadIfGLValue(baseExpr);
  // Use base for both vectors. But we are only selecting values from the
  // first one.
  return theBuilder.createVectorShuffle(type, baseVal, baseVal, selectors);
}

SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
  if (const uint32_t id = tryToEvaluateAsConst(expr))
    return id;

  return InitListHandler(*this).process(expr);
}

SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
  llvm::SmallVector<uint32_t, 4> indices;
  const Expr *base = collectArrayStructIndices(expr, &indices);
  auto info = doExpr(base);

  if (!indices.empty()) {
    const uint32_t ptrType = theBuilder.getPointerType(
        typeTranslator.translateType(expr->getType(), info.layoutRule),
        info.storageClass);
    info.resultId = theBuilder.createAccessChain(ptrType, info, indices);
  }

  return info;
}

SpirvEvalInfo SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
  const auto opcode = expr->getOpcode();
  const auto *subExpr = expr->getSubExpr();
  const auto subType = subExpr->getType();
  auto subValue = doExpr(subExpr);
  const auto subTypeId = typeTranslator.translateType(subType);

  switch (opcode) {
  case UO_PreInc:
  case UO_PreDec:
  case UO_PostInc:
  case UO_PostDec: {
    const bool isPre = opcode == UO_PreInc || opcode == UO_PreDec;
    const bool isInc = opcode == UO_PreInc || opcode == UO_PostInc;

    const spv::Op spvOp = translateOp(isInc ? BO_Add : BO_Sub, subType);
    const uint32_t originValue = theBuilder.createLoad(subTypeId, subValue);
    const uint32_t one = hlsl::IsHLSLMatType(subType)
                             ? getMatElemValueOne(subType)
                             : getValueOne(subType);
    uint32_t incValue = 0;
    if (TypeTranslator::isSpirvAcceptableMatrixType(subType)) {
      // For matrices, we can only increment/decrement each vector of it.
      const auto actOnEachVec = [this, spvOp, one](
          uint32_t /*index*/, uint32_t vecType, uint32_t lhsVec) {
        return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, one);
      };
      incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec);
    } else {
      incValue = theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
    }
    theBuilder.createStore(subValue, incValue);

    // The prefix increment/decrement operator returns an lvalue, while the
    // postfix increment/decrement operator returns an rvalue.
    return isPre ? subValue : originValue;
  }
  case UO_Not:
    return theBuilder.createUnaryOp(spv::Op::OpNot, subTypeId, subValue);
  case UO_LNot:
    // Parsing will do the necessary casting to make sure we are applying the
    // ! operator on boolean values.
    return theBuilder.createUnaryOp(spv::Op::OpLogicalNot, subTypeId, subValue);
  case UO_Plus:
    // No need to do anything for the prefix + operator.
    return subValue;
  case UO_Minus: {
    // SPIR-V has two opcodes for negating values: OpSNegate and OpFNegate.
    const spv::Op spvOp = isFloatOrVecOfFloatType(subType) ? spv::Op::OpFNegate
                                                           : spv::Op::OpSNegate;
    return theBuilder.createUnaryOp(spvOp, subTypeId, subValue);
  }
  default:
    break;
  }

  emitError("unary operator '%0' unimplemented", expr->getExprLoc())
      << expr->getOpcodeStr(opcode);
  expr->dump();
  return 0;
}

spv::Op SPIRVEmitter::translateOp(BinaryOperator::Opcode op, QualType type) {
  const bool isSintType = isSintOrVecMatOfSintType(type);
  const bool isUintType = isUintOrVecMatOfUintType(type);
  const bool isFloatType = isFloatOrVecMatOfFloatType(type);

#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                     \
  \
  case BO_##kind : {                                                          \
    if (isSintType || isUintType) {                                           \
      return spv::Op::Op##intBinOp;                                           \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  }                                                                           \
  break

#define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp)   \
  \
  case BO_##kind : {                                                          \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  }                                                                           \
  break

#define BIN_OP_CASE_SINT_UINT(kind, sintBinOp, uintBinOp)                     \
  \
  case BO_##kind : {                                                          \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
  }                                                                           \
  break

  switch (op) {
  case BO_EQ: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalEqual;
    if (isSintType || isUintType)
      return spv::Op::OpIEqual;
    if (isFloatType)
      return spv::Op::OpFOrdEqual;
  } break;
  case BO_NE: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalNotEqual;
    if (isSintType || isUintType)
      return spv::Op::OpINotEqual;
    if (isFloatType)
      return spv::Op::OpFOrdNotEqual;
  } break;
  // According to HLSL doc, all sides of the && and || expression are always
  // evaluated.
  case BO_LAnd:
    return spv::Op::OpLogicalAnd;
  case BO_LOr:
    return spv::Op::OpLogicalOr;
    BIN_OP_CASE_INT_FLOAT(Add, IAdd, FAdd);
    BIN_OP_CASE_INT_FLOAT(AddAssign, IAdd, FAdd);
    BIN_OP_CASE_INT_FLOAT(Sub, ISub, FSub);
    BIN_OP_CASE_INT_FLOAT(SubAssign, ISub, FSub);
    BIN_OP_CASE_INT_FLOAT(Mul, IMul, FMul);
    BIN_OP_CASE_INT_FLOAT(MulAssign, IMul, FMul);
    BIN_OP_CASE_SINT_UINT_FLOAT(Div, SDiv, UDiv, FDiv);
    BIN_OP_CASE_SINT_UINT_FLOAT(DivAssign, SDiv, UDiv, FDiv);
    // According to the HLSL spec, "the modulus operator returns the remainder
    // of a division." "The % operator is defined only in cases where either
    // both sides are positive or both sides are negative."
    //
    // In SPIR-V, there are two remainder operations: Op*Rem and Op*Mod. With
    // the former, the sign of a non-0 result comes from Operand 1, while
    // with the latter, from Operand 2.
    //
    // For operands with different signs, technically we can map % to either
    // Op*Rem or Op*Mod since it's undefined behavior. But it is more
    // consistent with C (HLSL starts as a C derivative) and Clang frontend
    // const expression evaluation if we map % to Op*Rem.
    //
    // Note there is no OpURem in SPIR-V.
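    // For example: -5 % 3 maps to OpSRem, which gives -2 (sign follows the
    // dividend -5); OpSMod would instead give 1 (sign follows the divisor 3).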
    BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(RemAssign, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(LT, SLessThan, ULessThan, FOrdLessThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(LE, SLessThanEqual, ULessThanEqual,
                                FOrdLessThanEqual);
    BIN_OP_CASE_SINT_UINT_FLOAT(GT, SGreaterThan, UGreaterThan,
                                FOrdGreaterThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(GE, SGreaterThanEqual, UGreaterThanEqual,
                                FOrdGreaterThanEqual);
    BIN_OP_CASE_SINT_UINT(And, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(AndAssign, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(Or, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(OrAssign, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(Xor, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(XorAssign, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(Shl, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(ShlAssign, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(Shr, ShiftRightArithmetic, ShiftRightLogical);
    BIN_OP_CASE_SINT_UINT(ShrAssign, ShiftRightArithmetic, ShiftRightLogical);
  default:
    break;
  }

#undef BIN_OP_CASE_INT_FLOAT
#undef BIN_OP_CASE_SINT_UINT_FLOAT
#undef BIN_OP_CASE_SINT_UINT

  emitError("translating binary operator '%0' unimplemented", {})
      << BinaryOperator::getOpcodeStr(op);
  return spv::Op::OpNop;
}

SpirvEvalInfo SPIRVEmitter::processAssignment(const Expr *lhs,
                                              const SpirvEvalInfo &rhs,
                                              const bool isCompoundAssignment,
                                              SpirvEvalInfo lhsPtr) {
  // Assigning to vector swizzling should be handled differently.
  if (const SpirvEvalInfo result = tryToAssignToVectorElements(lhs, rhs))
    return result;

  // Assigning to matrix swizzling should be handled differently.
  if (const SpirvEvalInfo result = tryToAssignToMatrixElements(lhs, rhs))
    return result;

  // Assigning to a RWBuffer/RWTexture should be handled differently.
  if (const SpirvEvalInfo result = tryToAssignToRWBufferRWTexture(lhs, rhs))
    return result;

  // Normal assignment procedure
  if (!lhsPtr.resultId)
    lhsPtr = doExpr(lhs);

  storeValue(lhsPtr, rhs, lhs->getType());

  // Plain assignment returns an rvalue, while compound assignment returns
  // an lvalue.
  return isCompoundAssignment ? lhsPtr : rhs;
}

void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
                              const SpirvEvalInfo &rhsVal,
                              const QualType valType) {
  // If lhs and rhs have the same memory layout, we should be safe to load
  // from rhs and directly store into lhs, avoiding decomposing rhs.
  // TODO: is this optimization always correct?
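  // For example (a sketch): copying a struct read from a cbuffer (which uses
  // an std140-like layout) into a Function-storage local cannot be a single
  // OpStore; we must extract each field from the rhs and store it through a
  // separate access chain into the lhs, as the loops below do.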
  if (lhsPtr.layoutRule == rhsVal.layoutRule ||
      typeTranslator.isScalarType(valType) ||
      typeTranslator.isVectorType(valType) ||
      typeTranslator.isMxNMatrix(valType)) {
    theBuilder.createStore(lhsPtr, rhsVal);
  } else if (const auto *recordType = valType->getAs<RecordType>()) {
    uint32_t index = 0;
    for (const auto *decl : recordType->getDecl()->decls()) {
      // Ignore implicitly generated struct declarations/constructors/
      // destructors.
      if (decl->isImplicit())
        continue;

      const auto *field = cast<FieldDecl>(decl);
      assert(field);
      const auto subRhsValType =
          typeTranslator.translateType(field->getType(), rhsVal.layoutRule);
      const auto subRhsVal =
          theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
      const auto subLhsPtrType = theBuilder.getPointerType(
          typeTranslator.translateType(field->getType(), lhsPtr.layoutRule),
          lhsPtr.storageClass);
      const auto subLhsPtr = theBuilder.createAccessChain(
          subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});

      storeValue(lhsPtr.substResultId(subLhsPtr),
                 rhsVal.substResultId(subRhsVal), field->getType());
      ++index;
    }
  } else if (const auto *arrayType =
                 astContext.getAsConstantArrayType(valType)) {
    const auto elemType = arrayType->getElementType();
    // TODO: handle extra large array size?
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());

    for (uint32_t i = 0; i < size; ++i) {
      const auto subRhsValType =
          typeTranslator.translateType(elemType, rhsVal.layoutRule);
      const auto subRhsVal =
          theBuilder.createCompositeExtract(subRhsValType, rhsVal, {i});
      const auto subLhsPtrType = theBuilder.getPointerType(
          typeTranslator.translateType(elemType, lhsPtr.layoutRule),
          lhsPtr.storageClass);
      const auto subLhsPtr = theBuilder.createAccessChain(
          subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(i)});

      storeValue(lhsPtr.substResultId(subLhsPtr),
                 rhsVal.substResultId(subRhsVal), elemType);
    }
  } else {
    emitError("storing value of type %0 unimplemented", {}) << valType;
  }
}

SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                                            const BinaryOperatorKind opcode,
                                            const uint32_t resultType,
                                            SourceRange sourceRange,
                                            SpirvEvalInfo *lhsInfo,
                                            const spv::Op mandateGenOpcode) {
  // If the operands are of matrix type, we need to dispatch the operation
  // onto each element vector iff the operands are not degenerate matrices
  // and we don't have a matrix-specific SPIR-V instruction for the operation.
  if (!isSpirvMatrixOp(mandateGenOpcode) &&
      TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
    return processMatrixBinaryOp(lhs, rhs, opcode, sourceRange);
  }

  // The comma operator works differently from other binary operations as
  // there is no SPIR-V instruction for it. For each comma, we must evaluate
  // lhs and rhs respectively, and return the result of rhs.
  if (opcode == BO_Comma) {
    (void)doExpr(lhs);
    return doExpr(rhs);
  }

  const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
                            ? translateOp(opcode, lhs->getType())
                            : mandateGenOpcode;

  SpirvEvalInfo rhsVal = 0, lhsPtr = 0, lhsVal = 0;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = doExpr(rhs);
    lhsVal = lhsPtr = doExpr(lhs);
    // This is a compound assignment. We need to load the lhs value if lhs
    // does not generate a vector shuffle.
    if (!isVectorShuffle(lhs)) {
      const uint32_t lhsTy = typeTranslator.translateType(lhs->getType());
      lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
    }
  } else {
    // Evaluate lhs before rhs
    lhsVal = lhsPtr = doExpr(lhs);
    rhsVal = doExpr(rhs);
  }

  if (lhsInfo)
    *lhsInfo = lhsPtr;

  switch (opcode) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_LT:
  case BO_LE:
  case BO_GT:
  case BO_GE:
  case BO_EQ:
  case BO_NE:
  case BO_And:
  case BO_Or:
  case BO_Xor:
  case BO_Shl:
  case BO_Shr:
  case BO_LAnd:
  case BO_LOr:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign:
  case BO_AndAssign:
  case BO_OrAssign:
  case BO_XorAssign:
  case BO_ShlAssign:
  case BO_ShrAssign: {
    const auto result =
        theBuilder.createBinaryOp(spvOp, resultType, lhsVal, rhsVal);
    return lhsVal.isRelaxedPrecision || rhsVal.isRelaxedPrecision
               ? SpirvEvalInfo::withRelaxedPrecision(result)
               : result;
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
  default:
    break;
  }

  emitError("binary operator '%0' unimplemented", lhs->getExprLoc())
      << BinaryOperator::getOpcodeStr(opcode) << sourceRange;
  return 0;
}

void SPIRVEmitter::initOnce(std::string varName, uint32_t varPtr,
                            const Expr *varInit) {
  const uint32_t boolType = theBuilder.getBoolType();
  varName = "init.done." + varName;

  // Create a file/module visible variable to hold the initialization state.
  const uint32_t initDoneVar =
      theBuilder.addModuleVar(boolType, spv::StorageClass::Private, varName,
                              theBuilder.getConstantBool(false));

  const uint32_t condition = theBuilder.createLoad(boolType, initDoneVar);

  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");

  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);

  theBuilder.setInsertPoint(thenBB);
  // Do initialization and mark done
  theBuilder.createStore(varPtr, doExpr(varInit));
  theBuilder.createStore(initDoneVar, theBuilder.getConstantBool(true));
  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);

  theBuilder.setInsertPoint(mergeBB);
}

bool SPIRVEmitter::isVectorShuffle(const Expr *expr) {
  // TODO: the following check is essentially duplicated from
  // doHLSLVectorElementExpr. Should unify them.
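  // For example: v.xyzw and v.x are not vector shuffles (identity selection
  // and single-element selection, respectively), while v.yx and v.xxy are.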
  if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    const Expr *base = nullptr;
    hlsl::VectorMemberAccessPositions accessor;
    condenseVectorElementExpr(vecElemExpr, &base, &accessor);

    const auto accessorSize = accessor.Count;
    if (accessorSize == 1) {
      // Selecting only one element. OpAccessChain or OpCompositeExtract
      // suffices for such cases.
      return false;
    }

    const auto baseSize = hlsl::GetHLSLVecSize(base->getType());
    if (accessorSize != baseSize)
      return true;

    for (uint32_t i = 0; i < accessorSize; ++i) {
      uint32_t position;
      accessor.GetPosition(i, &position);
      if (position != i)
        return true;
    }

    // Selecting exactly the original vector. No vector shuffle generated.
    return false;
  }

  return false;
}

bool SPIRVEmitter::isTextureMipsSampleIndexing(const CXXOperatorCallExpr *expr,
                                               const Expr **base,
                                               const Expr **location,
                                               const Expr **lod) {
  if (!expr)
    return false;

  // <object>.mips[][] consists of an outer operator[] and an inner operator[]
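  // For example (a sketch): for myTexture.mips[2][coord], the inner
  // operator[] carries the lod (2) and the outer operator[] carries the
  // location (coord).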
  const CXXOperatorCallExpr *outerExpr = expr;
  if (outerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *arg0 = outerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const CXXOperatorCallExpr *innerExpr = dyn_cast<CXXOperatorCallExpr>(arg0);

  // Must have an inner operator[]
  if (!innerExpr ||
      innerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript) {
    return false;
  }

  const Expr *innerArg0 =
      innerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const MemberExpr *memberExpr = dyn_cast<MemberExpr>(innerArg0);
  if (!memberExpr)
    return false;

  // Must be accessing the member named "mips" or "sample"
  const auto &memberName =
      memberExpr->getMemberNameInfo().getName().getAsString();
  if (memberName != "mips" && memberName != "sample")
    return false;

  const Expr *object = memberExpr->getBase();
  const auto objectType = object->getType();
  if (!TypeTranslator::isTexture(objectType))
    return false;

  if (base)
    *base = object;
  if (lod)
    *lod = innerExpr->getArg(1);
  if (location)
    *location = outerExpr->getArg(1);
  return true;
}

bool SPIRVEmitter::isBufferTextureIndexing(const CXXOperatorCallExpr *indexExpr,
                                           const Expr **base,
                                           const Expr **index) {
  if (!indexExpr)
    return false;

  // Must be operator[]
  if (indexExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;
  const Expr *object = indexExpr->getArg(0);
  const auto objectType = object->getType();
  if (TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType) ||
      TypeTranslator::isTexture(objectType) ||
      TypeTranslator::isRWTexture(objectType)) {
    if (base)
      *base = object;
    if (index)
      *index = indexExpr->getArg(1);
    return true;
  }
  return false;
}

void SPIRVEmitter::condenseVectorElementExpr(
    const HLSLVectorElementExpr *expr, const Expr **basePtr,
    hlsl::VectorMemberAccessPositions *flattenedAccessor) {
  llvm::SmallVector<hlsl::VectorMemberAccessPositions, 2> accessors;
  accessors.push_back(expr->getEncodedElementAccess());

  // Recursively descend until we find the true base vector, collecting
  // accessors in reverse order along the way.
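  // For example (a sketch): v.zyx.xy condenses to v.zy, since applying .xy
  // to the intermediate .zyx selects its elements 0 and 1, i.e. z and y of
  // the original vector.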
  *basePtr = expr->getBase();
  while (const auto *vecElemBase = dyn_cast<HLSLVectorElementExpr>(*basePtr)) {
    accessors.push_back(vecElemBase->getEncodedElementAccess());
    *basePtr = vecElemBase->getBase();
  }

  *flattenedAccessor = accessors.back();
  for (int32_t i = accessors.size() - 2; i >= 0; --i) {
    const auto &currentAccessor = accessors[i];

    // Apply the current level of accessor to the flattened accessor of all
    // previous levels of ones.
    hlsl::VectorMemberAccessPositions combinedAccessor;
    for (uint32_t j = 0; j < currentAccessor.Count; ++j) {
      uint32_t currentPosition = 0;
      currentAccessor.GetPosition(j, &currentPosition);
      uint32_t previousPosition = 0;
      flattenedAccessor->GetPosition(currentPosition, &previousPosition);
      combinedAccessor.SetPosition(j, previousPosition);
    }
    combinedAccessor.Count = currentAccessor.Count;
    combinedAccessor.IsValid =
        flattenedAccessor->IsValid && currentAccessor.IsValid;

    *flattenedAccessor = combinedAccessor;
  }
}

SpirvEvalInfo SPIRVEmitter::createVectorSplat(const Expr *scalarExpr,
                                              uint32_t size) {
  bool isConstVal = false;
  uint32_t scalarVal = 0;

  // Try to evaluate the element as a constant first. If successful, then we
  // can generate constant instructions for this vector splat.
  if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
    isConstVal = true;
  } else {
    scalarVal = doExpr(scalarExpr);
  }

  // Just return the scalar value for vector splats of size 1
  if (size == 1)
    return isConstVal ? SpirvEvalInfo::withConst(scalarVal) : scalarVal;

  const uint32_t vecType = theBuilder.getVecType(
      typeTranslator.translateType(scalarExpr->getType()), size);
  llvm::SmallVector<uint32_t, 4> elements(size_t(size), scalarVal);

  if (isConstVal) {
    // TODO: we are saying the constant has Function storage class here.
    // Should find a more meaningful one.
    return SpirvEvalInfo::withConst(
        theBuilder.getConstantComposite(vecType, elements));
  } else {
    return theBuilder.createCompositeConstruct(vecType, elements);
  }
}

void SPIRVEmitter::splitVecLastElement(QualType vecType, uint32_t vec,
                                       uint32_t *residual,
                                       uint32_t *lastElement) {
  assert(hlsl::IsHLSLVecType(vecType));

  const uint32_t count = hlsl::GetHLSLVecSize(vecType);
  assert(count > 1);
  const uint32_t elemTypeId =
      typeTranslator.translateType(hlsl::GetHLSLVecElementType(vecType));

  if (count == 2) {
    *residual = theBuilder.createCompositeExtract(elemTypeId, vec, {0});
  } else {
    llvm::SmallVector<uint32_t, 4> indices;
    for (uint32_t i = 0; i < count - 1; ++i)
      indices.push_back(i);

    const uint32_t typeId = theBuilder.getVecType(elemTypeId, count - 1);
    *residual = theBuilder.createVectorShuffle(typeId, vec, vec, indices);
  }

  *lastElement =
      theBuilder.createCompositeExtract(elemTypeId, vec, {count - 1});
}

SpirvEvalInfo
SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  const SourceRange range = expr->getSourceRange();

  // We can only translate floatN * float into OpVectorTimesScalar.
  // So the result type must be floatN.
  if (!hlsl::IsHLSLVecType(type) ||
      !hlsl::GetHLSLVecElementType(type)->isFloatingType())
    return 0;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();

  // Multiplying a float vector with a float scalar will be represented in
  // the AST via a binary operation with two float vectors as operands; one
  // of the operands is from an implicit cast with kind CK_HLSLVectorSplat.
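  // For example (a sketch): for "float4 v; float s; v * s", the AST contains
  // v * HLSLVectorSplat(s); we look through the splat and emit
  // OpVectorTimesScalar on v and s directly instead of splatting s first.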

  // vector * scalar
  if (hlsl::IsHLSLVecType(lhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const uint32_t vecType = typeTranslator.translateType(expr->getType());
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvEvalInfo lhsPtr = 0;
          const auto result = processBinaryOp(
              lhs, cast->getSubExpr(), expr->getOpcode(), vecType, range,
              &lhsPtr, spv::Op::OpVectorTimesScalar);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 vecType, range, nullptr,
                                 spv::Op::OpVectorTimesScalar);
        }
      }
    }
  }

  // scalar * vector
  if (hlsl::IsHLSLVecType(rhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const uint32_t vecType = typeTranslator.translateType(expr->getType());
        // We need to switch the positions of lhs and rhs here because
        // OpVectorTimesScalar requires the first operand to be a vector and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               vecType, range, nullptr,
                               spv::Op::OpVectorTimesScalar);
      }
    }
  }

  return 0;
}
SpirvEvalInfo
SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  const SourceRange range = expr->getSourceRange();

  // We can only translate floatMxN * float into OpMatrixTimesScalar.
  // So the result type must be floatMxN.
  if (!hlsl::IsHLSLMatType(type) ||
      !hlsl::GetHLSLMatElementType(type)->isFloatingType())
    return 0;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();
  const QualType lhsType = lhs->getType();
  const QualType rhsType = rhs->getType();

  const auto selectOpcode = [](const QualType ty) {
    return TypeTranslator::isMx1Matrix(ty) || TypeTranslator::is1xNMatrix(ty)
               ? spv::Op::OpVectorTimesScalar
               : spv::Op::OpMatrixTimesScalar;
  };

  // Multiplying a float matrix with a float scalar is represented in the
  // AST as a binary operation with two float matrices as operands; one of
  // the operands comes from an implicit cast with kind CK_HLSLMatrixSplat.

  // matrix * scalar
  if (hlsl::IsHLSLMatType(lhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const uint32_t matType = typeTranslator.translateType(expr->getType());
        const spv::Op opcode = selectOpcode(lhsType);
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvEvalInfo lhsPtr = 0;
          const auto result =
              processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                              matType, range, &lhsPtr, opcode);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 matType, range, nullptr, opcode);
        }
      }
    }
  }

  // scalar * matrix
  if (hlsl::IsHLSLMatType(rhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const uint32_t matType = typeTranslator.translateType(expr->getType());
        const spv::Op opcode = selectOpcode(rhsType);
        // We need to switch the positions of lhs and rhs here because
        // OpMatrixTimesScalar requires the first operand to be a matrix and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               matType, range, nullptr, opcode);
      }
    }
  }

  return 0;
}
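
// Note on selectOpcode above: Mx1 and 1xN float matrices are represented as
// plain vectors in this translation, so e.g. "float1x3 m; float s; m * s"
// must use OpVectorTimesScalar rather than OpMatrixTimesScalar (illustrative;
// the exact representation is decided by TypeTranslator).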
SpirvEvalInfo
SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
                                          const SpirvEvalInfo &rhs) {
  // Assigning to a vector swizzling lhs is tricky if we are neither
  // writing to one element nor to all elements in their original order.
  // In such cases, we need to create a new vector swizzling involving
  // both the lhs and rhs vectors and then write the result of this swizzling
  // into the base vector of lhs.
  // For example, for vec4.yz = vec2, we need to do the following:
  //
  //   %vec4Val = OpLoad %v4float %vec4
  //   %vec2Val = OpLoad %v2float %vec2
  //   %shuffle = OpVectorShuffle %v4float %vec4Val %vec2Val 0 4 5 3
  //   OpStore %vec4 %shuffle
  //
  // When doing the vector shuffle, we use the lhs base vector as the first
  // vector and the rhs vector as the second vector. Therefore, all elements
  // in the second vector will be selected into the shuffle result.
  const auto *lhsExpr = dyn_cast<HLSLVectorElementExpr>(lhs);

  if (!lhsExpr)
    return 0;

  if (!isVectorShuffle(lhs)) {
    // No vector shuffle needs to be generated for this assignment.
    // Fall back to the normal handling of assignment.
    return 0;
  }

  const Expr *base = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(lhsExpr, &base, &accessor);

  const QualType baseType = base->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(baseSize);
  // Assume we are selecting all original elements first.
  for (uint32_t i = 0; i < baseSize; ++i) {
    selectors[i] = i;
  }

  // Now fix up the elements that actually get overwritten by the rhs vector.
  // Since we are using the rhs vector as the second vector, their indices
  // should be offset by the size of the lhs base vector.
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t position;
    accessor.GetPosition(i, &position);
    selectors[position] = baseSize + i;
  }

  const uint32_t baseTypeId = typeTranslator.translateType(baseType);
  const uint32_t vec1 = doExpr(base);
  const uint32_t vec1Val = theBuilder.createLoad(baseTypeId, vec1);
  const uint32_t shuffle =
      theBuilder.createVectorShuffle(baseTypeId, vec1Val, rhs, selectors);

  theBuilder.createStore(vec1, shuffle);

  // TODO: this return value is incorrect for compound assignments, for
  // which we should return lvalues. We should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}
SpirvEvalInfo
SPIRVEmitter::tryToAssignToRWBufferRWTexture(const Expr *lhs,
                                             const SpirvEvalInfo &rhs) {
  const Expr *baseExpr = nullptr;
  const Expr *indexExpr = nullptr;
  const auto lhsExpr = dyn_cast<CXXOperatorCallExpr>(lhs);
  if (isBufferTextureIndexing(lhsExpr, &baseExpr, &indexExpr)) {
    const uint32_t locId = doExpr(indexExpr);
    const uint32_t imageId = theBuilder.createLoad(
        typeTranslator.translateType(baseExpr->getType()), doExpr(baseExpr));
    theBuilder.createImageWrite(imageId, locId, rhs);
    return rhs;
  }
  return 0;
}
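
// Illustrative example for tryToAssignToRWBufferRWTexture above (IDs are made
// up): for "rwtex[loc] = val" where rwtex is a RWTexture2D<float4>, the code
// above emits roughly:
//
//   %img = OpLoad %type_2d_image %rwtex
//   OpImageWrite %img %locVal %valVal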
SpirvEvalInfo
SPIRVEmitter::tryToAssignToMatrixElements(const Expr *lhs,
                                          const SpirvEvalInfo &rhs) {
  const auto *lhsExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
  if (!lhsExpr)
    return 0;

  const Expr *baseMat = lhsExpr->getBase();
  const auto &base = doExpr(baseMat);
  const QualType elemType = hlsl::GetHLSLMatElementType(baseMat->getType());
  const uint32_t elemTypeId = typeTranslator.translateType(elemType);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseMat->getType(), rowCount, colCount);

  // For each lhs element written to:
  //   1. Extract the corresponding rhs element using OpCompositeExtract
  //   2. Create an access chain for the lhs element using OpAccessChain
  //   3. Write using OpStore
  const auto accessor = lhsExpr->getEncodedElementAccess();
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix has only one row/column, we are actually indexing into
    // a vector; only one index is needed for such cases.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    for (uint32_t j = 0; j < indices.size(); ++j)
      indices[j] = theBuilder.getConstantInt32(indices[j]);

    // If we are writing to only one element, the rhs should already be a
    // scalar value.
    uint32_t rhsElem = rhs;
    if (accessor.Count > 1)
      rhsElem = theBuilder.createCompositeExtract(elemTypeId, rhs, {i});

    const uint32_t ptrType =
        theBuilder.getPointerType(elemTypeId, base.storageClass);

    // If the lhs is actually a 1x1 matrix, we don't need the access chain:
    // base is already the destination pointer.
    uint32_t lhsElemPtr = base;
    if (!indices.empty()) {
      // Get a pointer to the element via an access chain.
      lhsElemPtr = theBuilder.createAccessChain(ptrType, lhsElemPtr, indices);
    }

    theBuilder.createStore(lhsElemPtr, rhsElem);
  }

  // TODO: this return value is incorrect for compound assignments, for
  // which we should return lvalues. We should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}
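
// Illustrative example for tryToAssignToMatrixElements above (IDs are made
// up): for "mat._m00_m11 = float2(a, b)" on a float2x2, the loop above emits
// roughly:
//
//   %e0 = OpCompositeExtract %float %rhs 0
//   %p0 = OpAccessChain %_ptr_Function_float %mat %int_0 %int_0
//   OpStore %p0 %e0
//   %e1 = OpCompositeExtract %float %rhs 1
//   %p1 = OpAccessChain %_ptr_Function_float %mat %int_1 %int_1
//   OpStore %p1 %e1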
uint32_t SPIRVEmitter::processEachVectorInMatrix(
    const Expr *matrix, const uint32_t matrixVal,
    llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
        actOnEachVector) {
  const auto matType = matrix->getType();
  assert(TypeTranslator::isSpirvAcceptableMatrixType(matType));
  const uint32_t vecType = typeTranslator.getComponentVectorType(matType);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);

  llvm::SmallVector<uint32_t, 4> vectors;
  // Extract each component vector and do the operation on it.
  for (uint32_t i = 0; i < rowCount; ++i) {
    const uint32_t lhsVec =
        theBuilder.createCompositeExtract(vecType, matrixVal, {i});
    vectors.push_back(actOnEachVector(i, vecType, lhsVec));
  }

  // Construct the result matrix.
  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(matType), vectors);
}
SpirvEvalInfo
SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                    const BinaryOperatorKind opcode,
                                    SourceRange range) {
  // TODO: some code is duplicated from processBinaryOp. Try to unify them.
  const auto lhsType = lhs->getType();
  assert(TypeTranslator::isSpirvAcceptableMatrixType(lhsType));
  const spv::Op spvOp = translateOp(opcode, lhsType);

  uint32_t rhsVal, lhsPtr, lhsVal;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = doExpr(rhs);
    lhsPtr = doExpr(lhs);
    const uint32_t lhsTy = typeTranslator.translateType(lhsType);
    lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
  } else {
    // Evaluate lhs before rhs
    lhsVal = lhsPtr = doExpr(lhs);
    rhsVal = doExpr(rhs);
  }

  switch (opcode) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign: {
    const auto actOnEachVec = [this, spvOp, rhsVal](uint32_t index,
                                                    uint32_t vecType,
                                                    uint32_t lhsVec) {
      // For each vector of lhs, we need to load the corresponding vector of
      // rhs and do the operation on them.
      const uint32_t rhsVec =
          theBuilder.createCompositeExtract(vecType, rhsVal, {index});
      return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec);
    };
    return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec);
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
  default:
    break;
  }

  emitError("binary operator '%0' over matrix type unimplemented",
            lhs->getExprLoc())
      << BinaryOperator::getOpcodeStr(opcode) << range;
  return 0;
}
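
// Illustrative example for processMatrixBinaryOp above (IDs are made up):
// SPIR-V arithmetic instructions do not accept matrix operands, so
// "float2x2 a, b; ... a + b ..." is decomposed per row, roughly:
//
//   %a0 = OpCompositeExtract %v2float %aVal 0
//   %b0 = OpCompositeExtract %v2float %bVal 0
//   %s0 = OpFAdd %v2float %a0 %b0
//   (same for row 1, yielding %s1)
//   %result = OpCompositeConstruct %mat2v2float %s0 %s1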
const Expr *SPIRVEmitter::collectArrayStructIndices(
    const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices) {
  if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
    // First check whether this is referring to a static member. If it is,
    // create a DeclRefExpr for it.
    if (auto *varDecl = dyn_cast<VarDecl>(indexing->getMemberDecl()))
      if (varDecl->isStaticDataMember())
        return DeclRefExpr::Create(
            astContext, NestedNameSpecifierLoc(), SourceLocation(), varDecl,
            /*RefersToEnclosingVariableOrCapture=*/false, SourceLocation(),
            varDecl->getType(), VK_LValue);

    const Expr *base = collectArrayStructIndices(
        indexing->getBase()->IgnoreParenNoopCasts(astContext), indices);

    // Append the index of the current level.
    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
    assert(fieldDecl);
    indices->push_back(theBuilder.getConstantInt32(fieldDecl->getFieldIndex()));

    return base;
  }

  if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
    // The base of an ArraySubscriptExpr has a wrapping LValueToRValue
    // implicit cast. We need to ignore it to avoid creating OpLoad.
    const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
    const Expr *base = collectArrayStructIndices(thisBase, indices);
    indices->push_back(doExpr(indexing->getIdx()));
    return base;
  }

  if (const auto *indexing = dyn_cast<CXXOperatorCallExpr>(expr))
    if (indexing->getOperator() == OverloadedOperatorKind::OO_Subscript) {
      const Expr *thisBase =
          indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
      const auto thisBaseType = thisBase->getType();
      const Expr *base = collectArrayStructIndices(thisBase, indices);

      // If the base is a StructuredBuffer, we need to push an additional
      // index 0 here. This is because we created an additional
      // OpTypeRuntimeArray in the structure.
      if (TypeTranslator::isStructuredBuffer(thisBaseType))
        indices->push_back(theBuilder.getConstantInt32(0));

      if ((hlsl::IsHLSLVecType(thisBaseType) &&
           (hlsl::GetHLSLVecSize(thisBaseType) == 1)) ||
          typeTranslator.is1x1Matrix(thisBaseType) ||
          typeTranslator.is1xNMatrix(thisBaseType)) {
        // If this is a size-1 vector or a 1xN matrix, ignore the index.
      } else {
        indices->push_back(doExpr(indexing->getArg(1)));
      }
      return base;
    }

  {
    const Expr *index = nullptr;
    // TODO: the following duplicates the logic in doCXXMemberCallExpr.
    if (const auto *object = isStructuredBufferLoad(expr, &index)) {
      // For object.Load(index), there should be no more indexing into the
      // object.
      indices->push_back(theBuilder.getConstantInt32(0));
      indices->push_back(doExpr(index));
      return object;
    }
  }

  // This is the deepest we can go. No more array or struct indexing.
  return expr;
}
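
// Illustrative example for collectArrayStructIndices above: for an lvalue
// like "s.member[3]", the recursion returns the expression for "s" and fills
// *indices with the constant for member's field index followed by the <id>
// for 3; the caller can then emit a single OpAccessChain from the two.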
uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
                                  QualType toBoolType) {
  if (isSameScalarOrVecType(fromType, toBoolType))
    return fromVal;

  // Converting to bool means comparing with value zero.
  const spv::Op spvOp = translateOp(BO_NE, fromType);
  const uint32_t boolType = typeTranslator.translateType(toBoolType);
  const uint32_t zeroVal = getValueZero(fromType);

  return theBuilder.createBinaryOp(spvOp, boolType, fromVal, zeroVal);
}
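
// Illustrative example for castToBool above: casting a signed int %n to bool
// becomes a comparison against zero,
//
//   %b = OpINotEqual %bool %n %int_0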
uint32_t SPIRVEmitter::castToInt(const uint32_t fromVal, QualType fromType,
                                 QualType toIntType, SourceLocation srcLoc) {
  if (isSameScalarOrVecType(fromType, toIntType))
    return fromVal;

  const uint32_t intType = typeTranslator.translateType(toIntType);
  if (isBoolOrVecOfBoolType(fromType)) {
    const uint32_t one = getValueOne(toIntType);
    const uint32_t zero = getValueZero(toIntType);
    return theBuilder.createSelect(intType, fromVal, one, zero);
  }

  if (isSintOrVecOfSintType(fromType) || isUintOrVecOfUintType(fromType)) {
    // TODO: handle different bitwidths
    return theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, fromVal);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    if (isSintOrVecOfSintType(toIntType)) {
      return theBuilder.createUnaryOp(spv::Op::OpConvertFToS, intType, fromVal);
    } else if (isUintOrVecOfUintType(toIntType)) {
      return theBuilder.createUnaryOp(spv::Op::OpConvertFToU, intType, fromVal);
    } else {
      emitError("casting from floating point to integer unimplemented", srcLoc);
    }
  } else {
    emitError("casting to integer unimplemented", srcLoc);
  }
  return 0;
}
uint32_t SPIRVEmitter::castToFloat(const uint32_t fromVal, QualType fromType,
                                   QualType toFloatType,
                                   SourceLocation srcLoc) {
  if (isSameScalarOrVecType(fromType, toFloatType))
    return fromVal;

  const uint32_t floatType = typeTranslator.translateType(toFloatType);

  if (isBoolOrVecOfBoolType(fromType)) {
    const uint32_t one = getValueOne(toFloatType);
    const uint32_t zero = getValueZero(toFloatType);
    return theBuilder.createSelect(floatType, fromVal, one, zero);
  }

  if (isSintOrVecOfSintType(fromType)) {
    return theBuilder.createUnaryOp(spv::Op::OpConvertSToF, floatType, fromVal);
  }

  if (isUintOrVecOfUintType(fromType)) {
    return theBuilder.createUnaryOp(spv::Op::OpConvertUToF, floatType, fromVal);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    emitError("casting between floating-point types of different bitwidths "
              "unimplemented",
              srcLoc);
    return 0;
  }

  emitError("casting to floating point unimplemented", srcLoc);
  return 0;
}
uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
  const FunctionDecl *callee = callExpr->getDirectCallee();
  assert(hlsl::IsIntrinsicOp(callee) &&
         "processIntrinsicCallExpr was called for a non-intrinsic function.");

  const bool isFloatType = isFloatOrVecMatOfFloatType(callExpr->getType());
  const bool isSintType = isSintOrVecMatOfSintType(callExpr->getType());

  // Figure out which intrinsic function to translate.
  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  hlsl::GetIntrinsicOp(callee, opcode, group);

  GLSLstd450 glslOpcode = GLSLstd450Bad;

#define INTRINSIC_SPIRV_OP_WITH_CAP_CASE(intrinsicOp, spirvOp, doEachVec, cap) \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                \
    theBuilder.requireCapability(cap);                                        \
    return processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,     \
                                          doEachVec);                         \
  } break

#define INTRINSIC_SPIRV_OP_CASE(intrinsicOp, spirvOp, doEachVec)              \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                \
    return processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,     \
                                          doEachVec);                         \
  } break

#define INTRINSIC_OP_CASE(intrinsicOp, glslOp, doEachVec)                     \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                \
    glslOpcode = GLSLstd450::GLSLstd450##glslOp;                              \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);    \
  } break

#define INTRINSIC_OP_CASE_INT_FLOAT(intrinsicOp, glslIntOp, glslFloatOp,      \
                                    doEachVec)                                \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                \
    glslOpcode = isFloatType ? GLSLstd450::GLSLstd450##glslFloatOp            \
                             : GLSLstd450::GLSLstd450##glslIntOp;             \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);    \
  } break

#define INTRINSIC_OP_CASE_SINT_UINT(intrinsicOp, glslSintOp, glslUintOp,      \
                                    doEachVec)                                \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                \
    glslOpcode = isSintType ? GLSLstd450::GLSLstd450##glslSintOp              \
                            : GLSLstd450::GLSLstd450##glslUintOp;             \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);    \
  } break

#define INTRINSIC_OP_CASE_SINT_UINT_FLOAT(intrinsicOp, glslSintOp,            \
                                          glslUintOp, glslFloatOp, doEachVec) \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                \
    glslOpcode = isFloatType                                                  \
                     ? GLSLstd450::GLSLstd450##glslFloatOp                    \
                     : isSintType ? GLSLstd450::GLSLstd450##glslSintOp        \
                                  : GLSLstd450::GLSLstd450##glslUintOp;       \
  return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);      \
  } break
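
// As an illustration of the macros above, INTRINSIC_OP_CASE(sin, Sin, true)
// expands to:
//
//   case hlsl::IntrinsicOp::IOP_sin: {
//     glslOpcode = GLSLstd450::GLSLstd450Sin;
//     return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, true);
//   } break;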
  switch (const auto hlslOpcode = static_cast<hlsl::IntrinsicOp>(opcode)) {
  case hlsl::IntrinsicOp::IOP_InterlockedAdd:
  case hlsl::IntrinsicOp::IOP_InterlockedAnd:
  case hlsl::IntrinsicOp::IOP_InterlockedMax:
  case hlsl::IntrinsicOp::IOP_InterlockedUMax:
  case hlsl::IntrinsicOp::IOP_InterlockedMin:
  case hlsl::IntrinsicOp::IOP_InterlockedUMin:
  case hlsl::IntrinsicOp::IOP_InterlockedOr:
  case hlsl::IntrinsicOp::IOP_InterlockedXor:
  case hlsl::IntrinsicOp::IOP_InterlockedExchange:
  case hlsl::IntrinsicOp::IOP_InterlockedCompareStore:
  case hlsl::IntrinsicOp::IOP_InterlockedCompareExchange:
    return processIntrinsicInterlockedMethod(callExpr, hlslOpcode);
  case hlsl::IntrinsicOp::IOP_tex1D:
  case hlsl::IntrinsicOp::IOP_tex1Dbias:
  case hlsl::IntrinsicOp::IOP_tex1Dgrad:
  case hlsl::IntrinsicOp::IOP_tex1Dlod:
  case hlsl::IntrinsicOp::IOP_tex1Dproj:
  case hlsl::IntrinsicOp::IOP_tex2D:
  case hlsl::IntrinsicOp::IOP_tex2Dbias:
  case hlsl::IntrinsicOp::IOP_tex2Dgrad:
  case hlsl::IntrinsicOp::IOP_tex2Dlod:
  case hlsl::IntrinsicOp::IOP_tex2Dproj:
  case hlsl::IntrinsicOp::IOP_tex3D:
  case hlsl::IntrinsicOp::IOP_tex3Dbias:
  case hlsl::IntrinsicOp::IOP_tex3Dgrad:
  case hlsl::IntrinsicOp::IOP_tex3Dlod:
  case hlsl::IntrinsicOp::IOP_tex3Dproj:
  case hlsl::IntrinsicOp::IOP_texCUBE:
  case hlsl::IntrinsicOp::IOP_texCUBEbias:
  case hlsl::IntrinsicOp::IOP_texCUBEgrad:
  case hlsl::IntrinsicOp::IOP_texCUBElod:
  case hlsl::IntrinsicOp::IOP_texCUBEproj: {
    emitError("deprecated %0 intrinsic function will not be supported",
              callExpr->getExprLoc())
        << callee->getName();
    return 0;
  }
  case hlsl::IntrinsicOp::IOP_dot:
    return processIntrinsicDot(callExpr);
  case hlsl::IntrinsicOp::IOP_mul:
    return processIntrinsicMul(callExpr);
  case hlsl::IntrinsicOp::IOP_all:
    return processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
  case hlsl::IntrinsicOp::IOP_any:
    return processIntrinsicAllOrAny(callExpr, spv::Op::OpAny);
  case hlsl::IntrinsicOp::IOP_asdouble:
  case hlsl::IntrinsicOp::IOP_asfloat:
  case hlsl::IntrinsicOp::IOP_asint:
  case hlsl::IntrinsicOp::IOP_asuint:
    return processIntrinsicAsType(callExpr);
  case hlsl::IntrinsicOp::IOP_clip:
    return processIntrinsicClip(callExpr);
  case hlsl::IntrinsicOp::IOP_clamp:
  case hlsl::IntrinsicOp::IOP_uclamp:
    return processIntrinsicClamp(callExpr);
  case hlsl::IntrinsicOp::IOP_frexp:
    return processIntrinsicFrexp(callExpr);
  case hlsl::IntrinsicOp::IOP_lit:
    return processIntrinsicLit(callExpr);
  case hlsl::IntrinsicOp::IOP_modf:
    return processIntrinsicModf(callExpr);
  case hlsl::IntrinsicOp::IOP_sign: {
    if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
      return processIntrinsicFloatSign(callExpr);
    else
      return processIntrinsicUsingGLSLInst(callExpr,
                                           GLSLstd450::GLSLstd450SSign,
                                           /*actPerRowForMatrices*/ true);
  }
  case hlsl::IntrinsicOp::IOP_isfinite:
    return processIntrinsicIsFinite(callExpr);
  case hlsl::IntrinsicOp::IOP_sincos:
    return processIntrinsicSinCos(callExpr);
  case hlsl::IntrinsicOp::IOP_rcp:
    return processIntrinsicRcp(callExpr);
  case hlsl::IntrinsicOp::IOP_saturate:
    return processIntrinsicSaturate(callExpr);
  case hlsl::IntrinsicOp::IOP_log10:
    return processIntrinsicLog10(callExpr);
  case hlsl::IntrinsicOp::IOP_f16tof32:
    return processIntrinsicF16ToF32(callExpr);
  case hlsl::IntrinsicOp::IOP_f32tof16:
    return processIntrinsicF32ToF16(callExpr);
    INTRINSIC_SPIRV_OP_CASE(transpose, Transpose, false);
    INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_coarse, DPdxCoarse, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_fine, DPdxFine, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_coarse, DPdyCoarse, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_fine, DPdyFine, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
    INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
    INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
    INTRINSIC_SPIRV_OP_CASE(fmod, FMod, true);
    INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
    INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
    INTRINSIC_OP_CASE(round, Round, true);
    INTRINSIC_OP_CASE_INT_FLOAT(abs, SAbs, FAbs, true);
    INTRINSIC_OP_CASE(acos, Acos, true);
    INTRINSIC_OP_CASE(asin, Asin, true);
    INTRINSIC_OP_CASE(atan, Atan, true);
    INTRINSIC_OP_CASE(atan2, Atan2, true);
    INTRINSIC_OP_CASE(ceil, Ceil, true);
    INTRINSIC_OP_CASE(cos, Cos, true);
    INTRINSIC_OP_CASE(cosh, Cosh, true);
    INTRINSIC_OP_CASE(cross, Cross, false);
    INTRINSIC_OP_CASE(degrees, Degrees, true);
    INTRINSIC_OP_CASE(distance, Distance, false);
    INTRINSIC_OP_CASE(determinant, Determinant, false);
    INTRINSIC_OP_CASE(exp, Exp, true);
    INTRINSIC_OP_CASE(exp2, Exp2, true);
    INTRINSIC_OP_CASE_SINT_UINT(firstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE_SINT_UINT(ufirstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE(faceforward, FaceForward, false);
    INTRINSIC_OP_CASE(firstbitlow, FindILsb, false);
    INTRINSIC_OP_CASE(floor, Floor, true);
    INTRINSIC_OP_CASE(fma, Fma, true);
    INTRINSIC_OP_CASE(frac, Fract, true);
    INTRINSIC_OP_CASE(length, Length, false);
    INTRINSIC_OP_CASE(ldexp, Ldexp, true);
    INTRINSIC_OP_CASE(lerp, FMix, true);
    INTRINSIC_OP_CASE(log, Log, true);
    INTRINSIC_OP_CASE(log2, Log2, true);
    INTRINSIC_OP_CASE(mad, Fma, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(max, SMax, UMax, FMax, true);
    INTRINSIC_OP_CASE(umax, UMax, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(min, SMin, UMin, FMin, true);
    INTRINSIC_OP_CASE(umin, UMin, true);
    INTRINSIC_OP_CASE(normalize, Normalize, false);
    INTRINSIC_OP_CASE(pow, Pow, true);
    INTRINSIC_OP_CASE(radians, Radians, true);
    INTRINSIC_OP_CASE(reflect, Reflect, false);
    INTRINSIC_OP_CASE(refract, Refract, false);
    INTRINSIC_OP_CASE(rsqrt, InverseSqrt, true);
    INTRINSIC_OP_CASE(smoothstep, SmoothStep, true);
    INTRINSIC_OP_CASE(step, Step, true);
    INTRINSIC_OP_CASE(sin, Sin, true);
    INTRINSIC_OP_CASE(sinh, Sinh, true);
    INTRINSIC_OP_CASE(tan, Tan, true);
    INTRINSIC_OP_CASE(tanh, Tanh, true);
    INTRINSIC_OP_CASE(sqrt, Sqrt, true);
    INTRINSIC_OP_CASE(trunc, Trunc, true);
  default:
    emitError("intrinsic '%0' function unimplemented", callExpr->getExprLoc())
        << callee->getName();
    return 0;
  }

#undef INTRINSIC_OP_CASE
#undef INTRINSIC_OP_CASE_INT_FLOAT
#undef INTRINSIC_OP_CASE_SINT_UINT
#undef INTRINSIC_OP_CASE_SINT_UINT_FLOAT
#undef INTRINSIC_SPIRV_OP_CASE
#undef INTRINSIC_SPIRV_OP_WITH_CAP_CASE

  return 0;
}
uint32_t
SPIRVEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
                                                hlsl::IntrinsicOp opcode) {
  // The signatures of the intrinsic atomic methods are:
  //   void Interlocked*(in R dest, in T value);
  //   void Interlocked*(in R dest, in T value, out T original_value);

  // Note: ALL Interlocked*() methods are forced to have an unsigned integer
  // 'value'. Meaning, T is forced to be 'unsigned int'. If the provided
  // parameter is not an unsigned integer, the frontend inserts an
  // 'ImplicitCastExpr' to convert it to unsigned integer. OpAtomicIAdd (and
  // other SPIR-V OpAtomic* instructions) require the pointee of 'dest' to
  // be of the same type as T. This results in invalid SPIR-V if 'dest'
  // is a signed integer typed resource such as RWTexture1D<int>. For example,
  // the following OpAtomicIAdd is invalid because the pointee type defined in
  // %1 is a signed integer, while the value passed to atomic add (%3) is an
  // unsigned integer.
  //
  //  %_ptr_Image_int = OpTypePointer Image %int
  //  %1 = OpImageTexelPointer %_ptr_Image_int %RWTexture1D_int %index %uint_0
  //  %2 = OpLoad %int %value
  //  %3 = OpBitcast %uint %2   <-------- Inserted by the frontend
  //  %4 = OpAtomicIAdd %int %1 %uint_1 %uint_0 %3
  //
  // In such cases, we bypass the forced IntegralCast.
  // Moreover, the frontend does not add a cast AST node to cast uint to int
  // where necessary. To ensure SPIR-V validity, we add such casts ourselves.

  const uint32_t zero = theBuilder.getConstantUint32(0);
  const uint32_t scope = theBuilder.getConstantUint32(1); // Device

  const auto *dest = expr->getArg(0);
  const auto baseType = dest->getType();
  const uint32_t baseTypeId = typeTranslator.translateType(baseType);

  const auto doArg = [baseType, this](const CallExpr *callExpr,
                                      uint32_t argIndex) {
    const Expr *valueExpr = callExpr->getArg(argIndex);
    if (const auto *castExpr = dyn_cast<ImplicitCastExpr>(valueExpr))
      if (castExpr->getCastKind() == CK_IntegralCast &&
          castExpr->getSubExpr()->getType() == baseType)
        valueExpr = castExpr->getSubExpr();

    uint32_t argId = doExpr(valueExpr);
    if (valueExpr->getType() != baseType)
      argId = castToInt(argId, valueExpr->getType(), baseType,
                        valueExpr->getExprLoc());
    return argId;
  };

  const auto writeToOutputArg = [&baseType, dest, this](
      uint32_t toWrite, const CallExpr *callExpr, uint32_t outputArgIndex) {
    const auto outputArg = callExpr->getArg(outputArgIndex);
    const auto outputArgType = outputArg->getType();
    if (baseType != outputArgType)
      toWrite = castToInt(toWrite, baseType, outputArgType, dest->getExprLoc());
    theBuilder.createStore(doExpr(outputArg), toWrite);
  };

  // If the argument is indexing into a texture/buffer, we need to create an
  // OpImageTexelPointer instruction.
  uint32_t ptr = 0;
  if (const auto *callExpr = dyn_cast<CXXOperatorCallExpr>(dest)) {
    const Expr *base = nullptr;
    const Expr *index = nullptr;
    if (isBufferTextureIndexing(callExpr, &base, &index)) {
      const auto ptrType =
          theBuilder.getPointerType(baseTypeId, spv::StorageClass::Image);
      const auto baseId = doExpr(base);
      const auto coordId = doExpr(index);
      ptr = theBuilder.createImageTexelPointer(ptrType, baseId, coordId, zero);
    }
  } else {
    ptr = doExpr(dest);
  }

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::IOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::IOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    const uint32_t comparator = doArg(expr, 1);
    const uint32_t valueId = doArg(expr, 2);
    const uint32_t originalVal = theBuilder.createAtomicCompareExchange(
        baseTypeId, ptr, scope, zero, zero, valueId, comparator);
    if (isCompareExchange)
      writeToOutputArg(originalVal, expr, 3);
  } else {
    const uint32_t valueId = doArg(expr, 1);
    // Since these atomic operations write through the provided pointer, the
    // signed vs. unsigned opcode must be decided based on the pointee type
    // of the first argument. However, the frontend decides the opcode based
    // on the second argument (value), so the HLSL opcode it provides may be
    // wrong. The following code makes sure we use the correct SPIR-V opcode.
    spv::Op atomicOp = translateAtomicHlslOpcodeToSpirvOpcode(opcode);
    if (atomicOp == spv::Op::OpAtomicUMax && baseType->isSignedIntegerType())
      atomicOp = spv::Op::OpAtomicSMax;
    if (atomicOp == spv::Op::OpAtomicSMax && baseType->isUnsignedIntegerType())
      atomicOp = spv::Op::OpAtomicUMax;
    if (atomicOp == spv::Op::OpAtomicUMin && baseType->isSignedIntegerType())
      atomicOp = spv::Op::OpAtomicSMin;
    if (atomicOp == spv::Op::OpAtomicSMin && baseType->isUnsignedIntegerType())
      atomicOp = spv::Op::OpAtomicUMin;
    const uint32_t originalVal = theBuilder.createAtomicOp(
        atomicOp, baseTypeId, ptr, scope, zero, valueId);
    if (expr->getNumArgs() > 2)
      writeToOutputArg(originalVal, expr, 2);
  }

  return 0;
}
uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
  // Signature is: ret modf(x, ip)
  // [in]  x: the input floating-point value.
  // [out] ip: the integer portion of x.
  // [out] ret: the fractional portion of x.
  // All of the above must be a scalar, vector, or matrix with the same
  // component types. Component types can be float or int.
  //
  // The ModfStruct SPIR-V instruction returns a struct. The first member is
  // the fractional part and the second member is the integer portion:
  //
  //   ModfStruct {
  //     <scalar or vector of float> frac;
  //     <scalar or vector of float> ip;
  //   }
  //
  // Note that if the input number (x) is not a float (i.e. 'x' is an int),
  // it is automatically converted to float before modf is invoked. Sadly,
  // the 'ip' argument is not treated the same way, so in such cases we have
  // to manually convert the float result into int.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const Expr *ipArg = callExpr->getArg(1);
  const auto argType = arg->getType();
  const auto ipType = ipArg->getType();
  const auto returnType = callExpr->getType();
  const auto returnTypeId = typeTranslator.translateType(returnType);
  const uint32_t argId = doExpr(arg);
  const uint32_t ipId = doExpr(ipArg);

  // TODO: We currently do not support non-float matrices.
  QualType ipElemType = {};
  if (TypeTranslator::isMxNMatrix(ipType, &ipElemType) &&
      !ipElemType->isFloatingType()) {
    emitError("non-floating-point matrix type unimplemented", {});
    return 0;
  }

  // For scalar and vector argument types.
  {
    if (TypeTranslator::isScalarType(argType) ||
        TypeTranslator::isVectorType(argType)) {
      const auto argTypeId = typeTranslator.translateType(argType);
      // The struct members *must* have the same type.
      const auto modfStructTypeId = theBuilder.getStructType(
          {argTypeId, argTypeId}, "ModfStructType", {"frac", "ip"});
      const auto modf =
          theBuilder.createExtInst(modfStructTypeId, glslInstSetId,
                                   GLSLstd450::GLSLstd450ModfStruct, {argId});
      auto ip = theBuilder.createCompositeExtract(argTypeId, modf, {1});
      // This does nothing if the input number (x) and ip are of the same
      // type. Otherwise, it converts ip into int as necessary.
      ip = castToInt(ip, argType, ipType, arg->getExprLoc());
      theBuilder.createStore(ipId, ip);
      return theBuilder.createCompositeExtract(argTypeId, modf, {0});
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    QualType elemType = {};
    if (TypeTranslator::isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
      const auto elemTypeId = typeTranslator.translateType(elemType);
      const auto colTypeId = theBuilder.getVecType(elemTypeId, colCount);
      const auto modfStructTypeId = theBuilder.getStructType(
          {colTypeId, colTypeId}, "ModfStructType", {"frac", "ip"});
      llvm::SmallVector<uint32_t, 4> fracs;
      llvm::SmallVector<uint32_t, 4> ips;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto curRow =
            theBuilder.createCompositeExtract(colTypeId, argId, {i});
        const auto modf = theBuilder.createExtInst(
            modfStructTypeId, glslInstSetId, GLSLstd450::GLSLstd450ModfStruct,
            {curRow});
        ips.push_back(theBuilder.createCompositeExtract(colTypeId, modf, {1}));
        fracs.push_back(
            theBuilder.createCompositeExtract(colTypeId, modf, {0}));
      }
      theBuilder.createStore(
          ipId, theBuilder.createCompositeConstruct(returnTypeId, ips));
      return theBuilder.createCompositeConstruct(returnTypeId, fracs);
    }
  }

  emitError("invalid argument type passed to Modf intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
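
// Illustrative example for processIntrinsicModf above (IDs are made up): for
// "float2 f = modf(x, ip)" with float2 arguments, the scalar/vector path
// emits roughly:
//
//   %modf = OpExtInst %ModfStructType %glsl_ext ModfStruct %xVal
//   %ip   = OpCompositeExtract %v2float %modf 1
//   OpStore %ipVar %ip
//   %frac = OpCompositeExtract %v2float %modf 0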
uint32_t SPIRVEmitter::processIntrinsicLit(const CallExpr *callExpr) {
  // Signature is: float4 lit(float n_dot_l, float n_dot_h, float m)
  //
  // This function returns a lighting coefficient vector
  // (ambient, diffuse, specular, 1) where:
  //   ambient  = 1
  //   diffuse  = (n_dot_l < 0) ? 0 : n_dot_l
  //   specular = (n_dot_l < 0 || n_dot_h < 0) ? 0 : (n_dot_h * m)
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const uint32_t nDotL = doExpr(callExpr->getArg(0));
  const uint32_t nDotH = doExpr(callExpr->getArg(1));
  const uint32_t m = doExpr(callExpr->getArg(2));
  const uint32_t floatType = theBuilder.getFloat32Type();
  const uint32_t boolType = theBuilder.getBoolType();
  const uint32_t floatZero = theBuilder.getConstantFloat32(0);
  const uint32_t floatOne = theBuilder.getConstantFloat32(1);
  const uint32_t retType = typeTranslator.translateType(callExpr->getType());
  const uint32_t diffuse = theBuilder.createExtInst(
      floatType, glslInstSetId, GLSLstd450::GLSLstd450FMax, {floatZero, nDotL});
  const uint32_t min = theBuilder.createExtInst(
      floatType, glslInstSetId, GLSLstd450::GLSLstd450FMin, {nDotL, nDotH});
  const uint32_t isNeg = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                                   boolType, min, floatZero);
  const uint32_t mul =
      theBuilder.createBinaryOp(spv::Op::OpFMul, floatType, nDotH, m);
  const uint32_t specular =
      theBuilder.createSelect(floatType, isNeg, floatZero, mul);
  return theBuilder.createCompositeConstruct(
      retType, {floatOne, diffuse, specular, floatOne});
}
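
// Worked example for the formulas implemented above: lit(0.3, 0.5, 8) yields
// (1, 0.3, 0.5 * 8, 1) = (1, 0.3, 4, 1), while lit(-0.3, 0.5, 8) yields
// (1, 0, 0, 1) because n_dot_l is negative.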
uint32_t SPIRVEmitter::processIntrinsicFrexp(const CallExpr *callExpr) {
  // Signature is: ret frexp(x, exp)
  // [in]  x: the input floating-point value.
  // [out] exp: the calculated exponent.
  // [out] ret: the calculated mantissa.
  // All of the above must be a scalar, vector, or matrix of *float* type.
  //
  // The FrexpStruct SPIR-V instruction returns a struct. The first member is
  // the significand (mantissa) and must be of the same type as the input
  // parameter; the second member is the exponent and must always be a scalar
  // or vector of 32-bit *integer* type:
  //
  //   FrexpStruct {
  //     <scalar or vector of float> mantissa;
  //     <scalar or vector of integer> exponent;
  //   }
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto intId = theBuilder.getInt32Type();
  const auto returnTypeId = typeTranslator.translateType(callExpr->getType());
  const uint32_t argId = doExpr(arg);
  const uint32_t expId = doExpr(callExpr->getArg(1));

  // For scalar and vector argument types.
  {
    uint32_t elemCount = 1;
    if (TypeTranslator::isScalarType(argType) ||
        TypeTranslator::isVectorType(argType, nullptr, &elemCount)) {
      const auto argTypeId = typeTranslator.translateType(argType);
      const auto expTypeId =
          elemCount == 1 ? intId : theBuilder.getVecType(intId, elemCount);
      const auto frexpStructTypeId = theBuilder.getStructType(
          {argTypeId, expTypeId}, "FrexpStructType", {"mantissa", "exponent"});
      const auto frexp =
          theBuilder.createExtInst(frexpStructTypeId, glslInstSetId,
                                   GLSLstd450::GLSLstd450FrexpStruct, {argId});
      const auto exponentInt =
          theBuilder.createCompositeExtract(expTypeId, frexp, {1});
      // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
      // expects a float, a conversion must take place before writing the
      // results.
      const auto exponentFloat = theBuilder.createUnaryOp(
          spv::Op::OpConvertSToF, returnTypeId, exponentInt);
      theBuilder.createStore(expId, exponentFloat);
      return theBuilder.createCompositeExtract(argTypeId, frexp, {0});
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(argType, nullptr, &rowCount, &colCount)) {
      const auto floatId = theBuilder.getFloat32Type();
      const auto expTypeId = theBuilder.getVecType(intId, colCount);
      const auto colTypeId = theBuilder.getVecType(floatId, colCount);
      const auto frexpStructTypeId = theBuilder.getStructType(
          {colTypeId, expTypeId}, "FrexpStructType", {"mantissa", "exponent"});
      llvm::SmallVector<uint32_t, 4> exponents;
      llvm::SmallVector<uint32_t, 4> mantissas;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto curRow =
            theBuilder.createCompositeExtract(colTypeId, argId, {i});
        const auto frexp = theBuilder.createExtInst(
            frexpStructTypeId, glslInstSetId, GLSLstd450::GLSLstd450FrexpStruct,
            {curRow});
        const auto exponentInt =
            theBuilder.createCompositeExtract(expTypeId, frexp, {1});
        // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
        // expects a float, a conversion must take place before writing the
        // results.
        exponents.push_back(theBuilder.createUnaryOp(
            spv::Op::OpConvertSToF, colTypeId, exponentInt));
        mantissas.push_back(
            theBuilder.createCompositeExtract(colTypeId, frexp, {0}));
      }
      const auto exponentsResultId =
          theBuilder.createCompositeConstruct(returnTypeId, exponents);
      theBuilder.createStore(expId, exponentsResultId);
      return theBuilder.createCompositeConstruct(returnTypeId, mantissas);
    }
  }

  emitError("invalid argument type passed to Frexp intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
uint32_t SPIRVEmitter::processIntrinsicClip(const CallExpr *callExpr) {
  // Discards the current pixel if the specified value is less than zero.
  // TODO: If the argument can be const folded and evaluated, we could
  // potentially avoid creating a branch. This would be a bit challenging for
  // matrix/vector arguments.
  assert(callExpr->getNumArgs() == 1u);
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto boolType = theBuilder.getBoolType();
  uint32_t condition = 0;

  // We do not evaluate the argument as a constant for now, so we branch based
  // on its runtime value. If the argument is a vector/matrix, clipping is
  // done if *any* element of the vector/matrix is less than zero.
  const uint32_t argId = doExpr(arg);
  QualType elemType = {};
  uint32_t elemCount = 0, rowCount = 0, colCount = 0;
  if (TypeTranslator::isScalarType(argType)) {
    const auto zero = getValueZero(argType);
    condition = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan, boolType,
                                          argId, zero);
  } else if (TypeTranslator::isVectorType(argType, nullptr, &elemCount)) {
    const auto zero = getValueZero(argType);
    const auto boolVecType = theBuilder.getVecType(boolType, elemCount);
    const auto cmp = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                               boolVecType, argId, zero);
    condition = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp);
  } else if (TypeTranslator::isMxNMatrix(argType, &elemType, &rowCount,
                                         &colCount)) {
    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
    const uint32_t floatVecType = theBuilder.getVecType(elemTypeId, colCount);
    const uint32_t elemZeroId = getValueZero(elemType);
    llvm::SmallVector<uint32_t, 4> elements(size_t(colCount), elemZeroId);
    const auto zero = theBuilder.getConstantComposite(floatVecType, elements);
    llvm::SmallVector<uint32_t, 4> cmpResults;
    for (uint32_t i = 0; i < rowCount; ++i) {
      const uint32_t lhsVec =
          theBuilder.createCompositeExtract(floatVecType, argId, {i});
      const auto boolColType = theBuilder.getVecType(boolType, colCount);
      const auto cmp = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                                 boolColType, lhsVec, zero);
      const auto any = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp);
      cmpResults.push_back(any);
    }
    const auto boolRowType = theBuilder.getVecType(boolType, rowCount);
    const auto results =
        theBuilder.createCompositeConstruct(boolRowType, cmpResults);
    condition = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, results);
  } else {
    emitError("invalid argument type passed to clip intrinsic function",
              callExpr->getExprLoc());
    return 0;
  }

  // Then we need to emit the instruction for the conditional branch.
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);
  // Handle the then branch.
  theBuilder.setInsertPoint(thenBB);
  theBuilder.createKill();
  theBuilder.addSuccessor(mergeBB);
  // From now on, emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
  return 0;
}
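
// Illustrative example for processIntrinsicClip above (IDs are made up; the
// exact structured control flow will differ): for a scalar argument, clip(x)
// becomes roughly:
//
//   %cond = OpFOrdLessThan %bool %xVal %float_0
//   OpSelectionMerge %if_merge None
//   OpBranchConditional %cond %if_true %if_merge
//   %if_true = OpLabel
//   OpKill
//   %if_merge = OpLabel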
uint32_t SPIRVEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
  // According to the HLSL reference, clamp(X, Min, Max) takes 3 arguments.
  // Each one may be int, uint, or float.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);

  GLSLstd450 glslOpcode = GLSLstd450::GLSLstd450UClamp;
  if (isFloatOrVecMatOfFloatType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450FClamp;
  else if (isSintOrVecMatOfSintType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450SClamp;

  // Get the function parameters. Expect 3 parameters.
  assert(callExpr->getNumArgs() == 3u);
  const Expr *argX = callExpr->getArg(0);
  const Expr *argMin = callExpr->getArg(1);
  const Expr *argMax = callExpr->getArg(2);
  const uint32_t argXId = doExpr(argX);
  const uint32_t argMinId = doExpr(argMin);
  const uint32_t argMaxId = doExpr(argMax);

  // FClamp, UClamp, and SClamp do not operate on matrices, so we should
  // perform the operation on each vector of the matrix.
  if (TypeTranslator::isSpirvAcceptableMatrixType(argX->getType())) {
    const auto actOnEachVec = [this, glslInstSetId, glslOpcode, argMinId,
                               argMaxId](uint32_t index, uint32_t vecType,
                                         uint32_t curRowId) {
      const auto minRowId =
          theBuilder.createCompositeExtract(vecType, argMinId, {index});
      const auto maxRowId =
          theBuilder.createCompositeExtract(vecType, argMaxId, {index});
      return theBuilder.createExtInst(vecType, glslInstSetId, glslOpcode,
                                      {curRowId, minRowId, maxRowId});
    };
    return processEachVectorInMatrix(argX, argXId, actOnEachVec);
  }

  return theBuilder.createExtInst(returnTypeId, glslInstSetId, glslOpcode,
                                  {argXId, argMinId, argMaxId});
}
uint32_t SPIRVEmitter::processIntrinsicMul(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId =
      typeTranslator.translateType(callExpr->getType());

  // Get the function parameters. Expect 2 parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const QualType arg0Type = arg0->getType();
  const QualType arg1Type = arg1->getType();

  // The HLSL mul() function takes 2 arguments. Each argument may be a scalar,
  // vector, or matrix. The frontend ensures that the two arguments have the
  // same component type. The only allowed component types are int and float.

  // mul(scalar, vector)
  {
    uint32_t elemCount = 0;
    if (TypeTranslator::isScalarType(arg0Type) &&
        TypeTranslator::isVectorType(arg1Type, nullptr, &elemCount)) {
      const uint32_t arg1Id = doExpr(arg1);

      // We can use OpVectorTimesScalar if arguments are floats.
      if (arg0Type->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnTypeId, arg1Id, doExpr(arg0));

      // Use OpIMul for integers.
      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId,
                                       createVectorSplat(arg0, elemCount),
                                       arg1Id);
    }
  }

  // mul(vector, scalar)
  {
    uint32_t elemCount = 0;
    if (TypeTranslator::isVectorType(arg0Type, nullptr, &elemCount) &&
        TypeTranslator::isScalarType(arg1Type)) {
      const uint32_t arg0Id = doExpr(arg0);

      // We can use OpVectorTimesScalar if arguments are floats.
      if (arg1Type->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnTypeId, arg0Id, doExpr(arg1));

      // Use OpIMul for integers.
      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId, arg0Id,
                                       createVectorSplat(arg1, elemCount));
    }
  }

  // mul(vector, vector)
  if (TypeTranslator::isVectorType(arg0Type) &&
      TypeTranslator::isVectorType(arg1Type))
    return processIntrinsicDot(callExpr);

  // All the following cases require handling arg0 and arg1 expressions first.
  const uint32_t arg0Id = doExpr(arg0);
  const uint32_t arg1Id = doExpr(arg1);

  // mul(scalar, scalar)
  if (TypeTranslator::isScalarType(arg0Type) &&
      TypeTranslator::isScalarType(arg1Type))
    return theBuilder.createBinaryOp(translateOp(BO_Mul, arg0Type),
                                     returnTypeId, arg0Id, arg1Id);

  // mul(scalar, matrix)
  if (TypeTranslator::isScalarType(arg0Type) &&
      TypeTranslator::isMxNMatrix(arg1Type)) {
    // We currently only support float matrices, so we can use
    // OpMatrixTimesScalar.
    if (arg0Type->isFloatingType())
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                       returnTypeId, arg1Id, arg0Id);
  }

  // mul(matrix, scalar)
  if (TypeTranslator::isScalarType(arg1Type) &&
      TypeTranslator::isMxNMatrix(arg0Type)) {
    // We currently only support float matrices, so we can use
    // OpMatrixTimesScalar.
    if (arg1Type->isFloatingType())
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                       returnTypeId, arg0Id, arg1Id);
  }

  // mul(vector, matrix)
  {
    QualType elemType = {};
    uint32_t elemCount = 0, numRows = 0;
    if (TypeTranslator::isVectorType(arg0Type, &elemType, &elemCount) &&
        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &numRows, nullptr) &&
        elemType->isFloatingType()) {
      assert(elemCount == numRows);
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesVector,
                                       returnTypeId, arg1Id, arg0Id);
    }
  }

  // mul(matrix, vector)
  {
    QualType elemType = {};
    uint32_t elemCount = 0, numCols = 0;
    if (TypeTranslator::isMxNMatrix(arg0Type, nullptr, nullptr, &numCols) &&
        TypeTranslator::isVectorType(arg1Type, &elemType, &elemCount) &&
        elemType->isFloatingType()) {
      assert(elemCount == numCols);
      return theBuilder.createBinaryOp(spv::Op::OpVectorTimesMatrix,
                                       returnTypeId, arg1Id, arg0Id);
    }
  }

  // mul(matrix, matrix)
  {
    QualType elemType = {};
    uint32_t arg0Cols = 0, arg1Rows = 0;
    if (TypeTranslator::isMxNMatrix(arg0Type, &elemType, nullptr, &arg0Cols) &&
        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &arg1Rows, nullptr) &&
        elemType->isFloatingType()) {
      assert(arg0Cols == arg1Rows);
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesMatrix,
                                       returnTypeId, arg1Id, arg0Id);
    }
  }

  emitError("invalid argument type passed to mul intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
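
// Note on the operand order above: this translation maps HLSL matrices to
// SPIR-V with rows and columns swapped, which is why e.g. mul(v, m) for a
// float vector and a float matrix is emitted as (illustrative)
//
//   %result = OpMatrixTimesVector %retType %mVal %vVal
//
// rather than OpVectorTimesMatrix.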
uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId =
      typeTranslator.translateType(callExpr->getType());

  // Get the function parameters. Expect 2 vectors as parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const uint32_t arg0Id = doExpr(arg0);
  const uint32_t arg1Id = doExpr(arg1);
  QualType arg0Type = arg0->getType();
  QualType arg1Type = arg1->getType();
  const size_t vec0Size = hlsl::GetHLSLVecSize(arg0Type);
  const size_t vec1Size = hlsl::GetHLSLVecSize(arg1Type);
  const QualType vec0ComponentType = hlsl::GetHLSLVecElementType(arg0Type);
  const QualType vec1ComponentType = hlsl::GetHLSLVecElementType(arg1Type);
  assert(returnType == vec1ComponentType);
  assert(vec0ComponentType == vec1ComponentType);
  assert(vec0Size == vec1Size);
  assert(vec0Size >= 1 && vec0Size <= 4);

  // According to the HLSL reference, the dot function only works on integers
  // and floats.
  assert(returnType->isFloatingType() || returnType->isIntegerType());

  // Special case: dot product of two vectors, each of size 1. That is
  // basically the same as regular multiplication of 2 scalars.
  if (vec0Size == 1) {
    const spv::Op spvOp = translateOp(BO_Mul, arg0Type);
    return theBuilder.createBinaryOp(spvOp, returnTypeId, arg0Id, arg1Id);
  }

  // If the vectors are of float type, we can use OpDot.
  if (returnType->isFloatingType()) {
    return theBuilder.createBinaryOp(spv::Op::OpDot, returnTypeId, arg0Id,
                                     arg1Id);
  }

  // The vector component type is integer (signed or unsigned). SPIR-V OpDot
  // does not support integer vectors, so we emit all the instructions
  // necessary to perform the dot product using additions and multiplications
  // instead.
  else {
    uint32_t result = 0;
    llvm::SmallVector<uint32_t, 4> multIds;
    const spv::Op multSpvOp = translateOp(BO_Mul, arg0Type);
    const spv::Op addSpvOp = translateOp(BO_Add, arg0Type);

    // Extract members from the two vectors and multiply them.
    for (unsigned int i = 0; i < vec0Size; ++i) {
      const uint32_t vec0member =
          theBuilder.createCompositeExtract(returnTypeId, arg0Id, {i});
      const uint32_t vec1member =
          theBuilder.createCompositeExtract(returnTypeId, arg1Id, {i});
      const uint32_t multId = theBuilder.createBinaryOp(
          multSpvOp, returnTypeId, vec0member, vec1member);
      multIds.push_back(multId);
    }

    // Add all the multiplications.
    result = multIds[0];
    for (unsigned int i = 1; i < vec0Size; ++i) {
      const uint32_t additionId = theBuilder.createBinaryOp(
          addSpvOp, returnTypeId, result, multIds[i]);
      result = additionId;
    }
    return result;
  }
}
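
// Illustrative example for the integer path above (IDs are made up): for
// "dot(int3 a, int3 b)", the expansion is:
//
//   %m0 = OpIMul %int %a0 %b0
//   %m1 = OpIMul %int %a1 %b1
//   %m2 = OpIMul %int %a2 %b2
//   %s0 = OpIAdd %int %m0 %m1
//   %s1 = OpIAdd %int %s0 %m2
//
// where %a0..%a2 and %b0..%b2 are OpCompositeExtract results.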
uint32_t SPIRVEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
  // 'rcp' takes only 1 argument that is a scalar, vector, or matrix of type
  // float or double.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);
  const QualType argType = arg->getType();

  // For cases with matrix argument.
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    const uint32_t vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecOne](
        uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
      return theBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne,
                                       curRowId);
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec);
  }

  // For cases with scalar or vector arguments.
  return theBuilder.createBinaryOp(spv::Op::OpFDiv, returnTypeId,
                                   getValueOne(argType), argId);
}

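// Illustrative sketch (assumed HLSL input, not from this file): for
//   float2x3 m; rcp(m);
// each of the two rows is lowered as OpFDiv of the constant (1.0, 1.0, 1.0)
// by the row vector, and the rows are recomposed into the result matrix by
// processEachVectorInMatrix.
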
uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
                                                spv::Op spvOp) {
  // 'all' and 'any' take only 1 parameter.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg = callExpr->getArg(0);
  const QualType argType = arg->getType();

  // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
  // Optimization: can directly cast them to boolean. No need for OpAny/OpAll.
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(argType, &scalarType) &&
        (scalarType->isBooleanType() || scalarType->isFloatingType() ||
         scalarType->isIntegerType()))
      return castToBool(doExpr(arg), argType, returnType);
  }

  // Handle vectors larger than 1, Mx1 matrices, and 1xN matrices as arguments.
  // Cast the vector to a boolean vector, then run OpAny/OpAll on it.
  {
    QualType elemType = {};
    uint32_t size = 0;
    if (TypeTranslator::isVectorType(argType, &elemType, &size)) {
      const QualType castToBoolType =
          astContext.getExtVectorType(returnType, size);
      uint32_t castedToBoolId =
          castToBool(doExpr(arg), argType, castToBoolType);
      return theBuilder.createUnaryOp(spvOp, returnTypeId, castedToBoolId);
    }
  }

  // Handle MxN matrices as arguments.
  {
    QualType elemType = {};
    uint32_t matRowCount = 0, matColCount = 0;
    if (TypeTranslator::isMxNMatrix(argType, &elemType, &matRowCount,
                                    &matColCount)) {
      if (!elemType->isFloatingType()) {
        emitError("non-floating-point matrix arguments in all/any intrinsic "
                  "function unimplemented",
                  callExpr->getExprLoc());
        return 0;
      }

      uint32_t matrixId = doExpr(arg);
      const uint32_t vecType = typeTranslator.getComponentVectorType(argType);
      llvm::SmallVector<uint32_t, 4> rowResults;
      for (uint32_t i = 0; i < matRowCount; ++i) {
        // Extract the row, which is a float vector of size matColCount.
        const uint32_t rowFloatVec =
            theBuilder.createCompositeExtract(vecType, matrixId, {i});
        // Cast the float vector to a boolean vector.
        const auto rowFloatQualType =
            astContext.getExtVectorType(elemType, matColCount);
        const auto rowBoolQualType =
            astContext.getExtVectorType(returnType, matColCount);
        const uint32_t rowBoolVec =
            castToBool(rowFloatVec, rowFloatQualType, rowBoolQualType);
        // Perform OpAny/OpAll on the boolean vector.
        rowResults.push_back(
            theBuilder.createUnaryOp(spvOp, returnTypeId, rowBoolVec));
      }
      // Create a new vector that is the concatenation of results of all rows.
      uint32_t boolId = theBuilder.getBoolType();
      uint32_t vecOfBoolsId = theBuilder.getVecType(boolId, matRowCount);
      const uint32_t rowResultsId =
          theBuilder.createCompositeConstruct(vecOfBoolsId, rowResults);
      // Run OpAny/OpAll on the newly-created vector.
      return theBuilder.createUnaryOp(spvOp, returnTypeId, rowResultsId);
    }
  }

  // All types should be handled already.
  llvm_unreachable("Unknown argument type passed to all()/any().");
  return 0;
}

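// Illustrative sketch (assumed HLSL input, not from this file): for
//   float2x2 m; any(m);
// each row is extracted, cast to bool2, reduced with OpAny, and the two
// per-row results are packed into a bool2 that feeds one final OpAny.
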
uint32_t SPIRVEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
  // This function handles 'asint', 'asuint', 'asfloat', and 'asdouble'.

  // Method 1: ret asint(arg)
  //    arg component type = {float, uint}
  //    arg template  type = {scalar, vector, matrix}
  //    ret template  type = same as arg template type.
  //    ret component type = int

  // Method 2: ret asuint(arg)
  //    arg component type = {float, int}
  //    arg template  type = {scalar, vector, matrix}
  //    ret template  type = same as arg template type.
  //    ret component type = uint

  // Method 3: ret asfloat(arg)
  //    arg component type = {float, uint, int}
  //    arg template  type = {scalar, vector, matrix}
  //    ret template  type = same as arg template type.
  //    ret component type = float

  // Method 4: double  asdouble(uint lowbits, uint highbits)
  // Method 5: double2 asdouble(uint2 lowbits, uint2 highbits)
  // Method 6:
  //    void asuint(
  //        in  double value,
  //        out uint lowbits,
  //        out uint highbits
  //    );

  const QualType returnType = callExpr->getType();
  const uint32_t numArgs = callExpr->getNumArgs();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg0 = callExpr->getArg(0);
  const QualType argType = arg0->getType();

  // Method 3's return type may be the same as the argument type, in which
  // case the cast is a no-op.
  if (returnType.getCanonicalType() == argType.getCanonicalType())
    return doExpr(arg0);

  // SPIR-V does not support non-floating-point matrices. For the above
  // methods that involve matrices, either the input or the output is a
  // non-float matrix (except for 'asfloat' taking a float matrix and
  // returning a float matrix, which is a no-op and is handled by the
  // condition above).
  if (TypeTranslator::isMxNMatrix(argType)) {
    emitError("non-floating-point matrix type unimplemented",
              callExpr->getExprLoc());
    return 0;
  }

  switch (numArgs) {
  case 1: {
    // Handling Methods 1, 2, and 3.
    return theBuilder.createUnaryOp(spv::Op::OpBitcast, returnTypeId,
                                    doExpr(arg0));
  }
  case 2: {
    const uint32_t lowbits = doExpr(arg0);
    const uint32_t highbits = doExpr(callExpr->getArg(1));
    const uint32_t uintType = theBuilder.getUint32Type();
    const uint32_t doubleType = theBuilder.getFloat64Type();
    // Handling Method 4.
    if (argType->isUnsignedIntegerType()) {
      const uint32_t uintVec2Type = theBuilder.getVecType(uintType, 2);
      const uint32_t operand = theBuilder.createCompositeConstruct(
          uintVec2Type, {lowbits, highbits});
      return theBuilder.createUnaryOp(spv::Op::OpBitcast, doubleType, operand);
    }
    // Handling Method 5.
    else {
      const uint32_t uintVec4Type = theBuilder.getVecType(uintType, 4);
      const uint32_t doubleVec2Type = theBuilder.getVecType(doubleType, 2);
      const uint32_t operand = theBuilder.createVectorShuffle(
          uintVec4Type, lowbits, highbits, {0, 2, 1, 3});
      return theBuilder.createUnaryOp(spv::Op::OpBitcast, doubleVec2Type,
                                      operand);
    }
  }
  case 3: {
    // Handling Method 6.
    const uint32_t value = doExpr(arg0);
    const uint32_t lowbits = doExpr(callExpr->getArg(1));
    const uint32_t highbits = doExpr(callExpr->getArg(2));
    const uint32_t uintType = theBuilder.getUint32Type();
    const uint32_t uintVec2Type = theBuilder.getVecType(uintType, 2);
    const uint32_t vecResult =
        theBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, value);
    theBuilder.createStore(
        lowbits, theBuilder.createCompositeExtract(uintType, vecResult, {0}));
    theBuilder.createStore(
        highbits, theBuilder.createCompositeExtract(uintType, vecResult, {1}));
    return 0;
  }
  default:
    emitError("unrecognized signature for intrinsic function %0",
              callExpr->getExprLoc())
        << callExpr->getDirectCallee()->getName();
    return 0;
  }
}

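// Illustrative sketch (assumed values, not from this file): asdouble with
// lowbits = 0x00000000 and highbits = 0x3FF00000 assembles the uint2
// (0x00000000, 0x3FF00000), whose OpBitcast to double yields 1.0, since the
// high word holds the sign, exponent, and upper mantissa bits of an IEEE 754
// binary64 value, and SPIR-V maps lower-numbered vector components to
// lower-order bits.
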
uint32_t SPIRVEmitter::processIntrinsicIsFinite(const CallExpr *callExpr) {
  // Since OpIsFinite needs the Kernel capability, translation is instead done
  // using OpIsNan and OpIsInf:
  //   isFinite = !(isNan || isInf)
  const auto arg = doExpr(callExpr->getArg(0));
  const auto returnType = typeTranslator.translateType(callExpr->getType());
  const auto isNan =
      theBuilder.createUnaryOp(spv::Op::OpIsNan, returnType, arg);
  const auto isInf =
      theBuilder.createUnaryOp(spv::Op::OpIsInf, returnType, arg);
  const auto isNanOrInf =
      theBuilder.createBinaryOp(spv::Op::OpLogicalOr, returnType, isNan, isInf);
  return theBuilder.createUnaryOp(spv::Op::OpLogicalNot, returnType,
                                  isNanOrInf);
}

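// Illustrative SPIR-V shape for the expansion above (not emitted verbatim):
//   %nan = OpIsNan %bool %x
//   %inf = OpIsInf %bool %x
//   %or  = OpLogicalOr %bool %nan %inf
//   %fin = OpLogicalNot %bool %or
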
uint32_t SPIRVEmitter::processIntrinsicSinCos(const CallExpr *callExpr) {
  // Since there is no sincos equivalent in SPIR-V, we need to perform Sin
  // once and Cos once. We can reuse existing Sine/Cosine handling functions.
  CallExpr *sincosExpr =
      new (astContext) CallExpr(astContext, Stmt::StmtClass::NoStmtClass, {});
  sincosExpr->setType(callExpr->getArg(0)->getType());
  sincosExpr->setNumArgs(astContext, 1);
  sincosExpr->setArg(0, const_cast<Expr *>(callExpr->getArg(0)));

  // Perform Sin and store the result in argument 1.
  const uint32_t sin =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Sin,
                                    /*actPerRowForMatrices*/ true);
  theBuilder.createStore(doExpr(callExpr->getArg(1)), sin);

  // Perform Cos and store the result in argument 2.
  const uint32_t cos =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Cos,
                                    /*actPerRowForMatrices*/ true);
  theBuilder.createStore(doExpr(callExpr->getArg(2)), cos);
  return 0;
}

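// Illustrative sketch (assumed HLSL input, not from this file):
//   float s, c; sincos(x, s, c);
// lowers to two GLSL.std.450 calls, ExtInst Sin and ExtInst Cos on x, each
// followed by an OpStore into the corresponding out-parameter.
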
uint32_t SPIRVEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
  const auto *arg = callExpr->getArg(0);
  const auto argId = doExpr(arg);
  const auto argType = arg->getType();
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();

  if (argType->isFloatingType()) {
    const uint32_t floatZero = getValueZero(argType);
    const uint32_t floatOne = getValueOne(argType);
    return theBuilder.createExtInst(returnType, glslInstSetId,
                                    GLSLstd450::GLSLstd450FClamp,
                                    {argId, floatZero, floatOne});
  }

  QualType elemType = {};
  uint32_t vecSize = 0;
  if (TypeTranslator::isVectorType(argType, &elemType, &vecSize)) {
    const uint32_t vecZero = getVecValueZero(elemType, vecSize);
    const uint32_t vecOne = getVecValueOne(elemType, vecSize);
    return theBuilder.createExtInst(returnType, glslInstSetId,
                                    GLSLstd450::GLSLstd450FClamp,
                                    {argId, vecZero, vecOne});
  }

  uint32_t numRows = 0, numCols = 0;
  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    const uint32_t vecZero = getVecValueZero(elemType, numCols);
    const uint32_t vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecZero, vecOne, glslInstSetId](
        uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
      return theBuilder.createExtInst(vecType, glslInstSetId,
                                      GLSLstd450::GLSLstd450FClamp,
                                      {curRowId, vecZero, vecOne});
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec);
  }

  emitError("invalid argument type passed to saturate intrinsic function",
            callExpr->getExprLoc());
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
  // Import the GLSL.std.450 extended instruction set.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const QualType returnType = callExpr->getType();
  const QualType argType = arg->getType();
  assert(isFloatOrVecMatOfFloatType(argType));
  const uint32_t argTypeId = typeTranslator.translateType(argType);
  const uint32_t argId = doExpr(arg);
  uint32_t floatSignResultId = 0;

  // For matrices, we can perform the instruction on each vector of the matrix.
  if (TypeTranslator::isSpirvAcceptableMatrixType(argType)) {
    const auto actOnEachVec = [this, glslInstSetId](
        uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
      return theBuilder.createExtInst(vecType, glslInstSetId,
                                      GLSLstd450::GLSLstd450FSign, {curRowId});
    };
    floatSignResultId = processEachVectorInMatrix(arg, argId, actOnEachVec);
  } else {
    floatSignResultId = theBuilder.createExtInst(
        argTypeId, glslInstSetId, GLSLstd450::GLSLstd450FSign, {argId});
  }

  return castToInt(floatSignResultId, arg->getType(), returnType,
                   arg->getExprLoc());
}

uint32_t SPIRVEmitter::processIntrinsicF16ToF32(const CallExpr *callExpr) {
  // f16tof32() takes in (a vector of) uint and returns (a vector of) float.
  // The frontend should guarantee that by inserting implicit casts.
  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
  const uint32_t f32TypeId = theBuilder.getFloat32Type();
  const uint32_t u32TypeId = theBuilder.getUint32Type();
  const uint32_t v2f32TypeId = theBuilder.getVecType(f32TypeId, 2);

  const auto *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);

  uint32_t elemCount = {};
  if (TypeTranslator::isVectorType(arg->getType(), nullptr, &elemCount)) {
    // The input is a vector. We need to handle each element separately.
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < elemCount; ++i) {
      const uint32_t srcElem =
          theBuilder.createCompositeExtract(u32TypeId, argId, {i});
      const uint32_t convert = theBuilder.createExtInst(
          v2f32TypeId, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, srcElem);
      elements.push_back(
          theBuilder.createCompositeExtract(f32TypeId, convert, {0}));
    }
    return theBuilder.createCompositeConstruct(
        theBuilder.getVecType(f32TypeId, elemCount), elements);
  }

  const uint32_t convert = theBuilder.createExtInst(
      v2f32TypeId, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, argId);
  // f16tof32() converts the float16 stored in the low half of the uint to
  // a float, so we just need to return the first component.
  return theBuilder.createCompositeExtract(f32TypeId, convert, {0});
}

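// Illustrative sketch (assumed value, not from this file): f16tof32(0x3C00)
// unpacks the uint into a float2 via UnpackHalf2x16; component 0 holds the
// half stored in the low 16 bits (here 1.0), and that component is returned.
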
uint32_t SPIRVEmitter::processIntrinsicF32ToF16(const CallExpr *callExpr) {
  // f32tof16() takes in (a vector of) float and returns (a vector of) uint.
  // The frontend should guarantee that by inserting implicit casts.
  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
  const uint32_t f32TypeId = theBuilder.getFloat32Type();
  const uint32_t u32TypeId = theBuilder.getUint32Type();
  const uint32_t v2f32TypeId = theBuilder.getVecType(f32TypeId, 2);
  const uint32_t zero = theBuilder.getConstantFloat32(0);

  const auto *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);
  uint32_t elemCount = {};

  if (TypeTranslator::isVectorType(arg->getType(), nullptr, &elemCount)) {
    // The input is a vector. We need to handle each element separately.
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < elemCount; ++i) {
      const uint32_t srcElem =
          theBuilder.createCompositeExtract(f32TypeId, argId, {i});
      const uint32_t srcVec =
          theBuilder.createCompositeConstruct(v2f32TypeId, {srcElem, zero});
      elements.push_back(theBuilder.createExtInst(
          u32TypeId, glsl, GLSLstd450::GLSLstd450PackHalf2x16, srcVec));
    }
    return theBuilder.createCompositeConstruct(
        theBuilder.getVecType(u32TypeId, elemCount), elements);
  }

  // f32tof16() stores the float into the low half of the uint, so we need
  // to supply a zero to fill the other half.
  const uint32_t srcVec =
      theBuilder.createCompositeConstruct(v2f32TypeId, {argId, zero});
  return theBuilder.createExtInst(u32TypeId, glsl,
                                  GLSLstd450::GLSLstd450PackHalf2x16, srcVec);
}

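// Illustrative sketch (assumed value, not from this file): f32tof16(1.0f)
// packs the float2 (1.0, 0.0) with PackHalf2x16; the resulting uint holds the
// half encoding 0x3C00 in its low 16 bits and zero in its high 16 bits.
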
uint32_t SPIRVEmitter::processIntrinsicUsingSpirvInst(
    const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    const uint32_t argId = doExpr(arg);

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg->getType())) {
      const auto actOnEachVec = [this, opcode](
          uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
        return theBuilder.createUnaryOp(opcode, vecType, {curRowId});
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec);
    }
    return theBuilder.createUnaryOp(opcode, returnType, {argId});
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg0->getType())) {
      const auto actOnEachVec = [this, opcode, arg1Id](
          uint32_t index, uint32_t vecType, uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        return theBuilder.createBinaryOp(opcode, vecType, arg0RowId, arg1RowId);
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createBinaryOp(opcode, returnType, arg0Id, arg1Id);
  }

  emitError("unsupported intrinsic function %0", callExpr->getExprLoc())
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicUsingGLSLInst(
    const CallExpr *callExpr, GLSLstd450 opcode, bool actPerRowForMatrices) {
  // Import the GLSL.std.450 extended instruction set.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    const uint32_t argId = doExpr(arg);

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode](
          uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {curRowId});
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode, {argId});
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg0->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode, arg1Id](
          uint32_t index, uint32_t vecType, uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {arg0RowId, arg1RowId});
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode,
                                    {arg0Id, arg1Id});
  } else if (callExpr->getNumArgs() == 3u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));
    const uint32_t arg2Id = doExpr(callExpr->getArg(2));

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg0->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode, arg0Id, arg1Id,
                                 arg2Id](uint32_t index, uint32_t vecType,
                                         uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        const uint32_t arg2RowId =
            theBuilder.createCompositeExtract(vecType, arg2Id, {index});
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {arg0RowId, arg1RowId, arg2RowId});
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode,
                                    {arg0Id, arg1Id, arg2Id});
  }

  emitError("unsupported intrinsic function %0", callExpr->getExprLoc())
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
  // Since there is no log10 instruction in SPIR-V, we use:
  //   log10(x) = log2(x) * (1 / log2(10))
  //   1 / log2(10) ~= 0.30103
  const auto scale = theBuilder.getConstantFloat32(0.30103f);
  const auto log2 =
      processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450Log2, true);
  const auto returnType = callExpr->getType();
  const auto returnTypeId = typeTranslator.translateType(returnType);
  spv::Op scaleOp = TypeTranslator::isScalarType(returnType)
                        ? spv::Op::OpFMul
                        : TypeTranslator::isVectorType(returnType)
                              ? spv::Op::OpVectorTimesScalar
                              : spv::Op::OpMatrixTimesScalar;
  return theBuilder.createBinaryOp(scaleOp, returnTypeId, log2, scale);
}

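// Quick sanity check on the constant above: log2(10) ~= 3.321928, so
// 1 / 3.321928 ~= 0.30103, and log10(100) = log2(100) * 0.30103
// ~= 6.643856 * 0.30103 ~= 2.0, as expected.
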
uint32_t SPIRVEmitter::getValueZero(QualType type) {
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(type, &scalarType)) {
      if (scalarType->isSignedIntegerType()) {
        return theBuilder.getConstantInt32(0);
      }
      if (scalarType->isUnsignedIntegerType()) {
        return theBuilder.getConstantUint32(0);
      }
      if (scalarType->isFloatingType()) {
        return theBuilder.getConstantFloat32(0.0);
      }
    }
  }

  {
    QualType elemType = {};
    uint32_t size = {};
    if (TypeTranslator::isVectorType(type, &elemType, &size)) {
      return getVecValueZero(elemType, size);
    }
  }

  // TODO: Handle getValueZero for MxN matrices.
  emitError("getting value 0 for type %0 unimplemented", {})
      << type.getAsString();
  return 0;
}

uint32_t SPIRVEmitter::getVecValueZero(QualType elemType, uint32_t size) {
  const uint32_t elemZeroId = getValueZero(elemType);

  if (size == 1)
    return elemZeroId;

  llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemZeroId);
  const uint32_t vecType =
      theBuilder.getVecType(typeTranslator.translateType(elemType), size);
  return theBuilder.getConstantComposite(vecType, elements);
}

uint32_t SPIRVEmitter::getValueOne(QualType type) {
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(type, &scalarType)) {
      // TODO: Support other types such as short, half, etc.
      if (scalarType->isSignedIntegerType()) {
        return theBuilder.getConstantInt32(1);
      }
      if (scalarType->isUnsignedIntegerType()) {
        return theBuilder.getConstantUint32(1);
      }
      if (const auto *builtinType = scalarType->getAs<BuiltinType>()) {
        // TODO: Add support for other types that are not covered yet.
        switch (builtinType->getKind()) {
        case BuiltinType::Double:
          return theBuilder.getConstantFloat64(1.0);
        case BuiltinType::Float:
          return theBuilder.getConstantFloat32(1.0);
        }
      }
    }
  }

  {
    QualType elemType = {};
    uint32_t size = {};
    if (TypeTranslator::isVectorType(type, &elemType, &size)) {
      return getVecValueOne(elemType, size);
    }
  }

  emitError("getting value 1 for type %0 unimplemented", {}) << type;
  return 0;
}

uint32_t SPIRVEmitter::getVecValueOne(QualType elemType, uint32_t size) {
  const uint32_t elemOneId = getValueOne(elemType);

  if (size == 1)
    return elemOneId;

  llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemOneId);
  const uint32_t vecType =
      theBuilder.getVecType(typeTranslator.translateType(elemType), size);
  return theBuilder.getConstantComposite(vecType, elements);
}

uint32_t SPIRVEmitter::getMatElemValueOne(QualType type) {
  assert(hlsl::IsHLSLMatType(type));
  const auto elemType = hlsl::GetHLSLMatElementType(type);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);

  if (rowCount == 1 && colCount == 1)
    return getValueOne(elemType);
  if (colCount == 1)
    return getVecValueOne(elemType, rowCount);
  return getVecValueOne(elemType, colCount);
}

uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
                                        const QualType targetType) {
  if (targetType->isBooleanType()) {
    const bool boolValue = value.getInt().getBoolValue();
    return theBuilder.getConstantBool(boolValue);
  }

  if (targetType->isIntegerType()) {
    const llvm::APInt &intValue = value.getInt();
    return translateAPInt(intValue, targetType);
  }

  if (targetType->isFloatingType()) {
    const llvm::APFloat &floatValue = value.getFloat();
    return translateAPFloat(floatValue, targetType);
  }

  if (hlsl::IsHLSLVecType(targetType)) {
    const uint32_t vecType = typeTranslator.translateType(targetType);
    const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
    const auto numElements = value.getVectorLength();
    // Special case for vectors of size 1. SPIR-V doesn't support this vector
    // size, so we need to translate it to scalar values.
    if (numElements == 1) {
      return translateAPValue(value.getVectorElt(0), elemType);
    }

    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < numElements; ++i) {
      elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
    }
    return theBuilder.getConstantComposite(vecType, elements);
  }

  emitError("APValue of type %0 unimplemented", {}) << value.getKind();
  value.dump();
  return 0;
}

uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                                      QualType targetType) {
  if (targetType->isSignedIntegerType()) {
    // Try to see if this integer can be represented in 32 bits.
    if (intValue.isSignedIntN(32))
      return theBuilder.getConstantInt32(
          static_cast<int32_t>(intValue.getSExtValue()));
  } else {
    // Try to see if this integer can be represented in 32 bits.
    if (intValue.isIntN(32))
      return theBuilder.getConstantUint32(
          static_cast<uint32_t>(intValue.getZExtValue()));
  }

  emitError("APInt for target bitwidth %0 unimplemented", {})
      << astContext.getIntWidth(targetType);
  return 0;
}

uint32_t SPIRVEmitter::translateAPFloat(const llvm::APFloat &floatValue,
                                        QualType targetType) {
  const auto &semantics = astContext.getFloatTypeSemantics(targetType);
  const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);

  switch (bitwidth) {
  case 32:
    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
  case 64:
    return theBuilder.getConstantFloat64(floatValue.convertToDouble());
  default:
    break;
  }

  emitError("APFloat for target bitwidth %0 unimplemented", {}) << bitwidth;
  return 0;
}

uint32_t SPIRVEmitter::tryToEvaluateAsConst(const Expr *expr) {
  Expr::EvalResult evalResult;
  if (expr->EvaluateAsRValue(evalResult, astContext) &&
      !evalResult.HasSideEffects) {
    return translateAPValue(evalResult.Val, expr->getType());
  }

  return 0;
}

spv::ExecutionModel
SPIRVEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
  // DXIL Models are:
  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Stage
  // vs_<version>         : Vertex Shader    : Vertex Shader
  // hs_<version>         : Hull Shader      : Tessellation Control Shader
  // ds_<version>         : Domain Shader    : Tessellation Evaluation Shader
  // gs_<version>         : Geometry Shader  : Geometry Shader
  // ps_<version>         : Pixel Shader     : Fragment Shader
  // cs_<version>         : Compute Shader   : Compute Shader
  switch (model.GetKind()) {
  case hlsl::ShaderModel::Kind::Vertex:
    return spv::ExecutionModel::Vertex;
  case hlsl::ShaderModel::Kind::Hull:
    return spv::ExecutionModel::TessellationControl;
  case hlsl::ShaderModel::Kind::Domain:
    return spv::ExecutionModel::TessellationEvaluation;
  case hlsl::ShaderModel::Kind::Geometry:
    return spv::ExecutionModel::Geometry;
  case hlsl::ShaderModel::Kind::Pixel:
    return spv::ExecutionModel::Fragment;
  case hlsl::ShaderModel::Kind::Compute:
    return spv::ExecutionModel::GLCompute;
  default:
    break;
  }
  llvm_unreachable("unknown shader model");
}

void SPIRVEmitter::AddRequiredCapabilitiesForShaderModel() {
  if (shaderModel.IsHS() || shaderModel.IsDS()) {
    theBuilder.requireCapability(spv::Capability::Tessellation);
  } else if (shaderModel.IsGS()) {
    theBuilder.requireCapability(spv::Capability::Geometry);
  } else {
    theBuilder.requireCapability(spv::Capability::Shader);
  }
}

void SPIRVEmitter::AddExecutionModeForEntryPoint(uint32_t entryPointId) {
  if (shaderModel.IsPS()) {
    theBuilder.addExecutionMode(entryPointId,
                                spv::ExecutionMode::OriginUpperLeft, {});
  }
}

bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
                                                   uint32_t *arraySize) {
  bool success = true;
  assert(shaderModel.IsGS());
  if (auto *vcAttr = decl->getAttr<HLSLMaxVertexCountAttr>()) {
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::OutputVertices,
                                {static_cast<uint32_t>(vcAttr->getCount())});
  }

  // Only one primitive type is permitted for the geometry shader.
  bool outPoint = false, outLine = false, outTriangle = false, inPoint = false,
       inLine = false, inTriangle = false, inLineAdj = false,
       inTriangleAdj = false;
  for (const auto *param : decl->params()) {
    // Add an execution mode based on the output stream type. Do not add an
    // execution mode more than once.
    if (param->hasAttr<HLSLInOutAttr>()) {
      const auto paramType = param->getType();
      if (hlsl::IsHLSLTriangleStreamType(paramType) && !outTriangle) {
        theBuilder.addExecutionMode(
            entryFunctionId, spv::ExecutionMode::OutputTriangleStrip, {});
        outTriangle = true;
      } else if (hlsl::IsHLSLLineStreamType(paramType) && !outLine) {
        theBuilder.addExecutionMode(entryFunctionId,
                                    spv::ExecutionMode::OutputLineStrip, {});
        outLine = true;
      } else if (hlsl::IsHLSLPointStreamType(paramType) && !outPoint) {
        theBuilder.addExecutionMode(entryFunctionId,
                                    spv::ExecutionMode::OutputPoints, {});
        outPoint = true;
      }
      // An output stream parameter will not have the input primitive type
      // attributes, so we can continue to the next parameter.
      continue;
    }

    // Add an execution mode based on the input primitive type. Do not add an
    // execution mode more than once.
    if (param->hasAttr<HLSLPointAttr>() && !inPoint) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::InputPoints, {});
      *arraySize = 1;
      inPoint = true;
    } else if (param->hasAttr<HLSLLineAttr>() && !inLine) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::InputLines, {});
      *arraySize = 2;
      inLine = true;
    } else if (param->hasAttr<HLSLTriangleAttr>() && !inTriangle) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::Triangles, {});
      *arraySize = 3;
      inTriangle = true;
    } else if (param->hasAttr<HLSLLineAdjAttr>() && !inLineAdj) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::InputLinesAdjacency, {});
      *arraySize = 4;
      inLineAdj = true;
    } else if (param->hasAttr<HLSLTriangleAdjAttr>() && !inTriangleAdj) {
      theBuilder.addExecutionMode(
          entryFunctionId, spv::ExecutionMode::InputTrianglesAdjacency, {});
      *arraySize = 6;
      inTriangleAdj = true;
    }
  }
  if (inPoint + inLine + inLineAdj + inTriangle + inTriangleAdj > 1) {
    emitError("only one input primitive type can be specified in the geometry "
              "shader",
              {});
    success = false;
  }
  if (outPoint + outTriangle + outLine > 1) {
    emitError("only one output primitive type can be specified in the geometry "
              "shader",
              {});
    success = false;
  }

  return success;
}

void SPIRVEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
  // If not explicitly specified, x, y, and z should be defaulted to 1.
  uint32_t x = 1, y = 1, z = 1;
  if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
    x = static_cast<uint32_t>(numThreadsAttr->getX());
    y = static_cast<uint32_t>(numThreadsAttr->getY());
    z = static_cast<uint32_t>(numThreadsAttr->getZ());
  }

  theBuilder.addExecutionMode(entryFunctionId, spv::ExecutionMode::LocalSize,
                              {x, y, z});
}

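// Illustrative sketch (assumed HLSL input, not from this file):
//   [numthreads(8, 8, 1)] void main() { ... }
// yields: OpExecutionMode %main LocalSize 8 8 1
// With no [numthreads] attribute present, LocalSize 1 1 1 is emitted.
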
bool SPIRVEmitter::processTessellationShaderAttributes(
    const FunctionDecl *decl, uint32_t *numOutputControlPoints) {
  assert(shaderModel.IsHS() || shaderModel.IsDS());
  using namespace spv;

  if (auto *domain = decl->getAttr<HLSLDomainAttr>()) {
    const auto domainType = domain->getDomainType().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(domainType)
            .Case("tri", ExecutionMode::Triangles)
            .Case("quad", ExecutionMode::Quads)
            .Case("isoline", ExecutionMode::Isolines)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown domain type specified for entry function",
                decl->getLocation());
      return false;
    }
    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
  }

  // Early return for domain shaders, as domain shaders only take the
  // 'domain' attribute.
  if (shaderModel.IsDS())
    return true;

  if (auto *partitioning = decl->getAttr<HLSLPartitioningAttr>()) {
    // TODO: Could not find an equivalent of the "pow2" partitioning scheme
    // in SPIR-V.
    const auto scheme = partitioning->getScheme().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(scheme)
            .Case("fractional_even", ExecutionMode::SpacingFractionalEven)
            .Case("fractional_odd", ExecutionMode::SpacingFractionalOdd)
            .Case("integer", ExecutionMode::SpacingEqual)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown partitioning scheme in hull shader",
                decl->getLocation());
      return false;
    }
    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
  }
  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
    const auto topology = outputTopology->getTopology().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(topology)
            .Case("point", ExecutionMode::PointMode)
            .Case("triangle_cw", ExecutionMode::VertexOrderCw)
            .Case("triangle_ccw", ExecutionMode::VertexOrderCcw)
            .Default(ExecutionMode::Max);
    // TODO: There is no SPIR-V equivalent for the "line" topology. Is it the
    // default?
    if (topology != "line") {
      if (hsExecMode != spv::ExecutionMode::Max) {
        theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
      } else {
        emitError("unknown output topology in hull shader",
                  decl->getLocation());
        return false;
      }
    }
  }
  if (auto *controlPoints = decl->getAttr<HLSLOutputControlPointsAttr>()) {
    *numOutputControlPoints = controlPoints->getCount();
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::OutputVertices,
                                {*numOutputControlPoints});
  }

  return true;
}

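// Illustrative sketch (assumed HLSL attributes, not from this file): a hull
// shader declared with
//   [domain("tri")] [partitioning("fractional_odd")]
//   [outputtopology("triangle_cw")] [outputcontrolpoints(3)]
// gets the execution modes Triangles, SpacingFractionalOdd, VertexOrderCw,
// and OutputVertices 3 attached to the entry point by the code above.
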
bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
                                            const uint32_t entryFuncId) {
  // HS specific attributes
  uint32_t numOutputControlPoints = 0;
  uint32_t outputControlPointIdVal = 0;  // SV_OutputControlPointID value
  uint32_t primitiveIdVar = 0;           // SV_PrimitiveID variable
  uint32_t hullMainInputPatchParam = 0;  // Temporary parameter for InputPatch<>

  // The array size of per-vertex input/output variables.
  // Used by HS/DS/GS for the additional arrayness; zero means not an array.
  uint32_t inputArraySize = 0;
  uint32_t outputArraySize = 0;

  // Construct the wrapper function signature.
  const uint32_t voidType = theBuilder.getVoidType();
  const uint32_t funcType = theBuilder.getFunctionType(voidType, {});

  // The wrapper entry function does not have a pre-assigned <result-id>,
  // unlike the other functions, which got one when they were added to the
  // work queue at their call sites. The wrapper is the real entry function.
  entryFunctionId =
      theBuilder.beginFunction(funcType, voidType, decl->getName());
  // Note this should happen before using declIdMapper for other tasks.
  declIdMapper.setEntryFunctionId(entryFunctionId);

  // Handle attributes specific to each shader stage.
  if (shaderModel.IsCS()) {
    processComputeShaderAttributes(decl);
  } else if (shaderModel.IsHS()) {
    if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
      return false;

    // The input array size for HS is specified in the InputPatch parameter.
    for (const auto *param : decl->params())
      if (hlsl::IsHLSLInputPatchType(param->getType())) {
        inputArraySize = hlsl::GetHLSLInputPatchCount(param->getType());
        break;
      }

    outputArraySize = numOutputControlPoints;
  } else if (shaderModel.IsDS()) {
    if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
      return false;

    // The input array size for DS is specified in the OutputPatch parameter.
    for (const auto *param : decl->params())
      if (hlsl::IsHLSLOutputPatchType(param->getType())) {
        inputArraySize = hlsl::GetHLSLOutputPatchCount(param->getType());
        break;
      }
    // The per-vertex output of DS is not an array.
  } else if (shaderModel.IsGS()) {
    if (!processGeometryShaderAttributes(decl, &inputArraySize))
      return false;
    // The per-vertex output of GS is not an array.
  }

  // Go through all parameters and record the declaration of SV_ClipDistance
  // and SV_CullDistance. We need to do this extra step because in HLSL we
  // can declare multiple SV_ClipDistance/SV_CullDistance variables of float
  // or vector of float types, but we can only have one single float array
  // for the ClipDistance/CullDistance builtin. So we need to group all
  // SV_ClipDistance/SV_CullDistance variables into one float array; thus we
  // need to calculate the total size of the array and the offset of each
  // variable within that array.
  for (const auto *param : decl->params()) {
    if (canActAsInParmVar(param))
      if (!declIdMapper.glPerVertex.recordClipCullDistanceDecl(param, true))
        return false;
    if (canActAsOutParmVar(param))
      if (!declIdMapper.glPerVertex.recordClipCullDistanceDecl(param, false))
        return false;
  }
  // Also consider the SV_ClipDistance/SV_CullDistance in the return type.
  if (!declIdMapper.glPerVertex.recordClipCullDistanceDecl(decl, false))
    return false;

  // Calculate the total size of the ClipDistance/CullDistance array and the
  // offset of SV_ClipDistance/SV_CullDistance variables within the array.
  declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();

  if (!shaderModel.IsCS()) {
    // Generate the gl_PerVertex structs or stand-alone builtins of
    // Position, ClipDistance, and CullDistance.
    declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
  }

  // Require the ClipDistance/CullDistance capability if necessary.
  // It is legal to just use the ClipDistance/CullDistance builtin without
  // requiring the ClipDistance/CullDistance capability, as long as we don't
  // read or write the builtin variable.
  // For our CodeGen, that corresponds to not seeing SV_ClipDistance or
  // SV_CullDistance at all. If we see them, we will generate code to read
  // them to initialize temporary variables for calling the source code entry
  // function, or write to them after calling the source code entry function.
  declIdMapper.glPerVertex.requireCapabilityIfNecessary();

  // The entry basic block.
  const uint32_t entryLabel = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(entryLabel);

  // Initialize all global variables at the beginning of the wrapper.
  for (const VarDecl *varDecl : toInitGloalVars)
    theBuilder.createStore(declIdMapper.getDeclResultId(varDecl),
                           doExpr(varDecl->getInit()));

  // Create temporary variables for holding function call arguments.
  llvm::SmallVector<uint32_t, 4> params;
  for (const auto *param : decl->params()) {
    const auto paramType = param->getType();
    const uint32_t typeId = typeTranslator.translateType(paramType);
    std::string tempVarName = "param.var." + param->getNameAsString();
    const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);

    params.push_back(tempVar);

    // Create the stage input variable for parameters not marked as pure out
    // and initialize the corresponding temporary variable.
    // Also do not create input variables for output stream objects of geometry
    // shaders (e.g. TriangleStream), which are required to be marked as
    // 'inout'.
    if (canActAsInParmVar(param)) {
      if (shaderModel.IsHS() && hlsl::IsHLSLInputPatchType(paramType)) {
        // Record the temporary variable holding InputPatch. It may be used
        // later in the patch constant function.
        hullMainInputPatchParam = tempVar;
      }

      uint32_t loadedValue = 0;
      if (!declIdMapper.createStageInputVar(param, &loadedValue, false))
        return false;

      theBuilder.createStore(tempVar, loadedValue);

      // Record the temporary variables holding SV_OutputControlPointID and
      // SV_PrimitiveID. They may be used later in the patch constant function.
      if (hasSemantic(param, hlsl::DXIL::SemanticKind::OutputControlPointID))
        outputControlPointIdVal = loadedValue;
      if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID))
        primitiveIdVar = tempVar;
    }
  }

  // Call the original entry function.
  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
  const uint32_t retVal =
      theBuilder.createFunctionCall(retType, entryFuncId, params);

  // Create and write stage output variables for the return value. Hull
  // shaders are a special case, since they operate differently in 2 ways:
  // 1- Their return value is in fact an array, and each invocation should
  //    write to the proper offset in the array.
  // 2- The patch constant function must be called *once* after all
  //    invocations of the main entry point function are done.
  if (shaderModel.IsHS()) {
    // Create stage output variables out of the return type.
    if (!declIdMapper.createStageOutputVar(decl, numOutputControlPoints,
                                           outputControlPointIdVal, retVal))
      return false;
    if (!processHullEntryPointOutputAndPatchConstFunc(
            decl, retType, retVal, numOutputControlPoints,
            outputControlPointIdVal, primitiveIdVar, hullMainInputPatchParam))
      return false;
  } else {
    if (!declIdMapper.createStageOutputVar(decl, retVal, /*forPCF*/ false))
      return false;
  }

  // Create and write stage output variables for parameters marked as
  // out/inout.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const auto *param = decl->getParamDecl(i);
    if (canActAsOutParmVar(param)) {
      // Load the value from the parameter after the function call.
      const uint32_t typeId = typeTranslator.translateType(param->getType());
      uint32_t loadedParam = 0;

      // Write-back of stage output variables in GS is manually controlled by
      // the .Append() intrinsic method. There is no need to load the
      // parameter, since we won't write back here.
      if (!shaderModel.IsGS())
        loadedParam = theBuilder.createLoad(typeId, params[i]);

      if (!declIdMapper.createStageOutputVar(param, loadedParam, false))
        return false;
    }
  }

  theBuilder.createReturn();
  theBuilder.endFunction();

  // For hull shaders, there is no explicit call to the PCF in the HLSL
  // source, so we should invoke a translation of the PCF manually.
  if (shaderModel.IsHS())
    doDecl(patchConstFunc);

  return true;
}

bool SPIRVEmitter::processHullEntryPointOutputAndPatchConstFunc(
    const FunctionDecl *hullMainFuncDecl, uint32_t retType, uint32_t retVal,
    uint32_t numOutputControlPoints, uint32_t outputControlPointId,
    uint32_t primitiveId, uint32_t hullMainInputPatch) {
  // This method may only be called for hull shaders.
  assert(shaderModel.IsHS());

  // For hull shaders, the real output is an array of size
  // numOutputControlPoints. The results of the main should be written to the
  // correct offset in the array (based on InvocationID).
  if (!numOutputControlPoints) {
    emitError("number of output control points cannot be zero",
              hullMainFuncDecl->getLocation());
    return false;
  }
  // TODO: We should be able to handle cases where the SV_OutputControlPointID
  // is not provided.
  if (!outputControlPointId) {
    emitError(
        "SV_OutputControlPointID semantic must be provided in hull shader",
        hullMainFuncDecl->getLocation());
    return false;
  }
  if (!patchConstFunc) {
    emitError("patch constant function not defined in hull shader",
              hullMainFuncDecl->getLocation());
    return false;
  }

  uint32_t hullMainOutputPatch = 0;
  // If the patch constant function (PCF) takes the result of the hull main
  // entry point, create a temporary function-scope variable and write the
  // results to it, so it can be passed to the PCF.
  if (patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
    const uint32_t hullMainRetType = theBuilder.getArrayType(
        retType, theBuilder.getConstantUint32(numOutputControlPoints));
    hullMainOutputPatch =
        theBuilder.addFnVar(hullMainRetType, "temp.var.hullMainRetVal");
    const auto tempLocation = theBuilder.createAccessChain(
        theBuilder.getPointerType(retType, spv::StorageClass::Function),
        hullMainOutputPatch, {outputControlPointId});
    theBuilder.createStore(tempLocation, retVal);
  }

  // Now create a barrier before calling the patch constant function (PCF).
  // Flags are:
  //   Execution barrier scope        = Workgroup (2)
  //   Memory barrier scope           = Device (1)
  //   Memory semantics barrier scope = None (0)
  theBuilder.createControlBarrier(theBuilder.getConstantUint32(2),
                                  theBuilder.getConstantUint32(1),
                                  theBuilder.getConstantUint32(0));

  // The PCF should be called only once. Therefore, we check the invocation
  // ID, and only allow ID 0 to call the PCF.
  const uint32_t condition = theBuilder.createBinaryOp(
      spv::Op::OpIEqual, theBuilder.getBoolType(), outputControlPointId,
      theBuilder.getConstantUint32(0));
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);

  theBuilder.setInsertPoint(thenBB);

  // Call the PCF. Since the function is not explicitly called, we must first
  // register an ID for it.
  const uint32_t pcfId = declIdMapper.getOrRegisterFnResultId(patchConstFunc);
  const uint32_t pcfRetType =
      typeTranslator.translateType(patchConstFunc->getReturnType());

  std::vector<uint32_t> pcfParams;
  for (const auto *param : patchConstFunc->parameters()) {
    // Note: According to the HLSL reference, the PCF takes an InputPatch of
    // ControlPoints as well as the PatchID (PrimitiveID). This does not
    // necessarily mean that they are present. There is also no requirement
    // on the order of parameters passed to the PCF.
    if (hlsl::IsHLSLInputPatchType(param->getType()))
      pcfParams.push_back(hullMainInputPatch);
    if (hlsl::IsHLSLOutputPatchType(param->getType()))
      pcfParams.push_back(hullMainOutputPatch);
    if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
      if (!primitiveId) {
        const uint32_t typeId = typeTranslator.translateType(param->getType());
        std::string tempVarName = "param.var." + param->getNameAsString();
        const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
        uint32_t loadedValue = 0;
        declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
        theBuilder.createStore(tempVar, loadedValue);
        primitiveId = tempVar;
      }
      pcfParams.push_back(primitiveId);
    }
  }
  const uint32_t pcfResultId =
      theBuilder.createFunctionCall(pcfRetType, pcfId, {pcfParams});
  if (!declIdMapper.createStageOutputVar(patchConstFunc, pcfResultId,
                                         /*forPCF*/ true))
    return false;

  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setInsertPoint(mergeBB);
  return true;
}

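// Illustrative shape of the control flow emitted above (not verbatim):
//   OpControlBarrier %uint_2 %uint_1 %uint_0
//   %cond = OpIEqual %bool %invocationId %uint_0
//   OpSelectionMerge %if_merge None
//   OpBranchConditional %cond %if_true %if_merge
//   %if_true: call the PCF, write its stage outputs, branch to %if_merge
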
  5353. bool SPIRVEmitter::allSwitchCasesAreIntegerLiterals(const Stmt *root) {
  5354. if (!root)
  5355. return false;
  5356. const auto *caseStmt = dyn_cast<CaseStmt>(root);
  5357. const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  5358. if (!caseStmt && !compoundStmt)
  5359. return true;
  5360. if (caseStmt) {
  5361. const Expr *caseExpr = caseStmt->getLHS();
  5362. return caseExpr && caseExpr->isEvaluatable(astContext);
  5363. }
  5364. // Recurse down if facing a compound statement.
  5365. for (auto *st : compoundStmt->body())
  5366. if (!allSwitchCasesAreIntegerLiterals(st))
  5367. return false;
  5368. return true;
  5369. }
  5370. void SPIRVEmitter::discoverAllCaseStmtInSwitchStmt(
  5371. const Stmt *root, uint32_t *defaultBB,
  5372. std::vector<std::pair<uint32_t, uint32_t>> *targets) {
  5373. if (!root)
  5374. return;
  5375. // A switch case can only appear in DefaultStmt, CaseStmt, or
  5376. // CompoundStmt. For the rest, we can just return.
  5377. const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
  5378. const auto *caseStmt = dyn_cast<CaseStmt>(root);
  5379. const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  5380. if (!defaultStmt && !caseStmt && !compoundStmt)
  5381. return;
  5382. // Recurse down if facing a compound statement.
  5383. if (compoundStmt) {
  5384. for (auto *st : compoundStmt->body())
  5385. discoverAllCaseStmtInSwitchStmt(st, defaultBB, targets);
  5386. return;
  5387. }
  5388. std::string caseLabel;
  5389. uint32_t caseValue = 0;
  5390. if (defaultStmt) {
  5391. // This is the default branch.
  5392. caseLabel = "switch.default";
  5393. } else if (caseStmt) {
  5394. // This is a non-default case.
  5395. // When using OpSwitch, we only allow integer literal cases. e.g:
  5396. // case <literal_integer>: {...; break;}
  5397. const Expr *caseExpr = caseStmt->getLHS();
  5398. assert(caseExpr && caseExpr->isEvaluatable(astContext));
  5399. auto bitWidth = astContext.getIntWidth(caseExpr->getType());
  5400. if (bitWidth != 32)
  5401. emitError(
  5402. "non-32bit integer case value in switch statement unimplemented",
  5403. caseExpr->getExprLoc());
  5404. Expr::EvalResult evalResult;
  5405. caseExpr->EvaluateAsRValue(evalResult, astContext);
  5406. const int64_t value = evalResult.Val.getInt().getSExtValue();
  5407. caseValue = static_cast<uint32_t>(value);
  5408. caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
  5409. llvm::itostr(std::abs(value));
  5410. }
  5411. const uint32_t caseBB = theBuilder.createBasicBlock(caseLabel);
  5412. theBuilder.addSuccessor(caseBB);
  5413. stmtBasicBlock[root] = caseBB;
  5414. // Add all cases to the 'targets' vector.
  5415. if (caseStmt)
  5416. targets->emplace_back(caseValue, caseBB);
  5417. // The default label is not part of the 'targets' vector that is passed
  5418. // to the OpSwitch instruction.
  5419. // If default statement was discovered, return its label via defaultBB.
  5420. if (defaultStmt)
  5421. *defaultBB = caseBB;
  5422. // Process cases nested in other cases. It happens when we have fall through
  5423. // cases. For example:
  5424. // case 1: case 2: ...; break;
  5425. // will result in the CaseSmt for case 2 nested in the one for case 1.
  5426. discoverAllCaseStmtInSwitchStmt(caseStmt ? caseStmt->getSubStmt()
  5427. : defaultStmt->getSubStmt(),
  5428. defaultBB, targets);
  5429. }
void SPIRVEmitter::flattenSwitchStmtAST(const Stmt *root,
                                        std::vector<const Stmt *> *flatSwitch) {
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);

  if (compoundStmt) {
    for (const auto *st : compoundStmt->body())
      flattenSwitchStmtAST(st, flatSwitch);
    return;
  }

  flatSwitch->push_back(root);
  if (caseStmt)
    flattenSwitchStmtAST(caseStmt->getSubStmt(), flatSwitch);
  else if (defaultStmt)
    flattenSwitchStmtAST(defaultStmt->getSubStmt(), flatSwitch);
}
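
// Emits the body of the given CaseStmt/DefaultStmt into the basic block that
// discoverAllCaseStmtInSwitchStmt created for it. For example, with
//   case 1: a = 1;  // no break
//   case 2: ...
// case 1's block is left unterminated, so when case 2 is processed a branch
// from case 1's block to case 2's block is created to realize fall-through.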
void SPIRVEmitter::processCaseStmtOrDefaultStmt(const Stmt *stmt) {
  auto *caseStmt = dyn_cast<CaseStmt>(stmt);
  auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
  assert(caseStmt || defaultStmt);

  uint32_t caseBB = stmtBasicBlock[stmt];
  if (!theBuilder.isCurrentBasicBlockTerminated()) {
    // We are about to handle the case passed in as a parameter. If the
    // current basic block is not terminated, the previous case falls
    // through, so link it to the case to be processed.
    theBuilder.createBranch(caseBB);
    theBuilder.addSuccessor(caseBB);
  }
  theBuilder.setInsertPoint(caseBB);
  doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
}
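
// Lowers a SwitchStmt to a native OpSelectionMerge + OpSwitch pair. As a
// rough sketch (result ids elided), for
//   switch (a) { case 0: ...; break; default: ...; }
// the generated header looks like:
//   OpSelectionMerge %switch_merge None
//   OpSwitch %a %switch_default 0 %switch_0
// and each 'break' inside a case branches to %switch_merge.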
void SPIRVEmitter::processSwitchStmtUsingSpirvOpSwitch(
    const SwitchStmt *switchStmt) {
  // First handle the condition variable DeclStmt if one exists.
  // For example: handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  const uint32_t selector = doExpr(switchStmt->getCond());

  // We need a merge block regardless of the number of switch cases.
  // Since OpSwitch always requires a default label, if the switch statement
  // does not have a default branch, we use the merge block as the default
  // target.
  const uint32_t mergeBB = theBuilder.createBasicBlock("switch.merge");
  theBuilder.setMergeTarget(mergeBB);
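  // 'break' statements inside the switch body should branch to the merge
  // block, so make it the innermost break target.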
  breakStack.push(mergeBB);

  uint32_t defaultBB = mergeBB;

  // (literal, labelId) pairs to pass to the OpSwitch instruction.
  std::vector<std::pair<uint32_t, uint32_t>> targets;
  discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);

  // Create the OpSelectionMerge and OpSwitch.
  theBuilder.createSwitch(mergeBB, selector, defaultBB, targets);

  // Handle the switch body.
  doStmt(switchStmt->getBody());

  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(mergeBB);
  theBuilder.setInsertPoint(mergeBB);
  breakStack.pop();
}
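
// Emulates a switch with an if-else-if ladder built from faux AST nodes.
// Conceptually,
//   switch (x) { case 1: a(); break; case 2: b(); break; default: c(); }
// is processed as if the source had been
//   if (x == 1) { a(); } else if (x == 2) { b(); } else { c(); }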
void SPIRVEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
  std::vector<const Stmt *> flatSwitch;
  flattenSwitchStmtAST(switchStmt->getBody(), &flatSwitch);

  // First handle the condition variable DeclStmt if one exists.
  // For example: handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  // Figure out the indexes of CaseStmts (and DefaultStmt if it exists) in
  // the flattened switch AST.
  // For instance, for the following flat vector:
  //   +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  //   |Case1|Stmt1|Case2|Stmt2|Break|Case3|Case4|Stmt4|Break|Default|Stmt5|
  //   +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  // the indexes are: {0, 2, 5, 6, 9}.
  std::vector<uint32_t> caseStmtLocs;
  for (uint32_t i = 0; i < flatSwitch.size(); ++i)
    if (isa<CaseStmt>(flatSwitch[i]) || isa<DefaultStmt>(flatSwitch[i]))
      caseStmtLocs.push_back(i);
  IfStmt *prevIfStmt = nullptr;
  IfStmt *rootIfStmt = nullptr;
  CompoundStmt *defaultBody = nullptr;

  // For each case, start at its index in the vector, and go forward
  // accumulating statements until a BreakStmt or the end of the vector is
  // reached.
  for (auto curCaseIndex : caseStmtLocs) {
    const Stmt *curCase = flatSwitch[curCaseIndex];

    // CompoundStmt to hold all statements for this case.
    CompoundStmt *cs = new (astContext) CompoundStmt(Stmt::EmptyShell());

    // Accumulate all non-case/default/break statements as the body for the
    // current case.
    std::vector<Stmt *> statements;
    for (size_t i = curCaseIndex + 1;
         i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]); ++i) {
      if (!isa<CaseStmt>(flatSwitch[i]) && !isa<DefaultStmt>(flatSwitch[i]))
        statements.push_back(const_cast<Stmt *>(flatSwitch[i]));
    }
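    // Note that fall-through duplicates statements across cases: for
    // 'case 3: case 4: Stmt4; break;' both case 3 and case 4 accumulate
    // Stmt4, which preserves fall-through semantics once the cases become
    // exclusive if/else branches.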
    if (!statements.empty())
      cs->setStmts(astContext, statements.data(), statements.size());

    // For non-default cases, generate the IfStmt that compares the switch
    // value to the case value.
    if (auto *caseStmt = dyn_cast<CaseStmt>(curCase)) {
      IfStmt *curIf = new (astContext) IfStmt(Stmt::EmptyShell());
      BinaryOperator *bo = new (astContext) BinaryOperator(Stmt::EmptyShell());
      bo->setLHS(const_cast<Expr *>(switchStmt->getCond()));
      bo->setRHS(const_cast<Expr *>(caseStmt->getLHS()));
      bo->setOpcode(BO_EQ);
      bo->setType(astContext.getLogicalOperationType());
      curIf->setCond(bo);
      curIf->setThen(cs);
      // No conditional variable is associated with this faux if statement.
      curIf->setConditionVariable(astContext, nullptr);
      // Each if statement is the "else" of the previous if statement.
      if (prevIfStmt)
        prevIfStmt->setElse(curIf);
      else
        rootIfStmt = curIf;
      prevIfStmt = curIf;
    } else {
      // Record the DefaultStmt body as it will be used as the body of the
      // "else" block in the if-elseif-...-else pattern.
      defaultBody = cs;
    }
  }
  // If a default case exists, it is the "else" of the last if statement.
  if (prevIfStmt)
    prevIfStmt->setElse(defaultBody);

  // Since all else-if and else statements are the child nodes of the first
  // IfStmt, we only need to call doStmt for the first IfStmt.
  if (rootIfStmt)
    doStmt(rootIfStmt);
  // If there are no CaseStmts and only a DefaultStmt, no if statements are
  // generated; the switch simply executes the body of the default case.
  else if (defaultBody)
    doStmt(defaultBody);
}

} // end namespace spirv
} // end namespace clang