|
@@ -511,18 +511,22 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer;
|
|
|
|
+ spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::FxcSBuffer;
|
|
} else if (spirvOptions.useGlLayout) {
|
|
} else if (spirvOptions.useGlLayout) {
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140;
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430;
|
|
|
|
+ spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::GLSLStd430;
|
|
} else if (spirvOptions.useScalarLayout) {
|
|
} else if (spirvOptions.useScalarLayout) {
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::Scalar;
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::Scalar;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::Scalar;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::Scalar;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::Scalar;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::Scalar;
|
|
|
|
+ spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::Scalar;
|
|
} else {
|
|
} else {
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140;
|
|
spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
|
|
spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
|
|
spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
|
|
|
|
+ spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430;
|
|
}
|
|
}
|
|
|
|
|
|
// Set shader module version, source file name, and source file content (if
|
|
// Set shader module version, source file name, and source file content (if
|
|
@@ -4989,6 +4993,11 @@ SpirvEmitter::processAssignment(const Expr *lhs, SpirvInstruction *rhs,
|
|
if (SpirvInstruction *result = tryToAssignToRWBufferRWTexture(lhs, rhs))
|
|
if (SpirvInstruction *result = tryToAssignToRWBufferRWTexture(lhs, rhs))
|
|
return result;
|
|
return result;
|
|
|
|
|
|
|
|
+ // Assigning to a out attribute or indices object in mesh shader should be
|
|
|
|
+ // handled differently.
|
|
|
|
+ if (SpirvInstruction *result = tryToAssignToMSOutAttrsOrIndices(lhs, rhs))
|
|
|
|
+ return result;
|
|
|
|
+
|
|
// Normal assignment procedure
|
|
// Normal assignment procedure
|
|
|
|
|
|
if (!lhsPtr)
|
|
if (!lhsPtr)
|
|
@@ -5833,6 +5842,11 @@ SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
|
|
(void)result;
|
|
(void)result;
|
|
return rhs; // TODO: incorrect for compound assignments
|
|
return rhs; // TODO: incorrect for compound assignments
|
|
} else {
|
|
} else {
|
|
|
|
+ // Assigning to one component of mesh out attribute/indices vector object.
|
|
|
|
+ SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
|
|
|
|
+ astContext.UnsignedIntTy, llvm::APInt(32, accessor.Swz0));
|
|
|
|
+ if (tryToAssignToMSOutAttrsOrIndices(base, rhs, vecComponent))
|
|
|
|
+ return rhs;
|
|
// Assigning to one normal vector component. Nothing special, just fall
|
|
// Assigning to one normal vector component. Nothing special, just fall
|
|
// back to the normal CodeGen path.
|
|
// back to the normal CodeGen path.
|
|
return nullptr;
|
|
return nullptr;
|
|
@@ -5854,6 +5868,26 @@ SpirvEmitter::tryToAssignToVectorElements(const Expr *lhs,
|
|
return processAssignment(base, rhs, false);
|
|
return processAssignment(base, rhs, false);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if (tryToAssignToMSOutAttrsOrIndices(base, rhs, /*vecComponent=*/nullptr,
|
|
|
|
+ /*noWriteBack=*/true)) {
|
|
|
|
+ // Assigning to 'n' components of mesh out attribute/indices vector object.
|
|
|
|
+ const QualType elemType =
|
|
|
|
+ hlsl::GetHLSLVecElementType(rhs->getAstResultType());
|
|
|
|
+ uint32_t i = 0;
|
|
|
|
+ for (; i < accessor.Count; ++i) {
|
|
|
|
+ auto *rhsElem = spvBuilder.createCompositeExtract(elemType, rhs, {i},
|
|
|
|
+ lhs->getLocStart());
|
|
|
|
+ uint32_t position;
|
|
|
|
+ accessor.GetPosition(i, &position);
|
|
|
|
+ SpirvInstruction *vecComponent = spvBuilder.getConstantInt(
|
|
|
|
+ astContext.UnsignedIntTy, llvm::APInt(32, position));
|
|
|
|
+ if (!tryToAssignToMSOutAttrsOrIndices(base, rhsElem, vecComponent))
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ assert(i == accessor.Count);
|
|
|
|
+ return rhs;
|
|
|
|
+ }
|
|
|
|
+
|
|
llvm::SmallVector<uint32_t, 4> selectors;
|
|
llvm::SmallVector<uint32_t, 4> selectors;
|
|
selectors.resize(baseSize);
|
|
selectors.resize(baseSize);
|
|
// Assume we are selecting all original elements first.
|
|
// Assume we are selecting all original elements first.
|
|
@@ -5969,6 +6003,191 @@ SpirvEmitter::tryToAssignToMatrixElements(const Expr *lhs,
|
|
return rhs;
|
|
return rhs;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+SpirvInstruction *SpirvEmitter::tryToAssignToMSOutAttrsOrIndices(
|
|
|
|
+ const Expr *lhs, SpirvInstruction *rhs, SpirvInstruction *vecComponent,
|
|
|
|
+ bool noWriteBack) {
|
|
|
|
+ // Early exit for non-mesh shaders.
|
|
|
|
+ if (!spvContext.isMS())
|
|
|
|
+ return nullptr;
|
|
|
|
+
|
|
|
|
+ llvm::SmallVector<SpirvInstruction *, 4> indices;
|
|
|
|
+ bool isMSOutAttribute = false;
|
|
|
|
+ bool isMSOutAttributeBlock = false;
|
|
|
|
+ bool isMSOutIndices = false;
|
|
|
|
+
|
|
|
|
+ const Expr *base = collectArrayStructIndices(lhs, /*rawIndex*/ false,
|
|
|
|
+ /*rawIndices*/ nullptr, &indices,
|
|
|
|
+ &isMSOutAttribute);
|
|
|
|
+ // Expecting at least one array index - early exit.
|
|
|
|
+ if (!base || indices.empty())
|
|
|
|
+ return nullptr;
|
|
|
|
+
|
|
|
|
+ const DeclaratorDecl *varDecl = nullptr;
|
|
|
|
+ if (isMSOutAttribute) {
|
|
|
|
+ const MemberExpr *memberExpr = dyn_cast<MemberExpr>(base);
|
|
|
|
+ assert(memberExpr);
|
|
|
|
+ varDecl = cast<DeclaratorDecl>(memberExpr->getMemberDecl());
|
|
|
|
+ } else {
|
|
|
|
+ if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
|
|
|
|
+ if (varDecl = dyn_cast<DeclaratorDecl>(arg->getDecl())) {
|
|
|
|
+ if (varDecl->hasAttr<HLSLIndicesAttr>()) {
|
|
|
|
+ isMSOutIndices = true;
|
|
|
|
+ } else if (varDecl->hasAttr<HLSLVerticesAttr>() ||
|
|
|
|
+ varDecl->hasAttr<HLSLPrimitivesAttr>()) {
|
|
|
|
+ isMSOutAttributeBlock = true;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Return if no out attribute or indices object found.
|
|
|
|
+ if (!(isMSOutAttribute || isMSOutAttributeBlock || isMSOutIndices)) {
|
|
|
|
+ return nullptr;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // For noWriteBack, return without generating write instructions.
|
|
|
|
+ if (noWriteBack) {
|
|
|
|
+ return rhs;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Add vecComponent to indices.
|
|
|
|
+ if (vecComponent) {
|
|
|
|
+ indices.push_back(vecComponent);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (isMSOutAttribute) {
|
|
|
|
+ assignToMSOutAttribute(varDecl, rhs, indices);
|
|
|
|
+ } else if (isMSOutIndices) {
|
|
|
|
+ assignToMSOutIndices(varDecl, rhs, indices);
|
|
|
|
+ } else {
|
|
|
|
+ assert(isMSOutAttributeBlock);
|
|
|
|
+ QualType type = varDecl->getType();
|
|
|
|
+ assert(isa<ConstantArrayType>(type));
|
|
|
|
+ type = astContext.getAsConstantArrayType(type)->getElementType();
|
|
|
|
+ assert(type->isStructureType());
|
|
|
|
+
|
|
|
|
+ // Extract subvalue and assign to its corresponding member attribute.
|
|
|
|
+ const auto *structDecl = type->getAs<RecordType>()->getDecl();
|
|
|
|
+ for (const auto *field : structDecl->fields()) {
|
|
|
|
+ const auto fieldType = field->getType();
|
|
|
|
+ SpirvInstruction *subValue = spvBuilder.createCompositeExtract(
|
|
|
|
+ fieldType, rhs, {getNumBaseClasses(type) + field->getFieldIndex()},
|
|
|
|
+ lhs->getLocStart());
|
|
|
|
+ assignToMSOutAttribute(field, subValue, indices);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // TODO: OK, this return value is incorrect for compound assignments, for
|
|
|
|
+ // which cases we should return lvalues. Should at least emit errors if
|
|
|
|
+ // this return value is used (can be checked via ASTContext.getParents).
|
|
|
|
+ return rhs;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+void SpirvEmitter::assignToMSOutAttribute(
|
|
|
|
+ const DeclaratorDecl *decl, SpirvInstruction *value,
|
|
|
|
+ const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
|
|
|
|
+ assert(spvContext.isMS() && !indices.empty());
|
|
|
|
+
|
|
|
|
+ // Extract attribute index and vecComponent (if any).
|
|
|
|
+ SpirvInstruction *attrIndex = indices.front();
|
|
|
|
+ SpirvInstruction *vecComponent = nullptr;
|
|
|
|
+ if (indices.size() > 1) {
|
|
|
|
+ vecComponent = indices.back();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ auto semanticInfo = declIdMapper.getStageVarSemantic(decl);
|
|
|
|
+ assert(semanticInfo.isValid());
|
|
|
|
+ const auto loc = decl->getLocation();
|
|
|
|
+ // Special handle writes to clip/cull distance attributes.
|
|
|
|
+ if (!declIdMapper.glPerVertex.tryToAccess(
|
|
|
|
+ hlsl::DXIL::SigPointKind::MSOut, semanticInfo.semantic->GetKind(),
|
|
|
|
+ semanticInfo.index, attrIndex, &value, /*noWriteBack=*/false,
|
|
|
|
+ vecComponent, loc)) {
|
|
|
|
+ // All other attribute writes are handled below.
|
|
|
|
+ auto *varInstr = declIdMapper.getStageVarInstruction(decl);
|
|
|
|
+ QualType valueType = value->getAstResultType();
|
|
|
|
+ varInstr = spvBuilder.createAccessChain(valueType, varInstr, indices, loc);
|
|
|
|
+ spvBuilder.createStore(varInstr, value, loc);
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+void SpirvEmitter::assignToMSOutIndices(
|
|
|
|
+ const DeclaratorDecl *decl, SpirvInstruction *value,
|
|
|
|
+ const llvm::SmallVector<SpirvInstruction *, 4> &indices) {
|
|
|
|
+ assert(spvContext.isMS() && !indices.empty());
|
|
|
|
+
|
|
|
|
+ // Extract vertex index and vecComponent (if any).
|
|
|
|
+ SpirvInstruction *vertIndex = indices.front();
|
|
|
|
+ SpirvInstruction *vecComponent = nullptr;
|
|
|
|
+ if (indices.size() > 1) {
|
|
|
|
+ vecComponent = indices.back();
|
|
|
|
+ }
|
|
|
|
+ auto *var = declIdMapper.getStageVarInstruction(decl);
|
|
|
|
+ const auto *varTypeDecl = astContext.getAsConstantArrayType(decl->getType());
|
|
|
|
+ QualType varType = varTypeDecl->getElementType();
|
|
|
|
+ uint32_t numVertices = 1;
|
|
|
|
+ if (!isVectorType(varType, nullptr, &numVertices)) {
|
|
|
|
+ assert(isScalarType(varType));
|
|
|
|
+ }
|
|
|
|
+ QualType valueType = value->getAstResultType();
|
|
|
|
+ uint32_t numValues = 1;
|
|
|
|
+ if (!isVectorType(valueType, nullptr, &numValues)) {
|
|
|
|
+ assert(isScalarType(valueType));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ const auto loc = decl->getLocation();
|
|
|
|
+ if (numVertices == 1) {
|
|
|
|
+ // for "point" output topology.
|
|
|
|
+ assert(numValues == 1);
|
|
|
|
+ // create accesschain for PrimitiveIndicesNV[vertIndex].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
|
|
|
|
+ vertIndex, loc);
|
|
|
|
+ // finally create store for PrimitiveIndicesNV[vertIndex] = value.
|
|
|
|
+ spvBuilder.createStore(ptr, value, loc);
|
|
|
|
+ } else {
|
|
|
|
+ // for "line" or "triangle" output topology.
|
|
|
|
+ assert(numVertices == 2 || numVertices == 3);
|
|
|
|
+ // set baseOffset = vertIndex * numVertices.
|
|
|
|
+ auto *baseOffset = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIMul, astContext.UnsignedIntTy, vertIndex,
|
|
|
|
+ spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
+ llvm::APInt(32, numVertices)));
|
|
|
|
+ if (vecComponent) {
|
|
|
|
+ // write an individual vector component of uint2 or uint3.
|
|
|
|
+ assert(numValues == 1);
|
|
|
|
+ // set baseOffset = baseOffset + vecComponent.
|
|
|
|
+ baseOffset = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset, vecComponent);
|
|
|
|
+ // create accesschain for PrimitiveIndicesNV[baseOffset].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
|
|
|
|
+ baseOffset, loc);
|
|
|
|
+ // finally create store for PrimitiveIndicesNV[baseOffset] = value.
|
|
|
|
+ spvBuilder.createStore(ptr, value, loc);
|
|
|
|
+ } else {
|
|
|
|
+ // write all vector components of uint2 or uint3.
|
|
|
|
+ assert(numValues == numVertices);
|
|
|
|
+ auto *curOffset = baseOffset;
|
|
|
|
+ for (uint32_t i = 0; i < numValues; ++i) {
|
|
|
|
+ if (i != 0) {
|
|
|
|
+ // set curOffset = baseOffset + i.
|
|
|
|
+ curOffset = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset,
|
|
|
|
+ spvBuilder.getConstantInt(astContext.UnsignedIntTy,
|
|
|
|
+ llvm::APInt(32, i)));
|
|
|
|
+ }
|
|
|
|
+ // create accesschain for PrimitiveIndicesNV[curOffset].
|
|
|
|
+ auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, var,
|
|
|
|
+ curOffset, loc);
|
|
|
|
+ // finally create store for PrimitiveIndicesNV[curOffset] = value[i].
|
|
|
|
+ spvBuilder.createStore(ptr,
|
|
|
|
+ spvBuilder.createCompositeExtract(
|
|
|
|
+ astContext.UnsignedIntTy, value, {i}, loc),
|
|
|
|
+ loc);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
|
|
SpirvInstruction *SpirvEmitter::processEachVectorInMatrix(
|
|
const Expr *matrix, SpirvInstruction *matrixVal,
|
|
const Expr *matrix, SpirvInstruction *matrixVal,
|
|
llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
|
|
llvm::function_ref<SpirvInstruction *(uint32_t, QualType,
|
|
@@ -6125,7 +6344,8 @@ SpirvEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
|
|
const Expr *SpirvEmitter::collectArrayStructIndices(
|
|
const Expr *SpirvEmitter::collectArrayStructIndices(
|
|
const Expr *expr, bool rawIndex,
|
|
const Expr *expr, bool rawIndex,
|
|
llvm::SmallVectorImpl<uint32_t> *rawIndices,
|
|
llvm::SmallVectorImpl<uint32_t> *rawIndices,
|
|
- llvm::SmallVectorImpl<SpirvInstruction *> *indices) {
|
|
|
|
|
|
+ llvm::SmallVectorImpl<SpirvInstruction *> *indices,
|
|
|
|
+ bool *isMSOutAttribute) {
|
|
assert((rawIndex && rawIndices) || (!rawIndex && indices));
|
|
assert((rawIndex && rawIndices) || (!rawIndex && indices));
|
|
|
|
|
|
if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
|
|
if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
|
|
@@ -6140,7 +6360,20 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
|
|
|
|
|
|
const Expr *base = collectArrayStructIndices(
|
|
const Expr *base = collectArrayStructIndices(
|
|
indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
|
|
indexing->getBase()->IgnoreParenNoopCasts(astContext), rawIndex,
|
|
- rawIndices, indices);
|
|
|
|
|
|
+ rawIndices, indices, isMSOutAttribute);
|
|
|
|
+
|
|
|
|
+ if (isMSOutAttribute && base) {
|
|
|
|
+ if (const auto *arg = dyn_cast<DeclRefExpr>(base)) {
|
|
|
|
+ if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
|
|
|
|
+ if (varDecl->hasAttr<HLSLVerticesAttr>() ||
|
|
|
|
+ varDecl->hasAttr<HLSLPrimitivesAttr>()) {
|
|
|
|
+ assert(spvContext.isMS());
|
|
|
|
+ *isMSOutAttribute = true;
|
|
|
|
+ return expr;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
|
|
// Append the index of the current level
|
|
// Append the index of the current level
|
|
const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
|
|
const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
|
|
@@ -6167,8 +6400,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
|
|
// The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
|
|
// The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
|
|
// cast. We need to ingore it to avoid creating OpLoad.
|
|
// cast. We need to ingore it to avoid creating OpLoad.
|
|
const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
|
|
const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
|
|
- const Expr *base =
|
|
|
|
- collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
|
|
|
|
|
|
+ const Expr *base = collectArrayStructIndices(thisBase, rawIndex, rawIndices,
|
|
|
|
+ indices, isMSOutAttribute);
|
|
indices->push_back(doExpr(indexing->getIdx()));
|
|
indices->push_back(doExpr(indexing->getIdx()));
|
|
return base;
|
|
return base;
|
|
}
|
|
}
|
|
@@ -6188,8 +6421,8 @@ const Expr *SpirvEmitter::collectArrayStructIndices(
|
|
indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
|
|
indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
|
|
|
|
|
|
const auto thisBaseType = thisBase->getType();
|
|
const auto thisBaseType = thisBase->getType();
|
|
- const Expr *base =
|
|
|
|
- collectArrayStructIndices(thisBase, rawIndex, rawIndices, indices);
|
|
|
|
|
|
+ const Expr *base = collectArrayStructIndices(
|
|
|
|
+ thisBase, rawIndex, rawIndices, indices, isMSOutAttribute);
|
|
|
|
|
|
if (thisBaseType != base->getType() &&
|
|
if (thisBaseType != base->getType() &&
|
|
isAKindOfStructuredOrByteBuffer(thisBaseType)) {
|
|
isAKindOfStructuredOrByteBuffer(thisBaseType)) {
|
|
@@ -6841,6 +7074,14 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
|
|
case hlsl::IntrinsicOp::IOP_CallShader: {
|
|
case hlsl::IntrinsicOp::IOP_CallShader: {
|
|
processCallShader(callExpr);
|
|
processCallShader(callExpr);
|
|
break;
|
|
break;
|
|
|
|
+ }
|
|
|
|
+ case hlsl::IntrinsicOp::IOP_DispatchMesh: {
|
|
|
|
+ processDispatchMesh(callExpr);
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ case hlsl::IntrinsicOp::IOP_SetMeshOutputCounts: {
|
|
|
|
+ processMeshOutputCounts(callExpr);
|
|
|
|
+ break;
|
|
}
|
|
}
|
|
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
|
|
INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
|
|
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
|
|
INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
|
|
@@ -9172,6 +9413,7 @@ SpirvInstruction *SpirvEmitter::processReportHit(const CallExpr *callExpr) {
|
|
astContext.BoolTy, reportHitArgs,
|
|
astContext.BoolTy, reportHitArgs,
|
|
callExpr->getExprLoc());
|
|
callExpr->getExprLoc());
|
|
}
|
|
}
|
|
|
|
+
|
|
void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
|
|
void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
|
|
SpirvInstruction *callDataLocInst = nullptr;
|
|
SpirvInstruction *callDataLocInst = nullptr;
|
|
SpirvInstruction *callDataStageVar = nullptr;
|
|
SpirvInstruction *callDataStageVar = nullptr;
|
|
@@ -9238,11 +9480,12 @@ void SpirvEmitter::processCallShader(const CallExpr *callExpr) {
|
|
spvBuilder.createStore(callDataArgInst, tempLoad, callExpr->getExprLoc());
|
|
spvBuilder.createStore(callDataArgInst, tempLoad, callExpr->getExprLoc());
|
|
return;
|
|
return;
|
|
}
|
|
}
|
|
|
|
+
|
|
void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
- SpirvInstruction *payloadLocInst = nullptr;
|
|
|
|
- SpirvInstruction *payloadStageVar = nullptr;
|
|
|
|
- const VarDecl *payloadArg = nullptr;
|
|
|
|
- QualType payloadType;
|
|
|
|
|
|
+ SpirvInstruction *rayPayloadLocInst = nullptr;
|
|
|
|
+ SpirvInstruction *rayPayloadStageVar = nullptr;
|
|
|
|
+ const VarDecl *rayPayloadArg = nullptr;
|
|
|
|
+ QualType rayPayloadType;
|
|
|
|
|
|
const auto args = callExpr->getArgs();
|
|
const auto args = callExpr->getArgs();
|
|
|
|
|
|
@@ -9252,7 +9495,7 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
}
|
|
}
|
|
|
|
|
|
// HLSL Func
|
|
// HLSL Func
|
|
- // template<typename Payload>
|
|
|
|
|
|
+ // template<typename RayPayload>
|
|
// void TraceRay(RaytracingAccelerationStructure rs,
|
|
// void TraceRay(RaytracingAccelerationStructure rs,
|
|
// uint rayflags,
|
|
// uint rayflags,
|
|
// uint InstanceInclusionMask
|
|
// uint InstanceInclusionMask
|
|
@@ -9260,36 +9503,36 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
// uint MultiplierForGeometryContributionToHitGroupIndex,
|
|
// uint MultiplierForGeometryContributionToHitGroupIndex,
|
|
// uint MissShaderIndex,
|
|
// uint MissShaderIndex,
|
|
// RayDesc ray,
|
|
// RayDesc ray,
|
|
- // inout Payload p)
|
|
|
|
|
|
+ // inout RayPayload p)
|
|
// where RayDesc = {float3 origin, float tMin, float3 direction, float tMax}
|
|
// where RayDesc = {float3 origin, float tMin, float3 direction, float tMax}
|
|
|
|
|
|
if (const auto *implCastExpr = dyn_cast<CastExpr>(args[7])) {
|
|
if (const auto *implCastExpr = dyn_cast<CastExpr>(args[7])) {
|
|
if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
|
|
if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
|
|
if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
|
|
if (const auto *varDecl = dyn_cast<VarDecl>(arg->getDecl())) {
|
|
- payloadType = varDecl->getType();
|
|
|
|
- payloadArg = varDecl;
|
|
|
|
- const auto payloadPair = payloadMap.find(payloadType);
|
|
|
|
- // Check if same type of payload stage variable was already
|
|
|
|
|
|
+ rayPayloadType = varDecl->getType();
|
|
|
|
+ rayPayloadArg = varDecl;
|
|
|
|
+ const auto rayPayloadPair = rayPayloadMap.find(rayPayloadType);
|
|
|
|
+ // Check if same type of rayPayload stage variable was already
|
|
// created, if so re-use
|
|
// created, if so re-use
|
|
- if (payloadPair == payloadMap.end()) {
|
|
|
|
- int numPayloadVars = payloadMap.size();
|
|
|
|
- payloadStageVar = declIdMapper.createRayTracingNVStageVar(
|
|
|
|
|
|
+ if (rayPayloadPair == rayPayloadMap.end()) {
|
|
|
|
+ int numPayloadVars = rayPayloadMap.size();
|
|
|
|
+ rayPayloadStageVar = declIdMapper.createRayTracingNVStageVar(
|
|
spv::StorageClass::RayPayloadNV, varDecl);
|
|
spv::StorageClass::RayPayloadNV, varDecl);
|
|
// Decorate unique location id for each created stage var
|
|
// Decorate unique location id for each created stage var
|
|
- spvBuilder.decorateLocation(payloadStageVar, numPayloadVars);
|
|
|
|
- payloadLocInst = spvBuilder.getConstantInt(
|
|
|
|
|
|
+ spvBuilder.decorateLocation(rayPayloadStageVar, numPayloadVars);
|
|
|
|
+ rayPayloadLocInst = spvBuilder.getConstantInt(
|
|
astContext.UnsignedIntTy, llvm::APInt(32, numPayloadVars));
|
|
astContext.UnsignedIntTy, llvm::APInt(32, numPayloadVars));
|
|
- payloadMap[payloadType] =
|
|
|
|
- std::make_pair(payloadStageVar, payloadLocInst);
|
|
|
|
|
|
+ rayPayloadMap[rayPayloadType] =
|
|
|
|
+ std::make_pair(rayPayloadStageVar, rayPayloadLocInst);
|
|
} else {
|
|
} else {
|
|
- payloadStageVar = payloadPair->second.first;
|
|
|
|
- payloadLocInst = payloadPair->second.second;
|
|
|
|
|
|
+ rayPayloadStageVar = rayPayloadPair->second.first;
|
|
|
|
+ rayPayloadLocInst = rayPayloadPair->second.second;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- assert(payloadStageVar && payloadArg);
|
|
|
|
|
|
+ assert(rayPayloadStageVar && rayPayloadArg);
|
|
|
|
|
|
const auto floatType = astContext.FloatTy;
|
|
const auto floatType = astContext.FloatTy;
|
|
const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);
|
|
const auto vecType = astContext.getExtVectorType(astContext.FloatTy, 3);
|
|
@@ -9307,11 +9550,12 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
spvBuilder.createCompositeExtract(floatType, rayDescArg, {3}, loc);
|
|
spvBuilder.createCompositeExtract(floatType, rayDescArg, {3}, loc);
|
|
|
|
|
|
// Copy argument to stage variable
|
|
// Copy argument to stage variable
|
|
- const auto payloadArgInst =
|
|
|
|
- declIdMapper.getDeclEvalInfo(payloadArg, payloadArg->getLocStart());
|
|
|
|
- auto tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadArgInst,
|
|
|
|
- payloadArg->getLocStart());
|
|
|
|
- spvBuilder.createStore(payloadStageVar, tempLoad, callExpr->getExprLoc());
|
|
|
|
|
|
+ const auto rayPayloadArgInst =
|
|
|
|
+ declIdMapper.getDeclEvalInfo(rayPayloadArg, rayPayloadArg->getLocStart());
|
|
|
|
+ auto tempLoad =
|
|
|
|
+ spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadArgInst,
|
|
|
|
+ rayPayloadArg->getLocStart());
|
|
|
|
+ spvBuilder.createStore(rayPayloadStageVar, tempLoad, callExpr->getExprLoc());
|
|
|
|
|
|
// SPIR-V Instruction
|
|
// SPIR-V Instruction
|
|
// void OpTraceNV ( <id> AccelerationStructureNV acStruct,
|
|
// void OpTraceNV ( <id> AccelerationStructureNV acStruct,
|
|
@@ -9324,7 +9568,7 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
// <id> float Ray Tmin,
|
|
// <id> float Ray Tmin,
|
|
// <id> vec3 Ray Direction,
|
|
// <id> vec3 Ray Direction,
|
|
// <id> float Ray Tmax,
|
|
// <id> float Ray Tmax,
|
|
- // <id> uint Payload number)
|
|
|
|
|
|
+ // <id> uint RayPayload number)
|
|
|
|
|
|
llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
|
|
llvm::SmallVector<SpirvInstruction *, 8> traceArgs;
|
|
for (int ii = 0; ii < 6; ii++) {
|
|
for (int ii = 0; ii < 6; ii++) {
|
|
@@ -9335,18 +9579,78 @@ void SpirvEmitter::processTraceRay(const CallExpr *callExpr) {
|
|
traceArgs.push_back(tMin);
|
|
traceArgs.push_back(tMin);
|
|
traceArgs.push_back(direction);
|
|
traceArgs.push_back(direction);
|
|
traceArgs.push_back(tMax);
|
|
traceArgs.push_back(tMax);
|
|
- traceArgs.push_back(payloadLocInst);
|
|
|
|
|
|
+ traceArgs.push_back(rayPayloadLocInst);
|
|
|
|
|
|
spvBuilder.createRayTracingOpsNV(spv::Op::OpTraceNV, QualType(), traceArgs,
|
|
spvBuilder.createRayTracingOpsNV(spv::Op::OpTraceNV, QualType(), traceArgs,
|
|
callExpr->getExprLoc());
|
|
callExpr->getExprLoc());
|
|
|
|
|
|
// Copy arguments back to stage variable
|
|
// Copy arguments back to stage variable
|
|
- tempLoad = spvBuilder.createLoad(payloadArg->getType(), payloadStageVar,
|
|
|
|
- payloadArg->getLocStart());
|
|
|
|
- spvBuilder.createStore(payloadArgInst, tempLoad, callExpr->getExprLoc());
|
|
|
|
|
|
+ tempLoad = spvBuilder.createLoad(rayPayloadArg->getType(), rayPayloadStageVar,
|
|
|
|
+ rayPayloadArg->getLocStart());
|
|
|
|
+ spvBuilder.createStore(rayPayloadArgInst, tempLoad, callExpr->getExprLoc());
|
|
return;
|
|
return;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+void SpirvEmitter::processDispatchMesh(const CallExpr *callExpr) {
|
|
|
|
+ // HLSL Func - void DispatchMesh(uint ThreadGroupCountX,
|
|
|
|
+ // uint ThreadGroupCountY,
|
|
|
|
+ // uint ThreadGroupCountZ,
|
|
|
|
+ // groupshared <structType> MeshPayload);
|
|
|
|
+ assert(callExpr->getNumArgs() == 4);
|
|
|
|
+ const auto args = callExpr->getArgs();
|
|
|
|
+ const auto loc = callExpr->getExprLoc();
|
|
|
|
+
|
|
|
|
+ // 1) create a barrier GroupMemoryBarrierWithGroupSync().
|
|
|
|
+ processIntrinsicMemoryBarrier(callExpr,
|
|
|
|
+ /*isDevice*/ false,
|
|
|
|
+ /*groupSync*/ true,
|
|
|
|
+ /*isAllBarrier*/ false);
|
|
|
|
+
|
|
|
|
+ // 2) set TaskCountNV = threadX * threadY * threadZ.
|
|
|
|
+ auto *threadX = doExpr(args[0]);
|
|
|
|
+ auto *threadY = doExpr(args[1]);
|
|
|
|
+ auto *threadZ = doExpr(args[2]);
|
|
|
|
+ auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::TaskCountNV,
|
|
|
|
+ astContext.UnsignedIntTy, loc);
|
|
|
|
+ auto *taskCount = spvBuilder.createBinaryOp(
|
|
|
|
+ spv::Op::OpIMul, astContext.UnsignedIntTy, threadX,
|
|
|
|
+ spvBuilder.createBinaryOp(spv::Op::OpIMul, astContext.UnsignedIntTy,
|
|
|
|
+ threadY, threadZ));
|
|
|
|
+ spvBuilder.createStore(var, taskCount, loc);
|
|
|
|
+
|
|
|
|
+ // 3) create PerTaskNV out attribute block and store MeshPayload info.
|
|
|
|
+ const auto *sigPoint =
|
|
|
|
+ hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::MSOut);
|
|
|
|
+ spv::StorageClass sc = spv::StorageClass::Output;
|
|
|
|
+ auto *payloadArg = doExpr(args[3]);
|
|
|
|
+ bool isValid = false;
|
|
|
|
+ if (const auto *implCastExpr = dyn_cast<CastExpr>(args[3])) {
|
|
|
|
+ if (const auto *arg = dyn_cast<DeclRefExpr>(implCastExpr->getSubExpr())) {
|
|
|
|
+ if (const auto *paramDecl = dyn_cast<VarDecl>(arg->getDecl())) {
|
|
|
|
+ if (paramDecl->hasAttr<HLSLGroupSharedAttr>()) {
|
|
|
|
+ isValid = declIdMapper.createPayloadStageVars(
|
|
|
|
+ sigPoint, sc, paramDecl, /*asInput=*/false, paramDecl->getType(),
|
|
|
|
+ "out.var", &payloadArg);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ if (!isValid) {
|
|
|
|
+ emitError("expected groupshared object as argument to DispatchMesh()",
|
|
|
|
+ args[3]->getExprLoc());
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+void SpirvEmitter::processMeshOutputCounts(const CallExpr *callExpr) {
|
|
|
|
+ // HLSL Func - void SetMeshOutputCounts(uint numVertices, uint numPrimitives);
|
|
|
|
+ assert(callExpr->getNumArgs() == 2);
|
|
|
|
+ const auto args = callExpr->getArgs();
|
|
|
|
+ const auto loc = callExpr->getExprLoc();
|
|
|
|
+ auto *var = declIdMapper.getBuiltinVar(spv::BuiltIn::PrimitiveCountNV,
|
|
|
|
+ astContext.UnsignedIntTy, loc);
|
|
|
|
+ spvBuilder.createStore(var, doExpr(args[1]), loc);
|
|
|
|
+}
|
|
|
|
+
|
|
SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
|
|
SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
|
|
{
|
|
{
|
|
QualType scalarType = {};
|
|
QualType scalarType = {};
|
|
@@ -9637,10 +9941,24 @@ hlsl::ShaderModel::Kind SpirvEmitter::getShaderModelKind(StringRef stageName) {
|
|
smk = hlsl::ShaderModel::Kind::Intersection;
|
|
smk = hlsl::ShaderModel::Kind::Intersection;
|
|
break;
|
|
break;
|
|
case 'a':
|
|
case 'a':
|
|
- smk = hlsl::ShaderModel::Kind::AnyHit;
|
|
|
|
|
|
+ switch (stageName[1]) {
|
|
|
|
+ case 'm':
|
|
|
|
+ smk = hlsl::ShaderModel::Kind::Amplification;
|
|
|
|
+ break;
|
|
|
|
+ case 'n':
|
|
|
|
+ smk = hlsl::ShaderModel::Kind::AnyHit;
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
break;
|
|
break;
|
|
case 'm':
|
|
case 'm':
|
|
- smk = hlsl::ShaderModel::Kind::Miss;
|
|
|
|
|
|
+ switch (stageName[1]) {
|
|
|
|
+ case 'e':
|
|
|
|
+ smk = hlsl::ShaderModel::Kind::Mesh;
|
|
|
|
+ break;
|
|
|
|
+ case 'i':
|
|
|
|
+ smk = hlsl::ShaderModel::Kind::Miss;
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
break;
|
|
break;
|
|
default:
|
|
default:
|
|
smk = hlsl::ShaderModel::Kind::Invalid;
|
|
smk = hlsl::ShaderModel::Kind::Invalid;
|
|
@@ -9679,6 +9997,10 @@ SpirvEmitter::getSpirvShaderStage(hlsl::ShaderModel::Kind smk) {
|
|
return spv::ExecutionModel::MissNV;
|
|
return spv::ExecutionModel::MissNV;
|
|
case hlsl::ShaderModel::Kind::Callable:
|
|
case hlsl::ShaderModel::Kind::Callable:
|
|
return spv::ExecutionModel::CallableNV;
|
|
return spv::ExecutionModel::CallableNV;
|
|
|
|
+ case hlsl::ShaderModel::Kind::Mesh:
|
|
|
|
+ return spv::ExecutionModel::MeshNV;
|
|
|
|
+ case hlsl::ShaderModel::Kind::Amplification:
|
|
|
|
+ return spv::ExecutionModel::TaskNV;
|
|
default:
|
|
default:
|
|
llvm_unreachable("invalid shader model kind");
|
|
llvm_unreachable("invalid shader model kind");
|
|
break;
|
|
break;
|
|
@@ -9942,8 +10264,8 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
|
|
paramTypes.push_back(paramType);
|
|
paramTypes.push_back(paramType);
|
|
|
|
|
|
// Order of arguments is fixed
|
|
// Order of arguments is fixed
|
|
- // Any-Hit/Closest-Hit : Arg 0 = payload(inout), Arg1 = attribute(in)
|
|
|
|
- // Miss : Arg 0 = payload(inout)
|
|
|
|
|
|
+ // Any-Hit/Closest-Hit : Arg 0 = rayPayload(inout), Arg1 = attribute(in)
|
|
|
|
+ // Miss : Arg 0 = rayPayload(inout)
|
|
// Callable : Arg 0 = callable data(inout)
|
|
// Callable : Arg 0 = callable data(inout)
|
|
// Raygeneration/Intersection : No Args allowed
|
|
// Raygeneration/Intersection : No Args allowed
|
|
if (sKind == hlsl::ShaderModel::Kind::RayGeneration) {
|
|
if (sKind == hlsl::ShaderModel::Kind::RayGeneration) {
|
|
@@ -9954,7 +10276,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
|
|
sKind == hlsl::ShaderModel::Kind::AnyHit) {
|
|
sKind == hlsl::ShaderModel::Kind::AnyHit) {
|
|
// Generate rayPayloadInNV and hitAttributeNV stage variables
|
|
// Generate rayPayloadInNV and hitAttributeNV stage variables
|
|
if (i == 0) {
|
|
if (i == 0) {
|
|
- // First argument is always payload
|
|
|
|
|
|
+ // First argument is always rayPayload
|
|
curStageVar = declIdMapper.createRayTracingNVStageVar(
|
|
curStageVar = declIdMapper.createRayTracingNVStageVar(
|
|
spv::StorageClass::IncomingRayPayloadNV, param);
|
|
spv::StorageClass::IncomingRayPayloadNV, param);
|
|
} else {
|
|
} else {
|
|
@@ -9964,7 +10286,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
|
|
}
|
|
}
|
|
} else if (sKind == hlsl::ShaderModel::Kind::Miss) {
|
|
} else if (sKind == hlsl::ShaderModel::Kind::Miss) {
|
|
// Generate rayPayloadInNV stage variable
|
|
// Generate rayPayloadInNV stage variable
|
|
- // First and only argument is payload
|
|
|
|
|
|
+ // First and only argument is rayPayload
|
|
curStageVar = declIdMapper.createRayTracingNVStageVar(
|
|
curStageVar = declIdMapper.createRayTracingNVStageVar(
|
|
spv::StorageClass::IncomingRayPayloadNV, param);
|
|
spv::StorageClass::IncomingRayPayloadNV, param);
|
|
} else if (sKind == hlsl::ShaderModel::Kind::Callable) {
|
|
} else if (sKind == hlsl::ShaderModel::Kind::Callable) {
|
|
@@ -10004,6 +10326,166 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
|
|
return true;
|
|
return true;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
|
|
|
|
+ const FunctionDecl *decl, uint32_t *outVerticesArraySize) {
|
|
|
|
+ if (auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>()) {
|
|
|
|
+ uint32_t x, y, z;
|
|
|
|
+ x = static_cast<uint32_t>(numThreadsAttr->getX());
|
|
|
|
+ y = static_cast<uint32_t>(numThreadsAttr->getY());
|
|
|
|
+ z = static_cast<uint32_t>(numThreadsAttr->getZ());
|
|
|
|
+ spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
|
|
|
|
+ {x, y, z}, decl->getLocation());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Early return for amplification shaders as they only take the 'numthreads'
|
|
|
|
+ // attribute.
|
|
|
|
+ if (spvContext.isAS())
|
|
|
|
+ return true;
|
|
|
|
+
|
|
|
|
+ spv::ExecutionMode outputPrimitive = spv::ExecutionMode::Max;
|
|
|
|
+ if (auto *outputTopology = decl->getAttr<HLSLOutputTopologyAttr>()) {
|
|
|
|
+ const auto topology = outputTopology->getTopology().lower();
|
|
|
|
+ outputPrimitive =
|
|
|
|
+ llvm::StringSwitch<spv::ExecutionMode>(topology)
|
|
|
|
+ .Case("point", spv::ExecutionMode::OutputPoints)
|
|
|
|
+ .Case("line", spv::ExecutionMode::OutputLinesNV)
|
|
|
|
+ .Case("triangle", spv::ExecutionMode::OutputTrianglesNV);
|
|
|
|
+ if (outputPrimitive != spv::ExecutionMode::Max) {
|
|
|
|
+ spvBuilder.addExecutionMode(entryFunction, outputPrimitive, {},
|
|
|
|
+ decl->getLocation());
|
|
|
|
+ } else {
|
|
|
|
+ emitError("unknown output topology in mesh shader",
|
|
|
|
+ outputTopology->getLocation());
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ uint32_t numVertices = 0;
|
|
|
|
+ uint32_t numIndices = 0;
|
|
|
|
+ uint32_t numPrimitives = 0;
|
|
|
|
+ bool payloadDeclSeen = false;
|
|
|
|
+
|
|
|
|
+ for (uint32_t i = 0; i < decl->getNumParams(); i++) {
|
|
|
|
+ const auto param = decl->getParamDecl(i);
|
|
|
|
+ const auto paramType = param->getType();
|
|
|
|
+ const auto paramLoc = param->getLocation();
|
|
|
|
+ if (param->hasAttr<HLSLVerticesAttr>() ||
|
|
|
|
+ param->hasAttr<HLSLIndicesAttr>() ||
|
|
|
|
+ param->hasAttr<HLSLPrimitivesAttr>()) {
|
|
|
|
+ uint32_t arraySize = 0;
|
|
|
|
+ if (const auto *arrayType =
|
|
|
|
+ astContext.getAsConstantArrayType(paramType)) {
|
|
|
|
+ const auto eleType =
|
|
|
|
+ arrayType->getElementType()->getCanonicalTypeUnqualified();
|
|
|
|
+ if (param->hasAttr<HLSLIndicesAttr>()) {
|
|
|
|
+ switch (outputPrimitive) {
|
|
|
|
+ case spv::ExecutionMode::OutputPoints:
|
|
|
|
+ if (eleType != astContext.UnsignedIntTy) {
|
|
|
|
+ emitError("expected 1D array of uint type", paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ break;
|
|
|
|
+ case spv::ExecutionMode::OutputLinesNV: {
|
|
|
|
+ QualType baseType;
|
|
|
|
+ uint32_t length;
|
|
|
|
+ if (!isVectorType(eleType, &baseType, &length) ||
|
|
|
|
+ baseType != astContext.UnsignedIntTy || length != 2) {
|
|
|
|
+ emitError("expected 1D array of uint2 type", paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ case spv::ExecutionMode::OutputTrianglesNV: {
|
|
|
|
+ QualType baseType;
|
|
|
|
+ uint32_t length;
|
|
|
|
+ if (!isVectorType(eleType, &baseType, &length) ||
|
|
|
|
+ baseType != astContext.UnsignedIntTy || length != 3) {
|
|
|
|
+ emitError("expected 1D array of uint3 type", paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ break;
|
|
|
|
+ }
|
|
|
|
+ default:
|
|
|
|
+ assert(false && "unexpected spirv execution mode");
|
|
|
|
+ }
|
|
|
|
+ } else if (!eleType->isStructureType()) {
|
|
|
|
+ // vertices/primitives objects
|
|
|
|
+ emitError("expected 1D array of struct type", paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ arraySize = static_cast<uint32_t>(arrayType->getSize().getZExtValue());
|
|
|
|
+ } else {
|
|
|
|
+ emitError("expected 1D array of indices/vertices/primitives object",
|
|
|
|
+ paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ if (param->hasAttr<HLSLVerticesAttr>()) {
|
|
|
|
+ if (numVertices != 0) {
|
|
|
|
+ emitError("only one object with 'vertices' modifier is allowed",
|
|
|
|
+ paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ numVertices = arraySize;
|
|
|
|
+ } else if (param->hasAttr<HLSLIndicesAttr>()) {
|
|
|
|
+ if (numIndices != 0) {
|
|
|
|
+ emitError("only one object with 'indices' modifier is allowed",
|
|
|
|
+ paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ numIndices = arraySize;
|
|
|
|
+ } else if (param->hasAttr<HLSLPrimitivesAttr>()) {
|
|
|
|
+ if (numPrimitives != 0) {
|
|
|
|
+ emitError("only one object with 'primitives' modifier is allowed",
|
|
|
|
+ paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ numPrimitives = arraySize;
|
|
|
|
+ }
|
|
|
|
+ } else if (param->hasAttr<HLSLPayloadAttr>()) {
|
|
|
|
+ if (payloadDeclSeen) {
|
|
|
|
+ emitError("only one object with 'payload' modifier is allowed",
|
|
|
|
+ paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ payloadDeclSeen = true;
|
|
|
|
+ if (!paramType->isStructureType()) {
|
|
|
|
+ emitError("expected payload of struct type", paramLoc);
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Vertex attribute array is a mandatory param to mesh entry function.
|
|
|
|
+ if (numVertices != 0) {
|
|
|
|
+ *outVerticesArraySize = numVertices;
|
|
|
|
+ spvBuilder.addExecutionMode(
|
|
|
|
+ entryFunction, spv::ExecutionMode::OutputVertices,
|
|
|
|
+ {static_cast<uint32_t>(numVertices)}, decl->getLocation());
|
|
|
|
+ } else {
|
|
|
|
+ emitError("expected vertices object declaration", decl->getLocation());
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Vertex indices array is a mandatory param to mesh entry function.
|
|
|
|
+ if (numIndices != 0) {
|
|
|
|
+ spvBuilder.addExecutionMode(
|
|
|
|
+ entryFunction, spv::ExecutionMode::OutputPrimitivesNV,
|
|
|
|
+ {static_cast<uint32_t>(numIndices)}, decl->getLocation());
|
|
|
|
+ // Primitive attribute array is an optional param to mesh entry function,
|
|
|
|
+ // but the array size should match the indices array.
|
|
|
|
+ if (numPrimitives != 0 && numPrimitives != numIndices) {
|
|
|
|
+ emitError("array size of primitives object should match 'indices' object",
|
|
|
|
+ decl->getLocation());
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+ } else {
|
|
|
|
+ emitError("expected indices object declaration", decl->getLocation());
|
|
|
|
+ return false;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return true;
|
|
|
|
+}
|
|
|
|
+
|
|
bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
|
|
bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
|
|
SpirvFunction *entryFuncInstr) {
|
|
SpirvFunction *entryFuncInstr) {
|
|
// HS specific attributes
|
|
// HS specific attributes
|
|
@@ -10073,6 +10555,9 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
|
|
if (!processGeometryShaderAttributes(decl, &inputArraySize))
|
|
if (!processGeometryShaderAttributes(decl, &inputArraySize))
|
|
return false;
|
|
return false;
|
|
// The per-vertex output of GS is not an array.
|
|
// The per-vertex output of GS is not an array.
|
|
|
|
+ } else if (spvContext.isMS() || spvContext.isAS()) {
|
|
|
|
+ if (!processMeshOrAmplificationShaderAttributes(decl, &outputArraySize))
|
|
|
|
+ return false;
|
|
}
|
|
}
|
|
|
|
|
|
// Go through all parameters and record the declaration of SV_ClipDistance
|
|
// Go through all parameters and record the declaration of SV_ClipDistance
|
|
@@ -10101,7 +10586,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
|
|
// offset of SV_ClipDistance/SV_CullDistance variables within the array.
|
|
// offset of SV_ClipDistance/SV_CullDistance variables within the array.
|
|
declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();
|
|
declIdMapper.glPerVertex.calculateClipCullDistanceArraySize();
|
|
|
|
|
|
- if (!spvContext.isCS()) {
|
|
|
|
|
|
+ if (!spvContext.isCS() && !spvContext.isAS()) {
|
|
// Generate stand-alone builtins of Position, ClipDistance, and
|
|
// Generate stand-alone builtins of Position, ClipDistance, and
|
|
// CullDistance, which belongs to gl_PerVertex.
|
|
// CullDistance, which belongs to gl_PerVertex.
|
|
declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
|
|
declIdMapper.glPerVertex.generateVars(inputArraySize, outputArraySize);
|