// StateFunctionTransform.h
#pragma once

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"

#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>
  7. namespace llvm
  8. {
  9. class AllocaInst;
  10. class BasicBlock;
  11. class CallInst;
  12. class Function;
  13. class FunctionType;
  14. class Instruction;
  15. class Module;
  16. class raw_ostream;
  17. class ReturnInst;
  18. class StructType;
  19. class Type;
  20. class Value;
  21. }
  22. class LiveValues;
  23. typedef std::vector<llvm::BasicBlock*> BasicBlockVector;
  24. typedef llvm::SetVector<llvm::Instruction*> InstructionSetVector;
  25. //==============================================================================
  26. // Transforms the given function into a number of state functions to be
  27. // used in a state machine.
  28. //
  29. // State functions have the following signature:
  30. // int (<RuntimeDataTy> runtimeData).
  31. // They take an runtime data argument with a given type used by the runtime and
  32. // return the state ID of the next state. If the function contains calls to other
  33. // candidate functions that are to be transformed into state functions, the
  34. // function is split into multiple substate functions at call sites and the calls
  35. // are replaced with continuations. For example candidate funcA() calling candidate
  36. // funcB():
  37. // void funcA(int param0)
  38. // {
  39. // // code moved to funcA_ss0()
  40. // int foo = 10;
  41. // ...
  42. //
  43. // funcB(arg0, arg1);
  44. //
  45. // // code moved to funcA_ss1()
  46. // int bar = someFunc(foo);
  47. //
  48. // }
  49. // will be split into two substate functions, funcA_ss0() and funcA_ss1().
  50. // funcA_ss0() pushes the stateID for funcA_ss1() onto the stack, and
  51. // returns the state ID for the entry substate of funcB, funcB_ss0().
  52. // A substate of funcB will eventually pop the stack and return the state ID
  53. // for funcA_ss1(). funcA_ss1() in turn pops the stack to get the state ID
  54. // placed there by its caller.
  55. //
  56. // If candidate functions, like funcB(), have arguments they are moved to the stack.
  57. // Any values that are live across continuations, like foo in this example,
  58. // must also be saved to the stack before the continuation and restored before use.
  59. // Some values, like DXIL buffer handles should not be saved and must be
  60. // rematerialized after a continuation. The stack frame in a state function has
  61. // the following layout:
  62. //
  63. // | |
  64. // +---------------+
  65. // | argN |
  66. // | ... |
  67. // | arg0 |
  68. // | returnStateID | caller arg frame
  69. // +---------------+ <-- entry stack pointer
  70. // | |
  71. // | saved values |
  72. // | |
  73. // +---------------+
  74. // | argN |
  75. // | ... |
  76. // | arg0 |
  77. // | returnStateID | callee arg frame
  78. // +---------------+ <-- stack frame pointer
  79. // |
  80. // V stack grows downward towards smaller addresses
  81. //
  82. // The return state ID is stored at the base of the argument frame, followed by
  83. // function arguments, if any. The saved values follow the argument frame. Instead
  84. // of adjusting the size of the stack frame for the saved values and argument
  85. // frames of each continuation a single allocation is made with enough space to
  86. // accommodate all continuations in the function.
  87. //
  88. // Several placeholder functions are used during the process of the state function
  89. // transform to break dependency cycles. A placeholder for the runtime data pointer
  90. // is used to allocate the stack frame before the function signature is changed
  91. // and the pointer parameter is created. The stack frame is also allocated before
  92. // its size has been determined, so a placeholder is used. The state IDs corresponding
  93. // to function entry substates may also not be known before the transform has been
  94. // run on all the candidate functions. Therefore a placeholder is used for state
  95. // IDs as well. These are replaced by calling StateFunctionTransform::finalizeStateIds()
  96. // after all the candidate functions have been transformed.
  97. //
  98. // If the intrinsic Internal_CallIndirect(int stateId) appears in the body of
  99. // the function then it is treated as a continuation with a transition to the
  100. // specified stateId.
  101. //
  102. // When an attribute size is specified, space is allocated on the stack frame for
  103. // committed/pending attributes, as well as the previous offsets for the committed/
  104. // pending attributes. The attribute size should be set if the
  105. // function is TraceRay(). The payload offset needs to be set by the caller. The
  106. // stack frame for TraceRay() has the following layout:
  107. //
  108. // | |
  109. // +-------------------------+
  110. // | |
  111. // | TraceRay() args |
  112. // | |
  113. // +-------------------------+
  114. // | returnStateID | caller arg frame
  115. // +-------------------------+ <-- entry stack offset
  116. // | old committed attr offs |
  117. // | old pending attr offset |
  118. // +-------------------------+
  119. // | |
  120. // | committed attributes |
  121. // | |
  122. // +-------------------------+ <-- new committed attribute offset
  123. // | |
  124. // | pending attributes |
  125. // | |
  126. // +-------------------------+ <-- new pending attribute offset
  127. // | |
  128. // | saved values |
  129. // | |
  130. // +-------------------------+
  131. // | argN |
  132. // | ... |
  133. // | arg0 |
  134. // | returnStateID | callee arg frame
  135. // +-------------------------+ <-- stack frame offset
  136. //
  137. // The arguments to some functions (e.g. closesthit, anyhit, and miss shaders)
  138. // come from the payload or attributes. The positions of these arguments can be
  139. // specified to SFT, which will redirect the defs from the args to corresponding
  140. // values on the stack.
  141. //
  142. // The following runtime (LLVM) functions are used by SFT (all sizes and offsets
  143. // are in terms of ints):
  144. // void stackFramePush(<RuntimeDataTy> runtimeData, i32 size)
  145. // void stackFramePop(<RuntimeDataTy> runtimeData, i32 size)
  146. //
  147. // i32 stackFrameOffset(<RuntimeDataTy> runtimeData)
  148. // i32 payloadOffset(<RuntimeDataTy> runtimeData)
  149. // i32 committedAttrOffset(<RuntimeDataTy> runtimeData)
  150. // i32 pendingAttrOffset(<RuntimeDataTy> runtimeData)
  151. //
  152. // i32* stackIntPtr(<RuntimeDataTy> runtimeData, i32 baseOffset, i32 offset)
  153. //
  154. // Called before/after stackFramePush()/stackFramePop():
  155. // void traceFramePush(<RuntimeDataTy> runtimeData, i32 attrSize)
  156. // void traceFramePop(<RuntimeDataTy> runtimeData)
  157. class StateFunctionTransform
  158. {
  159. public:
  160. enum ParameterSemanticType
  161. {
  162. PST_NONE = 0,
  163. PST_PAYLOAD,
  164. PST_ATTRIBUTE,
  165. PST_COUNT
  166. };
  167. // func is the function to be transformed. candidateFuncNames is a list of all
  168. // functions that which have been or will be transformed to state functions,
  169. // including func. The runtimeDataArgTy is the type to use for the first argument
  170. // in state functions.
  171. StateFunctionTransform(llvm::Function* func, const std::vector<std::string>& candidateFuncNames, llvm::Type* runtimeDataArgTy);
  172. // Optional parameters to be specified before run()
  173. void setAttributeSize(int sizeInBytes); // needed for TraceRay()
  174. void setParameterInfo(const std::vector<ParameterSemanticType>& paramTypes, bool useCommittedAttr = true);
  175. void setResourceGlobals(const std::set<llvm::Value*>& resources);
  176. static llvm::Function* createDummyRuntimeDataArgFunc(llvm::Module* M, llvm::Type* runtimeDataArgTy);
  177. // Generates state functions from func into the same module. The original function
  178. // is left only as a declaration.
  179. void run(std::vector<llvm::Function*>& stateFunctions, _Out_ unsigned int &shaderStackSize);
  180. // candidateFuncEntryStateIds corresponding to the candidateFuncNames passed to
  181. // the constructor. stateIDs are computed as candidateFuncEntryStateIds[functionIdx]
  182. // + substateIdx, where functionIdx and substateIdx come from the arguments to
  183. // the placeholder stateID function.
  184. static void finalizeStateIds(llvm::Module* module, const std::vector<int>& candidateFuncEntryStateIds);
  185. // Outputs detailed diagnostic information if set to true.
  186. void setVerbose(bool val);
  187. void setDumpFilename(const std::string& dumpFilename);
  188. private:
  189. // Function to transform
  190. llvm::Function* m_function = nullptr;
  191. // Name of the function to transform
  192. std::string m_functionName;
  193. // Index of the function to transform in m_candidateFuncNames
  194. int m_functionIdx = 0;
  195. // cadidateFuncNames is a list of all functions that which have been or will
  196. // be transformed to state functions. Used to create function index used
  197. // by the stateID placeholder function.
  198. const std::vector<std::string>& m_candidateFuncNames;
  199. llvm::Type* m_runtimeDataArgTy = nullptr;
  200. llvm::Value* m_runtimeDataArg = nullptr; // set in init() and changeFunctionSignature()
  201. llvm::Value* m_stackFrameSizeVal = nullptr; // set in init() and preserveLiveValuesAcrossCallsites()
  202. int m_attributeSizeInBytes = -1;
  203. std::vector<ParameterSemanticType> m_paramTypes;
  204. bool m_useCommittedAttr = false;
  205. const std::set<llvm::Value*>* m_resources;
  206. std::vector<llvm::CallInst*> m_callSites;
  207. std::vector<int> m_callSiteFunctionIdx;
  208. std::vector<llvm::CallInst*> m_movePayloadToStackCalls;
  209. std::vector<llvm::CallInst*> m_setPendingAttrCalls;
  210. std::vector<llvm::ReturnInst*> m_returns;
  211. bool m_verbose = false;
  212. std::string m_dumpFilename;
  213. unsigned int m_dumpId = 0;
  214. llvm::Function* m_stackIntPtrFunc = nullptr;
  215. llvm::CallInst* m_stackFramePush = nullptr;
  216. llvm::CallInst* m_stackFrameOffset = nullptr;
  217. llvm::CallInst* m_payloadOffset = nullptr; // Offset at beginning of function
  218. llvm::CallInst* m_committedAttrOffset = nullptr; // Offset at beginning of function
  219. llvm::CallInst* m_pendingAttrOffset = nullptr; // Offset at beginning of function
  220. // Placeholder function taking constant values functionIdx and substate.
  221. // These are later translated to a stateId by finalizeStateIds().
  222. llvm::Function* m_dummyStateIdFunc = nullptr;
  223. int m_maxCallerArgFrameSizeInBytes = 0;
  224. int m_traceFrameSizeInBytes = 0;
  225. // Functions used to abstract stack operations. These make intermediate stages
  226. // in the transform a little bit cleaner.
  227. std::map<llvm::FunctionType*, llvm::Function*> m_stackStoreFuncs;
  228. std::map<llvm::FunctionType*, llvm::Function*> m_stackLoadFuncs;
  229. std::map<llvm::FunctionType*, llvm::Function*> m_stackPtrFuncs;
  230. // Main stages of the transformation
  231. void init();
  232. void findCallSitesIntrinsicsAndReturns();
  233. void changeCallingConvention();
  234. void preserveLiveValuesAcrossCallsites(_Out_ unsigned int &shaderStackSize);
  235. void createSubstateFunctions(std::vector<llvm::Function*>& stateFunctions);
  236. void lowerStackFuncs();
  237. llvm::Value* getDummyStateId(int functionIdx, int substate, llvm::Instruction* insertBefore);
  238. void allocateStackFrame();
  239. void allocateTraceFrame();
  240. void createArgFrames();
  241. void changeFunctionSignature();
  242. void createStackStore(llvm::Value* baseOffset, llvm::Value* val, int offsetInBytes, llvm::Instruction* insertBefore);
  243. llvm::Instruction* createStackLoad(llvm::Value* baseOffset, llvm::Value* val, int offsetInBytes, llvm::Instruction* insertBefore);
  244. llvm::Instruction* createStackPtr(llvm::Value* baseOffset, llvm::Value* val, int offsetInBytes, llvm::Instruction* insertBefore);
  245. llvm::Instruction* createStackPtr(llvm::Value* baseOffset, llvm::Type* valTy, llvm::Value* intIndex, llvm::Instruction* insertBefore);
  246. void rewriteDummyStackSize(uint64_t frameSizeInBytes);
  247. BasicBlockVector replaceCallSites();
  248. llvm::Function* split(llvm::Function* baseFunc, llvm::BasicBlock* subStateEntryBlock, int substateIndex);
  249. void flattenGepsOnValue(llvm::Value* val, llvm::Value* baseOffset, llvm::Value* offset);
  250. void scalarizeVectorStackAccess(llvm::Instruction* vecPtr, llvm::Value* baseOffset, llvm::Value* offsetVal);
  251. // Diagnostic printing functions
  252. llvm::raw_ostream& getOutputStream(const std::string functionName, const std::string& suffix, unsigned int dumpId);
  253. void printFunction(const llvm::Function* function, const std::string& suffix, unsigned int dumpId);
  254. void printFunction(const std::string& suffix);
  255. void printFunctions(const std::vector<llvm::Function*>& funcs, const char* suffix);
  256. void printModule(const llvm::Module* module, const std::string& suffix);
  257. void printSet(const InstructionSetVector& vals, const char* msg = nullptr);
  258. };