///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// HLLowerUDT.cpp                                                            //
// Copyright (C) Microsoft Corporation. All rights reserved.                 //
// This file is distributed under the University of Illinois Open Source     //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// Lower user defined type used directly by certain intrinsic operations.    //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
#include "dxc/HLSL/HLLowerUDT.h"

#include "dxc/DXIL/DxilConstants.h"
#include "dxc/DXIL/DxilTypeSystem.h"
#include "dxc/DXIL/DxilUtil.h"
#include "dxc/HLSL/HLMatrixLowerHelper.h"
#include "dxc/HLSL/HLMatrixType.h"
#include "dxc/HLSL/HLModule.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HlslIntrinsicOp.h"
#include "dxc/Support/Global.h"

#include "HLMatrixSubscriptUseReplacer.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"

using namespace llvm;
using namespace hlsl;
  35. // Lowered UDT is the same layout, but with vectors and matrices translated to
  36. // arrays.
  37. // Returns nullptr for failure due to embedded HLSL object type.
  38. StructType *hlsl::GetLoweredUDT(StructType *structTy, DxilTypeSystem *pTypeSys) {
  39. bool changed = false;
  40. SmallVector<Type*, 8> NewElTys(structTy->getNumContainedTypes());
  41. for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
  42. Type *FieldTy = structTy->getContainedType(iField);
  43. // Default to original type
  44. NewElTys[iField] = FieldTy;
  45. // Unwrap arrays:
  46. SmallVector<unsigned, 4> OuterToInnerLengths;
  47. Type *EltTy = dxilutil::StripArrayTypes(FieldTy, &OuterToInnerLengths);
  48. Type *NewTy = EltTy;
  49. // Lower element if necessary
  50. if (EltTy->isVectorTy()) {
  51. NewTy = ArrayType::get(EltTy->getVectorElementType(),
  52. EltTy->getVectorNumElements());
  53. } else if (HLMatrixType Mat = HLMatrixType::dyn_cast(EltTy)) {
  54. NewTy = ArrayType::get(Mat.getElementType(/*MemRepr*/true),
  55. Mat.getNumElements());
  56. } else if (dxilutil::IsHLSLObjectType(EltTy) ||
  57. dxilutil::IsHLSLRayQueryType(EltTy)) {
  58. // We cannot lower a structure with an embedded object type
  59. return nullptr;
  60. } else if (StructType *ST = dyn_cast<StructType>(EltTy)) {
  61. NewTy = GetLoweredUDT(ST);
  62. if (nullptr == NewTy)
  63. return nullptr; // Propagate failure back to root
  64. } else if (EltTy->isIntegerTy(1)) {
  65. // Must translate bool to mem type
  66. EltTy = IntegerType::get(EltTy->getContext(), 32);
  67. }
  68. // if unchanged, skip field
  69. if (NewTy == EltTy)
  70. continue;
  71. // Rewrap Arrays:
  72. for (auto itLen = OuterToInnerLengths.rbegin(),
  73. E = OuterToInnerLengths.rend();
  74. itLen != E; ++itLen) {
  75. NewTy = ArrayType::get(NewTy, *itLen);
  76. }
  77. // Update field, and set changed
  78. NewElTys[iField] = NewTy;
  79. changed = true;
  80. }
  81. if (changed) {
  82. StructType *newStructTy = StructType::create(
  83. structTy->getContext(), NewElTys, structTy->getStructName());
  84. if (DxilStructAnnotation *pSA = pTypeSys ?
  85. pTypeSys->GetStructAnnotation(structTy) : nullptr) {
  86. if (!pTypeSys->GetStructAnnotation(newStructTy)) {
  87. DxilStructAnnotation &NewSA = *pTypeSys->AddStructAnnotation(newStructTy);
  88. for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
  89. NewSA.GetFieldAnnotation(iField) = pSA->GetFieldAnnotation(iField);
  90. }
  91. }
  92. }
  93. return newStructTy;
  94. }
  95. return structTy;
  96. }
  97. Constant *hlsl::TranslateInitForLoweredUDT(
  98. Constant *Init, Type *NewTy,
  99. // We need orientation for matrix fields
  100. DxilTypeSystem *pTypeSys,
  101. MatrixOrientation matOrientation) {
  102. // handle undef and zero init
  103. if (isa<UndefValue>(Init))
  104. return UndefValue::get(NewTy);
  105. else if (Init->getType()->isAggregateType() && Init->isZeroValue())
  106. return ConstantAggregateZero::get(NewTy);
  107. // unchanged
  108. Type *Ty = Init->getType();
  109. if (Ty == NewTy)
  110. return Init;
  111. SmallVector<Constant*, 16> values;
  112. if (Ty->isArrayTy()) {
  113. values.reserve(Ty->getArrayNumElements());
  114. ConstantArray *CA = cast<ConstantArray>(Init);
  115. for (unsigned i = 0; i < Ty->getArrayNumElements(); ++i)
  116. values.emplace_back(
  117. TranslateInitForLoweredUDT(
  118. CA->getAggregateElement(i),
  119. NewTy->getArrayElementType(),
  120. pTypeSys, matOrientation));
  121. return ConstantArray::get(cast<ArrayType>(NewTy), values);
  122. } else if (Ty->isVectorTy()) {
  123. values.reserve(Ty->getVectorNumElements());
  124. ConstantVector *CV = cast<ConstantVector>(Init);
  125. for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i)
  126. values.emplace_back(CV->getAggregateElement(i));
  127. return ConstantArray::get(cast<ArrayType>(NewTy), values);
  128. } else if (HLMatrixType Mat = HLMatrixType::dyn_cast(Ty)) {
  129. values.reserve(Mat.getNumElements());
  130. ConstantArray *MatArray = cast<ConstantArray>(
  131. cast<ConstantStruct>(Init)->getOperand(0));
  132. for (unsigned row = 0; row < Mat.getNumRows(); ++row) {
  133. ConstantVector *RowVector = cast<ConstantVector>(
  134. MatArray->getOperand(row));
  135. for (unsigned col = 0; col < Mat.getNumColumns(); ++col) {
  136. unsigned index = matOrientation == MatrixOrientation::ColumnMajor ?
  137. Mat.getColumnMajorIndex(row, col) : Mat.getRowMajorIndex(row, col);
  138. values[index] = RowVector->getOperand(col);
  139. }
  140. }
  141. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  142. DxilStructAnnotation *pStructAnnotation =
  143. pTypeSys ? pTypeSys->GetStructAnnotation(ST) : nullptr;
  144. values.reserve(ST->getNumContainedTypes());
  145. ConstantStruct *CS = cast<ConstantStruct>(Init);
  146. for (unsigned i = 0; i < ST->getStructNumElements(); ++i) {
  147. MatrixOrientation matFieldOrientation = matOrientation;
  148. if (pStructAnnotation) {
  149. DxilFieldAnnotation &FA = pStructAnnotation->GetFieldAnnotation(i);
  150. if (FA.HasMatrixAnnotation()) {
  151. matFieldOrientation = FA.GetMatrixAnnotation().Orientation;
  152. }
  153. }
  154. values.emplace_back(
  155. TranslateInitForLoweredUDT(
  156. cast<Constant>(CS->getAggregateElement(i)),
  157. NewTy->getStructElementType(i),
  158. pTypeSys, matFieldOrientation));
  159. }
  160. return ConstantStruct::get(cast<StructType>(NewTy), values);
  161. }
  162. return Init;
  163. }
  164. void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
  165. Type *Ty = V->getType();
  166. Type *NewTy = NewV->getType();
  167. if (Ty == NewTy) {
  168. V->replaceAllUsesWith(NewV);
  169. if (Instruction *I = dyn_cast<Instruction>(V))
  170. I->dropAllReferences();
  171. if (Constant *CV = dyn_cast<Constant>(V))
  172. CV->removeDeadConstantUsers();
  173. return;
  174. }
  175. if (Ty->isPointerTy())
  176. Ty = Ty->getPointerElementType();
  177. if (NewTy->isPointerTy())
  178. NewTy = NewTy->getPointerElementType();
  179. while (!V->use_empty()) {
  180. Use &use = *V->use_begin();
  181. User *user = use.getUser();
  182. // Clear use to prevent infinite loop on unhandled case.
  183. use.set(UndefValue::get(V->getType()));
  184. if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
  185. // Load for non-matching type should only be vector
  186. DXASSERT(Ty->isVectorTy() && NewTy->isArrayTy() &&
  187. Ty->getVectorNumElements() == NewTy->getArrayNumElements(),
  188. "unexpected load of non-matching type");
  189. IRBuilder<> Builder(LI);
  190. Value *result = UndefValue::get(Ty);
  191. for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i) {
  192. Value *GEP = Builder.CreateInBoundsGEP(NewV,
  193. {Builder.getInt32(0), Builder.getInt32(i)});
  194. Value *El = Builder.CreateLoad(GEP);
  195. result = Builder.CreateInsertElement(result, El, i);
  196. }
  197. LI->replaceAllUsesWith(result);
  198. LI->eraseFromParent();
  199. } else if (StoreInst *SI = dyn_cast<StoreInst>(user)) {
  200. // Store for non-matching type should only be vector
  201. DXASSERT(Ty->isVectorTy() && NewTy->isArrayTy() &&
  202. Ty->getVectorNumElements() == NewTy->getArrayNumElements(),
  203. "unexpected load of non-matching type");
  204. IRBuilder<> Builder(SI);
  205. for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i) {
  206. Value *EE = Builder.CreateExtractElement(SI->getValueOperand(), i);
  207. Value *GEP = Builder.CreateInBoundsGEP(
  208. NewV, {Builder.getInt32(0), Builder.getInt32(i)});
  209. Builder.CreateStore(EE, GEP);
  210. }
  211. SI->eraseFromParent();
  212. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
  213. // Non-constant GEP
  214. IRBuilder<> Builder(GEP);
  215. SmallVector<Value*, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  216. Value *NewGEP = Builder.CreateGEP(NewV, idxList);
  217. ReplaceUsesForLoweredUDT(GEP, NewGEP);
  218. GEP->eraseFromParent();
  219. } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(user)) {
  220. // Has to be constant GEP, NewV better be constant
  221. SmallVector<Value*, 4> idxList(GEP->idx_begin(), GEP->idx_end());
  222. Constant *NewGEP = ConstantExpr::getGetElementPtr(
  223. nullptr, cast<Constant>(NewV), idxList, true);
  224. ReplaceUsesForLoweredUDT(GEP, NewGEP);
  225. GEP->dropAllReferences();
  226. } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
  227. // Address space cast
  228. IRBuilder<> Builder(AC);
  229. unsigned AddrSpace = AC->getType()->getPointerAddressSpace();
  230. Value *NewAC = Builder.CreateAddrSpaceCast(
  231. NewV, PointerType::get(NewTy, AddrSpace));
  232. ReplaceUsesForLoweredUDT(user, NewAC);
  233. AC->eraseFromParent();
  234. } else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
  235. IRBuilder<> Builder(BC);
  236. if (BC->getType()->getPointerElementType() == NewTy) {
  237. // if alreday bitcast to new type, just replace the bitcast
  238. // with the new value (already translated user function)
  239. BC->replaceAllUsesWith(NewV);
  240. } else {
  241. // Could be i8 for memcpy?
  242. // Replace bitcast argument with new value
  243. use.set(NewV);
  244. }
  245. } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
  246. // Constant AddrSpaceCast, or BitCast
  247. if (CE->getOpcode() == Instruction::AddrSpaceCast) {
  248. unsigned AddrSpace = CE->getType()->getPointerAddressSpace();
  249. ReplaceUsesForLoweredUDT(user,
  250. ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
  251. PointerType::get(NewTy, AddrSpace)));
  252. } else if (CE->getOpcode() == Instruction::BitCast) {
  253. if (CE->getType()->getPointerElementType() == NewTy) {
  254. // if alreday bitcast to new type, just replace the bitcast
  255. // with the new value
  256. CE->replaceAllUsesWith(NewV);
  257. } else {
  258. // Could be i8 for memcpy?
  259. // Replace bitcast argument with new value
  260. use.set(NewV);
  261. }
  262. } else {
  263. DXASSERT(0, "unhandled constant expr for lowered UTD");
  264. CE->dropAllReferences(); // better than infinite loop on release
  265. }
  266. } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
  267. // Lower some matrix intrinsics that access pointers early, and
  268. // cast arguments for user functions or special UDT intrinsics
  269. // for later translation.
  270. Function *F = CI->getCalledFunction();
  271. HLOpcodeGroup group = GetHLOpcodeGroupByName(F);
  272. HLMatrixType Mat = HLMatrixType::dyn_cast(Ty);
  273. bool bColMajor = false;
  274. switch (group) {
  275. case HLOpcodeGroup::HLMatLoadStore: {
  276. DXASSERT(Mat, "otherwise, matrix operation on non-matrix value");
  277. IRBuilder<> Builder(CI);
  278. HLMatLoadStoreOpcode opcode =
  279. static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
  280. switch (opcode) {
  281. case HLMatLoadStoreOpcode::ColMatLoad:
  282. bColMajor = true;
  283. __fallthrough;
  284. case HLMatLoadStoreOpcode::RowMatLoad: {
  285. Value *val = UndefValue::get(
  286. VectorType::get(NewTy->getArrayElementType(),
  287. NewTy->getArrayNumElements()));
  288. for (unsigned i = 0; i < NewTy->getArrayNumElements(); ++i) {
  289. Value *GEP = Builder.CreateGEP(NewV,
  290. {Builder.getInt32(0), Builder.getInt32(i)});
  291. Value *elt = Builder.CreateLoad(GEP);
  292. val = Builder.CreateInsertElement(val, elt, i);
  293. }
  294. if (bColMajor) {
  295. // transpose matrix to match expected value orientation for
  296. // default cast to matrix type
  297. SmallVector<int, 16> ShuffleIndices;
  298. for (unsigned RowIdx = 0; RowIdx < Mat.getNumRows(); ++RowIdx)
  299. for (unsigned ColIdx = 0; ColIdx < Mat.getNumColumns(); ++ColIdx)
  300. ShuffleIndices.emplace_back(
  301. static_cast<int>(Mat.getColumnMajorIndex(RowIdx, ColIdx)));
  302. val = Builder.CreateShuffleVector(val, val, ShuffleIndices);
  303. }
  304. // lower mem to reg type
  305. val = Mat.emitLoweredMemToReg(val, Builder);
  306. // cast vector back to matrix value (DefaultCast expects row major)
  307. unsigned newOpcode = (unsigned)HLCastOpcode::DefaultCast;
  308. val = callHLFunction(*F->getParent(), HLOpcodeGroup::HLCast, newOpcode,
  309. Ty, { Builder.getInt32(newOpcode), val }, Builder);
  310. if (bColMajor) {
  311. // emit cast row to col to match original result
  312. newOpcode = (unsigned)HLCastOpcode::RowMatrixToColMatrix;
  313. val = callHLFunction(*F->getParent(), HLOpcodeGroup::HLCast, newOpcode,
  314. Ty, { Builder.getInt32(newOpcode), val }, Builder);
  315. }
  316. // replace use of HLMatLoadStore with loaded vector
  317. CI->replaceAllUsesWith(val);
  318. } break;
  319. case HLMatLoadStoreOpcode::ColMatStore:
  320. bColMajor = true;
  321. __fallthrough;
  322. case HLMatLoadStoreOpcode::RowMatStore: {
  323. // HLCast matrix value to vector
  324. unsigned newOpcode = (unsigned)(bColMajor ?
  325. HLCastOpcode::ColMatrixToVecCast :
  326. HLCastOpcode::RowMatrixToVecCast);
  327. Value *val = callHLFunction(*F->getParent(),
  328. HLOpcodeGroup::HLCast, newOpcode,
  329. Mat.getLoweredVectorType(false),
  330. { Builder.getInt32(newOpcode),
  331. CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx) },
  332. Builder);
  333. // lower reg to mem type
  334. val = Mat.emitLoweredRegToMem(val, Builder);
  335. for (unsigned i = 0; i < NewTy->getArrayNumElements(); ++i) {
  336. Value *elt = Builder.CreateExtractElement(val, i);
  337. Value *GEP = Builder.CreateGEP(NewV,
  338. {Builder.getInt32(0), Builder.getInt32(i)});
  339. Builder.CreateStore(elt, GEP);
  340. }
  341. } break;
  342. default:
  343. DXASSERT(0, "invalid opcode");
  344. }
  345. CI->eraseFromParent();
  346. } break;
  347. case HLOpcodeGroup::HLSubscript: {
  348. SmallVector<Value*, 4> ElemIndices;
  349. HLSubscriptOpcode opcode =
  350. static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(CI));
  351. switch (opcode) {
  352. case HLSubscriptOpcode::VectorSubscript:
  353. DXASSERT(0, "not handled yet");
  354. break;
  355. case HLSubscriptOpcode::ColMatElement:
  356. bColMajor = true;
  357. __fallthrough;
  358. case HLSubscriptOpcode::RowMatElement: {
  359. ConstantDataSequential *cIdx = cast<ConstantDataSequential>(
  360. CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
  361. for (unsigned i = 0; i < cIdx->getNumElements(); ++i) {
  362. ElemIndices.push_back(cIdx->getElementAsConstant(i));
  363. }
  364. } break;
  365. case HLSubscriptOpcode::ColMatSubscript:
  366. bColMajor = true;
  367. __fallthrough;
  368. case HLSubscriptOpcode::RowMatSubscript: {
  369. for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < CI->getNumArgOperands(); ++Idx) {
  370. ElemIndices.emplace_back(CI->getArgOperand(Idx));
  371. }
  372. } break;
  373. default:
  374. DXASSERT(0, "invalid opcode");
  375. }
  376. std::vector<Instruction*> DeadInsts;
  377. HLMatrixSubscriptUseReplacer UseReplacer(
  378. CI, NewV, /*TempLoweredMatrix*/nullptr, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
  379. DXASSERT(CI->use_empty(),
  380. "Expected all matrix subscript uses to have been replaced.");
  381. CI->eraseFromParent();
  382. while (!DeadInsts.empty()) {
  383. DeadInsts.back()->eraseFromParent();
  384. DeadInsts.pop_back();
  385. }
  386. } break;
  387. //case HLOpcodeGroup::NotHL: // TODO: Support lib functions
  388. case HLOpcodeGroup::HLIntrinsic: {
  389. // Just bitcast for now
  390. IRBuilder<> Builder(CI);
  391. use.set(Builder.CreateBitCast(NewV, V->getType()));
  392. continue;
  393. } break;
  394. default:
  395. DXASSERT(0, "invalid opcode");
  396. }
  397. } else {
  398. // What else?
  399. DXASSERT(false, "case not handled.");
  400. }
  401. }
  402. }