ModuleBuilder.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. //===--- ModuleBuilder.cpp - SPIR-V builder implementation ----*- C++ -*---===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/ModuleBuilder.h"
  10. #include "spirv/1.0//spirv.hpp11"
  11. #include "clang/SPIRV/InstBuilder.h"
  12. #include "llvm/llvm_assert/assert.h"
  13. namespace clang {
  14. namespace spirv {
  15. ModuleBuilder::ModuleBuilder(SPIRVContext *C)
  16. : theContext(*C), theModule(), theFunction(nullptr), insertPoint(nullptr),
  17. instBuilder(nullptr) {
  18. instBuilder.setConsumer([this](std::vector<uint32_t> &&words) {
  19. this->constructSite = std::move(words);
  20. });
  21. }
  22. std::vector<uint32_t> ModuleBuilder::takeModule() {
  23. theModule.setBound(theContext.getNextId());
  24. std::vector<uint32_t> binary;
  25. auto ib = InstBuilder([&binary](std::vector<uint32_t> &&words) {
  26. binary.insert(binary.end(), words.begin(), words.end());
  27. });
  28. theModule.take(&ib);
  29. return std::move(binary);
  30. }
  31. uint32_t ModuleBuilder::beginFunction(uint32_t funcType, uint32_t returnType,
  32. llvm::StringRef funcName) {
  33. if (theFunction) {
  34. assert(false && "found nested function");
  35. return 0;
  36. }
  37. const uint32_t fId = theContext.takeNextId();
  38. theFunction = llvm::make_unique<Function>(
  39. returnType, fId, spv::FunctionControlMask::MaskNone, funcType);
  40. theModule.addDebugName(fId, funcName);
  41. return fId;
  42. }
  43. uint32_t ModuleBuilder::addFnParameter(uint32_t type, llvm::StringRef name) {
  44. assert(theFunction && "found detached parameter");
  45. const uint32_t pointerType =
  46. getPointerType(type, spv::StorageClass::Function);
  47. const uint32_t paramId = theContext.takeNextId();
  48. theFunction->addParameter(pointerType, paramId);
  49. theModule.addDebugName(paramId, name);
  50. return paramId;
  51. }
  52. uint32_t ModuleBuilder::addFnVariable(uint32_t type, llvm::StringRef name,
  53. llvm::Optional<uint32_t> init) {
  54. assert(theFunction && "found detached local variable");
  55. const uint32_t varId = theContext.takeNextId();
  56. theFunction->addVariable(type, varId, init);
  57. theModule.addDebugName(varId, name);
  58. return varId;
  59. }
  60. bool ModuleBuilder::endFunction() {
  61. if (theFunction == nullptr) {
  62. assert(false && "no active function");
  63. return false;
  64. }
  65. // Move all basic blocks into the current function.
  66. // TODO: we should adjust the order the basic blocks according to
  67. // SPIR-V validation rules.
  68. for (auto &bb : basicBlocks) {
  69. theFunction->addBasicBlock(std::move(bb.second));
  70. }
  71. basicBlocks.clear();
  72. theModule.addFunction(std::move(theFunction));
  73. theFunction.reset(nullptr);
  74. insertPoint = nullptr;
  75. return true;
  76. }
  77. uint32_t ModuleBuilder::createBasicBlock(llvm::StringRef name) {
  78. if (theFunction == nullptr) {
  79. assert(false && "found detached basic block");
  80. return 0;
  81. }
  82. const uint32_t labelId = theContext.takeNextId();
  83. basicBlocks[labelId] = llvm::make_unique<BasicBlock>(labelId);
  84. theModule.addDebugName(labelId, name);
  85. return labelId;
  86. }
  87. bool ModuleBuilder::setInsertPoint(uint32_t labelId) {
  88. auto it = basicBlocks.find(labelId);
  89. if (it == basicBlocks.end()) {
  90. assert(false && "invalid <label-id>");
  91. return false;
  92. }
  93. insertPoint = it->second.get();
  94. return true;
  95. }
  96. uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
  97. assert(insertPoint && "null insert point");
  98. const uint32_t resultId = theContext.takeNextId();
  99. instBuilder.opLoad(resultType, resultId, pointer, llvm::None).x();
  100. insertPoint->appendInstruction(std::move(constructSite));
  101. return resultId;
  102. }
  103. void ModuleBuilder::createStore(uint32_t address, uint32_t value) {
  104. assert(insertPoint && "null insert point");
  105. instBuilder.opStore(address, value, llvm::None).x();
  106. insertPoint->appendInstruction(std::move(constructSite));
  107. }
  108. uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
  109. llvm::ArrayRef<uint32_t> indexes) {
  110. assert(insertPoint && "null insert point");
  111. const uint32_t id = theContext.takeNextId();
  112. instBuilder.opAccessChain(resultType, id, base, indexes).x();
  113. insertPoint->appendInstruction(std::move(constructSite));
  114. return id;
  115. }
  116. void ModuleBuilder::createReturn() {
  117. assert(insertPoint && "null insert point");
  118. instBuilder.opReturn().x();
  119. insertPoint->appendInstruction(std::move(constructSite));
  120. }
  121. void ModuleBuilder::createReturnValue(uint32_t value) {
  122. assert(insertPoint && "null insert point");
  123. instBuilder.opReturnValue(value).x();
  124. insertPoint->appendInstruction(std::move(constructSite));
  125. }
  126. void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
  127. spv::ExecutionMode em,
  128. const std::vector<uint32_t> &params) {
  129. instBuilder.opExecutionMode(entryPointId, em);
  130. for (const auto &param : params) {
  131. instBuilder.literalInteger(param);
  132. }
  133. instBuilder.x();
  134. theModule.addExecutionMode(std::move(constructSite));
  135. }
  136. uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
  137. spv::StorageClass storageClass) {
  138. const uint32_t pointerType = getPointerType(type, storageClass);
  139. const uint32_t varId = theContext.takeNextId();
  140. instBuilder.opVariable(pointerType, varId, storageClass, llvm::None).x();
  141. theModule.addVariable(std::move(constructSite));
  142. return varId;
  143. }
  144. uint32_t ModuleBuilder::addStageBuiltinVariable(uint32_t type,
  145. spv::BuiltIn builtin) {
  146. spv::StorageClass sc = spv::StorageClass::Input;
  147. switch (builtin) {
  148. case spv::BuiltIn::Position:
  149. case spv::BuiltIn::PointSize:
  150. // TODO: add the rest output builtins
  151. sc = spv::StorageClass::Output;
  152. break;
  153. default:
  154. break;
  155. }
  156. const uint32_t pointerType = getPointerType(type, sc);
  157. const uint32_t varId = theContext.takeNextId();
  158. instBuilder.opVariable(pointerType, varId, sc, llvm::None).x();
  159. theModule.addVariable(std::move(constructSite));
  160. // Decorate with the specified Builtin
  161. const Decoration *d = Decoration::getBuiltIn(theContext, builtin);
  162. theModule.addDecoration(*d, varId);
  163. return varId;
  164. }
  165. void ModuleBuilder::decorateLocation(uint32_t targetId, uint32_t location) {
  166. const Decoration *d =
  167. Decoration::getLocation(theContext, location, llvm::None);
  168. theModule.addDecoration(*d, targetId);
  169. }
  170. #define IMPL_GET_PRIMITIVE_TYPE(ty) \
  171. \
  172. uint32_t ModuleBuilder::get##ty##Type() { \
  173. const Type *type = Type::get##ty(theContext); \
  174. const uint32_t typeId = theContext.getResultIdForType(type); \
  175. theModule.addType(type, typeId); \
  176. return typeId; \
  177. \
  178. }
  179. IMPL_GET_PRIMITIVE_TYPE(Void)
  180. IMPL_GET_PRIMITIVE_TYPE(Bool)
  181. IMPL_GET_PRIMITIVE_TYPE(Int32)
  182. IMPL_GET_PRIMITIVE_TYPE(Uint32)
  183. IMPL_GET_PRIMITIVE_TYPE(Float32)
  184. #undef IMPL_GET_PRIMITIVE_TYPE
  185. uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
  186. const Type *type = nullptr;
  187. switch (elemCount) {
  188. case 2:
  189. type = Type::getVec2(theContext, elemType);
  190. break;
  191. case 3:
  192. type = Type::getVec3(theContext, elemType);
  193. break;
  194. case 4:
  195. type = Type::getVec4(theContext, elemType);
  196. break;
  197. default:
  198. assert(false && "unhandled vector size");
  199. // Error found. Return 0 as the <result-id> directly.
  200. return 0;
  201. }
  202. const uint32_t typeId = theContext.getResultIdForType(type);
  203. theModule.addType(type, typeId);
  204. return typeId;
  205. }
  206. uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
  207. spv::StorageClass storageClass) {
  208. const Type *type = Type::getPointer(theContext, storageClass, pointeeType);
  209. const uint32_t typeId = theContext.getResultIdForType(type);
  210. theModule.addType(type, typeId);
  211. return typeId;
  212. }
  213. uint32_t ModuleBuilder::getStructType(llvm::ArrayRef<uint32_t> fieldTypes) {
  214. const Type *type = Type::getStruct(theContext, fieldTypes);
  215. const uint32_t typeId = theContext.getResultIdForType(type);
  216. theModule.addType(type, typeId);
  217. return typeId;
  218. }
  219. uint32_t
  220. ModuleBuilder::getFunctionType(uint32_t returnType,
  221. const std::vector<uint32_t> &paramTypes) {
  222. const Type *type = Type::getFunction(theContext, returnType, paramTypes);
  223. const uint32_t typeId = theContext.getResultIdForType(type);
  224. theModule.addType(type, typeId);
  225. return typeId;
  226. }
  227. uint32_t ModuleBuilder::getConstantBool(bool value) {
  228. const uint32_t typeId = getBoolType();
  229. const Constant *constant = value ? Constant::getTrue(theContext, typeId)
  230. : Constant::getFalse(theContext, typeId);
  231. const uint32_t constId = theContext.getResultIdForConstant(constant);
  232. theModule.addConstant(constant, constId);
  233. return constId;
  234. }
  235. #define IMPL_GET_PRIMITIVE_CONST(builderTy, cppTy) \
  236. \
  237. uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) { \
  238. const uint32_t typeId = get##builderTy##Type(); \
  239. const Constant *constant = \
  240. Constant::get##builderTy(theContext, typeId, value); \
  241. const uint32_t constId = theContext.getResultIdForConstant(constant); \
  242. theModule.addConstant(constant, constId); \
  243. return constId; \
  244. \
  245. }
  246. IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
  247. IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
  248. IMPL_GET_PRIMITIVE_CONST(Float32, float)
  249. #undef IMPL_GET_PRIMITIVE_VALUE
  250. uint32_t
  251. ModuleBuilder::getConstantComposite(uint32_t typeId,
  252. llvm::ArrayRef<uint32_t> constituents) {
  253. const Constant *constant =
  254. Constant::getComposite(theContext, typeId, constituents);
  255. const uint32_t constId = theContext.getResultIdForConstant(constant);
  256. theModule.addConstant(constant, constId);
  257. return constId;
  258. }
  259. } // end namespace spirv
  260. } // end namespace clang