//===------- SPIRVEmitter.cpp - SPIR-V Binary Code Emitter ------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//===----------------------------------------------------------------------===//
//
//  This file implements a SPIR-V emitter class that takes in HLSL AST and emits
//  SPIR-V binary words.
//
//===----------------------------------------------------------------------===//

#include "SPIRVEmitter.h"

#include "dxc/HlslIntrinsicOp.h"
#include "spirv-tools/optimizer.hpp"
#include "llvm/ADT/StringExtras.h"

#include "InitListHandler.h"

namespace clang {
namespace spirv {

namespace {

// Returns true if the given decl has the given semantic.
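// For example, a parameter declared as "float4 pos : SV_Position" would match
// SemanticKind::Position.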
bool hasSemantic(const DeclaratorDecl *decl,
                 hlsl::DXIL::SemanticKind semanticKind) {
  using namespace hlsl;
  for (auto *annotation : decl->getUnusualAnnotations()) {
    if (auto *semanticDecl = dyn_cast<SemanticDecl>(annotation)) {
      llvm::StringRef semanticName;
      uint32_t semanticIndex = 0;
      Semantic::DecomposeNameAndIndex(semanticDecl->SemanticName, &semanticName,
                                      &semanticIndex);
      const auto *semantic = Semantic::GetByName(semanticName);
      if (semantic->GetKind() == semanticKind)
        return true;
    }
  }
  return false;
}

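// Returns true if the given patch constant function takes an OutputPatch
// parameter.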
bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) {
  for (const auto *param : pcf->parameters())
    if (TypeTranslator::isOutputPatch(param->getType()))
      return true;
  return false;
}

// TODO: Maybe we should move these type probing functions to TypeTranslator.

/// Returns true if the two types are the same scalar or vector type.
bool isSameScalarOrVecType(QualType type1, QualType type2) {
  {
    QualType scalarType1 = {}, scalarType2 = {};
    if (TypeTranslator::isScalarType(type1, &scalarType1) &&
        TypeTranslator::isScalarType(type2, &scalarType2))
      return scalarType1.getCanonicalType() == scalarType2.getCanonicalType();
  }

  {
    QualType elemType1 = {}, elemType2 = {};
    uint32_t count1 = {}, count2 = {};
    if (TypeTranslator::isVectorType(type1, &elemType1, &count1) &&
        TypeTranslator::isVectorType(type2, &elemType2, &count2))
      return count1 == count2 &&
             elemType1.getCanonicalType() == elemType2.getCanonicalType();
  }

  return false;
}

/// Returns true if the given type is a bool or vector of bool type.
bool isBoolOrVecOfBoolType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isBooleanType();
}

/// Returns true if the given type is a signed integer or vector of signed
/// integer type.
bool isSintOrVecOfSintType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isSignedIntegerType();
}

/// Returns true if the given type is an unsigned integer or vector of unsigned
/// integer type.
bool isUintOrVecOfUintType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isUnsignedIntegerType();
}

/// Returns true if the given type is a float or vector of float type.
bool isFloatOrVecOfFloatType(QualType type) {
  QualType elemType = {};
  return (TypeTranslator::isScalarType(type, &elemType) ||
          TypeTranslator::isVectorType(type, &elemType)) &&
         elemType->isFloatingType();
}

/// Returns true if the given type is a bool or vector/matrix of bool type.
bool isBoolOrVecMatOfBoolType(QualType type) {
  return isBoolOrVecOfBoolType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isBooleanType());
}

/// Returns true if the given type is a signed integer or vector/matrix of
/// signed integer type.
bool isSintOrVecMatOfSintType(QualType type) {
  return isSintOrVecOfSintType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isSignedIntegerType());
}

/// Returns true if the given type is an unsigned integer or vector/matrix of
/// unsigned integer type.
bool isUintOrVecMatOfUintType(QualType type) {
  return isUintOrVecOfUintType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType());
}

/// Returns true if the given type is a float or vector/matrix of float type.
bool isFloatOrVecMatOfFloatType(QualType type) {
  return isFloatOrVecOfFloatType(type) ||
         (hlsl::IsHLSLMatType(type) &&
          hlsl::GetHLSLMatElementType(type)->isFloatingType());
}

bool isSpirvMatrixOp(spv::Op opcode) {
  switch (opcode) {
  case spv::Op::OpMatrixTimesMatrix:
  case spv::Op::OpMatrixTimesVector:
  case spv::Op::OpMatrixTimesScalar:
    return true;
  default:
    break;
  }
  return false;
}

/// If expr is a (RW)StructuredBuffer.Load(), returns the object and writes
/// the index expression to *index. Otherwise, returns nullptr.
// TODO: The following doesn't handle Load(int, int) yet. And it is basically a
// duplicate of doCXXMemberCallExpr.
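// For example, given "sbuf.Load(i)" where sbuf is a StructuredBuffer, this
// returns the expression for sbuf and writes the expression for i to *index.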
const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) {
  using namespace hlsl;

  if (const auto *indexing = dyn_cast<CXXMemberCallExpr>(expr)) {
    const auto *callee = indexing->getDirectCallee();
    uint32_t opcode = static_cast<uint32_t>(IntrinsicOp::Num_Intrinsics);
    llvm::StringRef group;

    if (GetIntrinsicOp(callee, opcode, group)) {
      if (static_cast<IntrinsicOp>(opcode) == IntrinsicOp::MOP_Load) {
        const auto *object = indexing->getImplicitObjectArgument();
        if (TypeTranslator::isStructuredBuffer(object->getType())) {
          *index = indexing->getArg(0);
          return indexing->getImplicitObjectArgument();
        }
      }
    }
  }

  return nullptr;
}

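/// Runs the SPIRV-Tools legalization/optimization passes over the given
/// module, in place. Exhaustive inlining is registered first so that the
/// function-local passes that follow can operate on the inlined code.
/// Returns true on success; error messages are accumulated into *messages.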
bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);

  optimizer.SetMessageConsumer(
      [messages](spv_message_level_t /*level*/, const char * /*source*/,
                 const spv_position_t & /*position*/,
                 const char *message) { *messages += message; });

  optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass());
  optimizer.RegisterPass(spvtools::CreateLocalAccessChainConvertPass());
  optimizer.RegisterPass(spvtools::CreateLocalSingleBlockLoadStoreElimPass());
  optimizer.RegisterPass(spvtools::CreateLocalSingleStoreElimPass());
  optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
  optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
  optimizer.RegisterPass(spvtools::CreateDeadBranchElimPass());
  optimizer.RegisterPass(spvtools::CreateBlockMergePass());
  optimizer.RegisterPass(spvtools::CreateLocalMultiStoreElimPass());
  optimizer.RegisterPass(spvtools::CreateInsertExtractElimPass());
  optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
  optimizer.RegisterPass(spvtools::CreateEliminateDeadFunctionsPass());
  optimizer.RegisterPass(spvtools::CreateEliminateDeadConstantPass());
  optimizer.RegisterPass(spvtools::CreateCompactIdsPass());

  return optimizer.Run(module->data(), module->size(), module);
}

/// Translates RWByteAddressBuffer atomic method opcode into SPIR-V opcode.
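/// Note that the resulting OpAtomic* instructions also take memory scope and
/// memory semantics operands; supplying those is left to the caller.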
spv::Op translateRWBABufferAtomicMethods(hlsl::IntrinsicOp opcode) {
  using namespace hlsl;
  using namespace spv;

  switch (opcode) {
  case IntrinsicOp::MOP_InterlockedAdd:
    return Op::OpAtomicIAdd;
  case IntrinsicOp::MOP_InterlockedAnd:
    return Op::OpAtomicAnd;
  case IntrinsicOp::MOP_InterlockedOr:
    return Op::OpAtomicOr;
  case IntrinsicOp::MOP_InterlockedXor:
    return Op::OpAtomicXor;
  case IntrinsicOp::MOP_InterlockedUMax:
    return Op::OpAtomicUMax;
  case IntrinsicOp::MOP_InterlockedUMin:
    return Op::OpAtomicUMin;
  case IntrinsicOp::MOP_InterlockedMax:
    return Op::OpAtomicSMax;
  case IntrinsicOp::MOP_InterlockedMin:
    return Op::OpAtomicSMin;
  case IntrinsicOp::MOP_InterlockedExchange:
    return Op::OpAtomicExchange;
  }

  assert(false && "unimplemented hlsl intrinsic opcode");
  return Op::Max;
}

} // namespace

SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
                           const EmitSPIRVOptions &options)
    : theCompilerInstance(ci), astContext(ci.getASTContext()),
      diags(ci.getDiagnostics()), spirvOptions(options),
      entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
      shaderModel(*hlsl::ShaderModel::GetByName(
          ci.getCodeGenOpts().HLSLProfile.c_str())),
      theContext(), theBuilder(&theContext),
      declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions),
      typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
      curFunction(nullptr), curThis(0), needsLegalization(false) {
  if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
    emitError("unknown shader model: %0") << shaderModel.GetName();
}

void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
  // Stop translating if there are errors in previous compilation stages.
  if (context.getDiagnostics().hasErrorOccurred())
    return;

  TranslationUnitDecl *tu = context.getTranslationUnitDecl();

  // The entry function is the seed of the queue.
  for (auto *decl : tu->decls()) {
    if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
      if (funcDecl->getName() == entryFunctionName) {
        workQueue.insert(funcDecl);
      }
      if (context.IsPatchConstantFunctionDecl(funcDecl)) {
        patchConstFunc = funcDecl;
      }
    } else if (auto *varDecl = dyn_cast<VarDecl>(decl)) {
      if (isa<HLSLBufferDecl>(varDecl->getDeclContext())) {
        // This is a VarDecl of a ConstantBuffer/TextureBuffer type.
        (void)declIdMapper.createCTBuffer(varDecl);
      } else {
        doVarDecl(varDecl);
      }
    } else if (auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
      // This is a cbuffer/tbuffer decl.
      (void)declIdMapper.createCTBuffer(bufferDecl);
    }
  }

  // Translate all functions reachable from the entry function.
  // The queue can grow in the meanwhile, so we need to keep re-evaluating
  // workQueue.size().
  for (uint32_t i = 0; i < workQueue.size(); ++i) {
    doDecl(workQueue[i]);
  }

  if (context.getDiagnostics().hasErrorOccurred())
    return;

  AddRequiredCapabilitiesForShaderModel();

  // Addressing and memory model are required in a valid SPIR-V module.
  theBuilder.setAddressingModel(spv::AddressingModel::Logical);
  theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);

  theBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunctionId,
                           entryFunctionName, declIdMapper.collectStageVars());

  AddExecutionModeForEntryPoint(entryFunctionId);

  // Add Location decorations to stage input/output variables.
  if (!declIdMapper.decorateStageIOLocations())
    return;

  // Add descriptor set and binding decorations to resource variables.
  if (!declIdMapper.decorateResourceBindings())
    return;

  // Output the constructed module.
  std::vector<uint32_t> m = theBuilder.takeModule();

  const auto optLevel = theCompilerInstance.getCodeGenOpts().OptimizationLevel;
  if (needsLegalization || optLevel > 0) {
    if (needsLegalization && optLevel == 0)
      emitWarning("-O0 ignored since SPIR-V legalization required");

    std::string messages;
    if (!spirvToolsOptimize(&m, &messages)) {
      emitFatalError("failed to legalize/optimize SPIR-V: %0") << messages;
      return;
    }
  }

  theCompilerInstance.getOutStream()->write(
      reinterpret_cast<const char *>(m.data()), m.size() * 4);
}

void SPIRVEmitter::doDecl(const Decl *decl) {
  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    doVarDecl(varDecl);
  } else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
    doFunctionDecl(funcDecl);
  } else if (dyn_cast<HLSLBufferDecl>(decl)) {
    llvm_unreachable("HLSLBufferDecl should not be handled here");
  } else {
    // TODO: Implement handling of other Decl types.
    emitWarning("Decl type '%0' is not supported yet.")
        << decl->getDeclKindName();
  }
}

void SPIRVEmitter::doStmt(const Stmt *stmt,
                          llvm::ArrayRef<const Attr *> attrs) {
  if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
    for (auto *st : compoundStmt->body())
      doStmt(st);
  } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
    doReturnStmt(retStmt);
  } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
    doDeclStmt(declStmt);
  } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
    doIfStmt(ifStmt);
  } else if (const auto *switchStmt = dyn_cast<SwitchStmt>(stmt)) {
    doSwitchStmt(switchStmt, attrs);
  } else if (const auto *caseStmt = dyn_cast<CaseStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (const auto *defaultStmt = dyn_cast<DefaultStmt>(stmt)) {
    processCaseStmtOrDefaultStmt(stmt);
  } else if (const auto *breakStmt = dyn_cast<BreakStmt>(stmt)) {
    doBreakStmt(breakStmt);
  } else if (const auto *theDoStmt = dyn_cast<DoStmt>(stmt)) {
    doDoStmt(theDoStmt, attrs);
  } else if (const auto *discardStmt = dyn_cast<DiscardStmt>(stmt)) {
    doDiscardStmt(discardStmt);
  } else if (const auto *continueStmt = dyn_cast<ContinueStmt>(stmt)) {
    doContinueStmt(continueStmt);
  } else if (const auto *whileStmt = dyn_cast<WhileStmt>(stmt)) {
    doWhileStmt(whileStmt, attrs);
  } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
    doForStmt(forStmt, attrs);
  } else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
    // For the null statement ";". We don't need to do anything.
  } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
    // All cases for expressions used as statements
    doExpr(expr);
  } else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
    doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
  } else {
    emitError("Stmt '%0' is not supported yet.") << stmt->getStmtClassName();
  }
}

SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
  if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
    return declIdMapper.getDeclResultId(declRefExpr->getFoundDecl());
  }

  if (const auto *parenExpr = dyn_cast<ParenExpr>(expr)) {
    // Just need to return what's inside the parentheses.
    return doExpr(parenExpr->getSubExpr());
  }

  if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
    return doMemberExpr(memberExpr);
  }

  if (const auto *castExpr = dyn_cast<CastExpr>(expr)) {
    return doCastExpr(castExpr);
  }

  if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
    return doInitListExpr(initListExpr);
  }

  if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
    const bool value = boolLiteral->getValue();
    return SpirvEvalInfo::withConst(theBuilder.getConstantBool(value));
  }

  if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
    return SpirvEvalInfo::withConst(
        translateAPInt(intLiteral->getValue(), expr->getType()));
  }

  if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
    return SpirvEvalInfo::withConst(
        translateAPFloat(floatLiteral->getValue(), expr->getType()));
  }

  // CompoundAssignOperator is a subclass of BinaryOperator. It should be
  // checked before BinaryOperator.
  if (const auto *compoundAssignOp = dyn_cast<CompoundAssignOperator>(expr)) {
    return doCompoundAssignOperator(compoundAssignOp);
  }

  if (const auto *binOp = dyn_cast<BinaryOperator>(expr)) {
    return doBinaryOperator(binOp);
  }

  if (const auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
    return doUnaryOperator(unaryOp);
  }

  if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    return doHLSLVectorElementExpr(vecElemExpr);
  }

  if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
    return doExtMatrixElementExpr(matElemExpr);
  }

  if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
    return doCallExpr(funcCall);
  }

  if (const auto *subscriptExpr = dyn_cast<ArraySubscriptExpr>(expr)) {
    return doArraySubscriptExpr(subscriptExpr);
  }

  if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
    return doConditionalOperator(condExpr);
  }

  if (isa<CXXThisExpr>(expr)) {
    assert(curThis);
    return curThis;
  }

  emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
  return 0;
}

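/// Evaluates the given expression and, if it is a GLValue (i.e., it denotes a
/// memory location), loads the value with OpLoad; rvalues are returned as-is.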
SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) {
  auto info = doExpr(expr);

  if (expr->isGLValue())
    info.resultId = theBuilder.createLoad(
        typeTranslator.translateType(expr->getType(), info.layoutRule),
        info.resultId);

  return info;
}

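/// Casts the given value from fromType to toType, dispatching on the target
/// type. For example, casting a float4 value to int4 goes through castToInt(),
/// which would typically lower to OpConvertFToS.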
uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
                                  QualType toType) {
  if (isFloatOrVecOfFloatType(toType))
    return castToFloat(value, fromType, toType);

  // Order matters here. Bool (vector) values will also be considered as uint
  // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType()
  // will also return true. We need to check bool before uint. The opposite is
  // not true.
  if (isBoolOrVecOfBoolType(toType))
    return castToBool(value, fromType, toType);

  if (isSintOrVecOfSintType(toType) || isUintOrVecOfUintType(toType))
    return castToInt(value, fromType, toType);

  emitError("casting to type %0 unimplemented") << toType;
  return 0;
}

void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
  // We are about to start translation for a new function. Clear the break
  // stack and the continue stack.
  breakStack = std::stack<uint32_t>();
  continueStack = std::stack<uint32_t>();

  curFunction = decl;

  std::string funcName = decl->getName();

  uint32_t funcId = 0;

  if (funcName == entryFunctionName) {
    // The entry function does not have a pre-assigned <result-id>, unlike
    // other functions, which get added to the work queue following function
    // calls.
    funcId = theContext.takeNextId();
    funcName = "src." + funcName;

    // Create wrapper for the entry function
    if (!emitEntryFunctionWrapper(decl, funcId))
      return;
  } else {
    // Non-entry functions are added to the work queue following function
    // calls. We have already assigned <result-id>s for them when translating
    // their call sites. Query it here.
    funcId = declIdMapper.getDeclResultId(decl).resultId;
  }

  if (!needsLegalization &&
      TypeTranslator::isOpaqueStructType(decl->getReturnType()))
    needsLegalization = true;

  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());

  // Construct the function signature.
  llvm::SmallVector<uint32_t, 4> paramTypes;

  bool isNonStaticMemberFn = false;
  if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
    isNonStaticMemberFn = !memberFn->isStatic();

    if (isNonStaticMemberFn) {
      // For non-static member functions, the first parameter should be the
      // object on which we are invoking this method.
      const uint32_t valueType = typeTranslator.translateType(
          memberFn->getThisType(astContext)->getPointeeType());
      const uint32_t ptrType =
          theBuilder.getPointerType(valueType, spv::StorageClass::Function);
      paramTypes.push_back(ptrType);
    }

    // Prefix the function name with the struct name
    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
      funcName = st->getName().str() + "." + funcName;
  }

  for (const auto *param : decl->params()) {
    const uint32_t valueType = typeTranslator.translateType(param->getType());
    const uint32_t ptrType =
        theBuilder.getPointerType(valueType, spv::StorageClass::Function);
    paramTypes.push_back(ptrType);

    if (!needsLegalization &&
        TypeTranslator::isOpaqueStructType(param->getType()))
      needsLegalization = true;
  }

  const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
  theBuilder.beginFunction(funcType, retType, funcName, funcId);

  if (isNonStaticMemberFn) {
    // Remember the parameter for the this object so later we can handle
    // CXXThisExpr correctly.
    curThis = theBuilder.addFnParam(paramTypes[0], "param.this");
  }

  // Create all parameters.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const ParmVarDecl *paramDecl = decl->getParamDecl(i);
    (void)declIdMapper.createFnParam(paramTypes[i + isNonStaticMemberFn],
                                     paramDecl);
  }

  if (decl->hasBody()) {
    // The entry basic block.
    const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
    theBuilder.setInsertPoint(entryLabel);

    // Process all statements in the body.
    doStmt(decl->getBody());

    // We have processed all Stmts in this function and now in the last
    // basic block. Make sure we have OpReturn if missing.
    if (!theBuilder.isCurrentBasicBlockTerminated()) {
      theBuilder.createReturn();
    }
  }

  theBuilder.endFunction();

  curFunction = nullptr;
}

void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
  uint32_t varId = 0;

  // The contents in externally visible variables can be updated via the
  // pipeline. They should be handled differently from file and function scope
  // variables.
  // File scope variables (static "global" and "local" variables) belong to
  // the Private storage class, while function scope variables (normal "local"
  // variables) belong to the Function storage class.
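  // For example:
  //   static float gGamma = 2.2;      // file scope -> Private
  //   void foo() { float t = 1.0; }   // function scope -> Function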
  if (!decl->isExternallyVisible()) {
    // Note: cannot move varType outside of this scope because it generates
    // SPIR-V types without decorations, while externally visible variables
    // should have SPIR-V types with decorations.
    const uint32_t varType = typeTranslator.translateType(decl->getType());

    // We already know the variable is not externally visible here. If it does
    // not have local storage, it should be a file scope variable.
    const bool isFileScopeVar = !decl->hasLocalStorage();

    // Handle initializer. SPIR-V requires that "initializer must be an <id>
    // from a constant instruction or a global (module scope) OpVariable
    // instruction."
    llvm::Optional<uint32_t> constInit;
    if (decl->hasInit()) {
      if (const uint32_t id = tryToEvaluateAsConst(decl->getInit()))
        constInit = llvm::Optional<uint32_t>(id);
    } else if (isFileScopeVar) {
      // For static variables, if no initializers are provided, we should
      // initialize them to zero values.
      constInit = llvm::Optional<uint32_t>(theBuilder.getConstantNull(varType));
    }

    if (isFileScopeVar)
      varId = declIdMapper.createFileVar(varType, decl, constInit);
    else
      varId = declIdMapper.createFnVar(varType, decl, constInit);

    // If we cannot evaluate the initializer as a constant expression, we'll
    // need to use OpStore to write the initializer to the variable.
    // Also we should only evaluate the initializer once for a static variable.
    if (decl->hasInit() && !constInit.hasValue()) {
      if (isFileScopeVar) {
        if (decl->isStaticLocal()) {
          initOnce(decl->getName(), varId, decl->getInit());
        } else {
          // Defer initializing these global variables to the beginning of the
          // entry function.
          toInitGloalVars.push_back(decl);
        }
      } else {
        storeValue(varId, loadIfGLValue(decl->getInit()), decl->getType());
      }
    }
  } else {
    varId = declIdMapper.createExternVar(decl);
  }

  if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
    theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
  }

  if (!needsLegalization &&
      TypeTranslator::isOpaqueStructType(decl->getType()))
    needsLegalization = true;
}

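/// Maps an HLSL loop attribute to the corresponding SPIR-V loop control mask.
/// For example, "[unroll] for (...)" yields LoopControlMask::Unroll, while
/// "[loop] for (...)" yields LoopControlMask::DontUnroll.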
spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Attr &attr) {
  switch (attr.getKind()) {
  case attr::HLSLLoop:
  case attr::HLSLFastOpt:
    return spv::LoopControlMask::DontUnroll;
  case attr::HLSLUnroll:
    return spv::LoopControlMask::Unroll;
  case attr::HLSLAllowUAVCondition:
    emitWarning("Unsupported allow_uav_condition attribute ignored.");
    break;
  default:
    emitError("Found unknown loop attribute.");
  }
  return spv::LoopControlMask::MaskNone;
}

void SPIRVEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  theBuilder.createKill();
  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them.
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt,
                            llvm::ArrayRef<const Attr *> attrs) {
  // do-while loops are composed of:
  //
  //   do {
  //     <body>
  //   } while (<check>);
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though do-while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> can be performed in the <continue> basic block.
  // The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //          +----------+
  //          |  header  | <-----------------------------------+
  //          +----------+                                     |
  //               |                                           |  (true)
  //               v                                           |
  //           +------+       +--------------------+           |
  //           | body | ----> | continue (<check>) |-----------+
  //           +------+       +--------------------+
  //                                    |
  //                                    | (false)
  //           +-------+                |
  //           | merge | <--------------+
  //           +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.

  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(*attrs.front());

  // Create basic blocks
  const uint32_t headerBB = theBuilder.createBasicBlock("do_while.header");
  const uint32_t bodyBB = theBuilder.createBasicBlock("do_while.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("do_while.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("do_while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Branch from the current insert point to the header block.
  theBuilder.createBranch(headerBB);
  theBuilder.addSuccessor(headerBB);

  // Process the <header> block
  // The header block must always branch to the body.
  theBuilder.setInsertPoint(headerBB);
  theBuilder.createBranch(bodyBB, mergeBB, continueBB, loopControl);
  theBuilder.addSuccessor(bodyBB);
  // The current basic block has an OpLoopMerge instruction. We need to set
  // its continue and merge target.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = theDoStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block. The check for whether the loop should
  // continue lies in the continue block.
  // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur
  // in a continue block of a loop, there should be no OpSelectionMerge. Only
  // an OpBranchConditional must be specified.
  theBuilder.setInsertPoint(continueBB);
  uint32_t condition = 0;
  if (const Expr *check = theDoStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, headerBB, mergeBB);
  theBuilder.addSuccessor(headerBB);
  theBuilder.addSuccessor(mergeBB);

  // Set insertion point to the <merge> block for subsequent statements
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}

void SPIRVEmitter::doContinueStmt(const ContinueStmt *continueStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  const uint32_t continueTargetBB = continueStack.top();
  theBuilder.createBranch(continueTargetBB);
  theBuilder.addSuccessor(continueTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. For example: StmtB and StmtC below are put inside a new
  // basic block which is unreachable.
  //
  //   while (true) {
  //     StmtA;
  //     continue;
  //     StmtB;
  //     StmtC;
  //   }
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}

void SPIRVEmitter::doWhileStmt(const WhileStmt *whileStmt,
                               llvm::ArrayRef<const Attr *> attrs) {
  // While loops are composed of:
  //   while (<check>) { <body> }
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks, <header> and <merge>. <header> is the block before control flow
  // diverges, and <merge> is the block where control flow subsequently
  // converges. The <check> block can take the responsibility of the <header>
  // block. The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //          +----------+
  //          |  header  | <------------------+
  //          | (check)  |                    |
  //          +----------+                    |
  //               |                          |
  //       +-------+-------+                  |
  //       | false         | true             |
  //       |               v                  |
  //       |            +------+     +------------------+
  //       |            | body | --> | continue (no-op) |
  //       v            +------+     +------------------+
  //   +-------+
  //   | merge |
  //   +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.

  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(*attrs.front());

  // Create basic blocks
  const uint32_t checkBB = theBuilder.createBasicBlock("while.check");
  const uint32_t bodyBB = theBuilder.createBasicBlock("while.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("while.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <check> block
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);
  theBuilder.setInsertPoint(checkBB);

  // If we have:
  //   while (int a = foo()) {...}
  // we should evaluate 'a' by calling 'foo()' every single time the check has
  // to occur.
  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
    doStmt(condVarDecl);

  uint32_t condition = 0;
  if (const Expr *check = whileStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, bodyBB,
                                     /*false branch*/ mergeBB,
                                     /*merge*/ mergeBB, continueBB,
                                     spv::SelectionControlMask::MaskNone,
                                     loopControl);
  theBuilder.addSuccessor(bodyBB);
  theBuilder.addSuccessor(mergeBB);
  // The current basic block has an OpLoopMerge instruction. We need to set
  // its continue and merge target.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = whileStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block. While loops do not have an explicit
  // continue block. The continue block just branches to the <check> block.
  theBuilder.setInsertPoint(continueBB);
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);

  // Set insertion point to the <merge> block for subsequent statements
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue and merge blocks.
  continueStack.pop();
  breakStack.pop();
}

void SPIRVEmitter::doForStmt(const ForStmt *forStmt,
                             llvm::ArrayRef<const Attr *> attrs) {
  // for loops are composed of:
  //   for (<init>; <check>; <continue>) <body>
  //
  // To translate a for loop, we'll need to emit all <init> statements
  // in the current basic block, and then have separate basic blocks for
  // <check>, <continue>, and <body>. Besides, since SPIR-V requires
  // structured control flow, we need two more basic blocks, <header>
  // and <merge>. <header> is the block before control flow diverges,
  // while <merge> is the block where control flow subsequently converges.
  // The <check> block can take the responsibility of the <header> block.
  // The final CFG should normally be like the following. Exceptions will
  // occur with non-local exits like loop breaks or early returns.
  //
  //             +--------+
  //             |  init  |
  //             +--------+
  //                 |
  //                 v
  //            +----------+
  //            |  header  | <---------------+
  //            |  (check) |                 |
  //            +----------+                 |
  //                 |                       |
  //         +-------+-------+               |
  //         | false         | true          |
  //         |               v               |
  //         |            +------+     +----------+
  //         |            | body | --> | continue |
  //         v            +------+     +----------+
  //     +-------+
  //     | merge |
  //     +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
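  //
  // For illustration only (a sketch, not code emitted verbatim; block names
  // are assumed): for an HLSL loop such as
  //   for (int i = 0; i < n; ++i) { ... }
  // 'int i = 0' is emitted into the current block before branching to
  // %for_check; 'i < n' plus OpLoopMerge/OpBranchConditional make up
  // %for_check; the body goes into %for_body; and '++i' goes into
  // %for_continue, which then branches back to %for_check.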
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(*attrs.front());

  // Create basic blocks
  const uint32_t checkBB = theBuilder.createBasicBlock("for.check");
  const uint32_t bodyBB = theBuilder.createBasicBlock("for.body");
  const uint32_t continueBB = theBuilder.createBasicBlock("for.continue");
  const uint32_t mergeBB = theBuilder.createBasicBlock("for.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the <init> block
  if (const Stmt *initStmt = forStmt->getInit()) {
    doStmt(initStmt);
  }
  theBuilder.createBranch(checkBB);
  theBuilder.addSuccessor(checkBB);

  // Process the <check> block
  theBuilder.setInsertPoint(checkBB);
  uint32_t condition;
  if (const Expr *check = forStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = theBuilder.getConstantBool(true);
  }
  theBuilder.createConditionalBranch(condition, bodyBB,
                                     /*false branch*/ mergeBB,
                                     /*merge*/ mergeBB, continueBB,
                                     spv::SelectionControlMask::MaskNone,
                                     loopControl);
  theBuilder.addSuccessor(bodyBB);
  theBuilder.addSuccessor(mergeBB);
  // The current basic block has the OpLoopMerge instruction. We need to set
  // its continue and merge targets.
  theBuilder.setContinueTarget(continueBB);
  theBuilder.setMergeTarget(mergeBB);

  // Process the <body> block
  theBuilder.setInsertPoint(bodyBB);
  if (const Stmt *body = forStmt->getBody()) {
    doStmt(body);
  }
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(continueBB);
  theBuilder.addSuccessor(continueBB);

  // Process the <continue> block
  theBuilder.setInsertPoint(continueBB);
  if (const Expr *cont = forStmt->getInc()) {
    doExpr(cont);
  }
  theBuilder.createBranch(checkBB); // <continue> should jump back to header
  theBuilder.addSuccessor(checkBB);

  // Set insertion point to the <merge> block for subsequent statements
  theBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}
void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
  // if statements are composed of:
  //   if (<check>) { <then> } else { <else> }
  //
  // To translate if statements, we'll need to emit the <check> expressions
  // in the current basic block, and then create separate basic blocks for
  // <then> and <else>. Additionally, we'll need a <merge> block as per
  // SPIR-V's structured control flow requirements. Depending on whether the
  // else branch exists, the final CFG should normally be like one of the
  // following. Exceptions will occur with non-local exits like loop breaks
  // or early returns.
  //
  //             +-------+                        +-------+
  //             | check |                        | check |
  //             +-------+                        +-------+
  //                 |                                |
  //         +-------+-------+                  +-----+-----+
  //         | true          | false            | true      | false
  //         v               v         or       v           |
  //     +------+         +------+           +------+       |
  //     | then |         | else |           | then |       |
  //     +------+         +------+           +------+       |
  //         |               |                   |          v
  //         |  +-------+    |                   |     +-------+
  //         +->| merge |<---+                   +---->| merge |
  //            +-------+                              +-------+
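  //
  // For illustration only (a sketch of the expected lowering, not code
  // emitted verbatim by this method; block names are assumed): an HLSL
  // statement such as
  //   if (cond) { thenPart; } else { elsePart; }
  // should roughly lower to
  //   OpSelectionMerge %if_merge None
  //   OpBranchConditional %cond %if_true %if_false
  //   %if_true = OpLabel
  //   ...
  //   OpBranch %if_merge
  //   %if_false = OpLabel
  //   ...
  //   OpBranch %if_merge
  //   %if_merge = OpLabel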
  { // Try to see if we can const-eval the condition
    bool condition = false;
    if (ifStmt->getCond()->EvaluateAsBooleanCondition(condition, astContext)) {
      if (condition) {
        doStmt(ifStmt->getThen());
      } else if (ifStmt->getElse()) {
        doStmt(ifStmt->getElse());
      }
      return;
    }
  }

  if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt())
    doDeclStmt(declStmt);

  // First emit the instruction for evaluating the condition.
  const uint32_t condition = doExpr(ifStmt->getCond());

  // Then we need to emit the instruction for the conditional branch.
  // We'll need the <label-id> for the then/else/merge block to do so.
  const bool hasElse = ifStmt->getElse() != nullptr;
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  const uint32_t elseBB =
      hasElse ? theBuilder.createBasicBlock("if.false") : mergeBB;

  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(elseBB);
  // The current basic block has the OpSelectionMerge instruction. We need
  // to record its merge target.
  theBuilder.setMergeTarget(mergeBB);

  // Handle the then branch
  theBuilder.setInsertPoint(thenBB);
  doStmt(ifStmt->getThen());
  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);

  // Handle the else branch (if it exists)
  if (hasElse) {
    theBuilder.setInsertPoint(elseBB);
    doStmt(ifStmt->getElse());
    if (!theBuilder.isCurrentBasicBlockTerminated())
      theBuilder.createBranch(mergeBB);
    theBuilder.addSuccessor(mergeBB);
  }

  // From now on, we'll emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
}
void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
  if (const auto *retVal = stmt->getRetValue()) {
    const auto retInfo = doExpr(retVal);
    const auto retType = retVal->getType();
    if (retInfo.storageClass != spv::StorageClass::Function &&
        retType->isStructureType()) {
      // We are returning some value from a non-Function storage class. Need to
      // create a temporary variable to "convert" the value to Function storage
      // class and then return.
      const uint32_t valType = typeTranslator.translateType(retType);
      const uint32_t tempVar = theBuilder.addFnVar(valType, "temp.var.ret");
      storeValue(tempVar, retInfo, retType);
      theBuilder.createReturnValue(theBuilder.createLoad(valType, tempVar));
    } else {
      theBuilder.createReturnValue(retInfo);
    }
  } else {
    theBuilder.createReturn();
  }

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. In this case, the newly created basic block will contain
  // any statement that may come after an early return.
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}
void SPIRVEmitter::doBreakStmt(const BreakStmt *breakStmt) {
  assert(!theBuilder.isCurrentBasicBlockTerminated());
  uint32_t breakTargetBB = breakStack.top();
  theBuilder.addSuccessor(breakTargetBB);
  theBuilder.createBranch(breakTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard) require creation of a new basic block to hold any statement that
  // may follow them. For example: StmtB and StmtC below are put inside a new
  // basic block which is unreachable.
  //
  //   while (true) {
  //     StmtA;
  //     break;
  //     StmtB;
  //     StmtC;
  //   }
  const uint32_t newBB = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(newBB);
}
void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
                                llvm::ArrayRef<const Attr *> attrs) {
  // Switch statements are composed of:
  //   switch (<condition variable>) {
  //     <CaseStmt>
  //     <CaseStmt>
  //     <CaseStmt>
  //     <DefaultStmt> (optional)
  //   }
  //
  //              +-------+
  //              | check |
  //              +-------+
  //                  |
  //         +--------+--------+---------------+---------------+
  //         | 1               | 2             | 3             | (others)
  //         v                 v               v               v
  //     +-------+      +-------------+     +-------+     +------------+
  //     | case1 |      |    case2    |     | case3 | ... |  default   |
  //     |       |      |(fallthrough)|---->|       |     | (optional) |
  //     +-------+      +-------------+     +-------+     +------------+
  //         |                                   |               |
  //         |                                   |               |
  //         |       +-------+                   |               |
  //         |       |       | <-----------------+               |
  //         +-----> | merge |                                   |
  //                 |       | <---------------------------------+
  //                 +-------+
  //
  // If no attributes are given, or if the 'forcecase' attribute was provided,
  // we'll do our best to use OpSwitch if possible.
  // If any of the cases compares to a variable (rather than an integer
  // literal), we cannot use OpSwitch because OpSwitch expects literal
  // numbers as parameters.
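  //
  // For illustration only (an assumed usage sketch, not from the original
  // source): in
  //   [forcecase] switch (x) {
  //     case 1: a = 1; break;
  //     case 2: a = 2; break;
  //     default: a = 0; break;
  //   }
  // every case value is an integer literal, so a single OpSwitch can be
  // emitted; if some case value were not an integer literal, we would fall
  // back to a chain of if statements instead.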
  const bool isAttrForceCase =
      !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase;
  const bool canUseSpirvOpSwitch =
      (attrs.empty() || isAttrForceCase) &&
      allSwitchCasesAreIntegerLiterals(switchStmt->getBody());

  if (isAttrForceCase && !canUseSpirvOpSwitch)
    emitWarning("Ignored 'forcecase' attribute for the switch statement "
                "since one or more case values are not integer literals.");

  if (canUseSpirvOpSwitch)
    processSwitchStmtUsingSpirvOpSwitch(switchStmt);
  else
    processSwitchStmtUsingIfStmts(switchStmt);
}
SpirvEvalInfo
SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
  llvm::SmallVector<uint32_t, 4> indices;
  const auto *base = collectArrayStructIndices(expr, &indices);
  auto info = doExpr(base);

  const uint32_t ptrType = theBuilder.getPointerType(
      typeTranslator.translateType(expr->getType(), info.layoutRule),
      info.storageClass);
  info.resultId = theBuilder.createAccessChain(ptrType, info, indices);

  return info;
}
SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Handle assignment first since we need to evaluate rhs before lhs.
  // For other binary operations, we need to evaluate lhs before rhs.
  if (opcode == BO_Assign) {
    return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()),
                             false);
  }

  // Try to optimize floatMxN * float and floatN * float case
  if (opcode == BO_Mul) {
    if (const SpirvEvalInfo result = tryToGenFloatMatrixScale(expr))
      return result;
    if (const SpirvEvalInfo result = tryToGenFloatVectorScale(expr))
      return result;
  }

  const uint32_t resultType = typeTranslator.translateType(expr->getType());
  return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode, resultType);
}
SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
  if (const auto *operatorCall = dyn_cast<CXXOperatorCallExpr>(callExpr))
    return doCXXOperatorCallExpr(operatorCall);

  if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr))
    return doCXXMemberCallExpr(memberCall);

  // Intrinsic functions such as 'dot' or 'mul'
  if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) {
    return processIntrinsicCallExpr(callExpr);
  }

  // Normal standalone functions
  return processCall(callExpr);
}
uint32_t SPIRVEmitter::processCall(const CallExpr *callExpr) {
  const FunctionDecl *callee = callExpr->getDirectCallee();

  if (callee) {
    const auto numParams = callee->getNumParams();
    bool isNonStaticMemberCall = false;

    llvm::SmallVector<uint32_t, 4> params; // Temporary variables
    llvm::SmallVector<uint32_t, 4> args;   // Evaluated arguments

    if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
      isNonStaticMemberCall =
          !cast<CXXMethodDecl>(memberCall->getCalleeDecl())->isStatic();
      if (isNonStaticMemberCall) {
        // For non-static member calls, evaluate the object and pass it as the
        // first argument.
        const auto *object = memberCall->getImplicitObjectArgument();
        args.push_back(doExpr(object));
        // We do not need to create a new temporary variable for the 'this'
        // object. Use the evaluated argument.
        params.push_back(args.back());
      }
    }

    // Evaluate parameters
    for (uint32_t i = 0; i < numParams; ++i) {
      const auto *arg = callExpr->getArg(i);
      const auto *param = callee->getParamDecl(i);

      // We need to create variables for holding the values to be used as
      // arguments. The variables themselves are of pointer types.
      const uint32_t varType = typeTranslator.translateType(arg->getType());
      const std::string varName = "param.var." + param->getNameAsString();
      const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);

      params.push_back(tempVarId);
      args.push_back(doExpr(arg));

      if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
        // The current parameter is marked as out/inout. The argument then is
        // essentially passed in by reference. We need to load the value
        // explicitly here since the AST won't inject an LValueToRValue
        // implicit cast for this case.
        const uint32_t value = theBuilder.createLoad(varType, args.back());
        theBuilder.createStore(tempVarId, value);
      } else {
        theBuilder.createStore(tempVarId, args.back());
      }
    }

    // Push the callee into the work queue if it is not there.
    if (!workQueue.count(callee)) {
      workQueue.insert(callee);
    }

    const uint32_t retType = typeTranslator.translateType(callExpr->getType());
    // Get or forward declare the function <result-id>
    const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);

    const uint32_t retVal =
        theBuilder.createFunctionCall(retType, funcId, params);

    // Go through all parameters and write back those marked as out/inout
    for (uint32_t i = 0; i < numParams; ++i) {
      const auto *param = callee->getParamDecl(i);
      if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
        const uint32_t index = i + isNonStaticMemberCall;
        const uint32_t typeId = typeTranslator.translateType(param->getType());
        const uint32_t value = theBuilder.createLoad(typeId, params[index]);
        theBuilder.createStore(args[index], value);
      }
    }

    return retVal;
  }

  emitError("calling non-function unimplemented");
  return 0;
}
SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
  const Expr *subExpr = expr->getSubExpr();
  const QualType toType = expr->getType();

  switch (expr->getCastKind()) {
  case CastKind::CK_LValueToRValue: {
    auto info = doExpr(subExpr);

    if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr) ||
        isBufferTextureIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr)) ||
        isTextureMipsSampleIndexing(dyn_cast<CXXOperatorCallExpr>(subExpr))) {
      // By reaching here, it means the vector/matrix/Buffer/RWBuffer/RWTexture
      // element accessing operation is an lvalue. For vector element
      // accessing, if we generated a vector shuffle for it and are trying to
      // use it as an rvalue, we cannot do the load here as normal. We need the
      // upper nodes in the AST tree to handle it properly. For matrix element
      // accessing, the load should have already happened after creating the
      // access chain for each element. For (RW)Buffer/RWTexture element
      // accessing, the load should have already happened using OpImageFetch.
      return info;
    }

    // Using an lvalue as an rvalue means we need to OpLoad the contents from
    // the parameter/variable first.
    info.resultId = theBuilder.createLoad(
        typeTranslator.translateType(expr->getType(), info.layoutRule), info);
    return info;
  }
  case CastKind::CK_NoOp:
    return doExpr(subExpr);
  case CastKind::CK_IntegralCast:
  case CastKind::CK_FloatingToIntegral:
  case CastKind::CK_HLSLCC_IntegralCast:
  case CastKind::CK_HLSLCC_FloatingToIntegral: {
    // Integer literals in the AST are represented using 64-bit APInt
    // themselves and then implicitly cast into the expected bitwidth.
    // We need special treatment of integer literals here because generating
    // a 64-bit constant and then explicitly casting in SPIR-V requires the
    // Int64 capability. We should avoid introducing unnecessary capabilities
    // as best we can.
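    //
    // For example (illustrative): the literal 1 in 'uint x = 1;' lives in the
    // AST as a 64-bit APInt; evaluating the cast expression directly lets us
    // emit a plain 32-bit OpConstant instead of a 64-bit constant plus a
    // conversion instruction, which would require the Int64 capability.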
    llvm::APSInt intValue;
    if (expr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects)) {
      return translateAPInt(intValue, toType);
    }

    return castToInt(doExpr(subExpr), subExpr->getType(), toType);
  }
  case CastKind::CK_FloatingCast:
  case CastKind::CK_IntegralToFloating:
  case CastKind::CK_HLSLCC_FloatingCast:
  case CastKind::CK_HLSLCC_IntegralToFloating: {
    // First try to see if we can do constant folding for floating point
    // numbers like what we are doing for integers in the above.
    Expr::EvalResult evalResult;
    if (expr->EvaluateAsRValue(evalResult, astContext) &&
        !evalResult.HasSideEffects) {
      return translateAPFloat(evalResult.Val.getFloat(), toType);
    }

    return castToFloat(doExpr(subExpr), subExpr->getType(), toType);
  }
  case CastKind::CK_IntegralToBoolean:
  case CastKind::CK_FloatingToBoolean:
  case CastKind::CK_HLSLCC_IntegralToBoolean:
  case CastKind::CK_HLSLCC_FloatingToBoolean: {
    // First try to see if we can do constant folding.
    bool boolVal;
    if (!expr->HasSideEffects(astContext) &&
        expr->EvaluateAsBooleanCondition(boolVal, astContext)) {
      return theBuilder.getConstantBool(boolVal);
    }

    return castToBool(doExpr(subExpr), subExpr->getType(), toType);
  }
  case CastKind::CK_HLSLVectorSplat: {
    const size_t size = hlsl::GetHLSLVecSize(expr->getType());
    return createVectorSplat(subExpr, size);
  }
  case CastKind::CK_HLSLVectorTruncationCast: {
    const uint32_t toVecTypeId = typeTranslator.translateType(toType);
    const uint32_t elemTypeId =
        typeTranslator.translateType(hlsl::GetHLSLVecElementType(toType));
    const auto toSize = hlsl::GetHLSLVecSize(toType);

    const uint32_t composite = doExpr(subExpr);
    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < toSize; ++i) {
      elements.push_back(
          theBuilder.createCompositeExtract(elemTypeId, composite, {i}));
    }

    if (toSize == 1) {
      return elements.front();
    }

    return theBuilder.createCompositeConstruct(toVecTypeId, elements);
  }
  case CastKind::CK_HLSLVectorToScalarCast: {
    // The underlying should already be a vector of size 1.
    assert(hlsl::GetHLSLVecSize(subExpr->getType()) == 1);
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLVectorToMatrixCast: {
    // The target type should already be a 1xN matrix type.
    assert(TypeTranslator::is1xNMatrix(toType));
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLMatrixSplat: {
    // From scalar to matrix
    uint32_t rowCount = 0, colCount = 0;
    hlsl::GetHLSLMatRowColCount(toType, rowCount, colCount);

    // Handle degenerate cases first
    if (rowCount == 1 && colCount == 1)
      return doExpr(subExpr);

    if (colCount == 1)
      return createVectorSplat(subExpr, rowCount);

    const auto vecSplat = createVectorSplat(subExpr, colCount);
    if (rowCount == 1)
      return vecSplat;

    const uint32_t matType = typeTranslator.translateType(toType);
    llvm::SmallVector<uint32_t, 4> vectors(size_t(rowCount), vecSplat);

    if (vecSplat.isConst) {
      return SpirvEvalInfo::withConst(
          theBuilder.getConstantComposite(matType, vectors));
    } else {
      return theBuilder.createCompositeConstruct(matType, vectors);
    }
  }
  case CastKind::CK_HLSLMatrixToScalarCast: {
    // The underlying should already be a 1x1 matrix.
    assert(TypeTranslator::is1x1Matrix(subExpr->getType()));
    return doExpr(subExpr);
  }
  case CastKind::CK_HLSLMatrixToVectorCast: {
    // The underlying should already be a 1xN or Mx1 matrix.
    assert(TypeTranslator::is1xNMatrix(subExpr->getType()) ||
           TypeTranslator::isMx1Matrix(subExpr->getType()));
    return doExpr(subExpr);
  }
  case CastKind::CK_FunctionToPointerDecay:
    // Just need to return the function id
    return doExpr(subExpr);
  default:
    emitError("ImplicitCast Kind '%0' is not supported yet.")
        << expr->getCastKindName();
    expr->dump();
    return 0;
  }
}
SpirvEvalInfo
SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
  const auto opcode = expr->getOpcode();

  // Try to optimize floatMxN *= float and floatN *= float case
  if (opcode == BO_MulAssign) {
    if (const SpirvEvalInfo result = tryToGenFloatMatrixScale(expr))
      return result;
    if (const SpirvEvalInfo result = tryToGenFloatVectorScale(expr))
      return result;
  }

  const auto *rhs = expr->getRHS();
  const auto *lhs = expr->getLHS();

  SpirvEvalInfo lhsPtr = 0;
  const uint32_t resultType = typeTranslator.translateType(expr->getType());
  const auto result = processBinaryOp(lhs, rhs, opcode, resultType, &lhsPtr);
  return processAssignment(lhs, result, true, lhsPtr);
}
uint32_t SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
  // According to the HLSL doc, all sides of the ?: expression are always
  // evaluated.
  const uint32_t type = typeTranslator.translateType(expr->getType());
  const uint32_t condition = doExpr(expr->getCond());
  const uint32_t trueBranch = doExpr(expr->getTrueExpr());
  const uint32_t falseBranch = doExpr(expr->getFalseExpr());

  return theBuilder.createSelect(type, condition, trueBranch, falseBranch);
}
uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
    const CXXMemberCallExpr *expr) {
  const auto *object = expr->getImplicitObjectArgument();
  const auto objectId = loadIfGLValue(object);
  const auto type = object->getType();
  const bool isByteAddressBuffer = TypeTranslator::isByteAddressBuffer(type) ||
                                   TypeTranslator::isRWByteAddressBuffer(type);
  const bool isStructuredBuffer =
      TypeTranslator::isStructuredBuffer(type) ||
      TypeTranslator::isAppendStructuredBuffer(type) ||
      TypeTranslator::isConsumeStructuredBuffer(type);
  assert(isByteAddressBuffer || isStructuredBuffer);

  // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a
  // structure with only one member that is a runtime array. We need to
  // perform OpArrayLength on member 0.
  const auto uintType = theBuilder.getUint32Type();
  uint32_t length =
      theBuilder.createBinaryOp(spv::Op::OpArrayLength, uintType, objectId, 0);

  // For (RW)ByteAddressBuffers, GetDimensions() must return the array length
  // in bytes, but OpArrayLength returns the number of uints in the runtime
  // array. Therefore we must multiply the result by 4.
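  //
  // For example (illustrative): a buffer backed by a runtime array of 16
  // uints yields OpArrayLength == 16, so GetDimensions() must report
  // 16 * 4 = 64 bytes.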
  if (isByteAddressBuffer) {
    length = theBuilder.createBinaryOp(spv::Op::OpIMul, uintType, length,
                                       theBuilder.getConstantUint32(4u));
  }
  theBuilder.createStore(doExpr(expr->getArg(0)), length);

  if (isStructuredBuffer) {
    // For (RW)StructuredBuffer, the stride of the runtime array (which is the
    // size of the struct) must also be written to the second argument.
    uint32_t size = 0, stride = 0;
    std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize(
        type, LayoutRule::GLSLStd430, /*isRowMajor*/ false, &stride);
    const auto sizeId = theBuilder.getConstantUint32(size);
    theBuilder.createStore(doExpr(expr->getArg(1)), sizeId);
  }

  return 0;
}
uint32_t SPIRVEmitter::processRWByteAddressBufferAtomicMethods(
    hlsl::IntrinsicOp opcode, const CXXMemberCallExpr *expr) {
  // The signatures of RWByteAddressBuffer atomic methods are largely:
  //   void Interlocked*(in UINT dest, in UINT value);
  //   void Interlocked*(in UINT dest, in UINT value, out UINT original_value);
  const auto *object = expr->getImplicitObjectArgument();
  // We do not need to load the object since we are using its pointers.
  const auto objectInfo = doExpr(object);

  const auto uintType = theBuilder.getUint32Type();
  const uint32_t zero = theBuilder.getConstantUint32(0);

  const uint32_t offset = doExpr(expr->getArg(0));
  // Right shift by 2 to convert the byte offset to a uint (word) offset.
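  // For example (illustrative): byte offset 8 becomes 8 >> 2 == 2, i.e. the
  // third uint element of the runtime array.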
  const uint32_t address =
      theBuilder.createBinaryOp(spv::Op::OpShiftRightLogical, uintType, offset,
                                theBuilder.getConstantUint32(2));
  const auto ptrType =
      theBuilder.getPointerType(uintType, objectInfo.storageClass);
  const uint32_t ptr =
      theBuilder.createAccessChain(ptrType, objectInfo, {zero, address});

  const uint32_t scope = theBuilder.getConstantUint32(1); // Device

  const bool isCompareExchange =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange;
  const bool isCompareStore =
      opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareStore;

  if (isCompareExchange || isCompareStore) {
    const uint32_t comparator = doExpr(expr->getArg(1));
    const uint32_t originalVal = theBuilder.createAtomicCompareExchange(
        uintType, ptr, scope, zero, zero, doExpr(expr->getArg(2)), comparator);
    if (isCompareExchange)
      theBuilder.createStore(doExpr(expr->getArg(3)), originalVal);
  } else {
    const uint32_t value = doExpr(expr->getArg(1));
    const uint32_t originalVal =
        theBuilder.createAtomicOp(translateRWBABufferAtomicMethods(opcode),
                                  uintType, ptr, scope, zero, value);
    if (expr->getNumArgs() > 2)
      theBuilder.createStore(doExpr(expr->getArg(2)), originalVal);
  }

  return 0;
}
uint32_t
SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
  theBuilder.requireCapability(spv::Capability::ImageQuery);

  const auto *object = expr->getImplicitObjectArgument();
  const auto objectId = loadIfGLValue(object);
  const auto type = object->getType();
  const auto *recType = type->getAs<RecordType>();
  assert(recType);
  const auto typeName = recType->getDecl()->getName();
  const auto numArgs = expr->getNumArgs();
  const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr;

  assert(TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) ||
         TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type));

  // For Texture1D, arguments are either:
  //   a) width
  //   b) MipLevel, width, NumLevels
  // For Texture1DArray, arguments are either:
  //   a) width, elements
  //   b) MipLevel, width, elements, NumLevels
  // For Texture2D, arguments are either:
  //   a) width, height
  //   b) MipLevel, width, height, NumLevels
  // For Texture2DArray, arguments are either:
  //   a) width, height, elements
  //   b) MipLevel, width, height, elements, NumLevels
  // For Texture3D, arguments are either:
  //   a) width, height, depth
  //   b) MipLevel, width, height, depth, NumLevels
  // For Texture2DMS, arguments are: width, height, NumSamples
  // For Texture2DMSArray, arguments are: width, height, elements, NumSamples

  if ((typeName == "Texture1D" && numArgs > 1) ||
      (typeName == "Texture2D" && numArgs > 2) ||
      (typeName == "Texture3D" && numArgs > 3) ||
      (typeName == "Texture1DArray" && numArgs > 2) ||
      (typeName == "Texture2DArray" && numArgs > 3)) {
    mipLevel = expr->getArg(0);
    numLevels = expr->getArg(numArgs - 1);
  }
  if (TypeTranslator::isTextureMS(type)) {
    numSamples = expr->getArg(numArgs - 1);
  }

  uint32_t querySize = numArgs;
  // If the numLevels arg is present, mipLevel must also be present. Neither
  // is queried via ImageQuerySize[Lod], so remove both from the query size.
  if (numLevels)
    querySize -= 2;
  // Similarly, a numSamples arg is not queried via ImageQuerySize.
  else if (numSamples)
    querySize -= 1;
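  //
  // For example (illustrative): Texture2D.GetDimensions(MipLevel, width,
  // height, NumLevels) arrives with numArgs == 4; dropping the
  // MipLevel/NumLevels pair leaves querySize == 2 (width and height).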
  const uint32_t uintId = theBuilder.getUint32Type();
  const uint32_t resultTypeId =
      querySize == 1 ? uintId : theBuilder.getVecType(uintId, querySize);

  // Only Texture types use ImageQuerySizeLod.
  // TextureMS, RWTexture, Buffers, RWBuffers use ImageQuerySize.
  uint32_t lod = 0;
  if (TypeTranslator::isTexture(type) && !numSamples) {
    if (mipLevel) {
      // For Texture types when the mipLevel argument is present.
      lod = doExpr(mipLevel);
    } else {
      // For Texture types when the mipLevel argument is omitted.
      lod = theBuilder.getConstantInt32(0);
    }
  }

  const uint32_t query =
      lod ? theBuilder.createBinaryOp(spv::Op::OpImageQuerySizeLod,
                                      resultTypeId, objectId, lod)
          : theBuilder.createUnaryOp(spv::Op::OpImageQuerySize, resultTypeId,
                                     objectId);

  if (querySize == 1) {
    const uint32_t argIndex = mipLevel ? 1 : 0;
    theBuilder.createStore(doExpr(expr->getArg(argIndex)), query);
  } else {
    for (uint32_t i = 0; i < querySize; ++i) {
      const uint32_t component =
          theBuilder.createCompositeExtract(uintId, query, {i});
      // If the first arg is the mipmap level, we must write the results
      // starting from Arg(i+1), not Arg(i).
      const uint32_t argIndex = mipLevel ? i + 1 : i;
      theBuilder.createStore(doExpr(expr->getArg(argIndex)), component);
    }
  }

  if (numLevels || numSamples) {
    const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples;
    const spv::Op opcode =
        numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples;
    const uint32_t resultType =
        typeTranslator.translateType(numLevelsSamplesArg->getType());
    const uint32_t numLevelsSamplesQuery =
        theBuilder.createUnaryOp(opcode, resultType, objectId);
    theBuilder.createStore(doExpr(numLevelsSamplesArg), numLevelsSamplesQuery);
  }

  return 0;
}
uint32_t
SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr) {
  // Possible signatures are as follows:
  //   Texture1D(Array).CalculateLevelOfDetail(SamplerState S, float x);
  //   Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy);
  //   TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz);
  //   Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz);
  // The return type is always a single float (the LOD).
  assert(expr->getNumArgs() == 2u);
  theBuilder.requireCapability(spv::Capability::ImageQuery);
  const auto *object = expr->getImplicitObjectArgument();
  const uint32_t objectId = loadIfGLValue(object);
  const uint32_t samplerState = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t sampledImageType = theBuilder.getSampledImageType(
      typeTranslator.translateType(object->getType()));
  const uint32_t sampledImage = theBuilder.createBinaryOp(
      spv::Op::OpSampledImage, sampledImageType, objectId, samplerState);

  // The result type of OpImageQueryLod must be a float2.
  const uint32_t queryResultType =
      theBuilder.getVecType(theBuilder.getFloat32Type(), 2u);
  const uint32_t query = theBuilder.createBinaryOp(
      spv::Op::OpImageQueryLod, queryResultType, sampledImage, coordinate);
  // The first component of the float2 contains the mipmap array layer.
  return theBuilder.createCompositeExtract(theBuilder.getFloat32Type(), query,
                                           {0});
}
uint32_t SPIRVEmitter::processTextureGatherRGBACmpRGBA(
    const CXXMemberCallExpr *expr, const bool isCmp, const uint32_t component) {
  // Parameters for .Gather{Red|Green|Blue|Alpha}() are one of the following
  // two sets:
  //   * SamplerState s, float2 location, int2 offset
  //   * SamplerState s, float2 location, int2 offset0, int2 offset1,
  //     int2 offset2, int2 offset3
  //
  // An additional 'out uint status' parameter can appear in both of the
  // above, which we do not support yet.
  //
  // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() are one of the
  // following two sets (note the extra compare value):
  //   * SamplerComparisonState s, float2 location, float compare_value,
  //     int2 offset
  //   * SamplerComparisonState s, float2 location, float compare_value,
  //     int2 offset0, int2 offset1, int2 offset2, int2 offset3
  //
  // An additional 'out uint status' parameter can appear in both of the
  // above, which we do not support yet.
  //
  // The return type is always a 4-component vector.
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();

  if (numArgs != 3 + isCmp && numArgs != 6 + isCmp) {
    emitError("unsupported '%0' method call with status parameter",
              expr->getExprLoc())
        << callee->getName() << expr->getSourceRange();
    return 0;
  }

  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t compareVal = isCmp ? doExpr(expr->getArg(2)) : 0;

  uint32_t constOffset = 0, varOffset = 0, constOffsets = 0;
  if (numArgs == 3 + isCmp) {
    // One offset parameter
    handleOptionalOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset);
  } else {
    // Four offset parameters
    const auto offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp));
    const auto offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp));
    const auto offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp));
    const auto offset3 = tryToEvaluateAsConst(expr->getArg(5 + isCmp));

    // Make sure we can generate the ConstOffsets image operands in SPIR-V.
    if (!offset0 || !offset1 || !offset2 || !offset3) {
      emitError("all offset parameters to '%0' method call must be constants",
                expr->getExprLoc())
          << callee->getName() << expr->getSourceRange();
      return 0;
    }
    const uint32_t v2i32 = theBuilder.getVecType(theBuilder.getInt32Type(), 2);
    const uint32_t offsetType =
        theBuilder.getArrayType(v2i32, theBuilder.getConstantUint32(4));
    constOffsets = theBuilder.getConstantComposite(
        offsetType, {offset0, offset1, offset2, offset3});
  }

  const auto retType = typeTranslator.translateType(callee->getReturnType());
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return theBuilder.createImageGather(
      retType, imageType, image, sampler, coordinate,
      theBuilder.getConstantInt32(component), compareVal, constOffset,
      varOffset, constOffsets, /*sampleNumber*/ 0);
}
uint32_t SPIRVEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) {
  // Signature:
  //
  //   float4 GatherCmp(
  //     in SamplerComparisonState s,
  //     in float2 location,
  //     in float compare_value
  //     [, in int2 offset]
  //   );
  const FunctionDecl *callee = expr->getDirectCallee();
  const auto numArgs = expr->getNumArgs();

  if (numArgs > 4) {
    emitError("unsupported '%0' method call with status parameter",
              expr->getExprLoc())
        << callee->getName() << expr->getSourceRange();
    return 0;
  }

  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t comparator = doExpr(expr->getArg(2));
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType = typeTranslator.translateType(callee->getReturnType());
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return theBuilder.createImageGather(
      retType, imageType, image, sampler, coordinate,
      /*component*/ 0, comparator, constOffset, varOffset, /*constOffsets*/ 0,
      /*sampleNumber*/ 0);
}
uint32_t SPIRVEmitter::processBufferTextureLoad(const Expr *object,
                                                const uint32_t locationId,
                                                uint32_t constOffset,
                                                uint32_t varOffset,
                                                uint32_t lod) {
  // Loading for Buffer and Texture translates to an OpImageFetch; loading for
  // RWBuffer and RWTexture translates to an OpImageRead.
  // The result type of an OpImageFetch must be a vec4 of float or int.
  const auto type = object->getType();
  assert(TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type) ||
         TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type));
  const bool doFetch =
      TypeTranslator::isBuffer(type) || TypeTranslator::isTexture(type);

  const uint32_t objectId = loadIfGLValue(object);

  // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod.
  uint32_t sampleNumber = 0;
  if (TypeTranslator::isTextureMS(type)) {
    sampleNumber = lod;
    lod = 0;
  }

  const auto sampledType = hlsl::GetHLSLResourceResultType(type);
  QualType elemType = sampledType;
  uint32_t elemCount = 1;
  uint32_t elemTypeId = 0;
  (void)TypeTranslator::isVectorType(sampledType, &elemType, &elemCount);
  if (elemType->isFloatingType()) {
    elemTypeId = theBuilder.getFloat32Type();
  } else if (elemType->isSignedIntegerType()) {
    elemTypeId = theBuilder.getInt32Type();
  } else if (elemType->isUnsignedIntegerType()) {
    elemTypeId = theBuilder.getUint32Type();
  } else {
    emitError("Unimplemented Buffer/Texture type");
    return 0;
  }

  const uint32_t resultTypeId =
      elemCount == 1 ? elemTypeId
                     : theBuilder.getVecType(elemTypeId, elemCount);

  // OpImageFetch can only fetch a vector of 4 elements. OpImageRead can load a
  // vector of any size.
  const uint32_t fetchTypeId = theBuilder.getVecType(elemTypeId, 4u);
  const uint32_t texel = theBuilder.createImageFetchOrRead(
      doFetch, doFetch ? fetchTypeId : resultTypeId, objectId, locationId, lod,
      constOffset, varOffset, /*constOffsets*/ 0, sampleNumber);

  // OpImageRead can load a vector of any size, so we can return the result of
  // the instruction directly.
  if (!doFetch) {
    return texel;
  }

  // OpImageFetch can only fetch vec4. If the result type is a vec1, vec2, or
  // vec3, some extra processing (extraction) is required.
  switch (elemCount) {
  case 1:
    return theBuilder.createCompositeExtract(elemTypeId, texel, {0});
  case 2:
    return theBuilder.createVectorShuffle(resultTypeId, texel, texel, {0, 1});
  case 3:
    return theBuilder.createVectorShuffle(resultTypeId, texel, texel,
                                          {0, 1, 2});
  case 4:
    return texel;
  }
  llvm_unreachable("Element count of a vector must be 1, 2, 3, or 4.");
}
uint32_t SPIRVEmitter::processByteAddressBufferLoadStore(
    const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) {
  uint32_t resultId = 0;
  const auto object = expr->getImplicitObjectArgument();
  const auto type = object->getType();
  const auto objectInfo = doExpr(object);
  assert(numWords >= 1 && numWords <= 4);
  if (doStore) {
    assert(typeTranslator.isRWByteAddressBuffer(type));
    assert(expr->getNumArgs() == 2);
  } else {
    assert(typeTranslator.isRWByteAddressBuffer(type) ||
           typeTranslator.isByteAddressBuffer(type));
    if (expr->getNumArgs() == 2) {
      emitError("Load(in Address, out Status) has not been implemented for "
                "(RW)ByteAddressBuffer yet.");
      return 0;
    }
  }
  const Expr *addressExpr = expr->getArg(0);
  const uint32_t byteAddress = doExpr(addressExpr);
  const uint32_t addressTypeId =
      typeTranslator.translateType(addressExpr->getType());

  // Do an OpShiftRightLogical by 2 (divide by 4 to get aligned memory
  // access). The AST always casts the address to unsigned integer, so shift
  // by unsigned integer 2.
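  //
  // For example (illustrative): Load2(8) shifts the byte address to word
  // index 8 >> 2 == 2 and reads words 2 and 3 of the runtime array,
  // returning a uint2.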
  const uint32_t constUint2 = theBuilder.getConstantUint32(2);
  const uint32_t address = theBuilder.createBinaryOp(
      spv::Op::OpShiftRightLogical, addressTypeId, byteAddress, constUint2);

  // Perform the access chain into the (RW)ByteAddressBuffer.
  // The first index must be zero (member 0 of the struct is the runtime
  // array). The second index passed to OpAccessChain should be the address.
  const uint32_t uintTypeId = theBuilder.getUint32Type();
  const uint32_t ptrType =
      theBuilder.getPointerType(uintTypeId, objectInfo.storageClass);
  const uint32_t constUint0 = theBuilder.getConstantUint32(0);

  if (doStore) {
    const uint32_t valuesId = doExpr(expr->getArg(1));
    uint32_t curStoreAddress = address;
    for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) {
      // Extract a 32-bit word from the input.
      const uint32_t curValue = numWords == 1
                                    ? valuesId
                                    : theBuilder.createCompositeExtract(
                                          uintTypeId, valuesId, {wordCounter});

      // Update the output address if necessary.
      if (wordCounter > 0) {
        const uint32_t offset = theBuilder.getConstantUint32(wordCounter);
        curStoreAddress = theBuilder.createBinaryOp(
            spv::Op::OpIAdd, addressTypeId, address, offset);
      }

      // Store the word to the right address at the output.
      const uint32_t storePtr = theBuilder.createAccessChain(
          ptrType, objectInfo, {constUint0, curStoreAddress});
      theBuilder.createStore(storePtr, curValue);
    }
  } else {
    uint32_t loadPtr = theBuilder.createAccessChain(ptrType, objectInfo,
                                                    {constUint0, address});
    resultId = theBuilder.createLoad(uintTypeId, loadPtr);
    if (numWords > 1) {
      // Load words 2, 3, and 4 where necessary. Use OpCompositeConstruct to
      // return a vector result.
      llvm::SmallVector<uint32_t, 4> values;
      values.push_back(resultId);
      for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) {
        const uint32_t offset = theBuilder.getConstantUint32(wordCounter - 1);
        const uint32_t newAddress = theBuilder.createBinaryOp(
            spv::Op::OpIAdd, addressTypeId, address, offset);
        loadPtr = theBuilder.createAccessChain(ptrType, objectInfo,
                                               {constUint0, newAddress});
        values.push_back(theBuilder.createLoad(uintTypeId, loadPtr));
      }
      const uint32_t resultType =
          theBuilder.getVecType(addressTypeId, numWords);
      resultId = theBuilder.createCompositeConstruct(resultType, values);
    }
  }
  return resultId;
}
SpirvEvalInfo
SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
  if (expr->getNumArgs() == 2) {
    emitError("Load(int, int) unimplemented for (RW)StructuredBuffer");
    return 0;
  }

  const auto *buffer = expr->getImplicitObjectArgument();
  auto info = doExpr(buffer);

  const QualType structType =
      hlsl::GetHLSLResourceResultType(buffer->getType());
  const uint32_t ptrType = theBuilder.getPointerType(
      typeTranslator.translateType(structType, info.layoutRule),
      info.storageClass);

  const uint32_t zero = theBuilder.getConstantInt32(0);
  const uint32_t index = doExpr(expr->getArg(0));

  info.resultId = theBuilder.createAccessChain(ptrType, info, {zero, index});
  return info;
}

SpirvEvalInfo
SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
  const bool isAppend = expr->getNumArgs() == 1;

  const uint32_t u32Type = theBuilder.getUint32Type();
  const uint32_t one = theBuilder.getConstantUint32(1);  // As scope: Device
  const uint32_t zero = theBuilder.getConstantUint32(0); // As memory sema: None

  const auto *object = expr->getImplicitObjectArgument();
  const auto *buffer = cast<DeclRefExpr>(object)->getDecl();

  // Calculate the index we should use for appending the value
  const uint32_t counterVar = declIdMapper.getCounterId(cast<VarDecl>(buffer));
  const uint32_t counterPtrType = theBuilder.getPointerType(
      theBuilder.getInt32Type(), spv::StorageClass::Uniform);
  const uint32_t counterPtr =
      theBuilder.createAccessChain(counterPtrType, counterVar, {zero});

  uint32_t index = 0;
  if (isAppend) {
    // For append, we add one to the counter.
    index = theBuilder.createAtomicOp(spv::Op::OpAtomicIAdd, u32Type,
                                      counterPtr, one, zero, one);
  } else {
    // For consume, we subtract one from the counter. Note that OpAtomicISub
    // returns the value before the subtraction, so we need to do the
    // subtraction again on OpAtomicISub's return value.
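    //
    // For example (illustrative): if the counter currently holds 5,
    // OpAtomicISub returns 5 and leaves the counter at 4; the element index
    // to consume is then 5 - 1 == 4.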
    const auto prevIndex = theBuilder.createAtomicOp(
        spv::Op::OpAtomicISub, u32Type, counterPtr, one, zero, one);
    index = theBuilder.createBinaryOp(spv::Op::OpISub, u32Type, prevIndex, one);
  }

  auto bufferInfo = declIdMapper.getDeclResultId(buffer);

  const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());
  const uint32_t bufferElemType =
      typeTranslator.translateType(bufferElemTy, bufferInfo.layoutRule);
  // Get the pointer inside the {Append|Consume}StructuredBuffer
  const uint32_t bufferElemPtrType =
      theBuilder.getPointerType(bufferElemType, bufferInfo.storageClass);
  const uint32_t bufferElemPtr = theBuilder.createAccessChain(
      bufferElemPtrType, bufferInfo.resultId, {zero, index});

  if (isAppend) {
    // Write out the value
    bufferInfo.resultId = bufferElemPtr;
    storeValue(bufferInfo, doExpr(expr->getArg(0)), bufferElemTy);
    return 0;
  } else {
    // If the element type is not a structure type, the return value of
    // .Consume() is not labelled as an xvalue, so the OpLoad that would
    // normally be generated is missing. Load directly here instead.
    if (bufferElemTy->isStructureType())
      bufferInfo.resultId = bufferElemPtr;
    else
      bufferInfo.resultId =
          theBuilder.createLoad(bufferElemType, bufferElemPtr);

    return bufferInfo;
  }
}
SpirvEvalInfo SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
  const FunctionDecl *callee = expr->getDirectCallee();

  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  if (hlsl::GetIntrinsicOp(callee, opcode, group)) {
    return processIntrinsicMemberCall(expr,
                                      static_cast<hlsl::IntrinsicOp>(opcode));
  }

  return processCall(expr);
}
void SPIRVEmitter::handleOptionalOffsetInMethodCall(
    const CXXMemberCallExpr *expr, uint32_t index, uint32_t *constOffset,
    uint32_t *varOffset) {
  *constOffset = *varOffset = 0; // Initialize both first

  if (expr->getNumArgs() == index + 1) { // Has offset argument
    // The extra parentheses mark the assignment-in-condition as intentional.
    if ((*constOffset = tryToEvaluateAsConst(expr->getArg(index))))
      return; // Constant offset
    else
      *varOffset = doExpr(expr->getArg(index));
  }
}
SpirvEvalInfo
SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
                                         hlsl::IntrinsicOp opcode) {
  using namespace hlsl;

  switch (opcode) {
  case IntrinsicOp::MOP_Sample:
    return processTextureSampleGather(expr, /*isSample=*/true);
  case IntrinsicOp::MOP_Gather:
    return processTextureSampleGather(expr, /*isSample=*/false);
  case IntrinsicOp::MOP_SampleBias:
    return processTextureSampleBiasLevel(expr, /*isBias=*/true);
  case IntrinsicOp::MOP_SampleLevel:
    return processTextureSampleBiasLevel(expr, /*isBias=*/false);
  case IntrinsicOp::MOP_SampleGrad:
    return processTextureSampleGrad(expr);
  case IntrinsicOp::MOP_SampleCmp:
    return processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/true);
  case IntrinsicOp::MOP_SampleCmpLevelZero:
    return processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/false);
  case IntrinsicOp::MOP_GatherRed:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 0);
  case IntrinsicOp::MOP_GatherGreen:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 1);
  case IntrinsicOp::MOP_GatherBlue:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 2);
  case IntrinsicOp::MOP_GatherAlpha:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 3);
  case IntrinsicOp::MOP_GatherCmp:
    return processTextureGatherCmp(expr);
  case IntrinsicOp::MOP_GatherCmpRed:
    return processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/true, 0);
  case IntrinsicOp::MOP_Load:
    return processBufferTextureLoad(expr);
  case IntrinsicOp::MOP_Load2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false);
  case IntrinsicOp::MOP_Load3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false);
  case IntrinsicOp::MOP_Load4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false);
  case IntrinsicOp::MOP_Store:
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true);
  case IntrinsicOp::MOP_Store2:
    return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true);
  case IntrinsicOp::MOP_Store3:
    return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true);
  case IntrinsicOp::MOP_Store4:
    return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true);
  case IntrinsicOp::MOP_GetDimensions:
    return processGetDimensions(expr);
  case IntrinsicOp::MOP_CalculateLevelOfDetail:
    return processTextureLevelOfDetail(expr);
  case IntrinsicOp::MOP_Append:
  case IntrinsicOp::MOP_Consume:
    return processACSBufferAppendConsume(expr);
  case IntrinsicOp::MOP_InterlockedAdd:
  case IntrinsicOp::MOP_InterlockedAnd:
  case IntrinsicOp::MOP_InterlockedOr:
  case IntrinsicOp::MOP_InterlockedXor:
  case IntrinsicOp::MOP_InterlockedUMax:
  case IntrinsicOp::MOP_InterlockedUMin:
  case IntrinsicOp::MOP_InterlockedMax:
  case IntrinsicOp::MOP_InterlockedMin:
  case IntrinsicOp::MOP_InterlockedExchange:
  case IntrinsicOp::MOP_InterlockedCompareExchange:
  case IntrinsicOp::MOP_InterlockedCompareStore:
    return processRWByteAddressBufferAtomicMethods(opcode, expr);
  }

  emitError("HLSL intrinsic member call unimplemented: %0")
      << expr->getDirectCallee()->getName();
  return 0;
}
uint32_t SPIRVEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr,
                                                  const bool isSample) {
  // Signatures:
  //   DXGI_FORMAT Object.Sample(sampler_state S,
  //                             float Location
  //                             [, int Offset]);
  //
  //   <Template Type>4 Object.Gather(sampler_state S,
  //                                  float2|3|4 Location
  //                                  [, int2 Offset]);
  const auto *imageExpr = expr->getImplicitObjectArgument();

  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());

  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  // .Sample()/.Gather() has a third optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 2, &constOffset, &varOffset);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  if (isSample) {
    return theBuilder.createImageSample(
        retType, imageType, image, sampler, coordinate, /*compareVal*/ 0,
        /*bias*/ 0, /*lod*/ 0, std::make_pair(0, 0), constOffset, varOffset,
        /*constOffsets*/ 0, /*sampleNumber*/ 0);
  } else {
    return theBuilder.createImageGather(
        retType, imageType, image, sampler, coordinate,
        // .Gather() doc says we return four components of red data.
        theBuilder.getConstantInt32(0), /*compareVal*/ 0, constOffset,
        varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0);
  }
}
uint32_t
SPIRVEmitter::processTextureSampleBiasLevel(const CXXMemberCallExpr *expr,
                                            const bool isBias) {
  // Signatures:
  // DXGI_FORMAT Object.SampleBias(sampler_state S,
  //                               float Location,
  //                               float Bias
  //                               [, int Offset]);
  //
  // DXGI_FORMAT Object.SampleLevel(sampler_state S,
  //                                float Location,
  //                                float LOD
  //                                [, int Offset]);
  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());
  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  uint32_t lod = 0;
  uint32_t bias = 0;
  if (isBias) {
    bias = doExpr(expr->getArg(2));
  } else {
    lod = doExpr(expr->getArg(2));
  }
  // .SampleBias()/.SampleLevel() has a fourth optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 3, &constOffset, &varOffset);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());

  return theBuilder.createImageSample(
      retType, imageType, image, sampler, coordinate, /*compareVal*/ 0, bias,
      lod, std::make_pair(0, 0), constOffset, varOffset, /*constOffsets*/ 0,
      /*sampleNumber*/ 0);
}
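
// Note on processTextureSampleBiasLevel() above. Illustrative lowering:
// t.SampleLevel(s, uv, 3.0) carries the LOD as an explicit image operand and
// therefore becomes OpImageSampleExplicitLod, whereas t.SampleBias(s, uv, 0.5)
// stays an OpImageSampleImplicitLod carrying a Bias image operand.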
uint32_t SPIRVEmitter::processTextureSampleGrad(const CXXMemberCallExpr *expr) {
  // Signature:
  // DXGI_FORMAT Object.SampleGrad(sampler_state S,
  //                               float Location,
  //                               float DDX,
  //                               float DDY
  //                               [, int Offset]);
  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t imageType = typeTranslator.translateType(imageExpr->getType());
  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t ddx = doExpr(expr->getArg(2));
  const uint32_t ddy = doExpr(expr->getArg(3));
  // .SampleGrad() has a fifth optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 4, &constOffset, &varOffset);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  return theBuilder.createImageSample(
      retType, imageType, image, sampler, coordinate, /*compareVal*/ 0,
      /*bias*/ 0, /*lod*/ 0, std::make_pair(ddx, ddy), constOffset, varOffset,
      /*constOffsets*/ 0,
      /*sampleNumber*/ 0);
}
uint32_t
SPIRVEmitter::processTextureSampleCmpCmpLevelZero(const CXXMemberCallExpr *expr,
                                                  const bool isCmp) {
  // .SampleCmp() Signature:
  //
  // float Object.SampleCmp(
  //   SamplerComparisonState S,
  //   float Location,
  //   float CompareValue,
  //   [int Offset]
  // );
  //
  // .SampleCmpLevelZero() is identical to .SampleCmp(), except that it
  // samples on mipmap level 0 only.
  const auto *imageExpr = expr->getImplicitObjectArgument();
  const uint32_t image = loadIfGLValue(imageExpr);
  const uint32_t sampler = doExpr(expr->getArg(0));
  const uint32_t coordinate = doExpr(expr->getArg(1));
  const uint32_t compareVal = doExpr(expr->getArg(2));
  // .SampleCmp() has a fourth optional parameter for offset.
  uint32_t constOffset = 0, varOffset = 0;
  handleOptionalOffsetInMethodCall(expr, 3, &constOffset, &varOffset);
  const uint32_t lod = isCmp ? 0 : theBuilder.getConstantFloat32(0);

  const auto retType =
      typeTranslator.translateType(expr->getDirectCallee()->getReturnType());
  const auto imageType = typeTranslator.translateType(imageExpr->getType());

  return theBuilder.createImageSample(
      retType, imageType, image, sampler, coordinate, compareVal, /*bias*/ 0,
      lod, std::make_pair(0, 0), constOffset, varOffset,
      /*constOffsets*/ 0, /*sampleNumber*/ 0);
}
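
// Note on processTextureSampleCmpCmpLevelZero() above. Illustrative lowering:
// with a comparison sampler, t.SampleCmp(sc, uv, ref) becomes
// OpImageSampleDrefImplicitLod, and t.SampleCmpLevelZero(sc, uv, ref) becomes
// OpImageSampleDrefExplicitLod with the Lod operand set to float constant 0.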
SpirvEvalInfo
SPIRVEmitter::processBufferTextureLoad(const CXXMemberCallExpr *expr) {
  // Signature:
  // ret Object.Load(int Location
  //                 [, int SampleIndex]
  //                 [, int Offset]);
  const auto *object = expr->getImplicitObjectArgument();
  const auto *location = expr->getArg(0);
  const auto objectType = object->getType();

  if (typeTranslator.isRWByteAddressBuffer(objectType) ||
      typeTranslator.isByteAddressBuffer(objectType))
    return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ false);

  if (TypeTranslator::isStructuredBuffer(objectType))
    return processStructuredBufferLoad(expr);

  if (TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType) ||
      TypeTranslator::isRWTexture(objectType))
    return processBufferTextureLoad(object, doExpr(location));

  if (TypeTranslator::isTexture(objectType)) {
    // .Load() has a second optional parameter for offset.
    const auto locationId = doExpr(location);
    uint32_t constOffset = 0, varOffset = 0;
    uint32_t coordinate = locationId, lod = 0;

    if (TypeTranslator::isTextureMS(objectType)) {
      // SampleIndex is only available when the Object is of Texture2DMS or
      // Texture2DMSArray types. In those cases, Offset will be the third
      // parameter (index 2).
      lod = doExpr(expr->getArg(1));
      handleOptionalOffsetInMethodCall(expr, 2, &constOffset, &varOffset);
    } else {
      // For Texture Load() functions, the location parameter is a vector
      // that consists of both the coordinate and the mipmap level (via the
      // last vector element). We need to split it here since the
      // OpImageFetch SPIR-V instruction encodes them as separate arguments.
      splitVecLastElement(location->getType(), locationId, &coordinate, &lod);
      // For textures other than Texture2DMS(Array), offset should be the
      // second parameter (index 1).
      handleOptionalOffsetInMethodCall(expr, 1, &constOffset, &varOffset);
    }

    return processBufferTextureLoad(object, coordinate, constOffset, varOffset,
                                    lod);
  }
  emitError("Load() is not implemented for the given object type.");
  return 0;
}
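
// Note on processBufferTextureLoad() above. Illustrative example:
//   Texture2D t;
//   float4 v = t.Load(int3(x, y, mip));
// The int3 location is split into the int2 coordinate (x, y) plus the mip
// level, which become the coordinate and Lod operands of OpImageFetch.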
uint32_t SPIRVEmitter::processGetDimensions(const CXXMemberCallExpr *expr) {
  const auto objectType = expr->getImplicitObjectArgument()->getType();
  if (TypeTranslator::isTexture(objectType) ||
      TypeTranslator::isRWTexture(objectType) ||
      TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType)) {
    return processBufferTextureGetDimensions(expr);
  } else if (TypeTranslator::isByteAddressBuffer(objectType) ||
             TypeTranslator::isRWByteAddressBuffer(objectType) ||
             TypeTranslator::isStructuredBuffer(objectType) ||
             TypeTranslator::isAppendStructuredBuffer(objectType) ||
             TypeTranslator::isConsumeStructuredBuffer(objectType)) {
    return processByteAddressBufferStructuredBufferGetDimensions(expr);
  } else {
    emitError("GetDimensions not implemented for the given type yet.");
    return 0;
  }
}
SpirvEvalInfo
SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
  { // Handle Buffer/RWBuffer/Texture/RWTexture indexing
    const Expr *baseExpr = nullptr;
    const Expr *indexExpr = nullptr;
    const Expr *lodExpr = nullptr;

    // For Textures, regular indexing (operator[]) uses slice 0.
    if (isBufferTextureIndexing(expr, &baseExpr, &indexExpr)) {
      const uint32_t lod = TypeTranslator::isTexture(baseExpr->getType())
                               ? theBuilder.getConstantUint32(0)
                               : 0;
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ 0, /*varOffset*/ 0, lod);
    }

    // .mips[][] or .sample[][] must use the correct slice.
    if (isTextureMipsSampleIndexing(expr, &baseExpr, &indexExpr, &lodExpr)) {
      const uint32_t lod = doExpr(lodExpr);
      return processBufferTextureLoad(baseExpr, doExpr(indexExpr),
                                      /*constOffset*/ 0, /*varOffset*/ 0, lod);
    }
  }

  llvm::SmallVector<uint32_t, 4> indices;
  const Expr *baseExpr = collectArrayStructIndices(expr, &indices);
  auto base = doExpr(baseExpr);

  if (indices.empty())
    return base; // For indexing into size-1 vectors and 1xN matrices

  // If we are indexing into an rvalue, to use OpAccessChain, we first need
  // to create a local variable to hold the rvalue.
  //
  // TODO: We can optimize the codegen by emitting OpCompositeExtract if
  // all indices are constant integers.
  if (!baseExpr->isGLValue()) {
    const uint32_t baseType = typeTranslator.translateType(baseExpr->getType());
    const uint32_t tempVar = theBuilder.addFnVar(baseType, "temp.var");
    theBuilder.createStore(tempVar, base);
    base = tempVar;
  }

  const uint32_t ptrType = theBuilder.getPointerType(
      typeTranslator.translateType(expr->getType(), base.layoutRule),
      base.storageClass);
  base.resultId = theBuilder.createAccessChain(ptrType, base, indices);

  return base;
}
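
// Note on doCXXOperatorCallExpr() above. Illustrative example: for an lvalue
// base, the collected indices feed a single OpAccessChain; for an rvalue base
// such as (m1 + m2)[0], the composite is first spilled into the "temp.var"
// Function-storage variable so that OpAccessChain has a pointer to work on.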
SpirvEvalInfo
SPIRVEmitter::doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
  const Expr *baseExpr = expr->getBase();
  const auto baseInfo = doExpr(baseExpr);
  const auto accessor = expr->getEncodedElementAccess();
  const uint32_t elemType = typeTranslator.translateType(
      hlsl::GetHLSLMatElementType(baseExpr->getType()));

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);

  // Construct a temporary vector out of all elements accessed:
  // 1. Create access chain for each element using OpAccessChain
  // 2. Load each element using OpLoad
  // 3. Create the vector using OpCompositeConstruct
  llvm::SmallVector<uint32_t, 4> elements;
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0, elem = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, we are indexing into a vector;
    // only one index is needed for such cases.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    if (baseExpr->isGLValue()) {
      for (uint32_t j = 0; j < indices.size(); ++j)
        indices[j] = theBuilder.getConstantInt32(indices[j]);
      const uint32_t ptrType =
          theBuilder.getPointerType(elemType, baseInfo.storageClass);
      if (!indices.empty()) {
        // Load the element via access chain
        elem = theBuilder.createAccessChain(ptrType, baseInfo, indices);
      } else {
        // The matrix is of size 1x1. No need to use an access chain; base
        // should be the source pointer.
        elem = baseInfo;
      }
      elem = theBuilder.createLoad(elemType, elem);
    } else { // e.g., (mat1 + mat2)._m11
      elem = theBuilder.createCompositeExtract(elemType, baseInfo, indices);
    }
    elements.push_back(elem);
  }

  if (elements.size() == 1)
    return elements.front();

  const uint32_t vecType = theBuilder.getVecType(elemType, elements.size());
  return theBuilder.createCompositeConstruct(vecType, elements);
}
SpirvEvalInfo
SPIRVEmitter::doHLSLVectorElementExpr(const HLSLVectorElementExpr *expr) {
  const Expr *baseExpr = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(expr, &baseExpr, &accessor);

  const QualType baseType = baseExpr->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);

  const uint32_t type = typeTranslator.translateType(expr->getType());
  const auto accessorSize = accessor.Count;

  // Depending on the number of elements selected, we emit different
  // instructions.
  // For vectors of size greater than 1: if we are only selecting one element,
  // a typical access chain or composite extraction is fine, but if we are
  // selecting more than one element, we must resort to vector-specific
  // operations.
  // For size-1 vectors: if we are selecting the single element multiple
  // times, we need composite construct instructions.
  if (accessorSize == 1) {
    if (baseSize == 1) {
      // Selecting one element from a size-1 vector. The underlying vector is
      // already treated as a scalar.
      return doExpr(baseExpr);
    }
    // If the base is an lvalue, we should emit an access chain instruction
    // so that we can load/store the specified element. For an rvalue base,
    // we should use composite extraction. We should check the immediate base
    // instead of the original base here since we can have something like
    // v.xyyz to turn an lvalue v into an rvalue.
    if (expr->getBase()->isGLValue()) { // E.g., v.x;
      const auto baseInfo = doExpr(baseExpr);
      const uint32_t ptrType =
          theBuilder.getPointerType(type, baseInfo.storageClass);
      const uint32_t index = theBuilder.getConstantInt32(accessor.Swz0);
      // We need an lvalue here. Do not try to load.
      return theBuilder.createAccessChain(ptrType, baseInfo, {index});
    } else { // E.g., (v + w).x;
      // The original base vector may not be an rvalue. We need to load it if
      // it is an lvalue since the ImplicitCastExpr (LValueToRValue) will be
      // missing for that case.
      return theBuilder.createCompositeExtract(type, loadIfGLValue(baseExpr),
                                               {accessor.Swz0});
    }
  }

  if (baseSize == 1) {
    // Selecting the single element of a size-1 vector multiple times.
    // Construct the result vector by replicating that element.
    llvm::SmallVector<uint32_t, 4> components(static_cast<size_t>(accessorSize),
                                              loadIfGLValue(baseExpr));
    return theBuilder.createCompositeConstruct(type, components);
  }

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(accessorSize);
  // Whether we are selecting elements in the original order
  bool originalOrder = baseSize == accessorSize;
  for (uint32_t i = 0; i < accessorSize; ++i) {
    accessor.GetPosition(i, &selectors[i]);
    // We can select more elements than the vector provides. This handles
    // that case too.
    originalOrder &= selectors[i] == i;
  }

  if (originalOrder)
    return doExpr(baseExpr);

  const uint32_t baseVal = loadIfGLValue(baseExpr);
  // Use base for both vectors. But we are only selecting values from the
  // first one.
  return theBuilder.createVectorShuffle(type, baseVal, baseVal, selectors);
}
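
// Note on doHLSLVectorElementExpr() above. Illustrative examples for a
// float4 lvalue v:
//   v.x    -> OpAccessChain into element 0 (an lvalue);
//   v.xz   -> OpVectorShuffle %v2float %v %v 0 2;
//   v.xyzw -> no instruction at all, since the selection keeps the original
//             order and size.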
SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
  if (const uint32_t id = tryToEvaluateAsConst(expr))
    return id;

  return InitListHandler(*this).process(expr);
}

SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
  llvm::SmallVector<uint32_t, 4> indices;
  const Expr *base = collectArrayStructIndices(expr, &indices);

  auto info = doExpr(base);
  const uint32_t ptrType = theBuilder.getPointerType(
      typeTranslator.translateType(expr->getType(), info.layoutRule),
      info.storageClass);
  info.resultId = theBuilder.createAccessChain(ptrType, info, indices);

  return info;
}
SpirvEvalInfo SPIRVEmitter::doUnaryOperator(const UnaryOperator *expr) {
  const auto opcode = expr->getOpcode();
  const auto *subExpr = expr->getSubExpr();
  const auto subType = subExpr->getType();
  auto subValue = doExpr(subExpr);
  const auto subTypeId = typeTranslator.translateType(subType);

  switch (opcode) {
  case UO_PreInc:
  case UO_PreDec:
  case UO_PostInc:
  case UO_PostDec: {
    const bool isPre = opcode == UO_PreInc || opcode == UO_PreDec;
    const bool isInc = opcode == UO_PreInc || opcode == UO_PostInc;

    const spv::Op spvOp = translateOp(isInc ? BO_Add : BO_Sub, subType);
    const uint32_t originValue = theBuilder.createLoad(subTypeId, subValue);
    const uint32_t one = hlsl::IsHLSLMatType(subType)
                             ? getMatElemValueOne(subType)
                             : getValueOne(subType);
    uint32_t incValue = 0;
    if (TypeTranslator::isSpirvAcceptableMatrixType(subType)) {
      // For matrices, we have to increment/decrement each component vector
      // individually.
      const auto actOnEachVec = [this, spvOp, one](
          uint32_t /*index*/, uint32_t vecType, uint32_t lhsVec) {
        return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, one);
      };
      incValue = processEachVectorInMatrix(subExpr, originValue, actOnEachVec);
    } else {
      incValue = theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
    }
    theBuilder.createStore(subValue, incValue);

    // Prefix increment/decrement operators return an lvalue, while postfix
    // increment/decrement operators return an rvalue.
    return isPre ? subValue : originValue;
  }
  case UO_Not:
    return theBuilder.createUnaryOp(spv::Op::OpNot, subTypeId, subValue);
  case UO_LNot:
    // Parsing will do the necessary casting to make sure we are applying the
    // ! operator on boolean values.
    return theBuilder.createUnaryOp(spv::Op::OpLogicalNot, subTypeId, subValue);
  case UO_Plus:
    // No need to do anything for the prefix + operator.
    return subValue;
  case UO_Minus: {
    // SPIR-V has two opcodes for negating values: OpSNegate and OpFNegate.
    const spv::Op spvOp = isFloatOrVecOfFloatType(subType) ? spv::Op::OpFNegate
                                                           : spv::Op::OpSNegate;
    return theBuilder.createUnaryOp(spvOp, subTypeId, subValue);
  }
  default:
    break;
  }

  emitError("unary operator '%0' unimplemented yet")
      << expr->getOpcodeStr(opcode);
  expr->dump();
  return 0;
}
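
// Note on doUnaryOperator() above. Illustrative example: for `int i; i++;`
// the inc/dec case loads i, adds the integer constant 1 with OpIAdd, stores
// the sum back, and yields the previously loaded value (an rvalue) as the
// expression result; ++i yields the variable pointer (an lvalue) instead.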
spv::Op SPIRVEmitter::translateOp(BinaryOperator::Opcode op, QualType type) {
  const bool isSintType = isSintOrVecMatOfSintType(type);
  const bool isUintType = isUintOrVecMatOfUintType(type);
  const bool isFloatType = isFloatOrVecMatOfFloatType(type);

#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                     \
  case BO_##kind : {                                                          \
    if (isSintType || isUintType) {                                           \
      return spv::Op::Op##intBinOp;                                           \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  } break

#define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp)   \
  case BO_##kind : {                                                          \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
    if (isFloatType) {                                                        \
      return spv::Op::Op##floatBinOp;                                         \
    }                                                                         \
  } break

#define BIN_OP_CASE_SINT_UINT(kind, sintBinOp, uintBinOp)                     \
  case BO_##kind : {                                                          \
    if (isSintType) {                                                         \
      return spv::Op::Op##sintBinOp;                                          \
    }                                                                         \
    if (isUintType) {                                                         \
      return spv::Op::Op##uintBinOp;                                          \
    }                                                                         \
  } break

  switch (op) {
  case BO_EQ: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalEqual;
    if (isSintType || isUintType)
      return spv::Op::OpIEqual;
    if (isFloatType)
      return spv::Op::OpFOrdEqual;
  } break;
  case BO_NE: {
    if (isBoolOrVecMatOfBoolType(type))
      return spv::Op::OpLogicalNotEqual;
    if (isSintType || isUintType)
      return spv::Op::OpINotEqual;
    if (isFloatType)
      return spv::Op::OpFOrdNotEqual;
  } break;
  // According to the HLSL doc, all sides of the && and || expression are
  // always evaluated.
  case BO_LAnd:
    return spv::Op::OpLogicalAnd;
  case BO_LOr:
    return spv::Op::OpLogicalOr;
  BIN_OP_CASE_INT_FLOAT(Add, IAdd, FAdd);
  BIN_OP_CASE_INT_FLOAT(AddAssign, IAdd, FAdd);
  BIN_OP_CASE_INT_FLOAT(Sub, ISub, FSub);
  BIN_OP_CASE_INT_FLOAT(SubAssign, ISub, FSub);
  BIN_OP_CASE_INT_FLOAT(Mul, IMul, FMul);
  BIN_OP_CASE_INT_FLOAT(MulAssign, IMul, FMul);
  BIN_OP_CASE_SINT_UINT_FLOAT(Div, SDiv, UDiv, FDiv);
  BIN_OP_CASE_SINT_UINT_FLOAT(DivAssign, SDiv, UDiv, FDiv);
  // According to the HLSL spec, "the modulus operator returns the remainder
  // of a division." "The % operator is defined only in cases where either
  // both sides are positive or both sides are negative."
  //
  // In SPIR-V, there are two remainder operations: Op*Rem and Op*Mod. With
  // the former, the sign of a non-0 result comes from Operand 1, while
  // with the latter, from Operand 2.
  //
  // For operands with different signs, technically we can map % to either
  // Op*Rem or Op*Mod since it's undefined behavior. But it is more
  // consistent with C (HLSL starts as a C derivative) and Clang frontend
  // const expression evaluation if we map % to Op*Rem.
  //
  // Note there is no OpURem in SPIR-V.
  BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
  BIN_OP_CASE_SINT_UINT_FLOAT(RemAssign, SRem, UMod, FRem);
  BIN_OP_CASE_SINT_UINT_FLOAT(LT, SLessThan, ULessThan, FOrdLessThan);
  BIN_OP_CASE_SINT_UINT_FLOAT(LE, SLessThanEqual, ULessThanEqual,
                              FOrdLessThanEqual);
  BIN_OP_CASE_SINT_UINT_FLOAT(GT, SGreaterThan, UGreaterThan,
                              FOrdGreaterThan);
  BIN_OP_CASE_SINT_UINT_FLOAT(GE, SGreaterThanEqual, UGreaterThanEqual,
                              FOrdGreaterThanEqual);
  BIN_OP_CASE_SINT_UINT(And, BitwiseAnd, BitwiseAnd);
  BIN_OP_CASE_SINT_UINT(AndAssign, BitwiseAnd, BitwiseAnd);
  BIN_OP_CASE_SINT_UINT(Or, BitwiseOr, BitwiseOr);
  BIN_OP_CASE_SINT_UINT(OrAssign, BitwiseOr, BitwiseOr);
  BIN_OP_CASE_SINT_UINT(Xor, BitwiseXor, BitwiseXor);
  BIN_OP_CASE_SINT_UINT(XorAssign, BitwiseXor, BitwiseXor);
  BIN_OP_CASE_SINT_UINT(Shl, ShiftLeftLogical, ShiftLeftLogical);
  BIN_OP_CASE_SINT_UINT(ShlAssign, ShiftLeftLogical, ShiftLeftLogical);
  BIN_OP_CASE_SINT_UINT(Shr, ShiftRightArithmetic, ShiftRightLogical);
  BIN_OP_CASE_SINT_UINT(ShrAssign, ShiftRightArithmetic, ShiftRightLogical);
  default:
    break;
  }

#undef BIN_OP_CASE_INT_FLOAT
#undef BIN_OP_CASE_SINT_UINT_FLOAT
#undef BIN_OP_CASE_SINT_UINT

  emitError("translating binary operator '%0' unimplemented")
      << BinaryOperator::getOpcodeStr(op);
  return spv::Op::OpNop;
}
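
// Note on translateOp() above. Illustrative mappings:
//   int  % int      -> OpSRem (the sign follows the dividend, matching C);
//   int  >> n       -> OpShiftRightArithmetic, but uint >> n
//                      -> OpShiftRightLogical;
//   float == float  -> OpFOrdEqual, i.e. a NaN operand compares false.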
SpirvEvalInfo SPIRVEmitter::processAssignment(const Expr *lhs,
                                              const SpirvEvalInfo &rhs,
                                              const bool isCompoundAssignment,
                                              SpirvEvalInfo lhsPtr) {
  // Assigning to vector swizzling should be handled differently.
  if (const SpirvEvalInfo result = tryToAssignToVectorElements(lhs, rhs))
    return result;

  // Assigning to matrix swizzling should be handled differently.
  if (const SpirvEvalInfo result = tryToAssignToMatrixElements(lhs, rhs))
    return result;

  // Assigning to a RWBuffer/RWTexture should be handled differently.
  if (const SpirvEvalInfo result = tryToAssignToRWBufferRWTexture(lhs, rhs))
    return result;

  // Normal assignment procedure
  if (!lhsPtr.resultId)
    lhsPtr = doExpr(lhs);

  storeValue(lhsPtr, rhs, lhs->getType());

  // Plain assignment returns an rvalue, while compound assignment returns
  // an lvalue.
  return isCompoundAssignment ? lhsPtr : rhs;
}
void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
                              const SpirvEvalInfo &rhsVal,
                              const QualType valType) {
  // If lhs and rhs have the same memory layout, we should be safe to load
  // from rhs and directly store into lhs, avoiding decomposing rhs.
  // TODO: is this optimization always correct?
  if (lhsPtr.layoutRule == rhsVal.layoutRule ||
      typeTranslator.isScalarType(valType) ||
      typeTranslator.isVectorType(valType) ||
      typeTranslator.isMxNMatrix(valType)) {
    theBuilder.createStore(lhsPtr, rhsVal);
  } else if (const auto *recordType = valType->getAs<RecordType>()) {
    uint32_t index = 0;
    for (const auto *decl : recordType->getDecl()->decls()) {
      // Ignore implicitly generated struct declarations/constructors/
      // destructors.
      if (decl->isImplicit())
        continue;

      const auto *field = cast<FieldDecl>(decl);
      assert(field);

      const auto subRhsValType =
          typeTranslator.translateType(field->getType(), rhsVal.layoutRule);
      const auto subRhsVal =
          theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
      const auto subLhsPtrType = theBuilder.getPointerType(
          typeTranslator.translateType(field->getType(), lhsPtr.layoutRule),
          lhsPtr.storageClass);
      const auto subLhsPtr = theBuilder.createAccessChain(
          subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});

      storeValue(lhsPtr.substResultId(subLhsPtr),
                 rhsVal.substResultId(subRhsVal), field->getType());
      ++index;
    }
  } else if (const auto *arrayType =
                 astContext.getAsConstantArrayType(valType)) {
    const auto elemType = arrayType->getElementType();
    // TODO: handle extra large array size?
    const auto size =
        static_cast<uint32_t>(arrayType->getSize().getZExtValue());

    for (uint32_t i = 0; i < size; ++i) {
      const auto subRhsValType =
          typeTranslator.translateType(elemType, rhsVal.layoutRule);
      const auto subRhsVal =
          theBuilder.createCompositeExtract(subRhsValType, rhsVal, {i});
      const auto subLhsPtrType = theBuilder.getPointerType(
          typeTranslator.translateType(elemType, lhsPtr.layoutRule),
          lhsPtr.storageClass);
      const auto subLhsPtr = theBuilder.createAccessChain(
          subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(i)});

      storeValue(lhsPtr.substResultId(subLhsPtr),
                 rhsVal.substResultId(subRhsVal), elemType);
    }
  } else {
    emitError("storing value of type %0 unimplemented") << valType;
  }
}
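
// Note on storeValue() above. Illustrative example: assigning a struct read
// from a buffer with an explicit layout rule to a Function-storage local
// cannot be a single OpStore, because the two SPIR-V types may differ; the
// recursion instead extracts each field with OpCompositeExtract and stores
// it through a per-field OpAccessChain, descending into nested structs and
// arrays the same way.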
SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                                            const BinaryOperatorKind opcode,
                                            const uint32_t resultType,
                                            SpirvEvalInfo *lhsInfo,
                                            const spv::Op mandateGenOpcode) {
  // If the operands are of matrix type, we need to dispatch the operation
  // onto each element vector iff the operands are not degenerate matrices
  // and we don't have a matrix-specific SPIR-V instruction for the operation.
  if (!isSpirvMatrixOp(mandateGenOpcode) &&
      TypeTranslator::isSpirvAcceptableMatrixType(lhs->getType())) {
    return processMatrixBinaryOp(lhs, rhs, opcode);
  }

  // The comma operator works differently from other binary operations as
  // there is no SPIR-V instruction for it. For each comma, we must evaluate
  // both lhs and rhs, and return the result of rhs.
  if (opcode == BO_Comma) {
    (void)doExpr(lhs);
    return doExpr(rhs);
  }

  const spv::Op spvOp = (mandateGenOpcode == spv::Op::Max)
                            ? translateOp(opcode, lhs->getType())
                            : mandateGenOpcode;

  SpirvEvalInfo rhsVal = 0, lhsPtr = 0, lhsVal = 0;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = doExpr(rhs);
    lhsVal = lhsPtr = doExpr(lhs);
    // This is a compound assignment. We need to load the lhs value if lhs
    // does not generate a vector shuffle.
    if (!isVectorShuffle(lhs)) {
      const uint32_t lhsTy = typeTranslator.translateType(lhs->getType());
      lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
    }
  } else {
    // Evaluate lhs before rhs
    lhsVal = lhsPtr = doExpr(lhs);
    rhsVal = doExpr(rhs);
  }

  if (lhsInfo)
    *lhsInfo = lhsPtr;

  switch (opcode) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_LT:
  case BO_LE:
  case BO_GT:
  case BO_GE:
  case BO_EQ:
  case BO_NE:
  case BO_And:
  case BO_Or:
  case BO_Xor:
  case BO_Shl:
  case BO_Shr:
  case BO_LAnd:
  case BO_LOr:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign:
  case BO_AndAssign:
  case BO_OrAssign:
  case BO_XorAssign:
  case BO_ShlAssign:
  case BO_ShrAssign: {
    const auto result =
        theBuilder.createBinaryOp(spvOp, resultType, lhsVal, rhsVal);
    return lhsVal.isRelaxedPrecision || rhsVal.isRelaxedPrecision
               ? SpirvEvalInfo::withRelaxedPrecision(result)
               : result;
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
  default:
    break;
  }

  emitError("BinaryOperator '%0' is not supported yet.")
      << BinaryOperator::getOpcodeStr(opcode);
  return 0;
}
void SPIRVEmitter::initOnce(std::string varName, uint32_t varPtr,
                            const Expr *varInit) {
  const uint32_t boolType = theBuilder.getBoolType();
  varName = "init.done." + varName;

  // Create a file/module visible variable to hold the initialization state.
  const uint32_t initDoneVar =
      theBuilder.addModuleVar(boolType, spv::StorageClass::Private, varName,
                              theBuilder.getConstantBool(false));

  const uint32_t condition = theBuilder.createLoad(boolType, initDoneVar);

  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");

  // If initialization has already been done, skip directly to the merge
  // block; otherwise run the initialization block.
  theBuilder.createConditionalBranch(condition, mergeBB, thenBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);

  theBuilder.setInsertPoint(thenBB);
  // Do initialization and mark done
  theBuilder.createStore(varPtr, doExpr(varInit));
  theBuilder.createStore(initDoneVar, theBuilder.getConstantBool(true));
  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);

  theBuilder.setInsertPoint(mergeBB);
}
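
// Note on initOnce() above. The emitted pattern is roughly (illustrative;
// result ids and the "foo" name are hypothetical):
//   %done = OpLoad %bool %init.done.foo
//           OpSelectionMerge %if.merge None
//           OpBranchConditional %done %if.merge %if.true
//   %if.true:  <store the initializer into the variable>
//              OpStore %init.done.foo %true
//              OpBranch %if.merge
//   %if.merge: ...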
bool SPIRVEmitter::isVectorShuffle(const Expr *expr) {
  // TODO: the following check is essentially duplicated from
  // doHLSLVectorElementExpr. Should unify them.
  if (const auto *vecElemExpr = dyn_cast<HLSLVectorElementExpr>(expr)) {
    const Expr *base = nullptr;
    hlsl::VectorMemberAccessPositions accessor;
    condenseVectorElementExpr(vecElemExpr, &base, &accessor);

    const auto accessorSize = accessor.Count;
    if (accessorSize == 1) {
      // Selecting only one element. We use OpAccessChain or
      // OpCompositeExtract for such cases.
      return false;
    }

    const auto baseSize = hlsl::GetHLSLVecSize(base->getType());
    if (accessorSize != baseSize)
      return true;

    for (uint32_t i = 0; i < accessorSize; ++i) {
      uint32_t position;
      accessor.GetPosition(i, &position);
      if (position != i)
        return true;
    }

    // Selecting exactly the original vector. No vector shuffle generated.
    return false;
  }

  return false;
}
bool SPIRVEmitter::isTextureMipsSampleIndexing(const CXXOperatorCallExpr *expr,
                                               const Expr **base,
                                               const Expr **location,
                                               const Expr **lod) {
  if (!expr)
    return false;

  // <object>.mips[][] consists of an outer operator[] and an inner operator[]
  const CXXOperatorCallExpr *outerExpr = expr;
  if (outerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *arg0 = outerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const CXXOperatorCallExpr *innerExpr = dyn_cast<CXXOperatorCallExpr>(arg0);

  // Must have an inner operator[]
  if (!innerExpr ||
      innerExpr->getOperator() != OverloadedOperatorKind::OO_Subscript) {
    return false;
  }

  const Expr *innerArg0 =
      innerExpr->getArg(0)->IgnoreParenNoopCasts(astContext);
  const MemberExpr *memberExpr = dyn_cast<MemberExpr>(innerArg0);
  if (!memberExpr)
    return false;

  // Must be accessing the member named "mips" or "sample"
  const auto &memberName =
      memberExpr->getMemberNameInfo().getName().getAsString();
  if (memberName != "mips" && memberName != "sample")
    return false;

  const Expr *object = memberExpr->getBase();
  const auto objectType = object->getType();
  if (!TypeTranslator::isTexture(objectType))
    return false;

  if (base)
    *base = object;
  if (lod)
    *lod = innerExpr->getArg(1);
  if (location)
    *location = outerExpr->getArg(1);
  return true;
}
bool SPIRVEmitter::isBufferTextureIndexing(const CXXOperatorCallExpr *indexExpr,
                                           const Expr **base,
                                           const Expr **index) {
  if (!indexExpr)
    return false;

  // Must be operator[]
  if (indexExpr->getOperator() != OverloadedOperatorKind::OO_Subscript)
    return false;

  const Expr *object = indexExpr->getArg(0);
  const auto objectType = object->getType();
  if (TypeTranslator::isBuffer(objectType) ||
      TypeTranslator::isRWBuffer(objectType) ||
      TypeTranslator::isTexture(objectType) ||
      TypeTranslator::isRWTexture(objectType)) {
    if (base)
      *base = object;
    if (index)
      *index = indexExpr->getArg(1);
    return true;
  }
  return false;
}
void SPIRVEmitter::condenseVectorElementExpr(
    const HLSLVectorElementExpr *expr, const Expr **basePtr,
    hlsl::VectorMemberAccessPositions *flattenedAccessor) {
  llvm::SmallVector<hlsl::VectorMemberAccessPositions, 2> accessors;
  accessors.push_back(expr->getEncodedElementAccess());

  // Recursively descend until we find the true base vector, collecting
  // the accessors in reverse order along the way.
  *basePtr = expr->getBase();
  while (const auto *vecElemBase = dyn_cast<HLSLVectorElementExpr>(*basePtr)) {
    accessors.push_back(vecElemBase->getEncodedElementAccess());
    *basePtr = vecElemBase->getBase();
  }

  *flattenedAccessor = accessors.back();
  for (int32_t i = accessors.size() - 2; i >= 0; --i) {
    const auto &currentAccessor = accessors[i];

    // Apply the current level of accessor to the flattened accessor of all
    // previous levels.
    hlsl::VectorMemberAccessPositions combinedAccessor;
    for (uint32_t j = 0; j < currentAccessor.Count; ++j) {
      uint32_t currentPosition = 0;
      currentAccessor.GetPosition(j, &currentPosition);
      uint32_t previousPosition = 0;
      flattenedAccessor->GetPosition(currentPosition, &previousPosition);
      combinedAccessor.SetPosition(j, previousPosition);
    }
    combinedAccessor.Count = currentAccessor.Count;
    combinedAccessor.IsValid =
        flattenedAccessor->IsValid && currentAccessor.IsValid;

    *flattenedAccessor = combinedAccessor;
  }
}
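
// Note on condenseVectorElementExpr() above. Illustrative example: for
// v.wzyx.yx the two accessors (3,2,1,0) and (1,0) compose into the single
// accessor (2,3): element 0 of the result is element 1 of (3,2,1,0), i.e.
// original element 2, and likewise for element 1. So one swizzle covers the
// whole chain instead of one per level.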
SpirvEvalInfo SPIRVEmitter::createVectorSplat(const Expr *scalarExpr,
                                              uint32_t size) {
  bool isConstVal = false;
  uint32_t scalarVal = 0;

  // Try to evaluate the element as a constant first. If successful, we can
  // generate constant instructions for this vector splat.
  if ((scalarVal = tryToEvaluateAsConst(scalarExpr))) {
    isConstVal = true;
  } else {
    scalarVal = doExpr(scalarExpr);
  }

  // Just return the scalar value for vector splats of size 1
  if (size == 1)
    return isConstVal ? SpirvEvalInfo::withConst(scalarVal) : scalarVal;

  const uint32_t vecType = theBuilder.getVecType(
      typeTranslator.translateType(scalarExpr->getType()), size);
  llvm::SmallVector<uint32_t, 4> elements(size_t(size), scalarVal);

  if (isConstVal) {
    // TODO: we are saying the constant has Function storage class here.
    // Should find a more meaningful one.
    return SpirvEvalInfo::withConst(
        theBuilder.getConstantComposite(vecType, elements));
  } else {
    return theBuilder.createCompositeConstruct(vecType, elements);
  }
}
void SPIRVEmitter::splitVecLastElement(QualType vecType, uint32_t vec,
                                       uint32_t *residual,
                                       uint32_t *lastElement) {
  assert(hlsl::IsHLSLVecType(vecType));

  const uint32_t count = hlsl::GetHLSLVecSize(vecType);
  assert(count > 1);
  const uint32_t elemTypeId =
      typeTranslator.translateType(hlsl::GetHLSLVecElementType(vecType));

  if (count == 2) {
    *residual = theBuilder.createCompositeExtract(elemTypeId, vec, {0});
  } else {
    llvm::SmallVector<uint32_t, 4> indices;
    for (uint32_t i = 0; i < count - 1; ++i)
      indices.push_back(i);

    const uint32_t typeId = theBuilder.getVecType(elemTypeId, count - 1);
    *residual = theBuilder.createVectorShuffle(typeId, vec, vec, indices);
  }

  *lastElement =
      theBuilder.createCompositeExtract(elemTypeId, vec, {count - 1});
}
SpirvEvalInfo
SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  // We can only translate floatN * float into OpVectorTimesScalar.
  // So the result type must be floatN.
  if (!hlsl::IsHLSLVecType(type) ||
      !hlsl::GetHLSLVecElementType(type)->isFloatingType())
    return 0;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();

  // Multiplying a float vector with a float scalar will be represented in
  // the AST via a binary operation with two float vectors as operands; one
  // of the operands is from an implicit cast with kind CK_HLSLVectorSplat.

  // vector * scalar
  if (hlsl::IsHLSLVecType(lhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const uint32_t vecType = typeTranslator.translateType(expr->getType());
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvEvalInfo lhsPtr = 0;
          const auto result =
              processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                              vecType, &lhsPtr, spv::Op::OpVectorTimesScalar);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 vecType, nullptr,
                                 spv::Op::OpVectorTimesScalar);
        }
      }
    }
  }

  // scalar * vector
  if (hlsl::IsHLSLVecType(rhs->getType())) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLVectorSplat) {
        const uint32_t vecType = typeTranslator.translateType(expr->getType());
        // We need to switch the positions of lhs and rhs here because
        // OpVectorTimesScalar requires the first operand to be a vector and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               vecType, nullptr, spv::Op::OpVectorTimesScalar);
      }
    }
  }

  return 0;
}
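
// Note on tryToGenFloatVectorScale() above. Illustrative example: for
//   float4 v; float s;  v * s
// the AST shows float4 * float4, with the scalar splatted via
// CK_HLSLVectorSplat. Peeling the splat off lets us emit one
// OpVectorTimesScalar instead of constructing the splat vector and then
// emitting OpFMul.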
SpirvEvalInfo
SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
  const QualType type = expr->getType();
  // We can only translate floatMxN * float into OpMatrixTimesScalar.
  // So the result type must be floatMxN.
  if (!hlsl::IsHLSLMatType(type) ||
      !hlsl::GetHLSLMatElementType(type)->isFloatingType())
    return 0;

  const Expr *lhs = expr->getLHS();
  const Expr *rhs = expr->getRHS();
  const QualType lhsType = lhs->getType();
  const QualType rhsType = rhs->getType();

  const auto selectOpcode = [](const QualType ty) {
    return TypeTranslator::isMx1Matrix(ty) || TypeTranslator::is1xNMatrix(ty)
               ? spv::Op::OpVectorTimesScalar
               : spv::Op::OpMatrixTimesScalar;
  };

  // Multiplying a float matrix with a float scalar will be represented in
  // the AST via a binary operation with two float matrices as operands; one
  // of the operands is from an implicit cast with kind CK_HLSLMatrixSplat.

  // matrix * scalar
  if (hlsl::IsHLSLMatType(lhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const uint32_t matType = typeTranslator.translateType(expr->getType());
        const spv::Op opcode = selectOpcode(lhsType);
        if (isa<CompoundAssignOperator>(expr)) {
          SpirvEvalInfo lhsPtr = 0;
          const auto result =
              processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                              matType, &lhsPtr, opcode);
          return processAssignment(lhs, result, true, lhsPtr);
        } else {
          return processBinaryOp(lhs, cast->getSubExpr(), expr->getOpcode(),
                                 matType, nullptr, opcode);
        }
      }
    }
  }

  // scalar * matrix
  if (hlsl::IsHLSLMatType(rhsType)) {
    if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
      if (cast->getCastKind() == CK_HLSLMatrixSplat) {
        const uint32_t matType = typeTranslator.translateType(expr->getType());
        const spv::Op opcode = selectOpcode(rhsType);
        // We need to switch the positions of lhs and rhs here because
        // OpMatrixTimesScalar requires the first operand to be a matrix and
        // the second to be a scalar.
        return processBinaryOp(rhs, cast->getSubExpr(), expr->getOpcode(),
                               matType, nullptr, opcode);
      }
    }
  }

  return 0;
}
SpirvEvalInfo
SPIRVEmitter::tryToAssignToVectorElements(const Expr *lhs,
                                          const SpirvEvalInfo &rhs) {
  // Assigning to a vector swizzling lhs is tricky if we are neither
  // writing to one element nor all elements in their original order.
  // In such cases, we need to create a new vector swizzling involving
  // both the lhs and rhs vectors and then write the result of this swizzling
  // into the base vector of lhs.
  // For example, for vec4.yz = vec2, we need to do the following:
  //
  //   %vec4Val = OpLoad %v4float %vec4
  //   %vec2Val = OpLoad %v2float %vec2
  //   %shuffle = OpVectorShuffle %v4float %vec4Val %vec2Val 0 4 5 3
  //   OpStore %vec4 %shuffle
  //
  // When doing the vector shuffle, we use the lhs base vector as the first
  // vector and the rhs vector as the second vector. Therefore, all elements
  // in the second vector will be selected into the shuffle result.
  const auto *lhsExpr = dyn_cast<HLSLVectorElementExpr>(lhs);

  if (!lhsExpr)
    return 0;

  if (!isVectorShuffle(lhs)) {
    // No vector shuffle needs to be generated for this assignment.
    // Fall back to the normal handling of assignment.
    return 0;
  }

  const Expr *base = nullptr;
  hlsl::VectorMemberAccessPositions accessor;
  condenseVectorElementExpr(lhsExpr, &base, &accessor);

  const QualType baseType = base->getType();
  assert(hlsl::IsHLSLVecType(baseType));
  const auto baseSize = hlsl::GetHLSLVecSize(baseType);

  llvm::SmallVector<uint32_t, 4> selectors;
  selectors.resize(baseSize);
  // Assume we are selecting all original elements first.
  for (uint32_t i = 0; i < baseSize; ++i) {
    selectors[i] = i;
  }

  // Now fix up the elements that actually got overwritten by the rhs vector.
  // Since we are using the rhs vector as the second vector, their indices
  // should be offset by the size of the lhs base vector.
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t position;
    accessor.GetPosition(i, &position);
    selectors[position] = baseSize + i;
  }

  const uint32_t baseTypeId = typeTranslator.translateType(baseType);
  const uint32_t vec1 = doExpr(base);
  const uint32_t vec1Val = theBuilder.createLoad(baseTypeId, vec1);
  const uint32_t shuffle =
      theBuilder.createVectorShuffle(baseTypeId, vec1Val, rhs, selectors);

  theBuilder.createStore(vec1, shuffle);

  // TODO: OK, this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}
SpirvEvalInfo
SPIRVEmitter::tryToAssignToRWBufferRWTexture(const Expr *lhs,
                                             const SpirvEvalInfo &rhs) {
  const Expr *baseExpr = nullptr;
  const Expr *indexExpr = nullptr;
  const auto lhsExpr = dyn_cast<CXXOperatorCallExpr>(lhs);
  if (isBufferTextureIndexing(lhsExpr, &baseExpr, &indexExpr)) {
    const uint32_t locId = doExpr(indexExpr);
    const uint32_t imageId = theBuilder.createLoad(
        typeTranslator.translateType(baseExpr->getType()), doExpr(baseExpr));
    theBuilder.createImageWrite(imageId, locId, rhs);
    return rhs;
  }
  return 0;
}
SpirvEvalInfo
SPIRVEmitter::tryToAssignToMatrixElements(const Expr *lhs,
                                          const SpirvEvalInfo &rhs) {
  const auto *lhsExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
  if (!lhsExpr)
    return 0;

  const Expr *baseMat = lhsExpr->getBase();
  const auto &base = doExpr(baseMat);
  const QualType elemType = hlsl::GetHLSLMatElementType(baseMat->getType());
  const uint32_t elemTypeId = typeTranslator.translateType(elemType);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(baseMat->getType(), rowCount, colCount);

  // For each lhs element written to:
  // 1. Extract the corresponding rhs element using OpCompositeExtract
  // 2. Create access chain for the lhs element using OpAccessChain
  // 3. Write using OpStore
  const auto accessor = lhsExpr->getEncodedElementAccess();
  for (uint32_t i = 0; i < accessor.Count; ++i) {
    uint32_t row = 0, col = 0;
    accessor.GetPosition(i, &row, &col);

    llvm::SmallVector<uint32_t, 2> indices;
    // If the matrix only has one row/column, we are indexing into a vector;
    // only one index is needed for such cases.
    if (rowCount > 1)
      indices.push_back(row);
    if (colCount > 1)
      indices.push_back(col);

    for (uint32_t j = 0; j < indices.size(); ++j)
      indices[j] = theBuilder.getConstantInt32(indices[j]);

    // If we are writing to only one element, the rhs should already be a
    // scalar value.
    uint32_t rhsElem = rhs;
    if (accessor.Count > 1)
      rhsElem = theBuilder.createCompositeExtract(elemTypeId, rhs, {i});

    const uint32_t ptrType =
        theBuilder.getPointerType(elemTypeId, base.storageClass);

    // If the lhs is actually a matrix of size 1x1, we don't need the access
    // chain; base is already the dest pointer.
    uint32_t lhsElemPtr = base;
    if (!indices.empty()) {
      // Get the pointer to the element via access chain
      lhsElemPtr = theBuilder.createAccessChain(ptrType, lhsElemPtr, indices);
    }

    theBuilder.createStore(lhsElemPtr, rhsElem);
  }

  // TODO: OK, this return value is incorrect for compound assignments, for
  // which cases we should return lvalues. Should at least emit errors if
  // this return value is used (can be checked via ASTContext.getParents).
  return rhs;
}
uint32_t SPIRVEmitter::processEachVectorInMatrix(
    const Expr *matrix, const uint32_t matrixVal,
    llvm::function_ref<uint32_t(uint32_t, uint32_t, uint32_t)>
        actOnEachVector) {
  const auto matType = matrix->getType();
  assert(TypeTranslator::isSpirvAcceptableMatrixType(matType));
  const uint32_t vecType = typeTranslator.getComponentVectorType(matType);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(matType, rowCount, colCount);

  llvm::SmallVector<uint32_t, 4> vectors;
  // Extract each component vector and do the operation on it
  for (uint32_t i = 0; i < rowCount; ++i) {
    const uint32_t lhsVec =
        theBuilder.createCompositeExtract(vecType, matrixVal, {i});
    vectors.push_back(actOnEachVector(i, vecType, lhsVec));
  }

  // Construct the result matrix
  return theBuilder.createCompositeConstruct(
      typeTranslator.translateType(matType), vectors);
}
SpirvEvalInfo
SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
                                    const BinaryOperatorKind opcode) {
  // TODO: some code is duplicated from processBinaryOp. Try to unify them.
  const auto lhsType = lhs->getType();
  assert(TypeTranslator::isSpirvAcceptableMatrixType(lhsType));
  const spv::Op spvOp = translateOp(opcode, lhsType);

  uint32_t rhsVal, lhsPtr, lhsVal;
  if (BinaryOperator::isCompoundAssignmentOp(opcode)) {
    // Evaluate rhs before lhs
    rhsVal = doExpr(rhs);
    lhsPtr = doExpr(lhs);
    const uint32_t lhsTy = typeTranslator.translateType(lhsType);
    lhsVal = theBuilder.createLoad(lhsTy, lhsPtr);
  } else {
    // Evaluate lhs before rhs
    lhsVal = lhsPtr = doExpr(lhs);
    rhsVal = doExpr(rhs);
  }

  switch (opcode) {
  case BO_Add:
  case BO_Sub:
  case BO_Mul:
  case BO_Div:
  case BO_Rem:
  case BO_AddAssign:
  case BO_SubAssign:
  case BO_MulAssign:
  case BO_DivAssign:
  case BO_RemAssign: {
    const uint32_t vecType = typeTranslator.getComponentVectorType(lhsType);
    const auto actOnEachVec = [this, spvOp, rhsVal](
        uint32_t index, uint32_t vecType, uint32_t lhsVec) {
      // For each vector of lhs, we need to load the corresponding vector of
      // rhs and do the operation on them.
      const uint32_t rhsVec =
          theBuilder.createCompositeExtract(vecType, rhsVal, {index});
      return theBuilder.createBinaryOp(spvOp, vecType, lhsVec, rhsVec);
    };
    return processEachVectorInMatrix(lhs, lhsVal, actOnEachVec);
  }
  case BO_Assign:
    llvm_unreachable("assignment should not be handled here");
  default:
    break;
  }

  emitError("BinaryOperator '%0' for matrices not supported yet")
      << BinaryOperator::getOpcodeStr(opcode);
  return 0;
}
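
// Note on processMatrixBinaryOp() above. Illustrative example: for
// float2x3 a, b; a + b, SPIR-V's OpFAdd does not accept matrix operands, so
// we extract each row (a float3) from both sides, add the matching rows, and
// rebuild the result matrix with OpCompositeConstruct.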
const Expr *SPIRVEmitter::collectArrayStructIndices(
    const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices) {
  if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
    const Expr *base = collectArrayStructIndices(
        indexing->getBase()->IgnoreParenNoopCasts(astContext), indices);

    // Append the index of the current level
    const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
    assert(fieldDecl);
    indices->push_back(theBuilder.getConstantInt32(fieldDecl->getFieldIndex()));

    return base;
  }

  if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
    // The base of an ArraySubscriptExpr has a wrapping LValueToRValue
    // implicit cast. We need to ignore it to avoid creating an OpLoad.
    const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
    const Expr *base = collectArrayStructIndices(thisBase, indices);
    indices->push_back(doExpr(indexing->getIdx()));
    return base;
  }

  if (const auto *indexing = dyn_cast<CXXOperatorCallExpr>(expr))
    if (indexing->getOperator() == OverloadedOperatorKind::OO_Subscript) {
      const Expr *thisBase =
          indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
      const auto thisBaseType = thisBase->getType();
      const Expr *base = collectArrayStructIndices(thisBase, indices);

      // If the base is a StructuredBuffer, we need to push an additional
      // index 0 here. This is because we created an additional
      // OpTypeRuntimeArray in the structure.
      if (TypeTranslator::isStructuredBuffer(thisBaseType))
        indices->push_back(theBuilder.getConstantInt32(0));

      if ((hlsl::IsHLSLVecType(thisBaseType) &&
           (hlsl::GetHLSLVecSize(thisBaseType) == 1)) ||
          typeTranslator.is1x1Matrix(thisBaseType) ||
          typeTranslator.is1xNMatrix(thisBaseType)) {
        // If this is a size-1 vector or 1xN matrix, ignore the index.
      } else {
        indices->push_back(doExpr(indexing->getArg(1)));
      }
      return base;
    }

  {
    const Expr *index = nullptr;
    // TODO: the following is duplicating the logic in doCXXMemberCallExpr.
    if (const auto *object = isStructuredBufferLoad(expr, &index)) {
      // For object.Load(index), there should be no more indexing into the
      // object.
      indices->push_back(theBuilder.getConstantInt32(0));
      indices->push_back(doExpr(index));
      return object;
    }
  }

  // This is the deepest we can go. No more array or struct indexing.
  return expr;
}
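
// Note on collectArrayStructIndices() above. Illustrative example (with
// hypothetical names): for s.inner.arr[i] the walk returns the DeclRefExpr
// for s and fills indices with {field-index of inner, field-index of arr,
// <id of i>}, which the caller then consumes in a single OpAccessChain.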
uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,
                                  QualType toBoolType) {
  if (isSameScalarOrVecType(fromType, toBoolType))
    return fromVal;

  // Converting to bool means comparing with value zero.
  const spv::Op spvOp = translateOp(BO_NE, fromType);
  const uint32_t boolType = typeTranslator.translateType(toBoolType);
  const uint32_t zeroVal = getValueZero(fromType);

  return theBuilder.createBinaryOp(spvOp, boolType, fromVal, zeroVal);
}
uint32_t SPIRVEmitter::castToInt(const uint32_t fromVal, QualType fromType,
                                 QualType toIntType) {
  if (isSameScalarOrVecType(fromType, toIntType))
    return fromVal;

  const uint32_t intType = typeTranslator.translateType(toIntType);
  if (isBoolOrVecOfBoolType(fromType)) {
    const uint32_t one = getValueOne(toIntType);
    const uint32_t zero = getValueZero(toIntType);
    return theBuilder.createSelect(intType, fromVal, one, zero);
  }

  if (isSintOrVecOfSintType(fromType) || isUintOrVecOfUintType(fromType)) {
    // TODO: handle different bitwidths
    return theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, fromVal);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    if (isSintOrVecOfSintType(toIntType)) {
      return theBuilder.createUnaryOp(spv::Op::OpConvertFToS, intType, fromVal);
    } else if (isUintOrVecOfUintType(toIntType)) {
      return theBuilder.createUnaryOp(spv::Op::OpConvertFToU, intType, fromVal);
    } else {
      emitError("unimplemented casting to integer from floating point");
    }
  } else {
    emitError("unimplemented casting to integer");
  }

  return 0;
}
uint32_t SPIRVEmitter::castToFloat(const uint32_t fromVal, QualType fromType,
                                   QualType toFloatType) {
  if (isSameScalarOrVecType(fromType, toFloatType))
    return fromVal;

  const uint32_t floatType = typeTranslator.translateType(toFloatType);
  if (isBoolOrVecOfBoolType(fromType)) {
    const uint32_t one = getValueOne(toFloatType);
    const uint32_t zero = getValueZero(toFloatType);
    return theBuilder.createSelect(floatType, fromVal, one, zero);
  }

  if (isSintOrVecOfSintType(fromType)) {
    return theBuilder.createUnaryOp(spv::Op::OpConvertSToF, floatType, fromVal);
  }

  if (isUintOrVecOfUintType(fromType)) {
    return theBuilder.createUnaryOp(spv::Op::OpConvertUToF, floatType, fromVal);
  }

  if (isFloatOrVecOfFloatType(fromType)) {
    emitError("casting between different fp bitwidths unimplemented");
    return 0;
  }

  emitError("unimplemented casting to floating point");
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
  const FunctionDecl *callee = callExpr->getDirectCallee();
  assert(hlsl::IsIntrinsicOp(callee) &&
         "processIntrinsicCallExpr was called for a non-intrinsic function.");

  const bool isFloatType = isFloatOrVecMatOfFloatType(callExpr->getType());
  const bool isSintType = isSintOrVecMatOfSintType(callExpr->getType());

  // Figure out which intrinsic function to translate.
  llvm::StringRef group;
  uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  hlsl::GetIntrinsicOp(callee, opcode, group);

  GLSLstd450 glslOpcode = GLSLstd450Bad;

#define INTRINSIC_SPIRV_OP_WITH_CAP_CASE(intrinsicOp, spirvOp, doEachVec, cap) \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    theBuilder.requireCapability(cap);                                         \
    return processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,      \
                                          doEachVec);                          \
  } break

#define INTRINSIC_SPIRV_OP_CASE(intrinsicOp, spirvOp, doEachVec)               \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    return processIntrinsicUsingSpirvInst(callExpr, spv::Op::Op##spirvOp,      \
                                          doEachVec);                          \
  } break

#define INTRINSIC_OP_CASE(intrinsicOp, glslOp, doEachVec)                      \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = GLSLstd450::GLSLstd450##glslOp;                               \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);     \
  } break

#define INTRINSIC_OP_CASE_INT_FLOAT(intrinsicOp, glslIntOp, glslFloatOp,       \
                                    doEachVec)                                 \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = isFloatType ? GLSLstd450::GLSLstd450##glslFloatOp             \
                             : GLSLstd450::GLSLstd450##glslIntOp;              \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);     \
  } break

#define INTRINSIC_OP_CASE_SINT_UINT(intrinsicOp, glslSintOp, glslUintOp,       \
                                    doEachVec)                                 \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = isSintType ? GLSLstd450::GLSLstd450##glslSintOp               \
                            : GLSLstd450::GLSLstd450##glslUintOp;              \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);     \
  } break

#define INTRINSIC_OP_CASE_SINT_UINT_FLOAT(intrinsicOp, glslSintOp, glslUintOp, \
                                          glslFloatOp, doEachVec)              \
  case hlsl::IntrinsicOp::IOP_##intrinsicOp: {                                 \
    glslOpcode = isFloatType                                                   \
                     ? GLSLstd450::GLSLstd450##glslFloatOp                     \
                     : isSintType ? GLSLstd450::GLSLstd450##glslSintOp         \
                                  : GLSLstd450::GLSLstd450##glslUintOp;        \
    return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, doEachVec);     \
  } break
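
  // For illustration, INTRINSIC_OP_CASE(floor, Floor, true) in the switch
  // below expands to roughly:
  //   case hlsl::IntrinsicOp::IOP_floor: {
  //     glslOpcode = GLSLstd450::GLSLstd450Floor;
  //     return processIntrinsicUsingGLSLInst(callExpr, glslOpcode, true);
  //   } break;
  // i.e. each helper macro is shorthand for one case label plus the matching
  // dispatch call.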

  switch (static_cast<hlsl::IntrinsicOp>(opcode)) {
  case hlsl::IntrinsicOp::IOP_dot:
    return processIntrinsicDot(callExpr);
  case hlsl::IntrinsicOp::IOP_mul:
    return processIntrinsicMul(callExpr);
  case hlsl::IntrinsicOp::IOP_all:
    return processIntrinsicAllOrAny(callExpr, spv::Op::OpAll);
  case hlsl::IntrinsicOp::IOP_any:
    return processIntrinsicAllOrAny(callExpr, spv::Op::OpAny);
  case hlsl::IntrinsicOp::IOP_asfloat:
  case hlsl::IntrinsicOp::IOP_asint:
  case hlsl::IntrinsicOp::IOP_asuint:
    return processIntrinsicAsType(callExpr);
  case hlsl::IntrinsicOp::IOP_clip:
    return processIntrinsicClip(callExpr);
  case hlsl::IntrinsicOp::IOP_clamp:
  case hlsl::IntrinsicOp::IOP_uclamp:
    return processIntrinsicClamp(callExpr);
  case hlsl::IntrinsicOp::IOP_frexp:
    return processIntrinsicFrexp(callExpr);
  case hlsl::IntrinsicOp::IOP_modf:
    return processIntrinsicModf(callExpr);
  case hlsl::IntrinsicOp::IOP_sign: {
    if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
      return processIntrinsicFloatSign(callExpr);
    else
      return processIntrinsicUsingGLSLInst(callExpr,
                                           GLSLstd450::GLSLstd450SSign,
                                           /*actPerRowForMatrices*/ true);
  }
  case hlsl::IntrinsicOp::IOP_isfinite:
    return processIntrinsicIsFinite(callExpr);
  case hlsl::IntrinsicOp::IOP_sincos:
    return processIntrinsicSinCos(callExpr);
  case hlsl::IntrinsicOp::IOP_rcp:
    return processIntrinsicRcp(callExpr);
  case hlsl::IntrinsicOp::IOP_saturate:
    return processIntrinsicSaturate(callExpr);
  case hlsl::IntrinsicOp::IOP_log10:
    return processIntrinsicLog10(callExpr);
    INTRINSIC_SPIRV_OP_CASE(transpose, Transpose, false);
    INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_coarse, DPdxCoarse, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddx_fine, DPdxFine, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_coarse, DPdyCoarse, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_WITH_CAP_CASE(ddy_fine, DPdyFine, false,
                                     spv::Capability::DerivativeControl);
    INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
    INTRINSIC_SPIRV_OP_CASE(isinf, IsInf, true);
    INTRINSIC_SPIRV_OP_CASE(isnan, IsNan, true);
    INTRINSIC_SPIRV_OP_CASE(fmod, FMod, true);
    INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
    INTRINSIC_SPIRV_OP_CASE(reversebits, BitReverse, false);
    INTRINSIC_OP_CASE(round, Round, true);
    INTRINSIC_OP_CASE_INT_FLOAT(abs, SAbs, FAbs, true);
    INTRINSIC_OP_CASE(acos, Acos, true);
    INTRINSIC_OP_CASE(asin, Asin, true);
    INTRINSIC_OP_CASE(atan, Atan, true);
    INTRINSIC_OP_CASE(atan2, Atan2, true);
    INTRINSIC_OP_CASE(ceil, Ceil, true);
    INTRINSIC_OP_CASE(cos, Cos, true);
    INTRINSIC_OP_CASE(cosh, Cosh, true);
    INTRINSIC_OP_CASE(cross, Cross, false);
    INTRINSIC_OP_CASE(degrees, Degrees, true);
    INTRINSIC_OP_CASE(distance, Distance, false);
    INTRINSIC_OP_CASE(determinant, Determinant, false);
    INTRINSIC_OP_CASE(exp, Exp, true);
    INTRINSIC_OP_CASE(exp2, Exp2, true);
    INTRINSIC_OP_CASE_SINT_UINT(firstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE_SINT_UINT(ufirstbithigh, FindSMsb, FindUMsb, false);
    INTRINSIC_OP_CASE(faceforward, FaceForward, false);
    INTRINSIC_OP_CASE(firstbitlow, FindILsb, false);
    INTRINSIC_OP_CASE(floor, Floor, true);
    INTRINSIC_OP_CASE(fma, Fma, true);
    INTRINSIC_OP_CASE(frac, Fract, true);
    INTRINSIC_OP_CASE(length, Length, false);
    INTRINSIC_OP_CASE(ldexp, Ldexp, true);
    INTRINSIC_OP_CASE(lerp, FMix, true);
    INTRINSIC_OP_CASE(log, Log, true);
    INTRINSIC_OP_CASE(log2, Log2, true);
    INTRINSIC_OP_CASE(mad, Fma, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(max, SMax, UMax, FMax, true);
    INTRINSIC_OP_CASE(umax, UMax, true);
    INTRINSIC_OP_CASE_SINT_UINT_FLOAT(min, SMin, UMin, FMin, true);
    INTRINSIC_OP_CASE(umin, UMin, true);
    INTRINSIC_OP_CASE(normalize, Normalize, false);
    INTRINSIC_OP_CASE(pow, Pow, true);
    INTRINSIC_OP_CASE(radians, Radians, true);
    INTRINSIC_OP_CASE(reflect, Reflect, false);
    INTRINSIC_OP_CASE(refract, Refract, false);
    INTRINSIC_OP_CASE(rsqrt, InverseSqrt, true);
    INTRINSIC_OP_CASE(smoothstep, SmoothStep, true);
    INTRINSIC_OP_CASE(step, Step, true);
    INTRINSIC_OP_CASE(sin, Sin, true);
    INTRINSIC_OP_CASE(sinh, Sinh, true);
    INTRINSIC_OP_CASE(tan, Tan, true);
    INTRINSIC_OP_CASE(tanh, Tanh, true);
    INTRINSIC_OP_CASE(sqrt, Sqrt, true);
    INTRINSIC_OP_CASE(trunc, Trunc, true);
  default:
    emitError("Intrinsic function '%0' not yet implemented.")
        << callee->getName();
    return 0;
  }

#undef INTRINSIC_OP_CASE
#undef INTRINSIC_OP_CASE_INT_FLOAT
#undef INTRINSIC_OP_CASE_SINT_UINT
#undef INTRINSIC_OP_CASE_SINT_UINT_FLOAT
#undef INTRINSIC_SPIRV_OP_CASE
#undef INTRINSIC_SPIRV_OP_WITH_CAP_CASE

  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
  // Signature is: ret modf(x, ip)
  // [in]  x:   the input floating-point value.
  // [out] ip:  the integer portion of x.
  // [out] ret: the fractional portion of x.
  // All of the above must be a scalar, vector, or matrix with the same
  // component types. Component types can be float or int.

  // The ModfStruct SPIR-V instruction returns a struct. The first member is
  // the fractional part and the second member is the integer portion.
  // ModfStruct {
  //   <scalar or vector of float> frac;
  //   <scalar or vector of float> ip;
  // }

  // Note that if the input number (x) is not a float (i.e. 'x' is an int), it
  // is automatically converted to float before modf is invoked. Sadly, the
  // 'ip' argument is not treated the same way. Therefore, in such cases we'll
  // have to manually convert the float result into int.

  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const Expr *ipArg = callExpr->getArg(1);
  const auto argType = arg->getType();
  const auto ipType = ipArg->getType();
  const auto returnType = callExpr->getType();
  const auto returnTypeId = typeTranslator.translateType(returnType);
  const auto ipTypeId = typeTranslator.translateType(ipType);
  const uint32_t argId = doExpr(arg);
  const uint32_t ipId = doExpr(ipArg);

  // TODO: We currently do not support non-float matrices.
  QualType ipElemType = {};
  if (TypeTranslator::isMxNMatrix(ipType, &ipElemType) &&
      !ipElemType->isFloatingType()) {
    emitError("Non-FP matrices are currently not supported.");
    return 0;
  }

  // For scalar and vector argument types.
  {
    if (TypeTranslator::isScalarType(argType) ||
        TypeTranslator::isVectorType(argType)) {
      const auto argTypeId = typeTranslator.translateType(argType);
      // The struct members *must* have the same type.
      const auto modfStructTypeId = theBuilder.getStructType(
          {argTypeId, argTypeId}, "ModfStructType", {"frac", "ip"});
      const auto modf =
          theBuilder.createExtInst(modfStructTypeId, glslInstSetId,
                                   GLSLstd450::GLSLstd450ModfStruct, {argId});
      auto ip = theBuilder.createCompositeExtract(argTypeId, modf, {1});
      // This will do nothing if the input number (x) and the ip are both of
      // the same type. Otherwise, it will convert the ip into int as
      // necessary.
      ip = castToInt(ip, argType, ipType);
      theBuilder.createStore(ipId, ip);
      return theBuilder.createCompositeExtract(argTypeId, modf, {0});
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    QualType elemType = {};
    if (TypeTranslator::isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
      const auto elemTypeId = typeTranslator.translateType(elemType);
      const auto colTypeId = theBuilder.getVecType(elemTypeId, colCount);
      const auto modfStructTypeId = theBuilder.getStructType(
          {colTypeId, colTypeId}, "ModfStructType", {"frac", "ip"});
      llvm::SmallVector<uint32_t, 4> fracs;
      llvm::SmallVector<uint32_t, 4> ips;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto curRow =
            theBuilder.createCompositeExtract(colTypeId, argId, {i});
        const auto modf = theBuilder.createExtInst(
            modfStructTypeId, glslInstSetId, GLSLstd450::GLSLstd450ModfStruct,
            {curRow});
        auto ip = theBuilder.createCompositeExtract(colTypeId, modf, {1});
        ips.push_back(ip);
        fracs.push_back(theBuilder.createCompositeExtract(colTypeId, modf, {0}));
      }
      theBuilder.createStore(
          ipId, theBuilder.createCompositeConstruct(returnTypeId, ips));
      return theBuilder.createCompositeConstruct(returnTypeId, fracs);
    }
  }

  emitError("Unknown argument type passed to Modf function.");
  return 0;
}
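
// For illustration: modf(float2(2.75, 3.5), ip) stores float2(2.0, 3.0) into
// ip and returns float2(0.75, 0.5); both values come out of the single
// GLSLstd450 ModfStruct result via OpCompositeExtract.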

uint32_t SPIRVEmitter::processIntrinsicFrexp(const CallExpr *callExpr) {
  // Signature is: ret frexp(x, exp)
  // [in]  x:   the input floating-point value.
  // [out] exp: the calculated exponent.
  // [out] ret: the calculated mantissa.
  // All of the above must be a scalar, vector, or matrix of *float* type.

  // The FrexpStruct SPIR-V instruction returns a struct. The first member is
  // the significand (mantissa) and must be of the same type as the input
  // parameter, and the second member is the exponent and must always be a
  // scalar or vector of 32-bit *integer* type.
  // FrexpStruct {
  //   <scalar or vector of float> mantissa;
  //   <scalar or vector of integer> exponent;
  // }
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto intId = theBuilder.getInt32Type();
  const auto returnTypeId = typeTranslator.translateType(callExpr->getType());
  const uint32_t argId = doExpr(arg);
  const uint32_t expId = doExpr(callExpr->getArg(1));

  // For scalar and vector argument types.
  {
    uint32_t elemCount = 1;
    if (TypeTranslator::isScalarType(argType) ||
        TypeTranslator::isVectorType(argType, nullptr, &elemCount)) {
      const auto argTypeId = typeTranslator.translateType(argType);
      const auto expTypeId =
          elemCount == 1 ? intId : theBuilder.getVecType(intId, elemCount);
      const auto frexpStructTypeId = theBuilder.getStructType(
          {argTypeId, expTypeId}, "FrexpStructType", {"mantissa", "exponent"});
      const auto frexp =
          theBuilder.createExtInst(frexpStructTypeId, glslInstSetId,
                                   GLSLstd450::GLSLstd450FrexpStruct, {argId});
      const auto exponentInt =
          theBuilder.createCompositeExtract(expTypeId, frexp, {1});
      // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
      // expects a float, a conversion must take place before writing the
      // results.
      const auto exponentFloat = theBuilder.createUnaryOp(
          spv::Op::OpConvertSToF, returnTypeId, exponentInt);
      theBuilder.createStore(expId, exponentFloat);
      return theBuilder.createCompositeExtract(argTypeId, frexp, {0});
    }
  }

  // For matrix argument types.
  {
    uint32_t rowCount = 0, colCount = 0;
    if (TypeTranslator::isMxNMatrix(argType, nullptr, &rowCount, &colCount)) {
      const auto floatId = theBuilder.getFloat32Type();
      const auto expTypeId = theBuilder.getVecType(intId, colCount);
      const auto colTypeId = theBuilder.getVecType(floatId, colCount);
      const auto frexpStructTypeId = theBuilder.getStructType(
          {colTypeId, expTypeId}, "FrexpStructType", {"mantissa", "exponent"});
      llvm::SmallVector<uint32_t, 4> exponents;
      llvm::SmallVector<uint32_t, 4> mantissas;
      for (uint32_t i = 0; i < rowCount; ++i) {
        const auto curRow =
            theBuilder.createCompositeExtract(colTypeId, argId, {i});
        const auto frexp = theBuilder.createExtInst(
            frexpStructTypeId, glslInstSetId, GLSLstd450::GLSLstd450FrexpStruct,
            {curRow});
        const auto exponentInt =
            theBuilder.createCompositeExtract(expTypeId, frexp, {1});
        // Since the SPIR-V instruction returns an int, and the HLSL intrinsic
        // expects a float, a conversion must take place before writing the
        // results.
        const auto exponentFloat = theBuilder.createUnaryOp(
            spv::Op::OpConvertSToF, colTypeId, exponentInt);
        exponents.push_back(exponentFloat);
        mantissas.push_back(
            theBuilder.createCompositeExtract(colTypeId, frexp, {0}));
      }
      const auto exponentsResultId =
          theBuilder.createCompositeConstruct(returnTypeId, exponents);
      theBuilder.createStore(expId, exponentsResultId);
      return theBuilder.createCompositeConstruct(returnTypeId, mantissas);
    }
  }

  emitError("Unknown argument type passed to Frexp function.");
  return 0;
}
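
// For illustration: frexp(16.0, e) returns the mantissa 0.5 and stores 5.0
// into e, since 16 = 0.5 * 2^5; the exponent member of FrexpStruct is an
// integer, hence the OpConvertSToF above before the store.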

uint32_t SPIRVEmitter::processIntrinsicClip(const CallExpr *callExpr) {
  // Discards the current pixel if the specified value is less than zero.
  // TODO: If the argument can be const folded and evaluated, we could
  // potentially avoid creating a branch. This would be a bit challenging for
  // matrix/vector arguments.
  assert(callExpr->getNumArgs() == 1u);
  const Expr *arg = callExpr->getArg(0);
  const auto argType = arg->getType();
  const auto boolType = theBuilder.getBoolType();
  uint32_t condition = 0;

  // Could not determine the argument as a constant. We need to branch based
  // on the argument. If the argument is a vector/matrix, clipping is done if
  // *any* element of the vector/matrix is less than zero.
  const uint32_t argId = doExpr(arg);

  QualType elemType = {};
  uint32_t elemCount = 0, rowCount = 0, colCount = 0;
  if (TypeTranslator::isScalarType(argType)) {
    const auto zero = getValueZero(argType);
    condition = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan, boolType,
                                          argId, zero);
  } else if (TypeTranslator::isVectorType(argType, nullptr, &elemCount)) {
    const auto zero = getValueZero(argType);
    const auto boolVecType = theBuilder.getVecType(boolType, elemCount);
    const auto cmp = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                               boolVecType, argId, zero);
    condition = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp);
  } else if (TypeTranslator::isMxNMatrix(argType, &elemType, &rowCount,
                                         &colCount)) {
    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
    const uint32_t floatVecType = theBuilder.getVecType(elemTypeId, colCount);
    const uint32_t elemZeroId = getValueZero(elemType);
    llvm::SmallVector<uint32_t, 4> elements(size_t(colCount), elemZeroId);
    const auto zero = theBuilder.getConstantComposite(floatVecType, elements);
    llvm::SmallVector<uint32_t, 4> cmpResults;
    for (uint32_t i = 0; i < rowCount; ++i) {
      const uint32_t lhsVec =
          theBuilder.createCompositeExtract(floatVecType, argId, {i});
      const auto boolColType = theBuilder.getVecType(boolType, colCount);
      const auto cmp = theBuilder.createBinaryOp(spv::Op::OpFOrdLessThan,
                                                 boolColType, lhsVec, zero);
      const auto any = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, cmp);
      cmpResults.push_back(any);
    }
    const auto boolRowType = theBuilder.getVecType(boolType, rowCount);
    const auto results =
        theBuilder.createCompositeConstruct(boolRowType, cmpResults);
    condition = theBuilder.createUnaryOp(spv::Op::OpAny, boolType, results);
  } else {
    emitError("Invalid type passed to clip function.");
    return 0;
  }

  // Then we need to emit the instruction for the conditional branch.
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  // Create the branch instruction. This will end the current basic block.
  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);
  // Handle the then branch.
  theBuilder.setInsertPoint(thenBB);
  theBuilder.createKill();
  theBuilder.addSuccessor(mergeBB);
  // From now on, we'll emit instructions into the merge block.
  theBuilder.setInsertPoint(mergeBB);
  return 0;
}
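
// For illustration, clip(v) for a float4 v is lowered to roughly:
//   %cmp  = OpFOrdLessThan %v4bool %v %float4_zero
//   %cond = OpAny %bool %cmp
//           OpBranchConditional %cond %if_true %if_merge
//   %if_true:  OpKill
//   %if_merge: ...subsequent instructions...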

uint32_t SPIRVEmitter::processIntrinsicClamp(const CallExpr *callExpr) {
  // According to the HLSL reference: clamp(X, Min, Max) takes 3 arguments.
  // Each one may be int, uint, or float.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  GLSLstd450 glslOpcode = GLSLstd450::GLSLstd450UClamp;
  if (isFloatOrVecMatOfFloatType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450FClamp;
  else if (isSintOrVecMatOfSintType(returnType))
    glslOpcode = GLSLstd450::GLSLstd450SClamp;

  // Get the function parameters. Expect 3 parameters.
  assert(callExpr->getNumArgs() == 3u);
  const Expr *argX = callExpr->getArg(0);
  const Expr *argMin = callExpr->getArg(1);
  const Expr *argMax = callExpr->getArg(2);
  const uint32_t argXId = doExpr(argX);
  const uint32_t argMinId = doExpr(argMin);
  const uint32_t argMaxId = doExpr(argMax);

  // FClamp, UClamp, and SClamp do not operate on matrices, so we should
  // perform the operation on each vector of the matrix.
  if (TypeTranslator::isSpirvAcceptableMatrixType(argX->getType())) {
    const auto actOnEachVec = [this, glslInstSetId, glslOpcode, argMinId,
                               argMaxId](uint32_t index, uint32_t vecType,
                                         uint32_t curRowId) {
      const auto minRowId =
          theBuilder.createCompositeExtract(vecType, argMinId, {index});
      const auto maxRowId =
          theBuilder.createCompositeExtract(vecType, argMaxId, {index});
      return theBuilder.createExtInst(vecType, glslInstSetId, glslOpcode,
                                      {curRowId, minRowId, maxRowId});
    };
    return processEachVectorInMatrix(argX, argXId, actOnEachVec);
  }

  return theBuilder.createExtInst(returnTypeId, glslInstSetId, glslOpcode,
                                  {argXId, argMinId, argMaxId});
}
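
// For illustration, clamp() on float2x3 operands takes the per-row path
// above, emitting roughly the following per row (a sketch):
//   %x0   = OpCompositeExtract %v3float %x 0
//   %min0 = OpCompositeExtract %v3float %min 0
//   %max0 = OpCompositeExtract %v3float %max 0
//   %r0   = OpExtInst %v3float %glsl FClamp %x0 %min0 %max0
// and row 1 likewise, with the result rows reassembled by
// processEachVectorInMatrix.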

uint32_t SPIRVEmitter::processIntrinsicMul(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId =
      typeTranslator.translateType(callExpr->getType());

  // Get the function parameters. Expect 2 parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const QualType arg0Type = arg0->getType();
  const QualType arg1Type = arg1->getType();

  // The HLSL mul() function takes 2 arguments. Each argument may be a scalar,
  // vector, or matrix. The frontend ensures that the two arguments have the
  // same component type. The only allowed component types are int and float.

  // mul(scalar, vector)
  {
    uint32_t elemCount = 0;
    if (TypeTranslator::isScalarType(arg0Type) &&
        TypeTranslator::isVectorType(arg1Type, nullptr, &elemCount)) {
      const uint32_t arg1Id = doExpr(arg1);
      // We can use OpVectorTimesScalar if the arguments are floats.
      if (arg0Type->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnTypeId, arg1Id, doExpr(arg0));
      // Use OpIMul for integers.
      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId,
                                       createVectorSplat(arg0, elemCount),
                                       arg1Id);
    }
  }

  // mul(vector, scalar)
  {
    uint32_t elemCount = 0;
    if (TypeTranslator::isVectorType(arg0Type, nullptr, &elemCount) &&
        TypeTranslator::isScalarType(arg1Type)) {
      const uint32_t arg0Id = doExpr(arg0);
      // We can use OpVectorTimesScalar if the arguments are floats.
      if (arg1Type->isFloatingType())
        return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
                                         returnTypeId, arg0Id, doExpr(arg1));
      // Use OpIMul for integers.
      return theBuilder.createBinaryOp(spv::Op::OpIMul, returnTypeId, arg0Id,
                                       createVectorSplat(arg1, elemCount));
    }
  }

  // mul(vector, vector)
  if (TypeTranslator::isVectorType(arg0Type) &&
      TypeTranslator::isVectorType(arg1Type))
    return processIntrinsicDot(callExpr);

  // All the following cases require handling arg0 and arg1 expressions first.
  const uint32_t arg0Id = doExpr(arg0);
  const uint32_t arg1Id = doExpr(arg1);

  // mul(scalar, scalar)
  if (TypeTranslator::isScalarType(arg0Type) &&
      TypeTranslator::isScalarType(arg1Type))
    return theBuilder.createBinaryOp(translateOp(BO_Mul, arg0Type),
                                     returnTypeId, arg0Id, arg1Id);

  // mul(scalar, matrix)
  if (TypeTranslator::isScalarType(arg0Type) &&
      TypeTranslator::isMxNMatrix(arg1Type)) {
    // We currently only support float matrices, so we can use
    // OpMatrixTimesScalar.
    if (arg0Type->isFloatingType())
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                       returnTypeId, arg1Id, arg0Id);
  }

  // mul(matrix, scalar)
  if (TypeTranslator::isScalarType(arg1Type) &&
      TypeTranslator::isMxNMatrix(arg0Type)) {
    // We currently only support float matrices, so we can use
    // OpMatrixTimesScalar.
    if (arg1Type->isFloatingType())
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesScalar,
                                       returnTypeId, arg0Id, arg1Id);
  }

  // mul(vector, matrix)
  {
    QualType elemType = {};
    uint32_t elemCount = 0, numRows = 0;
    if (TypeTranslator::isVectorType(arg0Type, &elemType, &elemCount) &&
        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &numRows, nullptr) &&
        elemType->isFloatingType()) {
      assert(elemCount == numRows);
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesVector,
                                       returnTypeId, arg1Id, arg0Id);
    }
  }

  // mul(matrix, vector)
  {
    QualType elemType = {};
    uint32_t elemCount = 0, numCols = 0;
    if (TypeTranslator::isMxNMatrix(arg0Type, nullptr, nullptr, &numCols) &&
        TypeTranslator::isVectorType(arg1Type, &elemType, &elemCount) &&
        elemType->isFloatingType()) {
      assert(elemCount == numCols);
      return theBuilder.createBinaryOp(spv::Op::OpVectorTimesMatrix,
                                       returnTypeId, arg1Id, arg0Id);
    }
  }

  // mul(matrix, matrix)
  {
    QualType elemType = {};
    uint32_t arg0Cols = 0, arg1Rows = 0;
    if (TypeTranslator::isMxNMatrix(arg0Type, &elemType, nullptr, &arg0Cols) &&
        TypeTranslator::isMxNMatrix(arg1Type, nullptr, &arg1Rows, nullptr) &&
        elemType->isFloatingType()) {
      assert(arg0Cols == arg1Rows);
      return theBuilder.createBinaryOp(spv::Op::OpMatrixTimesMatrix,
                                       returnTypeId, arg1Id, arg0Id);
    }
  }

  emitError("Unsupported arguments passed to mul() function.");
  return 0;
}
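
// Summary of the mul() lowering above (note that in the matrix/vector forms
// the SPIR-V operand order is swapped relative to the HLSL argument order, as
// seen in the createBinaryOp calls):
//   mul(scalar, vector) : OpVectorTimesScalar (float) / splat + OpIMul (int)
//   mul(vector, scalar) : OpVectorTimesScalar (float) / splat + OpIMul (int)
//   mul(vector, vector) : processIntrinsicDot
//   mul(scalar, scalar) : ordinary scalar multiplication
//   mul(scalar, matrix) : OpMatrixTimesScalar (float only)
//   mul(matrix, scalar) : OpMatrixTimesScalar (float only)
//   mul(vector, matrix) : OpMatrixTimesVector (float only)
//   mul(matrix, vector) : OpVectorTimesMatrix (float only)
//   mul(matrix, matrix) : OpMatrixTimesMatrix (float only)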

uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId =
      typeTranslator.translateType(callExpr->getType());

  // Get the function parameters. Expect 2 vectors as parameters.
  assert(callExpr->getNumArgs() == 2u);
  const Expr *arg0 = callExpr->getArg(0);
  const Expr *arg1 = callExpr->getArg(1);
  const uint32_t arg0Id = doExpr(arg0);
  const uint32_t arg1Id = doExpr(arg1);
  QualType arg0Type = arg0->getType();
  QualType arg1Type = arg1->getType();
  const size_t vec0Size = hlsl::GetHLSLVecSize(arg0Type);
  const size_t vec1Size = hlsl::GetHLSLVecSize(arg1Type);
  const QualType vec0ComponentType = hlsl::GetHLSLVecElementType(arg0Type);
  const QualType vec1ComponentType = hlsl::GetHLSLVecElementType(arg1Type);
  assert(returnType == vec1ComponentType);
  assert(vec0ComponentType == vec1ComponentType);
  assert(vec0Size == vec1Size);
  assert(vec0Size >= 1 && vec0Size <= 4);

  // According to the HLSL reference, the dot function only works on integers
  // and floats.
  assert(returnType->isFloatingType() || returnType->isIntegerType());

  // Special case: dot product of two vectors, each of size 1. That is
  // basically the same as regular multiplication of 2 scalars.
  if (vec0Size == 1) {
    const spv::Op spvOp = translateOp(BO_Mul, arg0Type);
    return theBuilder.createBinaryOp(spvOp, returnTypeId, arg0Id, arg1Id);
  }

  // If the vectors are of type float, we can use OpDot.
  if (returnType->isFloatingType()) {
    return theBuilder.createBinaryOp(spv::Op::OpDot, returnTypeId, arg0Id,
                                     arg1Id);
  }
  // The vector component type is integer (signed or unsigned). Create all
  // instructions necessary to perform a dot product on two integer vectors.
  // SPIR-V OpDot does not support integer vectors. Therefore, we use other
  // SPIR-V instructions (addition and multiplication).
  else {
    uint32_t result = 0;
    llvm::SmallVector<uint32_t, 4> multIds;
    const spv::Op multSpvOp = translateOp(BO_Mul, arg0Type);
    const spv::Op addSpvOp = translateOp(BO_Add, arg0Type);

    // Extract members from the two vectors and multiply them.
    for (unsigned int i = 0; i < vec0Size; ++i) {
      const uint32_t vec0member =
          theBuilder.createCompositeExtract(returnTypeId, arg0Id, {i});
      const uint32_t vec1member =
          theBuilder.createCompositeExtract(returnTypeId, arg1Id, {i});
      const uint32_t multId = theBuilder.createBinaryOp(
          multSpvOp, returnTypeId, vec0member, vec1member);
      multIds.push_back(multId);
    }

    // Add all the multiplications.
    result = multIds[0];
    for (unsigned int i = 1; i < vec0Size; ++i) {
      const uint32_t additionId = theBuilder.createBinaryOp(
          addSpvOp, returnTypeId, result, multIds[i]);
      result = additionId;
    }
    return result;
  }
}
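
// For illustration, the integer path above turns dot(int3(1, 2, 3),
// int3(4, 5, 6)) into three OpIMul instructions followed by a chain of
// OpIAdd instructions, yielding 1*4 + 2*5 + 3*6 = 32.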

uint32_t SPIRVEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
  // 'rcp' takes only 1 argument that is a scalar, vector, or matrix of type
  // float or double.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg = callExpr->getArg(0);
  const uint32_t argId = doExpr(arg);
  const QualType argType = arg->getType();

  // For cases with matrix argument.
  QualType elemType = {};
  uint32_t numRows = 0, numCols = 0;
  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    const uint32_t vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecOne](uint32_t /*index*/,
                                             uint32_t vecType,
                                             uint32_t curRowId) {
      return theBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne,
                                       curRowId);
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec);
  }

  // For cases with scalar or vector arguments.
  return theBuilder.createBinaryOp(spv::Op::OpFDiv, returnTypeId,
                                   getValueOne(argType), argId);
}

uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
                                                spv::Op spvOp) {
  // 'all' and 'any' take only 1 parameter.
  assert(callExpr->getNumArgs() == 1u);
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  const Expr *arg = callExpr->getArg(0);
  const QualType argType = arg->getType();

  // Handle scalars, vectors of size 1, and 1x1 matrices as arguments.
  // Optimization: we can directly cast them to boolean. No need for
  // OpAny/OpAll.
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(argType, &scalarType) &&
        (scalarType->isBooleanType() || scalarType->isFloatingType() ||
         scalarType->isIntegerType()))
      return castToBool(doExpr(arg), argType, returnType);
  }

  // Handle vectors of size larger than 1, Mx1 matrices, and 1xN matrices as
  // arguments. Cast the vector to a boolean vector, then run OpAny/OpAll on
  // it.
  {
    QualType elemType = {};
    uint32_t size = 0;
    if (TypeTranslator::isVectorType(argType, &elemType, &size)) {
      const QualType castToBoolType =
          astContext.getExtVectorType(returnType, size);
      uint32_t castedToBoolId = castToBool(doExpr(arg), argType, castToBoolType);
      return theBuilder.createUnaryOp(spvOp, returnTypeId, castedToBoolId);
    }
  }

  // Handle MxN matrices as arguments.
  {
    QualType elemType = {};
    uint32_t matRowCount = 0, matColCount = 0;
    if (TypeTranslator::isMxNMatrix(argType, &elemType, &matRowCount,
                                    &matColCount)) {
      if (!elemType->isFloatingType()) {
        emitError("'all' and 'any' currently do not take non-floating point "
                  "matrices as arguments.");
        return 0;
      }

      uint32_t matrixId = doExpr(arg);
      const uint32_t vecType = typeTranslator.getComponentVectorType(argType);
      llvm::SmallVector<uint32_t, 4> rowResults;
      for (uint32_t i = 0; i < matRowCount; ++i) {
        // Extract the row, which is a float vector of size matColCount.
        const uint32_t rowFloatVec =
            theBuilder.createCompositeExtract(vecType, matrixId, {i});
        // Cast the float vector to a boolean vector.
        const auto rowFloatQualType =
            astContext.getExtVectorType(elemType, matColCount);
        const auto rowBoolQualType =
            astContext.getExtVectorType(returnType, matColCount);
        const uint32_t rowBoolVec =
            castToBool(rowFloatVec, rowFloatQualType, rowBoolQualType);
        // Perform OpAny/OpAll on the boolean vector.
        rowResults.push_back(
            theBuilder.createUnaryOp(spvOp, returnTypeId, rowBoolVec));
      }
      // Create a new vector that is the concatenation of results of all rows.
      uint32_t boolId = theBuilder.getBoolType();
      uint32_t vecOfBoolsId = theBuilder.getVecType(boolId, matRowCount);
      const uint32_t rowResultsId =
          theBuilder.createCompositeConstruct(vecOfBoolsId, rowResults);
      // Run OpAny/OpAll on the newly-created vector.
      return theBuilder.createUnaryOp(spvOp, returnTypeId, rowResultsId);
    }
  }

  // All types should be handled already.
  llvm_unreachable("Unknown argument type passed to all()/any().");
  return 0;
}
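
// For illustration, any(v) for a float3 v takes the vector path above
// (a sketch; the NotEqual variant comes from castToBool):
//   %b   = OpFOrdNotEqual %v3bool %v %float3_zero
//   %res = OpAny %bool %b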

uint32_t SPIRVEmitter::processIntrinsicAsType(const CallExpr *callExpr) {
  const QualType returnType = callExpr->getType();
  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
  assert(callExpr->getNumArgs() == 1u);
  const Expr *arg = callExpr->getArg(0);
  const QualType argType = arg->getType();

  // asfloat may take a float, a float vector, or a float matrix as its
  // argument. Those cases are no-ops.
  if (returnType.getCanonicalType() == argType.getCanonicalType())
    return doExpr(arg);

  // SPIR-V does not support non-floating point matrices. So 'asint' and
  // 'asuint' for MxN matrices are currently not supported.
  if (TypeTranslator::isMxNMatrix(argType)) {
    emitError("SPIR-V does not support non-floating point matrices. Thus, "
              "'asint' and 'asuint' currently do not take matrix arguments.");
    return 0;
  }

  return theBuilder.createUnaryOp(spv::Op::OpBitcast, returnTypeId,
                                  doExpr(arg));
}

uint32_t SPIRVEmitter::processIntrinsicIsFinite(const CallExpr *callExpr) {
  // Since OpIsFinite needs the Kernel capability, translation is instead done
  // using OpIsNan and OpIsInf:
  //   isFinite = !(isNan || isInf)
  const auto arg = doExpr(callExpr->getArg(0));
  const auto returnType = typeTranslator.translateType(callExpr->getType());
  const auto isNan = theBuilder.createUnaryOp(spv::Op::OpIsNan, returnType, arg);
  const auto isInf = theBuilder.createUnaryOp(spv::Op::OpIsInf, returnType, arg);
  const auto isNanOrInf =
      theBuilder.createBinaryOp(spv::Op::OpLogicalOr, returnType, isNan, isInf);
  return theBuilder.createUnaryOp(spv::Op::OpLogicalNot, returnType,
                                  isNanOrInf);
}

uint32_t SPIRVEmitter::processIntrinsicSinCos(const CallExpr *callExpr) {
  // Since there is no sincos equivalent in SPIR-V, we need to perform Sin
  // once and Cos once. We can reuse the existing Sine/Cosine handling
  // functions.
  CallExpr *sincosExpr =
      new (astContext) CallExpr(astContext, Stmt::StmtClass::NoStmtClass, {});
  sincosExpr->setType(callExpr->getArg(0)->getType());
  sincosExpr->setNumArgs(astContext, 1);
  sincosExpr->setArg(0, const_cast<Expr *>(callExpr->getArg(0)));

  // Perform Sin and store the result in argument 1.
  const uint32_t sin =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Sin,
                                    /*actPerRowForMatrices*/ true);
  theBuilder.createStore(doExpr(callExpr->getArg(1)), sin);

  // Perform Cos and store the result in argument 2.
  const uint32_t cos =
      processIntrinsicUsingGLSLInst(sincosExpr, GLSLstd450::GLSLstd450Cos,
                                    /*actPerRowForMatrices*/ true);
  theBuilder.createStore(doExpr(callExpr->getArg(2)), cos);
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicSaturate(const CallExpr *callExpr) {
  const auto *arg = callExpr->getArg(0);
  const auto argId = doExpr(arg);
  const auto argType = arg->getType();
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();

  if (argType->isFloatingType()) {
    const uint32_t floatZero = getValueZero(argType);
    const uint32_t floatOne = getValueOne(argType);
    return theBuilder.createExtInst(returnType, glslInstSetId,
                                    GLSLstd450::GLSLstd450FClamp,
                                    {argId, floatZero, floatOne});
  }

  QualType elemType = {};
  uint32_t vecSize = 0;
  if (TypeTranslator::isVectorType(argType, &elemType, &vecSize)) {
    const uint32_t vecZero = getVecValueZero(elemType, vecSize);
    const uint32_t vecOne = getVecValueOne(elemType, vecSize);
    return theBuilder.createExtInst(returnType, glslInstSetId,
                                    GLSLstd450::GLSLstd450FClamp,
                                    {argId, vecZero, vecOne});
  }

  uint32_t numRows = 0, numCols = 0;
  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
    const uint32_t vecZero = getVecValueZero(elemType, numCols);
    const uint32_t vecOne = getVecValueOne(elemType, numCols);
    const auto actOnEachVec = [this, vecZero, vecOne, glslInstSetId](
        uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
      return theBuilder.createExtInst(vecType, glslInstSetId,
                                      GLSLstd450::GLSLstd450FClamp,
                                      {curRowId, vecZero, vecOne});
    };
    return processEachVectorInMatrix(arg, argId, actOnEachVec);
  }

  emitError("Invalid argument type passed to saturate().");
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicFloatSign(const CallExpr *callExpr) {
  // Import the GLSL.std.450 extended instruction set.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const Expr *arg = callExpr->getArg(0);
  const QualType returnType = callExpr->getType();
  const QualType argType = arg->getType();
  assert(isFloatOrVecMatOfFloatType(argType));
  const uint32_t argTypeId = typeTranslator.translateType(argType);
  const uint32_t argId = doExpr(arg);
  uint32_t floatSignResultId = 0;

  // For matrices, we can perform the instruction on each vector of the matrix.
  if (TypeTranslator::isSpirvAcceptableMatrixType(argType)) {
    const auto actOnEachVec = [this, glslInstSetId](
        uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
      return theBuilder.createExtInst(vecType, glslInstSetId,
                                      GLSLstd450::GLSLstd450FSign, {curRowId});
    };
    floatSignResultId = processEachVectorInMatrix(arg, argId, actOnEachVec);
  } else {
    floatSignResultId = theBuilder.createExtInst(
        argTypeId, glslInstSetId, GLSLstd450::GLSLstd450FSign, {argId});
  }

  return castToInt(floatSignResultId, arg->getType(), returnType);
}

uint32_t SPIRVEmitter::processIntrinsicUsingSpirvInst(
    const CallExpr *callExpr, spv::Op opcode, bool actPerRowForMatrices) {
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    const uint32_t argId = doExpr(arg);

    // If the instruction cannot operate on matrices, perform it on each row
    // vector of the matrix instead.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg->getType())) {
      const auto actOnEachVec = [this, opcode](
          uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
        return theBuilder.createUnaryOp(opcode, vecType, {curRowId});
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec);
    }
    return theBuilder.createUnaryOp(opcode, returnType, {argId});
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));

    // If the instruction cannot operate on matrices, perform it on each row
    // vector of the matrix instead.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg0->getType())) {
      const auto actOnEachVec = [this, opcode, arg1Id](
          uint32_t index, uint32_t vecType, uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        return theBuilder.createBinaryOp(opcode, vecType, arg0RowId, arg1RowId);
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createBinaryOp(opcode, returnType, arg0Id, arg1Id);
  }

  emitError("Unsupported intrinsic function %0.")
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return 0;
}

uint32_t SPIRVEmitter::processIntrinsicUsingGLSLInst(
    const CallExpr *callExpr, GLSLstd450 opcode, bool actPerRowForMatrices) {
  // Import the GLSL.std.450 extended instruction set.
  const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
  const uint32_t returnType = typeTranslator.translateType(callExpr->getType());
  if (callExpr->getNumArgs() == 1u) {
    const Expr *arg = callExpr->getArg(0);
    const uint32_t argId = doExpr(arg);

    // If the instruction cannot operate on matrices, perform it on each row
    // vector of the matrix instead.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode](
          uint32_t /*index*/, uint32_t vecType, uint32_t curRowId) {
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {curRowId});
      };
      return processEachVectorInMatrix(arg, argId, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode, {argId});
  } else if (callExpr->getNumArgs() == 2u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));

    // If the instruction cannot operate on matrices, perform it on each row
    // vector of the matrix instead.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg0->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode, arg1Id](
          uint32_t index, uint32_t vecType, uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {arg0RowId, arg1RowId});
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode,
                                    {arg0Id, arg1Id});
  } else if (callExpr->getNumArgs() == 3u) {
    const Expr *arg0 = callExpr->getArg(0);
    const uint32_t arg0Id = doExpr(arg0);
    const uint32_t arg1Id = doExpr(callExpr->getArg(1));
    const uint32_t arg2Id = doExpr(callExpr->getArg(2));

    // If the instruction cannot operate on matrices, perform it on each row
    // vector of the matrix instead.
    if (actPerRowForMatrices &&
        TypeTranslator::isSpirvAcceptableMatrixType(arg0->getType())) {
      const auto actOnEachVec = [this, glslInstSetId, opcode, arg0Id, arg1Id,
                                 arg2Id](uint32_t index, uint32_t vecType,
                                         uint32_t arg0RowId) {
        const uint32_t arg1RowId =
            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
        const uint32_t arg2RowId =
            theBuilder.createCompositeExtract(vecType, arg2Id, {index});
        return theBuilder.createExtInst(vecType, glslInstSetId, opcode,
                                        {arg0RowId, arg1RowId, arg2RowId});
      };
      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
    }
    return theBuilder.createExtInst(returnType, glslInstSetId, opcode,
                                    {arg0Id, arg1Id, arg2Id});
  }

  emitError("Unsupported intrinsic function %0.")
      << cast<DeclRefExpr>(callExpr->getCallee())->getNameInfo().getAsString();
  return 0;
}
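
// For illustration, a two-argument GLSL instruction such as Pow applied to
// float2x3 operands takes the per-row path above: for each row, the matching
// row of the second argument is extracted and a vector-level OpExtInst is
// emitted (a sketch):
//   %a0 = OpCompositeExtract %v3float %arg1 0
//   %r0 = OpExtInst %v3float %glsl Pow %row0 %a0
// with the result rows reassembled by processEachVectorInMatrix.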

uint32_t SPIRVEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
  // Since there is no log10 instruction in SPIR-V, we can use:
  //   log10(x) = log2(x) * (1 / log2(10))
  //   1 / log2(10) = 0.30103
  const auto scale = theBuilder.getConstantFloat32(0.30103f);
  const auto log2 =
      processIntrinsicUsingGLSLInst(callExpr, GLSLstd450::GLSLstd450Log2, true);
  const auto returnType = callExpr->getType();
  const auto returnTypeId = typeTranslator.translateType(returnType);
  spv::Op scaleOp = TypeTranslator::isScalarType(returnType)
                        ? spv::Op::OpFMul
                        : TypeTranslator::isVectorType(returnType)
                              ? spv::Op::OpVectorTimesScalar
                              : spv::Op::OpMatrixTimesScalar;
  return theBuilder.createBinaryOp(scaleOp, returnTypeId, log2, scale);
}
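
// For illustration: log10(100.0) is emitted as Log2(100.0) * 0.30103, i.e.
// roughly 6.6439 * 0.30103 ≈ 2.0, matching the identity
// log10(x) = log2(x) / log2(10).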

uint32_t SPIRVEmitter::getValueZero(QualType type) {
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(type, &scalarType)) {
      if (scalarType->isSignedIntegerType()) {
        return theBuilder.getConstantInt32(0);
      }
      if (scalarType->isUnsignedIntegerType()) {
        return theBuilder.getConstantUint32(0);
      }
      if (scalarType->isFloatingType()) {
        return theBuilder.getConstantFloat32(0.0);
      }
    }
  }

  {
    QualType elemType = {};
    uint32_t size = {};
    if (TypeTranslator::isVectorType(type, &elemType, &size)) {
      return getVecValueZero(elemType, size);
    }
  }

  // TODO: Handle getValueZero for MxN matrices.
  emitError("getting value 0 for type '%0' unimplemented")
      << type.getAsString();
  return 0;
}

uint32_t SPIRVEmitter::getVecValueZero(QualType elemType, uint32_t size) {
  const uint32_t elemZeroId = getValueZero(elemType);
  if (size == 1)
    return elemZeroId;

  llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemZeroId);
  const uint32_t vecType =
      theBuilder.getVecType(typeTranslator.translateType(elemType), size);
  return theBuilder.getConstantComposite(vecType, elements);
}

uint32_t SPIRVEmitter::getValueOne(QualType type) {
  {
    QualType scalarType = {};
    if (TypeTranslator::isScalarType(type, &scalarType)) {
      // TODO: Support other types such as short, half, etc.
      if (scalarType->isSignedIntegerType()) {
        return theBuilder.getConstantInt32(1);
      }
      if (scalarType->isUnsignedIntegerType()) {
        return theBuilder.getConstantUint32(1);
      }
      if (const auto *builtinType = scalarType->getAs<BuiltinType>()) {
        // TODO: Add support for other types that are not covered yet.
        switch (builtinType->getKind()) {
        case BuiltinType::Double:
          return theBuilder.getConstantFloat64(1.0);
        case BuiltinType::Float:
          return theBuilder.getConstantFloat32(1.0);
        default:
          break;
        }
      }
    }
  }

  {
    QualType elemType = {};
    uint32_t size = {};
    if (TypeTranslator::isVectorType(type, &elemType, &size)) {
      return getVecValueOne(elemType, size);
    }
  }

  emitError("getting value 1 for type '%0' unimplemented") << type;
  return 0;
}

uint32_t SPIRVEmitter::getVecValueOne(QualType elemType, uint32_t size) {
  const uint32_t elemOneId = getValueOne(elemType);
  if (size == 1)
    return elemOneId;

  llvm::SmallVector<uint32_t, 4> elements(size_t(size), elemOneId);
  const uint32_t vecType =
      theBuilder.getVecType(typeTranslator.translateType(elemType), size);
  return theBuilder.getConstantComposite(vecType, elements);
}

uint32_t SPIRVEmitter::getMatElemValueOne(QualType type) {
  assert(hlsl::IsHLSLMatType(type));
  const auto elemType = hlsl::GetHLSLMatElementType(type);

  uint32_t rowCount = 0, colCount = 0;
  hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);

  if (rowCount == 1 && colCount == 1)
    return getValueOne(elemType);
  if (colCount == 1)
    return getVecValueOne(elemType, rowCount);
  return getVecValueOne(elemType, colCount);
}
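
// For illustration: for a float2x3 matrix this returns the constant
// float3(1, 1, 1) (one row's worth of ones); for a 1x1 matrix it degenerates
// to the scalar constant 1.0, and for an Mx1 matrix to a vector of M ones.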

uint32_t SPIRVEmitter::translateAPValue(const APValue &value,
                                        const QualType targetType) {
  if (targetType->isBooleanType()) {
    const bool boolValue = value.getInt().getBoolValue();
    return theBuilder.getConstantBool(boolValue);
  }

  if (targetType->isIntegerType()) {
    const llvm::APInt &intValue = value.getInt();
    return translateAPInt(intValue, targetType);
  }

  if (targetType->isFloatingType()) {
    const llvm::APFloat &floatValue = value.getFloat();
    return translateAPFloat(floatValue, targetType);
  }

  if (hlsl::IsHLSLVecType(targetType)) {
    const uint32_t vecType = typeTranslator.translateType(targetType);
    const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
    const auto numElements = value.getVectorLength();
    // Special case for vectors of size 1. SPIR-V doesn't support this vector
    // size so we need to translate it to scalar values.
    if (numElements == 1) {
      return translateAPValue(value.getVectorElt(0), elemType);
    }

    llvm::SmallVector<uint32_t, 4> elements;
    for (uint32_t i = 0; i < numElements; ++i) {
      elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
    }
    return theBuilder.getConstantComposite(vecType, elements);
  }

  emitError("APValue of type '%0' is not supported yet.") << value.getKind();
  value.dump();
  return 0;
}

uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
                                      QualType targetType) {
  if (targetType->isSignedIntegerType()) {
    // Try to see if this integer can be represented in 32 bits.
    if (intValue.isSignedIntN(32))
      return theBuilder.getConstantInt32(
          static_cast<int32_t>(intValue.getSExtValue()));
  } else {
    // Try to see if this integer can be represented in 32 bits.
    if (intValue.isIntN(32))
      return theBuilder.getConstantUint32(
          static_cast<uint32_t>(intValue.getZExtValue()));
  }

  emitError("APInt for target bitwidth '%0' is not supported yet.")
      << astContext.getIntWidth(targetType);
  return 0;
}

uint32_t SPIRVEmitter::translateAPFloat(const llvm::APFloat &floatValue,
                                        QualType targetType) {
  const auto &semantics = astContext.getFloatTypeSemantics(targetType);
  const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);

  switch (bitwidth) {
  case 32:
    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
  case 64:
    return theBuilder.getConstantFloat64(floatValue.convertToDouble());
  default:
    break;
  }

  emitError("APFloat for target bitwidth '%0' is not supported yet.")
      << bitwidth;
  return 0;
}

uint32_t SPIRVEmitter::tryToEvaluateAsConst(const Expr *expr) {
  Expr::EvalResult evalResult;
  if (expr->EvaluateAsRValue(evalResult, astContext) &&
      !evalResult.HasSideEffects) {
    return translateAPValue(evalResult.Val, expr->getType());
  }
  return 0;
}

spv::ExecutionModel
SPIRVEmitter::getSpirvShaderStage(const hlsl::ShaderModel &model) {
  // DXIL models are:
  // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Stage
  // vs_<version>         : Vertex Shader    : Vertex Shader
  // hs_<version>         : Hull Shader      : Tessellation Control Shader
  // ds_<version>         : Domain Shader    : Tessellation Evaluation Shader
  // gs_<version>         : Geometry Shader  : Geometry Shader
  // ps_<version>         : Pixel Shader     : Fragment Shader
  // cs_<version>         : Compute Shader   : Compute Shader
  switch (model.GetKind()) {
  case hlsl::ShaderModel::Kind::Vertex:
    return spv::ExecutionModel::Vertex;
  case hlsl::ShaderModel::Kind::Hull:
    return spv::ExecutionModel::TessellationControl;
  case hlsl::ShaderModel::Kind::Domain:
    return spv::ExecutionModel::TessellationEvaluation;
  case hlsl::ShaderModel::Kind::Geometry:
    return spv::ExecutionModel::Geometry;
  case hlsl::ShaderModel::Kind::Pixel:
    return spv::ExecutionModel::Fragment;
  case hlsl::ShaderModel::Kind::Compute:
    return spv::ExecutionModel::GLCompute;
  default:
    break;
  }
  llvm_unreachable("unknown shader model");
}

void SPIRVEmitter::AddRequiredCapabilitiesForShaderModel() {
  if (shaderModel.IsHS() || shaderModel.IsDS()) {
    theBuilder.requireCapability(spv::Capability::Tessellation);
  } else if (shaderModel.IsGS()) {
    theBuilder.requireCapability(spv::Capability::Geometry);
    emitError("Geometry shaders are currently not supported.");
  } else {
    theBuilder.requireCapability(spv::Capability::Shader);
  }
}

void SPIRVEmitter::AddExecutionModeForEntryPoint(uint32_t entryPointId) {
  if (shaderModel.IsPS()) {
    theBuilder.addExecutionMode(entryPointId,
                                spv::ExecutionMode::OriginUpperLeft, {});
  }
}
bool SPIRVEmitter::processHullShaderAttributes(
    const FunctionDecl *decl, uint32_t *numOutputControlPoints) {
  assert(shaderModel.IsHS());
  using namespace spv;

  if (auto *domain = decl->getAttr<HLSLDomainAttr>()) {
    const auto domainType = domain->getDomainType().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(domainType)
            .Case("tri", ExecutionMode::Triangles)
            .Case("quad", ExecutionMode::Quads)
            .Case("isoline", ExecutionMode::Isolines)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown domain type in hull shader", decl->getLocation());
      return false;
    }
    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
  }

  if (auto *partitioning = decl->getAttr<HLSLPartitioningAttr>()) {
    // TODO: Could not find an equivalent of the "pow2" partitioning scheme
    // in SPIR-V.
    const auto scheme = partitioning->getScheme().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(scheme)
            .Case("fractional_even", ExecutionMode::SpacingFractionalEven)
            .Case("fractional_odd", ExecutionMode::SpacingFractionalOdd)
            .Case("integer", ExecutionMode::SpacingEqual)
            .Default(ExecutionMode::Max);
    if (hsExecMode == ExecutionMode::Max) {
      emitError("unknown partitioning scheme in hull shader",
                decl->getLocation());
      return false;
    }
    theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
  }

  if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
    const auto topology = outputTopology->getTopology().lower();
    const ExecutionMode hsExecMode =
        llvm::StringSwitch<ExecutionMode>(topology)
            .Case("point", ExecutionMode::PointMode)
            .Case("triangle_cw", ExecutionMode::VertexOrderCw)
            .Case("triangle_ccw", ExecutionMode::VertexOrderCcw)
            .Default(ExecutionMode::Max);
    // TODO: There is no SPIR-V equivalent for the "line" topology. Is it the
    // default?
    if (topology != "line") {
      if (hsExecMode != ExecutionMode::Max) {
        theBuilder.addExecutionMode(entryFunctionId, hsExecMode, {});
      } else {
        emitError("unknown output topology in hull shader",
                  decl->getLocation());
        return false;
      }
    }
  }

  if (auto *controlPoints = decl->getAttr<HLSLOutputControlPointsAttr>()) {
    *numOutputControlPoints = controlPoints->getCount();
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::OutputVertices,
                                {*numOutputControlPoints});
  }

  return true;
}
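
// Emits a void(void) wrapper around the source entry function: it creates
// stage input variables and loads them into temporaries, calls the source
// entry function with those temporaries, and finally writes the return value
// and out/inout parameters back to stage output variables.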
bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
                                            const uint32_t entryFuncId) {
  // These are going to be used for Hull shaders only.
  uint32_t numOutputControlPoints = 0;
  uint32_t outputControlPointIdVal = 0;
  uint32_t primitiveIdVar = 0;
  uint32_t hullMainInputPatchParam = 0;

  // Construct the wrapper function signature.
  const uint32_t voidType = theBuilder.getVoidType();
  const uint32_t funcType = theBuilder.getFunctionType(voidType, {});

  // Unlike other functions, which get added to the work queue following
  // function calls and thus have pre-assigned <result-id>s, the wrapper is
  // created here on the fly. The wrapper is also the real entry function.
  entryFunctionId =
      theBuilder.beginFunction(funcType, voidType, decl->getName());
  declIdMapper.setEntryFunctionId(entryFunctionId);

  // Handle translation of the numthreads attribute for compute shaders.
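  // For example, '[numthreads(8, 8, 1)]' on the entry function becomes
  // 'OpExecutionMode %wrapper LocalSize 8 8 1' (illustrative disassembly;
  // the actual <result-id> name will differ); a missing attribute falls back
  // to a 1x1x1 local size below.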
  if (shaderModel.IsCS()) {
    // The numthreads attribute values are stored as (signed) integers. We
    // cast them to uint32_t to pass to the OpExecutionMode SPIR-V
    // instruction.
    if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
      theBuilder.addExecutionMode(
          entryFunctionId, spv::ExecutionMode::LocalSize,
          {static_cast<uint32_t>(numThreadsAttr->getX()),
           static_cast<uint32_t>(numThreadsAttr->getY()),
           static_cast<uint32_t>(numThreadsAttr->getZ())});
    } else {
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::LocalSize, {1, 1, 1});
    }
  } else if (shaderModel.IsHS()) {
    if (!processHullShaderAttributes(decl, &numOutputControlPoints))
      return false;
  }

  // The entry basic block.
  const uint32_t entryLabel = theBuilder.createBasicBlock();
  theBuilder.setInsertPoint(entryLabel);

  // Initialize all global variables at the beginning of the wrapper.
  for (const VarDecl *varDecl : toInitGloalVars)
    theBuilder.createStore(declIdMapper.getDeclResultId(varDecl),
                           doExpr(varDecl->getInit()));

  // Create temporary variables for holding function call arguments.
  llvm::SmallVector<uint32_t, 4> params;
  for (const auto *param : decl->params()) {
    const uint32_t typeId = typeTranslator.translateType(param->getType());
    std::string tempVarName = "param.var." + param->getNameAsString();
    const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);

    params.push_back(tempVar);

    // Create a stage input variable for each parameter that is not marked as
    // pure out, and initialize the corresponding temporary variable from it.
    if (!param->getAttr<HLSLOutAttr>()) {
      uint32_t loadedValue = 0;
      if (TypeTranslator::isInputPatch(param->getType())) {
        const uint32_t hullInputPatchId =
            declIdMapper.createStageVarWithoutSemantics(
                /*isInput*/ true, typeId, "hullEntryPointInput",
                decl->getAttr<VKLocationAttr>());
        loadedValue = theBuilder.createLoad(typeId, hullInputPatchId);
        hullMainInputPatchParam = tempVar;
      } else if (!declIdMapper.createStageInputVar(param, &loadedValue,
                                                   /*isPC*/ false)) {
        return false;
      }

      theBuilder.createStore(tempVar, loadedValue);

      if (hasSemantic(param, hlsl::DXIL::SemanticKind::OutputControlPointID))
        outputControlPointIdVal = loadedValue;
      if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID))
        primitiveIdVar = tempVar;
    }
  }

  // Call the original entry function.
  const uint32_t retType = typeTranslator.translateType(decl->getReturnType());
  const uint32_t retVal =
      theBuilder.createFunctionCall(retType, entryFuncId, params);

  // Create and write stage output variables for the return value. Hull
  // shaders are special-cased since they differ in two ways:
  // 1- Their return value is in fact an array, and each invocation should
  //    write to the proper offset in that array.
  // 2- The patch constant function must be called *once* after all
  //    invocations of the main entry point function are done.
  if (shaderModel.IsHS()) {
    if (!processHullEntryPointOutputAndPatchConstFunc(
            decl, retType, retVal, numOutputControlPoints,
            outputControlPointIdVal, primitiveIdVar, hullMainInputPatchParam))
      return false;
  } else {
    if (!declIdMapper.createStageOutputVar(decl, retVal, /*isPC*/ false))
      return false;
  }

  // Create and write stage output variables for parameters marked as
  // out/inout.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const auto *param = decl->getParamDecl(i);
    if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
      // Load the value from the parameter after the function call.
      const uint32_t typeId = typeTranslator.translateType(param->getType());
      const uint32_t loadedParam = theBuilder.createLoad(typeId, params[i]);
      if (!declIdMapper.createStageOutputVar(param, loadedParam,
                                             /*isPC*/ false))
        return false;
    }
  }

  theBuilder.createReturn();
  theBuilder.endFunction();

  // For Hull shaders, there is no explicit call to the PCF in the HLSL
  // source, so we must trigger its translation manually here.
  if (shaderModel.IsHS())
    doDecl(patchConstFunc);

  return true;
}

bool SPIRVEmitter::processHullEntryPointOutputAndPatchConstFunc(
    const FunctionDecl *hullMainFuncDecl, uint32_t retType, uint32_t retVal,
    uint32_t numOutputControlPoints, uint32_t outputControlPointId,
    uint32_t primitiveId, uint32_t hullMainInputPatch) {
  // This method may only be called for Hull shaders.
  assert(shaderModel.IsHS());

  uint32_t hullMainOutputPatch = 0;

  // For Hull shaders, the real output is an array of size
  // numOutputControlPoints. The result of the main function should be
  // written to the correct offset in that array (based on the InvocationID).
  if (!numOutputControlPoints) {
    emitError("number of output control points cannot be zero",
              hullMainFuncDecl->getLocation());
    return false;
  }

  // TODO: We should be able to handle cases where the SV_OutputControlPointID
  // is not provided.
  if (!outputControlPointId) {
    emitError(
        "SV_OutputControlPointID semantic must be provided in the hull shader",
        hullMainFuncDecl->getLocation());
    return false;
  }

  if (!patchConstFunc) {
    emitError("patch constant function not defined in hull shader",
              hullMainFuncDecl->getLocation());
    return false;
  }

  // Let's call the return value of the Hull entry point function
  // "hullEntryPointOutput". Its type should be an array of size
  // numOutputControlPoints.
  const uint32_t hullEntryPointOutputType = theBuilder.getArrayType(
      retType, theBuilder.getConstantUint32(numOutputControlPoints));
  const auto loc = hullMainFuncDecl->getAttr<VKLocationAttr>();
  const auto hullOutputVar = declIdMapper.createStageVarWithoutSemantics(
      /*isInput*/ false, hullEntryPointOutputType, "hullEntryPointOutput",
      loc);
  if (!hullOutputVar)
    return false;

  // Write the results into the correct offset of the Output array.
  const auto location = theBuilder.createAccessChain(
      theBuilder.getPointerType(retType, spv::StorageClass::Output),
      hullOutputVar, {outputControlPointId});
  theBuilder.createStore(location, retVal);

  // If the patch constant function (PCF) takes the result of the Hull main
  // entry point, create a temporary function-scope variable and write the
  // results to it, so it can be passed to the PCF.
  if (patchConstFuncTakesHullOutputPatch(patchConstFunc)) {
    hullMainOutputPatch = theBuilder.addFnVar(hullEntryPointOutputType,
                                              "temp.var.hullEntryPointOutput");
    const auto tempLocation = theBuilder.createAccessChain(
        theBuilder.getPointerType(retType, spv::StorageClass::Function),
        hullMainOutputPatch, {outputControlPointId});
    theBuilder.createStore(tempLocation, retVal);
  }

  // Now create a barrier before calling the Patch Constant Function (PCF).
  // Flags are:
  //   Execution Barrier scope        = Workgroup (2)
  //   Memory Barrier scope           = Device (1)
  //   Memory Semantics Barrier scope = None (0)
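  // That is, the generated instruction is (illustrative disassembly):
  //   OpControlBarrier %uint_2 %uint_1 %uint_0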
  theBuilder.createControlBarrier(theBuilder.getConstantUint32(2),
                                  theBuilder.getConstantUint32(1),
                                  theBuilder.getConstantUint32(0));

  // The PCF should be called only once. Therefore, we check the InvocationID
  // and only allow invocation 0 to call the PCF.
  const uint32_t condition = theBuilder.createBinaryOp(
      spv::Op::OpIEqual, theBuilder.getBoolType(), outputControlPointId,
      theBuilder.getConstantUint32(0));
  const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  theBuilder.createConditionalBranch(condition, thenBB, mergeBB, mergeBB);
  theBuilder.addSuccessor(thenBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setMergeTarget(mergeBB);

  theBuilder.setInsertPoint(thenBB);

  // Call the PCF. Since the function is not explicitly called, we must first
  // register a <result-id> for it.
  const uint32_t pcfId = declIdMapper.getOrRegisterFnResultId(patchConstFunc);
  const uint32_t pcfRetType =
      typeTranslator.translateType(patchConstFunc->getReturnType());

  std::vector<uint32_t> pcfParams;
  for (const auto *param : patchConstFunc->parameters()) {
    // Note: According to the HLSL reference, the PCF takes an InputPatch of
    // ControlPoints as well as the PatchID (PrimitiveID). This does not
    // necessarily mean that they are present. There is also no requirement
    // on the order of parameters passed to the PCF.
    if (TypeTranslator::isInputPatch(param->getType()))
      pcfParams.push_back(hullMainInputPatch);
    if (TypeTranslator::isOutputPatch(param->getType()))
      pcfParams.push_back(hullMainOutputPatch);
    if (hasSemantic(param, hlsl::DXIL::SemanticKind::PrimitiveID)) {
      if (!primitiveId) {
        const uint32_t typeId = typeTranslator.translateType(param->getType());
        std::string tempVarName = "param.var." + param->getNameAsString();
        const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
        uint32_t loadedValue = 0;
        declIdMapper.createStageInputVar(param, &loadedValue, /*isPC*/ true);
        theBuilder.createStore(tempVar, loadedValue);
        primitiveId = tempVar;
      }
      pcfParams.push_back(primitiveId);
    }
  }

  const uint32_t pcfResultId =
      theBuilder.createFunctionCall(pcfRetType, pcfId, pcfParams);
  if (!declIdMapper.createStageOutputVar(patchConstFunc, pcfResultId,
                                         /*isPC*/ true))
    return false;

  theBuilder.createBranch(mergeBB);
  theBuilder.addSuccessor(mergeBB);
  theBuilder.setInsertPoint(mergeBB);
  return true;
}

bool SPIRVEmitter::allSwitchCasesAreIntegerLiterals(const Stmt *root) {
  if (!root)
    return false;

  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  if (!caseStmt && !compoundStmt)
    return true;

  if (caseStmt) {
    const Expr *caseExpr = caseStmt->getLHS();
    return caseExpr && caseExpr->isEvaluatable(astContext);
  }

  // Recurse down if facing a compound statement.
  for (auto *st : compoundStmt->body())
    if (!allSwitchCasesAreIntegerLiterals(st))
      return false;

  return true;
}

void SPIRVEmitter::discoverAllCaseStmtInSwitchStmt(
    const Stmt *root, uint32_t *defaultBB,
    std::vector<std::pair<uint32_t, uint32_t>> *targets) {
  if (!root)
    return;

  // A switch case can only appear in a DefaultStmt, CaseStmt, or
  // CompoundStmt. For anything else, we can just return.
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  if (!defaultStmt && !caseStmt && !compoundStmt)
    return;

  // Recurse down if facing a compound statement.
  if (compoundStmt) {
    for (auto *st : compoundStmt->body())
      discoverAllCaseStmtInSwitchStmt(st, defaultBB, targets);
    return;
  }

  std::string caseLabel;
  uint32_t caseValue = 0;
  if (defaultStmt) {
    // This is the default branch.
    caseLabel = "switch.default";
  } else if (caseStmt) {
    // This is a non-default case.
    // When using OpSwitch, we only allow integer literal cases, e.g.:
    //   case <literal_integer>: {...; break;}
    const Expr *caseExpr = caseStmt->getLHS();
    assert(caseExpr && caseExpr->isEvaluatable(astContext));
    auto bitWidth = astContext.getIntWidth(caseExpr->getType());
    if (bitWidth != 32)
      emitError("Switch statement translation currently only supports 32-bit "
                "integer case values.");
    Expr::EvalResult evalResult;
    caseExpr->EvaluateAsRValue(evalResult, astContext);
    const int64_t value = evalResult.Val.getInt().getSExtValue();
    caseValue = static_cast<uint32_t>(value);
    caseLabel = "switch." + std::string(value < 0 ? "n" : "") +
                llvm::itostr(std::abs(value));
  }

  const uint32_t caseBB = theBuilder.createBasicBlock(caseLabel);
  theBuilder.addSuccessor(caseBB);
  stmtBasicBlock[root] = caseBB;

  // Add all non-default cases to the 'targets' vector.
  if (caseStmt)
    targets->emplace_back(caseValue, caseBB);

  // The default label is not part of the 'targets' vector that is passed
  // to the OpSwitch instruction. If a default statement was discovered,
  // return its label via defaultBB instead.
  if (defaultStmt)
    *defaultBB = caseBB;

  // Process cases nested in other cases, which happens with fall-through
  // cases. For example:
  //   case 1: case 2: ...; break;
  // results in the CaseStmt for case 2 being nested inside the one for
  // case 1.
  discoverAllCaseStmtInSwitchStmt(caseStmt ? caseStmt->getSubStmt()
                                           : defaultStmt->getSubStmt(),
                                  defaultBB, targets);
}

void SPIRVEmitter::flattenSwitchStmtAST(const Stmt *root,
                                        std::vector<const Stmt *> *flatSwitch) {
  const auto *caseStmt = dyn_cast<CaseStmt>(root);
  const auto *compoundStmt = dyn_cast<CompoundStmt>(root);
  const auto *defaultStmt = dyn_cast<DefaultStmt>(root);

  if (!compoundStmt) {
    flatSwitch->push_back(root);
  }

  if (compoundStmt) {
    for (const auto *st : compoundStmt->body())
      flattenSwitchStmtAST(st, flatSwitch);
  } else if (caseStmt) {
    flattenSwitchStmtAST(caseStmt->getSubStmt(), flatSwitch);
  } else if (defaultStmt) {
    flattenSwitchStmtAST(defaultStmt->getSubStmt(), flatSwitch);
  }
}

void SPIRVEmitter::processCaseStmtOrDefaultStmt(const Stmt *stmt) {
  auto *caseStmt = dyn_cast<CaseStmt>(stmt);
  auto *defaultStmt = dyn_cast<DefaultStmt>(stmt);
  assert(caseStmt || defaultStmt);

  uint32_t caseBB = stmtBasicBlock[stmt];
  if (!theBuilder.isCurrentBasicBlockTerminated()) {
    // We are about to handle the case passed in as a parameter. If the
    // current basic block is not terminated, the previous case was a
    // fall-through case, and we need to link it to the case about to be
    // processed.
    theBuilder.createBranch(caseBB);
    theBuilder.addSuccessor(caseBB);
  }
  theBuilder.setInsertPoint(caseBB);
  doStmt(caseStmt ? caseStmt->getSubStmt() : defaultStmt->getSubStmt());
}
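
// Translates a SwitchStmt directly into SPIR-V's OpSwitch. For example,
// 'switch (a) { case 1: ...; break; default: ...; }' produces, roughly
// (illustrative disassembly; the actual label names will differ):
//
//   OpSelectionMerge %switch_merge None
//   OpSwitch %a %switch_default 1 %switch_1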
void SPIRVEmitter::processSwitchStmtUsingSpirvOpSwitch(
    const SwitchStmt *switchStmt) {
  // First handle the condition variable DeclStmt if one exists.
  // For example, handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  const uint32_t selector = doExpr(switchStmt->getCond());

  // We need a merge block regardless of the number of switch cases.
  // Since OpSwitch always requires a default label, if the switch statement
  // does not have a default branch, we use the merge block as the default
  // target.
  const uint32_t mergeBB = theBuilder.createBasicBlock("switch.merge");
  theBuilder.setMergeTarget(mergeBB);
  breakStack.push(mergeBB);
  uint32_t defaultBB = mergeBB;

  // (literal, labelId) pairs to pass to the OpSwitch instruction.
  std::vector<std::pair<uint32_t, uint32_t>> targets;
  discoverAllCaseStmtInSwitchStmt(switchStmt->getBody(), &defaultBB, &targets);

  // Create the OpSelectionMerge and OpSwitch.
  theBuilder.createSwitch(mergeBB, selector, defaultBB, targets);

  // Handle the switch body.
  doStmt(switchStmt->getBody());

  if (!theBuilder.isCurrentBasicBlockTerminated())
    theBuilder.createBranch(mergeBB);
  theBuilder.setInsertPoint(mergeBB);
  breakStack.pop();
}
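
// Emulates a SwitchStmt with a chain of if statements built on the fly, for
// switches that cannot be translated with OpSwitch (e.g., when not all case
// values are integer literals; cf. allSwitchCasesAreIntegerLiterals above).
// For example:
//
//   switch (x) { case 1: A; break; case 2: B; break; default: C; }
//
// is processed as if the source had been:
//
//   if (x == 1) { A; } else if (x == 2) { B; } else { C; }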
void SPIRVEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
  std::vector<const Stmt *> flatSwitch;
  flattenSwitchStmtAST(switchStmt->getBody(), &flatSwitch);

  // First handle the condition variable DeclStmt if one exists.
  // For example, handle 'int a = b' in the following:
  //   switch (int a = b) {...}
  if (const auto *condVarDeclStmt = switchStmt->getConditionVariableDeclStmt())
    doDeclStmt(condVarDeclStmt);

  // Figure out the indexes of CaseStmts (and the DefaultStmt if it exists) in
  // the flattened switch AST.
  // For instance, for the following flat vector:
  // +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  // |Case1|Stmt1|Case2|Stmt2|Break|Case3|Case4|Stmt4|Break|Default|Stmt5|
  // +-----+-----+-----+-----+-----+-----+-----+-----+-----+-------+-----+
  // the indexes are: {0, 2, 5, 6, 9}.
  std::vector<uint32_t> caseStmtLocs;
  for (uint32_t i = 0; i < flatSwitch.size(); ++i)
    if (isa<CaseStmt>(flatSwitch[i]) || isa<DefaultStmt>(flatSwitch[i]))
      caseStmtLocs.push_back(i);

  IfStmt *prevIfStmt = nullptr;
  IfStmt *rootIfStmt = nullptr;
  CompoundStmt *defaultBody = nullptr;

  // For each case, start at its index in the vector, and go forward
  // accumulating statements until a BreakStmt or the end of the vector is
  // reached.
  for (auto curCaseIndex : caseStmtLocs) {
    const Stmt *curCase = flatSwitch[curCaseIndex];

    // CompoundStmt to hold all statements for this case.
    CompoundStmt *cs = new (astContext) CompoundStmt(Stmt::EmptyShell());

    // Accumulate all non-case/default/break statements as the body for the
    // current case.
    std::vector<Stmt *> statements;
    for (size_t i = curCaseIndex + 1;
         i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]); ++i) {
      if (!isa<CaseStmt>(flatSwitch[i]) && !isa<DefaultStmt>(flatSwitch[i]))
        statements.push_back(const_cast<Stmt *>(flatSwitch[i]));
    }
    if (!statements.empty())
      cs->setStmts(astContext, statements.data(), statements.size());

    // For non-default cases, generate the IfStmt that compares the switch
    // value to the case value.
    if (auto *caseStmt = dyn_cast<CaseStmt>(curCase)) {
      IfStmt *curIf = new (astContext) IfStmt(Stmt::EmptyShell());
      BinaryOperator *bo = new (astContext) BinaryOperator(Stmt::EmptyShell());
      bo->setLHS(const_cast<Expr *>(switchStmt->getCond()));
      bo->setRHS(const_cast<Expr *>(caseStmt->getLHS()));
      bo->setOpcode(BO_EQ);
      bo->setType(astContext.getLogicalOperationType());
      curIf->setCond(bo);
      curIf->setThen(cs);
      // No condition variable is associated with this faux if statement.
      curIf->setConditionVariable(astContext, nullptr);
      // Each if statement is the "else" of the previous if statement.
      if (prevIfStmt)
        prevIfStmt->setElse(curIf);
      else
        rootIfStmt = curIf;
      prevIfStmt = curIf;
    } else {
      // Record the DefaultStmt body as it will be used as the body of the
      // "else" block in the if-elseif-...-else pattern.
      defaultBody = cs;
    }
  }

  // If a default case exists, it is the "else" of the last if statement.
  if (prevIfStmt)
    prevIfStmt->setElse(defaultBody);

  // Since all else-if and else statements are child nodes of the first
  // IfStmt, we only need to call doStmt on the first IfStmt.
  if (rootIfStmt)
    doStmt(rootIfStmt);
  // If there is no CaseStmt and only one DefaultStmt, there will be no if
  // statements. In that case, the switch only executes the body of the
  // default case.
  else if (defaultBody)
    doStmt(defaultBody);
}

} // end namespace spirv
} // end namespace clang