HLLegalizeParameter.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLLegalizeParameter.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Legalize 'in' parameters that are written and 'out' parameters that are read. //
  9. // Must be called before the inline pass. //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLModule.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/DXIL/DxilUtil.h"
  14. #include "dxc/HLSL/DxilGenerationPass.h"
  15. #include "dxc/HLSL/HLUtil.h"
  16. #include "dxc/DXIL/DxilTypeSystem.h"
  17. #include "llvm/IR/IntrinsicInst.h"
  18. #include "dxc/Support/Global.h"
  19. #include "llvm/Pass.h"
  20. #include "llvm/ADT/ArrayRef.h"
  21. #include "llvm/ADT/SmallVector.h"
  22. #include "llvm/IR/Constant.h"
  23. #include "llvm/IR/Constants.h"
  24. #include "llvm/IR/Function.h"
  25. #include "llvm/IR/Instruction.h"
  26. #include "llvm/IR/Instructions.h"
  27. #include "llvm/IR/IRBuilder.h"
  28. #include "llvm/IR/Module.h"
  29. #include "llvm/Support/Casting.h"
  30. #include <vector>
  31. using namespace llvm;
  32. using namespace hlsl;
  33. // For each parameter that needs legalizing, create an alloca that replaces all uses of the parameter, and copy between the alloca and the parameter.
  34. namespace {
  35. class HLLegalizeParameter : public ModulePass {
  36. public:
  37. static char ID;
  38. explicit HLLegalizeParameter() : ModulePass(ID) {}
  39. bool runOnModule(Module &M) override;
  40. private:
  41. void patchWriteOnInParam(Function &F, Argument &Arg, const DataLayout &DL);
  42. void patchReadOnOutParam(Function &F, Argument &Arg, const DataLayout &DL);
  43. };
  44. AllocaInst *createAllocaForPatch(Function &F, Type *Ty) {
  45. IRBuilder<> Builder(F.getEntryBlock().getFirstInsertionPt());
  46. return Builder.CreateAlloca(Ty);
  47. }
  48. void copyIn(AllocaInst *temp, Value *arg, CallInst *CI, unsigned size) {
  49. if (size == 0)
  50. return;
  51. // Copy arg to temp before CI.
  52. IRBuilder<> Builder(CI);
  53. Builder.CreateMemCpy(temp, arg, size, 1);
  54. }
  55. void copyOut(AllocaInst *temp, Value *arg, CallInst *CI, unsigned size) {
  56. if (size == 0)
  57. return;
  58. // Copy temp to arg after CI.
  59. IRBuilder<> Builder(CI->getNextNode());
  60. Builder.CreateMemCpy(arg, temp, size, 1);
  61. }
  62. bool isPointerNeedToLower(Value *V, Type *HandleTy) {
  63. // CBuffer, Buffer, Texture....
  64. // Anything related to dxil op.
  65. // hl.subscript.
  66. // Got to root of GEP.
  67. while (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
  68. V = GEP->getPointerOperand();
  69. }
  70. CallInst *CI = dyn_cast<CallInst>(V);
  71. if (!CI) {
  72. // If array of vector, we need a copy to handle vector to array in LowerTypePasses.
  73. Type *Ty = V->getType();
  74. if (Ty->isPointerTy())
  75. Ty = Ty->getPointerElementType();
  76. if (!Ty->isArrayTy())
  77. return false;
  78. while (Ty->isArrayTy()) {
  79. Ty = Ty->getArrayElementType();
  80. }
  81. return Ty->isVectorTy();
  82. }
  83. HLOpcodeGroup group = GetHLOpcodeGroup(CI->getCalledFunction());
  84. if (group != HLOpcodeGroup::HLSubscript)
  85. return false;
  86. Value *Ptr = CI->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx);
  87. // Ptr from resource handle.
  88. if (Ptr->getType() == HandleTy)
  89. return true;
  90. unsigned Opcode = GetHLOpcode(CI);
  91. // Ptr from cbuffer.
  92. if (Opcode == (unsigned)HLSubscriptOpcode::CBufferSubscript)
  93. return true;
  94. return isPointerNeedToLower(Ptr, HandleTy);
  95. }
  96. bool mayAliasWithGlobal(Value *V, CallInst *CallSite, std::vector<GlobalVariable *> &staticGVs) {
  97. // The unsafe case need copy-in copy-out will be global variable alias with
  98. // parameter. Then global variable is updated in the function, the parameter
  99. // will be updated silently.
  100. // Currently add copy for all non-const static global in
  101. // CGMSHLSLRuntime::EmitHLSLOutParamConversionInit.
  102. //So here just return false and do nothing.
  103. // For case like
  104. // struct T {
  105. // float4 a[10];
  106. //};
  107. // static T g;
  108. // void foo(inout T t) {
  109. // // modify g
  110. //}
  111. // void bar() {
  112. // T t = g;
  113. // // Not copy because t is local.
  114. // // But optimizations will change t to g later.
  115. // foo(t);
  116. //}
  117. // Optimizations which remove the copy should not replace foo(t) into foo(g)
  118. // when g could be modified.
  119. // TODO: remove copy for global in
  120. // CGMSHLSLRuntime::EmitHLSLOutParamConversionInit, do analysis to check alias
  121. // only generate copy when there's alias.
  122. return false;
  123. }
  124. struct CopyData {
  125. CallInst *CallSite;
  126. Value *Arg;
  127. bool bCopyIn;
  128. bool bCopyOut;
  129. };
  130. void ParameterCopyInCopyOut(hlsl::HLModule &HLM) {
  131. Module &M = *HLM.GetModule();
  132. Type *HandleTy = HLM.GetOP()->GetHandleType();
  133. const DataLayout &DL = M.getDataLayout();
  134. std::vector<GlobalVariable *> staticGVs;
  135. for (GlobalVariable &GV : M.globals()) {
  136. if (dxilutil::IsStaticGlobal(&GV) && !GV.isConstant()) {
  137. staticGVs.emplace_back(&GV);
  138. }
  139. }
  140. SmallVector<CopyData, 4> WorkList;
  141. for (Function &F : M) {
  142. if (F.user_empty())
  143. continue;
  144. DxilFunctionAnnotation *Annot = HLM.GetFunctionAnnotation(&F);
  145. // Skip functions don't have annotation, include llvm intrinsic and HLOp
  146. // functions.
  147. if (!Annot)
  148. continue;
  149. bool bNoInline = F.hasFnAttribute(llvm::Attribute::NoInline) || F.isDeclaration();
  150. for (User *U : F.users()) {
  151. CallInst *CI = dyn_cast<CallInst>(U);
  152. if (!CI)
  153. continue;
  154. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  155. Value *arg = CI->getArgOperand(i);
  156. if (!arg->getType()->isPointerTy())
  157. continue;
  158. DxilParameterAnnotation &ParamAnnot = Annot->GetParameterAnnotation(i);
  159. bool bCopyIn = false;
  160. bool bCopyOut = false;
  161. switch (ParamAnnot.GetParamInputQual()) {
  162. default:
  163. break;
  164. case DxilParamInputQual::In: {
  165. bCopyIn = true;
  166. } break;
  167. case DxilParamInputQual::Out: {
  168. bCopyOut = true;
  169. } break;
  170. case DxilParamInputQual::Inout: {
  171. bCopyIn = true;
  172. bCopyOut = true;
  173. } break;
  174. }
  175. if (!bCopyIn && !bCopyOut)
  176. continue;
  177. // When use ptr from cbuffer/buffer, need copy to avoid lower on user
  178. // function.
  179. bool bNeedCopy = mayAliasWithGlobal(arg, CI, staticGVs);
  180. if (bNoInline)
  181. bNeedCopy |= isPointerNeedToLower(arg, HandleTy);
  182. if (!bNeedCopy)
  183. continue;
  184. CopyData data = {CI, arg, bCopyIn, bCopyOut};
  185. WorkList.emplace_back(data);
  186. }
  187. }
  188. }
  189. for (CopyData &data : WorkList) {
  190. CallInst *CI = data.CallSite;
  191. Value *arg = data.Arg;
  192. Type *Ty = arg->getType()->getPointerElementType();
  193. Type *EltTy = dxilutil::GetArrayEltTy(Ty);
  194. // Skip on object type and resource type.
  195. if (dxilutil::IsHLSLObjectType(EltTy) ||
  196. dxilutil::IsHLSLResourceType(EltTy))
  197. continue;
  198. unsigned size = DL.getTypeAllocSize(Ty);
  199. AllocaInst *temp = createAllocaForPatch(*CI->getParent()->getParent(), Ty);
  200. // TODO: Adding lifetime intrinsics isn't easy here, have to analyze uses.
  201. if (data.bCopyIn)
  202. copyIn(temp, arg, CI, size);
  203. if (data.bCopyOut)
  204. copyOut(temp, arg, CI, size);
  205. CI->replaceUsesOfWith(arg, temp);
  206. }
  207. }
  208. } // namespace
  209. bool HLLegalizeParameter::runOnModule(Module &M) {
  210. HLModule &HLM = M.GetOrCreateHLModule();
  211. auto &typeSys = HLM.GetTypeSystem();
  212. const DataLayout &DL = M.getDataLayout();
  213. for (Function &F : M) {
  214. if (F.isDeclaration())
  215. continue;
  216. DxilFunctionAnnotation *Annot = HLM.GetFunctionAnnotation(&F);
  217. if (!Annot)
  218. continue;
  219. for (Argument &Arg : F.args()) {
  220. if (!Arg.getType()->isPointerTy())
  221. continue;
  222. Type *EltTy = dxilutil::GetArrayEltTy(Arg.getType());
  223. if (dxilutil::IsHLSLObjectType(EltTy) ||
  224. dxilutil::IsHLSLResourceType(EltTy))
  225. continue;
  226. DxilParameterAnnotation &ParamAnnot =
  227. Annot->GetParameterAnnotation(Arg.getArgNo());
  228. switch (ParamAnnot.GetParamInputQual()) {
  229. default:
  230. break;
  231. case DxilParamInputQual::In: {
  232. hlutil::PointerStatus PS(&Arg, 0, /*bLdStOnly*/ true);
  233. PS.analyze(typeSys, /*bStructElt*/ false);
  234. if (PS.HasStored()) {
  235. patchWriteOnInParam(F, Arg, DL);
  236. }
  237. } break;
  238. case DxilParamInputQual::Out: {
  239. hlutil::PointerStatus PS(&Arg, 0, /*bLdStOnly*/ true);
  240. PS.analyze(typeSys, /*bStructElt*/false);
  241. if (PS.HasLoaded()) {
  242. patchReadOnOutParam(F, Arg, DL);
  243. }
  244. }
  245. }
  246. }
  247. }
  248. // Copy-in copy-out for ptr arg when need.
  249. ParameterCopyInCopyOut(HLM);
  250. return true;
  251. }
  252. void HLLegalizeParameter::patchWriteOnInParam(Function &F, Argument &Arg,
  253. const DataLayout &DL) {
  254. // TODO: Adding lifetime intrinsics isn't easy here, have to analyze uses.
  255. Type *Ty = Arg.getType()->getPointerElementType();
  256. AllocaInst *temp = createAllocaForPatch(F, Ty);
  257. Arg.replaceAllUsesWith(temp);
  258. IRBuilder<> Builder(temp->getNextNode());
  259. unsigned size = DL.getTypeAllocSize(Ty);
  260. // copy arg to temp at beginning of function.
  261. Builder.CreateMemCpy(temp, &Arg, size, 1);
  262. }
  263. void HLLegalizeParameter::patchReadOnOutParam(Function &F, Argument &Arg,
  264. const DataLayout &DL) {
  265. // TODO: Adding lifetime intrinsics isn't easy here, have to analyze uses.
  266. Type *Ty = Arg.getType()->getPointerElementType();
  267. AllocaInst *temp = createAllocaForPatch(F, Ty);
  268. Arg.replaceAllUsesWith(temp);
  269. unsigned size = DL.getTypeAllocSize(Ty);
  270. for (auto &BB : F.getBasicBlockList()) {
  271. // copy temp to arg before every return.
  272. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  273. IRBuilder<> RetBuilder(RI);
  274. RetBuilder.CreateMemCpy(&Arg, temp, size, 1);
  275. }
  276. }
  277. }
  278. char HLLegalizeParameter::ID = 0;
  279. ModulePass *llvm::createHLLegalizeParameter() {
  280. return new HLLegalizeParameter();
  281. }
  282. INITIALIZE_PASS(HLLegalizeParameter, "hl-legalize-parameter",
  283. "Legalize parameter", false, false)