EmitVisitor.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. //===-- EmitVisitor.h - Emit Visitor ----------------------------*- C++ -*-===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //===----------------------------------------------------------------------===//
  8. #ifndef LLVM_CLANG_SPIRV_EMITVISITOR_H
  9. #define LLVM_CLANG_SPIRV_EMITVISITOR_H
  10. #include "clang/SPIRV/SpirvContext.h"
  11. #include "clang/SPIRV/SpirvVisitor.h"
  12. #include "llvm/ADT/DenseMap.h"
  13. #include "llvm/ADT/StringMap.h"
  14. #include <functional>
  15. namespace clang {
  16. namespace spirv {
  17. class SpirvFunction;
  18. class SpirvBasicBlock;
  19. class SpirvType;
  20. class EmitTypeHandler {
  21. public:
  22. struct DecorationInfo {
  23. DecorationInfo(spv::Decoration decor, llvm::ArrayRef<uint32_t> params = {},
  24. llvm::Optional<uint32_t> index = llvm::None)
  25. : decoration(decor), decorationParams(params.begin(), params.end()),
  26. memberIndex(index) {}
  27. bool operator==(const DecorationInfo &other) const {
  28. return decoration == other.decoration &&
  29. decorationParams == other.decorationParams &&
  30. memberIndex.hasValue() == other.memberIndex.hasValue() &&
  31. (!memberIndex.hasValue() ||
  32. memberIndex.getValue() == other.memberIndex.getValue());
  33. }
  34. spv::Decoration decoration;
  35. llvm::SmallVector<uint32_t, 4> decorationParams;
  36. llvm::Optional<uint32_t> memberIndex;
  37. };
  38. public:
  39. EmitTypeHandler(ASTContext &astCtx, SpirvContext &spvContext,
  40. const SpirvCodeGenOptions &opts,
  41. std::vector<uint32_t> *debugVec,
  42. std::vector<uint32_t> *decVec,
  43. std::vector<uint32_t> *typesVec,
  44. const std::function<uint32_t()> &takeNextIdFn)
  45. : astContext(astCtx), context(spvContext), spvOptions(opts),
  46. debugVariableBinary(debugVec), annotationsBinary(decVec),
  47. typeConstantBinary(typesVec), takeNextIdFunction(takeNextIdFn),
  48. emittedConstantInts({}), emittedConstantFloats({}),
  49. emittedConstantComposites({}), emittedConstantNulls({}),
  50. emittedConstantBools() {
  51. assert(decVec);
  52. assert(typesVec);
  53. }
  54. // Disable copy constructor/assignment.
  55. EmitTypeHandler(const EmitTypeHandler &) = delete;
  56. EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
  57. // Emits the instruction for the given type into the typeConstantBinary and
  58. // returns the result-id for the type. If the type has already been emitted,
  59. // it only returns its result-id.
  60. //
  61. // If any names are associated with the type (or its members in case of
  62. // structs), the OpName/OpMemberNames will also be emitted.
  63. //
  64. // If any decorations apply to the type, it also emits the decoration
  65. // instructions into the annotationsBinary.
  66. uint32_t emitType(const SpirvType *);
  67. // Emits OpDecorate (or OpMemberDecorate if memberIndex is non-zero)
  68. // targetting the given type. Uses the given decoration kind and its
  69. // parameters.
  70. void emitDecoration(uint32_t typeResultId, spv::Decoration,
  71. llvm::ArrayRef<uint32_t> decorationParams,
  72. llvm::Optional<uint32_t> memberIndex = llvm::None);
  73. uint32_t getOrCreateConstant(SpirvConstant *);
  74. // Emits an OpConstant instruction and returns its result-id.
  75. // For non-specialization constants, if an identical constant has already been
  76. // emitted, returns the existing constant's result-id.
  77. //
  78. // Note1: This method modifies the curTypeInst. Do not call in the middle of
  79. // construction of another instruction.
  80. //
  81. // Note 2: Integer constants may need to be generated for cases where there is
  82. // no SpirvConstantInteger instruction in the module. For example, we need to
  83. // emit an integer in order to create an array type. Therefore,
  84. // 'getOrCreateConstantInt' has a different signature than others. If a
  85. // constant instruction is provided, and it already has a result-id assigned,
  86. // it will be used. Otherwise a new result-id will be allocated for the
  87. // instruction.
  88. uint32_t
  89. getOrCreateConstantInt(llvm::APInt value, const SpirvType *type,
  90. bool isSpecConst,
  91. SpirvInstruction *constantInstruction = nullptr);
  92. uint32_t getOrCreateConstantFloat(SpirvConstantFloat *);
  93. uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
  94. uint32_t getOrCreateConstantNull(SpirvConstantNull *);
  95. uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
  96. private:
  97. void initTypeInstruction(spv::Op op);
  98. void finalizeTypeInstruction();
  99. // Returns the result-id for the given type and decorations. If a type with
  100. // the same decorations have already been used, it returns the existing
  101. // result-id. If not, creates a new result-id for such type and returns it.
  102. uint32_t getResultIdForType(const SpirvType *, bool *alreadyExists);
  103. // Emits an OpName (if memberIndex is not provided) or OpMemberName (if
  104. // memberIndex is provided) for the given target result-id.
  105. void emitNameForType(llvm::StringRef name, uint32_t targetTypeId,
  106. llvm::Optional<uint32_t> memberIndex = llvm::None);
  107. // There is no guarantee that an instruction or a function or a basic block
  108. // has been assigned result-id. This method returns the result-id for the
  109. // given object. If a result-id has not been assigned yet, it'll assign
  110. // one and return it.
  111. template <class T> uint32_t getOrAssignResultId(T *obj) {
  112. if (!obj->getResultId()) {
  113. obj->setResultId(takeNextIdFunction());
  114. }
  115. return obj->getResultId();
  116. }
  117. private:
  118. /// Emits error to the diagnostic engine associated with this visitor.
  119. template <unsigned N>
  120. DiagnosticBuilder emitError(const char (&message)[N],
  121. SourceLocation loc = {}) {
  122. const auto diagId = astContext.getDiagnostics().getCustomDiagID(
  123. clang::DiagnosticsEngine::Error, message);
  124. return astContext.getDiagnostics().Report(loc, diagId);
  125. }
  126. private:
  127. ASTContext &astContext;
  128. SpirvContext &context;
  129. const SpirvCodeGenOptions &spvOptions;
  130. std::vector<uint32_t> curTypeInst;
  131. std::vector<uint32_t> curDecorationInst;
  132. std::vector<uint32_t> *debugVariableBinary;
  133. std::vector<uint32_t> *annotationsBinary;
  134. std::vector<uint32_t> *typeConstantBinary;
  135. std::function<uint32_t()> takeNextIdFunction;
  136. // The array type requires the result-id of an OpConstant for its length. In
  137. // order to avoid duplicate OpConstant instructions, we keep a map of constant
  138. // uint value to the result-id of the OpConstant for that value.
  139. llvm::DenseMap<std::pair<uint64_t, const SpirvType *>, uint32_t>
  140. emittedConstantInts;
  141. llvm::DenseMap<std::pair<uint64_t, const SpirvType *>, uint32_t>
  142. emittedConstantFloats;
  143. llvm::SmallVector<SpirvConstantComposite *, 8> emittedConstantComposites;
  144. llvm::SmallVector<SpirvConstantNull *, 8> emittedConstantNulls;
  145. SpirvConstantBoolean *emittedConstantBools[2];
  146. // emittedTypes is a map that caches the result-id of types in order to avoid
  147. // emitting an identical type multiple times.
  148. llvm::DenseMap<const SpirvType *, uint32_t> emittedTypes;
  149. };
  150. /// \breif The visitor class that emits the SPIR-V words from the in-memory
  151. /// representation.
  152. class EmitVisitor : public Visitor {
  153. public:
  154. /// \brief The struct representing a SPIR-V module header.
  155. struct Header {
  156. /// \brief Default constructs a SPIR-V module header with id bound 0.
  157. Header(uint32_t bound, uint32_t version);
  158. /// \brief Feeds the consumer with all the SPIR-V words for this header.
  159. std::vector<uint32_t> takeBinary();
  160. const uint32_t magicNumber;
  161. uint32_t version;
  162. const uint32_t generator;
  163. uint32_t bound;
  164. const uint32_t reserved;
  165. };
  166. public:
  167. EmitVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
  168. const SpirvCodeGenOptions &opts)
  169. : Visitor(opts, spvCtx), astContext(astCtx), id(0),
  170. typeHandler(astCtx, spvCtx, opts, &debugVariableBinary,
  171. &annotationsBinary, &typeConstantBinary,
  172. [this]() -> uint32_t { return takeNextId(); }),
  173. debugMainFileId(0), debugLine(0), debugColumn(0),
  174. lastOpWasMergeInst(false) {}
  175. // Visit different SPIR-V constructs for emitting.
  176. bool visit(SpirvModule *, Phase phase) override;
  177. bool visit(SpirvFunction *, Phase phase) override;
  178. bool visit(SpirvBasicBlock *, Phase phase) override;
  179. bool visit(SpirvCapability *) override;
  180. bool visit(SpirvExtension *) override;
  181. bool visit(SpirvExtInstImport *) override;
  182. bool visit(SpirvMemoryModel *) override;
  183. bool visit(SpirvEmitVertex *) override;
  184. bool visit(SpirvEndPrimitive *) override;
  185. bool visit(SpirvEntryPoint *) override;
  186. bool visit(SpirvExecutionMode *) override;
  187. bool visit(SpirvString *) override;
  188. bool visit(SpirvSource *) override;
  189. bool visit(SpirvModuleProcessed *) override;
  190. bool visit(SpirvDecoration *) override;
  191. bool visit(SpirvVariable *) override;
  192. bool visit(SpirvFunctionParameter *) override;
  193. bool visit(SpirvLoopMerge *) override;
  194. bool visit(SpirvSelectionMerge *) override;
  195. bool visit(SpirvBranch *) override;
  196. bool visit(SpirvBranchConditional *) override;
  197. bool visit(SpirvKill *) override;
  198. bool visit(SpirvReturn *) override;
  199. bool visit(SpirvSwitch *) override;
  200. bool visit(SpirvUnreachable *) override;
  201. bool visit(SpirvAccessChain *) override;
  202. bool visit(SpirvAtomic *) override;
  203. bool visit(SpirvBarrier *) override;
  204. bool visit(SpirvBinaryOp *) override;
  205. bool visit(SpirvBitFieldExtract *) override;
  206. bool visit(SpirvBitFieldInsert *) override;
  207. bool visit(SpirvConstantBoolean *) override;
  208. bool visit(SpirvConstantInteger *) override;
  209. bool visit(SpirvConstantFloat *) override;
  210. bool visit(SpirvConstantComposite *) override;
  211. bool visit(SpirvConstantNull *) override;
  212. bool visit(SpirvCompositeConstruct *) override;
  213. bool visit(SpirvCompositeExtract *) override;
  214. bool visit(SpirvCompositeInsert *) override;
  215. bool visit(SpirvExtInst *) override;
  216. bool visit(SpirvFunctionCall *) override;
  217. bool visit(SpirvNonUniformBinaryOp *) override;
  218. bool visit(SpirvNonUniformElect *) override;
  219. bool visit(SpirvNonUniformUnaryOp *) override;
  220. bool visit(SpirvImageOp *) override;
  221. bool visit(SpirvImageQuery *) override;
  222. bool visit(SpirvImageSparseTexelsResident *) override;
  223. bool visit(SpirvImageTexelPointer *) override;
  224. bool visit(SpirvLoad *) override;
  225. bool visit(SpirvCopyObject *) override;
  226. bool visit(SpirvSampledImage *) override;
  227. bool visit(SpirvSelect *) override;
  228. bool visit(SpirvSpecConstantBinaryOp *) override;
  229. bool visit(SpirvSpecConstantUnaryOp *) override;
  230. bool visit(SpirvStore *) override;
  231. bool visit(SpirvUnaryOp *) override;
  232. bool visit(SpirvVectorShuffle *) override;
  233. bool visit(SpirvArrayLength *) override;
  234. bool visit(SpirvRayTracingOpNV *) override;
  235. bool visit(SpirvDemoteToHelperInvocationEXT *) override;
  236. bool visit(SpirvRayQueryOpKHR *) override;
  237. using Visitor::visit;
  238. // Returns the assembled binary built up in this visitor.
  239. std::vector<uint32_t> takeBinary();
  240. private:
  241. // Returns the next available result-id.
  242. uint32_t takeNextId() { return ++id; }
  243. // There is no guarantee that an instruction or a function or a basic block
  244. // has been assigned result-id. This method returns the result-id for the
  245. // given object. If a result-id has not been assigned yet, it'll assign
  246. // one and return it.
  247. template <class T> uint32_t getOrAssignResultId(T *obj) {
  248. if (!obj->getResultId()) {
  249. obj->setResultId(takeNextId());
  250. }
  251. return obj->getResultId();
  252. }
  253. void emitDebugLine(spv::Op op, const SourceLocation &loc);
  254. // Initiates the creation of a new instruction with the given Opcode.
  255. void initInstruction(spv::Op, const SourceLocation &);
  256. // Initiates the creation of the given SPIR-V instruction.
  257. // If the given instruction has a return type, it will also trigger emitting
  258. // the necessary type (and its associated decorations) and uses its result-id
  259. // in the instruction.
  260. void initInstruction(SpirvInstruction *);
  261. // Finalizes the current instruction by encoding the instruction size into the
  262. // first word, and then appends the current instruction to the SPIR-V binary.
  263. void finalizeInstruction();
  264. // Encodes the given string into the current instruction that is being built.
  265. void encodeString(llvm::StringRef value);
  266. // Emits an OpName instruction into the debugBinary for the given target.
  267. void emitDebugNameForInstruction(uint32_t resultId, llvm::StringRef name);
  268. // TODO: Add a method for adding OpMemberName instructions for struct members
  269. // using the type information.
  270. private:
  271. /// Emits error to the diagnostic engine associated with this visitor.
  272. template <unsigned N>
  273. DiagnosticBuilder emitError(const char (&message)[N],
  274. SourceLocation loc = {}) {
  275. const auto diagId = astContext.getDiagnostics().getCustomDiagID(
  276. clang::DiagnosticsEngine::Error, message);
  277. return astContext.getDiagnostics().Report(loc, diagId);
  278. }
  279. private:
  280. // Object that holds Clang AST nodes.
  281. ASTContext &astContext;
  282. // The last result-id that's been used so far.
  283. uint32_t id;
  284. // Handler for emitting types and their related instructions.
  285. EmitTypeHandler typeHandler;
  286. // Current instruction being built
  287. SmallVector<uint32_t, 16> curInst;
  288. // All preamble instructions in the following order:
  289. // OpCapability, OpExtension, OpExtInstImport, OpMemoryModel, OpEntryPoint,
  290. // OpExecutionMode(Id)
  291. std::vector<uint32_t> preambleBinary;
  292. // Debug instructions related to file. Includes:
  293. // OpString, OpSourceExtension, OpSource, OpSourceContinued
  294. std::vector<uint32_t> debugFileBinary;
  295. // All debug instructions related to variable name. Includes:
  296. // OpName, OpMemberName, OpModuleProcessed
  297. std::vector<uint32_t> debugVariableBinary;
  298. // All annotation instructions: OpDecorate, OpMemberDecorate, OpGroupDecorate,
  299. // OpGroupMemberDecorate, and OpDecorationGroup.
  300. std::vector<uint32_t> annotationsBinary;
  301. // All type and constant instructions
  302. std::vector<uint32_t> typeConstantBinary;
  303. // All other instructions
  304. std::vector<uint32_t> mainBinary;
  305. // File information for debugging that will be used by OpLine.
  306. llvm::StringMap<uint32_t> debugFileIdMap;
  307. // Main file information for debugging that will be used by OpLine.
  308. uint32_t debugMainFileId;
  309. // One HLSL source line may result in several SPIR-V instructions. In order to
  310. // avoid emitting many OpLine instructions with identical line and column
  311. // numbers, we record the last line and column number that was used by OpLine,
  312. // and only emit a new OpLine when a new line/column in the source is
  313. // discovered. The last debug line number information emitted by OpLine.
  314. uint32_t debugLine;
  315. // The last debug column number information emitted by OpLine.
  316. uint32_t debugColumn;
  317. // True if the last emitted instruction was OpSelectionMerge or OpLoopMerge.
  318. bool lastOpWasMergeInst;
  319. };
  320. } // namespace spirv
  321. } // namespace clang
  322. #endif // LLVM_CLANG_SPIRV_EMITVISITOR_H