HLSignatureLower.cpp 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLSignatureLower.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Lower signatures of entry function to DXIL LoadInput/StoreOutput. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "HLSignatureLower.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/DXIL/DxilSignatureElement.h"
  14. #include "dxc/DXIL/DxilSigPoint.h"
  15. #include "dxc/Support/Global.h"
  16. #include "dxc/DXIL/DxilTypeSystem.h"
  17. #include "dxc/DXIL/DxilSemantic.h"
  18. #include "dxc/HLSL/HLModule.h"
  19. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  20. #include "dxc/HLSL/HLMatrixType.h"
  21. #include "dxc/HlslIntrinsicOp.h"
  22. #include "dxc/DXIL/DxilUtil.h"
  23. #include "dxc/HLSL/DxilPackSignatureElement.h"
  24. #include "llvm/IR/IRBuilder.h"
  25. #include "llvm/IR/DebugInfo.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/IR/Module.h"
  28. #include "llvm/Transforms/Utils/Local.h"
  29. using namespace llvm;
  30. using namespace hlsl;
  31. namespace {
  32. // Decompose semantic name (eg FOO1=>FOO,1), change interp mode for SV_Position.
  33. // Return semantic index.
  34. unsigned UpdateSemanticAndInterpMode(StringRef &semName,
  35. DXIL::InterpolationMode &mode,
  36. DXIL::SigPointKind kind,
  37. LLVMContext &Context) {
  38. llvm::StringRef baseSemName; // The 'FOO' in 'FOO1'.
  39. uint32_t semIndex; // The '1' in 'FOO1'
  40. // Split semName and index.
  41. Semantic::DecomposeNameAndIndex(semName, &baseSemName, &semIndex);
  42. semName = baseSemName;
  43. const Semantic *semantic = Semantic::GetByName(semName, kind);
  44. if (semantic && semantic->GetKind() == Semantic::Kind::Position) {
  45. // Update interp mode to no_perspective version for SV_Position.
  46. switch (mode) {
  47. case InterpolationMode::Kind::LinearCentroid:
  48. mode = InterpolationMode::Kind::LinearNoperspectiveCentroid;
  49. break;
  50. case InterpolationMode::Kind::LinearSample:
  51. mode = InterpolationMode::Kind::LinearNoperspectiveSample;
  52. break;
  53. case InterpolationMode::Kind::Linear:
  54. mode = InterpolationMode::Kind::LinearNoperspective;
  55. break;
  56. case InterpolationMode::Kind::Constant:
  57. case InterpolationMode::Kind::Undefined:
  58. case InterpolationMode::Kind::Invalid: {
  59. Context.emitError("invalid interpolation mode for SV_Position");
  60. } break;
  61. case InterpolationMode::Kind::LinearNoperspective:
  62. case InterpolationMode::Kind::LinearNoperspectiveCentroid:
  63. case InterpolationMode::Kind::LinearNoperspectiveSample:
  64. // Already Noperspective modes.
  65. break;
  66. }
  67. }
  68. return semIndex;
  69. }
  70. DxilSignatureElement *FindArgInSignature(Argument &arg,
  71. llvm::StringRef semantic,
  72. DXIL::InterpolationMode interpMode,
  73. DXIL::SigPointKind kind,
  74. DxilSignature &sig) {
  75. // Match output ID.
  76. unsigned semIndex =
  77. UpdateSemanticAndInterpMode(semantic, interpMode, kind, arg.getContext());
  78. for (uint32_t i = 0; i < sig.GetElements().size(); i++) {
  79. DxilSignatureElement &SE = sig.GetElement(i);
  80. bool semNameMatch = semantic.equals_lower(SE.GetName());
  81. bool semIndexMatch = semIndex == SE.GetSemanticIndexVec()[0];
  82. if (semNameMatch && semIndexMatch) {
  83. // Find a match.
  84. return &SE;
  85. }
  86. }
  87. return nullptr;
  88. }
  89. } // namespace
  90. namespace {
  91. void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *GV,
  92. OP *hlslOP, IRBuilder<> &Builder) {
  93. Type *Ty = GV->getType();
  94. if (Ty->isPointerTy())
  95. Ty = Ty->getPointerElementType();
  96. OP::OpCode opcode;
  97. switch (semKind) {
  98. case Semantic::Kind::DomainLocation:
  99. opcode = OP::OpCode::DomainLocation;
  100. break;
  101. case Semantic::Kind::OutputControlPointID:
  102. opcode = OP::OpCode::OutputControlPointID;
  103. break;
  104. case Semantic::Kind::GSInstanceID:
  105. opcode = OP::OpCode::GSInstanceID;
  106. break;
  107. case Semantic::Kind::PrimitiveID:
  108. opcode = OP::OpCode::PrimitiveID;
  109. break;
  110. case Semantic::Kind::SampleIndex:
  111. opcode = OP::OpCode::SampleIndex;
  112. break;
  113. case Semantic::Kind::Coverage:
  114. opcode = OP::OpCode::Coverage;
  115. break;
  116. case Semantic::Kind::InnerCoverage:
  117. opcode = OP::OpCode::InnerCoverage;
  118. break;
  119. case Semantic::Kind::ViewID:
  120. opcode = OP::OpCode::ViewID;
  121. break;
  122. case Semantic::Kind::GroupThreadID:
  123. opcode = OP::OpCode::ThreadIdInGroup;
  124. break;
  125. case Semantic::Kind::GroupID:
  126. opcode = OP::OpCode::GroupId;
  127. break;
  128. case Semantic::Kind::DispatchThreadID:
  129. opcode = OP::OpCode::ThreadId;
  130. break;
  131. case Semantic::Kind::GroupIndex:
  132. opcode = OP::OpCode::FlattenedThreadIdInGroup;
  133. break;
  134. default:
  135. DXASSERT(0, "invalid semantic");
  136. return;
  137. }
  138. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty->getScalarType());
  139. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  140. Value *newArg = nullptr;
  141. if (semKind == Semantic::Kind::DomainLocation ||
  142. semKind == Semantic::Kind::GroupThreadID ||
  143. semKind == Semantic::Kind::GroupID ||
  144. semKind == Semantic::Kind::DispatchThreadID) {
  145. unsigned vecSize = 1;
  146. if (Ty->isVectorTy())
  147. vecSize = Ty->getVectorNumElements();
  148. newArg = Builder.CreateCall(dxilFunc, { OpArg,
  149. semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(0) : hlslOP->GetU32Const(0) });
  150. if (vecSize > 1) {
  151. Value *result = UndefValue::get(Ty);
  152. result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
  153. for (unsigned i = 1; i < vecSize; i++) {
  154. Value *newElt =
  155. Builder.CreateCall(dxilFunc, { OpArg,
  156. semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(i)
  157. : hlslOP->GetU32Const(i) });
  158. result = Builder.CreateInsertElement(result, newElt, i);
  159. }
  160. newArg = result;
  161. }
  162. } else {
  163. newArg = Builder.CreateCall(dxilFunc, {OpArg});
  164. }
  165. if (newArg->getType() != GV->getType()) {
  166. DXASSERT_NOMSG(GV->getType()->isPointerTy());
  167. for (User *U : GV->users()) {
  168. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  169. LI->replaceAllUsesWith(newArg);
  170. }
  171. }
  172. } else {
  173. GV->replaceAllUsesWith(newArg);
  174. }
  175. }
  176. } // namespace
  // Lower one parameter `arg` of `func` (the entry function, or the hull
  // shader's patch constant function) into the entry signatures.
  //  - Validates the parameter's semantic and records which semantic indices
  //    are in use, reporting overlaps and invalid semantics on `func`.
  //  - For NotInSig/Shadow interpretations, replaces uses of `arg` with the
  //    corresponding DXIL intrinsic (NotInSig parameters stop there).
  //  - Otherwise, appends a new DxilSignatureElement to the input, output or
  //    patch-constant/primitive signature. For patch constant function
  //    parameters that map to the input/output signature, it instead looks up
  //    the element already created for the hull entry.
  // Parameters:
  //  func, funcAnnotation    - function owning `arg` and its annotation.
  //  props, pSM              - entry function properties and shader model.
  //  isPatchConstantFunction - true when `arg` belongs to the patch constant
  //                            function.
  //  forceOut                - treat the parameter as `out`; used for the
  //                            output half of an inout parameter.
  //  hasClipPlane            - in/out; when set and this parameter is the
  //                            SV_Position output, clip planes are generated
  //                            and the flag is cleared.
  // Diagnostics go through dxilutil::EmitErrorOnFunction. Most errors return
  // immediately; the mesh primitive-output interpolation error does not.
  void HLSignatureLower::ProcessArgument(Function *func,
                                         DxilFunctionAnnotation *funcAnnotation,
                                         Argument &arg, DxilFunctionProps &props,
                                         const ShaderModel *pSM,
                                         bool isPatchConstantFunction,
                                         bool forceOut, bool &hasClipPlane) {
    Type *Ty = arg.getType();
    DxilParameterAnnotation &paramAnnotation =
        funcAnnotation->GetParameterAnnotation(arg.getArgNo());
    hlsl::DxilParamInputQual qual =
        forceOut ? DxilParamInputQual::Out : paramAnnotation.GetParamInputQual();
    bool isInout = qual == DxilParamInputQual::Inout;
    // If this was an inout param, do the output side first (recursive call with
    // forceOut), then fall through and process it again as an input.
    if (isInout) {
      DXASSERT(!isPatchConstantFunction,
               "Patch Constant function should not have inout param");
      m_inoutArgSet.insert(&arg);
      const bool bForceOutTrue = true;
      ProcessArgument(func, funcAnnotation, arg, props, pSM,
                      isPatchConstantFunction, bForceOutTrue, hasClipPlane);
      qual = DxilParamInputQual::In;
    }
    // Get stream index: OutStream1..3 map to streams 1..3; all else uses 0.
    unsigned streamIdx = 0;
    switch (qual) {
    case DxilParamInputQual::OutStream1:
      streamIdx = 1;
      break;
    case DxilParamInputQual::OutStream2:
      streamIdx = 2;
      break;
    case DxilParamInputQual::OutStream3:
      streamIdx = 3;
      break;
    default:
      // Use streamIdx = 0 by default.
      break;
    }
    // Signature point determined by qualifier, shader kind and whether this is
    // the patch constant function.
    const SigPoint *sigPoint = SigPoint::GetSigPoint(
        SigPointFromInputQual(qual, props.shaderKind, isPatchConstantFunction));
    unsigned rows, cols;
    HLModule::GetParameterRowsAndCols(Ty, rows, cols, paramAnnotation);
    CompType EltTy = paramAnnotation.GetCompType();
    DXIL::InterpolationMode interpMode =
        paramAnnotation.GetInterpolationMode().GetKind();
    // Set undefined interpMode.
    // Mesh primitive outputs are forced to Constant (error if something other
    // than Constant/Undefined was requested); sig points that don't need an
    // interpolation mode get Undefined.
    if (sigPoint->GetKind() == DXIL::SigPointKind::MSPOut) {
      if (interpMode != InterpolationMode::Kind::Undefined &&
          interpMode != InterpolationMode::Kind::Constant) {
        dxilutil::EmitErrorOnFunction(func,
          "Mesh shader's primitive outputs' interpolation mode must be constant or undefined.");
      }
      interpMode = InterpolationMode::Kind::Constant;
    }
    else if (!sigPoint->NeedsInterpMode())
      interpMode = InterpolationMode::Kind::Undefined;
    else if (interpMode == InterpolationMode::Kind::Undefined) {
      // Type-based default: linear for floats, constant for others.
      if (EltTy.IsFloatTy())
        interpMode = InterpolationMode::Kind::Linear;
      else
        interpMode = InterpolationMode::Kind::Constant;
    }
    // back-compat mode - remap obsolete semantics
    if (HLM.GetHLOptions().bDX9CompatMode && paramAnnotation.HasSemanticString()) {
      hlsl::RemapObsoleteSemantic(paramAnnotation, sigPoint->GetKind(), HLM.GetCtx());
    }
    llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
    if (semanticStr.empty()) {
      dxilutil::EmitErrorOnFunction(func,
          "Semantic must be defined for all parameters of an entry function or "
          "patch constant function");
      return;
    }
    // Strips the index from semanticStr (e.g. "FOO1" -> "FOO") and adjusts
    // interpMode for SV_Position.
    UpdateSemanticAndInterpMode(semanticStr, interpMode, sigPoint->GetKind(),
                                arg.getContext());
    // Get Semantic interpretation, skipping if not in signature
    const Semantic *pSemantic = Semantic::GetByName(semanticStr);
    DXIL::SemanticInterpretationKind interpretation =
        SigPoint::GetInterpretation(pSemantic->GetKind(), sigPoint->GetKind(),
                                    pSM->GetMajor(), pSM->GetMinor());
    // Verify system value semantics do not overlap.
    // Note: Arbitrary are always in the signature and will be verified with a
    // different mechanism. For patch constant function, only validate patch
    // constant elements (others already validated on hull function)
    if (pSemantic->GetKind() != DXIL::SemanticKind::Arbitrary &&
        (!isPatchConstantFunction ||
         (!sigPoint->IsInput() && !sigPoint->IsOutput()))) {
      // Pick the usage map for this sig point: input, per-stream output,
      // patch constant/primitive, or other.
      auto &SemanticUseMap =
          sigPoint->IsInput()
              ? m_InputSemanticsUsed
              : (sigPoint->IsOutput()
                     ? m_OutputSemanticsUsed[streamIdx]
                     : (sigPoint->IsPatchConstOrPrim() ? m_PatchConstantSemanticsUsed
                                                       : m_OtherSemanticsUsed));
      // Error if any of this parameter's semantic indices is already in use.
      if (SemanticUseMap.count((unsigned)pSemantic->GetKind()) > 0) {
        auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
        for (unsigned idx : paramAnnotation.GetSemanticIndexVec()) {
          if (SemanticIndexSet.count(idx) > 0) {
            dxilutil::EmitErrorOnFunction(func, "Parameter with semantic " + semanticStr +
              " has overlapping semantic index at " + std::to_string(idx) + ".");
            return;
          }
        }
      }
      // Record this parameter's semantic indices as used.
      auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
      for (unsigned idx : paramAnnotation.GetSemanticIndexVec()) {
        SemanticIndexSet.emplace(idx);
      }
      // Enforce Coverage and InnerCoverage input mutual exclusivity
      if (sigPoint->IsInput()) {
        if ((pSemantic->GetKind() == DXIL::SemanticKind::Coverage &&
             SemanticUseMap.count((unsigned)DXIL::SemanticKind::InnerCoverage) >
                 0) ||
            (pSemantic->GetKind() == DXIL::SemanticKind::InnerCoverage &&
             SemanticUseMap.count((unsigned)DXIL::SemanticKind::Coverage) > 0)) {
          dxilutil::EmitErrorOnFunction(func,
              "Pixel shader inputs SV_Coverage and SV_InnerCoverage are mutually "
              "exclusive");
          return;
        }
      }
    }
    // Validate interpretation and replace argument usage with load/store
    // intrinsics
    {
      switch (interpretation) {
      case DXIL::SemanticInterpretationKind::NA: {
        dxilutil::EmitErrorOnFunction(func, Twine("Semantic ") + semanticStr +
                                 Twine(" is invalid for shader model: ") +
                                 ShaderModel::GetKindName(props.shaderKind));
        return;
      }
      case DXIL::SemanticInterpretationKind::NotInSig:
      case DXIL::SemanticInterpretationKind::Shadow: {
        // Replace uses with the system-value intrinsic at the start of the
        // entry block, reusing the arg's dbg.declare location if one exists.
        IRBuilder<> funcBuilder(func->getEntryBlock().getFirstInsertionPt());
        if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&arg)) {
          funcBuilder.SetCurrentDebugLocation(DDI->getDebugLoc());
        }
        replaceInputOutputWithIntrinsic(pSemantic->GetKind(), &arg, HLM.GetOP(),
                                        funcBuilder);
        if (interpretation == DXIL::SemanticInterpretationKind::NotInSig)
          return; // This argument should not be included in the signature
        break;
      }
      case DXIL::SemanticInterpretationKind::SV:
      case DXIL::SemanticInterpretationKind::SGV:
      case DXIL::SemanticInterpretationKind::Arb:
      case DXIL::SemanticInterpretationKind::Target:
      case DXIL::SemanticInterpretationKind::TessFactor:
      case DXIL::SemanticInterpretationKind::NotPacked:
      case DXIL::SemanticInterpretationKind::ClipCull:
        // Will be replaced with load/store intrinsics in
        // GenerateDxilInputsOutputs
        break;
      default:
        DXASSERT(false, "Unexpected SemanticInterpretationKind");
        return;
      }
    }
    // Determine signature this argument belongs in, if any
    DxilSignature *pSig = nullptr;
    DXIL::SignatureKind sigKind = sigPoint->GetSignatureKindWithFallback();
    switch (sigKind) {
    case DXIL::SignatureKind::Input:
      pSig = &EntrySig.InputSignature;
      break;
    case DXIL::SignatureKind::Output:
      pSig = &EntrySig.OutputSignature;
      break;
    case DXIL::SignatureKind::PatchConstOrPrim:
      pSig = &EntrySig.PatchConstOrPrimSignature;
      break;
    default:
      DXASSERT(false, "Expected real signature kind at this point");
      return; // No corresponding signature
    }
    // Create and add element to signature
    DxilSignatureElement *pSE = nullptr;
    {
      // Add signature element to appropriate maps
      if (isPatchConstantFunction &&
          sigKind != DXIL::SignatureKind::PatchConstOrPrim) {
        // Patch constant function parameter referring to a hull input/output:
        // reuse the element created for the hull entry (matched by the full,
        // undecomposed semantic string) instead of creating a new one.
        pSE = FindArgInSignature(arg, paramAnnotation.GetSemanticString(),
                                 interpMode, sigPoint->GetKind(), *pSig);
        if (!pSE) {
          dxilutil::EmitErrorOnFunction(func, Twine("Signature element ") + semanticStr +
                                   Twine(", referred to by patch constant function, is not found in "
                                         "corresponding hull shader ") +
                                   (sigKind == DXIL::SignatureKind::Input ? "input." : "output."));
          return;
        }
        m_patchConstantInputsSigMap[arg.getArgNo()] = pSE;
      } else {
        // New element; ownership moves into pSig, pSE stays a non-owning view.
        std::unique_ptr<DxilSignatureElement> SE = pSig->CreateElement();
        pSE = SE.get();
        pSig->AppendElement(std::move(SE));
        pSE->SetSigPointKind(sigPoint->GetKind());
        pSE->Initialize(semanticStr, EltTy, interpMode, rows, cols,
                        Semantic::kUndefinedRow, Semantic::kUndefinedCol,
                        pSE->GetID(), paramAnnotation.GetSemanticIndexVec());
        m_sigValueMap[pSE] = &arg;
      }
    }
    if (paramAnnotation.IsPrecise())
      m_preciseSigSet.insert(pSE);
    // Clip planes are generated once, from the SV_Position output.
    if (sigKind == DXIL::SignatureKind::Output &&
        pSemantic->GetKind() == Semantic::Kind::Position && hasClipPlane) {
      GenerateClipPlanesForVS(&arg);
      hasClipPlane = false;
    }
    // Set Output Stream.
    if (streamIdx > 0)
      pSE->SetOutputStream(streamIdx);
  }
  392. void HLSignatureLower::CreateDxilSignatures() {
  393. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  394. const ShaderModel *pSM = HLM.GetShaderModel();
  395. DXASSERT(Entry->getReturnType()->isVoidTy(),
  396. "Should changed in SROA_Parameter_HLSL");
  397. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  398. DXASSERT(EntryAnnotation, "must have function annotation for entry function");
  399. bool bHasClipPlane =
  400. props.shaderKind == DXIL::ShaderKind::Vertex ? HasClipPlanes() : false;
  401. const bool isPatchConstantFunctionFalse = false;
  402. const bool bForOutFasle = false;
  403. for (Argument &arg : Entry->getArgumentList()) {
  404. Type *Ty = arg.getType();
  405. // Skip streamout obj.
  406. if (HLModule::IsStreamOutputPtrType(Ty))
  407. continue;
  408. // Skip OutIndices and InPayload
  409. DxilParameterAnnotation &paramAnnotation =
  410. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  411. hlsl::DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
  412. if (qual == hlsl::DxilParamInputQual::OutIndices ||
  413. qual == hlsl::DxilParamInputQual::InPayload)
  414. continue;
  415. ProcessArgument(Entry, EntryAnnotation, arg, props, pSM,
  416. isPatchConstantFunctionFalse, bForOutFasle, bHasClipPlane);
  417. }
  418. if (bHasClipPlane) {
  419. dxilutil::EmitErrorOnFunction(Entry, "Cannot use clipplanes attribute without "
  420. "specifying a 4-component SV_Position "
  421. "output");
  422. }
  423. m_OtherSemanticsUsed.clear();
  424. if (props.shaderKind == DXIL::ShaderKind::Hull) {
  425. Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc;
  426. if (patchConstantFunc == nullptr) {
  427. dxilutil::EmitErrorOnFunction(Entry,
  428. "Patch constant function is not specified.");
  429. }
  430. DxilFunctionAnnotation *patchFuncAnnotation =
  431. HLM.GetFunctionAnnotation(patchConstantFunc);
  432. DXASSERT(patchFuncAnnotation,
  433. "must have function annotation for patch constant function");
  434. const bool isPatchConstantFunctionTrue = true;
  435. for (Argument &arg : patchConstantFunc->getArgumentList()) {
  436. ProcessArgument(patchConstantFunc, patchFuncAnnotation, arg, props, pSM,
  437. isPatchConstantFunctionTrue, bForOutFasle, bHasClipPlane);
  438. }
  439. }
  440. }
  441. // Allocate input/output slots
  442. void HLSignatureLower::AllocateDxilInputOutputs() {
  443. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  444. const ShaderModel *pSM = HLM.GetShaderModel();
  445. const HLOptions &opts = HLM.GetHLOptions();
  446. DXASSERT_NOMSG(opts.PackingStrategy <
  447. (unsigned)DXIL::PackingStrategy::Invalid);
  448. DXIL::PackingStrategy packing = (DXIL::PackingStrategy)opts.PackingStrategy;
  449. if (packing == DXIL::PackingStrategy::Default)
  450. packing = pSM->GetDefaultPackingStrategy();
  451. hlsl::PackDxilSignature(EntrySig.InputSignature, packing);
  452. if (!EntrySig.InputSignature.IsFullyAllocated()) {
  453. dxilutil::EmitErrorOnFunction(Entry,
  454. "Failed to allocate all input signature elements in available space.");
  455. }
  456. if (props.shaderKind != DXIL::ShaderKind::Amplification) {
  457. hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
  458. if (!EntrySig.OutputSignature.IsFullyAllocated()) {
  459. dxilutil::EmitErrorOnFunction(Entry,
  460. "Failed to allocate all output signature elements in available space.");
  461. }
  462. }
  463. if (props.shaderKind == DXIL::ShaderKind::Hull ||
  464. props.shaderKind == DXIL::ShaderKind::Domain ||
  465. props.shaderKind == DXIL::ShaderKind::Mesh) {
  466. hlsl::PackDxilSignature(EntrySig.PatchConstOrPrimSignature, packing);
  467. if (!EntrySig.PatchConstOrPrimSignature.IsFullyAllocated()) {
  468. dxilutil::EmitErrorOnFunction(Entry,
  469. "Failed to allocate all patch constant signature "
  470. "elements in available space.");
  471. }
  472. }
  473. }
  474. namespace {
  475. // Helper functions and class for lower signature.
  476. void GenerateStOutput(Function *stOutput, MutableArrayRef<Value *> args,
  477. IRBuilder<> &Builder, bool cast) {
  478. if (cast) {
  479. Value *value = args[DXIL::OperandIndex::kStoreOutputValOpIdx];
  480. args[DXIL::OperandIndex::kStoreOutputValOpIdx] =
  481. Builder.CreateZExt(value, Builder.getInt32Ty());
  482. }
  483. Builder.CreateCall(stOutput, args);
  484. }
  485. void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
  486. Constant *OpArg, Constant *outputID, Value *idx,
  487. unsigned cols, Value *vertexOrPrimID, bool bI1Cast) {
  488. IRBuilder<> Builder(stInst);
  489. Value *val = stInst->getValueOperand();
  490. if (VectorType *VT = dyn_cast<VectorType>(val->getType())) {
  491. DXASSERT_LOCALVAR(VT, cols == VT->getNumElements(), "vec size must match");
  492. for (unsigned col = 0; col < cols; col++) {
  493. Value *subVal = Builder.CreateExtractElement(val, col);
  494. Value *colIdx = Builder.getInt8(col);
  495. SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, subVal};
  496. if (vertexOrPrimID)
  497. args.emplace_back(vertexOrPrimID);
  498. GenerateStOutput(stOutput, args, Builder, bI1Cast);
  499. }
  500. // remove stInst
  501. stInst->eraseFromParent();
  502. } else if (!val->getType()->isArrayTy()) {
  503. // TODO: support case cols not 1
  504. DXASSERT(cols == 1, "only support scalar here");
  505. Value *colIdx = Builder.getInt8(0);
  506. SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, val};
  507. if (vertexOrPrimID)
  508. args.emplace_back(vertexOrPrimID);
  509. GenerateStOutput(stOutput, args, Builder, bI1Cast);
  510. // remove stInst
  511. stInst->eraseFromParent();
  512. } else {
  513. DXASSERT(0, "not support array yet");
  514. // TODO: support array.
  515. Value *colIdx = Builder.getInt8(0);
  516. ArrayType *AT = cast<ArrayType>(val->getType());
  517. Value *args[] = {OpArg, outputID, idx, colIdx, /*val*/ nullptr};
  518. (void)args;
  519. (void)AT;
  520. }
  521. }
  522. Value *GenerateLdInput(Function *loadInput, ArrayRef<Value *> args,
  523. IRBuilder<> &Builder, Value *zero, bool bCast,
  524. Type *Ty) {
  525. Value *input = Builder.CreateCall(loadInput, args);
  526. if (!bCast)
  527. return input;
  528. else {
  529. Value *bVal = Builder.CreateICmpNE(input, zero);
  530. IntegerType *IT = cast<IntegerType>(Ty);
  531. if (IT->getBitWidth() == 1)
  532. return bVal;
  533. else
  534. return Builder.CreateZExt(bVal, Ty);
  535. }
  536. }
  537. Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
  538. unsigned cols, MutableArrayRef<Value *> args,
  539. bool bCast) {
  540. IRBuilder<> Builder(ldInst);
  541. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(ldInst));
  542. Type *Ty = ldInst->getType();
  543. Type *EltTy = Ty->getScalarType();
  544. // Change i1 to i32 for load input.
  545. Value *zero = Builder.getInt32(0);
  546. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  547. Value *newVec = llvm::UndefValue::get(VT);
  548. DXASSERT(cols == VT->getNumElements(), "vec size must match");
  549. for (unsigned col = 0; col < cols; col++) {
  550. Value *colIdx = Builder.getInt8(col);
  551. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  552. Value *input =
  553. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  554. newVec = Builder.CreateInsertElement(newVec, input, col);
  555. }
  556. ldInst->replaceAllUsesWith(newVec);
  557. ldInst->eraseFromParent();
  558. return newVec;
  559. } else {
  560. Value *colIdx = args[DXIL::OperandIndex::kLoadInputColOpIdx];
  561. if (colIdx == nullptr) {
  562. DXASSERT(cols == 1, "only support scalar here");
  563. colIdx = Builder.getInt8(0);
  564. } else {
  565. if (colIdx->getType() == Builder.getInt32Ty()) {
  566. colIdx = Builder.CreateTrunc(colIdx, Builder.getInt8Ty());
  567. }
  568. }
  569. if (isa<ConstantInt>(colIdx)) {
  570. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  571. Value *input =
  572. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  573. ldInst->replaceAllUsesWith(input);
  574. ldInst->eraseFromParent();
  575. return input;
  576. } else {
  577. // Vector indexing.
  578. // Load to array.
  579. ArrayType *AT = ArrayType::get(ldInst->getType(), cols);
  580. Value *arrayVec = AllocaBuilder.CreateAlloca(AT);
  581. Value *zeroIdx = Builder.getInt32(0);
  582. for (unsigned col = 0; col < cols; col++) {
  583. Value *colIdx = Builder.getInt8(col);
  584. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  585. Value *input =
  586. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  587. Value *GEP = Builder.CreateInBoundsGEP(arrayVec, {zeroIdx, colIdx});
  588. Builder.CreateStore(input, GEP);
  589. }
  590. Value *vecIndexingPtr =
  591. Builder.CreateInBoundsGEP(arrayVec, {zeroIdx, colIdx});
  592. Value *input = Builder.CreateLoad(vecIndexingPtr);
  593. ldInst->replaceAllUsesWith(input);
  594. ldInst->eraseFromParent();
  595. return input;
  596. }
  597. }
  598. }
  599. void replaceDirectInputParameter(Value *param, Function *loadInput,
  600. unsigned cols, MutableArrayRef<Value *> args,
  601. bool bCast, OP *hlslOP, IRBuilder<> &Builder) {
  602. Value *zero = hlslOP->GetU32Const(0);
  603. Type *Ty = param->getType();
  604. Type *EltTy = Ty->getScalarType();
  605. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  606. Value *newVec = llvm::UndefValue::get(VT);
  607. DXASSERT(cols == VT->getNumElements(), "vec size must match");
  608. for (unsigned col = 0; col < cols; col++) {
  609. Value *colIdx = hlslOP->GetU8Const(col);
  610. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  611. Value *input =
  612. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  613. newVec = Builder.CreateInsertElement(newVec, input, col);
  614. }
  615. param->replaceAllUsesWith(newVec);
  616. // THe individual loadInputs are the authoritative source of values for the vector.
  617. dxilutil::TryScatterDebugValueToVectorElements(newVec);
  618. } else if (!Ty->isArrayTy() && !HLMatrixType::isa(Ty)) {
  619. DXASSERT(cols == 1, "only support scalar here");
  620. Value *colIdx = hlslOP->GetU8Const(0);
  621. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  622. Value *input =
  623. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  624. param->replaceAllUsesWith(input); // Will properly relocate any DbgValueInst
  625. } else if (HLMatrixType::isa(Ty)) {
  626. if (param->use_empty()) return;
  627. DXASSERT(param->hasOneUse(),
  628. "matrix arg should only has one use as matrix to vec");
  629. CallInst *CI = cast<CallInst>(param->user_back());
  630. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  631. DXASSERT_LOCALVAR(group, group == HLOpcodeGroup::HLCast,
  632. "must be hlcast here");
  633. unsigned opcode = GetHLOpcode(CI);
  634. HLCastOpcode matOp = static_cast<HLCastOpcode>(opcode);
  635. switch (matOp) {
  636. case HLCastOpcode::ColMatrixToVecCast: {
  637. IRBuilder<> LocalBuilder(CI);
  638. HLMatrixType MatTy = HLMatrixType::cast(
  639. CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
  640. Type *EltTy = MatTy.getElementTypeForReg();
  641. std::vector<Value *> matElts(MatTy.getNumElements());
  642. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  643. Value *rowIdx = hlslOP->GetI32Const(c);
  644. args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
  645. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  646. Value *colIdx = hlslOP->GetU8Const(r);
  647. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  648. Value *input =
  649. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  650. matElts[MatTy.getColumnMajorIndex(r, c)] = input;
  651. }
  652. }
  653. Value *newVec =
  654. HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
  655. CI->replaceAllUsesWith(newVec);
  656. CI->eraseFromParent();
  657. } break;
  658. case HLCastOpcode::RowMatrixToVecCast: {
  659. IRBuilder<> LocalBuilder(CI);
  660. HLMatrixType MatTy = HLMatrixType::cast(
  661. CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
  662. Type *EltTy = MatTy.getElementTypeForReg();
  663. std::vector<Value *> matElts(MatTy.getNumElements());
  664. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  665. Value *rowIdx = hlslOP->GetI32Const(r);
  666. args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
  667. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  668. Value *colIdx = hlslOP->GetU8Const(c);
  669. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  670. Value *input =
  671. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  672. matElts[MatTy.getRowMajorIndex(r, c)] = input;
  673. }
  674. }
  675. Value *newVec =
  676. HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
  677. CI->replaceAllUsesWith(newVec);
  678. CI->eraseFromParent();
  679. } break;
  680. default:
  681. // Only matrix to vector casts are valid.
  682. break;
  683. }
  684. } else {
  685. DXASSERT(0, "invalid type for direct input");
  686. }
  687. }
  688. struct InputOutputAccessInfo {
  689. // For input output which has only 1 row, idx is 0.
  690. Value *idx;
  691. // VertexID for HS/DS/GS input, MS vertex output. PrimitiveID for MS primitive output
  692. Value *vertexOrPrimID;
  693. // Vector index.
  694. Value *vectorIdx;
  695. // Load/Store/LoadMat/StoreMat on input/output.
  696. Instruction *user;
  697. InputOutputAccessInfo(Value *index, Instruction *I)
  698. : idx(index), vertexOrPrimID(nullptr), vectorIdx(nullptr), user(I) {}
  699. InputOutputAccessInfo(Value *index, Instruction *I, Value *ID, Value *vecIdx)
  700. : idx(index), vertexOrPrimID(ID), vectorIdx(vecIdx), user(I) {}
  701. };
  702. void collectInputOutputAccessInfo(
  703. Value *GV, Constant *constZero,
  704. std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexOrPrimID,
  705. bool bInput, bool bRowMajor, bool isMS) {
  706. // merge GEP use for input output.
  707. HLModule::MergeGepUse(GV);
  708. for (auto User = GV->user_begin(); User != GV->user_end();) {
  709. Value *I = *(User++);
  710. if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
  711. if (bInput) {
  712. InputOutputAccessInfo info = {constZero, ldInst};
  713. accessInfoList.push_back(info);
  714. }
  715. } else if (StoreInst *stInst = dyn_cast<StoreInst>(I)) {
  716. if (!bInput) {
  717. InputOutputAccessInfo info = {constZero, stInst};
  718. accessInfoList.push_back(info);
  719. }
  720. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
  721. // Vector indexing may has more indices.
  722. // Vector indexing changed to array indexing in SROA_HLSL.
  723. auto idx = GEP->idx_begin();
  724. DXASSERT_LOCALVAR(idx, idx->get() == constZero,
  725. "only support 0 offset for input pointer");
  726. Value *vertexOrPrimID = nullptr;
  727. Value *vectorIdx = nullptr;
  728. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  729. // Skip first pointer idx which must be 0.
  730. GEPIt++;
  731. if (hasVertexOrPrimID) {
  732. // Save vertexOrPrimID.
  733. vertexOrPrimID = GEPIt.getOperand();
  734. GEPIt++;
  735. }
  736. // Start from first index.
  737. Value *rowIdx = GEPIt.getOperand();
  738. if (GEPIt != E) {
  739. if ((*GEPIt)->isVectorTy()) {
  740. // Vector indexing.
  741. rowIdx = constZero;
  742. vectorIdx = GEPIt.getOperand();
  743. DXASSERT_NOMSG((++GEPIt) == E);
  744. } else {
  745. // Array which may have vector indexing.
  746. // Highest dim index is saved in rowIdx,
  747. // array size for highest dim not affect index.
  748. GEPIt++;
  749. IRBuilder<> Builder(GEP);
  750. Type *idxTy = rowIdx->getType();
  751. for (; GEPIt != E; ++GEPIt) {
  752. DXASSERT(!GEPIt->isStructTy(),
  753. "Struct should be flattened SROA_Parameter_HLSL");
  754. DXASSERT(!GEPIt->isPointerTy(),
  755. "not support pointer type in middle of GEP");
  756. if (GEPIt->isArrayTy()) {
  757. Constant *arraySize =
  758. ConstantInt::get(idxTy, GEPIt->getArrayNumElements());
  759. rowIdx = Builder.CreateMul(rowIdx, arraySize);
  760. rowIdx = Builder.CreateAdd(rowIdx, GEPIt.getOperand());
  761. } else {
  762. Type *Ty = *GEPIt;
  763. DXASSERT_LOCALVAR(Ty, Ty->isVectorTy(),
  764. "must be vector type here to index");
  765. // Save vector idx.
  766. vectorIdx = GEPIt.getOperand();
  767. }
  768. }
  769. if (HLMatrixType MatTy = HLMatrixType::dyn_cast(*GEPIt)) {
  770. Constant *arraySize = ConstantInt::get(idxTy, MatTy.getNumColumns());
  771. if (bRowMajor) {
  772. arraySize = ConstantInt::get(idxTy, MatTy.getNumRows());
  773. }
  774. rowIdx = Builder.CreateMul(rowIdx, arraySize);
  775. }
  776. }
  777. } else
  778. rowIdx = constZero;
  779. auto GepUser = GEP->user_begin();
  780. auto GepUserE = GEP->user_end();
  781. Value *idxVal = rowIdx;
  782. for (; GepUser != GepUserE;) {
  783. auto GepUserIt = GepUser++;
  784. if (LoadInst *ldInst = dyn_cast<LoadInst>(*GepUserIt)) {
  785. if (bInput) {
  786. InputOutputAccessInfo info = {idxVal, ldInst, vertexOrPrimID, vectorIdx};
  787. accessInfoList.push_back(info);
  788. }
  789. } else if (StoreInst *stInst = dyn_cast<StoreInst>(*GepUserIt)) {
  790. if (!bInput) {
  791. InputOutputAccessInfo info = {idxVal, stInst, vertexOrPrimID, vectorIdx};
  792. accessInfoList.push_back(info);
  793. }
  794. } else if (CallInst *CI = dyn_cast<CallInst>(*GepUserIt)) {
  795. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  796. DXASSERT_LOCALVAR(group, group == HLOpcodeGroup::HLMatLoadStore,
  797. "input/output should only used by ld/st");
  798. HLMatLoadStoreOpcode opcode = (HLMatLoadStoreOpcode)GetHLOpcode(CI);
  799. if ((opcode == HLMatLoadStoreOpcode::ColMatLoad ||
  800. opcode == HLMatLoadStoreOpcode::RowMatLoad)
  801. ? bInput
  802. : !bInput) {
  803. InputOutputAccessInfo info = {idxVal, CI, vertexOrPrimID, vectorIdx};
  804. accessInfoList.push_back(info);
  805. }
  806. } else {
  807. DXASSERT(0, "input output should only used by ld/st");
  808. }
  809. }
  810. } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
  811. InputOutputAccessInfo info = {constZero, CI};
  812. accessInfoList.push_back(info);
  813. } else {
  814. DXASSERT(0, "input output should only used by ld/st");
  815. }
  816. }
  817. }
  818. void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertexIdx,
  819. Function *ldStFunc, Constant *OpArg, Constant *ID, unsigned cols, bool bI1Cast,
  820. Constant *columnConsts[],
  821. bool bNeedVertexOrPrimID, bool isArrayTy, bool bInput, bool bIsInout) {
  822. Value *idxVal = info.idx;
  823. Value *vertexOrPrimID = undefVertexIdx;
  824. if (bNeedVertexOrPrimID && isArrayTy) {
  825. vertexOrPrimID = info.vertexOrPrimID;
  826. }
  827. if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
  828. SmallVector<Value *, 4> args = {OpArg, ID, idxVal, info.vectorIdx};
  829. if (vertexOrPrimID)
  830. args.emplace_back(vertexOrPrimID);
  831. replaceLdWithLdInput(ldStFunc, ldInst, cols, args, bI1Cast);
  832. } else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user)) {
  833. if (bInput) {
  834. DXASSERT_LOCALVAR(bIsInout, bIsInout, "input should not have store use.");
  835. } else {
  836. if (!info.vectorIdx) {
  837. replaceStWithStOutput(ldStFunc, stInst, OpArg, ID, idxVal, cols,
  838. vertexOrPrimID, bI1Cast);
  839. } else {
  840. Value *V = stInst->getValueOperand();
  841. Type *Ty = V->getType();
  842. DXASSERT_LOCALVAR(Ty == Ty->getScalarType() && !Ty->isAggregateType(),
  843. Ty, "only support scalar here");
  844. if (ConstantInt *ColIdx = dyn_cast<ConstantInt>(info.vectorIdx)) {
  845. IRBuilder<> Builder(stInst);
  846. if (ColIdx->getType()->getBitWidth() != 8) {
  847. ColIdx = Builder.getInt8(ColIdx->getValue().getLimitedValue());
  848. }
  849. SmallVector<Value *, 6> args = {OpArg, ID, idxVal, ColIdx, V};
  850. if (vertexOrPrimID)
  851. args.emplace_back(vertexOrPrimID);
  852. GenerateStOutput(ldStFunc, args, Builder, bI1Cast);
  853. } else {
  854. BasicBlock *BB = stInst->getParent();
  855. BasicBlock *EndBB = BB->splitBasicBlock(stInst);
  856. TerminatorInst *TI = BB->getTerminator();
  857. IRBuilder<> SwitchBuilder(TI);
  858. LLVMContext &Ctx = stInst->getContext();
  859. SwitchInst *Switch =
  860. SwitchBuilder.CreateSwitch(info.vectorIdx, EndBB, cols);
  861. TI->eraseFromParent();
  862. Function *F = EndBB->getParent();
  863. for (unsigned i = 0; i < cols; i++) {
  864. BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case", F, EndBB);
  865. Switch->addCase(SwitchBuilder.getInt32(i), CaseBB);
  866. IRBuilder<> CaseBuilder(CaseBB);
  867. ConstantInt *CaseIdx = SwitchBuilder.getInt8(i);
  868. SmallVector<Value *, 6> args = {OpArg, ID, idxVal, CaseIdx, V};
  869. if (vertexOrPrimID)
  870. args.emplace_back(vertexOrPrimID);
  871. GenerateStOutput(ldStFunc, args, CaseBuilder, bI1Cast);
  872. CaseBuilder.CreateBr(EndBB);
  873. }
  874. }
  875. // remove stInst
  876. stInst->eraseFromParent();
  877. }
  878. }
  879. } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
  880. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  881. // Intrinsic will be translated later.
  882. if (group == HLOpcodeGroup::HLIntrinsic || group == HLOpcodeGroup::NotHL)
  883. return;
  884. unsigned opcode = GetHLOpcode(CI);
  885. DXASSERT_NOMSG(group == HLOpcodeGroup::HLMatLoadStore);
  886. HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  887. switch (matOp) {
  888. case HLMatLoadStoreOpcode::ColMatLoad:
  889. case HLMatLoadStoreOpcode::RowMatLoad: {
  890. IRBuilder<> LocalBuilder(CI);
  891. HLMatrixType MatTy = HLMatrixType::cast(
  892. CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
  893. ->getType()->getPointerElementType());
  894. std::vector<Value *> matElts(MatTy.getNumElements());
  895. if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
  896. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  897. Constant *constRowIdx = LocalBuilder.getInt32(c);
  898. Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
  899. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  900. SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
  901. if (vertexOrPrimID)
  902. args.emplace_back(vertexOrPrimID);
  903. Value *input = LocalBuilder.CreateCall(ldStFunc, args);
  904. unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
  905. matElts[matIdx] = input;
  906. }
  907. }
  908. } else {
  909. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  910. Constant *constRowIdx = LocalBuilder.getInt32(r);
  911. Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
  912. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  913. SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
  914. if (vertexOrPrimID)
  915. args.emplace_back(vertexOrPrimID);
  916. Value *input = LocalBuilder.CreateCall(ldStFunc, args);
  917. unsigned matIdx = MatTy.getRowMajorIndex(r, c);
  918. matElts[matIdx] = input;
  919. }
  920. }
  921. }
  922. Value *newVec =
  923. HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
  924. newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
  925. CI->replaceAllUsesWith(newVec);
  926. CI->eraseFromParent();
  927. } break;
  928. case HLMatLoadStoreOpcode::ColMatStore:
  929. case HLMatLoadStoreOpcode::RowMatStore: {
  930. IRBuilder<> LocalBuilder(CI);
  931. Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  932. HLMatrixType MatTy = HLMatrixType::cast(
  933. CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
  934. ->getType()->getPointerElementType());
  935. Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
  936. if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
  937. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  938. Constant *constColIdx = LocalBuilder.getInt32(c);
  939. Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
  940. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  941. unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
  942. Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
  943. LocalBuilder.CreateCall(ldStFunc,
  944. { OpArg, ID, colIdx, columnConsts[r], Elt });
  945. }
  946. }
  947. } else {
  948. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  949. Constant *constRowIdx = LocalBuilder.getInt32(r);
  950. Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
  951. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  952. unsigned matIdx = MatTy.getRowMajorIndex(r, c);
  953. Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
  954. LocalBuilder.CreateCall(ldStFunc,
  955. { OpArg, ID, rowIdx, columnConsts[c], Elt });
  956. }
  957. }
  958. }
  959. CI->eraseFromParent();
  960. } break;
  961. }
  962. } else {
  963. DXASSERT(0, "invalid operation on input output");
  964. }
  965. }
  966. } // namespace
  967. void HLSignatureLower::GenerateDxilInputs() {
  968. GenerateDxilInputsOutputs(DXIL::SignatureKind::Input);
  969. }
  970. void HLSignatureLower::GenerateDxilOutputs() {
  971. GenerateDxilInputsOutputs(DXIL::SignatureKind::Output);
  972. }
  973. void HLSignatureLower::GenerateDxilPrimOutputs() {
  974. GenerateDxilInputsOutputs(DXIL::SignatureKind::PatchConstOrPrim);
  975. }
  976. void HLSignatureLower::GenerateDxilInputsOutputs(DXIL::SignatureKind SK) {
  977. OP *hlslOP = HLM.GetOP();
  978. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  979. Module &M = *(HLM.GetModule());
  980. OP::OpCode opcode = (OP::OpCode)-1;
  981. switch (SK) {
  982. case DXIL::SignatureKind::Input:
  983. opcode = OP::OpCode::LoadInput;
  984. break;
  985. case DXIL::SignatureKind::Output:
  986. opcode = props.IsMS() ? OP::OpCode::StoreVertexOutput : OP::OpCode::StoreOutput;
  987. break;
  988. case DXIL::SignatureKind::PatchConstOrPrim:
  989. opcode = OP::OpCode::StorePrimitiveOutput;
  990. break;
  991. default:
  992. DXASSERT_NOMSG(0);
  993. }
  994. bool bInput = SK == DXIL::SignatureKind::Input;
  995. bool bNeedVertexOrPrimID = bInput && (props.IsGS() || props.IsDS() || props.IsHS());
  996. bNeedVertexOrPrimID |= !bInput && props.IsMS();
  997. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  998. Constant *columnConsts[] = {
  999. hlslOP->GetU8Const(0), hlslOP->GetU8Const(1), hlslOP->GetU8Const(2),
  1000. hlslOP->GetU8Const(3), hlslOP->GetU8Const(4), hlslOP->GetU8Const(5),
  1001. hlslOP->GetU8Const(6), hlslOP->GetU8Const(7), hlslOP->GetU8Const(8),
  1002. hlslOP->GetU8Const(9), hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
  1003. hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
  1004. hlslOP->GetU8Const(15)};
  1005. Constant *constZero = hlslOP->GetU32Const(0);
  1006. Value *undefVertexIdx = props.IsMS() || !bInput ? nullptr : UndefValue::get(Type::getInt32Ty(HLM.GetCtx()));
  1007. DxilSignature &Sig =
  1008. bInput ? EntrySig.InputSignature :
  1009. SK == DXIL::SignatureKind::Output ? EntrySig.OutputSignature :
  1010. EntrySig.PatchConstOrPrimSignature;
  1011. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  1012. DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
  1013. Type *i1Ty = Type::getInt1Ty(constZero->getContext());
  1014. Type *i32Ty = constZero->getType();
  1015. llvm::SmallVector<unsigned, 8> removeIndices;
  1016. for (unsigned i = 0; i < Sig.GetElements().size(); i++) {
  1017. DxilSignatureElement *SE = &Sig.GetElement(i);
  1018. llvm::Type *Ty = SE->GetCompType().GetLLVMType(HLM.GetCtx());
  1019. // Cast i1 to i32 for load input.
  1020. bool bI1Cast = false;
  1021. if (Ty == i1Ty) {
  1022. bI1Cast = true;
  1023. Ty = i32Ty;
  1024. }
  1025. if (!hlslOP->IsOverloadLegal(opcode, Ty)) {
  1026. std::string O;
  1027. raw_string_ostream OSS(O);
  1028. Ty->print(OSS);
  1029. OSS << "(type for " << SE->GetName() << ")";
  1030. OSS << " cannot be used as shader inputs or outputs.";
  1031. OSS.flush();
  1032. HLM.GetCtx().emitError(O);
  1033. continue;
  1034. }
  1035. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty);
  1036. Constant *ID = hlslOP->GetU32Const(i);
  1037. unsigned cols = SE->GetCols();
  1038. Value *GV = m_sigValueMap[SE];
  1039. bool bIsInout = m_inoutArgSet.count(GV) > 0;
  1040. IRBuilder<> EntryBuilder(Entry->getEntryBlock().getFirstInsertionPt());
  1041. if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(GV)) {
  1042. EntryBuilder.SetCurrentDebugLocation(DDI->getDebugLoc());
  1043. }
  1044. DXIL::SemanticInterpretationKind SI = SE->GetInterpretation();
  1045. DXASSERT_NOMSG(SI < DXIL::SemanticInterpretationKind::Invalid);
  1046. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NA);
  1047. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NotInSig);
  1048. if (SI == DXIL::SemanticInterpretationKind::Shadow)
  1049. continue; // Handled in ProcessArgument
  1050. if (!GV->getType()->isPointerTy()) {
  1051. DXASSERT(bInput, "direct parameter must be input");
  1052. Value *vertexOrPrimID = undefVertexIdx;
  1053. Value *args[] = {OpArg, ID, /*rowIdx*/ constZero, /*colIdx*/ nullptr,
  1054. vertexOrPrimID};
  1055. replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP,
  1056. EntryBuilder);
  1057. continue;
  1058. }
  1059. bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
  1060. bool bIsPrecise = m_preciseSigSet.count(SE);
  1061. if (bIsPrecise)
  1062. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
  1063. bool bRowMajor = false;
  1064. if (Argument *Arg = dyn_cast<Argument>(GV)) {
  1065. if (pFuncAnnot) {
  1066. auto &paramAnnot = pFuncAnnot->GetParameterAnnotation(Arg->getArgNo());
  1067. if (paramAnnot.HasMatrixAnnotation())
  1068. bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
  1069. MatrixOrientation::RowMajor;
  1070. }
  1071. }
  1072. std::vector<InputOutputAccessInfo> accessInfoList;
  1073. collectInputOutputAccessInfo(GV, constZero, accessInfoList,
  1074. bNeedVertexOrPrimID && bIsArrayTy, bInput, bRowMajor, props.IsMS());
  1075. for (InputOutputAccessInfo &info : accessInfoList) {
  1076. GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
  1077. cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
  1078. bIsArrayTy, bInput, bIsInout);
  1079. }
  1080. }
  1081. }
  1082. void HLSignatureLower::GenerateDxilCSInputs() {
  1083. OP *hlslOP = HLM.GetOP();
  1084. DxilFunctionAnnotation *funcAnnotation = HLM.GetFunctionAnnotation(Entry);
  1085. DXASSERT(funcAnnotation, "must find annotation for entry function");
  1086. IRBuilder<> Builder(Entry->getEntryBlock().getFirstInsertionPt());
  1087. for (Argument &arg : Entry->args()) {
  1088. DxilParameterAnnotation &paramAnnotation =
  1089. funcAnnotation->GetParameterAnnotation(arg.getArgNo());
  1090. llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
  1091. if (semanticStr.empty()) {
  1092. dxilutil::EmitErrorOnFunction(Entry, "Semantic must be defined for all "
  1093. "parameters of an entry function or patch "
  1094. "constant function");
  1095. return;
  1096. }
  1097. const Semantic *semantic =
  1098. Semantic::GetByName(semanticStr, DXIL::SigPointKind::CSIn);
  1099. OP::OpCode opcode;
  1100. switch (semantic->GetKind()) {
  1101. case Semantic::Kind::GroupThreadID:
  1102. opcode = OP::OpCode::ThreadIdInGroup;
  1103. break;
  1104. case Semantic::Kind::GroupID:
  1105. opcode = OP::OpCode::GroupId;
  1106. break;
  1107. case Semantic::Kind::DispatchThreadID:
  1108. opcode = OP::OpCode::ThreadId;
  1109. break;
  1110. case Semantic::Kind::GroupIndex:
  1111. opcode = OP::OpCode::FlattenedThreadIdInGroup;
  1112. break;
  1113. default:
  1114. DXASSERT(semantic->IsInvalid(),
  1115. "else compute shader semantics out-of-date");
  1116. dxilutil::EmitErrorOnFunction(Entry, "invalid semantic found in CS");
  1117. return;
  1118. }
  1119. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1120. Type *NumTy = arg.getType();
  1121. DXASSERT(!NumTy->isPointerTy(), "Unexpected byref value for CS SV_***ID semantic.");
  1122. DXASSERT(NumTy->getScalarType()->isIntegerTy(), "Unexpected non-integer value for CS SV_***ID semantic.");
  1123. // Always use the i32 overload of those intrinsics, and then cast as needed
  1124. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Builder.getInt32Ty());
  1125. Value *newArg = nullptr;
  1126. if (opcode == OP::OpCode::FlattenedThreadIdInGroup) {
  1127. newArg = Builder.CreateCall(dxilFunc, {OpArg});
  1128. } else {
  1129. unsigned vecSize = 1;
  1130. if (NumTy->isVectorTy())
  1131. vecSize = NumTy->getVectorNumElements();
  1132. newArg = Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU32Const(0)});
  1133. if (vecSize > 1) {
  1134. Value *result = UndefValue::get(VectorType::get(Builder.getInt32Ty(), vecSize));
  1135. result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
  1136. for (unsigned i = 1; i < vecSize; i++) {
  1137. Value *newElt =
  1138. Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU32Const(i)});
  1139. result = Builder.CreateInsertElement(result, newElt, i);
  1140. }
  1141. newArg = result;
  1142. }
  1143. }
  1144. // If the argument is of non-i32 type, convert here
  1145. if (newArg->getType() != NumTy)
  1146. newArg = Builder.CreateZExtOrTrunc(newArg, NumTy);
  1147. if (newArg->getType() != arg.getType()) {
  1148. DXASSERT_NOMSG(arg.getType()->isPointerTy());
  1149. for (User *U : arg.users()) {
  1150. LoadInst *LI = cast<LoadInst>(U);
  1151. LI->replaceAllUsesWith(newArg);
  1152. }
  1153. } else {
  1154. arg.replaceAllUsesWith(newArg);
  1155. }
  1156. }
  1157. }
  1158. void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
  1159. OP *hlslOP = HLM.GetOP();
  1160. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  1161. Module &M = *(HLM.GetModule());
  1162. Constant *constZero = hlslOP->GetU32Const(0);
  1163. DxilSignature &Sig = EntrySig.PatchConstOrPrimSignature;
  1164. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  1165. DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
  1166. auto InsertPt = Entry->getEntryBlock().getFirstInsertionPt();
  1167. const bool bIsHs = props.IsHS();
  1168. const bool bIsInput = !bIsHs;
  1169. const bool bIsInout = false;
  1170. const bool bNeedVertexOrPrimID = false;
  1171. if (bIsHs) {
  1172. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1173. Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
  1174. InsertPt = patchConstantFunc->getEntryBlock().getFirstInsertionPt();
  1175. pFuncAnnot = typeSys.GetFunctionAnnotation(patchConstantFunc);
  1176. }
  1177. IRBuilder<> Builder(InsertPt);
  1178. Type *i1Ty = Builder.getInt1Ty();
  1179. Type *i32Ty = Builder.getInt32Ty();
  1180. // LoadPatchConst don't have vertexIdx operand.
  1181. Value *undefVertexIdx = nullptr;
  1182. Constant *columnConsts[] = {
  1183. hlslOP->GetU8Const(0), hlslOP->GetU8Const(1), hlslOP->GetU8Const(2),
  1184. hlslOP->GetU8Const(3), hlslOP->GetU8Const(4), hlslOP->GetU8Const(5),
  1185. hlslOP->GetU8Const(6), hlslOP->GetU8Const(7), hlslOP->GetU8Const(8),
  1186. hlslOP->GetU8Const(9), hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
  1187. hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
  1188. hlslOP->GetU8Const(15)};
  1189. OP::OpCode opcode =
  1190. bIsInput ? OP::OpCode::LoadPatchConstant : OP::OpCode::StorePatchConstant;
  1191. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1192. for (unsigned i = 0; i < Sig.GetElements().size(); i++) {
  1193. DxilSignatureElement *SE = &Sig.GetElement(i);
  1194. Value *GV = m_sigValueMap[SE];
  1195. DXIL::SemanticInterpretationKind SI = SE->GetInterpretation();
  1196. DXASSERT_NOMSG(SI < DXIL::SemanticInterpretationKind::Invalid);
  1197. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NA);
  1198. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NotInSig);
  1199. if (SI == DXIL::SemanticInterpretationKind::Shadow)
  1200. continue; // Handled in ProcessArgument
  1201. Constant *ID = hlslOP->GetU32Const(i);
  1202. // Generate LoadPatchConstant.
  1203. Type *Ty = SE->GetCompType().GetLLVMType(HLM.GetCtx());
  1204. // Cast i1 to i32 for load input.
  1205. bool bI1Cast = false;
  1206. if (Ty == i1Ty) {
  1207. bI1Cast = true;
  1208. Ty = i32Ty;
  1209. }
  1210. unsigned cols = SE->GetCols();
  1211. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty);
  1212. if (!GV->getType()->isPointerTy()) {
  1213. DXASSERT(bIsInput, "Must be DS input.");
  1214. Constant *OpArg = hlslOP->GetU32Const(
  1215. static_cast<unsigned>(OP::OpCode::LoadPatchConstant));
  1216. Value *args[] = {OpArg, ID, /*rowIdx*/ constZero, /*colIdx*/ nullptr};
  1217. replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP,
  1218. Builder);
  1219. continue;
  1220. }
  1221. bool bRowMajor = false;
  1222. if (Argument *Arg = dyn_cast<Argument>(GV)) {
  1223. if (pFuncAnnot) {
  1224. auto &paramAnnot = pFuncAnnot->GetParameterAnnotation(Arg->getArgNo());
  1225. if (paramAnnot.HasMatrixAnnotation())
  1226. bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
  1227. MatrixOrientation::RowMajor;
  1228. }
  1229. }
  1230. std::vector<InputOutputAccessInfo> accessInfoList;
  1231. collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexOrPrimID,
  1232. bIsInput, bRowMajor, false);
  1233. bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
  1234. bool isPrecise = m_preciseSigSet.count(SE);
  1235. if (isPrecise)
  1236. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
  1237. for (InputOutputAccessInfo &info : accessInfoList) {
  1238. GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
  1239. cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
  1240. bIsArrayTy, bIsInput, bIsInout);
  1241. }
  1242. }
  1243. }
  1244. void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
  1245. // Map input patch, to input sig
  1246. // LoadOutputControlPoint for output patch .
  1247. OP *hlslOP = HLM.GetOP();
  1248. Constant *constZero = hlslOP->GetU32Const(0);
  1249. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1250. Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
  1251. DxilFunctionAnnotation *patchFuncAnnotation =
  1252. HLM.GetFunctionAnnotation(patchConstantFunc);
  1253. DXASSERT(patchFuncAnnotation,
  1254. "must find annotation for patch constant function");
  1255. Type *i1Ty = Type::getInt1Ty(constZero->getContext());
  1256. Type *i32Ty = constZero->getType();
  1257. for (Argument &arg : patchConstantFunc->args()) {
  1258. DxilParameterAnnotation &paramAnnotation =
  1259. patchFuncAnnotation->GetParameterAnnotation(arg.getArgNo());
  1260. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1261. if (inputQual == DxilParamInputQual::InputPatch ||
  1262. inputQual == DxilParamInputQual::OutputPatch) {
  1263. DxilSignatureElement *SE = m_patchConstantInputsSigMap[arg.getArgNo()];
  1264. if (!SE) // Error should have been reported at an earlier stage.
  1265. continue;
  1266. Constant *inputID = hlslOP->GetU32Const(SE->GetID());
  1267. unsigned cols = SE->GetCols();
  1268. Type *Ty = SE->GetCompType().GetLLVMType(HLM.GetCtx());
  1269. // Cast i1 to i32 for load input.
  1270. bool bI1Cast = false;
  1271. if (Ty == i1Ty) {
  1272. bI1Cast = true;
  1273. Ty = i32Ty;
  1274. }
  1275. OP::OpCode opcode = inputQual == DxilParamInputQual::InputPatch
  1276. ? OP::OpCode::LoadInput
  1277. : OP::OpCode::LoadOutputControlPoint;
  1278. Function *dxilLdFunc = hlslOP->GetOpFunc(opcode, Ty);
  1279. bool bRowMajor = false;
  1280. if (Argument *Arg = dyn_cast<Argument>(&arg)) {
  1281. if (patchFuncAnnotation) {
  1282. auto &paramAnnot = patchFuncAnnotation->GetParameterAnnotation(Arg->getArgNo());
  1283. if (paramAnnot.HasMatrixAnnotation())
  1284. bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
  1285. MatrixOrientation::RowMajor;
  1286. }
  1287. }
  1288. std::vector<InputOutputAccessInfo> accessInfoList;
  1289. collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
  1290. /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
  1291. for (InputOutputAccessInfo &info : accessInfoList) {
  1292. if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
  1293. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1294. Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
  1295. info.vertexOrPrimID};
  1296. replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
  1297. } else {
  1298. DXASSERT(0, "input should only be ld");
  1299. }
  1300. }
  1301. }
  1302. }
  1303. }
  1304. bool HLSignatureLower::HasClipPlanes() {
  1305. if (!HLM.HasDxilFunctionProps(Entry))
  1306. return false;
  1307. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1308. auto &VS = EntryQual.ShaderProps.VS;
  1309. unsigned numClipPlanes = 0;
  1310. for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
  1311. if (!VS.clipPlanes[i])
  1312. break;
  1313. numClipPlanes++;
  1314. }
  1315. return numClipPlanes != 0;
  1316. }
  1317. void HLSignatureLower::GenerateClipPlanesForVS(Value *outPosition) {
  1318. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1319. auto &VS = EntryQual.ShaderProps.VS;
  1320. unsigned numClipPlanes = 0;
  1321. for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
  1322. if (!VS.clipPlanes[i])
  1323. break;
  1324. numClipPlanes++;
  1325. }
  1326. if (!numClipPlanes)
  1327. return;
  1328. LLVMContext &Ctx = HLM.GetCtx();
  1329. Function *dp4 =
  1330. HLM.GetOP()->GetOpFunc(DXIL::OpCode::Dot4, Type::getFloatTy(Ctx));
  1331. Value *dp4Args[] = {
  1332. ConstantInt::get(Type::getInt32Ty(Ctx),
  1333. static_cast<unsigned>(DXIL::OpCode::Dot4)),
  1334. nullptr,
  1335. nullptr,
  1336. nullptr,
  1337. nullptr,
  1338. nullptr,
  1339. nullptr,
  1340. nullptr,
  1341. nullptr,
  1342. };
  1343. // out SV_Position should only have StoreInst use.
  1344. // Done by LegalizeDxilInputOutputs in ScalarReplAggregatesHLSL.cpp
  1345. for (User *U : outPosition->users()) {
  1346. StoreInst *ST = cast<StoreInst>(U);
  1347. Value *posVal = ST->getValueOperand();
  1348. DXASSERT(posVal->getType()->isVectorTy(), "SV_Position must be a vector");
  1349. IRBuilder<> Builder(ST);
  1350. // Put position to args.
  1351. for (unsigned i = 0; i < 4; i++)
  1352. dp4Args[i + 1] = Builder.CreateExtractElement(posVal, i);
  1353. // For each clip plane.
  1354. // clipDistance = dp4 position, clipPlane.
  1355. auto argIt = Entry->getArgumentList().rbegin();
  1356. for (int clipIdx = numClipPlanes - 1; clipIdx >= 0; clipIdx--) {
  1357. Constant *GV = VS.clipPlanes[clipIdx];
  1358. DXASSERT_NOMSG(GV->hasOneUse());
  1359. StoreInst *ST = cast<StoreInst>(GV->user_back());
  1360. Value *clipPlane = ST->getValueOperand();
  1361. ST->eraseFromParent();
  1362. Argument &arg = *(argIt++);
  1363. // Put clipPlane to args.
  1364. for (unsigned i = 0; i < 4; i++)
  1365. dp4Args[i + 5] = Builder.CreateExtractElement(clipPlane, i);
  1366. Value *clipDistance = Builder.CreateCall(dp4, dp4Args);
  1367. Builder.CreateStore(clipDistance, &arg);
  1368. }
  1369. }
  1370. }
  1371. namespace {
  1372. Value *TranslateStreamAppend(CallInst *CI, unsigned ID, hlsl::OP *OP) {
  1373. Function *DxilFunc = OP->GetOpFunc(OP::OpCode::EmitStream, CI->getType());
  1374. // TODO: generate a emit which has the data being emited as its argment.
  1375. // Value *data = CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
  1376. Constant *opArg = OP->GetU32Const((unsigned)OP::OpCode::EmitStream);
  1377. IRBuilder<> Builder(CI);
  1378. Constant *streamID = OP->GetU8Const(ID);
  1379. Value *args[] = {opArg, streamID};
  1380. return Builder.CreateCall(DxilFunc, args);
  1381. }
  1382. Value *TranslateStreamCut(CallInst *CI, unsigned ID, hlsl::OP *OP) {
  1383. Function *DxilFunc = OP->GetOpFunc(OP::OpCode::CutStream, CI->getType());
  1384. // TODO: generate a emit which has the data being emited as its argment.
  1385. // Value *data = CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
  1386. Constant *opArg = OP->GetU32Const((unsigned)OP::OpCode::CutStream);
  1387. IRBuilder<> Builder(CI);
  1388. Constant *streamID = OP->GetU8Const(ID);
  1389. Value *args[] = {opArg, streamID};
  1390. return Builder.CreateCall(DxilFunc, args);
  1391. }
  1392. } // namespace
  1393. // Generate DXIL stream output operation.
  1394. void HLSignatureLower::GenerateStreamOutputOperation(Value *streamVal, unsigned ID) {
  1395. OP * hlslOP = HLM.GetOP();
  1396. for (auto U = streamVal->user_begin(); U != streamVal->user_end();) {
  1397. Value *user = *(U++);
  1398. // Should only used by append, restartStrip .
  1399. CallInst *CI = cast<CallInst>(user);
  1400. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1401. // Ignore user functions.
  1402. if (group == HLOpcodeGroup::NotHL)
  1403. continue;
  1404. unsigned opcode = GetHLOpcode(CI);
  1405. DXASSERT_LOCALVAR(group, group == HLOpcodeGroup::HLIntrinsic, "Must be HLIntrinsic here");
  1406. IntrinsicOp IOP = static_cast<IntrinsicOp>(opcode);
  1407. switch (IOP) {
  1408. case IntrinsicOp::MOP_Append:
  1409. TranslateStreamAppend(CI, ID, hlslOP);
  1410. break;
  1411. case IntrinsicOp::MOP_RestartStrip:
  1412. TranslateStreamCut(CI, ID, hlslOP);
  1413. break;
  1414. default:
  1415. DXASSERT(0, "invalid operation on stream");
  1416. }
  1417. CI->eraseFromParent();
  1418. }
  1419. }
  1420. // Generate DXIL stream output operations.
  1421. void HLSignatureLower::GenerateStreamOutputOperations() {
  1422. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  1423. DXASSERT(EntryAnnotation, "must find annotation for entry function");
  1424. for (Argument &arg : Entry->getArgumentList()) {
  1425. if (HLModule::IsStreamOutputPtrType(arg.getType())) {
  1426. unsigned streamID = 0;
  1427. DxilParameterAnnotation &paramAnnotation =
  1428. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  1429. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1430. switch (inputQual) {
  1431. case DxilParamInputQual::OutStream0:
  1432. streamID = 0;
  1433. break;
  1434. case DxilParamInputQual::OutStream1:
  1435. streamID = 1;
  1436. break;
  1437. case DxilParamInputQual::OutStream2:
  1438. streamID = 2;
  1439. break;
  1440. case DxilParamInputQual::OutStream3:
  1441. default:
  1442. DXASSERT(inputQual == DxilParamInputQual::OutStream3,
  1443. "invalid input qual.");
  1444. streamID = 3;
  1445. break;
  1446. }
  1447. GenerateStreamOutputOperation(&arg, streamID);
  1448. }
  1449. }
  1450. }
  1451. // Generate DXIL EmitIndices operation.
  1452. void HLSignatureLower::GenerateEmitIndicesOperation(Value *indicesOutput) {
  1453. OP * hlslOP = HLM.GetOP();
  1454. Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::EmitIndices, Type::getVoidTy(indicesOutput->getContext()));
  1455. Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::EmitIndices);
  1456. for (auto U = indicesOutput->user_begin(); U != indicesOutput->user_end();) {
  1457. Value *user = *(U++);
  1458. GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);
  1459. auto idx = GEP->idx_begin();
  1460. DXASSERT_LOCALVAR(idx, idx->get() == hlslOP->GetU32Const(0),
  1461. "only support 0 offset for input pointer");
  1462. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  1463. // Skip first pointer idx which must be 0.
  1464. GEPIt++;
  1465. Value *primIdx = GEPIt.getOperand();
  1466. DXASSERT(++GEPIt == E, "invalid GEP here"); (void)E;
  1467. auto GepUser = GEP->user_begin();
  1468. auto GepUserE = GEP->user_end();
  1469. for (; GepUser != GepUserE;) {
  1470. auto GepUserIt = GepUser++;
  1471. StoreInst *stInst = cast<StoreInst>(*GepUserIt);
  1472. Value *stVal = stInst->getValueOperand();
  1473. VectorType *VT = cast<VectorType>(stVal->getType());
  1474. unsigned eleCount = VT->getNumElements();
  1475. IRBuilder<> Builder(stInst);
  1476. Value *subVal0 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(0));
  1477. Value *subVal1 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(1));
  1478. Value *subVal2 = eleCount == 3 ?
  1479. Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(2)) : hlslOP->GetU32Const(0);
  1480. Value *args[] = { opArg, primIdx, subVal0, subVal1, subVal2 };
  1481. Builder.CreateCall(DxilFunc, args);
  1482. stInst->eraseFromParent();
  1483. }
  1484. GEP->eraseFromParent();
  1485. }
  1486. }
  1487. // Generate DXIL EmitIndices operations.
  1488. void HLSignatureLower::GenerateEmitIndicesOperations() {
  1489. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  1490. DXASSERT(EntryAnnotation, "must find annotation for entry function");
  1491. for (Argument &arg : Entry->getArgumentList()) {
  1492. DxilParameterAnnotation &paramAnnotation =
  1493. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  1494. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1495. if (inputQual == DxilParamInputQual::OutIndices) {
  1496. GenerateEmitIndicesOperation(&arg);
  1497. }
  1498. }
  1499. }
  1500. // Generate DXIL GetMeshPayload operation.
  1501. void HLSignatureLower::GenerateGetMeshPayloadOperation() {
  1502. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  1503. DXASSERT(EntryAnnotation, "must find annotation for entry function");
  1504. for (Argument &arg : Entry->getArgumentList()) {
  1505. DxilParameterAnnotation &paramAnnotation =
  1506. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  1507. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1508. if (inputQual == DxilParamInputQual::InPayload) {
  1509. OP * hlslOP = HLM.GetOP();
  1510. Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::GetMeshPayload, arg.getType());
  1511. Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::GetMeshPayload);
  1512. IRBuilder<> Builder(arg.getParent()->getEntryBlock().getFirstInsertionPt());
  1513. Value *args[] = { opArg };
  1514. Value *payload = Builder.CreateCall(DxilFunc, args);
  1515. arg.replaceAllUsesWith(payload);
  1516. }
  1517. }
  1518. }
  1519. // Lower signatures.
  1520. void HLSignatureLower::Run() {
  1521. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  1522. if (props.IsGraphics()) {
  1523. if (props.IsMS()) {
  1524. GenerateEmitIndicesOperations();
  1525. GenerateGetMeshPayloadOperation();
  1526. }
  1527. CreateDxilSignatures();
  1528. // Allocate input output.
  1529. AllocateDxilInputOutputs();
  1530. GenerateDxilInputs();
  1531. GenerateDxilOutputs();
  1532. if (props.IsMS()) {
  1533. GenerateDxilPrimOutputs();
  1534. }
  1535. } else if (props.IsCS()) {
  1536. GenerateDxilCSInputs();
  1537. }
  1538. if (props.IsDS() || props.IsHS())
  1539. GenerateDxilPatchConstantLdSt();
  1540. if (props.IsHS())
  1541. GenerateDxilPatchConstantFunctionInputs();
  1542. if (props.IsGS())
  1543. GenerateStreamOutputOperations();
  1544. }