//===------- SPIRVEmitter.cpp - SPIR-V Binary Code Emitter -----*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//
//
// This file implements a SPIR-V emitter class that takes in HLSL AST and emits
// SPIR-V binary words.
//
//===----------------------------------------------------------------------===//

#include "SPIRVEmitter.h"

#include "dxc/HlslIntrinsicOp.h"
#include "spirv-tools/optimizer.hpp"
#include "llvm/ADT/StringExtras.h"

#include "InitListHandler.h"

namespace clang {
namespace spirv {

namespace {

// Returns true if the given decl has the given semantic.
bool hasSemantic(const DeclaratorDecl *decl,
                 hlsl::DXIL::SemanticKind semanticKind) {
  using namespace hlsl;
  for (auto *annotation : decl->getUnusualAnnotations()) {
    if (auto *semanticDecl = dyn_cast<SemanticDecl>(annotation)) {
      llvm::StringRef semanticName;
      uint32_t semanticIndex = 0;
      Semantic::DecomposeNameAndIndex(semanticDecl->SemanticName, &semanticName,
                                      &semanticIndex);
      const auto *semantic = Semantic::GetByName(semanticName);
      if (semantic->GetKind() == semanticKind)
        return true;
    }
  }
  return false;
}

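/// Returns true if the given patch constant function declares at least one
/// parameter of hull shader OutputPatch type.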
bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
  for (const auto *param : pcf->parameters())
    if (hlsl::IsHLSLOutputPatchType(param->getType()))
      return true;
  return false;
}

// TODO: Maybe we should move these type probing functions to TypeTranslator.

/// Returns true if the given type is a bool or vector of bool type.
bool isBoolOrVecOfBoolType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isBooleanType();
}

/// Returns true if the given type is a signed integer or vector of signed
/// integer type.
bool isSintOrVecOfSintType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isSignedIntegerType();
}

/// Returns true if the given type is an unsigned integer or vector of unsigned
/// integer type.
bool isUintOrVecOfUintType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isUnsignedIntegerType();
}

/// Returns true if the given type is a float or vector of float type.
bool isFloatOrVecOfFloatType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isFloatingType();
}

/// Returns true if the given type is a bool or vector/matrix of bool type.
bool isBoolOrVecMatOfBoolType(QualType type) {
  return isBoolOrVecOfBoolType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isBooleanType());
}

/// Returns true if the given type is a signed integer or vector/matrix of
/// signed integer type.
bool isSintOrVecMatOfSintType(QualType type) {
  return isSintOrVecOfSintType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
}

/// Returns true if the given type is an unsigned integer or vector/matrix of
/// unsigned integer type.
bool isUintOrVecMatOfUintType(QualType type) {
  return isUintOrVecOfUintType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
}

/// Returns true if the given type is a float or vector/matrix of float type.
bool isFloatOrVecMatOfFloatType(QualType type) {
  return isFloatOrVecOfFloatType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isFloatingType());
}

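/// Returns true if the given SPIR-V opcode is one of the matrix multiplication
/// opcodes (matrix*matrix, matrix*vector, or matrix*scalar).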
inline bool isSpirvMatrixOp(spv::Op opcode) {
  return opcode == spv::Op::OpMatrixTimesMatrix ||
         opcode == spv::Op::OpMatrixTimesVector ||
         opcode == spv::Op::OpMatrixTimesScalar;
}

/// If expr is a (RW)StructuredBuffer.Load(), returns the object and writes the
/// index into *index. Otherwise, returns nullptr.
// TODO: The following doesn't handle Load(int, int) yet. And it is basically a
// duplicate of doCXXMemberCallExpr.
const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
  using namespace hlsl;

  if (const auto *indexing = dyn_cast<CXXMemberCallExpr>(expr)) {
    const auto *callee = indexing->getDirectCallee();
    uint32_t opcode = static_cast<uint32_t>(IntrinsicOp::Num_Intrinsics);
    llvm::StringRef group;

    if (GetIntrinsicOp(callee, opcode, group)) {
      if (static_cast<IntrinsicOp>(opcode) == IntrinsicOp::MOP_Load) {
        const auto *object = indexing->getImplicitObjectArgument();
        if (TypeTranslator::isStructuredBuffer(object->getType())) {
          *index = indexing->getArg(0);
          return indexing->getImplicitObjectArgument();
        }
      }
    }
  }

  return nullptr;
}

/// Returns true if the given VarDecl will be translated into a SPIR-V variable
/// not in the Private or Function storage class.
inline bool isExternalVar(const VarDecl *var) {
  // Class static variables should be put in the Private storage class.
  // groupshared variables are allowed to be declared as "static". But we still
  // need to put them in the Workgroup storage class. That is, when seeing
  // "static groupshared", ignore "static".
  return var->hasExternalFormalLinkage()
             ? !var->isStaticDataMember()
             : (var->getAttr<HLSLGroupSharedAttr>() != nullptr);
}

/// Returns the referenced variable's DeclContext if the given expr is
/// a DeclRefExpr referencing a ConstantBuffer/TextureBuffer. Otherwise,
/// returns nullptr.
const DeclContext *isConstantTextureBufferDeclRef(const Expr *expr) {
  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr->IgnoreParenCasts()))
    if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
      if (TypeTranslator::isConstantTextureBuffer(varDecl))
        return varDecl->getType()->getAs<RecordType>()->getDecl();

  return nullptr;
}

/// Returns true if
/// * the given expr is a DeclRefExpr referencing a kind of structured or byte
///   buffer and it is a non-alias one, or
/// * the given expr is a CallExpr returning a kind of structured or byte
///   buffer, or
/// * the given expr is an ArraySubscriptExpr referencing a kind of structured
///   or byte buffer.
///
/// Note: legalization specific code
bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
  expr = expr->IgnoreParenCasts();

  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    if (const auto *varDecl = dyn_cast<VarDecl>(declRefExpr->getFoundDecl()))
      if (TypeTranslator::isAKindOfStructuredOrByteBuffer(varDecl->getType()))
        return isExternalVar(varDecl);
  } else if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
    if (TypeTranslator::isAKindOfStructuredOrByteBuffer(callExpr->getType()))
      return true;
  } else if (const auto *arrSubExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
    return isReferencingNonAliasStructuredOrByteBuffer(arrSubExpr->getBase());
  }

  return false;
}

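/// Runs SPIRV-Tools legalization passes (plus invalid-opcode replacement and
/// ID compaction) over the given module in place. Returns true on success;
/// messages from the optimizer are appended to *messages.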
bool spirvToolsLegalize(spv_target_env env, std::vector<uint32_t> *module,
                        std::string *messages) {
  spvtools::Optimizer optimizer(env);

  optimizer.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  spvtools::OptimizerOptions options;
  options.set_run_validator(false);

  optimizer.RegisterLegalizationPasses();
  optimizer.RegisterPass(spvtools::CreateReplaceInvalidOpcodePass());
  optimizer.RegisterPass(spvtools::CreateCompactIdsPass());

  return optimizer.Run(module->data(), module->size(), module, options);
}

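/// Runs SPIRV-Tools optimization passes over the given module in place: the
/// default performance passes if no flags are given, otherwise the passes
/// named by the given flags. Returns true on success.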
bool spirvToolsOptimize(spv_target_env env, std::vector<uint32_t> *module,
                        const llvm::SmallVector<llvm::StringRef, 4> &flags,
                        std::string *messages) {
  spvtools::Optimizer optimizer(env);

  optimizer.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  spvtools::OptimizerOptions options;
  options.set_run_validator(false);

  if (flags.empty()) {
    optimizer.RegisterPerformancePasses();
    optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
  } else {
    // Command line options use llvm::SmallVector and llvm::StringRef, whereas
    // the SPIR-V optimizer uses std::vector and std::string.
    std::vector<std::string> stdFlags;
    for (const auto &f : flags)
      stdFlags.push_back(f.str());
    if (!optimizer.RegisterPassesFromFlags(stdFlags))
      return false;
  }

  return optimizer.Run(module->data(), module->size(), module, options);
}

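/// Validates the given module with SPIRV-Tools, using the block layout rules
/// selected by glLayout/dxLayout. Returns true if the module is valid.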
bool spirvToolsValidate(spv_target_env env, std::vector<uint32_t> *module,
                        std::string *messages, bool relaxLogicalPointer,
                        bool glLayout, bool dxLayout) {
  spvtools::SpirvTools tools(env);

  tools.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  spvtools::ValidatorOptions options;
  options.SetRelaxLogicalPointer(relaxLogicalPointer);
  // GL: strict block layout rules
  // VK: relaxed block layout rules
  // DX: skip block layout rules
  options.SetRelaxBlockLayout(!glLayout && !dxLayout);
  options.SetSkipBlockLayout(dxLayout);

  return tools.Validate(module->data(), module->size(), options);
}

/// Translates atomic HLSL opcodes into the equivalent SPIR-V opcode.
spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) {
  using namespace hlsl;
  using namespace spv;

  switch (opcode) {
  case IntrinsicOp::IOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAdd:
    return Op::OpAtomicIAdd;
  case IntrinsicOp::IOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedAnd:
    return Op::OpAtomicAnd;
  case IntrinsicOp::IOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedOr:
    return Op::OpAtomicOr;
  case IntrinsicOp::IOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedXor:
    return Op::OpAtomicXor;
  case IntrinsicOp::IOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMax:
    return Op::OpAtomicUMax;
  case IntrinsicOp::IOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedUMin:
    return Op::OpAtomicUMin;
  case IntrinsicOp::IOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMax:
    return Op::OpAtomicSMax;
  case IntrinsicOp::IOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedMin:
    return Op::OpAtomicSMin;
  case IntrinsicOp::IOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedExchange:
    return Op::OpAtomicExchange;
  default:
    // Only atomic opcodes are relevant.
    break;
  }

  assert(false && "unimplemented hlsl intrinsic opcode");
  return Op::Max;
}

// Returns true if the given opcode is an accepted binary opcode in
// OpSpecConstantOp.
bool isAcceptedSpecConstantBinaryOp(spv::Op op) {
  switch (op) {
  case spv::Op::OpIAdd:
  case spv::Op::OpISub:
  case spv::Op::OpIMul:
  case spv::Op::OpUDiv:
  case spv::Op::OpSDiv:
  case spv::Op::OpUMod:
  case spv::Op::OpSRem:
  case spv::Op::OpSMod:
  case spv::Op::OpShiftRightLogical:
  case spv::Op::OpShiftRightArithmetic:
  case spv::Op::OpShiftLeftLogical:
  case spv::Op::OpBitwiseOr:
  case spv::Op::OpBitwiseXor:
  case spv::Op::OpBitwiseAnd:
  case spv::Op::OpVectorShuffle:
  case spv::Op::OpCompositeExtract:
  case spv::Op::OpCompositeInsert:
  case spv::Op::OpLogicalOr:
  case spv::Op::OpLogicalAnd:
  case spv::Op::OpLogicalNot:
  case spv::Op::OpLogicalEqual:
  case spv::Op::OpLogicalNotEqual:
  case spv::Op::OpIEqual:
  case spv::Op::OpINotEqual:
  case spv::Op::OpULessThan:
  case spv::Op::OpSLessThan:
  case spv::Op::OpUGreaterThan:
  case spv::Op::OpSGreaterThan:
  case spv::Op::OpULessThanEqual:
  case spv::Op::OpSLessThanEqual:
  case spv::Op::OpUGreaterThanEqual:
  case spv::Op::OpSGreaterThanEqual:
    return true;
  default:
    // Accepted binary opcodes return true. Anything else is false.
    break;
  }
  return false;
}

/// Returns true if the given expression is an accepted initializer for a spec
/// constant.
bool isAcceptedSpecConstantInit(const Expr *init) {
  // Allow numeric casts
  init = init->IgnoreParenCasts();
  if (isa<CXXBoolLiteralExpr>(init) || isa<IntegerLiteral>(init) ||
      isa<FloatingLiteral>(init))
    return true;

  // Allow the minus operator which is used to specify negative values
  if (const auto *unaryOp = dyn_cast<UnaryOperator>(init))
    return unaryOp->getOpcode() == UO_Minus &&
           isAcceptedSpecConstantInit(unaryOp->getSubExpr());

  return false;
}

/// Returns true if the given function parameter can act as a shader stage
/// input parameter.
inline bool canActAsInParmVar(const ParmVarDecl *param) {
  // If the parameter has no in/out/inout attribute, it is defaulted to
  // an in parameter.
  return !param->hasAttr<HLSLOutAttr>() &&
         // GS output streams are marked as inout, but they should not be
         // used as in parameters.
         !hlsl::IsHLSLStreamOutputType(param->getType());
}

/// Returns true if the given function parameter can act as a shader stage
/// output parameter.
inline bool canActAsOutParmVar(const ParmVarDecl *param) {
  return param->hasAttr<HLSLOutAttr>() || param->hasAttr<HLSLInOutAttr>();
}

/// Returns true if the given expression is of builtin type and can be evaluated
/// to a constant zero. Returns false otherwise.
inline bool evaluatesToConstZero(const Expr *expr, ASTContext &astContext) {
  const auto type = expr->getType();
  if (!type->isBuiltinType())
    return false;

  Expr::EvalResult evalResult;
  if (expr->EvaluateAsRValue(evalResult, astContext) &&
      !evalResult.HasSideEffects) {
    const auto &val = evalResult.Val;
    return ((type->isBooleanType() && !val.getInt().getBoolValue()) ||
            (type->isIntegerType() && !val.getInt().getBoolValue()) ||
            (type->isFloatingType() && val.getFloat().isZero()));
  }

  return false;
}

/// Returns the HLSLBufferDecl if the given VarDecl is inside a cbuffer/tbuffer.
/// Returns nullptr otherwise, including when the VarDecl is a ConstantBuffer
/// or TextureBuffer itself.
inline const HLSLBufferDecl *getCTBufferContext(const VarDecl *varDecl) {
  if (const auto *bufferDecl =
          dyn_cast<HLSLBufferDecl>(varDecl->getDeclContext()))
    // Filter ConstantBuffer/TextureBuffer
    if (!bufferDecl->isConstantBufferView())
      return bufferDecl;
  return nullptr;
}

/// Returns the real definition of the callee of the given CallExpr.
///
/// If we are calling a forward-declared function, callee will be the
/// FunctionDecl for the forward-declared function, not the actual
/// definition. The forward-declaration and definition are two completely
/// different AST nodes.
inline const FunctionDecl *getCalleeDefinition(const CallExpr *expr) {
  const auto *callee = expr->getDirectCallee();

  if (callee->isThisDeclarationADefinition())
    return callee;

  // We need to update callee to the actual definition here
  if (!callee->isDefined(callee))
    return nullptr;

  return callee;
}

/// Returns the referenced definition. The given expr is expected to be a
/// DeclRefExpr or CallExpr after ignoring casts. Returns nullptr otherwise.
const DeclaratorDecl *getReferencedDef(const Expr *expr) {
  if (!expr)
    return nullptr;

  expr = expr->IgnoreParenCasts();

  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    return dyn_cast_or_null<DeclaratorDecl>(declRefExpr->getDecl());
  }

  if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
    return getCalleeDefinition(callExpr);
  }

  return nullptr;
}

/// Returns the number of base classes if this type is a derived class/struct.
/// Returns zero otherwise.
inline uint32_t getNumBaseClasses(QualType type) {
  if (const auto *cxxDecl = type->getAsCXXRecordDecl())
    return cxxDecl->getNumBases();
  return 0;
}

/// Gets the index sequence of casting a derived object to a base object by
/// following the cast chain.
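/// For example, given `struct X {}; struct A {}; struct B : A {};
/// struct C : X, B {};`, casting a C object to A follows C -> B -> A and
/// yields the index sequence {1, 0}: B is base #1 of C, and A is base #0 of B.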
void getBaseClassIndices(const CastExpr *expr,
                         llvm::SmallVectorImpl<uint32_t> *indices) {
  assert(expr->getCastKind() == CK_UncheckedDerivedToBase ||
         expr->getCastKind() == CK_HLSLDerivedToBase);

  indices->clear();

  QualType derivedType = expr->getSubExpr()->getType();
  const auto *derivedDecl = derivedType->getAsCXXRecordDecl();

  // Go through the base cast chain: for each derived-to-base cast, find the
  // index of the base in question among the derived class's bases.
  for (auto pathIt = expr->path_begin(), pathIe = expr->path_end();
       pathIt != pathIe; ++pathIt) {
    // The type of the base in question
    const auto baseType = (*pathIt)->getType();

    uint32_t index = 0;
    for (auto baseIt = derivedDecl->bases_begin(),
              baseIe = derivedDecl->bases_end();
         baseIt != baseIe; ++baseIt, ++index)
      if (baseIt->getType() == baseType) {
        indices->push_back(index);
        break;
      }
    assert(index < derivedDecl->getNumBases());

    // Continue to the next base in the chain
    derivedType = baseType;
    derivedDecl = derivedType->getAsCXXRecordDecl();
  }
}
spv::Capability getCapabilityForGroupNonUniform(spv::Op opcode) {
  switch (opcode) {
  case spv::Op::OpGroupNonUniformElect:
    return spv::Capability::GroupNonUniform;
  case spv::Op::OpGroupNonUniformAny:
  case spv::Op::OpGroupNonUniformAll:
  case spv::Op::OpGroupNonUniformAllEqual:
    return spv::Capability::GroupNonUniformVote;
  case spv::Op::OpGroupNonUniformBallot:
  case spv::Op::OpGroupNonUniformBallotBitCount:
  case spv::Op::OpGroupNonUniformBroadcast:
  case spv::Op::OpGroupNonUniformBroadcastFirst:
    return spv::Capability::GroupNonUniformBallot;
  case spv::Op::OpGroupNonUniformIAdd:
  case spv::Op::OpGroupNonUniformFAdd:
  case spv::Op::OpGroupNonUniformIMul:
  case spv::Op::OpGroupNonUniformFMul:
  case spv::Op::OpGroupNonUniformSMax:
  case spv::Op::OpGroupNonUniformUMax:
  case spv::Op::OpGroupNonUniformFMax:
  case spv::Op::OpGroupNonUniformSMin:
  case spv::Op::OpGroupNonUniformUMin:
  case spv::Op::OpGroupNonUniformFMin:
  case spv::Op::OpGroupNonUniformBitwiseAnd:
  case spv::Op::OpGroupNonUniformBitwiseOr:
  case spv::Op::OpGroupNonUniformBitwiseXor:
    return spv::Capability::GroupNonUniformArithmetic;
  case spv::Op::OpGroupNonUniformQuadBroadcast:
  case spv::Op::OpGroupNonUniformQuadSwap:
    return spv::Capability::GroupNonUniformQuad;
  default:
    break;
  }
  // All group non-uniform opcodes should have been handled above.
  assert(false && "unhandled opcode");
  return spv::Capability::Max;
}
std::string getNamespacePrefix(const Decl *decl) {
  std::string nsPrefix = "";
  const DeclContext *dc = decl->getDeclContext();
  while (dc && !dc->isTranslationUnit()) {
    if (const NamespaceDecl *ns = dyn_cast<NamespaceDecl>(dc)) {
      if (!ns->isAnonymousNamespace()) {
        nsPrefix = ns->getName().str() + "::" + nsPrefix;
      }
    }
    dc = dc->getParent();
  }
  return nsPrefix;
}
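/// Returns the source-level name used for the given function: namespace
/// prefix, then the enclosing struct/class name (if any), then the function
/// name. For example, a method `f` of struct `S` inside namespace `ns` is
/// named "ns::S.f". (Example names are illustrative.)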
std::string getFnName(const FunctionDecl *fn) {
  // Prefix the function name with the struct name if necessary
  std::string classOrStructName = "";
  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(fn))
    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
      classOrStructName = st->getName().str() + ".";
  return getNamespacePrefix(fn) + classOrStructName + fn->getName().str();
}
/// Returns the capability required to non-uniformly index into the given type.
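/// e.g. non-uniformly indexing into an array of textures, as in
///   gTextures[NonUniformResourceIndex(idx)].Sample(gSampler, uv)
/// requires SampledImageArrayNonUniformIndexingEXT. (Illustrative HLSL;
/// gTextures, gSampler, idx, and uv are placeholder names.)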
spv::Capability getNonUniformCapability(QualType type) {
  using spv::Capability;

  if (type->isArrayType()) {
    return getNonUniformCapability(
        type->getAsArrayTypeUnsafe()->getElementType());
  }
  if (TypeTranslator::isTexture(type) || TypeTranslator::isSampler(type)) {
    return Capability::SampledImageArrayNonUniformIndexingEXT;
  }
  if (TypeTranslator::isRWTexture(type)) {
    return Capability::StorageImageArrayNonUniformIndexingEXT;
  }
  if (TypeTranslator::isBuffer(type)) {
    return Capability::UniformTexelBufferArrayNonUniformIndexingEXT;
  }
  if (TypeTranslator::isRWBuffer(type)) {
    return Capability::StorageTexelBufferArrayNonUniformIndexingEXT;
  }
  if (const auto *recordType = type->getAs<RecordType>()) {
    const auto name = recordType->getDecl()->getName();
    if (name == "SubpassInput" || name == "SubpassInputMS") {
      return Capability::InputAttachmentArrayNonUniformIndexingEXT;
    }
  }

  return Capability::Max;
}
} // namespace
SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
    : theCompilerInstance(ci), astContext(ci.getASTContext()),
      diags(ci.getDiagnostics()),
      spirvOptions(ci.getCodeGenOpts().SpirvOptions),
      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
      shaderModel(*hlsl::ShaderModel::GetByName(
          ci.getCodeGenOpts().HLSLProfile.c_str())),
      theContext(), featureManager(diags, spirvOptions),
      theBuilder(&theContext, &featureManager, spirvOptions),
      typeTranslator(astContext, theBuilder, diags, spirvOptions),
      declIdMapper(shaderModel, astContext, theBuilder, *this, typeTranslator,
                   featureManager, spirvOptions),
      entryFunctionId(0), curFunction(nullptr), curThis(0),
      seenPushConstantAt(), isSpecConstantMode(false),
      foundNonUniformResourceIndex(false), needsLegalization(false),
      mainSourceFileId(0) {
  if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
    emitError("unknown shader model: %0", {}) << shaderModel.GetName();

  if (spirvOptions.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
      !shaderModel.IsGS())
    emitError("-fvk-invert-y can only be used in VS/DS/GS", {});

  if (spirvOptions.useGlLayout && spirvOptions.useDxLayout)
    emitError("cannot specify both -fvk-use-dx-layout and -fvk-use-gl-layout",
              {});
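  // Pick c/t/s-buffer layout rules based on the command-line flags, e.g.
  // invoking (illustrative command line)
  //   dxc -T ps_6_0 -E main -spirv -fvk-use-dx-layout shader.hlsl
  // selects the fxc-compatible packing rules below.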
  if (spirvOptions.useDxLayout) {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer;
  } else if (spirvOptions.useGlLayout) {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
  } else {
    spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140;
    spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
    spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
  }
  // Set shader model version
  theBuilder.setShaderModelVersion(shaderModel.GetMajor(),
                                   shaderModel.GetMinor());

  // Set debug info
  const auto &inputFiles = ci.getFrontendOpts().Inputs;
  if (spirvOptions.debugInfoFile && !inputFiles.empty()) {
    // File name
    mainSourceFileId = theContext.takeNextId();
    theBuilder.setSourceFileName(mainSourceFileId,
                                 inputFiles.front().getFile().str());

    // Source code
    const auto &sm = ci.getSourceManager();
    const llvm::MemoryBuffer *mainFile =
        sm.getBuffer(sm.getMainFileID(), SourceLocation());
    theBuilder.setSourceFileContent(
        StringRef(mainFile->getBufferStart(), mainFile->getBufferSize()));
  }
}
void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
  // Stop translating if there are errors in previous compilation stages.
  if (context.getDiagnostics().hasErrorOccurred())
    return;

  TranslationUnitDecl *tu = context.getTranslationUnitDecl();

  // The entry function is the seed of the queue.
  for (auto *decl : tu->decls()) {
    if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
      if (funcDecl->getName() == entryFunctionName) {
        workQueue.insert(funcDecl);
      }
    } else {
      doDecl(decl);
    }
  }

  // Translate all functions reachable from the entry function.
  // The queue can grow in the meantime, so we need to keep re-evaluating
  // workQueue.size().
  for (uint32_t i = 0; i < workQueue.size(); ++i) {
    doDecl(workQueue[i]);
  }
  if (context.getDiagnostics().hasErrorOccurred())
    return;

  const spv_target_env targetEnv = featureManager.getTargetEnv();

  AddRequiredCapabilitiesForShaderModel();

  // Addressing and memory model are required in a valid SPIR-V module.
  theBuilder.setAddressingModel(spv::AddressingModel::Logical);
  theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);

  theBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunctionId,
                           entryFunctionName, declIdMapper.collectStageVars());

  // Add Location decorations to stage input/output variables.
  if (!declIdMapper.decorateStageIOLocations())
    return;

  // Add descriptor set and binding decorations to resource variables.
  if (!declIdMapper.decorateResourceBindings())
    return;

  // Output the constructed module.
  std::vector<uint32_t> m = theBuilder.takeModule();

  if (!spirvOptions.codeGenHighLevel) {
    // Run legalization passes
    if (needsLegalization || declIdMapper.requiresLegalization()) {
      std::string messages;
      if (!spirvToolsLegalize(targetEnv, &m, &messages)) {
        emitFatalError("failed to legalize SPIR-V: %0", {}) << messages;
        emitNote("please file a bug report on "
                 "https://github.com/Microsoft/DirectXShaderCompiler/issues "
                 "with source code if possible",
                 {});
        return;
      } else if (!messages.empty()) {
        emitWarning("SPIR-V legalization: %0", {}) << messages;
      }
    }

    // Run optimization passes
    if (theCompilerInstance.getCodeGenOpts().OptimizationLevel > 0) {
      std::string messages;
      if (!spirvToolsOptimize(targetEnv, &m, spirvOptions.optConfig,
                              &messages)) {
        emitFatalError("failed to optimize SPIR-V: %0", {}) << messages;
        emitNote("please file a bug report on "
                 "https://github.com/Microsoft/DirectXShaderCompiler/issues "
                 "with source code if possible",
                 {});
        return;
      }
    }
  }

  // Validate the generated SPIR-V code
  if (!spirvOptions.disableValidation) {
    std::string messages;
    if (!spirvToolsValidate(
            targetEnv, &m, &messages, declIdMapper.requiresLegalization(),
            spirvOptions.useGlLayout, spirvOptions.useDxLayout)) {
      emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
      emitNote("please file a bug report on "
               "https://github.com/Microsoft/DirectXShaderCompiler/issues "
               "with source code if possible",
               {});
      return;
    }
  }

  theCompilerInstance.getOutStream()->write(
      reinterpret_cast<const char *>(m.data()), m.size() * 4);
}
void SPIRVEmitter::doDecl(const Decl *decl) {
  if (decl->isImplicit() || isa<EmptyDecl>(decl) || isa<TypedefDecl>(decl))
    return;

  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
    // to emit their cbuffer/tbuffer as a whole and access each individual one
    // using access chains.
    if (const auto *bufferDecl = getCTBufferContext(varDecl)) {
      doHLSLBufferDecl(bufferDecl);
    } else {
      doVarDecl(varDecl);
    }
  } else if (const auto *namespaceDecl = dyn_cast<NamespaceDecl>(decl)) {
    for (auto *subDecl : namespaceDecl->decls())
      // Note: We only emit functions as they are discovered through the call
      // graph starting from the entry-point. We should not emit unused
      // functions inside namespaces.
      if (!isa<FunctionDecl>(subDecl))
        doDecl(subDecl);
  } else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
    doFunctionDecl(funcDecl);
  } else if (const auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
    doHLSLBufferDecl(bufferDecl);
  } else if (const auto *recordDecl = dyn_cast<RecordDecl>(decl)) {
    doRecordDecl(recordDecl);
  } else {
    emitError("decl type %0 unimplemented", decl->getLocation())
        << decl->getDeclKindName();
  }
}
void SPIRVEmitter::doStmt(const Stmt *stmt,
                          llvm::ArrayRef<const Attr *> attrs) {
  if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
    for (auto *st : compoundStmt->body())
      doStmt(st);
  } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
    doReturnStmt(retStmt);
  } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
    doDeclStmt(declStmt);
  } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
    doIfStmt(ifStmt, attrs);
  } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
    doSwitchStmt(switchStmt, attrs);
  } else if (isa<CaseStmt>(stmt) || isa<DefaultStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
    doBreakStmt(breakStmt);
  } else if (const auto *theDoStmt = dyn_cast<DoStmt>(stmt)) {
    doDoStmt(theDoStmt, attrs);
  } else if (const auto *discardStmt = dyn_cast<DiscardStmt>(stmt)) {
    doDiscardStmt(discardStmt);
  } else if (const auto *continueStmt = dyn_cast<ContinueStmt>(stmt)) {
    doContinueStmt(continueStmt);
  } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
    doWhileStmt(whileStmt, attrs);
  } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
    doForStmt(forStmt, attrs);
  } else if (isa<NullStmt>(stmt)) {
    // The null statement ";" requires no work.
  } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
    // All cases for expressions used as statements
    doExpr(expr);
  } else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
    doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
  } else {
    emitError("statement class '%0' unimplemented", stmt->getLocStart())
        << stmt->getStmtClassName() << stmt->getSourceRange();
  }
}
SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
  SpirvEvalInfo result(/*id*/ 0);

  // Provide a hint to the typeTranslator that if a literal is discovered, its
  // intended usage is as this expression's type.
  TypeTranslator::LiteralTypeHint hint(typeTranslator, expr->getType());

  expr = expr->IgnoreParens();

  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    result = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl());
  } else if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
    result = doMemberExpr(memberExpr);
  } else if (const auto *castExpr = dyn_cast<CastExpr>(expr)) {
    result = doCastExpr(castExpr);
  } else if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
    result = doInitListExpr(initListExpr);
  } else if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
    const auto value =
        theBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode);
    result = SpirvEvalInfo(value).setConstant().setRValue();
  } else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
    const auto value = translateAPInt(intLiteral->getValue(), expr->getType());
    result = SpirvEvalInfo(value).setConstant().setRValue();
  } else if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
    const auto value =
        translateAPFloat(floatLiteral->getValue(), expr->getType());
    result = SpirvEvalInfo(value).setConstant().setRValue();
  } else if (const auto *compoundAssignOp =
                 dyn_cast<CompoundAssignOperator>(expr)) {
    // CompoundAssignOperator is a subclass of BinaryOperator. It should be
    // checked before BinaryOperator.
    result = doCompoundAssignOperator(compoundAssignOp);
  } else if (const auto *binOp = dyn_cast<BinaryOperator>(expr)) {
    result = doBinaryOperator(binOp);
  } else if (const auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
    result = doUnaryOperator(unaryOp);
  } else if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    result = doHLSLVectorElementExpr(vecElemExpr);
  } else if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
    result = doExtMatrixElementExpr(matElemExpr);
  } else if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
    result = doCallExpr(funcCall);
  } else if (const auto *subscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
    result = doArraySubscriptExpr(subscriptExpr);
  } else if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
    result = doConditionalOperator(condExpr);
  } else if (const auto *defaultArgExpr = dyn_cast<CXXDefaultArgExpr>(expr)) {
    result = doExpr(defaultArgExpr->getParam()->getDefaultArg());
  } else if (isa<CXXThisExpr>(expr)) {
    assert(curThis);
    result = curThis;
  } else {
    emitError("expression class '%0' unimplemented", expr->getExprLoc())
        << expr->getStmtClassName() << expr->getSourceRange();
  }

  return result;
}
SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) {
  // We are trying to load the value here, which is what an LValueToRValue
  // implicit cast is intended to do. We can ignore the cast if it exists.
  expr = expr->IgnoreParenLValueCasts();
  return loadIfGLValue(expr, doExpr(expr));
}
SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
                                          SpirvEvalInfo info) {
  // Do nothing if this is already an rvalue
  if (info.isRValue())
    return info;

  // Check whether we are trying to load an array of opaque objects as a whole.
  // If true, we are likely to copy it as a whole. To assist per-element
  // copying, avoid the load here and return the pointer directly.
  // TODO: consider moving this hack into SPIRV-Tools as a transformation.
  if (TypeTranslator::isOpaqueArrayType(expr->getType()))
    return info;

  // Check whether we are trying to load an externally visible structured/byte
  // buffer as a whole. If true, it means we are creating an alias for it.
  // Avoid the load and write the pointer directly to the alias variable
  // instead.
  //
  // The same goes for alias function returns: if we are trying to load an
  // alias function return as a whole, it means we are assigning it to another
  // alias variable. Avoid the load and write the pointer directly.
  //
  // Note: legalization specific code
  if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
    return info.setRValue();
  }

  if (loadIfAliasVarRef(expr, info)) {
    // We are loading an alias variable as a whole here. This is likely for
    // wholesale assignments or function returns. Need to load the pointer.
    //
    // Note: legalization specific code
    return info;
  }

  uint32_t valType = 0;
  // TODO: Ouch. Very hacky. We need a special path to get the value type if
  // we are loading a whole ConstantBuffer/TextureBuffer since the normal
  // type translation path won't work.
  if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
    valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
  } else {
    valType =
        typeTranslator.translateType(expr->getType(), info.getLayoutRule());
  }
  uint32_t loadedId = theBuilder.createLoad(valType, info);

  // Decorate with NonUniformEXT if loading from a pointer with that property.
  // We are likely loading an element from a resource array here.
  if (info.isNonUniform()) {
    theBuilder.decorateNonUniformEXT(loadedId);
  }

  // Special-case: According to the SPIR-V spec, there is no physical size or
  // bit pattern defined for the boolean type. Therefore an unsigned integer
  // is used to represent booleans when layout is required. In such cases,
  // after loading the uint, we should perform a comparison.
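  // For example, a `bool b;` member in a cbuffer occupies a uint slot, so a
  // load of it roughly lowers to (illustrative SPIR-V):
  //   %val = OpLoad %uint %ptr
  //   %b   = OpINotEqual %bool %val %uint_0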
  {
    uint32_t vecSize = 1, numRows = 0, numCols = 0;
    if (info.getLayoutRule() != SpirvLayoutRule::Void &&
        isBoolOrVecMatOfBoolType(expr->getType())) {
      const auto exprType = expr->getType();
      QualType uintType = astContext.UnsignedIntTy;
      QualType boolType = astContext.BoolTy;
      if (TypeTranslator::isScalarType(exprType) ||
          TypeTranslator::isVectorType(exprType, nullptr, &vecSize)) {
        const auto fromType =
            vecSize == 1 ? uintType
                         : astContext.getExtVectorType(uintType, vecSize);
        const auto toType =
            vecSize == 1 ? boolType
                         : astContext.getExtVectorType(boolType, vecSize);
        loadedId = castToBool(loadedId, fromType, toType);
      } else {
        const bool isMat =
            TypeTranslator::isMxNMatrix(exprType, nullptr, &numRows, &numCols);
        assert(isMat);
        (void)isMat;

        const auto uintRowQualType =
            astContext.getExtVectorType(uintType, numCols);
        const auto uintRowQualTypeId =
            typeTranslator.translateType(uintRowQualType);
        const auto boolRowQualType =
            astContext.getExtVectorType(boolType, numCols);
        const auto boolRowQualTypeId =
            typeTranslator.translateType(boolRowQualType);
        const uint32_t resultTypeId =
            theBuilder.getMatType(boolType, boolRowQualTypeId, numRows);

        llvm::SmallVector<uint32_t, 4> rows;
        for (uint32_t i = 0; i < numRows; ++i) {
          const auto row = theBuilder.createCompositeExtract(uintRowQualTypeId,
                                                             loadedId, {i});
          rows.push_back(castToBool(row, uintRowQualType, boolRowQualType));
        }
        loadedId = theBuilder.createCompositeConstruct(resultTypeId, rows);
      }
      // Now that it is converted to bool, it has no layout rule.
      // This result-id should be evaluated as bool from here on out.
      info.setLayoutRule(SpirvLayoutRule::Void);
    }
  }

  return info.setResultId(loadedId).setRValue();
}
SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) {
  auto info = doExpr(expr);
  loadIfAliasVarRef(expr, info);
  return info;
}
bool SPIRVEmitter::loadIfAliasVarRef(const Expr *varExpr, SpirvEvalInfo &info) {
  if (info.containsAliasComponent() &&
      TypeTranslator::isAKindOfStructuredOrByteBuffer(varExpr->getType())) {
    // Aliased-to variables are all in the Uniform storage class with GLSL
    // std430 layout rules.
    const auto ptrType = typeTranslator.translateType(varExpr->getType());

    // Load the pointer of the aliased-to variable if the expression has a
    // pointer-to-pointer type. That is, the expression itself is an lvalue.
    // (Note that we translate alias function return values as pointer types,
    // not pointer-to-pointer types.)
    if (varExpr->isGLValue())
      info.setResultId(theBuilder.createLoad(ptrType, info));

    info.setStorageClass(spv::StorageClass::Uniform)
        .setLayoutRule(spirvOptions.sBufferLayoutRule)
        // Now it is a pointer to the global resource, which is an lvalue.
        .setRValue(false)
        // Set to false to indicate that we've performed the dereference over
        // the pointer-to-pointer and now should fall back to the normal path.
        .setContainsAliasComponent(false);

    return true;
  }

  return false;
}
uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
                                  QualType toType, SourceLocation srcLoc) {
  if (isFloatOrVecOfFloatType(toType))
    return castToFloat(value, fromType, toType, srcLoc);

  // Order matters here. Bool (vector) values will also be considered as uint
  // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType()
  // will also return true. We need to check bool before uint. The opposite is
  // not true.
  if (isBoolOrVecOfBoolType(toType))
    return castToBool(value, fromType, toType);

  if (isSintOrVecOfSintType(toType) || isUintOrVecOfUintType(toType))
    return castToInt(value, fromType, toType, srcLoc);

  emitError("casting to type %0 unimplemented", {}) << toType;
  return 0;
}
void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
  assert(decl->isThisDeclarationADefinition());

  // A RAII class for maintaining the current function under traversal.
  class FnEnvRAII {
  public:
    // Creates a new instance which sets fnEnv to newFn on creation,
    // and resets fnEnv to its original value on destruction.
    FnEnvRAII(const FunctionDecl **fnEnv, const FunctionDecl *newFn)
        : oldFn(*fnEnv), fnSlot(fnEnv) {
      *fnEnv = newFn;
    }
    ~FnEnvRAII() { *fnSlot = oldFn; }

  private:
    const FunctionDecl *oldFn;
    const FunctionDecl **fnSlot;
  };

  FnEnvRAII fnEnvRAII(&curFunction, decl);

  // We are about to start translation for a new function. Clear the break
  // stack and the continue stack.
  breakStack = std::stack<uint32_t>();
  continueStack = std::stack<uint32_t>();

  // This will allow the entry-point name to be something like
  // myNamespace::myEntrypointFunc.
  std::string funcName = getFnName(decl);

  uint32_t funcId = 0;

  if (funcName == entryFunctionName) {
    // The entry function surely does not have a pre-assigned <result-id>
    // like other functions that got added to the work queue following
    // function calls.
    funcId = theContext.takeNextId();
    funcName = "src." + funcName;

    // Create wrapper for the entry function
    if (!emitEntryFunctionWrapper(decl, funcId))
      return;
  } else {
    // Non-entry functions are added to the work queue following function
    // calls. We have already assigned <result-id>s for them when translating
    // their call sites. Query it here.
    funcId = declIdMapper.getDeclEvalInfo(decl);
  }

  const uint32_t retType =
      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);

  // Construct the function signature.
  llvm::SmallVector<uint32_t, 4> paramTypes;

  bool isNonStaticMemberFn = false;
  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
    isNonStaticMemberFn = !memberFn->isStatic();

    if (isNonStaticMemberFn) {
      // For a non-static member function, the first parameter should be the
      // object on which we are invoking this method.
      const uint32_t valueType = typeTranslator.translateType(
          memberFn->getThisType(astContext)->getPointeeType());
      const uint32_t ptrType =
          theBuilder.getPointerType(valueType, spv::StorageClass::Function);
      paramTypes.push_back(ptrType);
    }
  }

  for (const auto *param : decl->params()) {
    const uint32_t valueType =
        declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
    const uint32_t ptrType =
        theBuilder.getPointerType(valueType, spv::StorageClass::Function);
    paramTypes.push_back(ptrType);
  }

  const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
  theBuilder.beginFunction(funcType, retType, funcName, funcId);

  if (isNonStaticMemberFn) {
    // Remember the parameter for the this object so later we can handle
    // CXXThisExpr correctly.
    curThis = theBuilder.addFnParam(paramTypes[0], "param.this");
  }

  // Create all parameters.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const ParmVarDecl *paramDecl = decl->getParamDecl(i);
    (void)declIdMapper.createFnParam(paramDecl);
  }

  if (decl->hasBody()) {
    // The entry basic block.
    const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
    theBuilder.setInsertPoint(entryLabel);

    // Process all statements in the body.
    doStmt(decl->getBody());

    // We have processed all Stmts in this function and are now in the last
    // basic block. Make sure we have a termination instruction.
    if (!theBuilder.isCurrentBasicBlockTerminated()) {
      const auto retType = decl->getReturnType();

      if (retType->isVoidType()) {
        theBuilder.createReturn();
      } else {
        // If the source code does not provide a proper return value for some
        // control flow path, it's undefined behavior. We just return a null
        // value here.
        theBuilder.createReturnValue(
            theBuilder.getConstantNull(typeTranslator.translateType(retType)));
      }
    }
  }

  theBuilder.endFunction();
}
bool SPIRVEmitter::validateVKAttributes(const NamedDecl *decl) {
  bool success = true;

  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    const auto varType = varDecl->getType();
    if ((TypeTranslator::isSubpassInput(varType) ||
         TypeTranslator::isSubpassInputMS(varType)) &&
        !varDecl->hasAttr<VKInputAttachmentIndexAttr>()) {
      emitError("missing vk::input_attachment_index attribute",
                varDecl->getLocation());
      success = false;
    }
  }

  if (decl->getAttr<VKInputAttachmentIndexAttr>()) {
    if (!shaderModel.IsPS()) {
      emitError("SubpassInput(MS) only allowed in pixel shader",
                decl->getLocation());
      success = false;
    }

    if (!decl->isExternallyVisible()) {
      emitError("SubpassInput(MS) must be externally visible",
                decl->getLocation());
      success = false;
    }

    // We only allow VKInputAttachmentIndexAttr to be attached to global
    // variables. So it should be fine to cast here.
    const auto elementType =
        hlsl::GetHLSLResourceResultType(cast<VarDecl>(decl)->getType());

    if (!TypeTranslator::isScalarType(elementType) &&
        !TypeTranslator::isVectorType(elementType)) {
      emitError(
          "only scalar/vector types allowed as SubpassInput(MS) parameter type",
          decl->getLocation());
      // Return directly to avoid further type processing, which will hit
      // asserts in TypeTranslator.
      return false;
    }
  }

  // The frontend will make sure that
  //   * vk::push_constant applies to global variables of struct type
  //   * vk::binding applies to global variables or cbuffers/tbuffers
  //   * vk::counter_binding applies to global variables of RW/Append/Consume
  //     StructuredBuffer type
  //   * vk::location applies to function parameters/returns and struct fields
  // So the only case we need to check for co-existence is vk::push_constant
  // and vk::binding.
  if (const auto *pcAttr = decl->getAttr<VKPushConstantAttr>()) {
    const auto loc = pcAttr->getLocation();

    if (seenPushConstantAt.isInvalid()) {
      seenPushConstantAt = loc;
    } else {
      // TODO: Actually this is slightly incorrect. The Vulkan spec says:
      //   There must be no more than one push constant block statically used
      //   per shader entry point.
      // But we are checking whether more than one push constant block is
      // defined. Tracking usage requires more work.
      emitError("cannot have more than one push constant block", loc);
      emitNote("push constant block previously defined here",
               seenPushConstantAt);
      success = false;
    }

    if (decl->hasAttr<VKBindingAttr>()) {
      emitError("vk::push_constant attribute cannot be used together with "
                "vk::binding attribute",
                loc);
      success = false;
    }
  }

  return success;
}
void SPIRVEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
  // This is a cbuffer/tbuffer decl.

  // Check and emit warnings for member initializers, which are not
  // supported in Vulkan.
  for (const auto *member : bufferDecl->decls()) {
    if (const auto *varMember = dyn_cast<VarDecl>(member)) {
      if (!spirvOptions.noWarnIgnoredFeatures) {
        if (const auto *init = varMember->getInit())
          emitWarning("%select{tbuffer|cbuffer}0 member initializer "
                      "ignored since no Vulkan equivalent",
                      init->getExprLoc())
              << bufferDecl->isCBuffer() << init->getSourceRange();
      }

      // We cannot handle external initialization of column-major matrices now.
      if (typeTranslator.isOrContainsNonFpColMajorMatrix(varMember->getType(),
                                                         varMember)) {
        emitError("externally initialized non-floating-point column-major "
                  "matrices not supported yet",
                  varMember->getLocation());
      }
    }
  }

  if (!validateVKAttributes(bufferDecl))
    return;

  (void)declIdMapper.createCTBuffer(bufferDecl);
}
void SPIRVEmitter::doRecordDecl(const RecordDecl *recordDecl) {
  // Ignore implicit records.
  // Somehow we'll have implicit records with:
  //   static const int Length = count;
  // that can mess up the normal CodeGen.
  if (recordDecl->isImplicit())
    return;

  // Handle each static member with an inline initializer.
  // Each static member has a corresponding VarDecl inside the
  // RecordDecl. For those defined in the translation unit,
  // their VarDecls do not have initializers.
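  // For example (illustrative HLSL):
  //   struct S {
  //     static const int Count = 4; // inline initializer: handled below
  //   };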
  for (auto *subDecl : recordDecl->decls())
    if (auto *varDecl = dyn_cast<VarDecl>(subDecl))
      if (varDecl->isStaticDataMember() && varDecl->hasInit())
        doVarDecl(varDecl);
}
void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
  if (!validateVKAttributes(decl))
    return;

  // We cannot handle external initialization of column-major matrices now.
  if (isExternalVar(decl) &&
      typeTranslator.isOrContainsNonFpColMajorMatrix(decl->getType(), decl)) {
    emitError("externally initialized non-floating-point column-major "
              "matrices not supported yet",
              decl->getLocation());
  }

  // Reject arrays of RW/append/consume structured buffers. They have
  // associated counters, which are quite nasty to handle.
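  // For example, the following declaration is rejected (illustrative HLSL):
  //   RWStructuredBuffer<float4> gBuffers[4];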
  if (decl->getType()->isArrayType()) {
    auto type = decl->getType();
    do {
      type = type->getAsArrayTypeUnsafe()->getElementType();
    } while (type->isArrayType());

    if (TypeTranslator::isRWAppendConsumeSBuffer(type)) {
      emitError("arrays of RW/append/consume structured buffers unsupported",
                decl->getLocation());
      return;
    }
  }
  if (decl->hasAttr<VKConstantIdAttr>()) {
    // This is a VarDecl for a specialization constant.
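    // e.g. declared in HLSL as (illustrative):
    //   [[vk::constant_id(0)]] const int kSampleCount = 1;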
    createSpecConstant(decl);
    return;
  }

  if (decl->hasAttr<VKPushConstantAttr>()) {
    // This is a VarDecl for a PushConstant block.
    (void)declIdMapper.createPushConstant(decl);
    return;
  }

  if (isa<HLSLBufferDecl>(decl->getDeclContext())) {
    // This is a VarDecl of ConstantBuffer/TextureBuffer type.
    (void)declIdMapper.createCTBuffer(decl);
    return;
  }

  SpirvEvalInfo varId(0);

  // The contents of externally visible variables can be updated via the
  // pipeline. They should be handled differently from file and function scope
  // variables.
  // File scope variables (static "global" and "local" variables) belong to
  // the Private storage class, while function scope variables (normal "local"
  // variables) belong to the Function storage class.
  if (isExternalVar(decl)) {
    varId = declIdMapper.createExternVar(decl);
  } else {
    // We already know the variable is not externally visible here. If it does
    // not have local storage, it should be a file scope variable.
    const bool isFileScopeVar = !decl->hasLocalStorage();

    if (isFileScopeVar)
      varId = declIdMapper.createFileVar(decl, llvm::None);
    else
      varId = declIdMapper.createFnVar(decl, llvm::None);

    // Emit OpStore to initialize the variable
    // TODO: revert back to use OpVariable initializer

    // We should only evaluate the initializer once for a static variable.
    if (isFileScopeVar) {
      if (decl->isStaticLocal()) {
        initOnce(decl->getType(), decl->getName(), varId, decl->getInit());
      } else {
        // Defer initializing these global variables to the beginning of the
        // entry function.
        toInitGloalVars.push_back(decl);
      }
    }
    // Function local variables. Just emit OpStore at the current insert point.
    else if (const Expr *init = decl->getInit()) {
      if (const auto constId = tryToEvaluateAsConst(init))
        theBuilder.createStore(varId, constId);
      else
        storeValue(varId, loadIfGLValue(init), decl->getType());

      // Update counter variables associated with local variables
      tryToAssignCounterVar(decl, init);
    }

    // Variables that are not externally visible and of opaque types should
    // request legalization.
    if (!needsLegalization && TypeTranslator::isOpaqueType(decl->getType()))
      needsLegalization = true;
  }

  if (TypeTranslator::isRelaxedPrecisionType(decl->getType(), spirvOptions)) {
    theBuilder.decorateRelaxedPrecision(varId);
  }

  // All variables that are of opaque struct types should request legalization.
  if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType()))
    needsLegalization = true;
}
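/// Translates an HLSL loop attribute into the corresponding SPIR-V loop
/// control mask, e.g. (illustrative HLSL):
///   [unroll] for (int i = 0; i < 4; ++i) { ... }  // -> Unroll
///   [loop]   for (int i = 0; i < n; ++i) { ... }  // -> DontUnroll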
spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Stmt *stmt,
                                                          const Attr &attr) {
  switch (attr.getKind()) {
  case attr::HLSLLoop:
  case attr::HLSLFastOpt:
    return spv::LoopControlMask::DontUnroll;
  case attr::HLSLUnroll:
    return spv::LoopControlMask::Unroll;
  case attr::HLSLAllowUAVCondition:
    if (!spirvOptions.noWarnIgnoredFeatures) {
      emitWarning("unsupported allow_uav_condition attribute ignored",
                  stmt->getLocStart());
    }
    break;
  default:
    llvm_unreachable("found unknown loop attribute");
  }
  return spv::LoopControlMask::MaskNone;
}
void SPIRVEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  theBuilder.createKill();

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statements that
  // may follow them.
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}
void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt,
                            llvm::ArrayRef<const Attr *> attrs) {
  // do-while loops are composed of:
  //
  //   do {
  //     <body>
  //   } while (<check>);
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though do-while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> can be performed in the <continue> basic block.
  // The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //     +----------+
  //     |  header  | <-----------------------------------+
  //     +----------+                                     |
  //          |                                           | (true)
  //          v                                           |
  //      +------+       +--------------------+           |
  //      | body | ----> | continue (<check>) |-----------+
  //      +------+       +--------------------+
  //                               |
  //                               | (false)
  //      +-------+                |
  //      | merge | <--------------+
  //      +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(theDoStmt, *attrs.front());

  // Create basic blocks
  const uint32_t headerBB = theBuilder.createBasicBlock("do_while.header");
  const uint32_t bodyBB = theBuilder.createBasicBlock("do_while.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("do_while.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("do_while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Branch from the current insert point to the header block.
  theBuilder.createBranch(headerBB);
  theBuilder.addSuccessor(headerBB);

  // Process the <header> block.
  // The header block must always branch to the body.
  theBuilder.setInsertPoint(headerBB);
  theBuilder.createBranch(bodyBB, mergeBB, continueBB, loopControl);
  theBuilder.addSuccessor(bodyBB);
  // The current basic block has an OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = theDoStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block. The check for whether the loop should
  // continue lies in the continue block.
  // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur
  // in a continue block of a loop, there should be no OpSelectionMerge. Only
  // an OpBranchConditional must be specified.
  theBuilder.setInsertPoint(continueBB);
  uint32_t condition = 0;
  if (const Expr *check = theDoStmt->getCond()) {
    emitDebugLine(check->getLocStart());
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, headerBB, mergeBB);
  theBuilder.addSuccessor(headerBB);
  theBuilder.addSuccessor(mergeBB);

  // Set insertion point to the <merge> block for subsequent statements
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}
void SPIRVEmitter::doContinueStmt(const ContinueStmt *continueStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  const uint32_t continueTargetBB = continueStack.top();
  theBuilder.createBranch(continueTargetBB);
  theBuilder.addSuccessor(continueTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statements that
  // may follow them. For example, StmtB and StmtC below are put inside a new
  // basic block which is unreachable:
  //
  //   while (true) {
  //     StmtA;
  //     continue;
  //     StmtB;
  //     StmtC;
  //   }
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}
void SPIRVEmitter::doWhileStmt(const WhileStmt *whileStmt,
                               llvm::ArrayRef<const Attr *> attrs) {
  // While loops are composed of:
  //   while (<check>) { <body> }
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> block can take the responsibility of the <header>
  // block. The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //            +----------+
  //            |  header  | <------------------+
  //            | (check)  |                    |
  //            +----------+                    |
  //                 |                          |
  //         +-------+-------+                  |
  //         | false         | true             |
  //         |               v                  |
  //         |            +------+     +------------------+
  //         |            | body | --> | continue (no-op) |
  //         v            +------+     +------------------+
  //     +-------+
  //     | merge |
  //     +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(whileStmt, *attrs.front());

  // Create basic blocks
  const uint32_t checkBB = theBuilder.createBasicBlock("while.check");
  const uint32_t bodyBB = theBuilder.createBasicBlock("while.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("while.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <check> block
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);
  theBuilder.setInsertPoint(checkBB);

  // If we have:
  //   while (int a = foo()) { ... }
  // we should evaluate 'a' by calling 'foo()' every single time the check has
  // to occur.
  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
    doStmt(condVarDecl);

  uint32_t condition = 0;
  if (const Expr *check = whileStmt->getCond()) {
    emitDebugLine(check->getLocStart());
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, bodyBB,
                                     /*false branch*/ mergeBB,
                                     /*merge*/ mergeBB, continueBB,
                                     spv::SelectionControlMask::MaskNone,
                                     loopControl);
  theBuilder.addSuccessor(bodyBB);
  theBuilder.addSuccessor(mergeBB);
  // The current basic block has an OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = whileStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block. While loops do not have an explicit
  // continue block. The continue block just branches to the <check> block.
  theBuilder.setInsertPoint(continueBB);
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);

  // Set insertion point to the <merge> block for subsequent statements
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue and merge blocks.
  continueStack.pop();
  breakStack.pop();
}
void SPIRVEmitter::doForStmt(const ForStmt *forStmt,
                             llvm::ArrayRef<const Attr *> attrs) {
  // for loops are composed of:
  //   for (<init>; <check>; <continue>) <body>
  //
  // To translate a for loop, we'll need to emit all <init> statements
  // in the current basic block, and then have separate basic blocks for
  // <check>, <continue>, and <body>. Besides, since SPIR-V requires
  // structured control flow, we need two more basic blocks, <header>
  // and <merge>. <header> is the block before control flow diverges,
  // while <merge> is the block where control flow subsequently converges.
  // The <check> block can take the responsibility of the <header> block.
  // The final CFG should normally be like the following. Exceptions will
  // occur with non-local exits like loop breaks or early returns.
  //
  //             +--------+
  //             |  init  |
  //             +--------+
  //                 |
  //                 v
  //            +----------+
  //            |  header  | <---------------+
  //            | (check)  |                 |
  //            +----------+                 |
  //                 |                       |
  //         +-------+-------+               |
  //         | false         | true          |
  //         |               v               |
  //         |            +------+     +----------+
  //         |            | body | --> | continue |
  //         v            +------+     +----------+
  //     +-------+
  //     | merge |
  //     +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(forStmt, *attrs.front());

  // Create basic blocks
  const uint32_t checkBB = theBuilder.createBasicBlock("for.check");
  const uint32_t bodyBB = theBuilder.createBasicBlock("for.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("for.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("for.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <init> block
  if (const Stmt *initStmt = forStmt->getInit()) {
    emitDebugLine(initStmt->getLocStart());
    doStmt(initStmt);
  }
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);

  // Process the <check> block
  theBuilder.setInsertPoint(checkBB);
  uint32_t condition;
  if (const Expr *check = forStmt->getCond()) {
    emitDebugLine(check->getLocStart());
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, bodyBB,
                                     /*false branch*/ mergeBB,
                                     /*merge*/ mergeBB, continueBB,
                                     spv::SelectionControlMask::MaskNone,
                                     loopControl);
  theBuilder.addSuccessor(bodyBB);
  theBuilder.addSuccessor(mergeBB);
  // The current basic block has an OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = forStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block
  theBuilder.setInsertPoint(continueBB);
  if (const Expr *cont = forStmt->getInc()) {
    emitDebugLine(cont->getLocStart());
    doExpr(cont);
  }
  theBuilder.createBranch(checkBB); // <continue> should jump back to header
  theBuilder.addSuccessor(checkBB);

  // Set insertion point to the <merge> block for subsequent statements
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}
void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt,
                            llvm::ArrayRef<const Attr *> attrs) {
  // if statements are composed of:
  //   if (<check>) { <then> } else { <else> }
  //
  // To translate if statements, we'll need to emit the <check> expressions
  // in the current basic block, and then create separate basic blocks for
  // <then> and <else>. Additionally, we'll need a <merge> block as per
  // SPIR-V's structured control flow requirements. Depending on whether the
  // else branch exists, the final CFG should normally be like the following.
  // Exceptions will occur with non-local exits like loop breaks or early
  // returns.
  //
  //              +-------+                     +-------+
  //              | check |                     | check |
  //              +-------+                     +-------+
  //                  |                             |
  //          +-------+-------+               +-----+-----+
  //          | true          | false         | true      | false
  //          v               v      or       v           |
  //      +------+        +------+        +------+        |
  //      | then |        | else |        | then |        |
  //      +------+        +------+        +------+        |
  //          |               |               |           v
  //          |   +-------+   |               |       +-------+
  //          +-> | merge | <-+               +-----> | merge |
  //              +-------+                           +-------+
  { // Try to see if we can const-eval the condition
    bool condition = false;
    if (ifStmt->getCond()->EvaluateAsBooleanCondition(condition, astContext)) {
      if (condition) {
        doStmt(ifStmt->getThen());
      } else if (ifStmt->getElse()) {
        doStmt(ifStmt->getElse());
      }
      return;
    }
  }

  auto selectionControl = spv::SelectionControlMask::MaskNone;
  if (!attrs.empty()) {
    const Attr *attribute = attrs.front();
    switch (attribute->getKind()) {
    case attr::HLSLBranch:
      selectionControl = spv::SelectionControlMask::DontFlatten;
      break;
    case attr::HLSLFlatten:
      selectionControl = spv::SelectionControlMask::Flatten;
      break;
    default:
      if (!spirvOptions.noWarnIgnoredFeatures) {
        emitWarning("unknown if statement attribute '%0' ignored",
                    attribute->getLocation())
            << attribute->getSpelling();
      }
      break;
    }
  }

  if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt())
    doDeclStmt(declStmt);

  emitDebugLine(ifStmt->getCond()->getLocStart());

  // First emit the instruction for evaluating the condition.
  const uint32_t condition = doExpr(ifStmt->getCond());

  // Then we need to emit the instruction for the conditional branch.
  // We'll need the <label-id>s for the then/else/merge blocks to do so.
  const bool hasElse = ifStmt->getElse() != nullptr;
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  const uint32_t elseBB =
      hasElse ? theBuilder.createBasicBlock("if.false") : mergeBB;

  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB,
                                     /*continue*/ 0, selectionControl);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(elseBB);
  // The current basic block has the OpSelectionMerge instruction. We need
  // to record its merge target.
  theBuilder.setMergeTarget(mergeBB);

  // Handle the then branch
  theBuilder.setInsertPoint(thenBB);
  doStmt(ifStmt->getThen());
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);

  // Handle the else branch (if it exists)
  if (hasElse) {
    theBuilder.setInsertPoint(elseBB);
    doStmt(ifStmt->getElse());
    if (!theBuilder.isCurrentBasicBlockTerminated())
      theBuilder.createBranch(mergeBB);
    theBuilder.addSuccessor(mergeBB);
  }

  // From now on, we'll emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
}

void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
  if (const auto *retVal = stmt->getRetValue()) {
    // Update counter variable associated with function returns
    tryToAssignCounterVar(curFunction, retVal);

    const auto retInfo = loadIfGLValue(retVal);
    const auto retType = retVal->getType();
    if (retInfo.getStorageClass() != spv::StorageClass::Function &&
        retType->isStructureType()) {
      // We are returning some value from a non-Function storage class. Need to
      // create a temporary variable to "convert" the value to Function storage
      // class and then return.
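      // e.g. (illustrative): returning a struct that was read straight out of
      // a cbuffer would otherwise hand back data in the Uniform storage class.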
      const uint32_t valType = typeTranslator.translateType(retType);
      const uint32_t tempVar = theBuilder.addFnVar(valType, "temp.var.ret");
      storeValue(tempVar, retInfo, retType);
      theBuilder.createReturnValue(theBuilder.createLoad(valType, tempVar));
    } else {
      theBuilder.createReturnValue(retInfo);
    }
  } else {
    theBuilder.createReturn();
  }

  // We are translating a ReturnStmt, so we should be in some function's body.
  assert(curFunction->hasBody());
  // If this return statement is the last statement in the function, then
  // we have no more work to do.
  if (cast<CompoundStmt>(curFunction->getBody())->body_back() == stmt)
    return;

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. In this case, the newly created basic block will contain
  // any statement that may come after an early return.
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doBreakStmt(const BreakStmt *breakStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  uint32_t breakTargetBB = breakStack.top();
  theBuilder.addSuccessor(breakTargetBB);
  theBuilder.createBranch(breakTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. For example: StmtB and StmtC below are put inside a new
  // basic block which is unreachable.
  //
  //   while (true) {
  //     StmtA;
  //     break;
  //     StmtB;
  //     StmtC;
  //   }
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
                                llvm::ArrayRef<const Attr *> attrs) {
  // Switch statements are composed of:
  //   switch (<condition variable>) {
  //     <CaseStmt>
  //     <CaseStmt>
  //     <CaseStmt>
  //     <DefaultStmt> (optional)
  //   }
  //
  //                     +-------+
  //                     | check |
  //                     +-------+
  //                         |
  //     +---------+---------+--------------+-----------------+
  //     | 1       | 2                      | 3               | (others)
  //     v         v                        v                 v
  // +-------+ +-------------+          +-------+       +------------+
  // | case1 | | case2       |          | case3 |  ...  | default    |
  // |       | |(fallthrough)|--------->|       |       | (optional) |
  // +-------+ +-------------+          +-------+       +------------+
  //     |                                  |                 |
  //     |     +-------+                    |                 |
  //     |     |       | <------------------+                 |
  //     +---> | merge |                                      |
  //           |       | <------------------------------------+
  //           +-------+
  //
  // If no attributes are given, or if the "forcecase" attribute was provided,
  // we'll do our best to use OpSwitch if possible.
  // If any of the cases compares to a variable (rather than an integer
  // literal), we cannot use OpSwitch because OpSwitch expects literal
  // numbers as parameters.
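  //
  // For illustration (hypothetical HLSL), this switch can map to OpSwitch:
  //   [forcecase] switch (i) { case 0: r = 1; break; default: r = 2; break; }
  // while this one falls back to a series of if statements, since 'n' is a
  // variable rather than an integer literal:
  //   switch (i) { case n: r = 1; break; default: r = 2; break; }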
  const bool isAttrForceCase =
      !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase;
  const bool canUseSpirvOpSwitch =
      (attrs.empty() || isAttrForceCase) &&
      allSwitchCasesAreIntegerLiterals(switchStmt->getBody());

  if (isAttrForceCase && !canUseSpirvOpSwitch &&
      !spirvOptions.noWarnIgnoredFeatures) {
    emitWarning("ignored 'forcecase' attribute for the switch statement "
                "since one or more case values are not integer literals",
                switchStmt->getLocStart());
  }

  if (canUseSpirvOpSwitch)
    processSwitchStmtUsingSpirvOpSwitch(switchStmt);
  else
    processSwitchStmtUsingIfStmts(switchStmt);
}

SpirvEvalInfo
SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
  // Make sure we don't have a previously unhandled NonUniformResourceIndex().
  assert(!foundNonUniformResourceIndex);

  llvm::SmallVector<uint32_t, 4> indices;
  const auto *base = collectArrayStructIndices(expr, &indices);
  auto info = loadIfAliasVarRef(base);

  if (foundNonUniformResourceIndex) {
    // Add the necessary capability required for indexing into this kind
    // of resource
    theBuilder.requireCapability(getNonUniformCapability(base->getType()));
    info.setNonUniform(); // Carry forward the NonUniformEXT decoration
    foundNonUniformResourceIndex = false;
  }

  if (!indices.empty()) {
    (void)turnIntoElementPtr(base->getType(), info, expr->getType(), indices);
  }

  return info;
}

SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Handle assignment first since we need to evaluate rhs before lhs.
  // For other binary operations, we need to evaluate lhs before rhs.
  if (opcode == BO_Assign) {
    // Update counter variable associated with lhs of assignments
    tryToAssignCounterVar(expr->getLHS(), expr->getRHS());

    return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
                             /*isCompoundAssignment=*/false);
  }

  // Try to optimize floatMxN * float and floatN * float case
  if (opcode == BO_Mul) {
    if (SpirvEvalInfo result = tryToGenFloatMatrixScale(expr))
      return result;
    if (SpirvEvalInfo result = tryToGenFloatVectorScale(expr))
      return result;
  }

  return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode,
                         expr->getLHS()->getType(), expr->getType(),
                         expr->getSourceRange());
}

SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
  emitDebugLine(callExpr->getLocStart());
  if (const auto *operatorCall = dyn_cast<CXXOperatorCallExpr>(callExpr))
    return doCXXOperatorCallExpr(operatorCall);

  if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr))
    return doCXXMemberCallExpr(memberCall);

  // Intrinsic functions such as 'dot' or 'mul'
  if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) {
    return processIntrinsicCallExpr(callExpr);
  }

  // Normal standalone functions
  return processCall(callExpr);
}

SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
  const FunctionDecl *callee = getCalleeDefinition(callExpr);

  // Note that we always want the definition because Stmts/Exprs in the
  // function body reference the parameters in the definition.
  if (!callee) {
    emitError("found undefined function", callExpr->getExprLoc());
    return 0;
  }

  const auto numParams = callee->getNumParams();

  bool isNonStaticMemberCall = false;
  QualType objectType = {};         // Type of the object (if exists)
  SpirvEvalInfo objectEvalInfo = 0; // EvalInfo for the object (if exists)
  bool needsTempVar = false;        // Whether we need temporary variable.

  llvm::SmallVector<uint32_t, 4> vars;      // Variables for function call
  llvm::SmallVector<bool, 4> isTempVar;     // Temporary variable or not
  llvm::SmallVector<SpirvEvalInfo, 4> args; // Evaluated arguments

  if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
    const auto *memberFn = cast<CXXMethodDecl>(memberCall->getCalleeDecl());
    isNonStaticMemberCall = !memberFn->isStatic();

    if (isNonStaticMemberCall) {
      // For non-static member calls, evaluate the object and pass it as the
      // first argument.
      const auto *object = memberCall->getImplicitObjectArgument();
      object = object->IgnoreParenNoopCasts(astContext);

      // Update counter variable associated with the implicit object
      tryToAssignCounterVar(getOrCreateDeclForMethodObject(memberFn), object);

      objectType = object->getType();
      objectEvalInfo = doExpr(object);
      uint32_t objectId = objectEvalInfo;

      // If not already a variable, we need to create a temporary variable and
      // pass the object pointer to the function. Example:
      //   getObject().objectMethod();
      // Also, any parameter passed to the member function must be of Function
      // storage class.
      needsTempVar =
          objectEvalInfo.isRValue() ||
          objectEvalInfo.getStorageClass() != spv::StorageClass::Function;

      if (needsTempVar) {
        objectId =
            createTemporaryVar(objectType, TypeTranslator::getName(objectType),
                               // May need to load to use as initializer
                               loadIfGLValue(object, objectEvalInfo));
      }

      args.push_back(objectId);

      // We do not need to create a new temporary variable for the this
      // object. Use the evaluated argument.
      vars.push_back(args.back());
      isTempVar.push_back(false);
    }
  }

  // Evaluate parameters
  for (uint32_t i = 0; i < numParams; ++i) {
    // We want the argument variable here so that we can write back to it
    // later. We will do the OpLoad of this argument manually. So ignore
    // the LValueToRValue implicit cast here.
    auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts();
    const auto *param = callee->getParamDecl(i);

    // Get the evaluation info if this argument is referencing some variable
    // *as a whole*, in which case we can avoid creating the temporary variable
    // for it if it is of Function scope and can act as out parameter.
    SpirvEvalInfo argInfo = 0;
    if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(arg)) {
      argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl());
    }

    if (argInfo && argInfo.getStorageClass() == spv::StorageClass::Function &&
        canActAsOutParmVar(param)) {
      vars.push_back(argInfo);
      isTempVar.push_back(false);
      args.push_back(doExpr(arg));
    } else {
      // We need to create variables for holding the values to be used as
      // arguments. The variables themselves are of pointer types.
      const uint32_t varType =
          declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
      const std::string varName = "param.var." + param->getNameAsString();
      const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);

      vars.push_back(tempVarId);
      isTempVar.push_back(true);
      args.push_back(doExpr(arg));

      // Update counter variable associated with function parameters
      tryToAssignCounterVar(param, arg);

      // Manually load the argument here
      const auto rhsVal = loadIfGLValue(arg, args.back());
      // Initialize the temporary variables using the contents of the arguments
      storeValue(tempVarId, rhsVal, param->getType());
    }
  }

  assert(vars.size() == isTempVar.size());
  assert(vars.size() == args.size());

  // Push the callee into the work queue if it is not there.
  if (!workQueue.count(callee)) {
    workQueue.insert(callee);
  }

  const uint32_t retType =
      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(callee);
  // Get or forward declare the function <result-id>
  const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);

  const uint32_t retVal = theBuilder.createFunctionCall(retType, funcId, vars);

  // If we created a temporary variable for the lvalue object this method is
  // invoked upon, we need to copy the contents in the temporary variable back
  // to the original object's variable in case there are side effects.
  if (needsTempVar && !objectEvalInfo.isRValue()) {
    const uint32_t typeId = typeTranslator.translateType(objectType);
    const uint32_t value = theBuilder.createLoad(typeId, vars.front());
    storeValue(objectEvalInfo, value, objectType);
  }

  // Go through all parameters and write those marked as out/inout
  for (uint32_t i = 0; i < numParams; ++i) {
    const auto *param = callee->getParamDecl(i);
    if (isTempVar[i] && canActAsOutParmVar(param)) {
      const auto *arg = callExpr->getArg(i);
      const uint32_t index = i + isNonStaticMemberCall;
      const uint32_t typeId = typeTranslator.translateType(param->getType());
      const uint32_t value = theBuilder.createLoad(typeId, vars[index]);
      processAssignment(arg, value, false, args[index]);
    }
  }

  // Inherit the SpirvEvalInfo from the function definition
  return declIdMapper.getDeclEvalInfo(callee).setResultId(retVal);
}

SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
  const Expr *subExpr = expr->getSubExpr();
  const QualType subExprType = subExpr->getType();
  const QualType toType = expr->getType();

  // Unfortunately the front-end fails to deduce some types in certain cases.
  // Provide a hint about literal type usage if possible.
  TypeTranslator::LiteralTypeHint hint(typeTranslator);

  // 'literal int' to 'float' conversion. If a literal integer is to be used as
  // a 32-bit float, the hint is a 32-bit integer.
  if (toType->isFloatingType() &&
      subExprType->isSpecificBuiltinType(BuiltinType::LitInt) &&
      llvm::APFloat::getSizeInBits(astContext.getFloatTypeSemantics(toType)) ==
          32)
    hint.setHint(astContext.IntTy);
  // 'literal float' to 'float' conversion where intended type is float32.
  if (toType->isFloatingType() &&
      subExprType->isSpecificBuiltinType(BuiltinType::LitFloat) &&
      llvm::APFloat::getSizeInBits(astContext.getFloatTypeSemantics(toType)) ==
          32)
    hint.setHint(astContext.FloatTy);

  // TODO: We could provide other useful hints. For instance:
  // For the case of toType being a boolean, if the fromType is a literal float,
  // we could provide a FloatTy hint and if the fromType is a literal integer,
  // we could provide an IntTy hint. The front-end, however, seems to deduce the
  // correct type in these cases; therefore we currently don't provide any
  // additional hints.

  switch (expr->getCastKind()) {
  case CastKind::CK_LValueToRValue:
    return loadIfGLValue(subExpr);
  case CastKind::CK_NoOp:
    return doExpr(subExpr);
  case CastKind::CK_IntegralCast:
  case CastKind::CK_FloatingToIntegral:
  case CastKind::CK_HLSLCC_IntegralCast:
  case CastKind::CK_HLSLCC_FloatingToIntegral: {
    // Integer literals in the AST are represented using 64bit APInt
    // themselves and then implicitly casted into the expected bitwidth.
    // We need special treatment of integer literals here because generating
    // a 64bit constant and then explicit casting in SPIR-V requires Int64
    // capability. We should avoid introducing unnecessary capabilities to
    // our best.
    if (const uint32_t valueId = tryToEvaluateAsConst(expr))
      return SpirvEvalInfo(valueId).setConstant().setRValue();

    const auto valueId =
        castToInt(doExpr(subExpr), subExprType, toType, subExpr->getExprLoc());
    return SpirvEvalInfo(valueId).setRValue();
  }
  case CastKind::CK_FloatingCast:
  case CastKind::CK_IntegralToFloating:
  case CastKind::CK_HLSLCC_FloatingCast:
  case CastKind::CK_HLSLCC_IntegralToFloating: {
    // First try to see if we can do constant folding for floating point
    // numbers like what we are doing for integers in the above.
    if (const uint32_t valueId = tryToEvaluateAsConst(expr))
      return SpirvEvalInfo(valueId).setConstant().setRValue();

    const auto valueId = castToFloat(doExpr(subExpr), subExprType, toType,
                                     subExpr->getExprLoc());
    return SpirvEvalInfo(valueId).setRValue();
  }
  case CastKind::CK_IntegralToBoolean:
  case CastKind::CK_FloatingToBoolean:
  case CastKind::CK_HLSLCC_IntegralToBoolean:
  case CastKind::CK_HLSLCC_FloatingToBoolean: {
    // First try to see if we can do constant folding.
    if (const uint32_t valueId = tryToEvaluateAsConst(expr))
      return SpirvEvalInfo(valueId).setConstant().setRValue();

    const auto valueId = castToBool(doExpr(subExpr), subExprType, toType);
    return SpirvEvalInfo(valueId).setRValue();
  }
  case CastKind::CK_HLSLVectorSplat: {
    const size_t size = hlsl::GetHLSLVecSize(expr->getType());
    return createVectorSplat(subExpr, size);
  }
  case CastKind::CK_HLSLVectorTruncationCast: {
    const uint32_t toVecTypeId = typeTranslator.translateType(toType);
    const uint32_t elemTypeId =
        typeTranslator.translateType(hlsl::GetHLSLVecElementType(toType));
    const auto toSize = hlsl::GetHLSLVecSize(toType);

    const uint32_t composite = doExpr(subExpr);
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < toSize; ++i) {
      elements.push_back(
          theBuilder.createCompositeExtract(elemTypeId, composite, {i}));
    }

    auto valueId = elements.front();
    if (toSize > 1)
      valueId = theBuilder.createCompositeConstruct(toVecTypeId, elements);

    return SpirvEvalInfo(valueId).setRValue();
  }
  case CastKind::CK_HLSLVectorToScalarCast: {
    // The underlying should already be a vector of size 1.
    assert(hlsl::GetHLSLVecSize(subExprType) == 1);
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLVectorToMatrixCast: {
    // If the target type is already a 1xN matrix type, we just return the
    // underlying vector.
    if (TypeTranslator::is1xNMatrix(toType))
      return doExpr(subExpr);

    // A vector can have no more than 4 elements. The only remaining case
    // is casting from a size-4 vector to a size-2-by-2 matrix.
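    // e.g. (illustrative HLSL):
    //   float4 v = ...;
    //   float2x2 m = (float2x2)v;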
    const auto vec = loadIfGLValue(subExpr);

    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    const bool isMat =
        TypeTranslator::isMxNMatrix(toType, &elemType, &rowCount, &colCount);

    assert(isMat && rowCount == 2 && colCount == 2);
    (void)isMat;

    uint32_t vec2Type =
        theBuilder.getVecType(typeTranslator.translateType(elemType), 2);
    const auto subVec1 =
        theBuilder.createVectorShuffle(vec2Type, vec, vec, {0, 1});
    const auto subVec2 =
        theBuilder.createVectorShuffle(vec2Type, vec, vec, {2, 3});

    const auto mat = theBuilder.createCompositeConstruct(
        theBuilder.getMatType(elemType, vec2Type, 2), {subVec1, subVec2});

    return SpirvEvalInfo(mat).setRValue();
  }
  case CastKind::CK_HLSLMatrixSplat: {
    // From scalar to matrix
    uint32_t rowCount = 0, colCount = 0;
    hlsl::GetHLSLMatRowColCount(toType, rowCount, colCount);

    // Handle degenerated cases first
    if (rowCount == 1 && colCount == 1)
      return doExpr(subExpr);

    if (colCount == 1)
      return createVectorSplat(subExpr, rowCount);

    const auto vecSplat = createVectorSplat(subExpr, colCount);
    if (rowCount == 1)
      return vecSplat;

    const uint32_t matType = typeTranslator.translateType(toType);
    llvm::SmallVector<uint32_t, 4> vectors(size_t(rowCount), vecSplat);

    if (vecSplat.isConstant()) {
      const auto valueId = theBuilder.getConstantComposite(matType, vectors);
      return SpirvEvalInfo(valueId).setConstant().setRValue();
    } else {
      const auto valueId =
          theBuilder.createCompositeConstruct(matType, vectors);
      return SpirvEvalInfo(valueId).setRValue();
    }
  }
  case CastKind::CK_HLSLMatrixTruncationCast: {
    const QualType srcType = subExprType;
    const uint32_t srcId = doExpr(subExpr);
    const QualType elemType = hlsl::GetHLSLMatElementType(srcType);
    const uint32_t dstTypeId = typeTranslator.translateType(toType);
    llvm::SmallVector<uint32_t, 4> indexes;

    // It is possible that the source matrix is in fact a vector.
    // For example: Truncate float1x3 --> float1x2.
    // The front-end disallows float1x3 --> float2x1.
    {
      uint32_t srcVecSize = 0, dstVecSize = 0;
      if (TypeTranslator::isVectorType(srcType, nullptr, &srcVecSize) &&
          TypeTranslator::isVectorType(toType, nullptr, &dstVecSize)) {
        for (uint32_t i = 0; i < dstVecSize; ++i)
          indexes.push_back(i);
        const auto valId =
            theBuilder.createVectorShuffle(dstTypeId, srcId, srcId, indexes);
        return SpirvEvalInfo(valId).setRValue();
      }
    }

    uint32_t srcRows = 0, srcCols = 0, dstRows = 0, dstCols = 0;
    hlsl::GetHLSLMatRowColCount(srcType, srcRows, srcCols);
    hlsl::GetHLSLMatRowColCount(toType, dstRows, dstCols);
    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
    const uint32_t srcRowType = theBuilder.getVecType(elemTypeId, srcCols);

    // Indexes to pass to OpVectorShuffle
    for (uint32_t i = 0; i < dstCols; ++i)
      indexes.push_back(i);

    llvm::SmallVector<uint32_t, 4> extractedVecs;
    for (uint32_t row = 0; row < dstRows; ++row) {
      // Extract a row
      uint32_t rowId =
          theBuilder.createCompositeExtract(srcRowType, srcId, {row});
      // Extract the necessary columns from that row.
      // The front-end ensures dstCols <= srcCols.
      // If dstCols equals srcCols, we can use the whole row directly.
      if (dstCols == 1) {
        rowId = theBuilder.createCompositeExtract(elemTypeId, rowId, {0});
      } else if (dstCols < srcCols) {
        rowId = theBuilder.createVectorShuffle(
            theBuilder.getVecType(elemTypeId, dstCols), rowId, rowId, indexes);
      }
      extractedVecs.push_back(rowId);
    }

    uint32_t valId = extractedVecs.front();
    if (extractedVecs.size() > 1) {
      valId = theBuilder.createCompositeConstruct(
          typeTranslator.translateType(toType), extractedVecs);
    }
    return SpirvEvalInfo(valId).setRValue();
  }
  case CastKind::CK_HLSLMatrixToScalarCast: {
    // The underlying should already be a 1x1 matrix.
    assert(TypeTranslator::is1x1Matrix(subExprType));
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLMatrixToVectorCast: {
    // The underlying should already be a 1xN or Mx1 matrix.
    assert(TypeTranslator::is1xNMatrix(subExprType) ||
           TypeTranslator::isMx1Matrix(subExprType));
    return doExpr(subExpr);
  }
  case CastKind::CK_FunctionToPointerDecay:
    // Just need to return the function id
    return doExpr(subExpr);
  case CastKind::CK_FlatConversion: {
    uint32_t subExprId = 0;
    QualType evalType = subExprType;

    // Optimization: we can use OpConstantNull for cases where we want to
    // initialize an entire data structure to zeros.
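    // e.g. (illustrative HLSL): MyStruct s = (MyStruct)0;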
    if (evaluatesToConstZero(subExpr, astContext)) {
      subExprId =
          theBuilder.getConstantNull(typeTranslator.translateType(toType));
      return SpirvEvalInfo(subExprId).setRValue().setConstant();
    }

    TypeTranslator::LiteralTypeHint hint(typeTranslator);

    // Try to evaluate float literals as float rather than double.
    if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(subExpr)) {
      subExprId = tryToEvaluateAsFloat32(floatLiteral->getValue());
      if (subExprId)
        evalType = astContext.FloatTy;
    }
    // Evaluate 'literal float' initializer type as float rather than double.
    // TODO: This could result in rounding error if the initializer is a
    // non-literal expression that requires larger than 32 bits and has the
    // 'literal float' type.
    else if (subExprType->isSpecificBuiltinType(BuiltinType::LitFloat)) {
      evalType = astContext.FloatTy;
      hint.setHint(astContext.FloatTy);
    }
    // Try to evaluate integer literals as 32-bit int rather than 64-bit int.
    else if (const auto *intLiteral = dyn_cast<IntegerLiteral>(subExpr)) {
      const bool isSigned = subExprType->isSignedIntegerType();
      subExprId = tryToEvaluateAsInt32(intLiteral->getValue(), isSigned);
      if (subExprId)
        evalType = isSigned ? astContext.IntTy : astContext.UnsignedIntTy;
    }
    // For assigning one array instance to another one with the same array type
    // (regardless of constness and literalness), the rhs will be wrapped in a
    // FlatConversion:
    //   |- <lhs>
    //   `- ImplicitCastExpr <FlatConversion>
    //      `- ImplicitCastExpr <LValueToRValue>
    //         `- <rhs>
    // This FlatConversion does not affect CodeGen, so we can ignore it.
    else if (subExprType->isArrayType() &&
             typeTranslator.isSameType(expr->getType(), subExprType)) {
      return doExpr(subExpr);
    }

    if (!subExprId)
      subExprId = doExpr(subExpr);
    const auto valId =
        processFlatConversion(toType, evalType, subExprId, expr->getExprLoc());
    return SpirvEvalInfo(valId).setRValue();
  }
  case CastKind::CK_UncheckedDerivedToBase:
  case CastKind::CK_HLSLDerivedToBase: {
    // Find the index sequence of the base to which we are casting
    llvm::SmallVector<uint32_t, 4> baseIndices;
    getBaseClassIndices(expr, &baseIndices);

    // Turn them into SPIR-V constants
    for (uint32_t i = 0; i < baseIndices.size(); ++i)
      baseIndices[i] = theBuilder.getConstantUint32(baseIndices[i]);

    auto derivedInfo = doExpr(subExpr);
    return turnIntoElementPtr(subExpr->getType(), derivedInfo, expr->getType(),
                              baseIndices);
  }
  default:
    emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc())
        << expr->getCastKindName() << expr->getSourceRange();
    expr->dump();
    return 0;
  }
}

uint32_t SPIRVEmitter::processFlatConversion(const QualType type,
                                             const QualType initType,
                                             const uint32_t initId,
                                             SourceLocation srcLoc) {
  // Try to translate the canonical type first
  const auto canonicalType = type.getCanonicalType();
  if (canonicalType != type)
    return processFlatConversion(canonicalType, initType, initId, srcLoc);

  // Primitive types
  {
    QualType ty = {};
    if (TypeTranslator::isScalarType(type, &ty)) {
      if (const auto *builtinType = ty->getAs<BuiltinType>()) {
        switch (builtinType->getKind()) {
        case BuiltinType::Void: {
          emitError("cannot create a constant of void type", srcLoc);
          return 0;
        }
        case BuiltinType::Bool:
          return castToBool(initId, initType, ty);
        // Target type is an integer variant.
        case BuiltinType::Int:
        case BuiltinType::Short:
        case BuiltinType::Min12Int:
        case BuiltinType::Min16Int:
        case BuiltinType::Min16UInt:
        case BuiltinType::UShort:
        case BuiltinType::UInt:
        case BuiltinType::Long:
        case BuiltinType::LongLong:
        case BuiltinType::ULong:
        case BuiltinType::ULongLong:
          return castToInt(initId, initType, ty, srcLoc);
        // Target type is a float variant.
        case BuiltinType::Double:
        case BuiltinType::Float:
        case BuiltinType::Half:
        case BuiltinType::HalfFloat:
        case BuiltinType::Min10Float:
        case BuiltinType::Min16Float:
          return castToFloat(initId, initType, ty, srcLoc);
        default:
          emitError("flat conversion of type %0 unimplemented", srcLoc)
              << builtinType->getTypeClassName();
          return 0;
        }
      }
    }
  }
  // Vector types
  {
    QualType elemType = {};
    uint32_t elemCount = {};
    if (TypeTranslator::isVectorType(type, &elemType, &elemCount)) {
      const uint32_t elemId =
          processFlatConversion(elemType, initType, initId, srcLoc);
      llvm::SmallVector<uint32_t, 4> constituents(size_t(elemCount), elemId);
      return theBuilder.createCompositeConstruct(
          typeTranslator.translateType(type), constituents);
    }
  }
  // Matrix types
  {
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
      // By default HLSL matrices are row major, while SPIR-V matrices are
      // column major. We are mapping what HLSL semantically means as a row
      // into a column here.
      const uint32_t vecType = theBuilder.getVecType(
          typeTranslator.translateType(elemType), colCount);
      const uint32_t elemId =
          processFlatConversion(elemType, initType, initId, srcLoc);
      const llvm::SmallVector<uint32_t, 4> constituents(size_t(colCount),
                                                        elemId);
      const uint32_t colId =
          theBuilder.createCompositeConstruct(vecType, constituents);
      const llvm::SmallVector<uint32_t, 4> rows(size_t(rowCount), colId);
      return theBuilder.createCompositeConstruct(
          typeTranslator.translateType(type), rows);
    }
  }

  // Struct type
  if (const auto *structType = type->getAs<RecordType>()) {
    const auto *decl = structType->getDecl();
    llvm::SmallVector<uint32_t, 4> fields;

    for (const auto *field : decl->fields()) {
      // There is a special case for FlatConversion. If T is a struct with only
      // one member, S, then (T)<an-instance-of-S> is allowed, which essentially
      // constructs a new T instance using the instance of S as its only member.
      // Check whether we are handling that case here first.
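      // e.g. (illustrative HLSL):
      //   struct T { S s; };
      //   T t = (T)myS;  // myS : S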
      if (field->getType().getCanonicalType() == initType.getCanonicalType()) {
        fields.push_back(initId);
      } else {
        fields.push_back(
            processFlatConversion(field->getType(), initType, initId, srcLoc));
      }
    }

    return theBuilder.createCompositeConstruct(
        typeTranslator.translateType(type), fields);
  }

  // Array type
  if (const auto *arrayType = astContext.getAsConstantArrayType(type)) {
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    const uint32_t elemId = processFlatConversion(arrayType->getElementType(),
                                                  initType, initId, srcLoc);
    llvm::SmallVector<uint32_t, 4> constituents(size_t(size), elemId);
    return theBuilder.createCompositeConstruct(
        typeTranslator.translateType(type), constituents);
  }

  emitError("flat conversion of type %0 unimplemented", {})
      << type->getTypeClassName();
  type->dump();
  return 0;
}

SpirvEvalInfo
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Try to optimize floatMxN *= float and floatN *= float case
  if (opcode == BO_MulAssign) {
    if (SpirvEvalInfo result = tryToGenFloatMatrixScale(expr))
      return result;
    if (SpirvEvalInfo result = tryToGenFloatVectorScale(expr))
      return result;
  }

  const auto *rhs = expr->getRHS();
  const auto *lhs = expr->getLHS();

  SpirvEvalInfo lhsPtr = 0;
  const auto result =
      processBinaryOp(lhs, rhs, opcode, expr->getComputationLHSType(),
                      expr->getType(), expr->getSourceRange(), &lhsPtr);
  return processAssignment(lhs, result, true, lhsPtr);
}

SpirvEvalInfo
SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
  const auto type = expr->getType();

  // Enhancement for special case when the ConditionalOperator return type is a
  // literal type. For example:
  //
  //   float a = cond ? 1 : 2;
  //   int b = cond ? 1.5 : 2.5;
  //
  // There will be no indications about whether '1' and '2' should be used as
  // 32-bit or 64-bit integers. Similarly, there will be no indication about
  // whether '1.5' and '2.5' should be used as 32-bit or 64-bit floats.
  //
  // We want to avoid using 64-bit int and 64-bit float as much as possible.
  //
  // Note that if the literal is in fact large enough that it can't be
  // represented in 32 bits (e.g. integer larger than 3e+9), we should *not*
  // provide a hint.
  TypeTranslator::LiteralTypeHint hint(typeTranslator);
  const bool isLitInt = type->isSpecificBuiltinType(BuiltinType::LitInt);
  const bool isLitFloat = type->isSpecificBuiltinType(BuiltinType::LitFloat);
  // Return type of ConditionalOperator is a 'literal int' or 'literal float'
  if (isLitInt || isLitFloat) {
    // There is no hint about the intended usage of the literal type.
    if (typeTranslator.getIntendedLiteralType(type) == type) {
      // If either branch is a literal that is larger than 32-bits, do not
      // provide a hint.
      if (!isLiteralLargerThan32Bits(expr->getTrueExpr()) &&
          !isLiteralLargerThan32Bits(expr->getFalseExpr())) {
        if (isLitInt)
          hint.setHint(astContext.IntTy);
        else if (isLitFloat)
          hint.setHint(astContext.FloatTy);
      }
    }
  }

  // According to the HLSL doc, all sides of the ?: expression are always
  // evaluated.
  const uint32_t typeId = typeTranslator.translateType(type);

  // If we are selecting between two SamplerState objects, none of the three
  // operands has a LValueToRValue implicit cast.
  uint32_t condition = loadIfGLValue(expr->getCond());
  const auto trueBranch = loadIfGLValue(expr->getTrueExpr());
  const auto falseBranch = loadIfGLValue(expr->getFalseExpr());

  // For cases where the return type is a scalar or a vector, we can use
  // OpSelect to choose between the two. OpSelect's return type must be either
  // scalar or vector.
  if (TypeTranslator::isScalarType(type) ||
      TypeTranslator::isVectorType(type)) {
    // The SPIR-V OpSelect instruction must have a selection argument that is
    // the same size as the return type. If the return type is a vector, the
    // selection must be a vector of booleans (one per output component).
    uint32_t count = 0;
    if (TypeTranslator::isVectorType(expr->getType(), nullptr, &count) &&
        !TypeTranslator::isVectorType(expr->getCond()->getType())) {
      const uint32_t condVecType =
          theBuilder.getVecType(theBuilder.getBoolType(), count);
      const llvm::SmallVector<uint32_t, 4> components(size_t(count),
                                                      condition);
      condition = theBuilder.createCompositeConstruct(condVecType, components);
    }

    auto valueId =
        theBuilder.createSelect(typeId, condition, trueBranch, falseBranch);
    return SpirvEvalInfo(valueId).setRValue();
  }

  // If we can't use OpSelect, we need to create if-else control flow.
  const uint32_t tempVar = theBuilder.addFnVar(typeId, "temp.var.ternary");
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  const uint32_t elseBB = theBuilder.createBasicBlock("if.false");

  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(elseBB);
  theBuilder.setMergeTarget(mergeBB);
  // Handle the then branch
  theBuilder.setInsertPoint(thenBB);
  theBuilder.createStore(tempVar, trueBranch);
  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);
  // Handle the else branch
  theBuilder.setInsertPoint(elseBB);
  theBuilder.createStore(tempVar, falseBranch);
  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);
  // From now on, emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
  return SpirvEvalInfo(theBuilder.createLoad(typeId, tempVar)).setRValue();
}

uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
    const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectId = loadIfAliasVarRef(object);
  const auto type = object->getType();
  const bool isByteAddressBuffer = TypeTranslator::isByteAddressBuffer(type) ||
                                   TypeTranslator::isRWByteAddressBuffer(type);
  const bool isStructuredBuffer =
      TypeTranslator::isStructuredBuffer(type) ||
      TypeTranslator::isAppendStructuredBuffer(type) ||
      TypeTranslator::isConsumeStructuredBuffer(type);
  assert(isByteAddressBuffer || isStructuredBuffer);

  // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a
  // structure with only one member that is a runtime array. We need to perform
  // OpArrayLength on member 0.
  const auto uintType = theBuilder.getUint32Type();
  uint32_t length =
      theBuilder.createBinaryOp(spv::Op::OpArrayLength, uintType, objectId, 0);

  // For (RW)ByteAddressBuffers, GetDimensions() must return the array length
  // in bytes, but OpArrayLength returns the number of uints in the runtime
  // array. Therefore we must multiply the results by 4.
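  // e.g. (illustrative): a runtime array holding 256 uints reports 1024
  // (bytes) through GetDimensions().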
  if (isByteAddressBuffer) {
    length = theBuilder.createBinaryOp(spv::Op::OpIMul, uintType, length,
                                       theBuilder.getConstantUint32(4u));
  }
  theBuilder.createStore(doExpr(expr->getArg(0)), length);

  if (isStructuredBuffer) {
    // For (RW)StructuredBuffer, the stride of the runtime array (which is the
    // size of the struct) must also be written to the second argument.
    uint32_t size = 0, stride = 0;
    std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize(
        type, spirvOptions.sBufferLayoutRule, &stride);
    const auto sizeId = theBuilder.getConstantUint32(size);
    theBuilder.createStore(doExpr(expr->getArg(1)), sizeId);
  }

  return 0;
}

uint32_t SPIRVEmitter::processRWByteAddressBufferAtomicMethods(
    hlsl::IntrinsicOp opcode, const CXXMemberCallExpr *expr) {
  // The signatures of RWByteAddressBuffer atomic methods are largely:
  //   void Interlocked*(in UINT dest, in UINT value);
  //   void Interlocked*(in UINT dest, in UINT value, out UINT original_value);
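  //
  // e.g. (illustrative HLSL):
  //   uint original;
  //   rwByteBuf.InterlockedAdd(/*dest*/ 16, /*value*/ 1, original);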
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectInfo = loadIfAliasVarRef(object);

  const auto uintType = theBuilder.getUint32Type();
  const uint32_t zero = theBuilder.getConstantUint32(0);

  const uint32_t offset = doExpr(expr->getArg(0));
  // Right shift by 2 to convert the byte offset to a uint32_t offset
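  // (e.g. byte offset 16 addresses element 16 >> 2 == 4 of the uint array)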
  const uint32_t address =
      theBuilder.createBinaryOp(spv::Op::OpShiftRightLogical, uintType, offset,
                                theBuilder.getConstantUint32(2));
  const auto ptrType =
      theBuilder.getPointerType(uintType, objectInfo.getStorageClass());
  const uint32_t ptr =
      theBuilder.createAccessChain(ptrType, objectInfo, {zero, address});

  const uint32_t scope = theBuilder.getConstantUint32(1); // Device

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    const uint32_t comparator = doExpr(expr->getArg(1));
    const uint32_t originalVal = theBuilder.createAtomicCompareExchange(
        uintType, ptr, scope, zero, zero, doExpr(expr->getArg(2)), comparator);
    if (isCompareExchange)
      theBuilder.createStore(doExpr(expr->getArg(3)), originalVal);
  } else {
    const uint32_t value = doExpr(expr->getArg(1));
    const uint32_t originalVal = theBuilder.createAtomicOp(
        translateAtomicHlslOpcodeToSpirvOpcode(opcode), uintType, ptr, scope,
        zero, value);
    if (expr->getNumArgs() > 2)
      theBuilder.createStore(doExpr(expr->getArg(2)), originalVal);
  }

  return 0;
}

uint32_t SPIRVEmitter::processGetSamplePosition(const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument()->IgnoreParens();
  const auto sampleCount = theBuilder.createUnaryOp(
      spv::Op::OpImageQuerySamples, theBuilder.getUint32Type(),
      loadIfGLValue(object));
  if (!spirvOptions.noWarnEmulatedFeatures)
    emitWarning("GetSamplePosition is emulated using many SPIR-V instructions "
                "due to lack of direct SPIR-V equivalent, so it only supports "
                "standard sample settings with 1, 2, 4, 8, or 16 samples and "
                "will return float2(0, 0) for other cases",
                expr->getCallee()->getExprLoc());
  return emitGetSamplePosition(sampleCount, doExpr(expr->getArg(0)));
}

SpirvEvalInfo SPIRVEmitter::processSubpassLoad(const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument()->IgnoreParens();
  const uint32_t sample = expr->getNumArgs() == 1
                              ? static_cast<uint32_t>(doExpr(expr->getArg(0)))
                              : 0;
  const uint32_t zero = theBuilder.getConstantInt32(0);
  const uint32_t location = theBuilder.getConstantComposite(
      theBuilder.getVecType(theBuilder.getInt32Type(), 2), {zero, zero});

  return processBufferTextureLoad(object, location, /*constOffset*/ 0,
                                  /*varOffset*/ 0, /*lod*/ sample,
                                  /*residencyCode*/ 0);
}

uint32_t
SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectId = loadIfGLValue(object);
  const auto type = object->getType();
  const auto *recType = type->getAs<RecordType>();
  assert(recType);
  const auto typeName = recType->getDecl()->getName();
  const auto numArgs = expr->getNumArgs();
  const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr;

  assert(TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) ||
         TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type));

  // For Texture1D, arguments are either:
  //   a) width
  //   b) MipLevel, width, NumLevels
  // For Texture1DArray, arguments are either:
  //   a) width, elements
  //   b) MipLevel, width, elements, NumLevels
  // For Texture2D, arguments are either:
  //   a) width, height
  //   b) MipLevel, width, height, NumLevels
  // For Texture2DArray, arguments are either:
  //   a) width, height, elements
  //   b) MipLevel, width, height, elements, NumLevels
  // For Texture3D, arguments are either:
  //   a) width, height, depth
  //   b) MipLevel, width, height, depth, NumLevels
  // For Texture2DMS, arguments are: width, height, NumSamples
  // For Texture2DMSArray, arguments are: width, height, elements, NumSamples
  // For TextureCube, arguments are either:
  //   a) width, height
  //   b) MipLevel, width, height, NumLevels
  // For TextureCubeArray, arguments are either:
  //   a) width, height, elements
  //   b) MipLevel, width, height, elements, NumLevels
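  //
  // e.g. (illustrative HLSL):
  //   uint width, height, numLevels;
  //   myTexture2D.GetDimensions(/*MipLevel*/ 0, width, height, numLevels);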
  //
  // Note: the SPIR-V spec requires the return type of OpImageQuerySize(Lod) to
  // be a scalar/vector of integers. It also requires the return type of
  // OpImageQueryLevels and OpImageQuerySamples to be scalar integers.
  // The HLSL methods, however, have overloaded functions which have float
  // output arguments. Since the AST naturally won't have casting AST nodes for
  // such cases, we'll have to perform the cast ourselves.
  const auto storeToOutputArg = [this](const Expr *outputArg, uint32_t id,
                                       QualType type) {
    id = castToType(id, type, outputArg->getType(), outputArg->getExprLoc());
    theBuilder.createStore(doExpr(outputArg), id);
  };

  if ((typeName == "Texture1D" && numArgs > 1) ||
      (typeName == "Texture2D" && numArgs > 2) ||
      (typeName == "TextureCube" && numArgs > 2) ||
      (typeName == "Texture3D" && numArgs > 3) ||
      (typeName == "Texture1DArray" && numArgs > 2) ||
      (typeName == "TextureCubeArray" && numArgs > 3) ||
      (typeName == "Texture2DArray" && numArgs > 3)) {
    mipLevel = expr->getArg(0);
    numLevels = expr->getArg(numArgs - 1);
  }
  if (TypeTranslator::isTextureMS(type)) {
    numSamples = expr->getArg(numArgs - 1);
  }

  uint32_t querySize = numArgs;
  // If the numLevels arg is present, mipLevel must also be present; neither
  // is queried via ImageQuerySizeLod, so exclude both from the size query.
  if (numLevels)
    querySize -= 2;
  // Similarly, the numSamples arg is not part of the size query.
  else if (numSamples)
    querySize -= 1;

  const uint32_t uintId = theBuilder.getUint32Type();
  const QualType resultQualType =
      querySize == 1
          ? astContext.UnsignedIntTy
          : astContext.getExtVectorType(astContext.UnsignedIntTy, querySize);
  const uint32_t resultTypeId = typeTranslator.translateType(resultQualType);

  // Only Texture types use ImageQuerySizeLod.
  // TextureMS, RWTexture, Buffers, RWBuffers use ImageQuerySize.
  uint32_t lod = 0;
  if (TypeTranslator::isTexture(type) && !numSamples) {
    if (mipLevel) {
      // For Texture types when mipLevel argument is present.
      lod = doExpr(mipLevel);
    } else {
      // For Texture types when mipLevel argument is omitted.
      lod = theBuilder.getConstantInt32(0);
    }
  }

  const uint32_t query =
      lod ? theBuilder.createBinaryOp(spv::Op::OpImageQuerySizeLod,
                                      resultTypeId, objectId, lod)
          : theBuilder.createUnaryOp(spv::Op::OpImageQuerySize, resultTypeId,
                                     objectId);

  if (querySize == 1) {
    const uint32_t argIndex = mipLevel ? 1 : 0;
    storeToOutputArg(expr->getArg(argIndex), query, resultQualType);
  } else {
    for (uint32_t i = 0; i < querySize; ++i) {
      const uint32_t component =
          theBuilder.createCompositeExtract(uintId, query, {i});
      // If the first arg is the mipmap level, we must write the results
      // starting from Arg(i+1), not Arg(i).
      const uint32_t argIndex = mipLevel ? i + 1 : i;
      storeToOutputArg(expr->getArg(argIndex), component,
                       astContext.UnsignedIntTy);
    }
  }

  if (numLevels || numSamples) {
    const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples;
    const spv::Op opcode =
        numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples;
    const uint32_t numLevelsSamplesQuery =
        theBuilder.createUnaryOp(opcode, uintId, objectId);
    storeToOutputArg(numLevelsSamplesArg, numLevelsSamplesQuery,
                     astContext.UnsignedIntTy);
  }

  return 0;
}

uint32_t
SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
                                          bool unclamped) {
  // Possible signatures are as follows:
  //   Texture1D(Array).CalculateLevelOfDetail(SamplerState S, float x);
  //   Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy);
  //   TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz);
  //   Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz);
  // Return type is always a single float (LOD).
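  //
  // e.g. (illustrative HLSL):
  //   float lod = myTexture2D.CalculateLevelOfDetail(mySampler, uv);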
  assert(expr->getNumArgs() == 2u);
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectInfo = loadIfGLValue(object);
  const auto samplerState = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t sampledImageType = theBuilder.getSampledImageType(
      typeTranslator.translateType(object->getType()));
  const uint32_t sampledImage = theBuilder.createBinaryOp(
      spv::Op::OpSampledImage, sampledImageType, objectInfo, samplerState);

  if (objectInfo.isNonUniform() || samplerState.isNonUniform()) {
    // The sampled image will be used to access resource's memory, so we need
    // to decorate it with NonUniformEXT.
    theBuilder.decorateNonUniformEXT(sampledImage);
  }

  // The result type of OpImageQueryLod must be a float2.
  const uint32_t queryResultType =
      theBuilder.getVecType(theBuilder.getFloat32Type(), 2u);
  const uint32_t query = theBuilder.createBinaryOp(
      spv::Op::OpImageQueryLod, queryResultType, sampledImage, coordinate);
  // The first component of the float2 contains the mipmap array layer.
  // The second component of the float2 represents the unclamped lod.
  return theBuilder.createCompositeExtract(theBuilder.getFloat32Type(), query,
                                           unclamped ? 1 : 0);
}

uint32_t SPIRVEmitter::processTextureGatherRGBACmpRGBA(
    const CXXMemberCallExpr *expr, const bool isCmp, const uint32_t component) {
  // Parameters for .Gather{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  //   * SamplerState s, float2 location, int2 offset
  //   * SamplerState s, float2 location, int2 offset0, int2 offset1,
  //     int2 offset2, int2 offset3
  //
  // An additional 'out uint status' parameter can appear in both of the above.
  //
  // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  //   * SamplerState s, float2 location, float compare_value, int2 offset
  //   * SamplerState s, float2 location, float compare_value, int2 offset1,
  //     int2 offset2, int2 offset3, int2 offset4
  //
  // An additional 'out uint status' parameter can appear in both of the above.
  //
  // TextureCube's signature is somewhat different from the rest.
  // Parameters for .Gather{Red|Green|Blue|Alpha}() for TextureCube are:
  //   * SamplerState s, float2 location, out uint status
  // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() for TextureCube are:
  //   * SamplerState s, float2 location, float compare_value, out uint status
  //
  // Return type is always a 4-component vector.
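  //
  // e.g. (illustrative HLSL):
  //   float4 reds = myTexture2D.GatherRed(mySampler, uv, int2(0, 0));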
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();
  const auto *imageExpr = expr->getImplicitObjectArgument();
  const QualType imageType = imageExpr->getType();
  const auto imageTypeId = typeTranslator.translateType(imageType);
  const auto retTypeId = typeTranslator.translateType(callee->getReturnType());

  // If the last arg is an unsigned integer, it must be the status.
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();

  // Subtract 1 for status arg (if it exists), subtract 1 for compare_value (if
  // it exists), and subtract 2 for SamplerState and location.
  const auto numOffsetArgs = numArgs - hasStatusArg - isCmp - 2;
  // No offset args for TextureCube, 1 or 4 offset args for the rest.
  assert(numOffsetArgs == 0 || numOffsetArgs == 1 || numOffsetArgs == 4);

  const auto image = loadIfGLValue(imageExpr);
  const auto sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t compareVal =
      isCmp ? static_cast<uint32_t>(doExpr(expr->getArg(2))) : 0;

  // Handle offsets (if any).
  bool needsEmulation = false;
  uint32_t constOffset = 0, varOffset = 0, constOffsets = 0;
  if (numOffsetArgs == 1) {
    // The offset arg is not optional.
    handleOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset);
  } else if (numOffsetArgs == 4) {
    const auto offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp));
    const auto offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp));
    const auto offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp));
    const auto offset3 = tryToEvaluateAsConst(expr->getArg(5 + isCmp));

    // If any of the offsets is not constant, we then need to emulate the call
    // using 4 OpImageGather instructions. Otherwise, we can leverage the
    // ConstOffsets image operand.
    if (offset0 && offset1 && offset2 && offset3) {
      const uint32_t v2i32 =
          theBuilder.getVecType(theBuilder.getInt32Type(), 2);
      const uint32_t offsetType =
          theBuilder.getArrayType(v2i32, theBuilder.getConstantUint32(4));
      constOffsets = theBuilder.getConstantComposite(
          offsetType, {offset0, offset1, offset2, offset3});
    } else {
      needsEmulation = true;
    }
  }

  const auto status =
      hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
                   : 0;
  const bool isNonUniform = image.isNonUniform() || sampler.isNonUniform();

  if (needsEmulation) {
    const auto elemType = typeTranslator.translateType(
        hlsl::GetHLSLVecElementType(callee->getReturnType()));

    uint32_t texels[4];
    for (uint32_t i = 0; i < 4; ++i) {
      varOffset = doExpr(expr->getArg(2 + isCmp + i));
      const uint32_t gatherRet = theBuilder.createImageGather(
          retTypeId, imageTypeId, image, sampler, isNonUniform, coordinate,
          theBuilder.getConstantInt32(component), compareVal, /*constOffset*/ 0,
          varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0, status);
      texels[i] = theBuilder.createCompositeExtract(elemType, gatherRet, {i});
    }
    return theBuilder.createCompositeConstruct(
        retTypeId, {texels[0], texels[1], texels[2], texels[3]});
  }

  return theBuilder.createImageGather(
      retTypeId, imageTypeId, image, sampler, isNonUniform, coordinate,
      theBuilder.getConstantInt32(component), compareVal, constOffset,
      varOffset, constOffsets, /*sampleNumber*/ 0, status);
}

uint32_t SPIRVEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) {
  // Signature for Texture2D/Texture2DArray:
  //
  //   float4 GatherCmp(
  //     in SamplerComparisonState s,
  //     in float2 location,
  //     in float compare_value
  //     [, in int2 offset]
  //     [, out uint Status]
  //   );
  //
  // Signature for TextureCube/TextureCubeArray:
  //
  //   float4 GatherCmp(
  //     in SamplerComparisonState s,
  //     in float2 location,
  //     in float compare_value,
  //     out uint Status
  //   );
  //
  // Other Texture types do not have the GatherCmp method.
  2774. const FunctionDecl *callee = expr->getDirectCallee();
  2775. const auto numArgs = expr->getNumArgs();
  2776. const bool hasStatusArg =
  2777. expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  2778. const bool hasOffsetArg = (numArgs == 5) || (numArgs == 4 && !hasStatusArg);
  2779. const auto *imageExpr = expr->getImplicitObjectArgument();
  2780. const auto image = loadIfGLValue(imageExpr);
  2781. const auto sampler = doExpr(expr->getArg(0));
  2782. const uint32_t coordinate = doExpr(expr->getArg(1));
  2783. const uint32_t comparator = doExpr(expr->getArg(2));
  2784. uint32_t constOffset = 0, varOffset = 0;
  2785. if (hasOffsetArg)
  2786. handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);
  2787. const auto retType = typeTranslator.translateType(callee->getReturnType());
  2788. const auto imageType = typeTranslator.translateType(imageExpr->getType());
  2789. const auto status =
  2790. hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
  2791. : 0;
  2792. return theBuilder.createImageGather(
  2793. retType, imageType, image, sampler,
  2794. image.isNonUniform() || sampler.isNonUniform(), coordinate,
  2795. /*component*/ 0, comparator, constOffset, varOffset, /*constOffsets*/ 0,
  2796. /*sampleNumber*/ 0, status);
  2797. }
SpirvEvalInfo SPIRVEmitter::processBufferTextureLoad(
    const Expr *object, const uint32_t locationId, uint32_t constOffset,
    uint32_t varOffset, uint32_t lod, uint32_t residencyCode) {
  // Loading for Buffer and RWBuffer translates to an OpImageFetch.
  // The result type of an OpImageFetch must be a vec4 of float or int.
  const auto type = object->getType();
  assert(TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type) ||
         TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) ||
         TypeTranslator::isSubpassInput(type) ||
         TypeTranslator::isSubpassInputMS(type));

  const bool doFetch =
      TypeTranslator::isBuffer(type) || TypeTranslator::isTexture(type);

  const auto objectInfo = loadIfGLValue(object);

  if (objectInfo.isNonUniform()) {
    // Decorate the image handle for OpImageFetch/OpImageRead.
    theBuilder.decorateNonUniformEXT(objectInfo);
  }

  // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod.
  uint32_t sampleNumber = 0;
  if (TypeTranslator::isTextureMS(type) ||
      TypeTranslator::isSubpassInputMS(type)) {
    sampleNumber = lod;
    lod = 0;
  }

  const auto sampledType = hlsl::GetHLSLResourceResultType(type);
  QualType elemType = sampledType;
  uint32_t elemCount = 1;
  uint32_t elemTypeId = 0;
  bool isTemplateOverStruct = false;

  // Check whether the template type is a vector type or struct type.
  if (!TypeTranslator::isVectorType(sampledType, &elemType, &elemCount)) {
    if (sampledType->getAsStructureType()) {
      isTemplateOverStruct = true;
      // For struct type, we need to make sure it can fit into a 4-component
      // vector. Detailed failing reasons will be emitted by the function so
      // we don't need to emit errors here.
      if (!typeTranslator.canFitIntoOneRegister(sampledType, &elemType,
                                                &elemCount))
        return 0;
    }
  }

  if (elemType->isFloatingType()) {
    elemTypeId = theBuilder.getFloat32Type();
  } else if (elemType->isSignedIntegerType()) {
    elemTypeId = theBuilder.getInt32Type();
  } else if (elemType->isUnsignedIntegerType()) {
    elemTypeId = theBuilder.getUint32Type();
  } else {
    emitError("loading %0 value unsupported", object->getExprLoc()) << type;
    return 0;
  }

  // OpImageFetch and OpImageRead can only fetch a vector of 4 elements.
  const uint32_t texelTypeId = theBuilder.getVecType(elemTypeId, 4u);
  const uint32_t texel = theBuilder.createImageFetchOrRead(
      doFetch, texelTypeId, type, objectInfo, locationId, lod, constOffset,
      varOffset, /*constOffsets*/ 0, sampleNumber, residencyCode);

  // If the result type is a vec1, vec2, or vec3, some extra processing
  // (extraction) is required.
  uint32_t retVal = extractVecFromVec4(texel, elemCount, elemTypeId);
  if (isTemplateOverStruct) {
    // Convert to the struct so that we are consistent with types in the AST.
    retVal = convertVectorToStruct(sampledType, elemTypeId, retVal);
  }
  return SpirvEvalInfo(retVal).setRValue();
}
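// (RW)ByteAddressBuffer is lowered to a struct whose sole member is a runtime
// array of uint, so loads and stores below operate on 32-bit words addressed
// by byte offset / 4. Illustrative HLSL (not from the original source):
//
//   RWByteAddressBuffer buf;
//   buf.Store2(8, uint2(a, b)); // writes the words at indices 2 and 3
//   uint2 v = buf.Load2(8);     // reads them back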
SpirvEvalInfo SPIRVEmitter::processByteAddressBufferLoadStore(
    const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) {
  uint32_t resultId = 0;
  const auto object = expr->getImplicitObjectArgument();
  const auto objectInfo = loadIfAliasVarRef(object);
  assert(numWords >= 1 && numWords <= 4);
  if (doStore) {
    assert(typeTranslator.isRWByteAddressBuffer(object->getType()));
    assert(expr->getNumArgs() == 2);
  } else {
    assert(typeTranslator.isRWByteAddressBuffer(object->getType()) ||
           typeTranslator.isByteAddressBuffer(object->getType()));
    if (expr->getNumArgs() == 2) {
      emitError(
          "(RW)ByteAddressBuffer::Load(in address, out status) not supported",
          expr->getExprLoc());
      return 0;
    }
  }
  const Expr *addressExpr = expr->getArg(0);
  const uint32_t byteAddress = doExpr(addressExpr);
  const uint32_t addressTypeId =
      typeTranslator.translateType(addressExpr->getType());

  // Do an OpShiftRightLogical by 2 (divide by 4 to get word-aligned memory
  // access). The AST always casts the address to unsigned integer, so shift
  // by unsigned integer 2.
  const uint32_t constUint2 = theBuilder.getConstantUint32(2);
  const uint32_t address = theBuilder.createBinaryOp(
      spv::Op::OpShiftRightLogical, addressTypeId, byteAddress, constUint2);
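  // For example, a byte address of 16 becomes word index 16 >> 2 == 4 in the
  // underlying uint runtime array.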
  // Perform access chain into the RWByteAddressBuffer.
  // First index must be zero (member 0 of the struct is a
  // runtimeArray). The second index passed to OpAccessChain should be
  // the address.
  const uint32_t uintTypeId = theBuilder.getUint32Type();
  const uint32_t ptrType =
      theBuilder.getPointerType(uintTypeId, objectInfo.getStorageClass());
  const uint32_t constUint0 = theBuilder.getConstantUint32(0);

  if (doStore) {
    const uint32_t valuesId = doExpr(expr->getArg(1));
    uint32_t curStoreAddress = address;
    for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) {
      // Extract a 32-bit word from the input.
      const uint32_t curValue = numWords == 1
                                    ? valuesId
                                    : theBuilder.createCompositeExtract(
                                          uintTypeId, valuesId, {wordCounter});

      // Update the output address if necessary.
      if (wordCounter > 0) {
        const uint32_t offset = theBuilder.getConstantUint32(wordCounter);
        curStoreAddress = theBuilder.createBinaryOp(
            spv::Op::OpIAdd, addressTypeId, address, offset);
      }

      // Store the word to the right address at the output.
      const uint32_t storePtr = theBuilder.createAccessChain(
          ptrType, objectInfo, {constUint0, curStoreAddress});
      theBuilder.createStore(storePtr, curValue);
    }
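    // Thus Store3 at byte address A writes the three input words to array
    // indices A/4, A/4 + 1, and A/4 + 2, one OpStore per word.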
  } else {
    uint32_t loadPtr = theBuilder.createAccessChain(ptrType, objectInfo,
                                                    {constUint0, address});
    resultId = theBuilder.createLoad(uintTypeId, loadPtr);
    if (numWords > 1) {
      // Load word 2, 3, and 4 where necessary. Use OpCompositeConstruct to
      // return a vector result.
      llvm::SmallVector<uint32_t, 4> values;
      values.push_back(resultId);
      for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) {
        const uint32_t offset = theBuilder.getConstantUint32(wordCounter - 1);
        const uint32_t newAddress = theBuilder.createBinaryOp(
            spv::Op::OpIAdd, addressTypeId, address, offset);
        loadPtr = theBuilder.createAccessChain(ptrType, objectInfo,
                                               {constUint0, newAddress});
        values.push_back(theBuilder.createLoad(uintTypeId, loadPtr));
      }
      const uint32_t resultType =
          theBuilder.getVecType(addressTypeId, numWords);
      resultId = theBuilder.createCompositeConstruct(resultType, values);
    }
  }
  return SpirvEvalInfo(resultId).setRValue();
}
SpirvEvalInfo
SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
  if (expr->getNumArgs() == 2) {
    emitError(
        "(RW)StructuredBuffer::Load(in location, out status) not supported",
        expr->getExprLoc());
    return 0;
  }

  const auto *buffer = expr->getImplicitObjectArgument();
  auto info = loadIfAliasVarRef(buffer);

  const QualType structType =
      hlsl::GetHLSLResourceResultType(buffer->getType());

  const uint32_t zero = theBuilder.getConstantInt32(0);
  const uint32_t index = doExpr(expr->getArg(0));

  return turnIntoElementPtr(buffer->getType(), info, structType,
                            {zero, index});
}
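// The counter associated with an RW/Append/Consume structured buffer lives in
// its own separate variable. IncrementCounter()/DecrementCounter() bump it
// atomically at Device scope with no extra memory-semantics bits; see the
// `one` (scope) and `zero` (semantics) constants below.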
uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
                                                bool isInc, bool loadObject) {
  const uint32_t i32Type = theBuilder.getInt32Type();
  const uint32_t one = theBuilder.getConstantUint32(1);  // As scope: Device
  const uint32_t zero = theBuilder.getConstantUint32(0); // As memory sema: None
  const uint32_t sOne = theBuilder.getConstantInt32(1);

  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);

  if (loadObject) {
    // We don't need the object's <result-id> here since the counter variable
    // is a separate variable. But we still need the side effects of evaluating
    // the object, e.g., if the source code is foo(...).IncrementCounter(), we
    // still want to emit the code for foo(...).
    (void)doExpr(object);
  }

  const auto *counterPair = getFinalACSBufferCounter(object);
  if (!counterPair) {
    emitFatalError("cannot find the associated counter variable",
                   object->getExprLoc());
    return 0;
  }

  const uint32_t counterPtrType = theBuilder.getPointerType(
      theBuilder.getInt32Type(), spv::StorageClass::Uniform);
  const uint32_t counterPtr = theBuilder.createAccessChain(
      counterPtrType, counterPair->get(theBuilder, typeTranslator), {zero});

  uint32_t index = 0;
  if (isInc) {
    index = theBuilder.createAtomicOp(spv::Op::OpAtomicIAdd, i32Type,
                                      counterPtr, one, zero, sOne);
  } else {
    // Note that OpAtomicISub returns the value before the subtraction;
    // so we need to subtract 1 again from OpAtomicISub's return value.
    const auto prev = theBuilder.createAtomicOp(spv::Op::OpAtomicISub, i32Type,
                                                counterPtr, one, zero, sOne);
    index = theBuilder.createBinaryOp(spv::Op::OpISub, i32Type, prev, sOne);
  }
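  // Consequently Append() indexes with the counter value before the increment
  // (OpAtomicIAdd returns the old value), while Consume() indexes with the
  // value after the decrement: both yield the slot belonging to this call.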
  return index;
}
bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
                                         const Expr *srcExpr) {
  // We are handling associated counters here. Casts should not alter which
  // associated counter to manipulate.
  srcExpr = srcExpr->IgnoreParenCasts();

  // For parameters of forward-declared functions, we must make sure the
  // associated counter variable is created, because the translation of the
  // real definition may not have started yet.
  if (const auto *param = dyn_cast<ParmVarDecl>(dstDecl))
    declIdMapper.createFnParamCounterVar(param);
  // For implicit objects of methods. Similar to the above.
  else if (const auto *thisObject = dyn_cast<ImplicitParamDecl>(dstDecl))
    declIdMapper.createFnParamCounterVar(thisObject);

  // Handle AssocCounter#1 (see the CounterVarFields comment)
  if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) {
    const auto *srcPair = getFinalACSBufferCounter(srcExpr);
    if (!srcPair) {
      emitFatalError("cannot find the associated counter variable",
                     srcExpr->getExprLoc());
      return false;
    }
    dstPair->assign(*srcPair, theBuilder, typeTranslator);
    return true;
  }

  // Handle AssocCounter#3
  llvm::SmallVector<uint32_t, 4> srcIndices;
  const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl);
  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);

  if (dstFields && srcFields) {
    if (!dstFields->assign(*srcFields, theBuilder, typeTranslator)) {
      emitFatalError("cannot handle associated counter variable assignment",
                     srcExpr->getExprLoc());
      return false;
    }
    return true;
  }

  // AssocCounter#2 and AssocCounter#4 for the lhs cannot happen since the lhs
  // is a stand-alone decl in this method.
  return false;
}
bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr,
                                         const Expr *srcExpr) {
  dstExpr = dstExpr->IgnoreParenCasts();
  srcExpr = srcExpr->IgnoreParenCasts();

  const auto *dstPair = getFinalACSBufferCounter(dstExpr);
  const auto *srcPair = getFinalACSBufferCounter(srcExpr);

  if ((dstPair == nullptr) != (srcPair == nullptr)) {
    emitFatalError("cannot handle associated counter variable assignment",
                   srcExpr->getExprLoc());
    return false;
  }

  // Handle AssocCounter#1 & AssocCounter#2
  if (dstPair && srcPair) {
    dstPair->assign(*srcPair, theBuilder, typeTranslator);
    return true;
  }

  // Handle AssocCounter#3 & AssocCounter#4
  llvm::SmallVector<uint32_t, 4> dstIndices;
  llvm::SmallVector<uint32_t, 4> srcIndices;
  const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices);
  const auto *dstFields = getIntermediateACSBufferCounter(dstExpr, &dstIndices);

  if (dstFields && srcFields) {
    return dstFields->assign(*srcFields, dstIndices, srcIndices, theBuilder,
                             typeTranslator);
  }

  return false;
}
const CounterIdAliasPair *
SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) {
  // AssocCounter#1: referencing some stand-alone variable
  if (const auto *decl = getReferencedDef(expr))
    return declIdMapper.getCounterIdAliasPair(decl);

  // AssocCounter#2: referencing some non-struct field
  llvm::SmallVector<uint32_t, 4> indices;
  const auto *base =
      collectArrayStructIndices(expr, &indices, /*rawIndex=*/true);
  const auto *decl =
      (base && isa<CXXThisExpr>(base))
          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
          : getReferencedDef(base);
  return declIdMapper.getCounterIdAliasPair(decl, &indices);
}
const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter(
    const Expr *expr, llvm::SmallVector<uint32_t, 4> *indices) {
  const auto *base =
      collectArrayStructIndices(expr, indices, /*rawIndex=*/true);
  const auto *decl =
      (base && isa<CXXThisExpr>(base))
          // Use the decl we created to represent the implicit object
          ? getOrCreateDeclForMethodObject(cast<CXXMethodDecl>(curFunction))
          // Find the referenced decl from the original source code
          : getReferencedDef(base);
  return declIdMapper.getCounterVarFields(decl);
}
const ImplicitParamDecl *
SPIRVEmitter::getOrCreateDeclForMethodObject(const CXXMethodDecl *method) {
  const auto found = thisDecls.find(method);
  if (found != thisDecls.end())
    return found->second;

  const std::string name = method->getName().str() + ".this";
  // Create a new identifier to convey the name
  auto &identifier = astContext.Idents.get(name);

  return thisDecls[method] = ImplicitParamDecl::Create(
             astContext, /*DC=*/nullptr, SourceLocation(), &identifier,
             method->getThisType(astContext)->getPointeeType());
}
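// Illustrative HLSL for the method below (not from the original source):
//
//   AppendStructuredBuffer<S> asb;
//   asb.Append(value);            // one argument: the append path
//   ConsumeStructuredBuffer<S> csb;
//   S v = csb.Consume();          // no argument: returns an element lvalue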
SpirvEvalInfo
SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
  const bool isAppend = expr->getNumArgs() == 1;

  const uint32_t zero = theBuilder.getConstantUint32(0);

  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);

  auto bufferInfo = loadIfAliasVarRef(object);

  uint32_t index = incDecRWACSBufferCounter(
      expr, isAppend,
      // We have already translated the object above. Avoid duplication.
      /*loadObject=*/false);

  const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());
  (void)turnIntoElementPtr(object->getType(), bufferInfo, bufferElemTy,
                           {zero, index});

  if (isAppend) {
    // Write out the value
    auto arg0 = doExpr(expr->getArg(0));
    if (!arg0.isRValue()) {
      arg0.setResultId(theBuilder.createLoad(
          typeTranslator.translateType(bufferElemTy), arg0));
    }
    storeValue(bufferInfo, arg0, bufferElemTy);
    return 0;
  } else {
    // Note that we are returning a pointer (lvalue) here in order to further
    // access the fields in this element, e.g., buffer.Consume().a.b. So we
    // cannot forcefully set all normal function calls as returning rvalues.
    return bufferInfo;
  }
}
uint32_t
SPIRVEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) {
  // TODO: handle multiple stream-output objects
  const auto *object =
      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
  const auto *stream = cast<DeclRefExpr>(object)->getDecl();
  const uint32_t value = doExpr(expr->getArg(0));

  declIdMapper.writeBackOutputStream(stream, stream->getType(), value);
  theBuilder.createEmitVertex();

  return 0;
}
uint32_t
SPIRVEmitter::processStreamOutputRestart(const CXXMemberCallExpr *expr) {
  // TODO: handle multiple stream-output objects
  theBuilder.createEndPrimitive();
  return 0;
}
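// GetSamplePosition() has no direct SPIR-V counterpart, so the method below
// materializes the sample-position tables in function-scope variables and
// selects the right entry through a chain of nested if/else constructs.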
uint32_t SPIRVEmitter::emitGetSamplePosition(const uint32_t sampleCount,
                                             const uint32_t sampleIndex) {
  struct Float2 {
    float x;
    float y;
  };

  static const Float2 pos2[] = {
      {4.0 / 16.0, 4.0 / 16.0},
      {-4.0 / 16.0, -4.0 / 16.0},
  };
  static const Float2 pos4[] = {
      {-2.0 / 16.0, -6.0 / 16.0},
      {6.0 / 16.0, -2.0 / 16.0},
      {-6.0 / 16.0, 2.0 / 16.0},
      {2.0 / 16.0, 6.0 / 16.0},
  };
  static const Float2 pos8[] = {
      {1.0 / 16.0, -3.0 / 16.0}, {-1.0 / 16.0, 3.0 / 16.0},
      {5.0 / 16.0, 1.0 / 16.0},  {-3.0 / 16.0, -5.0 / 16.0},
      {-5.0 / 16.0, 5.0 / 16.0}, {-7.0 / 16.0, -1.0 / 16.0},
      {3.0 / 16.0, 7.0 / 16.0},  {7.0 / 16.0, -7.0 / 16.0},
  };
  static const Float2 pos16[] = {
      {1.0 / 16.0, 1.0 / 16.0},   {-1.0 / 16.0, -3.0 / 16.0},
      {-3.0 / 16.0, 2.0 / 16.0},  {4.0 / 16.0, -1.0 / 16.0},
      {-5.0 / 16.0, -2.0 / 16.0}, {2.0 / 16.0, 5.0 / 16.0},
      {5.0 / 16.0, 3.0 / 16.0},   {3.0 / 16.0, -5.0 / 16.0},
      {-2.0 / 16.0, 6.0 / 16.0},  {0.0 / 16.0, -7.0 / 16.0},
      {-4.0 / 16.0, -6.0 / 16.0}, {-6.0 / 16.0, 4.0 / 16.0},
      {-8.0 / 16.0, 0.0 / 16.0},  {7.0 / 16.0, -4.0 / 16.0},
      {6.0 / 16.0, 7.0 / 16.0},   {-7.0 / 16.0, -8.0 / 16.0},
  };
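  // These tables appear to follow the standard D3D multisample patterns,
  // expressed in units of 1/16 pixel within [-0.5, 0.5) around the pixel
  // center.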
  // We are emitting the SPIR-V for the following HLSL source code:
  //
  //   float2 position;
  //
  //   if (count == 2) {
  //     position = pos2[index];
  //   }
  //   else if (count == 4) {
  //     position = pos4[index];
  //   }
  //   else if (count == 8) {
  //     position = pos8[index];
  //   }
  //   else if (count == 16) {
  //     position = pos16[index];
  //   }
  //   else {
  //     position = float2(0.0f, 0.0f);
  //   }

  const uint32_t boolType = theBuilder.getBoolType();
  const auto v2f32Type = theBuilder.getVecType(theBuilder.getFloat32Type(), 2);
  const uint32_t ptrType =
      theBuilder.getPointerType(v2f32Type, spv::StorageClass::Function);

  // Creates a SPIR-V function scope variable of type float2[len].
  const auto createArray = [this, v2f32Type](const Float2 *ptr, uint32_t len) {
    llvm::SmallVector<uint32_t, 16> components;
    for (uint32_t i = 0; i < len; ++i) {
      const auto x = theBuilder.getConstantFloat32(ptr[i].x);
      const auto y = theBuilder.getConstantFloat32(ptr[i].y);
      components.push_back(theBuilder.getConstantComposite(v2f32Type, {x, y}));
    }

    const auto arrType =
        theBuilder.getArrayType(v2f32Type, theBuilder.getConstantUint32(len));
    const auto val = theBuilder.getConstantComposite(arrType, components);

    const std::string varName =
        "var.GetSamplePosition.data." + std::to_string(len);
    const auto var = theBuilder.addFnVar(arrType, varName);
    theBuilder.createStore(var, val);
    return var;
  };

  const uint32_t pos2Arr = createArray(pos2, 2);
  const uint32_t pos4Arr = createArray(pos4, 4);
  const uint32_t pos8Arr = createArray(pos8, 8);
  const uint32_t pos16Arr = createArray(pos16, 16);

  const uint32_t resultVar =
      theBuilder.addFnVar(v2f32Type, "var.GetSamplePosition.result");

  const uint32_t then2BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.then2");
  const uint32_t then4BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.then4");
  const uint32_t then8BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.then8");
  const uint32_t then16BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.then16");
  const uint32_t else2BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.else2");
  const uint32_t else4BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.else4");
  const uint32_t else8BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.else8");
  const uint32_t else16BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.else16");
  const uint32_t merge2BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.merge2");
  const uint32_t merge4BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.merge4");
  const uint32_t merge8BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.merge8");
  const uint32_t merge16BB =
      theBuilder.createBasicBlock("if.GetSamplePosition.merge16");

  // if (count == 2) {
  const auto check2 =
      theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount,
                                theBuilder.getConstantUint32(2));
  theBuilder.createConditionalBranch(check2, then2BB, else2BB, merge2BB);
  theBuilder.addSuccessor(then2BB);
  theBuilder.addSuccessor(else2BB);
  theBuilder.setMergeTarget(merge2BB);

  //   position = pos2[index];
  // }
  theBuilder.setInsertPoint(then2BB);
  auto ac = theBuilder.createAccessChain(ptrType, pos2Arr, {sampleIndex});
  theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac));
  theBuilder.createBranch(merge2BB);
  theBuilder.addSuccessor(merge2BB);

  // else if (count == 4) {
  theBuilder.setInsertPoint(else2BB);
  const auto check4 =
      theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount,
                                theBuilder.getConstantUint32(4));
  theBuilder.createConditionalBranch(check4, then4BB, else4BB, merge4BB);
  theBuilder.addSuccessor(then4BB);
  theBuilder.addSuccessor(else4BB);
  theBuilder.setMergeTarget(merge4BB);

  //   position = pos4[index];
  // }
  theBuilder.setInsertPoint(then4BB);
  ac = theBuilder.createAccessChain(ptrType, pos4Arr, {sampleIndex});
  theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac));
  theBuilder.createBranch(merge4BB);
  theBuilder.addSuccessor(merge4BB);

  // else if (count == 8) {
  theBuilder.setInsertPoint(else4BB);
  const auto check8 =
      theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount,
                                theBuilder.getConstantUint32(8));
  theBuilder.createConditionalBranch(check8, then8BB, else8BB, merge8BB);
  theBuilder.addSuccessor(then8BB);
  theBuilder.addSuccessor(else8BB);
  theBuilder.setMergeTarget(merge8BB);

  //   position = pos8[index];
  // }
  theBuilder.setInsertPoint(then8BB);
  ac = theBuilder.createAccessChain(ptrType, pos8Arr, {sampleIndex});
  theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac));
  theBuilder.createBranch(merge8BB);
  theBuilder.addSuccessor(merge8BB);

  // else if (count == 16) {
  theBuilder.setInsertPoint(else8BB);
  const auto check16 =
      theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount,
                                theBuilder.getConstantUint32(16));
  theBuilder.createConditionalBranch(check16, then16BB, else16BB, merge16BB);
  theBuilder.addSuccessor(then16BB);
  theBuilder.addSuccessor(else16BB);
  theBuilder.setMergeTarget(merge16BB);

  //   position = pos16[index];
  // }
  theBuilder.setInsertPoint(then16BB);
  ac = theBuilder.createAccessChain(ptrType, pos16Arr, {sampleIndex});
  theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac));
  theBuilder.createBranch(merge16BB);
  theBuilder.addSuccessor(merge16BB);

  // else {
  //   position = float2(0.0f, 0.0f);
  // }
  theBuilder.setInsertPoint(else16BB);
  const auto zero = theBuilder.getConstantFloat32(0);
  const auto v2f32Zero =
      theBuilder.getConstantComposite(v2f32Type, {zero, zero});
  theBuilder.createStore(resultVar, v2f32Zero);
  theBuilder.createBranch(merge16BB);
  theBuilder.addSuccessor(merge16BB);
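  // Close the nested selection constructs from the innermost one outwards:
  // each merge block simply branches to the merge block of the enclosing if.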
  theBuilder.setInsertPoint(merge16BB);
  theBuilder.createBranch(merge8BB);
  theBuilder.addSuccessor(merge8BB);

  theBuilder.setInsertPoint(merge8BB);
  theBuilder.createBranch(merge4BB);
  theBuilder.addSuccessor(merge4BB);

  theBuilder.setInsertPoint(merge4BB);
  theBuilder.createBranch(merge2BB);
  theBuilder.addSuccessor(merge2BB);

  theBuilder.setInsertPoint(merge2BB);
  return theBuilder.createLoad(v2f32Type, resultVar);
}
SpirvEvalInfo SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
  const FunctionDecl *callee = expr->getDirectCallee();

  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  if (hlsl::GetIntrinsicOp(callee, opcode, group)) {
    return processIntrinsicMemberCall(expr,
                                      static_cast<hlsl::IntrinsicOp>(opcode));
  }
  return processCall(expr);
}
void SPIRVEmitter::handleOffsetInMethodCall(const CXXMemberCallExpr *expr,
                                            uint32_t index,
                                            uint32_t *constOffset,
                                            uint32_t *varOffset) {
  // Ensure the given arg index is not out-of-range.
  assert(index < expr->getNumArgs());

  *constOffset = *varOffset = 0; // Initialize both first
  if ((*constOffset = tryToEvaluateAsConst(expr->getArg(index))))
    return; // Constant offset
  else
    *varOffset = doExpr(expr->getArg(index));
}
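// For example (illustrative HLSL, not from the original source), the literal
// offset in t.Sample(s, uv, int2(1, 2)) folds into the ConstOffset image
// operand, whereas a non-constant offset variable has to be passed through
// the Offset image operand instead.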
SpirvEvalInfo
SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
                                         hlsl::IntrinsicOp opcode) {
  using namespace hlsl;

  uint32_t retVal = 0;
  switch (opcode) {
  case IntrinsicOp::MOP_Sample:
    retVal = processTextureSampleGather(expr, /*isSample=*/true);
    break;
  case IntrinsicOp::MOP_Gather:
    retVal = processTextureSampleGather(expr, /*isSample=*/false);
    break;
  case IntrinsicOp::MOP_SampleBias:
    retVal = processTextureSampleBiasLevel(expr, /*isBias=*/true);
    break;
  case IntrinsicOp::MOP_SampleLevel:
    retVal = processTextureSampleBiasLevel(expr, /*isBias=*/false);
    break;
  case IntrinsicOp::MOP_SampleGrad:
    retVal = processTextureSampleGrad(expr);
    break;
  case IntrinsicOp::MOP_SampleCmp:
    retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/true);
    break;
  case IntrinsicOp::MOP_SampleCmpLevelZero:
    retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/false);
    break;
  case IntrinsicOp::MOP_GatherRed:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 0);
    break;
  case IntrinsicOp::MOP_GatherGreen:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 1);
    break;
  case IntrinsicOp::MOP_GatherBlue:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 2);
    break;
  case IntrinsicOp::MOP_GatherAlpha:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 3);
    break;
  case IntrinsicOp::MOP_GatherCmp:
    retVal = processTextureGatherCmp(expr);
    break;
  case IntrinsicOp::MOP_GatherCmpRed:
    retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/true, 0);
    break;
  case IntrinsicOp::MOP_Load:
    return processBufferTextureLoad(expr);
  case IntrinsicOp::MOP_Load2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false);
  case IntrinsicOp::MOP_Load3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false);
  case IntrinsicOp::MOP_Load4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false);
  case IntrinsicOp::MOP_Store:
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true);
  case IntrinsicOp::MOP_Store2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true);
  case IntrinsicOp::MOP_Store3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true);
  case IntrinsicOp::MOP_Store4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true);
  case IntrinsicOp::MOP_GetDimensions:
    retVal = processGetDimensions(expr);
    break;
  case IntrinsicOp::MOP_CalculateLevelOfDetail:
    retVal = processTextureLevelOfDetail(expr, /* unclamped */ false);
    break;
  case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped:
    retVal = processTextureLevelOfDetail(expr, /* unclamped */ true);
    break;
  case IntrinsicOp::MOP_IncrementCounter:
    retVal = theBuilder.createUnaryOp(
        spv::Op::OpBitcast, theBuilder.getUint32Type(),
        incDecRWACSBufferCounter(expr, /*isInc*/ true));
    break;
  case IntrinsicOp::MOP_DecrementCounter:
    retVal = theBuilder.createUnaryOp(
        spv::Op::OpBitcast, theBuilder.getUint32Type(),
        incDecRWACSBufferCounter(expr, /*isInc*/ false));
    break;
  case IntrinsicOp::MOP_Append:
    if (hlsl::IsHLSLStreamOutputType(
            expr->getImplicitObjectArgument()->getType()))
      return processStreamOutputAppend(expr);
    else
      return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_Consume:
    return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_RestartStrip:
    retVal = processStreamOutputRestart(expr);
    break;
  case IntrinsicOp::MOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedCompareExchange:
  case IntrinsicOp::MOP_InterlockedCompareStore:
    retVal = processRWByteAddressBufferAtomicMethods(opcode, expr);
    break;
  case IntrinsicOp::MOP_GetSamplePosition:
    retVal = processGetSamplePosition(expr);
    break;
  case IntrinsicOp::MOP_SubpassLoad:
    retVal = processSubpassLoad(expr);
    break;
  case IntrinsicOp::MOP_GatherCmpGreen:
  case IntrinsicOp::MOP_GatherCmpBlue:
  case IntrinsicOp::MOP_GatherCmpAlpha:
    emitError("no equivalent for %0 intrinsic method in Vulkan",
              expr->getCallee()->getExprLoc())
        << expr->getMethodDecl()->getName();
    return 0;
  default:
    emitError("intrinsic '%0' method unimplemented",
              expr->getCallee()->getExprLoc())
        << expr->getDirectCallee()->getName();
    return 0;
  }

  return SpirvEvalInfo(retVal).setRValue();
}
uint32_t SPIRVEmitter::createImageSample(
    QualType retType, uint32_t imageType, uint32_t image, uint32_t sampler,
    bool isNonUniform, uint32_t coordinate, uint32_t compareVal, uint32_t bias,
    uint32_t lod, std::pair<uint32_t, uint32_t> grad, uint32_t constOffset,
    uint32_t varOffset, uint32_t constOffsets, uint32_t sample, uint32_t minLod,
    uint32_t residencyCodeId) {
  const auto retTypeId = typeTranslator.translateType(retType);

  // SampleDref* instructions in SPIR-V always return a scalar.
  // They also have the correct type in HLSL.
  if (compareVal) {
    return theBuilder.createImageSample(
        retTypeId, imageType, image, sampler, isNonUniform, coordinate,
        compareVal, bias, lod, grad, constOffset, varOffset, constOffsets,
        sample, minLod, residencyCodeId);
  }

  // Non-Dref Sample instructions in SPIR-V must always return a vec4.
  auto texelTypeId = retTypeId;
  QualType elemType = {};
  uint32_t elemTypeId = 0;
  uint32_t retVecSize = 0;
  if (TypeTranslator::isVectorType(retType, &elemType, &retVecSize) &&
      retVecSize != 4) {
    elemTypeId = typeTranslator.translateType(elemType);
    texelTypeId = theBuilder.getVecType(elemTypeId, 4);
  } else if (TypeTranslator::isScalarType(retType)) {
    retVecSize = 1;
    elemTypeId = typeTranslator.translateType(retType);
    texelTypeId = theBuilder.getVecType(elemTypeId, 4);
  }
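  // E.g., sampling a Texture2D<float> still yields a vec4 at the SPIR-V
  // level; the scalar result is carved out of it further down.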
  // The Lod and Grad image operands require explicit-lod instructions.
  // Otherwise we use implicit-lod instructions.
  const bool isExplicit = lod || (grad.first && grad.second);

  // Implicit-lod instructions are only allowed in the pixel shader.
  if (!shaderModel.IsPS() && !isExplicit)
    needsLegalization = true;

  uint32_t retVal = theBuilder.createImageSample(
      texelTypeId, imageType, image, sampler, isNonUniform, coordinate,
      compareVal, bias, lod, grad, constOffset, varOffset, constOffsets, sample,
      minLod, residencyCodeId);

  // Extract smaller vector from the vec4 result if necessary.
  if (texelTypeId != retTypeId) {
    retVal = extractVecFromVec4(retVal, retVecSize, elemTypeId);
  }

  return retVal;
}
uint32_t SPIRVEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
                                                  const bool isSample) {
  // Signatures:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, Texture3D:
  // DXGI_FORMAT Object.Sample(sampler_state S,
  //                           float Location
  //                           [, int Offset]
  //                           [, float Clamp]
  //                           [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.Sample(sampler_state S,
  //                           float Location
  //                           [, float Clamp]
  //                           [, out uint Status]);
  //
  // For Texture2D/Texture2DArray:
  // <Template Type>4 Object.Gather(sampler_state S,
  //                                float2|3|4 Location,
  //                                int2 Offset
  //                                [, uint Status]);
  //
  // For TextureCube/TextureCubeArray:
  // <Template Type>4 Object.Gather(sampler_state S,
  //                                float2|3|4 Location
  //                                [, uint Status]);
  //
  // Other Texture types do not have a Gather method.

  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();

  uint32_t clamp = 0;
  if (numArgs > 2 && expr->getArg(2)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(2));
  else if (numArgs > 3 && expr->getArg(3)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(3));
  const bool hasClampArg = (clamp != 0);
  const auto status =
      hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
                   : 0;

  // Subtract 1 for status (if it exists), subtract 1 for clamp (if it exists),
  // and subtract 2 for sampler_state and location.
  const bool hasOffsetArg = numArgs - hasStatusArg - hasClampArg - 2 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());
  const auto image = loadIfGLValue(imageExpr);
  const auto sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  // .Sample()/.Gather() may have a third optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 2, &constOffset, &varOffset);
  const bool isNonUniform = image.isNonUniform() || sampler.isNonUniform();

  const auto retType = expr->getDirectCallee()->getReturnType();
  const auto retTypeId = typeTranslator.translateType(retType);
  if (isSample) {
    return createImageSample(
        retType, imageType, image, sampler, isNonUniform, coordinate,
        /*compareVal*/ 0, /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0),
        constOffset, varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0,
        /*minLod*/ clamp, status);
  } else {
    return theBuilder.createImageGather(
        retTypeId, imageType, image, sampler, isNonUniform, coordinate,
        // .Gather() doc says we return four components of red data.
        theBuilder.getConstantInt32(0), /*compareVal*/ 0, constOffset,
        varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0, status);
  }
}
uint32_t
SPIRVEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr,
                                            const bool isBias) {
  // Signatures:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, and Texture3D:
  // DXGI_FORMAT Object.SampleBias(sampler_state S,
  //                               float Location,
  //                               float Bias
  //                               [, int Offset]
  //                               [, float clamp]
  //                               [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.SampleBias(sampler_state S,
  //                               float Location,
  //                               float Bias
  //                               [, float clamp]
  //                               [, out uint Status]);
  //
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, and Texture3D:
  // DXGI_FORMAT Object.SampleLevel(sampler_state S,
  //                                float Location,
  //                                float LOD
  //                                [, int Offset]
  //                                [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.SampleLevel(sampler_state S,
  //                                float Location,
  //                                float LOD
  //                                [, out uint Status]);
  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  const auto status =
      hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
                   : 0;

  uint32_t clamp = 0;
  // The .SampleLevel() methods do not take the clamp argument.
  if (isBias) {
    if (numArgs > 3 && expr->getArg(3)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(3));
    else if (numArgs > 4 && expr->getArg(4)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(4));
  }
  const bool hasClampArg = clamp != 0;

  // Subtract 1 for clamp (if it exists), 1 for status (if it exists),
  // and 3 for sampler_state, location, and Bias/LOD.
  const bool hasOffsetArg = numArgs - hasClampArg - hasStatusArg - 3 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());
  const auto image = loadIfGLValue(imageExpr);
  const auto sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  uint32_t lod = 0;
  uint32_t bias = 0;
  if (isBias) {
    bias = doExpr(expr->getArg(2));
  } else {
    lod = doExpr(expr->getArg(2));
  }
  // If offset is present in .SampleBias()/.SampleLevel(), it is the fourth
  // argument.
  uint32_t constOffset = 0, varOffset = 0;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType = expr->getDirectCallee()->getReturnType();
  return createImageSample(
      retType, imageType, image, sampler,
      image.isNonUniform() || sampler.isNonUniform(), coordinate,
      /*compareVal*/ 0, bias, lod, std::make_pair(0, 0), constOffset, varOffset,
      /*constOffsets*/ 0, /*sampleNumber*/ 0, /*minLod*/ clamp, status);
}
uint32_t SPIRVEmitter::processTextureSampleGrad(const CXXMemberCallExpr *expr) {
  // Signature:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, and Texture3D:
  // DXGI_FORMAT Object.SampleGrad(sampler_state S,
  //                               float Location,
  //                               float DDX,
  //                               float DDY
  //                               [, int Offset]
  //                               [, float Clamp]
  //                               [, out uint Status]);
  //
  // For TextureCube and TextureCubeArray:
  // DXGI_FORMAT Object.SampleGrad(sampler_state S,
  //                               float Location,
  //                               float DDX,
  //                               float DDY
  //                               [, float Clamp]
  //                               [, out uint Status]);
  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  const auto status =
      hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
                   : 0;

  uint32_t clamp = 0;
  if (numArgs > 4 && expr->getArg(4)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(4));
  else if (numArgs > 5 && expr->getArg(5)->getType()->isFloatingType())
    clamp = doExpr(expr->getArg(5));
  const bool hasClampArg = clamp != 0;

  // Subtract 1 for clamp (if it exists), 1 for status (if it exists),
  // and 4 for sampler_state, location, DDX, and DDY.
  const bool hasOffsetArg = numArgs - hasClampArg - hasStatusArg - 4 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());
  const auto image = loadIfGLValue(imageExpr);
  const auto sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t ddx = doExpr(expr->getArg(2));
  const uint32_t ddy = doExpr(expr->getArg(3));
  // If offset is present in .SampleGrad(), it is the fifth argument.
  uint32_t constOffset = 0, varOffset = 0;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 4, &constOffset, &varOffset);

  const auto retType = expr->getDirectCallee()->getReturnType();
  return createImageSample(
      retType, imageType, image, sampler,
      image.isNonUniform() || sampler.isNonUniform(), coordinate,
      /*compareVal*/ 0, /*bias*/ 0, /*lod*/ 0, std::make_pair(ddx, ddy),
      constOffset, varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0,
      /*minLod*/ clamp, status);
}
uint32_t
SPIRVEmitter::processTextureSampleCmpCmpLevelZero(const CXXMemberCallExpr *expr,
                                                  const bool isCmp) {
  // .SampleCmp() Signature:
  //
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray:
  // float Object.SampleCmp(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, int Offset]
  //   [, float Clamp]
  //   [, out uint Status]
  // );
  //
  // For TextureCube and TextureCubeArray:
  // float Object.SampleCmp(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, float Clamp]
  //   [, out uint Status]
  // );
  //
  // .SampleCmpLevelZero() is identical to .SampleCmp(), except that it samples
  // mipmap level 0 only. It never takes a clamp argument, which is good
  // because lod and clamp may not be used together.
  //
  // .SampleCmpLevelZero() Signature:
  //
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray:
  // float Object.SampleCmpLevelZero(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, int Offset]
  //   [, out uint Status]
  // );
  //
  // For TextureCube and TextureCubeArray:
  // float Object.SampleCmpLevelZero(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue
  //   [, out uint Status]
  // );
  const auto numArgs = expr->getNumArgs();
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  const auto status =
      hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
                   : 0;

  uint32_t clamp = 0;
  // The .SampleCmpLevelZero() methods do not take the clamp argument.
  if (isCmp) {
    if (numArgs > 3 && expr->getArg(3)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(3));
    else if (numArgs > 4 && expr->getArg(4)->getType()->isFloatingType())
      clamp = doExpr(expr->getArg(4));
  }
  const bool hasClampArg = clamp != 0;

  // Subtract 1 for clamp (if it exists), 1 for status (if it exists),
  // and 3 for sampler_state, location, and compare_value.
  const bool hasOffsetArg = numArgs - hasClampArg - hasStatusArg - 3 > 0;

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const auto image = loadIfGLValue(imageExpr);
  const auto sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t compareVal = doExpr(expr->getArg(2));
  // If offset is present in .SampleCmp(), it will be the fourth argument.
  uint32_t constOffset = 0, varOffset = 0;
  if (hasOffsetArg)
    handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset);
  const uint32_t lod = isCmp ? 0 : theBuilder.getConstantFloat32(0);
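  // A zero lod id means "sample with implicit lod"; for SampleCmpLevelZero we
  // instead pass an explicit constant 0.0f so that mipmap level 0 is used.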
  const auto retType = expr->getDirectCallee()->getReturnType();
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return createImageSample(
      retType, imageType, image, sampler,
      image.isNonUniform() || sampler.isNonUniform(), coordinate, compareVal,
      /*bias*/ 0, lod, std::make_pair(0, 0), constOffset, varOffset,
      /*constOffsets*/ 0, /*sampleNumber*/ 0, /*minLod*/ clamp, status);
}
SpirvEvalInfo
SPIRVEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
  // Signature:
  // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, Texture3D:
  // ret Object.Load(int Location
  //                 [, int Offset]
  //                 [, uint status]);
  //
  // For Texture2DMS and Texture2DMSArray, there is one additional argument:
  // ret Object.Load(int Location
  //                 [, int SampleIndex]
  //                 [, int Offset]
  //                 [, uint status]);
  //
  // For (RW)Buffer, RWTexture1D, RWTexture1DArray, RWTexture2D,
  // RWTexture2DArray, RWTexture3D:
  // ret Object.Load(int Location
  //                 [, uint status]);
  //
  // Note: (RW)ByteAddressBuffer and (RW)StructuredBuffer types also have Load
  // methods that take an additional Status argument. However, since these
  // types are not represented as OpTypeImage in SPIR-V, we don't have a way
  // of figuring out the Residency Code for them. Therefore having the Status
  // argument for these types is not supported.
  //
  // For (RW)ByteAddressBuffer:
  // ret Object.{Load,Load2,Load3,Load4}(int Location
  //                                     [, uint status]);
  //
  // For (RW)StructuredBuffer:
  // ret Object.Load(int Location
  //                 [, uint status]);
  //
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectType = object->getType();

  if (typeTranslator.isRWByteAddressBuffer(objectType) ||
      typeTranslator.isByteAddressBuffer(objectType))
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);

  if (TypeTranslator::isStructuredBuffer(objectType))
    return processStructuredBufferLoad(expr);

  const auto numArgs = expr->getNumArgs();
  const auto *location = expr->getArg(0);
  const bool isTextureMS = TypeTranslator::isTextureMS(objectType);
  const bool hasStatusArg =
      expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType();
  const auto status =
      hasStatusArg ? static_cast<uint32_t>(doExpr(expr->getArg(numArgs - 1)))
                   : 0;

  if (TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType) ||
      TypeTranslator::isRWTexture(objectType))
    return processBufferTextureLoad(object, doExpr(location), /*constOffset*/ 0,
                                    /*varOffset*/ 0, /*lod*/ 0,
                                    /*residencyCode*/ status);

  // Subtract 1 for status (if it exists), 1 for sampleIndex (if it exists),
  // and 1 for location.
  const bool hasOffsetArg = numArgs - hasStatusArg - isTextureMS - 1 > 0;

  if (TypeTranslator::isTexture(objectType)) {
    // .Load() has a second optional parameter for offset.
    const auto locationId = doExpr(location);
    uint32_t constOffset = 0, varOffset = 0;
    uint32_t coordinate = locationId, lod = 0;

    if (isTextureMS) {
      // SampleIndex is only available when the Object is of Texture2DMS or
      // Texture2DMSArray types. Under those cases, Offset will be the third
      // parameter (index 2).
      lod = doExpr(expr->getArg(1));
      if (hasOffsetArg)
        handleOffsetInMethodCall(expr, 2, &constOffset, &varOffset);
    } else {
      // For Texture Load() functions, the location parameter is a vector
      // that consists of both the coordinate and the mipmap level (via the
      // last vector element). We need to split it here since the
      // OpImageFetch SPIR-V instruction encodes them as separate arguments.
      splitVecLastElement(location->getType(), locationId, &coordinate, &lod);
      // For textures other than Texture2DMS(Array), offset should be the
      // second parameter (index 1).
      if (hasOffsetArg)
        handleOffsetInMethodCall(expr, 1, &constOffset, &varOffset);
    }

    return processBufferTextureLoad(object, coordinate, constOffset, varOffset,
                                    lod, status);
  }
  emitError("Load() of the given object type unimplemented",
            object->getExprLoc());
  return 0;
}
uint32_t SPIRVEmitter::processGetDimensions(const CXXMemberCallExpr *expr) {
  const auto objectType = expr->getImplicitObjectArgument()->getType();
  if (TypeTranslator::isTexture(objectType) ||
      TypeTranslator::isRWTexture(objectType) ||
      TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType)) {
    return processBufferTextureGetDimensions(expr);
  } else if (TypeTranslator::isByteAddressBuffer(objectType) ||
             TypeTranslator::isRWByteAddressBuffer(objectType) ||
             TypeTranslator::isStructuredBuffer(objectType) ||
             TypeTranslator::isAppendStructuredBuffer(objectType) ||
             TypeTranslator::isConsumeStructuredBuffer(objectType)) {
    return processByteAddressBufferStructuredBufferGetDimensions(expr);
  } else {
    emitError("GetDimensions() of the given object type unimplemented",
              expr->getExprLoc());
    return 0;
  }
}
SpirvEvalInfo
SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
  { // Handle Buffer/RWBuffer/Texture/RWTexture indexing
    const Expr *baseExpr = nullptr;
    const Expr *indexExpr = nullptr;
    const Expr *lodExpr = nullptr;

    // For Textures, regular indexing (operator[]) uses slice 0.
    if (isBufferTextureIndexing(expr, &baseExpr, &indexExpr)) {
      const uint32_t lod = TypeTranslator::isTexture(baseExpr->getType())
                               ? theBuilder.getConstantUint32(0)
                               : 0;
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ 0, /*varOffset*/ 0, lod,
                                      /*residencyCode*/ 0);
    }
    // .mips[][] or .sample[][] must use the correct slice.
    if (isTextureMipsSampleIndexing(expr, &baseExpr, &indexExpr, &lodExpr)) {
      const uint32_t lod = doExpr(lodExpr);
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ 0, /*varOffset*/ 0, lod,
                                      /*residencyCode*/ 0);
    }
  }

  llvm::SmallVector<uint32_t, 4> indices;
  const Expr *baseExpr = collectArrayStructIndices(expr, &indices);
  auto base = loadIfAliasVarRef(baseExpr);

  if (indices.empty())
    return base; // For indexing into size-1 vectors and 1xN matrices

  // If we are indexing into an rvalue, to use OpAccessChain, we first need
  // to create a local variable to hold the rvalue.
  //
  // TODO: We can optimize the codegen by emitting OpCompositeExtract if
  // all indices are constant integers.
  if (base.isRValue()) {
    base = createTemporaryVar(baseExpr->getType(), "vector", base);
  }

  return turnIntoElementPtr(baseExpr->getType(), base, expr->getType(),
                            indices);
}
SpirvEvalInfo
SPIRVEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
  const Expr *baseExpr = expr->getBase();
  const auto baseInfo = doExpr(baseExpr);
  const auto layoutRule = baseInfo.getLayoutRule();
  const auto elemType = hlsl::GetHLSLMatElementType(baseExpr->getType());
  const auto accessor = expr->getEncodedElementAccess();
  const uint32_t elemTypeId =
      typeTranslator.translateType(elemType, layoutRule);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);

  // Construct a temporary vector out of all elements accessed:
  // 1. Create access chain for each element using OpAccessChain
  // 2. Load each element using OpLoad
  // 3. Create the vector using OpCompositeConstruct
  llvm::SmallVector<uint32_t, 4> elements;
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0, elem = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row or column, we are indexing into a
    // vector; only one index is needed in that case.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    if (baseExpr->isGLValue()) {
      for (uint32_t i = 0; i < indices.size(); ++i)
        indices[i] = theBuilder.getConstantInt32(indices[i]);

      const uint32_t ptrType =
          theBuilder.getPointerType(elemTypeId, baseInfo.getStorageClass());
      if (!indices.empty()) {
        assert(!baseInfo.isRValue());
        // Load the element via access chain
        elem = theBuilder.createAccessChain(ptrType, baseInfo, indices);
      } else {
        // The matrix is of size 1x1. No need to use access chain; base
        // should be the source pointer.
        elem = baseInfo;
      }
      elem = theBuilder.createLoad(elemTypeId, elem);
    } else { // e.g., (mat1 + mat2)._m11
      elem = theBuilder.createCompositeExtract(elemTypeId, baseInfo, indices);
    }
    elements.push_back(elem);
  }

  const auto size = elements.size();
  auto valueId = elements.front();
  if (size > 1) {
    const uint32_t vecType = theBuilder.getVecType(elemTypeId, size);
    valueId = theBuilder.createCompositeConstruct(vecType, elements);
  }

  // Special case: booleans have no physical layout, so when layout is
  // required booleans are represented as unsigned integers. After loading
  // the uint we should therefore convert it back to boolean.
  if (elemType->isBooleanType() && layoutRule != SpirvLayoutRule::Void) {
    const auto fromType =
        size == 1 ? astContext.UnsignedIntTy
                  : astContext.getExtVectorType(astContext.UnsignedIntTy, size);
    const auto toType =
        size == 1 ? astContext.BoolTy
                  : astContext.getExtVectorType(astContext.BoolTy, size);
    valueId = castToBool(valueId, fromType, toType);
  }

  return SpirvEvalInfo(valueId).setRValue();
}

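// Illustrative examples for the swizzle cases handled below (assumed HLSL
// `float4 v; float s;`):
//   v.y    -> OpAccessChain (lvalue base) or OpCompositeExtract (rvalue base)
//   s.xx   -> OpCompositeConstruct repeating the scalar twice
//   v.xyzw -> no extra instruction: the original vector in original order
//   v.zwxy -> OpVectorShuffle with selectors 2 3 0 1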
SpirvEvalInfo
SPIRVEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr) {
  const Expr *baseExpr = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(expr, &baseExpr, &accessor);

  const QualType baseType = baseExpr->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);
  const auto accessorSize = static_cast<size_t>(accessor.Count);

  // Depending on the number of elements selected, we emit different
  // instructions.
  // For vectors of size greater than 1, if we are only selecting one
  // element, a typical access chain or composite extraction is fine. But if
  // we are selecting more than one element, we must resort to
  // vector-specific operations.
  // For size-1 vectors, if we are selecting the single element multiple
  // times, we need composite construct instructions.

  if (accessorSize == 1) {
    auto baseInfo = doExpr(baseExpr);

    if (baseSize == 1) {
      // Selecting one element from a size-1 vector. The underlying vector is
      // already treated as a scalar.
      return baseInfo;
    }

    // If the base is an lvalue, we should emit an access chain instruction
    // so that we can load/store the specified element. For an rvalue base,
    // we should use composite extraction. We should check the immediate base
    // instead of the original base here since we can have something like
    // v.xyyz to turn an lvalue v into an rvalue.
    const auto type =
        typeTranslator.translateType(expr->getType(), baseInfo.getLayoutRule());
    if (!baseInfo.isRValue()) { // E.g., v.x;
      const uint32_t ptrType =
          theBuilder.getPointerType(type, baseInfo.getStorageClass());
      const uint32_t index = theBuilder.getConstantInt32(accessor.Swz0);
      // We need an lvalue here. Do not try to load.
      return baseInfo.setResultId(
          theBuilder.createAccessChain(ptrType, baseInfo, {index}));
    } else { // E.g., (v + w).x;
      // The original base vector may not be an rvalue. We need to load it if
      // it is an lvalue since the ImplicitCastExpr (LValueToRValue) will be
      // missing for that case.
      auto result =
          theBuilder.createCompositeExtract(type, baseInfo, {accessor.Swz0});
      // Special case: booleans in SPIR-V do not have a physical layout. Uint
      // is used to represent them when layout is required.
      if (expr->getType()->isBooleanType() &&
          baseInfo.getLayoutRule() != SpirvLayoutRule::Void)
        result =
            castToBool(result, astContext.UnsignedIntTy, astContext.BoolTy);
      return baseInfo.setResultId(result);
    }
  }

  if (baseSize == 1) {
    // Selecting more than one element from a size-1 vector, for example,
    // <scalar>.xx. Construct the vector.
    auto info = loadIfGLValue(baseExpr);
    const auto type =
        typeTranslator.translateType(expr->getType(), info.getLayoutRule());
    llvm::SmallVector<uint32_t, 4> components(accessorSize, info);
    return info
        .setResultId(theBuilder.createCompositeConstruct(type, components))
        .setRValue();
  }

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(accessorSize);
  // Whether we are selecting elements in the original order
  bool originalOrder = baseSize == accessorSize;
  for (uint32_t i = 0; i < accessorSize; ++i) {
    accessor.GetPosition(i, &selectors[i]);
    // We can select more elements than the vector provides. This handles
    // that case too.
    originalOrder &= selectors[i] == i;
  }

  if (originalOrder)
    return doExpr(baseExpr);

  auto info = loadIfGLValue(baseExpr);
  const auto type =
      typeTranslator.translateType(expr->getType(), info.getLayoutRule());
  // Use base for both vectors. But we are only selecting values from the
  // first one.
  return info.setResultId(
      theBuilder.createVectorShuffle(type, info, info, selectors));
}

SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
  if (const uint32_t id = tryToEvaluateAsConst(expr))
    return SpirvEvalInfo(id).setRValue();

  return SpirvEvalInfo(InitListHandler(*this).process(expr)).setRValue();
}

SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
  llvm::SmallVector<uint32_t, 4> indices;
  const Expr *base = collectArrayStructIndices(expr, &indices);
  auto info = loadIfAliasVarRef(base);

  if (!indices.empty()) {
    (void)turnIntoElementPtr(base->getType(), info, expr->getType(), indices);
  }

  return info;
}

uint32_t SPIRVEmitter::createTemporaryVar(QualType type, llvm::StringRef name,
                                          const SpirvEvalInfo &init) {
  // We are creating a temporary variable in the Function storage class here,
  // which means it has void layout rule.
  const uint32_t varType = typeTranslator.translateType(type);
  const std::string varName = "temp.var." + name.str();
  const uint32_t varId = theBuilder.addFnVar(varType, varName);
  storeValue(varId, init, type);
  return varId;
}

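// Illustrative sketch for the ++/-- cases handled below (assumed HLSL
// `int i; ++i;`; the SPIR-V is approximate):
//   %old = OpLoad %int %i
//   %new = OpIAdd %int %old %int_1
//          OpStore %i %new
// The prefix form then returns the pointer to i (an lvalue), while the
// postfix form returns %old (an rvalue).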
SpirvEvalInfo SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
  const auto opcode = expr->getOpcode();
  const auto *subExpr = expr->getSubExpr();
  const auto subType = subExpr->getType();
  auto subValue = doExpr(subExpr);
  const auto subTypeId = typeTranslator.translateType(subType);

  switch (opcode) {
  case UO_PreInc:
  case UO_PreDec:
  case UO_PostInc:
  case UO_PostDec: {
    const bool isPre = opcode == UO_PreInc || opcode == UO_PreDec;
    const bool isInc = opcode == UO_PreInc || opcode == UO_PostInc;

    const spv::Op spvOp = translateOp(isInc ? BO_Add : BO_Sub, subType);
    const uint32_t originValue = theBuilder.createLoad(subTypeId, subValue);
    const uint32_t one = hlsl::IsHLSLMatType(subType)
                             ? getMatElemValueOne(subType)
                             : getValueOne(subType);
    uint32_t incValue = 0;
    if (TypeTranslator::isMxNMatrix(subType)) {
      // For matrices, we can only increment/decrement each vector of it.
      const auto actOnEachVec = [this, spvOp, one](uint32_t /*index*/,
                                                   uint32_t vecType,
                                                   uint32_t lhsVec) {
        const auto valId =
            theBuilder.createBinaryOp(spvOp, vecType, lhsVec, one);
        return SpirvEvalInfo(valId).setRValue();
      };
      incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec);
    } else {
      incValue = theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
    }
    theBuilder.createStore(subValue, incValue);

    // Prefix increment/decrement operators return an lvalue, while postfix
    // increment/decrement operators return an rvalue.
    return isPre ? subValue : subValue.setResultId(originValue).setRValue();
  }
  case UO_Not: {
    return subValue
        .setResultId(
            theBuilder.createUnaryOp(spv::Op::OpNot, subTypeId, subValue))
        .setRValue();
  }
  case UO_LNot: {
    // Parsing will do the necessary casting to make sure we are applying the
    // ! operator on boolean values.
    return subValue
        .setResultId(theBuilder.createUnaryOp(spv::Op::OpLogicalNot, subTypeId,
                                              subValue))
        .setRValue();
  }
  case UO_Plus:
    // No need to do anything for the prefix + operator.
    return subValue;
  case UO_Minus: {
    // SPIR-V has two opcodes for negating values: OpSNegate and OpFNegate.
    const spv::Op spvOp = isFloatOrVecOfFloatType(subType) ? spv::Op::OpFNegate
                                                           : spv::Op::OpSNegate;
    return subValue
        .setResultId(theBuilder.createUnaryOp(spvOp, subTypeId, subValue))
        .setRValue();
  }
  default:
    break;
  }

  emitError("unary operator '%0' unimplemented", expr->getExprLoc())
      << expr->getOpcodeStr(opcode);
  expr->dump();
  return 0;
}

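// Illustrative mapping (not exhaustive) of what the table below produces:
//   translateOp(BO_Add, float) -> OpFAdd
//   translateOp(BO_Shr, uint)  -> OpShiftRightLogical
//   translateOp(BO_Shr, int)   -> OpShiftRightArithmetic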
spv::Op SPIRVEmitter::translateOp(BinaryOperator::Opcode op, QualType type) {
  const bool isSintType = isSintOrVecMatOfSintType(type);
  const bool isUintType = isUintOrVecMatOfUintType(type);
  const bool isFloatType = isFloatOrVecMatOfFloatType(type);

#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                     \
  case BO_##kind: {                                                           \
    if (isSintType || isUintType) {                                           \
      return spv::Op::Op##intBinOp;                                           \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  } break

#define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp)   \
  case BO_##kind: {                                                           \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  } break

#define BIN_OP_CASE_SINT_UINT(kind, sintBinOp, uintBinOp)                     \
  case BO_##kind: {                                                           \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
  } break

  switch (op) {
  case BO_EQ: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalEqual;
    if (isSintType || isUintType)
      return spv::Op::OpIEqual;
    if (isFloatType)
      return spv::Op::OpFOrdEqual;
  } break;
  case BO_NE: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalNotEqual;
    if (isSintType || isUintType)
      return spv::Op::OpINotEqual;
    if (isFloatType)
      return spv::Op::OpFOrdNotEqual;
  } break;
  // According to the HLSL doc, both sides of the && and || expression are
  // always evaluated.
  case BO_LAnd:
    return spv::Op::OpLogicalAnd;
  case BO_LOr:
    return spv::Op::OpLogicalOr;
    BIN_OP_CASE_INT_FLOAT(Add, IAdd, FAdd);
    BIN_OP_CASE_INT_FLOAT(AddAssign, IAdd, FAdd);
    BIN_OP_CASE_INT_FLOAT(Sub, ISub, FSub);
    BIN_OP_CASE_INT_FLOAT(SubAssign, ISub, FSub);
    BIN_OP_CASE_INT_FLOAT(Mul, IMul, FMul);
    BIN_OP_CASE_INT_FLOAT(MulAssign, IMul, FMul);
    BIN_OP_CASE_SINT_UINT_FLOAT(Div, SDiv, UDiv, FDiv);
    BIN_OP_CASE_SINT_UINT_FLOAT(DivAssign, SDiv, UDiv, FDiv);
    // According to the HLSL spec, "the modulus operator returns the remainder
    // of a division." "The % operator is defined only in cases where either
    // both sides are positive or both sides are negative."
    //
    // In SPIR-V, there are two remainder operations: Op*Rem and Op*Mod. With
    // the former, the sign of a non-0 result comes from Operand 1, while
    // with the latter, from Operand 2.
    //
    // For operands with different signs, technically we can map % to either
    // Op*Rem or Op*Mod since it's undefined behavior. But it is more
    // consistent with C (HLSL starts as a C derivative) and Clang frontend
    // const expression evaluation if we map % to Op*Rem.
    //
    // Note there is no OpURem in SPIR-V.
    BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(RemAssign, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(LT, SLessThan, ULessThan, FOrdLessThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(LE, SLessThanEqual, ULessThanEqual,
                                FOrdLessThanEqual);
    BIN_OP_CASE_SINT_UINT_FLOAT(GT, SGreaterThan, UGreaterThan,
                                FOrdGreaterThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(GE, SGreaterThanEqual, UGreaterThanEqual,
                                FOrdGreaterThanEqual);
    BIN_OP_CASE_SINT_UINT(And, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(AndAssign, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(Or, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(OrAssign, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(Xor, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(XorAssign, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(Shl, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(ShlAssign, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(Shr, ShiftRightArithmetic, ShiftRightLogical);
    BIN_OP_CASE_SINT_UINT(ShrAssign, ShiftRightArithmetic, ShiftRightLogical);
  default:
    break;
  }

#undef BIN_OP_CASE_INT_FLOAT
#undef BIN_OP_CASE_SINT_UINT_FLOAT
#undef BIN_OP_CASE_SINT_UINT

  emitError("translating binary operator '%0' unimplemented", {})
      << BinaryOperator::getOpcodeStr(op);
  return spv::Op::OpNop;
}

SpirvEvalInfo SPIRVEmitter::processAssignment(const Expr *lhs,
                                              const SpirvEvalInfo &rhs,
                                              const bool isCompoundAssignment,
                                              SpirvEvalInfo lhsPtr) {
  lhs = lhs->IgnoreParenNoopCasts(astContext);

  // Assigning to vector swizzling should be handled differently.
  if (SpirvEvalInfo result = tryToAssignToVectorElements(lhs, rhs))
    return result;

  // Assigning to matrix swizzling should be handled differently.
  if (SpirvEvalInfo result = tryToAssignToMatrixElements(lhs, rhs))
    return result;

  // Assigning to a RWBuffer/RWTexture should be handled differently.
  if (SpirvEvalInfo result = tryToAssignToRWBufferRWTexture(lhs, rhs))
    return result;

  // Normal assignment procedure
  if (!lhsPtr)
    lhsPtr = doExpr(lhs);

  storeValue(lhsPtr, rhs, lhs->getType());

  // Plain assignment returns an rvalue, while compound assignment returns
  // an lvalue.
  return isCompoundAssignment ? lhsPtr : rhs;
}

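// Illustrative sketch for the boolean special case handled below (assumed
// HLSL `bool b` stored into a buffer-resident field, i.e. a non-void layout
// rule; the SPIR-V is approximate):
//   %asUint = OpSelect %uint %b %uint_1 %uint_0
//             OpStore %fieldPtr %asUint
// Booleans have no physical size or bit pattern, so the uint stand-in is
// what actually gets stored.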
void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
                              const SpirvEvalInfo &rhsVal,
                              QualType lhsValType) {
  if (const auto *refType = lhsValType->getAs<ReferenceType>())
    lhsValType = refType->getPointeeType();

  QualType matElemType = {};
  const bool lhsIsMat = typeTranslator.isMxNMatrix(lhsValType, &matElemType);
  const bool lhsIsFloatMat = lhsIsMat && matElemType->isFloatingType();
  const bool lhsIsNonFpMat = lhsIsMat && !matElemType->isFloatingType();

  if (typeTranslator.isScalarType(lhsValType) ||
      typeTranslator.isVectorType(lhsValType) || lhsIsFloatMat) {
    uint32_t rhsValId = rhsVal;

    // Special case: according to the SPIR-V spec, there is no physical size
    // or bit pattern defined for the boolean type. Therefore an unsigned
    // integer is used to represent booleans when layout is required. In such
    // cases, we should cast the boolean to uint before creating OpStore.
    if (isBoolOrVecOfBoolType(lhsValType) &&
        lhsPtr.getLayoutRule() != SpirvLayoutRule::Void) {
      uint32_t vecSize = 1;
      const bool isVec =
          TypeTranslator::isVectorType(lhsValType, nullptr, &vecSize);
      const auto toType =
          isVec ? astContext.getExtVectorType(astContext.UnsignedIntTy, vecSize)
                : astContext.UnsignedIntTy;
      const auto fromType =
          isVec ? astContext.getExtVectorType(astContext.BoolTy, vecSize)
                : astContext.BoolTy;
      rhsValId = castToInt(rhsValId, fromType, toType, {});
    }

    theBuilder.createStore(lhsPtr, rhsValId);
  } else if (TypeTranslator::isOpaqueType(lhsValType)) {
    // Resource types are represented using RecordType in the AST.
    // Handle them before the general RecordType.
    //
    // HLSL allows putting resource types (which translate into SPIR-V opaque
    // types) in structs, and assigning to variables of resource types. These
    // can all result in illegal SPIR-V for Vulkan. We just translate
    // literally here and let SPIRV-Tools opt do the legalization work.
    //
    // Note: legalization specific code
    theBuilder.createStore(lhsPtr, rhsVal);
    needsLegalization = true;
  } else if (TypeTranslator::isAKindOfStructuredOrByteBuffer(lhsValType)) {
    // The rhs should be a pointer and the lhs should be a pointer-to-pointer.
    // Directly store the pointer here and let SPIRV-Tools opt do the cleanup.
    //
    // Note: legalization specific code
    theBuilder.createStore(lhsPtr, rhsVal);
    needsLegalization = true;

    // For ConstantBuffers/TextureBuffers, we decompose and assign each field
    // recursively like normal structs using the following logic.
    //
    // The frontend forbids declaring ConstantBuffer<T> or TextureBuffer<T>
    // variables as function parameters/returns/variables, but happily accepts
    // assignments/returns from ConstantBuffer<T>/TextureBuffer<T> to function
    // parameters/returns/variables of type T. And ConstantBuffer<T> is not
    // represented differently from struct T.
  } else if (TypeTranslator::isOpaqueArrayType(lhsValType)) {
    // For opaque array types, we cannot perform OpLoad on the whole array and
    // then write out as a whole; instead, we need to OpLoad each element
    // using access chains. This is to influence later SPIR-V transformations
    // to use access chains to access each opaque object; if we handle the
    // array wholesale here, whole-array loads/stores will survive into the
    // final transformed code, and drivers don't like that.
    // TODO: consider moving this hack into SPIRV-Tools as a transformation.
    assert(lhsValType->isConstantArrayType());
    assert(!rhsVal.isRValue());

    const auto *arrayType = astContext.getAsConstantArrayType(lhsValType);
    const auto elemType = arrayType->getElementType();
    const auto arraySize =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());

    // Do a separate load of each element via access chain
    llvm::SmallVector<uint32_t, 8> elements;
    for (uint32_t i = 0; i < arraySize; ++i) {
      const auto subRhsValType =
          typeTranslator.translateType(elemType, rhsVal.getLayoutRule());
      const auto subRhsPtrType =
          theBuilder.getPointerType(subRhsValType, rhsVal.getStorageClass());
      const auto subRhsPtr = theBuilder.createAccessChain(
          subRhsPtrType, rhsVal, {theBuilder.getConstantInt32(i)});

      elements.push_back(theBuilder.createLoad(subRhsValType, subRhsPtr));
    }

    // Create a new composite and write out once
    const auto lhsValTypeId =
        typeTranslator.translateType(lhsValType, lhsPtr.getLayoutRule());
    theBuilder.createStore(
        lhsPtr, theBuilder.createCompositeConstruct(lhsValTypeId, elements));
  } else if (lhsPtr.getLayoutRule() == rhsVal.getLayoutRule()) {
    // If lhs and rhs have the same memory layout, we should be safe to load
    // from rhs and directly store into lhs, avoiding decomposing rhs.
    // Note: this check should happen after those setting needsLegalization.
    // TODO: is this optimization always correct?
    theBuilder.createStore(lhsPtr, rhsVal);
  } else if (lhsValType->isRecordType() || lhsValType->isConstantArrayType() ||
             lhsIsNonFpMat) {
    theBuilder.createStore(
        lhsPtr, reconstructValue(rhsVal, lhsValType, lhsPtr.getLayoutRule()));
  } else {
    emitError("storing value of type %0 unimplemented", {}) << lhsValType;
  }
}

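// Illustrative example of the layout conversion performed below: copying a
// `struct S { float2 v; bool b; }` value out of a cbuffer (non-void layout
// rule) into a local (void layout rule) extracts each field with
// OpCompositeExtract, casts the uint-encoded `b` back to bool, and rebuilds
// the struct with OpCompositeConstruct.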
uint32_t SPIRVEmitter::reconstructValue(const SpirvEvalInfo &srcVal,
                                        const QualType valType,
                                        SpirvLayoutRule dstLR) {
  // Lambda for casting scalar or vector of bool<-->uint in cases where one
  // side of the reconstruction (lhs or rhs) has a layout rule.
  const auto handleBooleanLayout = [this, &srcVal, dstLR](uint32_t val,
                                                          QualType valType) {
    // We only need to cast if we have a scalar or vector of booleans.
    if (!isBoolOrVecOfBoolType(valType))
      return val;

    SpirvLayoutRule srcLR = srcVal.getLayoutRule();
    // Source value has a layout rule, and has therefore been represented
    // as a uint. Cast it to boolean before using.
    bool shouldCastToBool =
        srcLR != SpirvLayoutRule::Void && dstLR == SpirvLayoutRule::Void;
    // Destination has a layout rule, and should therefore be represented
    // as a uint. Cast to uint before using.
    bool shouldCastToUint =
        srcLR == SpirvLayoutRule::Void && dstLR != SpirvLayoutRule::Void;
    // No boolean layout issues to take care of.
    if (!shouldCastToBool && !shouldCastToUint)
      return val;

    uint32_t vecSize = 1;
    TypeTranslator::isVectorType(valType, nullptr, &vecSize);
    QualType boolType =
        vecSize == 1 ? astContext.BoolTy
                     : astContext.getExtVectorType(astContext.BoolTy, vecSize);
    QualType uintType =
        vecSize == 1
            ? astContext.UnsignedIntTy
            : astContext.getExtVectorType(astContext.UnsignedIntTy, vecSize);

    if (shouldCastToBool)
      return castToBool(val, uintType, boolType);
    if (shouldCastToUint)
      return castToInt(val, boolType, uintType, {});

    return val;
  };

  // Lambda for cases where we want to reconstruct an array
  const auto reconstructArray = [this, &srcVal, valType,
                                 dstLR](uint32_t arraySize,
                                        QualType arrayElemType) {
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < arraySize; ++i) {
      const auto subSrcValType =
          typeTranslator.translateType(arrayElemType, srcVal.getLayoutRule());
      const auto subSrcVal =
          theBuilder.createCompositeExtract(subSrcValType, srcVal, {i});
      elements.push_back(reconstructValue(srcVal.substResultId(subSrcVal),
                                          arrayElemType, dstLR));
    }
    const auto dstValType = typeTranslator.translateType(valType, dstLR);
    return theBuilder.createCompositeConstruct(dstValType, elements);
  };

  // Constant arrays
  if (const auto *arrayType = astContext.getAsConstantArrayType(valType)) {
    const auto elemType = arrayType->getElementType();
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    return reconstructArray(size, elemType);
  }

  // Non-floating-point matrices
  QualType matElemType = {};
  uint32_t numRows = 0, numCols = 0;
  const bool isNonFpMat =
      typeTranslator.isMxNMatrix(valType, &matElemType, &numRows, &numCols) &&
      !matElemType->isFloatingType();
  if (isNonFpMat) {
    // Note: This check should happen before the RecordType check.
    // Non-fp matrices are represented as arrays of vectors in SPIR-V.
    // Each array element is a vector. Get the QualType for the vector.
    const auto elemType = astContext.getExtVectorType(matElemType, numCols);
    return reconstructArray(numRows, elemType);
  }

  // Note: This check should happen before the RecordType check since
  // vector/matrix/resource types are represented as RecordType in the AST.
  if (hlsl::IsHLSLVecMatType(valType) || hlsl::IsHLSLResourceType(valType))
    return handleBooleanLayout(srcVal, valType);

  // Structs
  if (const auto *recordType = valType->getAs<RecordType>()) {
    uint32_t index = 0;
    llvm::SmallVector<uint32_t, 4> elements;
    for (const auto *field : recordType->getDecl()->fields()) {
      const auto subSrcValType = typeTranslator.translateType(
          field->getType(), srcVal.getLayoutRule());
      const auto subSrcVal =
          theBuilder.createCompositeExtract(subSrcValType, srcVal, {index});
      elements.push_back(reconstructValue(srcVal.substResultId(subSrcVal),
                                          field->getType(), dstLR));
      ++index;
    }
    const auto dstValType = typeTranslator.translateType(valType, dstLR);
    return theBuilder.createCompositeConstruct(dstValType, elements);
  }

  return handleBooleanLayout(srcVal, valType);
}

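// Illustrative sketch for the shift masking performed below (assumed 32-bit
// int operands; the SPIR-V is approximate): for `x << 35`, the shift amount
// is first masked to the bit width,
//   %masked = OpBitwiseAnd %int %shift %int_31
//   %result = OpShiftLeftLogical %int %x %masked
// so the effective shift amount becomes 3.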
SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                                            const BinaryOperatorKind opcode,
                                            const QualType computationType,
                                            const QualType resultType,
                                            SourceRange sourceRange,
                                            SpirvEvalInfo *lhsInfo,
                                            const spv::Op mandateGenOpcode) {
  const QualType lhsType = lhs->getType();
  const QualType rhsType = rhs->getType();

  // Binary logical operations (such as ==, !=, etc.) that return a boolean
  // type may get a literal (e.g. 0, 1, etc.) as lhs or rhs args. Since only
  // the non-zero-ness of these literals matters, they can be translated as
  // 32-bit values.
  TypeTranslator::LiteralTypeHint hint(typeTranslator);
  if (resultType->isBooleanType()) {
    if (lhsType->isSpecificBuiltinType(BuiltinType::LitInt) ||
        rhsType->isSpecificBuiltinType(BuiltinType::LitInt))
      hint.setHint(astContext.IntTy);
    if (lhsType->isSpecificBuiltinType(BuiltinType::LitFloat) ||
        rhsType->isSpecificBuiltinType(BuiltinType::LitFloat))
      hint.setHint(astContext.FloatTy);
  }

  // If the operands are of matrix type, we need to dispatch the operation
  // onto each element vector iff the operands are not degenerate matrices
  // and we don't have a matrix specific SPIR-V instruction for the operation.
  if (!isSpirvMatrixOp(mandateGenOpcode) &&
      TypeTranslator::isMxNMatrix(lhsType)) {
    return processMatrixBinaryOp(lhs, rhs, opcode, sourceRange);
  }

  // The comma operator works differently from other binary operations as
  // there is no SPIR-V instruction for it. For each comma, we must evaluate
  // lhs and then rhs, and return the result of rhs.
  if (opcode == BO_Comma) {
    (void)doExpr(lhs);
    return doExpr(rhs);
  }

  SpirvEvalInfo rhsVal = 0, lhsPtr = 0, lhsVal = 0;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = loadIfGLValue(rhs);
    lhsVal = lhsPtr = doExpr(lhs);

    // This is a compound assignment. We need to load the lhs value if lhs
    // is not already an rvalue and does not generate a vector shuffle.
    if (!lhsPtr.isRValue() && !isVectorShuffle(lhs)) {
      lhsVal = loadIfGLValue(lhs, lhsPtr);
    }

    // For compound assignments, the AST does not have the proper implicit
    // cast if lhs and rhs have different types. So we need to manually cast
    // lhs to the computation type.
    if (computationType != lhsType)
      lhsVal.setResultId(
          castToType(lhsVal, lhsType, computationType, lhs->getExprLoc()));
  } else {
    // Evaluate lhs before rhs
    lhsPtr = doExpr(lhs);
    lhsVal = loadIfGLValue(lhs, lhsPtr);
    rhsVal = loadIfGLValue(rhs);
  }

  if (lhsInfo)
    *lhsInfo = lhsPtr;

  const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
                            ? translateOp(opcode, computationType)
                            : mandateGenOpcode;

  switch (opcode) {
  case BO_Shl:
  case BO_Shr:
  case BO_ShlAssign:
  case BO_ShrAssign:
    // We need to mask the RHS to make sure that we are not shifting by an
    // amount that is larger than the bitwidth of the LHS.
    rhsVal.setResultId(theBuilder.createBinaryOp(
        spv::Op::OpBitwiseAnd, typeTranslator.translateType(computationType),
        rhsVal, getMaskForBitwidthValue(rhsType)));
    // Fall through
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_LT:
  case BO_LE:
  case BO_GT:
  case BO_GE:
  case BO_EQ:
  case BO_NE:
  case BO_And:
  case BO_Or:
  case BO_Xor:
  case BO_LAnd:
  case BO_LOr:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign:
  case BO_AndAssign:
  case BO_OrAssign:
  case BO_XorAssign: {
    // To evaluate this expression as an OpSpecConstantOp, we need to make
    // sure both operands are constant and at least one of them is a spec
    // constant.
    if (lhsVal.isConstant() && rhsVal.isConstant() &&
        (lhsVal.isSpecConstant() || rhsVal.isSpecConstant()) &&
        isAcceptedSpecConstantBinaryOp(spvOp)) {
      const auto valId = theBuilder.createSpecConstantBinaryOp(
          spvOp, typeTranslator.translateType(resultType), lhsVal, rhsVal);
      return SpirvEvalInfo(valId).setRValue().setSpecConstant();
    }

    // Normal binary operation
    uint32_t valId = 0;
    if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
      valId = theBuilder.createBinaryOp(
          spvOp, typeTranslator.translateType(computationType), lhsVal, rhsVal);
      // For compound assignments, the AST does not have the proper implicit
      // cast if lhs and rhs have different types. So we need to manually cast
      // the result back to lhs' type.
      if (computationType != lhsType)
        valId = castToType(valId, computationType, lhsType, lhs->getExprLoc());
    } else {
      valId = theBuilder.createBinaryOp(
          spvOp, typeTranslator.translateType(resultType), lhsVal, rhsVal);
    }
    auto result = SpirvEvalInfo(valId).setRValue();

    // Propagate RelaxedPrecision
    if (lhsVal.isRelaxedPrecision() || rhsVal.isRelaxedPrecision())
      result.setRelaxedPrecision();
    // Propagate NonUniformEXT
    if (lhsVal.isNonUniform() || rhsVal.isNonUniform())
      result.setNonUniform();

    return result;
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
    break;
  case BO_PtrMemD:
  case BO_PtrMemI:
  case BO_Comma:
    // Unimplemented
    break;
  }

  emitError("binary operator '%0' unimplemented", lhs->getExprLoc())
      << BinaryOperator::getOpcodeStr(opcode) << sourceRange;
  return 0;
}

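// Illustrative sketch of the control flow generated below (the SPIR-V is
// approximate): a private bool flag guards one-time initialization of a
// static local:
//   %cond = OpLoad %bool %init_done_var
//           OpSelectionMerge %if_init_done None
//           OpBranchConditional %cond %if_init_done %if_init_todo
//   %if_init_todo = OpLabel
//           ; ...initialize, then OpStore %init_done_var %true...
//           OpBranch %if_init_done
//   %if_init_done = OpLabel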
void SPIRVEmitter::initOnce(QualType varType, std::string varName,
                            uint32_t varPtr, const Expr *varInit) {
  // For uninitialized resource objects, we do nothing since there are no
  // meaningful zero values for them.
  if (!varInit && hlsl::IsHLSLResourceType(varType))
    return;

  const uint32_t boolType = theBuilder.getBoolType();
  varName = "init.done." + varName;

  // Create a file/module visible variable to hold the initialization state.
  const uint32_t initDoneVar =
      theBuilder.addModuleVar(boolType, spv::StorageClass::Private, varName,
                              theBuilder.getConstantBool(false));

  const uint32_t condition = theBuilder.createLoad(boolType, initDoneVar);

  const uint32_t todoBB = theBuilder.createBasicBlock("if.init.todo");
  const uint32_t doneBB = theBuilder.createBasicBlock("if.init.done");

  // If initDoneVar contains true, we jump to the "done" basic block;
  // otherwise, jump to the "todo" basic block.
  theBuilder.createConditionalBranch(condition, doneBB, todoBB, doneBB);
  theBuilder.addSuccessor(todoBB);
  theBuilder.addSuccessor(doneBB);
  theBuilder.setMergeTarget(doneBB);

  theBuilder.setInsertPoint(todoBB);

  // Do initialization and mark done
  if (varInit) {
    storeValue(
        // Static function variables are of private storage class
        SpirvEvalInfo(varPtr).setStorageClass(spv::StorageClass::Private),
        doExpr(varInit), varInit->getType());
  } else {
    const auto typeId = typeTranslator.translateType(varType);
    theBuilder.createStore(varPtr, theBuilder.getConstantNull(typeId));
  }
  theBuilder.createStore(initDoneVar, theBuilder.getConstantBool(true));
  theBuilder.createBranch(doneBB);
  theBuilder.addSuccessor(doneBB);

  theBuilder.setInsertPoint(doneBB);
}

bool SPIRVEmitter::isVectorShuffle(const Expr *expr) {
  // TODO: the following check is essentially duplicated from
  // doHLSLVectorElementExpr. Should unify them.
  if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    const Expr *base = nullptr;
    hlsl::VectorMemberAccessPositions accessor;
    condenseVectorElementExpr(vecElemExpr, &base, &accessor);

    const auto accessorSize = accessor.Count;
    if (accessorSize == 1) {
      // Selecting only one element. OpAccessChain or OpCompositeExtract for
      // such cases.
      return false;
    }

    const auto baseSize = hlsl::GetHLSLVecSize(base->getType());
    if (accessorSize != baseSize)
      return true;

    for (uint32_t i = 0; i < accessorSize; ++i) {
      uint32_t position;
      accessor.GetPosition(i, &position);
      if (position != i)
        return true;
    }

    // Selecting exactly the original vector. No vector shuffle generated.
    return false;
  }

  return false;
}

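// Illustrative AST shape recognized below (assumed HLSL `t.mips[lod][pos]`):
// the outer operator[] takes the inner operator[] as its object argument,
// and the inner one's object is a MemberExpr for the `mips` (or `sample`)
// member of the texture.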
bool SPIRVEmitter::isTextureMipsSampleIndexing(const CXXOperatorCallExpr *expr,
                                               const Expr **base,
                                               const Expr **location,
                                               const Expr **lod) {
  if (!expr)
    return false;

  // <object>.mips[][] consists of an outer operator[] and an inner operator[]
  const CXXOperatorCallExpr *outerExpr = expr;
  if (outerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *arg0 = outerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const CXXOperatorCallExpr *innerExpr = dyn_cast<CXXOperatorCallExpr>(arg0);

  // Must have an inner operator[]
  if (!innerExpr ||
      innerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript) {
    return false;
  }

  const Expr *innerArg0 =
      innerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const MemberExpr *memberExpr = dyn_cast<MemberExpr>(innerArg0);
  if (!memberExpr)
    return false;

  // Must be accessing the member named "mips" or "sample"
  const auto &memberName =
      memberExpr->getMemberNameInfo().getName().getAsString();
  if (memberName != "mips" && memberName != "sample")
    return false;

  const Expr *object = memberExpr->getBase();
  const auto objectType = object->getType();
  if (!TypeTranslator::isTexture(objectType))
    return false;

  if (base)
    *base = object;
  if (lod)
    *lod = innerExpr->getArg(1);
  if (location)
    *location = outerExpr->getArg(1);
  return true;
}

bool SPIRVEmitter::isBufferTextureIndexing(const CXXOperatorCallExpr *indexExpr,
                                           const Expr **base,
                                           const Expr **index) {
  if (!indexExpr)
    return false;

  // Must be operator[]
  if (indexExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *object = indexExpr->getArg(0);
  const auto objectType = object->getType();
  if (TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType) ||
      TypeTranslator::isTexture(objectType) ||
      TypeTranslator::isRWTexture(objectType)) {
    if (base)
      *base = object;
    if (index)
      *index = indexExpr->getArg(1);
    return true;
  }
  return false;
}

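// Illustrative example of the condensing performed below: `v.wzyx.xy`
// collapses to the single accessor `v.wz`, since position 0 of `.xy` maps
// through `.wzyx` to element 3 (w) and position 1 maps to element 2 (z).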
void SPIRVEmitter::condenseVectorElementExpr(
    const HLSLVectorElementExpr *expr, const Expr **basePtr,
    hlsl::VectorMemberAccessPositions *flattenedAccessor) {
  llvm::SmallVector<hlsl::VectorMemberAccessPositions, 2> accessors;
  accessors.push_back(expr->getEncodedElementAccess());

  // Recursively descend until we find the true base vector, collecting
  // accessors in reverse order along the way.
  *basePtr = expr->getBase();
  while (const auto *vecElemBase = dyn_cast<HLSLVectorElementExpr>(*basePtr)) {
    accessors.push_back(vecElemBase->getEncodedElementAccess());
    *basePtr = vecElemBase->getBase();
  }

  *flattenedAccessor = accessors.back();
  for (int32_t i = accessors.size() - 2; i >= 0; --i) {
    const auto &currentAccessor = accessors[i];

    // Apply the current level's accessor to the flattened accessor of all
    // previous levels.
    hlsl::VectorMemberAccessPositions combinedAccessor;
    for (uint32_t j = 0; j < currentAccessor.Count; ++j) {
      uint32_t currentPosition = 0;
      currentAccessor.GetPosition(j, &currentPosition);
      uint32_t previousPosition = 0;
      flattenedAccessor->GetPosition(currentPosition, &previousPosition);
      combinedAccessor.SetPosition(j, previousPosition);
    }
    combinedAccessor.Count = currentAccessor.Count;
    combinedAccessor.IsValid =
        flattenedAccessor->IsValid && currentAccessor.IsValid;

    *flattenedAccessor = combinedAccessor;
  }
}

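// Illustrative example for the splat below (assumed HLSL): splatting the
// literal 1.5 to a float4 yields a constant composite of four float 1.5
// constants via getConstantComposite, while splatting a non-constant scalar
// emits an OpCompositeConstruct instead.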
SpirvEvalInfo SPIRVEmitter::createVectorSplat(const Expr *scalarExpr,
                                              uint32_t size) {
  bool isConstVal = false;
  SpirvEvalInfo scalarVal = 0;

  // Try to evaluate the element as a constant first. If successful, we can
  // generate constant instructions for this vector splat.
  if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
    isConstVal = true;
  } else {
    scalarVal = doExpr(scalarExpr);
  }

  if (size == 1) {
    // Just return the scalar value for a vector splat with size 1.
    // Note that it can be used as an lvalue, so we need to carry over
    // the lvalueness for non-constant cases.
    return isConstVal ? scalarVal.setConstant().setRValue() : scalarVal;
  }

  const uint32_t vecType = theBuilder.getVecType(
      typeTranslator.translateType(scalarExpr->getType()), size);
  llvm::SmallVector<uint32_t, 4> elements(size_t(size), scalarVal);

  // TODO: we are saying the constant has Function storage class here.
  // Should find a more meaningful one.
  if (isConstVal) {
    const auto valueId = theBuilder.getConstantComposite(vecType, elements);
    return SpirvEvalInfo(valueId).setConstant().setRValue();
  } else {
    const auto valueId = theBuilder.createCompositeConstruct(vecType, elements);
    return SpirvEvalInfo(valueId).setRValue();
  }
}

void SPIRVEmitter::splitVecLastElement(QualType vecType, uint32_t vec,
                                       uint32_t *residual,
                                       uint32_t *lastElement) {
  assert(hlsl::IsHLSLVecType(vecType));

  const uint32_t count = hlsl::GetHLSLVecSize(vecType);
  assert(count > 1);
  const uint32_t elemTypeId =
      typeTranslator.translateType(hlsl::GetHLSLVecElementType(vecType));

  if (count == 2) {
    *residual = theBuilder.createCompositeExtract(elemTypeId, vec, 0);
  } else {
    llvm::SmallVector<uint32_t, 4> indices;
    for (uint32_t i = 0; i < count - 1; ++i)
      indices.push_back(i);

    const uint32_t typeId = theBuilder.getVecType(elemTypeId, count - 1);
    *residual = theBuilder.createVectorShuffle(typeId, vec, vec, indices);
  }

  *lastElement =
      theBuilder.createCompositeExtract(elemTypeId, vec, {count - 1});
}

uint32_t SPIRVEmitter::convertVectorToStruct(QualType structType,
                                             uint32_t elemTypeId,
                                             uint32_t vector) {
  assert(structType->isStructureType());

  const auto *structDecl = structType->getAsStructureType()->getDecl();
  uint32_t vectorIndex = 0;
  uint32_t elemCount = 1;
  llvm::SmallVector<uint32_t, 4> members;

  for (const auto *field : structDecl->fields()) {
    if (TypeTranslator::isScalarType(field->getType())) {
      members.push_back(theBuilder.createCompositeExtract(elemTypeId, vector,
                                                          {vectorIndex++}));
    } else if (TypeTranslator::isVectorType(field->getType(), nullptr,
                                            &elemCount)) {
      llvm::SmallVector<uint32_t, 4> indices;
      for (uint32_t i = 0; i < elemCount; ++i)
        indices.push_back(vectorIndex++);

      const uint32_t type = theBuilder.getVecType(elemTypeId, elemCount);
      members.push_back(
          theBuilder.createVectorShuffle(type, vector, vector, indices));
    } else {
      assert(false && "unhandled type");
    }
  }

  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(structType), members);
}

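// Illustrative example for the scaling pattern matched below (assumed HLSL
// `float4 v; float s; v * s`): the AST models this as v * splat(s); instead
// of materializing the splat, we emit a single
//   %r = OpVectorTimesScalar %v4float %v %s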
SpirvEvalInfo
SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  const SourceRange range = expr->getSourceRange();

  // We can only translate floatN * float into OpVectorTimesScalar.
  // So the result type must be floatN.
  if (!hlsl::IsHLSLVecType(type) ||
      !hlsl::GetHLSLVecElementType(type)->isFloatingType())
    return 0;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();

  // Multiplying a float vector with a float scalar will be represented in
  // the AST via a binary operation with two float vectors as operands; one
  // of the operands comes from an implicit cast with kind CK_HLSLVectorSplat.

  // vector * scalar
  if (hlsl::IsHLSLVecType(lhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const QualType vecType = expr->getType();
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvEvalInfo lhsPtr = 0;
          const auto result = processBinaryOp(
              lhs, cast->getSubExpr(), expr->getOpcode(), vecType, vecType,
              range, &lhsPtr, spv::Op::OpVectorTimesScalar);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 vecType, vecType, range, nullptr,
                                 spv::Op::OpVectorTimesScalar);
        }
      }
    }
  }

  // scalar * vector
  if (hlsl::IsHLSLVecType(rhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const QualType vecType = expr->getType();
        // We need to switch the positions of lhs and rhs here because
        // OpVectorTimesScalar requires the first operand to be a vector and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               vecType, vecType, range, nullptr,
                               spv::Op::OpVectorTimesScalar);
      }
    }
  }

  return 0;
}

SpirvEvalInfo
SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  const SourceRange range = expr->getSourceRange();

  // We can only translate floatMxN * float into OpMatrixTimesScalar.
  // So the result type must be floatMxN.
  if (!hlsl::IsHLSLMatType(type) ||
      !hlsl::GetHLSLMatElementType(type)->isFloatingType())
    return 0;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();
  const QualType lhsType = lhs->getType();
  const QualType rhsType = rhs->getType();

  const auto selectOpcode = [](const QualType ty) {
    return TypeTranslator::isMx1Matrix(ty) || TypeTranslator::is1xNMatrix(ty)
               ? spv::Op::OpVectorTimesScalar
               : spv::Op::OpMatrixTimesScalar;
  };

  // Multiplying a float matrix with a float scalar will be represented in
  // the AST via a binary operation with two float matrices as operands; one
  // of the operands comes from an implicit cast with kind CK_HLSLMatrixSplat.

  // matrix * scalar
  if (hlsl::IsHLSLMatType(lhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const QualType matType = expr->getType();
        const spv::Op opcode = selectOpcode(lhsType);
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvEvalInfo lhsPtr = 0;
          const auto result =
              processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                              matType, matType, range, &lhsPtr, opcode);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 matType, matType, range, nullptr, opcode);
        }
      }
    }
  }

  // scalar * matrix
  if (hlsl::IsHLSLMatType(rhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const QualType matType = expr->getType();
        const spv::Op opcode = selectOpcode(rhsType);
        // We need to switch the positions of lhs and rhs here because
        // OpMatrixTimesScalar requires the first operand to be a matrix and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               matType, matType, range, nullptr, opcode);
      }
    }
  }

  return 0;
}

SpirvEvalInfo
SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
                                          const SpirvEvalInfo &rhs) {
  // Assigning to a vector swizzling lhs is tricky if we are neither
  // writing to one element nor to all elements in their original order.
  // In such cases, we need to create a new vector swizzling involving
  // both the lhs and rhs vectors and then write the result of this swizzling
  // into the base vector of lhs.
  // For example, for vec4.yz = vec2, we need to do the following:
  //
  //   %vec4Val = OpLoad %v4float %vec4
  //   %vec2Val = OpLoad %v2float %vec2
  //   %shuffle = OpVectorShuffle %v4float %vec4Val %vec2Val 0 4 5 3
  //   OpStore %vec4 %shuffle
  //
  // When doing the vector shuffle, we use the lhs base vector as the first
  // vector and the rhs vector as the second vector. Therefore, all elements
  // in the second vector will be selected into the shuffle result.
  const auto *lhsExpr = dyn_cast<HLSLVectorElementExpr>(lhs);

  if (!lhsExpr)
    return 0;

  // Special case for <scalar-value>.x, which will have an AST of
  // HLSLVectorElementExpr whose base is an ImplicitCastExpr
  // (CK_HLSLVectorSplat). We just need to assign to <scalar-value>
  // in that case.
  if (const auto *baseCast = dyn_cast<CastExpr>(lhsExpr->getBase()))
    if (baseCast->getCastKind() == CastKind::CK_HLSLVectorSplat &&
        hlsl::GetHLSLVecSize(baseCast->getType()) == 1)
      return processAssignment(baseCast->getSubExpr(), rhs, false);

  const Expr *base = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(lhsExpr, &base, &accessor);

  const QualType baseType = base->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const uint32_t baseTypeId = typeTranslator.translateType(baseType);
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);
  const auto accessorSize = accessor.Count;

  // Whether we are selecting the whole original vector
  bool isSelectOrigin = accessorSize == baseSize;

  // Assigning to one component
  if (accessorSize == 1) {
    if (isBufferTextureIndexing(dyn_cast_or_null<CXXOperatorCallExpr>(base))) {
      // Assigning to one component of a RWBuffer/RWTexture element.
      // We need to use OpImageWrite here.
      // Compose the new vector value first
      const uint32_t oldVec = doExpr(base);
      const uint32_t newVec = theBuilder.createCompositeInsert(
          baseTypeId, oldVec, {accessor.Swz0}, rhs);
      const auto result = tryToAssignToRWBufferRWTexture(base, newVec);
      assert(result); // Definitely RWBuffer/RWTexture assignment
      (void)result;
      return rhs; // TODO: incorrect for compound assignments
    } else {
      // Assigning to one normal vector component. Nothing special, just fall
      // back to the normal CodeGen path.
      return 0;
    }
  }

  if (isSelectOrigin) {
    for (uint32_t i = 0; i < accessorSize; ++i) {
      uint32_t position;
      accessor.GetPosition(i, &position);
      if (position != i)
        isSelectOrigin = false;
    }
  }

  // Assigning to the original vector
  if (isSelectOrigin) {
    // Ignore this HLSLVectorElementExpr and dispatch to base
    return processAssignment(base, rhs, false);
  }

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(baseSize);
  // Assume we are selecting all original elements first.
  for (uint32_t i = 0; i < baseSize; ++i) {
    selectors[i] = i;
  }

  // Now fix up the elements that actually got overwritten by the rhs vector.
  // Since we are using the rhs vector as the second vector, their indices
  // should be offset by the size of the lhs base vector.
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t position;
    accessor.GetPosition(i, &position);
    selectors[position] = baseSize + i;
  }

  const auto vec1 = doExpr(base);
  const uint32_t vec1Val = vec1.isRValue()
                               ? static_cast<uint32_t>(vec1)
                               : theBuilder.createLoad(baseTypeId, vec1);
  const uint32_t shuffle =
      theBuilder.createVectorShuffle(baseTypeId, vec1Val, rhs, selectors);

  if (!tryToAssignToRWBufferRWTexture(base, shuffle))
    theBuilder.createStore(vec1, shuffle);

  // TODO: OK, this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}

SpirvEvalInfo
SPIRVEmitter::tryToAssignToRWBufferRWTexture(const Expr *lhs,
                                             const SpirvEvalInfo &rhs) {
  const Expr *baseExpr = nullptr;
  const Expr *indexExpr = nullptr;
  const auto lhsExpr = dyn_cast<CXXOperatorCallExpr>(lhs);
  if (isBufferTextureIndexing(lhsExpr, &baseExpr, &indexExpr)) {
    const uint32_t locId = doExpr(indexExpr);
    const QualType imageType = baseExpr->getType();
    const auto baseInfo = doExpr(baseExpr);
    const uint32_t imageId = theBuilder.createLoad(
        typeTranslator.translateType(imageType), baseInfo);
    theBuilder.createImageWrite(imageType, imageId, locId, rhs);
    if (baseInfo.isNonUniform()) {
      // Decorate the image handle for OpImageWrite
      theBuilder.decorateNonUniformEXT(imageId);
    }
    return rhs;
  }
  return 0;
}

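// Illustrative example for the matrix swizzle assignment handled below
// (assumed HLSL `m._m01_m10 = float2(a, b)`): `a` and `b` are extracted
// from the rhs with OpCompositeExtract, and each is stored through its own
// OpAccessChain into element (0,1) and (1,0) of `m`, respectively.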
SpirvEvalInfo
SPIRVEmitter::tryToAssignToMatrixElements(const Expr *lhs,
                                          const SpirvEvalInfo &rhs) {
  const auto *lhsExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
  if (!lhsExpr)
    return 0;

  const Expr *baseMat = lhsExpr->getBase();
  const auto base = doExpr(baseMat);
  const QualType elemType = hlsl::GetHLSLMatElementType(baseMat->getType());
  const uint32_t elemTypeId = typeTranslator.translateType(elemType);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseMat->getType(), rowCount, colCount);

  // For each lhs element written to:
  // 1. Extract the corresponding rhs element using OpCompositeExtract
  // 2. Create access chain for the lhs element using OpAccessChain
  // 3. Write using OpStore
  const auto accessor = lhsExpr->getEncodedElementAccess();
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row or column, we are indexing into a
    // vector; only one index is needed in that case.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    for (uint32_t i = 0; i < indices.size(); ++i)
      indices[i] = theBuilder.getConstantInt32(indices[i]);

    // If we are writing to only one element, the rhs should already be a
    // scalar value.
    uint32_t rhsElem = rhs;
    if (accessor.Count > 1)
      rhsElem = theBuilder.createCompositeExtract(elemTypeId, rhs, {i});

    const uint32_t ptrType =
        theBuilder.getPointerType(elemTypeId, base.getStorageClass());

    // If the lhs is actually a matrix of size 1x1, we don't need the access
    // chain; base is already the dest pointer.
    uint32_t lhsElemPtr = base;
    if (!indices.empty()) {
      assert(!base.isRValue());
      // Get the pointer to the element via access chain
      lhsElemPtr = theBuilder.createAccessChain(ptrType, lhsElemPtr, indices);
    }

    theBuilder.createStore(lhsElemPtr, rhsElem);
  }

  // TODO: OK, this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}


SpirvEvalInfo SPIRVEmitter::processEachVectorInMatrix(
    const Expr *matrix, const uint32_t matrixVal,
    llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
        actOnEachVector) {
  const auto matType = matrix->getType();
  assert(TypeTranslator::isMxNMatrix(matType));
  const uint32_t vecType = typeTranslator.getComponentVectorType(matType);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);

  llvm::SmallVector<uint32_t, 4> vectors;
  // Extract each component vector and perform the operation on it
  for (uint32_t i = 0; i < rowCount; ++i) {
    const uint32_t lhsVec =
        theBuilder.createCompositeExtract(vecType, matrixVal, {i});
    vectors.push_back(actOnEachVector(i, vecType, lhsVec));
  }

  // Construct the result matrix
  const auto valId = theBuilder.createCompositeConstruct(
      typeTranslator.translateType(matType), vectors);
  return SpirvEvalInfo(valId).setRValue();
}
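
// An HLSL declaration handled by createSpecConstant looks like the following
// (illustrative example, not from the source):
//   [[vk::constant_id(3)]] const bool kFlag = true;
// which is translated into an OpSpecConstant* instruction (here,
// OpSpecConstantTrue) decorated with SpecId 3.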
void SPIRVEmitter::createSpecConstant(const VarDecl *varDecl) {
  class SpecConstantEnvRAII {
  public:
    // Creates a new instance which sets mode to true on creation,
    // and resets mode to false on destruction.
    SpecConstantEnvRAII(bool *mode) : modeSlot(mode) { *modeSlot = true; }
    ~SpecConstantEnvRAII() { *modeSlot = false; }

  private:
    bool *modeSlot;
  };

  const QualType varType = varDecl->getType();
  bool hasError = false;

  if (!varDecl->isExternallyVisible()) {
    emitError("specialization constant must be externally visible",
              varDecl->getLocation());
    hasError = true;
  }

  if (const auto *builtinType = varType->getAs<BuiltinType>()) {
    switch (builtinType->getKind()) {
    case BuiltinType::Bool:
    case BuiltinType::Int:
    case BuiltinType::UInt:
    case BuiltinType::Float:
      break;
    default:
      emitError("unsupported specialization constant type",
                varDecl->getLocStart());
      hasError = true;
    }
  }

  const auto *init = varDecl->getInit();
  if (!init) {
    emitError("missing default value for specialization constant",
              varDecl->getLocation());
    hasError = true;
  } else if (!isAcceptedSpecConstantInit(init)) {
    emitError("unsupported specialization constant initializer",
              init->getLocStart())
        << init->getSourceRange();
    hasError = true;
  }

  if (hasError)
    return;

  SpecConstantEnvRAII specConstantEnvRAII(&isSpecConstantMode);
  const auto specConstant = doExpr(init);

  // We are not creating a variable to hold the spec constant; instead, we
  // translate the varDecl directly into a spec constant here.
  theBuilder.decorateSpecId(
      specConstant, varDecl->getAttr<VKConstantIdAttr>()->getSpecConstId());

  declIdMapper.registerSpecConstant(varDecl, specConstant);
}

SpirvEvalInfo
SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                    const BinaryOperatorKind opcode,
                                    SourceRange range) {
  // TODO: some code is duplicated from processBinaryOp. Try to unify them.
  const auto lhsType = lhs->getType();
  assert(TypeTranslator::isMxNMatrix(lhsType));
  const spv::Op spvOp = translateOp(opcode, lhsType);

  uint32_t rhsVal, lhsPtr, lhsVal;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = doExpr(rhs);
    lhsPtr = doExpr(lhs);
    const uint32_t lhsTy = typeTranslator.translateType(lhsType);
    lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
  } else {
    // Evaluate lhs before rhs
    lhsVal = lhsPtr = doExpr(lhs);
    rhsVal = doExpr(rhs);
  }
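
  // Element-wise matrix operations are lowered per component vector
  // (illustrative example): a float2x2 addition becomes two OpFAdd
  // instructions on %v2float rows, whose results are recombined with
  // OpCompositeConstruct into the result matrix.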
  switch (opcode) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign: {
    const auto actOnEachVec = [this, spvOp, rhsVal](uint32_t index,
                                                    uint32_t vecType,
                                                    uint32_t lhsVec) {
      // For each vector of lhs, we need to load the corresponding vector of
      // rhs and do the operation on them.
      const uint32_t rhsVec =
          theBuilder.createCompositeExtract(vecType, rhsVal, {index});
      const auto valId =
          theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec);
      return SpirvEvalInfo(valId).setRValue();
    };
    return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec);
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
  default:
    break;
  }

  emitError("binary operator '%0' over matrix type unimplemented",
            lhs->getExprLoc())
      << BinaryOperator::getOpcodeStr(opcode) << range;
  return 0;
}
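
// Walks an indexing expression from the outside in, collecting one index per
// level, and returns the innermost base. For example (illustrative): for an
// expression like "s.arr[i].f", the returned base is the DeclRefExpr for "s"
// and indices receives {field index of arr, i, field index of f}.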
const Expr *SPIRVEmitter::collectArrayStructIndices(
    const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices, bool rawIndex) {
  if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
    // First check whether this is referring to a static member. If it is, we
    // create a DeclRefExpr for it.
    if (auto *varDecl = dyn_cast<VarDecl>(indexing->getMemberDecl()))
      if (varDecl->isStaticDataMember())
        return DeclRefExpr::Create(
            astContext, NestedNameSpecifierLoc(), SourceLocation(), varDecl,
            /*RefersToEnclosingVariableOrCapture=*/false, SourceLocation(),
            varDecl->getType(), VK_LValue);

    const Expr *base = collectArrayStructIndices(
        indexing->getBase()->IgnoreParenNoopCasts(astContext), indices,
        rawIndex);

    // Append the index of the current level
    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
    assert(fieldDecl);
    // If we are accessing a derived struct, we need to account for the number
    // of base structs, since they are placed as fields at the beginning of the
    // derived struct.
    const uint32_t index = getNumBaseClasses(indexing->getBase()->getType()) +
                           fieldDecl->getFieldIndex();
    indices->push_back(rawIndex ? index : theBuilder.getConstantInt32(index));
    return base;
  }

  // Provide a hint to the TypeTranslator that the integer literal used to
  // index into the following cases should be translated as a 32-bit integer.
  TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);

  if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
    if (rawIndex)
      return nullptr; // TODO: handle constant array index

    // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
    // cast. We need to ignore it to avoid creating OpLoad.
    const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
    const Expr *base = collectArrayStructIndices(thisBase, indices, rawIndex);
    indices->push_back(doExpr(indexing->getIdx()));
    return base;
  }

  if (const auto *indexing = dyn_cast<CXXOperatorCallExpr>(expr))
    if (indexing->getOperator() == OverloadedOperatorKind::OO_Subscript) {
      if (rawIndex)
        return nullptr; // TODO: handle constant array index

      // If this is indexing into resources, we need specific OpImage*
      // instructions for accessing. Return directly to avoid further building
      // up the access chain.
      if (isBufferTextureIndexing(indexing))
        return indexing;

      const Expr *thisBase =
          indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
      const auto thisBaseType = thisBase->getType();
      const Expr *base = collectArrayStructIndices(thisBase, indices, rawIndex);

      if (thisBaseType != base->getType() &&
          TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) {
        // The immediate base is a kind of structured or byte buffer. It should
        // be an alias variable. Break the normal index collecting chain.
        // Return the immediate base as the base so that we can apply other
        // hacks for legalization over it.
        //
        // Note: legalization specific code
        indices->clear();
        base = thisBase;
      }

      // If the base is a structured buffer, we need to push an additional
      // index 0 here, because the structured buffer is translated into an
      // OpTypeStruct wrapping an OpTypeRuntimeArray; the extra index steps
      // into that runtime array.
      if (TypeTranslator::isStructuredBuffer(thisBaseType))
        indices->push_back(theBuilder.getConstantInt32(0));

      if ((hlsl::IsHLSLVecType(thisBaseType) &&
           (hlsl::GetHLSLVecSize(thisBaseType) == 1)) ||
          typeTranslator.is1x1Matrix(thisBaseType) ||
          typeTranslator.is1xNMatrix(thisBaseType)) {
        // If this is a size-1 vector or 1xN matrix, ignore the index.
      } else {
        indices->push_back(doExpr(indexing->getArg(1)));
      }
      return base;
    }

  {
    const Expr *index = nullptr;
    // TODO: the following is duplicating the logic in doCXXMemberCallExpr.
    if (const auto *object = isStructuredBufferLoad(expr, &index)) {
      if (rawIndex)
        return nullptr; // TODO: handle constant array index

      // For object.Load(index), there should be no more indexing into the
      // object.
      indices->push_back(theBuilder.getConstantInt32(0));
      indices->push_back(doExpr(index));
      return object;
    }
  }

  // This is the deepest we can go. No more array or struct indexing.
  return expr;
}

SpirvEvalInfo &SPIRVEmitter::turnIntoElementPtr(
    QualType baseType, SpirvEvalInfo &base, QualType elemType,
    const llvm::SmallVector<uint32_t, 4> &indices) {
  // If this is an rvalue, we need a temporary object to hold it so that we
  // can get an access chain from it.
  const bool needTempVar = base.isRValue();

  if (needTempVar) {
    auto varName = TypeTranslator::getName(baseType);
    const auto var = createTemporaryVar(baseType, varName, base);
    base.setResultId(var)
        .setLayoutRule(SpirvLayoutRule::Void)
        .setStorageClass(spv::StorageClass::Function);
  }

  const uint32_t elemTypeId =
      typeTranslator.translateType(elemType, base.getLayoutRule());
  const uint32_t ptrType =
      theBuilder.getPointerType(elemTypeId, base.getStorageClass());
  base.setResultId(theBuilder.createAccessChain(ptrType, base, indices));

  // Okay, this part seems weird, but it is intended:
  // If the base is originally an rvalue, the whole AST involving the base
  // is consistently set up to handle rvalues. By copying the base into a
  // temporary variable and grabbing an access chain from it, we are breaking
  // the consistency by turning the base from an rvalue into an lvalue. Keep
  // in mind that there will be no LValueToRValue casts in the AST for us to
  // rely on to load the access chain if an rvalue is expected. Therefore, we
  // must do the load here. Otherwise, it's up to the consumer of this access
  // chain to do the load, and that could be anywhere.
  if (needTempVar) {
    base.setResultId(theBuilder.createLoad(elemTypeId, base));
  }

  return base;
}
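
// Converts the given value to a bool (or vector/matrix of bool) by comparing
// it against zero. For example (illustrative): an int3 is turned into a bool3
// via OpINotEqual against a zero vector.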
uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
                                  QualType toBoolType) {
  if (TypeTranslator::isSameScalarOrVecType(fromType, toBoolType))
    return fromVal;

  const uint32_t boolType = typeTranslator.translateType(toBoolType);

  { // Special case handling for converting to a matrix of booleans.
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(fromType, &elemType, &rowCount,
                                    &colCount)) {
      const auto fromRowQualType =
          astContext.getExtVectorType(elemType, colCount);
      const auto fromRowQualTypeId =
          typeTranslator.translateType(fromRowQualType);
      const auto toBoolRowQualType =
          astContext.getExtVectorType(astContext.BoolTy, colCount);
      llvm::SmallVector<uint32_t, 4> rows;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto row =
            theBuilder.createCompositeExtract(fromRowQualTypeId, fromVal, {i});
        rows.push_back(castToBool(row, fromRowQualType, toBoolRowQualType));
      }
      return theBuilder.createCompositeConstruct(boolType, rows);
    }
  }

  // Converting to bool means comparing with value zero.
  const spv::Op spvOp = translateOp(BO_NE, fromType);
  const uint32_t zeroVal = getValueZero(fromType);
  return theBuilder.createBinaryOp(spvOp, boolType, fromVal, zeroVal);
}

uint32_t SPIRVEmitter::castToInt(uint32_t fromVal, QualType fromType,
                                 QualType toIntType, SourceLocation srcLoc) {
  if (TypeTranslator::isSameScalarOrVecType(fromType, toIntType))
    return fromVal;

  uint32_t intType = typeTranslator.translateType(toIntType);

  if (isBoolOrVecOfBoolType(fromType)) {
    const uint32_t one = getValueOne(toIntType);
    const uint32_t zero = getValueZero(toIntType);
    return theBuilder.createSelect(intType, fromVal, one, zero);
  }

  if (isSintOrVecOfSintType(fromType) || isUintOrVecOfUintType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    uint32_t convertedType = 0;
    fromVal = convertBitwidth(fromVal, fromType, toIntType, &convertedType);
    // If bitwidth conversion was the only thing we needed to do, we're done.
    if (convertedType == typeTranslator.translateType(toIntType))
      return fromVal;
    return theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, fromVal);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    fromVal = convertBitwidth(fromVal, fromType, toIntType);
    if (isSintOrVecOfSintType(toIntType)) {
      return theBuilder.createUnaryOp(spv::Op::OpConvertFToS, intType, fromVal);
    } else if (isUintOrVecOfUintType(toIntType)) {
      return theBuilder.createUnaryOp(spv::Op::OpConvertFToU, intType, fromVal);
    } else {
      emitError("casting from floating point to integer unimplemented", srcLoc);
    }
  }

  {
    QualType elemType = {};
    uint32_t numRows = 0, numCols = 0;
    if (TypeTranslator::isMxNMatrix(fromType, &elemType, &numRows, &numCols)) {
      // The source matrix and the target matrix must have the same dimensions.
      QualType toElemType = {};
      uint32_t toNumRows = 0, toNumCols = 0;
      const bool isMat = TypeTranslator::isMxNMatrix(toIntType, &toElemType,
                                                     &toNumRows, &toNumCols);
      assert(isMat && numRows == toNumRows && numCols == toNumCols);
      (void)isMat;
      (void)toNumRows;
      (void)toNumCols;

      // Casting to a matrix of integers: Cast each row and construct a
      // composite.
      llvm::SmallVector<uint32_t, 4> castedRows;
      const uint32_t vecType = typeTranslator.getComponentVectorType(fromType);
      const auto fromVecQualType =
          astContext.getExtVectorType(elemType, numCols);
      const auto toIntVecQualType =
          astContext.getExtVectorType(toElemType, numCols);
      for (uint32_t row = 0; row < numRows; ++row) {
        const auto rowId =
            theBuilder.createCompositeExtract(vecType, fromVal, {row});
        castedRows.push_back(
            castToInt(rowId, fromVecQualType, toIntVecQualType, srcLoc));
      }
      return theBuilder.createCompositeConstruct(intType, castedRows);
    }
  }

  return 0;
}
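
// For example (illustrative): convertBitwidth turns a 32-bit float value into
// a 16-bit one via OpFConvert, and a 32-bit signed integer into a 64-bit one
// via OpSConvert; same-width inputs are returned unchanged.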
uint32_t SPIRVEmitter::convertBitwidth(uint32_t fromVal, QualType fromType,
                                       QualType toType, uint32_t *resultType) {
  // At the moment, we will not make bitwidth conversions for literal int and
  // literal float types because they always indicate 64-bit and do not
  // represent the bitwidth that SPIR-V actually resolved them to.
  // TODO: If the evaluated type is added to SpirvEvalInfo, change 'fromVal' to
  // SpirvEvalInfo and use it to handle literal types more accurately.
  if (fromType->isSpecificBuiltinType(BuiltinType::LitFloat) ||
      fromType->isSpecificBuiltinType(BuiltinType::LitInt))
    return fromVal;

  const auto fromBitwidth = typeTranslator.getElementSpirvBitwidth(fromType);
  const auto toBitwidth = typeTranslator.getElementSpirvBitwidth(toType);
  if (fromBitwidth == toBitwidth) {
    if (resultType)
      *resultType = typeTranslator.translateType(fromType);
    return fromVal;
  }

  // We want the 'fromType' with the 'toBitwidth'.
  const uint32_t targetTypeId =
      typeTranslator.getTypeWithCustomBitwidth(fromType, toBitwidth);
  if (resultType)
    *resultType = targetTypeId;

  if (isFloatOrVecOfFloatType(fromType))
    return theBuilder.createUnaryOp(spv::Op::OpFConvert, targetTypeId, fromVal);
  if (isSintOrVecOfSintType(fromType))
    return theBuilder.createUnaryOp(spv::Op::OpSConvert, targetTypeId, fromVal);
  if (isUintOrVecOfUintType(fromType))
    return theBuilder.createUnaryOp(spv::Op::OpUConvert, targetTypeId, fromVal);

  llvm_unreachable("invalid type passed to convertBitwidth");
}

uint32_t SPIRVEmitter::castToFloat(uint32_t fromVal, QualType fromType,
                                   QualType toFloatType,
                                   SourceLocation srcLoc) {
  if (TypeTranslator::isSameScalarOrVecType(fromType, toFloatType))
    return fromVal;

  const uint32_t floatType = typeTranslator.translateType(toFloatType);

  if (isBoolOrVecOfBoolType(fromType)) {
    const uint32_t one = getValueOne(toFloatType);
    const uint32_t zero = getValueZero(toFloatType);
    return theBuilder.createSelect(floatType, fromVal, one, zero);
  }

  if (isSintOrVecOfSintType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    fromVal = convertBitwidth(fromVal, fromType, toFloatType);
    return theBuilder.createUnaryOp(spv::Op::OpConvertSToF, floatType, fromVal);
  }

  if (isUintOrVecOfUintType(fromType)) {
    // First convert the source to the bitwidth of the destination if
    // necessary.
    fromVal = convertBitwidth(fromVal, fromType, toFloatType);
    return theBuilder.createUnaryOp(spv::Op::OpConvertUToF, floatType, fromVal);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    // This is the case of float to float conversion with different bitwidths.
    return convertBitwidth(fromVal, fromType, toFloatType);
  }

  // Casting matrix types
  {
    QualType elemType = {};
    uint32_t numRows = 0, numCols = 0;
    if (TypeTranslator::isMxNMatrix(fromType, &elemType, &numRows, &numCols)) {
      // The source matrix and the target matrix must have the same dimensions.
      QualType toElemType = {};
      uint32_t toNumRows = 0, toNumCols = 0;
      const auto isMat = TypeTranslator::isMxNMatrix(toFloatType, &toElemType,
                                                     &toNumRows, &toNumCols);
      assert(isMat && numRows == toNumRows && numCols == toNumCols);
      (void)isMat;
      (void)toNumRows;
      (void)toNumCols;

      // Casting to a matrix of floats: Cast each row and construct a
      // composite.
      llvm::SmallVector<uint32_t, 4> castedRows;
      const uint32_t vecType = typeTranslator.getComponentVectorType(fromType);
      const auto fromVecQualType =
          astContext.getExtVectorType(elemType, numCols);
      const auto toFloatVecQualType =
          astContext.getExtVectorType(toElemType, numCols);
      for (uint32_t row = 0; row < numRows; ++row) {
        const auto rowId =
            theBuilder.createCompositeExtract(vecType, fromVal, {row});
        castedRows.push_back(
            castToFloat(rowId, fromVecQualType, toFloatVecQualType, srcLoc));
      }
      return theBuilder.createCompositeConstruct(floatType, castedRows);
    }
  }

  emitError("casting to floating point unimplemented", srcLoc);
  return 0;
}

SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
  const FunctionDecl *callee = callExpr->getDirectCallee();
  assert(hlsl::IsIntrinsicOp(callee) &&
         "processIntrinsicCallExpr was called for a non-intrinsic function.");

  const bool isFloatType = isFloatOrVecMatOfFloatType(callExpr->getType());
  const bool isSintType = isSintOrVecMatOfSintType(callExpr->getType());

  // Figure out which intrinsic function to translate.
  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  hlsl::GetIntrinsicOp(callee, opcode, group);

  GLSLstd450 glslOpcode = GLSLstd450Bad;
  uint32_t retVal = 0;
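
  // The INTRINSIC_* macros below map a single HLSL intrinsic either to a
  // native SPIR-V opcode or to a GLSL.std.450 extended instruction, in some
  // cases selecting the variant based on the component type (float vs.
  // signed vs. unsigned integer).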
#define INTRINSIC_SPIRV_OP_WITH_CAP_CASE(intrinsicOp, spirvOp, doEachVec, cap) \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    theBuilder.requireCapability(cap);                                         \
    retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,    \
                                            doEachVec);                        \
  } break

#define INTRINSIC_SPIRV_OP_CASE(intrinsicOp, spirvOp, doEachVec)               \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    retVal = processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,    \
                                            doEachVec);                        \
  } break

#define INTRINSIC_OP_CASE(intrinsicOp, glslOp, doEachVec)                      \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = GLSLstd450::GLSLstd450##glslOp;                               \
    retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);   \
  } break

#define INTRINSIC_OP_CASE_INT_FLOAT(intrinsicOp, glslIntOp, glslFloatOp,       \
                                    doEachVec)                                 \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = isFloatType ? GLSLstd450::GLSLstd450##glslFloatOp             \
                             : GLSLstd450::GLSLstd450##glslIntOp;              \
    retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);   \
  } break

#define INTRINSIC_OP_CASE_SINT_UINT(intrinsicOp, glslSintOp, glslUintOp,       \
                                    doEachVec)                                 \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = isSintType ? GLSLstd450::GLSLstd450##glslSintOp               \
                            : GLSLstd450::GLSLstd450##glslUintOp;              \
    retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);   \
  } break

#define INTRINSIC_OP_CASE_SINT_UINT_FLOAT(intrinsicOp, glslSintOp, glslUintOp, \
                                          glslFloatOp, doEachVec)              \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = isFloatType ? GLSLstd450::GLSLstd450##glslFloatOp             \
                 : isSintType ? GLSLstd450::GLSLstd450##glslSintOp             \
                              : GLSLstd450::GLSLstd450##glslUintOp;            \
    retVal = processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);   \
  } break

  switch (const auto hlslOpcode = static_cast<hlsl::IntrinsicOp>(opcode)) {
  case hlsl::IntrinsicOp::IOP_InterlockedAdd:
  case hlsl::IntrinsicOp::IOP_InterlockedAnd:
  case hlsl::IntrinsicOp::IOP_InterlockedMax:
  case hlsl::IntrinsicOp::IOP_InterlockedUMax:
  case hlsl::IntrinsicOp::IOP_InterlockedMin:
  case hlsl::IntrinsicOp::IOP_InterlockedUMin:
  case hlsl::IntrinsicOp::IOP_InterlockedOr:
  case hlsl::IntrinsicOp::IOP_InterlockedXor:
  case hlsl::IntrinsicOp::IOP_InterlockedExchange:
  case hlsl::IntrinsicOp::IOP_InterlockedCompareStore:
  case hlsl::IntrinsicOp::IOP_InterlockedCompareExchange:
    retVal = processIntrinsicInterlockedMethod(callExpr, hlslOpcode);
    break;
  case hlsl::IntrinsicOp::IOP_NonUniformResourceIndex:
    retVal = processIntrinsicNonUniformResourceIndex(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_tex1D:
  case hlsl::IntrinsicOp::IOP_tex1Dbias:
  case hlsl::IntrinsicOp::IOP_tex1Dgrad:
  case hlsl::IntrinsicOp::IOP_tex1Dlod:
  case hlsl::IntrinsicOp::IOP_tex1Dproj:
  case hlsl::IntrinsicOp::IOP_tex2D:
  case hlsl::IntrinsicOp::IOP_tex2Dbias:
  case hlsl::IntrinsicOp::IOP_tex2Dgrad:
  case hlsl::IntrinsicOp::IOP_tex2Dlod:
  case hlsl::IntrinsicOp::IOP_tex2Dproj:
  case hlsl::IntrinsicOp::IOP_tex3D:
  case hlsl::IntrinsicOp::IOP_tex3Dbias:
  case hlsl::IntrinsicOp::IOP_tex3Dgrad:
  case hlsl::IntrinsicOp::IOP_tex3Dlod:
  case hlsl::IntrinsicOp::IOP_tex3Dproj:
  case hlsl::IntrinsicOp::IOP_texCUBE:
  case hlsl::IntrinsicOp::IOP_texCUBEbias:
  case hlsl::IntrinsicOp::IOP_texCUBEgrad:
  case hlsl::IntrinsicOp::IOP_texCUBElod:
  case hlsl::IntrinsicOp::IOP_texCUBEproj: {
    emitError("deprecated %0 intrinsic function will not be supported",
              callExpr->getExprLoc())
        << callee->getName();
    return 0;
  }
  case hlsl::IntrinsicOp::IOP_dot:
    retVal = processIntrinsicDot(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_GroupMemoryBarrier:
    retVal = processIntrinsicMemoryBarrier(callExpr,
                                           /*isDevice*/ false,
                                           /*groupSync*/ false,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_GroupMemoryBarrierWithGroupSync:
    retVal = processIntrinsicMemoryBarrier(callExpr,
                                           /*isDevice*/ false,
                                           /*groupSync*/ true,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_DeviceMemoryBarrier:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ false,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_DeviceMemoryBarrierWithGroupSync:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ true,
                                           /*isAllBarrier*/ false);
    break;
  case hlsl::IntrinsicOp::IOP_AllMemoryBarrier:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ false,
                                           /*isAllBarrier*/ true);
    break;
  case hlsl::IntrinsicOp::IOP_AllMemoryBarrierWithGroupSync:
    retVal = processIntrinsicMemoryBarrier(callExpr, /*isDevice*/ true,
                                           /*groupSync*/ true,
                                           /*isAllBarrier*/ true);
    break;
  case hlsl::IntrinsicOp::IOP_CheckAccessFullyMapped:
    retVal =
        theBuilder.createImageSparseTexelsResident(doExpr(callExpr->getArg(0)));
    break;
  case hlsl::IntrinsicOp::IOP_mul:
  case hlsl::IntrinsicOp::IOP_umul:
    retVal = processIntrinsicMul(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_all:
    retVal = processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
    break;
  case hlsl::IntrinsicOp::IOP_any:
    retVal = processIntrinsicAllOrAny(callExpr, spv::Op::OpAny);
    break;
  case hlsl::IntrinsicOp::IOP_asdouble:
  case hlsl::IntrinsicOp::IOP_asfloat:
  case hlsl::IntrinsicOp::IOP_asint:
  case hlsl::IntrinsicOp::IOP_asuint:
    retVal = processIntrinsicAsType(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_clip:
    retVal = processIntrinsicClip(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_dst:
    retVal = processIntrinsicDst(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_clamp:
  case hlsl::IntrinsicOp::IOP_uclamp:
    retVal = processIntrinsicClamp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_frexp:
    retVal = processIntrinsicFrexp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_ldexp:
    retVal = processIntrinsicLdexp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_lit:
    retVal = processIntrinsicLit(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_modf:
    retVal = processIntrinsicModf(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_msad4:
    retVal = processIntrinsicMsad4(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_sign: {
    if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
      retVal = processIntrinsicFloatSign(callExpr);
    else
      retVal =
          processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450SSign,
                                        /*actPerRowForMatrices*/ true);
  } break;
  case hlsl::IntrinsicOp::IOP_D3DCOLORtoUBYTE4:
    retVal = processD3DCOLORtoUBYTE4(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_isfinite:
    retVal = processIntrinsicIsFinite(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_sincos:
    retVal = processIntrinsicSinCos(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_rcp:
    retVal = processIntrinsicRcp(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_saturate:
    retVal = processIntrinsicSaturate(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_log10:
    retVal = processIntrinsicLog10(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_f16tof32:
    retVal = processIntrinsicF16ToF32(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_f32tof16:
    retVal = processIntrinsicF32ToF16(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_WaveGetLaneCount: {
    featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneCount",
                                    callExpr->getExprLoc());
    const uint32_t retType =
        typeTranslator.translateType(callExpr->getCallReturnType(astContext));
    const uint32_t varId =
        declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupSize);
    retVal = theBuilder.createLoad(retType, varId);
  } break;
  case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
    featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "WaveGetLaneIndex",
                                    callExpr->getExprLoc());
    const uint32_t retType =
        typeTranslator.translateType(callExpr->getCallReturnType(astContext));
    const uint32_t varId =
        declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
    retVal = theBuilder.createLoad(retType, varId);
  } break;
  case hlsl::IntrinsicOp::IOP_WaveIsFirstLane:
    retVal = processWaveQuery(callExpr, spv::Op::OpGroupNonUniformElect);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveAllTrue:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAll);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveAnyTrue:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAny);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveBallot:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformBallot);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveAllEqual:
    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
    retVal = processWaveReductionOrPrefix(
        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
        spv::GroupOperation::Reduce);
    break;
  case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
  case hlsl::IntrinsicOp::IOP_WaveActiveSum:
  case hlsl::IntrinsicOp::IOP_WaveActiveUProduct:
  case hlsl::IntrinsicOp::IOP_WaveActiveProduct:
  case hlsl::IntrinsicOp::IOP_WaveActiveUMax:
  case hlsl::IntrinsicOp::IOP_WaveActiveMax:
  case hlsl::IntrinsicOp::IOP_WaveActiveUMin:
  case hlsl::IntrinsicOp::IOP_WaveActiveMin:
  case hlsl::IntrinsicOp::IOP_WaveActiveBitAnd:
  case hlsl::IntrinsicOp::IOP_WaveActiveBitOr:
  case hlsl::IntrinsicOp::IOP_WaveActiveBitXor: {
    const auto retType = callExpr->getCallReturnType(astContext);
    retVal = processWaveReductionOrPrefix(
        callExpr, translateWaveOp(hlslOpcode, retType, callExpr->getExprLoc()),
        spv::GroupOperation::Reduce);
  } break;
  case hlsl::IntrinsicOp::IOP_WavePrefixUSum:
  case hlsl::IntrinsicOp::IOP_WavePrefixSum:
  case hlsl::IntrinsicOp::IOP_WavePrefixUProduct:
  case hlsl::IntrinsicOp::IOP_WavePrefixProduct: {
    const auto retType = callExpr->getCallReturnType(astContext);
    retVal = processWaveReductionOrPrefix(
        callExpr, translateWaveOp(hlslOpcode, retType, callExpr->getExprLoc()),
        spv::GroupOperation::ExclusiveScan);
  } break;
  case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
    retVal = processWaveReductionOrPrefix(
        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
        spv::GroupOperation::ExclusiveScan);
    break;
  case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
    retVal = processWaveBroadcast(callExpr);
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
    retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
    break;
  case hlsl::IntrinsicOp::IOP_abort:
  case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
  case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
    emitError("no equivalent for %0 intrinsic function in Vulkan",
              callExpr->getExprLoc())
        << callee->getName();
    return 0;
  }
  case hlsl::IntrinsicOp::IOP_transpose: {
    const Expr *mat = callExpr->getArg(0);
    const QualType matType = mat->getType();
    if (hlsl::GetHLSLMatElementType(matType)->isFloatingType())
      retVal =
          processIntrinsicUsingSpirvInst(callExpr, spv::Op::OpTranspose, false);
    else
      retVal = processNonFpMatrixTranspose(matType, doExpr(mat));
    break;
  }
    INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_coarse, DPdxCoarse, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_fine, DPdxFine, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_coarse, DPdyCoarse, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_fine, DPdyFine, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
    INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
    INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
    INTRINSIC_SPIRV_OP_CASE(fmod, FMod, true);
    INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
    INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
    INTRINSIC_OP_CASE(round, Round, true);
    INTRINSIC_OP_CASE_INT_FLOAT(abs, SAbs, FAbs, true);
    INTRINSIC_OP_CASE(acos, Acos, true);
    INTRINSIC_OP_CASE(asin, Asin, true);
    INTRINSIC_OP_CASE(atan, Atan, true);
    INTRINSIC_OP_CASE(atan2, Atan2, true);
    INTRINSIC_OP_CASE(ceil, Ceil, true);
    INTRINSIC_OP_CASE(cos, Cos, true);
    INTRINSIC_OP_CASE(cosh, Cosh, true);
    INTRINSIC_OP_CASE(cross, Cross, false);
    INTRINSIC_OP_CASE(degrees, Degrees, true);
    INTRINSIC_OP_CASE(distance, Distance, false);
    INTRINSIC_OP_CASE(determinant, Determinant, false);
    INTRINSIC_OP_CASE(exp, Exp, true);
    INTRINSIC_OP_CASE(exp2, Exp2, true);
    INTRINSIC_OP_CASE_SINT_UINT(firstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE_SINT_UINT(ufirstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE(faceforward, FaceForward, false);
    INTRINSIC_OP_CASE(firstbitlow, FindILsb, false);
    INTRINSIC_OP_CASE(floor, Floor, true);
    INTRINSIC_OP_CASE(fma, Fma, true);
    INTRINSIC_OP_CASE(frac, Fract, true);
    INTRINSIC_OP_CASE(length, Length, false);
    INTRINSIC_OP_CASE(lerp, FMix, true);
    INTRINSIC_OP_CASE(log, Log, true);
    INTRINSIC_OP_CASE(log2, Log2, true);
    INTRINSIC_OP_CASE(mad, Fma, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(max, SMax, UMax, FMax, true);
    INTRINSIC_OP_CASE(umax, UMax, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(min, SMin, UMin, FMin, true);
    INTRINSIC_OP_CASE(umin, UMin, true);
    INTRINSIC_OP_CASE(normalize, Normalize, false);
    INTRINSIC_OP_CASE(pow, Pow, true);
    INTRINSIC_OP_CASE(radians, Radians, true);
    INTRINSIC_OP_CASE(reflect, Reflect, false);
    INTRINSIC_OP_CASE(refract, Refract, false);
    INTRINSIC_OP_CASE(rsqrt, InverseSqrt, true);
    INTRINSIC_OP_CASE(smoothstep, SmoothStep, true);
    INTRINSIC_OP_CASE(step, Step, true);
    INTRINSIC_OP_CASE(sin, Sin, true);
    INTRINSIC_OP_CASE(sinh, Sinh, true);
    INTRINSIC_OP_CASE(tan, Tan, true);
    INTRINSIC_OP_CASE(tanh, Tanh, true);
    INTRINSIC_OP_CASE(sqrt, Sqrt, true);
    INTRINSIC_OP_CASE(trunc, Trunc, true);
  default:
    emitError("%0 intrinsic function unimplemented", callExpr->getExprLoc())
        << callee->getName();
    return 0;
  }

#undef INTRINSIC_SPIRV_OP_WITH_CAP_CASE
#undef INTRINSIC_SPIRV_OP_CASE
#undef INTRINSIC_OP_CASE
#undef INTRINSIC_OP_CASE_INT_FLOAT
#undef INTRINSIC_OP_CASE_SINT_UINT
#undef INTRINSIC_OP_CASE_SINT_UINT_FLOAT

  return SpirvEvalInfo(retVal).setRValue();
}

uint32_t
SPIRVEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
                                                hlsl::IntrinsicOp opcode) {
  // The signatures of the intrinsic atomic methods are:
  // void Interlocked*(in R dest, in T value);
  // void Interlocked*(in R dest, in T value, out T original_value);

  // Note: ALL Interlocked*() methods are forced to have an unsigned integer
  // 'value'. Meaning, T is forced to be 'unsigned int'. If the provided
  // parameter is not an unsigned integer, the frontend inserts an
  // 'ImplicitCastExpr' to convert it to unsigned integer. OpAtomicIAdd (and
  // other SPIR-V OpAtomic* instructions) require the pointee of 'dest' to
  // be of the same type as T. This will result in invalid SPIR-V if 'dest'
  // is a signed integer typed resource such as RWTexture1D<int>. For example,
  // the following OpAtomicIAdd is invalid because the pointee type defined in
  // %1 is a signed integer, while the value passed to atomic add (%3) is an
  // unsigned integer.
  //
  //  %_ptr_Image_int = OpTypePointer Image %int
  //  %1 = OpImageTexelPointer %_ptr_Image_int %RWTexture1D_int %index %uint_0
  //  %2 = OpLoad %int %value
  //  %3 = OpBitcast %uint %2   <-------- Inserted by the frontend
  //  %4 = OpAtomicIAdd %int %1 %uint_1 %uint_0 %3
  //
  // In such cases, we bypass the forced IntegralCast.
  // Moreover, the frontend does not add a cast AST node to cast uint to int
  // where necessary. To ensure SPIR-V validity, we add that where necessary.

  const uint32_t zero = theBuilder.getConstantUint32(0);
  const uint32_t scope = theBuilder.getConstantUint32(1); // Device

  const auto *dest = expr->getArg(0);
  const auto baseType = dest->getType();

  if (!baseType->isIntegerType()) {
    emitError("can only perform atomic operations on scalar integer values",
              dest->getLocStart());
    return 0;
  }

  const uint32_t baseTypeId = typeTranslator.translateType(baseType);

  const auto doArg = [baseType, this](const CallExpr *callExpr,
                                      uint32_t argIndex) {
    const Expr *valueExpr = callExpr->getArg(argIndex);
    if (const auto *castExpr = dyn_cast<ImplicitCastExpr>(valueExpr))
      if (castExpr->getCastKind() == CK_IntegralCast &&
          castExpr->getSubExpr()->getType() == baseType)
        valueExpr = castExpr->getSubExpr();

    uint32_t argId = doExpr(valueExpr);
    if (valueExpr->getType() != baseType)
      argId = castToInt(argId, valueExpr->getType(), baseType,
                        valueExpr->getExprLoc());
    return argId;
  };

  const auto writeToOutputArg = [&baseType, dest, this](
                                    uint32_t toWrite, const CallExpr *callExpr,
                                    uint32_t outputArgIndex) {
    const auto outputArg = callExpr->getArg(outputArgIndex);
    const auto outputArgType = outputArg->getType();
    if (baseType != outputArgType)
      toWrite = castToInt(toWrite, baseType, outputArgType, dest->getExprLoc());
    theBuilder.createStore(doExpr(outputArg), toWrite);
  };

  // If the argument is indexing into a texture/buffer, we need to create an
  // OpImageTexelPointer instruction.
  uint32_t ptr = 0;
  if (const auto *callExpr = dyn_cast<CXXOperatorCallExpr>(dest)) {
    const Expr *base = nullptr;
    const Expr *index = nullptr;
    if (isBufferTextureIndexing(callExpr, &base, &index)) {
      const auto ptrType =
          theBuilder.getPointerType(baseTypeId, spv::StorageClass::Image);
      auto baseId = doExpr(base);
      if (baseId.isRValue()) {
        // OpImageTexelPointer's Image argument must have a type of
        // OpTypePointer with Type OpTypeImage. We need to create a temporary
        // variable if the baseId is an rvalue.
        baseId = createTemporaryVar(
            base->getType(), TypeTranslator::getName(base->getType()), baseId);
      }
      const auto coordId = doExpr(index);
      ptr = theBuilder.createImageTexelPointer(ptrType, baseId, coordId, zero);
      if (baseId.isNonUniform()) {
        // The image texel pointer will be used to access image memory. Vulkan
        // requires it to be decorated with NonUniformEXT.
        theBuilder.decorateNonUniformEXT(ptr);
      }
    }
  }

  if (!ptr)
    ptr = doExpr(dest);

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::IOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::IOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    const uint32_t comparator = doArg(expr, 1);
    const uint32_t valueId = doArg(expr, 2);
    const uint32_t originalVal = theBuilder.createAtomicCompareExchange(
        baseTypeId, ptr, scope, zero, zero, valueId, comparator);
    if (isCompareExchange)
      writeToOutputArg(originalVal, expr, 3);
  } else {
    const uint32_t valueId = doArg(expr, 1);
    // Since these atomic operations write through the provided pointer, the
    // signed vs. unsigned opcode must be decided based on the pointee type
    // of the first argument. However, the frontend decides the opcode based
    // on the second argument (value), so the HLSL opcode it provides may be
    // wrong. The following code makes sure we use the correct SPIR-V opcode.
    spv::Op atomicOp = translateAtomicHlslOpcodeToSpirvOpcode(opcode);
    if (atomicOp == spv::Op::OpAtomicUMax && baseType->isSignedIntegerType())
      atomicOp = spv::Op::OpAtomicSMax;
    if (atomicOp == spv::Op::OpAtomicSMax && baseType->isUnsignedIntegerType())
      atomicOp = spv::Op::OpAtomicUMax;
    if (atomicOp == spv::Op::OpAtomicUMin && baseType->isSignedIntegerType())
      atomicOp = spv::Op::OpAtomicSMin;
    if (atomicOp == spv::Op::OpAtomicSMin && baseType->isUnsignedIntegerType())
      atomicOp = spv::Op::OpAtomicUMin;
    const uint32_t originalVal = theBuilder.createAtomicOp(
        atomicOp, baseTypeId, ptr, scope, zero, valueId);
    if (expr->getNumArgs() > 2)
      writeToOutputArg(originalVal, expr, 2);
  }

  return 0;
}
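
// Illustrative HLSL usage handled by the following method (example only):
//   Texture2D gTextures[];
//   float4 c = gTextures[NonUniformResourceIndex(idx)].Sample(gSampler, uv);
// Both the index expression and every resource access derived from it must
// carry the NonUniformEXT decoration.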
SpirvEvalInfo
SPIRVEmitter::processIntrinsicNonUniformResourceIndex(const CallExpr *expr) {
  foundNonUniformResourceIndex = true;
  theBuilder.addExtension(Extension::EXT_descriptor_indexing,
                          "NonUniformResourceIndex", expr->getExprLoc());
  theBuilder.requireCapability(spv::Capability::ShaderNonUniformEXT);

  auto index = doExpr(expr->getArg(0)).setNonUniform();
  // Decorate the expression in NonUniformResourceIndex() with NonUniformEXT.
  // Aside from this, we also need to eventually propagate the NonUniformEXT
  // status to the usages of this expression: the "pointer" operand to a memory
  // access instruction. The Vulkan spec has the following rule:
  //
  // If an instruction loads from or stores to a resource (including atomics
  // and image instructions) and the resource descriptor being accessed is not
  // dynamically uniform, then the operand corresponding to that resource
  // (e.g. the pointer or sampled image operand) must be decorated with
  // NonUniformEXT.
  theBuilder.decorateNonUniformEXT(index);
  return index;
}

uint32_t SPIRVEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
  if (!spirvOptions.noWarnEmulatedFeatures)
    emitWarning("msad4 intrinsic function is emulated using many SPIR-V "
                "instructions due to lack of direct SPIR-V equivalent",
                callExpr->getExprLoc());

  // Compares a 4-byte reference value and an 8-byte source value and
  // accumulates a vector of 4 sums. Each sum corresponds to the masked sum
  // of absolute differences of a different byte alignment between the
  // reference value and the source value.

  // If we have:
  //   uint v0; // reference
  //   uint2 v1; // source
  //   uint4 v2; // accum
  //   uint4 o0; // result of msad4
  //   uint4 r0, t0; // temporary values
  //
  // Then msad4(v0, v1, v2) translates to the following SM5 assembly according
  // to fxc:
  //   Step 1:
  //     ushr r0.xyz, v1.xxxx, l(8, 16, 24, 0)
  //   Step 2:
  //         [result], [ width ], [ offset ], [ insert ], [ base ]
  //     bfi t0.yzw, l(0, 8, 16, 24), l(0, 24, 16, 8), v1.yyyy, r0.xxyz
  //     mov t0.x, v1.x
  //   Step 3:
  //     msad o0.xyzw, v0.xxxx, t0.xyzw, v2.xyzw

  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
  const auto boolType = theBuilder.getBoolType();
  const auto intType = theBuilder.getInt32Type();
  const auto uintType = theBuilder.getUint32Type();
  const auto uint4Type = theBuilder.getVecType(uintType, 4);
  const uint32_t reference = doExpr(callExpr->getArg(0));
  const uint32_t source = doExpr(callExpr->getArg(1));
  const uint32_t accum = doExpr(callExpr->getArg(2));
  const auto uint0 = theBuilder.getConstantUint32(0);
  const auto uint8 = theBuilder.getConstantUint32(8);
  const auto uint16 = theBuilder.getConstantUint32(16);
  const auto uint24 = theBuilder.getConstantUint32(24);

  // Step 1.
  const uint32_t v1x = theBuilder.createCompositeExtract(uintType, source, {0});
  // r0.x = v1xS8 = v1.x shifted by 8 bits
  uint32_t v1xS8 = theBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                             uintType, v1x, uint8);
  // r0.y = v1xS16 = v1.x shifted by 16 bits
  uint32_t v1xS16 = theBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                              uintType, v1x, uint16);
  // r0.z = v1xS24 = v1.x shifted by 24 bits
  uint32_t v1xS24 = theBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
                                              uintType, v1x, uint24);

  // Step 2.
  // Do bfi 3 times. DXIL bfi is equivalent to SPIR-V OpBitFieldInsert.
  const uint32_t v1y = theBuilder.createCompositeExtract(uintType, source, {1});
  // Note that t0.x = v1.x, nothing we need to do for that.
  const uint32_t t0y =
      theBuilder.createBitFieldInsert(uintType, /*base*/ v1xS8, /*insert*/ v1y,
                                      /*offset*/ uint24,
                                      /*width*/ uint8);
  const uint32_t t0z =
      theBuilder.createBitFieldInsert(uintType, /*base*/ v1xS16, /*insert*/ v1y,
                                      /*offset*/ uint16,
                                      /*width*/ uint16);
  const uint32_t t0w =
      theBuilder.createBitFieldInsert(uintType, /*base*/ v1xS24, /*insert*/ v1y,
                                      /*offset*/ uint8,
                                      /*width*/ uint24);

  // Step 3. MSAD (Masked Sum of Absolute Differences)
  // Now perform MSAD four times.
  // Need to mimic this algorithm in SPIR-V!
  //
  // UINT msad( UINT ref, UINT src, UINT accum )
  // {
  //     for (UINT i = 0; i < 4; i++)
  //     {
  //         BYTE refByte, srcByte, absDiff;
  //
  //         refByte = (BYTE)(ref >> (i * 8));
  //         if (!refByte)
  //         {
  //             continue;
  //         }
  //
  //         srcByte = (BYTE)(src >> (i * 8));
  //         if (refByte >= srcByte)
  //         {
  //             absDiff = refByte - srcByte;
  //         }
  //         else
  //         {
  //             absDiff = srcByte - refByte;
  //         }
  //
  //         // The recommended overflow behavior for MSAD is
  //         // to do a 32-bit saturate. This is not
  //         // required, however, and wrapping is allowed.
  //         // So from an application point of view,
  //         // overflow behavior is undefined.
  //         if (UINT_MAX - accum < absDiff)
  //         {
  //             accum = UINT_MAX;
  //             break;
  //         }
  //         accum += absDiff;
  //     }
  //
  //     return accum;
  // }
  const uint32_t accum0 =
      theBuilder.createCompositeExtract(uintType, accum, {0});
  const uint32_t accum1 =
      theBuilder.createCompositeExtract(uintType, accum, {1});
  const uint32_t accum2 =
      theBuilder.createCompositeExtract(uintType, accum, {2});
  const uint32_t accum3 =
      theBuilder.createCompositeExtract(uintType, accum, {3});
  const llvm::SmallVector<uint32_t, 4> sources = {v1x, t0y, t0z, t0w};
  llvm::SmallVector<uint32_t, 4> accums = {accum0, accum1, accum2, accum3};
  llvm::SmallVector<uint32_t, 4> refBytes;
  llvm::SmallVector<uint32_t, 4> signedRefBytes;
  llvm::SmallVector<uint32_t, 4> isRefByteZero;

  for (uint32_t i = 0; i < 4; ++i) {
    refBytes.push_back(theBuilder.createBitFieldExtract(
        uintType, reference, /*offset*/ theBuilder.getConstantUint32(i * 8),
        /*count*/ uint8, /*isSigned*/ false));
    signedRefBytes.push_back(
        theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, refBytes.back()));
    isRefByteZero.push_back(theBuilder.createBinaryOp(
        spv::Op::OpIEqual, boolType, refBytes.back(), uint0));
  }

  for (uint32_t msadNum = 0; msadNum < 4; ++msadNum) {
    for (uint32_t byteCount = 0; byteCount < 4; ++byteCount) {
      // 'count' is always 8 because we are extracting 8 bits out of 32.
      const uint32_t srcByte = theBuilder.createBitFieldExtract(
          uintType, sources[msadNum],
          /*offset*/ theBuilder.getConstantUint32(8 * byteCount),
          /*count*/ uint8, /*isSigned*/ false);
      const uint32_t signedSrcByte =
          theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, srcByte);
      const uint32_t sub = theBuilder.createBinaryOp(
          spv::Op::OpISub, intType, signedRefBytes[byteCount], signedSrcByte);
      const uint32_t absSub = theBuilder.createExtInst(
          intType, glsl, GLSLstd450::GLSLstd450SAbs, {sub});
      const uint32_t diff = theBuilder.createSelect(
          uintType, isRefByteZero[byteCount], uint0,
          theBuilder.createUnaryOp(spv::Op::OpBitcast, uintType, absSub));

      // As pointed out by the DXIL reference above, it is *not* required to
      // saturate the output to UINT_MAX in case of overflow. Wrapping around
      // is also allowed. For simplicity, we will wrap around at this point.
      accums[msadNum] = theBuilder.createBinaryOp(spv::Op::OpIAdd, uintType,
                                                  accums[msadNum], diff);
    }
  }
  return theBuilder.createCompositeConstruct(uint4Type, accums);
}

uint32_t SPIRVEmitter::processWaveQuery(const CallExpr *callExpr,
                                        spv::Op opcode) {
  // Signatures:
  // bool WaveIsFirstLane()
  // uint WaveGetLaneCount()
  // uint WaveGetLaneIndex()
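  //
  // For example (illustrative SPIR-V), WaveIsFirstLane() becomes:
  //   %result = OpGroupNonUniformElect %bool %int_3
  // where %int_3 is the Subgroup scope constant.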
  assert(callExpr->getNumArgs() == 0);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
  const uint32_t retType =
      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
  return theBuilder.createGroupNonUniformOp(opcode, retType, subgroupScope);
}

uint32_t SPIRVEmitter::processWaveVote(const CallExpr *callExpr,
                                       spv::Op opcode) {
  // Signatures:
  // bool WaveActiveAnyTrue( bool expr )
  // bool WaveActiveAllTrue( bool expr )
  // uint4 WaveActiveBallot( bool expr )
  assert(callExpr->getNumArgs() == 1);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
  const uint32_t predicate = doExpr(callExpr->getArg(0));
  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
  const uint32_t retType =
      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
  return theBuilder.createGroupNonUniformUnaryOp(opcode, retType, subgroupScope,
                                                 predicate);
}

spv::Op SPIRVEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
                                      SourceLocation srcLoc) {
  const bool isSintType = isSintOrVecMatOfSintType(type);
  const bool isUintType = isUintOrVecMatOfUintType(type);
  const bool isFloatType = isFloatOrVecMatOfFloatType(type);

#define WAVE_OP_CASE_INT(kind, intWaveOp)                                      \
  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                    \
    if (isSintType || isUintType) {                                            \
      return spv::Op::OpGroupNonUniform##intWaveOp;                            \
    }                                                                          \
  } break

#define WAVE_OP_CASE_INT_FLOAT(kind, intWaveOp, floatWaveOp)                   \
  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                    \
    if (isSintType || isUintType) {                                            \
      return spv::Op::OpGroupNonUniform##intWaveOp;                            \
    }                                                                          \
    if (isFloatType) {                                                         \
      return spv::Op::OpGroupNonUniform##floatWaveOp;                          \
    }                                                                          \
  } break

#define WAVE_OP_CASE_SINT_UINT_FLOAT(kind, sintWaveOp, uintWaveOp,             \
                                     floatWaveOp)                              \
  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                    \
    if (isSintType) {                                                          \
      return spv::Op::OpGroupNonUniform##sintWaveOp;                           \
    }                                                                          \
    if (isUintType) {                                                          \
      return spv::Op::OpGroupNonUniform##uintWaveOp;                           \
    }                                                                          \
    if (isFloatType) {                                                         \
      return spv::Op::OpGroupNonUniform##floatWaveOp;                          \
    }                                                                          \
  } break

  switch (op) {
    WAVE_OP_CASE_INT_FLOAT(ActiveUSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(ActiveSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(ActiveUProduct, IMul, FMul);
    WAVE_OP_CASE_INT_FLOAT(ActiveProduct, IMul, FMul);
    WAVE_OP_CASE_INT_FLOAT(PrefixUSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(PrefixSum, IAdd, FAdd);
    WAVE_OP_CASE_INT_FLOAT(PrefixUProduct, IMul, FMul);
    WAVE_OP_CASE_INT_FLOAT(PrefixProduct, IMul, FMul);
    WAVE_OP_CASE_INT(ActiveBitAnd, BitwiseAnd);
    WAVE_OP_CASE_INT(ActiveBitOr, BitwiseOr);
    WAVE_OP_CASE_INT(ActiveBitXor, BitwiseXor);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMax, SMax, UMax, FMax);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMax, SMax, UMax, FMax);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMin, SMin, UMin, FMin);
    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMin, SMin, UMin, FMin);
  default:
    // Only simple wave ops are handled here.
    break;
  }

#undef WAVE_OP_CASE_INT_FLOAT
#undef WAVE_OP_CASE_INT
#undef WAVE_OP_CASE_SINT_UINT_FLOAT

  emitError("translating wave operator '%0' unimplemented", srcLoc)
      << static_cast<uint32_t>(op);
  return spv::Op::OpNop;
}

uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
    const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
  // Signatures:
  // bool WaveActiveAllEqual( <type> expr )
  // uint WaveActiveCountBits( bool bBit )
  // <type> WaveActiveSum( <type> expr )
  // <type> WaveActiveProduct( <type> expr )
  // <int_type> WaveActiveBitAnd( <int_type> expr )
  // <int_type> WaveActiveBitOr( <int_type> expr )
  // <int_type> WaveActiveBitXor( <int_type> expr )
  // <type> WaveActiveMin( <type> expr )
  // <type> WaveActiveMax( <type> expr )
  //
  // uint WavePrefixCountBits( bool bBit )
  // <type> WavePrefixProduct( <type> value )
  // <type> WavePrefixSum( <type> value )
  assert(callExpr->getNumArgs() == 1);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
  const uint32_t predicate = doExpr(callExpr->getArg(0));
  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
  const uint32_t retType =
      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
  return theBuilder.createGroupNonUniformUnaryOp(
      opcode, retType, subgroupScope, predicate,
      llvm::Optional<spv::GroupOperation>(groupOp));
}
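
// For example (a rough sketch, not verbatim output): with
// groupOp == spv::GroupOperation::Reduce, WaveActiveSum(val) on a float
// lowers to roughly
//
//   %result = OpGroupNonUniformFAdd %float %int_3 Reduce %val
//
// where the constant 3 is the Subgroup scope (spv::Scope::Subgroup == 3).
// The prefix variants pass a scan group operation (e.g. ExclusiveScan)
// instead of Reduce.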

uint32_t SPIRVEmitter::processWaveBroadcast(const CallExpr *callExpr) {
  // Signatures:
  // <type> WaveReadLaneFirst(<type> expr)
  // <type> WaveReadLaneAt(<type> expr, uint laneIndex)
  const auto numArgs = callExpr->getNumArgs();
  assert(numArgs == 1 || numArgs == 2);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  theBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
  const uint32_t value = doExpr(callExpr->getArg(0));
  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
  const uint32_t retType =
      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
  if (numArgs == 2)
    return theBuilder.createGroupNonUniformBinaryOp(
        spv::Op::OpGroupNonUniformBroadcast, retType, subgroupScope, value,
        doExpr(callExpr->getArg(1)));
  else
    return theBuilder.createGroupNonUniformUnaryOp(
        spv::Op::OpGroupNonUniformBroadcastFirst, retType, subgroupScope,
        value);
}

uint32_t SPIRVEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
                                                  hlsl::IntrinsicOp op) {
  // Signatures:
  // <type> QuadReadAcrossX(<type> localValue)
  // <type> QuadReadAcrossY(<type> localValue)
  // <type> QuadReadAcrossDiagonal(<type> localValue)
  // <type> QuadReadLaneAt(<type> sourceValue, uint quadLaneID)
  assert(callExpr->getNumArgs() == 1 || callExpr->getNumArgs() == 2);
  featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
                                  callExpr->getExprLoc());
  theBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
  const uint32_t value = doExpr(callExpr->getArg(0));
  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
  const uint32_t retType =
      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
  uint32_t target = 0;
  spv::Op opcode = spv::Op::OpGroupNonUniformQuadSwap;
  switch (op) {
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
    target = theBuilder.getConstantUint32(0);
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
    target = theBuilder.getConstantUint32(1);
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
    target = theBuilder.getConstantUint32(2);
    break;
  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
    target = doExpr(callExpr->getArg(1));
    opcode = spv::Op::OpGroupNonUniformQuadBroadcast;
    break;
  default:
    llvm_unreachable("case should not appear here");
  }
  return theBuilder.createGroupNonUniformBinaryOp(opcode, retType,
                                                  subgroupScope, value, target);
}
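
// The constants 0/1/2 above are the OpGroupNonUniformQuadSwap 'Direction'
// operand: 0 swaps within a horizontal pair, 1 within a vertical pair, and
// 2 across the diagonal, which is how QuadReadAcrossX/Y/Diagonal map onto a
// single opcode. QuadReadLaneAt instead broadcasts from an explicit quad
// lane, hence the switch to OpGroupNonUniformQuadBroadcast.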

uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
  // Signature is: ret modf(x, ip)
  // [in]  x: the input floating-point value.
  // [out] ip: the integer portion of x.
  // [out] ret: the fractional portion of x.
  // All of the above must be a scalar, vector, or matrix with the same
  // component types. Component types can be float or int.
  //
  // The ModfStruct SPIR-V instruction returns a struct. The first member is
  // the fractional part and the second member is the integer portion:
  // ModfStruct {
  //   <scalar or vector of float> frac;
  //   <scalar or vector of float> ip;
  // }
  //
  // Note that if the input number (x) is not a float (i.e. 'x' is an int), it
  // is automatically converted to float before modf is invoked. Sadly, the
  // 'ip' argument is not treated the same way. Therefore, in such cases we'll
  // have to manually convert the float result into int.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const Expr *ipArg = callExpr->getArg(1);
  const auto argType = arg->getType();
  const auto ipType = ipArg->getType();
  const auto returnType = callExpr->getType();
  const auto returnTypeId = typeTranslator.translateType(returnType);
  const uint32_t argId = doExpr(arg);
  const uint32_t ipId = doExpr(ipArg);

  // For scalar and vector argument types.
  {
    if (TypeTranslator::isScalarType(argType) ||
        TypeTranslator::isVectorType(argType)) {
      const auto argTypeId = typeTranslator.translateType(argType);
      // The struct members *must* have the same type.
      const auto modfStructTypeId = theBuilder.getStructType(
          {argTypeId, argTypeId}, "ModfStructType", {"frac", "ip"});
      const auto modf =
          theBuilder.createExtInst(modfStructTypeId, glslInstSetId,
                                   GLSLstd450::GLSLstd450ModfStruct, {argId});
      auto ip = theBuilder.createCompositeExtract(argTypeId, modf, {1});
      // This will do nothing if the input number (x) and the ip are both of
      // the same type. Otherwise, it will convert the ip into int as
      // necessary.
      ip = castToInt(ip, argType, ipType, arg->getExprLoc());
      theBuilder.createStore(ipId, ip);
      return theBuilder.createCompositeExtract(argTypeId, modf, {0});
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    QualType elemType = {};
    if (TypeTranslator::isMxNMatrix(argType, &elemType, &rowCount,
                                    &colCount)) {
      const auto elemTypeId = typeTranslator.translateType(elemType);
      const auto colTypeId = theBuilder.getVecType(elemTypeId, colCount);
      const auto modfStructTypeId = theBuilder.getStructType(
          {colTypeId, colTypeId}, "ModfStructType", {"frac", "ip"});
      llvm::SmallVector<uint32_t, 4> fracs;
      llvm::SmallVector<uint32_t, 4> ips;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto curRow =
            theBuilder.createCompositeExtract(colTypeId, argId, {i});
        const auto modf = theBuilder.createExtInst(
            modfStructTypeId, glslInstSetId, GLSLstd450::GLSLstd450ModfStruct,
            {curRow});
        auto ip = theBuilder.createCompositeExtract(colTypeId, modf, {1});
        ips.push_back(ip);
        fracs.push_back(
            theBuilder.createCompositeExtract(colTypeId, modf, {0}));
      }
      uint32_t ip = theBuilder.createCompositeConstruct(
          typeTranslator.translateType(argType), ips);
      // If the 'ip' is not a float type, the AST will not contain a CastExpr
      // because this is internal to the intrinsic function. So, in such a
      // case we need to cast manually.
      if (!hlsl::GetHLSLMatElementType(ipType)->isFloatingType())
        ip = castToInt(ip, argType, ipType, ipArg->getExprLoc());
      theBuilder.createStore(ipId, ip);
      return theBuilder.createCompositeConstruct(returnTypeId, fracs);
    }
  }

  emitError("invalid argument type passed to Modf intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
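
// A worked example of the scalar path above (values for illustration only):
// for modf(3.75, ip), GLSLstd450ModfStruct yields { frac = 0.75, ip = 3.0 };
// member 1 (3.0) is stored through the 'ip' pointer (cast to int first if
// 'ip' was declared as an int), and member 0 (0.75) is the return value.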

uint32_t SPIRVEmitter::processIntrinsicLit(const CallExpr *callExpr) {
  // Signature is: float4 lit(float n_dot_l, float n_dot_h, float m)
  //
  // This function returns a lighting coefficient vector
  // (ambient, diffuse, specular, 1) where:
  // ambient  = 1.
  // diffuse  = (n_dot_l < 0) ? 0 : n_dot_l
  // specular = (n_dot_l < 0 || n_dot_h < 0) ? 0 : (n_dot_h * m)
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const uint32_t nDotL = doExpr(callExpr->getArg(0));
  const uint32_t nDotH = doExpr(callExpr->getArg(1));
  const uint32_t m = doExpr(callExpr->getArg(2));
  const uint32_t floatType = theBuilder.getFloat32Type();
  const uint32_t boolType = theBuilder.getBoolType();
  const uint32_t floatZero = theBuilder.getConstantFloat32(0);
  const uint32_t floatOne = theBuilder.getConstantFloat32(1);
  const uint32_t retType = typeTranslator.translateType(callExpr->getType());
  const uint32_t diffuse = theBuilder.createExtInst(
      floatType, glslInstSetId, GLSLstd450::GLSLstd450FMax, {floatZero, nDotL});
  const uint32_t min = theBuilder.createExtInst(
      floatType, glslInstSetId, GLSLstd450::GLSLstd450FMin, {nDotL, nDotH});
  const uint32_t isNeg = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                                   boolType, min, floatZero);
  const uint32_t mul =
      theBuilder.createBinaryOp(spv::Op::OpFMul, floatType, nDotH, m);
  const uint32_t specular =
      theBuilder.createSelect(floatType, isNeg, floatZero, mul);
  return theBuilder.createCompositeConstruct(
      retType, {floatOne, diffuse, specular, floatOne});
}
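
// Worked example with assumed inputs: lit(0.5, 0.25, 8) evaluates, per the
// formula above, to:
//   diffuse  = max(0, 0.5)                        = 0.5
//   specular = min(0.5, 0.25) < 0 ? 0 : 0.25 * 8  = 2.0
// so the returned vector is (1, 0.5, 2.0, 1). The single FMin/FOrdLessThan
// pair implements the (n_dot_l < 0 || n_dot_h < 0) test in one comparison.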

uint32_t SPIRVEmitter::processIntrinsicFrexp(const CallExpr *callExpr) {
  // Signature is: ret frexp(x, exp)
  // [in]  x: the input floating-point value.
  // [out] exp: the calculated exponent.
  // [out] ret: the calculated mantissa.
  // All of the above must be a scalar, vector, or matrix of *float* type.
  //
  // The FrexpStruct SPIR-V instruction returns a struct. The first member is
  // the significand (mantissa) and must be of the same type as the input
  // parameter, and the second member is the exponent and must always be a
  // scalar or vector of 32-bit *integer* type:
  // FrexpStruct {
  //   <scalar or vector of float> mantissa;
  //   <scalar or vector of int> exponent;
  // }
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto intId = theBuilder.getInt32Type();
  const auto returnTypeId = typeTranslator.translateType(callExpr->getType());
  const uint32_t argId = doExpr(arg);
  const uint32_t expId = doExpr(callExpr->getArg(1));

  // For scalar and vector argument types.
  {
    uint32_t elemCount = 1;
    if (TypeTranslator::isScalarType(argType) ||
        TypeTranslator::isVectorType(argType, nullptr, &elemCount)) {
      const auto argTypeId = typeTranslator.translateType(argType);
      const auto expTypeId =
          elemCount == 1 ? intId : theBuilder.getVecType(intId, elemCount);
      const auto frexpStructTypeId = theBuilder.getStructType(
          {argTypeId, expTypeId}, "FrexpStructType", {"mantissa", "exponent"});
      const auto frexp =
          theBuilder.createExtInst(frexpStructTypeId, glslInstSetId,
                                   GLSLstd450::GLSLstd450FrexpStruct, {argId});
      const auto exponentInt =
          theBuilder.createCompositeExtract(expTypeId, frexp, {1});
      // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
      // expects a float, a conversion must take place before writing the
      // results.
      const auto exponentFloat = theBuilder.createUnaryOp(
          spv::Op::OpConvertSToF, returnTypeId, exponentInt);
      theBuilder.createStore(expId, exponentFloat);
      return theBuilder.createCompositeExtract(argTypeId, frexp, {0});
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(argType, nullptr, &rowCount, &colCount)) {
      const auto floatId = theBuilder.getFloat32Type();
      const auto expTypeId = theBuilder.getVecType(intId, colCount);
      const auto colTypeId = theBuilder.getVecType(floatId, colCount);
      const auto frexpStructTypeId = theBuilder.getStructType(
          {colTypeId, expTypeId}, "FrexpStructType", {"mantissa", "exponent"});
      llvm::SmallVector<uint32_t, 4> exponents;
      llvm::SmallVector<uint32_t, 4> mantissas;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto curRow =
            theBuilder.createCompositeExtract(colTypeId, argId, {i});
        const auto frexp = theBuilder.createExtInst(
            frexpStructTypeId, glslInstSetId, GLSLstd450::GLSLstd450FrexpStruct,
            {curRow});
        const auto exponentInt =
            theBuilder.createCompositeExtract(expTypeId, frexp, {1});
        // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
        // expects a float, a conversion must take place before writing the
        // results.
        const auto exponentFloat = theBuilder.createUnaryOp(
            spv::Op::OpConvertSToF, colTypeId, exponentInt);
        exponents.push_back(exponentFloat);
        mantissas.push_back(
            theBuilder.createCompositeExtract(colTypeId, frexp, {0}));
      }
      const auto exponentsResultId =
          theBuilder.createCompositeConstruct(returnTypeId, exponents);
      theBuilder.createStore(expId, exponentsResultId);
      return theBuilder.createCompositeConstruct(returnTypeId, mantissas);
    }
  }

  emitError("invalid argument type passed to Frexp intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
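
// A worked example of the scalar path (values for illustration only): for
// frexp(16.0, e), GLSLstd450FrexpStruct yields { mantissa = 0.5,
// exponent = 5 }, since 16.0 == 0.5 * 2^5. The exponent is converted with
// OpConvertSToF so that 5.0 is stored through 'e', and 0.5 is returned.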

uint32_t SPIRVEmitter::processIntrinsicLdexp(const CallExpr *callExpr) {
  // Signature is: ret ldexp(x, exp)
  // This function uses the following formula: x * 2^exp.
  // Note that we cannot use the GLSL extended instruction Ldexp since it
  // requires the exponent to be an integer (vector), but HLSL takes a float
  // (vector) exponent. So we must calculate the result manually.
  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
  const Expr *x = callExpr->getArg(0);
  const auto paramType = x->getType();
  const uint32_t xId = doExpr(x);
  const uint32_t expId = doExpr(callExpr->getArg(1));

  // For scalar and vector argument types.
  if (TypeTranslator::isScalarType(paramType) ||
      TypeTranslator::isVectorType(paramType)) {
    const auto paramTypeId = typeTranslator.translateType(paramType);
    const auto twoExp = theBuilder.createExtInst(
        paramTypeId, glsl, GLSLstd450::GLSLstd450Exp2, {expId});
    return theBuilder.createBinaryOp(spv::Op::OpFMul, paramTypeId, xId, twoExp);
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(paramType, nullptr, &rowCount,
                                    &colCount)) {
      const auto actOnEachVec = [this, glsl, expId](uint32_t index,
                                                    uint32_t vecType,
                                                    uint32_t xRowId) {
        const auto expRowId =
            theBuilder.createCompositeExtract(vecType, expId, {index});
        const auto twoExp = theBuilder.createExtInst(
            vecType, glsl, GLSLstd450::GLSLstd450Exp2, {expRowId});
        return theBuilder.createBinaryOp(spv::Op::OpFMul, vecType, xRowId,
                                         twoExp);
      };
      return processEachVectorInMatrix(x, xId, actOnEachVec);
    }
  }

  emitError("invalid argument type passed to ldexp intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
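
// In other words, ldexp(x, exp) is emitted as exp2(exp) followed by an FMul.
// For example (values assumed for illustration), ldexp(1.5, 3.0) becomes
// 1.5 * exp2(3.0) = 1.5 * 8.0 = 12.0, matching x * 2^exp without requiring
// an integer exponent operand.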

uint32_t SPIRVEmitter::processIntrinsicDst(const CallExpr *callExpr) {
  // Signature is: float4 dst(float4 src0, float4 src1)
  // result.x = 1;
  // result.y = src0.y * src1.y;
  // result.z = src0.z;
  // result.w = src1.w;
  const auto floatId = theBuilder.getFloat32Type();
  const auto arg0Id = doExpr(callExpr->getArg(0));
  const auto arg1Id = doExpr(callExpr->getArg(1));
  const auto arg0y = theBuilder.createCompositeExtract(floatId, arg0Id, {1});
  const auto arg1y = theBuilder.createCompositeExtract(floatId, arg1Id, {1});
  const auto arg0z = theBuilder.createCompositeExtract(floatId, arg0Id, {2});
  const auto arg1w = theBuilder.createCompositeExtract(floatId, arg1Id, {3});
  const auto arg0yMularg1y =
      theBuilder.createBinaryOp(spv::Op::OpFMul, floatId, arg0y, arg1y);
  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(callExpr->getType()),
      {theBuilder.getConstantFloat32(1.0), arg0yMularg1y, arg0z, arg1w});
}
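
// A worked example with assumed inputs: dst(float4(9, 4, 16, 1),
// float4(1, 0.5, 1, 0.25)) yields (1, 4 * 0.5, 16, 0.25) = (1, 2, 16, 0.25).
// Only components y, z, and w of the inputs participate; x of both sources
// is ignored and the first result component is always 1.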

uint32_t SPIRVEmitter::processIntrinsicClip(const CallExpr *callExpr) {
  // Discards the current pixel if the specified value is less than zero.
  // TODO: If the argument can be const folded and evaluated, we could
  // potentially avoid creating a branch. This would be a bit challenging for
  // matrix/vector arguments.
  assert(callExpr->getNumArgs() == 1u);
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto boolType = theBuilder.getBoolType();
  uint32_t condition = 0;

  // Since we could not fold the argument into a constant, we need to branch
  // based on it. If the argument is a vector/matrix, clipping is done if
  // *any* element of the vector/matrix is less than zero.
  const uint32_t argId = doExpr(arg);
  QualType elemType = {};
  uint32_t elemCount = 0, rowCount = 0, colCount = 0;
  if (TypeTranslator::isScalarType(argType)) {
    const auto zero = getValueZero(argType);
    condition = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan, boolType,
                                          argId, zero);
  } else if (TypeTranslator::isVectorType(argType, nullptr, &elemCount)) {
    const auto zero = getValueZero(argType);
    const auto boolVecType = theBuilder.getVecType(boolType, elemCount);
    const auto cmp = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                               boolVecType, argId, zero);
    condition = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp);
  } else if (TypeTranslator::isMxNMatrix(argType, &elemType, &rowCount,
                                         &colCount)) {
    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
    const uint32_t floatVecType = theBuilder.getVecType(elemTypeId, colCount);
    const uint32_t elemZeroId = getValueZero(elemType);
    llvm::SmallVector<uint32_t, 4> elements(size_t(colCount), elemZeroId);
    const auto zero = theBuilder.getConstantComposite(floatVecType, elements);
    llvm::SmallVector<uint32_t, 4> cmpResults;
    for (uint32_t i = 0; i < rowCount; ++i) {
      const uint32_t lhsVec =
          theBuilder.createCompositeExtract(floatVecType, argId, {i});
      const auto boolColType = theBuilder.getVecType(boolType, colCount);
      const auto cmp = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                                 boolColType, lhsVec, zero);
      const auto any = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp);
      cmpResults.push_back(any);
    }
    const auto boolRowType = theBuilder.getVecType(boolType, rowCount);
    const auto results =
        theBuilder.createCompositeConstruct(boolRowType, cmpResults);
    condition = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, results);
  } else {
    emitError("invalid argument type passed to clip intrinsic function",
              callExpr->getExprLoc());
    return 0;
  }

  // Then we need to emit the instruction for the conditional branch.
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);
  // Handle the then branch.
  theBuilder.setInsertPoint(thenBB);
  theBuilder.createKill();
  theBuilder.addSuccessor(mergeBB);
  // From now on, we'll emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
  return 0;
}
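
// The generated control flow for clip(v) with a float3 'v' looks roughly
// like this (an illustrative SPIR-V sketch, ids abbreviated):
//
//   %cmp  = OpFOrdLessThan %v3bool %v %zero
//   %cond = OpAny %bool %cmp
//           OpSelectionMerge %if_merge None
//           OpBranchConditional %cond %if_true %if_merge
//   %if_true:  OpKill
//   %if_merge: ...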

uint32_t SPIRVEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
  // According to the HLSL reference: clamp(X, Min, Max) takes 3 arguments.
  // Each one may be int, uint, or float.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  GLSLstd450 glslOpcode = GLSLstd450::GLSLstd450UClamp;
  if (isFloatOrVecMatOfFloatType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450FClamp;
  else if (isSintOrVecMatOfSintType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450SClamp;

  // Get the function parameters. Expect 3 parameters.
  assert(callExpr->getNumArgs() == 3u);
  const Expr *argX = callExpr->getArg(0);
  const Expr *argMin = callExpr->getArg(1);
  const Expr *argMax = callExpr->getArg(2);
  const uint32_t argXId = doExpr(argX);
  const uint32_t argMinId = doExpr(argMin);
  const uint32_t argMaxId = doExpr(argMax);

  // FClamp, UClamp, and SClamp do not operate on matrices, so we should
  // perform the operation on each vector of the matrix.
  if (TypeTranslator::isMxNMatrix(argX->getType())) {
    const auto actOnEachVec = [this, glslInstSetId, glslOpcode, argMinId,
                               argMaxId](uint32_t index, uint32_t vecType,
                                         uint32_t curRowId) {
      const auto minRowId =
          theBuilder.createCompositeExtract(vecType, argMinId, {index});
      const auto maxRowId =
          theBuilder.createCompositeExtract(vecType, argMaxId, {index});
      return theBuilder.createExtInst(vecType, glslInstSetId, glslOpcode,
                                      {curRowId, minRowId, maxRowId});
    };
    return processEachVectorInMatrix(argX, argXId, actOnEachVec);
  }

  return theBuilder.createExtInst(returnTypeId, glslInstSetId, glslOpcode,
                                  {argXId, argMinId, argMaxId});
}

uint32_t SPIRVEmitter::processIntrinsicMemoryBarrier(const CallExpr *callExpr,
                                                     bool isDevice,
                                                     bool groupSync,
                                                     bool isAllBarrier) {
  // * DeviceMemoryBarrier =
  //   OpMemoryBarrier (memScope=Device,
  //                    sem=Image|Uniform|AcquireRelease)
  //
  // * DeviceMemoryBarrierWithGroupSync =
  //   OpControlBarrier (execScope=Workgroup,
  //                     memScope=Device,
  //                     sem=Image|Uniform|AcquireRelease)
  const spv::MemorySemanticsMask deviceMemoryBarrierSema =
      spv::MemorySemanticsMask::ImageMemory |
      spv::MemorySemanticsMask::UniformMemory |
      spv::MemorySemanticsMask::AcquireRelease;

  // * GroupMemoryBarrier =
  //   OpMemoryBarrier (memScope=Workgroup,
  //                    sem=Workgroup|AcquireRelease)
  //
  // * GroupMemoryBarrierWithGroupSync =
  //   OpControlBarrier (execScope=Workgroup,
  //                     memScope=Workgroup,
  //                     sem=Workgroup|AcquireRelease)
  const spv::MemorySemanticsMask groupMemoryBarrierSema =
      spv::MemorySemanticsMask::WorkgroupMemory |
      spv::MemorySemanticsMask::AcquireRelease;

  // * AllMemoryBarrier =
  //   OpMemoryBarrier (memScope=Device,
  //                    sem=Image|Uniform|Workgroup|AcquireRelease)
  //
  // * AllMemoryBarrierWithGroupSync =
  //   OpControlBarrier (execScope=Workgroup,
  //                     memScope=Device,
  //                     sem=Image|Uniform|Workgroup|AcquireRelease)
  const spv::MemorySemanticsMask allMemoryBarrierSema =
      spv::MemorySemanticsMask::ImageMemory |
      spv::MemorySemanticsMask::UniformMemory |
      spv::MemorySemanticsMask::WorkgroupMemory |
      spv::MemorySemanticsMask::AtomicCounterMemory |
      spv::MemorySemanticsMask::AcquireRelease;

  // Get <result-id> for execution scope.
  // If present, execution scope is always Workgroup!
  const uint32_t execScopeId =
      groupSync ? theBuilder.getConstantUint32(
                      static_cast<uint32_t>(spv::Scope::Workgroup))
                : 0;

  // Get <result-id> for memory scope.
  const spv::Scope memScope =
      (isDevice || isAllBarrier) ? spv::Scope::Device : spv::Scope::Workgroup;
  const uint32_t memScopeId =
      theBuilder.getConstantUint32(static_cast<uint32_t>(memScope));

  // Get <result-id> for memory semantics.
  const auto memSemaMask = isAllBarrier ? allMemoryBarrierSema
                                        : isDevice ? deviceMemoryBarrierSema
                                                   : groupMemoryBarrierSema;
  const uint32_t memSema =
      theBuilder.getConstantUint32(static_cast<uint32_t>(memSemaMask));
  theBuilder.createBarrier(execScopeId, memScopeId, memSema);
  return 0;
}
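
// For instance (a sketch; the exact mask value follows from the SPIR-V
// MemorySemantics bit assignments), GroupMemoryBarrierWithGroupSync emits:
//
//   OpControlBarrier %workgroup %workgroup %sem
//
// where %workgroup is the constant 2 (spv::Scope::Workgroup) and %sem is
// WorkgroupMemory | AcquireRelease = 0x100 | 0x8 = 0x108.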

uint32_t SPIRVEmitter::processNonFpMatrixTranspose(QualType matType,
                                                   uint32_t matId) {
  // The simplest way is to flatten the matrix and then construct a new matrix
  // from the flattened elements.
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  const bool isMat =
      TypeTranslator::isMxNMatrix(matType, &elemType, &numRows, &numCols);
  assert(isMat && !elemType->isFloatingType());
  (void)isMat;

  const auto colQualType = astContext.getExtVectorType(elemType, numRows);
  const uint32_t colTypeId = typeTranslator.translateType(colQualType);
  const uint32_t elemTypeId = typeTranslator.translateType(elemType);

  // You cannot perform a composite construct of an array using a few vectors.
  // The number of constituents passed to OpCompositeConstruct must be equal
  // to the number of array elements.
  llvm::SmallVector<uint32_t, 4> elems;
  for (uint32_t i = 0; i < numRows; ++i)
    for (uint32_t j = 0; j < numCols; ++j)
      elems.push_back(
          theBuilder.createCompositeExtract(elemTypeId, matId, {i, j}));

  llvm::SmallVector<uint32_t, 4> cols;
  for (uint32_t i = 0; i < numCols; ++i) {
    // The elements in the i-th vector of the "transposed" array are at
    // offsets i, i + numCols, i + 2 * numCols, ...
    llvm::SmallVector<uint32_t, 4> indexes;
    for (uint32_t j = 0; j < numRows; ++j)
      indexes.push_back(elems[i + (j * numCols)]);
    cols.push_back(theBuilder.createCompositeConstruct(colTypeId, indexes));
  }

  const auto transposeTypeId =
      theBuilder.getArrayType(colTypeId, theBuilder.getConstantUint32(numCols));
  return theBuilder.createCompositeConstruct(transposeTypeId, cols);
}
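
// Index mapping for an assumed 2x3 input: the flattened element list is
// [e00, e01, e02, e10, e11, e12] (numCols == 3), so column i of the result
// gathers elems[i] and elems[i + 3]:
//   col 0 = (e00, e10), col 1 = (e01, e11), col 2 = (e02, e12)
// which is exactly the 3x2 transpose, stored as an array of 3 vectors.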

uint32_t SPIRVEmitter::processNonFpDot(uint32_t vec1Id, uint32_t vec2Id,
                                       uint32_t vecSize, QualType elemType) {
  const auto elemTypeId = typeTranslator.translateType(elemType);
  llvm::SmallVector<uint32_t, 4> muls;
  for (uint32_t i = 0; i < vecSize; ++i) {
    const auto elem1 =
        theBuilder.createCompositeExtract(elemTypeId, vec1Id, {i});
    const auto elem2 =
        theBuilder.createCompositeExtract(elemTypeId, vec2Id, {i});
    muls.push_back(theBuilder.createBinaryOp(translateOp(BO_Mul, elemType),
                                             elemTypeId, elem1, elem2));
  }
  uint32_t sum = muls[0];
  for (uint32_t i = 1; i < vecSize; ++i) {
    sum = theBuilder.createBinaryOp(translateOp(BO_Add, elemType), elemTypeId,
                                    sum, muls[i]);
  }
  return sum;
}
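
// In other words, the dot product is expanded element-wise: for an assumed
// int3 pair, the emitted sequence computes
//   (1, 2, 3) . (4, 5, 6) = 1*4 + 2*5 + 3*6 = 32
// via three multiplies folded together with two adds, since SPIR-V's OpDot
// is only defined for floating-point vectors.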

uint32_t SPIRVEmitter::processNonFpScalarTimesMatrix(QualType scalarType,
                                                     uint32_t scalarId,
                                                     QualType matrixType,
                                                     uint32_t matrixId) {
  assert(TypeTranslator::isScalarType(scalarType));
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  const bool isMat =
      TypeTranslator::isMxNMatrix(matrixType, &elemType, &numRows, &numCols);
  assert(isMat);
  assert(typeTranslator.isSameType(scalarType, elemType));
  (void)isMat;

  // We need to multiply the scalar by each vector of the matrix.
  // The front-end guarantees that the scalar and matrix element type are
  // the same. For example, if the scalar is a float, the matrix is cast
  // to a float matrix before being passed to mul(). It is also guaranteed
  // that types such as bool are cast to float or int before being
  // passed to mul().
  const auto rowType = astContext.getExtVectorType(elemType, numCols);
  const auto rowTypeId = typeTranslator.translateType(rowType);
  llvm::SmallVector<uint32_t, 4> splat(size_t(numCols), scalarId);
  const auto scalarSplat =
      theBuilder.createCompositeConstruct(rowTypeId, splat);
  llvm::SmallVector<uint32_t, 4> mulRows;
  for (uint32_t row = 0; row < numRows; ++row) {
    const auto rowId =
        theBuilder.createCompositeExtract(rowTypeId, matrixId, {row});
    mulRows.push_back(theBuilder.createBinaryOp(translateOp(BO_Mul, scalarType),
                                                rowTypeId, rowId, scalarSplat));
  }
  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(matrixType), mulRows);
}

uint32_t SPIRVEmitter::processNonFpVectorTimesMatrix(QualType vecType,
                                                     uint32_t vecId,
                                                     QualType matType,
                                                     uint32_t matId,
                                                     uint32_t matTransposeId) {
  // This function assumes that the vector element type and matrix element
  // type are the same.
  QualType vecElemType = {}, matElemType = {};
  uint32_t vecSize = 0, numRows = 0, numCols = 0;
  const bool isVec =
      TypeTranslator::isVectorType(vecType, &vecElemType, &vecSize);
  const bool isMat =
      TypeTranslator::isMxNMatrix(matType, &matElemType, &numRows, &numCols);
  assert(typeTranslator.isSameType(vecElemType, matElemType));
  assert(isVec);
  assert(isMat);
  assert(vecSize == numRows);
  (void)isVec;
  (void)isMat;

  // When processing vector times matrix, the vector is a row vector, and it
  // should be multiplied by the matrix *columns*. The most efficient way to
  // handle this in SPIR-V is to first transpose the matrix, and then extract
  // and dot each column of the transposed matrix.
  if (!matTransposeId)
    matTransposeId = processNonFpMatrixTranspose(matType, matId);

  const auto vecTypeId = typeTranslator.translateType(vecType);
  llvm::SmallVector<uint32_t, 4> resultElems;
  for (uint32_t col = 0; col < numCols; ++col) {
    const auto colId =
        theBuilder.createCompositeExtract(vecTypeId, matTransposeId, {col});
    resultElems.push_back(processNonFpDot(vecId, colId, vecSize, vecElemType));
  }
  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(
          astContext.getExtVectorType(vecElemType, numCols)),
      resultElems);
}

uint32_t SPIRVEmitter::processNonFpMatrixTimesVector(QualType matType,
                                                     uint32_t matId,
                                                     QualType vecType,
                                                     uint32_t vecId) {
  // This function assumes that the vector element type and matrix element
  // type are the same.
  QualType vecElemType = {}, matElemType = {};
  uint32_t vecSize = 0, numRows = 0, numCols = 0;
  const bool isVec =
      TypeTranslator::isVectorType(vecType, &vecElemType, &vecSize);
  const bool isMat =
      TypeTranslator::isMxNMatrix(matType, &matElemType, &numRows, &numCols);
  assert(typeTranslator.isSameType(vecElemType, matElemType));
  assert(isVec);
  assert(isMat);
  assert(vecSize == numCols);
  (void)isVec;
  (void)isMat;

  // When processing matrix times vector, the vector is a column vector. So we
  // simply get each row of the matrix and perform a dot product with the
  // vector.
  const auto vecTypeId = typeTranslator.translateType(vecType);
  llvm::SmallVector<uint32_t, 4> resultElems;
  for (uint32_t row = 0; row < numRows; ++row) {
    const auto rowId =
        theBuilder.createCompositeExtract(vecTypeId, matId, {row});
    resultElems.push_back(processNonFpDot(rowId, vecId, vecSize, vecElemType));
  }
  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(
          astContext.getExtVectorType(vecElemType, numRows)),
      resultElems);
}

uint32_t SPIRVEmitter::processNonFpMatrixTimesMatrix(QualType lhsType,
                                                     uint32_t lhsId,
                                                     QualType rhsType,
                                                     uint32_t rhsId) {
  // This function assumes that the element types of the two matrices are the
  // same.
  QualType lhsElemType = {}, rhsElemType = {};
  uint32_t lhsNumRows = 0, lhsNumCols = 0;
  uint32_t rhsNumRows = 0, rhsNumCols = 0;
  const bool lhsIsMat = TypeTranslator::isMxNMatrix(lhsType, &lhsElemType,
                                                    &lhsNumRows, &lhsNumCols);
  const bool rhsIsMat = TypeTranslator::isMxNMatrix(rhsType, &rhsElemType,
                                                    &rhsNumRows, &rhsNumCols);
  assert(typeTranslator.isSameType(lhsElemType, rhsElemType));
  assert(lhsIsMat && rhsIsMat);
  assert(lhsNumCols == rhsNumRows);
  (void)rhsIsMat;
  (void)lhsIsMat;

  const uint32_t rhsTranspose = processNonFpMatrixTranspose(rhsType, rhsId);
  const auto vecType = astContext.getExtVectorType(lhsElemType, lhsNumCols);
  const auto vecTypeId = typeTranslator.translateType(vecType);
  llvm::SmallVector<uint32_t, 4> resultRows;
  for (uint32_t row = 0; row < lhsNumRows; ++row) {
    const auto rowId =
        theBuilder.createCompositeExtract(vecTypeId, lhsId, {row});
    resultRows.push_back(processNonFpVectorTimesMatrix(vecType, rowId, rhsType,
                                                       rhsId, rhsTranspose));
  }

  // The resulting matrix will have 'lhsNumRows' rows and 'rhsNumCols' columns.
  const auto elemTypeId = typeTranslator.translateType(lhsElemType);
  const auto resultNumRows = theBuilder.getConstantUint32(lhsNumRows);
  const auto resultColType = theBuilder.getVecType(elemTypeId, rhsNumCols);
  const auto resultType = theBuilder.getArrayType(resultColType, resultNumRows);
  return theBuilder.createCompositeConstruct(resultType, resultRows);
}

uint32_t SPIRVEmitter::processIntrinsicMul(const CallExpr *callExpr) {
  const uint32_t returnTypeId =
      typeTranslator.translateType(callExpr->getType());

  // Get the function parameters. Expect 2 parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const QualType arg0Type = arg0->getType();
  const QualType arg1Type = arg1->getType();

  // The HLSL mul() function takes 2 arguments. Each argument may be a scalar,
  // vector, or matrix. The frontend ensures that the two arguments have the
  // same component type. The only allowed component types are int and float.

  // mul(scalar, vector)
  {
    uint32_t elemCount = 0;
    if (TypeTranslator::isScalarType(arg0Type) &&
        TypeTranslator::isVectorType(arg1Type, nullptr, &elemCount)) {
      const uint32_t arg1Id = doExpr(arg1);
      // We can use OpVectorTimesScalar if arguments are floats.
      if (arg0Type->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnTypeId, arg1Id, doExpr(arg0));
      // Use OpIMul for integers.
      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId,
                                       createVectorSplat(arg0, elemCount),
                                       arg1Id);
    }
  }

  // mul(vector, scalar)
  {
    uint32_t elemCount = 0;
    if (TypeTranslator::isVectorType(arg0Type, nullptr, &elemCount) &&
        TypeTranslator::isScalarType(arg1Type)) {
      const uint32_t arg0Id = doExpr(arg0);
      // We can use OpVectorTimesScalar if arguments are floats.
      if (arg1Type->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnTypeId, arg0Id, doExpr(arg1));
      // Use OpIMul for integers.
      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId, arg0Id,
                                       createVectorSplat(arg1, elemCount));
    }
  }

  // mul(vector, vector)
  if (TypeTranslator::isVectorType(arg0Type) &&
      TypeTranslator::isVectorType(arg1Type))
    return processIntrinsicDot(callExpr);

  // All the following cases require handling arg0 and arg1 expressions first.
  const uint32_t arg0Id = doExpr(arg0);
  const uint32_t arg1Id = doExpr(arg1);

  // mul(scalar, scalar)
  if (TypeTranslator::isScalarType(arg0Type) &&
      TypeTranslator::isScalarType(arg1Type))
    return theBuilder.createBinaryOp(translateOp(BO_Mul, arg0Type),
                                     returnTypeId, arg0Id, arg1Id);

  // mul(scalar, matrix)
  {
    QualType elemType = {};
    if (TypeTranslator::isScalarType(arg0Type) &&
        TypeTranslator::isMxNMatrix(arg1Type, &elemType)) {
      // OpMatrixTimesScalar can only be used if *both* the matrix element
      // type and the scalar type are float.
      if (arg0Type->isFloatingType() && elemType->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                         returnTypeId, arg1Id, arg0Id);
      else
        return processNonFpScalarTimesMatrix(arg0Type, arg0Id, arg1Type,
                                             arg1Id);
    }
  }

  // mul(matrix, scalar)
  {
    QualType elemType = {};
    if (TypeTranslator::isScalarType(arg1Type) &&
        TypeTranslator::isMxNMatrix(arg0Type, &elemType)) {
      // OpMatrixTimesScalar can only be used if *both* the matrix element
      // type and the scalar type are float.
      if (arg1Type->isFloatingType() && elemType->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                         returnTypeId, arg0Id, arg1Id);
      else
        return processNonFpScalarTimesMatrix(arg1Type, arg1Id, arg0Type,
                                             arg0Id);
    }
  }

  // mul(vector, matrix)
  {
    QualType vecElemType = {}, matElemType = {};
    uint32_t elemCount = 0, numRows = 0;
    if (TypeTranslator::isVectorType(arg0Type, &vecElemType, &elemCount) &&
        TypeTranslator::isMxNMatrix(arg1Type, &matElemType, &numRows)) {
      assert(elemCount == numRows);
      if (vecElemType->isFloatingType() && matElemType->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesVector,
                                         returnTypeId, arg1Id, arg0Id);
      else
        return processNonFpVectorTimesMatrix(arg0Type, arg0Id, arg1Type,
                                             arg1Id);
    }
  }

  // mul(matrix, vector)
  {
    QualType vecElemType = {}, matElemType = {};
    uint32_t elemCount = 0, numCols = 0;
    if (TypeTranslator::isMxNMatrix(arg0Type, &matElemType, nullptr,
                                    &numCols) &&
        TypeTranslator::isVectorType(arg1Type, &vecElemType, &elemCount)) {
      assert(elemCount == numCols);
      if (vecElemType->isFloatingType() && matElemType->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesMatrix,
                                         returnTypeId, arg1Id, arg0Id);
      else
        return processNonFpMatrixTimesVector(arg0Type, arg0Id, arg1Type,
                                             arg1Id);
    }
  }

  // mul(matrix, matrix)
  {
    // The front-end ensures that the two matrix element types match.
    QualType elemType = {};
    uint32_t lhsCols = 0, rhsRows = 0;
    if (TypeTranslator::isMxNMatrix(arg0Type, &elemType, nullptr, &lhsCols) &&
        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &rhsRows, nullptr)) {
      assert(lhsCols == rhsRows);
      if (elemType->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesMatrix,
                                         returnTypeId, arg1Id, arg0Id);
      else
        return processNonFpMatrixTimesMatrix(arg0Type, arg0Id, arg1Type,
                                             arg1Id);
    }
  }

  emitError("invalid argument type passed to mul intrinsic function",
            callExpr->getExprLoc());
  return 0;
}
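
// Summary of the dispatch above (float component types use the native SPIR-V
// matrix/vector opcodes; integer types fall back to the processNonFp* paths):
//
//   mul(scalar, vector) -> OpVectorTimesScalar / OpIMul with a splat
//   mul(vector, scalar) -> OpVectorTimesScalar / OpIMul with a splat
//   mul(vector, vector) -> processIntrinsicDot
//   mul(scalar, scalar) -> OpFMul / OpIMul (via translateOp)
//   mul(scalar, matrix) -> OpMatrixTimesScalar / processNonFpScalarTimesMatrix
//   mul(matrix, scalar) -> OpMatrixTimesScalar / processNonFpScalarTimesMatrix
//   mul(vector, matrix) -> OpMatrixTimesVector / processNonFpVectorTimesMatrix
//   mul(matrix, vector) -> OpVectorTimesMatrix / processNonFpMatrixTimesVector
//   mul(matrix, matrix) -> OpMatrixTimesMatrix / processNonFpMatrixTimesMatrix
//
// Note that the float matrix cases pass the operands in swapped order, which
// appears to compensate for how this emitter represents HLSL row-major
// matrices in SPIR-V.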

uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId =
      typeTranslator.translateType(callExpr->getType());

  // Get the function parameters. Expect 2 vectors as parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const uint32_t arg0Id = doExpr(arg0);
  const uint32_t arg1Id = doExpr(arg1);
  QualType arg0Type = arg0->getType();
  QualType arg1Type = arg1->getType();
  const size_t vec0Size = hlsl::GetHLSLVecSize(arg0Type);
  const size_t vec1Size = hlsl::GetHLSLVecSize(arg1Type);
  const QualType vec0ComponentType = hlsl::GetHLSLVecElementType(arg0Type);
  const QualType vec1ComponentType = hlsl::GetHLSLVecElementType(arg1Type);
  assert(returnType == vec1ComponentType);
  assert(vec0ComponentType == vec1ComponentType);
  assert(vec0Size == vec1Size);
  assert(vec0Size >= 1 && vec0Size <= 4);
  (void)vec0ComponentType;
  (void)vec1ComponentType;
  (void)vec1Size;

  // According to the HLSL reference, the dot function only works on integers
  // and floats.
  assert(returnType->isFloatingType() || returnType->isIntegerType());

  // Special case: dot product of two vectors, each of size 1. That is
  // basically the same as regular multiplication of 2 scalars.
  if (vec0Size == 1) {
    const spv::Op spvOp = translateOp(BO_Mul, arg0Type);
    return theBuilder.createBinaryOp(spvOp, returnTypeId, arg0Id, arg1Id);
  }

  // If the vectors are of float type, we can use OpDot.
  if (returnType->isFloatingType()) {
    return theBuilder.createBinaryOp(spv::Op::OpDot, returnTypeId, arg0Id,
                                     arg1Id);
  }
  // The vector component type is integer (signed or unsigned).
  // Create all instructions necessary to perform a dot product on two
  // integer vectors. SPIR-V OpDot does not support integer vectors.
  // Therefore, we use other SPIR-V instructions (addition and
  // multiplication).
  else {
    uint32_t result = 0;
    llvm::SmallVector<uint32_t, 4> multIds;
    const spv::Op multSpvOp = translateOp(BO_Mul, arg0Type);
    const spv::Op addSpvOp = translateOp(BO_Add, arg0Type);

    // Extract members from the two vectors and multiply them.
    for (unsigned int i = 0; i < vec0Size; ++i) {
      const uint32_t vec0member =
          theBuilder.createCompositeExtract(returnTypeId, arg0Id, {i});
      const uint32_t vec1member =
          theBuilder.createCompositeExtract(returnTypeId, arg1Id, {i});
      const uint32_t multId = theBuilder.createBinaryOp(multSpvOp, returnTypeId,
                                                        vec0member, vec1member);
      multIds.push_back(multId);
    }
    // Add all the multiplications.
    result = multIds[0];
    for (unsigned int i = 1; i < vec0Size; ++i) {
      const uint32_t additionId =
          theBuilder.createBinaryOp(addSpvOp, returnTypeId, result, multIds[i]);
      result = additionId;
    }
    return result;
  }
}
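
// Illustrative SPIR-V for the integer path with int2 arguments (a sketch,
// ids abbreviated):
//
//   %a0 = OpCompositeExtract %int %a 0
//   %b0 = OpCompositeExtract %int %b 0
//   %m0 = OpIMul %int %a0 %b0
//   %a1 = OpCompositeExtract %int %a 1
//   %b1 = OpCompositeExtract %int %b 1
//   %m1 = OpIMul %int %a1 %b1
//   %r  = OpIAdd %int %m0 %m1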

uint32_t SPIRVEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
  // 'rcp' takes only 1 argument that is a scalar, vector, or matrix of type
  // float or double.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);
  const QualType argType = arg->getType();

  // For cases with matrix argument.
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    const uint32_t vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecOne](uint32_t /*index*/,
                                             uint32_t vecType,
                                             uint32_t curRowId) {
      return theBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne,
                                       curRowId);
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec);
  }

  // For cases with scalar or vector arguments.
  return theBuilder.createBinaryOp(spv::Op::OpFDiv, returnTypeId,
                                   getValueOne(argType), argId);
}

uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
                                                spv::Op spvOp) {
  // 'all' and 'any' take only 1 parameter.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg = callExpr->getArg(0);
  const QualType argType = arg->getType();

  // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
  // Optimization: can directly cast them to boolean. No need for OpAny/OpAll.
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(argType, &scalarType) &&
        (scalarType->isBooleanType() || scalarType->isFloatingType() ||
         scalarType->isIntegerType()))
      return castToBool(doExpr(arg), argType, returnType);
  }

  // Handle vectors larger than 1, Mx1 matrices, and 1xN matrices as arguments.
  // Cast the vector to a boolean vector, then run OpAny/OpAll on it.
  {
    QualType elemType = {};
    uint32_t size = 0;
    if (TypeTranslator::isVectorType(argType, &elemType, &size)) {
      const QualType castToBoolType =
          astContext.getExtVectorType(returnType, size);
      uint32_t castedToBoolId =
          castToBool(doExpr(arg), argType, castToBoolType);
      return theBuilder.createUnaryOp(spvOp, returnTypeId, castedToBoolId);
    }
  }

  // Handle MxN matrices as arguments.
  {
    QualType elemType = {};
    uint32_t matRowCount = 0, matColCount = 0;
    if (TypeTranslator::isMxNMatrix(argType, &elemType, &matRowCount,
                                    &matColCount)) {
      uint32_t matrixId = doExpr(arg);
      const uint32_t vecType = typeTranslator.getComponentVectorType(argType);
      llvm::SmallVector<uint32_t, 4> rowResults;
      for (uint32_t i = 0; i < matRowCount; ++i) {
        // Extract the row, which is a float vector of size matColCount.
        const uint32_t rowFloatVec =
            theBuilder.createCompositeExtract(vecType, matrixId, {i});
        // Cast the float vector to a boolean vector.
        const auto rowFloatQualType =
            astContext.getExtVectorType(elemType, matColCount);
        const auto rowBoolQualType =
            astContext.getExtVectorType(returnType, matColCount);
        const uint32_t rowBoolVec =
            castToBool(rowFloatVec, rowFloatQualType, rowBoolQualType);
        // Perform OpAny/OpAll on the boolean vector.
        rowResults.push_back(
            theBuilder.createUnaryOp(spvOp, returnTypeId, rowBoolVec));
      }
      // Create a new vector that is the concatenation of results of all rows.
      uint32_t boolId = theBuilder.getBoolType();
      uint32_t vecOfBoolsId = theBuilder.getVecType(boolId, matRowCount);
      const uint32_t rowResultsId =
          theBuilder.createCompositeConstruct(vecOfBoolsId, rowResults);
      // Run OpAny/OpAll on the newly-created vector.
      return theBuilder.createUnaryOp(spvOp, returnTypeId, rowResultsId);
    }
  }

  // All types should be handled already.
  llvm_unreachable("Unknown argument type passed to all()/any().");
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
  // This function handles 'asint', 'asuint', 'asfloat', and 'asdouble'.
  //
  // Method 1: ret asint(arg)
  //   arg component type = {float, uint}
  //   arg template type  = {scalar, vector, matrix}
  //   ret template type  = same as arg template type.
  //   ret component type = int
  //
  // Method 2: ret asuint(arg)
  //   arg component type = {float, int}
  //   arg template type  = {scalar, vector, matrix}
  //   ret template type  = same as arg template type.
  //   ret component type = uint
  //
  // Method 3: ret asfloat(arg)
  //   arg component type = {float, uint, int}
  //   arg template type  = {scalar, vector, matrix}
  //   ret template type  = same as arg template type.
  //   ret component type = float
  //
  // Method 4: double  asdouble(uint lowbits, uint highbits)
  // Method 5: double2 asdouble(uint2 lowbits, uint2 highbits)
  // Method 6:
  //   void asuint(
  //       in double value,
  //       out uint lowbits,
  //       out uint highbits
  //   );
  const QualType returnType = callExpr->getType();
  const uint32_t numArgs = callExpr->getNumArgs();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg0 = callExpr->getArg(0);
  const QualType argType = arg0->getType();

  // Method 3 return type may be the same as arg type, so it would be a no-op.
  if (typeTranslator.isSameType(returnType, argType))
    return doExpr(arg0);

  switch (numArgs) {
  case 1: {
    // Handling Method 1, 2, and 3.
    const auto argId = doExpr(arg0);
    QualType fromElemType = {};
    uint32_t numRows = 0, numCols = 0;
    // For non-matrix arguments (scalar or vector), just do an OpBitcast.
    if (!TypeTranslator::isMxNMatrix(argType, &fromElemType, &numRows,
                                     &numCols)) {
      return theBuilder.createUnaryOp(spv::Op::OpBitcast, returnTypeId, argId);
    }

    // Input or output type is a matrix: bitcast each row separately.
    const QualType toElemType = hlsl::GetHLSLMatElementType(returnType);
    llvm::SmallVector<uint32_t, 4> castedRows;
    const auto fromVecQualType =
        astContext.getExtVectorType(fromElemType, numCols);
    const auto toVecQualType = astContext.getExtVectorType(toElemType, numCols);
    const auto fromVecTypeId = typeTranslator.translateType(fromVecQualType);
    const auto toVecTypeId = typeTranslator.translateType(toVecQualType);
    for (uint32_t row = 0; row < numRows; ++row) {
      const auto rowId =
          theBuilder.createCompositeExtract(fromVecTypeId, argId, {row});
      castedRows.push_back(
          theBuilder.createUnaryOp(spv::Op::OpBitcast, toVecTypeId, rowId));
    }
    return theBuilder.createCompositeConstruct(returnTypeId, castedRows);
  }
  case 2: {
    const uint32_t lowbits = doExpr(arg0);
    const uint32_t highbits = doExpr(callExpr->getArg(1));
    const uint32_t uintType = theBuilder.getUint32Type();
    const uint32_t doubleType = theBuilder.getFloat64Type();
    // Handling Method 4.
    if (argType->isUnsignedIntegerType()) {
      const uint32_t uintVec2Type = theBuilder.getVecType(uintType, 2);
      const uint32_t operand = theBuilder.createCompositeConstruct(
          uintVec2Type, {lowbits, highbits});
      return theBuilder.createUnaryOp(spv::Op::OpBitcast, doubleType, operand);
    }
    // Handling Method 5.
    else {
      const uint32_t uintVec4Type = theBuilder.getVecType(uintType, 4);
      const uint32_t doubleVec2Type = theBuilder.getVecType(doubleType, 2);
      const uint32_t operand = theBuilder.createVectorShuffle(
          uintVec4Type, lowbits, highbits, {0, 2, 1, 3});
      return theBuilder.createUnaryOp(spv::Op::OpBitcast, doubleVec2Type,
                                      operand);
    }
  }
  case 3: {
    // Handling Method 6.
    const uint32_t value = doExpr(arg0);
    const uint32_t lowbits = doExpr(callExpr->getArg(1));
    const uint32_t highbits = doExpr(callExpr->getArg(2));
    const uint32_t uintType = theBuilder.getUint32Type();
    const uint32_t uintVec2Type = theBuilder.getVecType(uintType, 2);
    const uint32_t vecResult =
        theBuilder.createUnaryOp(spv::Op::OpBitcast, uintVec2Type, value);
    theBuilder.createStore(
        lowbits, theBuilder.createCompositeExtract(uintType, vecResult, {0}));
    theBuilder.createStore(
        highbits, theBuilder.createCompositeExtract(uintType, vecResult, {1}));
    return 0;
  }
  default:
    emitError("unrecognized signature for %0 intrinsic function",
              callExpr->getExprLoc())
        << callExpr->getDirectCallee()->getName();
    return 0;
  }
}
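
// For Method 5, the vector shuffle interleaves the low and high words before
// the bitcast: with (assumed) inputs lowbits = (l0, l1) and
// highbits = (h0, h1), indices {0, 2, 1, 3} over the concatenation
// (l0, l1, h0, h1) produce (l0, h0, l1, h1), so each double is reassembled
// from its own low/high pair.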

uint32_t SPIRVEmitter::processD3DCOLORtoUBYTE4(const CallExpr *callExpr) {
  // Should take a float4 and return an int4 by doing:
  // int4 result = input.zyxw * 255.001953;
  // Maximum float precision makes the scaling factor 255.002.
  const auto arg = callExpr->getArg(0);
  const auto argId = doExpr(arg);
  const auto argTypeId = typeTranslator.translateType(arg->getType());
  const auto swizzle =
      theBuilder.createVectorShuffle(argTypeId, argId, argId, {2, 1, 0, 3});
  const auto scaled = theBuilder.createBinaryOp(
      spv::Op::OpVectorTimesScalar, argTypeId, swizzle,
      theBuilder.getConstantFloat32(255.002f));
  return castToInt(scaled, arg->getType(), callExpr->getType(),
                   callExpr->getExprLoc());
}
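
// Worked example (values assumed): for input (1.0, 0.0, 0.5, 1.0), the zyxw
// swizzle gives (0.5, 0.0, 1.0, 1.0); scaling by 255.002 gives roughly
// (127.5, 0.0, 255.0, 255.0), and the conversion to int truncates to
// (127, 0, 255, 255).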

uint32_t SPIRVEmitter::processIntrinsicIsFinite(const CallExpr *callExpr) {
  // Since OpIsFinite needs the Kernel capability, translation is instead done
  // using OpIsNan and OpIsInf:
  //   isFinite = !(isNan || isInf)
  const auto arg = doExpr(callExpr->getArg(0));
  const auto returnType = typeTranslator.translateType(callExpr->getType());
  const auto isNan =
      theBuilder.createUnaryOp(spv::Op::OpIsNan, returnType, arg);
  const auto isInf =
      theBuilder.createUnaryOp(spv::Op::OpIsInf, returnType, arg);
  const auto isNanOrInf =
      theBuilder.createBinaryOp(spv::Op::OpLogicalOr, returnType, isNan, isInf);
  return theBuilder.createUnaryOp(spv::Op::OpLogicalNot, returnType,
                                  isNanOrInf);
}

uint32_t SPIRVEmitter::processIntrinsicSinCos(const CallExpr *callExpr) {
  // Since there is no sincos equivalent in SPIR-V, we need to perform Sin
  // once and Cos once. We can reuse the existing Sine/Cosine handling
  // functions.
  CallExpr *sincosExpr =
      new (astContext) CallExpr(astContext, Stmt::StmtClass::NoStmtClass, {});
  sincosExpr->setType(callExpr->getArg(0)->getType());
  sincosExpr->setNumArgs(astContext, 1);
  sincosExpr->setArg(0, const_cast<Expr *>(callExpr->getArg(0)));

  // Perform Sin and store the results in argument 1.
  const uint32_t sin =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Sin,
                                    /*actPerRowForMatrices*/ true);
  theBuilder.createStore(doExpr(callExpr->getArg(1)), sin);

  // Perform Cos and store the results in argument 2.
  const uint32_t cos =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Cos,
                                    /*actPerRowForMatrices*/ true);
  theBuilder.createStore(doExpr(callExpr->getArg(2)), cos);
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
  const auto *arg = callExpr->getArg(0);
  const auto argId = doExpr(arg);
  const auto argType = arg->getType();
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();

  if (argType->isFloatingType()) {
    const uint32_t floatZero = getValueZero(argType);
    const uint32_t floatOne = getValueOne(argType);
    return theBuilder.createExtInst(returnType, glslInstSetId,
                                    GLSLstd450::GLSLstd450FClamp,
                                    {argId, floatZero, floatOne});
  }

  QualType elemType = {};
  uint32_t vecSize = 0;
  if (TypeTranslator::isVectorType(argType, &elemType, &vecSize)) {
    const uint32_t vecZero = getVecValueZero(elemType, vecSize);
    const uint32_t vecOne = getVecValueOne(elemType, vecSize);
    return theBuilder.createExtInst(returnType, glslInstSetId,
                                    GLSLstd450::GLSLstd450FClamp,
                                    {argId, vecZero, vecOne});
  }

  uint32_t numRows = 0, numCols = 0;
  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    const uint32_t vecZero = getVecValueZero(elemType, numCols);
    const uint32_t vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecZero, vecOne, glslInstSetId](
                                  uint32_t /*index*/, uint32_t vecType,
                                  uint32_t curRowId) {
      return theBuilder.createExtInst(vecType, glslInstSetId,
                                      GLSLstd450::GLSLstd450FClamp,
                                      {curRowId, vecZero, vecOne});
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec);
  }

  emitError("invalid argument type passed to saturate intrinsic function",
            callExpr->getExprLoc());
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
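  // HLSL's sign() returns an int (or int vector/matrix), while GLSL.std.450
  // FSign returns a float result; hence the castToInt at the end.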
  // Import the GLSL.std.450 extended instruction set.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const QualType returnType = callExpr->getType();
  const QualType argType = arg->getType();
  assert(isFloatOrVecMatOfFloatType(argType));
  const uint32_t argTypeId = typeTranslator.translateType(argType);
  const uint32_t argId = doExpr(arg);
  uint32_t floatSignResultId = 0;

  // For matrices, we can perform the instruction on each vector of the matrix.
  if (TypeTranslator::isMxNMatrix(argType)) {
    const auto actOnEachVec = [this, glslInstSetId](uint32_t /*index*/,
                                                    uint32_t vecType,
                                                    uint32_t curRowId) {
      return theBuilder.createExtInst(vecType, glslInstSetId,
                                      GLSLstd450::GLSLstd450FSign, {curRowId});
    };
    floatSignResultId = processEachVectorInMatrix(arg, argId, actOnEachVec);
  } else {
    floatSignResultId = theBuilder.createExtInst(
        argTypeId, glslInstSetId, GLSLstd450::GLSLstd450FSign, {argId});
  }

  return castToInt(floatSignResultId, arg->getType(), returnType,
                   arg->getExprLoc());
}

uint32_t SPIRVEmitter::processIntrinsicF16ToF32(const CallExpr *callExpr) {
  // f16tof32() takes in (vector of) uint and returns (vector of) float.
  // The frontend should guarantee that by inserting implicit casts.
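  // GLSL.std.450 UnpackHalf2x16 expands one uint into a float2: component x
  // comes from the low 16 bits and y from the high 16 bits. f16tof32() only
  // cares about the low half, so we keep component 0 of each unpack result.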
  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
  const uint32_t f32TypeId = theBuilder.getFloat32Type();
  const uint32_t u32TypeId = theBuilder.getUint32Type();
  const uint32_t v2f32TypeId = theBuilder.getVecType(f32TypeId, 2);
  const auto *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);

  uint32_t elemCount = {};
  if (TypeTranslator::isVectorType(arg->getType(), nullptr, &elemCount)) {
    // The input is a vector. We need to handle each element separately.
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < elemCount; ++i) {
      const uint32_t srcElem =
          theBuilder.createCompositeExtract(u32TypeId, argId, {i});
      const uint32_t convert = theBuilder.createExtInst(
          v2f32TypeId, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, srcElem);
      elements.push_back(
          theBuilder.createCompositeExtract(f32TypeId, convert, {0}));
    }
    return theBuilder.createCompositeConstruct(
        theBuilder.getVecType(f32TypeId, elemCount), elements);
  }

  const uint32_t convert = theBuilder.createExtInst(
      v2f32TypeId, glsl, GLSLstd450::GLSLstd450UnpackHalf2x16, argId);
  // f16tof32() converts the float16 stored in the low half of the uint to
  // a float. So we just need to return the first component.
  return theBuilder.createCompositeExtract(f32TypeId, convert, {0});
}

uint32_t SPIRVEmitter::processIntrinsicF32ToF16(const CallExpr *callExpr) {
  // f32tof16() takes in (vector of) float and returns (vector of) uint.
  // The frontend should guarantee that by inserting implicit casts.
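  // GLSL.std.450 PackHalf2x16 is the inverse of UnpackHalf2x16: it packs a
  // float2 into one uint, x into the low 16 bits and y into the high 16 bits.
  // Pairing each float with a zero thus populates only the low half.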
  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
  const uint32_t f32TypeId = theBuilder.getFloat32Type();
  const uint32_t u32TypeId = theBuilder.getUint32Type();
  const uint32_t v2f32TypeId = theBuilder.getVecType(f32TypeId, 2);
  const uint32_t zero = theBuilder.getConstantFloat32(0);
  const auto *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);

  uint32_t elemCount = {};
  if (TypeTranslator::isVectorType(arg->getType(), nullptr, &elemCount)) {
    // The input is a vector. We need to handle each element separately.
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < elemCount; ++i) {
      const uint32_t srcElem =
          theBuilder.createCompositeExtract(f32TypeId, argId, {i});
      const uint32_t srcVec =
          theBuilder.createCompositeConstruct(v2f32TypeId, {srcElem, zero});
      elements.push_back(theBuilder.createExtInst(
          u32TypeId, glsl, GLSLstd450::GLSLstd450PackHalf2x16, srcVec));
    }
    return theBuilder.createCompositeConstruct(
        theBuilder.getVecType(u32TypeId, elemCount), elements);
  }

  // f32tof16() stores the float into the low half of the uint. So we need
  // to supply another zero to take the other half.
  const uint32_t srcVec =
      theBuilder.createCompositeConstruct(v2f32TypeId, {argId, zero});
  return theBuilder.createExtInst(u32TypeId, glsl,
                                  GLSLstd450::GLSLstd450PackHalf2x16, srcVec);
}

uint32_t SPIRVEmitter::processIntrinsicUsingSpirvInst(
    const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
  // Certain opcodes are only allowed in pixel shaders.
  if (!shaderModel.IsPS())
    switch (opcode) {
    case spv::Op::OpDPdx:
    case spv::Op::OpDPdy:
    case spv::Op::OpDPdxFine:
    case spv::Op::OpDPdyFine:
    case spv::Op::OpDPdxCoarse:
    case spv::Op::OpDPdyCoarse:
    case spv::Op::OpFwidth:
    case spv::Op::OpFwidthFine:
    case spv::Op::OpFwidthCoarse:
      needsLegalization = true;
      break;
    default:
      // Only the opcodes above need legalization. Anything else should leave
      // the flag unchanged.
      break;
    }

  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    const uint32_t argId = doExpr(arg);

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && TypeTranslator::isMxNMatrix(arg->getType())) {
      const auto actOnEachVec = [this, opcode](uint32_t /*index*/,
                                               uint32_t vecType,
                                               uint32_t curRowId) {
        return theBuilder.createUnaryOp(opcode, vecType, curRowId);
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec);
    }
    return theBuilder.createUnaryOp(opcode, returnType, argId);
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && TypeTranslator::isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, opcode, arg1Id](uint32_t index,
                                                       uint32_t vecType,
                                                       uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        return theBuilder.createBinaryOp(opcode, vecType, arg0RowId, arg1RowId);
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createBinaryOp(opcode, returnType, arg0Id, arg1Id);
  }

  emitError("unsupported %0 intrinsic function", callExpr->getExprLoc())
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicUsingGLSLInst(
    const CallExpr *callExpr, GLSLstd450 opcode, bool actPerRowForMatrices) {
  // Import the GLSL.std.450 extended instruction set.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    const uint32_t argId = doExpr(arg);

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && TypeTranslator::isMxNMatrix(arg->getType())) {
      const auto actOnEachVec = [this, glslInstSetId,
                                 opcode](uint32_t /*index*/, uint32_t vecType,
                                         uint32_t curRowId) {
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {curRowId});
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode, {argId});
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && TypeTranslator::isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode,
                                 arg1Id](uint32_t index, uint32_t vecType,
                                         uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {arg0RowId, arg1RowId});
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode,
                                    {arg0Id, arg1Id});
  } else if (callExpr->getNumArgs() == 3u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));
    const uint32_t arg2Id = doExpr(callExpr->getArg(2));

    // If the instruction does not operate on matrices, we can perform the
    // instruction on each vector of the matrix.
    if (actPerRowForMatrices && TypeTranslator::isMxNMatrix(arg0->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode, arg1Id,
                                 arg2Id](uint32_t index, uint32_t vecType,
                                         uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        const uint32_t arg2RowId =
            theBuilder.createCompositeExtract(vecType, arg2Id, {index});
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {arg0RowId, arg1RowId, arg2RowId});
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode,
                                    {arg0Id, arg1Id, arg2Id});
  }

  emitError("unsupported %0 intrinsic function", callExpr->getExprLoc())
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
  // Since there is no log10 instruction in SPIR-V, we can use:
  //   log10(x) = log2(x) * (1 / log2(10))
  //   1 / log2(10) = 0.30103
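  // Sanity check: log10(1000) = log2(1000) * 0.30103 ~= 9.9658 * 0.30103
  // ~= 3.0, as expected.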
  const auto scale = theBuilder.getConstantFloat32(0.30103f);
  const auto log2 =
      processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450Log2, true);
  const auto returnType = callExpr->getType();
  const auto returnTypeId = typeTranslator.translateType(returnType);
  spv::Op scaleOp = TypeTranslator::isScalarType(returnType)
                        ? spv::Op::OpFMul
                        : TypeTranslator::isVectorType(returnType)
                              ? spv::Op::OpVectorTimesScalar
                              : spv::Op::OpMatrixTimesScalar;
  return theBuilder.createBinaryOp(scaleOp, returnTypeId, log2, scale);
}

uint32_t SPIRVEmitter::getValueZero(QualType type) {
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(type, &scalarType)) {
      if (scalarType->isSignedIntegerType()) {
        return theBuilder.getConstantInt32(0);
      }

      if (scalarType->isUnsignedIntegerType()) {
        return theBuilder.getConstantUint32(0);
      }

      if (scalarType->isFloatingType()) {
        return theBuilder.getConstantFloat32(0.0);
      }
    }
  }

  {
    QualType elemType = {};
    uint32_t size = {};
    if (TypeTranslator::isVectorType(type, &elemType, &size)) {
      return getVecValueZero(elemType, size);
    }
  }

  {
    QualType elemType = {};
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(type, &elemType, &rowCount, &colCount)) {
      const auto row = getVecValueZero(elemType, colCount);
      llvm::SmallVector<uint32_t, 4> rows((size_t)rowCount, row);
      return theBuilder.createCompositeConstruct(
          typeTranslator.translateType(type), rows);
    }
  }

  emitError("getting value 0 for type %0 unimplemented", {})
      << type.getAsString();
  return 0;
}

uint32_t SPIRVEmitter::getVecValueZero(QualType elemType, uint32_t size) {
  const uint32_t elemZeroId = getValueZero(elemType);

  if (size == 1)
    return elemZeroId;

  llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemZeroId);
  const uint32_t vecType =
      theBuilder.getVecType(typeTranslator.translateType(elemType), size);
  return theBuilder.getConstantComposite(vecType, elements);
}

uint32_t SPIRVEmitter::getValueOne(QualType type) {
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(type, &scalarType)) {
      if (scalarType->isBooleanType()) {
        return theBuilder.getConstantBool(true);
      }

      const auto bitWidth = typeTranslator.getElementSpirvBitwidth(scalarType);
      if (scalarType->isSignedIntegerType()) {
        switch (bitWidth) {
        case 16:
          return theBuilder.getConstantInt16(1);
        case 32:
          return theBuilder.getConstantInt32(1);
        case 64:
          return theBuilder.getConstantInt64(1);
        }
      }

      if (scalarType->isUnsignedIntegerType()) {
        switch (bitWidth) {
        case 16:
          return theBuilder.getConstantUint16(1);
        case 32:
          return theBuilder.getConstantUint32(1);
        case 64:
          return theBuilder.getConstantUint64(1);
        }
      }

      if (scalarType->isFloatingType()) {
        switch (bitWidth) {
        case 16:
          return theBuilder.getConstantFloat16(1);
        case 32:
          return theBuilder.getConstantFloat32(1.0);
        case 64:
          return theBuilder.getConstantFloat64(1.0);
        }
      }
    }
  }

  {
    QualType elemType = {};
    uint32_t size = {};
    if (TypeTranslator::isVectorType(type, &elemType, &size)) {
      return getVecValueOne(elemType, size);
    }
  }

  emitError("getting value 1 for type %0 unimplemented", {}) << type;
  return 0;
}

uint32_t SPIRVEmitter::getVecValueOne(QualType elemType, uint32_t size) {
  const uint32_t elemOneId = getValueOne(elemType);

  if (size == 1)
    return elemOneId;

  llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemOneId);
  const uint32_t vecType =
      theBuilder.getVecType(typeTranslator.translateType(elemType), size);
  return theBuilder.getConstantComposite(vecType, elements);
}

uint32_t SPIRVEmitter::getMatElemValueOne(QualType type) {
  assert(hlsl::IsHLSLMatType(type));
  const auto elemType = hlsl::GetHLSLMatElementType(type);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);

  if (rowCount == 1 && colCount == 1)
    return getValueOne(elemType);
  if (colCount == 1)
    return getVecValueOne(elemType, rowCount);
  return getVecValueOne(elemType, colCount);
}

uint32_t SPIRVEmitter::getMaskForBitwidthValue(QualType type) {
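  // Returns the constant (bitwidth - 1), e.g. 31 for a 32-bit type, as a
  // scalar or a vector splat matching the given type. ANDing a shift amount
  // with this mask keeps it within the value's bit width.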
  QualType elemType = {};
  uint32_t count = 1;

  if (TypeTranslator::isScalarType(type, &elemType) ||
      TypeTranslator::isVectorType(type, &elemType, &count)) {
    const auto bitwidth = typeTranslator.getElementSpirvBitwidth(elemType);
    uint32_t mask = 0;
    uint32_t elemTypeId = 0;
    switch (bitwidth) {
    case 16:
      mask = theBuilder.getConstantUint16(bitwidth - 1);
      elemTypeId = theBuilder.getUint16Type();
      break;
    case 32:
      mask = theBuilder.getConstantUint32(bitwidth - 1);
      elemTypeId = theBuilder.getUint32Type();
      break;
    case 64:
      mask = theBuilder.getConstantUint64(bitwidth - 1);
      elemTypeId = theBuilder.getUint64Type();
      break;
    default:
      assert(false && "this method only supports 16-, 32-, and 64-bit types");
    }

    if (count == 1)
      return mask;

    const uint32_t typeId = theBuilder.getVecType(elemTypeId, count);
    llvm::SmallVector<uint32_t, 4> elements(size_t(count), mask);
    return theBuilder.getConstantComposite(typeId, elements);
  }

  assert(false && "this method only supports scalars and vectors");
  return 0;
}

uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
                                        const QualType targetType) {
  uint32_t result = 0;

  // Provide a hint to the typeTranslator that if a literal is discovered, its
  // intended usage is targetType.
  TypeTranslator::LiteralTypeHint hint(typeTranslator, targetType);

  if (targetType->isBooleanType()) {
    result = theBuilder.getConstantBool(value.getInt().getBoolValue(),
                                        isSpecConstantMode);
  } else if (targetType->isIntegerType()) {
    result = translateAPInt(value.getInt(), targetType);
  } else if (targetType->isFloatingType()) {
    result = translateAPFloat(value.getFloat(), targetType);
  } else if (hlsl::IsHLSLVecType(targetType)) {
    const uint32_t vecType = typeTranslator.translateType(targetType);
    const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
    const auto numElements = value.getVectorLength();
    // Special case for vectors of size 1. SPIR-V doesn't support this vector
    // size so we need to translate it to scalar values.
    if (numElements == 1) {
      result = translateAPValue(value.getVectorElt(0), elemType);
    } else {
      llvm::SmallVector<uint32_t, 4> elements;
      for (uint32_t i = 0; i < numElements; ++i) {
        elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
      }
      result = theBuilder.getConstantComposite(vecType, elements);
    }
  }

  if (result)
    return result;

  emitError("APValue of type %0 unimplemented", {}) << value.getKind();
  value.dump();
  return 0;
}

uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                                      QualType targetType) {
  targetType = typeTranslator.getIntendedLiteralType(targetType);
  const auto targetTypeBitWidth = astContext.getTypeSize(targetType);
  const bool isSigned = targetType->isSignedIntegerType();
  switch (targetTypeBitWidth) {
  case 16: {
    if (spirvOptions.enable16BitTypes) {
      if (isSigned) {
        return theBuilder.getConstantInt16(
            static_cast<int16_t>(intValue.getSExtValue()));
      } else {
        return theBuilder.getConstantUint16(
            static_cast<uint16_t>(intValue.getZExtValue()));
      }
    } else {
      // If the enable16BitTypes option is not true, treat as a 32-bit integer.
      if (isSigned)
        return theBuilder.getConstantInt32(
            static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
      else
        return theBuilder.getConstantUint32(
            static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
    }
  }
  case 32: {
    if (isSigned) {
      if (!intValue.isSignedIntN(32)) {
        emitError("evaluating integer literal %0 as a 32-bit integer loses "
                  "information",
                  {})
            << std::to_string(intValue.getSExtValue());
        return 0;
      }
      return theBuilder.getConstantInt32(
          static_cast<int32_t>(intValue.getSExtValue()), isSpecConstantMode);
    } else {
      if (!intValue.isIntN(32)) {
        emitError("evaluating integer literal %0 as a 32-bit integer loses "
                  "information",
                  {})
            << std::to_string(intValue.getZExtValue());
        return 0;
      }
      return theBuilder.getConstantUint32(
          static_cast<uint32_t>(intValue.getZExtValue()), isSpecConstantMode);
    }
  }
  case 64: {
    if (isSigned)
      return theBuilder.getConstantInt64(intValue.getSExtValue());
    else
      return theBuilder.getConstantUint64(intValue.getZExtValue());
  }
  }

  emitError("APInt for target bitwidth %0 unimplemented", {})
      << astContext.getIntWidth(targetType);
  return 0;
}

bool SPIRVEmitter::isLiteralLargerThan32Bits(const Expr *expr) {
  if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
    const bool isSigned = expr->getType()->isSignedIntegerType();
    const llvm::APInt &value = intLiteral->getValue();
    return (isSigned && !value.isSignedIntN(32)) ||
           (!isSigned && !value.isIntN(32));
  }

  if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
    llvm::APFloat value = floatLiteral->getValue();
    const auto &semantics = value.getSemantics();
    // Regular 'half' and 'float' can be represented in 32 bits, so they are
    // never larger than 32 bits.
    if (&semantics == &llvm::APFloat::IEEEsingle ||
        &semantics == &llvm::APFloat::IEEEhalf)
      return false;

    // See if a 'double' value can be represented in 32 bits without losing
    // info.
    bool losesInfo = false;
    const auto convertStatus =
        value.convert(llvm::APFloat::IEEEsingle,
                      llvm::APFloat::rmNearestTiesToEven, &losesInfo);
    if (convertStatus != llvm::APFloat::opOK &&
        convertStatus != llvm::APFloat::opInexact)
      return true;
  }

  return false;
}

uint32_t SPIRVEmitter::tryToEvaluateAsInt32(const llvm::APInt &intValue,
                                            bool isSigned) {
  if (isSigned && intValue.isSignedIntN(32)) {
    return theBuilder.getConstantInt32(
        static_cast<int32_t>(intValue.getSExtValue()));
  }
  if (!isSigned && intValue.isIntN(32)) {
    return theBuilder.getConstantUint32(
        static_cast<uint32_t>(intValue.getZExtValue()));
  }

  // Couldn't evaluate as a 32-bit int without losing information.
  return 0;
}

uint32_t SPIRVEmitter::tryToEvaluateAsFloat32(const llvm::APFloat &floatValue) {
  const auto &semantics = floatValue.getSemantics();
  // If the given value is already a 32-bit float, there is no need to convert.
  if (&semantics == &llvm::APFloat::IEEEsingle) {
    return theBuilder.getConstantFloat32(floatValue.convertToFloat(),
                                         isSpecConstantMode);
  }

  // Try to see if this literal float can be represented in 32 bits.
  // Since the convert function below may modify the fp value, we call it on a
  // temporary copy.
  llvm::APFloat eval = floatValue;
  bool losesInfo = false;
  const auto convertStatus =
      eval.convert(llvm::APFloat::IEEEsingle,
                   llvm::APFloat::rmNearestTiesToEven, &losesInfo);
  if (convertStatus == llvm::APFloat::opOK && !losesInfo)
    return theBuilder.getConstantFloat32(eval.convertToFloat());

  // Couldn't evaluate as a 32-bit float without losing information.
  return 0;
}

uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
                                        QualType targetType) {
  using llvm::APFloat;
  const auto originalValue = floatValue;
  const auto valueBitwidth = APFloat::getSizeInBits(floatValue.getSemantics());

  // Find out the target bitwidth.
  targetType = typeTranslator.getIntendedLiteralType(targetType);
  auto targetBitwidth =
      APFloat::getSizeInBits(astContext.getFloatTypeSemantics(targetType));
  // If 16-bit types are not enabled, treat them as 32-bit floats.
  if (targetBitwidth == 16 && !spirvOptions.enable16BitTypes)
    targetBitwidth = 32;

  if (targetBitwidth != valueBitwidth) {
    bool losesInfo = false;
    const llvm::fltSemantics &targetSemantics =
        targetBitwidth == 16
            ? APFloat::IEEEhalf
            : targetBitwidth == 32 ? APFloat::IEEEsingle : APFloat::IEEEdouble;
    const auto status = floatValue.convert(
        targetSemantics, APFloat::roundingMode::rmTowardZero, &losesInfo);
    if (status != APFloat::opStatus::opOK &&
        status != APFloat::opStatus::opInexact) {
      emitError(
          "evaluating float literal %0 at a lower bitwidth loses information",
          {})
          // Converting from 16-bit to 32/64-bit won't lose information.
          // So only 32/64-bit values can reach here.
          << std::to_string(valueBitwidth == 32
                                ? originalValue.convertToFloat()
                                : originalValue.convertToDouble());
      return 0;
    }
  }

  switch (targetBitwidth) {
  case 16:
    return theBuilder.getConstantFloat16(
        static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
  case 32:
    return theBuilder.getConstantFloat32(floatValue.convertToFloat(),
                                         isSpecConstantMode);
  case 64:
    return theBuilder.getConstantFloat64(floatValue.convertToDouble());
  default:
    break;
  }

  emitError("APFloat for target bitwidth %0 unimplemented", {})
      << targetBitwidth;
  return 0;
}

uint32_t SPIRVEmitter::tryToEvaluateAsConst(const Expr *expr) {
  Expr::EvalResult evalResult;
  if (expr->EvaluateAsRValue(evalResult, astContext) &&
      !evalResult.HasSideEffects) {
    return translateAPValue(evalResult.Val, expr->getType());
  }

  return 0;
}

spv::ExecutionModel
SPIRVEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
  // DXIL Models are:
  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Stage
  // vs_<version>         : Vertex Shader    : Vertex Shader
  // hs_<version>         : Hull Shader      : Tessellation Control Shader
  // ds_<version>         : Domain Shader    : Tessellation Evaluation Shader
  // gs_<version>         : Geometry Shader  : Geometry Shader
  // ps_<version>         : Pixel Shader     : Fragment Shader
  // cs_<version>         : Compute Shader   : Compute Shader
  switch (model.GetKind()) {
  case hlsl::ShaderModel::Kind::Vertex:
    return spv::ExecutionModel::Vertex;
  case hlsl::ShaderModel::Kind::Hull:
    return spv::ExecutionModel::TessellationControl;
  case hlsl::ShaderModel::Kind::Domain:
    return spv::ExecutionModel::TessellationEvaluation;
  case hlsl::ShaderModel::Kind::Geometry:
    return spv::ExecutionModel::Geometry;
  case hlsl::ShaderModel::Kind::Pixel:
    return spv::ExecutionModel::Fragment;
  case hlsl::ShaderModel::Kind::Compute:
    return spv::ExecutionModel::GLCompute;
  default:
    break;
  }
  llvm_unreachable("unknown shader model");
}

void SPIRVEmitter::AddRequiredCapabilitiesForShaderModel() {
  if (shaderModel.IsHS() || shaderModel.IsDS()) {
    theBuilder.requireCapability(spv::Capability::Tessellation);
  } else if (shaderModel.IsGS()) {
    theBuilder.requireCapability(spv::Capability::Geometry);
  } else {
    theBuilder.requireCapability(spv::Capability::Shader);
  }
}

bool SPIRVEmitter::processGeometryShaderAttributes(const FunctionDecl *decl,
                                                   uint32_t *arraySize) {
  bool success = true;
  assert(shaderModel.IsGS());
  if (auto *vcAttr = decl->getAttr<HLSLMaxVertexCountAttr>()) {
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::OutputVertices,
                                {static_cast<uint32_t>(vcAttr->getCount())});
  }

  uint32_t invocations = 1;
  if (auto *instanceAttr = decl->getAttr<HLSLInstanceAttr>()) {
    invocations = static_cast<uint32_t>(instanceAttr->getCount());
  }
  theBuilder.addExecutionMode(entryFunctionId, spv::ExecutionMode::Invocations,
                              {invocations});

  // Only one primitive type is permitted for the geometry shader.
  bool outPoint = false, outLine = false, outTriangle = false, inPoint = false,
       inLine = false, inTriangle = false, inLineAdj = false,
       inTriangleAdj = false;
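  // The input primitive type also fixes the arrayness of the per-vertex
  // inputs: point = 1, line = 2, triangle = 3, lineadj = 4, triangleadj = 6;
  // the loop below reports this via *arraySize.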
  for (const auto *param : decl->params()) {
    // Add an execution mode based on the output stream type. Do not add an
    // execution mode more than once.
    if (param->hasAttr<HLSLInOutAttr>()) {
      const auto paramType = param->getType();
      if (hlsl::IsHLSLTriangleStreamType(paramType) && !outTriangle) {
        theBuilder.addExecutionMode(
            entryFunctionId, spv::ExecutionMode::OutputTriangleStrip, {});
        outTriangle = true;
      } else if (hlsl::IsHLSLLineStreamType(paramType) && !outLine) {
        theBuilder.addExecutionMode(entryFunctionId,
                                    spv::ExecutionMode::OutputLineStrip, {});
        outLine = true;
      } else if (hlsl::IsHLSLPointStreamType(paramType) && !outPoint) {
        theBuilder.addExecutionMode(entryFunctionId,
                                    spv::ExecutionMode::OutputPoints, {});
        outPoint = true;
      }
      // An output stream parameter will not have the input primitive type
      // attributes, so we can continue to the next parameter.
      continue;
    }

    // Add an execution mode based on the input primitive type. Do not add an
    // execution mode more than once.
    if (param->hasAttr<HLSLPointAttr>() && !inPoint) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::InputPoints, {});
      *arraySize = 1;
      inPoint = true;
    } else if (param->hasAttr<HLSLLineAttr>() && !inLine) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::InputLines, {});
      *arraySize = 2;
      inLine = true;
    } else if (param->hasAttr<HLSLTriangleAttr>() && !inTriangle) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::Triangles, {});
      *arraySize = 3;
      inTriangle = true;
    } else if (param->hasAttr<HLSLLineAdjAttr>() && !inLineAdj) {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::InputLinesAdjacency, {});
      *arraySize = 4;
      inLineAdj = true;
    } else if (param->hasAttr<HLSLTriangleAdjAttr>() && !inTriangleAdj) {
      theBuilder.addExecutionMode(
          entryFunctionId, spv::ExecutionMode::InputTrianglesAdjacency, {});
      *arraySize = 6;
      inTriangleAdj = true;
    }
  }

  if (inPoint + inLine + inLineAdj + inTriangle + inTriangleAdj > 1) {
    emitError("only one input primitive type can be specified in the geometry "
              "shader",
              {});
    success = false;
  }
  if (outPoint + outTriangle + outLine > 1) {
    emitError("only one output primitive type can be specified in the geometry "
              "shader",
              {});
    success = false;
  }

  return success;
}

void SPIRVEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {
  theBuilder.addExecutionMode(entryFunctionId,
                              spv::ExecutionMode::OriginUpperLeft, {});
  if (decl->getAttr<HLSLEarlyDepthStencilAttr>()) {
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::EarlyFragmentTests, {});
  }
  if (decl->getAttr<VKPostDepthCoverageAttr>()) {
    theBuilder.addExtension(Extension::KHR_post_depth_coverage,
                            "[[vk::post_depth_coverage]]", decl->getLocation());
    theBuilder.requireCapability(spv::Capability::SampleMaskPostDepthCoverage);
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::PostDepthCoverage, {});
  }
}

void SPIRVEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
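  // HLSL's [numthreads(X, Y, Z)] attribute maps to the SPIR-V LocalSize
  // execution mode; e.g. [numthreads(8, 8, 1)] results in
  //   OpExecutionMode %entry LocalSize 8 8 1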
  // If not explicitly specified, x, y, and z should be defaulted to 1.
  uint32_t x = 1, y = 1, z = 1;
  if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
    x = static_cast<uint32_t>(numThreadsAttr->getX());
    y = static_cast<uint32_t>(numThreadsAttr->getY());
    z = static_cast<uint32_t>(numThreadsAttr->getZ());
  }
  theBuilder.addExecutionMode(entryFunctionId, spv::ExecutionMode::LocalSize,
                              {x, y, z});
}

bool SPIRVEmitter::processTessellationShaderAttributes(
    const FunctionDecl *decl, uint32_t *numOutputControlPoints) {
  assert(shaderModel.IsHS() || shaderModel.IsDS());
  using namespace spv;

  if (auto *domain = decl->getAttr<HLSLDomainAttr>()) {
    const auto domainType = domain->getDomainType().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(domainType)
            .Case("tri", ExecutionMode::Triangles)
            .Case("quad", ExecutionMode::Quads)
            .Case("isoline", ExecutionMode::Isolines)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown domain type specified for entry function",
                domain->getLocation());
      return false;
    }
    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
  }

  // Early return for domain shaders since they only take the 'domain'
  // attribute.
  if (shaderModel.IsDS())
    return true;

  if (auto *partitioning = decl->getAttr<HLSLPartitioningAttr>()) {
    const auto scheme = partitioning->getScheme().lower();
    if (scheme == "pow2") {
      emitError("pow2 partitioning scheme is not supported since there is no "
                "equivalent in Vulkan",
                partitioning->getLocation());
      return false;
    }
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(scheme)
            .Case("fractional_even", ExecutionMode::SpacingFractionalEven)
            .Case("fractional_odd", ExecutionMode::SpacingFractionalOdd)
            .Case("integer", ExecutionMode::SpacingEqual)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown partitioning scheme in hull shader",
                partitioning->getLocation());
      return false;
    }
    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
  }
  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
    const auto topology = outputTopology->getTopology().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(topology)
            .Case("point", ExecutionMode::PointMode)
            .Case("triangle_cw", ExecutionMode::VertexOrderCw)
            .Case("triangle_ccw", ExecutionMode::VertexOrderCcw)
            .Default(ExecutionMode::Max);
    // TODO: There is no SPIR-V equivalent for "line" topology. Is it the
    // default?
    if (topology != "line") {
      if (hsExecMode != spv::ExecutionMode::Max) {
        theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
      } else {
        emitError("unknown output topology in hull shader",
                  outputTopology->getLocation());
        return false;
      }
    }
  }
  if (auto *controlPoints = decl->getAttr<HLSLOutputControlPointsAttr>()) {
    *numOutputControlPoints = controlPoints->getCount();
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::OutputVertices,
                                {*numOutputControlPoints});
  }
  if (auto *pcf = decl->getAttr<HLSLPatchConstantFuncAttr>()) {
    llvm::StringRef pcf_name = pcf->getFunctionName();
    for (auto *decl : astContext.getTranslationUnitDecl()->decls())
      if (auto *funcDecl = dyn_cast<FunctionDecl>(decl))
        if (astContext.IsPatchConstantFunctionDecl(funcDecl) &&
            funcDecl->getName() == pcf_name)
          patchConstFunc = funcDecl;
  }

  return true;
}

bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
                                            const uint32_t entryFuncId) {
  // HS specific attributes
  uint32_t numOutputControlPoints = 0;
  uint32_t outputControlPointIdVal = 0; // SV_OutputControlPointID value
  uint32_t primitiveIdVar = 0;          // SV_PrimitiveID variable
  uint32_t viewIdVar = 0;               // SV_ViewID variable
  uint32_t hullMainInputPatchParam = 0; // Temporary parameter for InputPatch<>

  // The array size of per-vertex input/output variables.
  // Used by HS/DS/GS for the additional arrayness; zero means not an array.
  uint32_t inputArraySize = 0;
  uint32_t outputArraySize = 0;

  // Construct the wrapper function signature.
  const uint32_t voidType = theBuilder.getVoidType();
  const uint32_t funcType = theBuilder.getFunctionType(voidType, {});
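  // In SPIR-V for Vulkan, an entry point takes no parameters and returns
  // void; all stage inputs and outputs are communicated through module-level
  // variables instead. Hence the wrapper's 'void ()' signature here.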
  // The wrapper entry function surely does not have pre-assigned <result-id>
  // for it like other functions that got added to the work queue following
  // function calls. And the wrapper is the entry function.
  entryFunctionId =
      theBuilder.beginFunction(funcType, voidType, decl->getName());
  // Note this should happen before using declIdMapper for other tasks.
  declIdMapper.setEntryFunctionId(entryFunctionId);

  // Handle attributes specific to each shader stage.
  if (shaderModel.IsPS()) {
    processPixelShaderAttributes(decl);
  } else if (shaderModel.IsCS()) {
    processComputeShaderAttributes(decl);
  } else if (shaderModel.IsHS()) {
    if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
      return false;

    // The input array size for HS is specified in the InputPatch parameter.
    for (const auto *param : decl->params())
      if (hlsl::IsHLSLInputPatchType(param->getType())) {
        inputArraySize = hlsl::GetHLSLInputPatchCount(param->getType());
        break;
      }

    outputArraySize = numOutputControlPoints;
  } else if (shaderModel.IsDS()) {
    if (!processTessellationShaderAttributes(decl, &numOutputControlPoints))
      return false;

    // The input array size for DS is specified in the OutputPatch parameter.
    for (const auto *param : decl->params())
      if (hlsl::IsHLSLOutputPatchType(param->getType())) {
        inputArraySize = hlsl::GetHLSLOutputPatchCount(param->getType());
        break;
      }
    // The per-vertex output of DS is not an array.
  } else if (shaderModel.IsGS()) {
    if (!processGeometryShaderAttributes(decl, &inputArraySize))
      return false;
    // The per-vertex output of GS is not an array.
  }
  // Go through all parameters and record the declaration of SV_ClipDistance
  // and SV_CullDistance. We need to do this extra step because in HLSL we
  // can declare multiple SV_ClipDistance/SV_CullDistance variables of float
  // or vector of float types, but we can only have one single float array
  // for the ClipDistance/CullDistance builtin. So we need to group all
  // SV_ClipDistance/SV_CullDistance variables into one float array, and thus
  // need to calculate the total size of the array and the offset of each
  // variable within that array.
  // Also go through all parameters to record the semantic strings provided for
  // the builtins in gl_PerVertex.
  for (const auto *param : decl->params()) {
    if (canActAsInParmVar(param))
      if (!declIdMapper.glPerVertex.recordGlPerVertexDeclFacts(param, true))
        return false;
    if (canActAsOutParmVar(param))
      if (!declIdMapper.glPerVertex.recordGlPerVertexDeclFacts(param, false))
        return false;
  }
  // Also consider the SV_ClipDistance/SV_CullDistance in the return type.
  if (!declIdMapper.glPerVertex.recordGlPerVertexDeclFacts(decl, false))
    return false;
  // Calculate the total size of the ClipDistance/CullDistance array and the
  // offset of SV_ClipDistance/SV_CullDistance variables within the array.
  declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();

  if (!shaderModel.IsCS()) {
    // Generate stand-alone builtins of Position, ClipDistance, and
    // CullDistance, which belong to gl_PerVertex.
    declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
  }

  // Require the ClipDistance/CullDistance capability if necessary.
  // It is legal to just use the ClipDistance/CullDistance builtin without
  // requiring the ClipDistance/CullDistance capability, as long as we don't
  // read or write the builtin variable.
  // For our CodeGen, that corresponds to not seeing SV_ClipDistance or
  // SV_CullDistance at all. If we see them, we will generate code to read
  // them to initialize temporary variables for calling the source code entry
  // function, or write to them after calling the source code entry function.
  declIdMapper.glPerVertex.requireCapabilityIfNecessary();
  // The entry basic block.
  const uint32_t entryLabel = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(entryLabel);

  // Initialize all global variables at the beginning of the wrapper.
  for (const VarDecl *varDecl : toInitGloalVars) {
    const auto varInfo = declIdMapper.getDeclEvalInfo(varDecl);
    if (const auto *init = varDecl->getInit()) {
      storeValue(varInfo, doExpr(init), varDecl->getType());
      // Update counter variables associated with global variables.
      tryToAssignCounterVar(varDecl, init);
    }
    // If not explicitly initialized, initialize with zero values, unless they
    // are resource objects.
    else if (!hlsl::IsHLSLResourceType(varDecl->getType())) {
      const auto typeId = typeTranslator.translateType(varDecl->getType());
      theBuilder.createStore(varInfo, theBuilder.getConstantNull(typeId));
    }
  }
  // Create temporary variables for holding function call arguments.
  llvm::SmallVector<uint32_t, 4> params;
  for (const auto *param : decl->params()) {
    const auto paramType = param->getType();
    const uint32_t typeId = typeTranslator.translateType(paramType);
    std::string tempVarName = "param.var." + param->getNameAsString();
    const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
    params.push_back(tempVar);

    // Create the stage input variable for parameters not marked as pure out,
    // and initialize the corresponding temporary variable.
    // Also do not create input variables for output stream objects of geometry
    // shaders (e.g. TriangleStream) which are required to be marked as
    // 'inout'.
    if (canActAsInParmVar(param)) {
      if (shaderModel.IsHS() && hlsl::IsHLSLInputPatchType(paramType)) {
        // Record the temporary variable holding InputPatch. It may be used
        // later in the patch constant function.
        hullMainInputPatchParam = tempVar;
      }

      uint32_t loadedValue = 0;
      if (!declIdMapper.createStageInputVar(param, &loadedValue, false))
        return false;

      // Only initialize the temporary variable if the parameter is indeed
      // used.
      if (param->isUsed()) {
        theBuilder.createStore(tempVar, loadedValue);
      }

      // Record the temporary variables holding SV_OutputControlPointID,
      // SV_PrimitiveID, and SV_ViewID. They may be used later in the patch
      // constant function.
      if (hasSemantic(param, hlsl::DXIL::SemanticKind::OutputControlPointID))
        outputControlPointIdVal = loadedValue;
      else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID))
        primitiveIdVar = tempVar;
      else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID))
        viewIdVar = tempVar;
    }
  }
  // Call the original entry function.
  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
  const uint32_t retVal =
      theBuilder.createFunctionCall(retType, entryFuncId, params);

  // Create and write stage output variables for the return value. Special
  // case for hull shaders since they operate differently in two ways:
  // 1- Their return value is in fact an array and each invocation should
  //    write to the proper offset in the array.
  // 2- The patch constant function must be called *once* after all
  //    invocations of the main entry point function are done.
  if (shaderModel.IsHS()) {
    // Create stage output variables out of the return type.
    if (!declIdMapper.createStageOutputVar(decl, numOutputControlPoints,
                                           outputControlPointIdVal, retVal))
      return false;
    if (!processHSEntryPointOutputAndPCF(
            decl, retType, retVal, numOutputControlPoints,
            outputControlPointIdVal, primitiveIdVar, viewIdVar,
            hullMainInputPatchParam))
      return false;
  } else {
    if (!declIdMapper.createStageOutputVar(decl, retVal, /*forPCF*/ false))
      return false;
  }
  // Create and write stage output variables for parameters marked as
  // out/inout.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const auto *param = decl->getParamDecl(i);
    if (canActAsOutParmVar(param)) {
      // Load the value from the parameter after the function call.
      const uint32_t typeId = typeTranslator.translateType(param->getType());
      uint32_t loadedParam = 0;

      // No need to write back the value if the parameter is not used at all
      // in the original entry function.
      //
      // Write back of stage output variables in GS is manually controlled by
      // the .Append() intrinsic method. No need to load the parameter since
      // we won't need to write back here.
      if (param->isUsed() && !shaderModel.IsGS())
        loadedParam = theBuilder.createLoad(typeId, params[i]);

      if (!declIdMapper.createStageOutputVar(param, loadedParam, false))
        return false;
    }
  }

  theBuilder.createReturn();
  theBuilder.endFunction();

  // For hull shaders, there is no explicit call to the PCF in the HLSL
  // source. We should invoke a translation of the PCF manually.
  if (shaderModel.IsHS())
    doDecl(patchConstFunc);

  return true;
}

bool SPIRVEmitter::processHSEntryPointOutputAndPCF(
    const FunctionDecl *hullMainFuncDecl, uint32_t retType, uint32_t retVal,
    uint32_t numOutputControlPoints, uint32_t outputControlPointId,
    uint32_t primitiveId, uint32_t viewId, uint32_t hullMainInputPatch) {
  // This method may only be called for hull shaders.
  assert(shaderModel.IsHS());

  // For hull shaders, the real output is an array of size
  // numOutputControlPoints. The results of the main entry point should be
  // written to the correct offset in the array (based on InvocationID).
  if (!numOutputControlPoints) {
    emitError("number of output control points cannot be zero",
              hullMainFuncDecl->getLocation());
    return false;
  }
  // TODO: We should be able to handle cases where the SV_OutputControlPointID
  // is not provided.
  if (!outputControlPointId) {
    emitError(
        "SV_OutputControlPointID semantic must be provided in hull shader",
        hullMainFuncDecl->getLocation());
    return false;
  }
  if (!patchConstFunc) {
    emitError("patch constant function not defined in hull shader",
              hullMainFuncDecl->getLocation());
    return false;
  }
  uint32_t hullMainOutputPatch = 0;
  // If the patch constant function (PCF) takes the result of the Hull main
  // entry point, create a temporary function-scope variable and write the
  // results to it, so it can be passed to the PCF.
  if (patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
    const uint32_t hullMainRetType = theBuilder.getArrayType(
        retType, theBuilder.getConstantUint32(numOutputControlPoints));
    hullMainOutputPatch =
        theBuilder.addFnVar(hullMainRetType, "temp.var.hullMainRetVal");
    const auto tempLocation = theBuilder.createAccessChain(
        theBuilder.getPointerType(retType, spv::StorageClass::Function),
        hullMainOutputPatch, {outputControlPointId});
    theBuilder.createStore(tempLocation, retVal);
  }
  // Now create a barrier before calling the Patch Constant Function (PCF).
  // Flags are:
  //   Execution scope  = Workgroup  (2)
  //   Memory scope     = Invocation (4)
  //   Memory semantics = None       (0)
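  // In SPIR-V terms this comes out roughly as:
  //   OpControlBarrier %uint_2 %uint_4 %uint_0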
  const auto constZero = theBuilder.getConstantUint32(0);
  const auto constFour = theBuilder.getConstantUint32(4);
  const auto constTwo = theBuilder.getConstantUint32(2);
  theBuilder.createBarrier(constTwo, constFour, constZero);

  // The PCF should be called only once. Therefore, we check the invocationID,
  // and we only allow ID 0 to call the PCF.
  const uint32_t condition = theBuilder.createBinaryOp(
      spv::Op::OpIEqual, theBuilder.getBoolType(), outputControlPointId,
      theBuilder.getConstantUint32(0));
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);

  theBuilder.setInsertPoint(thenBB);

  // Call the PCF. Since the function is not explicitly called, we must first
  // register an ID for it.
  const uint32_t pcfId = declIdMapper.getOrRegisterFnResultId(patchConstFunc);
  const uint32_t pcfRetType =
      typeTranslator.translateType(patchConstFunc->getReturnType());

  std::vector<uint32_t> pcfParams;

  // A lambda for creating a stage input variable and its associated temporary
  // variable for function call. Also initializes the temporary variable using
  // the contents loaded from the stage input variable. Returns the <result-id>
  // of the temporary variable.
  const auto createParmVarAndInitFromStageInputVar =
      [this](const ParmVarDecl *param) {
        const uint32_t typeId = typeTranslator.translateType(param->getType());
        std::string tempVarName = "param.var." + param->getNameAsString();
        const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
        uint32_t loadedValue = 0;
        declIdMapper.createStageInputVar(param, &loadedValue, /*forPCF*/ true);
        theBuilder.createStore(tempVar, loadedValue);
        return tempVar;
      };

  for (const auto *param : patchConstFunc->parameters()) {
    // Note: According to the HLSL reference, the PCF takes an InputPatch of
    // ControlPoints as well as the PatchID (PrimitiveID). This does not
    // necessarily mean that they are present. There is also no requirement
    // for the order of parameters passed to the PCF.
    if (hlsl::IsHLSLInputPatchType(param->getType())) {
      pcfParams.push_back(hullMainInputPatch);
    } else if (hlsl::IsHLSLOutputPatchType(param->getType())) {
      pcfParams.push_back(hullMainOutputPatch);
    } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
      if (!primitiveId) {
        primitiveId = createParmVarAndInitFromStageInputVar(param);
      }
      pcfParams.push_back(primitiveId);
    } else if (hasSemantic(param, hlsl::DXIL::SemanticKind::ViewID)) {
      if (!viewId) {
        viewId = createParmVarAndInitFromStageInputVar(param);
      }
      pcfParams.push_back(viewId);
    } else {
      emitError("patch constant function parameter '%0' unknown",
                param->getLocation())
          << param->getName();
    }
  }
  const uint32_t pcfResultId =
      theBuilder.createFunctionCall(pcfRetType, pcfId, {pcfParams});
  if (!declIdMapper.createStageOutputVar(patchConstFunc, pcfResultId,
                                         /*forPCF*/ true))
    return false;

  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setInsertPoint(mergeBB);
  return true;
}

bool SPIRVEmitter::allSwitchCasesAreIntegerLiterals(const Stmt *root) {
  if (!root)
    return false;

  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  if (!caseStmt && !compoundStmt)
    return true;

  if (caseStmt) {
    const Expr *caseExpr = caseStmt->getLHS();
    return caseExpr && caseExpr->isEvaluatable(astContext);
  }

  // Recurse down if facing a compound statement.
  for (auto *st : compoundStmt->body())
    if (!allSwitchCasesAreIntegerLiterals(st))
      return false;

  return true;
}

void SPIRVEmitter::discoverAllCaseStmtInSwitchStmt(
    const Stmt *root, uint32_t *defaultBB,
    std::vector<std::pair<uint32_t, uint32_t>> *targets) {
  if (!root)
    return;

  // A switch case can only appear in DefaultStmt, CaseStmt, or
  // CompoundStmt. For the rest, we can just return.
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  if (!defaultStmt && !caseStmt && !compoundStmt)
    return;

  // Recurse down if facing a compound statement.
  if (compoundStmt) {
    for (auto *st : compoundStmt->body())
      discoverAllCaseStmtInSwitchStmt(st, defaultBB, targets);
    return;
  }

  std::string caseLabel;
  uint32_t caseValue = 0;
  if (defaultStmt) {
    // This is the default branch.
    caseLabel = "switch.default";
  } else if (caseStmt) {
    // This is a non-default case.
    // When using OpSwitch, we only allow integer literal cases. e.g.:
    //   case <literal_integer>: {...; break;}
    const Expr *caseExpr = caseStmt->getLHS();
    assert(caseExpr && caseExpr->isEvaluatable(astContext));
    auto bitWidth = astContext.getIntWidth(caseExpr->getType());
    if (bitWidth != 32)
      emitError(
          "non-32bit integer case value in switch statement unimplemented",
          caseExpr->getExprLoc());
    Expr::EvalResult evalResult;
    caseExpr->EvaluateAsRValue(evalResult, astContext);
    const int64_t value = evalResult.Val.getInt().getSExtValue();
    caseValue = static_cast<uint32_t>(value);
  8732. caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
  8733. llvm::itostr(std::abs(value));
  8734. }
  8735. const uint32_t caseBB = theBuilder.createBasicBlock(caseLabel);
  8736. theBuilder.addSuccessor(caseBB);
  8737. stmtBasicBlock[root] = caseBB;
  8738. // Add all cases to the 'targets' vector.
  8739. if (caseStmt)
  8740. targets->emplace_back(caseValue, caseBB);
  8741. // The default label is not part of the 'targets' vector that is passed
  8742. // to the OpSwitch instruction.
  8743. // If default statement was discovered, return its label via defaultBB.
  8744. if (defaultStmt)
  8745. *defaultBB = caseBB;
  8746. // Process cases nested in other cases. It happens when we have fall through
  8747. // cases. For example:
  8748. // case 1: case 2: ...; break;
  8749. // will result in the CaseSmt for case 2 nested in the one for case 1.
  8750. discoverAllCaseStmtInSwitchStmt(caseStmt ? caseStmt->getSubStmt()
  8751. : defaultStmt->getSubStmt(),
  8752. defaultBB, targets);
  8753. }
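
// Flattens the AST of a switch body into a linear sequence of statements by
// pulling the sub-statements of CaseStmts and DefaultStmts up next to their
// labels. For example (an illustrative sketch), the body of
//   switch (x) { case 1: case 2: a = 1; break; default: a = 0; }
// flattens to:
//   Case1, Case2, 'a = 1', Break, Default, 'a = 0'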
void SPIRVEmitter::flattenSwitchStmtAST(const Stmt *root,
                                        std::vector<const Stmt *> *flatSwitch) {
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);

  if (!compoundStmt) {
    flatSwitch->push_back(root);
  }

  if (compoundStmt) {
    for (const auto *st : compoundStmt->body())
      flattenSwitchStmtAST(st, flatSwitch);
  } else if (caseStmt) {
    flattenSwitchStmtAST(caseStmt->getSubStmt(), flatSwitch);
  } else if (defaultStmt) {
    flattenSwitchStmtAST(defaultStmt->getSubStmt(), flatSwitch);
  }
}
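
// Emits the body of a single CaseStmt or DefaultStmt into the basic block
// created for it by discoverAllCaseStmtInSwitchStmt. If the previous case
// fell through (its block is not yet terminated), a branch to this case's
// block is added first.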
void SPIRVEmitter::processCaseStmtOrDefaultStmt(const Stmt *stmt) {
  auto *caseStmt = dyn_cast<CaseStmt>(stmt);
  auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
  assert(caseStmt || defaultStmt);

  uint32_t caseBB = stmtBasicBlock[stmt];
  if (!theBuilder.isCurrentBasicBlockTerminated()) {
    // We are about to handle the case passed in as parameter. If the current
    // basic block is not terminated, it means the previous case is a
    // fall-through case. We need to link it to the case to be processed.
    theBuilder.createBranch(caseBB);
    theBuilder.addSuccessor(caseBB);
  }
  theBuilder.setInsertPoint(caseBB);
  doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
}
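
// Lowers a switch statement using the native SPIR-V OpSwitch instruction.
// The emitted structure looks roughly like the following sketch (concrete
// ids and labels depend on the builder):
//   OpSelectionMerge %switch_merge None
//   OpSwitch %selector %default 1 %switch_1 2 %switch_2 ...
// where %default is %switch_merge if the statement has no default case.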
void SPIRVEmitter::processSwitchStmtUsingSpirvOpSwitch(
    const SwitchStmt *switchStmt) {
  // First handle the condition variable DeclStmt if one exists.
  // For example: handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  const uint32_t selector = doExpr(switchStmt->getCond());

  // We need a merge block regardless of the number of switch cases.
  // Since OpSwitch always requires a default label, if the switch statement
  // does not have a default branch, we use the merge block as the default
  // target.
  const uint32_t mergeBB = theBuilder.createBasicBlock("switch.merge");
  theBuilder.setMergeTarget(mergeBB);
  breakStack.push(mergeBB);
  uint32_t defaultBB = mergeBB;

  // (literal, labelId) pairs to pass to the OpSwitch instruction.
  std::vector<std::pair<uint32_t, uint32_t>> targets;
  discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);

  // Create the OpSelectionMerge and OpSwitch.
  theBuilder.createSwitch(mergeBB, selector, defaultBB, targets);

  // Handle the switch body.
  doStmt(switchStmt->getBody());

  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(mergeBB);
  theBuilder.setInsertPoint(mergeBB);
  breakStack.pop();
}
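
// Lowers a switch statement without OpSwitch by rebuilding it as a chain of
// faux if statements directly on the AST. For example, the following HLSL
// (an illustrative sketch):
//   switch (x) {
//   case 1:  a = 1; break;
//   case 2:  a = 2; break;
//   default: a = 0; break;
//   }
// is processed as if it were written:
//   if (x == 1) { a = 1; }
//   else if (x == 2) { a = 2; }
//   else { a = 0; }
// Fall-through cases are handled by accumulating statements until the next
// BreakStmt is reached (see the accumulation loop below).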
void SPIRVEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
  std::vector<const Stmt *> flatSwitch;
  flattenSwitchStmtAST(switchStmt->getBody(), &flatSwitch);

  // First handle the condition variable DeclStmt if one exists.
  // For example: handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  // Figure out the indexes of CaseStmts (and the DefaultStmt if it exists) in
  // the flattened switch AST.
  // For instance, for the following flat vector:
  //   +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  //   |Case1|Stmt1|Case2|Stmt2|Break|Case3|Case4|Stmt4|Break|Default|Stmt5|
  //   +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  // the indexes are: {0, 2, 5, 6, 9}.
  std::vector<uint32_t> caseStmtLocs;
  for (uint32_t i = 0; i < flatSwitch.size(); ++i)
    if (isa<CaseStmt>(flatSwitch[i]) || isa<DefaultStmt>(flatSwitch[i]))
      caseStmtLocs.push_back(i);

  IfStmt *prevIfStmt = nullptr;
  IfStmt *rootIfStmt = nullptr;
  CompoundStmt *defaultBody = nullptr;

  // For each case, start at its index in the vector, and go forward
  // accumulating statements until a BreakStmt or the end of the vector is
  // reached.
  for (auto curCaseIndex : caseStmtLocs) {
    const Stmt *curCase = flatSwitch[curCaseIndex];

    // CompoundStmt to hold all statements for this case.
    CompoundStmt *cs = new (astContext) CompoundStmt(Stmt::EmptyShell());

    // Accumulate all non-case/default/break statements as the body for the
    // current case.
    std::vector<Stmt *> statements;
    for (unsigned i = curCaseIndex + 1;
         i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]); ++i) {
      if (!isa<CaseStmt>(flatSwitch[i]) && !isa<DefaultStmt>(flatSwitch[i]))
        statements.push_back(const_cast<Stmt *>(flatSwitch[i]));
    }
    if (!statements.empty())
      cs->setStmts(astContext, statements.data(), statements.size());

    // For non-default cases, generate the IfStmt that compares the switch
    // value to the case value.
    if (auto *caseStmt = dyn_cast<CaseStmt>(curCase)) {
      IfStmt *curIf = new (astContext) IfStmt(Stmt::EmptyShell());
      BinaryOperator *bo = new (astContext) BinaryOperator(Stmt::EmptyShell());
      bo->setLHS(const_cast<Expr *>(switchStmt->getCond()));
      bo->setRHS(const_cast<Expr *>(caseStmt->getLHS()));
      bo->setOpcode(BO_EQ);
      bo->setType(astContext.getLogicalOperationType());
      curIf->setCond(bo);
      curIf->setThen(cs);
      // No condition variable is associated with this faux if statement.
      curIf->setConditionVariable(astContext, nullptr);
      // Each if statement is the "else" of the previous if statement.
      if (prevIfStmt)
        prevIfStmt->setElse(curIf);
      else
        rootIfStmt = curIf;
      prevIfStmt = curIf;
    } else {
      // Record the DefaultStmt body as it will be used as the body of the
      // "else" block in the if-elseif-...-else pattern.
      defaultBody = cs;
    }
  }

  // If a default case exists, it is the "else" of the last if statement.
  if (prevIfStmt)
    prevIfStmt->setElse(defaultBody);

  // Since all else-if and else statements are the child nodes of the first
  // IfStmt, we only need to call doStmt for the first IfStmt.
  if (rootIfStmt)
    doStmt(rootIfStmt);
  // If there are no CaseStmts and only a DefaultStmt, there will be no if
  // statements. The switch in that case only executes the body of the
  // default case.
  else if (defaultBody)
    doStmt(defaultBody);
}
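
// Extracts the first 'targetVecSize' components from the 4-component vector
// 'fromId': a size of 1 yields a scalar via OpCompositeExtract, sizes 2 and 3
// yield smaller vectors via OpVectorShuffle, and a size of 4 returns the
// source id unchanged.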
uint32_t SPIRVEmitter::extractVecFromVec4(uint32_t fromId,
                                          uint32_t targetVecSize,
                                          uint32_t targetElemTypeId) {
  assert(targetVecSize > 0 && targetVecSize < 5);
  const uint32_t retType =
      targetVecSize == 1
          ? targetElemTypeId
          : theBuilder.getVecType(targetElemTypeId, targetVecSize);
  switch (targetVecSize) {
  case 1:
    return theBuilder.createCompositeExtract(retType, fromId, {0});
  case 2:
    return theBuilder.createVectorShuffle(retType, fromId, fromId, {0, 1});
  case 3:
    return theBuilder.createVectorShuffle(retType, fromId, fromId, {0, 1, 2});
  case 4:
    return fromId;
  default:
    llvm_unreachable("vector element count must be 1, 2, 3, or 4");
  }
}
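
// Emits a debug line instruction (OpLine) for 'loc', provided that line-level
// debug info was requested and the id of the main source file is known.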
void SPIRVEmitter::emitDebugLine(SourceLocation loc) {
  if (spirvOptions.debugInfoLine && mainSourceFileId != 0) {
    auto floc = FullSourceLoc(loc, theCompilerInstance.getSourceManager());
    theBuilder.debugLine(mainSourceFileId, floc.getSpellingLineNumber(),
                         floc.getSpellingColumnNumber());
  }
}

} // end namespace spirv
} // end namespace clang