EmitVisitor.h 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. //===-- EmitVisitor.h - Emit Visitor ----------------------------*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //===----------------------------------------------------------------------===//
  8. #ifndef LLVM_CLANG_SPIRV_EMITVISITOR_H
  9. #define LLVM_CLANG_SPIRV_EMITVISITOR_H
  10. #include "clang/SPIRV/SpirvContext.h"
  11. #include "clang/SPIRV/SpirvVisitor.h"
  12. #include "llvm/ADT/DenseMap.h"
  13. #include "llvm/ADT/StringMap.h"
  14. #include <functional>
  15. namespace clang {
  16. namespace spirv {
  17. class SpirvFunction;
  18. class SpirvBasicBlock;
  19. class SpirvType;
  20. class EmitTypeHandler {
  21. public:
  22. struct DecorationInfo {
  23. DecorationInfo(spv::Decoration decor, llvm::ArrayRef<uint32_t> params = {},
  24. llvm::Optional<uint32_t> index = llvm::None)
  25. : decoration(decor), decorationParams(params.begin(), params.end()),
  26. memberIndex(index) {}
  27. bool operator==(const DecorationInfo &other) const {
  28. return decoration == other.decoration &&
  29. decorationParams == other.decorationParams &&
  30. memberIndex.hasValue() == other.memberIndex.hasValue() &&
  31. (!memberIndex.hasValue() ||
  32. memberIndex.getValue() == other.memberIndex.getValue());
  33. }
  34. spv::Decoration decoration;
  35. llvm::SmallVector<uint32_t, 4> decorationParams;
  36. llvm::Optional<uint32_t> memberIndex;
  37. };
  38. public:
  39. EmitTypeHandler(ASTContext &astCtx, SpirvContext &spvContext,
  40. const SpirvCodeGenOptions &opts,
  41. std::vector<uint32_t> *debugVec,
  42. std::vector<uint32_t> *decVec,
  43. std::vector<uint32_t> *typesVec,
  44. const std::function<uint32_t()> &takeNextIdFn)
  45. : astContext(astCtx), context(spvContext), spvOptions(opts),
  46. debugVariableBinary(debugVec), annotationsBinary(decVec),
  47. typeConstantBinary(typesVec), takeNextIdFunction(takeNextIdFn),
  48. emittedConstantInts({}), emittedConstantFloats({}),
  49. emittedConstantComposites({}), emittedConstantNulls({}),
  50. emittedConstantBools() {
  51. assert(decVec);
  52. assert(typesVec);
  53. }
  54. // Disable copy constructor/assignment.
  55. EmitTypeHandler(const EmitTypeHandler &) = delete;
  56. EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
  57. // Emits the instruction for the given type into the typeConstantBinary and
  58. // returns the result-id for the type. If the type has already been emitted,
  59. // it only returns its result-id.
  60. //
  61. // If any names are associated with the type (or its members in case of
  62. // structs), the OpName/OpMemberNames will also be emitted.
  63. //
  64. // If any decorations apply to the type, it also emits the decoration
  65. // instructions into the annotationsBinary.
  66. uint32_t emitType(const SpirvType *);
  67. // Emits OpDecorate (or OpMemberDecorate if memberIndex is non-zero)
  68. // targetting the given type. Uses the given decoration kind and its
  69. // parameters.
  70. void emitDecoration(uint32_t typeResultId, spv::Decoration,
  71. llvm::ArrayRef<uint32_t> decorationParams,
  72. llvm::Optional<uint32_t> memberIndex = llvm::None);
  73. uint32_t getOrCreateConstant(SpirvConstant *);
  74. // Emits an OpConstant instruction and returns its result-id.
  75. // For non-specialization constants, if an identical constant has already been
  76. // emitted, returns the existing constant's result-id.
  77. //
  78. // Note1: This method modifies the curTypeInst. Do not call in the middle of
  79. // construction of another instruction.
  80. //
  81. // Note 2: Integer constants may need to be generated for cases where there is
  82. // no SpirvConstantInteger instruction in the module. For example, we need to
  83. // emit an integer in order to create an array type. Therefore,
  84. // 'getOrCreateConstantInt' has a different signature than others. If a
  85. // constant instruction is provided, and it already has a result-id assigned,
  86. // it will be used. Otherwise a new result-id will be allocated for the
  87. // instruction.
  88. uint32_t
  89. getOrCreateConstantInt(llvm::APInt value, const SpirvType *type,
  90. bool isSpecConst,
  91. SpirvInstruction *constantInstruction = nullptr);
  92. uint32_t getOrCreateConstantFloat(SpirvConstantFloat *);
  93. uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
  94. uint32_t getOrCreateConstantNull(SpirvConstantNull *);
  95. uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
  96. private:
  97. void initTypeInstruction(spv::Op op);
  98. void finalizeTypeInstruction();
  99. // Returns the result-id for the given type and decorations. If a type with
  100. // the same decorations have already been used, it returns the existing
  101. // result-id. If not, creates a new result-id for such type and returns it.
  102. uint32_t getResultIdForType(const SpirvType *, bool *alreadyExists);
  103. // Emits an OpName (if memberIndex is not provided) or OpMemberName (if
  104. // memberIndex is provided) for the given target result-id.
  105. void emitNameForType(llvm::StringRef name, uint32_t targetTypeId,
  106. llvm::Optional<uint32_t> memberIndex = llvm::None);
  107. // There is no guarantee that an instruction or a function or a basic block
  108. // has been assigned result-id. This method returns the result-id for the
  109. // given object. If a result-id has not been assigned yet, it'll assign
  110. // one and return it.
  111. template <class T> uint32_t getOrAssignResultId(T *obj) {
  112. if (!obj->getResultId()) {
  113. obj->setResultId(takeNextIdFunction());
  114. }
  115. return obj->getResultId();
  116. }
  117. private:
  118. /// Emits error to the diagnostic engine associated with this visitor.
  119. template <unsigned N>
  120. DiagnosticBuilder emitError(const char (&message)[N],
  121. SourceLocation loc = {}) {
  122. const auto diagId = astContext.getDiagnostics().getCustomDiagID(
  123. clang::DiagnosticsEngine::Error, message);
  124. return astContext.getDiagnostics().Report(loc, diagId);
  125. }
  126. private:
  127. ASTContext &astContext;
  128. SpirvContext &context;
  129. const SpirvCodeGenOptions &spvOptions;
  130. std::vector<uint32_t> curTypeInst;
  131. std::vector<uint32_t> curDecorationInst;
  132. std::vector<uint32_t> *debugVariableBinary;
  133. std::vector<uint32_t> *annotationsBinary;
  134. std::vector<uint32_t> *typeConstantBinary;
  135. std::function<uint32_t()> takeNextIdFunction;
  136. // The array type requires the result-id of an OpConstant for its length. In
  137. // order to avoid duplicate OpConstant instructions, we keep a map of constant
  138. // uint value to the result-id of the OpConstant for that value.
  139. llvm::DenseMap<std::pair<uint64_t, const SpirvType *>, uint32_t>
  140. emittedConstantInts;
  141. llvm::DenseMap<std::pair<uint64_t, const SpirvType *>, uint32_t>
  142. emittedConstantFloats;
  143. llvm::SmallVector<SpirvConstantComposite *, 8> emittedConstantComposites;
  144. llvm::SmallVector<SpirvConstantNull *, 8> emittedConstantNulls;
  145. SpirvConstantBoolean *emittedConstantBools[2];
  146. // emittedTypes is a map that caches the result-id of types in order to avoid
  147. // emitting an identical type multiple times.
  148. llvm::DenseMap<const SpirvType *, uint32_t> emittedTypes;
  149. };
  150. /// \breif The visitor class that emits the SPIR-V words from the in-memory
  151. /// representation.
  152. class EmitVisitor : public Visitor {
  153. public:
  154. /// \brief The struct representing a SPIR-V module header.
  155. struct Header {
  156. /// \brief Default constructs a SPIR-V module header with id bound 0.
  157. Header(uint32_t bound, uint32_t version);
  158. /// \brief Feeds the consumer with all the SPIR-V words for this header.
  159. std::vector<uint32_t> takeBinary();
  160. const uint32_t magicNumber;
  161. uint32_t version;
  162. const uint32_t generator;
  163. uint32_t bound;
  164. const uint32_t reserved;
  165. };
  166. public:
  167. EmitVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
  168. const SpirvCodeGenOptions &opts)
  169. : Visitor(opts, spvCtx), astContext(astCtx), id(0),
  170. typeHandler(astCtx, spvCtx, opts, &debugVariableBinary,
  171. &annotationsBinary, &typeConstantBinary,
  172. [this]() -> uint32_t { return takeNextId(); }),
  173. debugMainFileId(0), debugLine(0), debugColumn(0),
  174. lastOpWasMergeInst(false) {}
  175. // Visit different SPIR-V constructs for emitting.
  176. bool visit(SpirvModule *, Phase phase);
  177. bool visit(SpirvFunction *, Phase phase);
  178. bool visit(SpirvBasicBlock *, Phase phase);
  179. bool visit(SpirvCapability *);
  180. bool visit(SpirvExtension *);
  181. bool visit(SpirvExtInstImport *);
  182. bool visit(SpirvMemoryModel *);
  183. bool visit(SpirvEmitVertex *);
  184. bool visit(SpirvEndPrimitive *);
  185. bool visit(SpirvEntryPoint *);
  186. bool visit(SpirvExecutionMode *);
  187. bool visit(SpirvString *);
  188. bool visit(SpirvSource *);
  189. bool visit(SpirvModuleProcessed *);
  190. bool visit(SpirvDecoration *);
  191. bool visit(SpirvVariable *);
  192. bool visit(SpirvFunctionParameter *);
  193. bool visit(SpirvLoopMerge *);
  194. bool visit(SpirvSelectionMerge *);
  195. bool visit(SpirvBranch *);
  196. bool visit(SpirvBranchConditional *);
  197. bool visit(SpirvKill *);
  198. bool visit(SpirvReturn *);
  199. bool visit(SpirvSwitch *);
  200. bool visit(SpirvUnreachable *);
  201. bool visit(SpirvAccessChain *);
  202. bool visit(SpirvAtomic *);
  203. bool visit(SpirvBarrier *);
  204. bool visit(SpirvBinaryOp *);
  205. bool visit(SpirvBitFieldExtract *);
  206. bool visit(SpirvBitFieldInsert *);
  207. bool visit(SpirvConstantBoolean *);
  208. bool visit(SpirvConstantInteger *);
  209. bool visit(SpirvConstantFloat *);
  210. bool visit(SpirvConstantComposite *);
  211. bool visit(SpirvConstantNull *);
  212. bool visit(SpirvCompositeConstruct *);
  213. bool visit(SpirvCompositeExtract *);
  214. bool visit(SpirvCompositeInsert *);
  215. bool visit(SpirvExtInst *);
  216. bool visit(SpirvFunctionCall *);
  217. bool visit(SpirvNonUniformBinaryOp *);
  218. bool visit(SpirvNonUniformElect *);
  219. bool visit(SpirvNonUniformUnaryOp *);
  220. bool visit(SpirvImageOp *);
  221. bool visit(SpirvImageQuery *);
  222. bool visit(SpirvImageSparseTexelsResident *);
  223. bool visit(SpirvImageTexelPointer *);
  224. bool visit(SpirvLoad *);
  225. bool visit(SpirvCopyObject *);
  226. bool visit(SpirvSampledImage *);
  227. bool visit(SpirvSelect *);
  228. bool visit(SpirvSpecConstantBinaryOp *);
  229. bool visit(SpirvSpecConstantUnaryOp *);
  230. bool visit(SpirvStore *);
  231. bool visit(SpirvUnaryOp *);
  232. bool visit(SpirvVectorShuffle *);
  233. bool visit(SpirvArrayLength *);
  234. bool visit(SpirvRayTracingOpNV *);
  235. bool visit(SpirvDemoteToHelperInvocationEXT *);
  236. bool visit(SpirvRayQueryOpKHR *);
  237. // Returns the assembled binary built up in this visitor.
  238. std::vector<uint32_t> takeBinary();
  239. private:
  240. // Returns the next available result-id.
  241. uint32_t takeNextId() { return ++id; }
  242. // There is no guarantee that an instruction or a function or a basic block
  243. // has been assigned result-id. This method returns the result-id for the
  244. // given object. If a result-id has not been assigned yet, it'll assign
  245. // one and return it.
  246. template <class T> uint32_t getOrAssignResultId(T *obj) {
  247. if (!obj->getResultId()) {
  248. obj->setResultId(takeNextId());
  249. }
  250. return obj->getResultId();
  251. }
  252. void emitDebugLine(spv::Op op, const SourceLocation &loc);
  253. // Initiates the creation of a new instruction with the given Opcode.
  254. void initInstruction(spv::Op, const SourceLocation &);
  255. // Initiates the creation of the given SPIR-V instruction.
  256. // If the given instruction has a return type, it will also trigger emitting
  257. // the necessary type (and its associated decorations) and uses its result-id
  258. // in the instruction.
  259. void initInstruction(SpirvInstruction *);
  260. // Finalizes the current instruction by encoding the instruction size into the
  261. // first word, and then appends the current instruction to the SPIR-V binary.
  262. void finalizeInstruction();
  263. // Encodes the given string into the current instruction that is being built.
  264. void encodeString(llvm::StringRef value);
  265. // Emits an OpName instruction into the debugBinary for the given target.
  266. void emitDebugNameForInstruction(uint32_t resultId, llvm::StringRef name);
  267. // TODO: Add a method for adding OpMemberName instructions for struct members
  268. // using the type information.
  269. private:
  270. /// Emits error to the diagnostic engine associated with this visitor.
  271. template <unsigned N>
  272. DiagnosticBuilder emitError(const char (&message)[N],
  273. SourceLocation loc = {}) {
  274. const auto diagId = astContext.getDiagnostics().getCustomDiagID(
  275. clang::DiagnosticsEngine::Error, message);
  276. return astContext.getDiagnostics().Report(loc, diagId);
  277. }
  278. private:
  279. // Object that holds Clang AST nodes.
  280. ASTContext &astContext;
  281. // The last result-id that's been used so far.
  282. uint32_t id;
  283. // Handler for emitting types and their related instructions.
  284. EmitTypeHandler typeHandler;
  285. // Current instruction being built
  286. SmallVector<uint32_t, 16> curInst;
  287. // All preamble instructions in the following order:
  288. // OpCapability, OpExtension, OpExtInstImport, OpMemoryModel, OpEntryPoint,
  289. // OpExecutionMode(Id)
  290. std::vector<uint32_t> preambleBinary;
  291. // Debug instructions related to file. Includes:
  292. // OpString, OpSourceExtension, OpSource, OpSourceContinued
  293. std::vector<uint32_t> debugFileBinary;
  294. // All debug instructions related to variable name. Includes:
  295. // OpName, OpMemberName, OpModuleProcessed
  296. std::vector<uint32_t> debugVariableBinary;
  297. // All annotation instructions: OpDecorate, OpMemberDecorate, OpGroupDecorate,
  298. // OpGroupMemberDecorate, and OpDecorationGroup.
  299. std::vector<uint32_t> annotationsBinary;
  300. // All type and constant instructions
  301. std::vector<uint32_t> typeConstantBinary;
  302. // All other instructions
  303. std::vector<uint32_t> mainBinary;
  304. // File information for debugging that will be used by OpLine.
  305. llvm::StringMap<uint32_t> debugFileIdMap;
  306. // Main file information for debugging that will be used by OpLine.
  307. uint32_t debugMainFileId;
  308. // One HLSL source line may result in several SPIR-V instructions. In order to
  309. // avoid emitting many OpLine instructions with identical line and column
  310. // numbers, we record the last line and column number that was used by OpLine,
  311. // and only emit a new OpLine when a new line/column in the source is
  312. // discovered. The last debug line number information emitted by OpLine.
  313. uint32_t debugLine;
  314. // The last debug column number information emitted by OpLine.
  315. uint32_t debugColumn;
  316. // True if the last emitted instruction was OpSelectionMerge or OpLoopMerge.
  317. bool lastOpWasMergeInst;
  318. };
  319. } // namespace spirv
  320. } // namespace clang
  321. #endif // LLVM_CLANG_SPIRV_EMITVISITOR_H