DeclResultIdMapper.cpp 79 KB


  1. //===--- DeclResultIdMapper.cpp - DeclResultIdMapper impl --------*- C++ -*-==//
  2. //
  3. // The LLVM Compiler Infrastructure
  4. //
  5. // This file is distributed under the University of Illinois Open Source
  6. // License. See LICENSE.TXT for details.
  7. //
  8. //===----------------------------------------------------------------------===//
  9. #include "DeclResultIdMapper.h"
  10. #include <algorithm>
  11. #include <cstring>
  12. #include <sstream>
  13. #include <unordered_map>
  14. #include "dxc/HLSL/DxilConstants.h"
  15. #include "dxc/HLSL/DxilTypeSystem.h"
  16. #include "clang/AST/Expr.h"
  17. #include "clang/AST/HlslTypes.h"
  18. #include "clang/AST/RecursiveASTVisitor.h"
  19. #include "llvm/ADT/SmallBitVector.h"
  20. #include "llvm/ADT/StringSet.h"
  21. namespace clang {
  22. namespace spirv {
  23. namespace {
  24. /// \brief Returns the stage variable's register assignment for the given Decl.
  25. const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
  26. for (auto *annotation : decl->getUnusualAnnotations()) {
  27. if (auto *reg = dyn_cast<hlsl::RegisterAssignment>(annotation)) {
  28. return reg;
  29. }
  30. }
  31. return nullptr;
  32. }
  33. /// \brief Returns the resource category for the given type.
  34. ResourceVar::Category getResourceCategory(QualType type) {
  35. if (TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type))
  36. return ResourceVar::Category::Image;
  37. if (TypeTranslator::isSampler(type))
  38. return ResourceVar::Category::Sampler;
  39. return ResourceVar::Category::Other;
  40. }
  41. /// \brief Returns true if the given declaration has a primitive type qualifier.
  42. /// Returns false otherwise.
  43. inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
  44. return decl->hasAttr<HLSLTriangleAttr>() ||
  45. decl->hasAttr<HLSLTriangleAdjAttr>() ||
  46. decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
  47. decl->hasAttr<HLSLLineAdjAttr>();
  48. }
  49. /// \brief Deduces the parameter qualifier for the given decl.
  50. hlsl::DxilParamInputQual deduceParamQual(const DeclaratorDecl *decl,
  51. bool asInput) {
  52. const auto type = decl->getType();
  53. if (hlsl::IsHLSLInputPatchType(type))
  54. return hlsl::DxilParamInputQual::InputPatch;
  55. if (hlsl::IsHLSLOutputPatchType(type))
  56. return hlsl::DxilParamInputQual::OutputPatch;
  57. // TODO: Add support for multiple output streams.
  58. if (hlsl::IsHLSLStreamOutputType(type))
  59. return hlsl::DxilParamInputQual::OutStream0;
  60. // The inputs to the geometry shader that have a primitive type qualifier
  61. // must use 'InputPrimitive'.
  62. if (hasGSPrimitiveTypeQualifier(decl))
  63. return hlsl::DxilParamInputQual::InputPrimitive;
  64. return asInput ? hlsl::DxilParamInputQual::In : hlsl::DxilParamInputQual::Out;
  65. }
  66. /// \brief Deduces the HLSL SigPoint for the given decl appearing in the given
  67. /// shader model.
  68. const hlsl::SigPoint *deduceSigPoint(const DeclaratorDecl *decl, bool asInput,
  69. const hlsl::ShaderModel::Kind kind,
  70. bool forPCF) {
  71. return hlsl::SigPoint::GetSigPoint(hlsl::SigPointFromInputQual(
  72. deduceParamQual(decl, asInput), kind, forPCF));
  73. }
  74. /// Returns the type of the given decl. If the given decl is a FunctionDecl,
  75. /// returns its result type.
  76. inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
  77. if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
  78. return funcDecl->getReturnType();
  79. }
  80. return decl->getType();
  81. }
  82. } // anonymous namespace
  83. std::string StageVar::getSemanticStr() const {
  84. // A special case for zero index, which is equivalent to no index.
  85. // Use what is in the source code.
  86. // TODO: this looks like a hack to make the current tests happy.
  87. // Should consider remove it and fix all tests.
  88. if (semanticIndex == 0)
  89. return semanticStr;
  90. std::ostringstream ss;
  91. ss << semanticName.str() << semanticIndex;
  92. return ss.str();
  93. }
  94. uint32_t CounterIdAliasPair::get(ModuleBuilder &builder,
  95. TypeTranslator &translator) const {
  96. if (isAlias) {
  97. const uint32_t counterVarType = builder.getPointerType(
  98. translator.getACSBufferCounter(), spv::StorageClass::Uniform);
  99. return builder.createLoad(counterVarType, resultId);
  100. }
  101. return resultId;
  102. }
  103. const CounterIdAliasPair *
  104. CounterVarFields::get(const llvm::SmallVectorImpl<uint32_t> &indices) const {
  105. for (const auto &field : fields)
  106. if (field.indices == indices)
  107. return &field.counterVar;
  108. return nullptr;
  109. }
  110. bool CounterVarFields::assign(const CounterVarFields &srcFields,
  111. ModuleBuilder &builder,
  112. TypeTranslator &translator) const {
  113. for (const auto &field : fields) {
  114. const auto *srcField = srcFields.get(field.indices);
  115. if (!srcField)
  116. return false;
  117. field.counterVar.assign(*srcField, builder, translator);
  118. }
  119. return true;
  120. }
  121. bool CounterVarFields::assign(const CounterVarFields &srcFields,
  122. const llvm::SmallVector<uint32_t, 4> &dstPrefix,
  123. const llvm::SmallVector<uint32_t, 4> &srcPrefix,
  124. ModuleBuilder &builder,
  125. TypeTranslator &translator) const {
  126. if (dstPrefix.empty() && srcPrefix.empty())
  127. return assign(srcFields, builder, translator);
  128. llvm::SmallVector<uint32_t, 4> srcIndices = srcPrefix;
  129. // If whole has the given prefix, appends all elements after the prefix in
  130. // whole to srcIndices.
  131. const auto applyDiff =
  132. [&srcIndices](const llvm::SmallVector<uint32_t, 4> &whole,
  133. const llvm::SmallVector<uint32_t, 4> &prefix) -> bool {
  134. uint32_t i = 0;
  135. for (; i < prefix.size(); ++i)
  136. if (whole[i] != prefix[i]) {
  137. break;
  138. }
  139. if (i == prefix.size()) {
  140. for (; i < whole.size(); ++i)
  141. srcIndices.push_back(whole[i]);
  142. return true;
  143. }
  144. return false;
  145. };
  146. for (const auto &field : fields)
  147. if (applyDiff(field.indices, dstPrefix)) {
  148. const auto *srcField = srcFields.get(srcIndices);
  149. if (!srcField)
  150. return false;
  151. field.counterVar.assign(*srcField, builder, translator);
  152. for (uint32_t i = srcPrefix.size(); i < srcIndices.size(); ++i)
  153. srcIndices.pop_back();
  154. }
  155. return true;
  156. }
  157. DeclResultIdMapper::SemanticInfo
  158. DeclResultIdMapper::getStageVarSemantic(const ValueDecl *decl) {
  159. for (auto *annotation : decl->getUnusualAnnotations()) {
  160. if (auto *sema = dyn_cast<hlsl::SemanticDecl>(annotation)) {
  161. llvm::StringRef semanticStr = sema->SemanticName;
  162. llvm::StringRef semanticName;
  163. uint32_t index = 0;
  164. hlsl::Semantic::DecomposeNameAndIndex(semanticStr, &semanticName, &index);
  165. const auto *semantic = hlsl::Semantic::GetByName(semanticName);
  166. return {semanticStr, semantic, semanticName, index, sema->Loc};
  167. }
  168. }
  169. return {};
  170. }
  171. bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
  172. uint32_t storedValue,
  173. bool forPCF) {
  174. QualType type = getTypeOrFnRetType(decl);
  175. // Output stream types (PointStream, LineStream, TriangleStream) are
  176. // translated as their underlying struct types.
  177. if (hlsl::IsHLSLStreamOutputType(type))
  178. type = hlsl::GetHLSLResourceResultType(type);
  179. const auto *sigPoint =
  180. deduceSigPoint(decl, /*asInput=*/false, shaderModel.GetKind(), forPCF);
  181. // HS output variables are created using the other overload. For the rest,
  182. // none of them should be created as arrays.
  183. assert(sigPoint->GetKind() != hlsl::DXIL::SigPointKind::HSCPOut);
  184. SemanticInfo inheritSemantic = {};
  185. return createStageVars(sigPoint, decl, /*asInput=*/false, type,
  186. /*arraySize=*/0, "out.var", llvm::None, &storedValue,
  187. // Write back of stage output variables in GS is
  188. // manually controlled by .Append() intrinsic method,
  189. // implemented in writeBackOutputStream(). So
  190. // noWriteBack should be set to true for GS.
  191. shaderModel.IsGS(), &inheritSemantic);
  192. }
  193. bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
  194. uint32_t arraySize,
  195. uint32_t invocationId,
  196. uint32_t storedValue) {
  197. assert(shaderModel.IsHS());
  198. QualType type = getTypeOrFnRetType(decl);
  199. const auto *sigPoint =
  200. hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::HSCPOut);
  201. SemanticInfo inheritSemantic = {};
  202. return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
  203. "out.var", invocationId, &storedValue,
  204. /*noWriteBack=*/false, &inheritSemantic);
  205. }
  206. bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
  207. uint32_t *loadedValue,
  208. bool forPCF) {
  209. uint32_t arraySize = 0;
  210. QualType type = paramDecl->getType();
  211. // Deprive the outermost arrayness for HS/DS/GS and use arraySize
  212. // to convey that information
  213. if (hlsl::IsHLSLInputPatchType(type)) {
  214. arraySize = hlsl::GetHLSLInputPatchCount(type);
  215. type = hlsl::GetHLSLInputPatchElementType(type);
  216. } else if (hlsl::IsHLSLOutputPatchType(type)) {
  217. arraySize = hlsl::GetHLSLOutputPatchCount(type);
  218. type = hlsl::GetHLSLOutputPatchElementType(type);
  219. }
  220. if (hasGSPrimitiveTypeQualifier(paramDecl)) {
  221. const auto *typeDecl = astContext.getAsConstantArrayType(type);
  222. arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
  223. type = typeDecl->getElementType();
  224. }
  225. const auto *sigPoint = deduceSigPoint(paramDecl, /*asInput=*/true,
  226. shaderModel.GetKind(), forPCF);
  227. SemanticInfo inheritSemantic = {};
  228. return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, arraySize,
  229. "in.var", llvm::None, loadedValue,
  230. /*noWriteBack=*/false, &inheritSemantic);
  231. }
  232. const DeclResultIdMapper::DeclSpirvInfo *
  233. DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
  234. auto it = astDecls.find(decl);
  235. if (it != astDecls.end())
  236. return &it->second;
  237. return nullptr;
  238. }
  239. SpirvEvalInfo DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
  240. bool checkRegistered) {
  241. if (const auto *info = getDeclSpirvInfo(decl))
  242. if (info->indexInCTBuffer >= 0) {
  243. // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
  244. // OpAccessChain to get the pointer to the variable since we created
  245. // a single variable for the whole buffer object.
  246. const uint32_t varType = typeTranslator.translateType(
  247. // Should only have VarDecls in a HLSLBufferDecl.
  248. cast<VarDecl>(decl)->getType(),
  249. // We need to set decorateLayout here to avoid creating SPIR-V
  250. // instructions for the current type without decorations.
  251. info->info.getLayoutRule(), info->isRowMajor);
  252. const uint32_t elemId = theBuilder.createAccessChain(
  253. theBuilder.getPointerType(varType, info->info.getStorageClass()),
  254. info->info, {theBuilder.getConstantInt32(info->indexInCTBuffer)});
  255. return SpirvEvalInfo(elemId)
  256. .setStorageClass(info->info.getStorageClass())
  257. .setLayoutRule(info->info.getLayoutRule());
  258. } else {
  259. return *info;
  260. }
  261. if (checkRegistered) {
  262. emitFatalError("found unregistered decl", decl->getLocation())
  263. << decl->getName();
  264. emitNote("please file a bug report on "
  265. "https://github.com/Microsoft/DirectXShaderCompiler/issues with "
  266. "source code if possible",
  267. {});
  268. }
  269. return 0;
  270. }
  271. uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
  272. bool isAlias = false;
  273. auto &info = astDecls[param].info;
  274. const uint32_t type =
  275. getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias, &info);
  276. const uint32_t ptrType =
  277. theBuilder.getPointerType(type, spv::StorageClass::Function);
  278. const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
  279. info.setResultId(id);
  280. return id;
  281. }
  282. void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
  283. const QualType declType = getTypeOrFnRetType(decl);
  284. if (!counterVars.count(decl) &&
  285. TypeTranslator::isRWAppendConsumeSBuffer(declType)) {
  286. createCounterVar(decl, /*isAlias=*/true);
  287. } else if (!fieldCounterVars.count(decl) && declType->isStructureType() &&
  288. // Exclude other resource types which are represented as structs
  289. !hlsl::IsHLSLResourceType(declType)) {
  290. createFieldCounterVars(decl);
  291. }
  292. }
  293. uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
  294. llvm::Optional<uint32_t> init) {
  295. bool isAlias = false;
  296. auto &info = astDecls[var].info;
  297. const uint32_t type =
  298. getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
  299. const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
  300. info.setResultId(id);
  301. return id;
  302. }
  303. uint32_t DeclResultIdMapper::createFileVar(const VarDecl *var,
  304. llvm::Optional<uint32_t> init) {
  305. bool isAlias = false;
  306. auto &info = astDecls[var].info;
  307. const uint32_t type =
  308. getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
  309. const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
  310. var->getName(), init);
  311. info.setResultId(id).setStorageClass(spv::StorageClass::Private);
  312. return id;
  313. }
  314. uint32_t DeclResultIdMapper::createExternVar(const VarDecl *var) {
  315. auto storageClass = spv::StorageClass::UniformConstant;
  316. auto rule = LayoutRule::Void;
  317. bool isACRWSBuffer = false; // Whether its {Append|Consume|RW}StructuredBuffer
  318. if (var->getAttr<HLSLGroupSharedAttr>()) {
  319. // For CS groupshared variables
  320. storageClass = spv::StorageClass::Workgroup;
  321. } else if (auto *t = var->getType()->getAs<RecordType>()) {
  322. const llvm::StringRef typeName = t->getDecl()->getName();
  323. // These types are all translated into OpTypeStruct with BufferBlock
  324. // decoration. They should follow standard storage buffer layout,
  325. // which GLSL std430 rules statisfies.
  326. if (typeName == "StructuredBuffer" || typeName == "ByteAddressBuffer" ||
  327. typeName == "RWByteAddressBuffer") {
  328. storageClass = spv::StorageClass::Uniform;
  329. rule = LayoutRule::GLSLStd430;
  330. } else if (typeName == "RWStructuredBuffer" ||
  331. typeName == "AppendStructuredBuffer" ||
  332. typeName == "ConsumeStructuredBuffer") {
  333. storageClass = spv::StorageClass::Uniform;
  334. rule = LayoutRule::GLSLStd430;
  335. isACRWSBuffer = true;
  336. }
  337. }
  338. const auto varType = typeTranslator.translateType(var->getType(), rule);
  339. const uint32_t id = theBuilder.addModuleVar(varType, storageClass,
  340. var->getName(), llvm::None);
  341. astDecls[var] =
  342. SpirvEvalInfo(id).setStorageClass(storageClass).setLayoutRule(rule);
  343. const auto *regAttr = getResourceBinding(var);
  344. const auto *bindingAttr = var->getAttr<VKBindingAttr>();
  345. const auto *counterBindingAttr = var->getAttr<VKCounterBindingAttr>();
  346. resourceVars.emplace_back(id, getResourceCategory(var->getType()), regAttr,
  347. bindingAttr, counterBindingAttr);
  348. if (const auto *inputAttachment = var->getAttr<VKInputAttachmentIndexAttr>())
  349. theBuilder.decorateInputAttachmentIndex(id, inputAttachment->getIndex());
  350. if (isACRWSBuffer) {
  351. // For {Append|Consume|RW}StructuredBuffer, we need to always create another
  352. // variable for its associated counter.
  353. createCounterVar(var, /*isAlias=*/false);
  354. }
  355. return id;
  356. }
  357. uint32_t DeclResultIdMapper::createVarOfExplicitLayoutStruct(
  358. const DeclContext *decl, const ContextUsageKind usageKind,
  359. llvm::StringRef typeName, llvm::StringRef varName) {
  360. // cbuffers are translated into OpTypeStruct with Block decoration.
  361. // tbuffers are translated into OpTypeStruct with BufferBlock decoration.
  362. // PushConstants are translated into OpTypeStruct with Block decoration.
  363. //
  364. // Both cbuffers and tbuffers have the SPIR-V Uniform storage class. cbuffers
  365. // follow GLSL std140 layout rules, and tbuffers follow GLSL std430 layout
  366. // rules. PushConstants follow GLSL std430 layout rules.
  367. auto &context = *theBuilder.getSPIRVContext();
  368. const LayoutRule layoutRule = usageKind == ContextUsageKind::CBuffer
  369. ? LayoutRule::GLSLStd140
  370. : LayoutRule::GLSLStd430;
  371. const auto *blockDec = usageKind == ContextUsageKind::TBuffer
  372. ? Decoration::getBufferBlock(context)
  373. : Decoration::getBlock(context);
  374. auto decorations = typeTranslator.getLayoutDecorations(decl, layoutRule);
  375. decorations.push_back(blockDec);
  376. // Collect the type and name for each field
  377. llvm::SmallVector<uint32_t, 4> fieldTypes;
  378. llvm::SmallVector<llvm::StringRef, 4> fieldNames;
  379. uint32_t fieldIndex = 0;
  380. for (const auto *subDecl : decl->decls()) {
  381. // Ignore implicit generated struct declarations/constructors/destructors.
  382. // Ignore embedded struct/union/class/enum/function decls.
  383. if (subDecl->isImplicit() || isa<TagDecl>(subDecl) ||
  384. isa<FunctionDecl>(subDecl))
  385. continue;
  386. // The field can only be FieldDecl (for normal structs) or VarDecl (for
  387. // HLSLBufferDecls).
  388. assert(isa<VarDecl>(subDecl) || isa<FieldDecl>(subDecl));
  389. const auto *declDecl = cast<DeclaratorDecl>(subDecl);
  390. // All fields are qualified with const. It will affect the debug name.
  391. // We don't need it here.
  392. auto varType = declDecl->getType();
  393. varType.removeLocalConst();
  394. const bool isRowMajor = typeTranslator.isRowMajorMatrix(varType, declDecl);
  395. fieldTypes.push_back(
  396. typeTranslator.translateType(varType, layoutRule, isRowMajor));
  397. fieldNames.push_back(declDecl->getName());
  398. // tbuffer/TextureBuffers are non-writable SSBOs. OpMemberDecorate
  399. // NonWritable must be applied to all fields.
  400. if (usageKind == ContextUsageKind::TBuffer) {
  401. decorations.push_back(Decoration::getNonWritable(
  402. *theBuilder.getSPIRVContext(), fieldIndex));
  403. }
  404. ++fieldIndex;
  405. }
  406. // Get the type for the whole struct
  407. const uint32_t structType =
  408. theBuilder.getStructType(fieldTypes, typeName, fieldNames, decorations);
  409. // Register the <type-id> for this decl
  410. ctBufferPCTypeIds[decl] = structType;
  411. const auto sc = usageKind == ContextUsageKind::PushConstant
  412. ? spv::StorageClass::PushConstant
  413. : spv::StorageClass::Uniform;
  414. // Create the variable for the whole struct
  415. return theBuilder.addModuleVar(structType, sc, varName);
  416. }
  417. uint32_t DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
  418. const auto usageKind =
  419. decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;
  420. const std::string structName = "type." + decl->getName().str();
  421. const std::string varName = "var." + decl->getName().str();
  422. const uint32_t bufferVar =
  423. createVarOfExplicitLayoutStruct(decl, usageKind, structName, varName);
  424. // We still register all VarDecls seperately here. All the VarDecls are
  425. // mapped to the <result-id> of the buffer object, which means when querying
  426. // querying the <result-id> for a certain VarDecl, we need to do an extra
  427. // OpAccessChain.
  428. int index = 0;
  429. for (const auto *subDecl : decl->decls()) {
  430. // Ignore implicit generated struct declarations/constructors/destructors.
  431. // Ignore embedded struct/union/class/enum/function decls.
  432. if (subDecl->isImplicit() || isa<TagDecl>(subDecl) ||
  433. isa<FunctionDecl>(subDecl))
  434. continue;
  435. const auto *varDecl = cast<VarDecl>(subDecl);
  436. const bool isRowMajor =
  437. typeTranslator.isRowMajorMatrix(varDecl->getType(), varDecl);
  438. astDecls[varDecl] = {SpirvEvalInfo(bufferVar)
  439. .setStorageClass(spv::StorageClass::Uniform)
  440. .setLayoutRule(decl->isCBuffer()
  441. ? LayoutRule::GLSLStd140
  442. : LayoutRule::GLSLStd430),
  443. index++, isRowMajor};
  444. }
  445. resourceVars.emplace_back(
  446. bufferVar, ResourceVar::Category::Other, getResourceBinding(decl),
  447. decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
  448. return bufferVar;
  449. }
  450. uint32_t DeclResultIdMapper::createCTBuffer(const VarDecl *decl) {
  451. const auto *recordType = decl->getType()->getAs<RecordType>();
  452. assert(recordType);
  453. const auto *context = cast<HLSLBufferDecl>(decl->getDeclContext());
  454. const auto usageKind = context->isCBuffer() ? ContextUsageKind::CBuffer
  455. : ContextUsageKind::TBuffer;
  456. const char *ctBufferName =
  457. context->isCBuffer() ? "ConstantBuffer." : "TextureBuffer.";
  458. const std::string structName = "type." + std::string(ctBufferName) +
  459. recordType->getDecl()->getName().str();
  460. const uint32_t bufferVar = createVarOfExplicitLayoutStruct(
  461. recordType->getDecl(), usageKind, structName, decl->getName());
  462. // We register the VarDecl here.
  463. astDecls[decl] =
  464. SpirvEvalInfo(bufferVar)
  465. .setStorageClass(spv::StorageClass::Uniform)
  466. .setLayoutRule(context->isCBuffer() ? LayoutRule::GLSLStd140
  467. : LayoutRule::GLSLStd430);
  468. resourceVars.emplace_back(
  469. bufferVar, ResourceVar::Category::Other, getResourceBinding(context),
  470. decl->getAttr<VKBindingAttr>(), decl->getAttr<VKCounterBindingAttr>());
  471. return bufferVar;
  472. }
  473. uint32_t DeclResultIdMapper::createPushConstant(const VarDecl *decl) {
  474. const auto *recordType = decl->getType()->getAs<RecordType>();
  475. assert(recordType);
  476. const std::string structName =
  477. "type.PushConstant." + recordType->getDecl()->getName().str();
  478. const uint32_t var = createVarOfExplicitLayoutStruct(
  479. recordType->getDecl(), ContextUsageKind::PushConstant, structName,
  480. decl->getName());
  481. // Register the VarDecl
  482. astDecls[decl] = SpirvEvalInfo(var)
  483. .setStorageClass(spv::StorageClass::PushConstant)
  484. .setLayoutRule(LayoutRule::GLSLStd430);
  485. // Do not push this variable into resourceVars since it does not need
  486. // descriptor set.
  487. return var;
  488. }
  489. uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
  490. if (const auto *info = getDeclSpirvInfo(fn))
  491. return info->info;
  492. auto &info = astDecls[fn].info;
  493. bool isAlias = false;
  494. const uint32_t type =
  495. getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias, &info);
  496. const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
  497. info.setResultId(id);
  498. // No need to dereference to get the pointer. Function returns that are
  499. // stand-alone aliases are already pointers to values. All other cases should
  500. // be normal rvalues.
  501. if (!isAlias ||
  502. !TypeTranslator::isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
  503. info.setRValue();
  504. return id;
  505. }
  506. const CounterIdAliasPair *DeclResultIdMapper::getCounterIdAliasPair(
  507. const DeclaratorDecl *decl, const llvm::SmallVector<uint32_t, 4> *indices) {
  508. if (!decl)
  509. return nullptr;
  510. if (indices) {
  511. // Indices are provided. Walk through the fields of the decl.
  512. const auto counter = fieldCounterVars.find(decl);
  513. if (counter != fieldCounterVars.end())
  514. return counter->second.get(*indices);
  515. } else {
  516. // No indices. Check the stand-alone entities.
  517. const auto counter = counterVars.find(decl);
  518. if (counter != counterVars.end())
  519. return &counter->second;
  520. }
  521. return nullptr;
  522. }
  523. const CounterVarFields *
  524. DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
  525. if (!decl)
  526. return nullptr;
  527. const auto found = fieldCounterVars.find(decl);
  528. if (found != fieldCounterVars.end())
  529. return &found->second;
  530. return nullptr;
  531. }
  532. void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
  533. uint32_t specConstant) {
  534. astDecls[decl].info.setResultId(specConstant).setRValue().setSpecConstant();
  535. }
  536. void DeclResultIdMapper::createCounterVar(
  537. const DeclaratorDecl *decl, bool isAlias,
  538. const llvm::SmallVector<uint32_t, 4> *indices) {
  539. std::string counterName = "counter.var." + decl->getName().str();
  540. if (indices) {
  541. // Append field indices to the name
  542. for (const auto index : *indices)
  543. counterName += "." + std::to_string(index);
  544. }
  545. uint32_t counterType = typeTranslator.getACSBufferCounter();
  546. // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
  547. // Alias counter variables should be created into the Private storage class.
  548. const spv::StorageClass sc =
  549. isAlias ? spv::StorageClass::Private : spv::StorageClass::Uniform;
  550. if (isAlias) {
  551. // Apply an extra level of pointer for alias counter variable
  552. counterType =
  553. theBuilder.getPointerType(counterType, spv::StorageClass::Uniform);
  554. }
  555. const uint32_t counterId =
  556. theBuilder.addModuleVar(counterType, sc, counterName);
  557. if (!isAlias) {
  558. // Non-alias counter variables should be put in to resourceVars so that
  559. // descriptors can be allocated for them.
  560. resourceVars.emplace_back(counterId, ResourceVar::Category::Other,
  561. getResourceBinding(decl),
  562. decl->getAttr<VKBindingAttr>(),
  563. decl->getAttr<VKCounterBindingAttr>(), true);
  564. }
  565. if (indices)
  566. fieldCounterVars[decl].append(*indices, counterId);
  567. else
  568. counterVars[decl] = {counterId, isAlias};
  569. }
  570. void DeclResultIdMapper::createFieldCounterVars(
  571. const DeclaratorDecl *rootDecl, const DeclaratorDecl *decl,
  572. llvm::SmallVector<uint32_t, 4> *indices) {
  573. const QualType type = getTypeOrFnRetType(decl);
  574. const auto *recordType = type->getAs<RecordType>();
  575. assert(recordType);
  576. const auto *recordDecl = recordType->getDecl();
  577. for (const auto *field : recordDecl->fields()) {
  578. indices->push_back(field->getFieldIndex()); // Build up the index chain
  579. const QualType fieldType = field->getType();
  580. if (TypeTranslator::isRWAppendConsumeSBuffer(fieldType))
  581. createCounterVar(rootDecl, /*isAlias=*/true, indices);
  582. else if (fieldType->isStructureType() &&
  583. !hlsl::IsHLSLResourceType(fieldType))
  584. // Go recursively into all nested structs
  585. createFieldCounterVars(rootDecl, field, indices);
  586. indices->pop_back();
  587. }
  588. }
  589. uint32_t
  590. DeclResultIdMapper::getCTBufferPushConstantTypeId(const DeclContext *decl) {
  591. const auto found = ctBufferPCTypeIds.find(decl);
  592. assert(found != ctBufferPCTypeIds.end());
  593. return found->second;
  594. }
  595. std::vector<uint32_t> DeclResultIdMapper::collectStageVars() const {
  596. std::vector<uint32_t> vars;
  597. for (auto var : glPerVertex.getStageInVars())
  598. vars.push_back(var);
  599. for (auto var : glPerVertex.getStageOutVars())
  600. vars.push_back(var);
  601. for (const auto &var : stageVars)
  602. vars.push_back(var.getSpirvId());
  603. return vars;
  604. }
  605. namespace {
  606. /// A class for managing stage input/output locations to avoid duplicate uses of
  607. /// the same location.
  608. class LocationSet {
  609. public:
  610. /// Maximum number of locations supported
  611. // Typically we won't have that many stage input or output variables.
  612. // Using 64 should be fine here.
  613. const static uint32_t kMaxLoc = 64;
  614. LocationSet() : usedLocs(kMaxLoc, false), nextLoc(0) {}
  615. /// Uses the given location.
  616. void useLoc(uint32_t loc) { usedLocs.set(loc); }
  617. /// Uses the next available location.
  618. uint32_t useNextLoc() {
  619. while (usedLocs[nextLoc])
  620. nextLoc++;
  621. usedLocs.set(nextLoc);
  622. return nextLoc++;
  623. }
  624. /// Returns true if the given location number is already used.
  625. bool isLocUsed(uint32_t loc) { return usedLocs[loc]; }
  626. private:
  627. llvm::SmallBitVector usedLocs; ///< All previously used locations
  628. uint32_t nextLoc; ///< Next available location
  629. };
  630. /// A class for managing resource bindings to avoid duplicate uses of the same
  631. /// set and binding number.
  632. class BindingSet {
  633. public:
  634. /// Tries to use the given set and binding number. Returns true if possible,
  635. /// false otherwise and writes the source location of where the binding number
  636. /// is used to *usedLoc.
  637. bool tryToUseBinding(uint32_t binding, uint32_t set,
  638. ResourceVar::Category category, SourceLocation tryLoc,
  639. SourceLocation *usedLoc) {
  640. const auto cat = static_cast<uint32_t>(category);
  641. // Note that we will create the entry for binding in bindings[set] here.
  642. // But that should not have bad effects since it defaults to zero.
  643. if ((usedBindings[set][binding] & cat) == 0) {
  644. usedBindings[set][binding] |= cat;
  645. whereUsed[set][binding] = tryLoc;
  646. return true;
  647. }
  648. *usedLoc = whereUsed[set][binding];
  649. return false;
  650. }
  651. /// Uses the next avaiable binding number in set 0.
  652. uint32_t useNextBinding(uint32_t set, ResourceVar::Category category) {
  653. auto &binding = usedBindings[set];
  654. auto &next = nextBindings[set];
  655. while (binding.count(next))
  656. ++next;
  657. binding[next] = static_cast<uint32_t>(category);
  658. return next++;
  659. }
  660. private:
  661. ///< set number -> (binding number -> resource category)
  662. llvm::DenseMap<uint32_t, llvm::DenseMap<uint32_t, uint32_t>> usedBindings;
  663. ///< set number -> (binding number -> source location)
  664. llvm::DenseMap<uint32_t, llvm::DenseMap<uint32_t, SourceLocation>> whereUsed;
  665. ///< set number -> next available binding number
  666. llvm::DenseMap<uint32_t, uint32_t> nextBindings;
  667. };
  668. } // namespace
  669. bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
  670. llvm::StringSet<> seenSemantics;
  671. bool success = true;
  672. for (const auto &var : stageVars) {
  673. auto s = var.getSemanticStr();
  674. if (forInput && var.getSigPoint()->IsInput()) {
  675. if (seenSemantics.count(s)) {
  676. emitError("input semantic '%0' used more than once", {}) << s;
  677. success = false;
  678. }
  679. seenSemantics.insert(s);
  680. } else if (!forInput && var.getSigPoint()->IsOutput()) {
  681. if (seenSemantics.count(s)) {
  682. emitError("output semantic '%0' used more than once", {}) << s;
  683. success = false;
  684. }
  685. seenSemantics.insert(s);
  686. }
  687. }
  688. return success;
  689. }
  690. bool DeclResultIdMapper::finalizeStageIOLocations(bool forInput) {
  691. if (!checkSemanticDuplication(forInput))
  692. return false;
  693. // Returns false if the given StageVar is an input/output variable without
  694. // explicit location assignment. Otherwise, returns true.
  695. const auto locAssigned = [forInput, this](const StageVar &v) {
  696. if (forInput == isInputStorageClass(v))
  697. // No need to assign location for builtins. Treat as assigned.
  698. return v.isSpirvBuitin() || v.getLocationAttr() != nullptr;
  699. // For the ones we don't care, treat as assigned.
  700. return true;
  701. };
  702. // If we have explicit location specified for all input/output variables,
  703. // use them instead assign by ourselves.
  704. if (std::all_of(stageVars.begin(), stageVars.end(), locAssigned)) {
  705. LocationSet locSet;
  706. bool noError = true;
  707. for (const auto &var : stageVars) {
  708. // Skip those stage variables we are not handling for this call
  709. if (forInput != isInputStorageClass(var))
  710. continue;
  711. // Skip builtins
  712. if (var.isSpirvBuitin())
  713. continue;
  714. const auto *attr = var.getLocationAttr();
  715. const auto loc = attr->getNumber();
  716. const auto attrLoc = attr->getLocation(); // Attr source code location
  717. if (loc >= LocationSet::kMaxLoc) {
  718. emitError("stage %select{output|input}0 location #%1 too large",
  719. attrLoc)
  720. << forInput << loc;
  721. return false;
  722. }
  723. // Make sure the same location is not assigned more than once
  724. if (locSet.isLocUsed(loc)) {
  725. emitError("stage %select{output|input}0 location #%1 already assigned",
  726. attrLoc)
  727. << forInput << loc;
  728. noError = false;
  729. }
  730. locSet.useLoc(loc);
  731. theBuilder.decorateLocation(var.getSpirvId(), loc);
  732. }
  733. return noError;
  734. }
  735. std::vector<const StageVar *> vars;
  736. LocationSet locSet;
  737. for (const auto &var : stageVars) {
  738. if (forInput != isInputStorageClass(var))
  739. continue;
  740. if (!var.isSpirvBuitin()) {
  741. if (var.getLocationAttr() != nullptr) {
  742. // We have checked that not all of the stage variables have explicit
  743. // location assignment.
  744. emitError("partial explicit stage %select{output|input}0 location "
  745. "assignment via vk::location(X) unsupported",
  746. {})
  747. << forInput;
  748. return false;
  749. }
  750. // Only SV_Target, SV_Depth, SV_DepthLessEqual, SV_DepthGreaterEqual,
  751. // SV_StencilRef, SV_Coverage are allowed in the pixel shader.
  752. // Arbitrary semantics are disallowed in pixel shader.
  753. if (var.getSemantic() &&
  754. var.getSemantic()->GetKind() == hlsl::Semantic::Kind::Target) {
  755. theBuilder.decorateLocation(var.getSpirvId(), var.getSemanticIndex());
  756. locSet.useLoc(var.getSemanticIndex());
  757. } else {
  758. vars.push_back(&var);
  759. }
  760. }
  761. }
  762. // If alphabetical ordering was requested, sort by semantic string.
  763. // Since HS includes 2 sets of outputs (patch-constant output and
  764. // OutputPatch), running into location mismatches between HS and DS is very
  765. // likely. In order to avoid location mismatches between HS and DS, use
  766. // alphabetical ordering.
  767. if (spirvOptions.stageIoOrder == "alpha" ||
  768. (!forInput && shaderModel.IsHS()) || (forInput && shaderModel.IsDS())) {
  769. // Sort stage input/output variables alphabetically
  770. std::sort(vars.begin(), vars.end(),
  771. [](const StageVar *a, const StageVar *b) {
  772. return a->getSemanticStr() < b->getSemanticStr();
  773. });
  774. }
  775. for (const auto *var : vars)
  776. theBuilder.decorateLocation(var->getSpirvId(), locSet.useNextLoc());
  777. return true;
  778. }
  779. namespace {
  780. /// A class for maintaining the binding number shift requested for descriptor
  781. /// sets.
  782. class BindingShiftMapper {
  783. public:
  784. explicit BindingShiftMapper(const llvm::SmallVectorImpl<uint32_t> &shifts)
  785. : masterShift(0) {
  786. assert(shifts.size() % 2 == 0);
  787. for (uint32_t i = 0; i < shifts.size(); i += 2)
  788. perSetShift[shifts[i + 1]] = shifts[i];
  789. }
  790. /// Returns the shift amount for the given set.
  791. uint32_t getShiftForSet(uint32_t set) const {
  792. const auto found = perSetShift.find(set);
  793. if (found != perSetShift.end())
  794. return found->second;
  795. return masterShift;
  796. }
  797. private:
  798. uint32_t masterShift; /// Shift amount applies to all sets.
  799. llvm::DenseMap<uint32_t, uint32_t> perSetShift;
  800. };
  801. } // namespace
  802. bool DeclResultIdMapper::decorateResourceBindings() {
  803. // For normal resource, we support 3 approaches of setting binding numbers:
  804. // - m1: [[vk::binding(...)]]
  805. // - m2: :register(...)
  806. // - m3: None
  807. //
  808. // For associated counters, we support 2 approaches:
  809. // - c1: [[vk::counter_binding(...)]
  810. // - c2: None
  811. //
  812. // In combination, we need to handle 9 cases:
  813. // - 3 cases for nomral resoures (m1, m2, m3)
  814. // - 6 cases for associated counters (mX * cY)
  815. //
  816. // In the following order:
  817. // - m1, mX * c1
  818. // - m2
  819. // - m3, mX * c2
  820. BindingSet bindingSet;
  821. // Decorates the given varId of the given category with set number
  822. // setNo, binding number bindingNo. Emits warning if overlap.
  823. const auto tryToDecorate = [this, &bindingSet](
  824. const uint32_t varId, const uint32_t setNo,
  825. const uint32_t bindingNo,
  826. const ResourceVar::Category cat,
  827. SourceLocation loc) {
  828. SourceLocation prevUseLoc;
  829. if (!bindingSet.tryToUseBinding(bindingNo, setNo, cat, loc, &prevUseLoc)) {
  830. emitWarning("resource binding #%0 in descriptor set #%1 already assigned",
  831. loc)
  832. << bindingNo << setNo;
  833. emitNote("binding number previously assigned here", prevUseLoc);
  834. }
  835. theBuilder.decorateDSetBinding(varId, setNo, bindingNo);
  836. };
  837. for (const auto &var : resourceVars) {
  838. if (var.isCounter()) {
  839. if (const auto *vkCBinding = var.getCounterBinding()) {
  840. // Process mX * c1
  841. uint32_t set = 0;
  842. if (const auto *vkBinding = var.getBinding())
  843. set = vkBinding->getSet();
  844. if (const auto *reg = var.getRegister())
  845. set = reg->RegisterSpace;
  846. tryToDecorate(var.getSpirvId(), set, vkCBinding->getBinding(),
  847. var.getCategory(), vkCBinding->getLocation());
  848. }
  849. } else {
  850. if (const auto *vkBinding = var.getBinding()) {
  851. // Process m1
  852. tryToDecorate(var.getSpirvId(), vkBinding->getSet(),
  853. vkBinding->getBinding(), var.getCategory(),
  854. vkBinding->getLocation());
  855. }
  856. }
  857. }
  858. BindingShiftMapper bShiftMapper(spirvOptions.bShift);
  859. BindingShiftMapper tShiftMapper(spirvOptions.tShift);
  860. BindingShiftMapper sShiftMapper(spirvOptions.sShift);
  861. BindingShiftMapper uShiftMapper(spirvOptions.uShift);
  862. // Process m2
  863. for (const auto &var : resourceVars)
  864. if (!var.isCounter() && !var.getBinding())
  865. if (const auto *reg = var.getRegister()) {
  866. const uint32_t set = reg->RegisterSpace;
  867. uint32_t binding = reg->RegisterNumber;
  868. switch (reg->RegisterType) {
  869. case 'b':
  870. binding += bShiftMapper.getShiftForSet(set);
  871. break;
  872. case 't':
  873. binding += tShiftMapper.getShiftForSet(set);
  874. break;
  875. case 's':
  876. binding += sShiftMapper.getShiftForSet(set);
  877. break;
  878. case 'u':
  879. binding += uShiftMapper.getShiftForSet(set);
  880. break;
  881. case 'c':
  882. // For setting packing offset. Does not affect binding.
  883. break;
  884. default:
  885. llvm_unreachable("unknown register type found");
  886. }
  887. tryToDecorate(var.getSpirvId(), set, binding, var.getCategory(),
  888. reg->Loc);
  889. }
  890. for (const auto &var : resourceVars) {
  891. const auto cat = var.getCategory();
  892. if (var.isCounter()) {
  893. if (!var.getCounterBinding()) {
  894. // Process mX * c2
  895. uint32_t set = 0;
  896. if (const auto *vkBinding = var.getBinding())
  897. set = vkBinding->getSet();
  898. else if (const auto *reg = var.getRegister())
  899. set = reg->RegisterSpace;
  900. theBuilder.decorateDSetBinding(var.getSpirvId(), set,
  901. bindingSet.useNextBinding(set, cat));
  902. }
  903. } else if (!var.getBinding() && !var.getRegister()) {
  904. // Process m3
  905. theBuilder.decorateDSetBinding(var.getSpirvId(), 0,
  906. bindingSet.useNextBinding(0, cat));
  907. }
  908. }
  909. return true;
  910. }
  /// Creates the SPIR-V stage input/output variable(s) for \p decl, whose
  /// source type is \p type, at signature point \p sigPoint.
  ///
  /// A non-struct decl with a usable semantic maps to a single stage variable,
  /// unless gl_PerVertex handling claims it. A struct decl is handled
  /// recursively, one call per field, with the fields inheriting the semantic
  /// in use (if any).
  ///
  /// \param asInput         true: load the variable(s) and return the composed
  ///                        value in *value; false: store *value into them.
  /// \param arraySize       extra outer arrayness of the stage variable;
  ///                        0 means none.
  /// \param namePrefix      prefix of the generated variable name
  ///                        ("<prefix>.<semantic>").
  /// \param invocationId    only for HS per-vertex output (asserted): the id
  ///                        used as the index of the array element to write.
  /// \param value           in/out: the value loaded (input) or to be stored
  ///                        (output).
  /// \param noWriteBack     when writing outputs, create the variables but
  ///                        emit no stores.
  /// \param inheritSemantic semantic from an enclosing entity; when valid it
  ///                        overrides \p decl's own semantic. The index of the
  ///                        semantic in use is incremented each time a stage
  ///                        variable is created with it here.
  /// \returns false after emitting a diagnostic on error; true otherwise.
  bool DeclResultIdMapper::createStageVars(
      const hlsl::SigPoint *sigPoint, const DeclaratorDecl *decl, bool asInput,
      QualType type, uint32_t arraySize, const llvm::StringRef namePrefix,
      llvm::Optional<uint32_t> invocationId, uint32_t *value, bool noWriteBack,
      SemanticInfo *inheritSemantic) {
    // invocationId should only be used for handling HS per-vertex output.
    if (invocationId.hasValue()) {
      assert(shaderModel.IsHS() && arraySize != 0 && !asInput);
    }
    assert(inheritSemantic);
    if (type->isVoidType()) {
      // No stage variables will be created for void type.
      return true;
    }
    uint32_t typeId = typeTranslator.translateType(type);
    // We have several cases regarding HLSL semantics to handle here:
    // * If the current decl inherits a semantic from some enclosing entity,
    //   use the inherited semantic no matter whether there is a semantic
    //   attached to the current decl.
    // * If there is no semantic to inherit,
    //   * If the current decl is a struct,
    //     * If the current decl has a semantic, all its members inherit this
    //       decl's semantic, with the index sequentially increasing;
    //     * If the current decl does not have a semantic, all its members
    //       should have semantics attached;
    //   * If the current decl is not a struct, it should have semantic attached.
    auto thisSemantic = getStageVarSemantic(decl);
    // Which semantic we should use for this decl
    auto *semanticToUse = &thisSemantic;
    // Enclosing semantics override internal ones
    if (inheritSemantic->isValid()) {
      if (thisSemantic.isValid()) {
        emitWarning(
            "internal semantic '%0' overridden by enclosing semantic '%1'",
            thisSemantic.loc)
            << thisSemantic.str << inheritSemantic->str;
      }
      semanticToUse = inheritSemantic;
    }
    if (semanticToUse->isValid() &&
        // Structs with attached semantics will be handled later.
        !type->isStructureType()) {
      // Found semantic attached directly to this Decl. This means we need to
      // map this decl to a single stage variable.
      const auto semanticKind = semanticToUse->semantic->GetKind();
      // Error out when the given semantic is invalid in this shader model
      if (hlsl::SigPoint::GetInterpretation(semanticKind, sigPoint->GetKind(),
                                            shaderModel.GetMajor(),
                                            shaderModel.GetMinor()) ==
          hlsl::DXIL::SemanticInterpretationKind::NA) {
        emitError("invalid usage of semantic '%0' in shader profile %1",
                  decl->getLocation())
            << semanticToUse->str << shaderModel.GetName();
        return false;
      }
      if (!validateVKBuiltins(decl, sigPoint))
        return false;
      const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>();
      // For VS/HS/DS, the PointSize builtin is handled in gl_PerVertex.
      // For GSVIn also in gl_PerVertex; for GSOut, it's a stand-alone
      // variable handled below.
      if (builtinAttr && builtinAttr->getBuiltIn() == "PointSize" &&
          glPerVertex.tryToAccessPointSize(sigPoint->GetKind(), invocationId,
                                           value, noWriteBack))
        return true;
      // Special handling of certain mappings between HLSL semantics and
      // SPIR-V builtins:
      // * SV_Position/SV_CullDistance/SV_ClipDistance should be grouped into the
      //   gl_PerVertex struct in vertex processing stages.
      // * SV_DomainLocation can refer to a float2, whereas TessCoord is a float3.
      //   To ensure SPIR-V validity, we must create a float3 and extract a
      //   float2 from it before passing it to the main function.
      // * SV_TessFactor is an array of size 2 for isoline patch, array of size 3
      //   for tri patch, and array of size 4 for quad patch, but it must always
      //   be an array of size 4 in SPIR-V for Vulkan.
      // * SV_InsideTessFactor is a single float for tri patch, and an array of
      //   size 2 for a quad patch, but it must always be an array of size 2 in
      //   SPIR-V for Vulkan.
      // * SV_Coverage is an uint value, but the builtin it corresponds to,
      //   SampleMask, must be an array of integers.
      // * SV_Barycentrics is a float3 in HLSL, but its stage variable is
      //   created as a float2; the third component is computed after loading.
      if (glPerVertex.tryToAccess(sigPoint->GetKind(), semanticKind,
                                  semanticToUse->index, invocationId, value,
                                  noWriteBack))
        return true;
      const uint32_t srcTypeId = typeId; // Variable type in source code
      // Replace the SPIR-V variable type for the semantics listed above whose
      // stage variable shape differs from the HLSL source type.
      switch (semanticKind) {
      case hlsl::Semantic::Kind::DomainLocation:
        typeId = theBuilder.getVecType(theBuilder.getFloat32Type(), 3);
        break;
      case hlsl::Semantic::Kind::TessFactor:
        typeId = theBuilder.getArrayType(theBuilder.getFloat32Type(),
                                         theBuilder.getConstantUint32(4));
        break;
      case hlsl::Semantic::Kind::InsideTessFactor:
        typeId = theBuilder.getArrayType(theBuilder.getFloat32Type(),
                                         theBuilder.getConstantUint32(2));
        break;
      case hlsl::Semantic::Kind::Coverage:
        typeId = theBuilder.getArrayType(typeId, theBuilder.getConstantUint32(1));
        break;
      case hlsl::Semantic::Kind::Barycentrics:
        typeId = theBuilder.getVecType(theBuilder.getFloat32Type(), 2);
        break;
      }
      // Handle the extra arrayness
      const uint32_t elementTypeId = typeId; // Array element's type
      if (arraySize != 0)
        typeId = theBuilder.getArrayType(typeId,
                                         theBuilder.getConstantUint32(arraySize));
      StageVar stageVar(sigPoint, semanticToUse->str, semanticToUse->semantic,
                        semanticToUse->name, semanticToUse->index, builtinAttr,
                        typeId);
      const auto name = namePrefix.str() + "." + stageVar.getSemanticStr();
      const uint32_t varId =
          createSpirvStageVar(&stageVar, decl, name, semanticToUse->loc);
      // A zero id signals that the variable could not be created.
      if (varId == 0)
        return false;
      stageVar.setSpirvId(varId);
      stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
      stageVars.push_back(stageVar);
      stageVarIds[decl] = varId;
      // Mark that we have used one index for this semantic
      ++semanticToUse->index;
      // TODO: the following may not be correct?
      if (sigPoint->GetSignatureKind() ==
          hlsl::DXIL::SignatureKind::PatchConstant)
        theBuilder.decorate(varId, spv::Decoration::Patch);
      // Decorate with interpolation modes for pixel shader input variables
      if (shaderModel.IsPS() && sigPoint->IsInput() &&
          // BaryCoord*AMD builtins already encode the interpolation mode.
          semanticKind != hlsl::Semantic::Kind::Barycentrics)
        decoratePSInterpolationMode(decl, type, varId);
      if (asInput) {
        *value = theBuilder.createLoad(typeId, varId);
        // Fix ups for corner cases
        // Special handling of SV_TessFactor DS patch constant input.
        // TessLevelOuter is always an array of size 4 in SPIR-V, but
        // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
        // relevant indexes must be loaded.
        if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
            hlsl::GetArraySize(type) != 4) {
          llvm::SmallVector<uint32_t, 4> components;
          const auto f32TypeId = theBuilder.getFloat32Type();
          const auto tessFactorSize = hlsl::GetArraySize(type);
          const auto arrType = theBuilder.getArrayType(
              f32TypeId, theBuilder.getConstantUint32(tessFactorSize));
          for (uint32_t i = 0; i < tessFactorSize; ++i)
            components.push_back(
                theBuilder.createCompositeExtract(f32TypeId, *value, {i}));
          *value = theBuilder.createCompositeConstruct(arrType, components);
        }
        // Special handling of SV_InsideTessFactor DS patch constant input.
        // TessLevelInner is always an array of size 2 in SPIR-V, but
        // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
        // HLSL. If SV_InsideTessFactor is a scalar, only extract index 0 of
        // TessLevelInner.
        else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
                 // Some developers use float[1] instead of a scalar float.
                 (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
          const auto f32Type = theBuilder.getFloat32Type();
          *value = theBuilder.createCompositeExtract(f32Type, *value, {0});
          if (type->isArrayType()) // float[1]
            *value = theBuilder.createCompositeConstruct(
                theBuilder.getArrayType(f32Type, theBuilder.getConstantUint32(1)),
                {*value});
        }
        // SV_DomainLocation can refer to a float2 or a float3, whereas TessCoord
        // is always a float3. To ensure SPIR-V validity, a float3 stage variable
        // is created, and we must extract a float2 from it before passing it to
        // the main function.
        else if (semanticKind == hlsl::Semantic::Kind::DomainLocation &&
                 hlsl::GetHLSLVecSize(type) != 3) {
          const auto domainLocSize = hlsl::GetHLSLVecSize(type);
          *value = theBuilder.createVectorShuffle(
              theBuilder.getVecType(theBuilder.getFloat32Type(), domainLocSize),
              *value, *value, {0, 1});
        }
        // Special handling of SV_Coverage, which is an uint value. We need to
        // read SampleMask and extract its first element.
        else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
          *value = theBuilder.createCompositeExtract(srcTypeId, *value, {0});
        }
        // Special handling of SV_Barycentrics, which is a float3, but the
        // underlying stage input variable is a float2 (only provides the first
        // two components). Calculate the third element as 1 - (x + y).
        else if (semanticKind == hlsl::Semantic::Kind::Barycentrics) {
          const auto f32Type = theBuilder.getFloat32Type();
          const auto x = theBuilder.createCompositeExtract(f32Type, *value, {0});
          const auto y = theBuilder.createCompositeExtract(f32Type, *value, {1});
          const auto xy =
              theBuilder.createBinaryOp(spv::Op::OpFAdd, f32Type, x, y);
          const auto z = theBuilder.createBinaryOp(
              spv::Op::OpFSub, f32Type, theBuilder.getConstantFloat32(1), xy);
          const auto v3f32Type = theBuilder.getVecType(f32Type, 3);
          *value = theBuilder.createCompositeConstruct(v3f32Type, {x, y, z});
        }
      } else {
        // The variable has been created and recorded; no store is requested.
        if (noWriteBack)
          return true;
        uint32_t ptr = varId;
        // Special handling of SV_TessFactor HS patch constant output.
        // TessLevelOuter is always an array of size 4 in SPIR-V, but
        // SV_TessFactor could be an array of size 2, 3, or 4 in HLSL. Only the
        // relevant indexes must be written to.
        if (semanticKind == hlsl::Semantic::Kind::TessFactor &&
            hlsl::GetArraySize(type) != 4) {
          const auto f32TypeId = theBuilder.getFloat32Type();
          const auto tessFactorSize = hlsl::GetArraySize(type);
          for (uint32_t i = 0; i < tessFactorSize; ++i) {
            const uint32_t ptrType =
                theBuilder.getPointerType(f32TypeId, spv::StorageClass::Output);
            ptr = theBuilder.createAccessChain(ptrType, varId,
                                               theBuilder.getConstantUint32(i));
            theBuilder.createStore(
                ptr, theBuilder.createCompositeExtract(f32TypeId, *value, i));
          }
        }
        // Special handling of SV_InsideTessFactor HS patch constant output.
        // TessLevelInner is always an array of size 2 in SPIR-V, but
        // SV_InsideTessFactor could be an array of size 1 (scalar) or size 2 in
        // HLSL. If SV_InsideTessFactor is a scalar, only write to index 0 of
        // TessLevelInner.
        else if (semanticKind == hlsl::Semantic::Kind::InsideTessFactor &&
                 // Some developers use float[1] instead of a scalar float.
                 (!type->isArrayType() || hlsl::GetArraySize(type) == 1)) {
          const auto f32Type = theBuilder.getFloat32Type();
          ptr = theBuilder.createAccessChain(
              theBuilder.getPointerType(f32Type, spv::StorageClass::Output),
              varId, theBuilder.getConstantUint32(0));
          if (type->isArrayType()) // float[1]
            *value = theBuilder.createCompositeExtract(f32Type, *value, {0});
          theBuilder.createStore(ptr, *value);
        }
        // Special handling of SV_Coverage, which is a uint value. We need to
        // write it to the first element in the SampleMask builtin.
        else if (semanticKind == hlsl::Semantic::Kind::Coverage) {
          ptr = theBuilder.createAccessChain(
              theBuilder.getPointerType(srcTypeId, spv::StorageClass::Output),
              varId, theBuilder.getConstantUint32(0));
          theBuilder.createStore(ptr, *value);
        }
        // Special handling of HS output, for which we write to only one
        // element in the per-vertex data array: the one indexed by
        // SV_ControlPointID.
        else if (invocationId.hasValue()) {
          const uint32_t ptrType =
              theBuilder.getPointerType(elementTypeId, spv::StorageClass::Output);
          const uint32_t index = invocationId.getValue();
          ptr = theBuilder.createAccessChain(ptrType, varId, index);
          theBuilder.createStore(ptr, *value);
        }
        // For all normal cases
        else {
          theBuilder.createStore(ptr, *value);
        }
      }
      return true;
    }
    // If the decl itself doesn't have semantic string attached and there is no
    // one to inherit, it should be a struct having all its fields with semantic
    // strings.
    if (!semanticToUse->isValid() && !type->isStructureType()) {
      emitError("semantic string missing for shader %select{output|input}0 "
                "variable '%1'",
                decl->getLocation())
          << asInput << decl->getName();
      return false;
    }
    // From here on, `type` is a struct: handle it field by field, passing down
    // the semantic in use (if any) for the fields to inherit.
    const auto *structDecl = type->getAs<RecordType>()->getDecl();
    if (asInput) {
      // If this decl translates into multiple stage input variables, we need to
      // load their values into a composite.
      llvm::SmallVector<uint32_t, 4> subValues;
      for (const auto *field : structDecl->fields()) {
        uint32_t subValue = 0;
        if (!createStageVars(sigPoint, field, asInput, field->getType(),
                             arraySize, namePrefix, invocationId, &subValue,
                             noWriteBack, semanticToUse))
          return false;
        subValues.push_back(subValue);
      }
      if (arraySize == 0) {
        *value = theBuilder.createCompositeConstruct(typeId, subValues);
        return true;
      }
      // Handle the extra level of arrayness.
      // We need to return an array of structs. But we get arrays of fields
      // from visiting all fields. So now we need to extract all the elements
      // at the same index of each field arrays and compose a new struct out
      // of them.
      const uint32_t structType = typeTranslator.translateType(type);
      const uint32_t arrayType = theBuilder.getArrayType(
          structType, theBuilder.getConstantUint32(arraySize));
      llvm::SmallVector<uint32_t, 16> arrayElements;
      for (uint32_t arrayIndex = 0; arrayIndex < arraySize; ++arrayIndex) {
        llvm::SmallVector<uint32_t, 8> fields;
        // Extract the element at index arrayIndex from each field
        for (const auto *field : structDecl->fields()) {
          const uint32_t fieldType =
              typeTranslator.translateType(field->getType());
          fields.push_back(theBuilder.createCompositeExtract(
              fieldType, subValues[field->getFieldIndex()], {arrayIndex}));
        }
        // Compose a new struct out of them
        arrayElements.push_back(
            theBuilder.createCompositeConstruct(structType, fields));
      }
      *value = theBuilder.createCompositeConstruct(arrayType, arrayElements);
    } else {
      // Unlike reading, which may require us to read stand-alone builtins and
      // stage input variables and compose an array of structs out of them,
      // it happens that we don't need to write an array of structs in a bunch
      // for all shader stages:
      //
      // * VS: output is a single struct, without extra arrayness
      // * HS: output is an array of structs, with extra arrayness,
      //       but we only write to the struct at the InvocationID index
      // * DS: output is a single struct, without extra arrayness
      // * GS: output is controlled by OpEmitVertex, one vertex per time
      //
      // The interesting shader stage is HS. We need the InvocationID to write
      // out the value to the correct array element.
      for (const auto *field : structDecl->fields()) {
        const uint32_t fieldType = typeTranslator.translateType(field->getType());
        uint32_t subValue = 0;
        // With noWriteBack, *value is not extracted from; fields get 0.
        if (!noWriteBack)
          subValue = theBuilder.createCompositeExtract(fieldType, *value,
                                                       {field->getFieldIndex()});
        if (!createStageVars(sigPoint, field, asInput, field->getType(),
                             arraySize, namePrefix, invocationId, &subValue,
                             noWriteBack, semanticToUse))
          return false;
      }
    }
    return true;
  }
  1247. bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
  1248. uint32_t value) {
  1249. assert(shaderModel.IsGS()); // Only for GS use
  1250. QualType type = decl->getType();
  1251. if (hlsl::IsHLSLStreamOutputType(type))
  1252. type = hlsl::GetHLSLResourceResultType(type);
  1253. if (hasGSPrimitiveTypeQualifier(decl))
  1254. type = astContext.getAsConstantArrayType(type)->getElementType();
  1255. auto semanticInfo = getStageVarSemantic(decl);
  1256. if (semanticInfo.isValid()) {
  1257. // Found semantic attached directly to this Decl. Write the value for this
  1258. // Decl to the corresponding stage output variable.
  1259. const uint32_t srcTypeId = typeTranslator.translateType(type);
  1260. // Handle SV_Position, SV_ClipDistance, and SV_CullDistance
  1261. if (glPerVertex.tryToAccess(
  1262. hlsl::DXIL::SigPointKind::GSOut, semanticInfo.semantic->GetKind(),
  1263. semanticInfo.index, llvm::None, &value, /*noWriteBack=*/false))
  1264. return true;
  1265. // Query the <result-id> for the stage output variable generated out
  1266. // of this decl.
  1267. const auto found = stageVarIds.find(decl);
  1268. // We should have recorded its stage output variable previously.
  1269. assert(found != stageVarIds.end());
  1270. // Negate SV_Position.y if requested
  1271. if (spirvOptions.invertY &&
  1272. semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position) {
  1273. const auto f32Type = theBuilder.getFloat32Type();
  1274. const auto v4f32Type = theBuilder.getVecType(f32Type, 4);
  1275. const auto oldY = theBuilder.createCompositeExtract(f32Type, value, {1});
  1276. const auto newY =
  1277. theBuilder.createUnaryOp(spv::Op::OpFNegate, f32Type, oldY);
  1278. value = theBuilder.createCompositeInsert(v4f32Type, value, {1}, newY);
  1279. }
  1280. theBuilder.createStore(found->second, value);
  1281. return true;
  1282. }
  1283. // If the decl itself doesn't have semantic string attached, it should be
  1284. // a struct having all its fields with semantic strings.
  1285. if (!type->isStructureType()) {
  1286. emitError("semantic string missing for shader output variable '%0'",
  1287. decl->getLocation())
  1288. << decl->getName();
  1289. return false;
  1290. }
  1291. const auto *structDecl = type->getAs<RecordType>()->getDecl();
  1292. // Write out each field
  1293. for (const auto *field : structDecl->fields()) {
  1294. const uint32_t fieldType = typeTranslator.translateType(field->getType());
  1295. const uint32_t subValue = theBuilder.createCompositeExtract(
  1296. fieldType, value, {field->getFieldIndex()});
  1297. if (!writeBackOutputStream(field, subValue))
  1298. return false;
  1299. }
  1300. return true;
  1301. }
  1302. void DeclResultIdMapper::decoratePSInterpolationMode(const DeclaratorDecl *decl,
  1303. QualType type,
  1304. uint32_t varId) {
  1305. const QualType elemType = typeTranslator.getElementType(type);
  1306. if (elemType->isBooleanType() || elemType->isIntegerType()) {
  1307. // TODO: Probably we can call hlsl::ValidateSignatureElement() for the
  1308. // following check.
  1309. if (decl->getAttr<HLSLLinearAttr>() || decl->getAttr<HLSLCentroidAttr>() ||
  1310. decl->getAttr<HLSLNoPerspectiveAttr>() ||
  1311. decl->getAttr<HLSLSampleAttr>()) {
  1312. emitError("only nointerpolation mode allowed for integer input "
  1313. "parameters in pixel shader",
  1314. decl->getLocation());
  1315. } else {
  1316. theBuilder.decorate(varId, spv::Decoration::Flat);
  1317. }
  1318. } else {
  1319. // Do nothing for HLSLLinearAttr since its the default
  1320. // Attributes can be used together. So cannot use else if.
  1321. if (decl->getAttr<HLSLCentroidAttr>())
  1322. theBuilder.decorate(varId, spv::Decoration::Centroid);
  1323. if (decl->getAttr<HLSLNoInterpolationAttr>())
  1324. theBuilder.decorate(varId, spv::Decoration::Flat);
  1325. if (decl->getAttr<HLSLNoPerspectiveAttr>())
  1326. theBuilder.decorate(varId, spv::Decoration::NoPerspective);
  1327. if (decl->getAttr<HLSLSampleAttr>()) {
  1328. theBuilder.requireCapability(spv::Capability::SampleRateShading);
  1329. theBuilder.decorate(varId, spv::Decoration::Sample);
  1330. }
  1331. }
  1332. }
  1333. uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar,
  1334. const DeclaratorDecl *decl,
  1335. const llvm::StringRef name,
  1336. SourceLocation srcLoc) {
  1337. using spv::BuiltIn;
  1338. const auto sigPoint = stageVar->getSigPoint();
  1339. const auto semanticKind = stageVar->getSemantic()->GetKind();
  1340. const auto sigPointKind = sigPoint->GetKind();
  1341. const uint32_t type = stageVar->getSpirvTypeId();
  1342. spv::StorageClass sc = getStorageClassForSigPoint(sigPoint);
  1343. if (sc == spv::StorageClass::Max)
  1344. return 0;
  1345. stageVar->setStorageClass(sc);
  1346. // [[vk::builtin(...)]] takes precedence.
  1347. if (const auto *builtinAttr = stageVar->getBuiltInAttr()) {
  1348. const auto spvBuiltIn =
  1349. llvm::StringSwitch<BuiltIn>(builtinAttr->getBuiltIn())
  1350. .Case("PointSize", BuiltIn::PointSize)
  1351. .Case("HelperInvocation", BuiltIn::HelperInvocation)
  1352. .Default(BuiltIn::Max);
  1353. assert(spvBuiltIn != BuiltIn::Max); // The frontend should guarantee this.
  1354. return theBuilder.addStageBuiltinVar(type, sc, spvBuiltIn);
  1355. }
  1356. // The following translation assumes that semantic validity in the current
  1357. // shader model is already checked, so it only covers valid SigPoints for
  1358. // each semantic.
  1359. switch (semanticKind) {
  1360. // According to DXIL spec, the Position SV can be used by all SigPoints
  1361. // other than PCIn, HSIn, GSIn, PSOut, CSIn.
  1362. // According to Vulkan spec, the Position BuiltIn can only be used
  1363. // by VSOut, HS/DS/GS In/Out.
  1364. case hlsl::Semantic::Kind::Position: {
  1365. switch (sigPointKind) {
  1366. case hlsl::SigPoint::Kind::VSIn:
  1367. case hlsl::SigPoint::Kind::PCOut:
  1368. case hlsl::SigPoint::Kind::DSIn:
  1369. return theBuilder.addStageIOVar(type, sc, name.str());
  1370. case hlsl::SigPoint::Kind::VSOut:
  1371. case hlsl::SigPoint::Kind::HSCPIn:
  1372. case hlsl::SigPoint::Kind::HSCPOut:
  1373. case hlsl::SigPoint::Kind::DSCPIn:
  1374. case hlsl::SigPoint::Kind::DSOut:
  1375. case hlsl::SigPoint::Kind::GSVIn:
  1376. llvm_unreachable("should be handled in gl_PerVertex struct");
  1377. case hlsl::SigPoint::Kind::GSOut:
  1378. stageVar->setIsSpirvBuiltin();
  1379. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::Position);
  1380. case hlsl::SigPoint::Kind::PSIn:
  1381. stageVar->setIsSpirvBuiltin();
  1382. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragCoord);
  1383. default:
  1384. llvm_unreachable("invalid usage of SV_Position sneaked in");
  1385. }
  1386. }
  1387. // According to DXIL spec, the VertexID SV can only be used by VSIn.
  1388. // According to Vulkan spec, the VertexIndex BuiltIn can only be used by
  1389. // VSIn.
  1390. case hlsl::Semantic::Kind::VertexID: {
  1391. stageVar->setIsSpirvBuiltin();
  1392. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::VertexIndex);
  1393. }
  1394. // According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
  1395. // HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
  1396. // According to Vulkan spec, the InstanceIndex BuitIn can only be used by
  1397. // VSIn.
  1398. case hlsl::Semantic::Kind::InstanceID: {
  1399. switch (sigPointKind) {
  1400. case hlsl::SigPoint::Kind::VSIn:
  1401. stageVar->setIsSpirvBuiltin();
  1402. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::InstanceIndex);
  1403. case hlsl::SigPoint::Kind::VSOut:
  1404. case hlsl::SigPoint::Kind::HSCPIn:
  1405. case hlsl::SigPoint::Kind::HSCPOut:
  1406. case hlsl::SigPoint::Kind::DSCPIn:
  1407. case hlsl::SigPoint::Kind::DSOut:
  1408. case hlsl::SigPoint::Kind::GSVIn:
  1409. case hlsl::SigPoint::Kind::GSOut:
  1410. case hlsl::SigPoint::Kind::PSIn:
  1411. return theBuilder.addStageIOVar(type, sc, name.str());
  1412. default:
  1413. llvm_unreachable("invalid usage of SV_InstanceID sneaked in");
  1414. }
  1415. }
  1416. // According to DXIL spec, the Depth{|GreaterEqual|LessEqual} SV can only be
  1417. // used by PSOut.
  1418. // According to Vulkan spec, the FragDepth BuiltIn can only be used by PSOut.
  1419. case hlsl::Semantic::Kind::Depth:
  1420. case hlsl::Semantic::Kind::DepthGreaterEqual:
  1421. case hlsl::Semantic::Kind::DepthLessEqual: {
  1422. stageVar->setIsSpirvBuiltin();
  1423. if (semanticKind == hlsl::Semantic::Kind::DepthGreaterEqual)
  1424. theBuilder.addExecutionMode(entryFunctionId,
  1425. spv::ExecutionMode::DepthGreater, {});
  1426. else if (semanticKind == hlsl::Semantic::Kind::DepthLessEqual)
  1427. theBuilder.addExecutionMode(entryFunctionId,
  1428. spv::ExecutionMode::DepthLess, {});
  1429. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragDepth);
  1430. }
  1431. // According to DXIL spec, the ClipDistance/CullDistance SV can be used by all
  1432. // SigPoints other than PCIn, HSIn, GSIn, PSOut, CSIn.
  1433. // According to Vulkan spec, the ClipDistance/CullDistance BuiltIn can only
  1434. // be
  1435. // used by VSOut, HS/DS/GS In/Out.
  1436. case hlsl::Semantic::Kind::ClipDistance:
  1437. case hlsl::Semantic::Kind::CullDistance: {
  1438. switch (sigPointKind) {
  1439. case hlsl::SigPoint::Kind::VSIn:
  1440. case hlsl::SigPoint::Kind::PCOut:
  1441. case hlsl::SigPoint::Kind::DSIn:
  1442. return theBuilder.addStageIOVar(type, sc, name.str());
  1443. case hlsl::SigPoint::Kind::VSOut:
  1444. case hlsl::SigPoint::Kind::HSCPIn:
  1445. case hlsl::SigPoint::Kind::HSCPOut:
  1446. case hlsl::SigPoint::Kind::DSCPIn:
  1447. case hlsl::SigPoint::Kind::DSOut:
  1448. case hlsl::SigPoint::Kind::GSVIn:
  1449. case hlsl::SigPoint::Kind::GSOut:
  1450. case hlsl::SigPoint::Kind::PSIn:
  1451. llvm_unreachable("should be handled in gl_PerVertex struct");
  1452. default:
  1453. llvm_unreachable(
  1454. "invalid usage of SV_ClipDistance/SV_CullDistance sneaked in");
  1455. }
  1456. }
  1457. // According to DXIL spec, the IsFrontFace SV can only be used by GSOut and
  1458. // PSIn.
  1459. // According to Vulkan spec, the FrontFacing BuitIn can only be used in PSIn.
  1460. case hlsl::Semantic::Kind::IsFrontFace: {
  1461. switch (sigPointKind) {
  1462. case hlsl::SigPoint::Kind::GSOut:
  1463. return theBuilder.addStageIOVar(type, sc, name.str());
  1464. case hlsl::SigPoint::Kind::PSIn:
  1465. stageVar->setIsSpirvBuiltin();
  1466. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FrontFacing);
  1467. default:
  1468. llvm_unreachable("invalid usage of SV_IsFrontFace sneaked in");
  1469. }
  1470. }
  1471. // According to DXIL spec, the Target SV can only be used by PSOut.
  1472. // There is no corresponding builtin decoration in SPIR-V. So generate normal
  1473. // Vulkan stage input/output variables.
  1474. case hlsl::Semantic::Kind::Target:
  1475. // An arbitrary semantic is defined by users. Generate normal Vulkan stage
  1476. // input/output variables.
  1477. case hlsl::Semantic::Kind::Arbitrary: {
  1478. return theBuilder.addStageIOVar(type, sc, name.str());
  1479. // TODO: patch constant function in hull shader
  1480. }
  1481. // According to DXIL spec, the DispatchThreadID SV can only be used by CSIn.
  1482. // According to Vulkan spec, the GlobalInvocationId can only be used in CSIn.
  1483. case hlsl::Semantic::Kind::DispatchThreadID: {
  1484. stageVar->setIsSpirvBuiltin();
  1485. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::GlobalInvocationId);
  1486. }
  1487. // According to DXIL spec, the GroupID SV can only be used by CSIn.
  1488. // According to Vulkan spec, the WorkgroupId can only be used in CSIn.
  1489. case hlsl::Semantic::Kind::GroupID: {
  1490. stageVar->setIsSpirvBuiltin();
  1491. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::WorkgroupId);
  1492. }
  1493. // According to DXIL spec, the GroupThreadID SV can only be used by CSIn.
  1494. // According to Vulkan spec, the LocalInvocationId can only be used in CSIn.
  1495. case hlsl::Semantic::Kind::GroupThreadID: {
  1496. stageVar->setIsSpirvBuiltin();
  1497. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::LocalInvocationId);
  1498. }
  1499. // According to DXIL spec, the GroupIndex SV can only be used by CSIn.
  1500. // According to Vulkan spec, the LocalInvocationIndex can only be used in
  1501. // CSIn.
  1502. case hlsl::Semantic::Kind::GroupIndex: {
  1503. stageVar->setIsSpirvBuiltin();
  1504. return theBuilder.addStageBuiltinVar(type, sc,
  1505. BuiltIn::LocalInvocationIndex);
  1506. }
  1507. // According to DXIL spec, the OutputControlID SV can only be used by HSIn.
  1508. // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  1509. // HS/GS In.
  1510. case hlsl::Semantic::Kind::OutputControlPointID: {
  1511. stageVar->setIsSpirvBuiltin();
  1512. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId);
  1513. }
  1514. // According to DXIL spec, the PrimitiveID SV can only be used by PCIn, HSIn,
  1515. // DSIn, GSIn, GSOut, and PSIn.
  1516. // According to Vulkan spec, the PrimitiveId BuiltIn can only be used in
  1517. // HS/DS/PS In, GS In/Out.
  1518. case hlsl::Semantic::Kind::PrimitiveID: {
  1519. // PrimitiveId requires either Tessellation or Geometry capability.
  1520. // Need to require one for PSIn.
  1521. if (sigPointKind == hlsl::SigPoint::Kind::PSIn)
  1522. theBuilder.requireCapability(spv::Capability::Geometry);
  1523. // Translate to PrimitiveId BuiltIn for all valid SigPoints.
  1524. stageVar->setIsSpirvBuiltin();
  1525. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::PrimitiveId);
  1526. }
  1527. // According to DXIL spec, the TessFactor SV can only be used by PCOut and
  1528. // DSIn.
  1529. // According to Vulkan spec, the TessLevelOuter BuiltIn can only be used in
  1530. // PCOut and DSIn.
  1531. case hlsl::Semantic::Kind::TessFactor: {
  1532. stageVar->setIsSpirvBuiltin();
  1533. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelOuter);
  1534. }
  1535. // According to DXIL spec, the InsideTessFactor SV can only be used by PCOut
  1536. // and DSIn.
  1537. // According to Vulkan spec, the TessLevelInner BuiltIn can only be used in
  1538. // PCOut and DSIn.
  1539. case hlsl::Semantic::Kind::InsideTessFactor: {
  1540. stageVar->setIsSpirvBuiltin();
  1541. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessLevelInner);
  1542. }
  1543. // According to DXIL spec, the DomainLocation SV can only be used by DSIn.
  1544. // According to Vulkan spec, the TessCoord BuiltIn can only be used in DSIn.
  1545. case hlsl::Semantic::Kind::DomainLocation: {
  1546. stageVar->setIsSpirvBuiltin();
  1547. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::TessCoord);
  1548. }
  1549. // According to DXIL spec, the GSInstanceID SV can only be used by GSIn.
  1550. // According to Vulkan spec, the InvocationId BuiltIn can only be used in
  1551. // HS/GS In.
  1552. case hlsl::Semantic::Kind::GSInstanceID: {
  1553. stageVar->setIsSpirvBuiltin();
  1554. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::InvocationId);
  1555. }
  1556. // According to DXIL spec, the SampleIndex SV can only be used by PSIn.
  1557. // According to Vulkan spec, the SampleId BuiltIn can only be used in PSIn.
  1558. case hlsl::Semantic::Kind::SampleIndex: {
  1559. theBuilder.requireCapability(spv::Capability::SampleRateShading);
  1560. stageVar->setIsSpirvBuiltin();
  1561. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleId);
  1562. }
  1563. // According to DXIL spec, the StencilRef SV can only be used by PSOut.
  1564. case hlsl::Semantic::Kind::StencilRef: {
  1565. theBuilder.addExtension("SPV_EXT_shader_stencil_export");
  1566. theBuilder.requireCapability(spv::Capability::StencilExportEXT);
  1567. stageVar->setIsSpirvBuiltin();
  1568. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::FragStencilRefEXT);
  1569. }
  1570. // According to DXIL spec, the ViewID SV can only be used by PSIn.
  1571. case hlsl::Semantic::Kind::Barycentrics: {
  1572. theBuilder.addExtension("SPV_AMD_shader_explicit_vertex_parameter");
  1573. stageVar->setIsSpirvBuiltin();
  1574. // Selecting the correct builtin according to interpolation mode
  1575. auto bi = BuiltIn::Max;
  1576. if (decl->hasAttr<HLSLNoPerspectiveAttr>()) {
  1577. if (decl->hasAttr<HLSLCentroidAttr>()) {
  1578. bi = BuiltIn::BaryCoordNoPerspCentroidAMD;
  1579. } else if (decl->hasAttr<HLSLSampleAttr>()) {
  1580. bi = BuiltIn::BaryCoordNoPerspSampleAMD;
  1581. } else {
  1582. bi = BuiltIn::BaryCoordNoPerspAMD;
  1583. }
  1584. } else {
  1585. if (decl->hasAttr<HLSLCentroidAttr>()) {
  1586. bi = BuiltIn::BaryCoordSmoothCentroidAMD;
  1587. } else if (decl->hasAttr<HLSLSampleAttr>()) {
  1588. bi = BuiltIn::BaryCoordSmoothSampleAMD;
  1589. } else {
  1590. bi = BuiltIn::BaryCoordSmoothAMD;
  1591. }
  1592. }
  1593. return theBuilder.addStageBuiltinVar(type, sc, bi);
  1594. }
  1595. // According to DXIL spec, the RenderTargetArrayIndex SV can only be used by
  1596. // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
  1597. // According to Vulkan spec, the Layer BuiltIn can only be used in GSOut and
  1598. // PSIn.
  1599. case hlsl::Semantic::Kind::RenderTargetArrayIndex: {
  1600. switch (sigPointKind) {
  1601. case hlsl::SigPoint::Kind::VSIn:
  1602. case hlsl::SigPoint::Kind::VSOut:
  1603. case hlsl::SigPoint::Kind::HSCPIn:
  1604. case hlsl::SigPoint::Kind::HSCPOut:
  1605. case hlsl::SigPoint::Kind::PCOut:
  1606. case hlsl::SigPoint::Kind::DSIn:
  1607. case hlsl::SigPoint::Kind::DSCPIn:
  1608. case hlsl::SigPoint::Kind::DSOut:
  1609. case hlsl::SigPoint::Kind::GSVIn:
  1610. return theBuilder.addStageIOVar(type, sc, name.str());
  1611. case hlsl::SigPoint::Kind::GSOut:
  1612. case hlsl::SigPoint::Kind::PSIn:
  1613. theBuilder.requireCapability(spv::Capability::Geometry);
  1614. stageVar->setIsSpirvBuiltin();
  1615. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::Layer);
  1616. default:
  1617. llvm_unreachable("invalid usage of SV_RenderTargetArrayIndex sneaked in");
  1618. }
  1619. }
  1620. // According to DXIL spec, the ViewportArrayIndex SV can only be used by
  1621. // VSIn, VSOut, HSCPIn, HSCPOut, DSIn, DSOut, GSVIn, GSOut, PSIn.
  1622. // According to Vulkan spec, the ViewportIndex BuiltIn can only be used in
  1623. // GSOut and PSIn.
  1624. case hlsl::Semantic::Kind::ViewPortArrayIndex: {
  1625. switch (sigPointKind) {
  1626. case hlsl::SigPoint::Kind::VSIn:
  1627. case hlsl::SigPoint::Kind::VSOut:
  1628. case hlsl::SigPoint::Kind::HSCPIn:
  1629. case hlsl::SigPoint::Kind::HSCPOut:
  1630. case hlsl::SigPoint::Kind::PCOut:
  1631. case hlsl::SigPoint::Kind::DSIn:
  1632. case hlsl::SigPoint::Kind::DSCPIn:
  1633. case hlsl::SigPoint::Kind::DSOut:
  1634. case hlsl::SigPoint::Kind::GSVIn:
  1635. return theBuilder.addStageIOVar(type, sc, name.str());
  1636. case hlsl::SigPoint::Kind::GSOut:
  1637. case hlsl::SigPoint::Kind::PSIn:
  1638. theBuilder.requireCapability(spv::Capability::MultiViewport);
  1639. stageVar->setIsSpirvBuiltin();
  1640. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewportIndex);
  1641. default:
  1642. llvm_unreachable("invalid usage of SV_ViewportArrayIndex sneaked in");
  1643. }
  1644. }
  1645. // According to DXIL spec, the Coverage SV can only be used by PSIn and PSOut.
  1646. // According to Vulkan spec, the SampleMask BuiltIn can only be used in
  1647. // PSIn and PSOut.
  1648. case hlsl::Semantic::Kind::Coverage: {
  1649. stageVar->setIsSpirvBuiltin();
  1650. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::SampleMask);
  1651. }
  1652. // According to DXIL spec, the ViewID SV can only be used by VSIn, PCIn,
  1653. // HSIn, DSIn, GSIn, PSIn.
  1654. // According to Vulkan spec, the ViewIndex BuiltIn can only be used in
  1655. // VS/HS/DS/GS/PS input.
  1656. case hlsl::Semantic::Kind::ViewID: {
  1657. theBuilder.addExtension("SPV_KHR_multiview");
  1658. theBuilder.requireCapability(spv::Capability::MultiView);
  1659. stageVar->setIsSpirvBuiltin();
  1660. return theBuilder.addStageBuiltinVar(type, sc, BuiltIn::ViewIndex);
  1661. }
  1662. case hlsl::Semantic::Kind::InnerCoverage: {
  1663. emitError("no equivalent for semantic SV_InnerCoverage in Vulkan", srcLoc);
  1664. return 0;
  1665. }
  1666. default:
  1667. emitError("semantic %0 unimplemented", srcLoc)
  1668. << stageVar->getSemantic()->GetName();
  1669. break;
  1670. }
  1671. return 0;
  1672. }
  1673. bool DeclResultIdMapper::validateVKBuiltins(const DeclaratorDecl *decl,
  1674. const hlsl::SigPoint *sigPoint) {
  1675. bool success = true;
  1676. if (const auto *builtinAttr = decl->getAttr<VKBuiltInAttr>()) {
  1677. const auto declType = getTypeOrFnRetType(decl);
  1678. const auto loc = builtinAttr->getLocation();
  1679. if (decl->hasAttr<VKLocationAttr>()) {
  1680. emitError("cannot use vk::builtin and vk::location together", loc);
  1681. success = false;
  1682. }
  1683. const llvm::StringRef builtin = builtinAttr->getBuiltIn();
  1684. if (builtin == "HelperInvocation") {
  1685. if (!declType->isBooleanType()) {
  1686. emitError("HelperInvocation builtin must be of boolean type", loc);
  1687. success = false;
  1688. }
  1689. if (sigPoint->GetKind() != hlsl::SigPoint::Kind::PSIn) {
  1690. emitError(
  1691. "HelperInvocation builtin can only be used as pixel shader input",
  1692. loc);
  1693. success = false;
  1694. }
  1695. } else if (builtin == "PointSize") {
  1696. if (!declType->isFloatingType()) {
  1697. emitError("PointSize builtin must be of float type", loc);
  1698. success = false;
  1699. }
  1700. switch (sigPoint->GetKind()) {
  1701. case hlsl::SigPoint::Kind::VSOut:
  1702. case hlsl::SigPoint::Kind::HSCPIn:
  1703. case hlsl::SigPoint::Kind::HSCPOut:
  1704. case hlsl::SigPoint::Kind::DSCPIn:
  1705. case hlsl::SigPoint::Kind::DSOut:
  1706. case hlsl::SigPoint::Kind::GSVIn:
  1707. case hlsl::SigPoint::Kind::GSOut:
  1708. case hlsl::SigPoint::Kind::PSIn:
  1709. break;
  1710. default:
  1711. emitError("PointSize builtin cannot be used as %0", loc)
  1712. << sigPoint->GetName();
  1713. success = false;
  1714. }
  1715. }
  1716. }
  1717. return success;
  1718. }
  1719. spv::StorageClass
  1720. DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
  1721. // This translation is done based on the HLSL reference (see docs/dxil.rst).
  1722. const auto sigPointKind = sigPoint->GetKind();
  1723. const auto signatureKind = sigPoint->GetSignatureKind();
  1724. spv::StorageClass sc = spv::StorageClass::Max;
  1725. switch (signatureKind) {
  1726. case hlsl::DXIL::SignatureKind::Input:
  1727. sc = spv::StorageClass::Input;
  1728. break;
  1729. case hlsl::DXIL::SignatureKind::Output:
  1730. sc = spv::StorageClass::Output;
  1731. break;
  1732. case hlsl::DXIL::SignatureKind::Invalid: {
  1733. // There are some special cases in HLSL (See docs/dxil.rst):
  1734. // SignatureKind is "invalid" for PCIn, HSIn, GSIn, and CSIn.
  1735. switch (sigPointKind) {
  1736. case hlsl::DXIL::SigPointKind::PCIn:
  1737. case hlsl::DXIL::SigPointKind::HSIn:
  1738. case hlsl::DXIL::SigPointKind::GSIn:
  1739. case hlsl::DXIL::SigPointKind::CSIn:
  1740. sc = spv::StorageClass::Input;
  1741. break;
  1742. default:
  1743. llvm_unreachable("Found invalid SigPoint kind for semantic");
  1744. }
  1745. break;
  1746. }
  1747. case hlsl::DXIL::SignatureKind::PatchConstant: {
  1748. // There are some special cases in HLSL (See docs/dxil.rst):
  1749. // SignatureKind is "PatchConstant" for PCOut and DSIn.
  1750. switch (sigPointKind) {
  1751. case hlsl::DXIL::SigPointKind::PCOut:
  1752. // Patch Constant Output (Output of Hull which is passed to Domain).
  1753. sc = spv::StorageClass::Output;
  1754. break;
  1755. case hlsl::DXIL::SigPointKind::DSIn:
  1756. // Domain Shader regular input - Patch Constant data plus system values.
  1757. sc = spv::StorageClass::Input;
  1758. break;
  1759. default:
  1760. llvm_unreachable("Found invalid SigPoint kind for semantic");
  1761. }
  1762. break;
  1763. }
  1764. default:
  1765. llvm_unreachable("Found invalid SigPoint kind for semantic");
  1766. }
  1767. return sc;
  1768. }
  1769. uint32_t DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
  1770. const DeclaratorDecl *decl, bool *shouldBeAlias, SpirvEvalInfo *info) {
  1771. if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
  1772. // This method is only intended to be used to create SPIR-V variables in the
  1773. // Function or Private storage class.
  1774. assert(!varDecl->isExternallyVisible() || varDecl->isStaticDataMember());
  1775. }
  1776. const QualType type = getTypeOrFnRetType(decl);
  1777. // Whether we should generate this decl as an alias variable.
  1778. bool genAlias = false;
  1779. if (const auto *buffer = dyn_cast<HLSLBufferDecl>(decl->getDeclContext())) {
  1780. // For ConstantBuffer and TextureBuffer
  1781. if (buffer->isConstantBufferView())
  1782. genAlias = true;
  1783. } else if (TypeTranslator::isOrContainsAKindOfStructuredOrByteBuffer(type)) {
  1784. genAlias = true;
  1785. }
  1786. if (shouldBeAlias)
  1787. *shouldBeAlias = genAlias;
  1788. if (genAlias) {
  1789. needsLegalization = true;
  1790. createCounterVarForDecl(decl);
  1791. if (info)
  1792. info->setContainsAliasComponent(true);
  1793. }
  1794. return typeTranslator.translateType(type);
  1795. }
  1796. } // end namespace spirv
  1797. } // end namespace clang