HLSignatureLower.cpp 65 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLSignatureLower.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Lower signatures of entry function to DXIL LoadInput/StoreOutput. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "HLSignatureLower.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/DXIL/DxilSignatureElement.h"
  14. #include "dxc/DXIL/DxilSigPoint.h"
  15. #include "dxc/Support/Global.h"
  16. #include "dxc/DXIL/DxilTypeSystem.h"
  17. #include "dxc/DXIL/DxilSemantic.h"
  18. #include "dxc/HLSL/HLModule.h"
  19. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  20. #include "dxc/HLSL/HLMatrixType.h"
  21. #include "dxc/HlslIntrinsicOp.h"
  22. #include "dxc/DXIL/DxilUtil.h"
  23. #include "dxc/HLSL/DxilPackSignatureElement.h"
  24. #include "llvm/IR/IRBuilder.h"
  25. #include "llvm/IR/DebugInfo.h"
  26. #include "llvm/IR/IntrinsicInst.h"
  27. #include "llvm/IR/Module.h"
  28. #include "llvm/Transforms/Utils/Local.h"
  29. using namespace llvm;
  30. using namespace hlsl;
  31. namespace {
  32. // Decompose semantic name (eg FOO1=>FOO,1), change interp mode for SV_Position.
  33. // Return semantic index.
  34. unsigned UpdateSemanticAndInterpMode(StringRef &semName,
  35. DXIL::InterpolationMode &mode,
  36. DXIL::SigPointKind kind,
  37. LLVMContext &Context) {
  38. llvm::StringRef baseSemName; // The 'FOO' in 'FOO1'.
  39. uint32_t semIndex; // The '1' in 'FOO1'
  40. // Split semName and index.
  41. Semantic::DecomposeNameAndIndex(semName, &baseSemName, &semIndex);
  42. semName = baseSemName;
  43. const Semantic *semantic = Semantic::GetByName(semName, kind);
  44. if (semantic && semantic->GetKind() == Semantic::Kind::Position) {
  45. // Update interp mode to no_perspective version for SV_Position.
  46. switch (mode) {
  47. case InterpolationMode::Kind::LinearCentroid:
  48. mode = InterpolationMode::Kind::LinearNoperspectiveCentroid;
  49. break;
  50. case InterpolationMode::Kind::LinearSample:
  51. mode = InterpolationMode::Kind::LinearNoperspectiveSample;
  52. break;
  53. case InterpolationMode::Kind::Linear:
  54. mode = InterpolationMode::Kind::LinearNoperspective;
  55. break;
  56. case InterpolationMode::Kind::Constant:
  57. case InterpolationMode::Kind::Undefined:
  58. case InterpolationMode::Kind::Invalid: {
  59. Context.emitError("invalid interpolation mode for SV_Position");
  60. } break;
  61. case InterpolationMode::Kind::LinearNoperspective:
  62. case InterpolationMode::Kind::LinearNoperspectiveCentroid:
  63. case InterpolationMode::Kind::LinearNoperspectiveSample:
  64. // Already Noperspective modes.
  65. break;
  66. }
  67. }
  68. return semIndex;
  69. }
  70. DxilSignatureElement *FindArgInSignature(Argument &arg,
  71. llvm::StringRef semantic,
  72. DXIL::InterpolationMode interpMode,
  73. DXIL::SigPointKind kind,
  74. DxilSignature &sig) {
  75. // Match output ID.
  76. unsigned semIndex =
  77. UpdateSemanticAndInterpMode(semantic, interpMode, kind, arg.getContext());
  78. for (uint32_t i = 0; i < sig.GetElements().size(); i++) {
  79. DxilSignatureElement &SE = sig.GetElement(i);
  80. bool semNameMatch = semantic.equals_lower(SE.GetName());
  81. bool semIndexMatch = semIndex == SE.GetSemanticIndexVec()[0];
  82. if (semNameMatch && semIndexMatch) {
  83. // Find a match.
  84. return &SE;
  85. }
  86. }
  87. return nullptr;
  88. }
  89. } // namespace
  90. namespace {
  91. void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *GV,
  92. OP *hlslOP, IRBuilder<> &Builder) {
  93. Type *Ty = GV->getType();
  94. if (Ty->isPointerTy())
  95. Ty = Ty->getPointerElementType();
  96. OP::OpCode opcode;
  97. switch (semKind) {
  98. case Semantic::Kind::DomainLocation:
  99. opcode = OP::OpCode::DomainLocation;
  100. break;
  101. case Semantic::Kind::OutputControlPointID:
  102. opcode = OP::OpCode::OutputControlPointID;
  103. break;
  104. case Semantic::Kind::GSInstanceID:
  105. opcode = OP::OpCode::GSInstanceID;
  106. break;
  107. case Semantic::Kind::PrimitiveID:
  108. opcode = OP::OpCode::PrimitiveID;
  109. break;
  110. case Semantic::Kind::SampleIndex:
  111. opcode = OP::OpCode::SampleIndex;
  112. break;
  113. case Semantic::Kind::Coverage:
  114. opcode = OP::OpCode::Coverage;
  115. break;
  116. case Semantic::Kind::InnerCoverage:
  117. opcode = OP::OpCode::InnerCoverage;
  118. break;
  119. case Semantic::Kind::ViewID:
  120. opcode = OP::OpCode::ViewID;
  121. break;
  122. case Semantic::Kind::GroupThreadID:
  123. opcode = OP::OpCode::ThreadIdInGroup;
  124. break;
  125. case Semantic::Kind::GroupID:
  126. opcode = OP::OpCode::GroupId;
  127. break;
  128. case Semantic::Kind::DispatchThreadID:
  129. opcode = OP::OpCode::ThreadId;
  130. break;
  131. case Semantic::Kind::GroupIndex:
  132. opcode = OP::OpCode::FlattenedThreadIdInGroup;
  133. break;
  134. case Semantic::Kind::CullPrimitive: {
  135. GV->replaceAllUsesWith(ConstantInt::get(Ty, (uint64_t)0));
  136. return;
  137. } break;
  138. default:
  139. DXASSERT(0, "invalid semantic");
  140. return;
  141. }
  142. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty->getScalarType());
  143. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  144. Value *newArg = nullptr;
  145. if (semKind == Semantic::Kind::DomainLocation ||
  146. semKind == Semantic::Kind::GroupThreadID ||
  147. semKind == Semantic::Kind::GroupID ||
  148. semKind == Semantic::Kind::DispatchThreadID) {
  149. unsigned vecSize = 1;
  150. if (Ty->isVectorTy())
  151. vecSize = Ty->getVectorNumElements();
  152. newArg = Builder.CreateCall(dxilFunc, { OpArg,
  153. semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(0) : hlslOP->GetU32Const(0) });
  154. if (vecSize > 1) {
  155. Value *result = UndefValue::get(Ty);
  156. result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
  157. for (unsigned i = 1; i < vecSize; i++) {
  158. Value *newElt =
  159. Builder.CreateCall(dxilFunc, { OpArg,
  160. semKind == Semantic::Kind::DomainLocation ? hlslOP->GetU8Const(i)
  161. : hlslOP->GetU32Const(i) });
  162. result = Builder.CreateInsertElement(result, newElt, i);
  163. }
  164. newArg = result;
  165. }
  166. } else {
  167. newArg = Builder.CreateCall(dxilFunc, {OpArg});
  168. }
  169. if (newArg->getType() != GV->getType()) {
  170. DXASSERT_NOMSG(GV->getType()->isPointerTy());
  171. for (User *U : GV->users()) {
  172. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  173. LI->replaceAllUsesWith(newArg);
  174. }
  175. }
  176. } else {
  177. GV->replaceAllUsesWith(newArg);
  178. }
  179. }
  180. } // namespace
  181. void HLSignatureLower::ProcessArgument(Function *func,
  182. DxilFunctionAnnotation *funcAnnotation,
  183. Argument &arg, DxilFunctionProps &props,
  184. const ShaderModel *pSM,
  185. bool isPatchConstantFunction,
  186. bool forceOut, bool &hasClipPlane) {
  187. Type *Ty = arg.getType();
  188. DxilParameterAnnotation &paramAnnotation =
  189. funcAnnotation->GetParameterAnnotation(arg.getArgNo());
  190. hlsl::DxilParamInputQual qual =
  191. forceOut ? DxilParamInputQual::Out : paramAnnotation.GetParamInputQual();
  192. bool isInout = qual == DxilParamInputQual::Inout;
  193. // If this was an inout param, do the output side first
  194. if (isInout) {
  195. DXASSERT(!isPatchConstantFunction,
  196. "Patch Constant function should not have inout param");
  197. m_inoutArgSet.insert(&arg);
  198. const bool bForceOutTrue = true;
  199. ProcessArgument(func, funcAnnotation, arg, props, pSM,
  200. isPatchConstantFunction, bForceOutTrue, hasClipPlane);
  201. qual = DxilParamInputQual::In;
  202. }
  203. // Get stream index
  204. unsigned streamIdx = 0;
  205. switch (qual) {
  206. case DxilParamInputQual::OutStream1:
  207. streamIdx = 1;
  208. break;
  209. case DxilParamInputQual::OutStream2:
  210. streamIdx = 2;
  211. break;
  212. case DxilParamInputQual::OutStream3:
  213. streamIdx = 3;
  214. break;
  215. default:
  216. // Use streamIdx = 0 by default.
  217. break;
  218. }
  219. const SigPoint *sigPoint = SigPoint::GetSigPoint(
  220. SigPointFromInputQual(qual, props.shaderKind, isPatchConstantFunction));
  221. unsigned rows, cols;
  222. HLModule::GetParameterRowsAndCols(Ty, rows, cols, paramAnnotation);
  223. CompType EltTy = paramAnnotation.GetCompType();
  224. DXIL::InterpolationMode interpMode =
  225. paramAnnotation.GetInterpolationMode().GetKind();
  226. // Set undefined interpMode.
  227. if (sigPoint->GetKind() == DXIL::SigPointKind::MSPOut) {
  228. if (interpMode != InterpolationMode::Kind::Undefined &&
  229. interpMode != InterpolationMode::Kind::Constant) {
  230. dxilutil::EmitErrorOnFunction(func,
  231. "Mesh shader's primitive outputs' interpolation mode must be constant or undefined.");
  232. }
  233. interpMode = InterpolationMode::Kind::Constant;
  234. }
  235. else if (!sigPoint->NeedsInterpMode())
  236. interpMode = InterpolationMode::Kind::Undefined;
  237. else if (interpMode == InterpolationMode::Kind::Undefined) {
  238. // Type-based default: linear for floats, constant for others.
  239. if (EltTy.IsFloatTy())
  240. interpMode = InterpolationMode::Kind::Linear;
  241. else
  242. interpMode = InterpolationMode::Kind::Constant;
  243. }
  244. // back-compat mode - remap obsolete semantics
  245. if (HLM.GetHLOptions().bDX9CompatMode && paramAnnotation.HasSemanticString()) {
  246. hlsl::RemapObsoleteSemantic(paramAnnotation, sigPoint->GetKind(), HLM.GetCtx());
  247. }
  248. llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
  249. if (semanticStr.empty()) {
  250. dxilutil::EmitErrorOnFunction(func,
  251. "Semantic must be defined for all parameters of an entry function or "
  252. "patch constant function");
  253. return;
  254. }
  255. UpdateSemanticAndInterpMode(semanticStr, interpMode, sigPoint->GetKind(),
  256. arg.getContext());
  257. // Get Semantic interpretation, skipping if not in signature
  258. const Semantic *pSemantic = Semantic::GetByName(semanticStr);
  259. DXIL::SemanticInterpretationKind interpretation =
  260. SigPoint::GetInterpretation(pSemantic->GetKind(), sigPoint->GetKind(),
  261. pSM->GetMajor(), pSM->GetMinor());
  262. // Verify system value semantics do not overlap.
  263. // Note: Arbitrary are always in the signature and will be verified with a
  264. // different mechanism. For patch constant function, only validate patch
  265. // constant elements (others already validated on hull function)
  266. if (pSemantic->GetKind() != DXIL::SemanticKind::Arbitrary &&
  267. (!isPatchConstantFunction ||
  268. (!sigPoint->IsInput() && !sigPoint->IsOutput()))) {
  269. auto &SemanticUseMap =
  270. sigPoint->IsInput()
  271. ? m_InputSemanticsUsed
  272. : (sigPoint->IsOutput()
  273. ? m_OutputSemanticsUsed[streamIdx]
  274. : (sigPoint->IsPatchConstOrPrim() ? m_PatchConstantSemanticsUsed
  275. : m_OtherSemanticsUsed));
  276. if (SemanticUseMap.count((unsigned)pSemantic->GetKind()) > 0) {
  277. auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
  278. for (unsigned idx : paramAnnotation.GetSemanticIndexVec()) {
  279. if (SemanticIndexSet.count(idx) > 0) {
  280. dxilutil::EmitErrorOnFunction(func, "Parameter with semantic " + semanticStr +
  281. " has overlapping semantic index at " + std::to_string(idx) + ".");
  282. return;
  283. }
  284. }
  285. }
  286. auto &SemanticIndexSet = SemanticUseMap[(unsigned)pSemantic->GetKind()];
  287. for (unsigned idx : paramAnnotation.GetSemanticIndexVec()) {
  288. SemanticIndexSet.emplace(idx);
  289. }
  290. // Enforce Coverage and InnerCoverage input mutual exclusivity
  291. if (sigPoint->IsInput()) {
  292. if ((pSemantic->GetKind() == DXIL::SemanticKind::Coverage &&
  293. SemanticUseMap.count((unsigned)DXIL::SemanticKind::InnerCoverage) >
  294. 0) ||
  295. (pSemantic->GetKind() == DXIL::SemanticKind::InnerCoverage &&
  296. SemanticUseMap.count((unsigned)DXIL::SemanticKind::Coverage) > 0)) {
  297. dxilutil::EmitErrorOnFunction(func,
  298. "Pixel shader inputs SV_Coverage and SV_InnerCoverage are mutually "
  299. "exclusive");
  300. return;
  301. }
  302. }
  303. }
  304. // Validate interpretation and replace argument usage with load/store
  305. // intrinsics
  306. {
  307. switch (interpretation) {
  308. case DXIL::SemanticInterpretationKind::NA: {
  309. dxilutil::EmitErrorOnFunction(func, Twine("Semantic ") + semanticStr +
  310. Twine(" is invalid for shader model: ") +
  311. ShaderModel::GetKindName(props.shaderKind));
  312. return;
  313. }
  314. case DXIL::SemanticInterpretationKind::NotInSig:
  315. case DXIL::SemanticInterpretationKind::Shadow: {
  316. IRBuilder<> funcBuilder(func->getEntryBlock().getFirstInsertionPt());
  317. if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&arg)) {
  318. funcBuilder.SetCurrentDebugLocation(DDI->getDebugLoc());
  319. }
  320. replaceInputOutputWithIntrinsic(pSemantic->GetKind(), &arg, HLM.GetOP(),
  321. funcBuilder);
  322. if (interpretation == DXIL::SemanticInterpretationKind::NotInSig)
  323. return; // This argument should not be included in the signature
  324. break;
  325. }
  326. case DXIL::SemanticInterpretationKind::SV:
  327. case DXIL::SemanticInterpretationKind::SGV:
  328. case DXIL::SemanticInterpretationKind::Arb:
  329. case DXIL::SemanticInterpretationKind::Target:
  330. case DXIL::SemanticInterpretationKind::TessFactor:
  331. case DXIL::SemanticInterpretationKind::NotPacked:
  332. case DXIL::SemanticInterpretationKind::ClipCull:
  333. // Will be replaced with load/store intrinsics in
  334. // GenerateDxilInputsOutputs
  335. break;
  336. default:
  337. DXASSERT(false, "Unexpected SemanticInterpretationKind");
  338. return;
  339. }
  340. }
  341. // Determine signature this argument belongs in, if any
  342. DxilSignature *pSig = nullptr;
  343. DXIL::SignatureKind sigKind = sigPoint->GetSignatureKindWithFallback();
  344. switch (sigKind) {
  345. case DXIL::SignatureKind::Input:
  346. pSig = &EntrySig.InputSignature;
  347. break;
  348. case DXIL::SignatureKind::Output:
  349. pSig = &EntrySig.OutputSignature;
  350. break;
  351. case DXIL::SignatureKind::PatchConstOrPrim:
  352. pSig = &EntrySig.PatchConstOrPrimSignature;
  353. break;
  354. default:
  355. DXASSERT(false, "Expected real signature kind at this point");
  356. return; // No corresponding signature
  357. }
  358. // Create and add element to signature
  359. DxilSignatureElement *pSE = nullptr;
  360. {
  361. // Add signature element to appropriate maps
  362. if (isPatchConstantFunction &&
  363. sigKind != DXIL::SignatureKind::PatchConstOrPrim) {
  364. pSE = FindArgInSignature(arg, paramAnnotation.GetSemanticString(),
  365. interpMode, sigPoint->GetKind(), *pSig);
  366. if (!pSE) {
  367. dxilutil::EmitErrorOnFunction(func, Twine("Signature element ") + semanticStr +
  368. Twine(", referred to by patch constant function, is not found in "
  369. "corresponding hull shader ") +
  370. (sigKind == DXIL::SignatureKind::Input ? "input." : "output."));
  371. return;
  372. }
  373. m_patchConstantInputsSigMap[arg.getArgNo()] = pSE;
  374. } else {
  375. std::unique_ptr<DxilSignatureElement> SE = pSig->CreateElement();
  376. pSE = SE.get();
  377. pSig->AppendElement(std::move(SE));
  378. pSE->SetSigPointKind(sigPoint->GetKind());
  379. pSE->Initialize(semanticStr, EltTy, interpMode, rows, cols,
  380. Semantic::kUndefinedRow, Semantic::kUndefinedCol,
  381. pSE->GetID(), paramAnnotation.GetSemanticIndexVec());
  382. m_sigValueMap[pSE] = &arg;
  383. }
  384. }
  385. if (paramAnnotation.IsPrecise())
  386. m_preciseSigSet.insert(pSE);
  387. if (sigKind == DXIL::SignatureKind::Output &&
  388. pSemantic->GetKind() == Semantic::Kind::Position && hasClipPlane) {
  389. GenerateClipPlanesForVS(&arg);
  390. hasClipPlane = false;
  391. }
  392. // Set Output Stream.
  393. if (streamIdx > 0)
  394. pSE->SetOutputStream(streamIdx);
  395. }
  396. void HLSignatureLower::CreateDxilSignatures() {
  397. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  398. const ShaderModel *pSM = HLM.GetShaderModel();
  399. DXASSERT(Entry->getReturnType()->isVoidTy(),
  400. "Should changed in SROA_Parameter_HLSL");
  401. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  402. DXASSERT(EntryAnnotation, "must have function annotation for entry function");
  403. bool bHasClipPlane =
  404. props.shaderKind == DXIL::ShaderKind::Vertex ? HasClipPlanes() : false;
  405. const bool isPatchConstantFunctionFalse = false;
  406. const bool bForOutFasle = false;
  407. for (Argument &arg : Entry->getArgumentList()) {
  408. Type *Ty = arg.getType();
  409. // Skip streamout obj.
  410. if (HLModule::IsStreamOutputPtrType(Ty))
  411. continue;
  412. // Skip OutIndices and InPayload
  413. DxilParameterAnnotation &paramAnnotation =
  414. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  415. hlsl::DxilParamInputQual qual = paramAnnotation.GetParamInputQual();
  416. if (qual == hlsl::DxilParamInputQual::OutIndices ||
  417. qual == hlsl::DxilParamInputQual::InPayload)
  418. continue;
  419. ProcessArgument(Entry, EntryAnnotation, arg, props, pSM,
  420. isPatchConstantFunctionFalse, bForOutFasle, bHasClipPlane);
  421. }
  422. if (bHasClipPlane) {
  423. dxilutil::EmitErrorOnFunction(Entry, "Cannot use clipplanes attribute without "
  424. "specifying a 4-component SV_Position "
  425. "output");
  426. }
  427. m_OtherSemanticsUsed.clear();
  428. if (props.shaderKind == DXIL::ShaderKind::Hull) {
  429. Function *patchConstantFunc = props.ShaderProps.HS.patchConstantFunc;
  430. if (patchConstantFunc == nullptr) {
  431. dxilutil::EmitErrorOnFunction(Entry,
  432. "Patch constant function is not specified.");
  433. }
  434. DxilFunctionAnnotation *patchFuncAnnotation =
  435. HLM.GetFunctionAnnotation(patchConstantFunc);
  436. DXASSERT(patchFuncAnnotation,
  437. "must have function annotation for patch constant function");
  438. const bool isPatchConstantFunctionTrue = true;
  439. for (Argument &arg : patchConstantFunc->getArgumentList()) {
  440. ProcessArgument(patchConstantFunc, patchFuncAnnotation, arg, props, pSM,
  441. isPatchConstantFunctionTrue, bForOutFasle, bHasClipPlane);
  442. }
  443. }
  444. }
  445. // Allocate input/output slots
  446. void HLSignatureLower::AllocateDxilInputOutputs() {
  447. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  448. const ShaderModel *pSM = HLM.GetShaderModel();
  449. const HLOptions &opts = HLM.GetHLOptions();
  450. DXASSERT_NOMSG(opts.PackingStrategy <
  451. (unsigned)DXIL::PackingStrategy::Invalid);
  452. DXIL::PackingStrategy packing = (DXIL::PackingStrategy)opts.PackingStrategy;
  453. if (packing == DXIL::PackingStrategy::Default)
  454. packing = pSM->GetDefaultPackingStrategy();
  455. hlsl::PackDxilSignature(EntrySig.InputSignature, packing);
  456. if (!EntrySig.InputSignature.IsFullyAllocated()) {
  457. dxilutil::EmitErrorOnFunction(Entry,
  458. "Failed to allocate all input signature elements in available space.");
  459. }
  460. if (props.shaderKind != DXIL::ShaderKind::Amplification) {
  461. hlsl::PackDxilSignature(EntrySig.OutputSignature, packing);
  462. if (!EntrySig.OutputSignature.IsFullyAllocated()) {
  463. dxilutil::EmitErrorOnFunction(Entry,
  464. "Failed to allocate all output signature elements in available space.");
  465. }
  466. }
  467. if (props.shaderKind == DXIL::ShaderKind::Hull ||
  468. props.shaderKind == DXIL::ShaderKind::Domain ||
  469. props.shaderKind == DXIL::ShaderKind::Mesh) {
  470. hlsl::PackDxilSignature(EntrySig.PatchConstOrPrimSignature, packing);
  471. if (!EntrySig.PatchConstOrPrimSignature.IsFullyAllocated()) {
  472. dxilutil::EmitErrorOnFunction(Entry,
  473. "Failed to allocate all patch constant signature "
  474. "elements in available space.");
  475. }
  476. }
  477. }
  478. namespace {
  479. // Helper functions and class for lower signature.
  480. void GenerateStOutput(Function *stOutput, MutableArrayRef<Value *> args,
  481. IRBuilder<> &Builder, bool cast) {
  482. if (cast) {
  483. Value *value = args[DXIL::OperandIndex::kStoreOutputValOpIdx];
  484. args[DXIL::OperandIndex::kStoreOutputValOpIdx] =
  485. Builder.CreateZExt(value, Builder.getInt32Ty());
  486. }
  487. Builder.CreateCall(stOutput, args);
  488. }
  489. void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
  490. Constant *OpArg, Constant *outputID, Value *idx,
  491. unsigned cols, Value *vertexOrPrimID, bool bI1Cast) {
  492. IRBuilder<> Builder(stInst);
  493. Value *val = stInst->getValueOperand();
  494. if (VectorType *VT = dyn_cast<VectorType>(val->getType())) {
  495. DXASSERT_LOCALVAR(VT, cols == VT->getNumElements(), "vec size must match");
  496. for (unsigned col = 0; col < cols; col++) {
  497. Value *subVal = Builder.CreateExtractElement(val, col);
  498. Value *colIdx = Builder.getInt8(col);
  499. SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, subVal};
  500. if (vertexOrPrimID)
  501. args.emplace_back(vertexOrPrimID);
  502. GenerateStOutput(stOutput, args, Builder, bI1Cast);
  503. }
  504. // remove stInst
  505. stInst->eraseFromParent();
  506. } else if (!val->getType()->isArrayTy()) {
  507. // TODO: support case cols not 1
  508. DXASSERT(cols == 1, "only support scalar here");
  509. Value *colIdx = Builder.getInt8(0);
  510. SmallVector<Value *, 4> args = {OpArg, outputID, idx, colIdx, val};
  511. if (vertexOrPrimID)
  512. args.emplace_back(vertexOrPrimID);
  513. GenerateStOutput(stOutput, args, Builder, bI1Cast);
  514. // remove stInst
  515. stInst->eraseFromParent();
  516. } else {
  517. DXASSERT(0, "not support array yet");
  518. // TODO: support array.
  519. Value *colIdx = Builder.getInt8(0);
  520. ArrayType *AT = cast<ArrayType>(val->getType());
  521. Value *args[] = {OpArg, outputID, idx, colIdx, /*val*/ nullptr};
  522. (void)args;
  523. (void)AT;
  524. }
  525. }
  526. Value *GenerateLdInput(Function *loadInput, ArrayRef<Value *> args,
  527. IRBuilder<> &Builder, Value *zero, bool bCast,
  528. Type *Ty) {
  529. Value *input = Builder.CreateCall(loadInput, args);
  530. if (!bCast)
  531. return input;
  532. else {
  533. Value *bVal = Builder.CreateICmpNE(input, zero);
  534. IntegerType *IT = cast<IntegerType>(Ty);
  535. if (IT->getBitWidth() == 1)
  536. return bVal;
  537. else
  538. return Builder.CreateZExt(bVal, Ty);
  539. }
  540. }
  541. Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
  542. unsigned cols, MutableArrayRef<Value *> args,
  543. bool bCast) {
  544. IRBuilder<> Builder(ldInst);
  545. IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(ldInst));
  546. Type *Ty = ldInst->getType();
  547. Type *EltTy = Ty->getScalarType();
  548. // Change i1 to i32 for load input.
  549. Value *zero = Builder.getInt32(0);
  550. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  551. Value *newVec = llvm::UndefValue::get(VT);
  552. DXASSERT(cols == VT->getNumElements(), "vec size must match");
  553. for (unsigned col = 0; col < cols; col++) {
  554. Value *colIdx = Builder.getInt8(col);
  555. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  556. Value *input =
  557. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  558. newVec = Builder.CreateInsertElement(newVec, input, col);
  559. }
  560. ldInst->replaceAllUsesWith(newVec);
  561. ldInst->eraseFromParent();
  562. return newVec;
  563. } else {
  564. Value *colIdx = args[DXIL::OperandIndex::kLoadInputColOpIdx];
  565. if (colIdx == nullptr) {
  566. DXASSERT(cols == 1, "only support scalar here");
  567. colIdx = Builder.getInt8(0);
  568. } else {
  569. if (colIdx->getType() == Builder.getInt32Ty()) {
  570. colIdx = Builder.CreateTrunc(colIdx, Builder.getInt8Ty());
  571. }
  572. }
  573. if (isa<ConstantInt>(colIdx)) {
  574. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  575. Value *input =
  576. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  577. ldInst->replaceAllUsesWith(input);
  578. ldInst->eraseFromParent();
  579. return input;
  580. } else {
  581. // Vector indexing.
  582. // Load to array.
  583. ArrayType *AT = ArrayType::get(ldInst->getType(), cols);
  584. Value *arrayVec = AllocaBuilder.CreateAlloca(AT);
  585. Value *zeroIdx = Builder.getInt32(0);
  586. for (unsigned col = 0; col < cols; col++) {
  587. Value *colIdx = Builder.getInt8(col);
  588. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  589. Value *input =
  590. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  591. Value *GEP = Builder.CreateInBoundsGEP(arrayVec, {zeroIdx, colIdx});
  592. Builder.CreateStore(input, GEP);
  593. }
  594. Value *vecIndexingPtr =
  595. Builder.CreateInBoundsGEP(arrayVec, {zeroIdx, colIdx});
  596. Value *input = Builder.CreateLoad(vecIndexingPtr);
  597. ldInst->replaceAllUsesWith(input);
  598. ldInst->eraseFromParent();
  599. return input;
  600. }
  601. }
  602. }
  603. void replaceDirectInputParameter(Value *param, Function *loadInput,
  604. unsigned cols, MutableArrayRef<Value *> args,
  605. bool bCast, OP *hlslOP, IRBuilder<> &Builder) {
  606. Value *zero = hlslOP->GetU32Const(0);
  607. Type *Ty = param->getType();
  608. Type *EltTy = Ty->getScalarType();
  609. if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
  610. Value *newVec = llvm::UndefValue::get(VT);
  611. DXASSERT(cols == VT->getNumElements(), "vec size must match");
  612. for (unsigned col = 0; col < cols; col++) {
  613. Value *colIdx = hlslOP->GetU8Const(col);
  614. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  615. Value *input =
  616. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  617. newVec = Builder.CreateInsertElement(newVec, input, col);
  618. }
  619. param->replaceAllUsesWith(newVec);
  620. // THe individual loadInputs are the authoritative source of values for the vector.
  621. dxilutil::TryScatterDebugValueToVectorElements(newVec);
  622. } else if (!Ty->isArrayTy() && !HLMatrixType::isa(Ty)) {
  623. DXASSERT(cols == 1, "only support scalar here");
  624. Value *colIdx = hlslOP->GetU8Const(0);
  625. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  626. Value *input =
  627. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  628. param->replaceAllUsesWith(input); // Will properly relocate any DbgValueInst
  629. } else if (HLMatrixType::isa(Ty)) {
  630. if (param->use_empty()) return;
  631. DXASSERT(param->hasOneUse(),
  632. "matrix arg should only has one use as matrix to vec");
  633. CallInst *CI = cast<CallInst>(param->user_back());
  634. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  635. DXASSERT_LOCALVAR(group, group == HLOpcodeGroup::HLCast,
  636. "must be hlcast here");
  637. unsigned opcode = GetHLOpcode(CI);
  638. HLCastOpcode matOp = static_cast<HLCastOpcode>(opcode);
  639. switch (matOp) {
  640. case HLCastOpcode::ColMatrixToVecCast: {
  641. IRBuilder<> LocalBuilder(CI);
  642. HLMatrixType MatTy = HLMatrixType::cast(
  643. CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
  644. Type *EltTy = MatTy.getElementTypeForReg();
  645. std::vector<Value *> matElts(MatTy.getNumElements());
  646. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  647. Value *rowIdx = hlslOP->GetI32Const(c);
  648. args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
  649. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  650. Value *colIdx = hlslOP->GetU8Const(r);
  651. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  652. Value *input =
  653. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  654. matElts[MatTy.getColumnMajorIndex(r, c)] = input;
  655. }
  656. }
  657. Value *newVec =
  658. HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
  659. CI->replaceAllUsesWith(newVec);
  660. CI->eraseFromParent();
  661. } break;
  662. case HLCastOpcode::RowMatrixToVecCast: {
  663. IRBuilder<> LocalBuilder(CI);
  664. HLMatrixType MatTy = HLMatrixType::cast(
  665. CI->getArgOperand(HLOperandIndex::kUnaryOpSrc0Idx)->getType());
  666. Type *EltTy = MatTy.getElementTypeForReg();
  667. std::vector<Value *> matElts(MatTy.getNumElements());
  668. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  669. Value *rowIdx = hlslOP->GetI32Const(r);
  670. args[DXIL::OperandIndex::kLoadInputRowOpIdx] = rowIdx;
  671. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  672. Value *colIdx = hlslOP->GetU8Const(c);
  673. args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
  674. Value *input =
  675. GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
  676. matElts[MatTy.getRowMajorIndex(r, c)] = input;
  677. }
  678. }
  679. Value *newVec =
  680. HLMatrixLower::BuildVector(EltTy, matElts, LocalBuilder);
  681. CI->replaceAllUsesWith(newVec);
  682. CI->eraseFromParent();
  683. } break;
  684. default:
  685. // Only matrix to vector casts are valid.
  686. break;
  687. }
  688. } else {
  689. DXASSERT(0, "invalid type for direct input");
  690. }
  691. }
  692. struct InputOutputAccessInfo {
  693. // For input output which has only 1 row, idx is 0.
  694. Value *idx;
  695. // VertexID for HS/DS/GS input, MS vertex output. PrimitiveID for MS primitive output
  696. Value *vertexOrPrimID;
  697. // Vector index.
  698. Value *vectorIdx;
  699. // Load/Store/LoadMat/StoreMat on input/output.
  700. Instruction *user;
  701. InputOutputAccessInfo(Value *index, Instruction *I)
  702. : idx(index), vertexOrPrimID(nullptr), vectorIdx(nullptr), user(I) {}
  703. InputOutputAccessInfo(Value *index, Instruction *I, Value *ID, Value *vecIdx)
  704. : idx(index), vertexOrPrimID(ID), vectorIdx(vecIdx), user(I) {}
  705. };
  706. void collectInputOutputAccessInfo(
  707. Value *GV, Constant *constZero,
  708. std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexOrPrimID,
  709. bool bInput, bool bRowMajor, bool isMS) {
  710. // merge GEP use for input output.
  711. HLModule::MergeGepUse(GV);
  712. for (auto User = GV->user_begin(); User != GV->user_end();) {
  713. Value *I = *(User++);
  714. if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
  715. if (bInput) {
  716. InputOutputAccessInfo info = {constZero, ldInst};
  717. accessInfoList.push_back(info);
  718. }
  719. } else if (StoreInst *stInst = dyn_cast<StoreInst>(I)) {
  720. if (!bInput) {
  721. InputOutputAccessInfo info = {constZero, stInst};
  722. accessInfoList.push_back(info);
  723. }
  724. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
  725. // Vector indexing may has more indices.
  726. // Vector indexing changed to array indexing in SROA_HLSL.
  727. auto idx = GEP->idx_begin();
  728. DXASSERT_LOCALVAR(idx, idx->get() == constZero,
  729. "only support 0 offset for input pointer");
  730. Value *vertexOrPrimID = nullptr;
  731. Value *vectorIdx = nullptr;
  732. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  733. // Skip first pointer idx which must be 0.
  734. GEPIt++;
  735. if (hasVertexOrPrimID) {
  736. // Save vertexOrPrimID.
  737. vertexOrPrimID = GEPIt.getOperand();
  738. GEPIt++;
  739. }
  740. // Start from first index.
  741. Value *rowIdx = GEPIt.getOperand();
  742. if (GEPIt != E) {
  743. if ((*GEPIt)->isVectorTy()) {
  744. // Vector indexing.
  745. rowIdx = constZero;
  746. vectorIdx = GEPIt.getOperand();
  747. DXASSERT_NOMSG((++GEPIt) == E);
  748. } else {
  749. // Array which may have vector indexing.
  750. // Highest dim index is saved in rowIdx,
  751. // array size for highest dim not affect index.
  752. GEPIt++;
  753. IRBuilder<> Builder(GEP);
  754. Type *idxTy = rowIdx->getType();
  755. for (; GEPIt != E; ++GEPIt) {
  756. DXASSERT(!GEPIt->isStructTy(),
  757. "Struct should be flattened SROA_Parameter_HLSL");
  758. DXASSERT(!GEPIt->isPointerTy(),
  759. "not support pointer type in middle of GEP");
  760. if (GEPIt->isArrayTy()) {
  761. Constant *arraySize =
  762. ConstantInt::get(idxTy, GEPIt->getArrayNumElements());
  763. rowIdx = Builder.CreateMul(rowIdx, arraySize);
  764. rowIdx = Builder.CreateAdd(rowIdx, GEPIt.getOperand());
  765. } else {
  766. Type *Ty = *GEPIt;
  767. DXASSERT_LOCALVAR(Ty, Ty->isVectorTy(),
  768. "must be vector type here to index");
  769. // Save vector idx.
  770. vectorIdx = GEPIt.getOperand();
  771. }
  772. }
  773. if (HLMatrixType MatTy = HLMatrixType::dyn_cast(*GEPIt)) {
  774. Constant *arraySize = ConstantInt::get(idxTy, MatTy.getNumColumns());
  775. if (bRowMajor) {
  776. arraySize = ConstantInt::get(idxTy, MatTy.getNumRows());
  777. }
  778. rowIdx = Builder.CreateMul(rowIdx, arraySize);
  779. }
  780. }
  781. } else
  782. rowIdx = constZero;
  783. auto GepUser = GEP->user_begin();
  784. auto GepUserE = GEP->user_end();
  785. Value *idxVal = rowIdx;
  786. for (; GepUser != GepUserE;) {
  787. auto GepUserIt = GepUser++;
  788. if (LoadInst *ldInst = dyn_cast<LoadInst>(*GepUserIt)) {
  789. if (bInput) {
  790. InputOutputAccessInfo info = {idxVal, ldInst, vertexOrPrimID, vectorIdx};
  791. accessInfoList.push_back(info);
  792. }
  793. } else if (StoreInst *stInst = dyn_cast<StoreInst>(*GepUserIt)) {
  794. if (!bInput) {
  795. InputOutputAccessInfo info = {idxVal, stInst, vertexOrPrimID, vectorIdx};
  796. accessInfoList.push_back(info);
  797. }
  798. } else if (CallInst *CI = dyn_cast<CallInst>(*GepUserIt)) {
  799. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  800. DXASSERT_LOCALVAR(group, group == HLOpcodeGroup::HLMatLoadStore,
  801. "input/output should only used by ld/st");
  802. HLMatLoadStoreOpcode opcode = (HLMatLoadStoreOpcode)GetHLOpcode(CI);
  803. if ((opcode == HLMatLoadStoreOpcode::ColMatLoad ||
  804. opcode == HLMatLoadStoreOpcode::RowMatLoad)
  805. ? bInput
  806. : !bInput) {
  807. InputOutputAccessInfo info = {idxVal, CI, vertexOrPrimID, vectorIdx};
  808. accessInfoList.push_back(info);
  809. }
  810. } else {
  811. DXASSERT(0, "input output should only used by ld/st");
  812. }
  813. }
  814. } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
  815. InputOutputAccessInfo info = {constZero, CI};
  816. accessInfoList.push_back(info);
  817. } else {
  818. DXASSERT(0, "input output should only used by ld/st");
  819. }
  820. }
  821. }
  822. void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertexIdx,
  823. Function *ldStFunc, Constant *OpArg, Constant *ID, unsigned cols, bool bI1Cast,
  824. Constant *columnConsts[],
  825. bool bNeedVertexOrPrimID, bool isArrayTy, bool bInput, bool bIsInout) {
  826. Value *idxVal = info.idx;
  827. Value *vertexOrPrimID = undefVertexIdx;
  828. if (bNeedVertexOrPrimID && isArrayTy) {
  829. vertexOrPrimID = info.vertexOrPrimID;
  830. }
  831. if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
  832. SmallVector<Value *, 4> args = {OpArg, ID, idxVal, info.vectorIdx};
  833. if (vertexOrPrimID)
  834. args.emplace_back(vertexOrPrimID);
  835. replaceLdWithLdInput(ldStFunc, ldInst, cols, args, bI1Cast);
  836. } else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user)) {
  837. if (bInput) {
  838. DXASSERT_LOCALVAR(bIsInout, bIsInout, "input should not have store use.");
  839. } else {
  840. if (!info.vectorIdx) {
  841. replaceStWithStOutput(ldStFunc, stInst, OpArg, ID, idxVal, cols,
  842. vertexOrPrimID, bI1Cast);
  843. } else {
  844. Value *V = stInst->getValueOperand();
  845. Type *Ty = V->getType();
  846. DXASSERT_LOCALVAR(Ty == Ty->getScalarType() && !Ty->isAggregateType(),
  847. Ty, "only support scalar here");
  848. if (ConstantInt *ColIdx = dyn_cast<ConstantInt>(info.vectorIdx)) {
  849. IRBuilder<> Builder(stInst);
  850. if (ColIdx->getType()->getBitWidth() != 8) {
  851. ColIdx = Builder.getInt8(ColIdx->getValue().getLimitedValue());
  852. }
  853. SmallVector<Value *, 6> args = {OpArg, ID, idxVal, ColIdx, V};
  854. if (vertexOrPrimID)
  855. args.emplace_back(vertexOrPrimID);
  856. GenerateStOutput(ldStFunc, args, Builder, bI1Cast);
  857. } else {
  858. BasicBlock *BB = stInst->getParent();
  859. BasicBlock *EndBB = BB->splitBasicBlock(stInst);
  860. TerminatorInst *TI = BB->getTerminator();
  861. IRBuilder<> SwitchBuilder(TI);
  862. LLVMContext &Ctx = stInst->getContext();
  863. SwitchInst *Switch =
  864. SwitchBuilder.CreateSwitch(info.vectorIdx, EndBB, cols);
  865. TI->eraseFromParent();
  866. Function *F = EndBB->getParent();
  867. for (unsigned i = 0; i < cols; i++) {
  868. BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case", F, EndBB);
  869. Switch->addCase(SwitchBuilder.getInt32(i), CaseBB);
  870. IRBuilder<> CaseBuilder(CaseBB);
  871. ConstantInt *CaseIdx = SwitchBuilder.getInt8(i);
  872. SmallVector<Value *, 6> args = {OpArg, ID, idxVal, CaseIdx, V};
  873. if (vertexOrPrimID)
  874. args.emplace_back(vertexOrPrimID);
  875. GenerateStOutput(ldStFunc, args, CaseBuilder, bI1Cast);
  876. CaseBuilder.CreateBr(EndBB);
  877. }
  878. }
  879. // remove stInst
  880. stInst->eraseFromParent();
  881. }
  882. }
  883. } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
  884. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  885. // Intrinsic will be translated later.
  886. if (group == HLOpcodeGroup::HLIntrinsic || group == HLOpcodeGroup::NotHL)
  887. return;
  888. unsigned opcode = GetHLOpcode(CI);
  889. DXASSERT_NOMSG(group == HLOpcodeGroup::HLMatLoadStore);
  890. HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
  891. switch (matOp) {
  892. case HLMatLoadStoreOpcode::ColMatLoad:
  893. case HLMatLoadStoreOpcode::RowMatLoad: {
  894. IRBuilder<> LocalBuilder(CI);
  895. HLMatrixType MatTy = HLMatrixType::cast(
  896. CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
  897. ->getType()->getPointerElementType());
  898. std::vector<Value *> matElts(MatTy.getNumElements());
  899. if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
  900. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  901. Constant *constRowIdx = LocalBuilder.getInt32(c);
  902. Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
  903. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  904. SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
  905. if (vertexOrPrimID)
  906. args.emplace_back(vertexOrPrimID);
  907. Value *input = LocalBuilder.CreateCall(ldStFunc, args);
  908. unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
  909. matElts[matIdx] = input;
  910. }
  911. }
  912. } else {
  913. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  914. Constant *constRowIdx = LocalBuilder.getInt32(r);
  915. Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
  916. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  917. SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
  918. if (vertexOrPrimID)
  919. args.emplace_back(vertexOrPrimID);
  920. Value *input = LocalBuilder.CreateCall(ldStFunc, args);
  921. unsigned matIdx = MatTy.getRowMajorIndex(r, c);
  922. matElts[matIdx] = input;
  923. }
  924. }
  925. }
  926. Value *newVec =
  927. HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
  928. newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
  929. CI->replaceAllUsesWith(newVec);
  930. CI->eraseFromParent();
  931. } break;
  932. case HLMatLoadStoreOpcode::ColMatStore:
  933. case HLMatLoadStoreOpcode::RowMatStore: {
  934. IRBuilder<> LocalBuilder(CI);
  935. Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
  936. HLMatrixType MatTy = HLMatrixType::cast(
  937. CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
  938. ->getType()->getPointerElementType());
  939. Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
  940. if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
  941. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  942. Constant *constColIdx = LocalBuilder.getInt32(c);
  943. Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
  944. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  945. unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
  946. Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
  947. LocalBuilder.CreateCall(ldStFunc,
  948. { OpArg, ID, colIdx, columnConsts[r], Elt });
  949. }
  950. }
  951. } else {
  952. for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
  953. Constant *constRowIdx = LocalBuilder.getInt32(r);
  954. Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
  955. for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
  956. unsigned matIdx = MatTy.getRowMajorIndex(r, c);
  957. Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
  958. LocalBuilder.CreateCall(ldStFunc,
  959. { OpArg, ID, rowIdx, columnConsts[c], Elt });
  960. }
  961. }
  962. }
  963. CI->eraseFromParent();
  964. } break;
  965. }
  966. } else {
  967. DXASSERT(0, "invalid operation on input output");
  968. }
  969. }
  970. } // namespace
  971. void HLSignatureLower::GenerateDxilInputs() {
  972. GenerateDxilInputsOutputs(DXIL::SignatureKind::Input);
  973. }
  974. void HLSignatureLower::GenerateDxilOutputs() {
  975. GenerateDxilInputsOutputs(DXIL::SignatureKind::Output);
  976. }
  977. void HLSignatureLower::GenerateDxilPrimOutputs() {
  978. GenerateDxilInputsOutputs(DXIL::SignatureKind::PatchConstOrPrim);
  979. }
  980. void HLSignatureLower::GenerateDxilInputsOutputs(DXIL::SignatureKind SK) {
  981. OP *hlslOP = HLM.GetOP();
  982. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  983. Module &M = *(HLM.GetModule());
  984. OP::OpCode opcode = (OP::OpCode)-1;
  985. switch (SK) {
  986. case DXIL::SignatureKind::Input:
  987. opcode = OP::OpCode::LoadInput;
  988. break;
  989. case DXIL::SignatureKind::Output:
  990. opcode = props.IsMS() ? OP::OpCode::StoreVertexOutput : OP::OpCode::StoreOutput;
  991. break;
  992. case DXIL::SignatureKind::PatchConstOrPrim:
  993. opcode = OP::OpCode::StorePrimitiveOutput;
  994. break;
  995. default:
  996. DXASSERT_NOMSG(0);
  997. }
  998. bool bInput = SK == DXIL::SignatureKind::Input;
  999. bool bNeedVertexOrPrimID = bInput && (props.IsGS() || props.IsDS() || props.IsHS());
  1000. bNeedVertexOrPrimID |= !bInput && props.IsMS();
  1001. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1002. Constant *columnConsts[] = {
  1003. hlslOP->GetU8Const(0), hlslOP->GetU8Const(1), hlslOP->GetU8Const(2),
  1004. hlslOP->GetU8Const(3), hlslOP->GetU8Const(4), hlslOP->GetU8Const(5),
  1005. hlslOP->GetU8Const(6), hlslOP->GetU8Const(7), hlslOP->GetU8Const(8),
  1006. hlslOP->GetU8Const(9), hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
  1007. hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
  1008. hlslOP->GetU8Const(15)};
  1009. Constant *constZero = hlslOP->GetU32Const(0);
  1010. Value *undefVertexIdx = props.IsMS() || !bInput ? nullptr : UndefValue::get(Type::getInt32Ty(HLM.GetCtx()));
  1011. DxilSignature &Sig =
  1012. bInput ? EntrySig.InputSignature :
  1013. SK == DXIL::SignatureKind::Output ? EntrySig.OutputSignature :
  1014. EntrySig.PatchConstOrPrimSignature;
  1015. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  1016. DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
  1017. Type *i1Ty = Type::getInt1Ty(constZero->getContext());
  1018. Type *i32Ty = constZero->getType();
  1019. llvm::SmallVector<unsigned, 8> removeIndices;
  1020. for (unsigned i = 0; i < Sig.GetElements().size(); i++) {
  1021. DxilSignatureElement *SE = &Sig.GetElement(i);
  1022. llvm::Type *Ty = SE->GetCompType().GetLLVMType(HLM.GetCtx());
  1023. // Cast i1 to i32 for load input.
  1024. bool bI1Cast = false;
  1025. if (Ty == i1Ty) {
  1026. bI1Cast = true;
  1027. Ty = i32Ty;
  1028. }
  1029. if (!hlslOP->IsOverloadLegal(opcode, Ty)) {
  1030. std::string O;
  1031. raw_string_ostream OSS(O);
  1032. Ty->print(OSS);
  1033. OSS << "(type for " << SE->GetName() << ")";
  1034. OSS << " cannot be used as shader inputs or outputs.";
  1035. OSS.flush();
  1036. HLM.GetCtx().emitError(O);
  1037. continue;
  1038. }
  1039. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty);
  1040. Constant *ID = hlslOP->GetU32Const(i);
  1041. unsigned cols = SE->GetCols();
  1042. Value *GV = m_sigValueMap[SE];
  1043. bool bIsInout = m_inoutArgSet.count(GV) > 0;
  1044. IRBuilder<> EntryBuilder(Entry->getEntryBlock().getFirstInsertionPt());
  1045. if (DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(GV)) {
  1046. EntryBuilder.SetCurrentDebugLocation(DDI->getDebugLoc());
  1047. }
  1048. DXIL::SemanticInterpretationKind SI = SE->GetInterpretation();
  1049. DXASSERT_NOMSG(SI < DXIL::SemanticInterpretationKind::Invalid);
  1050. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NA);
  1051. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NotInSig);
  1052. if (SI == DXIL::SemanticInterpretationKind::Shadow)
  1053. continue; // Handled in ProcessArgument
  1054. if (!GV->getType()->isPointerTy()) {
  1055. DXASSERT(bInput, "direct parameter must be input");
  1056. Value *vertexOrPrimID = undefVertexIdx;
  1057. Value *args[] = {OpArg, ID, /*rowIdx*/ constZero, /*colIdx*/ nullptr,
  1058. vertexOrPrimID};
  1059. replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP,
  1060. EntryBuilder);
  1061. continue;
  1062. }
  1063. bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
  1064. bool bIsPrecise = m_preciseSigSet.count(SE);
  1065. if (bIsPrecise)
  1066. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
  1067. bool bRowMajor = false;
  1068. if (Argument *Arg = dyn_cast<Argument>(GV)) {
  1069. if (pFuncAnnot) {
  1070. auto &paramAnnot = pFuncAnnot->GetParameterAnnotation(Arg->getArgNo());
  1071. if (paramAnnot.HasMatrixAnnotation())
  1072. bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
  1073. MatrixOrientation::RowMajor;
  1074. }
  1075. }
  1076. std::vector<InputOutputAccessInfo> accessInfoList;
  1077. collectInputOutputAccessInfo(GV, constZero, accessInfoList,
  1078. bNeedVertexOrPrimID && bIsArrayTy, bInput, bRowMajor, props.IsMS());
  1079. for (InputOutputAccessInfo &info : accessInfoList) {
  1080. GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
  1081. cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
  1082. bIsArrayTy, bInput, bIsInout);
  1083. }
  1084. }
  1085. }
  1086. void HLSignatureLower::GenerateDxilCSInputs() {
  1087. OP *hlslOP = HLM.GetOP();
  1088. DxilFunctionAnnotation *funcAnnotation = HLM.GetFunctionAnnotation(Entry);
  1089. DXASSERT(funcAnnotation, "must find annotation for entry function");
  1090. IRBuilder<> Builder(Entry->getEntryBlock().getFirstInsertionPt());
  1091. for (Argument &arg : Entry->args()) {
  1092. DxilParameterAnnotation &paramAnnotation =
  1093. funcAnnotation->GetParameterAnnotation(arg.getArgNo());
  1094. llvm::StringRef semanticStr = paramAnnotation.GetSemanticString();
  1095. if (semanticStr.empty()) {
  1096. dxilutil::EmitErrorOnFunction(Entry, "Semantic must be defined for all "
  1097. "parameters of an entry function or patch "
  1098. "constant function");
  1099. return;
  1100. }
  1101. const Semantic *semantic =
  1102. Semantic::GetByName(semanticStr, DXIL::SigPointKind::CSIn);
  1103. OP::OpCode opcode;
  1104. switch (semantic->GetKind()) {
  1105. case Semantic::Kind::GroupThreadID:
  1106. opcode = OP::OpCode::ThreadIdInGroup;
  1107. break;
  1108. case Semantic::Kind::GroupID:
  1109. opcode = OP::OpCode::GroupId;
  1110. break;
  1111. case Semantic::Kind::DispatchThreadID:
  1112. opcode = OP::OpCode::ThreadId;
  1113. break;
  1114. case Semantic::Kind::GroupIndex:
  1115. opcode = OP::OpCode::FlattenedThreadIdInGroup;
  1116. break;
  1117. default:
  1118. DXASSERT(semantic->IsInvalid(),
  1119. "else compute shader semantics out-of-date");
  1120. dxilutil::EmitErrorOnFunction(Entry, "invalid semantic found in CS");
  1121. return;
  1122. }
  1123. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1124. Type *NumTy = arg.getType();
  1125. DXASSERT(!NumTy->isPointerTy(), "Unexpected byref value for CS SV_***ID semantic.");
  1126. DXASSERT(NumTy->getScalarType()->isIntegerTy(), "Unexpected non-integer value for CS SV_***ID semantic.");
  1127. // Always use the i32 overload of those intrinsics, and then cast as needed
  1128. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Builder.getInt32Ty());
  1129. Value *newArg = nullptr;
  1130. if (opcode == OP::OpCode::FlattenedThreadIdInGroup) {
  1131. newArg = Builder.CreateCall(dxilFunc, {OpArg});
  1132. } else {
  1133. unsigned vecSize = 1;
  1134. if (NumTy->isVectorTy())
  1135. vecSize = NumTy->getVectorNumElements();
  1136. newArg = Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU32Const(0)});
  1137. if (vecSize > 1) {
  1138. Value *result = UndefValue::get(VectorType::get(Builder.getInt32Ty(), vecSize));
  1139. result = Builder.CreateInsertElement(result, newArg, (uint64_t)0);
  1140. for (unsigned i = 1; i < vecSize; i++) {
  1141. Value *newElt =
  1142. Builder.CreateCall(dxilFunc, {OpArg, hlslOP->GetU32Const(i)});
  1143. result = Builder.CreateInsertElement(result, newElt, i);
  1144. }
  1145. newArg = result;
  1146. }
  1147. }
  1148. // If the argument is of non-i32 type, convert here
  1149. if (newArg->getType() != NumTy)
  1150. newArg = Builder.CreateZExtOrTrunc(newArg, NumTy);
  1151. if (newArg->getType() != arg.getType()) {
  1152. DXASSERT_NOMSG(arg.getType()->isPointerTy());
  1153. for (User *U : arg.users()) {
  1154. LoadInst *LI = cast<LoadInst>(U);
  1155. LI->replaceAllUsesWith(newArg);
  1156. }
  1157. } else {
  1158. arg.replaceAllUsesWith(newArg);
  1159. }
  1160. }
  1161. }
  1162. void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
  1163. OP *hlslOP = HLM.GetOP();
  1164. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  1165. Module &M = *(HLM.GetModule());
  1166. Constant *constZero = hlslOP->GetU32Const(0);
  1167. DxilSignature &Sig = EntrySig.PatchConstOrPrimSignature;
  1168. DxilTypeSystem &typeSys = HLM.GetTypeSystem();
  1169. DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
  1170. auto InsertPt = Entry->getEntryBlock().getFirstInsertionPt();
  1171. const bool bIsHs = props.IsHS();
  1172. const bool bIsInput = !bIsHs;
  1173. const bool bIsInout = false;
  1174. const bool bNeedVertexOrPrimID = false;
  1175. if (bIsHs) {
  1176. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1177. Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
  1178. InsertPt = patchConstantFunc->getEntryBlock().getFirstInsertionPt();
  1179. pFuncAnnot = typeSys.GetFunctionAnnotation(patchConstantFunc);
  1180. }
  1181. IRBuilder<> Builder(InsertPt);
  1182. Type *i1Ty = Builder.getInt1Ty();
  1183. Type *i32Ty = Builder.getInt32Ty();
  1184. // LoadPatchConst don't have vertexIdx operand.
  1185. Value *undefVertexIdx = nullptr;
  1186. Constant *columnConsts[] = {
  1187. hlslOP->GetU8Const(0), hlslOP->GetU8Const(1), hlslOP->GetU8Const(2),
  1188. hlslOP->GetU8Const(3), hlslOP->GetU8Const(4), hlslOP->GetU8Const(5),
  1189. hlslOP->GetU8Const(6), hlslOP->GetU8Const(7), hlslOP->GetU8Const(8),
  1190. hlslOP->GetU8Const(9), hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
  1191. hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
  1192. hlslOP->GetU8Const(15)};
  1193. OP::OpCode opcode =
  1194. bIsInput ? OP::OpCode::LoadPatchConstant : OP::OpCode::StorePatchConstant;
  1195. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1196. for (unsigned i = 0; i < Sig.GetElements().size(); i++) {
  1197. DxilSignatureElement *SE = &Sig.GetElement(i);
  1198. Value *GV = m_sigValueMap[SE];
  1199. DXIL::SemanticInterpretationKind SI = SE->GetInterpretation();
  1200. DXASSERT_NOMSG(SI < DXIL::SemanticInterpretationKind::Invalid);
  1201. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NA);
  1202. DXASSERT_NOMSG(SI != DXIL::SemanticInterpretationKind::NotInSig);
  1203. if (SI == DXIL::SemanticInterpretationKind::Shadow)
  1204. continue; // Handled in ProcessArgument
  1205. Constant *ID = hlslOP->GetU32Const(i);
  1206. // Generate LoadPatchConstant.
  1207. Type *Ty = SE->GetCompType().GetLLVMType(HLM.GetCtx());
  1208. // Cast i1 to i32 for load input.
  1209. bool bI1Cast = false;
  1210. if (Ty == i1Ty) {
  1211. bI1Cast = true;
  1212. Ty = i32Ty;
  1213. }
  1214. unsigned cols = SE->GetCols();
  1215. Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty);
  1216. if (!GV->getType()->isPointerTy()) {
  1217. DXASSERT(bIsInput, "Must be DS input.");
  1218. Constant *OpArg = hlslOP->GetU32Const(
  1219. static_cast<unsigned>(OP::OpCode::LoadPatchConstant));
  1220. Value *args[] = {OpArg, ID, /*rowIdx*/ constZero, /*colIdx*/ nullptr};
  1221. replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP,
  1222. Builder);
  1223. continue;
  1224. }
  1225. bool bRowMajor = false;
  1226. if (Argument *Arg = dyn_cast<Argument>(GV)) {
  1227. if (pFuncAnnot) {
  1228. auto &paramAnnot = pFuncAnnot->GetParameterAnnotation(Arg->getArgNo());
  1229. if (paramAnnot.HasMatrixAnnotation())
  1230. bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
  1231. MatrixOrientation::RowMajor;
  1232. }
  1233. }
  1234. std::vector<InputOutputAccessInfo> accessInfoList;
  1235. collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexOrPrimID,
  1236. bIsInput, bRowMajor, false);
  1237. bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
  1238. bool isPrecise = m_preciseSigSet.count(SE);
  1239. if (isPrecise)
  1240. HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
  1241. for (InputOutputAccessInfo &info : accessInfoList) {
  1242. GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
  1243. cols, bI1Cast, columnConsts, bNeedVertexOrPrimID,
  1244. bIsArrayTy, bIsInput, bIsInout);
  1245. }
  1246. }
  1247. }
  1248. void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
  1249. // Map input patch, to input sig
  1250. // LoadOutputControlPoint for output patch .
  1251. OP *hlslOP = HLM.GetOP();
  1252. Constant *constZero = hlslOP->GetU32Const(0);
  1253. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1254. Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
  1255. DxilFunctionAnnotation *patchFuncAnnotation =
  1256. HLM.GetFunctionAnnotation(patchConstantFunc);
  1257. DXASSERT(patchFuncAnnotation,
  1258. "must find annotation for patch constant function");
  1259. Type *i1Ty = Type::getInt1Ty(constZero->getContext());
  1260. Type *i32Ty = constZero->getType();
  1261. for (Argument &arg : patchConstantFunc->args()) {
  1262. DxilParameterAnnotation &paramAnnotation =
  1263. patchFuncAnnotation->GetParameterAnnotation(arg.getArgNo());
  1264. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1265. if (inputQual == DxilParamInputQual::InputPatch ||
  1266. inputQual == DxilParamInputQual::OutputPatch) {
  1267. DxilSignatureElement *SE = m_patchConstantInputsSigMap[arg.getArgNo()];
  1268. if (!SE) // Error should have been reported at an earlier stage.
  1269. continue;
  1270. Constant *inputID = hlslOP->GetU32Const(SE->GetID());
  1271. unsigned cols = SE->GetCols();
  1272. Type *Ty = SE->GetCompType().GetLLVMType(HLM.GetCtx());
  1273. // Cast i1 to i32 for load input.
  1274. bool bI1Cast = false;
  1275. if (Ty == i1Ty) {
  1276. bI1Cast = true;
  1277. Ty = i32Ty;
  1278. }
  1279. OP::OpCode opcode = inputQual == DxilParamInputQual::InputPatch
  1280. ? OP::OpCode::LoadInput
  1281. : OP::OpCode::LoadOutputControlPoint;
  1282. Function *dxilLdFunc = hlslOP->GetOpFunc(opcode, Ty);
  1283. bool bRowMajor = false;
  1284. if (Argument *Arg = dyn_cast<Argument>(&arg)) {
  1285. if (patchFuncAnnotation) {
  1286. auto &paramAnnot = patchFuncAnnotation->GetParameterAnnotation(Arg->getArgNo());
  1287. if (paramAnnot.HasMatrixAnnotation())
  1288. bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
  1289. MatrixOrientation::RowMajor;
  1290. }
  1291. }
  1292. std::vector<InputOutputAccessInfo> accessInfoList;
  1293. collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
  1294. /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
  1295. for (InputOutputAccessInfo &info : accessInfoList) {
  1296. if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
  1297. Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
  1298. Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
  1299. info.vertexOrPrimID};
  1300. replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
  1301. } else {
  1302. DXASSERT(0, "input should only be ld");
  1303. }
  1304. }
  1305. }
  1306. }
  1307. }
  1308. bool HLSignatureLower::HasClipPlanes() {
  1309. if (!HLM.HasDxilFunctionProps(Entry))
  1310. return false;
  1311. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1312. auto &VS = EntryQual.ShaderProps.VS;
  1313. unsigned numClipPlanes = 0;
  1314. for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
  1315. if (!VS.clipPlanes[i])
  1316. break;
  1317. numClipPlanes++;
  1318. }
  1319. return numClipPlanes != 0;
  1320. }
  1321. void HLSignatureLower::GenerateClipPlanesForVS(Value *outPosition) {
  1322. DxilFunctionProps &EntryQual = HLM.GetDxilFunctionProps(Entry);
  1323. auto &VS = EntryQual.ShaderProps.VS;
  1324. unsigned numClipPlanes = 0;
  1325. for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
  1326. if (!VS.clipPlanes[i])
  1327. break;
  1328. numClipPlanes++;
  1329. }
  1330. if (!numClipPlanes)
  1331. return;
  1332. LLVMContext &Ctx = HLM.GetCtx();
  1333. Function *dp4 =
  1334. HLM.GetOP()->GetOpFunc(DXIL::OpCode::Dot4, Type::getFloatTy(Ctx));
  1335. Value *dp4Args[] = {
  1336. ConstantInt::get(Type::getInt32Ty(Ctx),
  1337. static_cast<unsigned>(DXIL::OpCode::Dot4)),
  1338. nullptr,
  1339. nullptr,
  1340. nullptr,
  1341. nullptr,
  1342. nullptr,
  1343. nullptr,
  1344. nullptr,
  1345. nullptr,
  1346. };
  1347. // out SV_Position should only have StoreInst use.
  1348. // Done by LegalizeDxilInputOutputs in ScalarReplAggregatesHLSL.cpp
  1349. for (User *U : outPosition->users()) {
  1350. StoreInst *ST = cast<StoreInst>(U);
  1351. Value *posVal = ST->getValueOperand();
  1352. DXASSERT(posVal->getType()->isVectorTy(), "SV_Position must be a vector");
  1353. IRBuilder<> Builder(ST);
  1354. // Put position to args.
  1355. for (unsigned i = 0; i < 4; i++)
  1356. dp4Args[i + 1] = Builder.CreateExtractElement(posVal, i);
  1357. // For each clip plane.
  1358. // clipDistance = dp4 position, clipPlane.
  1359. auto argIt = Entry->getArgumentList().rbegin();
  1360. for (int clipIdx = numClipPlanes - 1; clipIdx >= 0; clipIdx--) {
  1361. Constant *GV = VS.clipPlanes[clipIdx];
  1362. DXASSERT_NOMSG(GV->hasOneUse());
  1363. StoreInst *ST = cast<StoreInst>(GV->user_back());
  1364. Value *clipPlane = ST->getValueOperand();
  1365. ST->eraseFromParent();
  1366. Argument &arg = *(argIt++);
  1367. // Put clipPlane to args.
  1368. for (unsigned i = 0; i < 4; i++)
  1369. dp4Args[i + 5] = Builder.CreateExtractElement(clipPlane, i);
  1370. Value *clipDistance = Builder.CreateCall(dp4, dp4Args);
  1371. Builder.CreateStore(clipDistance, &arg);
  1372. }
  1373. }
  1374. }
  1375. namespace {
  1376. Value *TranslateStreamAppend(CallInst *CI, unsigned ID, hlsl::OP *OP) {
  1377. Function *DxilFunc = OP->GetOpFunc(OP::OpCode::EmitStream, CI->getType());
  1378. // TODO: generate a emit which has the data being emited as its argment.
  1379. // Value *data = CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
  1380. Constant *opArg = OP->GetU32Const((unsigned)OP::OpCode::EmitStream);
  1381. IRBuilder<> Builder(CI);
  1382. Constant *streamID = OP->GetU8Const(ID);
  1383. Value *args[] = {opArg, streamID};
  1384. return Builder.CreateCall(DxilFunc, args);
  1385. }
  1386. Value *TranslateStreamCut(CallInst *CI, unsigned ID, hlsl::OP *OP) {
  1387. Function *DxilFunc = OP->GetOpFunc(OP::OpCode::CutStream, CI->getType());
  1388. // TODO: generate a emit which has the data being emited as its argment.
  1389. // Value *data = CI->getArgOperand(HLOperandIndex::kStreamAppendDataOpIndex);
  1390. Constant *opArg = OP->GetU32Const((unsigned)OP::OpCode::CutStream);
  1391. IRBuilder<> Builder(CI);
  1392. Constant *streamID = OP->GetU8Const(ID);
  1393. Value *args[] = {opArg, streamID};
  1394. return Builder.CreateCall(DxilFunc, args);
  1395. }
  1396. } // namespace
  1397. // Generate DXIL stream output operation.
  1398. void HLSignatureLower::GenerateStreamOutputOperation(Value *streamVal, unsigned ID) {
  1399. OP * hlslOP = HLM.GetOP();
  1400. for (auto U = streamVal->user_begin(); U != streamVal->user_end();) {
  1401. Value *user = *(U++);
  1402. // Should only used by append, restartStrip .
  1403. CallInst *CI = cast<CallInst>(user);
  1404. HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
  1405. // Ignore user functions.
  1406. if (group == HLOpcodeGroup::NotHL)
  1407. continue;
  1408. unsigned opcode = GetHLOpcode(CI);
  1409. DXASSERT_LOCALVAR(group, group == HLOpcodeGroup::HLIntrinsic, "Must be HLIntrinsic here");
  1410. IntrinsicOp IOP = static_cast<IntrinsicOp>(opcode);
  1411. switch (IOP) {
  1412. case IntrinsicOp::MOP_Append:
  1413. TranslateStreamAppend(CI, ID, hlslOP);
  1414. break;
  1415. case IntrinsicOp::MOP_RestartStrip:
  1416. TranslateStreamCut(CI, ID, hlslOP);
  1417. break;
  1418. default:
  1419. DXASSERT(0, "invalid operation on stream");
  1420. }
  1421. CI->eraseFromParent();
  1422. }
  1423. }
  1424. // Generate DXIL stream output operations.
  1425. void HLSignatureLower::GenerateStreamOutputOperations() {
  1426. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  1427. DXASSERT(EntryAnnotation, "must find annotation for entry function");
  1428. for (Argument &arg : Entry->getArgumentList()) {
  1429. if (HLModule::IsStreamOutputPtrType(arg.getType())) {
  1430. unsigned streamID = 0;
  1431. DxilParameterAnnotation &paramAnnotation =
  1432. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  1433. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1434. switch (inputQual) {
  1435. case DxilParamInputQual::OutStream0:
  1436. streamID = 0;
  1437. break;
  1438. case DxilParamInputQual::OutStream1:
  1439. streamID = 1;
  1440. break;
  1441. case DxilParamInputQual::OutStream2:
  1442. streamID = 2;
  1443. break;
  1444. case DxilParamInputQual::OutStream3:
  1445. default:
  1446. DXASSERT(inputQual == DxilParamInputQual::OutStream3,
  1447. "invalid input qual.");
  1448. streamID = 3;
  1449. break;
  1450. }
  1451. GenerateStreamOutputOperation(&arg, streamID);
  1452. }
  1453. }
  1454. }
  1455. // Generate DXIL EmitIndices operation.
  1456. void HLSignatureLower::GenerateEmitIndicesOperation(Value *indicesOutput) {
  1457. OP * hlslOP = HLM.GetOP();
  1458. Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::EmitIndices, Type::getVoidTy(indicesOutput->getContext()));
  1459. Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::EmitIndices);
  1460. for (auto U = indicesOutput->user_begin(); U != indicesOutput->user_end();) {
  1461. Value *user = *(U++);
  1462. GetElementPtrInst *GEP = cast<GetElementPtrInst>(user);
  1463. auto idx = GEP->idx_begin();
  1464. DXASSERT_LOCALVAR(idx, idx->get() == hlslOP->GetU32Const(0),
  1465. "only support 0 offset for input pointer");
  1466. gep_type_iterator GEPIt = gep_type_begin(GEP), E = gep_type_end(GEP);
  1467. // Skip first pointer idx which must be 0.
  1468. GEPIt++;
  1469. Value *primIdx = GEPIt.getOperand();
  1470. DXASSERT(++GEPIt == E, "invalid GEP here"); (void)E;
  1471. auto GepUser = GEP->user_begin();
  1472. auto GepUserE = GEP->user_end();
  1473. for (; GepUser != GepUserE;) {
  1474. auto GepUserIt = GepUser++;
  1475. StoreInst *stInst = cast<StoreInst>(*GepUserIt);
  1476. Value *stVal = stInst->getValueOperand();
  1477. VectorType *VT = cast<VectorType>(stVal->getType());
  1478. unsigned eleCount = VT->getNumElements();
  1479. IRBuilder<> Builder(stInst);
  1480. Value *subVal0 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(0));
  1481. Value *subVal1 = Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(1));
  1482. Value *subVal2 = eleCount == 3 ?
  1483. Builder.CreateExtractElement(stVal, hlslOP->GetU32Const(2)) : hlslOP->GetU32Const(0);
  1484. Value *args[] = { opArg, primIdx, subVal0, subVal1, subVal2 };
  1485. Builder.CreateCall(DxilFunc, args);
  1486. stInst->eraseFromParent();
  1487. }
  1488. GEP->eraseFromParent();
  1489. }
  1490. }
  1491. // Generate DXIL EmitIndices operations.
  1492. void HLSignatureLower::GenerateEmitIndicesOperations() {
  1493. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  1494. DXASSERT(EntryAnnotation, "must find annotation for entry function");
  1495. for (Argument &arg : Entry->getArgumentList()) {
  1496. DxilParameterAnnotation &paramAnnotation =
  1497. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  1498. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1499. if (inputQual == DxilParamInputQual::OutIndices) {
  1500. GenerateEmitIndicesOperation(&arg);
  1501. }
  1502. }
  1503. }
  1504. // Generate DXIL GetMeshPayload operation.
  1505. void HLSignatureLower::GenerateGetMeshPayloadOperation() {
  1506. DxilFunctionAnnotation *EntryAnnotation = HLM.GetFunctionAnnotation(Entry);
  1507. DXASSERT(EntryAnnotation, "must find annotation for entry function");
  1508. for (Argument &arg : Entry->getArgumentList()) {
  1509. DxilParameterAnnotation &paramAnnotation =
  1510. EntryAnnotation->GetParameterAnnotation(arg.getArgNo());
  1511. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
  1512. if (inputQual == DxilParamInputQual::InPayload) {
  1513. OP * hlslOP = HLM.GetOP();
  1514. Function *DxilFunc = hlslOP->GetOpFunc(OP::OpCode::GetMeshPayload, arg.getType());
  1515. Constant *opArg = hlslOP->GetU32Const((unsigned)OP::OpCode::GetMeshPayload);
  1516. IRBuilder<> Builder(arg.getParent()->getEntryBlock().getFirstInsertionPt());
  1517. Value *args[] = { opArg };
  1518. Value *payload = Builder.CreateCall(DxilFunc, args);
  1519. arg.replaceAllUsesWith(payload);
  1520. }
  1521. }
  1522. }
  1523. // Lower signatures.
  1524. void HLSignatureLower::Run() {
  1525. DxilFunctionProps &props = HLM.GetDxilFunctionProps(Entry);
  1526. if (props.IsGraphics()) {
  1527. if (props.IsMS()) {
  1528. GenerateEmitIndicesOperations();
  1529. GenerateGetMeshPayloadOperation();
  1530. }
  1531. CreateDxilSignatures();
  1532. // Allocate input output.
  1533. AllocateDxilInputOutputs();
  1534. GenerateDxilInputs();
  1535. GenerateDxilOutputs();
  1536. if (props.IsMS()) {
  1537. GenerateDxilPrimOutputs();
  1538. }
  1539. } else if (props.IsCS()) {
  1540. GenerateDxilCSInputs();
  1541. }
  1542. if (props.IsDS() || props.IsHS())
  1543. GenerateDxilPatchConstantLdSt();
  1544. if (props.IsHS())
  1545. GenerateDxilPatchConstantFunctionInputs();
  1546. if (props.IsGS())
  1547. GenerateStreamOutputOperations();
  1548. }