|
@@ -11,7 +11,9 @@
|
|
|
|
|
|
#include "dxc/HLSL/DxilConstants.h"
|
|
|
#include "dxc/HLSL/DxilTypeSystem.h"
|
|
|
+#include "clang/AST/Expr.h"
|
|
|
#include "clang/AST/HlslTypes.h"
|
|
|
+#include "clang/AST/RecursiveASTVisitor.h"
|
|
|
|
|
|
namespace clang {
|
|
|
namespace spirv {
|
|
@@ -29,31 +31,42 @@ bool DeclResultIdMapper::createStageVarFromFnParam(
|
|
|
|
|
|
void DeclResultIdMapper::registerDeclResultId(const NamedDecl *symbol,
|
|
|
uint32_t resultId) {
|
|
|
- normalDecls[symbol] = resultId;
|
|
|
+ auto sc = spv::StorageClass::Function;
|
|
|
+ // TODO: need to fix the storage class for other cases
|
|
|
+ if (const auto *varDecl = dyn_cast<VarDecl>(symbol)) {
|
|
|
+ if (!varDecl->isLocalVarDecl()) {
|
|
|
+ // Global variables are by default constant. But the default behavior
|
|
|
+ // can be changed via command line option.
|
|
|
+ sc = spv::StorageClass::Uniform;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ normalDecls[symbol] = {resultId, sc};
|
|
|
}
|
|
|
|
|
|
-bool DeclResultIdMapper::isStageVariable(uint32_t varId) const {
|
|
|
- for (const auto &var : stageVars)
|
|
|
- if (var.getSpirvId() == varId)
|
|
|
- return true;
|
|
|
- return false;
|
|
|
+const DeclResultIdMapper::DeclSpirvInfo *
|
|
|
+DeclResultIdMapper::getDeclSpirvInfo(const NamedDecl *decl) const {
|
|
|
+ auto it = remappedDecls.find(decl);
|
|
|
+ if (it != remappedDecls.end())
|
|
|
+ return &it->second;
|
|
|
+
|
|
|
+ it = normalDecls.find(decl);
|
|
|
+ if (it != normalDecls.end())
|
|
|
+ return &it->second;
|
|
|
+
|
|
|
+ return nullptr;
|
|
|
}
|
|
|
|
|
|
uint32_t DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) const {
|
|
|
- if (const uint32_t id = getNormalDeclResultId(decl))
|
|
|
- return id;
|
|
|
- if (const uint32_t id = getRemappedDeclResultId(decl))
|
|
|
- return id;
|
|
|
+ if (const auto *info = getDeclSpirvInfo(decl))
|
|
|
+ return info->resultId;
|
|
|
|
|
|
assert(false && "found unregistered decl");
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
uint32_t DeclResultIdMapper::getOrRegisterDeclResultId(const NamedDecl *decl) {
|
|
|
- if (const uint32_t id = getNormalDeclResultId(decl))
|
|
|
- return id;
|
|
|
- if (const uint32_t id = getRemappedDeclResultId(decl))
|
|
|
- return id;
|
|
|
+ if (const auto *info = getDeclSpirvInfo(decl))
|
|
|
+ return info->resultId;
|
|
|
|
|
|
const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
|
|
|
registerDeclResultId(decl, id);
|
|
@@ -65,16 +78,67 @@ uint32_t
|
|
|
DeclResultIdMapper::getRemappedDeclResultId(const NamedDecl *decl) const {
|
|
|
auto it = remappedDecls.find(decl);
|
|
|
if (it != remappedDecls.end())
|
|
|
- return it->second;
|
|
|
+ return it->second.resultId;
|
|
|
return 0;
|
|
|
}
|
|
|
|
|
|
-uint32_t
|
|
|
-DeclResultIdMapper::getNormalDeclResultId(const NamedDecl *decl) const {
|
|
|
- auto it = normalDecls.find(decl);
|
|
|
- if (it != normalDecls.end())
|
|
|
- return it->second;
|
|
|
- return 0;
|
|
|
+namespace {
|
|
|
+/// A class for resolving the storage class of a given Decl or Expr.
|
|
|
+class StorageClassResolver : public RecursiveASTVisitor<StorageClassResolver> {
|
|
|
+public:
|
|
|
+ explicit StorageClassResolver(const DeclResultIdMapper &mapper)
|
|
|
+ : declIdMapper(mapper), storageClass(spv::StorageClass::Max) {}
|
|
|
+
|
|
|
+ // For querying the storage class of a remapped decl
|
|
|
+
|
|
|
+ // Semantics may be attached to FunctionDecl, ParmVarDecl, and FieldDecl.
|
|
|
+ // We create stage variables for them and we may need to query the storage
|
|
|
+ // classes of these stage variables.
|
|
|
+ bool VisitFunctionDecl(FunctionDecl *decl) { return processDecl(decl); }
|
|
|
+ bool VisitFieldDecl(FieldDecl *decl) { return processDecl(decl); }
|
|
|
+ bool VisitParmVarDecl(ParmVarDecl *decl) { return processDecl(decl); }
|
|
|
+
|
|
|
+ // For querying the storage class of a normal decl
|
|
|
+
|
|
|
+ // Normal decls should be referred in expressions.
|
|
|
+ bool VisitDeclRefExpr(DeclRefExpr *expr) {
|
|
|
+ return processDecl(expr->getDecl());
|
|
|
+ }
|
|
|
+
|
|
|
+ bool processDecl(NamedDecl *decl) {
|
|
|
+ const auto *info = declIdMapper.getDeclSpirvInfo(decl);
|
|
|
+ assert(info);
|
|
|
+ if (storageClass == spv::StorageClass::Max) {
|
|
|
+ storageClass = info->storageClass;
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Two decls with different storage classes are referenced in this
|
|
|
+ // expression. We should not visit such expression using this class.
|
|
|
+ assert(storageClass == info->storageClass);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ spv::StorageClass get() const { return storageClass; }
|
|
|
+
|
|
|
+private:
|
|
|
+ const DeclResultIdMapper &declIdMapper;
|
|
|
+ spv::StorageClass storageClass;
|
|
|
+};
|
|
|
+} // namespace
|
|
|
+
|
|
|
+spv::StorageClass
|
|
|
+DeclResultIdMapper::resolveStorageClass(const Expr *expr) const {
|
|
|
+ auto resolver = StorageClassResolver(*this);
|
|
|
+ resolver.TraverseStmt(const_cast<Expr *>(expr));
|
|
|
+ return resolver.get();
|
|
|
+}
|
|
|
+
|
|
|
+spv::StorageClass
|
|
|
+DeclResultIdMapper::resolveStorageClass(const Decl *decl) const {
|
|
|
+ auto resolver = StorageClassResolver(*this);
|
|
|
+ resolver.TraverseDecl(const_cast<Decl *>(decl));
|
|
|
+ return resolver.get();
|
|
|
}
|
|
|
|
|
|
std::vector<uint32_t> DeclResultIdMapper::collectStageVariables() const {
|
|
@@ -153,7 +217,7 @@ bool DeclResultIdMapper::createStageVariables(const DeclaratorDecl *decl,
|
|
|
stageVar.setSpirvId(varId);
|
|
|
|
|
|
stageVars.push_back(stageVar);
|
|
|
- remappedDecls[decl] = varId;
|
|
|
+ remappedDecls[decl] = {varId, stageVar.getStorageClass()};
|
|
|
} else {
|
|
|
// If the decl itself doesn't have semantic, it should be a struct having
|
|
|
// all its fields with semantics.
|
|
@@ -173,6 +237,8 @@ bool DeclResultIdMapper::createStageVariables(const DeclaratorDecl *decl,
|
|
|
}
|
|
|
|
|
|
uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
+ using spv::BuiltIn;
|
|
|
+
|
|
|
const auto semanticKind = stageVar->getSemantic()->GetKind();
|
|
|
const auto sigPointKind = stageVar->getSigPoint()->GetKind();
|
|
|
const uint32_t type = stageVar->getSpirvTypeId();
|
|
@@ -181,6 +247,12 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
// shader model is already checked, so it only covers valid SigPoints for
|
|
|
// each semantic.
|
|
|
|
|
|
+ // TODO: case for patch constant
|
|
|
+ const auto sc = stageVar->getSigPoint()->IsInput()
|
|
|
+ ? spv::StorageClass::Input
|
|
|
+ : spv::StorageClass::Output;
|
|
|
+ stageVar->setStorageClass(sc);
|
|
|
+
|
|
|
switch (semanticKind) {
|
|
|
// According to DXIL spec, the Position SV can be used by all SigPoints
|
|
|
// other than PCIn, HSIn, GSIn, PSOut, CSIn.
|
|
@@ -189,13 +261,13 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
case hlsl::Semantic::Kind::Position: {
|
|
|
switch (sigPointKind) {
|
|
|
case hlsl::SigPoint::Kind::VSIn:
|
|
|
- return theBuilder.addStageIOVariable(type, spv::StorageClass::Input);
|
|
|
+ return theBuilder.addStageIOVariable(type, sc);
|
|
|
case hlsl::SigPoint::Kind::VSOut:
|
|
|
stageVar->setIsSpirvBuiltin();
|
|
|
- return theBuilder.addStageBuiltinVariable(type, spv::BuiltIn::Position);
|
|
|
+ return theBuilder.addStageBuiltinVariable(type, sc, BuiltIn::Position);
|
|
|
case hlsl::SigPoint::Kind::PSIn:
|
|
|
stageVar->setIsSpirvBuiltin();
|
|
|
- return theBuilder.addStageBuiltinVariable(type, spv::BuiltIn::FragCoord);
|
|
|
+ return theBuilder.addStageBuiltinVariable(type, sc, BuiltIn::FragCoord);
|
|
|
default:
|
|
|
emitError("semantic Position for SigPoint %0 unimplemented yet")
|
|
|
<< stageVar->getSigPoint()->GetName();
|
|
@@ -205,7 +277,7 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
// According to DXIL spec, the VertexID SV can only be used by VSIn.
|
|
|
case hlsl::Semantic::Kind::VertexID:
|
|
|
stageVar->setIsSpirvBuiltin();
|
|
|
- return theBuilder.addStageBuiltinVariable(type, spv::BuiltIn::VertexIndex);
|
|
|
+ return theBuilder.addStageBuiltinVariable(type, sc, BuiltIn::VertexIndex);
|
|
|
// According to DXIL spec, the InstanceID SV can be used by VSIn, VSOut,
|
|
|
// HSCPIn, HSCPOut, DSCPIn, DSOut, GSVIn, GSOut, PSIn.
|
|
|
// According to Vulkan spec, the InstanceIndex can only be used by VSIn.
|
|
@@ -213,12 +285,12 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
switch (sigPointKind) {
|
|
|
case hlsl::SigPoint::Kind::VSIn:
|
|
|
stageVar->setIsSpirvBuiltin();
|
|
|
- return theBuilder.addStageBuiltinVariable(type,
|
|
|
- spv::BuiltIn::InstanceIndex);
|
|
|
+ return theBuilder.addStageBuiltinVariable(type, sc,
|
|
|
+ BuiltIn::InstanceIndex);
|
|
|
case hlsl::SigPoint::Kind::VSOut:
|
|
|
- return theBuilder.addStageIOVariable(type, spv::StorageClass::Output);
|
|
|
+ return theBuilder.addStageIOVariable(type, sc);
|
|
|
case hlsl::SigPoint::Kind::PSIn:
|
|
|
- return theBuilder.addStageIOVariable(type, spv::StorageClass::Input);
|
|
|
+ return theBuilder.addStageIOVariable(type, sc);
|
|
|
default:
|
|
|
emitError("semantic InstanceID for SigPoint %0 unimplemented yet")
|
|
|
<< stageVar->getSigPoint()->GetName();
|
|
@@ -228,7 +300,7 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
// According to DXIL spec, the Depth SV can only be used by PSOut.
|
|
|
case hlsl::Semantic::Kind::Depth:
|
|
|
stageVar->setIsSpirvBuiltin();
|
|
|
- return theBuilder.addStageBuiltinVariable(type, spv::BuiltIn::FragDepth);
|
|
|
+ return theBuilder.addStageBuiltinVariable(type, sc, BuiltIn::FragDepth);
|
|
|
// According to DXIL spec, the Target SV can only be used by PSOut.
|
|
|
// There is no corresponding builtin decoration in SPIR-V. So generate normal
|
|
|
// Vulkan stage input/output variables.
|
|
@@ -236,10 +308,7 @@ uint32_t DeclResultIdMapper::createSpirvStageVar(StageVar *stageVar) {
|
|
|
// An arbitrary semantic is defined by users. Generate normal Vulkan stage
|
|
|
// input/output variables.
|
|
|
case hlsl::Semantic::Kind::Arbitrary: {
|
|
|
- if (stageVar->getSigPoint()->IsInput())
|
|
|
- return theBuilder.addStageIOVariable(type, spv::StorageClass::Input);
|
|
|
- if (stageVar->getSigPoint()->IsOutput())
|
|
|
- return theBuilder.addStageIOVariable(type, spv::StorageClass::Output);
|
|
|
+ return theBuilder.addStageIOVariable(type, sc);
|
|
|
// TODO: patch constant function in hull shader
|
|
|
}
|
|
|
default:
|