  1. //===--- EmitSPIRVAction.cpp - EmitSPIRVAction implementation -------------===//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "clang/SPIRV/EmitSPIRVAction.h"
  10. #include "dxc/HlslIntrinsicOp.h"
  11. #include "clang/AST/AST.h"
  12. #include "clang/AST/ASTConsumer.h"
  13. #include "clang/AST/ASTContext.h"
  14. #include "clang/Basic/Diagnostic.h"
  15. #include "clang/Frontend/CompilerInstance.h"
  16. #include "clang/SPIRV/DeclResultIdMapper.h"
  17. #include "clang/SPIRV/ModuleBuilder.h"
  18. #include "clang/SPIRV/TypeTranslator.h"
  19. #include "llvm/ADT/STLExtras.h"
  20. #include "llvm/ADT/SetVector.h"
  21. namespace clang {
  22. namespace spirv {
  23. namespace {
  24. /// Returns true if the given type is a bool or vector of bool type.
  25. bool isBoolOrVecOfBoolType(QualType type) {
  26. return type->isBooleanType() ||
  27. (hlsl::IsHLSLVecType(type) &&
  28. hlsl::GetHLSLVecElementType(type)->isBooleanType());
  29. }
  30. /// Returns true if the given type is a signed integer or vector of signed
  31. /// integer type.
  32. bool isSintOrVecOfSintType(QualType type) {
  33. return type->isSignedIntegerType() ||
  34. (hlsl::IsHLSLVecType(type) &&
  35. hlsl::GetHLSLVecElementType(type)->isSignedIntegerType());
  36. }
  37. /// Returns true if the given type is an unsigned integer or vector of unsigned
  38. /// integer type.
  39. bool isUintOrVecOfUintType(QualType type) {
  40. return type->isUnsignedIntegerType() ||
  41. (hlsl::IsHLSLVecType(type) &&
  42. hlsl::GetHLSLVecElementType(type)->isUnsignedIntegerType());
  43. }
  44. /// Returns true if the given type is a float or vector of float type.
  45. bool isFloatOrVecOfFloatType(QualType type) {
  46. return type->isFloatingType() ||
  47. (hlsl::IsHLSLVecType(type) &&
  48. hlsl::GetHLSLVecElementType(type)->isFloatingType());
  49. }
  50. } // namespace
  51. /// SPIR-V emitter class. It consumes the HLSL AST and emits SPIR-V words.
  52. ///
  53. /// This class only overrides the HandleTranslationUnit() method; traversing
  54. /// through the AST is done manually instead of using ASTConsumer's harness.
  55. class SPIRVEmitter : public ASTConsumer {
  56. public:
  57. explicit SPIRVEmitter(CompilerInstance &ci)
  58. : theCompilerInstance(ci), astContext(ci.getASTContext()),
  59. diags(ci.getDiagnostics()),
  60. entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction),
  61. shaderStage(getSpirvShaderStageFromHlslProfile(
  62. ci.getCodeGenOpts().HLSLProfile.c_str())),
  63. theContext(), theBuilder(&theContext),
  64. declIdMapper(shaderStage, theBuilder, diags),
  65. typeTranslator(theBuilder, diags), entryFunctionId(0),
  66. curFunction(nullptr) {}
  67. spv::ExecutionModel getSpirvShaderStageFromHlslProfile(const char *profile) {
  68. assert(profile && "nullptr passed as HLSL profile.");
  69. // DXIL Models are:
  70. // Profile (DXIL Model) : HLSL Shader Kind : SPIR-V Shader Stage
  71. // vs_<version> : Vertex Shader : Vertex Shader
  72. // hs_<version> : Hull Shader : Tessellation Control Shader
  73. // ds_<version> : Domain Shader : Tessellation Evaluation Shader
  74. // gs_<version> : Geometry Shader : Geometry Shader
  75. // ps_<version> : Pixel Shader : Fragment Shader
  76. // cs_<version> : Compute Shader : Compute Shader
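  // For example, a profile string like "ps_6_0" begins with 'p' and therefore
  // maps to spv::ExecutionModel::Fragment in the switch below.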
  77. switch (profile[0]) {
  78. case 'v':
  79. return spv::ExecutionModel::Vertex;
  80. case 'h':
  81. return spv::ExecutionModel::TessellationControl;
  82. case 'd':
  83. return spv::ExecutionModel::TessellationEvaluation;
  84. case 'g':
  85. return spv::ExecutionModel::Geometry;
  86. case 'p':
  87. return spv::ExecutionModel::Fragment;
  88. case 'c':
  89. return spv::ExecutionModel::GLCompute;
  90. default:
  91. emitError("Unknown HLSL Profile: %0") << profile;
  92. return spv::ExecutionModel::Fragment;
  93. }
  94. }
  95. void AddRequiredCapabilitiesForExecutionModel(spv::ExecutionModel em) {
  96. if (em == spv::ExecutionModel::TessellationControl ||
  97. em == spv::ExecutionModel::TessellationEvaluation) {
  98. theBuilder.requireCapability(spv::Capability::Tessellation);
  99. emitError("Tasselation shaders are currently not supported.");
  100. } else if (em == spv::ExecutionModel::Geometry) {
  101. theBuilder.requireCapability(spv::Capability::Geometry);
  102. emitError("Geometry shaders are currently not supported.");
  103. } else {
  104. theBuilder.requireCapability(spv::Capability::Shader);
  105. }
  106. }
  107. /// \brief Adds the execution mode for the given entry point based on the
  108. /// execution model.
  109. void AddExecutionModeForEntryPoint(spv::ExecutionModel execModel,
  110. uint32_t entryPointId) {
  111. if (execModel == spv::ExecutionModel::Fragment) {
  112. // TODO: Implement the logic to determine the proper Execution Mode for
  113. // fragment shaders. Currently using OriginUpperLeft as default.
  114. theBuilder.addExecutionMode(entryPointId,
  115. spv::ExecutionMode::OriginUpperLeft, {});
  116. emitWarning("Execution mode for fragment shaders is "
  117. "currently set to OriginUpperLeft by default.");
  118. } else {
  119. emitWarning(
  120. "Execution mode is currently only defined for fragment shaders.");
  121. // TODO: Implement logic for adding proper execution mode for other
  122. // shader stages. Silently skipping for now.
  123. }
  124. }
  125. void HandleTranslationUnit(ASTContext &context) override {
  126. const spv::ExecutionModel em = getSpirvShaderStageFromHlslProfile(
  127. theCompilerInstance.getCodeGenOpts().HLSLProfile.c_str());
  128. AddRequiredCapabilitiesForExecutionModel(em);
  129. // Addressing and memory model are required in a valid SPIR-V module.
  130. theBuilder.setAddressingModel(spv::AddressingModel::Logical);
  131. theBuilder.setMemoryModel(spv::MemoryModel::GLSL450);
  132. TranslationUnitDecl *tu = context.getTranslationUnitDecl();
  133. // The entry function is the seed of the queue.
  134. for (auto *decl : tu->decls()) {
  135. if (auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
  136. if (funcDecl->getName() == entryFunctionName) {
  137. workQueue.insert(funcDecl);
  138. }
  139. }
  140. }
  141. // TODO: enlarge the queue upon seeing a function call.
  142. // Translate all functions reachable from the entry function.
  143. // The queue can grow in the meanwhile, so we need to keep re-evaluating
  144. // workQueue.size().
  145. for (uint32_t i = 0; i < workQueue.size(); ++i) {
  146. doDecl(workQueue[i]);
  147. }
  148. theBuilder.addEntryPoint(shaderStage, entryFunctionId, entryFunctionName,
  149. declIdMapper.collectStageVariables());
  150. AddExecutionModeForEntryPoint(shaderStage, entryFunctionId);
  151. // Add Location decorations to stage input/output variables.
  152. declIdMapper.finalizeStageIOLocations();
  153. // Output the constructed module.
  154. std::vector<uint32_t> m = theBuilder.takeModule();
  155. theCompilerInstance.getOutStream()->write(
  156. reinterpret_cast<const char *>(m.data()), m.size() * 4);
  157. }
  158. void doDecl(const Decl *decl) {
  159. if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
  160. doVarDecl(varDecl);
  161. } else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
  162. doFunctionDecl(funcDecl);
  163. } else {
  164. // TODO: Implement handling of other Decl types.
  165. emitWarning("Decl type '%0' is not supported yet.")
  166. << std::string(decl->getDeclKindName());
  167. }
  168. }
  169. void doFunctionDecl(const FunctionDecl *decl) {
  170. curFunction = decl;
  171. const llvm::StringRef funcName = decl->getName();
  172. uint32_t funcId;
  173. if (funcName == entryFunctionName) {
  174. // First create stage variables for the entry point.
  175. declIdMapper.createStageVarFromFnReturn(decl);
  176. for (const auto *param : decl->params())
  177. declIdMapper.createStageVarFromFnParam(param);
  178. // Construct the function signature.
  179. const uint32_t voidType = theBuilder.getVoidType();
  180. const uint32_t funcType = theBuilder.getFunctionType(voidType, {});
  181. // Unlike other functions, which are added to the work queue (and get
  182. // <result-id>s assigned) when their call sites are translated, the entry
  183. // function does not have a pre-assigned <result-id>.
  184. funcId = theBuilder.beginFunction(funcType, voidType, funcName);
  185. // Record the entry function's <result-id>.
  186. entryFunctionId = funcId;
  187. } else {
  188. const uint32_t retType =
  189. typeTranslator.translateType(decl->getReturnType());
  190. // Construct the function signature.
  191. llvm::SmallVector<uint32_t, 4> paramTypes;
  192. for (const auto *param : decl->params()) {
  193. const uint32_t valueType =
  194. typeTranslator.translateType(param->getType());
  195. const uint32_t ptrType =
  196. theBuilder.getPointerType(valueType, spv::StorageClass::Function);
  197. paramTypes.push_back(ptrType);
  198. }
  199. const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
  200. // Non-entry functions are added to the work queue following function
  201. // calls. A <result-id> has already been assigned for this function when
  202. // translating its call site. Query it here.
  203. funcId = declIdMapper.getDeclResultId(decl);
  204. theBuilder.beginFunction(funcType, retType, funcName, funcId);
  205. // Create all parameters.
  206. for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
  207. const ParmVarDecl *paramDecl = decl->getParamDecl(i);
  208. const uint32_t paramId =
  209. theBuilder.addFnParameter(paramTypes[i], paramDecl->getName());
  210. declIdMapper.registerDeclResultId(paramDecl, paramId);
  211. }
  212. }
  213. if (decl->hasBody()) {
  214. // The entry basic block.
  215. const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry");
  216. theBuilder.setInsertPoint(entryLabel);
  217. // Process all statements in the body.
  218. doStmt(decl->getBody());
  219. // We have processed all Stmts in this function and are now in its last
  220. // basic block. Add an OpReturn if the block is not already terminated.
  221. if (!theBuilder.isCurrentBasicBlockTerminated()) {
  222. theBuilder.createReturn();
  223. }
  224. }
  225. theBuilder.endFunction();
  226. curFunction = nullptr;
  227. }
  228. void doVarDecl(const VarDecl *decl) {
  229. if (decl->isLocalVarDecl()) {
  230. const uint32_t ptrType = theBuilder.getPointerType(
  231. typeTranslator.translateType(decl->getType()),
  232. spv::StorageClass::Function);
  233. // Handle initializer. SPIR-V requires that "initializer must be an <id>
  234. // from a constant instruction or a global (module scope) OpVariable
  235. // instruction."
  236. llvm::Optional<uint32_t> constInitializer = llvm::None;
  237. uint32_t varInitializer = 0;
  238. if (decl->hasInit()) {
  239. const Expr *declInit = decl->getInit();
  240. // First try to evaluate the initializer as a constant expression
  241. Expr::EvalResult evalResult;
  242. if (declInit->EvaluateAsRValue(evalResult, astContext) &&
  243. !evalResult.HasSideEffects) {
  244. constInitializer = llvm::Optional<uint32_t>(
  245. translateAPValue(evalResult.Val, decl->getType()));
  246. }
  247. // If we cannot evaluate the initializer as a constant expression, we'll
  248. // need to use OpStore to write the initializer to the variable.
  249. if (!constInitializer.hasValue()) {
  250. varInitializer = doExpr(declInit);
  251. }
  252. }
  253. const uint32_t varId =
  254. theBuilder.addFnVariable(ptrType, decl->getName(), constInitializer);
  255. declIdMapper.registerDeclResultId(decl, varId);
  256. if (varInitializer) {
  257. theBuilder.createStore(varId, varInitializer);
  258. }
  259. } else {
  260. // TODO: handle global variables
  261. emitError("Global variables are not supported yet.");
  262. }
  263. }
  264. void doStmt(const Stmt *stmt) {
  265. if (const auto *compoundStmt = dyn_cast<CompoundStmt>(stmt)) {
  266. for (auto *st : compoundStmt->body())
  267. doStmt(st);
  268. } else if (const auto *retStmt = dyn_cast<ReturnStmt>(stmt)) {
  269. doReturnStmt(retStmt);
  270. } else if (const auto *declStmt = dyn_cast<DeclStmt>(stmt)) {
  271. for (auto *decl : declStmt->decls()) {
  272. doDecl(decl);
  273. }
  274. } else if (const auto *ifStmt = dyn_cast<IfStmt>(stmt)) {
  275. doIfStmt(ifStmt);
  276. } else if (const auto *forStmt = dyn_cast<ForStmt>(stmt)) {
  277. doForStmt(forStmt);
  278. } else if (const auto *nullStmt = dyn_cast<NullStmt>(stmt)) {
  279. // For the null statement ";", we don't need to do anything.
  280. } else if (const auto *expr = dyn_cast<Expr>(stmt)) {
  281. // All cases for expressions used as statements
  282. doExpr(expr);
  283. } else {
  284. emitError("Stmt '%0' is not supported yet.") << stmt->getStmtClassName();
  285. }
  286. }
  287. void doReturnStmt(const ReturnStmt *stmt) {
  288. // For normal functions, just return in the normal way.
  289. if (curFunction->getName() != entryFunctionName) {
  290. theBuilder.createReturnValue(doExpr(stmt->getRetValue()));
  291. return;
  292. }
  293. // SPIR-V requires the signature of entry functions to be void(), while
  294. // in HLSL we can have non-void parameter and return types for entry points.
  295. // So we should treat the ReturnStmt in entry functions specially.
  296. //
  297. // We need to walk through the return type and, for each subtype that has
  298. // a semantic attached, write the value out to the corresponding stage
  299. // variable mapped to that semantic.
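  // For example (illustrative HLSL), for an entry point declared as
  //   struct PSOutput { float4 color : SV_Target; };
  //   PSOutput main() { ... }
  // the returned struct is not returned directly; instead each field (here
  // 'color') is stored into the stage output variable created for its
  // semantic, and the SPIR-V entry function itself returns void.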
  300. const uint32_t stageVarId =
  301. declIdMapper.getRemappedDeclResultId(curFunction);
  302. if (stageVarId) {
  303. // The return value is mapped to a single stage variable. We just need
  304. // to store the value into the stage variable instead.
  305. theBuilder.createStore(stageVarId, doExpr(stmt->getRetValue()));
  306. theBuilder.createReturn();
  307. return;
  308. }
  309. QualType retType = stmt->getRetValue()->getType();
  310. if (const auto *structType = retType->getAsStructureType()) {
  311. // We are trying to return a value of struct type.
  312. // First get the return value. Clang AST will use an LValueToRValue cast
  313. // for returning a struct variable. We need to ignore the cast to avoid
  314. // creating an OpLoad instruction since we need the pointer to the
  315. // variable for creating an access chain later.
  316. const uint32_t retValue =
  317. doExpr(stmt->getRetValue()->IgnoreParenLValueCasts());
  318. // Then go through all fields.
  319. uint32_t fieldIndex = 0;
  320. for (const auto *field : structType->getDecl()->fields()) {
  321. // Load the value from the current field.
  322. const uint32_t valueType =
  323. typeTranslator.translateType(field->getType());
  324. // TODO: We may need to change the storage class accordingly.
  325. const uint32_t ptrType = theBuilder.getPointerType(
  326. typeTranslator.translateType(field->getType()),
  327. spv::StorageClass::Function);
  328. const uint32_t indexId = theBuilder.getConstantInt32(fieldIndex++);
  329. const uint32_t valuePtr =
  330. theBuilder.createAccessChain(ptrType, retValue, {indexId});
  331. const uint32_t value = theBuilder.createLoad(valueType, valuePtr);
  332. // Store it to the corresponding stage variable.
  333. const uint32_t targetVar = declIdMapper.getDeclResultId(field);
  334. theBuilder.createStore(targetVar, value);
  335. }
  336. } else {
  337. emitError("Return type '%0' for entry function is not supported yet.")
  338. << retType->getTypeClassName();
  339. }
  340. }
  341. void doIfStmt(const IfStmt *ifStmt) {
  342. // if statements are composed of:
  343. // if (<check>) { <then> } else { <else> }
  344. //
  345. // To translate if statements, we'll need to emit the <check> expressions
  346. // in the current basic block, and then create separate basic blocks for
  347. // <then> and <else>. Additionally, we'll need a <merge> block as per
  348. // SPIR-V's structured control flow requirements. Depending on whether the
  349. // else branch exists, the final CFG should normally be like the
  350. // following. Exceptions will occur with non-local exits like loop breaks
  351. // or early returns.
  352. // +-------+ +-------+
  353. // | check | | check |
  354. // +-------+ +-------+
  355. // | |
  356. // +-------+-------+ +-----+-----+
  357. // | true | false | true | false
  358. // v v or v |
  359. // +------+ +------+ +------+ |
  360. // | then | | else | | then | |
  361. // +------+ +------+ +------+ |
  362. // | | | v
  363. // | +-------+ | | +-------+
  364. // +-> | merge | <-+ +---> | merge |
  365. // +-------+ +-------+
  366. // First emit the instruction for evaluating the condition.
  367. const uint32_t condition = doExpr(ifStmt->getCond());
  368. // Then we need to emit the instruction for the conditional branch.
  369. // We'll need the <label-id> for the then/else/merge block to do so.
  370. const bool hasElse = ifStmt->getElse() != nullptr;
  371. const uint32_t thenBB = theBuilder.createBasicBlock("if.true");
  372. const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge");
  373. const uint32_t elseBB =
  374. hasElse ? theBuilder.createBasicBlock("if.false") : mergeBB;
  375. // Create the branch instruction. This will end the current basic block.
  376. theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB);
  377. theBuilder.addSuccessor(thenBB);
  378. theBuilder.addSuccessor(elseBB);
  379. // The current basic block has the OpSelectionMerge instruction. We need
  380. // to record its merge target.
  381. theBuilder.setMergeTarget(mergeBB);
  382. // Handle the then branch
  383. theBuilder.setInsertPoint(thenBB);
  384. doStmt(ifStmt->getThen());
  385. if (!theBuilder.isCurrentBasicBlockTerminated())
  386. theBuilder.createBranch(mergeBB);
  387. theBuilder.addSuccessor(mergeBB);
  388. // Handle the else branch (if it exists)
  389. if (hasElse) {
  390. theBuilder.setInsertPoint(elseBB);
  391. doStmt(ifStmt->getElse());
  392. if (!theBuilder.isCurrentBasicBlockTerminated())
  393. theBuilder.createBranch(mergeBB);
  394. theBuilder.addSuccessor(mergeBB);
  395. }
  396. // From now on, we'll emit instructions into the merge block.
  397. theBuilder.setInsertPoint(mergeBB);
  398. }
  399. void doForStmt(const ForStmt *forStmt) {
  400. // for loops are composed of:
  401. // for (<init>; <check>; <continue>) <body>
  402. //
  403. // To translate a for loop, we'll need to emit all <init> statements
  404. // in the current basic block, and then have separate basic blocks for
  405. // <check>, <continue>, and <body>. Besides, since SPIR-V requires
  406. // structured control flow, we need two more basic blocks, <header>
  407. // and <merge>. <header> is the block before control flow diverges,
  408. // while <merge> is the block where control flow subsequently converges.
  409. // The <check> block can take the responsibility of the <header> block.
  410. // The final CFG should normally be like the following. Exceptions will
  411. // occur with non-local exits like loop breaks or early returns.
  412. // +--------+
  413. // | init |
  414. // +--------+
  415. // |
  416. // v
  417. // +----------+
  418. // | header | <---------------+
  419. // | (check) | |
  420. // +----------+ |
  421. // | |
  422. // +-------+-------+ |
  423. // | false | true |
  424. // | v |
  425. // | +------+ +----------+
  426. // | | body | --> | continue |
  427. // v +------+ +----------+
  428. // +-------+
  429. // | merge |
  430. // +-------+
  431. //
  432. // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  433. // Create basic blocks
  434. const uint32_t checkBB = theBuilder.createBasicBlock("for.check");
  435. const uint32_t bodyBB = theBuilder.createBasicBlock("for.body");
  436. const uint32_t continueBB = theBuilder.createBasicBlock("for.continue");
  437. const uint32_t mergeBB = theBuilder.createBasicBlock("for.merge");
  438. // Process the <init> block
  439. if (const Stmt *initStmt = forStmt->getInit()) {
  440. doStmt(initStmt);
  441. }
  442. theBuilder.createBranch(checkBB);
  443. theBuilder.addSuccessor(checkBB);
  444. // Process the <check> block
  445. theBuilder.setInsertPoint(checkBB);
  446. uint32_t condition;
  447. if (const Expr *check = forStmt->getCond()) {
  448. condition = doExpr(check);
  449. } else {
  450. condition = theBuilder.getConstantBool(true);
  451. }
  452. theBuilder.createConditionalBranch(condition, bodyBB,
  453. /*false branch*/ mergeBB,
  454. /*merge*/ mergeBB, continueBB);
  455. theBuilder.addSuccessor(bodyBB);
  456. theBuilder.addSuccessor(mergeBB);
  457. // The current basic block has the OpLoopMerge instruction. We need to set
  458. // its continue and merge targets.
  459. theBuilder.setContinueTarget(continueBB);
  460. theBuilder.setMergeTarget(mergeBB);
  461. // Process the <body> block
  462. theBuilder.setInsertPoint(bodyBB);
  463. if (const Stmt *body = forStmt->getBody()) {
  464. doStmt(body);
  465. }
  466. theBuilder.createBranch(continueBB);
  467. theBuilder.addSuccessor(continueBB);
  468. // Process the <continue> block
  469. theBuilder.setInsertPoint(continueBB);
  470. if (const Expr *cont = forStmt->getInc()) {
  471. doExpr(cont);
  472. }
  473. theBuilder.createBranch(checkBB); // <continue> should jump back to header
  474. theBuilder.addSuccessor(checkBB);
  475. // Set insertion point to the <merge> block for subsequent statements
  476. theBuilder.setInsertPoint(mergeBB);
  477. }
  478. uint32_t doExpr(const Expr *expr) {
  479. if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
  480. // Returns the <result-id> of the referenced Decl.
  481. const NamedDecl *referredDecl = declRefExpr->getFoundDecl();
  482. assert(referredDecl && "found non-NamedDecl referenced");
  483. return declIdMapper.getDeclResultId(referredDecl);
  484. }
  485. if (const auto *parenExpr = dyn_cast<ParenExpr>(expr)) {
  486. // Just need to return what's inside the parentheses.
  487. return doExpr(parenExpr->getSubExpr());
  488. }
  489. if (const auto *memberExpr = dyn_cast<MemberExpr>(expr)) {
  490. const uint32_t base = doExpr(memberExpr->getBase());
  491. const auto *memberDecl = memberExpr->getMemberDecl();
  492. if (const auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
  493. const auto index =
  494. theBuilder.getConstantInt32(fieldDecl->getFieldIndex());
  495. const uint32_t fieldType =
  496. typeTranslator.translateType(fieldDecl->getType());
  497. const uint32_t ptrType =
  498. theBuilder.getPointerType(fieldType, spv::StorageClass::Function);
  499. return theBuilder.createAccessChain(ptrType, base, {index});
  500. } else {
  501. emitError("Decl '%0' in MemberExpr is not supported yet.")
  502. << memberDecl->getDeclKindName();
  503. return 0;
  504. }
  505. }
  506. if (const auto *castExpr = dyn_cast<CastExpr>(expr)) {
  507. return doCastExpr(castExpr);
  508. }
  509. if (const auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
  510. return doInitListExpr(initListExpr);
  511. }
  512. if (const auto *boolLiteral = dyn_cast<CXXBoolLiteralExpr>(expr)) {
  513. const bool value = boolLiteral->getValue();
  514. return theBuilder.getConstantBool(value);
  515. }
  516. if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
  517. return translateAPInt(intLiteral->getValue(), expr->getType());
  518. }
  519. if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
  520. return translateAPFloat(floatLiteral->getValue(), expr->getType());
  521. }
  522. // CompoundAssignOperator is a subclass of BinaryOperator. It should be
  523. // checked before BinaryOperator.
  524. if (const auto *compoundAssignOp = dyn_cast<CompoundAssignOperator>(expr)) {
  525. return doCompoundAssignOperator(compoundAssignOp);
  526. }
  527. if (const auto *binOp = dyn_cast<BinaryOperator>(expr)) {
  528. return doBinaryOperator(binOp);
  529. }
  530. if (const auto *unaryOp = dyn_cast<UnaryOperator>(expr)) {
  531. return doUnaryOperator(unaryOp);
  532. }
  533. if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
  534. return doCallExpr(funcCall);
  535. }
  536. if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
  537. return doConditionalOperator(condExpr);
  538. }
  539. emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
  540. // TODO: handle other expressions
  541. return 0;
  542. }
  543. uint32_t doInitListExpr(const InitListExpr *expr) {
  544. // First try to evaluate the expression as a constant expression
  545. Expr::EvalResult evalResult;
  546. if (expr->EvaluateAsRValue(evalResult, astContext) &&
  547. !evalResult.HasSideEffects) {
  548. return translateAPValue(evalResult.Val, expr->getType());
  549. }
  550. const QualType type = expr->getType();
  551. // InitListExpr is tricky to handle. It can have initializers of different
  552. // types, and each initializer can itself be of a composite type.
  553. // The front-end parsing only guarantees that the total number of elements
  554. // in the initializers matches that of the InitListExpr's type.
  555. // For builtin types, we can assume the front end parsing has injected
  556. // the necessary ImplicitCastExpr for type casting. So we just need to
  557. // return the result of processing the only initializer.
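  // For example (illustrative), 'float3 v = {1.0, someFloat2};' provides three
  // scalar elements in total, matching the three elements of float3.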
  558. if (type->isBuiltinType()) {
  559. assert(expr->getNumInits() == 1);
  560. return doExpr(expr->getInit(0));
  561. }
  562. // For composite types, we need to type cast the initializers if necessary.
  563. const auto initCount = expr->getNumInits();
  564. const uint32_t resultType = typeTranslator.translateType(type);
  565. // For an InitListExpr of vector type with a single initializer, we can
  566. // avoid composite extraction and construction.
  567. if (initCount == 1 && hlsl::IsHLSLVecType(type)) {
  568. const Expr *init = expr->getInit(0);
  569. // If the initializer already has the correct type, we don't need to
  570. // type cast.
  571. if (init->getType() == type) {
  572. return doExpr(init);
  573. }
  574. // For the rest, we can do the type cast as a whole.
  575. const auto targetElemType = hlsl::GetHLSLVecElementType(type);
  576. if (targetElemType->isBooleanType()) {
  577. return castToBool(init, type);
  578. } else if (targetElemType->isIntegerType()) {
  579. return castToInt(init, type);
  580. } else if (targetElemType->isFloatingType()) {
  581. return castToFloat(init, type);
  582. } else {
  583. emitError("unimplemented vector InitList cases");
  584. expr->dump();
  585. return 0;
  586. }
  587. }
  588. // Cases needing composite extraction and construction
  589. std::vector<uint32_t> constituents;
  590. for (size_t i = 0; i < initCount; ++i) {
  591. const Expr *init = expr->getInit(i);
  592. if (!init->getType()->isBuiltinType()) {
  593. emitError("unimplemented InitList initializer type");
  594. init->dump();
  595. return 0;
  596. }
  597. constituents.push_back(doExpr(init));
  598. }
  599. return theBuilder.createCompositeConstruct(resultType, constituents);
  600. }
  601. uint32_t doBinaryOperator(const BinaryOperator *expr) {
  602. const auto opcode = expr->getOpcode();
  603. // Handle assignment first since we need to evaluate rhs before lhs.
  604. // For other binary operations, we need to evaluate lhs before rhs.
  605. if (opcode == BO_Assign) {
  606. const uint32_t rhs = doExpr(expr->getRHS());
  607. const uint32_t lhs = doExpr(expr->getLHS());
  608. theBuilder.createStore(lhs, rhs);
  609. // Assignment returns an rvalue.
  610. return rhs;
  611. }
  612. // Try to optimize floatN * float case
  613. if (opcode == BO_Mul) {
  614. if (const uint32_t result = tryToGenFloatVectorScale(expr))
  615. return result;
  616. }
  617. const uint32_t lhs = doExpr(expr->getLHS());
  618. const uint32_t rhs = doExpr(expr->getRHS());
  619. const uint32_t typeId = typeTranslator.translateType(expr->getType());
  620. const QualType elemType = expr->getLHS()->getType();
  621. switch (opcode) {
  622. case BO_Add:
  623. case BO_Sub:
  624. case BO_Mul:
  625. case BO_Div:
  626. case BO_Rem:
  627. case BO_LT:
  628. case BO_LE:
  629. case BO_GT:
  630. case BO_GE:
  631. case BO_EQ:
  632. case BO_NE:
  633. case BO_And:
  634. case BO_Or:
  635. case BO_Xor:
  636. case BO_Shl:
  637. case BO_Shr:
  638. case BO_LAnd:
  639. case BO_LOr: {
  640. const spv::Op spvOp = translateOp(opcode, elemType);
  641. return theBuilder.createBinaryOp(spvOp, typeId, lhs, rhs);
  642. }
  643. case BO_Assign: {
  644. llvm_unreachable("assignment already handled before");
  645. } break;
  646. default:
  647. break;
  648. }
  649. emitError("BinaryOperator '%0' is not supported yet.")
  650. << expr->getOpcodeStr(opcode);
  651. expr->dump();
  652. return 0;
  653. }
  654. uint32_t doCompoundAssignOperator(const CompoundAssignOperator *expr) {
  655. const auto opcode = expr->getOpcode();
  656. // Try to optimize floatN *= float case
  657. if (opcode == BO_MulAssign) {
  658. if (const uint32_t result = tryToGenFloatVectorScale(expr))
  659. return result;
  660. }
  661. const auto *rhs = expr->getRHS();
  662. const auto *lhs = expr->getLHS();
  663. switch (opcode) {
  664. case BO_AddAssign:
  665. case BO_SubAssign:
  666. case BO_MulAssign:
  667. case BO_DivAssign:
  668. case BO_RemAssign:
  669. case BO_AndAssign:
  670. case BO_OrAssign:
  671. case BO_XorAssign:
  672. case BO_ShlAssign:
  673. case BO_ShrAssign: {
  674. const uint32_t resultType = typeTranslator.translateType(expr->getType());
  675. // Evaluate rhs before lhs
  676. const uint32_t rhsVal = doExpr(rhs);
  677. const uint32_t lhsPtr = doExpr(lhs);
  678. const uint32_t lhsVal = theBuilder.createLoad(resultType, lhsPtr);
  679. const spv::Op spvOp = translateOp(opcode, expr->getType());
  680. const uint32_t result =
  681. theBuilder.createBinaryOp(spvOp, resultType, lhsVal, rhsVal);
  682. theBuilder.createStore(lhsPtr, result);
  683. // Compound assign operators return lvalues.
  684. return lhsPtr;
  685. }
  686. default:
  687. emitError("CompoundAssignOperator '%0' unimplemented")
  688. << expr->getOpcodeStr(opcode);
  689. return 0;
  690. }
  691. }
  692. uint32_t doUnaryOperator(const UnaryOperator *expr) {
  693. const auto opcode = expr->getOpcode();
  694. const auto *subExpr = expr->getSubExpr();
  695. const auto subType = subExpr->getType();
  696. const auto subValue = doExpr(subExpr);
  697. const auto subTypeId = typeTranslator.translateType(subType);
  698. switch (opcode) {
  699. case UO_PreInc:
  700. case UO_PreDec:
  701. case UO_PostInc:
  702. case UO_PostDec: {
  703. const bool isPre = opcode == UO_PreInc || opcode == UO_PreDec;
  704. const bool isInc = opcode == UO_PreInc || opcode == UO_PostInc;
  705. const spv::Op spvOp = translateOp(isInc ? BO_Add : BO_Sub, subType);
  706. const uint32_t one = getValueOne(subType);
  707. const uint32_t originValue = theBuilder.createLoad(subTypeId, subValue);
  708. const uint32_t incValue =
  709. theBuilder.createBinaryOp(spvOp, subTypeId, originValue, one);
  710. theBuilder.createStore(subValue, incValue);
  711. // The prefix increment/decrement operator returns an lvalue, while the
  712. // postfix one returns an rvalue.
  713. return isPre ? subValue : originValue;
  714. }
  715. case UO_Not:
  716. return theBuilder.createUnaryOp(spv::Op::OpNot, subTypeId, subValue);
  717. case UO_LNot:
  718. // Parsing will do the necessary casting to make sure we are applying the
  719. // ! operator on boolean values.
  720. return theBuilder.createUnaryOp(spv::Op::OpLogicalNot, subTypeId,
  721. subValue);
  722. case UO_Plus:
  723. // No need to do anything for the prefix + operator.
  724. return subValue;
  725. case UO_Minus: {
  726. // SPIR-V has two opcodes for negating values: OpSNegate and OpFNegate.
  727. const spv::Op spvOp = isFloatOrVecOfFloatType(subType)
  728. ? spv::Op::OpFNegate
  729. : spv::Op::OpSNegate;
  730. return theBuilder.createUnaryOp(spvOp, subTypeId, subValue);
  731. }
  732. default:
  733. break;
  734. }
  735. emitError("unary operator '%0' unimplemented yet")
  736. << expr->getOpcodeStr(opcode);
  737. expr->dump();
  738. return 0;
  739. }
  740. uint32_t doCastExpr(const CastExpr *expr) {
  741. const Expr *subExpr = expr->getSubExpr();
  742. const QualType toType = expr->getType();
  743. switch (expr->getCastKind()) {
  744. case CastKind::CK_LValueToRValue: {
  745. const uint32_t fromValue = doExpr(subExpr);
  746. // Using an lvalue as an rvalue means we need to OpLoad the contents from
  747. // the parameter/variable first.
  748. const uint32_t resultType = typeTranslator.translateType(toType);
  749. return theBuilder.createLoad(resultType, fromValue);
  750. }
  751. case CastKind::CK_NoOp:
  752. return doExpr(subExpr);
  753. case CastKind::CK_IntegralCast:
  754. case CastKind::CK_FloatingToIntegral:
  755. case CastKind::CK_HLSLCC_IntegralCast:
  756. case CastKind::CK_HLSLCC_FloatingToIntegral: {
  757. // Integer literals in the AST are represented using 64-bit APInt values
  758. // and are then implicitly cast into the expected bitwidth.
  759. // We need special treatment of integer literals here because generating
  760. // a 64-bit constant and then explicitly casting it in SPIR-V requires the
  761. // Int64 capability. We should avoid introducing unnecessary capabilities
  762. // whenever possible.
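  // For example, for 'int x = 5;' the literal 5 is a 64-bit APInt in the AST;
  // evaluating it here lets us directly emit a 32-bit constant for it.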
  763. llvm::APSInt intValue;
  764. if (expr->EvaluateAsInt(intValue, astContext, Expr::SE_NoSideEffects)) {
  765. return translateAPInt(intValue, toType);
  766. }
  767. return castToInt(subExpr, toType);
  768. }
  769. case CastKind::CK_FloatingCast:
  770. case CastKind::CK_IntegralToFloating:
  771. case CastKind::CK_HLSLCC_FloatingCast:
  772. case CastKind::CK_HLSLCC_IntegralToFloating: {
  773. // First try to see if we can do constant folding for floating point
  774. // numbers, like what we do for integers above.
  775. Expr::EvalResult evalResult;
  776. if (expr->EvaluateAsRValue(evalResult, astContext) &&
  777. !evalResult.HasSideEffects) {
  778. return translateAPFloat(evalResult.Val.getFloat(), toType);
  779. }
  780. return castToFloat(subExpr, toType);
  781. }
  782. case CastKind::CK_IntegralToBoolean:
  783. case CastKind::CK_FloatingToBoolean:
  784. case CastKind::CK_HLSLCC_IntegralToBoolean:
  785. case CastKind::CK_HLSLCC_FloatingToBoolean: {
  786. // First try to see if we can do constant folding.
  787. bool boolVal;
  788. if (!expr->HasSideEffects(astContext) &&
  789. expr->EvaluateAsBooleanCondition(boolVal, astContext)) {
  790. return theBuilder.getConstantBool(boolVal);
  791. }
  792. return castToBool(subExpr, toType);
  793. }
  794. case CastKind::CK_HLSLVectorSplat: {
  795. const size_t size = hlsl::GetHLSLVecSize(expr->getType());
  796. const uint32_t scalarValue = doExpr(subExpr);
  797. // Just return the scalar value for vector splat with size 1
  798. if (size == 1) {
  799. return scalarValue;
  800. }
  801. const uint32_t vecTypeId = typeTranslator.translateType(toType);
  802. llvm::SmallVector<uint32_t, 4> elements(size, scalarValue);
  803. return theBuilder.createCompositeConstruct(vecTypeId, elements);
  804. }
  805. case CastKind::CK_HLSLVectorTruncationCast: {
  806. const uint32_t toVecTypeId = typeTranslator.translateType(toType);
  807. const uint32_t elemTypeId =
  808. typeTranslator.translateType(hlsl::GetHLSLVecElementType(toType));
  809. const auto toSize = hlsl::GetHLSLVecSize(toType);
  810. const uint32_t composite = doExpr(subExpr);
  811. llvm::SmallVector<uint32_t, 4> elements;
  812. for (uint32_t i = 0; i < toSize; ++i) {
  813. elements.push_back(
  814. theBuilder.createCompositeExtract(elemTypeId, composite, {i}));
  815. }
  816. if (toSize == 1) {
  817. return elements.front();
  818. }
  819. return theBuilder.createCompositeConstruct(toVecTypeId, elements);
  820. }
  821. case CastKind::CK_HLSLVectorToScalarCast: {
  822. // The underlying value should already be a vector of size 1.
  823. assert(hlsl::GetHLSLVecSize(subExpr->getType()) == 1);
  824. return doExpr(subExpr);
  825. }
  826. case CastKind::CK_FunctionToPointerDecay:
  827. // Just need to return the function id
  828. return doExpr(subExpr);
  829. default:
  830. emitError("ImplictCast Kind '%0' is not supported yet.")
  831. << expr->getCastKindName();
  832. expr->dump();
  833. return 0;
  834. }
  835. }
  836. uint32_t processIntrinsicDot(const CallExpr *callExpr) {
  837. const uint32_t returnType =
  838. typeTranslator.translateType(callExpr->getType());
  839. // Get the function parameters. Expect 2 vectors as parameters.
  840. assert(callExpr->getNumArgs() == 2u);
  841. const Expr *arg0 = callExpr->getArg(0);
  842. const Expr *arg1 = callExpr->getArg(1);
  843. const uint32_t arg0Id = doExpr(arg0);
  844. const uint32_t arg1Id = doExpr(arg1);
  845. QualType arg0Type = arg0->getType();
  846. QualType arg1Type = arg1->getType();
  847. const size_t vec0Size = hlsl::GetHLSLVecSize(arg0Type);
  848. const size_t vec1Size = hlsl::GetHLSLVecSize(arg1Type);
  849. const QualType vec0ComponentType = hlsl::GetHLSLVecElementType(arg0Type);
  850. const QualType vec1ComponentType = hlsl::GetHLSLVecElementType(arg1Type);
  851. assert(callExpr->getType() == vec1ComponentType);
  852. assert(vec0ComponentType == vec1ComponentType);
  853. assert(vec0Size == vec1Size);
  854. assert(vec0Size >= 1 && vec0Size <= 4);
  855. // According to the HLSL reference, the dot function only works on integers
  856. // and floats.
  857. const auto returnTypeBuiltinKind =
  858. cast<BuiltinType>(callExpr->getType().getTypePtr())->getKind();
  859. assert(returnTypeBuiltinKind == BuiltinType::Float ||
  860. returnTypeBuiltinKind == BuiltinType::Int ||
  861. returnTypeBuiltinKind == BuiltinType::UInt);
  862. // Special case: dot product of two vectors, each of size 1. That is
  863. // basically the same as regular multiplication of 2 scalars.
  864. if (vec0Size == 1) {
  865. const spv::Op spvOp = translateOp(BO_Mul, arg0Type);
  866. return theBuilder.createBinaryOp(spvOp, returnType, arg0Id, arg1Id);
  867. }
  868. // If the vectors are of type Float, we can use OpDot.
  869. if (returnTypeBuiltinKind == BuiltinType::Float) {
  870. return theBuilder.createBinaryOp(spv::Op::OpDot, returnType, arg0Id,
  871. arg1Id);
  872. }
  873. // Vector component type is Integer (signed or unsigned).
  874. // Create all instructions necessary to perform a dot product on
  875. // two integer vectors. SPIR-V OpDot does not support integer vectors.
  876. // Therefore, we use other SPIR-V instructions (addition and
  877. // multiplication).
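  // For example, dot(a, b) on two int3 vectors is expanded below into
  // a[0]*b[0] + a[1]*b[1] + a[2]*b[2] using OpIMul and OpIAdd on the
  // extracted components.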
  878. else {
  879. uint32_t result = 0;
  880. llvm::SmallVector<uint32_t, 4> multIds;
  881. const spv::Op multSpvOp = translateOp(BO_Mul, arg0Type);
  882. const spv::Op addSpvOp = translateOp(BO_Add, arg0Type);
  883. // Extract members from the two vectors and multiply them.
  884. for (unsigned int i = 0; i < vec0Size; ++i) {
  885. const uint32_t vec0member =
  886. theBuilder.createCompositeExtract(returnType, arg0Id, {i});
  887. const uint32_t vec1member =
  888. theBuilder.createCompositeExtract(returnType, arg1Id, {i});
  889. const uint32_t multId = theBuilder.createBinaryOp(
  890. multSpvOp, returnType, vec0member, vec1member);
  891. multIds.push_back(multId);
  892. }
  893. // Add all the multiplications.
  894. result = multIds[0];
  895. for (unsigned int i = 1; i < vec0Size; ++i) {
  896. const uint32_t additionId =
  897. theBuilder.createBinaryOp(addSpvOp, returnType, result, multIds[i]);
  898. result = additionId;
  899. }
  900. return result;
  901. }
  902. }
  903. /// Processes the given expr, casts the result into the given bool (vector)
  904. /// type and returns the <result-id> of the cast value.
  905. uint32_t castToBool(const Expr *expr, QualType toBoolType) {
  906. // Converting to bool means comparing with value zero.
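  // For example, casting an int to bool roughly becomes an OpINotEqual
  // comparison against a zero constant (OpFOrdNotEqual for floats).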
  907. const uint32_t fromVal = doExpr(expr);
  908. if (isBoolOrVecOfBoolType(expr->getType()))
  909. return fromVal;
  910. const spv::Op spvOp = translateOp(BO_NE, expr->getType());
  911. const uint32_t boolType = typeTranslator.translateType(toBoolType);
  912. const uint32_t zeroVal = getValueZero(expr->getType());
  913. return theBuilder.createBinaryOp(spvOp, boolType, fromVal, zeroVal);
  914. }
  915. /// Processes the given expr, casts the result into the given integer (vector)
  916. /// type and returns the <result-id> of the cast value.
  917. uint32_t castToInt(const Expr *expr, QualType toIntType) {
  918. const QualType fromType = expr->getType();
  919. const uint32_t intType = typeTranslator.translateType(toIntType);
  920. const uint32_t fromVal = doExpr(expr);
  921. if (isBoolOrVecOfBoolType(fromType)) {
  922. const uint32_t one = getValueOne(toIntType);
  923. const uint32_t zero = getValueZero(toIntType);
  924. return theBuilder.createSelect(intType, fromVal, one, zero);
  925. } else if (isSintOrVecOfSintType(fromType) ||
  926. isUintOrVecOfUintType(fromType)) {
  927. if (fromType == toIntType)
  928. return fromVal;
  929. // TODO: handle different bitwidths
  930. return theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, fromVal);
  931. } else if (isFloatOrVecOfFloatType(fromType)) {
  932. if (isSintOrVecOfSintType(toIntType)) {
  933. return theBuilder.createUnaryOp(spv::Op::OpConvertFToS, intType,
  934. fromVal);
  935. } else if (isUintOrVecOfUintType(toIntType)) {
  936. return theBuilder.createUnaryOp(spv::Op::OpConvertFToU, intType,
  937. fromVal);
  938. } else {
  939. emitError("unimplemented casting to integer from floating point");
  940. }
  941. } else {
  942. emitError("unimplemented casting to integer");
  943. }
  944. expr->dump();
  945. return 0;
  946. }
  947. uint32_t processIntrinsicCallExpr(const CallExpr *callExpr) {
  948. const FunctionDecl *callee = callExpr->getDirectCallee();
  949. assert(hlsl::IsIntrinsicOp(callee) &&
  950. "processIntrinsicCallExpr was called for a non-intrinsic function.");
  951. // Figure out which intrinsic function to translate.
  952. llvm::StringRef group;
  953. uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
  954. hlsl::GetIntrinsicOp(callee, opcode, group);
  955. switch (static_cast<hlsl::IntrinsicOp>(opcode)) {
  956. case hlsl::IntrinsicOp::IOP_dot: {
  957. return processIntrinsicDot(callExpr);
  958. break;
  959. }
  960. default:
  961. break;
  962. }
  963. emitError("Intrinsic function '%0' not yet implemented.")
  964. << callee->getName();
  965. return 0;
  966. }
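  /// Processes the given expr, casts the result into the given float (vector)
  /// type and returns the <result-id> of the cast value.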
  967. uint32_t castToFloat(const Expr *expr, QualType toFloatType) {
  968. const QualType fromType = expr->getType();
  969. const uint32_t floatType = typeTranslator.translateType(toFloatType);
  970. const uint32_t fromVal = doExpr(expr);
  971. if (isBoolOrVecOfBoolType(fromType)) {
  972. const uint32_t one = getValueOne(toFloatType);
  973. const uint32_t zero = getValueZero(toFloatType);
  974. return theBuilder.createSelect(floatType, fromVal, one, zero);
  975. }
  976. if (isSintOrVecOfSintType(fromType)) {
  977. return theBuilder.createUnaryOp(spv::Op::OpConvertSToF, floatType,
  978. fromVal);
  979. }
  980. if (isUintOrVecOfUintType(fromType)) {
  981. return theBuilder.createUnaryOp(spv::Op::OpConvertUToF, floatType,
  982. fromVal);
  983. }
  984. if (isFloatOrVecOfFloatType(fromType)) {
  985. return fromVal;
  986. }
  987. emitError("unimplemented casting to floating point");
  988. expr->dump();
  989. return 0;
  990. }
  991. uint32_t doCallExpr(const CallExpr *callExpr) {
  992. const FunctionDecl *callee = callExpr->getDirectCallee();
  993. // Intrinsic functions such as 'dot' or 'mul'
  994. if (hlsl::IsIntrinsicOp(callee)) {
  995. return processIntrinsicCallExpr(callExpr);
  996. }
  997. if (callee) {
  998. const uint32_t returnType =
  999. typeTranslator.translateType(callExpr->getType());
  1000. // Get or forward declare the function <result-id>
  1001. const uint32_t funcId = declIdMapper.getOrRegisterDeclResultId(callee);
  1002. // Evaluate parameters
  1003. llvm::SmallVector<uint32_t, 4> params;
  1004. for (const auto *arg : callExpr->arguments()) {
  1005. // We need to create variables for holding the values to be used as
  1006. // arguments. The variables themselves are of pointer types.
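  // For example, a call like 'foo(a + 1)' (illustrative) gets a fresh
  // Function-storage variable per argument: the evaluated argument is stored
  // into it, and the pointer is what gets passed to OpFunctionCall below.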
  1007. const uint32_t ptrType = theBuilder.getPointerType(
  1008. typeTranslator.translateType(arg->getType()),
  1009. spv::StorageClass::Function);
  1010. const uint32_t tempVarId = theBuilder.addFnVariable(ptrType);
  1011. theBuilder.createStore(tempVarId, doExpr(arg));
  1012. params.push_back(tempVarId);
  1013. }
  1014. // Push the callee into the work queue if it is not there.
  1015. if (!workQueue.count(callee)) {
  1016. workQueue.insert(callee);
  1017. }
  1018. return theBuilder.createFunctionCall(returnType, funcId, params);
  1019. }
  1020. emitError("calling non-function unimplemented");
  1021. return 0;
  1022. }
  1023. uint32_t doConditionalOperator(const ConditionalOperator *expr) {
  1024. // According to HLSL doc, all sides of the ?: expression are always
  1025. // evaluated.
  1026. const uint32_t type = typeTranslator.translateType(expr->getType());
  1027. const uint32_t condition = doExpr(expr->getCond());
  1028. const uint32_t trueBranch = doExpr(expr->getTrueExpr());
  1029. const uint32_t falseBranch = doExpr(expr->getFalseExpr());
  1030. return theBuilder.createSelect(type, condition, trueBranch, falseBranch);
  1031. }
  1032. /// Translates a floatN * float multiplication into SPIR-V instructions and
  1033. /// returns the <result-id>. Returns 0 if the given binary operation is not
  1034. /// floatN * float.
  1035. uint32_t tryToGenFloatVectorScale(const BinaryOperator *expr) {
  1036. const QualType type = expr->getType();
  1037. // We can only translate floatN * float into OpVectorTimesScalar.
  1038. // So the result type must be floatN.
  1039. if (!hlsl::IsHLSLVecType(type) ||
  1040. !hlsl::GetHLSLVecElementType(type)->isFloatingType())
  1041. return 0;
  1042. const Expr *lhs = expr->getLHS();
  1043. const Expr *rhs = expr->getRHS();
  1044. // Multiplying a float vector with a float scalar will be represented in
  1045. // the AST via a binary operation with two float vectors as operands; one
  1046. // of the operands comes from an implicit cast with kind CK_HLSLVectorSplat.
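  // For example (illustrative HLSL), 'float3 v = u * 2.0;' shows up as
  // u * HLSLVectorSplat(2.0) in the AST and is mapped to OpVectorTimesScalar.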
  1047. // vector * scalar
  1048. if (hlsl::IsHLSLVecType(lhs->getType())) {
  1049. if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
  1050. if (cast->getCastKind() == CK_HLSLVectorSplat) {
  1051. const uint32_t vecType =
  1052. typeTranslator.translateType(expr->getType());
  1053. if (isa<CompoundAssignOperator>(expr)) {
  1054. // For floatN *= float cases, we'll need to do the load/store and
  1055. // return the lhs.
  1056. const uint32_t rhsVal = doExpr(cast->getSubExpr());
  1057. const uint32_t lhsPtr = doExpr(lhs);
  1058. const uint32_t lhsVal = theBuilder.createLoad(vecType, lhsPtr);
  1059. const uint32_t result = theBuilder.createBinaryOp(
  1060. spv::Op::OpVectorTimesScalar, vecType, lhsVal, rhsVal);
  1061. theBuilder.createStore(lhsPtr, result);
  1062. return lhsPtr;
  1063. } else {
  1064. const uint32_t lhsId = doExpr(lhs);
  1065. const uint32_t rhsId = doExpr(cast->getSubExpr());
  1066. return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
  1067. vecType, lhsId, rhsId);
  1068. }
  1069. }
  1070. }
  1071. }
  1072. // scalar * vector
  1073. if (hlsl::IsHLSLVecType(rhs->getType())) {
  1074. if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
  1075. if (cast->getCastKind() == CK_HLSLVectorSplat) {
  1076. const uint32_t vecType =
  1077. typeTranslator.translateType(expr->getType());
  1078. const uint32_t lhsId = doExpr(cast->getSubExpr());
  1079. const uint32_t rhsId = doExpr(rhs);
  1080. return theBuilder.createBinaryOp(spv::Op::OpVectorTimesScalar,
  1081. vecType, rhsId, lhsId);
  1082. }
  1083. }
  1084. }
  1085. return 0;
  1086. }
  1087. /// Translates the given frontend binary operator into its SPIR-V equivalent
  1088. /// taking into consideration the operand type.
  1089. spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
  1090. // TODO: the following does not consider vector types yet.
  1091. const bool isSintType = isSintOrVecOfSintType(type);
  1092. const bool isUintType = isUintOrVecOfUintType(type);
  1093. const bool isFloatType = isFloatOrVecOfFloatType(type);
  1094. #define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp) \
  1095. \
  1096. case BO_##kind : { \
  1097. if (isSintType || isUintType) { \
  1098. return spv::Op::Op##intBinOp; \
  1099. } \
  1100. if (isFloatType) { \
  1101. return spv::Op::Op##floatBinOp; \
  1102. } \
  1103. } \
  1104. break
  1105. #define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp) \
  1106. \
  1107. case BO_##kind : { \
  1108. if (isSintType) { \
  1109. return spv::Op::Op##sintBinOp; \
  1110. } \
  1111. if (isUintType) { \
  1112. return spv::Op::Op##uintBinOp; \
  1113. } \
  1114. if (isFloatType) { \
  1115. return spv::Op::Op##floatBinOp; \
  1116. } \
  1117. } \
  1118. break
  1119. #define BIN_OP_CASE_SINT_UINT(kind, sintBinOp, uintBinOp) \
  1120. \
  1121. case BO_##kind : { \
  1122. if (isSintType) { \
  1123. return spv::Op::Op##sintBinOp; \
  1124. } \
  1125. if (isUintType) { \
  1126. return spv::Op::Op##uintBinOp; \
  1127. } \
  1128. } \
  1129. break
  1130. switch (op) {
  1131. BIN_OP_CASE_INT_FLOAT(Add, IAdd, FAdd);
  1132. BIN_OP_CASE_INT_FLOAT(AddAssign, IAdd, FAdd);
  1133. BIN_OP_CASE_INT_FLOAT(Sub, ISub, FSub);
  1134. BIN_OP_CASE_INT_FLOAT(SubAssign, ISub, FSub);
  1135. BIN_OP_CASE_INT_FLOAT(Mul, IMul, FMul);
  1136. BIN_OP_CASE_INT_FLOAT(MulAssign, IMul, FMul);
  1137. BIN_OP_CASE_SINT_UINT_FLOAT(Div, SDiv, UDiv, FDiv);
  1138. BIN_OP_CASE_SINT_UINT_FLOAT(DivAssign, SDiv, UDiv, FDiv);
  1139. // According to the HLSL spec, "the modulus operator returns the remainder of
  1140. // a division." "The % operator is defined only in cases where either both
  1141. // sides are positive or both sides are negative."
  1142. //
  1143. // In SPIR-V, there are two remainder operations: Op*Rem and Op*Mod. With
  1144. // the former, the sign of a non-0 result comes from Operand 1, while
  1145. // with the latter, from Operand 2.
  1146. //
  1147. // For operands with different signs, technically we can map % to either
  1148. // Op*Rem or Op*Mod since it's undefined behavior. But it is more
  1149. // consistent with C (HLSL starts as a C derivative) and Clang frontend
  1150. // const expression evaluation if we map % to Op*Rem.
  1151. //
  1152. // Note there is no OpURem in SPIR-V.
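  // For example, -5 % 3 yields -2 with Op*Rem (sign follows Operand 1) but
  // would yield 1 with Op*Mod (sign follows Operand 2).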
  1153. BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
  1154. BIN_OP_CASE_SINT_UINT_FLOAT(RemAssign, SRem, UMod, FRem);
    BIN_OP_CASE_SINT_UINT_FLOAT(LT, SLessThan, ULessThan, FOrdLessThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(LE, SLessThanEqual, ULessThanEqual,
                                FOrdLessThanEqual);
    BIN_OP_CASE_SINT_UINT_FLOAT(GT, SGreaterThan, UGreaterThan,
                                FOrdGreaterThan);
    BIN_OP_CASE_SINT_UINT_FLOAT(GE, SGreaterThanEqual, UGreaterThanEqual,
                                FOrdGreaterThanEqual);
    BIN_OP_CASE_INT_FLOAT(EQ, IEqual, FOrdEqual);
    BIN_OP_CASE_INT_FLOAT(NE, INotEqual, FOrdNotEqual);
    BIN_OP_CASE_SINT_UINT(And, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(AndAssign, BitwiseAnd, BitwiseAnd);
    BIN_OP_CASE_SINT_UINT(Or, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(OrAssign, BitwiseOr, BitwiseOr);
    BIN_OP_CASE_SINT_UINT(Xor, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(XorAssign, BitwiseXor, BitwiseXor);
    BIN_OP_CASE_SINT_UINT(Shl, ShiftLeftLogical, ShiftLeftLogical);
    BIN_OP_CASE_SINT_UINT(ShlAssign, ShiftLeftLogical, ShiftLeftLogical);
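    // Note for the shift-right cases below (illustrative values added here):
    // signed operands map to OpShiftRightArithmetic, which replicates the
    // sign bit, so int(-8) >> 1 yields -4; unsigned operands map to
    // OpShiftRightLogical, which shifts in zeros, so uint(8) >> 1 yields 4.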
    BIN_OP_CASE_SINT_UINT(Shr, ShiftRightArithmetic, ShiftRightLogical);
    BIN_OP_CASE_SINT_UINT(ShrAssign, ShiftRightArithmetic, ShiftRightLogical);
    // According to the HLSL doc, all sides of the && and || expression are
    // always evaluated.
    case BO_LAnd:
      return spv::Op::OpLogicalAnd;
    case BO_LOr:
      return spv::Op::OpLogicalOr;
    default:
      break;
    }

#undef BIN_OP_CASE_INT_FLOAT
#undef BIN_OP_CASE_SINT_UINT_FLOAT
#undef BIN_OP_CASE_SINT_UINT

    emitError("translating binary operator '%0' unimplemented")
        << BinaryOperator::getOpcodeStr(op);
    return spv::Op::OpNop;
  }
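
  // For illustration (HLSL snippets added here as examples, not taken from
  // the original source), the mapping above lowers:
  //
  //   int   a, b;  a + b;   // -> OpIAdd
  //   uint  c, d;  c / d;   // -> OpUDiv
  //   float e, f;  e < f;   // -> OpFOrdLessThan (the ordered form, which is
  //                         //    false when either operand is NaN)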

  /// Returns the <result-id> for constant value 1 of the given type.
  uint32_t getValueOne(QualType type) {
    if (type->isSignedIntegerType()) {
      return theBuilder.getConstantInt32(1);
    }
    if (type->isUnsignedIntegerType()) {
      return theBuilder.getConstantUint32(1);
    }
    if (type->isFloatingType()) {
      return theBuilder.getConstantFloat32(1.0);
    }
    if (hlsl::IsHLSLVecType(type)) {
      const QualType elemType = hlsl::GetHLSLVecElementType(type);
      const uint32_t elemOneId = getValueOne(elemType);
      const size_t size = hlsl::GetHLSLVecSize(type);
      if (size == 1)
        return elemOneId;
      llvm::SmallVector<uint32_t, 4> elements(size, elemOneId);
      const uint32_t vecTypeId = typeTranslator.translateType(type);
      return theBuilder.getConstantComposite(vecTypeId, elements);
    }

    emitError("getting value 1 for type '%0' unimplemented") << type;
    return 0;
  }
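
  // Sketch (illustrative, added here): for an HLSL float3, getValueOne first
  // obtains the scalar constant 1.0f and then emits an OpConstantComposite
  // containing three copies of it; for float1 it returns the scalar constant
  // directly, since SPIR-V has no one-component vector types.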

  /// Returns the <result-id> for constant value 0 of the given type.
  uint32_t getValueZero(QualType type) {
    if (type->isSignedIntegerType()) {
      return theBuilder.getConstantInt32(0);
    }
    if (type->isUnsignedIntegerType()) {
      return theBuilder.getConstantUint32(0);
    }
    if (type->isFloatingType()) {
      return theBuilder.getConstantFloat32(0.0);
    }
    if (hlsl::IsHLSLVecType(type)) {
      const QualType elemType = hlsl::GetHLSLVecElementType(type);
      const uint32_t elemZeroId = getValueZero(elemType);
      const size_t size = hlsl::GetHLSLVecSize(type);
      if (size == 1)
        return elemZeroId;
      llvm::SmallVector<uint32_t, 4> elements(size, elemZeroId);
      const uint32_t vecTypeId = typeTranslator.translateType(type);
      return theBuilder.getConstantComposite(vecTypeId, elements);
    }

    emitError("getting value 0 for type '%0' unimplemented")
        << type.getAsString();
    return 0;
  }

  /// Translates the given frontend APValue into its SPIR-V equivalent for the
  /// given targetType.
  uint32_t translateAPValue(const APValue &value, const QualType targetType) {
    if (targetType->isBooleanType()) {
      const bool boolValue = value.getInt().getBoolValue();
      return theBuilder.getConstantBool(boolValue);
    }
    if (targetType->isIntegerType()) {
      const llvm::APInt &intValue = value.getInt();
      return translateAPInt(intValue, targetType);
    }
    if (targetType->isFloatingType()) {
      const llvm::APFloat &floatValue = value.getFloat();
      return translateAPFloat(floatValue, targetType);
    }
    if (hlsl::IsHLSLVecType(targetType)) {
      const uint32_t vecType = typeTranslator.translateType(targetType);
      const QualType elemType = hlsl::GetHLSLVecElementType(targetType);
      const auto numElements = value.getVectorLength();
      // Special case for vectors of size 1. SPIR-V doesn't support this vector
      // size so we need to translate it to scalar values.
      if (numElements == 1) {
        return translateAPValue(value.getVectorElt(0), elemType);
      }
      llvm::SmallVector<uint32_t, 4> elements;
      for (uint32_t i = 0; i < numElements; ++i) {
        elements.push_back(translateAPValue(value.getVectorElt(i), elemType));
      }
      return theBuilder.getConstantComposite(vecType, elements);
    }

    emitError("APValue of type '%0' is not supported yet.") << value.getKind();
    value.dump();
    return 0;
  }
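
  // Illustration (hypothetical input, added here): a constant-folded
  // initializer such as float4(1, 2, 3, 4) would arrive as a vector APValue
  // of length 4 and become an OpConstantComposite of four float constants,
  // while a float1 value collapses to a plain scalar constant because of the
  // size-1 special case above.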

  /// Translates the given frontend APInt into its SPIR-V equivalent for the
  /// given targetType.
  uint32_t translateAPInt(const llvm::APInt &intValue, QualType targetType) {
    const auto bitwidth = astContext.getIntWidth(targetType);

    if (targetType->isSignedIntegerType()) {
      const int64_t value = intValue.getSExtValue();
      switch (bitwidth) {
      case 32:
        return theBuilder.getConstantInt32(static_cast<int32_t>(value));
      default:
        break;
      }
    } else {
      const uint64_t value = intValue.getZExtValue();
      switch (bitwidth) {
      case 32:
        return theBuilder.getConstantUint32(static_cast<uint32_t>(value));
      default:
        break;
      }
    }

    emitError("APInt for target bitwidth '%0' is not supported yet.")
        << bitwidth;
    return 0;
  }
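
  // Illustration (added here): the literal 5 of type int has a 32-bit width
  // and becomes a signed 32-bit OpConstant with value 5; any target type
  // reporting a different bitwidth (e.g. 64) currently falls through to the
  // error path, since only 32-bit integer constants are emitted here.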

  /// Translates the given frontend APFloat into its SPIR-V equivalent for the
  /// given targetType.
  uint32_t translateAPFloat(const llvm::APFloat &floatValue,
                            QualType targetType) {
    const auto &semantics = astContext.getFloatTypeSemantics(targetType);
    const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);

    switch (bitwidth) {
    case 32:
      return theBuilder.getConstantFloat32(floatValue.convertToFloat());
    default:
      break;
    }

    emitError("APFloat for target bitwidth '%0' is not supported yet.")
        << bitwidth;
    return 0;
  }
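
  // Illustration (added here): a float literal such as 1.5f uses 32-bit IEEE
  // semantics and is emitted through getConstantFloat32; a double literal
  // (64-bit semantics) would currently hit the "not supported yet" error,
  // as only 32-bit floating-point constants are handled above.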

private:
  /// \brief Wrapper method to create an error message and report it
  /// in the diagnostic engine associated with this consumer.
  template <unsigned N> DiagnosticBuilder emitError(const char (&message)[N]) {
    const auto diagId =
        diags.getCustomDiagID(clang::DiagnosticsEngine::Error, message);
    return diags.Report(diagId);
  }

  /// \brief Wrapper method to create a warning message and report it
  /// in the diagnostic engine associated with this consumer.
  template <unsigned N>
  DiagnosticBuilder emitWarning(const char (&message)[N]) {
    const auto diagId =
        diags.getCustomDiagID(clang::DiagnosticsEngine::Warning, message);
    return diags.Report(diagId);
  }

private:
  CompilerInstance &theCompilerInstance;
  ASTContext &astContext;
  DiagnosticsEngine &diags;

  /// Entry function name and shader stage. Both are derived from the command
  /// line and should be const.
  const llvm::StringRef entryFunctionName;
  const spv::ExecutionModel shaderStage;

  SPIRVContext theContext;
  ModuleBuilder theBuilder;
  DeclResultIdMapper declIdMapper;
  TypeTranslator typeTranslator;

  /// A queue of decls reachable from the entry function. Decls inserted into
  /// this queue will persist to avoid duplicated translations, and we'd like
  /// a deterministic order of iteration when looking for the next decl to
  /// translate; hence the SetVector.
  llvm::SetVector<const DeclaratorDecl *> workQueue;

  /// <result-id> for the entry function. Initially it is zero and will be
  /// reset when starting to translate the entry function.
  uint32_t entryFunctionId;

  /// The current function under traversal.
  const FunctionDecl *curFunction;
};

} // end namespace spirv

std::unique_ptr<ASTConsumer>
EmitSPIRVAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
  return llvm::make_unique<spirv::SPIRVEmitter>(CI);
}

} // end namespace clang