DeclResultIdMapper.cpp 147 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824
  1. //===--- DeclResultIdMapper.cpp - DeclResultIdMapper impl --------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "DeclResultIdMapper.h"
  10. #include <sstream>
  11. #include "dxc/DXIL/DxilConstants.h"
  12. #include "dxc/DXIL/DxilTypeSystem.h"
  13. #include "clang/AST/Expr.h"
  14. #include "clang/AST/HlslTypes.h"
  15. #include "clang/SPIRV/AstTypeProbe.h"
  16. #include "llvm/ADT/SmallBitVector.h"
  17. #include "llvm/ADT/StringMap.h"
  18. #include "llvm/ADT/StringSet.h"
  19. #include "llvm/Support/Casting.h"
  20. #include "AlignmentSizeCalculator.h"
  21. #include "SpirvEmitter.h"
  22. namespace clang {
  23. namespace spirv {
  24. namespace {
  25. uint32_t getVkBindingAttrSet(const VKBindingAttr *attr, uint32_t defaultSet) {
  26. // If the [[vk::binding(x)]] attribute is provided without the descriptor set,
  27. // we should use the default descriptor set.
  28. if (attr->getSet() == INT_MIN) {
  29. return defaultSet;
  30. }
  31. return attr->getSet();
  32. }
  33. /// Returns the :packoffset() annotation on the given decl. Returns nullptr if
  34. /// the decl does not have one.
  35. hlsl::ConstantPacking *getPackOffset(const clang::NamedDecl *decl) {
  36. for (auto *annotation : decl->getUnusualAnnotations())
  37. if (auto *packing = llvm::dyn_cast<hlsl::ConstantPacking>(annotation))
  38. return packing;
  39. return nullptr;
  40. }
  41. /// Returns the number of binding numbers that are used up by the given type.
  42. /// An array of size N consumes N*M binding numbers where M is the number of
  43. /// binding numbers used by each array element.
  44. /// The number of binding numbers used by a structure is the sum of binding
  45. /// numbers used by its members.
  46. uint32_t getNumBindingsUsedByResourceType(QualType type) {
  47. // For custom-generated types that have SpirvType but no QualType.
  48. if (type.isNull())
  49. return 1;
  50. // For every array dimension, the number of bindings needed should be
  51. // multiplied by the array size. For example: an array of two Textures should
  52. // use 2 binding slots.
  53. uint32_t arrayFactor = 1;
  54. while (auto constArrayType = dyn_cast<ConstantArrayType>(type)) {
  55. arrayFactor *=
  56. static_cast<uint32_t>(constArrayType->getSize().getZExtValue());
  57. type = constArrayType->getElementType();
  58. }
  59. // Once we remove the arrayness, we expect the given type to be either a
  60. // resource OR a structure that only contains resources.
  61. assert(isResourceType(type) || isResourceOnlyStructure(type));
  62. // In the case of a resource, each resource takes 1 binding slot, so in total
  63. // it consumes: 1 * arrayFactor.
  64. if (isResourceType(type))
  65. return arrayFactor;
  66. // In the case of a struct of resources, we need to sum up the number of
  67. // bindings for the struct members. So in total it consumes:
  68. // sum(bindings of struct members) * arrayFactor.
  69. if (isResourceOnlyStructure(type)) {
  70. uint32_t sumOfMemberBindings = 0;
  71. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  72. assert(structDecl);
  73. for (const auto *field : structDecl->fields())
  74. sumOfMemberBindings += getNumBindingsUsedByResourceType(field->getType());
  75. return sumOfMemberBindings * arrayFactor;
  76. }
  77. llvm_unreachable(
  78. "getNumBindingsUsedByResourceType was called with unknown resource type");
  79. }
  80. QualType getUintTypeWithSourceComponents(const ASTContext &astContext,
  81. QualType sourceType) {
  82. if (isScalarType(sourceType)) {
  83. return astContext.UnsignedIntTy;
  84. }
  85. uint32_t elemCount = 0;
  86. if (isVectorType(sourceType, nullptr, &elemCount)) {
  87. return astContext.getExtVectorType(astContext.UnsignedIntTy, elemCount);
  88. }
  89. llvm_unreachable("only scalar and vector types are supported in "
  90. "getUintTypeWithSourceComponents");
  91. }
  92. uint32_t getLocationCount(const ASTContext &astContext, QualType type) {
  93. // See Vulkan spec 14.1.4. Location Assignment for the complete set of rules.
  94. const auto canonicalType = type.getCanonicalType();
  95. if (canonicalType != type)
  96. return getLocationCount(astContext, canonicalType);
  97. // Inputs and outputs of the following types consume a single interface
  98. // location:
  99. // * 16-bit scalar and vector types, and
  100. // * 32-bit scalar and vector types, and
  101. // * 64-bit scalar and 2-component vector types.
  102. // 64-bit three- and four- component vectors consume two consecutive
  103. // locations.
  104. // Primitive types
  105. if (isScalarType(type))
  106. return 1;
  107. // Vector types
  108. {
  109. QualType elemType = {};
  110. uint32_t elemCount = {};
  111. if (isVectorType(type, &elemType, &elemCount)) {
  112. const auto *builtinType = elemType->getAs<BuiltinType>();
  113. switch (builtinType->getKind()) {
  114. case BuiltinType::Double:
  115. case BuiltinType::LongLong:
  116. case BuiltinType::ULongLong:
  117. if (elemCount >= 3)
  118. return 2;
  119. default:
  120. // Filter switch only interested in types occupying 2 locations.
  121. break;
  122. }
  123. return 1;
  124. }
  125. }
  126. // If the declared input or output is an n * m 16- , 32- or 64- bit matrix,
  127. // it will be assigned multiple locations starting with the location
  128. // specified. The number of locations assigned for each matrix will be the
  129. // same as for an n-element array of m-component vectors.
  130. // Matrix types
  131. {
  132. QualType elemType = {};
  133. uint32_t rowCount = 0, colCount = 0;
  134. if (isMxNMatrix(type, &elemType, &rowCount, &colCount))
  135. return getLocationCount(astContext,
  136. astContext.getExtVectorType(elemType, colCount)) *
  137. rowCount;
  138. }
  139. // Typedefs
  140. if (const auto *typedefType = type->getAs<TypedefType>())
  141. return getLocationCount(astContext, typedefType->desugar());
  142. // Reference types
  143. if (const auto *refType = type->getAs<ReferenceType>())
  144. return getLocationCount(astContext, refType->getPointeeType());
  145. // Pointer types
  146. if (const auto *ptrType = type->getAs<PointerType>())
  147. return getLocationCount(astContext, ptrType->getPointeeType());
  148. // If a declared input or output is an array of size n and each element takes
  149. // m locations, it will be assigned m * n consecutive locations starting with
  150. // the location specified.
  151. // Array types
  152. if (const auto *arrayType = astContext.getAsConstantArrayType(type))
  153. return getLocationCount(astContext, arrayType->getElementType()) *
  154. static_cast<uint32_t>(arrayType->getSize().getZExtValue());
  155. // Struct type
  156. if (type->getAs<RecordType>()) {
  157. assert(false && "all structs should already be flattened");
  158. return 0;
  159. }
  160. llvm_unreachable(
  161. "calculating number of occupied locations for type unimplemented");
  162. return 0;
  163. }
  164. bool shouldSkipInStructLayout(const Decl *decl) {
  165. // Ignore implicit generated struct declarations/constructors/destructors
  166. if (decl->isImplicit())
  167. return true;
  168. // Ignore embedded type decls
  169. if (isa<TypeDecl>(decl))
  170. return true;
  171. // Ignore embeded function decls
  172. if (isa<FunctionDecl>(decl))
  173. return true;
  174. // Ignore empty decls
  175. if (isa<EmptyDecl>(decl))
  176. return true;
  177. // For the $Globals cbuffer, we only care about externally-visible
  178. // non-resource-type variables. The rest should be filtered out.
  179. const auto *declContext = decl->getDeclContext();
  180. // $Globals' "struct" is the TranslationUnit, so we should ignore resources
  181. // in the TranslationUnit "struct" and its child namespaces.
  182. if (declContext->isTranslationUnit() || declContext->isNamespace()) {
  183. // External visibility
  184. if (const auto *declDecl = dyn_cast<DeclaratorDecl>(decl))
  185. if (!declDecl->hasExternalFormalLinkage())
  186. return true;
  187. // cbuffer/tbuffer
  188. if (isa<HLSLBufferDecl>(decl))
  189. return true;
  190. // 'groupshared' variables should not be placed in $Globals cbuffer.
  191. if (decl->hasAttr<HLSLGroupSharedAttr>())
  192. return true;
  193. // Other resource types
  194. if (const auto *valueDecl = dyn_cast<ValueDecl>(decl)) {
  195. const auto declType = valueDecl->getType();
  196. if (isResourceType(declType) || isResourceOnlyStructure(declType))
  197. return true;
  198. }
  199. }
  200. return false;
  201. }
  202. void collectDeclsInField(const Decl *field,
  203. llvm::SmallVector<const Decl *, 4> *decls) {
  204. // Case of nested namespaces.
  205. if (const auto *nsDecl = dyn_cast<NamespaceDecl>(field)) {
  206. for (const auto *decl : nsDecl->decls()) {
  207. collectDeclsInField(decl, decls);
  208. }
  209. }
  210. if (shouldSkipInStructLayout(field))
  211. return;
  212. if (!isa<DeclaratorDecl>(field)) {
  213. return;
  214. }
  215. decls->push_back(field);
  216. }
  217. llvm::SmallVector<const Decl *, 4>
  218. collectDeclsInDeclContext(const DeclContext *declContext) {
  219. llvm::SmallVector<const Decl *, 4> decls;
  220. for (const auto *field : declContext->decls()) {
  221. collectDeclsInField(field, &decls);
  222. }
  223. return decls;
  224. }
  225. /// \brief Returns true if the given decl is a boolean stage I/O variable.
  226. /// Returns false if the type is not boolean, or the decl is a built-in stage
  227. /// variable.
  228. bool isBooleanStageIOVar(const NamedDecl *decl, QualType type,
  229. const hlsl::DXIL::SemanticKind semanticKind,
  230. const hlsl::SigPoint::Kind sigPointKind) {
  231. // [[vk::builtin(...)]] makes the decl a built-in stage variable.
  232. // IsFrontFace (if used as PSIn) is the only known boolean built-in stage
  233. // variable.
  234. const bool isBooleanBuiltin =
  235. (decl->getAttr<VKBuiltInAttr>() != nullptr) ||
  236. (semanticKind == hlsl::Semantic::Kind::IsFrontFace &&
  237. sigPointKind == hlsl::SigPoint::Kind::PSIn);
  238. // TODO: support boolean matrix stage I/O variable if needed.
  239. QualType elemType = {};
  240. const bool isBooleanType =
  241. ((isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
  242. elemType->isBooleanType());
  243. return isBooleanType && !isBooleanBuiltin;
  244. }
  245. /// \brief Returns the stage variable's register assignment for the given Decl.
  246. const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
  247. for (auto *annotation : decl->getUnusualAnnotations()) {
  248. if (auto *reg = dyn_cast<hlsl::RegisterAssignment>(annotation)) {
  249. return reg;
  250. }
  251. }
  252. return nullptr;
  253. }
  254. /// \brief Returns the stage variable's 'register(c#) assignment for the given
  255. /// Decl. Return nullptr if the given variable does not have such assignment.
  256. const hlsl::RegisterAssignment *getRegisterCAssignment(const NamedDecl *decl) {
  257. const auto *regAssignment = getResourceBinding(decl);
  258. if (regAssignment)
  259. return regAssignment->RegisterType == 'c' ? regAssignment : nullptr;
  260. return nullptr;
  261. }
  262. /// \brief Returns true if the given declaration has a primitive type qualifier.
  263. /// Returns false otherwise.
  264. inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
  265. return decl->hasAttr<HLSLTriangleAttr>() ||
  266. decl->hasAttr<HLSLTriangleAdjAttr>() ||
  267. decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
  268. decl->hasAttr<HLSLLineAdjAttr>();
  269. }
  270. /// \brief Deduces the parameter qualifier for the given decl.
  271. hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
  272. bool asInput) {
  273. const auto type = decl->getType();
  274. if (hlsl::IsHLSLInputPatchType(type))
  275. return hlsl::DxilParamInputQual::InputPatch;
  276. if (hlsl::IsHLSLOutputPatchType(type))
  277. return hlsl::DxilParamInputQual::OutputPatch;
  278. // TODO: Add support for multiple output streams.
  279. if (hlsl::IsHLSLStreamOutputType(type))
  280. return hlsl::DxilParamInputQual::OutStream0;
  281. // The inputs to the geometry shader that have a primitive type qualifier
  282. // must use 'InputPrimitive'.
  283. if (hasGSPrimitiveTypeQualifier(decl))
  284. return hlsl::DxilParamInputQual::InputPrimitive;
  285. if (decl->hasAttr<HLSLIndicesAttr>())
  286. return hlsl::DxilParamInputQual::OutIndices;
  287. if (decl->hasAttr<HLSLVerticesAttr>())
  288. return hlsl::DxilParamInputQual::OutVertices;
  289. if (decl->hasAttr<HLSLPrimitivesAttr>())
  290. return hlsl::DxilParamInputQual::OutPrimitives;
  291. if (decl->hasAttr<HLSLPayloadAttr>())
  292. return hlsl::DxilParamInputQual::InPayload;
  293. return asInput ? hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out;
  294. }
  295. /// \brief Deduces the HLSL SigPoint for the given decl appearing in the given
  296. /// shader model.
  297. const hlsl::SigPoint *deduceSigPoint(const DeclaratorDecl *decl, bool asInput,
  298. const hlsl::ShaderModel::Kind kind,
  299. bool forPCF) {
  300. return hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
  301. deduceParamQual(decl, asInput), kind, forPCF));
  302. }
  303. /// Returns the type of the given decl. If the given decl is a FunctionDecl,
  304. /// returns its result type.
  305. inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
  306. if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
  307. return funcDecl->getReturnType();
  308. }
  309. return decl->getType();
  310. }
  311. /// Returns the number of base classes if this type is a derived class/struct.
  312. /// Returns zero otherwise.
  313. inline uint32_t getNumBaseClasses(QualType type) {
  314. if (const auto *cxxDecl = type->getAsCXXRecordDecl())
  315. return cxxDecl->getNumBases();
  316. return 0;
  317. }
  318. /// Returns the appropriate storage class for an extern variable of the given
  319. /// type.
  320. spv::StorageClass getStorageClassForExternVar(QualType type,
  321. bool hasGroupsharedAttr) {
  322. // For CS groupshared variables
  323. if (hasGroupsharedAttr)
  324. return spv::StorageClass::Workgroup;
  325. if (isAKindOfStructuredOrByteBuffer(type))
  326. return spv::StorageClass::Uniform;
  327. return spv::StorageClass::UniformConstant;
  328. }
  329. /// Returns the appropriate layout rule for an extern variable of the given
  330. /// type.
  331. SpirvLayoutRule getLayoutRuleForExternVar(QualType type,
  332. const SpirvCodeGenOptions &opts) {
  333. if (isAKindOfStructuredOrByteBuffer(type))
  334. return opts.sBufferLayoutRule;
  335. return SpirvLayoutRule::Void;
  336. }
  337. spv::ImageFormat getSpvImageFormat(const VKImageFormatAttr *imageFormatAttr) {
  338. if (imageFormatAttr == nullptr)
  339. return spv::ImageFormat::Unknown;
  340. switch (imageFormatAttr->getImageFormat()) {
  341. case VKImageFormatAttr::unknown:
  342. return spv::ImageFormat::Unknown;
  343. case VKImageFormatAttr::rgba32f:
  344. return spv::ImageFormat::Rgba32f;
  345. case VKImageFormatAttr::rgba16f:
  346. return spv::ImageFormat::Rgba16f;
  347. case VKImageFormatAttr::r32f:
  348. return spv::ImageFormat::R32f;
  349. case VKImageFormatAttr::rgba8:
  350. return spv::ImageFormat::Rgba8;
  351. case VKImageFormatAttr::rgba8snorm:
  352. return spv::ImageFormat::Rgba8Snorm;
  353. case VKImageFormatAttr::rg32f:
  354. return spv::ImageFormat::Rg32f;
  355. case VKImageFormatAttr::rg16f:
  356. return spv::ImageFormat::Rg16f;
  357. case VKImageFormatAttr::r11g11b10f:
  358. return spv::ImageFormat::R11fG11fB10f;
  359. case VKImageFormatAttr::r16f:
  360. return spv::ImageFormat::R16f;
  361. case VKImageFormatAttr::rgba16:
  362. return spv::ImageFormat::Rgba16;
  363. case VKImageFormatAttr::rgb10a2:
  364. return spv::ImageFormat::Rgb10A2;
  365. case VKImageFormatAttr::rg16:
  366. return spv::ImageFormat::Rg16;
  367. case VKImageFormatAttr::rg8:
  368. return spv::ImageFormat::Rg8;
  369. case VKImageFormatAttr::r16:
  370. return spv::ImageFormat::R16;
  371. case VKImageFormatAttr::r8:
  372. return spv::ImageFormat::R8;
  373. case VKImageFormatAttr::rgba16snorm:
  374. return spv::ImageFormat::Rgba16Snorm;
  375. case VKImageFormatAttr::rg16snorm:
  376. return spv::ImageFormat::Rg16Snorm;
  377. case VKImageFormatAttr::rg8snorm:
  378. return spv::ImageFormat::Rg8Snorm;
  379. case VKImageFormatAttr::r16snorm:
  380. return spv::ImageFormat::R16Snorm;
  381. case VKImageFormatAttr::r8snorm:
  382. return spv::ImageFormat::R8Snorm;
  383. case VKImageFormatAttr::rgba32i:
  384. return spv::ImageFormat::Rgba32i;
  385. case VKImageFormatAttr::rgba16i:
  386. return spv::ImageFormat::Rgba16i;
  387. case VKImageFormatAttr::rgba8i:
  388. return spv::ImageFormat::Rgba8i;
  389. case VKImageFormatAttr::r32i:
  390. return spv::ImageFormat::R32i;
  391. case VKImageFormatAttr::rg32i:
  392. return spv::ImageFormat::Rg32i;
  393. case VKImageFormatAttr::rg16i:
  394. return spv::ImageFormat::Rg16i;
  395. case VKImageFormatAttr::rg8i:
  396. return spv::ImageFormat::Rg8i;
  397. case VKImageFormatAttr::r16i:
  398. return spv::ImageFormat::R16i;
  399. case VKImageFormatAttr::r8i:
  400. return spv::ImageFormat::R8i;
  401. case VKImageFormatAttr::rgba32ui:
  402. return spv::ImageFormat::Rgba32ui;
  403. case VKImageFormatAttr::rgba16ui:
  404. return spv::ImageFormat::Rgba16ui;
  405. case VKImageFormatAttr::rgba8ui:
  406. return spv::ImageFormat::Rgba8ui;
  407. case VKImageFormatAttr::r32ui:
  408. return spv::ImageFormat::R32ui;
  409. case VKImageFormatAttr::rgb10a2ui:
  410. return spv::ImageFormat::Rgb10a2ui;
  411. case VKImageFormatAttr::rg32ui:
  412. return spv::ImageFormat::Rg32ui;
  413. case VKImageFormatAttr::rg16ui:
  414. return spv::ImageFormat::Rg16ui;
  415. case VKImageFormatAttr::rg8ui:
  416. return spv::ImageFormat::Rg8ui;
  417. case VKImageFormatAttr::r16ui:
  418. return spv::ImageFormat::R16ui;
  419. case VKImageFormatAttr::r8ui:
  420. return spv::ImageFormat::R8ui;
  421. case VKImageFormatAttr::r64ui:
  422. return spv::ImageFormat::R64ui;
  423. case VKImageFormatAttr::r64i:
  424. return spv::ImageFormat::R64i;
  425. }
  426. return spv::ImageFormat::Unknown;
  427. }
  428. } // anonymous namespace
  429. std::string StageVar::getSemanticStr() const {
  430. // A special case for zero index, which is equivalent to no index.
  431. // Use what is in the source code.
  432. // TODO: this looks like a hack to make the current tests happy.
  433. // Should consider remove it and fix all tests.
  434. if (semanticInfo.index == 0)
  435. return semanticInfo.str;
  436. std::ostringstream ss;
  437. ss << semanticInfo.name.str() << semanticInfo.index;
  438. return ss.str();
  439. }
  440. SpirvInstruction *CounterIdAliasPair::get(SpirvBuilder &builder,
  441. SpirvContext &spvContext) const {
  442. if (isAlias) {
  443. const auto *counterType = spvContext.getACSBufferCounterType();
  444. const auto *counterVarType =
  445. spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
  446. return builder.createLoad(counterVarType, counterVar,
  447. /* SourceLocation */ {});
  448. }
  449. return counterVar;
  450. }
  451. const CounterIdAliasPair *
  452. CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
  453. for (const auto &field : fields)
  454. if (field.indices == indices)
  455. return &field.counterVar;
  456. return nullptr;
  457. }
  458. bool CounterVarFields::assign(const CounterVarFields &srcFields,
  459. SpirvBuilder &builder,
  460. SpirvContext &context) const {
  461. for (const auto &field : fields) {
  462. const auto *srcField = srcFields.get(field.indices);
  463. if (!srcField)
  464. return false;
  465. field.counterVar.assign(*srcField, builder, context);
  466. }
  467. return true;
  468. }
  469. bool CounterVarFields::assign(const CounterVarFields &srcFields,
  470. const llvm::SmallVector<uint32_t, 4> &dstPrefix,
  471. const llvm::SmallVector<uint32_t, 4> &srcPrefix,
  472. SpirvBuilder &builder,
  473. SpirvContext &context) const {
  474. if (dstPrefix.empty() && srcPrefix.empty())
  475. return assign(srcFields, builder, context);
  476. llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;
  477. // If whole has the given prefix, appends all elements after the prefix in
  478. // whole to srcIndices.
  479. const auto applyDiff =
  480. [&srcIndices](const llvm::SmallVector<uint32_t, 4> &whole,
  481. const llvm::SmallVector<uint32_t, 4> &prefix) -> bool {
  482. uint32_t i = 0;
  483. for (; i < prefix.size(); ++i)
  484. if (whole[i] != prefix[i]) {
  485. break;
  486. }
  487. if (i == prefix.size()) {
  488. for (; i < whole.size(); ++i)
  489. srcIndices.push_back(whole[i]);
  490. return true;
  491. }
  492. return false;
  493. };
  494. for (const auto &field : fields)
  495. if (applyDiff(field.indices, dstPrefix)) {
  496. const auto *srcField = srcFields.get(srcIndices);
  497. if (!srcField)
  498. return false;
  499. field.counterVar.assign(*srcField, builder, context);
  500. for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
  501. srcIndices.pop_back();
  502. }
  503. return true;
  504. }
  505. SemanticInfo DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) {
  506. for (auto *annotation : decl->getUnusualAnnotations()) {
  507. if (auto *sema = dyn_cast<hlsl::SemanticDecl>(annotation)) {
  508. llvm::StringRef semanticStr = sema->SemanticName;
  509. llvm::StringRef semanticName;
  510. uint32_t index = 0;
  511. hlsl::Semantic::DecomposeNameAndIndex(semanticStr, &semanticName, &index);
  512. const auto *semantic = hlsl::Semantic::GetByName(semanticName);
  513. return {semanticStr, semantic, semanticName, index, sema->Loc};
  514. }
  515. }
  516. return {};
  517. }
  518. bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
  519. SpirvInstruction *storedValue,
  520. bool forPCF) {
  521. QualType type = getTypeOrFnRetType(decl);
  522. uint32_t arraySize = 0;
  523. // Output stream types (PointStream, LineStream, TriangleStream) are
  524. // translated as their underlying struct types.
  525. if (hlsl::IsHLSLStreamOutputType(type))
  526. type = hlsl::GetHLSLResourceResultType(type);
  527. if (decl->hasAttr<HLSLIndicesAttr>() || decl->hasAttr<HLSLVerticesAttr>() ||
  528. decl->hasAttr<HLSLPrimitivesAttr>()) {
  529. const auto *typeDecl = astContext.getAsConstantArrayType(type);
  530. type = typeDecl->getElementType();
  531. arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
  532. if (decl->hasAttr<HLSLIndicesAttr>()) {
  533. // create SPIR-V builtin array PrimitiveIndicesNV of type
  534. // "uint [MaxPrimitiveCount * verticesPerPrim]"
  535. uint32_t verticesPerPrim = 1;
  536. if (!isVectorType(type, nullptr, &verticesPerPrim)) {
  537. assert(isScalarType(type));
  538. }
  539. arraySize = arraySize * verticesPerPrim;
  540. QualType arrayType = astContext.getConstantArrayType(
  541. astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
  542. clang::ArrayType::Normal, 0);
  543. stageVarInstructions[cast<DeclaratorDecl>(decl)] = getBuiltinVar(
  544. spv::BuiltIn::PrimitiveIndicesNV, arrayType, decl->getLocation());
  545. return true;
  546. }
  547. }
  548. const auto *sigPoint = deduceSigPoint(
  549. decl, /*asInput=*/false, spvContext.getCurrentShaderModelKind(), forPCF);
  550. // HS output variables are created using the other overload. For the rest,
  551. // none of them should be created as arrays.
  552. assert(sigPoint->GetKind() != hlsl::DXIL::SigPointKind::HSCPOut);
  553. SemanticInfo inheritSemantic = {};
  554. // If storedValue is 0, it means this parameter in the original source code is
  555. // not used at all. Avoid writing back.
  556. //
  557. // Write back of stage output variables in GS is manually controlled by
  558. // .Append() intrinsic method, implemented in writeBackOutputStream(). So
  559. // ignoreValue should be set to true for GS.
  560. const bool noWriteBack =
  561. storedValue == nullptr || spvContext.isGS() || spvContext.isMS();
  562. return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
  563. "out.var", llvm::None, &storedValue, noWriteBack,
  564. &inheritSemantic);
  565. }
  566. bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
  567. uint32_t arraySize,
  568. SpirvInstruction *invocationId,
  569. SpirvInstruction *storedValue) {
  570. assert(spvContext.isHS());
  571. QualType type = getTypeOrFnRetType(decl);
  572. const auto *sigPoint =
  573. hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::HSCPOut);
  574. SemanticInfo inheritSemantic = {};
  575. return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
  576. "out.var", invocationId, &storedValue,
  577. /*noWriteBack=*/false, &inheritSemantic);
  578. }
  579. bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
  580. SpirvInstruction **loadedValue,
  581. bool forPCF) {
  582. uint32_t arraySize = 0;
  583. QualType type = paramDecl->getType();
  584. // Deprive the outermost arrayness for HS/DS/GS and use arraySize
  585. // to convey that information
  586. if (hlsl::IsHLSLInputPatchType(type)) {
  587. arraySize = hlsl::GetHLSLInputPatchCount(type);
  588. type = hlsl::GetHLSLInputPatchElementType(type);
  589. } else if (hlsl::IsHLSLOutputPatchType(type)) {
  590. arraySize = hlsl::GetHLSLOutputPatchCount(type);
  591. type = hlsl::GetHLSLOutputPatchElementType(type);
  592. }
  593. if (hasGSPrimitiveTypeQualifier(paramDecl)) {
  594. const auto *typeDecl = astContext.getAsConstantArrayType(type);
  595. arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
  596. type = typeDecl->getElementType();
  597. }
  598. const auto *sigPoint =
  599. deduceSigPoint(paramDecl, /*asInput=*/true,
  600. spvContext.getCurrentShaderModelKind(), forPCF);
  601. SemanticInfo inheritSemantic = {};
  602. if (paramDecl->hasAttr<HLSLPayloadAttr>()) {
  603. spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
  604. return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true,
  605. type, "in.var", loadedValue);
  606. } else {
  607. return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type,
  608. arraySize, "in.var", llvm::None, loadedValue,
  609. /*noWriteBack=*/false, &inheritSemantic);
  610. }
  611. }
  612. const DeclResultIdMapper::DeclSpirvInfo *
  613. DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
  614. auto it = astDecls.find(decl);
  615. if (it != astDecls.end())
  616. return &it->second;
  617. return nullptr;
  618. }
  619. SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
  620. SourceLocation loc) {
  621. const DeclSpirvInfo *info = getDeclSpirvInfo(decl);
  622. // If DeclSpirvInfo is not found for this decl, it might be because it is an
  623. // implicit VarDecl. All implicit VarDecls are lazily created in order to
  624. // avoid creating large number of unused variables/constants/enums.
  625. if (!info) {
  626. tryToCreateImplicitConstVar(decl);
  627. info = getDeclSpirvInfo(decl);
  628. }
  629. if (info) {
  630. if (info->indexInCTBuffer >= 0) {
  631. // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
  632. // OpAccessChain to get the pointer to the variable since we created
  633. // a single variable for the whole buffer object.
  634. // Should only have VarDecls in a HLSLBufferDecl.
  635. QualType valueType = cast<VarDecl>(decl)->getType();
  636. return spvBuilder.createAccessChain(
  637. valueType, info->instr,
  638. {spvBuilder.getConstantInt(
  639. astContext.IntTy, llvm::APInt(32, info->indexInCTBuffer, true))},
  640. loc);
  641. } else if (auto *type = info->instr->getResultType()) {
  642. const auto *ptrTy = dyn_cast<HybridPointerType>(type);
  643. // If it is a local variable or function parameter with a bindless
  644. // array of an opaque type, we have to load it because we pass a
  645. // pointer of a global variable that has the bindless opaque array.
  646. if (ptrTy != nullptr && isBindlessOpaqueArray(decl->getType())) {
  647. auto *load = spvBuilder.createLoad(ptrTy, info->instr, loc);
  648. load->setRValue(false);
  649. return load;
  650. } else {
  651. return *info;
  652. }
  653. } else {
  654. return *info;
  655. }
  656. }
  657. emitFatalError("found unregistered decl", decl->getLocation())
  658. << decl->getName();
  659. emitNote("please file a bug report on "
  660. "https://github.com/Microsoft/DirectXShaderCompiler/issues with "
  661. "source code if possible",
  662. {});
  663. return 0;
  664. }
  665. SpirvFunctionParameter *
  666. DeclResultIdMapper::createFnParam(const ParmVarDecl *param,
  667. uint32_t dbgArgNumber) {
  668. const auto type = getTypeOrFnRetType(param);
  669. const auto loc = param->getLocation();
  670. const auto name = param->getName();
  671. SpirvFunctionParameter *fnParamInstr = spvBuilder.addFnParam(
  672. type, param->hasAttr<HLSLPreciseAttr>(), loc, param->getName());
  673. bool isAlias = false;
  674. (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);
  675. fnParamInstr->setContainsAliasComponent(isAlias);
  676. assert(astDecls[param].instr == nullptr);
  677. astDecls[param].instr = fnParamInstr;
  678. if (spirvOptions.debugInfoRich) {
  679. // Add DebugLocalVariable information
  680. const auto &sm = astContext.getSourceManager();
  681. const uint32_t line = sm.getPresumedLineNumber(loc);
  682. const uint32_t column = sm.getPresumedColumnNumber(loc);
  683. const auto *info = theEmitter.getOrCreateRichDebugInfo(loc);
  684. // TODO: replace this with FlagIsLocal enum.
  685. uint32_t flags = 1 << 2;
  686. auto *debugLocalVar = spvBuilder.createDebugLocalVariable(
  687. type, name, info->source, line, column, info->scopeStack.back(), flags,
  688. dbgArgNumber);
  689. spvBuilder.createDebugDeclare(debugLocalVar, fnParamInstr);
  690. }
  691. return fnParamInstr;
  692. }
  693. void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
  694. const QualType declType = getTypeOrFnRetType(decl);
  695. if (!counterVars.count(decl) && isRWAppendConsumeSBuffer(declType)) {
  696. createCounterVar(decl, /*declId=*/0, /*isAlias=*/true);
  697. } else if (!fieldCounterVars.count(decl) && declType->isStructureType() &&
  698. // Exclude other resource types which are represented as structs
  699. !hlsl::IsHLSLResourceType(declType)) {
  700. createFieldCounterVars(decl);
  701. }
  702. }
  703. SpirvVariable *
  704. DeclResultIdMapper::createFnVar(const VarDecl *var,
  705. llvm::Optional<SpirvInstruction *> init) {
  706. const auto type = getTypeOrFnRetType(var);
  707. const auto loc = var->getLocation();
  708. const auto name = var->getName();
  709. const bool isPrecise = var->hasAttr<HLSLPreciseAttr>();
  710. SpirvVariable *varInstr = spvBuilder.addFnVar(
  711. type, loc, name, isPrecise, init.hasValue() ? init.getValue() : nullptr);
  712. bool isAlias = false;
  713. (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
  714. varInstr->setContainsAliasComponent(isAlias);
  715. assert(astDecls[var].instr == nullptr);
  716. astDecls[var].instr = varInstr;
  717. return varInstr;
  718. }
  719. SpirvDebugGlobalVariable *DeclResultIdMapper::createDebugGlobalVariable(
  720. SpirvVariable *var, const QualType &type, const SourceLocation &loc,
  721. const StringRef &name) {
  722. if (spirvOptions.debugInfoRich) {
  723. // Add DebugGlobalVariable information
  724. const auto &sm = astContext.getSourceManager();
  725. const uint32_t line = sm.getPresumedLineNumber(loc);
  726. const uint32_t column = sm.getPresumedColumnNumber(loc);
  727. const auto *info = theEmitter.getOrCreateRichDebugInfo(loc);
  728. // TODO: replace this with FlagIsDefinition enum.
  729. uint32_t flags = 1 << 3;
  730. // TODO: update linkageName correctly.
  731. auto *dbgGlobalVar = spvBuilder.createDebugGlobalVariable(
  732. type, name, info->source, line, column, info->scopeStack.back(),
  733. /* linkageName */ name, var, flags);
  734. dbgGlobalVar->setDebugSpirvType(var->getResultType());
  735. dbgGlobalVar->setLayoutRule(var->getLayoutRule());
  736. return dbgGlobalVar;
  737. }
  738. return nullptr;
  739. }
  740. SpirvVariable *
  741. DeclResultIdMapper::createFileVar(const VarDecl *var,
  742. llvm::Optional<SpirvInstruction *> init) {
  743. const auto type = getTypeOrFnRetType(var);
  744. const auto loc = var->getLocation();
  745. const auto name = var->getName();
  746. SpirvVariable *varInstr =
  747. spvBuilder.addModuleVar(type, spv::StorageClass::Private,
  748. var->hasAttr<HLSLPreciseAttr>(), name, init, loc);
  749. bool isAlias = false;
  750. (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
  751. varInstr->setContainsAliasComponent(isAlias);
  752. assert(astDecls[var].instr == nullptr);
  753. astDecls[var].instr = varInstr;
  754. createDebugGlobalVariable(varInstr, type, loc, name);
  755. return varInstr;
  756. }
  757. SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
  758. const auto type = var->getType();
  759. const bool isGroupShared = var->hasAttr<HLSLGroupSharedAttr>();
  760. const bool isACRWSBuffer = isRWAppendConsumeSBuffer(type);
  761. const auto storageClass = getStorageClassForExternVar(type, isGroupShared);
  762. const auto rule = getLayoutRuleForExternVar(type, spirvOptions);
  763. const auto loc = var->getLocation();
  764. if (!isGroupShared && !isResourceType(type) &&
  765. !isResourceOnlyStructure(type)) {
  766. // We currently cannot support global structures that contain both resources
  767. // and non-resources. That would require significant work in manipulating
  768. // structure field decls, manipulating QualTypes, as well as inserting
  769. // non-resources into the Globals cbuffer which changes offset decorations
  770. // for it.
  771. if (isStructureContainingMixOfResourcesAndNonResources(type)) {
  772. emitError("global structures containing both resources and non-resources "
  773. "are not supported",
  774. loc);
  775. return nullptr;
  776. }
  777. // This is a stand-alone externally-visiable non-resource-type variable.
  778. // They should be grouped into the $Globals cbuffer. We create that cbuffer
  779. // and record all variables inside it upon seeing the first such variable.
  780. if (astDecls.count(var) == 0)
  781. createGlobalsCBuffer(var);
  782. auto *varInstr = astDecls[var].instr;
  783. return varInstr ? cast<SpirvVariable>(varInstr) : nullptr;
  784. }
  785. if (isResourceOnlyStructure(type)) {
  786. // We currently do not support global structures that contain buffers.
  787. // Supporting global structures that contain buffers has two complications:
  788. //
  789. // 1- Buffers have the Uniform storage class, whereas Textures/Samplers have
  790. // UniformConstant storage class. As a result, if a struct contains both
  791. // textures and buffers, it is not clear what storage class should be used
  792. // for the struct. Also legalization cannot deduce the proper storage class
  793. // for struct members based on the structure's storage class.
  794. //
  795. // 2- Any kind of structured buffer has associated counters. The current DXC
  796. // code is not written in a way to place associated counters inside a
  797. // structure. Changing this behavior is non-trivial. There's also
  798. // significant work to be done both in DXC (to properly generate binding
  799. // numbers for the resource and its associated counters at correct offsets)
  800. // and in spirv-opt (to flatten such strcutures and modify the binding
  801. // numbers accordingly).
  802. if (isStructureContainingAnyKindOfBuffer(type)) {
  803. emitError("global structures containing buffers are not supported", loc);
  804. return nullptr;
  805. }
  806. needsFlatteningCompositeResources = true;
  807. }
  808. const auto name = var->getName();
  809. SpirvVariable *varInstr = spvBuilder.addModuleVar(
  810. type, storageClass, var->hasAttr<HLSLPreciseAttr>(), name, llvm::None,
  811. loc);
  812. varInstr->setLayoutRule(rule);
  813. // If this variable has [[vk::image_format("..")]] attribute, we have to keep
  814. // it in the SpirvContext and use it when we lower the QualType to SpirvType.
  815. auto spvImageFormat = getSpvImageFormat(var->getAttr<VKImageFormatAttr>());
  816. if (spvImageFormat != spv::ImageFormat::Unknown)
  817. spvContext.registerImageFormatForSpirvVariable(varInstr, spvImageFormat);
  818. DeclSpirvInfo info(varInstr);
  819. astDecls[var] = info;
  820. createDebugGlobalVariable(varInstr, type, loc, name);
  821. // Variables in Workgroup do not need descriptor decorations.
  822. if (storageClass == spv::StorageClass::Workgroup)
  823. return varInstr;
  824. const auto *regAttr = getResourceBinding(var);
  825. const auto *bindingAttr = var->getAttr<VKBindingAttr>();
  826. const auto *counterBindingAttr = var->getAttr<VKCounterBindingAttr>();
  827. resourceVars.emplace_back(varInstr, var, loc, regAttr, bindingAttr,
  828. counterBindingAttr);
  829. if (const auto *inputAttachment = var->getAttr<VKInputAttachmentIndexAttr>())
  830. spvBuilder.decorateInputAttachmentIndex(varInstr,
  831. inputAttachment->getIndex(), loc);
  832. if (isACRWSBuffer) {
  833. // For {Append|Consume|RW}StructuredBuffer, we need to always create another
  834. // variable for its associated counter.
  835. createCounterVar(var, varInstr, /*isAlias=*/false);
  836. }
  837. return varInstr;
  838. }
  839. SpirvInstruction *
  840. DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
  841. assert(hlsl::IsStringType(var->getType()) ||
  842. hlsl::IsStringLiteralType(var->getType()));
  843. // If the string variable is not initialized to a string literal, we cannot
  844. // generate an OpString for it.
  845. if (!var->hasInit()) {
  846. emitError("Found uninitialized string variable.", var->getLocation());
  847. return nullptr;
  848. }
  849. const StringLiteral *stringLiteral =
  850. dyn_cast<StringLiteral>(var->getInit()->IgnoreParenCasts());
  851. SpirvString *init = spvBuilder.getString(stringLiteral->getString());
  852. DeclSpirvInfo info(init);
  853. astDecls[var] = info;
  854. return init;
  855. }
  856. SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
  857. const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
  858. llvm::StringRef typeName, llvm::StringRef varName) {
  859. // cbuffers are translated into OpTypeStruct with Block decoration.
  860. // tbuffers are translated into OpTypeStruct with BufferBlock decoration.
  861. // Push constants are translated into OpTypeStruct with Block decoration.
  862. //
  863. // Both cbuffers and tbuffers have the SPIR-V Uniform storage class.
  864. // Push constants have the SPIR-V PushConstant storage class.
  865. const bool forCBuffer = usageKind == ContextUsageKind::CBuffer;
  866. const bool forTBuffer = usageKind == ContextUsageKind::TBuffer;
  867. const bool forGlobals = usageKind == ContextUsageKind::Globals;
  868. const bool forPC = usageKind == ContextUsageKind::PushConstant;
  869. const bool forShaderRecordNV =
  870. usageKind == ContextUsageKind::ShaderRecordBufferNV;
  871. const bool forShaderRecordEXT =
  872. usageKind == ContextUsageKind::ShaderRecordBufferEXT;
  873. const auto &declGroup = collectDeclsInDeclContext(decl);
  874. // Collect the type and name for each field
  875. llvm::SmallVector<HybridStructType::FieldInfo, 4> fields;
  876. for (const auto *subDecl : declGroup) {
  877. // The field can only be FieldDecl (for normal structs) or VarDecl (for
  878. // HLSLBufferDecls).
  879. assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
  880. const auto *declDecl = cast<DeclaratorDecl>(subDecl);
  881. // In case 'register(c#)' annotation is placed on a global variable.
  882. const hlsl::RegisterAssignment *registerC =
  883. forGlobals ? getRegisterCAssignment(declDecl) : nullptr;
  884. // All fields are qualified with const. It will affect the debug name.
  885. // We don't need it here.
  886. auto varType = declDecl->getType();
  887. varType.removeLocalConst();
  888. HybridStructType::FieldInfo info(varType, declDecl->getName(),
  889. declDecl->getAttr<VKOffsetAttr>(),
  890. getPackOffset(declDecl), registerC,
  891. declDecl->hasAttr<HLSLPreciseAttr>());
  892. fields.push_back(info);
  893. }
  894. // Get the type for the whole struct
  895. // tbuffer/TextureBuffers are non-writable SSBOs.
  896. const SpirvType *resultType = spvContext.getHybridStructType(
  897. fields, typeName, /*isReadOnly*/ forTBuffer,
  898. forTBuffer ? StructInterfaceType::StorageBuffer
  899. : StructInterfaceType::UniformBuffer);
  900. // Make an array if requested.
  901. if (arraySize > 0) {
  902. resultType = spvContext.getArrayType(resultType, arraySize,
  903. /*ArrayStride*/ llvm::None);
  904. } else if (arraySize == -1) {
  905. resultType =
  906. spvContext.getRuntimeArrayType(resultType, /*ArrayStride*/ llvm::None);
  907. }
  908. // Register the <type-id> for this decl
  909. ctBufferPCTypes[decl] = resultType;
  910. const auto sc = forPC ? spv::StorageClass::PushConstant
  911. : forShaderRecordNV
  912. ? spv::StorageClass::ShaderRecordBufferNV
  913. : forShaderRecordEXT
  914. ? spv::StorageClass::ShaderRecordBufferKHR
  915. : spv::StorageClass::Uniform;
  916. // Create the variable for the whole struct / struct array.
  917. // The fields may be 'precise', but the structure itself is not.
  918. SpirvVariable *var =
  919. spvBuilder.addModuleVar(resultType, sc, /*isPrecise*/ false, varName);
  920. const SpirvLayoutRule layoutRule =
  921. (forCBuffer || forGlobals)
  922. ? spirvOptions.cBufferLayoutRule
  923. : (forTBuffer ? spirvOptions.tBufferLayoutRule
  924. : spirvOptions.sBufferLayoutRule);
  925. var->setHlslUserType(forCBuffer ? "cbuffer" : forTBuffer ? "tbuffer" : "");
  926. var->setLayoutRule(layoutRule);
  927. return var;
  928. }
  929. void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
  930. const auto *valueDecl = dyn_cast<ValueDecl>(decl);
  931. const auto enumConstant =
  932. spvBuilder.getConstantInt(astContext.IntTy, decl->getInitVal());
  933. SpirvVariable *varInstr = spvBuilder.addModuleVar(
  934. astContext.IntTy, spv::StorageClass::Private, /*isPrecise*/ false,
  935. decl->getName(), enumConstant, decl->getLocation());
  936. astDecls[valueDecl] = DeclSpirvInfo(varInstr);
  937. }
  938. SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
  939. // This function handles creation of cbuffer or tbuffer.
  940. const auto usageKind =
  941. decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
  942. const std::string structName = "type." + decl->getName().str();
  943. // The front-end does not allow arrays of cbuffer/tbuffer.
  944. SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
  945. decl, /*arraySize*/ 0, usageKind, structName, decl->getName());
  946. // We still register all VarDecls seperately here. All the VarDecls are
  947. // mapped to the <result-id> of the buffer object, which means when querying
  948. // querying the <result-id> for a certain VarDecl, we need to do an extra
  949. // OpAccessChain.
  950. int index = 0;
  951. for (const auto *subDecl : decl->decls()) {
  952. if (shouldSkipInStructLayout(subDecl))
  953. continue;
  954. const auto *varDecl = cast<VarDecl>(subDecl);
  955. astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
  956. }
  957. resourceVars.emplace_back(
  958. bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
  959. decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
  960. auto *dbgGlobalVar = createDebugGlobalVariable(
  961. bufferVar, QualType(), decl->getLocation(), decl->getName());
  962. if (dbgGlobalVar != nullptr) {
  963. // C/TBuffer needs HLSLBufferDecl for debug type lowering.
  964. spvContext.registerStructDeclForSpirvType(bufferVar->getResultType(), decl);
  965. }
  966. return bufferVar;
  967. }
  968. SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
  969. // This function handles creation of ConstantBuffer<T> or TextureBuffer<T>.
  970. // The way this is represented in the AST is as follows:
  971. //
  972. // |-VarDecl MyCbuffer 'ConstantBuffer<T>':'ConstantBuffer<T>'
  973. // |-CXXRecordDecl referenced struct T definition
  974. // |-CXXRecordDecl implicit struct T
  975. // |-FieldDecl
  976. // |-...
  977. // |-FieldDecl
  978. const QualType type = decl->getType();
  979. assert(isConstantTextureBuffer(type));
  980. const RecordType *recordType = nullptr;
  981. const RecordType *templatedType = nullptr;
  982. int arraySize = 0;
  983. // In case we have an array of ConstantBuffer/TextureBuffer:
  984. if (const auto *arrayType = type->getAsArrayTypeUnsafe()) {
  985. const QualType elemType = arrayType->getElementType();
  986. recordType = elemType->getAs<RecordType>();
  987. templatedType =
  988. hlsl::GetHLSLResourceResultType(elemType)->getAs<RecordType>();
  989. if (const auto *caType = astContext.getAsConstantArrayType(type)) {
  990. arraySize = static_cast<uint32_t>(caType->getSize().getZExtValue());
  991. } else {
  992. arraySize = -1;
  993. }
  994. } else {
  995. recordType = type->getAs<RecordType>();
  996. templatedType = hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
  997. }
  998. if (!recordType) {
  999. emitError("constant/texture buffer type %0 unimplemented",
  1000. decl->getLocStart())
  1001. << type;
  1002. return nullptr;
  1003. }
  1004. if (!templatedType) {
  1005. emitError(
  1006. "the underlying type for constant/texture buffer must be a struct",
  1007. decl->getLocStart())
  1008. << type;
  1009. return nullptr;
  1010. }
  1011. const bool isConstBuffer = isConstantBuffer(type);
  1012. const auto usageKind =
  1013. isConstBuffer ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
  1014. const std::string structName = "type." +
  1015. recordType->getDecl()->getName().str() + "." +
  1016. templatedType->getDecl()->getName().str();
  1017. SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
  1018. templatedType->getDecl(), arraySize, usageKind, structName,
  1019. decl->getName());
  1020. // We register the VarDecl here.
  1021. astDecls[decl] = DeclSpirvInfo(bufferVar);
  1022. resourceVars.emplace_back(
  1023. bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
  1024. decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
  1025. return bufferVar;
  1026. }
  1027. SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
  1028. // The front-end errors out if non-struct type push constant is used.
  1029. const QualType type = decl->getType();
  1030. const auto *recordType = type->getAs<RecordType>();
  1031. if (isConstantBuffer(type)) {
  1032. // Get the templated type for ConstantBuffer.
  1033. recordType = hlsl::GetHLSLResourceResultType(type)->getAs<RecordType>();
  1034. }
  1035. assert(recordType);
  1036. const std::string structName =
  1037. "type.PushConstant." + recordType->getDecl()->getName().str();
  1038. SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
  1039. recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
  1040. structName, decl->getName());
  1041. // Register the VarDecl
  1042. astDecls[decl] = DeclSpirvInfo(var);
  1043. // Do not push this variable into resourceVars since it does not need
  1044. // descriptor set.
  1045. return var;
  1046. }
  1047. SpirvVariable *
  1048. DeclResultIdMapper::createShaderRecordBuffer(const VarDecl *decl,
  1049. ContextUsageKind kind) {
  1050. const auto *recordType =
  1051. hlsl::GetHLSLResourceResultType(decl->getType())->getAs<RecordType>();
  1052. assert(recordType);
  1053. assert(kind == ContextUsageKind::ShaderRecordBufferEXT ||
  1054. kind == ContextUsageKind::ShaderRecordBufferNV);
  1055. const auto typeName = kind == ContextUsageKind::ShaderRecordBufferEXT
  1056. ? "type.ShaderRecordBufferEXT."
  1057. : "type.ShaderRecordBufferNV.";
  1058. const std::string structName =
  1059. typeName + recordType->getDecl()->getName().str();
  1060. SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
  1061. recordType->getDecl(), /*arraySize*/ 0,
  1062. kind, structName, decl->getName());
  1063. // Register the VarDecl
  1064. astDecls[decl] = DeclSpirvInfo(var);
  1065. // Do not push this variable into resourceVars since it does not need
  1066. // descriptor set.
  1067. return var;
  1068. }
  1069. SpirvVariable *
  1070. DeclResultIdMapper::createShaderRecordBuffer(const HLSLBufferDecl *decl,
  1071. ContextUsageKind kind) {
  1072. assert(kind == ContextUsageKind::ShaderRecordBufferEXT ||
  1073. kind == ContextUsageKind::ShaderRecordBufferNV);
  1074. const auto typeName = kind == ContextUsageKind::ShaderRecordBufferEXT
  1075. ? "type.ShaderRecordBufferEXT."
  1076. : "type.ShaderRecordBufferNV.";
  1077. const std::string structName =
  1078. typeName + decl->getName().str();
  1079. // The front-end does not allow arrays of cbuffer/tbuffer.
  1080. SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
  1081. decl, /*arraySize*/ 0, kind, structName,
  1082. decl->getName());
  1083. // We still register all VarDecls seperately here. All the VarDecls are
  1084. // mapped to the <result-id> of the buffer object, which means when
  1085. // querying the <result-id> for a certain VarDecl, we need to do an extra
  1086. // OpAccessChain.
  1087. int index = 0;
  1088. for (const auto *subDecl : decl->decls()) {
  1089. if (shouldSkipInStructLayout(subDecl))
  1090. continue;
  1091. const auto *varDecl = cast<VarDecl>(subDecl);
  1092. astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
  1093. }
  1094. return bufferVar;
  1095. }
  1096. void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
  1097. if (astDecls.count(var) != 0)
  1098. return;
  1099. const auto *context = var->getTranslationUnitDecl();
  1100. SpirvVariable *globals = createStructOrStructArrayVarOfExplicitLayout(
  1101. context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
  1102. "$Globals");
  1103. resourceVars.emplace_back(globals, /*decl*/ nullptr, SourceLocation(),
  1104. nullptr, nullptr, nullptr, /*isCounterVar*/ false,
  1105. /*isGlobalsCBuffer*/ true);
  1106. uint32_t index = 0;
  1107. for (const auto *decl : collectDeclsInDeclContext(context)) {
  1108. if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
  1109. if (!spirvOptions.noWarnIgnoredFeatures) {
  1110. if (const auto *init = varDecl->getInit())
  1111. emitWarning(
  1112. "variable '%0' will be placed in $Globals so initializer ignored",
  1113. init->getExprLoc())
  1114. << var->getName() << init->getSourceRange();
  1115. }
  1116. if (const auto *attr = varDecl->getAttr<VKBindingAttr>()) {
  1117. emitError("variable '%0' will be placed in $Globals so cannot have "
  1118. "vk::binding attribute",
  1119. attr->getLocation())
  1120. << var->getName();
  1121. return;
  1122. }
  1123. astDecls[varDecl] = DeclSpirvInfo(globals, index++);
  1124. }
  1125. }
  1126. }
  1127. SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
  1128. // Return it if it's already been created.
  1129. auto it = astFunctionDecls.find(fn);
  1130. if (it != astFunctionDecls.end()) {
  1131. return it->second;
  1132. }
  1133. bool isAlias = false;
  1134. (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
  1135. const bool isPrecise = fn->hasAttr<HLSLPreciseAttr>();
  1136. const bool isNoInline = fn->hasAttr<NoInlineAttr>();
  1137. // Note: we do not need to worry about function parameter types at this point
  1138. // as this is used when function declarations are seen. When function
  1139. // definition is seen, the parameter types will be set properly and take into
  1140. // account whether the function is a member function of a class/struct (in
  1141. // which case a 'this' parameter is added at the beginnig).
  1142. SpirvFunction *spirvFunction =
  1143. spvBuilder.createSpirvFunction(fn->getReturnType(), fn->getLocation(),
  1144. fn->getName(), isPrecise, isNoInline);
  1145. // No need to dereference to get the pointer. Function returns that are
  1146. // stand-alone aliases are already pointers to values. All other cases should
  1147. // be normal rvalues.
  1148. if (!isAlias || !isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
  1149. spirvFunction->setRValue();
  1150. spirvFunction->setConstainsAliasComponent(isAlias);
  1151. astFunctionDecls[fn] = spirvFunction;
  1152. return spirvFunction;
  1153. }
  1154. const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
  1155. const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
  1156. if (!decl)
  1157. return nullptr;
  1158. if (indices) {
  1159. // Indices are provided. Walk through the fields of the decl.
  1160. const auto counter = fieldCounterVars.find(decl);
  1161. if (counter != fieldCounterVars.end())
  1162. return counter->second.get(*indices);
  1163. } else {
  1164. // No indices. Check the stand-alone entities.
  1165. const auto counter = counterVars.find(decl);
  1166. if (counter != counterVars.end())
  1167. return &counter->second;
  1168. }
  1169. return nullptr;
  1170. }
  1171. const CounterVarFields *
  1172. DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
  1173. if (!decl)
  1174. return nullptr;
  1175. const auto found = fieldCounterVars.find(decl);
  1176. if (found != fieldCounterVars.end())
  1177. return &found->second;
  1178. return nullptr;
  1179. }
  1180. void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
  1181. SpirvInstruction *specConstant) {
  1182. specConstant->setRValue();
  1183. astDecls[decl] = DeclSpirvInfo(specConstant);
  1184. }
  1185. void DeclResultIdMapper::createCounterVar(
  1186. const DeclaratorDecl *decl, SpirvInstruction *declInstr, bool isAlias,
  1187. const llvm::SmallVector<uint32_t, 4> *indices) {
  1188. std::string counterName = "counter.var." + decl->getName().str();
  1189. if (indices) {
  1190. // Append field indices to the name
  1191. for (const auto index : *indices)
  1192. counterName += "." + std::to_string(index);
  1193. }
  1194. const SpirvType *counterType = spvContext.getACSBufferCounterType();
  1195. // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
  1196. // Alias counter variables should be created into the Private storage class.
  1197. const spv::StorageClass sc =
  1198. isAlias ? spv::StorageClass::Private : spv::StorageClass::Uniform;
  1199. if (isAlias) {
  1200. // Apply an extra level of pointer for alias counter variable
  1201. counterType =
  1202. spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
  1203. }
  1204. SpirvVariable *counterInstr = spvBuilder.addModuleVar(
  1205. counterType, sc, /*isPrecise*/ false, counterName);
  1206. if (!isAlias) {
  1207. // Non-alias counter variables should be put in to resourceVars so that
  1208. // descriptors can be allocated for them.
  1209. resourceVars.emplace_back(counterInstr, decl, decl->getLocation(),
  1210. getResourceBinding(decl),
  1211. decl->getAttr<VKBindingAttr>(),
  1212. decl->getAttr<VKCounterBindingAttr>(), true);
  1213. assert(declInstr);
  1214. spvBuilder.decorateCounterBuffer(declInstr, counterInstr,
  1215. decl->getLocation());
  1216. }
  1217. if (indices)
  1218. fieldCounterVars[decl].append(*indices, counterInstr);
  1219. else
  1220. counterVars[decl] = {counterInstr, isAlias};
  1221. }
  1222. void DeclResultIdMapper::createFieldCounterVars(
  1223. const DeclaratorDecl *rootDecl, const DeclaratorDecl *decl,
  1224. llvm::SmallVector<uint32_t, 4> *indices) {
  1225. const QualType type = getTypeOrFnRetType(decl);
  1226. const auto *recordType = type->getAs<RecordType>();
  1227. assert(recordType);
  1228. const auto *recordDecl = recordType->getDecl();
  1229. for (const auto *field : recordDecl->fields()) {
  1230. // Build up the index chain
  1231. indices->push_back(getNumBaseClasses(type) + field->getFieldIndex());
  1232. const QualType fieldType = field->getType();
  1233. if (isRWAppendConsumeSBuffer(fieldType))
  1234. createCounterVar(rootDecl, /*declId=*/0, /*isAlias=*/true, indices);
  1235. else if (fieldType->isStructureType() &&
  1236. !hlsl::IsHLSLResourceType(fieldType))
  1237. // Go recursively into all nested structs
  1238. createFieldCounterVars(rootDecl, field, indices);
  1239. indices->pop_back();
  1240. }
  1241. }
  1242. const SpirvType *
  1243. DeclResultIdMapper::getCTBufferPushConstantType(const DeclContext *decl) {
  1244. const auto found = ctBufferPCTypes.find(decl);
  1245. assert(found != ctBufferPCTypes.end());
  1246. return found->second;
  1247. }
  1248. std::vector<SpirvVariable *> DeclResultIdMapper::collectStageVars() const {
  1249. std::vector<SpirvVariable *> vars;
  1250. for (auto var : glPerVertex.getStageInVars())
  1251. vars.push_back(var);
  1252. for (auto var : glPerVertex.getStageOutVars())
  1253. vars.push_back(var);
  1254. llvm::DenseSet<SpirvInstruction *> seenVars;
  1255. for (const auto &var : stageVars) {
  1256. auto *instr = var.getSpirvInstr();
  1257. if (seenVars.count(instr) == 0) {
  1258. vars.push_back(instr);
  1259. seenVars.insert(instr);
  1260. }
  1261. }
  1262. return vars;
  1263. }
  1264. namespace {
  1265. /// A class for managing stage input/output locations to avoid duplicate uses of
  1266. /// the same location.
  1267. class LocationSet {
  1268. public:
  1269. /// Maximum number of indices supported
  1270. const static uint32_t kMaxIndex = 2;
  1271. /// Maximum number of locations supported
  1272. // Typically we won't have that many stage input or output variables.
  1273. // Using 64 should be fine here.
  1274. const static uint32_t kMaxLoc = 64;
  1275. LocationSet() {
  1276. for (uint32_t i = 0; i < kMaxIndex; ++i) {
  1277. usedLocs[i].resize(kMaxLoc);
  1278. nextLoc[i] = 0;
  1279. }
  1280. }
  1281. /// Uses the given location.
  1282. void useLoc(uint32_t loc, uint32_t index = 0) {
  1283. assert(index < kMaxIndex);
  1284. usedLocs[index].set(loc);
  1285. }
  1286. /// Uses the next |count| available location.
  1287. int useNextLocs(uint32_t count, uint32_t index = 0) {
  1288. assert(index < kMaxIndex);
  1289. auto &locs = usedLocs[index];
  1290. auto &next = nextLoc[index];
  1291. while (locs[next])
  1292. next++;
  1293. int toUse = next;
  1294. for (uint32_t i = 0; i < count; ++i) {
  1295. assert(!locs[next]);
  1296. locs.set(next++);
  1297. }
  1298. return toUse;
  1299. }
  1300. /// Returns true if the given location number is already used.
  1301. bool isLocUsed(uint32_t loc, uint32_t index = 0) {
  1302. assert(index < kMaxIndex);
  1303. return usedLocs[index][loc];
  1304. }
  1305. private:
  1306. llvm::SmallBitVector usedLocs[kMaxIndex]; ///< All previously used locations
  1307. uint32_t nextLoc[kMaxIndex]; ///< Next available location
  1308. };
  1309. /// A class for managing resource bindings to avoid duplicate uses of the same
  1310. /// set and binding number.
  1311. class BindingSet {
  1312. public:
  1313. /// Uses the given set and binding number. Returns false if the binding number
  1314. /// was already occupied in the set, and returns true otherwise.
  1315. bool useBinding(uint32_t binding, uint32_t set) {
  1316. bool inserted = false;
  1317. std::tie(std::ignore, inserted) = usedBindings[set].insert(binding);
  1318. return inserted;
  1319. }
  1320. /// Uses the next available binding number in |set|. If more than one binding
  1321. /// number is to be occupied, it finds the next available chunk that can fit
  1322. /// |numBindingsToUse| in the |set|.
  1323. uint32_t useNextBinding(uint32_t set, uint32_t numBindingsToUse = 1,
  1324. uint32_t bindingShift = 0) {
  1325. uint32_t bindingNoStart =
  1326. getNextBindingChunk(set, numBindingsToUse, bindingShift);
  1327. auto &binding = usedBindings[set];
  1328. for (uint32_t i = 0; i < numBindingsToUse; ++i)
  1329. binding.insert(bindingNoStart + i);
  1330. return bindingNoStart;
  1331. }
  1332. /// Returns the first available binding number in the |set| for which |n|
  1333. /// consecutive binding numbers are unused starting at |bindingShift|.
  1334. uint32_t getNextBindingChunk(uint32_t set, uint32_t n,
  1335. uint32_t bindingShift) {
  1336. auto &existingBindings = usedBindings[set];
  1337. // There were no bindings in this set. Can start at binding zero.
  1338. if (existingBindings.empty())
  1339. return bindingShift;
  1340. // Check whether the chunk of |n| binding numbers can be fitted at the
  1341. // very beginning of the list (start at binding 0 in the current set).
  1342. uint32_t curBinding = *existingBindings.begin();
  1343. if (curBinding >= (n + bindingShift))
  1344. return bindingShift;
  1345. auto iter = std::next(existingBindings.begin());
  1346. while (iter != existingBindings.end()) {
  1347. // There exists a next binding number that is used. Check to see if the
  1348. // gap between current binding number and next binding number is large
  1349. // enough to accommodate |n|.
  1350. uint32_t nextBinding = *iter;
  1351. if ((bindingShift > 0) && (curBinding < (bindingShift - 1)))
  1352. curBinding = bindingShift - 1;
  1353. if (curBinding < nextBinding && n <= nextBinding - curBinding - 1)
  1354. return curBinding + 1;
  1355. curBinding = nextBinding;
  1356. // Peek at the next binding that has already been used (if any).
  1357. ++iter;
  1358. }
  1359. // |curBinding| was the last binding that was used in this set. The next
  1360. // chunk of |n| bindings can start at |curBinding|+1.
  1361. return std::max(curBinding + 1, bindingShift);
  1362. }
  1363. private:
  1364. ///< set number -> set of used binding number
  1365. llvm::DenseMap<uint32_t, std::set<uint32_t>> usedBindings;
  1366. };
  1367. } // namespace
  1368. bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
  1369. llvm::StringSet<> seenSemantics;
  1370. bool success = true;
  1371. for (const auto &var : stageVars) {
  1372. auto s = var.getSemanticStr();
  1373. if (s.empty()) {
  1374. // We translate WaveGetLaneCount(), WaveGetLaneIndex() and 'payload' param
  1375. // block declaration into builtin variables. Those variables are inserted
  1376. // into the normal stage IO processing pipeline, but with the semantics as
  1377. // empty strings.
  1378. assert(var.isSpirvBuitin());
  1379. continue;
  1380. }
  1381. // Allow builtin variables to alias each other. We already have uniqify
  1382. // mechanism in SpirvBuilder.
  1383. if (var.isSpirvBuitin())
  1384. continue;
  1385. if (forInput && var.getSigPoint()->IsInput()) {
  1386. if (seenSemantics.count(s)) {
  1387. emitError("input semantic '%0' used more than once", {}) << s;
  1388. success = false;
  1389. }
  1390. seenSemantics.insert(s);
  1391. } else if (!forInput && var.getSigPoint()->IsOutput()) {
  1392. if (seenSemantics.count(s)) {
  1393. emitError("output semantic '%0' used more than once", {}) << s;
  1394. success = false;
  1395. }
  1396. seenSemantics.insert(s);
  1397. }
  1398. }
  1399. return success;
  1400. }
  1401. bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
  1402. if (!checkSemanticDuplication(forInput))
  1403. return false;
  1404. // Returns false if the given StageVar is an input/output variable without
  1405. // explicit location assignment. Otherwise, returns true.
  1406. const auto locAssigned = [forInput, this](const StageVar &v) {
  1407. if (forInput == isInputStorageClass(v))
  1408. // No need to assign location for builtins. Treat as assigned.
  1409. return v.isSpirvBuitin() || v.getLocationAttr() != nullptr;
  1410. // For the ones we don't care, treat as assigned.
  1411. return true;
  1412. };
  1413. // If we have explicit location specified for all input/output variables,
  1414. // use them instead assign by ourselves.
  1415. if (std::all_of(stageVars.begin(), stageVars.end(), locAssigned)) {
  1416. LocationSet locSet;
  1417. bool noError = true;
  1418. for (const auto &var : stageVars) {
  1419. // Skip builtins & those stage variables we are not handling for this call
  1420. if (var.isSpirvBuitin() || forInput != isInputStorageClass(var))
  1421. continue;
  1422. const auto *attr = var.getLocationAttr();
  1423. const auto loc = attr->getNumber();
  1424. const auto attrLoc = attr->getLocation(); // Attr source code location
  1425. const auto idx = var.getIndexAttr() ? var.getIndexAttr()->getNumber() : 0;
  1426. if ((const unsigned)loc >= LocationSet::kMaxLoc) {
  1427. emitError("stage %select{output|input}0 location #%1 too large",
  1428. attrLoc)
  1429. << forInput << loc;
  1430. return false;
  1431. }
  1432. // Make sure the same location is not assigned more than once
  1433. if (locSet.isLocUsed(loc, idx)) {
  1434. emitError("stage %select{output|input}0 location #%1 already assigned",
  1435. attrLoc)
  1436. << forInput << loc;
  1437. noError = false;
  1438. }
  1439. locSet.useLoc(loc, idx);
  1440. spvBuilder.decorateLocation(var.getSpirvInstr(), loc);
  1441. if (var.getIndexAttr())
  1442. spvBuilder.decorateIndex(var.getSpirvInstr(), idx,
  1443. var.getSemanticInfo().loc);
  1444. }
  1445. return noError;
  1446. }
  1447. std::vector<const StageVar *> vars;
  1448. LocationSet locSet;
  1449. for (const auto &var : stageVars) {
  1450. if (var.isSpirvBuitin() || forInput != isInputStorageClass(var))
  1451. continue;
  1452. if (var.getLocationAttr()) {
  1453. // We have checked that not all of the stage variables have explicit
  1454. // location assignment.
  1455. emitError("partial explicit stage %select{output|input}0 location "
  1456. "assignment via vk::location(X) unsupported",
  1457. {})
  1458. << forInput;
  1459. return false;
  1460. }
  1461. const auto &semaInfo = var.getSemanticInfo();
  1462. // We should special rules for SV_Target: the location number comes from the
  1463. // semantic string index.
  1464. if (semaInfo.isTarget()) {
  1465. spvBuilder.decorateLocation(var.getSpirvInstr(), semaInfo.index);
  1466. locSet.useLoc(semaInfo.index);
  1467. } else {
  1468. vars.push_back(&var);
  1469. }
  1470. }
  1471. // If alphabetical ordering was requested, sort by semantic string.
  1472. // Since HS includes 2 sets of outputs (patch-constant output and
  1473. // OutputPatch), running into location mismatches between HS and DS is very
  1474. // likely. In order to avoid location mismatches between HS and DS, use
  1475. // alphabetical ordering.
  1476. if (spirvOptions.stageIoOrder == "alpha" ||
  1477. (!forInput && spvContext.isHS()) || (forInput && spvContext.isDS())) {
  1478. // Sort stage input/output variables alphabetically
  1479. std::sort(vars.begin(), vars.end(),
  1480. [](const StageVar *a, const StageVar *b) {
  1481. return a->getSemanticStr() < b->getSemanticStr();
  1482. });
  1483. }
  1484. for (const auto *var : vars)
  1485. spvBuilder.decorateLocation(var->getSpirvInstr(),
  1486. locSet.useNextLocs(var->getLocationCount()));
  1487. return true;
  1488. }
  1489. namespace {
  1490. /// A class for maintaining the binding number shift requested for descriptor
  1491. /// sets.
  1492. class BindingShiftMapper {
  1493. public:
  1494. explicit BindingShiftMapper(const llvm::SmallVectorImpl<int32_t> &shifts)
  1495. : masterShift(0) {
  1496. assert(shifts.size() % 2 == 0);
  1497. if (shifts.size() == 2 && shifts[1] == -1) {
  1498. masterShift = shifts[0];
  1499. } else {
  1500. for (uint32_t i = 0; i < shifts.size(); i += 2)
  1501. perSetShift[shifts[i + 1]] = shifts[i];
  1502. }
  1503. }
  1504. /// Returns the shift amount for the given set.
  1505. int32_t getShiftForSet(int32_t set) const {
  1506. const auto found = perSetShift.find(set);
  1507. if (found != perSetShift.end())
  1508. return found->second;
  1509. return masterShift;
  1510. }
  1511. private:
  1512. uint32_t masterShift; /// Shift amount applies to all sets.
  1513. llvm::DenseMap<int32_t, int32_t> perSetShift;
  1514. };
  1515. /// A class for maintaining the mapping from source code register attributes to
  1516. /// descriptor set and number settings.
  1517. class RegisterBindingMapper {
  1518. public:
  1519. /// Takes in the relation between register attributes and descriptor settings.
  1520. /// Each relation is represented by four strings:
  1521. /// <register-type-number> <space> <descriptor-binding> <set>
  1522. bool takeInRelation(const std::vector<std::string> &relation,
  1523. std::string *error) {
  1524. assert(relation.size() % 4 == 0);
  1525. mapping.clear();
  1526. for (uint32_t i = 0; i < relation.size(); i += 4) {
  1527. int32_t spaceNo = -1, setNo = -1, bindNo = -1;
  1528. if (StringRef(relation[i + 1]).getAsInteger(10, spaceNo) || spaceNo < 0) {
  1529. *error = "space number: " + relation[i + 1];
  1530. return false;
  1531. }
  1532. if (StringRef(relation[i + 2]).getAsInteger(10, bindNo) || bindNo < 0) {
  1533. *error = "binding number: " + relation[i + 2];
  1534. return false;
  1535. }
  1536. if (StringRef(relation[i + 3]).getAsInteger(10, setNo) || setNo < 0) {
  1537. *error = "set number: " + relation[i + 3];
  1538. return false;
  1539. }
  1540. mapping[relation[i + 1] + relation[i]] = std::make_pair(setNo, bindNo);
  1541. }
  1542. return true;
  1543. }
  1544. /// Returns true and set the correct set and binding number if we can find a
  1545. /// descriptor setting for the given register. False otherwise.
  1546. bool getSetBinding(const hlsl::RegisterAssignment *regAttr,
  1547. uint32_t defaultSpace, int *setNo, int *bindNo) const {
  1548. std::ostringstream iss;
  1549. iss << regAttr->RegisterSpace.getValueOr(defaultSpace)
  1550. << regAttr->RegisterType << regAttr->RegisterNumber;
  1551. auto found = mapping.find(iss.str());
  1552. if (found != mapping.end()) {
  1553. *setNo = found->second.first;
  1554. *bindNo = found->second.second;
  1555. return true;
  1556. }
  1557. return false;
  1558. }
  1559. private:
  1560. llvm::StringMap<std::pair<int, int>> mapping;
  1561. };
  1562. } // namespace
  /// Assigns DescriptorSet/Binding decorations to every entry in resourceVars,
  /// which covers both resources and their associated counter variables.
  /// Explicit assignments (vk::binding, vk::counter_binding, :register(xX)) are
  /// processed before implicit ones. The implicit assignments therefore skip
  /// binding numbers already recorded in the same set.
  /// When -fvk-bind-register is given, only that mapping (plus an optional
  /// -fvk-bind-globals for $Globals) is applied.
  /// Returns false on the option/annotation errors that return early below,
  /// and true otherwise.
  bool DeclResultIdMapper::decorateResourceBindings() {
    // For normal resource, we support 4 approaches of setting binding numbers:
    // - m1: [[vk::binding(...)]]
    // - m2: :register(xX, spaceY)
    // - m3: None
    // - m4: :register(spaceY)
    //
    // For associated counters, we support 2 approaches:
    // - c1: [[vk::counter_binding(...)]
    // - c2: None
    //
    // In combination, we need to handle 12 cases:
    // - 4 cases for normal resources (m1, m2, m3, m4)
    // - 8 cases for associated counters (mX * cY)
    //
    // In the following order:
    // - m1, mX * c1
    // - m2
    // - m3, m4, mX * c2
    // The "-auto-binding-space" command line option can be used to specify a
    // certain space as default. UINT_MAX means the user has not provided this
    // option. If not provided, the SPIR-V backend uses space "0" as default.
    auto defaultSpaceOpt =
        theEmitter.getCompilerInstance().getCodeGenOpts().HLSLDefaultSpace;
    uint32_t defaultSpace = (defaultSpaceOpt == UINT_MAX) ? 0 : defaultSpaceOpt;

    // Parse -fvk-bind-globals: bindGlobals[0] is the binding number and
    // bindGlobals[1] is the set number requested for the $Globals cbuffer.
    const bool bindGlobals = !spirvOptions.bindGlobals.empty();
    int32_t globalsBindNo = -1, globalsSetNo = -1;
    if (bindGlobals) {
      assert(spirvOptions.bindGlobals.size() == 2);
      if (StringRef(spirvOptions.bindGlobals[0])
              .getAsInteger(10, globalsBindNo) ||
          globalsBindNo < 0) {
        emitError("invalid -fvk-bind-globals binding number: %0", {})
            << spirvOptions.bindGlobals[0];
        return false;
      }
      if (StringRef(spirvOptions.bindGlobals[1]).getAsInteger(10, globalsSetNo) ||
          globalsSetNo < 0) {
        emitError("invalid -fvk-bind-globals set number: %0", {})
            << spirvOptions.bindGlobals[1];
        return false;
      }
    }

    // Special handling of -fvk-bind-register, which requires
    // * All resources are annotated with :register() in the source code
    // * -fvk-bind-register is specified for every resource
    // This mode returns before any of the other assignment strategies below.
    if (!spirvOptions.bindRegister.empty()) {
      RegisterBindingMapper bindingMapper;
      std::string error;
      if (!bindingMapper.takeInRelation(spirvOptions.bindRegister, &error)) {
        emitError("invalid -fvk-bind-register %0", {}) << error;
        return false;
      }
      for (const auto &var : resourceVars)
        if (const auto *regAttr = var.getRegister()) {
          if (var.isCounter()) {
            // NOTE(review): unlike the other failures in this loop, this error
            // path does not return false, so the function still returns true
            // after the loop. Confirm whether that is intended.
            emitError("-fvk-bind-register for RW/Append/Consume StructuredBuffer "
                      "unimplemented",
                      var.getSourceLocation());
          } else {
            int setNo = 0, bindNo = 0;
            if (!bindingMapper.getSetBinding(regAttr, defaultSpace, &setNo,
                                             &bindNo)) {
              emitError("missing -fvk-bind-register for resource",
                        var.getSourceLocation());
              return false;
            }
            spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindNo);
          }
        } else if (bindGlobals && var.isGlobalsBuffer()) {
          // $Globals has no register annotation but may be bound explicitly.
          spvBuilder.decorateDSetBinding(var.getSpirvInstr(), globalsSetNo,
                                         globalsBindNo);
        } else {
          emitError(
              "-fvk-bind-register requires register annotations on all resources",
              var.getSourceLocation());
          return false;
        }
      return true;
    }

    // Records every (set, binding) handed out so far, so that implicit
    // assignments can avoid them.
    BindingSet bindingSet;

    // Decorates the given varId of the given category with set number
    // setNo, binding number bindingNo. Ignores overlaps.
    const auto tryToDecorate = [this, &bindingSet](const ResourceVar &var,
                                                   const uint32_t setNo,
                                                   const uint32_t bindingNo) {
      // By default we use one binding number per resource, and an array of
      // resources also gets only one binding number. However, for array of
      // resources (e.g. array of textures), DX uses one binding number per array
      // element. We can match this behavior via a command line option.
      uint32_t numBindingsToUse = 1;
      if (spirvOptions.flattenResourceArrays || needsFlatteningCompositeResources)
        numBindingsToUse = getNumBindingsUsedByResourceType(
            var.getSpirvInstr()->getAstResultType());

      for (uint32_t i = 0; i < numBindingsToUse; ++i) {
        bool success = bindingSet.useBinding(bindingNo + i, setNo);
        // We will not emit an error if we find a set/binding overlap because it
        // is possible that the optimizer optimizes away a resource which resolves
        // the overlap.
        (void)success;
      }

      // No need to decorate multiple binding numbers for arrays. It will be done
      // by legalization/optimization.
      spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindingNo);
    };

    // Pass 1: explicit vk::binding (m1) and explicit vk::counter_binding
    // (mX * c1). A counter takes its set from the owning resource's
    // vk::binding, else from its register space, else the default space.
    for (const auto &var : resourceVars) {
      if (var.isCounter()) {
        if (const auto *vkCBinding = var.getCounterBinding()) {
          // Process mX * c1
          uint32_t set = defaultSpace;
          if (const auto *vkBinding = var.getBinding())
            set = getVkBindingAttrSet(vkBinding, defaultSpace);
          else if (const auto *reg = var.getRegister())
            set = reg->RegisterSpace.getValueOr(defaultSpace);

          tryToDecorate(var, set, vkCBinding->getBinding());
        }
      } else {
        if (const auto *vkBinding = var.getBinding()) {
          // Process m1
          tryToDecorate(var, getVkBindingAttrSet(vkBinding, defaultSpace),
                        vkBinding->getBinding());
        }
      }
    }

    // Per register-type binding shifts taken from the command line options.
    BindingShiftMapper bShiftMapper(spirvOptions.bShift);
    BindingShiftMapper tShiftMapper(spirvOptions.tShift);
    BindingShiftMapper sShiftMapper(spirvOptions.sShift);
    BindingShiftMapper uShiftMapper(spirvOptions.uShift);

    // Process m2
    // Pass 2: resources with a full :register(xX[, spaceY]) annotation and no
    // vk::binding. The binding is the register number plus the shift for
    // that register type and set.
    for (const auto &var : resourceVars)
      if (!var.isCounter() && !var.getBinding())
        if (const auto *reg = var.getRegister()) {
          // Skip space-only register() annotations
          if (reg->isSpaceOnly())
            continue;

          const uint32_t set = reg->RegisterSpace.getValueOr(defaultSpace);
          uint32_t binding = reg->RegisterNumber;
          switch (reg->RegisterType) {
          case 'b':
            binding += bShiftMapper.getShiftForSet(set);
            break;
          case 't':
            binding += tShiftMapper.getShiftForSet(set);
            break;
          case 's':
            binding += sShiftMapper.getShiftForSet(set);
            break;
          case 'u':
            binding += uShiftMapper.getShiftForSet(set);
            break;
          case 'c':
            // For setting packing offset. Does not affect binding.
            break;
          default:
            llvm_unreachable("unknown register type found");
          }

          tryToDecorate(var, set, binding);
        }

    // Pass 3: everything still unassigned gets the next free binding(s)
    // recorded in bindingSet. This covers m3, m4 and mX * c2.
    for (const auto &var : resourceVars) {
      // By default we use one binding number per resource, and an array of
      // resources also gets only one binding number. However, for array of
      // resources (e.g. array of textures), DX uses one binding number per array
      // element. We can match this behavior via a command line option.
      uint32_t numBindingsToUse = 1;
      if (spirvOptions.flattenResourceArrays || needsFlatteningCompositeResources)
        numBindingsToUse = getNumBindingsUsedByResourceType(
            var.getSpirvInstr()->getAstResultType());

      // With autoShiftBindings, implicit bindings also honor the shift of the
      // resource's implicit register type.
      BindingShiftMapper *bindingShiftMapper = nullptr;
      if (spirvOptions.autoShiftBindings) {
        char registerType = '\0';
        if (getImplicitRegisterType(var, &registerType)) {
          switch (registerType) {
          case 'b':
            bindingShiftMapper = &bShiftMapper;
            break;
          case 't':
            bindingShiftMapper = &tShiftMapper;
            break;
          case 's':
            bindingShiftMapper = &sShiftMapper;
            break;
          case 'u':
            bindingShiftMapper = &uShiftMapper;
            break;
          default:
            llvm_unreachable("unknown register type found");
          }
        }
      }

      if (var.isCounter()) {
        if (!var.getCounterBinding()) {
          // Process mX * c2
          uint32_t set = defaultSpace;
          if (const auto *vkBinding = var.getBinding())
            set = getVkBindingAttrSet(vkBinding, defaultSpace);
          else if (const auto *reg = var.getRegister())
            set = reg->RegisterSpace.getValueOr(defaultSpace);

          uint32_t bindingShift = 0;
          if (bindingShiftMapper)
            bindingShift = bindingShiftMapper->getShiftForSet(set);
          spvBuilder.decorateDSetBinding(
              var.getSpirvInstr(), set,
              bindingSet.useNextBinding(set, numBindingsToUse, bindingShift));
        }
      } else if (!var.getBinding()) {
        const auto *reg = var.getRegister();
        if (reg && reg->isSpaceOnly()) {
          // Process m4: only the space is given; pick a free binding in it.
          const uint32_t set = reg->RegisterSpace.getValueOr(defaultSpace);
          uint32_t bindingShift = 0;
          if (bindingShiftMapper)
            bindingShift = bindingShiftMapper->getShiftForSet(set);
          spvBuilder.decorateDSetBinding(
              var.getSpirvInstr(), set,
              bindingSet.useNextBinding(set, numBindingsToUse, bindingShift));
        } else if (!reg) {
          // Process m3 (no 'vk::binding' and no ':register' assignment)

          // There is a special case for the $Globals cbuffer. The $Globals buffer
          // doesn't have either 'vk::binding' or ':register', but the user may
          // ask for a specific binding for it via command line options.
          if (bindGlobals && var.isGlobalsBuffer()) {
            // NOTE(review): this binding is decorated directly and is not
            // recorded in bindingSet. A later implicit binding in globalsSetNo
            // could therefore receive the same number; verify whether that
            // overlap is possible/intended.
            uint32_t bindingShift = 0;
            if (bindingShiftMapper)
              bindingShift = bindingShiftMapper->getShiftForSet(globalsSetNo);
            spvBuilder.decorateDSetBinding(var.getSpirvInstr(), globalsSetNo,
                                           globalsBindNo + bindingShift);
          }
          // The normal case
          else {
            uint32_t bindingShift = 0;
            if (bindingShiftMapper)
              bindingShift = bindingShiftMapper->getShiftForSet(defaultSpace);
            spvBuilder.decorateDSetBinding(
                var.getSpirvInstr(), defaultSpace,
                bindingSet.useNextBinding(defaultSpace, numBindingsToUse,
                                          bindingShift));
          }
        }
      }
    }

    return true;
  }
  1804. bool DeclResultIdMapper::decorateResourceCoherent() {
  1805. for (const auto &var : resourceVars) {
  1806. if (const auto *decl = var.getDeclaration()) {
  1807. if (decl->getAttr<HLSLGloballyCoherentAttr>()) {
  1808. spvBuilder.decorateCoherent(var.getSpirvInstr(),
  1809. var.getSourceLocation());
  1810. }
  1811. }
  1812. }
  1813. return true;
  1814. }
  1815. bool DeclResultIdMapper::createStageVars(
  1816. const hlsl::SigPoint *sigPoint, const NamedDecl *decl, bool asInput,
  1817. QualType type, uint32_t arraySize, const llvm::StringRef namePrefix,
  1818. llvm::Optional<SpirvInstruction *> invocationId, SpirvInstruction **value,
  1819. bool noWriteBack, SemanticInfo *inheritSemantic) {
  1820. assert(value);
  1821. // invocationId should only be used for handling HS per-vertex output.
  1822. if (invocationId.hasValue()) {
  1823. assert(spvContext.isHS() && arraySize != 0 && !asInput);
  1824. }
  1825. assert(inheritSemantic);
  1826. if (type->isVoidType()) {
  1827. // No stage variables will be created for void type.
  1828. return true;
  1829. }
  1830. // The type the variable is evaluated as for SPIR-V.
  1831. QualType evalType = type;
  1832. // We have several cases regarding HLSL semantics to handle here:
  1833. // * If the currrent decl inherits a semantic from some enclosing entity,
  1834. // use the inherited semantic no matter whether there is a semantic
  1835. // attached to the current decl.
  1836. // * If there is no semantic to inherit,
  1837. // * If the current decl is a struct,
  1838. // * If the current decl has a semantic, all its members inhert this
  1839. // decl's semantic, with the index sequentially increasing;
  1840. // * If the current decl does not have a semantic, all its members
  1841. // should have semantics attached;
  1842. // * If the current decl is not a struct, it should have semantic attached.
  1843. auto thisSemantic = getStageVarSemantic(decl);
  1844. // Which semantic we should use for this decl
  1845. auto *semanticToUse = &thisSemantic;
  1846. // Enclosing semantics override internal ones
  1847. if (inheritSemantic->isValid()) {
  1848. if (thisSemantic.isValid()) {
  1849. emitWarning(
  1850. "internal semantic '%0' overridden by enclosing semantic '%1'",
  1851. thisSemantic.loc)
  1852. << thisSemantic.str << inheritSemantic->str;
  1853. }
  1854. semanticToUse = inheritSemantic;
  1855. }
  1856. const auto loc = decl->getLocation();
  1857. if (semanticToUse->isValid() &&
  1858. // Structs with attached semantics will be handled later.
  1859. !type->isStructureType()) {
  1860. // Found semantic attached directly to this Decl. This means we need to
  1861. // map this decl to a single stage variable.
  1862. if (!validateVKAttributes(decl))
  1863. return false;
  1864. const auto semanticKind = semanticToUse->getKind();
  1865. const auto sigPointKind = sigPoint->GetKind();
  1866. // Error out when the given semantic is invalid in this shader model
  1867. if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPointKind,
  1868. spvContext.getMajorVersion(),
  1869. spvContext.getMinorVersion()) ==
  1870. hlsl::DXIL::SemanticInterpretationKind::NA) {
  1871. // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex".
  1872. switch (sigPointKind) {
  1873. case hlsl::SigPoint::Kind::MSIn:
  1874. case hlsl::SigPoint::Kind::ASIn:
  1875. if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
  1876. const llvm::StringRef builtin = builtinAttr->getBuiltIn();
  1877. if (builtin == "DrawIndex") {
  1878. break;
  1879. }
  1880. }
  1881. // fall through
  1882. default:
  1883. emitError("invalid usage of semantic '%0' in shader profile %1", loc)
  1884. << semanticToUse->str
  1885. << hlsl::ShaderModel::GetKindName(
  1886. spvContext.getCurrentShaderModelKind());
  1887. return false;
  1888. }
  1889. }
  1890. if (!validateVKBuiltins(decl, sigPoint))
  1891. return false;
  1892. const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>();
  1893. // Special handling of certain mappings between HLSL semantics and
  1894. // SPIR-V builtins:
  1895. // * SV_CullDistance/SV_ClipDistance are outsourced to GlPerVertex.
  1896. // * SV_DomainLocation can refer to a float2, whereas TessCoord is a float3.
  1897. // To ensure SPIR-V validity, we must create a float3 and extract a
  1898. // float2 from it before passing it to the main function.
  1899. // * SV_TessFactor is an array of size 2 for isoline patch, array of size 3
  1900. // for tri patch, and array of size 4 for quad patch, but it must always
  1901. // be an array of size 4 in SPIR-V for Vulkan.
  1902. // * SV_InsideTessFactor is a single float for tri patch, and an array of
  1903. // size 2 for a quad patch, but it must always be an array of size 2 in
  1904. // SPIR-V for Vulkan.
  1905. // * SV_Coverage is an uint value, but the builtin it corresponds to,
  1906. // SampleMask, must be an array of integers.
  1907. // * SV_InnerCoverage is an uint value, but the corresponding builtin,
  1908. // FullyCoveredEXT, must be an boolean value.
  1909. // * SV_DispatchThreadID, SV_GroupThreadID, and SV_GroupID are allowed to be
  1910. // uint, uint2, or uint3, but the corresponding builtins
  1911. // (GlobalInvocationId, LocalInvocationId, WorkgroupId) must be a uint3.
  1912. if (glPerVertex.tryToAccess(sigPointKind, semanticKind,
  1913. semanticToUse->index, invocationId, value,
  1914. noWriteBack, /*vecComponent=*/nullptr, loc))
  1915. return true;
  1916. switch (semanticKind) {
  1917. case hlsl::Semantic::Kind::DomainLocation:
  1918. evalType = astContext.getExtVectorType(astContext.FloatTy, 3);
  1919. break;
  1920. case hlsl::Semantic::Kind::TessFactor:
  1921. evalType = astContext.getConstantArrayType(
  1922. astContext.FloatTy, llvm::APInt(32, 4), clang::ArrayType::Normal, 0);
  1923. break;
  1924. case hlsl::Semantic::Kind::InsideTessFactor:
  1925. evalType = astContext.getConstantArrayType(
  1926. astContext.FloatTy, llvm::APInt(32, 2), clang::ArrayType::Normal, 0);
  1927. break;
  1928. case hlsl::Semantic::Kind::Coverage:
  1929. evalType = astContext.getConstantArrayType(astContext.UnsignedIntTy,
  1930. llvm::APInt(32, 1),
  1931. clang::ArrayType::Normal, 0);
  1932. break;
  1933. case hlsl::Semantic::Kind::InnerCoverage:
  1934. if (!type->isSpecificBuiltinType(BuiltinType::UInt)) {
  1935. emitError("SV_InnerCoverage must be of uint type.", loc);
  1936. return false;
  1937. }
  1938. evalType = astContext.BoolTy;
  1939. break;
  1940. case hlsl::Semantic::Kind::Barycentrics:
  1941. evalType = astContext.getExtVectorType(astContext.FloatTy, 2);
  1942. break;
  1943. case hlsl::Semantic::Kind::DispatchThreadID:
  1944. case hlsl::Semantic::Kind::GroupThreadID:
  1945. case hlsl::Semantic::Kind::GroupID:
  1946. // Keep the original integer signedness
  1947. evalType = astContext.getExtVectorType(
  1948. hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type,
  1949. 3);
  1950. break;
  1951. default:
  1952. // Only the semantic kinds mentioned above are handled.
  1953. break;
  1954. }
  1955. // Boolean stage I/O variables must be represented as unsigned integers.
  1956. // Boolean built-in variables are represented as bool.
  1957. if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
  1958. evalType = getUintTypeWithSourceComponents(astContext, type);
  1959. }
  1960. // Handle the extra arrayness
  1961. if (arraySize != 0) {
  1962. evalType = astContext.getConstantArrayType(
  1963. evalType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
  1964. }
  1965. StageVar stageVar(
  1966. sigPoint, *semanticToUse, builtinAttr, evalType,
  1967. // For HS/DS/GS, we have already stripped the outmost arrayness on type.
  1968. getLocationCount(astContext, type));
  1969. const auto name = namePrefix.str() + "." + stageVar.getSemanticStr();
  1970. SpirvVariable *varInstr =
  1971. createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc);
  1972. if (!varInstr)
  1973. return false;
  1974. stageVar.setSpirvInstr(varInstr);
  1975. stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
  1976. stageVar.setIndexAttr(decl->getAttr<VKIndexAttr>());
  1977. stageVars.push_back(stageVar);
  1978. // Emit OpDecorate* instructions to link this stage variable with the HLSL
  1979. // semantic it is created for
  1980. spvBuilder.decorateHlslSemantic(varInstr, stageVar.getSemanticStr());
  1981. // We have semantics attached to this decl, which means it must be a
  1982. // function/parameter/variable. All are DeclaratorDecls.
  1983. stageVarInstructions[cast<DeclaratorDecl>(decl)] = varInstr;
  1984. // Special case: The DX12 SV_InstanceID always counts from 0, even if the
  1985. // StartInstanceLocation parameter is non-zero. gl_InstanceIndex, however,
  1986. // starts from firstInstance. Thus it doesn't emulate actual DX12 shader
  1987. // behavior. To make it equivalent, SPIR-V codegen should emit:
  1988. // SV_InstanceID = gl_InstanceIndex - gl_BaseInstance
  1989. // Unfortunately, this means that there is no 1-to-1 mapping of the HLSL
  1990. // semantic to the SPIR-V builtin. As a result, we have to manually create
  1991. // a second stage variable for this specific case.
  1992. //
  1993. // According to the Vulkan spec on builtin variables:
  1994. // www.khronos.org/registry/vulkan/specs/1.1-extensions/html/vkspec.html#interfaces-builtin-variables
  1995. //
  1996. // InstanceIndex:
  1997. // Decorating a variable in a vertex shader with the InstanceIndex
  1998. // built-in decoration will make that variable contain the index of the
  1999. // instance that is being processed by the current vertex shader
  2000. // invocation. InstanceIndex begins at the firstInstance.
  2001. // BaseInstance
  2002. // Decorating a variable with the BaseInstance built-in will make that
  2003. // variable contain the integer value corresponding to the first instance
  2004. // that was passed to the command that invoked the current vertex shader
  2005. // invocation. BaseInstance is the firstInstance parameter to a direct
  2006. // drawing command or the firstInstance member of a structure consumed by
  2007. // an indirect drawing command.
  2008. if (spirvOptions.supportNonzeroBaseInstance && asInput &&
  2009. semanticKind == hlsl::Semantic::Kind::InstanceID &&
  2010. sigPointKind == hlsl::SigPoint::Kind::VSIn) {
  2011. // The above call to createSpirvStageVar creates the gl_InstanceIndex.
  2012. // We should now manually create the gl_BaseInstance variable and do the
  2013. // subtraction.
  2014. auto *instanceIndexVar = varInstr;
  2015. auto *baseInstanceVar = spvBuilder.addStageBuiltinVar(
  2016. type, spv::StorageClass::Input, spv::BuiltIn::BaseInstance,
  2017. decl->hasAttr<HLSLPreciseAttr>(), semanticToUse->loc);
  2018. StageVar stageVar2(sigPoint, *semanticToUse, builtinAttr, evalType,
  2019. getLocationCount(astContext, type));
  2020. stageVar2.setSpirvInstr(baseInstanceVar);
  2021. stageVar2.setLocationAttr(decl->getAttr<VKLocationAttr>());
  2022. stageVar2.setIndexAttr(decl->getAttr<VKIndexAttr>());
  2023. stageVar2.setIsSpirvBuiltin();
  2024. stageVars.push_back(stageVar2);
  2025. // SPIR-V code fore 'SV_InstanceID = gl_InstanceIndex - gl_BaseInstance'
  2026. auto *instanceIdVar =
  2027. spvBuilder.addFnVar(type, semanticToUse->loc, "SV_InstanceID");
  2028. auto *instanceIndexValue =
  2029. spvBuilder.createLoad(type, instanceIndexVar, semanticToUse->loc);
  2030. auto *baseInstanceValue =
  2031. spvBuilder.createLoad(type, baseInstanceVar, semanticToUse->loc);
  2032. auto *instanceIdValue =
  2033. spvBuilder.createBinaryOp(spv::Op::OpISub, type, instanceIndexValue,
  2034. baseInstanceValue, semanticToUse->loc);
  2035. spvBuilder.createStore(instanceIdVar, instanceIdValue,
  2036. semanticToUse->loc);
  2037. stageVarInstructions[cast<DeclaratorDecl>(decl)] = instanceIdVar;
  2038. varInstr = instanceIdVar;
  2039. }
  2040. // Mark that we have used one index for this semantic
  2041. ++semanticToUse->index;
  2042. // TODO: the following may not be correct?
  2043. if (sigPoint->GetSignatureKind() ==
  2044. hlsl::DXIL::SignatureKind::PatchConstOrPrim) {
  2045. if (sigPointKind == hlsl::SigPoint::Kind::MSPOut) {
  2046. // Decorate with PerPrimitiveNV for per-primitive out variables.
  2047. spvBuilder.decoratePerPrimitiveNV(varInstr,
  2048. varInstr->getSourceLocation());
  2049. } else {
  2050. spvBuilder.decoratePatch(varInstr, varInstr->getSourceLocation());
  2051. }
  2052. }
  2053. // Decorate with interpolation modes for pixel shader input variables
  2054. // or vertex shader output variables.
  2055. if (((spvContext.isPS() && sigPoint->IsInput()) ||
  2056. (spvContext.isVS() && sigPoint->IsOutput())) &&
  2057. // BaryCoord*AMD buitins already encode the interpolation mode.
  2058. semanticKind != hlsl::Semantic::Kind::Barycentrics)
  2059. decorateInterpolationMode(decl, type, varInstr);
  2060. if (asInput) {
  2061. *value = spvBuilder.createLoad(evalType, varInstr, loc);
  2062. // Fix ups for corner cases
  2063. // Special handling of SV_TessFactor DS patch constant input.
  2064. // TessLevelOuter is always an array of size 4 in SPIR-V, but
  2065. // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
  2066. // relevant indexes must be loaded.
  2067. if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
  2068. hlsl::GetArraySize(type) != 4) {
  2069. llvm::SmallVector<SpirvInstruction *, 4> components;
  2070. const auto tessFactorSize = hlsl::GetArraySize(type);
  2071. const auto arrType = astContext.getConstantArrayType(
  2072. astContext.FloatTy, llvm::APInt(32, tessFactorSize),
  2073. clang::ArrayType::Normal, 0);
  2074. for (uint32_t i = 0; i < tessFactorSize; ++i)
  2075. components.push_back(spvBuilder.createCompositeExtract(
  2076. astContext.FloatTy, *value, {i}, thisSemantic.loc));
  2077. *value = spvBuilder.createCompositeConstruct(arrType, components,
  2078. thisSemantic.loc);
  2079. }
  2080. // Special handling of SV_InsideTessFactor DS patch constant input.
  2081. // TessLevelInner is always an array of size 2 in SPIR-V, but
  2082. // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
  2083. // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of
  2084. // TessLevelInner.
  2085. else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
  2086. // Some developers use float[1] instead of a scalar float.
  2087. (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
  2088. *value = spvBuilder.createCompositeExtract(astContext.FloatTy, *value,
  2089. {0}, thisSemantic.loc);
  2090. if (type->isArrayType()) { // float[1]
  2091. const auto arrType = astContext.getConstantArrayType(
  2092. astContext.FloatTy, llvm::APInt(32, 1), clang::ArrayType::Normal,
  2093. 0);
  2094. *value = spvBuilder.createCompositeConstruct(arrType, {*value},
  2095. thisSemantic.loc);
  2096. }
  2097. }
  2098. // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord
  2099. // is always a float3. To ensure SPIR-V validity, a float3 stage variable
  2100. // is created, and we must extract a float2 from it before passing it to
  2101. // the main function.
  2102. else if (semanticKind == hlsl::Semantic::Kind::DomainLocation &&
  2103. hlsl::GetHLSLVecSize(type) != 3) {
  2104. const auto domainLocSize = hlsl::GetHLSLVecSize(type);
  2105. *value = spvBuilder.createVectorShuffle(
  2106. astContext.getExtVectorType(astContext.FloatTy, domainLocSize),
  2107. *value, *value, {0, 1}, thisSemantic.loc);
  2108. }
  2109. // Special handling of SV_Coverage, which is an uint value. We need to
  2110. // read SampleMask and extract its first element.
  2111. else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
  2112. *value = spvBuilder.createCompositeExtract(type, *value, {0},
  2113. thisSemantic.loc);
  2114. }
  2115. // Special handling of SV_InnerCoverage, which is an uint value. We need
  2116. // to read FullyCoveredEXT, which is a boolean value, and convert it to an
  2117. // uint value. According to D3D12 "Conservative Rasterization" doc: "The
  2118. // Pixel Shader has a 32-bit scalar integer System Generate Value
  2119. // available: InnerCoverage. This is a bit-field that has bit 0 from the
  2120. // LSB set to 1 for a given conservatively rasterized pixel, only when
  2121. // that pixel is guaranteed to be entirely inside the current primitive.
  2122. // All other input register bits must be set to 0 when bit 0 is not set,
  2123. // but are undefined when bit 0 is set to 1 (essentially, this bit-field
  2124. // represents a Boolean value where false must be exactly 0, but true can
  2125. // be any odd (i.e. bit 0 set) non-zero value)."
  2126. else if (semanticKind == hlsl::Semantic::Kind::InnerCoverage) {
  2127. const auto constOne = spvBuilder.getConstantInt(
  2128. astContext.UnsignedIntTy, llvm::APInt(32, 1));
  2129. const auto constZero = spvBuilder.getConstantInt(
  2130. astContext.UnsignedIntTy, llvm::APInt(32, 0));
  2131. *value = spvBuilder.createSelect(astContext.UnsignedIntTy, *value,
  2132. constOne, constZero, thisSemantic.loc);
  2133. }
  2134. // Special handling of SV_Barycentrics, which is a float3, but the
  2135. // underlying stage input variable is a float2 (only provides the first
  2136. // two components). Calculate the third element.
  2137. else if (semanticKind == hlsl::Semantic::Kind::Barycentrics) {
  2138. const auto x = spvBuilder.createCompositeExtract(
  2139. astContext.FloatTy, *value, {0}, thisSemantic.loc);
  2140. const auto y = spvBuilder.createCompositeExtract(
  2141. astContext.FloatTy, *value, {1}, thisSemantic.loc);
  2142. const auto xy = spvBuilder.createBinaryOp(
  2143. spv::Op::OpFAdd, astContext.FloatTy, x, y, thisSemantic.loc);
  2144. const auto z = spvBuilder.createBinaryOp(
  2145. spv::Op::OpFSub, astContext.FloatTy,
  2146. spvBuilder.getConstantFloat(astContext.FloatTy,
  2147. llvm::APFloat(1.0f)),
  2148. xy, thisSemantic.loc);
  2149. *value = spvBuilder.createCompositeConstruct(
  2150. astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z},
  2151. thisSemantic.loc);
  2152. }
  2153. // Special handling of SV_DispatchThreadID and SV_GroupThreadID, which may
  2154. // be a uint or uint2, but the underlying stage input variable is a uint3.
  2155. // The last component(s) should be discarded in needed.
  2156. else if ((semanticKind == hlsl::Semantic::Kind::DispatchThreadID ||
  2157. semanticKind == hlsl::Semantic::Kind::GroupThreadID ||
  2158. semanticKind == hlsl::Semantic::Kind::GroupID) &&
  2159. (!hlsl::IsHLSLVecType(type) ||
  2160. hlsl::GetHLSLVecSize(type) != 3)) {
  2161. const auto srcVecElemType = hlsl::IsHLSLVecType(type)
  2162. ? hlsl::GetHLSLVecElementType(type)
  2163. : type;
  2164. const auto vecSize =
  2165. hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1;
  2166. if (vecSize == 1)
  2167. *value = spvBuilder.createCompositeExtract(srcVecElemType, *value,
  2168. {0}, thisSemantic.loc);
  2169. else if (vecSize == 2)
  2170. *value = spvBuilder.createVectorShuffle(
  2171. astContext.getExtVectorType(srcVecElemType, 2), *value, *value,
  2172. {0, 1}, thisSemantic.loc);
  2173. }
  2174. // Reciprocate SV_Position.w if requested
  2175. if (semanticKind == hlsl::Semantic::Kind::Position)
  2176. *value = invertWIfRequested(*value, thisSemantic.loc);
  2177. // Since boolean stage input variables are represented as unsigned
  2178. // integers, after loading them, we should cast them to boolean.
  2179. if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
  2180. *value =
  2181. theEmitter.castToType(*value, evalType, type, thisSemantic.loc);
  2182. }
  2183. } else {
  2184. if (noWriteBack)
  2185. return true;
  2186. // Negate SV_Position.y if requested
  2187. if (semanticKind == hlsl::Semantic::Kind::Position)
  2188. *value = invertYIfRequested(*value, thisSemantic.loc);
  2189. SpirvInstruction *ptr = varInstr;
  2190. // Special handling of SV_TessFactor HS patch constant output.
  2191. // TessLevelOuter is always an array of size 4 in SPIR-V, but
  2192. // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
  2193. // relevant indexes must be written to.
  2194. if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
  2195. hlsl::GetArraySize(type) != 4) {
  2196. const auto tessFactorSize = hlsl::GetArraySize(type);
  2197. for (uint32_t i = 0; i < tessFactorSize; ++i) {
  2198. ptr = spvBuilder.createAccessChain(
  2199. astContext.FloatTy, varInstr,
  2200. {spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  2201. llvm::APInt(32, i))},
  2202. thisSemantic.loc);
  2203. spvBuilder.createStore(
  2204. ptr,
  2205. spvBuilder.createCompositeExtract(astContext.FloatTy, *value, {i},
  2206. thisSemantic.loc),
  2207. thisSemantic.loc);
  2208. }
  2209. }
  2210. // Special handling of SV_InsideTessFactor HS patch constant output.
  2211. // TessLevelInner is always an array of size 2 in SPIR-V, but
  2212. // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
  2213. // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of
  2214. // TessLevelInner.
  2215. else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
  2216. // Some developers use float[1] instead of a scalar float.
  2217. (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
  2218. ptr = spvBuilder.createAccessChain(
  2219. astContext.FloatTy, varInstr,
  2220. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  2221. llvm::APInt(32, 0)),
  2222. thisSemantic.loc);
  2223. if (type->isArrayType()) // float[1]
  2224. *value = spvBuilder.createCompositeExtract(astContext.FloatTy, *value,
  2225. {0}, thisSemantic.loc);
  2226. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  2227. }
  2228. // Special handling of SV_Coverage, which is an unit value. We need to
  2229. // write it to the first element in the SampleMask builtin.
  2230. else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
  2231. ptr = spvBuilder.createAccessChain(
  2232. type, varInstr,
  2233. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  2234. llvm::APInt(32, 0)),
  2235. thisSemantic.loc);
  2236. ptr->setStorageClass(spv::StorageClass::Output);
  2237. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  2238. }
  2239. // Special handling of HS ouput, for which we write to only one
  2240. // element in the per-vertex data array: the one indexed by
  2241. // SV_ControlPointID.
  2242. else if (invocationId.hasValue() && invocationId.getValue() != nullptr) {
  2243. // Remove the arrayness to get the element type.
  2244. assert(isa<ConstantArrayType>(evalType));
  2245. const auto elementType =
  2246. astContext.getAsArrayType(evalType)->getElementType();
  2247. auto index = invocationId.getValue();
  2248. ptr = spvBuilder.createAccessChain(elementType, varInstr, index,
  2249. thisSemantic.loc);
  2250. ptr->setStorageClass(spv::StorageClass::Output);
  2251. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  2252. }
  2253. // Since boolean output stage variables are represented as unsigned
  2254. // integers, we must cast the value to uint before storing.
  2255. else if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
  2256. *value =
  2257. theEmitter.castToType(*value, type, evalType, thisSemantic.loc);
  2258. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  2259. }
  2260. // For all normal cases
  2261. else {
  2262. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  2263. }
  2264. }
  2265. return true;
  2266. }
  2267. // If the decl itself doesn't have semantic string attached and there is no
  2268. // one to inherit, it should be a struct having all its fields with semantic
  2269. // strings.
  2270. if (!semanticToUse->isValid() && !type->isStructureType()) {
  2271. emitError("semantic string missing for shader %select{output|input}0 "
  2272. "variable '%1'",
  2273. loc)
  2274. << asInput << decl->getName();
  2275. return false;
  2276. }
  2277. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  2278. if (asInput) {
  2279. // If this decl translates into multiple stage input variables, we need to
  2280. // load their values into a composite.
  2281. llvm::SmallVector<SpirvInstruction *, 4> subValues;
  2282. // If we have base classes, we need to handle them first.
  2283. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  2284. for (auto base : cxxDecl->bases()) {
  2285. SpirvInstruction *subValue = nullptr;
  2286. if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
  2287. asInput, base.getType(), arraySize, namePrefix,
  2288. invocationId, &subValue, noWriteBack,
  2289. semanticToUse))
  2290. return false;
  2291. subValues.push_back(subValue);
  2292. }
  2293. }
  2294. for (const auto *field : structDecl->fields()) {
  2295. SpirvInstruction *subValue = nullptr;
  2296. if (!createStageVars(sigPoint, field, asInput, field->getType(),
  2297. arraySize, namePrefix, invocationId, &subValue,
  2298. noWriteBack, semanticToUse))
  2299. return false;
  2300. subValues.push_back(subValue);
  2301. }
  2302. if (arraySize == 0) {
  2303. *value = spvBuilder.createCompositeConstruct(evalType, subValues, loc);
  2304. return true;
  2305. }
  2306. // Handle the extra level of arrayness.
  2307. // We need to return an array of structs. But we get arrays of fields
  2308. // from visiting all fields. So now we need to extract all the elements
  2309. // at the same index of each field arrays and compose a new struct out
  2310. // of them.
  2311. const auto structType = type;
  2312. const auto arrayType = astContext.getConstantArrayType(
  2313. structType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
  2314. llvm::SmallVector<SpirvInstruction *, 16> arrayElements;
  2315. for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) {
  2316. llvm::SmallVector<SpirvInstruction *, 8> fields;
  2317. // If we have base classes, we need to handle them first.
  2318. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  2319. uint32_t baseIndex = 0;
  2320. for (auto base : cxxDecl->bases()) {
  2321. const auto baseType = base.getType();
  2322. fields.push_back(spvBuilder.createCompositeExtract(
  2323. baseType, subValues[baseIndex++], {arrayIndex}, loc));
  2324. }
  2325. }
  2326. // Extract the element at index arrayIndex from each field
  2327. for (const auto *field : structDecl->fields()) {
  2328. const auto fieldType = field->getType();
  2329. fields.push_back(spvBuilder.createCompositeExtract(
  2330. fieldType,
  2331. subValues[getNumBaseClasses(type) + field->getFieldIndex()],
  2332. {arrayIndex}, loc));
  2333. }
  2334. // Compose a new struct out of them
  2335. arrayElements.push_back(
  2336. spvBuilder.createCompositeConstruct(structType, fields, loc));
  2337. }
  2338. *value = spvBuilder.createCompositeConstruct(arrayType, arrayElements, loc);
  2339. } else {
  2340. // If we have base classes, we need to handle them first.
  2341. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  2342. uint32_t baseIndex = 0;
  2343. for (auto base : cxxDecl->bases()) {
  2344. SpirvInstruction *subValue = nullptr;
  2345. if (!noWriteBack)
  2346. subValue = spvBuilder.createCompositeExtract(base.getType(), *value,
  2347. {baseIndex++}, loc);
  2348. if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
  2349. asInput, base.getType(), arraySize, namePrefix,
  2350. invocationId, &subValue, noWriteBack,
  2351. semanticToUse))
  2352. return false;
  2353. }
  2354. }
  2355. // Unlike reading, which may require us to read stand-alone builtins and
  2356. // stage input variables and compose an array of structs out of them,
  2357. // it happens that we don't need to write an array of structs in a bunch
  2358. // for all shader stages:
  2359. //
  2360. // * VS: output is a single struct, without extra arrayness
  2361. // * HS: output is an array of structs, with extra arrayness,
  2362. // but we only write to the struct at the InvocationID index
  2363. // * DS: output is a single struct, without extra arrayness
  2364. // * GS: output is controlled by OpEmitVertex, one vertex per time
  2365. // * MS: output is an array of structs, with extra arrayness
  2366. //
  2367. // The interesting shader stage is HS. We need the InvocationID to write
  2368. // out the value to the correct array element.
  2369. for (const auto *field : structDecl->fields()) {
  2370. const auto fieldType = field->getType();
  2371. SpirvInstruction *subValue = nullptr;
  2372. if (!noWriteBack)
  2373. subValue = spvBuilder.createCompositeExtract(
  2374. fieldType, *value,
  2375. {getNumBaseClasses(type) + field->getFieldIndex()}, loc);
  2376. if (!createStageVars(sigPoint, field, asInput, field->getType(),
  2377. arraySize, namePrefix, invocationId, &subValue,
  2378. noWriteBack, semanticToUse))
  2379. return false;
  2380. }
  2381. }
  2382. return true;
  2383. }
  2384. bool DeclResultIdMapper::createPayloadStageVars(
  2385. const hlsl::SigPoint *sigPoint, spv::StorageClass sc, const NamedDecl *decl,
  2386. bool asInput, QualType type, const llvm::StringRef namePrefix,
  2387. SpirvInstruction **value, uint32_t payloadMemOffset) {
  2388. assert(spvContext.isMS() || spvContext.isAS());
  2389. assert(value);
  2390. if (type->isVoidType()) {
  2391. // No stage variables will be created for void type.
  2392. return true;
  2393. }
  2394. const auto loc = decl->getLocation();
  2395. if (!type->isStructureType()) {
  2396. StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
  2397. getLocationCount(astContext, type));
  2398. const auto name = namePrefix.str() + "." + decl->getNameAsString();
  2399. SpirvVariable *varInstr =
  2400. spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false, loc);
  2401. if (!varInstr)
  2402. return false;
  2403. // Even though these as user defined IO stage variables, set them as SPIR-V
  2404. // builtins in order to bypass any semantic string checks and location
  2405. // assignment.
  2406. stageVar.setIsSpirvBuiltin();
  2407. stageVar.setSpirvInstr(varInstr);
  2408. stageVars.push_back(stageVar);
  2409. // Decorate with PerTaskNV for mesh/amplification shader payload variables.
  2410. spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
  2411. varInstr->getSourceLocation());
  2412. if (asInput) {
  2413. *value = spvBuilder.createLoad(type, varInstr, loc);
  2414. } else {
  2415. spvBuilder.createStore(varInstr, *value, loc);
  2416. }
  2417. return true;
  2418. }
  2419. // This decl translates into multiple stage input/output payload variables
  2420. // and we need to load/store these individual member variables.
  2421. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  2422. llvm::SmallVector<SpirvInstruction *, 4> subValues;
  2423. AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
  2424. uint32_t nextMemberOffset = 0;
  2425. for (const auto *field : structDecl->fields()) {
  2426. const auto fieldType = field->getType();
  2427. SpirvInstruction *subValue = nullptr;
  2428. uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
  2429. // The next avaiable offset after laying out the previous members.
  2430. std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
  2431. field->getType(), spirvOptions.ampPayloadLayoutRule,
  2432. /*isRowMajor*/ llvm::None, &stride);
  2433. alignmentCalc.alignUsingHLSLRelaxedLayout(
  2434. field->getType(), memberSize, memberAlignment, &nextMemberOffset);
  2435. // The vk::offset attribute takes precedence over all.
  2436. if (field->getAttr<VKOffsetAttr>()) {
  2437. nextMemberOffset = field->getAttr<VKOffsetAttr>()->getOffset();
  2438. }
  2439. // Each payload member must have an Offset Decoration.
  2440. payloadMemOffset = nextMemberOffset;
  2441. nextMemberOffset += memberSize;
  2442. if (!asInput) {
  2443. subValue = spvBuilder.createCompositeExtract(
  2444. fieldType, *value, {getNumBaseClasses(type) + field->getFieldIndex()},
  2445. loc);
  2446. }
  2447. if (!createPayloadStageVars(sigPoint, sc, field, asInput, field->getType(),
  2448. namePrefix, &subValue, payloadMemOffset))
  2449. return false;
  2450. if (asInput) {
  2451. subValues.push_back(subValue);
  2452. }
  2453. }
  2454. if (asInput) {
  2455. *value = spvBuilder.createCompositeConstruct(type, subValues, loc);
  2456. }
  2457. return true;
  2458. }
  2459. bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
  2460. QualType type,
  2461. SpirvInstruction *value) {
  2462. assert(spvContext.isGS()); // Only for GS use
  2463. if (hlsl::IsHLSLStreamOutputType(type))
  2464. type = hlsl::GetHLSLResourceResultType(type);
  2465. if (hasGSPrimitiveTypeQualifier(decl))
  2466. type = astContext.getAsConstantArrayType(type)->getElementType();
  2467. auto semanticInfo = getStageVarSemantic(decl);
  2468. const auto loc = decl->getLocation();
  2469. if (semanticInfo.isValid()) {
  2470. // Found semantic attached directly to this Decl. Write the value for this
  2471. // Decl to the corresponding stage output variable.
  2472. // Handle SV_ClipDistance, and SV_CullDistance
  2473. if (glPerVertex.tryToAccess(
  2474. hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
  2475. semanticInfo.index, llvm::None, &value,
  2476. /*noWriteBack=*/false, /*vecComponent=*/nullptr, loc))
  2477. return true;
  2478. // Query the <result-id> for the stage output variable generated out
  2479. // of this decl.
  2480. // We have semantic string attached to this decl; therefore, it must be a
  2481. // DeclaratorDecl.
  2482. const auto found = stageVarInstructions.find(cast<DeclaratorDecl>(decl));
  2483. // We should have recorded its stage output variable previously.
  2484. assert(found != stageVarInstructions.end());
  2485. // Negate SV_Position.y if requested
  2486. if (semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position)
  2487. value = invertYIfRequested(value, loc);
  2488. // Boolean stage output variables are represented as unsigned integers.
  2489. if (isBooleanStageIOVar(decl, type, semanticInfo.semantic->GetKind(),
  2490. hlsl::SigPoint::Kind::GSOut)) {
  2491. QualType uintType = getUintTypeWithSourceComponents(astContext, type);
  2492. value = theEmitter.castToType(value, type, uintType, loc);
  2493. }
  2494. spvBuilder.createStore(found->second, value, loc);
  2495. return true;
  2496. }
  2497. // If the decl itself doesn't have semantic string attached, it should be
  2498. // a struct having all its fields with semantic strings.
  2499. if (!type->isStructureType()) {
  2500. emitError("semantic string missing for shader output variable '%0'", loc)
  2501. << decl->getName();
  2502. return false;
  2503. }
  2504. // If we have base classes, we need to handle them first.
  2505. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  2506. uint32_t baseIndex = 0;
  2507. for (auto base : cxxDecl->bases()) {
  2508. auto *subValue = spvBuilder.createCompositeExtract(base.getType(), value,
  2509. {baseIndex++}, loc);
  2510. if (!writeBackOutputStream(base.getType()->getAsCXXRecordDecl(),
  2511. base.getType(), subValue))
  2512. return false;
  2513. }
  2514. }
  2515. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  2516. // Write out each field
  2517. for (const auto *field : structDecl->fields()) {
  2518. const auto fieldType = field->getType();
  2519. auto *subValue = spvBuilder.createCompositeExtract(
  2520. fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()},
  2521. loc);
  2522. if (!writeBackOutputStream(field, field->getType(), subValue))
  2523. return false;
  2524. }
  2525. return true;
  2526. }
  2527. SpirvInstruction *
  2528. DeclResultIdMapper::invertYIfRequested(SpirvInstruction *position,
  2529. SourceLocation loc) {
  2530. // Negate SV_Position.y if requested
  2531. if (spirvOptions.invertY) {
  2532. const auto oldY = spvBuilder.createCompositeExtract(astContext.FloatTy,
  2533. position, {1}, loc);
  2534. const auto newY = spvBuilder.createUnaryOp(spv::Op::OpFNegate,
  2535. astContext.FloatTy, oldY, loc);
  2536. position = spvBuilder.createCompositeInsert(
  2537. astContext.getExtVectorType(astContext.FloatTy, 4), position, {1}, newY,
  2538. loc);
  2539. }
  2540. return position;
  2541. }
  2542. SpirvInstruction *
  2543. DeclResultIdMapper::invertWIfRequested(SpirvInstruction *position,
  2544. SourceLocation loc) {
  2545. // Reciprocate SV_Position.w if requested
  2546. if (spirvOptions.invertW && spvContext.isPS()) {
  2547. const auto oldW = spvBuilder.createCompositeExtract(astContext.FloatTy,
  2548. position, {3}, loc);
  2549. const auto newW = spvBuilder.createBinaryOp(
  2550. spv::Op::OpFDiv, astContext.FloatTy,
  2551. spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f)),
  2552. oldW, loc);
  2553. position = spvBuilder.createCompositeInsert(
  2554. astContext.getExtVectorType(astContext.FloatTy, 4), position, {3}, newW,
  2555. loc);
  2556. }
  2557. return position;
  2558. }
  2559. void DeclResultIdMapper::decorateInterpolationMode(const NamedDecl *decl,
  2560. QualType type,
  2561. SpirvVariable *varInstr) {
  2562. const auto loc = decl->getLocation();
  2563. if (isUintOrVecMatOfUintType(type) || isSintOrVecMatOfSintType(type) ||
  2564. isBoolOrVecMatOfBoolType(type)) {
  2565. // TODO: Probably we can call hlsl::ValidateSignatureElement() for the
  2566. // following check.
  2567. if (decl->getAttr<HLSLLinearAttr>() || decl->getAttr<HLSLCentroidAttr>() ||
  2568. decl->getAttr<HLSLNoPerspectiveAttr>() ||
  2569. decl->getAttr<HLSLSampleAttr>()) {
  2570. emitError("only nointerpolation mode allowed for integer input "
  2571. "parameters in pixel shader or integer output in vertex shader",
  2572. decl->getLocation());
  2573. } else {
  2574. spvBuilder.decorateFlat(varInstr, loc);
  2575. }
  2576. } else {
  2577. // Do nothing for HLSLLinearAttr since its the default
  2578. // Attributes can be used together. So cannot use else if.
  2579. if (decl->getAttr<HLSLCentroidAttr>())
  2580. spvBuilder.decorateCentroid(varInstr, loc);
  2581. if (decl->getAttr<HLSLNoInterpolationAttr>())
  2582. spvBuilder.decorateFlat(varInstr, loc);
  2583. if (decl->getAttr<HLSLNoPerspectiveAttr>())
  2584. spvBuilder.decorateNoPerspective(varInstr, loc);
  2585. if (decl->getAttr<HLSLSampleAttr>()) {
  2586. spvBuilder.decorateSample(varInstr, loc);
  2587. }
  2588. }
  2589. }
  2590. SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
  2591. QualType type,
  2592. SourceLocation loc) {
  2593. // Guarantee uniqueness
  2594. uint32_t spvBuiltinId = static_cast<uint32_t>(builtIn);
  2595. const auto builtInVar = builtinToVarMap.find(spvBuiltinId);
  2596. if (builtInVar != builtinToVarMap.end()) {
  2597. return builtInVar->second;
  2598. }
  2599. spv::StorageClass sc = spv::StorageClass::Max;
  2600. // Valid builtins supported
  2601. switch (builtIn) {
  2602. case spv::BuiltIn::SubgroupSize:
  2603. case spv::BuiltIn::SubgroupLocalInvocationId:
  2604. case spv::BuiltIn::HitTNV:
  2605. case spv::BuiltIn::RayTmaxNV:
  2606. case spv::BuiltIn::RayTminNV:
  2607. case spv::BuiltIn::HitKindNV:
  2608. case spv::BuiltIn::IncomingRayFlagsNV:
  2609. case spv::BuiltIn::InstanceCustomIndexNV:
  2610. case spv::BuiltIn::RayGeometryIndexKHR:
  2611. case spv::BuiltIn::PrimitiveId:
  2612. case spv::BuiltIn::InstanceId:
  2613. case spv::BuiltIn::WorldRayDirectionNV:
  2614. case spv::BuiltIn::WorldRayOriginNV:
  2615. case spv::BuiltIn::ObjectRayDirectionNV:
  2616. case spv::BuiltIn::ObjectRayOriginNV:
  2617. case spv::BuiltIn::ObjectToWorldNV:
  2618. case spv::BuiltIn::WorldToObjectNV:
  2619. case spv::BuiltIn::LaunchIdNV:
  2620. case spv::BuiltIn::LaunchSizeNV:
  2621. case spv::BuiltIn::GlobalInvocationId:
  2622. case spv::BuiltIn::WorkgroupId:
  2623. case spv::BuiltIn::LocalInvocationIndex:
  2624. sc = spv::StorageClass::Input;
  2625. break;
  2626. case spv::BuiltIn::PrimitiveCountNV:
  2627. case spv::BuiltIn::PrimitiveIndicesNV:
  2628. case spv::BuiltIn::TaskCountNV:
  2629. sc = spv::StorageClass::Output;
  2630. break;
  2631. default:
  2632. assert(false && "unsupported SPIR-V builtin");
  2633. return nullptr;
  2634. }
  2635. // Create a dummy StageVar for this builtin variable
  2636. auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
  2637. /*isPrecise*/ false, loc);
  2638. const hlsl::SigPoint *sigPoint =
  2639. hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
  2640. hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
  2641. /*isPatchConstant=*/false));
  2642. StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
  2643. /*locCount=*/0);
  2644. stageVar.setIsSpirvBuiltin();
  2645. stageVar.setSpirvInstr(var);
  2646. stageVars.push_back(stageVar);
  2647. // Store in map for re-use
  2648. builtinToVarMap[spvBuiltinId] = var;
  2649. return var;
  2650. }
  2651. SpirvVariable *DeclResultIdMapper::createSpirvIntermediateOutputStageVar(
  2652. const NamedDecl *decl, const llvm::StringRef name, QualType type) {
  2653. const auto *semantic = hlsl::Semantic::GetByName(name);
  2654. SemanticInfo thisSemantic{name, semantic, name, 0, decl->getLocation()};
  2655. const auto *sigPoint =
  2656. deduceSigPoint(cast<DeclaratorDecl>(decl), /*asInput=*/false,
  2657. spvContext.getCurrentShaderModelKind(), /*forPCF=*/false);
  2658. StageVar stageVar(sigPoint, thisSemantic, decl->getAttr<VKBuiltInAttr>(),
  2659. type, /*locCount=*/1);
  2660. SpirvVariable *varInstr =
  2661. createSpirvStageVar(&stageVar, decl, name, thisSemantic.loc);
  2662. if (!varInstr)
  2663. return nullptr;
  2664. stageVar.setSpirvInstr(varInstr);
  2665. stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
  2666. stageVar.setIndexAttr(decl->getAttr<VKIndexAttr>());
  2667. stageVars.push_back(stageVar);
  2668. // Emit OpDecorate* instructions to link this stage variable with the HLSL
  2669. // semantic it is created for.
  2670. spvBuilder.decorateHlslSemantic(varInstr, stageVar.getSemanticStr());
  2671. // We have semantics attached to this decl, which means it must be a
  2672. // function/parameter/variable. All are DeclaratorDecls.
  2673. stageVarInstructions[cast<DeclaratorDecl>(decl)] = varInstr;
  2674. return varInstr;
  2675. }
  2676. SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
  2677. StageVar *stageVar, const NamedDecl *decl, const llvm::StringRef name,
  2678. SourceLocation srcLoc) {
  2679. using spv::BuiltIn;
  2680. const auto sigPoint = stageVar->getSigPoint();
  2681. const auto semanticKind = stageVar->getSemanticInfo().getKind();
  2682. const auto sigPointKind = sigPoint->GetKind();
  2683. const auto type = stageVar->getAstType();
  2684. const auto isPrecise = decl->hasAttr<HLSLPreciseAttr>();
  2685. spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
  2686. if (sc == spv::StorageClass::Max)
  2687. return 0;
  2688. stageVar->setStorageClass(sc);
  2689. // [[vk::builtin(...)]] takes precedence.
  2690. if (const auto *builtinAttr = stageVar->getBuiltInAttr()) {
  2691. const auto spvBuiltIn =
  2692. llvm::StringSwitch<BuiltIn>(builtinAttr->getBuiltIn())
  2693. .Case("PointSize", BuiltIn::PointSize)
  2694. .Case("HelperInvocation", BuiltIn::HelperInvocation)
  2695. .Case("BaseVertex", BuiltIn::BaseVertex)
  2696. .Case("BaseInstance", BuiltIn::BaseInstance)
  2697. .Case("DrawIndex", BuiltIn::DrawIndex)
  2698. .Case("DeviceIndex", BuiltIn::DeviceIndex)
  2699. .Case("ViewportMaskNV", BuiltIn::ViewportMaskNV)
  2700. .Default(BuiltIn::Max);
  2701. assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
  2702. return spvBuilder.addStageBuiltinVar(type, sc, spvBuiltIn, isPrecise,
  2703. srcLoc);
  2704. }
  2705. // The following translation assumes that semantic validity in the current
  2706. // shader model is already checked, so it only covers valid SigPoints for
  2707. // each semantic.
  2708. switch (semanticKind) {
  2709. // According to DXIL spec, the Position SV can be used by all SigPoints
  2710. // other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
  2711. // According to Vulkan spec, the Position BuiltIn can only be used
  2712. // by VSOut, HS/DS/GS In/Out, MSOut.
  2713. case hlsl::Semantic::Kind::Position: {
  2714. switch (sigPointKind) {
  2715. case hlsl::SigPoint::Kind::VSIn:
  2716. case hlsl::SigPoint::Kind::PCOut:
  2717. case hlsl::SigPoint::Kind::DSIn:
  2718. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2719. case hlsl::SigPoint::Kind::VSOut:
  2720. case hlsl::SigPoint::Kind::HSCPIn:
  2721. case hlsl::SigPoint::Kind::HSCPOut:
  2722. case hlsl::SigPoint::Kind::DSCPIn:
  2723. case hlsl::SigPoint::Kind::DSOut:
  2724. case hlsl::SigPoint::Kind::GSVIn:
  2725. case hlsl::SigPoint::Kind::GSOut:
  2726. case hlsl::SigPoint::Kind::MSOut:
  2727. stageVar->setIsSpirvBuiltin();
  2728. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position,
  2729. isPrecise, srcLoc);
  2730. case hlsl::SigPoint::Kind::PSIn:
  2731. stageVar->setIsSpirvBuiltin();
  2732. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragCoord,
  2733. isPrecise, srcLoc);
  2734. default:
  2735. llvm_unreachable("invalid usage of SV_Position sneaked in");
  2736. }
  2737. }
  2738. // According to DXIL spec, the VertexID SV can only be used by VSIn.
  2739. // According to Vulkan spec, the VertexIndex BuiltIn can only be used by
  2740. // VSIn.
  2741. case hlsl::Semantic::Kind::VertexID: {
  2742. stageVar->setIsSpirvBuiltin();
  2743. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::VertexIndex,
  2744. isPrecise, srcLoc);
  2745. }
  2746. // According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
  2747. // HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
  2748. // According to Vulkan spec, the InstanceIndex BuitIn can only be used by
  2749. // VSIn.
  2750. case hlsl::Semantic::Kind::InstanceID: {
  2751. switch (sigPointKind) {
  2752. case hlsl::SigPoint::Kind::VSIn:
  2753. stageVar->setIsSpirvBuiltin();
  2754. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InstanceIndex,
  2755. isPrecise, srcLoc);
  2756. case hlsl::SigPoint::Kind::VSOut:
  2757. case hlsl::SigPoint::Kind::HSCPIn:
  2758. case hlsl::SigPoint::Kind::HSCPOut:
  2759. case hlsl::SigPoint::Kind::DSCPIn:
  2760. case hlsl::SigPoint::Kind::DSOut:
  2761. case hlsl::SigPoint::Kind::GSVIn:
  2762. case hlsl::SigPoint::Kind::GSOut:
  2763. case hlsl::SigPoint::Kind::PSIn:
  2764. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2765. default:
  2766. llvm_unreachable("invalid usage of SV_InstanceID sneaked in");
  2767. }
  2768. }
  2769. // According to DXIL spec, the Depth{|GreaterEqual|LessEqual} SV can only be
  2770. // used by PSOut.
  2771. // According to Vulkan spec, the FragDepth BuiltIn can only be used by PSOut.
  2772. case hlsl::Semantic::Kind::Depth:
  2773. case hlsl::Semantic::Kind::DepthGreaterEqual:
  2774. case hlsl::Semantic::Kind::DepthLessEqual: {
  2775. stageVar->setIsSpirvBuiltin();
  2776. // Vulkan requires the DepthReplacing execution mode to write to FragDepth.
  2777. spvBuilder.addExecutionMode(entryFunction,
  2778. spv::ExecutionMode::DepthReplacing, {}, srcLoc);
  2779. if (semanticKind == hlsl::Semantic::Kind::DepthGreaterEqual)
  2780. spvBuilder.addExecutionMode(entryFunction,
  2781. spv::ExecutionMode::DepthGreater, {}, srcLoc);
  2782. else if (semanticKind == hlsl::Semantic::Kind::DepthLessEqual)
  2783. spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::DepthLess,
  2784. {}, srcLoc);
  2785. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth,
  2786. isPrecise, srcLoc);
  2787. }
  2788. // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
  2789. // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
  2790. // According to Vulkan spec, the ClipDistance/CullDistance
  2791. // BuiltIn can only be used by VSOut, HS/DS/GS In/Out, MSOut.
  2792. case hlsl::Semantic::Kind::ClipDistance:
  2793. case hlsl::Semantic::Kind::CullDistance: {
  2794. switch (sigPointKind) {
  2795. case hlsl::SigPoint::Kind::VSIn:
  2796. case hlsl::SigPoint::Kind::PCOut:
  2797. case hlsl::SigPoint::Kind::DSIn:
  2798. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2799. case hlsl::SigPoint::Kind::VSOut:
  2800. case hlsl::SigPoint::Kind::HSCPIn:
  2801. case hlsl::SigPoint::Kind::HSCPOut:
  2802. case hlsl::SigPoint::Kind::DSCPIn:
  2803. case hlsl::SigPoint::Kind::DSOut:
  2804. case hlsl::SigPoint::Kind::GSVIn:
  2805. case hlsl::SigPoint::Kind::GSOut:
  2806. case hlsl::SigPoint::Kind::PSIn:
  2807. case hlsl::SigPoint::Kind::MSOut:
  2808. llvm_unreachable("should be handled in gl_PerVertex struct");
  2809. default:
  2810. llvm_unreachable(
  2811. "invalid usage of SV_ClipDistance/SV_CullDistance sneaked in");
  2812. }
  2813. }
  2814. // According to DXIL spec, the IsFrontFace SV can only be used by GSOut and
  2815. // PSIn.
  2816. // According to Vulkan spec, the FrontFacing BuitIn can only be used in PSIn.
  2817. case hlsl::Semantic::Kind::IsFrontFace: {
  2818. switch (sigPointKind) {
  2819. case hlsl::SigPoint::Kind::GSOut:
  2820. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2821. case hlsl::SigPoint::Kind::PSIn:
  2822. stageVar->setIsSpirvBuiltin();
  2823. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FrontFacing,
  2824. isPrecise, srcLoc);
  2825. default:
  2826. llvm_unreachable("invalid usage of SV_IsFrontFace sneaked in");
  2827. }
  2828. }
  2829. // According to DXIL spec, the Target SV can only be used by PSOut.
  2830. // There is no corresponding builtin decoration in SPIR-V. So generate normal
  2831. // Vulkan stage input/output variables.
  2832. case hlsl::Semantic::Kind::Target:
  2833. // An arbitrary semantic is defined by users. Generate normal Vulkan stage
  2834. // input/output variables.
  2835. case hlsl::Semantic::Kind::Arbitrary: {
  2836. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2837. // TODO: patch constant function in hull shader
  2838. }
  2839. // According to DXIL spec, the DispatchThreadID SV can only be used by CSIn.
  2840. // According to Vulkan spec, the GlobalInvocationId can only be used in CSIn.
  2841. case hlsl::Semantic::Kind::DispatchThreadID: {
  2842. stageVar->setIsSpirvBuiltin();
  2843. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::GlobalInvocationId,
  2844. isPrecise, srcLoc);
  2845. }
  2846. // According to DXIL spec, the GroupID SV can only be used by CSIn.
  2847. // According to Vulkan spec, the WorkgroupId can only be used in CSIn.
  2848. case hlsl::Semantic::Kind::GroupID: {
  2849. stageVar->setIsSpirvBuiltin();
  2850. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::WorkgroupId,
  2851. isPrecise, srcLoc);
  2852. }
  2853. // According to DXIL spec, the GroupThreadID SV can only be used by CSIn.
  2854. // According to Vulkan spec, the LocalInvocationId can only be used in CSIn.
  2855. case hlsl::Semantic::Kind::GroupThreadID: {
  2856. stageVar->setIsSpirvBuiltin();
  2857. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::LocalInvocationId,
  2858. isPrecise, srcLoc);
  2859. }
  2860. // According to DXIL spec, the GroupIndex SV can only be used by CSIn.
  2861. // According to Vulkan spec, the LocalInvocationIndex can only be used in
  2862. // CSIn.
  2863. case hlsl::Semantic::Kind::GroupIndex: {
  2864. stageVar->setIsSpirvBuiltin();
  2865. return spvBuilder.addStageBuiltinVar(
  2866. type, sc, BuiltIn::LocalInvocationIndex, isPrecise, srcLoc);
  2867. }
  2868. // According to DXIL spec, the OutputControlID SV can only be used by HSIn.
  2869. // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  2870. // HS/GS In.
  2871. case hlsl::Semantic::Kind::OutputControlPointID: {
  2872. stageVar->setIsSpirvBuiltin();
  2873. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
  2874. isPrecise, srcLoc);
  2875. }
  2876. // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
  2877. // DSIn, GSIn, GSOut, PSIn, and MSPOut.
  2878. // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
  2879. // HS/DS/PS In, GS In/Out, MSPOut.
  2880. case hlsl::Semantic::Kind::PrimitiveID: {
  2881. // Translate to PrimitiveId BuiltIn for all valid SigPoints.
  2882. stageVar->setIsSpirvBuiltin();
  2883. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId,
  2884. isPrecise, srcLoc);
  2885. }
  2886. // According to DXIL spec, the TessFactor SV can only be used by PCOut and
  2887. // DSIn.
  2888. // According to Vulkan spec, the TessLevelOuter BuiltIn can only be used in
  2889. // PCOut and DSIn.
  2890. case hlsl::Semantic::Kind::TessFactor: {
  2891. stageVar->setIsSpirvBuiltin();
  2892. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelOuter,
  2893. isPrecise, srcLoc);
  2894. }
  2895. // According to DXIL spec, the InsideTessFactor SV can only be used by PCOut
  2896. // and DSIn.
  2897. // According to Vulkan spec, the TessLevelInner BuiltIn can only be used in
  2898. // PCOut and DSIn.
  2899. case hlsl::Semantic::Kind::InsideTessFactor: {
  2900. stageVar->setIsSpirvBuiltin();
  2901. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelInner,
  2902. isPrecise, srcLoc);
  2903. }
  2904. // According to DXIL spec, the DomainLocation SV can only be used by DSIn.
  2905. // According to Vulkan spec, the TessCoord BuiltIn can only be used in DSIn.
  2906. case hlsl::Semantic::Kind::DomainLocation: {
  2907. stageVar->setIsSpirvBuiltin();
  2908. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord,
  2909. isPrecise, srcLoc);
  2910. }
  2911. // According to DXIL spec, the GSInstanceID SV can only be used by GSIn.
  2912. // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  2913. // HS/GS In.
  2914. case hlsl::Semantic::Kind::GSInstanceID: {
  2915. stageVar->setIsSpirvBuiltin();
  2916. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
  2917. isPrecise, srcLoc);
  2918. }
  2919. // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
  2920. // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
  2921. case hlsl::Semantic::Kind::SampleIndex: {
  2922. stageVar->setIsSpirvBuiltin();
  2923. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId, isPrecise,
  2924. srcLoc);
  2925. }
  2926. // According to DXIL spec, the StencilRef SV can only be used by PSOut.
  2927. case hlsl::Semantic::Kind::StencilRef: {
  2928. stageVar->setIsSpirvBuiltin();
  2929. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT,
  2930. isPrecise, srcLoc);
  2931. }
  2932. // According to DXIL spec, the Barycentrics SV can only be used by PSIn.
  2933. case hlsl::Semantic::Kind::Barycentrics: {
  2934. stageVar->setIsSpirvBuiltin();
  2935. // Selecting the correct builtin according to interpolation mode
  2936. auto bi = BuiltIn::Max;
  2937. if (decl->hasAttr<HLSLNoPerspectiveAttr>()) {
  2938. if (decl->hasAttr<HLSLCentroidAttr>()) {
  2939. bi = BuiltIn::BaryCoordNoPerspCentroidAMD;
  2940. } else if (decl->hasAttr<HLSLSampleAttr>()) {
  2941. bi = BuiltIn::BaryCoordNoPerspSampleAMD;
  2942. } else {
  2943. bi = BuiltIn::BaryCoordNoPerspAMD;
  2944. }
  2945. } else {
  2946. if (decl->hasAttr<HLSLCentroidAttr>()) {
  2947. bi = BuiltIn::BaryCoordSmoothCentroidAMD;
  2948. } else if (decl->hasAttr<HLSLSampleAttr>()) {
  2949. bi = BuiltIn::BaryCoordSmoothSampleAMD;
  2950. } else {
  2951. bi = BuiltIn::BaryCoordSmoothAMD;
  2952. }
  2953. }
  2954. return spvBuilder.addStageBuiltinVar(type, sc, bi, isPrecise, srcLoc);
  2955. }
  2956. // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
  2957. // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
  2958. // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut
  2959. // PSIn, and MSPOut.
  2960. case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
  2961. switch (sigPointKind) {
  2962. case hlsl::SigPoint::Kind::VSIn:
  2963. case hlsl::SigPoint::Kind::HSCPIn:
  2964. case hlsl::SigPoint::Kind::HSCPOut:
  2965. case hlsl::SigPoint::Kind::PCOut:
  2966. case hlsl::SigPoint::Kind::DSIn:
  2967. case hlsl::SigPoint::Kind::DSCPIn:
  2968. case hlsl::SigPoint::Kind::GSVIn:
  2969. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2970. case hlsl::SigPoint::Kind::VSOut:
  2971. case hlsl::SigPoint::Kind::DSOut:
  2972. stageVar->setIsSpirvBuiltin();
  2973. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
  2974. srcLoc);
  2975. case hlsl::SigPoint::Kind::GSOut:
  2976. case hlsl::SigPoint::Kind::PSIn:
  2977. case hlsl::SigPoint::Kind::MSPOut:
  2978. stageVar->setIsSpirvBuiltin();
  2979. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
  2980. srcLoc);
  2981. default:
  2982. llvm_unreachable("invalid usage of SV_RenderTargetArrayIndex sneaked in");
  2983. }
  2984. }
  2985. // According to DXIL spec, the ViewportArrayIndex SV can only be used by
  2986. // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
  2987. // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
  2988. // GSOut, PSIn, and MSPOut.
  2989. case hlsl::Semantic::Kind::ViewPortArrayIndex: {
  2990. switch (sigPointKind) {
  2991. case hlsl::SigPoint::Kind::VSIn:
  2992. case hlsl::SigPoint::Kind::HSCPIn:
  2993. case hlsl::SigPoint::Kind::HSCPOut:
  2994. case hlsl::SigPoint::Kind::PCOut:
  2995. case hlsl::SigPoint::Kind::DSIn:
  2996. case hlsl::SigPoint::Kind::DSCPIn:
  2997. case hlsl::SigPoint::Kind::GSVIn:
  2998. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2999. case hlsl::SigPoint::Kind::VSOut:
  3000. case hlsl::SigPoint::Kind::DSOut:
  3001. stageVar->setIsSpirvBuiltin();
  3002. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
  3003. isPrecise, srcLoc);
  3004. case hlsl::SigPoint::Kind::GSOut:
  3005. case hlsl::SigPoint::Kind::PSIn:
  3006. case hlsl::SigPoint::Kind::MSPOut:
  3007. stageVar->setIsSpirvBuiltin();
  3008. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
  3009. isPrecise, srcLoc);
  3010. default:
  3011. llvm_unreachable("invalid usage of SV_ViewportArrayIndex sneaked in");
  3012. }
  3013. }
  3014. // According to DXIL spec, the Coverage SV can only be used by PSIn and PSOut.
  3015. // According to Vulkan spec, the SampleMask BuiltIn can only be used in
  3016. // PSIn and PSOut.
  3017. case hlsl::Semantic::Kind::Coverage: {
  3018. stageVar->setIsSpirvBuiltin();
  3019. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask,
  3020. isPrecise, srcLoc);
  3021. }
  3022. // According to DXIL spec, the ViewID SV can only be used by VSIn, PCIn,
  3023. // HSIn, DSIn, GSIn, PSIn.
  3024. // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
  3025. // VS/HS/DS/GS/PS input.
  3026. case hlsl::Semantic::Kind::ViewID: {
  3027. stageVar->setIsSpirvBuiltin();
  3028. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex,
  3029. isPrecise, srcLoc);
  3030. }
  3031. // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
  3032. // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
  3033. // PSIn.
  3034. case hlsl::Semantic::Kind::InnerCoverage: {
  3035. stageVar->setIsSpirvBuiltin();
  3036. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT,
  3037. isPrecise, srcLoc);
  3038. }
  3039. // According to DXIL spec, the ShadingRate SV can only be used by GSOut,
  3040. // VSOut, or PSIn. According to Vulkan spec, the FragSizeEXT BuiltIn can only
  3041. // be used as VSOut, GSOut, MSOut or PSIn.
  3042. case hlsl::Semantic::Kind::ShadingRate: {
  3043. QualType checkType = type->getAs<ReferenceType>()
  3044. ? type->getAs<ReferenceType>()->getPointeeType()
  3045. : type;
  3046. QualType scalarTy;
  3047. if (!isScalarType(checkType, &scalarTy) || !scalarTy->isIntegerType()) {
  3048. emitError("semantic ShadingRate must be interger scalar type", srcLoc);
  3049. }
  3050. switch (sigPointKind) {
  3051. case hlsl::SigPoint::Kind::PSIn:
  3052. stageVar->setIsSpirvBuiltin();
  3053. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ShadingRateKHR,
  3054. isPrecise, srcLoc);
  3055. case hlsl::SigPoint::Kind::VSOut:
  3056. case hlsl::SigPoint::Kind::GSOut:
  3057. case hlsl::SigPoint::Kind::MSOut:
  3058. stageVar->setIsSpirvBuiltin();
  3059. return spvBuilder.addStageBuiltinVar(
  3060. type, sc, BuiltIn::PrimitiveShadingRateKHR, isPrecise, srcLoc);
  3061. default:
  3062. emitError("semantic ShadingRate must be used only for PSIn, VSOut, "
  3063. "GSOut, MSOut",
  3064. srcLoc);
  3065. break;
  3066. }
  3067. break;
  3068. }
  3069. default:
  3070. emitError("semantic %0 unimplemented", srcLoc)
  3071. << stageVar->getSemanticStr();
  3072. break;
  3073. }
  3074. return 0;
  3075. }
  3076. bool DeclResultIdMapper::validateVKAttributes(const NamedDecl *decl) {
  3077. bool success = true;
  3078. if (const auto *idxAttr = decl->getAttr<VKIndexAttr>()) {
  3079. if (!spvContext.isPS()) {
  3080. emitError("vk::index only allowed in pixel shader",
  3081. idxAttr->getLocation());
  3082. success = false;
  3083. }
  3084. const auto *locAttr = decl->getAttr<VKLocationAttr>();
  3085. if (!locAttr) {
  3086. emitError("vk::index should be used together with vk::location for "
  3087. "dual-source blending",
  3088. idxAttr->getLocation());
  3089. success = false;
  3090. } else {
  3091. const auto locNumber = locAttr->getNumber();
  3092. if (locNumber != 0) {
  3093. emitError("dual-source blending should use vk::location 0",
  3094. locAttr->getLocation());
  3095. success = false;
  3096. }
  3097. }
  3098. const auto idxNumber = idxAttr->getNumber();
  3099. if (idxNumber != 0 && idxNumber != 1) {
  3100. emitError("dual-source blending only accepts 0 or 1 as vk::index",
  3101. idxAttr->getLocation());
  3102. success = false;
  3103. }
  3104. }
  3105. return success;
  3106. }
  3107. bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
  3108. const hlsl::SigPoint *sigPoint) {
  3109. bool success = true;
  3110. if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
  3111. // The front end parsing only allows vk::builtin to be attached to a
  3112. // function/parameter/variable; all of them are DeclaratorDecls.
  3113. const auto declType = getTypeOrFnRetType(cast<DeclaratorDecl>(decl));
  3114. const auto loc = builtinAttr->getLocation();
  3115. if (decl->hasAttr<VKLocationAttr>()) {
  3116. emitError("cannot use vk::builtin and vk::location together", loc);
  3117. success = false;
  3118. }
  3119. const llvm::StringRef builtin = builtinAttr->getBuiltIn();
  3120. if (builtin == "HelperInvocation") {
  3121. if (!declType->isBooleanType()) {
  3122. emitError("HelperInvocation builtin must be of boolean type", loc);
  3123. success = false;
  3124. }
  3125. if (sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) {
  3126. emitError(
  3127. "HelperInvocation builtin can only be used as pixel shader input",
  3128. loc);
  3129. success = false;
  3130. }
  3131. } else if (builtin == "PointSize") {
  3132. if (!declType->isFloatingType()) {
  3133. emitError("PointSize builtin must be of float type", loc);
  3134. success = false;
  3135. }
  3136. switch (sigPoint->GetKind()) {
  3137. case hlsl::SigPoint::Kind::VSOut:
  3138. case hlsl::SigPoint::Kind::HSCPIn:
  3139. case hlsl::SigPoint::Kind::HSCPOut:
  3140. case hlsl::SigPoint::Kind::DSCPIn:
  3141. case hlsl::SigPoint::Kind::DSOut:
  3142. case hlsl::SigPoint::Kind::GSVIn:
  3143. case hlsl::SigPoint::Kind::GSOut:
  3144. case hlsl::SigPoint::Kind::PSIn:
  3145. case hlsl::SigPoint::Kind::MSOut:
  3146. break;
  3147. default:
  3148. emitError("PointSize builtin cannot be used as %0", loc)
  3149. << sigPoint->GetName();
  3150. success = false;
  3151. }
  3152. } else if (builtin == "BaseVertex" || builtin == "BaseInstance" ||
  3153. builtin == "DrawIndex") {
  3154. if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
  3155. !declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)) {
  3156. emitError("%0 builtin must be of 32-bit scalar integer type", loc)
  3157. << builtin;
  3158. success = false;
  3159. }
  3160. switch (sigPoint->GetKind()) {
  3161. case hlsl::SigPoint::Kind::VSIn:
  3162. break;
  3163. case hlsl::SigPoint::Kind::MSIn:
  3164. case hlsl::SigPoint::Kind::ASIn:
  3165. if (builtin != "DrawIndex") {
  3166. emitError("%0 builtin cannot be used as %1", loc)
  3167. << builtin << sigPoint->GetName();
  3168. success = false;
  3169. }
  3170. break;
  3171. default:
  3172. emitError("%0 builtin cannot be used as %1", loc)
  3173. << builtin << sigPoint->GetName();
  3174. success = false;
  3175. }
  3176. } else if (builtin == "DeviceIndex") {
  3177. if (getStorageClassForSigPoint(sigPoint) != spv::StorageClass::Input) {
  3178. emitError("%0 builtin can only be used as shader input", loc)
  3179. << builtin;
  3180. success = false;
  3181. }
  3182. if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
  3183. !declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)) {
  3184. emitError("%0 builtin must be of 32-bit scalar integer type", loc)
  3185. << builtin;
  3186. success = false;
  3187. }
  3188. } else if (builtin == "ViewportMaskNV") {
  3189. if (sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) {
  3190. emitError("%0 builtin can only be used as 'primitives' output in MS",
  3191. loc)
  3192. << builtin;
  3193. success = false;
  3194. }
  3195. if (!declType->isArrayType() ||
  3196. !declType->getArrayElementTypeNoTypeQual()->isSpecificBuiltinType(
  3197. BuiltinType::Kind::Int)) {
  3198. emitError("%0 builtin must be of type array of integers", loc)
  3199. << builtin;
  3200. success = false;
  3201. }
  3202. }
  3203. }
  3204. return success;
  3205. }
  3206. spv::StorageClass
  3207. DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
  3208. // This translation is done based on the HLSL reference (see docs/dxil.rst).
  3209. const auto sigPointKind = sigPoint->GetKind();
  3210. const auto signatureKind = sigPoint->GetSignatureKind();
  3211. spv::StorageClass sc = spv::StorageClass::Max;
  3212. switch (signatureKind) {
  3213. case hlsl::DXIL::SignatureKind::Input:
  3214. sc = spv::StorageClass::Input;
  3215. break;
  3216. case hlsl::DXIL::SignatureKind::Output:
  3217. sc = spv::StorageClass::Output;
  3218. break;
  3219. case hlsl::DXIL::SignatureKind::Invalid: {
  3220. // There are some special cases in HLSL (See docs/dxil.rst):
  3221. // SignatureKind is "invalid" for PCIn, HSIn, GSIn, and CSIn.
  3222. switch (sigPointKind) {
  3223. case hlsl::DXIL::SigPointKind::PCIn:
  3224. case hlsl::DXIL::SigPointKind::HSIn:
  3225. case hlsl::DXIL::SigPointKind::GSIn:
  3226. case hlsl::DXIL::SigPointKind::CSIn:
  3227. case hlsl::DXIL::SigPointKind::MSIn:
  3228. case hlsl::DXIL::SigPointKind::ASIn:
  3229. sc = spv::StorageClass::Input;
  3230. break;
  3231. default:
  3232. llvm_unreachable("Found invalid SigPoint kind for semantic");
  3233. }
  3234. break;
  3235. }
  3236. case hlsl::DXIL::SignatureKind::PatchConstOrPrim: {
  3237. // There are some special cases in HLSL (See docs/dxil.rst):
  3238. // SignatureKind is "PatchConstOrPrim" for PCOut, MSPOut and DSIn.
  3239. switch (sigPointKind) {
  3240. case hlsl::DXIL::SigPointKind::PCOut:
  3241. case hlsl::DXIL::SigPointKind::MSPOut:
  3242. // Patch Constant Output (Output of Hull which is passed to Domain).
  3243. // Mesh Shader per-primitive output attributes.
  3244. sc = spv::StorageClass::Output;
  3245. break;
  3246. case hlsl::DXIL::SigPointKind::DSIn:
  3247. // Domain Shader regular input - Patch Constant data plus system values.
  3248. sc = spv::StorageClass::Input;
  3249. break;
  3250. default:
  3251. llvm_unreachable("Found invalid SigPoint kind for semantic");
  3252. }
  3253. break;
  3254. }
  3255. default:
  3256. llvm_unreachable("Found invalid SigPoint kind for semantic");
  3257. }
  3258. return sc;
  3259. }
  3260. QualType DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
  3261. const DeclaratorDecl *decl, bool *shouldBeAlias) {
  3262. if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
  3263. // This method is only intended to be used to create SPIR-V variables in the
  3264. // Function or Private storage class.
  3265. assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember());
  3266. }
  3267. const QualType type = getTypeOrFnRetType(decl);
  3268. // Whether we should generate this decl as an alias variable.
  3269. bool genAlias = false;
  3270. // For ConstantBuffers, TextureBuffers, StructuredBuffers, ByteAddressBuffers
  3271. if (isConstantTextureBuffer(type) ||
  3272. isOrContainsAKindOfStructuredOrByteBuffer(type)) {
  3273. genAlias = true;
  3274. }
  3275. // Return via parameter whether alias was generated.
  3276. if (shouldBeAlias)
  3277. *shouldBeAlias = genAlias;
  3278. if (genAlias) {
  3279. needsLegalization = true;
  3280. createCounterVarForDecl(decl);
  3281. }
  3282. return type;
  3283. }
  3284. bool DeclResultIdMapper::getImplicitRegisterType(const ResourceVar &var,
  3285. char *registerTypeOut) const {
  3286. assert(registerTypeOut);
  3287. if (var.getSpirvInstr()) {
  3288. if (var.getSpirvInstr()->hasAstResultType()) {
  3289. QualType type = var.getSpirvInstr()->getAstResultType();
  3290. // Strip outer arrayness first
  3291. while (type->isArrayType())
  3292. type = type->getAsArrayTypeUnsafe()->getElementType();
  3293. // t - for shader resource views (SRV)
  3294. if (isTexture(type) || isNonWritableStructuredBuffer(type) ||
  3295. isByteAddressBuffer(type) || isBuffer(type)) {
  3296. *registerTypeOut = 't';
  3297. return true;
  3298. }
  3299. // s - for samplers
  3300. else if (isSampler(type)) {
  3301. *registerTypeOut = 's';
  3302. return true;
  3303. }
  3304. // u - for unordered access views (UAV)
  3305. else if (isRWByteAddressBuffer(type) || isRWAppendConsumeSBuffer(type) ||
  3306. isRWBuffer(type) || isRWTexture(type)) {
  3307. *registerTypeOut = 'u';
  3308. return true;
  3309. }
  3310. } else {
  3311. llvm::StringRef hlslUserType = var.getSpirvInstr()->getHlslUserType();
  3312. // b - for constant buffer views (CBV)
  3313. if (var.isGlobalsBuffer() || hlslUserType == "cbuffer" ||
  3314. hlslUserType == "ConstantBuffer") {
  3315. *registerTypeOut = 'b';
  3316. return true;
  3317. }
  3318. if (hlslUserType == "tbuffer") {
  3319. *registerTypeOut = 't';
  3320. return true;
  3321. }
  3322. }
  3323. }
  3324. *registerTypeOut = '\0';
  3325. return false;
  3326. }
  3327. SpirvVariable *
  3328. DeclResultIdMapper::createRayTracingNVStageVar(spv::StorageClass sc,
  3329. const VarDecl *decl) {
  3330. QualType type = decl->getType();
  3331. SpirvVariable *retVal = nullptr;
  3332. // Raytracing interface variables are special since they do not participate
  3333. // in any interface matching and hence do not create StageVar and
  3334. // track them under StageVars vector
  3335. const auto name = decl->getName();
  3336. switch (sc) {
  3337. case spv::StorageClass::IncomingRayPayloadNV:
  3338. case spv::StorageClass::IncomingCallableDataNV:
  3339. case spv::StorageClass::HitAttributeNV:
  3340. case spv::StorageClass::RayPayloadNV:
  3341. case spv::StorageClass::CallableDataNV:
  3342. retVal = spvBuilder.addModuleVar(type, sc, decl->hasAttr<HLSLPreciseAttr>(),
  3343. name.str());
  3344. break;
  3345. default:
  3346. assert(false && "Unsupported SPIR-V storage class for raytracing");
  3347. }
  3348. return retVal;
  3349. }
  3350. void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
  3351. const VarDecl *varDecl = dyn_cast<VarDecl>(decl);
  3352. if (!varDecl || !varDecl->isImplicit())
  3353. return;
  3354. APValue *val = varDecl->evaluateValue();
  3355. if (!val)
  3356. return;
  3357. SpirvInstruction *constVal =
  3358. spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
  3359. constVal->setRValue(true);
  3360. astDecls[varDecl].instr = constVal;
  3361. }
  3362. SpirvInstruction *
  3363. DeclResultIdMapper::createHullMainOutputPatch(const ParmVarDecl *param,
  3364. const QualType retType,
  3365. uint32_t numOutputControlPoints) {
  3366. const QualType hullMainRetType = astContext.getConstantArrayType(
  3367. retType, llvm::APInt(32, numOutputControlPoints),
  3368. clang::ArrayType::Normal, 0);
  3369. SpirvInstruction *hullMainOutputPatch = createSpirvIntermediateOutputStageVar(
  3370. param, "temp.var.hullMainRetVal", hullMainRetType);
  3371. assert(astDecls[param].instr == nullptr);
  3372. astDecls[param].instr = hullMainOutputPatch;
  3373. return hullMainOutputPatch;
  3374. }
  3375. } // end namespace spirv
  3376. } // end namespace clang