DxrFallbackCompiler.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #pragma once
  2. #include <map>
  3. #include <memory>
  4. #include <string>
  5. #include <vector>
  6. namespace llvm
  7. {
  8. class CallInst;
  9. class Function;
  10. class Module;
  11. class Type;
  12. }
  13. // Combines DXIL raytracing shaders together into a compute shader.
  14. //
  15. // The incoming module should contain the following functions if the corresponding
  16. // intrinsic are called by the specified shaders,
  17. // if called:
  18. // Fallback_TraceRay()
  19. // Fallback_Ignore()
  20. // Fallback_AcceptHitAndEndSearch()
  21. // Fallback_ReportHit()
  22. //
  23. // Fallback_TraceRay() will be called with the original arguments, substituting
  24. // the offset of the payload on the stack for the actual payload.
  25. // Fallback_TraceRay() will also be used to replace calls to TraceRayTest().
  26. //
  27. // ReportHit() returns a boolean. But to handle the abort of the intersection
  28. // shader when AcceptHitAndEndSearch() is called we need a third return value.
  29. // Fallback_ReportHit() should return an integer < 0 for end search, 0 for ignore,
  30. // and > 0 for accept.
  31. //
  32. // The module should also contain a single call to Fallback_Scheduler() in the
  33. // entry shader for the raytracing compute shader.
  34. //
  35. // resizeStack() needs to be called after inlining everything in the compute
  36. // shader.
  37. //
  38. // Currently the main scheduling loop and the implementation for intrinsic
  39. // functions come from an internal runtime module.
  40. class DxrFallbackCompiler
  41. {
  42. public:
  43. typedef std::map<int, std::string> IntToFuncNameMap;
  44. // If findCalledShaders is true, then the list of shaderNames is expanded to
  45. // include shader functions (functions with attribute "exp-shader") that are
  46. // called by functions in shaderNames. Shader entry state IDs are still
  47. // returned only for those originally in shaderNames. findCalledShaders used
  48. // for testing.
  49. DxrFallbackCompiler(llvm::Module* mod, const std::vector<std::string>& shaderNames, unsigned maxAttributeSize, unsigned stackSizeInBytes, bool findCalledShaders = false);
  50. // 0 - no debug output
  51. // 1 - dump initial combined module, compiled module, and final linked module
  52. // 2 - dump intermediate stages of SFT to console
  53. // 3 - dump intermediate stages of SFT to file
  54. void setDebugOutputLevel(int val);
  55. // Returns the entry state id for each of shaderNames. The transformations
  56. // are performed in place on the module.
  57. void compile(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap);
  58. void link(std::vector<int>& shaderEntryStateIds, std::vector<unsigned int> &shaderStackSizes, IntToFuncNameMap *pCachedMap);
  59. // TODO: Ideally we would run this after inlining everything at the end of compile.
  60. // Until we figure out to do this, we will call the function after the final link.
  61. static void resizeStack(llvm::Function* F, unsigned stackSizeInBytes);
  62. private:
  63. typedef std::map<int, llvm::Function*> IntToFuncMap;
  64. typedef std::map<std::string, llvm::Function*> StringToFuncMap;
  65. llvm::Module* m_module = nullptr;
  66. const std::vector<std::string>& m_entryShaderNames;
  67. unsigned m_stackSizeInBytes = 0;
  68. unsigned m_maxAttributeSize = 0;
  69. bool m_findCalledShaders = false;
  70. int m_debugOutputLevel = 0;
  71. StringToFuncMap m_shaderMap;
  72. void initShaderMap(std::vector<std::string>& shaderNames);
  73. void linkRuntime();
  74. void lowerAnyHitControlFlowFuncs();
  75. void lowerReportHit();
  76. void lowerTraceRay(llvm::Type* runtimeDataArgTy);
  77. void createStateFunctions(IntToFuncMap& stateFunctionMap, std::vector<int>& shaderEntryStateIds, std::vector<unsigned int>& shaderStackSizes, int baseStateId, const std::vector<std::string>& shaderNames, llvm::Type* runtimeDataArgTy);
  78. void createLaunchParams(llvm::Function* func);
  79. void createStack(llvm::Function* func);
  80. void createStateDispatch(llvm::Function* func, const IntToFuncMap& stateFunctionMap, llvm::Type* runtimeDataArgTy);
  81. void lowerIntrinsics();
  82. llvm::Type* getRuntimeDataArgType();
  83. llvm::Function* createDispatchFunction(const IntToFuncMap &stateFunctionMap, llvm::Type* runtimeDataArgTy);
  84. // These functions return calls only in shaders in m_shaderMap.
  85. std::vector<llvm::CallInst*> getCallsInShadersToFunction(const std::string& funcName);
  86. std::vector<llvm::CallInst*> getCallsInShadersToFunctionWithPrefix(const std::string& funcNamePrefix);
  87. };