Texture2DMSto2DCodeMutator.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. /*
  2. * All or portions of this file Copyright (c) Amazon.com, Inc. or its affiliates or
  3. * its licensors.
  4. *
  5. * For complete copyright and license terms please see the LICENSE at the root of this
  6. * distribution (the "License"). All use of this software is governed by the License,
  7. * or, if provided, by the license below or the license accompanying this file. Do not
  8. * remove or modify any license notices. This file is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. *
  11. */
  12. #include "Texture2DMSto2DCodeMutator.h"
  13. namespace AZ::ShaderCompiler
  14. {
  // Names of the texture methods whose call sites this mutator may rewrite.
  // The anonymous namespace keeps them local to this translation unit.
  namespace
  {
      constexpr char FunctionNameLoad[] = "Load";
      constexpr char FunctionNameGetSamplePosition[] = "GetSamplePosition";
      constexpr char FunctionNameGetDimensions[] = "GetDimensions";
  }
  18. ///////////////////////////////////////////////////////////////////////
  19. // azslParserBaseListener Overrides ...
  20. void Texture2DMSto2DCodeMutator::enterFunctionCallExpression(azslParser::FunctionCallExpressionContext* ctx)
  21. {
  22. const auto expressionCtx = ctx->expression();
  23. const std::string functionName = expressionCtx->stop->getText();
  24. if (functionName == FunctionNameLoad)
  25. {
  26. OnEnterLoad(ctx);
  27. }
  28. else if (functionName == FunctionNameGetSamplePosition)
  29. {
  30. OnEnterGetSamplePosition(ctx);
  31. }
  32. else if (functionName == FunctionNameGetDimensions)
  33. {
  34. OnEnterGetDimensions(ctx);
  35. }
  36. }
  37. ///////////////////////////////////////////////////////////////////////
  38. ///////////////////////////////////////////////////////////////////////
  39. // ICodeEmissionMutator Overrides ...
  40. const CodeMutation* Texture2DMSto2DCodeMutator::GetMutation(ssize_t tokenId) const
  41. {
  42. auto itor = m_mutations.find(tokenId);
  43. if (itor == m_mutations.end())
  44. {
  45. return nullptr;
  46. }
  47. return &itor->second;
  48. }
  49. ///////////////////////////////////////////////////////////////////////
  50. void Texture2DMSto2DCodeMutator::RunMiddleEndMutations()
  51. {
  52. if (MutateTypeOfMultiSampleVariables())
  53. {
  54. MutateMultiSampleSystemSemantics();
  55. }
  56. }
  57. //! A helper function that returns the symbol name contained in @expressionCtx.
  58. static UnqualifiedName GetSymbolName(azslParser::ExpressionContext* expressionCtx)
  59. {
  60. const auto& children = expressionCtx->children;
  61. // We only care for cases with three children:
  62. // "<Symbol>", ".", "<funcName>"
  63. if (children.size() == 3)
  64. {
  65. string symbolName = Replace(children[0]->getText(), "::", "/");
  66. return UnqualifiedName{ symbolName };
  67. }
  68. return UnqualifiedName();
  69. }
  70. Texture2DMSto2DCodeMutator::TextureMSType Texture2DMSto2DCodeMutator::GetMultiSampledTextureClass(const UnqualifiedName& uqSymbolName)
  71. {
  72. if (uqSymbolName.empty())
  73. {
  74. return TextureMSType::None;
  75. }
  76. // We only care if the symbol that is calling Load(...) is of type Texture2DMS or Texture2DMSArray
  77. IdAndKind* idkind = m_ir->m_sema.LookupSymbol(uqSymbolName);
  78. if (!idkind)
  79. {
  80. return TextureMSType::None;
  81. }
  82. auto& [uid, kind] = *idkind;
  83. if (kind.GetKind() != Kind::Variable)
  84. {
  85. return TextureMSType::None;
  86. }
  87. auto varInfo = kind.GetSubAs<VarInfo>();
  88. if (varInfo->GetTypeClass() != TypeClass::MultisampledTexture)
  89. {
  90. return TextureMSType::None;
  91. }
  92. if (EndsWith(varInfo->m_typeInfoExt.m_coreType.m_typeId.GetName(), "Array"))
  93. {
  94. return TextureMSType::Texture2DMSArray;
  95. }
  96. return TextureMSType::Texture2DMS;
  97. }
  98. void Texture2DMSto2DCodeMutator::OnEnterLoad(azslParser::FunctionCallExpressionContext* ctx)
  99. {
  100. // First we must capture the complete name of the symbol that called <Symbol>.Load(...)
  101. const auto expressionCtx = ctx->expression();
  102. const UnqualifiedName uqSymbolName = GetSymbolName(expressionCtx);
  103. const TextureMSType textureMSType = GetMultiSampledTextureClass(uqSymbolName);
  104. if (textureMSType == TextureMSType::None)
  105. {
  106. return;
  107. }
  108. // Define the mutations.
  109. const auto argumentListCtx = ctx->argumentList();
  110. const auto argumentsCtx = argumentListCtx->arguments();
  111. auto vectorOfArguments = argumentsCtx->expression();
  112. // For Texture2DMS Load has two variants:
  113. // 1- Two arguments: int2 location, int sampleIndex
  114. // When mutating this variant to Texture2D the first argument will be prepended
  115. // with "int3("<location> and the second argument "sampleIndex" will be replaced with "0)".
  116. // 2- Three arguments: int2 location, int sampleIndex, int2 offset
  117. // When mutating this variant to Texture2D the first argument will be prepended
  118. // with "int3("<location> and the second argument "sampleIndex" will be replaced with "0)",
  119. // the third argument will remain as is.
  120. // For Texture2DMSArray it's the same as above, except that the first argument is of type int3.
  121. // And it will be wrapped with an int4.
  122. const string wrapperType = textureMSType == TextureMSType::Texture2DMSArray ? "int4(" : "int3(";
  123. if (vectorOfArguments.size() >= 2)
  124. {
  125. {
  126. auto firstArgContext = vectorOfArguments[0];
  127. ssize_t tokenIndex = firstArgContext->start->getTokenIndex();
  128. CodeMutation firstArgMutation;
  129. firstArgMutation.m_prepend.emplace(wrapperType);
  130. m_mutations.emplace(tokenIndex, firstArgMutation);
  131. }
  132. {
  133. // There's already a "," token that will be emitted after the first argument
  134. // So all we have to do is simply replace the second argument with "0)"
  135. // and will get in the end: int3( @firstArgContext, 0) or int4( @firstArgContext, 0) for MSArray.
  136. // Also keep in mind that we are working with ParseRuleContexts, and they are a range of
  137. // tokens, for the second argument all tokens will be dropped with "" empty strings,
  138. // and the last token will be dropped with "0)".
  139. auto secondArgContext = vectorOfArguments[1];
  140. const ssize_t startingTokenIndex = secondArgContext->start->getTokenIndex();
  141. const ssize_t stoppingTokenIndex = secondArgContext->stop->getTokenIndex();
  142. for (ssize_t tokenIndex = startingTokenIndex; tokenIndex < stoppingTokenIndex; ++tokenIndex)
  143. {
  144. CodeMutation codeMutation;
  145. codeMutation.m_replace.emplace("");
  146. m_mutations.emplace(tokenIndex, codeMutation);
  147. }
  148. CodeMutation codeMutation;
  149. codeMutation.m_replace.emplace("0)");
  150. m_mutations.emplace(stoppingTokenIndex, codeMutation);
  151. }
  152. }
  153. }
  154. void Texture2DMSto2DCodeMutator::OnEnterGetSamplePosition(azslParser::FunctionCallExpressionContext* ctx)
  155. {
  156. // First we must capture the complete name of the symbol that called <Symbol>.GetSamplePosition(...)
  157. const auto expressionCtx = ctx->expression();
  158. const UnqualifiedName uqSymbolName = GetSymbolName(expressionCtx);
  159. const TextureMSType textureMSType = GetMultiSampledTextureClass(uqSymbolName);
  160. if (textureMSType == TextureMSType::None)
  161. {
  162. return;
  163. }
  164. // Because GetSamplePosition() doesn't exist for Texture2D/Texture2DArray, we will replace
  165. // the whole expresion with float2(0, 0).
  166. const ssize_t startingTokenIndex = ctx->start->getTokenIndex();
  167. const ssize_t stoppingTokenIndex = ctx->stop->getTokenIndex();
  168. for (ssize_t tokenIndex = startingTokenIndex; tokenIndex < stoppingTokenIndex; ++tokenIndex)
  169. {
  170. CodeMutation codeMutation;
  171. codeMutation.m_replace.emplace("");
  172. m_mutations.emplace(tokenIndex, codeMutation);
  173. }
  174. CodeMutation codeMutation;
  175. codeMutation.m_replace.emplace("float2(0, 0)");
  176. m_mutations.emplace(stoppingTokenIndex, codeMutation);
  177. }
  178. void Texture2DMSto2DCodeMutator::OnEnterGetDimensions(azslParser::FunctionCallExpressionContext* ctx)
  179. {
  180. const auto expressionCtx = ctx->expression();
  181. const UnqualifiedName uqSymbolName = GetSymbolName(expressionCtx);
  182. const TextureMSType textureMSType = GetMultiSampledTextureClass(uqSymbolName);
  183. if (textureMSType == TextureMSType::None)
  184. {
  185. return;
  186. }
  187. // For Texture2DMS GetDimensions(...) only has one variant:
  188. // GetDimensions (width, height, samples)
  189. // All we have to do for Texture2D is to drop ", samples" and add "; samples = 1" after the closing parenthesis. We'll get:
  190. // GetDimensions (width, height); samples = 1
  191. // For Texture2DMSArray GetDimensions(...) only has one variant:
  192. // GetDimensions (width, height, elems, samples)
  193. // All we have to do for Texture2DArray is to drop ", samples" and add "; samples = 1" after the closing parenthesis. We'll get:
  194. // GetDimensions (width, height, elems); samples = 1
  195. // Remark: The last ";" is already present in the original code, this is why we append "; samples = 1" instead
  196. // of "; samples = 1;"
  197. const auto argumentListCtx = ctx->argumentList();
  198. const auto argumentsCtx = argumentListCtx->arguments();
  199. // From argumentsCtx we can detect the last "," token and we'll
  200. // add it to the mutation as a replacement with "".
  201. auto vectorOfCommas = argumentsCtx->Comma();
  202. auto lastCommaIndex = vectorOfCommas.size() - 1;
  203. auto lastCommaToken = vectorOfCommas[lastCommaIndex];
  204. {
  205. ssize_t tokenIndex = lastCommaToken->getSymbol()->getTokenIndex();
  206. CodeMutation codeMutation;
  207. codeMutation.m_replace.emplace("");
  208. m_mutations.emplace(tokenIndex, codeMutation);
  209. }
  210. // Drop the last argument.
  211. auto vectorOfArguments = argumentsCtx->expression();
  212. auto lastArgumentIndex = vectorOfArguments.size() - 1;
  213. auto lastArgumentCtx = vectorOfArguments[lastArgumentIndex];
  214. // Capture the name of the variable that gets the number of samples because
  215. // it will be assigned the value 1.
  216. string lastArgumentName = lastArgumentCtx->getText();
  217. {
  218. const ssize_t startingTokenIndex = lastArgumentCtx->start->getTokenIndex();
  219. const ssize_t stoppingTokenIndex = lastArgumentCtx->stop->getTokenIndex();
  220. for (ssize_t tokenIndex = startingTokenIndex; tokenIndex <= stoppingTokenIndex; ++tokenIndex)
  221. {
  222. CodeMutation codeMutation;
  223. codeMutation.m_replace.emplace("");
  224. m_mutations.emplace(tokenIndex, codeMutation);
  225. }
  226. }
  227. // Get the rule context for the closing right parenthesis ")".
  228. // "; samples = 1" will be added after it.
  229. const auto rightParenthesisNode = ctx->argumentList()->RightParen();
  230. {
  231. const ssize_t parenthesisTokenIndex = rightParenthesisNode->getSymbol()->getTokenIndex();
  232. CodeMutation codeMutation;
  233. string samplesExpression = AZ::FormatString("; %s = 1 ", lastArgumentName.c_str());
  234. codeMutation.m_append.emplace(samplesExpression);
  235. m_mutations.emplace(parenthesisTokenIndex, codeMutation);
  236. }
  237. }
  238. size_t Texture2DMSto2DCodeMutator::MutateTypeOfMultiSampleVariables()
  239. {
  240. size_t mutationCount = 0;
  241. // Get all variables that are members of something of type Texture2DMS
  242. // We use this function pointer to find SRGs that have no references.
  243. auto texture2DMSFilterFunc = +[](KindInfo* kindInfo) {
  244. const auto* varInfo = kindInfo->GetSubAs<VarInfo>();
  245. return varInfo->m_typeInfoExt.m_coreType.m_typeClass == TypeClass::MultisampledTexture;
  246. };
  247. vector<IdentifierUID> texture2DMSVariables = m_ir->GetFilteredSymbolsOfSubType<VarInfo>(texture2DMSFilterFunc);
  248. for (const auto& uid : texture2DMSVariables)
  249. {
  250. auto varInfo = m_ir->GetSymbolSubAs<VarInfo>(uid.GetName());
  251. auto& typeId = varInfo->m_typeInfoExt.m_coreType.m_typeId;
  252. auto typeName = typeId.GetName();
  253. if (typeName == "?Texture2DMS")
  254. {
  255. typeId.m_name = QualifiedName{ "?Texture2D" };
  256. }
  257. else
  258. {
  259. typeId.m_name = QualifiedName{ "?Texture2DArray" };
  260. }
  261. ++mutationCount;
  262. }
  263. return mutationCount;
  264. }
  265. void Texture2DMSto2DCodeMutator::MutateMultiSampleSystemSemantics()
  266. {
  267. //Let's find all variables that have a system semantic.
  268. auto variablesWithSystemSemanticFilterFunc = +[](KindInfo* kindInfo) {
  269. const auto* varInfo = kindInfo->GetSubAs<VarInfo>();
  270. if (!varInfo->m_declNode)
  271. {
  272. return false;
  273. }
  274. auto* semanticOption = varInfo->m_declNode->SemanticOpt;
  275. if (!semanticOption)
  276. {
  277. return false;
  278. }
  279. return semanticOption->hlslSemanticName()->HLSLSemanticSystem() != nullptr;
  280. };
  281. vector<IdentifierUID> systemSemanticVariables = m_ir->GetFilteredSymbolsOfSubType<VarInfo>(variablesWithSystemSemanticFilterFunc);
  282. for (const auto& uid : systemSemanticVariables)
  283. {
  284. auto varInfo = m_ir->GetSymbolSubAs<VarInfo>(uid.GetName());
  285. // Get the semantic name.
  286. auto systemSemanticName = varInfo->m_declNode->SemanticOpt->hlslSemanticName()->HLSLSemanticSystem()->getText();
  287. static const std::array<string_view, 2> SystemSemanticsNames =
  288. {
  289. "SV_SampleIndex",
  290. "SV_Coverage",
  291. };
  292. if (!IsIn(systemSemanticName, SystemSemanticsNames))
  293. {
  294. continue;
  295. }
  296. //Semantics can be part of a struct, or function parameters.
  297. if (ParamContextOverVariableDeclarator(varInfo->m_declNode))
  298. {
  299. // This is a function parameter.
  300. IdentifierUID functionUid = IdentifierUID{ GetParentName(uid.GetName()) };
  301. DropMultiSamplingSystemSemanticFromFunction(uid, varInfo, systemSemanticName, functionUid);
  302. }
  303. else
  304. {
  305. // This is a variable within a struct
  306. IdentifierUID structUid = IdentifierUID{ GetParentName(uid.GetName()) };
  307. MutateMultiSamplingSystemSemanticInStruct(uid, varInfo, systemSemanticName, structUid);
  308. }
  309. }
  310. }
  311. //! A helper method that figures out how a function argument should look like
  312. //! when mutated into a local variable.
  313. static string GetLocalVariableStringFromFunctionArgument(const UnqualifiedName& uqName, AstUnnamedVarDecl* ctx, const char * initializationValue)
  314. {
  315. azslParser::FunctionParamContext* paramCtx = nullptr;
  316. auto typeCtx = ExtractTypeFromVariableDeclarator(ctx, &paramCtx);
  317. auto variableTypeStr = typeCtx->getText();
  318. if (initializationValue)
  319. {
  320. return FormatString("%s %s = (%s)%s;\n", variableTypeStr.c_str(), uqName.c_str(), variableTypeStr.c_str(), initializationValue);
  321. }
  322. return FormatString("%s %s;\n", variableTypeStr.c_str(), uqName.c_str());
  323. }
  324. void Texture2DMSto2DCodeMutator::DropMultiSamplingSystemSemanticFromFunction(const IdentifierUID& varUid, const VarInfo* varInfo, const string& systemSemanticName, const IdentifierUID& functionUid)
  325. {
  326. // Let's get the FunctionInfo object and report this variable, which will be dropped from the
  327. // input arguments and will be re-emitted as a local variable to avoid compiler errors from other
  328. // pieces of code that may reference the semantic.
  329. // Example:
  330. // PSOutput mainPS(VSOutput IN, in uint sampleIndex : SV_SampleIndex)
  331. // {
  332. // ...
  333. // int2 sampleIndexVector = int2(sampleIndex, sampleIndex);
  334. // ...
  335. // }
  336. // Will look like this when emitted (Which will avoid compilation errors)
  337. // PSOutput mainPS(VSOutput IN)
  338. // {
  339. // uint sampleIndex = 0;
  340. // ...
  341. // int2 sampleIndexVector = int2(sampleIndex, sampleIndex);
  342. // ...
  343. // }
  344. FunctionInfo* functionInfo = m_ir->GetSymbolSubAs<FunctionInfo>(functionUid.GetName());
  345. functionInfo->DeleteParameter(varUid);
  346. string initializationValue = "0";
  347. if (systemSemanticName == "SV_Coverage")
  348. {
  349. // SV_Coverage is a mask where each bit represents active sample indices.
  350. // In this case we initialize to -1, because bitwise it will look like as if
  351. // all samples are active.
  352. // Usually code that that uses SV_Coverage loops over this mask (limited by the number of samples,
  353. // which will be 1 for no-MS) for each sampleIndex.
  354. // By settings to -1 it will mimic full coverage and the rendering logic will
  355. // work seamlessly. It could be set to "1", but "-1" would cover all cases.
  356. initializationValue = "-1";
  357. }
  358. auto newCode = GetLocalVariableStringFromFunctionArgument(varUid.GetNameLeaf(), varInfo->m_declNode, initializationValue.c_str());
  359. // The idea is to find the TokenIndex of the opening bracket "{",
  360. // Once we know that TokenIndex we can add code mutation as an appended
  361. // string.
  362. auto hlslFunctionDefinitionContext = ExtractSpecificParent<azslParser::HlslFunctionDefinitionContext>(functionInfo->m_defNode);
  363. auto blockContext = hlslFunctionDefinitionContext->block();
  364. auto leftBraceTokenIndex = blockContext->LeftBrace()->getSymbol()->getTokenIndex();
  365. auto itor = m_mutations.find(leftBraceTokenIndex);
  366. if (itor == m_mutations.end())
  367. {
  368. CodeMutation mutation;
  369. mutation.m_append.emplace(newCode);
  370. m_mutations.emplace(leftBraceTokenIndex, mutation);
  371. }
  372. else
  373. {
  374. CodeMutation& mutation = itor->second;
  375. string prevCode = mutation.m_append.value();
  376. mutation.m_append.emplace(prevCode + newCode);
  377. }
  378. }
  379. void Texture2DMSto2DCodeMutator::MutateMultiSamplingSystemSemanticInStruct(const IdentifierUID& varUid, const VarInfo* varInfo, const string& systemSemanticName, const IdentifierUID& structUid)
  380. {
  381. // This is the case of member variable of a struct, but it is a system semantic.
  382. // Example:
  383. // struct VSOutput
  384. // {
  385. // float4 m_position : SV_Position;
  386. // float2 m_texCoord : TEXCOORD0;
  387. // uint m_sampleIndex : SV_SampleIndex; <--- This is the variable in question.
  388. // };
  389. // Will look like this when emitted (Which will avoid compilation errors)
  390. // PSOutput mainPS(VSOutput IN)
  391. // {
  392. // float4 m_position : SV_Position;
  393. // float2 m_texCoord : TEXCOORD0;
  394. // static const uint m_sampleIndex = 0; <-- Became a static const, and of course, the semantic is removed.
  395. // ...
  396. // }
  397. string initializationValue = "0";
  398. if (systemSemanticName == "SV_Coverage")
  399. {
  400. initializationValue = "-1";
  401. }
  402. auto newCode = GetLocalVariableStringFromFunctionArgument(varUid.GetNameLeaf(), varInfo->m_declNode, initializationValue.c_str());
  403. auto tokenIndex = varInfo->m_declNode->start->getTokenIndex();
  404. CodeMutation mutation;
  405. mutation.m_prepend.emplace("static const ");
  406. mutation.m_replace.emplace(newCode);
  407. m_mutations.emplace(tokenIndex, mutation);
  408. }
  409. } //namespace AZ::ShaderCompiler