//===--- DeclResultIdMapper.cpp - DeclResultIdMapper impl --------*- C++ -*-==//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "DeclResultIdMapper.h"

#include <algorithm>
#include <cstring>
#include <sstream>
#include <unordered_map>

#include "dxc/HLSL/DxilConstants.h"
#include "dxc/HLSL/DxilTypeSystem.h"
#include "clang/AST/Expr.h"
#include "clang/AST/HlslTypes.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringSet.h"

namespace clang {
namespace spirv {
namespace {

/// \brief Returns the stage variable's register assignment for the given Decl.
const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
  for (auto *annotation : decl->getUnusualAnnotations()) {
    if (auto *reg = dyn_cast<hlsl::RegisterAssignment>(annotation)) {
      return reg;
    }
  }
  return nullptr;
}

/// \brief Returns true if the given declaration has a primitive type qualifier.
/// Returns false otherwise.
inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
  return decl->hasAttr<HLSLTriangleAttr>() ||
         decl->hasAttr<HLSLTriangleAdjAttr>() ||
         decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
         decl->hasAttr<HLSLLineAdjAttr>();
}

/// \brief Deduces the parameter qualifier for the given decl.
hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
                                         bool asInput) {
  const auto type = decl->getType();

  if (hlsl::IsHLSLInputPatchType(type))
    return hlsl::DxilParamInputQual::InputPatch;
  if (hlsl::IsHLSLOutputPatchType(type))
    return hlsl::DxilParamInputQual::OutputPatch;
  // TODO: Add support for multiple output streams.
  if (hlsl::IsHLSLStreamOutputType(type))
    return hlsl::DxilParamInputQual::OutStream0;

  // The inputs to the geometry shader that have a primitive type qualifier
  // must use 'InputPrimitive'.
  if (hasGSPrimitiveTypeQualifier(decl))
    return hlsl::DxilParamInputQual::InputPrimitive;

  return asInput ? hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out;
}

/// \brief Deduces the HLSL SigPoint for the given decl appearing in the given
/// shader model.
const hlsl::SigPoint *deduceSigPoint(const DeclaratorDecl *decl, bool asInput,
                                     const hlsl::ShaderModel::Kind kind,
                                     bool forPCF) {
  return hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
      deduceParamQual(decl, asInput), kind, forPCF));
}

/// Returns the type of the given decl. If the given decl is a FunctionDecl,
/// returns its result type.
inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
  if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
    return funcDecl->getReturnType();
  }
  return decl->getType();
}

/// Returns the number of base classes if this type is a derived class/struct.
/// Returns zero otherwise.
inline uint32_t getNumBaseClasses(QualType type) {
  if (const auto *cxxDecl = type->getAsCXXRecordDecl())
    return cxxDecl->getNumBases();
  return 0;
}

} // anonymous namespace

std::string StageVar::getSemanticStr() const {
  // A special case for zero index, which is equivalent to no index.
  // Use what is in the source code.
  // TODO: this looks like a hack to make the current tests happy.
  // Should consider removing it and fixing all tests.
  if (semanticIndex == 0)
    return semanticStr;

  std::ostringstream ss;
  ss << semanticName.str() << semanticIndex;
  return ss.str();
}
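
// Returns the <result-id> to use for accessing this counter. When isAlias is
// true, resultId names an alias variable that holds a pointer to the real
// counter (an ACSBufferCounter in the Uniform storage class), so an OpLoad is
// emitted to fetch that pointer first; otherwise resultId already names the
// counter variable itself.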
uint32_t CounterIdAliasPair::get(ModuleBuilder &builder,
                                 TypeTranslator &translator) const {
  if (isAlias) {
    const uint32_t counterVarType = builder.getPointerType(
        translator.getACSBufferCounter(), spv::StorageClass::Uniform);
    return builder.createLoad(counterVarType, resultId);
  }
  return resultId;
}

const CounterIdAliasPair *
CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
  for (const auto &field : fields)
    if (field.indices == indices)
      return &field.counterVar;
  return nullptr;
}

bool CounterVarFields::assign(const CounterVarFields &srcFields,
                              ModuleBuilder &builder,
                              TypeTranslator &translator) const {
  for (const auto &field : fields) {
    const auto *srcField = srcFields.get(field.indices);
    if (!srcField)
      return false;

    field.counterVar.assign(*srcField, builder, translator);
  }

  return true;
}

bool CounterVarFields::assign(const CounterVarFields &srcFields,
                              const llvm::SmallVector<uint32_t, 4> &dstPrefix,
                              const llvm::SmallVector<uint32_t, 4> &srcPrefix,
                              ModuleBuilder &builder,
                              TypeTranslator &translator) const {
  if (dstPrefix.empty() && srcPrefix.empty())
    return assign(srcFields, builder, translator);

  llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;

  // If whole has the given prefix, appends all elements after the prefix in
  // whole to srcIndices.
  const auto applyDiff =
      [&srcIndices](const llvm::SmallVector<uint32_t, 4> &whole,
                    const llvm::SmallVector<uint32_t, 4> &prefix) -> bool {
    uint32_t i = 0;
    for (; i < prefix.size(); ++i)
      if (whole[i] != prefix[i]) {
        break;
      }
    if (i == prefix.size()) {
      for (; i < whole.size(); ++i)
        srcIndices.push_back(whole[i]);
      return true;
    }
    return false;
  };
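  // For example, with whole == {1, 2, 3} and prefix == {1, 2}, applyDiff
  // appends 3 to srcIndices and returns true; with prefix == {0}, it leaves
  // srcIndices untouched and returns false.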

  for (const auto &field : fields)
    if (applyDiff(field.indices, dstPrefix)) {
      const auto *srcField = srcFields.get(srcIndices);
      if (!srcField)
        return false;

      field.counterVar.assign(*srcField, builder, translator);

      for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
        srcIndices.pop_back();
    }

  return true;
}

DeclResultIdMapper::SemanticInfo
DeclResultIdMapper::getStageVarSemantic(const NamedDecl *decl) {
  for (auto *annotation : decl->getUnusualAnnotations()) {
    if (auto *sema = dyn_cast<hlsl::SemanticDecl>(annotation)) {
      llvm::StringRef semanticStr = sema->SemanticName;
      llvm::StringRef semanticName;
      uint32_t index = 0;
      hlsl::Semantic::DecomposeNameAndIndex(semanticStr, &semanticName, &index);
      const auto *semantic = hlsl::Semantic::GetByName(semanticName);
      return {semanticStr, semantic, semanticName, index, sema->Loc};
    }
  }
  return {};
}

bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
                                              uint32_t storedValue,
                                              bool forPCF) {
  QualType type = getTypeOrFnRetType(decl);

  // Output stream types (PointStream, LineStream, TriangleStream) are
  // translated as their underlying struct types.
  if (hlsl::IsHLSLStreamOutputType(type))
    type = hlsl::GetHLSLResourceResultType(type);

  const auto *sigPoint =
      deduceSigPoint(decl, /*asInput=*/false, shaderModel.GetKind(), forPCF);

  // HS output variables are created using the other overload. For the rest,
  // none of them should be created as arrays.
  assert(sigPoint->GetKind() != hlsl::DXIL::SigPointKind::HSCPOut);

  SemanticInfo inheritSemantic = {};

  return createStageVars(sigPoint, decl, /*asInput=*/false, type,
                         /*arraySize=*/0, "out.var", llvm::None, &storedValue,
                         // Write back of stage output variables in GS is
                         // manually controlled by the .Append() intrinsic
                         // method, implemented in writeBackOutputStream(). So
                         // noWriteBack should be set to true for GS.
                         shaderModel.IsGS(), &inheritSemantic);
}

bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
                                              uint32_t arraySize,
                                              uint32_t invocationId,
                                              uint32_t storedValue) {
  assert(shaderModel.IsHS());

  QualType type = getTypeOrFnRetType(decl);

  const auto *sigPoint =
      hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::HSCPOut);

  SemanticInfo inheritSemantic = {};

  return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
                         "out.var", invocationId, &storedValue,
                         /*noWriteBack=*/false, &inheritSemantic);
}

bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
                                             uint32_t *loadedValue,
                                             bool forPCF) {
  uint32_t arraySize = 0;
  QualType type = paramDecl->getType();

  // Strip off the outermost arrayness for HS/DS/GS and use arraySize
  // to convey that information.
  if (hlsl::IsHLSLInputPatchType(type)) {
    arraySize = hlsl::GetHLSLInputPatchCount(type);
    type = hlsl::GetHLSLInputPatchElementType(type);
  } else if (hlsl::IsHLSLOutputPatchType(type)) {
    arraySize = hlsl::GetHLSLOutputPatchCount(type);
    type = hlsl::GetHLSLOutputPatchElementType(type);
  }
  if (hasGSPrimitiveTypeQualifier(paramDecl)) {
    const auto *typeDecl = astContext.getAsConstantArrayType(type);
    arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
    type = typeDecl->getElementType();
  }

  const auto *sigPoint = deduceSigPoint(paramDecl, /*asInput=*/true,
                                        shaderModel.GetKind(), forPCF);

  SemanticInfo inheritSemantic = {};

  return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, arraySize,
                         "in.var", llvm::None, loadedValue,
                         /*noWriteBack=*/false, &inheritSemantic);
}

const DeclResultIdMapper::DeclSpirvInfo *
DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
  auto it = astDecls.find(decl);
  if (it != astDecls.end())
    return &it->second;

  return nullptr;
}

SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
                                                  bool checkRegistered) {
  if (const auto *info = getDeclSpirvInfo(decl))
    if (info->indexInCTBuffer >= 0) {
      // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
      // OpAccessChain to get the pointer to the variable since we created
      // a single variable for the whole buffer object.
      const uint32_t varType = typeTranslator.translateType(
          // Should only have VarDecls in a HLSLBufferDecl.
          cast<VarDecl>(decl)->getType(),
          // We need to set decorateLayout here to avoid creating SPIR-V
          // instructions for the current type without decorations.
          info->info.getLayoutRule());

      const uint32_t elemId = theBuilder.createAccessChain(
          theBuilder.getPointerType(varType, info->info.getStorageClass()),
          info->info, {theBuilder.getConstantInt32(info->indexInCTBuffer)});

      return info->info.substResultId(elemId);
    } else {
      return *info;
    }

  if (checkRegistered) {
    emitFatalError("found unregistered decl", decl->getLocation())
        << decl->getName();
    emitNote("please file a bug report on "
             "https://github.com/Microsoft/DirectXShaderCompiler/issues with "
             "source code if possible",
             {});
  }

  return 0;
}

uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
  bool isAlias = false;
  auto &info = astDecls[param].info;
  const uint32_t type =
      getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias, &info);
  const uint32_t ptrType =
      theBuilder.getPointerType(type, spv::StorageClass::Function);
  const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
  info.setResultId(id);

  return id;
}

void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
  const QualType declType = getTypeOrFnRetType(decl);

  if (!counterVars.count(decl) &&
      TypeTranslator::isRWAppendConsumeSBuffer(declType)) {
    createCounterVar(decl, /*declId=*/0, /*isAlias=*/true);
  } else if (!fieldCounterVars.count(decl) && declType->isStructureType() &&
             // Exclude other resource types which are represented as structs
             !hlsl::IsHLSLResourceType(declType)) {
    createFieldCounterVars(decl);
  }
}

SpirvEvalInfo DeclResultIdMapper::createFnVar(const VarDecl *var,
                                              llvm::Optional<uint32_t> init) {
  bool isAlias = false;
  auto &info = astDecls[var].info;
  const uint32_t type =
      getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
  const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
  info.setResultId(id);

  return info;
}

SpirvEvalInfo DeclResultIdMapper::createFileVar(const VarDecl *var,
                                                llvm::Optional<uint32_t> init) {
  bool isAlias = false;
  auto &info = astDecls[var].info;
  const uint32_t type =
      getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
  const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
                                              var->getName(), init);
  info.setResultId(id).setStorageClass(spv::StorageClass::Private);

  return info;
}

SpirvEvalInfo DeclResultIdMapper::createExternVar(const VarDecl *var) {
  auto storageClass = spv::StorageClass::UniformConstant;
  auto rule = LayoutRule::Void;
  bool isACRWSBuffer = false; // Whether is {Append|Consume|RW}StructuredBuffer

  if (var->getAttr<HLSLGroupSharedAttr>()) {
    // For CS groupshared variables
    storageClass = spv::StorageClass::Workgroup;
  } else if (TypeTranslator::isResourceType(var)) {
    // See through the possible outer arrays
    QualType resourceType = var->getType();
    while (resourceType->isArrayType()) {
      resourceType = resourceType->getAsArrayTypeUnsafe()->getElementType();
    }

    const llvm::StringRef typeName =
        resourceType->getAs<RecordType>()->getDecl()->getName();

    // These types are all translated into OpTypeStruct with BufferBlock
    // decoration. They should follow standard storage buffer layout,
    // which GLSL std430 rules satisfy.
    if (typeName == "StructuredBuffer" || typeName == "ByteAddressBuffer" ||
        typeName == "RWByteAddressBuffer") {
      storageClass = spv::StorageClass::Uniform;
      rule = spirvOptions.sBufferLayoutRule;
    } else if (typeName == "RWStructuredBuffer" ||
               typeName == "AppendStructuredBuffer" ||
               typeName == "ConsumeStructuredBuffer") {
      storageClass = spv::StorageClass::Uniform;
      rule = spirvOptions.sBufferLayoutRule;
      isACRWSBuffer = true;
    }
  } else {
    // This is a stand-alone externally-visible non-resource-type variable.
    // They should be grouped into the $Globals cbuffer. We create that cbuffer
    // and record all variables inside it upon seeing the first such variable.
    if (astDecls.count(var) == 0)
      createGlobalsCBuffer(var);

    return astDecls[var].info;
  }

  uint32_t varType = typeTranslator.translateType(var->getType(), rule);
  const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
                                              var->getName(), llvm::None);
  const auto info =
      SpirvEvalInfo(id).setStorageClass(storageClass).setLayoutRule(rule);
  astDecls[var] = info;

  // Variables in Workgroup do not need descriptor decorations.
  if (storageClass == spv::StorageClass::Workgroup)
    return info;

  const auto *regAttr = getResourceBinding(var);
  const auto *bindingAttr = var->getAttr<VKBindingAttr>();
  const auto *counterBindingAttr = var->getAttr<VKCounterBindingAttr>();

  resourceVars.emplace_back(id, regAttr, bindingAttr, counterBindingAttr);

  if (const auto *inputAttachment = var->getAttr<VKInputAttachmentIndexAttr>())
    theBuilder.decorateInputAttachmentIndex(id, inputAttachment->getIndex());

  if (isACRWSBuffer) {
    // For {Append|Consume|RW}StructuredBuffer, we need to always create
    // another variable for its associated counter.
    createCounterVar(var, id, /*isAlias=*/false);
  }

  return info;
}

uint32_t DeclResultIdMapper::getMatrixStructType(const VarDecl *matVar,
                                                 spv::StorageClass sc,
                                                 LayoutRule rule) {
  const auto matType = matVar->getType();
  assert(TypeTranslator::isMxNMatrix(matType));

  auto &context = *theBuilder.getSPIRVContext();
  llvm::SmallVector<const Decoration *, 4> decorations;
  const bool isRowMajor = typeTranslator.isRowMajorMatrix(matType);

  uint32_t stride;
  (void)typeTranslator.getAlignmentAndSize(matType, rule, &stride);
  decorations.push_back(Decoration::getOffset(context, 0, 0));
  decorations.push_back(Decoration::getMatrixStride(context, stride, 0));
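  // Note that the majorness decoration below is deliberately the opposite of
  // isRowMajor: the backend maps what HLSL calls a row onto a SPIR-V column,
  // so an HLSL row-major matrix is decorated ColMajor (and vice versa).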
  decorations.push_back(isRowMajor ? Decoration::getColMajor(context, 0)
                                   : Decoration::getRowMajor(context, 0));
  decorations.push_back(Decoration::getBlock(context));

  // Get the type for the wrapping struct
  const std::string structName = "type." + matVar->getName().str();
  return theBuilder.getStructType({typeTranslator.translateType(matType)},
                                  structName, {}, decorations);
}

uint32_t DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
    const DeclContext *decl, uint32_t arraySize,
    const ContextUsageKind usageKind, llvm::StringRef typeName,
    llvm::StringRef varName) {
  // cbuffers are translated into OpTypeStruct with Block decoration.
  // tbuffers are translated into OpTypeStruct with BufferBlock decoration.
  // Push constants are translated into OpTypeStruct with Block decoration.
  //
  // Both cbuffers and tbuffers have the SPIR-V Uniform storage class.
  // Push constants have the SPIR-V PushConstant storage class.

  const bool forCBuffer = usageKind == ContextUsageKind::CBuffer;
  const bool forTBuffer = usageKind == ContextUsageKind::TBuffer;
  const bool forGlobals = usageKind == ContextUsageKind::Globals;

  auto &context = *theBuilder.getSPIRVContext();
  const LayoutRule layoutRule =
      (forCBuffer || forGlobals)
          ? spirvOptions.cBufferLayoutRule
          : (forTBuffer ? spirvOptions.tBufferLayoutRule
                        : spirvOptions.sBufferLayoutRule);
  const auto *blockDec = forTBuffer ? Decoration::getBufferBlock(context)
                                    : Decoration::getBlock(context);

  const llvm::SmallVector<const Decl *, 4> &declGroup =
      typeTranslator.collectDeclsInDeclContext(decl);
  auto decorations = typeTranslator.getLayoutDecorations(declGroup, layoutRule);
  decorations.push_back(blockDec);

  // Collect the type and name for each field
  llvm::SmallVector<uint32_t, 4> fieldTypes;
  llvm::SmallVector<llvm::StringRef, 4> fieldNames;
  uint32_t fieldIndex = 0;
  for (const auto *subDecl : declGroup) {
    // The field can only be FieldDecl (for normal structs) or VarDecl (for
    // HLSLBufferDecls).
    assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
    const auto *declDecl = cast<DeclaratorDecl>(subDecl);

    // All fields are qualified with const. The const qualifier would show up
    // in the debug name, and we don't need it there.
    auto varType = declDecl->getType();
    varType.removeLocalConst();

    fieldTypes.push_back(typeTranslator.translateType(varType, layoutRule));
    fieldNames.push_back(declDecl->getName());

    // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate
    // NonWritable must be applied to all fields.
    if (forTBuffer) {
      decorations.push_back(Decoration::getNonWritable(
          *theBuilder.getSPIRVContext(), fieldIndex));
    }
    ++fieldIndex;
  }

  // Get the type for the whole struct
  uint32_t resultType =
      theBuilder.getStructType(fieldTypes, typeName, fieldNames, decorations);

  // Make an array if requested.
  if (arraySize)
    resultType = theBuilder.getArrayType(
        resultType, theBuilder.getConstantUint32(arraySize));

  // Register the <type-id> for this decl
  ctBufferPCTypeIds[decl] = resultType;

  const auto sc = usageKind == ContextUsageKind::PushConstant
                      ? spv::StorageClass::PushConstant
                      : spv::StorageClass::Uniform;

  // Create the variable for the whole struct / struct array.
  return theBuilder.addModuleVar(resultType, sc, varName);
}

uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
  const auto usageKind =
      decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
  const std::string structName = "type." + decl->getName().str();
  // The front-end does not allow arrays of cbuffer/tbuffer.
  const uint32_t bufferVar = createStructOrStructArrayVarOfExplicitLayout(
      decl, /*arraySize*/ 0, usageKind, structName, decl->getName());

  // We still register all VarDecls separately here. All the VarDecls are
  // mapped to the <result-id> of the buffer object, which means when querying
  // the <result-id> for a certain VarDecl, we need to do an extra
  // OpAccessChain.
  int index = 0;
  for (const auto *subDecl : decl->decls()) {
    if (TypeTranslator::shouldSkipInStructLayout(subDecl))
      continue;

    const auto *varDecl = cast<VarDecl>(subDecl);
    astDecls[varDecl] =
        SpirvEvalInfo(bufferVar)
            .setStorageClass(spv::StorageClass::Uniform)
            .setLayoutRule(decl->isCBuffer() ? spirvOptions.cBufferLayoutRule
                                             : spirvOptions.tBufferLayoutRule);
    astDecls[varDecl].indexInCTBuffer = index++;
  }
  resourceVars.emplace_back(bufferVar, getResourceBinding(decl),
                            decl->getAttr<VKBindingAttr>(),
                            decl->getAttr<VKCounterBindingAttr>());

  return bufferVar;
}

uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
  const auto *recordType = decl->getType()->getAs<RecordType>();
  uint32_t arraySize = 0;

  // In case we have an array of ConstantBuffer/TextureBuffer:
  if (!recordType) {
    if (const auto *arrayType =
            astContext.getAsConstantArrayType(decl->getType())) {
      recordType = arrayType->getElementType()->getAs<RecordType>();
      arraySize = static_cast<uint32_t>(arrayType->getSize().getZExtValue());
    }
  }
  assert(recordType);

  const auto *context = cast<HLSLBufferDecl>(decl->getDeclContext());
  const auto usageKind = context->isCBuffer() ? ContextUsageKind::CBuffer
                                              : ContextUsageKind::TBuffer;

  const char *ctBufferName =
      context->isCBuffer() ? "ConstantBuffer." : "TextureBuffer.";
  const std::string structName = "type." + std::string(ctBufferName) +
                                 recordType->getDecl()->getName().str();

  const uint32_t bufferVar = createStructOrStructArrayVarOfExplicitLayout(
      recordType->getDecl(), arraySize, usageKind, structName, decl->getName());

  // We register the VarDecl here.
  astDecls[decl] =
      SpirvEvalInfo(bufferVar)
          .setStorageClass(spv::StorageClass::Uniform)
          .setLayoutRule(context->isCBuffer() ? spirvOptions.cBufferLayoutRule
                                              : spirvOptions.tBufferLayoutRule);
  resourceVars.emplace_back(bufferVar, getResourceBinding(context),
                            decl->getAttr<VKBindingAttr>(),
                            decl->getAttr<VKCounterBindingAttr>());

  return bufferVar;
}

uint32_t DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
  // The front-end errors out if a non-struct type push constant is used.
  const auto *recordType = decl->getType()->getAs<RecordType>();
  assert(recordType);

  const std::string structName =
      "type.PushConstant." + recordType->getDecl()->getName().str();
  const uint32_t var = createStructOrStructArrayVarOfExplicitLayout(
      recordType->getDecl(), /*arraySize*/ 0, ContextUsageKind::PushConstant,
      structName, decl->getName());

  // Register the VarDecl
  astDecls[decl] = SpirvEvalInfo(var)
                       .setStorageClass(spv::StorageClass::PushConstant)
                       .setLayoutRule(spirvOptions.sBufferLayoutRule);

  // Do not push this variable into resourceVars since it does not need a
  // descriptor set.

  return var;
}

void DeclResultIdMapper::createGlobalsCBuffer(const VarDecl *var) {
  if (astDecls.count(var) != 0)
    return;

  const auto *context = var->getTranslationUnitDecl();
  const uint32_t globals = createStructOrStructArrayVarOfExplicitLayout(
      context, /*arraySize*/ 0, ContextUsageKind::Globals, "type.$Globals",
      "$Globals");

  resourceVars.emplace_back(globals, nullptr, nullptr, nullptr);

  uint32_t index = 0;
  for (const auto *decl : typeTranslator.collectDeclsInDeclContext(context))
    if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
      if (const auto *attr = varDecl->getAttr<VKBindingAttr>()) {
        emitError("variable '%0' will be placed in $Globals so cannot have "
                  "vk::binding attribute",
                  attr->getLocation())
            << var->getName();
        return;
      }

      astDecls[varDecl] = SpirvEvalInfo(globals)
                              .setStorageClass(spv::StorageClass::Uniform)
                              .setLayoutRule(spirvOptions.cBufferLayoutRule);
      astDecls[varDecl].indexInCTBuffer = index++;
    }
}

uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
  if (const auto *info = getDeclSpirvInfo(fn))
    return info->info;

  auto &info = astDecls[fn].info;

  bool isAlias = false;
  const uint32_t type =
      getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias, &info);

  const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
  info.setResultId(id);
  // No need to dereference to get the pointer. Function returns that are
  // stand-alone aliases are already pointers to values. All other cases should
  // be normal rvalues.
  if (!isAlias ||
      !TypeTranslator::isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
    info.setRValue();

  return id;
}

const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
    const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
  if (!decl)
    return nullptr;

  if (indices) {
    // Indices are provided. Walk through the fields of the decl.
    const auto counter = fieldCounterVars.find(decl);
    if (counter != fieldCounterVars.end())
      return counter->second.get(*indices);
  } else {
    // No indices. Check the stand-alone entities.
    const auto counter = counterVars.find(decl);
    if (counter != counterVars.end())
      return &counter->second;
  }

  return nullptr;
}

const CounterVarFields *
DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
  if (!decl)
    return nullptr;

  const auto found = fieldCounterVars.find(decl);
  if (found != fieldCounterVars.end())
    return &found->second;

  return nullptr;
}

void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
                                              uint32_t specConstant) {
  astDecls[decl].info.setResultId(specConstant).setRValue().setSpecConstant();
}

void DeclResultIdMapper::createCounterVar(
    const DeclaratorDecl *decl, uint32_t declId, bool isAlias,
    const llvm::SmallVector<uint32_t, 4> *indices) {
  std::string counterName = "counter.var." + decl->getName().str();
  if (indices) {
    // Append field indices to the name
    for (const auto index : *indices)
      counterName += "." + std::to_string(index);
  }

  uint32_t counterType = typeTranslator.getACSBufferCounter();
  // {RW|Append|Consume}StructuredBuffer are all in the Uniform storage class.
  // Alias counter variables should be created in the Private storage class.
  const spv::StorageClass sc =
      isAlias ? spv::StorageClass::Private : spv::StorageClass::Uniform;

  if (isAlias) {
    // Apply an extra level of pointer for alias counter variables
    counterType =
        theBuilder.getPointerType(counterType, spv::StorageClass::Uniform);
  }

  const uint32_t counterId =
      theBuilder.addModuleVar(counterType, sc, counterName);

  if (!isAlias) {
    // Non-alias counter variables should be put into resourceVars so that
    // descriptors can be allocated for them.
    resourceVars.emplace_back(counterId, getResourceBinding(decl),
                              decl->getAttr<VKBindingAttr>(),
                              decl->getAttr<VKCounterBindingAttr>(), true);
    assert(declId);
    theBuilder.decorateCounterBufferId(declId, counterId);
  }

  if (indices)
    fieldCounterVars[decl].append(*indices, counterId);
  else
    counterVars[decl] = {counterId, isAlias};
}
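
// For instance, for a (hypothetical) struct such as
//   struct Bundle { RWStructuredBuffer<float> a; SubBundle s; };
// where SubBundle's first field is an AppendStructuredBuffer, the recursion
// below would register alias counters under the index chains {0} and {1, 0}.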
void DeclResultIdMapper::createFieldCounterVars(
    const DeclaratorDecl *rootDecl, const DeclaratorDecl *decl,
    llvm::SmallVector<uint32_t, 4> *indices) {
  const QualType type = getTypeOrFnRetType(decl);
  const auto *recordType = type->getAs<RecordType>();
  assert(recordType);
  const auto *recordDecl = recordType->getDecl();

  for (const auto *field : recordDecl->fields()) {
    // Build up the index chain
    indices->push_back(getNumBaseClasses(type) + field->getFieldIndex());

    const QualType fieldType = field->getType();
    if (TypeTranslator::isRWAppendConsumeSBuffer(fieldType))
      createCounterVar(rootDecl, /*declId=*/0, /*isAlias=*/true, indices);
    else if (fieldType->isStructureType() &&
             !hlsl::IsHLSLResourceType(fieldType))
      // Go recursively into all nested structs
      createFieldCounterVars(rootDecl, field, indices);

    indices->pop_back();
  }
}

uint32_t
DeclResultIdMapper::getCTBufferPushConstantTypeId(const DeclContext *decl) {
  const auto found = ctBufferPCTypeIds.find(decl);
  assert(found != ctBufferPCTypeIds.end());
  return found->second;
}

std::vector<uint32_t> DeclResultIdMapper::collectStageVars() const {
  std::vector<uint32_t> vars;

  for (auto var : glPerVertex.getStageInVars())
    vars.push_back(var);
  for (auto var : glPerVertex.getStageOutVars())
    vars.push_back(var);

  for (const auto &var : stageVars)
    vars.push_back(var.getSpirvId());

  return vars;
}

namespace {
/// A class for managing stage input/output locations to avoid duplicate uses
/// of the same location.
class LocationSet {
public:
  /// Maximum number of locations supported
  // Typically we won't have that many stage input or output variables.
  // Using 64 should be fine here.
  const static uint32_t kMaxLoc = 64;

  LocationSet() : usedLocs(kMaxLoc, false), nextLoc(0) {}

  /// Uses the given location.
  void useLoc(uint32_t loc) { usedLocs.set(loc); }

  /// Uses the next |count| available locations and returns the first one.
  int useNextLocs(uint32_t count) {
    while (usedLocs[nextLoc])
      nextLoc++;

    int toUse = nextLoc;

    for (uint32_t i = 0; i < count; ++i) {
      assert(!usedLocs[nextLoc]);
      usedLocs.set(nextLoc++);
    }

    return toUse;
  }
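
  // For example, if location 0 is already taken, useNextLocs(2) claims
  // locations 1 and 2 and returns 1. Consecutive free locations are assumed
  // to exist; the assert above fires otherwise.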

  /// Returns true if the given location number is already used.
  bool isLocUsed(uint32_t loc) { return usedLocs[loc]; }

private:
  llvm::SmallBitVector usedLocs; ///< All previously used locations
  uint32_t nextLoc;              ///< Next available location
};

/// A class for managing resource bindings to avoid duplicate uses of the same
/// set and binding number.
class BindingSet {
public:
  /// Uses the given set and binding number.
  void useBinding(uint32_t binding, uint32_t set) {
    usedBindings[set].insert(binding);
  }

  /// Uses the next available binding number in the given set.
  uint32_t useNextBinding(uint32_t set) {
    auto &binding = usedBindings[set];
    auto &next = nextBindings[set];
    while (binding.count(next))
      ++next;
    binding.insert(next);
    return next++;
  }

private:
  ///< set number -> set of used binding number
  llvm::DenseMap<uint32_t, llvm::DenseSet<uint32_t>> usedBindings;
  ///< set number -> next available binding number
  llvm::DenseMap<uint32_t, uint32_t> nextBindings;
};
} // namespace

bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
  llvm::StringSet<> seenSemantics;
  bool success = true;

  for (const auto &var : stageVars) {
    auto s = var.getSemanticStr();

    if (s.empty()) {
      // We translate WaveGetLaneCount() and WaveGetLaneIndex() into builtin
      // variables. Those variables are inserted into the normal stage IO
      // processing pipeline, but with the semantics as empty strings.
      assert(var.isSpirvBuitin());
      continue;
    }

    if (forInput && var.getSigPoint()->IsInput()) {
      if (seenSemantics.count(s)) {
        emitError("input semantic '%0' used more than once", {}) << s;
        success = false;
      }
      seenSemantics.insert(s);
    } else if (!forInput && var.getSigPoint()->IsOutput()) {
      if (seenSemantics.count(s)) {
        emitError("output semantic '%0' used more than once", {}) << s;
        success = false;
      }
      seenSemantics.insert(s);
    }
  }

  return success;
}

bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
  if (!checkSemanticDuplication(forInput))
    return false;

  // Returns false if the given StageVar is an input/output variable without
  // explicit location assignment. Otherwise, returns true.
  const auto locAssigned = [forInput, this](const StageVar &v) {
    if (forInput == isInputStorageClass(v))
      // No need to assign location for builtins. Treat as assigned.
      return v.isSpirvBuitin() || v.getLocationAttr() != nullptr;
    // For the ones we don't care about, treat as assigned.
    return true;
  };

  // If we have explicit locations specified for all input/output variables,
  // use them instead of assigning locations ourselves.
  if (std::all_of(stageVars.begin(), stageVars.end(), locAssigned)) {
    LocationSet locSet;
    bool noError = true;

    for (const auto &var : stageVars) {
      // Skip those stage variables we are not handling for this call
      if (forInput != isInputStorageClass(var))
        continue;

      // Skip builtins
      if (var.isSpirvBuitin())
        continue;

      const auto *attr = var.getLocationAttr();
      const auto loc = attr->getNumber();
      const auto attrLoc = attr->getLocation(); // Attr source code location

      if (loc >= LocationSet::kMaxLoc) {
        emitError("stage %select{output|input}0 location #%1 too large",
                  attrLoc)
            << forInput << loc;
        return false;
      }

      // Make sure the same location is not assigned more than once
      if (locSet.isLocUsed(loc)) {
        emitError("stage %select{output|input}0 location #%1 already assigned",
                  attrLoc)
            << forInput << loc;
        noError = false;
      }
      locSet.useLoc(loc);

      theBuilder.decorateLocation(var.getSpirvId(), loc);
    }

    return noError;
  }

  std::vector<const StageVar *> vars;
  LocationSet locSet;

  for (const auto &var : stageVars) {
    if (forInput != isInputStorageClass(var))
      continue;

    if (!var.isSpirvBuitin()) {
      if (var.getLocationAttr() != nullptr) {
        // We have checked that not all of the stage variables have explicit
        // location assignment.
        emitError("partial explicit stage %select{output|input}0 location "
                  "assignment via vk::location(X) unsupported",
                  {})
            << forInput;
        return false;
      }

      // Only SV_Target, SV_Depth, SV_DepthLessEqual, SV_DepthGreaterEqual,
      // SV_StencilRef, SV_Coverage are allowed in the pixel shader.
      // Arbitrary semantics are disallowed in pixel shader.
      if (var.getSemantic() &&
          var.getSemantic()->GetKind() == hlsl::Semantic::Kind::Target) {
        theBuilder.decorateLocation(var.getSpirvId(), var.getSemanticIndex());
        locSet.useLoc(var.getSemanticIndex());
      } else {
        vars.push_back(&var);
      }
    }
  }

  // Sort by semantic string if alphabetical ordering was requested. Since HS
  // includes 2 sets of outputs (patch-constant output and OutputPatch),
  // running into location mismatches between HS and DS is very likely. To
  // avoid that, also use alphabetical ordering for HS output and DS input.
  if (spirvOptions.stageIoOrder == "alpha" ||
      (!forInput && shaderModel.IsHS()) || (forInput && shaderModel.IsDS())) {
    // Sort stage input/output variables alphabetically
    std::sort(vars.begin(), vars.end(),
              [](const StageVar *a, const StageVar *b) {
                return a->getSemanticStr() < b->getSemanticStr();
              });
  }

  for (const auto *var : vars)
    theBuilder.decorateLocation(var->getSpirvId(),
                                locSet.useNextLocs(var->getLocationCount()));

  return true;
}

namespace {
/// A class for maintaining the binding number shift requested for descriptor
/// sets.
class BindingShiftMapper {
public:
  explicit BindingShiftMapper(const llvm::SmallVectorImpl<int32_t> &shifts)
      : masterShift(0) {
    assert(shifts.size() % 2 == 0);
    if (shifts.size() == 2 && shifts[1] == -1) {
      masterShift = shifts[0];
    } else {
      for (uint32_t i = 0; i < shifts.size(); i += 2)
        perSetShift[shifts[i + 1]] = shifts[i];
    }
  }
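
  // The shifts vector encodes (shift, set) pairs: e.g. {4, 0, 8, 1} shifts
  // bindings in set 0 by 4 and in set 1 by 8, while the special pair {10, -1}
  // shifts bindings in every set by 10. These values presumably originate
  // from the -fvk-{b|s|t|u}-shift command-line options.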

  /// Returns the shift amount for the given set.
  int32_t getShiftForSet(int32_t set) const {
    const auto found = perSetShift.find(set);
    if (found != perSetShift.end())
      return found->second;
    return masterShift;
  }

private:
  int32_t masterShift; /// Shift amount that applies to all sets.
  llvm::DenseMap<int32_t, int32_t> perSetShift;
};
} // namespace

bool DeclResultIdMapper::decorateResourceBindings() {
  // For normal resources, we support 3 approaches of setting binding numbers:
  // - m1: [[vk::binding(...)]]
  // - m2: :register(...)
  // - m3: None
  //
  // For associated counters, we support 2 approaches:
  // - c1: [[vk::counter_binding(...)]]
  // - c2: None
  //
  // In combination, we need to handle 9 cases:
  // - 3 cases for normal resources (m1, m2, m3)
  // - 6 cases for associated counters (mX * cY)
  //
  // In the following order:
  // - m1, mX * c1
  // - m2
  // - m3, mX * c2

  BindingSet bindingSet;

  // Decorates the given varId of the given category with set number
  // setNo, binding number bindingNo. Ignores overlaps.
  const auto tryToDecorate = [this, &bindingSet](const uint32_t varId,
                                                 const uint32_t setNo,
                                                 const uint32_t bindingNo) {
    bindingSet.useBinding(bindingNo, setNo);
    theBuilder.decorateDSetBinding(varId, setNo, bindingNo);
  };

  for (const auto &var : resourceVars) {
    if (var.isCounter()) {
      if (const auto *vkCBinding = var.getCounterBinding()) {
        // Process mX * c1
        uint32_t set = 0;
        if (const auto *vkBinding = var.getBinding())
          set = vkBinding->getSet();
        if (const auto *reg = var.getRegister())
          set = reg->RegisterSpace;

        tryToDecorate(var.getSpirvId(), set, vkCBinding->getBinding());
      }
    } else {
      if (const auto *vkBinding = var.getBinding()) {
        // Process m1
        tryToDecorate(var.getSpirvId(), vkBinding->getSet(),
                      vkBinding->getBinding());
      }
    }
  }

  BindingShiftMapper bShiftMapper(spirvOptions.bShift);
  BindingShiftMapper tShiftMapper(spirvOptions.tShift);
  BindingShiftMapper sShiftMapper(spirvOptions.sShift);
  BindingShiftMapper uShiftMapper(spirvOptions.uShift);

  // Process m2
  for (const auto &var : resourceVars)
    if (!var.isCounter() && !var.getBinding())
      if (const auto *reg = var.getRegister()) {
        const uint32_t set = reg->RegisterSpace;
        uint32_t binding = reg->RegisterNumber;
        switch (reg->RegisterType) {
        case 'b':
          binding += bShiftMapper.getShiftForSet(set);
          break;
        case 't':
          binding += tShiftMapper.getShiftForSet(set);
          break;
        case 's':
          binding += sShiftMapper.getShiftForSet(set);
          break;
        case 'u':
          binding += uShiftMapper.getShiftForSet(set);
          break;
        case 'c':
          // For setting packing offset. Does not affect binding.
          break;
        default:
          llvm_unreachable("unknown register type found");
        }

        tryToDecorate(var.getSpirvId(), set, binding);
      }

  for (const auto &var : resourceVars) {
    if (var.isCounter()) {
      if (!var.getCounterBinding()) {
        // Process mX * c2
        uint32_t set = 0;
        if (const auto *vkBinding = var.getBinding())
          set = vkBinding->getSet();
        else if (const auto *reg = var.getRegister())
          set = reg->RegisterSpace;

        theBuilder.decorateDSetBinding(var.getSpirvId(), set,
                                       bindingSet.useNextBinding(set));
      }
    } else if (!var.getBinding() && !var.getRegister()) {
      // Process m3
      theBuilder.decorateDSetBinding(var.getSpirvId(), 0,
                                     bindingSet.useNextBinding(0));
    }
  }

  return true;
}

bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
                                         const NamedDecl *decl, bool asInput,
                                         QualType type, uint32_t arraySize,
                                         const llvm::StringRef namePrefix,
                                         llvm::Optional<uint32_t> invocationId,
                                         uint32_t *value, bool noWriteBack,
                                         SemanticInfo *inheritSemantic) {
  // invocationId should only be used for handling HS per-vertex output.
  if (invocationId.hasValue()) {
    assert(shaderModel.IsHS() && arraySize != 0 && !asInput);
  }

  assert(inheritSemantic);

  if (type->isVoidType()) {
    // No stage variables will be created for void type.
    return true;
  }

  uint32_t typeId = typeTranslator.translateType(type);

  // We have several cases regarding HLSL semantics to handle here:
  // * If the current decl inherits a semantic from some enclosing entity,
  //   use the inherited semantic no matter whether there is a semantic
  //   attached to the current decl.
  // * If there is no semantic to inherit,
  //   * If the current decl is a struct,
  //     * If the current decl has a semantic, all its members inherit this
  //       decl's semantic, with the index sequentially increasing;
  //     * If the current decl does not have a semantic, all its members
  //       should have semantics attached;
  //   * If the current decl is not a struct, it should have a semantic
  //     attached.
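  //
  // For example, for a (hypothetical) output struct declared as
  //   struct VSOut { float4 pos; float4 color; };
  //   VSOut main(...) : OUT;
  // the enclosing OUT semantic would presumably be inherited by both fields,
  // with the semantic index increasing for each field.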

  auto thisSemantic = getStageVarSemantic(decl);

  // Which semantic we should use for this decl
  auto *semanticToUse = &thisSemantic;

  // Enclosing semantics override internal ones
  if (inheritSemantic->isValid()) {
    if (thisSemantic.isValid()) {
      emitWarning(
          "internal semantic '%0' overridden by enclosing semantic '%1'",
          thisSemantic.loc)
          << thisSemantic.str << inheritSemantic->str;
    }
    semanticToUse = inheritSemantic;
  }

  if (semanticToUse->isValid() &&
      // Structs with attached semantics will be handled later.
      !type->isStructureType()) {
    // Found semantic attached directly to this Decl. This means we need to
    // map this decl to a single stage variable.

    const auto semanticKind = semanticToUse->semantic->GetKind();

    // Error out when the given semantic is invalid in this shader model
    if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPoint->GetKind(),
                                          shaderModel.GetMajor(),
                                          shaderModel.GetMinor()) ==
        hlsl::DXIL::SemanticInterpretationKind::NA) {
      emitError("invalid usage of semantic '%0' in shader profile %1",
                decl->getLocation())
          << semanticToUse->str << shaderModel.GetName();
      return false;
    }

    if (!validateVKBuiltins(decl, sigPoint))
      return false;

    const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>();

    // For VS/HS/DS, the PointSize builtin is handled in gl_PerVertex.
    // For GSVIn it is also in gl_PerVertex; for GSOut, it's a stand-alone
    // variable handled below.
    if (builtinAttr && builtinAttr->getBuiltIn() == "PointSize" &&
        glPerVertex.tryToAccessPointSize(sigPoint->GetKind(), invocationId,
                                         value, noWriteBack))
      return true;

    // Special handling of certain mappings between HLSL semantics and
    // SPIR-V builtins:
    // * SV_Position/SV_CullDistance/SV_ClipDistance should be grouped into
    //   the gl_PerVertex struct in vertex processing stages.
    // * SV_DomainLocation can refer to a float2, whereas TessCoord is a
    //   float3. To ensure SPIR-V validity, we must create a float3 and
    //   extract a float2 from it before passing it to the main function.
    // * SV_TessFactor is an array of size 2 for isoline patch, array of size 3
    //   for tri patch, and array of size 4 for quad patch, but it must always
    //   be an array of size 4 in SPIR-V for Vulkan.
    // * SV_InsideTessFactor is a single float for tri patch, and an array of
    //   size 2 for a quad patch, but it must always be an array of size 2 in
    //   SPIR-V for Vulkan.
    // * SV_Coverage is a uint value, but the builtin it corresponds to,
    //   SampleMask, must be an array of integers.
    // * SV_InnerCoverage is a uint value, but the corresponding builtin,
    //   FullyCoveredEXT, must be a boolean value.
    // * SV_DispatchThreadID, SV_GroupThreadID, and SV_GroupID are allowed to
    //   be uint, uint2, or uint3, but the corresponding builtins
    //   (GlobalInvocationId, LocalInvocationId, WorkgroupId) must be a uint3.
    if (glPerVertex.tryToAccess(sigPoint->GetKind(), semanticKind,
                                semanticToUse->index, invocationId, value,
                                noWriteBack))
      return true;

    const uint32_t srcTypeId = typeId; // Variable type in source code
    uint32_t srcVecElemTypeId = 0;     // Variable element type if vector

    switch (semanticKind) {
    case hlsl::Semantic::Kind::DomainLocation:
      typeId = theBuilder.getVecType(theBuilder.getFloat32Type(), 3);
      break;
    case hlsl::Semantic::Kind::TessFactor:
      typeId = theBuilder.getArrayType(theBuilder.getFloat32Type(),
                                       theBuilder.getConstantUint32(4));
      break;
    case hlsl::Semantic::Kind::InsideTessFactor:
      typeId = theBuilder.getArrayType(theBuilder.getFloat32Type(),
                                       theBuilder.getConstantUint32(2));
      break;
    case hlsl::Semantic::Kind::Coverage:
      typeId = theBuilder.getArrayType(typeId, theBuilder.getConstantUint32(1));
      break;
    case hlsl::Semantic::Kind::InnerCoverage:
      typeId = theBuilder.getBoolType();
      break;
    case hlsl::Semantic::Kind::Barycentrics:
      typeId = theBuilder.getVecType(theBuilder.getFloat32Type(), 2);
      break;
    case hlsl::Semantic::Kind::DispatchThreadID:
    case hlsl::Semantic::Kind::GroupThreadID:
    case hlsl::Semantic::Kind::GroupID:
      // Keep the original integer signedness
      srcVecElemTypeId = typeTranslator.translateType(
          hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecElementType(type) : type);
      typeId = theBuilder.getVecType(srcVecElemTypeId, 3);
      break;
    }
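    // Example: given "float factors[3] : SV_TessFactor", typeId has just
    // been widened to float[4] (TessLevelOuter); the fix-up code further
    // below reads or writes only the three relevant elements.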
    // Handle the extra arrayness
    const uint32_t elementTypeId = typeId; // Array element's type
    if (arraySize != 0)
      typeId = theBuilder.getArrayType(typeId,
                                       theBuilder.getConstantUint32(arraySize));

    StageVar stageVar(
        sigPoint, semanticToUse->str, semanticToUse->semantic,
        semanticToUse->name, semanticToUse->index, builtinAttr, typeId,
        // For HS/DS/GS, we have already stripped the outermost arrayness on
        // type.
        typeTranslator.getLocationCount(type));
    const auto name = namePrefix.str() + "." + stageVar.getSemanticStr();
    const uint32_t varId =
        createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc);

    if (varId == 0)
      return false;

    stageVar.setSpirvId(varId);
    stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
    stageVars.push_back(stageVar);

    // Emit OpDecorate* instructions to link this stage variable with the HLSL
    // semantic it is created for
    theBuilder.decorateHlslSemantic(varId, stageVar.getSemanticStr());

    // We have semantics attached to this decl, which means it must be a
    // function/parameter/variable. All are DeclaratorDecls.
    stageVarIds[cast<DeclaratorDecl>(decl)] = varId;

    // Mark that we have used one index for this semantic
    ++semanticToUse->index;

    // Require extension and capability if using 16-bit types
    if (typeTranslator.getElementSpirvBitwidth(type) == 16) {
      theBuilder.addExtension(Extension::KHR_16bit_storage,
                              "16-bit stage IO variables", decl->getLocation());
      theBuilder.requireCapability(spv::Capability::StorageInputOutput16);
    }
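    // Example: an input whose element type is lowered to a real 16-bit type
    // (e.g. min16float when true 16-bit types are in use) takes the branch
    // above, pulling in SPV_KHR_16bit_storage and StorageInputOutput16.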
    // TODO: the following may not be correct?
    if (sigPoint->GetSignatureKind() ==
        hlsl::DXIL::SignatureKind::PatchConstant)
      theBuilder.decorate(varId, spv::Decoration::Patch);

    // Decorate with interpolation modes for pixel shader input variables
    if (shaderModel.IsPS() && sigPoint->IsInput() &&
        // BaryCoord*AMD builtins already encode the interpolation mode.
        semanticKind != hlsl::Semantic::Kind::Barycentrics)
      decoratePSInterpolationMode(decl, type, varId);

    if (asInput) {
      *value = theBuilder.createLoad(typeId, varId);

      // Fix-ups for corner cases

      // Special handling of SV_TessFactor DS patch constant input.
      // TessLevelOuter is always an array of size 4 in SPIR-V, but
      // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
      // relevant indexes must be loaded.
      if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
          hlsl::GetArraySize(type) != 4) {
        llvm::SmallVector<uint32_t, 4> components;
        const auto f32TypeId = theBuilder.getFloat32Type();
        const auto tessFactorSize = hlsl::GetArraySize(type);
        const auto arrType = theBuilder.getArrayType(
            f32TypeId, theBuilder.getConstantUint32(tessFactorSize));
        for (uint32_t i = 0; i < tessFactorSize; ++i)
          components.push_back(
              theBuilder.createCompositeExtract(f32TypeId, *value, {i}));
        *value = theBuilder.createCompositeConstruct(arrType, components);
      }
      // Special handling of SV_InsideTessFactor DS patch constant input.
      // TessLevelInner is always an array of size 2 in SPIR-V, but
      // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
      // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of
      // TessLevelInner.
      else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
               // Some developers use float[1] instead of a scalar float.
               (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
        const auto f32Type = theBuilder.getFloat32Type();
        *value = theBuilder.createCompositeExtract(f32Type, *value, {0});
        if (type->isArrayType()) // float[1]
          *value = theBuilder.createCompositeConstruct(
              theBuilder.getArrayType(f32Type, theBuilder.getConstantUint32(1)),
              {*value});
      }
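      // Example: for "float inside : SV_InsideTessFactor" in a DS, only
      // TessLevelInner[0] is kept from the loaded value above; a float[1]
      // declaration additionally gets re-wrapped into a one-element array.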
      // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord
      // is always a float3. To ensure SPIR-V validity, a float3 stage variable
      // is created, and we must extract a float2 from it before passing it to
      // the main function.
      else if (semanticKind == hlsl::Semantic::Kind::DomainLocation &&
               hlsl::GetHLSLVecSize(type) != 3) {
        const auto domainLocSize = hlsl::GetHLSLVecSize(type);
        *value = theBuilder.createVectorShuffle(
            theBuilder.getVecType(theBuilder.getFloat32Type(), domainLocSize),
            *value, *value, {0, 1});
      }
      // Special handling of SV_Coverage, which is a uint value. We need to
      // read SampleMask and extract its first element.
      else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
        *value = theBuilder.createCompositeExtract(srcTypeId, *value, {0});
      }
      // Special handling of SV_InnerCoverage, which is a uint value. We need
      // to read FullyCoveredEXT, which is a boolean value, and convert it to
      // a uint value. According to the D3D12 "Conservative Rasterization"
      // doc: "The Pixel Shader has a 32-bit scalar integer System Generated
      // Value available: InnerCoverage. This is a bit-field that has bit 0
      // from the LSB set to 1 for a given conservatively rasterized pixel,
      // only when that pixel is guaranteed to be entirely inside the current
      // primitive. All other input register bits must be set to 0 when bit 0
      // is not set, but are undefined when bit 0 is set to 1 (essentially,
      // this bit-field represents a Boolean value where false must be exactly
      // 0, but true can be any odd (i.e. bit 0 set) non-zero value)."
      else if (semanticKind == hlsl::Semantic::Kind::InnerCoverage) {
        *value = theBuilder.createSelect(theBuilder.getUint32Type(), *value,
                                         theBuilder.getConstantUint32(1),
                                         theBuilder.getConstantUint32(0));
      }
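      // Example: the OpSelect above turns the FullyCoveredEXT boolean into
      // 1u or 0u, so only bit 0 of the SV_InnerCoverage value is ever set,
      // which satisfies the D3D12 contract quoted above.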
      // Special handling of SV_Barycentrics, which is a float3, but the
      // underlying stage input variable is a float2 (only provides the first
      // two components). Calculate the third element.
      else if (semanticKind == hlsl::Semantic::Kind::Barycentrics) {
        const auto f32Type = theBuilder.getFloat32Type();
        const auto x = theBuilder.createCompositeExtract(f32Type, *value, {0});
        const auto y = theBuilder.createCompositeExtract(f32Type, *value, {1});
        const auto xy =
            theBuilder.createBinaryOp(spv::Op::OpFAdd, f32Type, x, y);
        const auto z = theBuilder.createBinaryOp(
            spv::Op::OpFSub, f32Type, theBuilder.getConstantFloat32(1), xy);
        const auto v3f32Type = theBuilder.getVecType(f32Type, 3);
        *value = theBuilder.createCompositeConstruct(v3f32Type, {x, y, z});
      }
      // Special handling of SV_DispatchThreadID, SV_GroupThreadID, and
      // SV_GroupID, which may be a uint or uint2, but the underlying stage
      // input variable is a uint3. The last component(s) should be discarded
      // if needed.
      else if ((semanticKind == hlsl::Semantic::Kind::DispatchThreadID ||
                semanticKind == hlsl::Semantic::Kind::GroupThreadID ||
                semanticKind == hlsl::Semantic::Kind::GroupID) &&
               (!hlsl::IsHLSLVecType(type) ||
                hlsl::GetHLSLVecSize(type) != 3)) {
        assert(srcVecElemTypeId);
        const auto vecSize =
            hlsl::IsHLSLVecType(type) ? hlsl::GetHLSLVecSize(type) : 1;
        if (vecSize == 1)
          *value =
              theBuilder.createCompositeExtract(srcVecElemTypeId, *value, {0});
        else if (vecSize == 2)
          *value = theBuilder.createVectorShuffle(
              theBuilder.getVecType(srcVecElemTypeId, 2), *value, *value,
              {0, 1});
      }
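      // Example: "uint2 tid : SV_DispatchThreadID" loads the full uint3
      // GlobalInvocationId and keeps only components {0, 1}; a scalar uint
      // declaration keeps only component 0.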
    } else {
      if (noWriteBack)
        return true;

      uint32_t ptr = varId;

      // Special handling of SV_TessFactor HS patch constant output.
      // TessLevelOuter is always an array of size 4 in SPIR-V, but
      // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
      // relevant indexes must be written to.
      if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
          hlsl::GetArraySize(type) != 4) {
        const auto f32TypeId = theBuilder.getFloat32Type();
        const auto tessFactorSize = hlsl::GetArraySize(type);
        for (uint32_t i = 0; i < tessFactorSize; ++i) {
          const uint32_t ptrType =
              theBuilder.getPointerType(f32TypeId, spv::StorageClass::Output);
          ptr = theBuilder.createAccessChain(ptrType, varId,
                                             theBuilder.getConstantUint32(i));
          theBuilder.createStore(
              ptr, theBuilder.createCompositeExtract(f32TypeId, *value, i));
        }
      }
      // Special handling of SV_InsideTessFactor HS patch constant output.
      // TessLevelInner is always an array of size 2 in SPIR-V, but
      // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
      // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of
      // TessLevelInner.
      else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
               // Some developers use float[1] instead of a scalar float.
               (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
        const auto f32Type = theBuilder.getFloat32Type();
        ptr = theBuilder.createAccessChain(
            theBuilder.getPointerType(f32Type, spv::StorageClass::Output),
            varId, theBuilder.getConstantUint32(0));
        if (type->isArrayType()) // float[1]
          *value = theBuilder.createCompositeExtract(f32Type, *value, {0});
        theBuilder.createStore(ptr, *value);
      }
      // Special handling of SV_Coverage, which is a uint value. We need to
      // write it to the first element in the SampleMask builtin.
      else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
        ptr = theBuilder.createAccessChain(
            theBuilder.getPointerType(srcTypeId, spv::StorageClass::Output),
            varId, theBuilder.getConstantUint32(0));
        theBuilder.createStore(ptr, *value);
      }
      // Special handling of HS output, for which we write to only one
      // element in the per-vertex data array: the one indexed by
      // SV_OutputControlPointID.
      else if (invocationId.hasValue()) {
        const uint32_t ptrType =
            theBuilder.getPointerType(elementTypeId, spv::StorageClass::Output);
        const uint32_t index = invocationId.getValue();
        ptr = theBuilder.createAccessChain(ptrType, varId, index);
        theBuilder.createStore(ptr, *value);
      }
      // For all normal cases
      else {
        theBuilder.createStore(ptr, *value);
      }
    }

    return true;
  }
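  // Example: a VS output struct such as
  //   struct VSOut { float4 pos : SV_Position; float2 uv : TEXCOORD0; };
  // carries no semantic on the decl itself; each of its fields is handled by
  // a recursive createStageVars() call below, and the per-field values are
  // composed into (or extracted from) the struct value.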
  // If the decl itself doesn't have a semantic string attached and there is
  // none to inherit, it should be a struct having all its fields with
  // semantic strings.
  if (!semanticToUse->isValid() && !type->isStructureType()) {
    emitError("semantic string missing for shader %select{output|input}0 "
              "variable '%1'",
              decl->getLocation())
        << asInput << decl->getName();
    return false;
  }

  const auto *structDecl = type->getAs<RecordType>()->getDecl();

  if (asInput) {
    // If this decl translates into multiple stage input variables, we need to
    // load their values into a composite.
    llvm::SmallVector<uint32_t, 4> subValues;

    // If we have base classes, we need to handle them first.
    if (const auto *cxxDecl = type->getAsCXXRecordDecl())
      for (auto base : cxxDecl->bases()) {
        uint32_t subValue = 0;
        if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
                             asInput, base.getType(), arraySize, namePrefix,
                             invocationId, &subValue, noWriteBack,
                             semanticToUse))
          return false;
        subValues.push_back(subValue);
      }

    for (const auto *field : structDecl->fields()) {
      uint32_t subValue = 0;
      if (!createStageVars(sigPoint, field, asInput, field->getType(),
                           arraySize, namePrefix, invocationId, &subValue,
                           noWriteBack, semanticToUse))
        return false;
      subValues.push_back(subValue);
    }

    if (arraySize == 0) {
      *value = theBuilder.createCompositeConstruct(typeId, subValues);
      return true;
    }

    // Handle the extra level of arrayness.
    // We need to return an array of structs, but visiting the fields gives us
    // an array per field. So we need to extract the elements at the same
    // index from each field array and compose a new struct out of them.
    const uint32_t structType = typeTranslator.translateType(type);
    const uint32_t arrayType = theBuilder.getArrayType(
        structType, theBuilder.getConstantUint32(arraySize));

    llvm::SmallVector<uint32_t, 16> arrayElements;
    for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) {
      llvm::SmallVector<uint32_t, 8> fields;

      // If we have base classes, we need to handle them first.
      if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
        uint32_t baseIndex = 0;
        for (auto base : cxxDecl->bases()) {
          const auto baseType = typeTranslator.translateType(base.getType());
          fields.push_back(theBuilder.createCompositeExtract(
              baseType, subValues[baseIndex++], {arrayIndex}));
        }
      }

      // Extract the element at index arrayIndex from each field
      for (const auto *field : structDecl->fields()) {
        const uint32_t fieldType =
            typeTranslator.translateType(field->getType());
        fields.push_back(theBuilder.createCompositeExtract(
            fieldType,
            subValues[getNumBaseClasses(type) + field->getFieldIndex()],
            {arrayIndex}));
      }
      // Compose a new struct out of them
      arrayElements.push_back(
          theBuilder.createCompositeConstruct(structType, fields));
    }

    *value = theBuilder.createCompositeConstruct(arrayType, arrayElements);
  } else {
    // If we have base classes, we need to handle them first.
    if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
      uint32_t baseIndex = 0;
      for (auto base : cxxDecl->bases()) {
        uint32_t subValue = 0;
        if (!noWriteBack)
          subValue = theBuilder.createCompositeExtract(
              typeTranslator.translateType(base.getType()), *value,
              {baseIndex++});
        if (!createStageVars(sigPoint, base.getType()->getAsCXXRecordDecl(),
                             asInput, base.getType(), arraySize, namePrefix,
                             invocationId, &subValue, noWriteBack,
                             semanticToUse))
          return false;
      }
    }

    // Unlike reading, which may require us to read stand-alone builtins and
    // stage input variables and compose an array of structs out of them, we
    // never need to write out a whole array of structs at once for any
    // shader stage:
    //
    // * VS: output is a single struct, without extra arrayness
    // * HS: output is an array of structs, with extra arrayness,
    //   but we only write to the struct at the InvocationID index
    // * DS: output is a single struct, without extra arrayness
    // * GS: output is controlled by OpEmitVertex, one vertex at a time
    //
    // The interesting shader stage is HS. We need the InvocationID to write
    // out the value to the correct array element.
    for (const auto *field : structDecl->fields()) {
      const uint32_t fieldType = typeTranslator.translateType(field->getType());
      uint32_t subValue = 0;
      if (!noWriteBack)
        subValue = theBuilder.createCompositeExtract(
            fieldType, *value,
            {getNumBaseClasses(type) + field->getFieldIndex()});

      if (!createStageVars(sigPoint, field, asInput, field->getType(),
                           arraySize, namePrefix, invocationId, &subValue,
                           noWriteBack, semanticToUse))
        return false;
    }
  }

  return true;
}
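
// Example: when a GS appends a vertex to its output stream, the value being
// appended flows through writeBackOutputStream() below; struct values are
// decomposed field by field, and each piece is stored into the stage output
// variable previously created for its semantic.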
bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
                                               QualType type, uint32_t value) {
  assert(shaderModel.IsGS()); // Only for GS use

  if (hlsl::IsHLSLStreamOutputType(type))
    type = hlsl::GetHLSLResourceResultType(type);
  if (hasGSPrimitiveTypeQualifier(decl))
    type = astContext.getAsConstantArrayType(type)->getElementType();

  auto semanticInfo = getStageVarSemantic(decl);

  if (semanticInfo.isValid()) {
    // Found semantic attached directly to this Decl. Write the value for this
    // Decl to the corresponding stage output variable.
    const uint32_t srcTypeId = typeTranslator.translateType(type);

    // Handle SV_Position, SV_ClipDistance, and SV_CullDistance
    if (glPerVertex.tryToAccess(
            hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
            semanticInfo.index, llvm::None, &value, /*noWriteBack=*/false))
      return true;

    // Query the <result-id> for the stage output variable generated out
    // of this decl.
    // We have a semantic string attached to this decl; therefore, it must be
    // a DeclaratorDecl.
    const auto found = stageVarIds.find(cast<DeclaratorDecl>(decl));

    // We should have recorded its stage output variable previously.
    assert(found != stageVarIds.end());

    // Negate SV_Position.y if requested
    if (spirvOptions.invertY &&
        semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position) {
      const auto f32Type = theBuilder.getFloat32Type();
      const auto v4f32Type = theBuilder.getVecType(f32Type, 4);
      const auto oldY = theBuilder.createCompositeExtract(f32Type, value, {1});
      const auto newY =
          theBuilder.createUnaryOp(spv::Op::OpFNegate, f32Type, oldY);
      value = theBuilder.createCompositeInsert(v4f32Type, value, {1}, newY);
    }

    theBuilder.createStore(found->second, value);
    return true;
  }

  // If the decl itself doesn't have a semantic string attached, it should be
  // a struct having all its fields with semantic strings.
  if (!type->isStructureType()) {
    emitError("semantic string missing for shader output variable '%0'",
              decl->getLocation())
        << decl->getName();
    return false;
  }

  // If we have base classes, we need to handle them first.
  if (const auto *cxxDecl = type->getAsCXXRecordDecl()) {
    uint32_t baseIndex = 0;
    for (auto base : cxxDecl->bases()) {
      const auto baseType = typeTranslator.translateType(base.getType());
      const auto subValue =
          theBuilder.createCompositeExtract(baseType, value, {baseIndex++});
      if (!writeBackOutputStream(base.getType()->getAsCXXRecordDecl(),
                                 base.getType(), subValue))
        return false;
    }
  }

  const auto *structDecl = type->getAs<RecordType>()->getDecl();

  // Write out each field
  for (const auto *field : structDecl->fields()) {
    const uint32_t fieldType = typeTranslator.translateType(field->getType());
    const uint32_t subValue = theBuilder.createCompositeExtract(
        fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()});

    if (!writeBackOutputStream(field, field->getType(), subValue))
      return false;
  }

  return true;
}

void DeclResultIdMapper::decoratePSInterpolationMode(const NamedDecl *decl,
                                                     QualType type,
                                                     uint32_t varId) {
  const QualType elemType = typeTranslator.getElementType(type);

  if (elemType->isBooleanType() || elemType->isIntegerType()) {
    // TODO: Probably we can call hlsl::ValidateSignatureElement() for the
    // following check.
    if (decl->getAttr<HLSLLinearAttr>() || decl->getAttr<HLSLCentroidAttr>() ||
        decl->getAttr<HLSLNoPerspectiveAttr>() ||
        decl->getAttr<HLSLSampleAttr>()) {
      emitError("only nointerpolation mode allowed for integer input "
                "parameters in pixel shader",
                decl->getLocation());
    } else {
      theBuilder.decorate(varId, spv::Decoration::Flat);
    }
  } else {
    // Do nothing for HLSLLinearAttr since it's the default.
    // Attributes can be used together, so we cannot use else-if here.
    if (decl->getAttr<HLSLCentroidAttr>())
      theBuilder.decorate(varId, spv::Decoration::Centroid);
    if (decl->getAttr<HLSLNoInterpolationAttr>())
      theBuilder.decorate(varId, spv::Decoration::Flat);
    if (decl->getAttr<HLSLNoPerspectiveAttr>())
      theBuilder.decorate(varId, spv::Decoration::NoPerspective);
    if (decl->getAttr<HLSLSampleAttr>()) {
      theBuilder.requireCapability(spv::Capability::SampleRateShading);
      theBuilder.decorate(varId, spv::Decoration::Sample);
    }
  }
}
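
// Example: "nointerpolation int3 flags : FLAGS" as a PS input gets
// Decoration::Flat above (integer inputs must be flat), while "sample float2
// uv : TEXCOORD0" gets Decoration::Sample plus the SampleRateShading
// capability.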
uint32_t DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn) {
  // Guarantee uniqueness
  switch (builtIn) {
  case spv::BuiltIn::SubgroupSize:
    if (laneCountBuiltinId)
      return laneCountBuiltinId;
    break;
  case spv::BuiltIn::SubgroupLocalInvocationId:
    if (laneIndexBuiltinId)
      return laneIndexBuiltinId;
    break;
  default:
    // Only allow the two cases we know about
    assert(false && "unsupported builtin case");
    return 0;
  }

  theBuilder.requireCapability(spv::Capability::GroupNonUniform);

  uint32_t type = theBuilder.getUint32Type();

  // Create a dummy StageVar for this builtin variable
  const uint32_t varId =
      theBuilder.addStageBuiltinVar(type, spv::StorageClass::Input, builtIn);

  const hlsl::SigPoint *sigPoint =
      hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
          hlsl::DxilParamInputQual::In, shaderModel.GetKind(),
          /*isPatchConstant=*/false));

  StageVar stageVar(sigPoint, /*semaStr=*/"", hlsl::Semantic::GetInvalid(),
                    /*semaName=*/"", /*semaIndex=*/0, /*builtinAttr=*/nullptr,
                    type, /*locCount=*/0);

  stageVar.setIsSpirvBuiltin();
  stageVar.setSpirvId(varId);
  stageVars.push_back(stageVar);

  switch (builtIn) {
  case spv::BuiltIn::SubgroupSize:
    laneCountBuiltinId = varId;
    break;
  case spv::BuiltIn::SubgroupLocalInvocationId:
    laneIndexBuiltinId = varId;
    break;
  }

  return varId;
}
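
// Example: lowering a wave intrinsic such as WaveGetLaneCount() requests the
// SubgroupSize builtin through getBuiltinVar(); a later request for the same
// builtin returns the cached variable id instead of creating a new one.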
uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
                                                 const NamedDecl *decl,
                                                 const llvm::StringRef name,
                                                 SourceLocation srcLoc) {
  using spv::BuiltIn;

  const auto sigPoint = stageVar->getSigPoint();
  const auto semanticKind = stageVar->getSemantic()->GetKind();
  const auto sigPointKind = sigPoint->GetKind();
  const uint32_t type = stageVar->getSpirvTypeId();

  spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
  if (sc == spv::StorageClass::Max)
    return 0;
  stageVar->setStorageClass(sc);

  // [[vk::builtin(...)]] takes precedence.
  if (const auto *builtinAttr = stageVar->getBuiltInAttr()) {
    const auto spvBuiltIn =
        llvm::StringSwitch<BuiltIn>(builtinAttr->getBuiltIn())
            .Case("PointSize", BuiltIn::PointSize)
            .Case("HelperInvocation", BuiltIn::HelperInvocation)
            .Case("BaseVertex", BuiltIn::BaseVertex)
            .Case("BaseInstance", BuiltIn::BaseInstance)
            .Case("DrawIndex", BuiltIn::DrawIndex)
            .Case("DeviceIndex", BuiltIn::DeviceIndex)
            .Default(BuiltIn::Max);

    assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.

    switch (spvBuiltIn) {
    case BuiltIn::BaseVertex:
    case BuiltIn::BaseInstance:
    case BuiltIn::DrawIndex:
      theBuilder.addExtension(Extension::KHR_shader_draw_parameters,
                              builtinAttr->getBuiltIn(),
                              builtinAttr->getLocation());
      theBuilder.requireCapability(spv::Capability::DrawParameters);
      break;
    case BuiltIn::DeviceIndex:
      theBuilder.addExtension(Extension::KHR_device_group,
                              stageVar->getSemanticStr(), srcLoc);
      theBuilder.requireCapability(spv::Capability::DeviceGroup);
      break;
    }

    return theBuilder.addStageBuiltinVar(type, sc, spvBuiltIn);
  }

  // The following translation assumes that semantic validity in the current
  // shader model is already checked, so it only covers valid SigPoints for
  // each semantic.
  switch (semanticKind) {
  // According to DXIL spec, the Position SV can be used by all SigPoints
  // other than PCIn, HSIn, GSIn, PSOut, CSIn.
  // According to Vulkan spec, the Position BuiltIn can only be used
  // by VSOut, HS/DS/GS In/Out.
  case hlsl::Semantic::Kind::Position: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
      return theBuilder.addStageIOVar(type, sc, name.str());
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
      llvm_unreachable("should be handled in gl_PerVertex struct");
    case hlsl::SigPoint::Kind::GSOut:
      stageVar->setIsSpirvBuiltin();
      return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position);
    case hlsl::SigPoint::Kind::PSIn:
      stageVar->setIsSpirvBuiltin();
      return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragCoord);
    default:
      llvm_unreachable("invalid usage of SV_Position sneaked in");
    }
  }
  // According to DXIL spec, the VertexID SV can only be used by VSIn.
  // According to Vulkan spec, the VertexIndex BuiltIn can only be used by
  // VSIn.
  case hlsl::Semantic::Kind::VertexID: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::VertexIndex);
  }
  // According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
  // HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
  // According to Vulkan spec, the InstanceIndex BuiltIn can only be used by
  // VSIn.
  case hlsl::Semantic::Kind::InstanceID: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
      stageVar->setIsSpirvBuiltin();
      return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::InstanceIndex);
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
      return theBuilder.addStageIOVar(type, sc, name.str());
    default:
      llvm_unreachable("invalid usage of SV_InstanceID sneaked in");
    }
  }
  // According to DXIL spec, the Depth{|GreaterEqual|LessEqual} SV can only be
  // used by PSOut.
  // According to Vulkan spec, the FragDepth BuiltIn can only be used by PSOut.
  case hlsl::Semantic::Kind::Depth:
  case hlsl::Semantic::Kind::DepthGreaterEqual:
  case hlsl::Semantic::Kind::DepthLessEqual: {
    stageVar->setIsSpirvBuiltin();
    // Vulkan requires the DepthReplacing execution mode to write to FragDepth.
    theBuilder.addExecutionMode(entryFunctionId,
                                spv::ExecutionMode::DepthReplacing, {});
    if (semanticKind == hlsl::Semantic::Kind::DepthGreaterEqual)
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::DepthGreater, {});
    else if (semanticKind == hlsl::Semantic::Kind::DepthLessEqual)
      theBuilder.addExecutionMode(entryFunctionId,
                                  spv::ExecutionMode::DepthLess, {});
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth);
  }
  // According to DXIL spec, the ClipDistance/CullDistance SV can be used by
  // all SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn.
  // According to Vulkan spec, the ClipDistance/CullDistance BuiltIn can only
  // be used by VSOut, HS/DS/GS In/Out.
  case hlsl::Semantic::Kind::ClipDistance:
  case hlsl::Semantic::Kind::CullDistance: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
      return theBuilder.addStageIOVar(type, sc, name.str());
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
      llvm_unreachable("should be handled in gl_PerVertex struct");
    default:
      llvm_unreachable(
          "invalid usage of SV_ClipDistance/SV_CullDistance sneaked in");
    }
  }
  // According to DXIL spec, the IsFrontFace SV can only be used by GSOut and
  // PSIn.
  // According to Vulkan spec, the FrontFacing BuiltIn can only be used in
  // PSIn.
  case hlsl::Semantic::Kind::IsFrontFace: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::GSOut:
      return theBuilder.addStageIOVar(type, sc, name.str());
    case hlsl::SigPoint::Kind::PSIn:
      stageVar->setIsSpirvBuiltin();
      return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FrontFacing);
    default:
      llvm_unreachable("invalid usage of SV_IsFrontFace sneaked in");
    }
  }
  // According to DXIL spec, the Target SV can only be used by PSOut.
  // There is no corresponding builtin decoration in SPIR-V. So generate normal
  // Vulkan stage input/output variables.
  case hlsl::Semantic::Kind::Target:
  // An arbitrary semantic is defined by users. Generate normal Vulkan stage
  // input/output variables.
  case hlsl::Semantic::Kind::Arbitrary: {
    return theBuilder.addStageIOVar(type, sc, name.str());
    // TODO: patch constant function in hull shader
  }
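  // Example: "float4 color : SV_Target0" (PSOut) and user-defined semantics
  // such as "TEXCOORD3" both become plain stage input/output variables here;
  // their Location decorations are assigned separately.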
  // According to DXIL spec, the DispatchThreadID SV can only be used by CSIn.
  // According to Vulkan spec, the GlobalInvocationId can only be used in CSIn.
  case hlsl::Semantic::Kind::DispatchThreadID: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::GlobalInvocationId);
  }
  // According to DXIL spec, the GroupID SV can only be used by CSIn.
  // According to Vulkan spec, the WorkgroupId can only be used in CSIn.
  case hlsl::Semantic::Kind::GroupID: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::WorkgroupId);
  }
  // According to DXIL spec, the GroupThreadID SV can only be used by CSIn.
  // According to Vulkan spec, the LocalInvocationId can only be used in CSIn.
  case hlsl::Semantic::Kind::GroupThreadID: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::LocalInvocationId);
  }
  // According to DXIL spec, the GroupIndex SV can only be used by CSIn.
  // According to Vulkan spec, the LocalInvocationIndex can only be used in
  // CSIn.
  case hlsl::Semantic::Kind::GroupIndex: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc,
                                         BuiltIn::LocalInvocationIndex);
  }
  // According to DXIL spec, the OutputControlPointID SV can only be used by
  // HSIn.
  // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  // HS/GS In.
  case hlsl::Semantic::Kind::OutputControlPointID: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId);
  }
  // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
  // DSIn, GSIn, GSOut, and PSIn.
  // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
  // HS/DS/PS In, GS In/Out.
  case hlsl::Semantic::Kind::PrimitiveID: {
    // PrimitiveId requires either the Tessellation or Geometry capability.
    // Need to require one for PSIn.
    if (sigPointKind == hlsl::SigPoint::Kind::PSIn)
      theBuilder.requireCapability(spv::Capability::Geometry);
    // Translate to PrimitiveId BuiltIn for all valid SigPoints.
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId);
  }
  // According to DXIL spec, the TessFactor SV can only be used by PCOut and
  // DSIn.
  // According to Vulkan spec, the TessLevelOuter BuiltIn can only be used in
  // PCOut and DSIn.
  case hlsl::Semantic::Kind::TessFactor: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelOuter);
  }
  // According to DXIL spec, the InsideTessFactor SV can only be used by PCOut
  // and DSIn.
  // According to Vulkan spec, the TessLevelInner BuiltIn can only be used in
  // PCOut and DSIn.
  case hlsl::Semantic::Kind::InsideTessFactor: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelInner);
  }
  // According to DXIL spec, the DomainLocation SV can only be used by DSIn.
  // According to Vulkan spec, the TessCoord BuiltIn can only be used in DSIn.
  case hlsl::Semantic::Kind::DomainLocation: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord);
  }
  // According to DXIL spec, the GSInstanceID SV can only be used by GSIn.
  // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  // HS/GS In.
  case hlsl::Semantic::Kind::GSInstanceID: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId);
  }
  // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
  // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
  case hlsl::Semantic::Kind::SampleIndex: {
    theBuilder.requireCapability(spv::Capability::SampleRateShading);
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId);
  }
  // According to DXIL spec, the StencilRef SV can only be used by PSOut.
  case hlsl::Semantic::Kind::StencilRef: {
    theBuilder.addExtension(Extension::EXT_shader_stencil_export,
                            stageVar->getSemanticStr(), srcLoc);
    theBuilder.requireCapability(spv::Capability::StencilExportEXT);
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT);
  }
  // According to DXIL spec, the Barycentrics SV can only be used by PSIn.
  case hlsl::Semantic::Kind::Barycentrics: {
    theBuilder.addExtension(Extension::AMD_shader_explicit_vertex_parameter,
                            stageVar->getSemanticStr(), srcLoc);
    stageVar->setIsSpirvBuiltin();

    // Select the correct builtin according to the interpolation mode
    auto bi = BuiltIn::Max;
    if (decl->hasAttr<HLSLNoPerspectiveAttr>()) {
      if (decl->hasAttr<HLSLCentroidAttr>()) {
        bi = BuiltIn::BaryCoordNoPerspCentroidAMD;
      } else if (decl->hasAttr<HLSLSampleAttr>()) {
        bi = BuiltIn::BaryCoordNoPerspSampleAMD;
      } else {
        bi = BuiltIn::BaryCoordNoPerspAMD;
      }
    } else {
      if (decl->hasAttr<HLSLCentroidAttr>()) {
        bi = BuiltIn::BaryCoordSmoothCentroidAMD;
      } else if (decl->hasAttr<HLSLSampleAttr>()) {
        bi = BuiltIn::BaryCoordSmoothSampleAMD;
      } else {
        bi = BuiltIn::BaryCoordSmoothAMD;
      }
    }

    return theBuilder.addStageBuiltinVar(type, sc, bi);
  }
  // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
  // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut and
  // PSIn.
  case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
      return theBuilder.addStageIOVar(type, sc, name.str());
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
      theBuilder.requireCapability(spv::Capability::Geometry);
      stageVar->setIsSpirvBuiltin();
      return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer);
    default:
      llvm_unreachable("invalid usage of SV_RenderTargetArrayIndex sneaked in");
    }
  }
  // According to DXIL spec, the ViewportArrayIndex SV can only be used by
  // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
  // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
  // GSOut and PSIn.
  case hlsl::Semantic::Kind::ViewPortArrayIndex: {
    switch (sigPointKind) {
    case hlsl::SigPoint::Kind::VSIn:
    case hlsl::SigPoint::Kind::VSOut:
    case hlsl::SigPoint::Kind::HSCPIn:
    case hlsl::SigPoint::Kind::HSCPOut:
    case hlsl::SigPoint::Kind::PCOut:
    case hlsl::SigPoint::Kind::DSIn:
    case hlsl::SigPoint::Kind::DSCPIn:
    case hlsl::SigPoint::Kind::DSOut:
    case hlsl::SigPoint::Kind::GSVIn:
      return theBuilder.addStageIOVar(type, sc, name.str());
    case hlsl::SigPoint::Kind::GSOut:
    case hlsl::SigPoint::Kind::PSIn:
      theBuilder.requireCapability(spv::Capability::MultiViewport);
      stageVar->setIsSpirvBuiltin();
      return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex);
    default:
      llvm_unreachable("invalid usage of SV_ViewportArrayIndex sneaked in");
    }
  }
  // According to DXIL spec, the Coverage SV can only be used by PSIn and
  // PSOut.
  // According to Vulkan spec, the SampleMask BuiltIn can only be used in
  // PSIn and PSOut.
  case hlsl::Semantic::Kind::Coverage: {
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask);
  }
  // According to DXIL spec, the ViewID SV can only be used by VSIn, PCIn,
  // HSIn, DSIn, GSIn, PSIn.
  // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
  // VS/HS/DS/GS/PS input.
  case hlsl::Semantic::Kind::ViewID: {
    theBuilder.addExtension(Extension::KHR_multiview,
                            stageVar->getSemanticStr(), srcLoc);
    theBuilder.requireCapability(spv::Capability::MultiView);
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex);
  }
  // According to DXIL spec, the InnerCoverage SV can only be used as PSIn.
  // According to Vulkan spec, the FullyCoveredEXT BuiltIn can only be used as
  // PSIn.
  case hlsl::Semantic::Kind::InnerCoverage: {
    theBuilder.addExtension(Extension::EXT_fragment_fully_covered,
                            stageVar->getSemanticStr(), srcLoc);
    theBuilder.requireCapability(spv::Capability::FragmentFullyCoveredEXT);
    stageVar->setIsSpirvBuiltin();
    return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FullyCoveredEXT);
  }
  default:
    emitError("semantic %0 unimplemented", srcLoc)
        << stageVar->getSemantic()->GetName();
    break;
  }

  return 0;
}

bool DeclResultIdMapper::validateVKBuiltins(const NamedDecl *decl,
                                            const hlsl::SigPoint *sigPoint) {
  bool success = true;

  if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
    // The front end parsing only allows vk::builtin to be attached to a
    // function/parameter/variable; all of them are DeclaratorDecls.
    const auto declType = getTypeOrFnRetType(cast<DeclaratorDecl>(decl));
    const auto loc = builtinAttr->getLocation();

    if (decl->hasAttr<VKLocationAttr>()) {
      emitError("cannot use vk::builtin and vk::location together", loc);
      success = false;
    }

    const llvm::StringRef builtin = builtinAttr->getBuiltIn();

    if (builtin == "HelperInvocation") {
      if (!declType->isBooleanType()) {
        emitError("HelperInvocation builtin must be of boolean type", loc);
        success = false;
      }

      if (sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) {
        emitError(
            "HelperInvocation builtin can only be used as pixel shader input",
            loc);
        success = false;
      }
    } else if (builtin == "PointSize") {
      if (!declType->isFloatingType()) {
        emitError("PointSize builtin must be of float type", loc);
        success = false;
      }

      switch (sigPoint->GetKind()) {
      case hlsl::SigPoint::Kind::VSOut:
      case hlsl::SigPoint::Kind::HSCPIn:
      case hlsl::SigPoint::Kind::HSCPOut:
      case hlsl::SigPoint::Kind::DSCPIn:
      case hlsl::SigPoint::Kind::DSOut:
      case hlsl::SigPoint::Kind::GSVIn:
      case hlsl::SigPoint::Kind::GSOut:
      case hlsl::SigPoint::Kind::PSIn:
        break;
      default:
        emitError("PointSize builtin cannot be used as %0", loc)
            << sigPoint->GetName();
        success = false;
      }
    } else if (builtin == "BaseVertex" || builtin == "BaseInstance" ||
               builtin == "DrawIndex") {
      if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
          !declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)) {
        emitError("%0 builtin must be of 32-bit scalar integer type", loc)
            << builtin;
        success = false;
      }

      if (sigPoint->GetKind() != hlsl::SigPoint::Kind::VSIn) {
        emitError("%0 builtin can only be used in vertex shader input", loc)
            << builtin;
        success = false;
      }
    } else if (builtin == "DeviceIndex") {
      if (getStorageClassForSigPoint(sigPoint) != spv::StorageClass::Input) {
        emitError("%0 builtin can only be used as shader input", loc)
            << builtin;
        success = false;
      }

      if (!declType->isSpecificBuiltinType(BuiltinType::Kind::Int) &&
          !declType->isSpecificBuiltinType(BuiltinType::Kind::UInt)) {
        emitError("%0 builtin must be of 32-bit scalar integer type", loc)
            << builtin;
        success = false;
      }
    }
  }

  return success;
}
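
// Example: "[[vk::builtin("PointSize")]] float ps : PSIZE" passes the checks
// above on a VS output, whereas the same annotation on a CS input would hit
// the "PointSize builtin cannot be used as ..." diagnostic.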
spv::StorageClass
DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
  // This translation is done based on the HLSL reference (see docs/dxil.rst).
  const auto sigPointKind = sigPoint->GetKind();
  const auto signatureKind = sigPoint->GetSignatureKind();
  spv::StorageClass sc = spv::StorageClass::Max;

  switch (signatureKind) {
  case hlsl::DXIL::SignatureKind::Input:
    sc = spv::StorageClass::Input;
    break;
  case hlsl::DXIL::SignatureKind::Output:
    sc = spv::StorageClass::Output;
    break;
  case hlsl::DXIL::SignatureKind::Invalid: {
    // There are some special cases in HLSL (see docs/dxil.rst):
    // SignatureKind is "invalid" for PCIn, HSIn, GSIn, and CSIn.
    switch (sigPointKind) {
    case hlsl::DXIL::SigPointKind::PCIn:
    case hlsl::DXIL::SigPointKind::HSIn:
    case hlsl::DXIL::SigPointKind::GSIn:
    case hlsl::DXIL::SigPointKind::CSIn:
      sc = spv::StorageClass::Input;
      break;
    default:
      llvm_unreachable("Found invalid SigPoint kind for semantic");
    }
    break;
  }
  case hlsl::DXIL::SignatureKind::PatchConstant: {
    // There are some special cases in HLSL (see docs/dxil.rst):
    // SignatureKind is "PatchConstant" for PCOut and DSIn.
    switch (sigPointKind) {
    case hlsl::DXIL::SigPointKind::PCOut:
      // Patch Constant Output (Output of Hull which is passed to Domain).
      sc = spv::StorageClass::Output;
      break;
    case hlsl::DXIL::SigPointKind::DSIn:
      // Domain Shader regular input - Patch Constant data plus system values.
      sc = spv::StorageClass::Input;
      break;
    default:
      llvm_unreachable("Found invalid SigPoint kind for semantic");
    }
    break;
  }
  default:
    llvm_unreachable("Found invalid SigPoint kind for semantic");
  }

  return sc;
}

uint32_t DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
    const DeclaratorDecl *decl, bool *shouldBeAlias, SpirvEvalInfo *info) {
  if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
    // This method is only intended to be used to create SPIR-V variables in
    // the Function or Private storage class.
    assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember());
  }

  const QualType type = getTypeOrFnRetType(decl);

  // Whether we should generate this decl as an alias variable.
  bool genAlias = false;

  if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
    // For ConstantBuffer and TextureBuffer
    if (buffer->isConstantBufferView())
      genAlias = true;
  } else if (TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(type)) {
    genAlias = true;
  }

  if (shouldBeAlias)
    *shouldBeAlias = genAlias;

  if (genAlias) {
    needsLegalization = true;
    createCounterVarForDecl(decl);
    if (info)
      info->setContainsAliasComponent(true);
  }
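
  // Example: a function-local "StructuredBuffer<Data> localSB = globalSB;"
  // (hypothetical names) takes the alias path above: it gets an associated
  // counter variable and flags the module as needing legalization.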
  return typeTranslator.translateType(type);
}

} // end namespace spirv
} // end namespace clang