AzslcEmitter.cpp 63 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366
  1. /*
  2. * Copyright (c) Contributors to the Open 3D Engine Project.
  3. * For complete copyright and license terms please see the LICENSE at the root of this distribution.
  4. *
  5. * SPDX-License-Identifier: Apache-2.0 OR MIT
  6. *
  7. */
  8. #include "AzslcEmitter.h"
  9. #include <tuple>
  10. #include <cmath>
  11. #include <filesystem>
  12. namespace StdFs = std::filesystem;
  13. // We should only include the base platform emitter
  14. // Every specific implementation is supplied via a factory get method
  15. #include "AzslcPlatformEmitter.h"
  16. namespace AZ
  17. {
  18. namespace
  19. {
  20. // spec: https://github.com/microsoft/DirectXShaderCompiler/blob/master/docs/SPIR-V.rst#subpass-inputs
  21. /// extract: "Subpasses are read through two new builtin resource types, available only in pixel shader"
  22. // because we have a unified file for all stages, we need to ensure the source remains buildable for other stages.
  23. static constexpr char StubSubpassInputTypes[] = R"(
  24. #if !defined(AZ_USE_SUBPASSINPUT)
  25. class SubpassInputStub
  26. {
  27. float4 SubpassLoad(){return (float4)0;}
  28. };
  29. class SubpassInputStubMS
  30. {
  31. float4 SubpassLoad(int sampleIndex){return (float4)0;}
  32. };
  33. #define SubpassInput SubpassInputStub
  34. #define SubpassInputMS SubpassInputStubMS
  35. #endif
  36. )";
  37. }
  38. }
  39. namespace AZ::ShaderCompiler
  40. {
  41. // to activate argument dependent lookup from template utilities in AzslcUtils, this must be in a reachable namespace
  42. Streamable& operator << (Streamable& out, const AttributeInfo::Argument& arg)
  43. {
  44. if (holds_alternative<string>(arg))
  45. {
  46. return out << get<string>(arg);
  47. }
  48. else if (holds_alternative<ConstNumericVal>(arg))
  49. {
  50. const auto& constVal = get<ConstNumericVal>(arg);
  51. if (holds_alternative<int32_t>(constVal) || holds_alternative<uint32_t>(constVal))
  52. {
  53. return out << ExtractValueAsInt64(constVal, std::numeric_limits<int64_t>::min());
  54. }
  55. else if (holds_alternative<float>(constVal))
  56. {
  57. return out << ExtractValueAsFloat(constVal, std::numeric_limits<float>::infinity());
  58. }
  59. }
  60. else if (holds_alternative<bool>(arg))
  61. {
  62. return out << get<bool>(arg);
  63. }
  64. return out;
  65. }
  66. Streamable& operator << (Streamable& out, const AttributeInfo& attr)
  67. {
  68. if (!attr.m_namespace.empty())
  69. {
  70. out << attr.m_namespace << "::";
  71. }
  72. out << attr.m_attribute;
  73. if (!attr.m_argList.empty())
  74. {
  75. out << "(" << Join(attr.m_argList, ", ") << ")";
  76. }
  77. return out;
  78. }
  79. using std::for_each;
  80. void CodeEmitter::Run(const Options& options)
  81. {
  82. const uint32_t numOf32bitConst = GetNumberOf32BitConstants(options, m_ir->m_rootConstantStructUID);
  83. const RootSigDesc rootSig = BuildSignatureDescription(options, numOf32bitConst);
  84. SetupScopeMigrations(options);
  85. SetupOptionsSpecializationId(options);
  86. // Emit global attributes
  87. for (const auto& attr : m_ir->m_symbols.GetGlobalAttributeList())
  88. {
  89. EmitAttribute(attr);
  90. }
  91. if (m_ir->m_sema.m_subpassInputSeen)
  92. {
  93. m_out << StubSubpassInputTypes << "\n";
  94. }
  95. EmitGetterFunctionDeclarationsForRootConstants(m_ir->m_rootConstantStructUID);
  96. // loop on the (mostly) user-defined order of code entities, and emit
  97. for (const IdentifierUID& iteratedSymbolUid : m_ir->m_symbols.GetOrderedSymbols())
  98. {
  99. const QualifiedNameView iteratedSymbolName = iteratedSymbolUid.GetName();
  100. const Kind iteratedSymbolKind = m_ir->GetKind(iteratedSymbolUid);
  101. switch (iteratedSymbolKind)
  102. {
  103. // top-level enums, structs and classes, as well as immediate-type-declaration enum/structs (`struct S{} s;`)
  104. case Kind::Interface:
  105. case Kind::Struct:
  106. case Kind::Class:
  107. case Kind::Enum:
  108. {
  109. if (IsTopLevelThroughTranslation(iteratedSymbolUid))
  110. {
  111. EmitPreprocessorLineDirective(iteratedSymbolName);
  112. auto* classInfo = m_ir->GetSymbolSubAs<ClassInfo>(iteratedSymbolName);
  113. iteratedSymbolKind == Kind::Enum ?
  114. EmitEnum(iteratedSymbolUid, *classInfo, options)
  115. : EmitStruct(*classInfo, iteratedSymbolName, options);
  116. }
  117. break;
  118. }
  119. // typedefs
  120. case Kind::TypeAlias:
  121. {
  122. if (IsTopLevelThroughTranslation(iteratedSymbolUid))
  123. {
  124. EmitPreprocessorLineDirective(iteratedSymbolName);
  125. auto* aliasInfo = m_ir->GetSymbolSubAs<TypeAliasInfo>(iteratedSymbolName);
  126. EmitTypeAlias(iteratedSymbolUid, *aliasInfo, options);
  127. }
  128. break;
  129. }
  130. // global variables
  131. case Kind::Variable:
  132. {
  133. if (IsTopLevelThroughTranslation(iteratedSymbolUid))
  134. {
  135. auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(iteratedSymbolName);
  136. if (varInfo->CheckHasStorageFlag(StorageFlag::Enumerator))
  137. { // Enumerators have already been emitted in the Kind::Enum case.
  138. // They act as static const, but no need to emit them here
  139. break;
  140. }
  141. // Option variables are emitted after the ShaderVariantFallback SRG
  142. if (varInfo->CheckHasStorageFlag(StorageFlag::Option))
  143. {
  144. EmitShaderVariantOptionVariableDeclaration(iteratedSymbolUid, options);
  145. break;
  146. }
  147. // a non-global extern is an SRG-variable. it should be emitted by EmitSRG
  148. bool global = IsGlobal(iteratedSymbolName);
  149. if (!global && !varInfo->StorageFlagIsLocalLinkage(global || varInfo->m_srgMember))
  150. {
  151. break;
  152. }
  153. EmitPreprocessorLineDirective(iteratedSymbolName);
  154. EmitVariableDeclaration(*varInfo, iteratedSymbolUid, options, VarDeclHasFlag(VarDeclHas::Initializer));
  155. m_out << ";\n";
  156. }
  157. break;
  158. }
  159. // SRG
  160. case Kind::ShaderResourceGroup:
  161. {
  162. EmitPreprocessorLineDirective(iteratedSymbolName);
  163. auto* srgSub = m_ir->GetSymbolSubAs<SRGInfo>(iteratedSymbolName);
  164. EmitSRG(*srgSub, iteratedSymbolUid, options, rootSig);
  165. break;
  166. }
  167. // function
  168. case Kind::Function:
  169. {
  170. EmitPreprocessorLineDirective(iteratedSymbolName);
  171. auto* funcSub = m_ir->GetSymbolSubAs<FunctionInfo>(iteratedSymbolName);
  172. const bool alreadyDeclared = AlreadyEmittedFunctionDeclaration(iteratedSymbolUid);
  173. assert(!funcSub->IsEmpty());
  174. const EmitFunctionAs form = (funcSub->HasUniqueDeclarationThroughDefinition() || alreadyDeclared) ?
  175. EmitFunctionAs::Definition : EmitFunctionAs::Declaration;
  176. EmitFunction(*funcSub, iteratedSymbolUid, form, options);
  177. break;
  178. }
  179. default: break;
  180. } // end switch on code entities kind
  181. } // end for all entities
  182. EmitRootConstants(rootSig, options);
  183. EmitShaderVariantOptionGetters(options);
  184. if (options.m_emitRootSig)
  185. {
  186. m_out << GetPlatformEmitter().GetRootSig(*this, rootSig, options, BindingPair::Set::Merged);
  187. }
  188. }
  189. //! azslSymbolName is going to be considered as a startup point of migration (e.g a scope)
  190. //! and all symbols under that scope will be migrated equivalently.
  191. void CodeEmitter::MigrateASTSubTree(const IdentifierUID& azslSymbol, QualifiedNameView landingScope)
  192. {
  193. m_translations.RegisterLandingScope(azslSymbol, landingScope);
  194. // special case for non-scoped enums because they have dependent symbols that are not children
  195. if (m_ir->GetKind(azslSymbol) == Kind::Enum)
  196. {
  197. auto* sub = m_ir->GetSymbolSubAs<ClassInfo>(azslSymbol.GetName());
  198. bool scoped = sub->Get<EnumerationInfo>()->m_isScoped;
  199. if (!scoped)
  200. {
  201. // migrate enumerators too
  202. for (auto& enumerator : sub->GetMemberFields())
  203. {
  204. m_translations.RegisterLandingScope(enumerator, landingScope);
  205. }
  206. }
  207. }
  208. }
  209. bool CodeEmitter::IsTopLevelThroughTranslation(const IdentifierUID& uid) const
  210. {
  211. return m_translations.GetLandingScope(uid.GetName()) == QualifiedNameView{"/"};
  212. }
  213. //! setup all scope migrations (srg content to global, local structs to global)
  214. void CodeEmitter::SetupScopeMigrations(const Options& options)
  215. {
  216. m_translations.SetAccessSymbolQueryFunctor([=](QualifiedNameView qnv){return m_ir->GetKindInfo(IdentifierUID{qnv});});
  217. m_translations.SetGetSeenatFunctor([=](QualifiedNameView qnv) -> vector<Seenat>&
  218. {
  219. auto* uidkind = m_ir->GetIdAndKindInfo(qnv);
  220. if (uidkind)
  221. {
  222. return uidkind->second.GetSeenats();
  223. }
  224. static vector<Seenat> s_empty;
  225. return s_empty;
  226. });
  227. auto globalScope = QualifiedNameView{ "/" };
  228. // Global root constant custom behavior
  229. for (auto&[uid, info] : m_ir->GetOrderedSymbolsOfSubType_2<VarInfo>())
  230. {
  231. bool isAlreadyGlobal = IsGlobal(uid.GetName());
  232. bool isNotContainedInType = !m_ir->IsNestedStructOrEnum(uid);
  233. if (isAlreadyGlobal && isNotContainedInType)
  234. {
  235. auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(uid.GetName());
  236. const auto& rootconstantClassInfo = *m_ir->GetSymbolSubAs<ClassInfo>(m_ir->m_rootConstantStructUID.m_name);
  237. if (options.m_rootConstantsMaxSize && rootconstantClassInfo.HasMember(uid.GetNameLeaf())
  238. && varInfo->CheckHasAllStorageFlags({ StorageFlag::Rootconstant }))
  239. {
  240. IdentifierUID constructRootconstantReference = *rootconstantClassInfo.FindMemberFromLeafName(uid.GetNameLeaf());
  241. m_translations.RegisterLandingScope(uid, m_ir->m_rootConstantStructUID.GetName());
  242. m_translations.AddCustomBehavior(uid.GetName(),
  243. BehaviorEvent::OnReference,
  244. [constructRootconstantReference](QualifiedNameView, UsageContext, string proposition, ssize_t)
  245. {
  246. // Construct the rootconstant member name which is declared as global variable with the use of get functions.
  247. // Type _g_MyRootConstVar = GetShaderRootConst_Member(); So now we just return _g_MyRootConstVar;
  248. return GetGlobalRootConstantVarName(constructRootconstantReference.GetName());
  249. });
  250. }
  251. }
  252. }
  253. // all SRGs -> erased. Migrate all their contents to global.
  254. for (auto& [srgUID, srgInfo] : m_ir->GetOrderedSymbolsOfSubType_2<SRGInfo>())
  255. {
  256. array<decltype(srgInfo->m_structs)*, 6> allSrgMembersUidArrays = {&srgInfo->m_structs,
  257. &srgInfo->m_srViews,
  258. &srgInfo->m_samplers,
  259. &srgInfo->m_CBs,
  260. &srgInfo->m_nonexternVariables,
  261. &srgInfo->m_functions};
  262. for (auto& array : allSrgMembersUidArrays)
  263. {
  264. for (auto& member : *array)
  265. {
  266. auto globalScope = QualifiedNameView{"/"};
  267. MigrateASTSubTree(member, globalScope);
  268. }
  269. }
  270. // variables get special treatment in case of non-emitConstantBufferBody, because SRG-constants go in a generated-struct: <SRGNAME>_SRGConstantStruct
  271. for (auto& member : srgInfo->m_implicitStruct.GetMemberFields())
  272. {
  273. auto globalScope = QualifiedNameView{"/"};
  274. auto constantsStruct = MakeSrgConstantsStructName(srgUID);
  275. MigrateASTSubTree(member, options.m_emitConstantBufferBody ? globalScope : QualifiedNameView{constantsStruct});
  276. }
  277. // add a special behavior for CB, because under cb-body switch, CBs are views, thus must be indexed.
  278. if (options.m_emitConstantBufferBody)
  279. {
  280. for_each(srgInfo->m_CBs.begin(), srgInfo->m_CBs.end(), [this](const IdentifierUID& viewUid)
  281. {
  282. auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(viewUid.GetName());
  283. // all buffer types references must be mutated to [0]
  284. assert(varInfo->GetTypeClass() == TypeClass::ConstantBuffer);
  285. m_translations.AddCustomBehavior(viewUid.GetName(),
  286. BehaviorEvent::OnReference,
  287. [this](QualifiedNameView, UsageContext, const string& proposition, ssize_t tokenId)
  288. {
  289. // in cb-body mode, CB are converted to views, a non-subscripted expression
  290. // needs to be abstracted away (by adding [0]) to feel like a constant buffer access.
  291. // but since we can have arrays of constant buffers in SRG, the user may have already subscripted
  292. // the access, in that case we can't double the subscript or we'll emit an ill-formed program.
  293. auto* ast = m_ir->m_tokenMap.GetNode(tokenId);
  294. return proposition + (IsNextToArrayAccessExpression(ast) ? "" : "[0]");
  295. });
  296. });
  297. }
  298. else // in the else-case, we have dedicated structures for SRG constants, and references need a custom behavior
  299. {
  300. for_each(srgInfo->m_implicitStruct.GetMemberFields().begin(), srgInfo->m_implicitStruct.GetMemberFields().end(), [this, srgUID=srgUID](IdentifierUID fieldUid)
  301. {
  302. // the natural suggestion from the translation system is going to be
  303. // to mutate MyRsc::m_f4 to MyRsc_SRGConstantStruct::MyRsc_m_f4
  304. // which is the "address" of the variable declaration, but not where the instance reside.
  305. // we need to mutate this to MyRsc_SRGConstantBuffer.MyRsc_m_f4
  306. m_translations.AddCustomBehavior(fieldUid.GetName(),
  307. BehaviorEvent::OnReference,
  308. [this, srgUID=srgUID](QualifiedNameView, UsageContext, string proposition, ssize_t)
  309. {
  310. string constantBufferId = UnMangle(MakeSrgConstantsCBName(srgUID));
  311. string translatedFieldId = UnMangle(string{ExtractLeaf(ReMangle(proposition))});
  312. return constantBufferId + "." + translatedFieldId;
  313. });
  314. });
  315. }
  316. }
  317. // structs/classes in functions/arguments/generic-parameters, may not be valid HLSL scopes to hold types -> migrate them to global
  318. for (auto& [uid, info] : m_ir->GetOrderedSymbolsOfSubType_2<ClassInfo>())
  319. {
  320. bool isAlreadyGlobal = IsGlobal(uid.GetName());
  321. bool isNotContainedInType = !m_ir->IsNestedStructOrEnum(uid);
  322. if (!isAlreadyGlobal && isNotContainedInType)
  323. {
  324. MigrateASTSubTree(uid, globalScope);
  325. }
  326. }
  327. // DXC has made into an error the definition of typedef in classes -> move them all to global
  328. for (auto& [typeAliasUid, typeAliasInfo] : m_ir->GetOrderedSymbolsOfSubType_2<TypeAliasInfo>())
  329. {
  330. if (!IsGlobal(typeAliasUid.GetName()))
  331. {
  332. MigrateASTSubTree(typeAliasUid, globalScope);
  333. }
  334. }
  335. }
  336. void CodeEmitter::EmitShaderVariantOptionVariableDeclaration(const IdentifierUID& symbolUid, const Options& options) const
  337. {
  338. assert(m_ir->GetKind(symbolUid) == Kind::Variable);
  339. assert(IsTopLevelThroughTranslation(symbolUid));
  340. auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(symbolUid.GetName());
  341. if (options.m_useSpecializationConstantsForOptions && varInfo->m_specializationId >= 0)
  342. {
  343. m_out << GetPlatformEmitter().GetSpecializationConstant(*this, symbolUid, options);
  344. }
  345. else
  346. {
  347. EmitGetShaderKeyFunctionDeclaration(symbolUid, varInfo->GetTypeRefInfo());
  348. m_out << ";\n\n";
  349. m_out << "#if defined(" + JoinAllNestedNamesWithUnderscore(symbolUid.m_name) + "_OPTION_DEF)\n";
  350. EmitVariableDeclaration(*varInfo, symbolUid, options, VarDeclHasFlag(VarDeclHas::OptionDefine));
  351. m_out << "_OPTION_DEF ;\n#else\n";
  352. EmitVariableDeclaration(*varInfo, symbolUid, options, VarDeclHasFlag(VarDeclHas::OptionDefine) | VarDeclHas::Initializer);
  353. m_out << ";\n#endif\n";
  354. }
  355. }
  356. void CodeEmitter::EmitShaderVariantOptionGetters(const Options& options) const
  357. {
  358. vector<pair<IdentifierUID, VarInfo*>> symbols;
  359. // browse all variables
  360. for (const auto& [uid, varInfo] : m_ir->m_symbols.GetOrderedSymbolsOfSubType_2<VarInfo>())
  361. {
  362. // For now only emit top level options
  363. if (!IsTopLevelThroughTranslation(uid) || !varInfo->CheckHasStorageFlag(StorageFlag::Option))
  364. {
  365. continue;
  366. }
  367. symbols.emplace_back(uid, varInfo);
  368. }
  369. if (!symbols.empty())
  370. {
  371. m_out << "// Generated code: ShaderVariantOptions fallback value getters:\n";
  372. auto shaderOptions = GetVariantList(options, true);
  373. for (uint32_t shaderOptionIndex = 0; shaderOptionIndex < symbols.size(); ++shaderOptionIndex)
  374. {
  375. const auto& [uid, varInfo] = symbols[shaderOptionIndex];
  376. if (options.m_useSpecializationConstantsForOptions && varInfo->m_specializationId >= 0)
  377. {
  378. continue;
  379. }
  380. const auto keySizeInBits = shaderOptions["ShaderOptions"][shaderOptionIndex]["keySize"].asUInt();
  381. const auto keyOffsetBits = shaderOptions["ShaderOptions"][shaderOptionIndex]["keyOffset"].asUInt();
  382. const auto defaultValue = shaderOptions["ShaderOptions"][shaderOptionIndex]["defaultValue"].asString();
  383. EmitGetShaderKeyFunction(m_shaderVariantFallbackUid, uid, keySizeInBits, keyOffsetBits, defaultValue, varInfo->GetTypeRefInfo());
  384. }
  385. }
  386. }
  387. void CodeEmitter::EmitGetFunctionsForRootConstants(const ClassInfo& rootConstInfo, string_view bufferName) const
  388. {
  389. for (const auto& memberVar : rootConstInfo.GetMemberFields())
  390. {
  391. const auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(memberVar.m_name);
  392. if (!varInfo->m_isPublic)
  393. {
  394. continue;
  395. }
  396. // Construct function return type and function name
  397. m_out << GetTranslatedName(varInfo->GetTypeRefInfo(), UsageContext::ReferenceSite) << " ";
  398. m_out << GetRootConstFunctionName(memberVar);
  399. m_out << "\n{\n";
  400. // The get function of root constants should return members of root ConstantBuffer object
  401. // Ex: return rootCB.ExampleSRG_myRootCstfloat4;
  402. m_out << " return " << bufferName << "." << memberVar.GetNameLeaf() << ";\n";
  403. m_out << "}\n\n";
  404. }
  405. }
  406. void CodeEmitter::EmitRootConstants(const RootSigDesc& rootSig, const Options& options) const
  407. {
  408. if (const auto* rootConstInfo = m_ir->GetSymbolSubAs<ClassInfo>(m_ir->m_rootConstantStructUID.GetName()))
  409. {
  410. if (!rootConstInfo->GetMemberFields().empty())
  411. {
  412. m_out << GetPlatformEmitter().GetRootConstantsView(*this, rootSig, options, BindingPair::Set::Merged);
  413. // Emit the get function definitions which returns root const member
  414. EmitGetFunctionsForRootConstants(*rootConstInfo,
  415. GetTranslatedName(RootConstantsViewName, UsageContext::ReferenceSite));
  416. }
  417. }
  418. }
  419. void CodeEmitter::EmitPreprocessorLineDirective(size_t azslLineNumber)
  420. {
  421. if (azslLineNumber == 0)
  422. return; // protect for this invalid case. seems to happen for "virtual" symbols (like OverloadSet)
  423. size_t supposedVirtualLine = m_lineFinder->GetVirtualLineNumber(azslLineNumber);
  424. size_t curHlslLine = m_out.GetLineCount() + 1; // "lines" is a space that is 1-based indexed.
  425. auto lastEmitted = Infimum(m_alreadyEmittedPreprocessorLineDirectives, curHlslLine);
  426. if (lastEmitted != m_alreadyEmittedPreprocessorLineDirectives.cend())
  427. {
  428. // verify if we can skip the line emission if current stream line feed is still in sync with expectations
  429. // image:
  430. // in sync out of sync
  431. // 1 | #line 1 1 | #line 1
  432. // 2 | code 2 | code
  433. // 3 | newSymbol 3 | code
  434. // 4 | #line 2
  435. // 5 | newSymbol
  436. size_t curHlslPhysicalDistance = curHlslLine - lastEmitted->first;
  437. size_t lastVirtualSet = lastEmitted->second;
  438. size_t nonAdjustedCurrentLandingLine = lastVirtualSet + curHlslPhysicalDistance - 1; // -1 because line directives specify the NEXT line
  439. if (nonAdjustedCurrentLandingLine == supposedVirtualLine)
  440. return; // no need to emit. we can skip
  441. }
  442. // get the original file as absolute path:
  443. const string& originalFileName = StdFs::absolute( m_lineFinder->GetVirtualFileName(azslLineNumber) ).lexically_normal().generic_string();
  444. // emit the line:
  445. m_out << "#line " << supposedVirtualLine << " \"" << originalFileName << "\"\n";
  446. // remember it:
  447. m_alreadyEmittedPreprocessorLineDirectives[curHlslLine] = supposedVirtualLine;
  448. }
  449. void CodeEmitter::EmitPreprocessorLineDirective(QualifiedNameView symbolName)
  450. {
  451. IdAndKind* idAndkindInfo = m_ir->GetIdAndKindInfo(symbolName);
  452. KindInfo& info = idAndkindInfo->second;
  453. const auto origSourceLine = info.VisitSub(GetOrigSourceLine_Visitor{});
  454. EmitPreprocessorLineDirective(origSourceLine);
  455. }
  456. string CodeEmitter::EmitInheritanceList(const ClassInfo& clInfo)
  457. {
  458. string hlsl = clInfo.HasAnyBases() ? " : " : "";
  459. vector<string> mutatedBaseNames;
  460. TransformCopy(clInfo.GetBases(), mutatedBaseNames,
  461. [&](const IdentifierUID& uid)
  462. {
  463. return GetTranslatedName(uid, UsageContext::ReferenceSite);
  464. });
  465. hlsl += Join(mutatedBaseNames, ", ");
  466. return hlsl;
  467. }
  468. void CodeEmitter::EmitStruct(const ClassInfo& classInfo, string_view structuredSymName, const Options& options)
  469. {
  470. auto HlslStructuredDelcTagFromKind = [](Kind k)
  471. {
  472. switch (k)
  473. {
  474. case Kind::Struct: return "struct";
  475. case Kind::Class: return "class";
  476. case Kind::Interface: return "interface";
  477. default: return " ";
  478. }
  479. };
  480. const bool hasName = (structuredSymName.length() > 0);
  481. const auto tabs = " ";
  482. if (hasName)
  483. {
  484. EmitAllAttachedAttributes(IdentifierUID { QualifiedNameView{structuredSymName} });
  485. m_out << HlslStructuredDelcTagFromKind(classInfo.m_kind) << " "
  486. << GetTranslatedName(QualifiedNameView{structuredSymName}, UsageContext::DeclarationSite)
  487. << EmitInheritanceList(classInfo)
  488. << "\n{\n"; // conclusion of "class X : ::Stuff {"
  489. }
  490. for (const IdentifierUID& memberUid : classInfo.GetOrderedMembers())
  491. {
  492. if (structuredSymName.empty() || m_translations.GetLandingScope(memberUid.GetName()) == structuredSymName)
  493. {
  494. auto& [uid, info] = *m_ir->GetIdAndKindInfo(memberUid.GetName());
  495. if (info.IsKindOneOf(Kind::Class, Kind::Struct, Kind::Interface))
  496. {
  497. // recurse
  498. EmitStruct(info.GetSubRefAs<ClassInfo>(), uid.GetName(), options);
  499. }
  500. else if (info.IsKindOneOf(Kind::TypeAlias))
  501. {
  502. EmitTypeAlias(uid, info.GetSubRefAs<TypeAliasInfo>(), options);
  503. }
  504. else if (info.IsKindOneOf(Kind::Enum))
  505. {
  506. EmitEnum(uid, info.GetSubRefAs<ClassInfo>(), options);
  507. }
  508. else if (info.IsKindOneOf(Kind::Variable))
  509. {
  510. const auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(uid.m_name);
  511. assert(varInfo);
  512. m_out << tabs;
  513. EmitVariableDeclaration(*varInfo, uid, options, VarDeclHasFlag(VarDeclHas::HlslSemantics));
  514. m_out << ";\n";
  515. }
  516. else if (info.IsKindOneOf(Kind::Function))
  517. {
  518. const auto* func = m_ir->GetSymbolSubAs<FunctionInfo>(uid.m_name);
  519. assert(func);
  520. m_out << tabs;
  521. EmitFunctionAs form = func->HasUniqueDeclarationThroughDefinition() ? EmitFunctionAs::Definition : EmitFunctionAs::Declaration;
  522. EmitFunction(*func, uid, form, options);
  523. }
  524. else
  525. {
  526. assert(false);// "Unhandled KindInfo case in struct emission must be handled properly!"
  527. }
  528. }
  529. }
  530. if (hasName)
  531. {
  532. m_out << "};\n\n";
  533. }
  534. }
  535. void CodeEmitter::EmitAllAttachedAttributes(const IdentifierUID& uid, Except omissionList/*= {}*/) const
  536. {
  537. if (auto attrList = m_ir->m_symbols.GetAttributeList(uid))
  538. {
  539. for_each(attrList->begin(), attrList->end(),
  540. [=](auto&& attrInfo)
  541. {
  542. if (!IsIn(attrInfo.m_attribute, omissionList))
  543. {
  544. EmitAttribute(attrInfo);
  545. }
  546. });
  547. }
  548. }
  549. void CodeEmitter::EmitGetterFunctionDeclarationsForRootConstants(const IdentifierUID& uid) const
  550. {
  551. if (const auto* rootConstInfo = m_ir->GetSymbolSubAs<ClassInfo>(uid.GetName()))
  552. {
  553. for (const auto& memberVar : rootConstInfo->GetMemberFields())
  554. {
  555. const auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(memberVar.m_name);
  556. if (!varInfo->m_isPublic)
  557. {
  558. continue;
  559. }
  560. m_out << GetTranslatedName(varInfo->GetTypeRefInfo(), UsageContext::ReferenceSite) << " ";
  561. m_out << GetRootConstFunctionName(memberVar);
  562. m_out << ";\n\n";
  563. // Emit a global variable to call the above declared function ex: static const float _gRootConstA = GetRootConst();
  564. // The global variable will be the reference variable to access the root constant
  565. m_out << "static const ";
  566. m_out << GetTranslatedName(varInfo->GetTypeRefInfo(), UsageContext::ReferenceSite) << " ";
  567. m_out << GetGlobalRootConstantVarName(memberVar.m_name) + " = ";
  568. m_out << GetRootConstFunctionName(memberVar);
  569. m_out << ";\n\n";
  570. }
  571. }
  572. }
  573. // special adapter version for AttributeInfo argument
  574. static string Undecorate(string_view decoration, const AttributeInfo::Argument& arg)
  575. {
  576. std::stringstream ss;
  577. MakeOStreamStreamable soss(ss);
  578. (Streamable&)soss << arg;
  579. return string{AZ::Undecorate(decoration, ss.str())};
  580. }
  581. void CodeEmitter::EmitAttribute(const AttributeInfo& attrInfo) const
  582. {
  583. return EmitAttribute(attrInfo, m_out);
  584. }
  585. void CodeEmitter::EmitAttribute(const AttributeInfo& attrInfo, Streamable& outstream)
  586. {
  587. if (attrInfo.m_attribute == "verbatim")
  588. {
  589. outstream << (attrInfo.m_argList.begin() == attrInfo.m_argList.end() ? "" : Unescape(Undecorate("\"", *attrInfo.m_argList.begin())));
  590. auto outVer = [&](const AttributeInfo::Argument& arg) { outstream << " " << Unescape(Undecorate("\"", arg)); };
  591. for_each(std::next(attrInfo.m_argList.begin()), attrInfo.m_argList.end(), outVer);
  592. outstream << '\n';
  593. }
  594. else if (attrInfo.m_attribute == "output_format")
  595. {
  596. if (!attrInfo.m_argList.empty())
  597. {
  598. if (holds_alternative<string>(attrInfo.m_argList[0]))
  599. {
  600. string poFormat{ Trim(get<string>(attrInfo.m_argList[0]), "\"") };
  601. outstream << "#pragma OutputFormatHint(default " << poFormat << ")\n";
  602. }
  603. else if (attrInfo.m_argList.size() > 1 &&
  604. holds_alternative<ConstNumericVal>(attrInfo.m_argList[0]) &&
  605. holds_alternative<string>(attrInfo.m_argList[1]))
  606. {
  607. const auto rtIndex = ExtractValueAsInt64(get<ConstNumericVal>(attrInfo.m_argList[0]), std::numeric_limits<int64_t>::min());
  608. if (rtIndex >= 0 && rtIndex <= 7)
  609. {
  610. string poFormat{ Trim(get<string>(attrInfo.m_argList[1]), "\"") };
  611. outstream << "#pragma OutputFormatHint(target " << rtIndex << " " << poFormat << ")\n";
  612. }
  613. }
  614. }
  615. }
  616. else if (attrInfo.m_attribute == "partial")
  617. {
  618. // Reserved for ShaderResourceGroup use. Do not re-emit
  619. outstream << "// original attribute: [[" << attrInfo << "]]\n ";
  620. }
  621. else if (attrInfo.m_attribute == "range")
  622. {
  623. // Reserved for integer type option variables. Do not re-emit
  624. outstream << "// original attribute: [[" << attrInfo << "]]\n ";
  625. }
  626. else if (attrInfo.m_attribute == "no_specialization")
  627. {
  628. // Reserved for avoiding specialization of a shader option. Do not re-emit
  629. outstream << "// original attribute: [[" << attrInfo << "]]\n ";
  630. }
  631. else
  632. {
  633. // We don't block any attributes we don't understand - pass them through
  634. outstream << ((attrInfo.m_category == AttributeCategory::Single) ? "[" : "[[")
  635. << attrInfo
  636. << ((attrInfo.m_category == AttributeCategory::Single) ? "]" : "]]")
  637. << "\n";
  638. }
  639. }
  640. void CodeEmitter::EmitTypeAlias(const IdentifierUID& uid, const TypeAliasInfo& aliasInfo, const Options& options) const
  641. {
  642. using SF = StorageFlag;
  643. m_out << "typedef " << GetTranslatedName(aliasInfo.m_canonicalType, UsageContext::ReferenceSite, options)
  644. << " " << GetTranslatedName(uid, UsageContext::DeclarationSite) << ";\n";
  645. }
  646. void CodeEmitter::EmitEnum(const IdentifierUID& uid, const ClassInfo& classInfo, const Options& options)
  647. {
  648. const auto& enumInfo = get<EnumerationInfo>(classInfo.m_subInfo);
  649. EmitAllAttachedAttributes(uid);
  650. m_out << "enum " << (enumInfo.m_isScoped ? "class " : "") << GetTranslatedName(uid.m_name, UsageContext::DeclarationSite) << "\n";
  651. m_out << "{\n";
  652. for (const auto& memberUid : classInfo.GetMemberFields())
  653. {
  654. // The variable can potentially have an initialization declaration as well
  655. const auto* varInfoPtr = m_ir->GetSymbolSubAs<VarInfo>(memberUid.m_name);
  656. EmitVariableDeclaration(*varInfoPtr,
  657. memberUid,
  658. options,
  659. VarDeclHasFlag(VarDeclHas::Initializer) | VarDeclHas::NoType | VarDeclHas::NoModifiers);
  660. m_out << ",\n";
  661. }
  662. m_out << "};\n\n";
  663. }
  664. bool CodeEmitter::AlreadyEmittedFunctionDeclaration(const IdentifierUID& uid) const
  665. {
  666. return m_alreadyEmittedFunctionDeclarations.find(uid) != m_alreadyEmittedFunctionDeclarations.end();
  667. }
  668. bool CodeEmitter::AlreadyEmittedFunctionDefinition(const IdentifierUID& uid) const
  669. {
  670. return m_alreadyEmittedFunctionDefinitions.find(uid) != m_alreadyEmittedFunctionDefinitions.end();
  671. }
  672. void CodeEmitter::EmitFunction(const FunctionInfo& funcSub, const IdentifierUID& uid, EmitFunctionAs entityConfiguration, const Options& options)
  673. {
  674. // reproduce the signature `[attr] [modifiers] rettype [classnameFQN::] Name(params) [semantics]`
  675. bool emitAsDeclaration = entityConfiguration == EmitFunctionAs::Declaration;
  676. bool emitAsDefinition = entityConfiguration == EmitFunctionAs::Definition;
  677. bool riskDoubleEmission = emitAsDeclaration && AlreadyEmittedFunctionDeclaration(uid)
  678. || emitAsDefinition && AlreadyEmittedFunctionDefinition(uid);
  679. bool undefinedFunction = funcSub.IsUndefinedFunction();
  680. if (riskDoubleEmission
  681. || undefinedFunction && emitAsDefinition)
  682. {
  683. return;
  684. }
  685. AstFuncSig* node = funcSub.m_defNode ? funcSub.m_defNode : funcSub.m_declNode;
  686. EmitAllAttachedAttributes(uid);
  687. // emit some modifiers only in case of first declaration/definition
  688. // because: class C{ static vd A(); };
  689. // static vd C::A(){} // ill-formed HLSL. we can't repeat static here.
  690. Modifiers forbidden;
  691. bool firstDecl = !AlreadyEmittedFunctionDeclaration(uid);
  692. if (!firstDecl)
  693. {
  694. forbidden = Modifiers{StorageFlag::Static} | StorageFlag::Inline | StorageFlag::Extern;
  695. }
  696. // emit return type:
  697. m_out << GetTranslatedName(funcSub.m_returnType, UsageContext::ReferenceSite, options, forbidden) << " ";
  698. // emit Name
  699. if (entityConfiguration == EmitFunctionAs::Definition && funcSub.HasDeportedDefinition())
  700. {
  701. // emit fully qualified function name (with classname prefix).
  702. // surrounded by round braces because otherwise clang's greedy parsing will wreak havoc.
  703. // indeed `RetType ::Class::Method() {}` will be parsed as `RetType::Class::Method(){}` one id-expr.
  704. // since we can't use trailing return type in HLSL, the fix is to use round braces.
  705. // solution from https://stackoverflow.com/a/3185232/893406
  706. m_out << "(" << GetTranslatedName(uid, UsageContext::ReferenceSite) << ")";
  707. }
  708. else
  709. {
  710. // emit leaf function name (in AZSL all declaration sites are Identifier and not idExpressions)
  711. m_out << GetTranslatedName(uid, UsageContext::DeclarationSite);
  712. }
  713. // emit parameters:
  714. m_out << "(";
  715. bool inhibitInitializers = AlreadyEmittedFunctionDeclaration(uid) && !emitAsDeclaration;
  716. EmitParameters(funcSub.GetParameters(emitAsDeclaration).begin(), funcSub.GetParameters(emitAsDeclaration).end(), options, !inhibitInitializers);
  717. m_out << ")";
  718. if (const auto hlslSemantic = node->hlslSemantic())
  719. {
  720. m_out << " " << hlslSemantic->getText();
  721. }
  722. if (entityConfiguration == EmitFunctionAs::Declaration)
  723. {
  724. m_out << ";\n";
  725. m_alreadyEmittedFunctionDeclarations.insert(uid);
  726. }
  727. else
  728. {
  729. auto scopeInterval = m_ir->m_scope.m_scopeIntervals[uid];
  730. auto astNode = m_ir->m_tokenMap.GetNode(scopeInterval.a);
  731. auto funcDefNode = ExtractSpecificParent<azslParser::HlslFunctionDefinitionContext>(astNode);
  732. auto blockInterval = funcDefNode->block()->getSourceInterval();
  733. m_out << "\n";
  734. EmitTranspiledTokens(blockInterval);
  735. m_out << "\n\n";
  736. m_alreadyEmittedFunctionDefinitions.insert(uid);
  737. }
  738. }
  739. void CodeEmitter::EmitVariableDeclaration(const VarInfo& varInfo, const IdentifierUID& uid, const Options& options, VarDeclHasFlag declOptions) const
  740. {
  741. // from MSDN: https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
  742. // [Storage_Class] [Type_Modifier] Type Name[Index] [: Semantic] [: Packoffset] [: Register]; [Annotations] [= Initial_Value]
  743. // example of valid HLSL statement:
  744. // static const int dim = 2;
  745. // extern uniform bool stuff[dim][1] : Color0 : register(b0) <int blabla=27; string blacksheep="Hello There";> = {{false, true}};
  746. const ICodeEmissionMutator* codeMutator = m_codeMutator;
  747. const CodeMutation* codeMutation = nullptr;
  748. if (codeMutator && varInfo.m_declNode)
  749. {
  750. const auto tokenIndex = varInfo.m_declNode->start->getTokenIndex();
  751. codeMutation = codeMutator->GetMutation(tokenIndex);
  752. }
  753. if (codeMutation && codeMutation->m_prepend)
  754. {
  755. m_out << codeMutation->m_prepend.value();
  756. }
  757. if (codeMutation && codeMutation->m_replace)
  758. {
  759. m_out << codeMutation->m_replace.value();
  760. }
  761. else
  762. {
  763. EmitAllAttachedAttributes(uid);
  764. // parameter in/out modifiers
  765. if (declOptions & VarDeclHas::InOutModifiers)
  766. {
  767. m_out << GetInputModifier(varInfo.m_typeInfoExt.m_qualifiers) << " ";
  768. }
  769. // type
  770. if (!(declOptions & VarDeclHas::NoType) || (declOptions & VarDeclHas::OptionDefine))
  771. {
  772. auto bannedModifiers = (declOptions & VarDeclHas::NoModifiers) ? ~Modifiers{(StorageFlag::EnumType)0} : Modifiers{};
  773. m_out << GetTranslatedName(varInfo.m_typeInfoExt, UsageContext::ReferenceSite, options, bannedModifiers) + " ";
  774. }
  775. // var name
  776. m_out << GetTranslatedName(uid.m_name, UsageContext::DeclarationSite);
  777. // array dimensions
  778. if (varInfo.m_declNode && !varInfo.m_declNode->ArrayRankSpecifiers.empty())
  779. {
  780. for (auto* rankCtx : varInfo.m_declNode->ArrayRankSpecifiers)
  781. {
  782. // the brackets are included by the rule arrayRankSpecifier
  783. EmitTranspiledTokens(rankCtx->getSourceInterval());
  784. }
  785. }
  786. else if (!varInfo.GetArrayDimensions().Empty())
  787. {
  788. m_out << varInfo.GetArrayDimensions().ToString();
  789. }
  790. // semantics
  791. if (varInfo.m_declNode && (declOptions & VarDeclHas::HlslSemantics))
  792. {
  793. if (auto* semanticOption = varInfo.m_declNode->SemanticOpt)
  794. {
  795. m_out << " " + semanticOption->getText();
  796. }
  797. }
  798. // init clause
  799. if (declOptions & VarDeclHas::OptionDefine)
  800. {
  801. if (declOptions & VarDeclHas::Initializer)
  802. {
  803. m_out << " = " << GetShaderKeyFunctionName(uid);
  804. }
  805. else
  806. {
  807. m_out << " = " << JoinAllNestedNamesWithUnderscore(uid.m_name);
  808. }
  809. }
  810. else if (declOptions & VarDeclHas::Initializer)
  811. {
  812. auto* initClause = varInfo.m_declNode ? varInfo.m_declNode->variableInitializer() : nullptr;
  813. if (initClause)
  814. {
  815. m_out << " ";
  816. EmitTranspiledTokens(initClause->getSourceInterval());
  817. }
  818. else // fallback on a potentially folded value, that has chances to work for constants like enumerators
  819. {
  820. int64_t constVal;
  821. if (TryGetConstExprValueAsInt64(varInfo.m_constVal, constVal))
  822. {
  823. m_out << " = " << constVal;
  824. }
  825. }
  826. }
  827. }
  828. if (codeMutation && codeMutation->m_append)
  829. {
  830. m_out << codeMutation->m_append.value();
  831. }
  832. }
  833. void CodeEmitter::EmitSRGCBUnified(const SRGInfo& srgInfo, IdentifierUID srgId, const Options& options, const RootSigDesc& rootSig)
  834. {
  835. auto bindSet = BindingPair::Set::Merged;
  836. // Use the uId of the SRG instead of a CBV, because we create a dummy placeholder CBV to hold the rest of the declarations:
  837. if (!options.m_emitConstantBufferBody)
  838. {
  839. if (srgInfo.m_implicitStruct.GetMemberFields().size() > 0)
  840. {
  841. const auto& bindInfo = rootSig.Get(srgId);
  842. const QualifiedName implicitStruct = MakeSrgConstantsStructName(srgId);
  843. const QualifiedName implicitCB = MakeSrgConstantsCBName(srgId);
  844. EmitStruct(srgInfo.m_implicitStruct, implicitStruct, options);
  845. const auto spaceX = ", space" + std::to_string(bindInfo.m_registerBinding.m_pair[bindSet].m_logicalSpace);
  846. const auto implicitStructForEmission = GetTranslatedName(implicitStruct, UsageContext::ReferenceSite);
  847. const auto implicitCBForEmission = GetTranslatedName(implicitCB, UsageContext::DeclarationSite);
  848. m_out << "ConstantBuffer<" << implicitStructForEmission << "> " << implicitCBForEmission << " : register(b" << bindInfo.m_registerBinding.m_pair[bindSet].m_registerIndex << spaceX << ");\n\n";
  849. }
  850. for (auto cId : srgInfo.m_CBs)
  851. {
  852. if (!m_ir->GetSymbolSubAs<VarInfo>(cId.GetName())->StorageFlagIsLocalLinkage(true))
  853. {
  854. EmitSRGCB(cId, options, rootSig);
  855. }
  856. }
  857. return;
  858. }
  859. if (srgInfo.m_CBs.empty() && srgInfo.m_implicitStruct.GetMemberFields().empty())
  860. {
  861. return;
  862. }
  863. const auto& bindInfo = rootSig.Get(srgId);
  864. m_out << "ConstantBuffer " << srgInfo.m_declNode->Name->getText() << "_CBContainer : register(b" << bindInfo.m_registerBinding.m_pair[bindSet].m_registerIndex << ")\n{\n";
  865. for (const auto& cId : srgInfo.m_CBs)
  866. {
  867. if (!m_ir->GetSymbolSubAs<VarInfo>(cId.GetName())->StorageFlagIsLocalLinkage(true))
  868. {
  869. const auto& uqName = cId.GetNameLeaf();
  870. m_out << " CBVArrayView " << uqName << ";\n";
  871. }
  872. }
  873. if (!srgInfo.m_implicitStruct.GetMemberFields().empty())
  874. {
  875. m_out << "\n";
  876. EmitStruct(srgInfo.m_implicitStruct, "", options);
  877. }
  878. m_out << "\n};\n";
  879. for (const auto& cId : srgInfo.m_CBs)
  880. {
  881. if (!m_ir->GetSymbolSubAs<VarInfo>(cId.GetName())->StorageFlagIsLocalLinkage(true))
  882. {
  883. const auto* memberInfo = m_ir->GetSymbolSubAs<VarInfo>(cId.m_name);
  884. const auto& cbName = ReplaceSeparators(cId.m_name, Underscore);
  885. const auto& uqName = cId.GetNameLeaf();
  886. if (memberInfo->IsConstantBuffer())
  887. {
  888. const auto& templatedCB = "RegularBuffer<" + GetTranslatedName(memberInfo->GetGenericParameterTypeId(), UsageContext::ReferenceSite) + "> ";
  889. m_out << "static const " << templatedCB << cbName << " = " << templatedCB << "(" << uqName << ");\n";
  890. }
  891. }
  892. }
  893. m_out << "\n\n";
  894. }
  895. void CodeEmitter::EmitSRGCB(const IdentifierUID& cId, const Options& options, const RootSigDesc& rootSig) const
  896. {
  897. EmitAllAttachedAttributes(cId);
  898. auto bindSet = BindingPair::Set::Merged;
  899. assert(!options.m_emitConstantBufferBody);
  900. const auto& bindInfo = rootSig.Get(cId);
  901. const auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(cId.m_name);
  902. auto cbName = ReplaceSeparators(cId.m_name, Underscore);
  903. assert(varInfo->IsConstantBuffer());
  904. // note: instead of redoing this work ad-hoc, EmitText could be used directly on the ext type.
  905. const auto genericType = "<" + GetTranslatedName(varInfo->m_typeInfoExt.m_genericParameter, UsageContext::ReferenceSite) + ">";
  906. const string spaceX = ", space" + std::to_string(bindInfo.m_registerBinding.m_pair[bindSet].m_logicalSpace);
  907. m_out << "ConstantBuffer " << genericType << " " << cbName;
  908. if (bindInfo.m_isUnboundedArray)
  909. {
  910. m_out << "[]";
  911. }
  912. else if (bindInfo.m_registerRange > 1)
  913. {
  914. m_out << "[" << bindInfo.m_registerRange << "]";
  915. }
  916. m_out << " : register(b" << bindInfo.m_registerBinding.m_pair[bindSet].m_registerIndex << spaceX << ");\n\n";
  917. }
  918. void CodeEmitter::EmitSRGSampler(const IdentifierUID& sId, const Options& options, const RootSigDesc& rootSig) const
  919. {
  920. EmitAllAttachedAttributes(sId);
  921. auto bindSet = BindingPair::Set::Merged;
  922. const auto& bindInfo = rootSig.Get(sId);
  923. const auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(sId.m_name);
  924. const string spaceX = ", space" + std::to_string(bindInfo.m_registerBinding.m_pair[bindSet].m_logicalSpace);
  925. m_out << (varInfo->m_samplerState->m_isComparison ? "SamplerComparisonState " : "SamplerState ")
  926. << ReplaceSeparators(sId.m_name, Underscore);
  927. if (bindInfo.m_isUnboundedArray)
  928. {
  929. m_out << "[]";
  930. }
  931. else if (bindInfo.m_registerRange > 1)
  932. {
  933. m_out << "[" << bindInfo.m_registerRange << "]";
  934. }
  935. m_out << " : register(s" << bindInfo.m_registerBinding.m_pair[bindSet].m_registerIndex << spaceX << ")";
  936. m_out << ";\n\n";
  937. }
  938. //! For scope-migration-aware name emission of symbol names
  939. string CodeEmitter::GetTranslatedName(QualifiedNameView mangledName, UsageContext context, ssize_t tokenId /*= NotOverToken*/) const
  940. {
  941. return m_translations.GetTranslatedName(mangledName, context, tokenId);
  942. }
  943. string CodeEmitter::GetTranslatedName(const IdentifierUID& uid, UsageContext context, ssize_t tokenId /*= NotOverToken*/) const
  944. {
  945. return GetTranslatedName(uid.m_name, context, tokenId);
  946. }
  947. string CodeEmitter::GetTranslatedName(const TypeRefInfo& typeRef, UsageContext context, ssize_t tokenId /*= NotOverToken*/) const
  948. {
  949. return GetTranslatedName(typeRef.m_typeId, context, tokenId);
  950. }
  951. string CodeEmitter::GetTranslatedName(const ExtendedTypeInfo& extTypeInfo, UsageContext context, const Options& options, Modifiers forbidden /*= {}*/, ssize_t tokenId /*= NotOverToken*/) const
  952. {
  953. return GetExtendedTypeInfo(extTypeInfo, options, forbidden,
  954. [this, context, tokenId](const TypeRefInfo& tri){ return GetTranslatedName(tri, context, tokenId); });
  955. }
  956. void CodeEmitter::EmitSRGDataView(const IdentifierUID& tId, const Options& options, const RootSigDesc& rootSig) const
  957. {
  958. EmitAllAttachedAttributes(tId, Except{ { "input_attachment_index" } });
  959. auto bindSet = BindingPair::Set::Merged;
  960. auto& bindInfo = rootSig.Get(tId);
  961. auto* varInfo = m_ir->GetSymbolSubAs<VarInfo>(tId.m_name);
  962. string varType = GetTranslatedName(varInfo->m_typeInfoExt, UsageContext::DeclarationSite, options);
  963. auto registerTypeLetter = ToLower(BindingType::ToStr(RootParamTypeToBindingType(bindInfo.m_type)));
  964. optional<string> stringifiedLogicalSpace = std::to_string(bindInfo.m_registerBinding.m_pair[bindSet].m_logicalSpace);
  965. // depending on platforms we may have supplementary attributes or/and type modifier.
  966. auto [prefix, suffix] = GetPlatformEmitter().GetDataViewHeaderFooter(*this, tId, bindInfo.m_registerBinding.m_pair[bindSet].m_registerIndex, registerTypeLetter, stringifiedLogicalSpace);
  967. m_out << prefix;
  968. // declaration of the view variable on the global scope.
  969. // type unmangled_path_to_symbol [optional_array_dimension] : optional_register_binding_as_suffix
  970. m_out << varType << " " << ReplaceSeparators(tId.m_name, Underscore);
  971. if (bindInfo.m_isUnboundedArray)
  972. {
  973. m_out << "[]";
  974. }
  975. else if (bindInfo.m_registerRange > 1)
  976. {
  977. m_out << "[" << bindInfo.m_registerRange << "]";
  978. }
  979. m_out << suffix;
  980. auto interval = m_ir->m_scope.m_scopeIntervals[tId];
  981. EmitTranspiledTokens(interval);
  982. m_out << ";\n\n";
  983. }
  984. void CodeEmitter::EmitGetShaderKeyFunctionDeclaration(const IdentifierUID& getterUid, const TypeRefInfo& returnType) const
  985. {
  986. m_out << GetTranslatedName(returnType, UsageContext::ReferenceSite) << " ";
  987. m_out << GetShaderKeyFunctionName(getterUid);
  988. }
  989. void CodeEmitter::EmitGetShaderKeyFunction(const IdentifierUID& shaderKeyUid, const IdentifierUID& getterUid, uint32_t keySizeInBits, uint32_t keyOffsetBits, string_view defaultValue, const TypeRefInfo& returnType) const
  990. {
  991. // Because we use uint on the shader source side no shader option can cross the 32-bit boundary
  992. // This is already ensured by the emission side, in Json::Value CodeEmitter::GetVariantList(...)
  993. uint32_t arraySlot = keyOffsetBits / AZ::ShaderCompiler::kShaderVariantKeyRegisterSize;
  994. keyOffsetBits -= arraySlot * AZ::ShaderCompiler::kShaderVariantKeyRegisterSize;
  995. uint32_t swizzle = keyOffsetBits / AZ::ShaderCompiler::kShaderVariantKeyElementSize;
  996. keyOffsetBits -= swizzle * AZ::ShaderCompiler::kShaderVariantKeyElementSize;
  997. // Intentional unnamed scope for error-checking
  998. {
  999. auto& varInfo = *m_ir->GetSymbolSubAs<VarInfo>(shaderKeyUid.m_name);
  1000. auto dims = varInfo.m_typeInfoExt.GetDimensions();
  1001. assert(dims.m_dimensions.size() == 1); // This is generated variable, it must have exactly 1 array dimension
  1002. if (arraySlot >= dims.m_dimensions[0])
  1003. {
  1004. const string errorMessage = ConcatString("The option {", UnmangleTrimedName(getterUid.m_name), "} exceeds the number of bits (",
  1005. AZ::ShaderCompiler::kShaderVariantKeyRegisterSize * dims.m_dimensions[0], ") allowed by the ShaderVariantFallback.\n",
  1006. "Either increase the limit or remove some options!");
  1007. throw AzslcEmitterException(EMITTER_OPTION_EXCEEDING_BITS_COUNT, errorMessage);
  1008. }
  1009. }
  1010. // The most significant bits are put in the first element, then the next, etc.
  1011. // The bit order within an element is swapped so the most significant option
  1012. // is at 0x00000001 and the least significant (32th bit) is at 0x80000000.
  1013. // This is tailored to the runtime to reduce CB compile times.
  1014. const char suffix[] = { 'x', 'y', 'z', 'w' };
  1015. m_out << "\n";
  1016. const auto mask = static_cast<uint32_t>(pow(2, keySizeInBits) - 1);
  1017. EmitGetShaderKeyFunctionDeclaration(getterUid, returnType);
  1018. m_out << "\n{\n";
  1019. if (keySizeInBits > 0) // Emit the option getter function body using CB access and bit decoding.
  1020. {
  1021. m_out << " uint shaderKey = ((" << GetTranslatedName(shaderKeyUid, UsageContext::ReferenceSite)
  1022. << "[" << arraySlot << "]." << suffix[swizzle] << " >> " << keyOffsetBits << ") & " << mask << ")";
  1023. // Also we need to reproduce the "range minimal" value as an addition post big-field extraction.
  1024. // because the CB compact storage into bits of an uint4 does not incorporate the offset of the range
  1025. // because they are "useless" (compressable) entropy bits that would take space in the key.
  1026. if (returnType.m_typeId.m_name.find("int") != string::npos) // if type int/uint/int16_t... then it's an integer range.
  1027. {
  1028. if (auto attrInfo = m_ir->m_symbols.GetAttribute(getterUid, "range"))
  1029. {
  1030. // Presence of the correct arglist has already been verified in Backend::AppendOptionRange
  1031. m_out << " + " << ExtractValueAsInt64(get<ConstNumericVal>(attrInfo->m_argList[0]));
  1032. }
  1033. }
  1034. else if (returnType.m_typeClass == TypeClass::Enum)
  1035. {
  1036. // In the case of an enumeration,
  1037. // recovering the correct value will depend if there are jumps in the enumerators.
  1038. // To discover that, we will use the presence of initializers.
  1039. // one initializer on the first enumerator is a special case because it still allows
  1040. // us to reconstruct the value from an integer cast and a +first enumerator.
  1041. // In a case of disparate enumerators, we'll need to reconstruct the values
  1042. // by emitting a switch case.
  1043. // We don't want that switch case as a generic case, that could simplify azslc's source
  1044. // but it risks slowing down the shader runtime. Sort of a pay-what-you-use principle.
  1045. auto& enumClassInfo = *m_ir->GetSymbolSubAs<ClassInfo>(returnType.m_typeId.GetName());
  1046. auto& enumerators = enumClassInfo.GetOrderedMembers();
  1047. bool otherInitializers = std::any_of(enumerators.begin() + 1, enumerators.end(),
  1048. [&](const IdentifierUID& e)
  1049. {
  1050. auto& var = *m_ir->GetSymbolSubAs<VarInfo>(e.GetName());
  1051. return !!var.m_declNodeEnum->Value;
  1052. });
  1053. if (otherInitializers)
  1054. {
  1055. m_out << ";\n switch (shaderKey)\n {\n";
  1056. for (int i = 0; i < enumerators.size(); ++i)
  1057. {
  1058. m_out << " case " << i << ": shaderKey = (" << GetTranslatedName(returnType, UsageContext::ReferenceSite) << ")"
  1059. << GetTranslatedName(enumerators[i], UsageContext::ReferenceSite) << "; break;\n";
  1060. }
  1061. m_out << " }";
  1062. }
  1063. else
  1064. {
  1065. // add the value of the first enumerator (by emitting its name)
  1066. m_out << " + (uint)" << GetTranslatedName(enumerators.front(), UsageContext::ReferenceSite); // we can access front with no defense because keySize would be 0 if there are no enumerators.
  1067. }
  1068. }
  1069. m_out << ";\n";
  1070. m_out << " return (" << GetTranslatedName(returnType, UsageContext::ReferenceSite) << ") shaderKey;\n";
  1071. }
  1072. else
  1073. {
  1074. // In the case of an empty enumeration type (no enumerators) or empty range, the keySize will be 0
  1075. // and this if-branch will be taken. if the programmer specified no initializer clause,
  1076. // this would produce unbuildable HLSL and error in generatored code. We really don't want that
  1077. // as it would be hard to diagnose for users. As a reasonable behavior, we can emit val = 0.
  1078. string typeAsStr = GetTranslatedName(returnType, UsageContext::ReferenceSite);
  1079. // "<type> val = (cast expr to <type>) 0 or default;"
  1080. m_out << " " << typeAsStr << " val = (" << typeAsStr << ")" << (defaultValue.empty() ? "0" : defaultValue) << ";\n";
  1081. m_out << " return val;\n";
  1082. }
  1083. m_out << "}\n\n";
  1084. }
  1085. void CodeEmitter::EmitSRG(const SRGInfo& srgInfo, const IdentifierUID& srgId, const Options& options, const RootSigDesc& rootSig)
  1086. {
  1087. RootSigDesc::SrgDesc srgDesc;
  1088. srgDesc.m_uid = srgId;
  1089. m_out << "/* Generated code from ";
  1090. // We don't emit the SRG attributes (only as a comment), but they can be accessed by the srgId if needed
  1091. EmitAllAttachedAttributes(srgId);
  1092. m_out << " ShaderResourceGroup " << srgInfo.m_declNode->Name->getText() << "*/\n";
  1093. for (const auto& t : srgInfo.m_srViews)
  1094. {
  1095. // (*1) non-extern symbols are not visible resources, and are emitted by canonical EmitVariableDeclaration path
  1096. if (!m_ir->GetSymbolSubAs<VarInfo>(t.GetName())->StorageFlagIsLocalLinkage(true))
  1097. {
  1098. EmitSRGDataView(t, options, rootSig);
  1099. }
  1100. }
  1101. for (const auto& s : srgInfo.m_samplers)
  1102. {
  1103. // same as (*1)
  1104. if (!m_ir->GetSymbolSubAs<VarInfo>(s.GetName())->StorageFlagIsLocalLinkage(true))
  1105. {
  1106. EmitSRGSampler(s, options, rootSig);
  1107. }
  1108. }
  1109. EmitSRGCBUnified(srgInfo, srgId, options, rootSig);
  1110. if (srgInfo.m_shaderVariantFallback)
  1111. {
  1112. m_shaderVariantFallbackUid = *srgInfo.m_shaderVariantFallback;
  1113. }
  1114. }
  1115. void CodeEmitter::IfIsSrgMemberValidateIsDefined(antlr4::Token* token, TokenToAst::AstNode* nodeFromToken) const
  1116. {
  1117. // A common user mistake is to reference an undefined SRG field like
  1118. // "MySrg::m_someVariable", where "m_someVariable" is undefined inside "MySrg",
  1119. // or because it's a typo. Either way let's catch the scenario and provide a meaningful message.
  1120. if (auto nestedNameSpecifierCtx = As<azslParser::NestedNameSpecifierContext*>(nodeFromToken))
  1121. {
  1122. // Get the parent rule, which should be a QualifiedIdContext.
  1123. auto qualifiedIdCtx = As<azslParser::QualifiedIdContext*>(nestedNameSpecifierCtx->parent);
  1124. if (!qualifiedIdCtx)
  1125. {
  1126. const string errorMessage = FormatString("Unexpected expression '%s'", token->getText().c_str());
  1127. throw AzslcEmitterException(EMITTER_UNEXPECTED_EXPRESSION, token, errorMessage);
  1128. }
  1129. // We only care to call out the error if the token has the name of an SRG.
  1130. QualifiedName qualifiedSymbolName(FormatString("/%s", token->getText().c_str()));
  1131. if (const IdAndKind* idAndKind = m_ir->GetIdAndKindInfo(qualifiedSymbolName))
  1132. {
  1133. const auto& [uid, kindInfo] = *idAndKind;
  1134. if (kindInfo.IsKindOneOf(Kind::ShaderResourceGroup))
  1135. {
  1136. const string errorMessage = FormatString("Undefined ShaderResourceGroup member '%s'", qualifiedIdCtx->getText().c_str());
  1137. throw AzslcEmitterException(EMITTER_UNDEFINED_SRG_MEMBER, qualifiedIdCtx->getStart(), errorMessage);
  1138. }
  1139. }
  1140. }
  1141. }
  1142. // override of the base method, to incorporate symbol and expression mutations
  1143. void CodeEmitter::EmitTranspiledTokens(misc::Interval interval, Streamable& output) const
  1144. {
  1145. const ICodeEmissionMutator* codeMutator = m_codeMutator;
  1146. ssize_t ii = interval.a;
  1147. while (ii <= interval.b)
  1148. {
  1149. auto* token = GetNextToken(ii /*inout*/);
  1150. const auto tokenIndex = token->getTokenIndex();
  1151. const CodeMutation* codeMutation = codeMutator ? codeMutator->GetMutation(tokenIndex) : nullptr;
  1152. if (codeMutation && codeMutation->m_prepend)
  1153. {
  1154. output << codeMutation->m_prepend.value();
  1155. }
  1156. if (codeMutation && codeMutation->m_replace)
  1157. {
  1158. output << codeMutation->m_replace.value();
  1159. }
  1160. else
  1161. {
  1162. // Access the AST from the token. Note that any processing relying on this, amounts to a hacky shortcut.
  1163. // (taken to make the economy of having to produce a specific emitters, for each AST-node/code-construct in HLSL)
  1164. auto* astNode = m_ir->m_tokenMap.GetNode(tokenIndex);
  1165. // watch for potential mutations in the middle of the original source.
  1166. auto [originalSymbol, endToken] = m_translations.OverOriginalDefinitionOf(token->getTokenIndex());
  1167. bool emitAsIs = originalSymbol.empty();
  1168. if (!emitAsIs)
  1169. {
  1170. auto* withinVarDecl = ExtractSpecificParent<azslParser::VariableDeclarationStatementContext>(astNode);
  1171. // if we are over a variable declaration, then it's a more evolved statement than just a struct/class declaration
  1172. // in such a case, the introduction of the type (declaration) is immediately used in a larger statement,
  1173. // so we need to re-emit it as a reference.
  1174. if (withinVarDecl && token == withinVarDecl->variableDeclaration()->type()->start)
  1175. {
  1176. output << GetTranslatedName(originalSymbol, UsageContext::ReferenceSite, ii) << " ";
  1177. }
  1178. // that and anything else: pure UDT declarations, or typeof/typealias may be jumped over completely
  1179. ii = endToken + 1;
  1180. }
  1181. else
  1182. {
  1183. auto idExpr = m_translations.GetIdExpression(token);
  1184. if (!idExpr.IsEmpty())
  1185. {
  1186. auto getToken = [this](ssize_t& tokenId) -> string
  1187. {
  1188. assert(tokenId >= 0);
  1189. auto* token = m_tokens->get(static_cast<size_t>(tokenId));
  1190. return token->getChannel() == Token::DEFAULT_CHANNEL ? token->getText() : string{};
  1191. };
  1192. output << m_translations.TranslateIdExpression(idExpr, ii, getToken) << " ";
  1193. ii += idExpr.m_span.length() - 1;
  1194. emitAsIs = false;
  1195. }
  1196. }
  1197. if (emitAsIs)
  1198. {
  1199. IfIsSrgMemberValidateIsDefined(token, astNode);
  1200. // do minimal reformatting to have a pseudo-readable emitted code
  1201. auto str = token->getText();
  1202. bool lineFeed = str == ";" || str == "{";
  1203. output << str << (lineFeed ? '\n' : ' ');
  1204. }
  1205. }
  1206. if (codeMutation && codeMutation->m_append)
  1207. {
  1208. output << codeMutation->m_append.value();
  1209. }
  1210. }
  1211. }
  1212. }