DeclResultIdMapper.cpp 127 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318
  1. //===--- DeclResultIdMapper.cpp - DeclResultIdMapper impl --------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "DeclResultIdMapper.h"
  10. #include <sstream>
  11. #include "dxc/DXIL/DxilConstants.h"
  12. #include "dxc/DXIL/DxilTypeSystem.h"
  13. #include "clang/AST/Expr.h"
  14. #include "clang/AST/HlslTypes.h"
  15. #include "clang/SPIRV/AstTypeProbe.h"
  16. #include "llvm/ADT/SmallBitVector.h"
  17. #include "llvm/ADT/StringMap.h"
  18. #include "llvm/ADT/StringSet.h"
  19. #include "llvm/Support/Casting.h"
  20. #include "AlignmentSizeCalculator.h"
  21. #include "SpirvEmitter.h"
  22. namespace clang {
  23. namespace spirv {
  24. namespace {
  25. uint32_t getVkBindingAttrSet(const VKBindingAttr *attr, uint32_t defaultSet) {
  26. // If the [[vk::binding(x)]] attribute is provided without the descriptor set,
  27. // we should use the default descriptor set.
  28. if (attr->getSet() == INT_MIN) {
  29. return defaultSet;
  30. }
  31. return attr->getSet();
  32. }
  33. /// Returns the :packoffset() annotation on the given decl. Returns nullptr if
  34. /// the decl does not have one.
  35. hlsl::ConstantPacking *getPackOffset(const clang::NamedDecl *decl) {
  36. for (auto *annotation : decl->getUnusualAnnotations())
  37. if (auto *packing = llvm::dyn_cast<hlsl::ConstantPacking>(annotation))
  38. return packing;
  39. return nullptr;
  40. }
  41. QualType getUintTypeWithSourceComponents(const ASTContext &astContext,
  42. QualType sourceType) {
  43. if (isScalarType(sourceType)) {
  44. return astContext.UnsignedIntTy;
  45. }
  46. uint32_t elemCount = 0;
  47. if (isVectorType(sourceType, nullptr, &elemCount)) {
  48. return astContext.getExtVectorType(astContext.UnsignedIntTy, elemCount);
  49. }
  50. llvm_unreachable("only scalar and vector types are supported in "
  51. "getUintTypeWithSourceComponents");
  52. }
  53. uint32_t getLocationCount(const ASTContext &astContext, QualType type) {
  54. // See Vulkan spec 14.1.4. Location Assignment for the complete set of rules.
  55. const auto canonicalType = type.getCanonicalType();
  56. if (canonicalType != type)
  57. return getLocationCount(astContext, canonicalType);
  58. // Inputs and outputs of the following types consume a single interface
  59. // location:
  60. // * 16-bit scalar and vector types, and
  61. // * 32-bit scalar and vector types, and
  62. // * 64-bit scalar and 2-component vector types.
  63. // 64-bit three- and four- component vectors consume two consecutive
  64. // locations.
  65. // Primitive types
  66. if (isScalarType(type))
  67. return 1;
  68. // Vector types
  69. {
  70. QualType elemType = {};
  71. uint32_t elemCount = {};
  72. if (isVectorType(type, &elemType, &elemCount)) {
  73. const auto *builtinType = elemType->getAs<BuiltinType>();
  74. switch (builtinType->getKind()) {
  75. case BuiltinType::Double:
  76. case BuiltinType::LongLong:
  77. case BuiltinType::ULongLong:
  78. if (elemCount >= 3)
  79. return 2;
  80. default:
  81. // Filter switch only interested in types occupying 2 locations.
  82. break;
  83. }
  84. return 1;
  85. }
  86. }
  87. // If the declared input or output is an n * m 16- , 32- or 64- bit matrix,
  88. // it will be assigned multiple locations starting with the location
  89. // specified. The number of locations assigned for each matrix will be the
  90. // same as for an n-element array of m-component vectors.
  91. // Matrix types
  92. {
  93. QualType elemType = {};
  94. uint32_t rowCount = 0, colCount = 0;
  95. if (isMxNMatrix(type, &elemType, &rowCount, &colCount))
  96. return getLocationCount(astContext,
  97. astContext.getExtVectorType(elemType, colCount)) *
  98. rowCount;
  99. }
  100. // Typedefs
  101. if (const auto *typedefType = type->getAs<TypedefType>())
  102. return getLocationCount(astContext, typedefType->desugar());
  103. // Reference types
  104. if (const auto *refType = type->getAs<ReferenceType>())
  105. return getLocationCount(astContext, refType->getPointeeType());
  106. // Pointer types
  107. if (const auto *ptrType = type->getAs<PointerType>())
  108. return getLocationCount(astContext, ptrType->getPointeeType());
  109. // If a declared input or output is an array of size n and each element takes
  110. // m locations, it will be assigned m * n consecutive locations starting with
  111. // the location specified.
  112. // Array types
  113. if (const auto *arrayType = astContext.getAsConstantArrayType(type))
  114. return getLocationCount(astContext, arrayType->getElementType()) *
  115. static_cast<uint32_t>(arrayType->getSize().getZExtValue());
  116. // Struct type
  117. if (type->getAs<RecordType>()) {
  118. assert(false && "all structs should already be flattened");
  119. return 0;
  120. }
  121. llvm_unreachable(
  122. "calculating number of occupied locations for type unimplemented");
  123. return 0;
  124. }
  125. bool shouldSkipInStructLayout(const Decl *decl) {
  126. // Ignore implicit generated struct declarations/constructors/destructors
  127. if (decl->isImplicit())
  128. return true;
  129. // Ignore embedded type decls
  130. if (isa<TypeDecl>(decl))
  131. return true;
  132. // Ignore embeded function decls
  133. if (isa<FunctionDecl>(decl))
  134. return true;
  135. // Ignore empty decls
  136. if (isa<EmptyDecl>(decl))
  137. return true;
  138. // For the $Globals cbuffer, we only care about externally-visible
  139. // non-resource-type variables. The rest should be filtered out.
  140. const auto *declContext = decl->getDeclContext();
  141. // Special check for ConstantBuffer/TextureBuffer, whose DeclContext is a
  142. // HLSLBufferDecl. So that we need to check the HLSLBufferDecl's parent decl
  143. // to check whether this is a ConstantBuffer/TextureBuffer defined in the
  144. // global namespace.
  145. // Note that we should not be seeing ConstantBuffer/TextureBuffer for normal
  146. // cbuffer/tbuffer or push constant blocks. So this case should only happen
  147. // for $Globals cbuffer.
  148. if (isConstantTextureBuffer(decl) &&
  149. declContext->getLexicalParent()->isTranslationUnit())
  150. return true;
  151. // $Globals' "struct" is the TranslationUnit, so we should ignore resources
  152. // in the TranslationUnit "struct" and its child namespaces.
  153. if (declContext->isTranslationUnit() || declContext->isNamespace()) {
  154. // External visibility
  155. if (const auto *declDecl = dyn_cast<DeclaratorDecl>(decl))
  156. if (!declDecl->hasExternalFormalLinkage())
  157. return true;
  158. // cbuffer/tbuffer
  159. if (isa<HLSLBufferDecl>(decl))
  160. return true;
  161. // Other resource types
  162. if (const auto *valueDecl = dyn_cast<ValueDecl>(decl))
  163. if (isResourceType(valueDecl))
  164. return true;
  165. }
  166. return false;
  167. }
  168. void collectDeclsInField(const Decl *field,
  169. llvm::SmallVector<const Decl *, 4> *decls) {
  170. // Case of nested namespaces.
  171. if (const auto *nsDecl = dyn_cast<NamespaceDecl>(field)) {
  172. for (const auto *decl : nsDecl->decls()) {
  173. collectDeclsInField(decl, decls);
  174. }
  175. }
  176. if (shouldSkipInStructLayout(field))
  177. return;
  178. if (!isa<DeclaratorDecl>(field)) {
  179. return;
  180. }
  181. decls->push_back(field);
  182. }
  183. llvm::SmallVector<const Decl *, 4>
  184. collectDeclsInDeclContext(const DeclContext *declContext) {
  185. llvm::SmallVector<const Decl *, 4> decls;
  186. for (const auto *field : declContext->decls()) {
  187. collectDeclsInField(field, &decls);
  188. }
  189. return decls;
  190. }
  191. /// \brief Returns true if the given decl is a boolean stage I/O variable.
  192. /// Returns false if the type is not boolean, or the decl is a built-in stage
  193. /// variable.
  194. bool isBooleanStageIOVar(const NamedDecl *decl, QualType type,
  195. const hlsl::DXIL::SemanticKind semanticKind,
  196. const hlsl::SigPoint::Kind sigPointKind) {
  197. // [[vk::builtin(...)]] makes the decl a built-in stage variable.
  198. // IsFrontFace (if used as PSIn) is the only known boolean built-in stage
  199. // variable.
  200. const bool isBooleanBuiltin =
  201. (decl->getAttr<VKBuiltInAttr>() != nullptr) ||
  202. (semanticKind == hlsl::Semantic::Kind::IsFrontFace &&
  203. sigPointKind == hlsl::SigPoint::Kind::PSIn);
  204. // TODO: support boolean matrix stage I/O variable if needed.
  205. QualType elemType = {};
  206. const bool isBooleanType =
  207. ((isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
  208. elemType->isBooleanType());
  209. return isBooleanType && !isBooleanBuiltin;
  210. }
  211. /// \brief Returns the stage variable's register assignment for the given Decl.
  212. const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
  213. for (auto *annotation : decl->getUnusualAnnotations()) {
  214. if (auto *reg = dyn_cast<hlsl::RegisterAssignment>(annotation)) {
  215. return reg;
  216. }
  217. }
  218. return nullptr;
  219. }
  220. /// \brief Returns the stage variable's 'register(c#) assignment for the given
  221. /// Decl. Return nullptr if the given variable does not have such assignment.
  222. const hlsl::RegisterAssignment *getRegisterCAssignment(const NamedDecl *decl) {
  223. const auto *regAssignment = getResourceBinding(decl);
  224. if (regAssignment)
  225. return regAssignment->RegisterType == 'c' ? regAssignment : nullptr;
  226. return nullptr;
  227. }
  228. /// \brief Returns true if the given declaration has a primitive type qualifier.
  229. /// Returns false otherwise.
  230. inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
  231. return decl->hasAttr<HLSLTriangleAttr>() ||
  232. decl->hasAttr<HLSLTriangleAdjAttr>() ||
  233. decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
  234. decl->hasAttr<HLSLLineAdjAttr>();
  235. }
  236. /// \brief Deduces the parameter qualifier for the given decl.
  237. hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
  238. bool asInput) {
  239. const auto type = decl->getType();
  240. if (hlsl::IsHLSLInputPatchType(type))
  241. return hlsl::DxilParamInputQual::InputPatch;
  242. if (hlsl::IsHLSLOutputPatchType(type))
  243. return hlsl::DxilParamInputQual::OutputPatch;
  244. // TODO: Add support for multiple output streams.
  245. if (hlsl::IsHLSLStreamOutputType(type))
  246. return hlsl::DxilParamInputQual::OutStream0;
  247. // The inputs to the geometry shader that have a primitive type qualifier
  248. // must use 'InputPrimitive'.
  249. if (hasGSPrimitiveTypeQualifier(decl))
  250. return hlsl::DxilParamInputQual::InputPrimitive;
  251. if (decl->hasAttr<HLSLIndicesAttr>())
  252. return hlsl::DxilParamInputQual::OutIndices;
  253. if (decl->hasAttr<HLSLVerticesAttr>())
  254. return hlsl::DxilParamInputQual::OutVertices;
  255. if (decl->hasAttr<HLSLPrimitivesAttr>())
  256. return hlsl::DxilParamInputQual::OutPrimitives;
  257. if (decl->hasAttr<HLSLPayloadAttr>())
  258. return hlsl::DxilParamInputQual::InPayload;
  259. return asInput ? hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out;
  260. }
  261. /// \brief Deduces the HLSL SigPoint for the given decl appearing in the given
  262. /// shader model.
  263. const hlsl::SigPoint *deduceSigPoint(const DeclaratorDecl *decl, bool asInput,
  264. const hlsl::ShaderModel::Kind kind,
  265. bool forPCF) {
  266. return hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
  267. deduceParamQual(decl, asInput), kind, forPCF));
  268. }
  269. /// Returns the type of the given decl. If the given decl is a FunctionDecl,
  270. /// returns its result type.
  271. inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
  272. if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
  273. return funcDecl->getReturnType();
  274. }
  275. return decl->getType();
  276. }
  277. /// Returns the number of base classes if this type is a derived class/struct.
  278. /// Returns zero otherwise.
  279. inline uint32_t getNumBaseClasses(QualType type) {
  280. if (const auto *cxxDecl = type->getAsCXXRecordDecl())
  281. return cxxDecl->getNumBases();
  282. return 0;
  283. }
  284. } // anonymous namespace
  285. std::string StageVar::getSemanticStr() const {
  286. // A special case for zero index, which is equivalent to no index.
  287. // Use what is in the source code.
  288. // TODO: this looks like a hack to make the current tests happy.
  289. // Should consider remove it and fix all tests.
  290. if (semanticInfo.index == 0)
  291. return semanticInfo.str;
  292. std::ostringstream ss;
  293. ss << semanticInfo.name.str() << semanticInfo.index;
  294. return ss.str();
  295. }
  296. SpirvInstruction *CounterIdAliasPair::get(SpirvBuilder &builder,
  297. SpirvContext &spvContext) const {
  298. if (isAlias) {
  299. const auto *counterType = spvContext.getACSBufferCounterType();
  300. const auto *counterVarType =
  301. spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
  302. return builder.createLoad(counterVarType, counterVar,
  303. /* SourceLocation */ {});
  304. }
  305. return counterVar;
  306. }
  307. const CounterIdAliasPair *
  308. CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
  309. for (const auto &field : fields)
  310. if (field.indices == indices)
  311. return &field.counterVar;
  312. return nullptr;
  313. }
  314. bool CounterVarFields::assign(const CounterVarFields &srcFields,
  315. SpirvBuilder &builder,
  316. SpirvContext &context) const {
  317. for (const auto &field : fields) {
  318. const auto *srcField = srcFields.get(field.indices);
  319. if (!srcField)
  320. return false;
  321. field.counterVar.assign(*srcField, builder, context);
  322. }
  323. return true;
  324. }
  325. bool CounterVarFields::assign(const CounterVarFields &srcFields,
  326. const llvm::SmallVector<uint32_t, 4> &dstPrefix,
  327. const llvm::SmallVector<uint32_t, 4> &srcPrefix,
  328. SpirvBuilder &builder,
  329. SpirvContext &context) const {
  330. if (dstPrefix.empty() && srcPrefix.empty())
  331. return assign(srcFields, builder, context);
  332. llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;
  333. // If whole has the given prefix, appends all elements after the prefix in
  334. // whole to srcIndices.
  335. const auto applyDiff =
  336. [&srcIndices](const llvm::SmallVector<uint32_t, 4> &whole,
  337. const llvm::SmallVector<uint32_t, 4> &prefix) -> bool {
  338. uint32_t i = 0;
  339. for (; i < prefix.size(); ++i)
  340. if (whole[i] != prefix[i]) {
  341. break;
  342. }
  343. if (i == prefix.size()) {
  344. for (; i < whole.size(); ++i)
  345. srcIndices.push_back(whole[i]);
  346. return true;
  347. }
  348. return false;
  349. };
  350. for (const auto &field : fields)
  351. if (applyDiff(field.indices, dstPrefix)) {
  352. const auto *srcField = srcFields.get(srcIndices);
  353. if (!srcField)
  354. return false;
  355. field.counterVar.assign(*srcField, builder, context);
  356. for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
  357. srcIndices.pop_back();
  358. }
  359. return true;
  360. }
  361. SemanticInfo DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) {
  362. for (auto *annotation : decl->getUnusualAnnotations()) {
  363. if (auto *sema = dyn_cast<hlsl::SemanticDecl>(annotation)) {
  364. llvm::StringRef semanticStr = sema->SemanticName;
  365. llvm::StringRef semanticName;
  366. uint32_t index = 0;
  367. hlsl::Semantic::DecomposeNameAndIndex(semanticStr, &semanticName, &index);
  368. const auto *semantic = hlsl::Semantic::GetByName(semanticName);
  369. return {semanticStr, semantic, semanticName, index, sema->Loc};
  370. }
  371. }
  372. return {};
  373. }
  374. bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
  375. SpirvInstruction *storedValue,
  376. bool forPCF) {
  377. QualType type = getTypeOrFnRetType(decl);
  378. uint32_t arraySize = 0;
  379. // Output stream types (PointStream, LineStream, TriangleStream) are
  380. // translated as their underlying struct types.
  381. if (hlsl::IsHLSLStreamOutputType(type))
  382. type = hlsl::GetHLSLResourceResultType(type);
  383. if (decl->hasAttr<HLSLIndicesAttr>() || decl->hasAttr<HLSLVerticesAttr>() ||
  384. decl->hasAttr<HLSLPrimitivesAttr>()) {
  385. const auto *typeDecl = astContext.getAsConstantArrayType(type);
  386. type = typeDecl->getElementType();
  387. arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
  388. if (decl->hasAttr<HLSLIndicesAttr>()) {
  389. // create SPIR-V builtin array PrimitiveIndicesNV of type
  390. // "uint [MaxPrimitiveCount * verticesPerPrim]"
  391. uint32_t verticesPerPrim = 1;
  392. if (!isVectorType(type, nullptr, &verticesPerPrim)) {
  393. assert(isScalarType(type));
  394. }
  395. arraySize = arraySize * verticesPerPrim;
  396. QualType arrayType = astContext.getConstantArrayType(
  397. astContext.UnsignedIntTy, llvm::APInt(32, arraySize),
  398. clang::ArrayType::Normal, 0);
  399. stageVarInstructions[cast<DeclaratorDecl>(decl)] = getBuiltinVar(
  400. spv::BuiltIn::PrimitiveIndicesNV, arrayType, decl->getLocation());
  401. return true;
  402. }
  403. }
  404. const auto *sigPoint = deduceSigPoint(
  405. decl, /*asInput=*/false, spvContext.getCurrentShaderModelKind(), forPCF);
  406. // HS output variables are created using the other overload. For the rest,
  407. // none of them should be created as arrays.
  408. assert(sigPoint->GetKind() != hlsl::DXIL::SigPointKind::HSCPOut);
  409. SemanticInfo inheritSemantic = {};
  410. // If storedValue is 0, it means this parameter in the original source code is
  411. // not used at all. Avoid writing back.
  412. //
  413. // Write back of stage output variables in GS is manually controlled by
  414. // .Append() intrinsic method, implemented in writeBackOutputStream(). So
  415. // ignoreValue should be set to true for GS.
  416. const bool noWriteBack =
  417. storedValue == nullptr || spvContext.isGS() || spvContext.isMS();
  418. return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
  419. "out.var", llvm::None, &storedValue, noWriteBack,
  420. &inheritSemantic);
  421. }
  422. bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
  423. uint32_t arraySize,
  424. SpirvInstruction *invocationId,
  425. SpirvInstruction *storedValue) {
  426. assert(spvContext.isHS());
  427. QualType type = getTypeOrFnRetType(decl);
  428. const auto *sigPoint =
  429. hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::HSCPOut);
  430. SemanticInfo inheritSemantic = {};
  431. return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
  432. "out.var", invocationId, &storedValue,
  433. /*noWriteBack=*/false, &inheritSemantic);
  434. }
  435. bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
  436. SpirvInstruction **loadedValue,
  437. bool forPCF) {
  438. uint32_t arraySize = 0;
  439. QualType type = paramDecl->getType();
  440. // Deprive the outermost arrayness for HS/DS/GS and use arraySize
  441. // to convey that information
  442. if (hlsl::IsHLSLInputPatchType(type)) {
  443. arraySize = hlsl::GetHLSLInputPatchCount(type);
  444. type = hlsl::GetHLSLInputPatchElementType(type);
  445. } else if (hlsl::IsHLSLOutputPatchType(type)) {
  446. arraySize = hlsl::GetHLSLOutputPatchCount(type);
  447. type = hlsl::GetHLSLOutputPatchElementType(type);
  448. }
  449. if (hasGSPrimitiveTypeQualifier(paramDecl)) {
  450. const auto *typeDecl = astContext.getAsConstantArrayType(type);
  451. arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
  452. type = typeDecl->getElementType();
  453. }
  454. const auto *sigPoint =
  455. deduceSigPoint(paramDecl, /*asInput=*/true,
  456. spvContext.getCurrentShaderModelKind(), forPCF);
  457. SemanticInfo inheritSemantic = {};
  458. if (paramDecl->hasAttr<HLSLPayloadAttr>()) {
  459. spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
  460. return createPayloadStageVars(sigPoint, sc, paramDecl, /*asInput=*/true,
  461. type, "in.var", loadedValue);
  462. } else {
  463. return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type,
  464. arraySize, "in.var", llvm::None, loadedValue,
  465. /*noWriteBack=*/false, &inheritSemantic);
  466. }
  467. }
  468. const DeclResultIdMapper::DeclSpirvInfo *
  469. DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
  470. auto it = astDecls.find(decl);
  471. if (it != astDecls.end())
  472. return &it->second;
  473. return nullptr;
  474. }
  475. SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
  476. SourceLocation loc) {
  477. if (const auto *info = getDeclSpirvInfo(decl)) {
  478. if (info->indexInCTBuffer >= 0) {
  479. // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
  480. // OpAccessChain to get the pointer to the variable since we created
  481. // a single variable for the whole buffer object.
  482. // Should only have VarDecls in a HLSLBufferDecl.
  483. QualType valueType = cast<VarDecl>(decl)->getType();
  484. return spvBuilder.createAccessChain(
  485. valueType, info->instr,
  486. {spvBuilder.getConstantInt(
  487. astContext.IntTy, llvm::APInt(32, info->indexInCTBuffer, true))},
  488. loc);
  489. } else {
  490. return *info;
  491. }
  492. }
  493. emitFatalError("found unregistered decl", decl->getLocation())
  494. << decl->getName();
  495. emitNote("please file a bug report on "
  496. "https://github.com/Microsoft/DirectXShaderCompiler/issues with "
  497. "source code if possible",
  498. {});
  499. return 0;
  500. }
  501. SpirvFunctionParameter *
  502. DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
  503. const auto type = getTypeOrFnRetType(param);
  504. const auto loc = param->getLocation();
  505. SpirvFunctionParameter *fnParamInstr = spvBuilder.addFnParam(
  506. type, param->hasAttr<HLSLPreciseAttr>(), loc, param->getName());
  507. bool isAlias = false;
  508. (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);
  509. fnParamInstr->setContainsAliasComponent(isAlias);
  510. assert(astDecls[param].instr == nullptr);
  511. astDecls[param].instr = fnParamInstr;
  512. return fnParamInstr;
  513. }
  514. void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
  515. const QualType declType = getTypeOrFnRetType(decl);
  516. if (!counterVars.count(decl) && isRWAppendConsumeSBuffer(declType)) {
  517. createCounterVar(decl, /*declId=*/0, /*isAlias=*/true);
  518. } else if (!fieldCounterVars.count(decl) && declType->isStructureType() &&
  519. // Exclude other resource types which are represented as structs
  520. !hlsl::IsHLSLResourceType(declType)) {
  521. createFieldCounterVars(decl);
  522. }
  523. }
  524. SpirvVariable *
  525. DeclResultIdMapper::createFnVar(const VarDecl *var,
  526. llvm::Optional<SpirvInstruction *> init) {
  527. const auto type = getTypeOrFnRetType(var);
  528. const auto loc = var->getLocation();
  529. const auto name = var->getName();
  530. const bool isPrecise = var->hasAttr<HLSLPreciseAttr>();
  531. SpirvVariable *varInstr = spvBuilder.addFnVar(
  532. type, loc, name, isPrecise, init.hasValue() ? init.getValue() : nullptr);
  533. bool isAlias = false;
  534. (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
  535. varInstr->setContainsAliasComponent(isAlias);
  536. assert(astDecls[var].instr == nullptr);
  537. astDecls[var].instr = varInstr;
  538. return varInstr;
  539. }
  540. SpirvVariable *
  541. DeclResultIdMapper::createFileVar(const VarDecl *var,
  542. llvm::Optional<SpirvInstruction *> init) {
  543. const auto type = getTypeOrFnRetType(var);
  544. const auto loc = var->getLocation();
  545. SpirvVariable *varInstr = spvBuilder.addModuleVar(
  546. type, spv::StorageClass::Private, var->hasAttr<HLSLPreciseAttr>(),
  547. var->getName(), init, loc);
  548. bool isAlias = false;
  549. (void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
  550. varInstr->setContainsAliasComponent(isAlias);
  551. assert(astDecls[var].instr == nullptr);
  552. astDecls[var].instr = varInstr;
  553. return varInstr;
  554. }
  555. SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
  556. auto storageClass = spv::StorageClass::UniformConstant;
  557. auto rule = SpirvLayoutRule::Void;
  558. bool isACRWSBuffer = false; // Whether is {Append|Consume|RW}StructuredBuffer
  559. if (var->getAttr<HLSLGroupSharedAttr>()) {
  560. // For CS groupshared variables
  561. storageClass = spv::StorageClass::Workgroup;
  562. } else if (isResourceType(var)) {
  563. // See through the possible outer arrays
  564. QualType resourceType = var->getType();
  565. while (resourceType->isArrayType()) {
  566. resourceType = resourceType->getAsArrayTypeUnsafe()->getElementType();
  567. }
  568. const llvm::StringRef typeName =
  569. resourceType->getAs<RecordType>()->getDecl()->getName();
  570. // These types are all translated into OpTypeStruct with BufferBlock
  571. // decoration. They should follow standard storage buffer layout,
  572. // which GLSL std430 rules statisfies.
  573. if (typeName == "StructuredBuffer" || typeName == "ByteAddressBuffer" ||
  574. typeName == "RWByteAddressBuffer") {
  575. storageClass = spv::StorageClass::Uniform;
  576. rule = spirvOptions.sBufferLayoutRule;
  577. } else if (typeName == "RWStructuredBuffer" ||
  578. typeName == "AppendStructuredBuffer" ||
  579. typeName == "ConsumeStructuredBuffer") {
  580. storageClass = spv::StorageClass::Uniform;
  581. rule = spirvOptions.sBufferLayoutRule;
  582. isACRWSBuffer = true;
  583. }
  584. } else {
  585. // This is a stand-alone externally-visiable non-resource-type variable.
  586. // They should be grouped into the $Globals cbuffer. We create that cbuffer
  587. // and record all variables inside it upon seeing the first such variable.
  588. if (astDecls.count(var) == 0)
  589. createGlobalsCBuffer(var);
  590. auto *varInstr = astDecls[var].instr;
  591. return varInstr ? cast<SpirvVariable>(varInstr) : nullptr;
  592. }
  593. const auto type = var->getType();
  594. const auto loc = var->getLocation();
  595. SpirvVariable *varInstr = spvBuilder.addModuleVar(
  596. type, storageClass, var->hasAttr<HLSLPreciseAttr>(), var->getName(),
  597. llvm::None, loc);
  598. varInstr->setLayoutRule(rule);
  599. DeclSpirvInfo info(varInstr);
  600. astDecls[var] = info;
  601. // Variables in Workgroup do not need descriptor decorations.
  602. if (storageClass == spv::StorageClass::Workgroup)
  603. return varInstr;
  604. const auto *regAttr = getResourceBinding(var);
  605. const auto *bindingAttr = var->getAttr<VKBindingAttr>();
  606. const auto *counterBindingAttr = var->getAttr<VKCounterBindingAttr>();
  607. resourceVars.emplace_back(varInstr, var, loc, regAttr, bindingAttr,
  608. counterBindingAttr);
  609. if (const auto *inputAttachment = var->getAttr<VKInputAttachmentIndexAttr>())
  610. spvBuilder.decorateInputAttachmentIndex(varInstr,
  611. inputAttachment->getIndex(), loc);
  612. if (isACRWSBuffer) {
  613. // For {Append|Consume|RW}StructuredBuffer, we need to always create another
  614. // variable for its associated counter.
  615. createCounterVar(var, varInstr, /*isAlias=*/false);
  616. }
  617. return varInstr;
  618. }
  619. SpirvInstruction *
  620. DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
  621. assert(hlsl::IsStringType(var->getType()) ||
  622. hlsl::IsStringLiteralType(var->getType()));
  623. // If the string variable is not initialized to a string literal, we cannot
  624. // generate an OpString for it.
  625. if (!var->hasInit()) {
  626. emitError("Found uninitialized string variable.", var->getLocation());
  627. return nullptr;
  628. }
  629. const StringLiteral *stringLiteral =
  630. dyn_cast<StringLiteral>(var->getInit()->IgnoreParenCasts());
  631. SpirvString *init = spvBuilder.getString(stringLiteral->getString());
  632. DeclSpirvInfo info(init);
  633. astDecls[var] = info;
  634. return init;
  635. }
  636. SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
  637. const DeclContext *decl, int arraySize, const ContextUsageKind usageKind,
  638. llvm::StringRef typeName, llvm::StringRef varName) {
  639. // cbuffers are translated into OpTypeStruct with Block decoration.
  640. // tbuffers are translated into OpTypeStruct with BufferBlock decoration.
  641. // Push constants are translated into OpTypeStruct with Block decoration.
  642. //
  643. // Both cbuffers and tbuffers have the SPIR-V Uniform storage class.
  644. // Push constants have the SPIR-V PushConstant storage class.
  645. const bool forCBuffer = usageKind == ContextUsageKind::CBuffer;
  646. const bool forTBuffer = usageKind == ContextUsageKind::TBuffer;
  647. const bool forGlobals = usageKind == ContextUsageKind::Globals;
  648. const bool forPC = usageKind == ContextUsageKind::PushConstant;
  649. const bool forShaderRecordNV =
  650. usageKind == ContextUsageKind::ShaderRecordBufferNV;
  651. const llvm::SmallVector<const Decl *, 4> &declGroup =
  652. collectDeclsInDeclContext(decl);
  653. // Collect the type and name for each field
  654. llvm::SmallVector<HybridStructType::FieldInfo, 4> fields;
  655. for (const auto *subDecl : declGroup) {
  656. // 'groupshared' variables should not be placed in $Globals cbuffer.
  657. if (forGlobals && subDecl->hasAttr<HLSLGroupSharedAttr>())
  658. continue;
  659. // The field can only be FieldDecl (for normal structs) or VarDecl (for
  660. // HLSLBufferDecls).
  661. assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
  662. const auto *declDecl = cast<DeclaratorDecl>(subDecl);
  663. // In case 'register(c#)' annotation is placed on a global variable.
  664. const hlsl::RegisterAssignment *registerC =
  665. forGlobals ? getRegisterCAssignment(declDecl) : nullptr;
  666. // All fields are qualified with const. It will affect the debug name.
  667. // We don't need it here.
  668. auto varType = declDecl->getType();
  669. varType.removeLocalConst();
  670. HybridStructType::FieldInfo info(varType, declDecl->getName(),
  671. declDecl->getAttr<VKOffsetAttr>(),
  672. getPackOffset(declDecl), registerC,
  673. declDecl->hasAttr<HLSLPreciseAttr>());
  674. fields.push_back(info);
  675. }
  676. // Get the type for the whole struct
  677. // tbuffer/TextureBuffers are non-writable SSBOs.
  678. const SpirvType *resultType = spvContext.getHybridStructType(
  679. fields, typeName, /*isReadOnly*/ forTBuffer,
  680. forTBuffer ? StructInterfaceType::StorageBuffer
  681. : StructInterfaceType::UniformBuffer);
  682. // Make an array if requested.
  683. if (arraySize > 0) {
  684. resultType = spvContext.getArrayType(resultType, arraySize,
  685. /*ArrayStride*/ llvm::None);
  686. } else if (arraySize == -1) {
  687. resultType =
  688. spvContext.getRuntimeArrayType(resultType, /*ArrayStride*/ llvm::None);
  689. }
  690. // Register the <type-id> for this decl
  691. ctBufferPCTypes[decl] = resultType;
  692. const auto sc = forPC ? spv::StorageClass::PushConstant
  693. : forShaderRecordNV
  694. ? spv::StorageClass::ShaderRecordBufferNV
  695. : spv::StorageClass::Uniform;
  696. // Create the variable for the whole struct / struct array.
  697. // The fields may be 'precise', but the structure itself is not.
  698. SpirvVariable *var =
  699. spvBuilder.addModuleVar(resultType, sc, /*isPrecise*/ false, varName);
  700. const SpirvLayoutRule layoutRule =
  701. (forCBuffer || forGlobals)
  702. ? spirvOptions.cBufferLayoutRule
  703. : (forTBuffer ? spirvOptions.tBufferLayoutRule
  704. : spirvOptions.sBufferLayoutRule);
  705. var->setHlslUserType(forCBuffer ? "cbuffer" : forTBuffer ? "tbuffer" : "");
  706. var->setLayoutRule(layoutRule);
  707. return var;
  708. }
  709. void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
  710. const auto *valueDecl = dyn_cast<ValueDecl>(decl);
  711. const auto enumConstant =
  712. spvBuilder.getConstantInt(astContext.IntTy, decl->getInitVal());
  713. SpirvVariable *varInstr = spvBuilder.addModuleVar(
  714. astContext.IntTy, spv::StorageClass::Private, /*isPrecise*/ false,
  715. decl->getName(), enumConstant, decl->getLocation());
  716. astDecls[valueDecl] = DeclSpirvInfo(varInstr);
  717. }
  718. SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
  719. const auto usageKind =
  720. decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
  721. const std::string structName = "type." + decl->getName().str();
  722. // The front-end does not allow arrays of cbuffer/tbuffer.
  723. SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
  724. decl, /*arraySize*/ 0, usageKind, structName, decl->getName());
  725. // We still register all VarDecls seperately here. All the VarDecls are
  726. // mapped to the <result-id> of the buffer object, which means when querying
  727. // querying the <result-id> for a certain VarDecl, we need to do an extra
  728. // OpAccessChain.
  729. int index = 0;
  730. for (const auto *subDecl : decl->decls()) {
  731. if (shouldSkipInStructLayout(subDecl))
  732. continue;
  733. const auto *varDecl = cast<VarDecl>(subDecl);
  734. astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
  735. }
  736. resourceVars.emplace_back(
  737. bufferVar, decl, decl->getLocation(), getResourceBinding(decl),
  738. decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
  739. return bufferVar;
  740. }
  741. SpirvVariable *DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
  742. const RecordType *recordType = nullptr;
  743. int arraySize = 0;
  744. // In case we have an array of ConstantBuffer/TextureBuffer:
  745. if (const auto *arrayType = decl->getType()->getAsArrayTypeUnsafe()) {
  746. recordType = arrayType->getElementType()->getAs<RecordType>();
  747. if (const auto *caType =
  748. astContext.getAsConstantArrayType(decl->getType())) {
  749. arraySize = static_cast<uint32_t>(caType->getSize().getZExtValue());
  750. } else {
  751. arraySize = -1;
  752. }
  753. } else {
  754. recordType = decl->getType()->getAs<RecordType>();
  755. }
  756. if (!recordType) {
  757. emitError("constant/texture buffer type %0 unimplemented",
  758. decl->getLocStart())
  759. << decl->getType();
  760. return 0;
  761. }
  762. const auto *context = cast<HLSLBufferDecl>(decl->getDeclContext());
  763. const auto usageKind = context->isCBuffer() ? ContextUsageKind::CBuffer
  764. : ContextUsageKind::TBuffer;
  765. const char *ctBufferName =
  766. context->isCBuffer() ? "ConstantBuffer." : "TextureBuffer.";
  767. const std::string structName = "type." + std::string(ctBufferName) +
  768. recordType->getDecl()->getName().str();
  769. SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
  770. recordType->getDecl(), arraySize, usageKind, structName, decl->getName());
  771. // We register the VarDecl here.
  772. astDecls[decl] = DeclSpirvInfo(bufferVar);
  773. resourceVars.emplace_back(
  774. bufferVar, decl, decl->getLocation(), getResourceBinding(context),
  775. decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
  776. return bufferVar;
  777. }
  778. SpirvVariable *DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
  779. // The front-end errors out if non-struct type push constant is used.
  780. const auto *recordType = decl->getType()->getAs<RecordType>();
  781. assert(recordType);
  782. const std::string structName =
  783. "type.PushConstant." + recordType->getDecl()->getName().str();
  784. SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
  785. recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
  786. structName, decl->getName());
  787. // Register the VarDecl
  788. astDecls[decl] = DeclSpirvInfo(var);
  789. // Do not push this variable into resourceVars since it does not need
  790. // descriptor set.
  791. return var;
  792. }
  793. SpirvVariable *
  794. DeclResultIdMapper::createShaderRecordBufferNV(const VarDecl *decl) {
  795. const auto *recordType = decl->getType()->getAs<RecordType>();
  796. assert(recordType);
  797. const std::string structName =
  798. "type.ShaderRecordBufferNV." + recordType->getDecl()->getName().str();
  799. SpirvVariable *var = createStructOrStructArrayVarOfExplicitLayout(
  800. recordType->getDecl(), /*arraySize*/ 0,
  801. ContextUsageKind::ShaderRecordBufferNV, structName, decl->getName());
  802. // Register the VarDecl
  803. astDecls[decl] = DeclSpirvInfo(var);
  804. // Do not push this variable into resourceVars since it does not need
  805. // descriptor set.
  806. return var;
  807. }
  808. SpirvVariable *
  809. DeclResultIdMapper::createShaderRecordBufferNV(const HLSLBufferDecl *decl) {
  810. const std::string structName =
  811. "type.ShaderRecordBufferNV." + decl->getName().str();
  812. // The front-end does not allow arrays of cbuffer/tbuffer.
  813. SpirvVariable *bufferVar = createStructOrStructArrayVarOfExplicitLayout(
  814. decl, /*arraySize*/ 0, ContextUsageKind::ShaderRecordBufferNV, structName,
  815. decl->getName());
  816. // We still register all VarDecls seperately here. All the VarDecls are
  817. // mapped to the <result-id> of the buffer object, which means when
  818. // querying the <result-id> for a certain VarDecl, we need to do an extra
  819. // OpAccessChain.
  820. int index = 0;
  821. for (const auto *subDecl : decl->decls()) {
  822. if (shouldSkipInStructLayout(subDecl))
  823. continue;
  824. const auto *varDecl = cast<VarDecl>(subDecl);
  825. astDecls[varDecl] = DeclSpirvInfo(bufferVar, index++);
  826. }
  827. return bufferVar;
  828. }
  829. void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
  830. if (astDecls.count(var) != 0)
  831. return;
  832. const auto *context = var->getTranslationUnitDecl();
  833. SpirvVariable *globals = createStructOrStructArrayVarOfExplicitLayout(
  834. context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
  835. "$Globals");
  836. resourceVars.emplace_back(globals, /*decl*/ nullptr, SourceLocation(),
  837. nullptr, nullptr, nullptr, /*isCounterVar*/ false,
  838. /*isGlobalsCBuffer*/ true);
  839. uint32_t index = 0;
  840. for (const auto *decl : collectDeclsInDeclContext(context)) {
  841. // 'groupshared' variables should not be placed in $Globals cbuffer.
  842. if (decl->hasAttr<HLSLGroupSharedAttr>())
  843. continue;
  844. if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
  845. if (!spirvOptions.noWarnIgnoredFeatures) {
  846. if (const auto *init = varDecl->getInit())
  847. emitWarning(
  848. "variable '%0' will be placed in $Globals so initializer ignored",
  849. init->getExprLoc())
  850. << var->getName() << init->getSourceRange();
  851. }
  852. if (const auto *attr = varDecl->getAttr<VKBindingAttr>()) {
  853. emitError("variable '%0' will be placed in $Globals so cannot have "
  854. "vk::binding attribute",
  855. attr->getLocation())
  856. << var->getName();
  857. return;
  858. }
  859. astDecls[varDecl] = DeclSpirvInfo(globals, index++);
  860. }
  861. }
  862. }
  863. SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
  864. // Return it if it's already been created.
  865. auto it = astFunctionDecls.find(fn);
  866. if (it != astFunctionDecls.end()) {
  867. return it->second;
  868. }
  869. bool isAlias = false;
  870. (void)getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias);
  871. const bool isPrecise = fn->hasAttr<HLSLPreciseAttr>();
  872. // Note: we do not need to worry about function parameter types at this point
  873. // as this is used when function declarations are seen. When function
  874. // definition is seen, the parameter types will be set properly and take into
  875. // account whether the function is a member function of a class/struct (in
  876. // which case a 'this' parameter is added at the beginnig).
  877. SpirvFunction *spirvFunction = new (spvContext)
  878. SpirvFunction(fn->getReturnType(), /* param QualTypes */ {},
  879. fn->getLocation(), fn->getName(), isPrecise);
  880. // No need to dereference to get the pointer. Function returns that are
  881. // stand-alone aliases are already pointers to values. All other cases should
  882. // be normal rvalues.
  883. if (!isAlias || !isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
  884. spirvFunction->setRValue();
  885. spirvFunction->setConstainsAliasComponent(isAlias);
  886. astFunctionDecls[fn] = spirvFunction;
  887. return spirvFunction;
  888. }
  889. const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
  890. const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
  891. if (!decl)
  892. return nullptr;
  893. if (indices) {
  894. // Indices are provided. Walk through the fields of the decl.
  895. const auto counter = fieldCounterVars.find(decl);
  896. if (counter != fieldCounterVars.end())
  897. return counter->second.get(*indices);
  898. } else {
  899. // No indices. Check the stand-alone entities.
  900. const auto counter = counterVars.find(decl);
  901. if (counter != counterVars.end())
  902. return &counter->second;
  903. }
  904. return nullptr;
  905. }
  906. const CounterVarFields *
  907. DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
  908. if (!decl)
  909. return nullptr;
  910. const auto found = fieldCounterVars.find(decl);
  911. if (found != fieldCounterVars.end())
  912. return &found->second;
  913. return nullptr;
  914. }
  915. void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
  916. SpirvInstruction *specConstant) {
  917. specConstant->setRValue();
  918. astDecls[decl] = DeclSpirvInfo(specConstant);
  919. }
  920. void DeclResultIdMapper::createCounterVar(
  921. const DeclaratorDecl *decl, SpirvInstruction *declInstr, bool isAlias,
  922. const llvm::SmallVector<uint32_t, 4> *indices) {
  923. std::string counterName = "counter.var." + decl->getName().str();
  924. if (indices) {
  925. // Append field indices to the name
  926. for (const auto index : *indices)
  927. counterName += "." + std::to_string(index);
  928. }
  929. const SpirvType *counterType = spvContext.getACSBufferCounterType();
  930. // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
  931. // Alias counter variables should be created into the Private storage class.
  932. const spv::StorageClass sc =
  933. isAlias ? spv::StorageClass::Private : spv::StorageClass::Uniform;
  934. if (isAlias) {
  935. // Apply an extra level of pointer for alias counter variable
  936. counterType =
  937. spvContext.getPointerType(counterType, spv::StorageClass::Uniform);
  938. }
  939. SpirvVariable *counterInstr = spvBuilder.addModuleVar(
  940. counterType, sc, /*isPrecise*/ false, counterName);
  941. if (!isAlias) {
  942. // Non-alias counter variables should be put in to resourceVars so that
  943. // descriptors can be allocated for them.
  944. resourceVars.emplace_back(counterInstr, decl, decl->getLocation(),
  945. getResourceBinding(decl),
  946. decl->getAttr<VKBindingAttr>(),
  947. decl->getAttr<VKCounterBindingAttr>(), true);
  948. assert(declInstr);
  949. spvBuilder.decorateCounterBuffer(declInstr, counterInstr,
  950. decl->getLocation());
  951. }
  952. if (indices)
  953. fieldCounterVars[decl].append(*indices, counterInstr);
  954. else
  955. counterVars[decl] = {counterInstr, isAlias};
  956. }
  957. void DeclResultIdMapper::createFieldCounterVars(
  958. const DeclaratorDecl *rootDecl, const DeclaratorDecl *decl,
  959. llvm::SmallVector<uint32_t, 4> *indices) {
  960. const QualType type = getTypeOrFnRetType(decl);
  961. const auto *recordType = type->getAs<RecordType>();
  962. assert(recordType);
  963. const auto *recordDecl = recordType->getDecl();
  964. for (const auto *field : recordDecl->fields()) {
  965. // Build up the index chain
  966. indices->push_back(getNumBaseClasses(type) + field->getFieldIndex());
  967. const QualType fieldType = field->getType();
  968. if (isRWAppendConsumeSBuffer(fieldType))
  969. createCounterVar(rootDecl, /*declId=*/0, /*isAlias=*/true, indices);
  970. else if (fieldType->isStructureType() &&
  971. !hlsl::IsHLSLResourceType(fieldType))
  972. // Go recursively into all nested structs
  973. createFieldCounterVars(rootDecl, field, indices);
  974. indices->pop_back();
  975. }
  976. }
  977. const SpirvType *
  978. DeclResultIdMapper::getCTBufferPushConstantType(const DeclContext *decl) {
  979. const auto found = ctBufferPCTypes.find(decl);
  980. assert(found != ctBufferPCTypes.end());
  981. return found->second;
  982. }
  983. std::vector<SpirvVariable *> DeclResultIdMapper::collectStageVars() const {
  984. std::vector<SpirvVariable *> vars;
  985. for (auto var : glPerVertex.getStageInVars())
  986. vars.push_back(var);
  987. for (auto var : glPerVertex.getStageOutVars())
  988. vars.push_back(var);
  989. llvm::DenseSet<SpirvInstruction *> seenVars;
  990. for (const auto &var : stageVars) {
  991. auto *instr = var.getSpirvInstr();
  992. if (seenVars.count(instr) == 0) {
  993. vars.push_back(instr);
  994. seenVars.insert(instr);
  995. }
  996. }
  997. return vars;
  998. }
  999. namespace {
  1000. /// A class for managing stage input/output locations to avoid duplicate uses of
  1001. /// the same location.
  1002. class LocationSet {
  1003. public:
  1004. /// Maximum number of indices supported
  1005. const static uint32_t kMaxIndex = 2;
  1006. /// Maximum number of locations supported
  1007. // Typically we won't have that many stage input or output variables.
  1008. // Using 64 should be fine here.
  1009. const static uint32_t kMaxLoc = 64;
  1010. LocationSet() {
  1011. for (uint32_t i = 0; i < kMaxIndex; ++i) {
  1012. usedLocs[i].resize(kMaxLoc);
  1013. nextLoc[i] = 0;
  1014. }
  1015. }
  1016. /// Uses the given location.
  1017. void useLoc(uint32_t loc, uint32_t index = 0) {
  1018. assert(index < kMaxIndex);
  1019. usedLocs[index].set(loc);
  1020. }
  1021. /// Uses the next |count| available location.
  1022. int useNextLocs(uint32_t count, uint32_t index = 0) {
  1023. assert(index < kMaxIndex);
  1024. auto &locs = usedLocs[index];
  1025. auto &next = nextLoc[index];
  1026. while (locs[next])
  1027. next++;
  1028. int toUse = next;
  1029. for (uint32_t i = 0; i < count; ++i) {
  1030. assert(!locs[next]);
  1031. locs.set(next++);
  1032. }
  1033. return toUse;
  1034. }
  1035. /// Returns true if the given location number is already used.
  1036. bool isLocUsed(uint32_t loc, uint32_t index = 0) {
  1037. assert(index < kMaxIndex);
  1038. return usedLocs[index][loc];
  1039. }
  1040. private:
  1041. llvm::SmallBitVector usedLocs[kMaxIndex]; ///< All previously used locations
  1042. uint32_t nextLoc[kMaxIndex]; ///< Next available location
  1043. };
  1044. /// A class for managing resource bindings to avoid duplicate uses of the same
  1045. /// set and binding number.
  1046. class BindingSet {
  1047. public:
  1048. /// Uses the given set and binding number. Returns false if the binding number
  1049. /// was already occupied in the set, and returns true otherwise.
  1050. bool useBinding(uint32_t binding, uint32_t set) {
  1051. bool inserted = false;
  1052. std::tie(std::ignore, inserted) = usedBindings[set].insert(binding);
  1053. return inserted;
  1054. }
  1055. /// Uses the next avaiable binding number in |set|. If more than one binding
  1056. /// number is to be occupied, it finds the next available chunk that can fit
  1057. /// |numBindingsToUse| in the |set|.
  1058. uint32_t useNextBinding(uint32_t set, uint32_t numBindingsToUse = 1) {
  1059. uint32_t bindingNoStart = getNextBindingChunk(set, numBindingsToUse);
  1060. auto &binding = usedBindings[set];
  1061. for (uint32_t i = 0; i < numBindingsToUse; ++i)
  1062. binding.insert(bindingNoStart + i);
  1063. return bindingNoStart;
  1064. }
  1065. /// Returns the first available binding number in the |set| for which |n|
  1066. /// consecutive binding numbers are unused.
  1067. uint32_t getNextBindingChunk(uint32_t set, uint32_t n) {
  1068. auto &existingBindings = usedBindings[set];
  1069. // There were no bindings in this set. Can start at binding zero.
  1070. if (existingBindings.empty())
  1071. return 0;
  1072. // Check whether the chunk of |n| binding numbers can be fitted at the
  1073. // very beginning of the list (start at binding 0 in the current set).
  1074. uint32_t curBinding = *existingBindings.begin();
  1075. if (curBinding >= n)
  1076. return 0;
  1077. auto iter = std::next(existingBindings.begin());
  1078. while (iter != existingBindings.end()) {
  1079. // There exists a next binding number that is used. Check to see if the
  1080. // gap between current binding number and next binding number is large
  1081. // enough to accommodate |n|.
  1082. uint32_t nextBinding = *iter;
  1083. if (n <= nextBinding - curBinding - 1)
  1084. return curBinding + 1;
  1085. curBinding = nextBinding;
  1086. // Peek at the next binding that has already been used (if any).
  1087. ++iter;
  1088. }
  1089. // |curBinding| was the last binding that was used in this set. The next
  1090. // chunk of |n| bindings can start at |curBinding|+1.
  1091. return curBinding + 1;
  1092. }
  1093. private:
  1094. ///< set number -> set of used binding number
  1095. llvm::DenseMap<uint32_t, std::set<uint32_t>> usedBindings;
  1096. };
  1097. } // namespace
  1098. bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
  1099. llvm::StringSet<> seenSemantics;
  1100. bool success = true;
  1101. for (const auto &var : stageVars) {
  1102. auto s = var.getSemanticStr();
  1103. if (s.empty()) {
  1104. // We translate WaveGetLaneCount(), WaveGetLaneIndex() and 'payload' param
  1105. // block declaration into builtin variables. Those variables are inserted
  1106. // into the normal stage IO processing pipeline, but with the semantics as
  1107. // empty strings.
  1108. assert(var.isSpirvBuitin());
  1109. continue;
  1110. }
  1111. // Allow builtin variables to alias each other. We already have uniqify
  1112. // mechanism in SpirvBuilder.
  1113. if (var.isSpirvBuitin())
  1114. continue;
  1115. if (forInput && var.getSigPoint()->IsInput()) {
  1116. if (seenSemantics.count(s)) {
  1117. emitError("input semantic '%0' used more than once", {}) << s;
  1118. success = false;
  1119. }
  1120. seenSemantics.insert(s);
  1121. } else if (!forInput && var.getSigPoint()->IsOutput()) {
  1122. if (seenSemantics.count(s)) {
  1123. emitError("output semantic '%0' used more than once", {}) << s;
  1124. success = false;
  1125. }
  1126. seenSemantics.insert(s);
  1127. }
  1128. }
  1129. return success;
  1130. }
  1131. bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
  1132. if (!checkSemanticDuplication(forInput))
  1133. return false;
  1134. // Returns false if the given StageVar is an input/output variable without
  1135. // explicit location assignment. Otherwise, returns true.
  1136. const auto locAssigned = [forInput, this](const StageVar &v) {
  1137. if (forInput == isInputStorageClass(v))
  1138. // No need to assign location for builtins. Treat as assigned.
  1139. return v.isSpirvBuitin() || v.getLocationAttr() != nullptr;
  1140. // For the ones we don't care, treat as assigned.
  1141. return true;
  1142. };
  1143. // If we have explicit location specified for all input/output variables,
  1144. // use them instead assign by ourselves.
  1145. if (std::all_of(stageVars.begin(), stageVars.end(), locAssigned)) {
  1146. LocationSet locSet;
  1147. bool noError = true;
  1148. for (const auto &var : stageVars) {
  1149. // Skip builtins & those stage variables we are not handling for this call
  1150. if (var.isSpirvBuitin() || forInput != isInputStorageClass(var))
  1151. continue;
  1152. const auto *attr = var.getLocationAttr();
  1153. const auto loc = attr->getNumber();
  1154. const auto attrLoc = attr->getLocation(); // Attr source code location
  1155. const auto idx = var.getIndexAttr() ? var.getIndexAttr()->getNumber() : 0;
  1156. if ((const unsigned)loc >= LocationSet::kMaxLoc) {
  1157. emitError("stage %select{output|input}0 location #%1 too large",
  1158. attrLoc)
  1159. << forInput << loc;
  1160. return false;
  1161. }
  1162. // Make sure the same location is not assigned more than once
  1163. if (locSet.isLocUsed(loc, idx)) {
  1164. emitError("stage %select{output|input}0 location #%1 already assigned",
  1165. attrLoc)
  1166. << forInput << loc;
  1167. noError = false;
  1168. }
  1169. locSet.useLoc(loc, idx);
  1170. spvBuilder.decorateLocation(var.getSpirvInstr(), loc);
  1171. if (var.getIndexAttr())
  1172. spvBuilder.decorateIndex(var.getSpirvInstr(), idx,
  1173. var.getSemanticInfo().loc);
  1174. }
  1175. return noError;
  1176. }
  1177. std::vector<const StageVar *> vars;
  1178. LocationSet locSet;
  1179. for (const auto &var : stageVars) {
  1180. if (var.isSpirvBuitin() || forInput != isInputStorageClass(var))
  1181. continue;
  1182. if (var.getLocationAttr()) {
  1183. // We have checked that not all of the stage variables have explicit
  1184. // location assignment.
  1185. emitError("partial explicit stage %select{output|input}0 location "
  1186. "assignment via vk::location(X) unsupported",
  1187. {})
  1188. << forInput;
  1189. return false;
  1190. }
  1191. const auto &semaInfo = var.getSemanticInfo();
  1192. // We should special rules for SV_Target: the location number comes from the
  1193. // semantic string index.
  1194. if (semaInfo.isTarget()) {
  1195. spvBuilder.decorateLocation(var.getSpirvInstr(), semaInfo.index);
  1196. locSet.useLoc(semaInfo.index);
  1197. } else {
  1198. vars.push_back(&var);
  1199. }
  1200. }
  1201. // If alphabetical ordering was requested, sort by semantic string.
  1202. // Since HS includes 2 sets of outputs (patch-constant output and
  1203. // OutputPatch), running into location mismatches between HS and DS is very
  1204. // likely. In order to avoid location mismatches between HS and DS, use
  1205. // alphabetical ordering.
  1206. if (spirvOptions.stageIoOrder == "alpha" ||
  1207. (!forInput && spvContext.isHS()) || (forInput && spvContext.isDS())) {
  1208. // Sort stage input/output variables alphabetically
  1209. std::sort(vars.begin(), vars.end(),
  1210. [](const StageVar *a, const StageVar *b) {
  1211. return a->getSemanticStr() < b->getSemanticStr();
  1212. });
  1213. }
  1214. for (const auto *var : vars)
  1215. spvBuilder.decorateLocation(var->getSpirvInstr(),
  1216. locSet.useNextLocs(var->getLocationCount()));
  1217. return true;
  1218. }
  1219. namespace {
  1220. /// A class for maintaining the binding number shift requested for descriptor
  1221. /// sets.
  1222. class BindingShiftMapper {
  1223. public:
  1224. explicit BindingShiftMapper(const llvm::SmallVectorImpl<int32_t> &shifts)
  1225. : masterShift(0) {
  1226. assert(shifts.size() % 2 == 0);
  1227. if (shifts.size() == 2 && shifts[1] == -1) {
  1228. masterShift = shifts[0];
  1229. } else {
  1230. for (uint32_t i = 0; i < shifts.size(); i += 2)
  1231. perSetShift[shifts[i + 1]] = shifts[i];
  1232. }
  1233. }
  1234. /// Returns the shift amount for the given set.
  1235. int32_t getShiftForSet(int32_t set) const {
  1236. const auto found = perSetShift.find(set);
  1237. if (found != perSetShift.end())
  1238. return found->second;
  1239. return masterShift;
  1240. }
  1241. private:
  1242. uint32_t masterShift; /// Shift amount applies to all sets.
  1243. llvm::DenseMap<int32_t, int32_t> perSetShift;
  1244. };
  1245. /// A class for maintaining the mapping from source code register attributes to
  1246. /// descriptor set and number settings.
  1247. class RegisterBindingMapper {
  1248. public:
  1249. /// Takes in the relation between register attributes and descriptor settings.
  1250. /// Each relation is represented by four strings:
  1251. /// <register-type-number> <space> <descriptor-binding> <set>
  1252. bool takeInRelation(const std::vector<std::string> &relation,
  1253. std::string *error) {
  1254. assert(relation.size() % 4 == 0);
  1255. mapping.clear();
  1256. for (uint32_t i = 0; i < relation.size(); i += 4) {
  1257. int32_t spaceNo = -1, setNo = -1, bindNo = -1;
  1258. if (StringRef(relation[i + 1]).getAsInteger(10, spaceNo) || spaceNo < 0) {
  1259. *error = "space number: " + relation[i + 1];
  1260. return false;
  1261. }
  1262. if (StringRef(relation[i + 2]).getAsInteger(10, bindNo) || bindNo < 0) {
  1263. *error = "binding number: " + relation[i + 2];
  1264. return false;
  1265. }
  1266. if (StringRef(relation[i + 3]).getAsInteger(10, setNo) || setNo < 0) {
  1267. *error = "set number: " + relation[i + 3];
  1268. return false;
  1269. }
  1270. mapping[relation[i + 1] + relation[i]] = std::make_pair(setNo, bindNo);
  1271. }
  1272. return true;
  1273. }
  1274. /// Returns true and set the correct set and binding number if we can find a
  1275. /// descriptor setting for the given register. False otherwise.
  1276. bool getSetBinding(const hlsl::RegisterAssignment *regAttr,
  1277. uint32_t defaultSpace, int *setNo, int *bindNo) const {
  1278. std::ostringstream iss;
  1279. iss << regAttr->RegisterSpace.getValueOr(defaultSpace)
  1280. << regAttr->RegisterType << regAttr->RegisterNumber;
  1281. auto found = mapping.find(iss.str());
  1282. if (found != mapping.end()) {
  1283. *setNo = found->second.first;
  1284. *bindNo = found->second.second;
  1285. return true;
  1286. }
  1287. return false;
  1288. }
  1289. private:
  1290. llvm::StringMap<std::pair<int, int>> mapping;
  1291. };
  1292. } // namespace
  1293. bool DeclResultIdMapper::decorateResourceBindings() {
  1294. // For normal resource, we support 4 approaches of setting binding numbers:
  1295. // - m1: [[vk::binding(...)]]
  1296. // - m2: :register(xX, spaceY)
  1297. // - m3: None
  1298. // - m4: :register(spaceY)
  1299. //
  1300. // For associated counters, we support 2 approaches:
  1301. // - c1: [[vk::counter_binding(...)]
  1302. // - c2: None
  1303. //
  1304. // In combination, we need to handle 12 cases:
  1305. // - 4 cases for nomral resoures (m1, m2, m3, m4)
  1306. // - 8 cases for associated counters (mX * cY)
  1307. //
  1308. // In the following order:
  1309. // - m1, mX * c1
  1310. // - m2
  1311. // - m3, m4, mX * c2
  1312. // The "-auto-binding-space" command line option can be used to specify a
  1313. // certain space as default. UINT_MAX means the user has not provided this
  1314. // option. If not provided, the SPIR-V backend uses space "0" as default.
  1315. auto defaultSpaceOpt =
  1316. theEmitter.getCompilerInstance().getCodeGenOpts().HLSLDefaultSpace;
  1317. uint32_t defaultSpace = (defaultSpaceOpt == UINT_MAX) ? 0 : defaultSpaceOpt;
  1318. const bool bindGlobals = !spirvOptions.bindGlobals.empty();
  1319. int32_t globalsBindNo = -1, globalsSetNo = -1;
  1320. if (bindGlobals) {
  1321. assert(spirvOptions.bindGlobals.size() == 2);
  1322. if (StringRef(spirvOptions.bindGlobals[0])
  1323. .getAsInteger(10, globalsBindNo) ||
  1324. globalsBindNo < 0) {
  1325. emitError("invalid -fvk-bind-globals binding number: %0", {})
  1326. << spirvOptions.bindGlobals[0];
  1327. return false;
  1328. }
  1329. if (StringRef(spirvOptions.bindGlobals[1]).getAsInteger(10, globalsSetNo) ||
  1330. globalsSetNo < 0) {
  1331. emitError("invalid -fvk-bind-globals set number: %0", {})
  1332. << spirvOptions.bindGlobals[1];
  1333. return false;
  1334. }
  1335. }
  1336. // Special handling of -fvk-bind-register, which requires
  1337. // * All resources are annoated with :register() in the source code
  1338. // * -fvk-bind-register is specified for every resource
  1339. if (!spirvOptions.bindRegister.empty()) {
  1340. RegisterBindingMapper bindingMapper;
  1341. std::string error;
  1342. if (!bindingMapper.takeInRelation(spirvOptions.bindRegister, &error)) {
  1343. emitError("invalid -fvk-bind-register %0", {}) << error;
  1344. return false;
  1345. }
  1346. for (const auto &var : resourceVars)
  1347. if (const auto *regAttr = var.getRegister()) {
  1348. if (var.isCounter()) {
  1349. emitError("-fvk-bind-register for RW/Append/Consume StructuredBuffer "
  1350. "unimplemented",
  1351. var.getSourceLocation());
  1352. } else {
  1353. int setNo = 0, bindNo = 0;
  1354. if (!bindingMapper.getSetBinding(regAttr, defaultSpace, &setNo,
  1355. &bindNo)) {
  1356. emitError("missing -fvk-bind-register for resource",
  1357. var.getSourceLocation());
  1358. return false;
  1359. }
  1360. spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindNo);
  1361. }
  1362. } else if (bindGlobals && var.isGlobalsBuffer()) {
  1363. spvBuilder.decorateDSetBinding(var.getSpirvInstr(), globalsSetNo,
  1364. globalsBindNo);
  1365. } else {
  1366. emitError(
  1367. "-fvk-bind-register requires register annotations on all resources",
  1368. var.getSourceLocation());
  1369. return false;
  1370. }
  1371. return true;
  1372. }
  1373. BindingSet bindingSet;
  1374. // Decorates the given varId of the given category with set number
  1375. // setNo, binding number bindingNo. Ignores overlaps.
  1376. const auto tryToDecorate = [this, &bindingSet](const ResourceVar &var,
  1377. const uint32_t setNo,
  1378. const uint32_t bindingNo) {
  1379. // By default we use one binding number per resource, and an array of
  1380. // resources also gets only one binding number. However, for array of
  1381. // resources (e.g. array of textures), DX uses one binding number per array
  1382. // element. We can match this behavior via a command line option.
  1383. uint32_t numBindingsToUse = 1;
  1384. if (spirvOptions.flattenResourceArrays)
  1385. numBindingsToUse = var.getArraySize();
  1386. for (uint32_t i = 0; i < numBindingsToUse; ++i) {
  1387. bool success = bindingSet.useBinding(bindingNo + i, setNo);
  1388. // We will not emit an error if we find a set/binding overlap because it
  1389. // is possible that the optimizer optimizes away a resource which resolves
  1390. // the overlap.
  1391. (void)success;
  1392. }
  1393. // No need to decorate multiple binding numbers for arrays. It will be done
  1394. // by legalization/optimization.
  1395. spvBuilder.decorateDSetBinding(var.getSpirvInstr(), setNo, bindingNo);
  1396. };
  1397. for (const auto &var : resourceVars) {
  1398. if (var.isCounter()) {
  1399. if (const auto *vkCBinding = var.getCounterBinding()) {
  1400. // Process mX * c1
  1401. uint32_t set = defaultSpace;
  1402. if (const auto *vkBinding = var.getBinding())
  1403. set = getVkBindingAttrSet(vkBinding, defaultSpace);
  1404. else if (const auto *reg = var.getRegister())
  1405. set = reg->RegisterSpace.getValueOr(defaultSpace);
  1406. tryToDecorate(var, set, vkCBinding->getBinding());
  1407. }
  1408. } else {
  1409. if (const auto *vkBinding = var.getBinding()) {
  1410. // Process m1
  1411. tryToDecorate(var, getVkBindingAttrSet(vkBinding, defaultSpace),
  1412. vkBinding->getBinding());
  1413. }
  1414. }
  1415. }
  1416. BindingShiftMapper bShiftMapper(spirvOptions.bShift);
  1417. BindingShiftMapper tShiftMapper(spirvOptions.tShift);
  1418. BindingShiftMapper sShiftMapper(spirvOptions.sShift);
  1419. BindingShiftMapper uShiftMapper(spirvOptions.uShift);
  1420. // Process m2
  1421. for (const auto &var : resourceVars)
  1422. if (!var.isCounter() && !var.getBinding())
  1423. if (const auto *reg = var.getRegister()) {
  1424. // Skip space-only register() annotations
  1425. if (reg->isSpaceOnly())
  1426. continue;
  1427. const uint32_t set = reg->RegisterSpace.getValueOr(defaultSpace);
  1428. uint32_t binding = reg->RegisterNumber;
  1429. switch (reg->RegisterType) {
  1430. case 'b':
  1431. binding += bShiftMapper.getShiftForSet(set);
  1432. break;
  1433. case 't':
  1434. binding += tShiftMapper.getShiftForSet(set);
  1435. break;
  1436. case 's':
  1437. binding += sShiftMapper.getShiftForSet(set);
  1438. break;
  1439. case 'u':
  1440. binding += uShiftMapper.getShiftForSet(set);
  1441. break;
  1442. case 'c':
  1443. // For setting packing offset. Does not affect binding.
  1444. break;
  1445. default:
  1446. llvm_unreachable("unknown register type found");
  1447. }
  1448. tryToDecorate(var, set, binding);
  1449. }
  1450. for (const auto &var : resourceVars) {
  1451. // By default we use one binding number per resource, and an array of
  1452. // resources also gets only one binding number. However, for array of
  1453. // resources (e.g. array of textures), DX uses one binding number per array
  1454. // element. We can match this behavior via a command line option.
  1455. uint32_t numBindingsToUse = 1;
  1456. if (spirvOptions.flattenResourceArrays)
  1457. numBindingsToUse = var.getArraySize();
  1458. if (var.isCounter()) {
  1459. if (!var.getCounterBinding()) {
  1460. // Process mX * c2
  1461. uint32_t set = defaultSpace;
  1462. if (const auto *vkBinding = var.getBinding())
  1463. set = getVkBindingAttrSet(vkBinding, defaultSpace);
  1464. else if (const auto *reg = var.getRegister())
  1465. set = reg->RegisterSpace.getValueOr(defaultSpace);
  1466. spvBuilder.decorateDSetBinding(
  1467. var.getSpirvInstr(), set,
  1468. bindingSet.useNextBinding(set, numBindingsToUse));
  1469. }
  1470. } else if (!var.getBinding()) {
  1471. const auto *reg = var.getRegister();
  1472. if (reg && reg->isSpaceOnly()) {
  1473. const uint32_t set = reg->RegisterSpace.getValueOr(defaultSpace);
  1474. spvBuilder.decorateDSetBinding(
  1475. var.getSpirvInstr(), set,
  1476. bindingSet.useNextBinding(set, numBindingsToUse));
  1477. } else if (!reg) {
  1478. // Process m3 (no 'vk::binding' and no ':register' assignment)
  1479. // There is a special case for the $Globals cbuffer. The $Globals buffer
  1480. // doesn't have either 'vk::binding' or ':register', but the user may
  1481. // ask for a specific binding for it via command line options.
  1482. if (bindGlobals && var.isGlobalsBuffer()) {
  1483. spvBuilder.decorateDSetBinding(var.getSpirvInstr(), globalsSetNo,
  1484. globalsBindNo);
  1485. }
  1486. // The normal case
  1487. else {
  1488. spvBuilder.decorateDSetBinding(
  1489. var.getSpirvInstr(), defaultSpace,
  1490. bindingSet.useNextBinding(defaultSpace, numBindingsToUse));
  1491. }
  1492. }
  1493. }
  1494. }
  1495. return true;
  1496. }
  1497. bool DeclResultIdMapper::createStageVars(
  1498. const hlsl::SigPoint *sigPoint, const NamedDecl *decl, bool asInput,
  1499. QualType type, uint32_t arraySize, const llvm::StringRef namePrefix,
  1500. llvm::Optional<SpirvInstruction *> invocationId, SpirvInstruction **value,
  1501. bool noWriteBack, SemanticInfo *inheritSemantic) {
  1502. assert(value);
  1503. // invocationId should only be used for handling HS per-vertex output.
  1504. if (invocationId.hasValue()) {
  1505. assert(spvContext.isHS() && arraySize != 0 && !asInput);
  1506. }
  1507. assert(inheritSemantic);
  1508. if (type->isVoidType()) {
  1509. // No stage variables will be created for void type.
  1510. return true;
  1511. }
  1512. // The type the variable is evaluated as for SPIR-V.
  1513. QualType evalType = type;
  1514. // We have several cases regarding HLSL semantics to handle here:
  1515. // * If the currrent decl inherits a semantic from some enclosing entity,
  1516. // use the inherited semantic no matter whether there is a semantic
  1517. // attached to the current decl.
  1518. // * If there is no semantic to inherit,
  1519. // * If the current decl is a struct,
  1520. // * If the current decl has a semantic, all its members inhert this
  1521. // decl's semantic, with the index sequentially increasing;
  1522. // * If the current decl does not have a semantic, all its members
  1523. // should have semantics attached;
  1524. // * If the current decl is not a struct, it should have semantic attached.
  1525. auto thisSemantic = getStageVarSemantic(decl);
  1526. // Which semantic we should use for this decl
  1527. auto *semanticToUse = &thisSemantic;
  1528. // Enclosing semantics override internal ones
  1529. if (inheritSemantic->isValid()) {
  1530. if (thisSemantic.isValid()) {
  1531. emitWarning(
  1532. "internal semantic '%0' overridden by enclosing semantic '%1'",
  1533. thisSemantic.loc)
  1534. << thisSemantic.str << inheritSemantic->str;
  1535. }
  1536. semanticToUse = inheritSemantic;
  1537. }
  1538. const auto loc = decl->getLocation();
  1539. if (semanticToUse->isValid() &&
  1540. // Structs with attached semantics will be handled later.
  1541. !type->isStructureType()) {
  1542. // Found semantic attached directly to this Decl. This means we need to
  1543. // map this decl to a single stage variable.
  1544. if (!validateVKAttributes(decl))
  1545. return false;
  1546. const auto semanticKind = semanticToUse->getKind();
  1547. const auto sigPointKind = sigPoint->GetKind();
  1548. // Error out when the given semantic is invalid in this shader model
  1549. if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPointKind,
  1550. spvContext.getMajorVersion(),
  1551. spvContext.getMinorVersion()) ==
  1552. hlsl::DXIL::SemanticInterpretationKind::NA) {
  1553. // Special handle MSIn/ASIn allowing VK-only builtin "DrawIndex".
  1554. switch (sigPointKind) {
  1555. case hlsl::SigPoint::Kind::MSIn:
  1556. case hlsl::SigPoint::Kind::ASIn:
  1557. if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
  1558. const llvm::StringRef builtin = builtinAttr->getBuiltIn();
  1559. if (builtin == "DrawIndex") {
  1560. break;
  1561. }
  1562. }
  1563. // fall through
  1564. default:
  1565. emitError("invalid usage of semantic '%0' in shader profile %1", loc)
  1566. << semanticToUse->str
  1567. << hlsl::ShaderModel::GetKindName(
  1568. spvContext.getCurrentShaderModelKind());
  1569. return false;
  1570. }
  1571. }
  1572. if (!validateVKBuiltins(decl, sigPoint))
  1573. return false;
  1574. const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>();
  1575. // Special handling of certain mappings between HLSL semantics and
  1576. // SPIR-V builtins:
  1577. // * SV_CullDistance/SV_ClipDistance are outsourced to GlPerVertex.
  1578. // * SV_DomainLocation can refer to a float2, whereas TessCoord is a float3.
  1579. // To ensure SPIR-V validity, we must create a float3 and extract a
  1580. // float2 from it before passing it to the main function.
  1581. // * SV_TessFactor is an array of size 2 for isoline patch, array of size 3
  1582. // for tri patch, and array of size 4 for quad patch, but it must always
  1583. // be an array of size 4 in SPIR-V for Vulkan.
  1584. // * SV_InsideTessFactor is a single float for tri patch, and an array of
  1585. // size 2 for a quad patch, but it must always be an array of size 2 in
  1586. // SPIR-V for Vulkan.
  1587. // * SV_Coverage is an uint value, but the builtin it corresponds to,
  1588. // SampleMask, must be an array of integers.
  1589. // * SV_InnerCoverage is an uint value, but the corresponding builtin,
  1590. // FullyCoveredEXT, must be an boolean value.
  1591. // * SV_DispatchThreadID, SV_GroupThreadID, and SV_GroupID are allowed to be
  1592. // uint, uint2, or uint3, but the corresponding builtins
  1593. // (GlobalInvocationId, LocalInvocationId, WorkgroupId) must be a uint3.
  1594. // * SV_ShadingRate is a uint value, but the builtin it corresponds to is a
  1595. // int2.
  1596. if (glPerVertex.tryToAccess(sigPointKind, semanticKind,
  1597. semanticToUse->index, invocationId, value,
  1598. noWriteBack, /*vecComponent=*/nullptr, loc))
  1599. return true;
  1600. switch (semanticKind) {
  1601. case hlsl::Semantic::Kind::DomainLocation:
  1602. evalType = astContext.getExtVectorType(astContext.FloatTy, 3);
  1603. break;
  1604. case hlsl::Semantic::Kind::TessFactor:
  1605. evalType = astContext.getConstantArrayType(
  1606. astContext.FloatTy, llvm::APInt(32, 4), clang::ArrayType::Normal, 0);
  1607. break;
  1608. case hlsl::Semantic::Kind::InsideTessFactor:
  1609. evalType = astContext.getConstantArrayType(
  1610. astContext.FloatTy, llvm::APInt(32, 2), clang::ArrayType::Normal, 0);
  1611. break;
  1612. case hlsl::Semantic::Kind::Coverage:
  1613. evalType = astContext.getConstantArrayType(astContext.UnsignedIntTy,
  1614. llvm::APInt(32, 1),
  1615. clang::ArrayType::Normal, 0);
  1616. break;
  1617. case hlsl::Semantic::Kind::InnerCoverage:
  1618. evalType = astContext.BoolTy;
  1619. break;
  1620. case hlsl::Semantic::Kind::Barycentrics:
  1621. evalType = astContext.getExtVectorType(astContext.FloatTy, 2);
  1622. break;
  1623. case hlsl::Semantic::Kind::DispatchThreadID:
  1624. case hlsl::Semantic::Kind::GroupThreadID:
  1625. case hlsl::Semantic::Kind::GroupID:
  1626. // Keep the original integer signedness
  1627. evalType = astContext.getExtVectorType(
  1628. hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type,
  1629. 3);
  1630. break;
  1631. case hlsl::Semantic::Kind::ShadingRate:
  1632. evalType = astContext.getExtVectorType(astContext.IntTy, 2);
  1633. break;
  1634. default:
  1635. // Only the semantic kinds mentioned above are handled.
  1636. break;
  1637. }
  1638. // Boolean stage I/O variables must be represented as unsigned integers.
  1639. // Boolean built-in variables are represented as bool.
  1640. if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
  1641. evalType = getUintTypeWithSourceComponents(astContext, type);
  1642. }
  1643. // Handle the extra arrayness
  1644. if (arraySize != 0) {
  1645. evalType = astContext.getConstantArrayType(
  1646. evalType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
  1647. }
  1648. StageVar stageVar(
  1649. sigPoint, *semanticToUse, builtinAttr, evalType,
  1650. // For HS/DS/GS, we have already stripped the outmost arrayness on type.
  1651. getLocationCount(astContext, type));
  1652. const auto name = namePrefix.str() + "." + stageVar.getSemanticStr();
  1653. SpirvVariable *varInstr =
  1654. createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc);
  1655. if (!varInstr)
  1656. return false;
  1657. stageVar.setSpirvInstr(varInstr);
  1658. stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
  1659. stageVar.setIndexAttr(decl->getAttr<VKIndexAttr>());
  1660. stageVars.push_back(stageVar);
  1661. // Emit OpDecorate* instructions to link this stage variable with the HLSL
  1662. // semantic it is created for
  1663. spvBuilder.decorateHlslSemantic(varInstr, stageVar.getSemanticStr());
  1664. // We have semantics attached to this decl, which means it must be a
  1665. // function/parameter/variable. All are DeclaratorDecls.
  1666. stageVarInstructions[cast<DeclaratorDecl>(decl)] = varInstr;
  1667. // Mark that we have used one index for this semantic
  1668. ++semanticToUse->index;
  1669. // TODO: the following may not be correct?
  1670. if (sigPoint->GetSignatureKind() ==
  1671. hlsl::DXIL::SignatureKind::PatchConstOrPrim) {
  1672. if (sigPointKind == hlsl::SigPoint::Kind::MSPOut) {
  1673. // Decorate with PerPrimitiveNV for per-primitive out variables.
  1674. spvBuilder.decoratePerPrimitiveNV(varInstr,
  1675. varInstr->getSourceLocation());
  1676. } else {
  1677. spvBuilder.decoratePatch(varInstr, varInstr->getSourceLocation());
  1678. }
  1679. }
  1680. // Decorate with interpolation modes for pixel shader input variables
  1681. // or vertex shader output variables.
  1682. if (((spvContext.isPS() && sigPoint->IsInput()) ||
  1683. (spvContext.isVS() && sigPoint->IsOutput())) &&
  1684. // BaryCoord*AMD buitins already encode the interpolation mode.
  1685. semanticKind != hlsl::Semantic::Kind::Barycentrics)
  1686. decorateInterpolationMode(decl, type, varInstr);
  1687. if (asInput) {
  1688. *value = spvBuilder.createLoad(evalType, varInstr, loc);
  1689. // Fix ups for corner cases
  1690. // Special handling of SV_TessFactor DS patch constant input.
  1691. // TessLevelOuter is always an array of size 4 in SPIR-V, but
  1692. // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
  1693. // relevant indexes must be loaded.
  1694. if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
  1695. hlsl::GetArraySize(type) != 4) {
  1696. llvm::SmallVector<SpirvInstruction *, 4> components;
  1697. const auto tessFactorSize = hlsl::GetArraySize(type);
  1698. const auto arrType = astContext.getConstantArrayType(
  1699. astContext.FloatTy, llvm::APInt(32, tessFactorSize),
  1700. clang::ArrayType::Normal, 0);
  1701. for (uint32_t i = 0; i < tessFactorSize; ++i)
  1702. components.push_back(spvBuilder.createCompositeExtract(
  1703. astContext.FloatTy, *value, {i}, thisSemantic.loc));
  1704. *value = spvBuilder.createCompositeConstruct(arrType, components,
  1705. thisSemantic.loc);
  1706. }
  1707. // Special handling of SV_InsideTessFactor DS patch constant input.
  1708. // TessLevelInner is always an array of size 2 in SPIR-V, but
  1709. // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
  1710. // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of
  1711. // TessLevelInner.
  1712. else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
  1713. // Some developers use float[1] instead of a scalar float.
  1714. (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
  1715. *value = spvBuilder.createCompositeExtract(astContext.FloatTy, *value,
  1716. {0}, thisSemantic.loc);
  1717. if (type->isArrayType()) { // float[1]
  1718. const auto arrType = astContext.getConstantArrayType(
  1719. astContext.FloatTy, llvm::APInt(32, 1), clang::ArrayType::Normal,
  1720. 0);
  1721. *value = spvBuilder.createCompositeConstruct(arrType, {*value},
  1722. thisSemantic.loc);
  1723. }
  1724. }
  1725. // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord
  1726. // is always a float3. To ensure SPIR-V validity, a float3 stage variable
  1727. // is created, and we must extract a float2 from it before passing it to
  1728. // the main function.
  1729. else if (semanticKind == hlsl::Semantic::Kind::DomainLocation &&
  1730. hlsl::GetHLSLVecSize(type) != 3) {
  1731. const auto domainLocSize = hlsl::GetHLSLVecSize(type);
  1732. *value = spvBuilder.createVectorShuffle(
  1733. astContext.getExtVectorType(astContext.FloatTy, domainLocSize),
  1734. *value, *value, {0, 1}, thisSemantic.loc);
  1735. }
  1736. // Special handling of SV_Coverage, which is an uint value. We need to
  1737. // read SampleMask and extract its first element.
  1738. else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
  1739. *value = spvBuilder.createCompositeExtract(type, *value, {0},
  1740. thisSemantic.loc);
  1741. }
  1742. // Special handling of SV_InnerCoverage, which is an uint value. We need
  1743. // to read FullyCoveredEXT, which is a boolean value, and convert it to an
  1744. // uint value. According to D3D12 "Conservative Rasterization" doc: "The
  1745. // Pixel Shader has a 32-bit scalar integer System Generate Value
  1746. // available: InnerCoverage. This is a bit-field that has bit 0 from the
  1747. // LSB set to 1 for a given conservatively rasterized pixel, only when
  1748. // that pixel is guaranteed to be entirely inside the current primitive.
  1749. // All other input register bits must be set to 0 when bit 0 is not set,
  1750. // but are undefined when bit 0 is set to 1 (essentially, this bit-field
  1751. // represents a Boolean value where false must be exactly 0, but true can
  1752. // be any odd (i.e. bit 0 set) non-zero value)."
  1753. else if (semanticKind == hlsl::Semantic::Kind::InnerCoverage) {
  1754. const auto constOne = spvBuilder.getConstantInt(
  1755. astContext.UnsignedIntTy, llvm::APInt(32, 1));
  1756. const auto constZero = spvBuilder.getConstantInt(
  1757. astContext.UnsignedIntTy, llvm::APInt(32, 0));
  1758. *value = spvBuilder.createSelect(astContext.UnsignedIntTy, *value,
  1759. constOne, constZero, thisSemantic.loc);
  1760. }
  1761. // Special handling of SV_Barycentrics, which is a float3, but the
  1762. // underlying stage input variable is a float2 (only provides the first
  1763. // two components). Calculate the third element.
  1764. else if (semanticKind == hlsl::Semantic::Kind::Barycentrics) {
  1765. const auto x = spvBuilder.createCompositeExtract(
  1766. astContext.FloatTy, *value, {0}, thisSemantic.loc);
  1767. const auto y = spvBuilder.createCompositeExtract(
  1768. astContext.FloatTy, *value, {1}, thisSemantic.loc);
  1769. const auto xy = spvBuilder.createBinaryOp(
  1770. spv::Op::OpFAdd, astContext.FloatTy, x, y, thisSemantic.loc);
  1771. const auto z = spvBuilder.createBinaryOp(
  1772. spv::Op::OpFSub, astContext.FloatTy,
  1773. spvBuilder.getConstantFloat(astContext.FloatTy,
  1774. llvm::APFloat(1.0f)),
  1775. xy, thisSemantic.loc);
  1776. *value = spvBuilder.createCompositeConstruct(
  1777. astContext.getExtVectorType(astContext.FloatTy, 3), {x, y, z},
  1778. thisSemantic.loc);
  1779. }
  1780. // Special handling of SV_DispatchThreadID and SV_GroupThreadID, which may
  1781. // be a uint or uint2, but the underlying stage input variable is a uint3.
  1782. // The last component(s) should be discarded in needed.
  1783. else if ((semanticKind == hlsl::Semantic::Kind::DispatchThreadID ||
  1784. semanticKind == hlsl::Semantic::Kind::GroupThreadID ||
  1785. semanticKind == hlsl::Semantic::Kind::GroupID) &&
  1786. (!hlsl::IsHLSLVecType(type) ||
  1787. hlsl::GetHLSLVecSize(type) != 3)) {
  1788. const auto srcVecElemType = hlsl::IsHLSLVecType(type)
  1789. ? hlsl::GetHLSLVecElementType(type)
  1790. : type;
  1791. const auto vecSize =
  1792. hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1;
  1793. if (vecSize == 1)
  1794. *value = spvBuilder.createCompositeExtract(srcVecElemType, *value,
  1795. {0}, thisSemantic.loc);
  1796. else if (vecSize == 2)
  1797. *value = spvBuilder.createVectorShuffle(
  1798. astContext.getExtVectorType(srcVecElemType, 2), *value, *value,
  1799. {0, 1}, thisSemantic.loc);
  1800. }
  1801. // Special handling of SV_ShadingRate, which is a bitpacked enum value,
  1802. // but SPIR-V's FragSizeEXT uses an int2. We build the enum value from
  1803. // the separate axis values.
  1804. else if (semanticKind == hlsl::Semantic::Kind::ShadingRate) {
  1805. // From the D3D12 functional spec for Variable-Rate Shading.
  1806. // #define D3D12_MAKE_COARSE_SHADING_RATE(x,y) ((x) << 2 | (y))
  1807. const auto x = spvBuilder.createCompositeExtract(
  1808. astContext.IntTy, *value, {0}, thisSemantic.loc);
  1809. const auto y = spvBuilder.createCompositeExtract(
  1810. astContext.IntTy, *value, {1}, thisSemantic.loc);
  1811. const auto constTwo =
  1812. spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 2));
  1813. *value = spvBuilder.createBinaryOp(
  1814. spv::Op::OpBitwiseOr, astContext.UnsignedIntTy,
  1815. spvBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
  1816. astContext.IntTy, x, constTwo,
  1817. thisSemantic.loc),
  1818. y, thisSemantic.loc);
  1819. }
  1820. // Reciprocate SV_Position.w if requested
  1821. if (semanticKind == hlsl::Semantic::Kind::Position)
  1822. *value = invertWIfRequested(*value, thisSemantic.loc);
  1823. // Since boolean stage input variables are represented as unsigned
  1824. // integers, after loading them, we should cast them to boolean.
  1825. if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
  1826. *value =
  1827. theEmitter.castToType(*value, evalType, type, thisSemantic.loc);
  1828. }
  1829. } else {
  1830. if (noWriteBack)
  1831. return true;
  1832. // Negate SV_Position.y if requested
  1833. if (semanticKind == hlsl::Semantic::Kind::Position)
  1834. *value = invertYIfRequested(*value, thisSemantic.loc);
  1835. SpirvInstruction *ptr = varInstr;
  1836. // Special handling of SV_TessFactor HS patch constant output.
  1837. // TessLevelOuter is always an array of size 4 in SPIR-V, but
  1838. // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
  1839. // relevant indexes must be written to.
  1840. if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
  1841. hlsl::GetArraySize(type) != 4) {
  1842. const auto tessFactorSize = hlsl::GetArraySize(type);
  1843. for (uint32_t i = 0; i < tessFactorSize; ++i) {
  1844. ptr = spvBuilder.createAccessChain(
  1845. astContext.FloatTy, varInstr,
  1846. {spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  1847. llvm::APInt(32, i))},
  1848. thisSemantic.loc);
  1849. spvBuilder.createStore(
  1850. ptr,
  1851. spvBuilder.createCompositeExtract(astContext.FloatTy, *value, {i},
  1852. thisSemantic.loc),
  1853. thisSemantic.loc);
  1854. }
  1855. }
  1856. // Special handling of SV_InsideTessFactor HS patch constant output.
  1857. // TessLevelInner is always an array of size 2 in SPIR-V, but
  1858. // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
  1859. // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of
  1860. // TessLevelInner.
  1861. else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
  1862. // Some developers use float[1] instead of a scalar float.
  1863. (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
  1864. ptr = spvBuilder.createAccessChain(
  1865. astContext.FloatTy, varInstr,
  1866. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  1867. llvm::APInt(32, 0)),
  1868. thisSemantic.loc);
  1869. if (type->isArrayType()) // float[1]
  1870. *value = spvBuilder.createCompositeExtract(astContext.FloatTy, *value,
  1871. {0}, thisSemantic.loc);
  1872. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  1873. }
  1874. // Special handling of SV_Coverage, which is an unit value. We need to
  1875. // write it to the first element in the SampleMask builtin.
  1876. else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
  1877. ptr = spvBuilder.createAccessChain(
  1878. type, varInstr,
  1879. spvBuilder.getConstantInt(astContext.UnsignedIntTy,
  1880. llvm::APInt(32, 0)),
  1881. thisSemantic.loc);
  1882. ptr->setStorageClass(spv::StorageClass::Output);
  1883. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  1884. }
  1885. // Special handling of HS ouput, for which we write to only one
  1886. // element in the per-vertex data array: the one indexed by
  1887. // SV_ControlPointID.
  1888. else if (invocationId.hasValue() && invocationId.getValue() != nullptr) {
  1889. // Remove the arrayness to get the element type.
  1890. assert(isa<ConstantArrayType>(evalType));
  1891. const auto elementType =
  1892. astContext.getAsArrayType(evalType)->getElementType();
  1893. auto index = invocationId.getValue();
  1894. ptr = spvBuilder.createAccessChain(elementType, varInstr, index,
  1895. thisSemantic.loc);
  1896. ptr->setStorageClass(spv::StorageClass::Output);
  1897. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  1898. }
  1899. // Since boolean output stage variables are represented as unsigned
  1900. // integers, we must cast the value to uint before storing.
  1901. else if (isBooleanStageIOVar(decl, type, semanticKind, sigPointKind)) {
  1902. *value =
  1903. theEmitter.castToType(*value, type, evalType, thisSemantic.loc);
  1904. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  1905. }
  1906. // For all normal cases
  1907. else {
  1908. spvBuilder.createStore(ptr, *value, thisSemantic.loc);
  1909. }
  1910. }
  1911. return true;
  1912. }
  1913. // If the decl itself doesn't have semantic string attached and there is no
  1914. // one to inherit, it should be a struct having all its fields with semantic
  1915. // strings.
  1916. if (!semanticToUse->isValid() && !type->isStructureType()) {
  1917. emitError("semantic string missing for shader %select{output|input}0 "
  1918. "variable '%1'",
  1919. loc)
  1920. << asInput << decl->getName();
  1921. return false;
  1922. }
  1923. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  1924. if (asInput) {
  1925. // If this decl translates into multiple stage input variables, we need to
  1926. // load their values into a composite.
  1927. llvm::SmallVector<SpirvInstruction *, 4> subValues;
  1928. // If we have base classes, we need to handle them first.
  1929. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  1930. for (auto base : cxxDecl->bases()) {
  1931. SpirvInstruction *subValue = nullptr;
  1932. if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
  1933. asInput, base.getType(), arraySize, namePrefix,
  1934. invocationId, &subValue, noWriteBack,
  1935. semanticToUse))
  1936. return false;
  1937. subValues.push_back(subValue);
  1938. }
  1939. }
  1940. for (const auto *field : structDecl->fields()) {
  1941. SpirvInstruction *subValue = nullptr;
  1942. if (!createStageVars(sigPoint, field, asInput, field->getType(),
  1943. arraySize, namePrefix, invocationId, &subValue,
  1944. noWriteBack, semanticToUse))
  1945. return false;
  1946. subValues.push_back(subValue);
  1947. }
  1948. if (arraySize == 0) {
  1949. *value = spvBuilder.createCompositeConstruct(evalType, subValues, loc);
  1950. return true;
  1951. }
  1952. // Handle the extra level of arrayness.
  1953. // We need to return an array of structs. But we get arrays of fields
  1954. // from visiting all fields. So now we need to extract all the elements
  1955. // at the same index of each field arrays and compose a new struct out
  1956. // of them.
  1957. const auto structType = type;
  1958. const auto arrayType = astContext.getConstantArrayType(
  1959. structType, llvm::APInt(32, arraySize), clang::ArrayType::Normal, 0);
  1960. llvm::SmallVector<SpirvInstruction *, 16> arrayElements;
  1961. for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) {
  1962. llvm::SmallVector<SpirvInstruction *, 8> fields;
  1963. // If we have base classes, we need to handle them first.
  1964. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  1965. uint32_t baseIndex = 0;
  1966. for (auto base : cxxDecl->bases()) {
  1967. const auto baseType = base.getType();
  1968. fields.push_back(spvBuilder.createCompositeExtract(
  1969. baseType, subValues[baseIndex++], {arrayIndex}, loc));
  1970. }
  1971. }
  1972. // Extract the element at index arrayIndex from each field
  1973. for (const auto *field : structDecl->fields()) {
  1974. const auto fieldType = field->getType();
  1975. fields.push_back(spvBuilder.createCompositeExtract(
  1976. fieldType,
  1977. subValues[getNumBaseClasses(type) + field->getFieldIndex()],
  1978. {arrayIndex}, loc));
  1979. }
  1980. // Compose a new struct out of them
  1981. arrayElements.push_back(
  1982. spvBuilder.createCompositeConstruct(structType, fields, loc));
  1983. }
  1984. *value = spvBuilder.createCompositeConstruct(arrayType, arrayElements, loc);
  1985. } else {
  1986. // If we have base classes, we need to handle them first.
  1987. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  1988. uint32_t baseIndex = 0;
  1989. for (auto base : cxxDecl->bases()) {
  1990. SpirvInstruction *subValue = nullptr;
  1991. if (!noWriteBack)
  1992. subValue = spvBuilder.createCompositeExtract(base.getType(), *value,
  1993. {baseIndex++}, loc);
  1994. if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
  1995. asInput, base.getType(), arraySize, namePrefix,
  1996. invocationId, &subValue, noWriteBack,
  1997. semanticToUse))
  1998. return false;
  1999. }
  2000. }
  2001. // Unlike reading, which may require us to read stand-alone builtins and
  2002. // stage input variables and compose an array of structs out of them,
  2003. // it happens that we don't need to write an array of structs in a bunch
  2004. // for all shader stages:
  2005. //
  2006. // * VS: output is a single struct, without extra arrayness
  2007. // * HS: output is an array of structs, with extra arrayness,
  2008. // but we only write to the struct at the InvocationID index
  2009. // * DS: output is a single struct, without extra arrayness
  2010. // * GS: output is controlled by OpEmitVertex, one vertex per time
  2011. // * MS: output is an array of structs, with extra arrayness
  2012. //
  2013. // The interesting shader stage is HS. We need the InvocationID to write
  2014. // out the value to the correct array element.
  2015. for (const auto *field : structDecl->fields()) {
  2016. const auto fieldType = field->getType();
  2017. SpirvInstruction *subValue = nullptr;
  2018. if (!noWriteBack)
  2019. subValue = spvBuilder.createCompositeExtract(
  2020. fieldType, *value,
  2021. {getNumBaseClasses(type) + field->getFieldIndex()}, loc);
  2022. if (!createStageVars(sigPoint, field, asInput, field->getType(),
  2023. arraySize, namePrefix, invocationId, &subValue,
  2024. noWriteBack, semanticToUse))
  2025. return false;
  2026. }
  2027. }
  2028. return true;
  2029. }
  2030. bool DeclResultIdMapper::createPayloadStageVars(
  2031. const hlsl::SigPoint *sigPoint, spv::StorageClass sc, const NamedDecl *decl,
  2032. bool asInput, QualType type, const llvm::StringRef namePrefix,
  2033. SpirvInstruction **value, uint32_t payloadMemOffset) {
  2034. assert(spvContext.isMS() || spvContext.isAS());
  2035. assert(value);
  2036. if (type->isVoidType()) {
  2037. // No stage variables will be created for void type.
  2038. return true;
  2039. }
  2040. const auto loc = decl->getLocation();
  2041. if (!type->isStructureType()) {
  2042. StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
  2043. getLocationCount(astContext, type));
  2044. const auto name = namePrefix.str() + "." + decl->getNameAsString();
  2045. SpirvVariable *varInstr =
  2046. spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false, loc);
  2047. if (!varInstr)
  2048. return false;
  2049. // Even though these as user defined IO stage variables, set them as SPIR-V
  2050. // builtins in order to bypass any semantic string checks and location
  2051. // assignment.
  2052. stageVar.setIsSpirvBuiltin();
  2053. stageVar.setSpirvInstr(varInstr);
  2054. stageVars.push_back(stageVar);
  2055. // Decorate with PerTaskNV for mesh/amplification shader payload variables.
  2056. spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
  2057. varInstr->getSourceLocation());
  2058. if (asInput) {
  2059. *value = spvBuilder.createLoad(type, varInstr, loc);
  2060. } else {
  2061. spvBuilder.createStore(varInstr, *value, loc);
  2062. }
  2063. return true;
  2064. }
  2065. // This decl translates into multiple stage input/output payload variables
  2066. // and we need to load/store these individual member variables.
  2067. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  2068. llvm::SmallVector<SpirvInstruction *, 4> subValues;
  2069. AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions);
  2070. uint32_t nextMemberOffset = 0;
  2071. for (const auto *field : structDecl->fields()) {
  2072. const auto fieldType = field->getType();
  2073. SpirvInstruction *subValue = nullptr;
  2074. uint32_t memberAlignment = 0, memberSize = 0, stride = 0;
  2075. // The next avaiable offset after laying out the previous members.
  2076. std::tie(memberAlignment, memberSize) = alignmentCalc.getAlignmentAndSize(
  2077. field->getType(), spirvOptions.ampPayloadLayoutRule,
  2078. /*isRowMajor*/ llvm::None, &stride);
  2079. alignmentCalc.alignUsingHLSLRelaxedLayout(
  2080. field->getType(), memberSize, memberAlignment, &nextMemberOffset);
  2081. // The vk::offset attribute takes precedence over all.
  2082. if (field->getAttr<VKOffsetAttr>()) {
  2083. nextMemberOffset = field->getAttr<VKOffsetAttr>()->getOffset();
  2084. }
  2085. // Each payload member must have an Offset Decoration.
  2086. payloadMemOffset = nextMemberOffset;
  2087. nextMemberOffset += memberSize;
  2088. if (!asInput) {
  2089. subValue = spvBuilder.createCompositeExtract(
  2090. fieldType, *value, {getNumBaseClasses(type) + field->getFieldIndex()},
  2091. loc);
  2092. }
  2093. if (!createPayloadStageVars(sigPoint, sc, field, asInput, field->getType(),
  2094. namePrefix, &subValue, payloadMemOffset))
  2095. return false;
  2096. if (asInput) {
  2097. subValues.push_back(subValue);
  2098. }
  2099. }
  2100. if (asInput) {
  2101. *value = spvBuilder.createCompositeConstruct(type, subValues, loc);
  2102. }
  2103. return true;
  2104. }
  2105. bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
  2106. QualType type,
  2107. SpirvInstruction *value) {
  2108. assert(spvContext.isGS()); // Only for GS use
  2109. if (hlsl::IsHLSLStreamOutputType(type))
  2110. type = hlsl::GetHLSLResourceResultType(type);
  2111. if (hasGSPrimitiveTypeQualifier(decl))
  2112. type = astContext.getAsConstantArrayType(type)->getElementType();
  2113. auto semanticInfo = getStageVarSemantic(decl);
  2114. const auto loc = decl->getLocation();
  2115. if (semanticInfo.isValid()) {
  2116. // Found semantic attached directly to this Decl. Write the value for this
  2117. // Decl to the corresponding stage output variable.
  2118. // Handle SV_ClipDistance, and SV_CullDistance
  2119. if (glPerVertex.tryToAccess(
  2120. hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
  2121. semanticInfo.index, llvm::None, &value,
  2122. /*noWriteBack=*/false, /*vecComponent=*/nullptr, loc))
  2123. return true;
  2124. // Query the <result-id> for the stage output variable generated out
  2125. // of this decl.
  2126. // We have semantic string attached to this decl; therefore, it must be a
  2127. // DeclaratorDecl.
  2128. const auto found = stageVarInstructions.find(cast<DeclaratorDecl>(decl));
  2129. // We should have recorded its stage output variable previously.
  2130. assert(found != stageVarInstructions.end());
  2131. // Negate SV_Position.y if requested
  2132. if (semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position)
  2133. value = invertYIfRequested(value, loc);
  2134. // Boolean stage output variables are represented as unsigned integers.
  2135. if (isBooleanStageIOVar(decl, type, semanticInfo.semantic->GetKind(),
  2136. hlsl::SigPoint::Kind::GSOut)) {
  2137. QualType uintType = getUintTypeWithSourceComponents(astContext, type);
  2138. value = theEmitter.castToType(value, type, uintType, loc);
  2139. }
  2140. spvBuilder.createStore(found->second, value, loc);
  2141. return true;
  2142. }
  2143. // If the decl itself doesn't have semantic string attached, it should be
  2144. // a struct having all its fields with semantic strings.
  2145. if (!type->isStructureType()) {
  2146. emitError("semantic string missing for shader output variable '%0'", loc)
  2147. << decl->getName();
  2148. return false;
  2149. }
  2150. // If we have base classes, we need to handle them first.
  2151. if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
  2152. uint32_t baseIndex = 0;
  2153. for (auto base : cxxDecl->bases()) {
  2154. auto *subValue = spvBuilder.createCompositeExtract(base.getType(), value,
  2155. {baseIndex++}, loc);
  2156. if (!writeBackOutputStream(base.getType()->getAsCXXRecordDecl(),
  2157. base.getType(), subValue))
  2158. return false;
  2159. }
  2160. }
  2161. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  2162. // Write out each field
  2163. for (const auto *field : structDecl->fields()) {
  2164. const auto fieldType = field->getType();
  2165. auto *subValue = spvBuilder.createCompositeExtract(
  2166. fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()},
  2167. loc);
  2168. if (!writeBackOutputStream(field, field->getType(), subValue))
  2169. return false;
  2170. }
  2171. return true;
  2172. }
  2173. SpirvInstruction *
  2174. DeclResultIdMapper::invertYIfRequested(SpirvInstruction *position,
  2175. SourceLocation loc) {
  2176. // Negate SV_Position.y if requested
  2177. if (spirvOptions.invertY) {
  2178. const auto oldY = spvBuilder.createCompositeExtract(astContext.FloatTy,
  2179. position, {1}, loc);
  2180. const auto newY = spvBuilder.createUnaryOp(spv::Op::OpFNegate,
  2181. astContext.FloatTy, oldY, loc);
  2182. position = spvBuilder.createCompositeInsert(
  2183. astContext.getExtVectorType(astContext.FloatTy, 4), position, {1}, newY,
  2184. loc);
  2185. }
  2186. return position;
  2187. }
  2188. SpirvInstruction *
  2189. DeclResultIdMapper::invertWIfRequested(SpirvInstruction *position,
  2190. SourceLocation loc) {
  2191. // Reciprocate SV_Position.w if requested
  2192. if (spirvOptions.invertW && spvContext.isPS()) {
  2193. const auto oldW = spvBuilder.createCompositeExtract(astContext.FloatTy,
  2194. position, {3}, loc);
  2195. const auto newW = spvBuilder.createBinaryOp(
  2196. spv::Op::OpFDiv, astContext.FloatTy,
  2197. spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(1.0f)),
  2198. oldW, loc);
  2199. position = spvBuilder.createCompositeInsert(
  2200. astContext.getExtVectorType(astContext.FloatTy, 4), position, {3}, newW,
  2201. loc);
  2202. }
  2203. return position;
  2204. }
  2205. void DeclResultIdMapper::decorateInterpolationMode(const NamedDecl *decl,
  2206. QualType type,
  2207. SpirvVariable *varInstr) {
  2208. const auto loc = decl->getLocation();
  2209. if (isUintOrVecMatOfUintType(type) || isSintOrVecMatOfSintType(type) ||
  2210. isBoolOrVecMatOfBoolType(type)) {
  2211. // TODO: Probably we can call hlsl::ValidateSignatureElement() for the
  2212. // following check.
  2213. if (decl->getAttr<HLSLLinearAttr>() || decl->getAttr<HLSLCentroidAttr>() ||
  2214. decl->getAttr<HLSLNoPerspectiveAttr>() ||
  2215. decl->getAttr<HLSLSampleAttr>()) {
  2216. emitError("only nointerpolation mode allowed for integer input "
  2217. "parameters in pixel shader or integer output in vertex shader",
  2218. decl->getLocation());
  2219. } else {
  2220. spvBuilder.decorateFlat(varInstr, loc);
  2221. }
  2222. } else {
  2223. // Do nothing for HLSLLinearAttr since its the default
  2224. // Attributes can be used together. So cannot use else if.
  2225. if (decl->getAttr<HLSLCentroidAttr>())
  2226. spvBuilder.decorateCentroid(varInstr, loc);
  2227. if (decl->getAttr<HLSLNoInterpolationAttr>())
  2228. spvBuilder.decorateFlat(varInstr, loc);
  2229. if (decl->getAttr<HLSLNoPerspectiveAttr>())
  2230. spvBuilder.decorateNoPerspective(varInstr, loc);
  2231. if (decl->getAttr<HLSLSampleAttr>()) {
  2232. spvBuilder.decorateSample(varInstr, loc);
  2233. }
  2234. }
  2235. }
  2236. SpirvVariable *DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn,
  2237. QualType type,
  2238. SourceLocation loc) {
  2239. // Guarantee uniqueness
  2240. uint32_t spvBuiltinId = static_cast<uint32_t>(builtIn);
  2241. const auto builtInVar = builtinToVarMap.find(spvBuiltinId);
  2242. if (builtInVar != builtinToVarMap.end()) {
  2243. return builtInVar->second;
  2244. }
  2245. spv::StorageClass sc = spv::StorageClass::Max;
  2246. // Valid builtins supported
  2247. switch (builtIn) {
  2248. case spv::BuiltIn::SubgroupSize:
  2249. case spv::BuiltIn::SubgroupLocalInvocationId:
  2250. case spv::BuiltIn::HitTNV:
  2251. case spv::BuiltIn::RayTminNV:
  2252. case spv::BuiltIn::HitKindNV:
  2253. case spv::BuiltIn::IncomingRayFlagsNV:
  2254. case spv::BuiltIn::InstanceCustomIndexNV:
  2255. case spv::BuiltIn::RayGeometryIndexKHR:
  2256. case spv::BuiltIn::PrimitiveId:
  2257. case spv::BuiltIn::InstanceId:
  2258. case spv::BuiltIn::WorldRayDirectionNV:
  2259. case spv::BuiltIn::WorldRayOriginNV:
  2260. case spv::BuiltIn::ObjectRayDirectionNV:
  2261. case spv::BuiltIn::ObjectRayOriginNV:
  2262. case spv::BuiltIn::ObjectToWorldNV:
  2263. case spv::BuiltIn::WorldToObjectNV:
  2264. case spv::BuiltIn::LaunchIdNV:
  2265. case spv::BuiltIn::LaunchSizeNV:
  2266. sc = spv::StorageClass::Input;
  2267. break;
  2268. case spv::BuiltIn::PrimitiveCountNV:
  2269. case spv::BuiltIn::PrimitiveIndicesNV:
  2270. case spv::BuiltIn::TaskCountNV:
  2271. sc = spv::StorageClass::Output;
  2272. break;
  2273. default:
  2274. assert(false && "unsupported SPIR-V builtin");
  2275. return nullptr;
  2276. }
  2277. // Create a dummy StageVar for this builtin variable
  2278. auto var = spvBuilder.addStageBuiltinVar(type, sc, builtIn,
  2279. /*isPrecise*/ false, loc);
  2280. const hlsl::SigPoint *sigPoint =
  2281. hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
  2282. hlsl::DxilParamInputQual::In, spvContext.getCurrentShaderModelKind(),
  2283. /*isPatchConstant=*/false));
  2284. StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
  2285. /*locCount=*/0);
  2286. stageVar.setIsSpirvBuiltin();
  2287. stageVar.setSpirvInstr(var);
  2288. stageVars.push_back(stageVar);
  2289. // Store in map for re-use
  2290. builtinToVarMap[spvBuiltinId] = var;
  2291. return var;
  2292. }
  2293. SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
  2294. StageVar *stageVar, const NamedDecl *decl, const llvm::StringRef name,
  2295. SourceLocation srcLoc) {
  2296. using spv::BuiltIn;
  2297. const auto sigPoint = stageVar->getSigPoint();
  2298. const auto semanticKind = stageVar->getSemanticInfo().getKind();
  2299. const auto sigPointKind = sigPoint->GetKind();
  2300. const auto type = stageVar->getAstType();
  2301. const auto isPrecise = decl->hasAttr<HLSLPreciseAttr>();
  2302. spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
  2303. if (sc == spv::StorageClass::Max)
  2304. return 0;
  2305. stageVar->setStorageClass(sc);
  2306. // [[vk::builtin(...)]] takes precedence.
  2307. if (const auto *builtinAttr = stageVar->getBuiltInAttr()) {
  2308. const auto spvBuiltIn =
  2309. llvm::StringSwitch<BuiltIn>(builtinAttr->getBuiltIn())
  2310. .Case("PointSize", BuiltIn::PointSize)
  2311. .Case("HelperInvocation", BuiltIn::HelperInvocation)
  2312. .Case("BaseVertex", BuiltIn::BaseVertex)
  2313. .Case("BaseInstance", BuiltIn::BaseInstance)
  2314. .Case("DrawIndex", BuiltIn::DrawIndex)
  2315. .Case("DeviceIndex", BuiltIn::DeviceIndex)
  2316. .Case("ViewportMaskNV", BuiltIn::ViewportMaskNV)
  2317. .Default(BuiltIn::Max);
  2318. assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
  2319. return spvBuilder.addStageBuiltinVar(type, sc, spvBuiltIn, isPrecise,
  2320. srcLoc);
  2321. }
  2322. // The following translation assumes that semantic validity in the current
  2323. // shader model is already checked, so it only covers valid SigPoints for
  2324. // each semantic.
  2325. switch (semanticKind) {
  2326. // According to DXIL spec, the Position SV can be used by all SigPoints
  2327. // other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
  2328. // According to Vulkan spec, the Position BuiltIn can only be used
  2329. // by VSOut, HS/DS/GS In/Out, MSOut.
  2330. case hlsl::Semantic::Kind::Position: {
  2331. switch (sigPointKind) {
  2332. case hlsl::SigPoint::Kind::VSIn:
  2333. case hlsl::SigPoint::Kind::PCOut:
  2334. case hlsl::SigPoint::Kind::DSIn:
  2335. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2336. case hlsl::SigPoint::Kind::VSOut:
  2337. case hlsl::SigPoint::Kind::HSCPIn:
  2338. case hlsl::SigPoint::Kind::HSCPOut:
  2339. case hlsl::SigPoint::Kind::DSCPIn:
  2340. case hlsl::SigPoint::Kind::DSOut:
  2341. case hlsl::SigPoint::Kind::GSVIn:
  2342. case hlsl::SigPoint::Kind::GSOut:
  2343. case hlsl::SigPoint::Kind::MSOut:
  2344. stageVar->setIsSpirvBuiltin();
  2345. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position,
  2346. isPrecise, srcLoc);
  2347. case hlsl::SigPoint::Kind::PSIn:
  2348. stageVar->setIsSpirvBuiltin();
  2349. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragCoord,
  2350. isPrecise, srcLoc);
  2351. default:
  2352. llvm_unreachable("invalid usage of SV_Position sneaked in");
  2353. }
  2354. }
  2355. // According to DXIL spec, the VertexID SV can only be used by VSIn.
  2356. // According to Vulkan spec, the VertexIndex BuiltIn can only be used by
  2357. // VSIn.
  2358. case hlsl::Semantic::Kind::VertexID: {
  2359. stageVar->setIsSpirvBuiltin();
  2360. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::VertexIndex,
  2361. isPrecise, srcLoc);
  2362. }
  2363. // According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
  2364. // HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
  2365. // According to Vulkan spec, the InstanceIndex BuitIn can only be used by
  2366. // VSIn.
  2367. case hlsl::Semantic::Kind::InstanceID: {
  2368. switch (sigPointKind) {
  2369. case hlsl::SigPoint::Kind::VSIn:
  2370. stageVar->setIsSpirvBuiltin();
  2371. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InstanceIndex,
  2372. isPrecise, srcLoc);
  2373. case hlsl::SigPoint::Kind::VSOut:
  2374. case hlsl::SigPoint::Kind::HSCPIn:
  2375. case hlsl::SigPoint::Kind::HSCPOut:
  2376. case hlsl::SigPoint::Kind::DSCPIn:
  2377. case hlsl::SigPoint::Kind::DSOut:
  2378. case hlsl::SigPoint::Kind::GSVIn:
  2379. case hlsl::SigPoint::Kind::GSOut:
  2380. case hlsl::SigPoint::Kind::PSIn:
  2381. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2382. default:
  2383. llvm_unreachable("invalid usage of SV_InstanceID sneaked in");
  2384. }
  2385. }
  2386. // According to DXIL spec, the Depth{|GreaterEqual|LessEqual} SV can only be
  2387. // used by PSOut.
  2388. // According to Vulkan spec, the FragDepth BuiltIn can only be used by PSOut.
  2389. case hlsl::Semantic::Kind::Depth:
  2390. case hlsl::Semantic::Kind::DepthGreaterEqual:
  2391. case hlsl::Semantic::Kind::DepthLessEqual: {
  2392. stageVar->setIsSpirvBuiltin();
  2393. // Vulkan requires the DepthReplacing execution mode to write to FragDepth.
  2394. spvBuilder.addExecutionMode(entryFunction,
  2395. spv::ExecutionMode::DepthReplacing, {}, srcLoc);
  2396. if (semanticKind == hlsl::Semantic::Kind::DepthGreaterEqual)
  2397. spvBuilder.addExecutionMode(entryFunction,
  2398. spv::ExecutionMode::DepthGreater, {}, srcLoc);
  2399. else if (semanticKind == hlsl::Semantic::Kind::DepthLessEqual)
  2400. spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::DepthLess,
  2401. {}, srcLoc);
  2402. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth,
  2403. isPrecise, srcLoc);
  2404. }
  2405. // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
  2406. // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn, MSIn, MSPOut, ASIn.
  2407. // According to Vulkan spec, the ClipDistance/CullDistance
  2408. // BuiltIn can only be used by VSOut, HS/DS/GS In/Out, MSOut.
  2409. case hlsl::Semantic::Kind::ClipDistance:
  2410. case hlsl::Semantic::Kind::CullDistance: {
  2411. switch (sigPointKind) {
  2412. case hlsl::SigPoint::Kind::VSIn:
  2413. case hlsl::SigPoint::Kind::PCOut:
  2414. case hlsl::SigPoint::Kind::DSIn:
  2415. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2416. case hlsl::SigPoint::Kind::VSOut:
  2417. case hlsl::SigPoint::Kind::HSCPIn:
  2418. case hlsl::SigPoint::Kind::HSCPOut:
  2419. case hlsl::SigPoint::Kind::DSCPIn:
  2420. case hlsl::SigPoint::Kind::DSOut:
  2421. case hlsl::SigPoint::Kind::GSVIn:
  2422. case hlsl::SigPoint::Kind::GSOut:
  2423. case hlsl::SigPoint::Kind::PSIn:
  2424. case hlsl::SigPoint::Kind::MSOut:
  2425. llvm_unreachable("should be handled in gl_PerVertex struct");
  2426. default:
  2427. llvm_unreachable(
  2428. "invalid usage of SV_ClipDistance/SV_CullDistance sneaked in");
  2429. }
  2430. }
  2431. // According to DXIL spec, the IsFrontFace SV can only be used by GSOut and
  2432. // PSIn.
  2433. // According to Vulkan spec, the FrontFacing BuitIn can only be used in PSIn.
  2434. case hlsl::Semantic::Kind::IsFrontFace: {
  2435. switch (sigPointKind) {
  2436. case hlsl::SigPoint::Kind::GSOut:
  2437. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2438. case hlsl::SigPoint::Kind::PSIn:
  2439. stageVar->setIsSpirvBuiltin();
  2440. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FrontFacing,
  2441. isPrecise, srcLoc);
  2442. default:
  2443. llvm_unreachable("invalid usage of SV_IsFrontFace sneaked in");
  2444. }
  2445. }
  2446. // According to DXIL spec, the Target SV can only be used by PSOut.
  2447. // There is no corresponding builtin decoration in SPIR-V. So generate normal
  2448. // Vulkan stage input/output variables.
  2449. case hlsl::Semantic::Kind::Target:
  2450. // An arbitrary semantic is defined by users. Generate normal Vulkan stage
  2451. // input/output variables.
  2452. case hlsl::Semantic::Kind::Arbitrary: {
  2453. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2454. // TODO: patch constant function in hull shader
  2455. }
  2456. // According to DXIL spec, the DispatchThreadID SV can only be used by CSIn.
  2457. // According to Vulkan spec, the GlobalInvocationId can only be used in CSIn.
  2458. case hlsl::Semantic::Kind::DispatchThreadID: {
  2459. stageVar->setIsSpirvBuiltin();
  2460. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::GlobalInvocationId,
  2461. isPrecise, srcLoc);
  2462. }
  2463. // According to DXIL spec, the GroupID SV can only be used by CSIn.
  2464. // According to Vulkan spec, the WorkgroupId can only be used in CSIn.
  2465. case hlsl::Semantic::Kind::GroupID: {
  2466. stageVar->setIsSpirvBuiltin();
  2467. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::WorkgroupId,
  2468. isPrecise, srcLoc);
  2469. }
  2470. // According to DXIL spec, the GroupThreadID SV can only be used by CSIn.
  2471. // According to Vulkan spec, the LocalInvocationId can only be used in CSIn.
  2472. case hlsl::Semantic::Kind::GroupThreadID: {
  2473. stageVar->setIsSpirvBuiltin();
  2474. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::LocalInvocationId,
  2475. isPrecise, srcLoc);
  2476. }
  2477. // According to DXIL spec, the GroupIndex SV can only be used by CSIn.
  2478. // According to Vulkan spec, the LocalInvocationIndex can only be used in
  2479. // CSIn.
  2480. case hlsl::Semantic::Kind::GroupIndex: {
  2481. stageVar->setIsSpirvBuiltin();
  2482. return spvBuilder.addStageBuiltinVar(
  2483. type, sc, BuiltIn::LocalInvocationIndex, isPrecise, srcLoc);
  2484. }
  2485. // According to DXIL spec, the OutputControlID SV can only be used by HSIn.
  2486. // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  2487. // HS/GS In.
  2488. case hlsl::Semantic::Kind::OutputControlPointID: {
  2489. stageVar->setIsSpirvBuiltin();
  2490. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
  2491. isPrecise, srcLoc);
  2492. }
  2493. // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
  2494. // DSIn, GSIn, GSOut, PSIn, and MSPOut.
  2495. // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
  2496. // HS/DS/PS In, GS In/Out, MSPOut.
  2497. case hlsl::Semantic::Kind::PrimitiveID: {
  2498. // Translate to PrimitiveId BuiltIn for all valid SigPoints.
  2499. stageVar->setIsSpirvBuiltin();
  2500. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId,
  2501. isPrecise, srcLoc);
  2502. }
  2503. // According to DXIL spec, the TessFactor SV can only be used by PCOut and
  2504. // DSIn.
  2505. // According to Vulkan spec, the TessLevelOuter BuiltIn can only be used in
  2506. // PCOut and DSIn.
  2507. case hlsl::Semantic::Kind::TessFactor: {
  2508. stageVar->setIsSpirvBuiltin();
  2509. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelOuter,
  2510. isPrecise, srcLoc);
  2511. }
  2512. // According to DXIL spec, the InsideTessFactor SV can only be used by PCOut
  2513. // and DSIn.
  2514. // According to Vulkan spec, the TessLevelInner BuiltIn can only be used in
  2515. // PCOut and DSIn.
  2516. case hlsl::Semantic::Kind::InsideTessFactor: {
  2517. stageVar->setIsSpirvBuiltin();
  2518. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelInner,
  2519. isPrecise, srcLoc);
  2520. }
  2521. // According to DXIL spec, the DomainLocation SV can only be used by DSIn.
  2522. // According to Vulkan spec, the TessCoord BuiltIn can only be used in DSIn.
  2523. case hlsl::Semantic::Kind::DomainLocation: {
  2524. stageVar->setIsSpirvBuiltin();
  2525. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord,
  2526. isPrecise, srcLoc);
  2527. }
  2528. // According to DXIL spec, the GSInstanceID SV can only be used by GSIn.
  2529. // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  2530. // HS/GS In.
  2531. case hlsl::Semantic::Kind::GSInstanceID: {
  2532. stageVar->setIsSpirvBuiltin();
  2533. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId,
  2534. isPrecise, srcLoc);
  2535. }
  2536. // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
  2537. // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
  2538. case hlsl::Semantic::Kind::SampleIndex: {
  2539. stageVar->setIsSpirvBuiltin();
  2540. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId, isPrecise,
  2541. srcLoc);
  2542. }
  2543. // According to DXIL spec, the StencilRef SV can only be used by PSOut.
  2544. case hlsl::Semantic::Kind::StencilRef: {
  2545. stageVar->setIsSpirvBuiltin();
  2546. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT,
  2547. isPrecise, srcLoc);
  2548. }
  2549. // According to DXIL spec, the Barycentrics SV can only be used by PSIn.
  2550. case hlsl::Semantic::Kind::Barycentrics: {
  2551. stageVar->setIsSpirvBuiltin();
  2552. // Selecting the correct builtin according to interpolation mode
  2553. auto bi = BuiltIn::Max;
  2554. if (decl->hasAttr<HLSLNoPerspectiveAttr>()) {
  2555. if (decl->hasAttr<HLSLCentroidAttr>()) {
  2556. bi = BuiltIn::BaryCoordNoPerspCentroidAMD;
  2557. } else if (decl->hasAttr<HLSLSampleAttr>()) {
  2558. bi = BuiltIn::BaryCoordNoPerspSampleAMD;
  2559. } else {
  2560. bi = BuiltIn::BaryCoordNoPerspAMD;
  2561. }
  2562. } else {
  2563. if (decl->hasAttr<HLSLCentroidAttr>()) {
  2564. bi = BuiltIn::BaryCoordSmoothCentroidAMD;
  2565. } else if (decl->hasAttr<HLSLSampleAttr>()) {
  2566. bi = BuiltIn::BaryCoordSmoothSampleAMD;
  2567. } else {
  2568. bi = BuiltIn::BaryCoordSmoothAMD;
  2569. }
  2570. }
  2571. return spvBuilder.addStageBuiltinVar(type, sc, bi, isPrecise, srcLoc);
  2572. }
  2573. // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
  2574. // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
  2575. // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut
  2576. // PSIn, and MSPOut.
  2577. case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
  2578. switch (sigPointKind) {
  2579. case hlsl::SigPoint::Kind::VSIn:
  2580. case hlsl::SigPoint::Kind::HSCPIn:
  2581. case hlsl::SigPoint::Kind::HSCPOut:
  2582. case hlsl::SigPoint::Kind::PCOut:
  2583. case hlsl::SigPoint::Kind::DSIn:
  2584. case hlsl::SigPoint::Kind::DSCPIn:
  2585. case hlsl::SigPoint::Kind::GSVIn:
  2586. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2587. case hlsl::SigPoint::Kind::VSOut:
  2588. case hlsl::SigPoint::Kind::DSOut:
  2589. stageVar->setIsSpirvBuiltin();
  2590. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
  2591. srcLoc);
  2592. case hlsl::SigPoint::Kind::GSOut:
  2593. case hlsl::SigPoint::Kind::PSIn:
  2594. case hlsl::SigPoint::Kind::MSPOut:
  2595. stageVar->setIsSpirvBuiltin();
  2596. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer, isPrecise,
  2597. srcLoc);
  2598. default:
  2599. llvm_unreachable("invalid usage of SV_RenderTargetArrayIndex sneaked in");
  2600. }
  2601. }
  2602. // According to DXIL spec, the ViewportArrayIndex SV can only be used by
  2603. // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn, MSPOut.
  2604. // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
  2605. // GSOut, PSIn, and MSPOut.
  2606. case hlsl::Semantic::Kind::ViewPortArrayIndex: {
  2607. switch (sigPointKind) {
  2608. case hlsl::SigPoint::Kind::VSIn:
  2609. case hlsl::SigPoint::Kind::HSCPIn:
  2610. case hlsl::SigPoint::Kind::HSCPOut:
  2611. case hlsl::SigPoint::Kind::PCOut:
  2612. case hlsl::SigPoint::Kind::DSIn:
  2613. case hlsl::SigPoint::Kind::DSCPIn:
  2614. case hlsl::SigPoint::Kind::GSVIn:
  2615. return spvBuilder.addStageIOVar(type, sc, name.str(), isPrecise, srcLoc);
  2616. case hlsl::SigPoint::Kind::VSOut:
  2617. case hlsl::SigPoint::Kind::DSOut:
  2618. stageVar->setIsSpirvBuiltin();
  2619. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
  2620. isPrecise, srcLoc);
  2621. case hlsl::SigPoint::Kind::GSOut:
  2622. case hlsl::SigPoint::Kind::PSIn:
  2623. case hlsl::SigPoint::Kind::MSPOut:
  2624. stageVar->setIsSpirvBuiltin();
  2625. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex,
  2626. isPrecise, srcLoc);
  2627. default:
  2628. llvm_unreachable("invalid usage of SV_ViewportArrayIndex sneaked in");
  2629. }
  2630. }
  2631. // According to DXIL spec, the Coverage SV can only be used by PSIn and PSOut.
  2632. // According to Vulkan spec, the SampleMask BuiltIn can only be used in
  2633. // PSIn and PSOut.
  2634. case hlsl::Semantic::Kind::Coverage: {
  2635. stageVar->setIsSpirvBuiltin();
  2636. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask,
  2637. isPrecise, srcLoc);
  2638. }
  2639. // According to DXIL spec, the ViewID SV can only be used by VSIn, PCIn,
  2640. // HSIn, DSIn, GSIn, PSIn.
  2641. // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
  2642. // VS/HS/DS/GS/PS input.
  2643. case hlsl::Semantic::Kind::ViewID: {
  2644. stageVar->setIsSpirvBuiltin();
  2645. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex,
  2646. isPrecise, srcLoc);
  2647. }
  2648. // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
  2649. // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
  2650. // PSIn.
  2651. case hlsl::Semantic::Kind::InnerCoverage: {
  2652. stageVar->setIsSpirvBuiltin();
  2653. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT,
  2654. isPrecise, srcLoc);
  2655. }
  2656. // According to DXIL spec, the ShadingRate SV can only be used by GSOut,
  2657. // VSOut, or PSIn. According to Vulkan spec, the FragSizeEXT BuiltIn can only
  2658. // be used as PSIn.
  2659. case hlsl::Semantic::Kind::ShadingRate: {
  2660. switch (sigPointKind) {
  2661. case hlsl::SigPoint::Kind::PSIn:
  2662. stageVar->setIsSpirvBuiltin();
  2663. return spvBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragSizeEXT,
  2664. isPrecise, srcLoc);
  2665. default:
  2666. emitError("semantic ShadingRate currently unsupported in non-PS shader"
  2667. " stages",
  2668. srcLoc);
  2669. break;
  2670. }
  2671. break;
  2672. }
  2673. default:
  2674. emitError("semantic %0 unimplemented", srcLoc)
  2675. << stageVar->getSemanticStr();
  2676. break;
  2677. }
  2678. return 0;
  2679. }
  2680. bool DeclResultIdMapper::validateVKAttributes(const NamedDecl *decl) {
  2681. bool success = true;
  2682. if (const auto *idxAttr = decl->getAttr<VKIndexAttr>()) {
  2683. if (!spvContext.isPS()) {
  2684. emitError("vk::index only allowed in pixel shader",
  2685. idxAttr->getLocation());
  2686. success = false;
  2687. }
  2688. const auto *locAttr = decl->getAttr<VKLocationAttr>();
  2689. if (!locAttr) {
  2690. emitError("vk::index should be used together with vk::location for "
  2691. "dual-source blending",
  2692. idxAttr->getLocation());
  2693. success = false;
  2694. } else {
  2695. const auto locNumber = locAttr->getNumber();
  2696. if (locNumber != 0) {
  2697. emitError("dual-source blending should use vk::location 0",
  2698. locAttr->getLocation());
  2699. success = false;
  2700. }
  2701. }
  2702. const auto idxNumber = idxAttr->getNumber();
  2703. if (idxNumber != 0 && idxNumber != 1) {
  2704. emitError("dual-source blending only accepts 0 or 1 as vk::index",
  2705. idxAttr->getLocation());
  2706. success = false;
  2707. }
  2708. }
  2709. return success;
  2710. }
  2711. bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
  2712. const hlsl::SigPoint *sigPoint) {
  2713. bool success = true;
  2714. if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
  2715. // The front end parsing only allows vk::builtin to be attached to a
  2716. // function/parameter/variable; all of them are DeclaratorDecls.
  2717. const auto declType = getTypeOrFnRetType(cast<DeclaratorDecl>(decl));
  2718. const auto loc = builtinAttr->getLocation();
  2719. if (decl->hasAttr<VKLocationAttr>()) {
  2720. emitError("cannot use vk::builtin and vk::location together", loc);
  2721. success = false;
  2722. }
  2723. const llvm::StringRef builtin = builtinAttr->getBuiltIn();
  2724. if (builtin == "HelperInvocation") {
  2725. if (!declType->isBooleanType()) {
  2726. emitError("HelperInvocation builtin must be of boolean type", loc);
  2727. success = false;
  2728. }
  2729. if (sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) {
  2730. emitError(
  2731. "HelperInvocation builtin can only be used as pixel shader input",
  2732. loc);
  2733. success = false;
  2734. }
  2735. } else if (builtin == "PointSize") {
  2736. if (!declType->isFloatingType()) {
  2737. emitError("PointSize builtin must be of float type", loc);
  2738. success = false;
  2739. }
  2740. switch (sigPoint->GetKind()) {
  2741. case hlsl::SigPoint::Kind::VSOut:
  2742. case hlsl::SigPoint::Kind::HSCPIn:
  2743. case hlsl::SigPoint::Kind::HSCPOut:
  2744. case hlsl::SigPoint::Kind::DSCPIn:
  2745. case hlsl::SigPoint::Kind::DSOut:
  2746. case hlsl::SigPoint::Kind::GSVIn:
  2747. case hlsl::SigPoint::Kind::GSOut:
  2748. case hlsl::SigPoint::Kind::PSIn:
  2749. case hlsl::SigPoint::Kind::MSOut:
  2750. break;
  2751. default:
  2752. emitError("PointSize builtin cannot be used as %0", loc)
  2753. << sigPoint->GetName();
  2754. success = false;
  2755. }
  2756. } else if (builtin == "BaseVertex" || builtin == "BaseInstance" ||
  2757. builtin == "DrawIndex") {
  2758. if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
  2759. !declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)) {
  2760. emitError("%0 builtin must be of 32-bit scalar integer type", loc)
  2761. << builtin;
  2762. success = false;
  2763. }
  2764. switch (sigPoint->GetKind()) {
  2765. case hlsl::SigPoint::Kind::VSIn:
  2766. break;
  2767. case hlsl::SigPoint::Kind::MSIn:
  2768. case hlsl::SigPoint::Kind::ASIn:
  2769. if (builtin != "DrawIndex") {
  2770. emitError("%0 builtin cannot be used as %1", loc)
  2771. << builtin << sigPoint->GetName();
  2772. success = false;
  2773. }
  2774. break;
  2775. default:
  2776. emitError("%0 builtin cannot be used as %1", loc)
  2777. << builtin << sigPoint->GetName();
  2778. success = false;
  2779. }
  2780. } else if (builtin == "DeviceIndex") {
  2781. if (getStorageClassForSigPoint(sigPoint) != spv::StorageClass::Input) {
  2782. emitError("%0 builtin can only be used as shader input", loc)
  2783. << builtin;
  2784. success = false;
  2785. }
  2786. if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
  2787. !declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)) {
  2788. emitError("%0 builtin must be of 32-bit scalar integer type", loc)
  2789. << builtin;
  2790. success = false;
  2791. }
  2792. } else if (builtin == "ViewportMaskNV") {
  2793. if (sigPoint->GetKind() != hlsl::SigPoint::Kind::MSPOut) {
  2794. emitError("%0 builtin can only be used as 'primitives' output in MS",
  2795. loc)
  2796. << builtin;
  2797. success = false;
  2798. }
  2799. if (!declType->isArrayType() ||
  2800. !declType->getArrayElementTypeNoTypeQual()->isSpecificBuiltinType(
  2801. BuiltinType::Kind::Int)) {
  2802. emitError("%0 builtin must be of type array of integers", loc)
  2803. << builtin;
  2804. success = false;
  2805. }
  2806. }
  2807. }
  2808. return success;
  2809. }
  2810. spv::StorageClass
  2811. DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
  2812. // This translation is done based on the HLSL reference (see docs/dxil.rst).
  2813. const auto sigPointKind = sigPoint->GetKind();
  2814. const auto signatureKind = sigPoint->GetSignatureKind();
  2815. spv::StorageClass sc = spv::StorageClass::Max;
  2816. switch (signatureKind) {
  2817. case hlsl::DXIL::SignatureKind::Input:
  2818. sc = spv::StorageClass::Input;
  2819. break;
  2820. case hlsl::DXIL::SignatureKind::Output:
  2821. sc = spv::StorageClass::Output;
  2822. break;
  2823. case hlsl::DXIL::SignatureKind::Invalid: {
  2824. // There are some special cases in HLSL (See docs/dxil.rst):
  2825. // SignatureKind is "invalid" for PCIn, HSIn, GSIn, and CSIn.
  2826. switch (sigPointKind) {
  2827. case hlsl::DXIL::SigPointKind::PCIn:
  2828. case hlsl::DXIL::SigPointKind::HSIn:
  2829. case hlsl::DXIL::SigPointKind::GSIn:
  2830. case hlsl::DXIL::SigPointKind::CSIn:
  2831. case hlsl::DXIL::SigPointKind::MSIn:
  2832. case hlsl::DXIL::SigPointKind::ASIn:
  2833. sc = spv::StorageClass::Input;
  2834. break;
  2835. default:
  2836. llvm_unreachable("Found invalid SigPoint kind for semantic");
  2837. }
  2838. break;
  2839. }
  2840. case hlsl::DXIL::SignatureKind::PatchConstOrPrim: {
  2841. // There are some special cases in HLSL (See docs/dxil.rst):
  2842. // SignatureKind is "PatchConstOrPrim" for PCOut, MSPOut and DSIn.
  2843. switch (sigPointKind) {
  2844. case hlsl::DXIL::SigPointKind::PCOut:
  2845. case hlsl::DXIL::SigPointKind::MSPOut:
  2846. // Patch Constant Output (Output of Hull which is passed to Domain).
  2847. // Mesh Shader per-primitive output attributes.
  2848. sc = spv::StorageClass::Output;
  2849. break;
  2850. case hlsl::DXIL::SigPointKind::DSIn:
  2851. // Domain Shader regular input - Patch Constant data plus system values.
  2852. sc = spv::StorageClass::Input;
  2853. break;
  2854. default:
  2855. llvm_unreachable("Found invalid SigPoint kind for semantic");
  2856. }
  2857. break;
  2858. }
  2859. default:
  2860. llvm_unreachable("Found invalid SigPoint kind for semantic");
  2861. }
  2862. return sc;
  2863. }
  2864. QualType DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
  2865. const DeclaratorDecl *decl, bool *shouldBeAlias) {
  2866. if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
  2867. // This method is only intended to be used to create SPIR-V variables in the
  2868. // Function or Private storage class.
  2869. assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember());
  2870. }
  2871. const QualType type = getTypeOrFnRetType(decl);
  2872. // Whether we should generate this decl as an alias variable.
  2873. bool genAlias = false;
  2874. if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
  2875. // For ConstantBuffer and TextureBuffer
  2876. if (buffer->isConstantBufferView())
  2877. genAlias = true;
  2878. } else if (isOrContainsAKindOfStructuredOrByteBuffer(type)) {
  2879. genAlias = true;
  2880. }
  2881. // Return via parameter whether alias was generated.
  2882. if (shouldBeAlias)
  2883. *shouldBeAlias = genAlias;
  2884. if (genAlias) {
  2885. needsLegalization = true;
  2886. createCounterVarForDecl(decl);
  2887. }
  2888. return type;
  2889. }
  2890. SpirvVariable *
  2891. DeclResultIdMapper::createRayTracingNVStageVar(spv::StorageClass sc,
  2892. const VarDecl *decl) {
  2893. QualType type = decl->getType();
  2894. SpirvVariable *retVal = nullptr;
  2895. // Raytracing interface variables are special since they do not participate
  2896. // in any interface matching and hence do not create StageVar and
  2897. // track them under StageVars vector
  2898. const auto name = decl->getName();
  2899. switch (sc) {
  2900. case spv::StorageClass::IncomingRayPayloadNV:
  2901. case spv::StorageClass::IncomingCallableDataNV:
  2902. case spv::StorageClass::HitAttributeNV:
  2903. case spv::StorageClass::RayPayloadNV:
  2904. case spv::StorageClass::CallableDataNV:
  2905. retVal = spvBuilder.addModuleVar(type, sc, decl->hasAttr<HLSLPreciseAttr>(),
  2906. name.str());
  2907. break;
  2908. default:
  2909. assert(false && "Unsupported SPIR-V storage class for raytracing");
  2910. }
  2911. return retVal;
  2912. }
  2913. void DeclResultIdMapper::createRayTracingNVImplicitVar(const VarDecl *varDecl) {
  2914. APValue *val = varDecl->evaluateValue();
  2915. assert(val);
  2916. SpirvInstruction *constVal =
  2917. spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
  2918. constVal->setRValue(true);
  2919. astDecls[varDecl].instr = constVal;
  2920. }
  2921. } // end namespace spirv
  2922. } // end namespace clang