HLLegalizeParameter.cpp 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLLegalizeParameter.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Legalize `in` parameters that are written to and `out` parameters that //
  9. // are read from. Must be called before the inline pass. //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/HLSL/HLModule.h"
  12. #include "dxc/DXIL/DxilOperations.h"
  13. #include "dxc/DXIL/DxilUtil.h"
  14. #include "dxc/HLSL/DxilGenerationPass.h"
  15. #include "dxc/HLSL/HLUtil.h"
  16. #include "dxc/DXIL/DxilTypeSystem.h"
  17. #include "llvm/IR/IntrinsicInst.h"
  18. #include "dxc/Support/Global.h"
  19. #include "llvm/Pass.h"
  20. #include "llvm/ADT/ArrayRef.h"
  21. #include "llvm/ADT/SmallVector.h"
  22. #include "llvm/IR/Constant.h"
  23. #include "llvm/IR/Constants.h"
  24. #include "llvm/IR/Function.h"
  25. #include "llvm/IR/Instruction.h"
  26. #include "llvm/IR/Instructions.h"
  27. #include "llvm/IR/IRBuilder.h"
  28. #include "llvm/IR/Module.h"
  29. #include "llvm/Support/Casting.h"
  30. #include <vector>
  31. using namespace llvm;
  32. using namespace hlsl;
  33. // For each parameter that needs legalizing, create an alloca to replace all of its uses, and copy between the alloca and the parameter.
  34. namespace {
  35. class HLLegalizeParameter : public ModulePass {
  36. public:
  37. static char ID;
  38. explicit HLLegalizeParameter() : ModulePass(ID) {}
  39. bool runOnModule(Module &M) override;
  40. private:
  41. void patchWriteOnInParam(Function &F, Argument &Arg, const DataLayout &DL);
  42. void patchReadOnOutParam(Function &F, Argument &Arg, const DataLayout &DL);
  43. };
  44. AllocaInst *createAllocaForPatch(Function &F, Type *Ty) {
  45. IRBuilder<> Builder(F.getEntryBlock().getFirstInsertionPt());
  46. return Builder.CreateAlloca(Ty);
  47. }
  48. void copyIn(AllocaInst *temp, Value *arg, CallInst *CI, unsigned size) {
  49. if (size == 0)
  50. return;
  51. // copy arg to temp befor CI.
  52. IRBuilder<> Builder(CI);
  53. Builder.CreateMemCpy(temp, arg, size, 1);
  54. }
  55. void copyOut(AllocaInst *temp, Value *arg, CallInst *CI, unsigned size) {
  56. if (size == 0)
  57. return;
  58. // copy temp to arg after CI.
  59. IRBuilder<> Builder(CI->getNextNode());
  60. Builder.CreateMemCpy(arg, temp, size, 1);
  61. }
  62. bool isPointerNeedToLower(Value *V, Type *HandleTy) {
  63. // CBuffer, Buffer, Texture....
  64. // Anything related to dxil op.
  65. // hl.subscript.
  66. // Got to root of GEP.
  67. while (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
  68. V = GEP->getPointerOperand();
  69. }
  70. CallInst *CI = dyn_cast<CallInst>(V);
  71. if (!CI)
  72. return false;
  73. HLOpcodeGroup group = GetHLOpcodeGroup(CI->getCalledFunction());
  74. if (group != HLOpcodeGroup::HLSubscript)
  75. return false;
  76. Value *Ptr = CI->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx);
  77. // Ptr from resource handle.
  78. if (Ptr->getType() == HandleTy)
  79. return true;
  80. unsigned Opcode = GetHLOpcode(CI);
  81. // Ptr from cbuffer.
  82. if (Opcode == (unsigned)HLSubscriptOpcode::CBufferSubscript)
  83. return true;
  84. return isPointerNeedToLower(Ptr, HandleTy);
  85. }
  86. bool mayAliasWithGlobal(Value *V, CallInst *CallSite, std::vector<GlobalVariable *> &staticGVs) {
  87. // The unsafe case need copy-in copy-out will be global variable alias with
  88. // parameter. Then global variable is updated in the function, the parameter
  89. // will be updated silently.
  90. // Currently add copy for all non-const static global in
  91. // CGMSHLSLRuntime::EmitHLSLOutParamConversionInit.
  92. //So here just return false and do nothing.
  93. // For case like
  94. // struct T {
  95. // float4 a[10];
  96. //};
  97. // static T g;
  98. // void foo(inout T t) {
  99. // // modify g
  100. //}
  101. // void bar() {
  102. // T t = g;
  103. // // Not copy because t is local.
  104. // // But optimizations will change t to g later.
  105. // foo(t);
  106. //}
  107. // Optimizations which remove the copy should not replace foo(t) into foo(g)
  108. // when g could be modified.
  109. // TODO: remove copy for global in
  110. // CGMSHLSLRuntime::EmitHLSLOutParamConversionInit, do analysis to check alias
  111. // only generate copy when there's alias.
  112. return false;
  113. }
  114. struct CopyData {
  115. CallInst *CallSite;
  116. Value *Arg;
  117. bool bCopyIn;
  118. bool bCopyOut;
  119. };
  120. void ParameterCopyInCopyOut(hlsl::HLModule &HLM) {
  121. Module &M = *HLM.GetModule();
  122. Type *HandleTy = HLM.GetOP()->GetHandleType();
  123. const DataLayout &DL = M.getDataLayout();
  124. std::vector<GlobalVariable *> staticGVs;
  125. for (GlobalVariable &GV : M.globals()) {
  126. if (dxilutil::IsStaticGlobal(&GV) && !GV.isConstant()) {
  127. staticGVs.emplace_back(&GV);
  128. }
  129. }
  130. SmallVector<CopyData, 4> WorkList;
  131. for (Function &F : M) {
  132. if (F.user_empty())
  133. continue;
  134. DxilFunctionAnnotation *Annot = HLM.GetFunctionAnnotation(&F);
  135. // Skip functions don't have annotation, include llvm intrinsic and HLOp
  136. // functions.
  137. if (!Annot)
  138. continue;
  139. bool bNoInline = F.hasFnAttribute(llvm::Attribute::NoInline) || F.isDeclaration();
  140. for (User *U : F.users()) {
  141. CallInst *CI = dyn_cast<CallInst>(U);
  142. if (!CI)
  143. continue;
  144. for (unsigned i = 0; i < CI->getNumArgOperands(); i++) {
  145. Value *arg = CI->getArgOperand(i);
  146. if (!arg->getType()->isPointerTy())
  147. continue;
  148. DxilParameterAnnotation &ParamAnnot = Annot->GetParameterAnnotation(i);
  149. bool bCopyIn = false;
  150. bool bCopyOut = false;
  151. switch (ParamAnnot.GetParamInputQual()) {
  152. default:
  153. break;
  154. case DxilParamInputQual::In: {
  155. bCopyIn = true;
  156. } break;
  157. case DxilParamInputQual::Out: {
  158. bCopyOut = true;
  159. } break;
  160. case DxilParamInputQual::Inout: {
  161. bCopyIn = true;
  162. bCopyOut = true;
  163. } break;
  164. }
  165. if (!bCopyIn && !bCopyOut)
  166. continue;
  167. // When use ptr from cbuffer/buffer, need copy to avoid lower on user
  168. // function.
  169. bool bNeedCopy = mayAliasWithGlobal(arg, CI, staticGVs);
  170. if (bNoInline)
  171. bNeedCopy |= isPointerNeedToLower(arg, HandleTy);
  172. if (!bNeedCopy)
  173. continue;
  174. CopyData data = {CI, arg, bCopyIn, bCopyOut};
  175. WorkList.emplace_back(data);
  176. }
  177. }
  178. }
  179. for (CopyData &data : WorkList) {
  180. CallInst *CI = data.CallSite;
  181. Value *arg = data.Arg;
  182. Type *Ty = arg->getType()->getPointerElementType();
  183. Type *EltTy = dxilutil::GetArrayEltTy(Ty);
  184. // Skip on object type and resource type.
  185. if (dxilutil::IsHLSLObjectType(EltTy) ||
  186. dxilutil::IsHLSLResourceType(EltTy))
  187. continue;
  188. unsigned size = DL.getTypeAllocSize(Ty);
  189. AllocaInst *temp = createAllocaForPatch(*CI->getParent()->getParent(), Ty);
  190. if (data.bCopyIn)
  191. copyIn(temp, arg, CI, size);
  192. if (data.bCopyOut)
  193. copyOut(temp, arg, CI, size);
  194. CI->replaceUsesOfWith(arg, temp);
  195. }
  196. }
  197. } // namespace
  198. bool HLLegalizeParameter::runOnModule(Module &M) {
  199. HLModule &HLM = M.GetOrCreateHLModule();
  200. // TODO: enable avoid copy for lib profile.
  201. if (HLM.GetShaderModel()->IsLib())
  202. return false;
  203. auto &typeSys = HLM.GetTypeSystem();
  204. const DataLayout &DL = M.getDataLayout();
  205. for (Function &F : M) {
  206. if (F.isDeclaration())
  207. continue;
  208. DxilFunctionAnnotation *Annot = HLM.GetFunctionAnnotation(&F);
  209. if (!Annot)
  210. continue;
  211. for (Argument &Arg : F.args()) {
  212. if (!Arg.getType()->isPointerTy())
  213. continue;
  214. Type *EltTy = dxilutil::GetArrayEltTy(Arg.getType());
  215. if (dxilutil::IsHLSLObjectType(EltTy) ||
  216. dxilutil::IsHLSLResourceType(EltTy))
  217. continue;
  218. DxilParameterAnnotation &ParamAnnot =
  219. Annot->GetParameterAnnotation(Arg.getArgNo());
  220. switch (ParamAnnot.GetParamInputQual()) {
  221. default:
  222. break;
  223. case DxilParamInputQual::In: {
  224. hlutil::PointerStatus PS(&Arg, 0, /*bLdStOnly*/ true);
  225. PS.analyze(typeSys, /*bStructElt*/ false);
  226. if (PS.HasStored()) {
  227. patchWriteOnInParam(F, Arg, DL);
  228. }
  229. } break;
  230. case DxilParamInputQual::Out: {
  231. hlutil::PointerStatus PS(&Arg, 0, /*bLdStOnly*/ true);
  232. PS.analyze(typeSys, /*bStructElt*/false);
  233. if (PS.HasLoaded()) {
  234. patchReadOnOutParam(F, Arg, DL);
  235. }
  236. }
  237. }
  238. }
  239. }
  240. // Copy-in copy-out for ptr arg when need.
  241. ParameterCopyInCopyOut(HLM);
  242. return true;
  243. }
  244. void HLLegalizeParameter::patchWriteOnInParam(Function &F, Argument &Arg,
  245. const DataLayout &DL) {
  246. Type *Ty = Arg.getType()->getPointerElementType();
  247. AllocaInst *temp = createAllocaForPatch(F, Ty);
  248. Arg.replaceAllUsesWith(temp);
  249. IRBuilder<> Builder(temp->getNextNode());
  250. unsigned size = DL.getTypeAllocSize(Ty);
  251. // copy arg to temp at beginning of function.
  252. Builder.CreateMemCpy(temp, &Arg, size, 1);
  253. }
  254. void HLLegalizeParameter::patchReadOnOutParam(Function &F, Argument &Arg,
  255. const DataLayout &DL) {
  256. Type *Ty = Arg.getType()->getPointerElementType();
  257. AllocaInst *temp = createAllocaForPatch(F, Ty);
  258. Arg.replaceAllUsesWith(temp);
  259. unsigned size = DL.getTypeAllocSize(Ty);
  260. for (auto &BB : F.getBasicBlockList()) {
  261. // copy temp to arg before every return.
  262. if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
  263. IRBuilder<> RetBuilder(RI);
  264. RetBuilder.CreateMemCpy(&Arg, temp, size, 1);
  265. }
  266. }
  267. }
  268. char HLLegalizeParameter::ID = 0;
  269. ModulePass *llvm::createHLLegalizeParameter() {
  270. return new HLLegalizeParameter();
  271. }
  272. INITIALIZE_PASS(HLLegalizeParameter, "hl-legalize-parameter",
  273. "Legalize parameter", false, false)