|
@@ -761,7 +761,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
|
|
const FunctionInfo *entryInfo = workQueue[i];
|
|
const FunctionInfo *entryInfo = workQueue[i];
|
|
assert(entryInfo->isEntryFunction);
|
|
assert(entryInfo->isEntryFunction);
|
|
spvBuilder.addEntryPoint(
|
|
spvBuilder.addEntryPoint(
|
|
- getSpirvShaderStage(entryInfo->shaderModelKind),
|
|
|
|
|
|
+ getSpirvShaderStage(entryInfo->shaderModelKind, featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)),
|
|
entryInfo->entryFunction, getEntryPointName(entryInfo),
|
|
entryInfo->entryFunction, getEntryPointName(entryInfo),
|
|
getInterfacesForEntryPoint(entryInfo->entryFunction));
|
|
getInterfacesForEntryPoint(entryInfo->entryFunction));
|
|
}
|
|
}
|
|
@@ -7330,6 +7330,9 @@ void SpirvEmitter::assignToMSOutIndices(
|
|
const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
|
|
const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
|
|
assert(spvContext.isMS() && !indices.empty());
|
|
assert(spvContext.isMS() && !indices.empty());
|
|
|
|
|
|
|
|
+ bool extMesh =
|
|
|
|
+ featureManager.isExtensionEnabled(Extension::EXT_mesh_shader);
|
|
|
|
+
|
|
// Extract vertex index and vecComponent (if any).
|
|
// Extract vertex index and vecComponent (if any).
|
|
SpirvInstruction *vertIndex = indices.front();
|
|
SpirvInstruction *vertIndex = indices.front();
|
|
SpirvInstruction *vecComponent = nullptr;
|
|
SpirvInstruction *vecComponent = nullptr;
|
|
@@ -7361,45 +7364,65 @@ void SpirvEmitter::assignToMSOutIndices(
|
|
} else {
|
|
} else {
|
|
// for "line" or "triangle" output topology.
|
|
// for "line" or "triangle" output topology.
|
|
assert(numVertices == 2 || numVertices == 3);
|
|
assert(numVertices == 2 || numVertices == 3);
|
|
- // set baseOffset = vertIndex * numVertices.
|
|
|
|
- auto *baseOffset = spvBuilder.createBinaryOp(
|
|
|
|
- spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
|
|
|
|
- spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
- llvm::APInt(32, numVertices)),
|
|
|
|
- loc);
|
|
|
|
|
|
+
|
|
if (vecComponent) {
|
|
if (vecComponent) {
|
|
// write an individual vector component of uint2 or uint3.
|
|
// write an individual vector component of uint2 or uint3.
|
|
assert(numValues == 1);
|
|
assert(numValues == 1);
|
|
- // set baseOffset = baseOffset + vecComponent.
|
|
|
|
- baseOffset =
|
|
|
|
- spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
|
|
|
|
- baseOffset, vecComponent, loc);
|
|
|
|
- // create accesschain for PrimitiveIndicesNV[baseOffset].
|
|
|
|
- auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
|
|
|
|
- {baseOffset}, loc);
|
|
|
|
- // finally create store for PrimitiveIndicesNV[baseOffset] = value.
|
|
|
|
- spvBuilder.createStore(ptr, value, loc);
|
|
|
|
|
|
+ if (extMesh) {
|
|
|
|
+ // create accesschain for Primitive*IndicesEXT[vertIndex][vecComponent].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(
|
|
|
|
+ astContext.UnsignedIntTy, var, {vertIndex, vecComponent}, loc);
|
|
|
|
+ // finally create store for Primitive*IndicesEXT[vertIndex][vecComponent] = value.
|
|
|
|
+ spvBuilder.createStore(ptr, value, loc);
|
|
|
|
+ } else {
|
|
|
|
+ // set baseOffset = vertIndex * numVertices.
|
|
|
|
+ auto *baseOffset = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
|
|
|
|
+ spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
+ llvm::APInt(32, numVertices)), loc);
|
|
|
|
+ // set baseOffset = baseOffset + vecComponent.
|
|
|
|
+ baseOffset =
|
|
|
|
+ spvBuilder.createBinaryOp(spv::Op::OpIAdd, astContext.UnsignedIntTy,
|
|
|
|
+ baseOffset, vecComponent, loc);
|
|
|
|
+ // create accesschain for PrimitiveIndicesNV[baseOffset].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
|
|
|
|
+ {baseOffset}, loc);
|
|
|
|
+ // finally create store for PrimitiveIndicesNV[baseOffset] = value.
|
|
|
|
+ spvBuilder.createStore(ptr, value, loc);
|
|
|
|
+ }
|
|
} else {
|
|
} else {
|
|
- // write all vector components of uint2 or uint3.
|
|
|
|
assert(numValues == numVertices);
|
|
assert(numValues == numVertices);
|
|
- auto *curOffset = baseOffset;
|
|
|
|
- for (uint32_t i = 0; i < numValues; ++i) {
|
|
|
|
- if (i != 0) {
|
|
|
|
- // set curOffset = baseOffset + i.
|
|
|
|
- curOffset = spvBuilder.createBinaryOp(
|
|
|
|
- spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
|
|
|
|
- spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
- llvm::APInt(32, i)),
|
|
|
|
- loc);
|
|
|
|
|
|
+ if (extMesh) {
|
|
|
|
+ // create accesschain for Primitive*IndicesEXT[vertIndex].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(varType, var, vertIndex, loc);
|
|
|
|
+ // finally create store for Primitive*IndicesEXT[vertIndex] = value.
|
|
|
|
+ spvBuilder.createStore(ptr, value, loc);
|
|
|
|
+ } else {
|
|
|
|
+ // set baseOffset = vertIndex * numVertices.
|
|
|
|
+ auto *baseOffset = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
|
|
|
|
+ spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
+ llvm::APInt(32, numVertices)), loc);
|
|
|
|
+ // write all vector components of uint2 or uint3.
|
|
|
|
+ auto *curOffset = baseOffset;
|
|
|
|
+ for (uint32_t i = 0; i < numValues; ++i) {
|
|
|
|
+ if (i != 0) {
|
|
|
|
+ // set curOffset = baseOffset + i.
|
|
|
|
+ curOffset = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
|
|
|
|
+ spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
+ llvm::APInt(32, i)),
|
|
|
|
+ loc);
|
|
|
|
+ }
|
|
|
|
+ // create accesschain for PrimitiveIndicesNV[curOffset].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy,
|
|
|
|
+ var, {curOffset}, loc);
|
|
|
|
+ // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
|
|
|
|
+ spvBuilder.createStore(ptr,
|
|
|
|
+ spvBuilder.createCompositeExtract(
|
|
|
|
+ astContext.UnsignedIntTy, value, {i}, loc),
|
|
|
|
+ loc);
|
|
}
|
|
}
|
|
- // create accesschain for PrimitiveIndicesNV[curOffset].
|
|
|
|
- auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
|
|
|
|
- {curOffset}, loc);
|
|
|
|
- // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
|
|
|
|
- spvBuilder.createStore(ptr,
|
|
|
|
- spvBuilder.createCompositeExtract(
|
|
|
|
- astContext.UnsignedIntTy, value, {i}, loc),
|
|
|
|
- loc);
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -11319,26 +11342,16 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
|
|
/*isDevice*/ false,
|
|
/*isDevice*/ false,
|
|
/*groupSync*/ true,
|
|
/*groupSync*/ true,
|
|
/*isAllBarrier*/ false);
|
|
/*isAllBarrier*/ false);
|
|
-
|
|
|
|
- // 2) set TaskCountNV = threadX * threadY * threadZ.
|
|
|
|
- auto *threadX = doExpr(args[0]);
|
|
|
|
- auto *threadY = doExpr(args[1]);
|
|
|
|
- auto *threadZ = doExpr(args[2]);
|
|
|
|
- auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
|
|
|
|
- astContext.UnsignedIntTy, loc);
|
|
|
|
- auto *taskCount = spvBuilder.createBinaryOp(
|
|
|
|
- spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
|
|
|
|
- spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
|
|
|
|
- threadY, threadZ, loc, range),
|
|
|
|
- loc, range);
|
|
|
|
- spvBuilder.createStore(var, taskCount, loc, range);
|
|
|
|
-
|
|
|
|
- // 3) create PerTaskNV out attribute block and store MeshPayload info.
|
|
|
|
|
|
+
|
|
|
|
+ // 2) create PerTaskNV out attribute block and store MeshPayload info.
|
|
const auto *sigPoint =
|
|
const auto *sigPoint =
|
|
hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
|
|
hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
|
|
- spv::StorageClass sc = spv::StorageClass::Output;
|
|
|
|
|
|
+ spv::StorageClass sc = featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)
|
|
|
|
+ ? spv::StorageClass::TaskPayloadWorkgroupEXT
|
|
|
|
+ : spv::StorageClass::Output;
|
|
auto *payloadArg = doExpr(args[3]);
|
|
auto *payloadArg = doExpr(args[3]);
|
|
bool isValid = false;
|
|
bool isValid = false;
|
|
|
|
+ const VarDecl *param = nullptr;
|
|
if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
|
|
if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
|
|
if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
|
|
if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
|
|
if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
|
|
if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
|
|
@@ -11346,6 +11359,7 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
|
|
isValid = declIdMapper.createPayloadStageVars(
|
|
isValid = declIdMapper.createPayloadStageVars(
|
|
sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
|
|
sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
|
|
"out.var", &payloadArg);
|
|
"out.var", &payloadArg);
|
|
|
|
+ param = paramDecl;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
@@ -11354,6 +11368,26 @@ void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
|
|
emitError("expected groupshared object as argument to DispatchMesh()",
|
|
emitError("expected groupshared object as argument to DispatchMesh()",
|
|
args[3]->getExprLoc());
|
|
args[3]->getExprLoc());
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ // 3) set up emit dimension.
|
|
|
|
+ auto *threadX = doExpr(args[0]);
|
|
|
|
+ auto *threadY = doExpr(args[1]);
|
|
|
|
+ auto *threadZ = doExpr(args[2]);
|
|
|
|
+
|
|
|
|
+ if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
|
|
|
|
+ // for EXT_mesh_shader, create opEmitMeshTasksEXT.
|
|
|
|
+ spvBuilder.createEmitMeshTasksEXT(threadX, threadY, threadZ, loc, nullptr, range);
|
|
|
|
+ } else {
|
|
|
|
+ // for NV_mesh_shader, set TaskCountNV = threadX * threadY * threadZ.
|
|
|
|
+ auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
|
|
|
|
+ astContext.UnsignedIntTy, loc);
|
|
|
|
+ auto *taskCount = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
|
|
|
|
+ spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
|
|
|
|
+ threadY, threadZ, loc, range),
|
|
|
|
+ loc, range);
|
|
|
|
+ spvBuilder.createStore(var, taskCount, loc, range);
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
|
|
void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
|
|
@@ -11362,9 +11396,14 @@ void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
|
|
const auto args = callExpr->getArgs();
|
|
const auto args = callExpr->getArgs();
|
|
const auto loc = callExpr->getExprLoc();
|
|
const auto loc = callExpr->getExprLoc();
|
|
const auto range = callExpr->getSourceRange();
|
|
const auto range = callExpr->getSourceRange();
|
|
- auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
|
|
|
|
|
|
+
|
|
|
|
+ if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
|
|
|
|
+ spvBuilder.createSetMeshOutputsEXT(doExpr(args[0]), doExpr(args[1]), loc, range);
|
|
|
|
+ } else {
|
|
|
|
+ auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
|
|
astContext.UnsignedIntTy, loc);
|
|
astContext.UnsignedIntTy, loc);
|
|
- spvBuilder.createStore(var, doExpr(args[1]), loc, range);
|
|
|
|
|
|
+ spvBuilder.createStore(var, doExpr(args[1]), loc, range);
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
|
|
SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
|
|
@@ -11687,7 +11726,7 @@ hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
|
|
}
|
|
}
|
|
|
|
|
|
spv::ExecutionModel
|
|
spv::ExecutionModel
|
|
-SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
|
|
|
|
|
|
+SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk, bool extMeshShading) {
|
|
switch (smk) {
|
|
switch (smk) {
|
|
case hlsl::ShaderModel::Kind::Vertex:
|
|
case hlsl::ShaderModel::Kind::Vertex:
|
|
return spv::ExecutionModel::Vertex;
|
|
return spv::ExecutionModel::Vertex;
|
|
@@ -11714,9 +11753,13 @@ SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
|
|
case hlsl::ShaderModel::Kind::Callable:
|
|
case hlsl::ShaderModel::Kind::Callable:
|
|
return spv::ExecutionModel::CallableNV;
|
|
return spv::ExecutionModel::CallableNV;
|
|
case hlsl::ShaderModel::Kind::Mesh:
|
|
case hlsl::ShaderModel::Kind::Mesh:
|
|
- return spv::ExecutionModel::MeshNV;
|
|
|
|
|
|
+ return extMeshShading ?
|
|
|
|
+ spv::ExecutionModel::MeshEXT:
|
|
|
|
+ spv::ExecutionModel::MeshNV;
|
|
case hlsl::ShaderModel::Kind::Amplification:
|
|
case hlsl::ShaderModel::Kind::Amplification:
|
|
- return spv::ExecutionModel::TaskNV;
|
|
|
|
|
|
+ return extMeshShading ?
|
|
|
|
+ spv::ExecutionModel::TaskEXT:
|
|
|
|
+ spv::ExecutionModel::TaskNV;
|
|
default:
|
|
default:
|
|
llvm_unreachable("invalid shader model kind");
|
|
llvm_unreachable("invalid shader model kind");
|
|
break;
|
|
break;
|