#pragma once #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" #include #include #include namespace llvm { class AllocaInst; class BasicBlock; class CallInst; class Function; class FunctionType; class Instruction; class Module; class raw_ostream; class ReturnInst; class StructType; class Type; class Value; } class LiveValues; typedef std::vector BasicBlockVector; typedef llvm::SetVector InstructionSetVector; //============================================================================== // Transforms the given function into a number of state functions to be // used in a state machine. // // State functions have the following signature: // int ( runtimeData). // They take an runtime data argument with a given type used by the runtime and // return the state ID of the next state. If the function contains calls to other // candidate functions that are to be transformed into state functions, the // function is split into multiple substate functions at call sites and the calls // are replaced with continuations. For example candidate funcA() calling candidate // funcB(): // void funcA(int param0) // { // // code moved to funcA_ss0() // int foo = 10; // ... // // funcB(arg0, arg1); // // // code moved to funcA_ss1() // int bar = someFunc(foo); // // } // will be split into two substate functions, funcA_ss0() and funcA_ss1(). // funcA_ss0() pushes the stateID for funcA_ss1() onto the stack, and // returns the state ID for the entry substate of funcB, funcB_ss0(). // A substate of funcB will eventually pop the stack and return the state ID // for funcA_ss1(). funcA_ss1() in turn pops the stack to get the state ID // placed there by its caller. // // If candidate functions, like funcB(), have arguments they are moved to the stack. // Any values that are live across continuations, like foo in this example, // must also be saved to the stack before the continuation and restored before use. 
// Some values, like DXIL buffer handles, should not be saved and must be
// rematerialized after a continuation. The stack frame in a state function has
// the following layout:
//
//   |               |
//   +---------------+
//   | argN          |
//   | ...           |
//   | arg0          |
//   | returnStateID | caller arg frame
//   +---------------+ <-- entry stack pointer
//   |               |
//   | saved values  |
//   |               |
//   +---------------+
//   | argN          |
//   | ...           |
//   | arg0          |
//   | returnStateID | callee arg frame
//   +---------------+ <-- stack frame pointer
//           |
//           V  stack grows downward towards smaller addresses
//
// The return state ID is stored at the base of the argument frame, followed by
// the function arguments, if any. The saved values follow the argument frame.
// Instead of adjusting the size of the stack frame for the saved values and
// argument frames of each continuation, a single allocation is made with enough
// space to accommodate all continuations in the function.
//
// Several placeholder functions are used during the state function transform to
// break dependency cycles. A placeholder for the runtime data pointer is used to
// allocate the stack frame before the function signature is changed and the
// pointer parameter is created. The stack frame is also allocated before its
// size has been determined, so a placeholder is used for the size as well. The
// state IDs corresponding to function entry substates may also not be known
// until the transform has been run on all the candidate functions, so a
// placeholder is used for state IDs too. These state ID placeholders are
// replaced by calling StateFunctionTransform::finalizeStateIds() after all the
// candidate functions have been transformed.
//
// If the intrinsic Internal_CallIndirect(int stateId) appears in the body of
// the function, then it is treated as a continuation with a transition to the
// specified stateId.
//
// When an attribute size is specified, space is allocated on the stack frame for
// the committed/pending attributes, as well as for the previous committed/pending
// attribute offsets. The attribute size (setAttributeSize()) should be set if the
// function is TraceRay(). The payload offset needs to be set by the caller. The
// stack frame for TraceRay() has the following layout:
//
//   |                         |
//   +-------------------------+
//   |                         |
//   | TraceRay() args         |
//   |                         |
//   +-------------------------+
//   | returnStateID           | caller arg frame
//   +-------------------------+ <-- entry stack offset
//   | old committed attr offs |
//   | old pending attr offset |
//   +-------------------------+
//   |                         |
//   | committed attributes    |
//   |                         |
//   +-------------------------+ <-- new committed attribute offset
//   |                         |
//   | pending attributes      |
//   |                         |
//   +-------------------------+ <-- new pending attribute offset
//   |                         |
//   | saved values            |
//   |                         |
//   +-------------------------+
//   | argN                    |
//   | ...                     |
//   | arg0                    |
//   | returnStateID           | callee arg frame
//   +-------------------------+ <-- stack frame offset
//
// The arguments to some functions (e.g. closesthit, anyhit, and miss shaders)
// come from the payload or attributes. The positions of these arguments can be
// specified to SFT (StateFunctionTransform), which will redirect the defs from
// the args to the corresponding values on the stack.
// // The following runtime (LLVM) functions are used by SFT (all sizes and offsets // are in terms of ints): // void stackFramePush( runtimeData, i32 size) // void stackFramePop( runtimeData, i32 size) // // i32 stackFrameOffset( runtimeData) // i32 payloadOffset( runtimeData) // i32 committedAttrOffset( runtimeData) // i32 pendingAttrOffset( runtimeData) // // i32* stackIntPtr( runtimeData, i32 baseOffset, i32 offset) // // Called before/after stackFramePush()/stackFramePop(): // void traceFramePush( runtimeData, i32 attrSize) // void traceFramePop( runtimeData) class StateFunctionTransform { public: enum ParameterSemanticType { PST_NONE = 0, PST_PAYLOAD, PST_ATTRIBUTE, PST_COUNT }; // func is the function to be transformed. candidateFuncNames is a list of all // functions that which have been or will be transformed to state functions, // including func. The runtimeDataArgTy is the type to use for the first argument // in state functions. StateFunctionTransform(llvm::Function* func, const std::vector& candidateFuncNames, llvm::Type* runtimeDataArgTy); // Optional parameters to be specified before run() void setAttributeSize(int sizeInBytes); // needed for TraceRay() void setParameterInfo(const std::vector& paramTypes, bool useCommittedAttr = true); void setResourceGlobals(const std::set& resources); static llvm::Function* createDummyRuntimeDataArgFunc(llvm::Module* M, llvm::Type* runtimeDataArgTy); // Generates state functions from func into the same module. The original function // is left only as a declaration. void run(std::vector& stateFunctions, _Out_ unsigned int &shaderStackSize); // candidateFuncEntryStateIds corresponding to the candidateFuncNames passed to // the constructor. stateIDs are computed as candidateFuncEntryStateIds[functionIdx] // + substateIdx, where functionIdx and substateIdx come from the arguments to // the placeholder stateID function. 
static void finalizeStateIds(llvm::Module* module, const std::vector& candidateFuncEntryStateIds); // Outputs detailed diagnostic information if set to true. void setVerbose(bool val); void setDumpFilename(const std::string& dumpFilename); private: // Function to transform llvm::Function* m_function = nullptr; // Name of the function to transform std::string m_functionName; // Index of the function to transform in m_candidateFuncNames int m_functionIdx = 0; // cadidateFuncNames is a list of all functions that which have been or will // be transformed to state functions. Used to create function index used // by the stateID placeholder function. const std::vector& m_candidateFuncNames; llvm::Type* m_runtimeDataArgTy = nullptr; llvm::Value* m_runtimeDataArg = nullptr; // set in init() and changeFunctionSignature() llvm::Value* m_stackFrameSizeVal = nullptr; // set in init() and preserveLiveValuesAcrossCallsites() int m_attributeSizeInBytes = -1; std::vector m_paramTypes; bool m_useCommittedAttr = false; const std::set* m_resources; std::vector m_callSites; std::vector m_callSiteFunctionIdx; std::vector m_movePayloadToStackCalls; std::vector m_setPendingAttrCalls; std::vector m_returns; bool m_verbose = false; std::string m_dumpFilename; unsigned int m_dumpId = 0; llvm::Function* m_stackIntPtrFunc = nullptr; llvm::CallInst* m_stackFramePush = nullptr; llvm::CallInst* m_stackFrameOffset = nullptr; llvm::CallInst* m_payloadOffset = nullptr; // Offset at beginning of function llvm::CallInst* m_committedAttrOffset = nullptr; // Offset at beginning of function llvm::CallInst* m_pendingAttrOffset = nullptr; // Offset at beginning of function // Placeholder function taking constant values functionIdx and substate. // These are later translated to a stateId by finalizeStateIds(). llvm::Function* m_dummyStateIdFunc = nullptr; int m_maxCallerArgFrameSizeInBytes = 0; int m_traceFrameSizeInBytes = 0; // Functions used to abstract stack operations. 
These make intermediate stages // in the transform a little bit cleaner. std::map m_stackStoreFuncs; std::map m_stackLoadFuncs; std::map m_stackPtrFuncs; // Main stages of the transformation void init(); void findCallSitesIntrinsicsAndReturns(); void changeCallingConvention(); void preserveLiveValuesAcrossCallsites(_Out_ unsigned int &shaderStackSize); void createSubstateFunctions(std::vector& stateFunctions); void lowerStackFuncs(); llvm::Value* getDummyStateId(int functionIdx, int substate, llvm::Instruction* insertBefore); void allocateStackFrame(); void allocateTraceFrame(); void createArgFrames(); void changeFunctionSignature(); void createStackStore(llvm::Value* baseOffset, llvm::Value* val, int offsetInBytes, llvm::Instruction* insertBefore); llvm::Instruction* createStackLoad(llvm::Value* baseOffset, llvm::Value* val, int offsetInBytes, llvm::Instruction* insertBefore); llvm::Instruction* createStackPtr(llvm::Value* baseOffset, llvm::Value* val, int offsetInBytes, llvm::Instruction* insertBefore); llvm::Instruction* createStackPtr(llvm::Value* baseOffset, llvm::Type* valTy, llvm::Value* intIndex, llvm::Instruction* insertBefore); void rewriteDummyStackSize(uint64_t frameSizeInBytes); BasicBlockVector replaceCallSites(); llvm::Function* split(llvm::Function* baseFunc, llvm::BasicBlock* subStateEntryBlock, int substateIndex); void flattenGepsOnValue(llvm::Value* val, llvm::Value* baseOffset, llvm::Value* offset); void scalarizeVectorStackAccess(llvm::Instruction* vecPtr, llvm::Value* baseOffset, llvm::Value* offsetVal); // Diagnostic printing functions llvm::raw_ostream& getOutputStream(const std::string functionName, const std::string& suffix, unsigned int dumpId); void printFunction(const llvm::Function* function, const std::string& suffix, unsigned int dumpId); void printFunction(const std::string& suffix); void printFunctions(const std::vector& funcs, const char* suffix); void printModule(const llvm::Module* module, const std::string& suffix); void 
printSet(const InstructionSetVector& vals, const char* msg = nullptr); };