|
@@ -44,10 +44,18 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
|
|
|
/*SourceLocation*/ {});
|
|
|
fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
|
|
|
|
|
|
- // Lower the SPIR-V function type if necessary.
|
|
|
- fn->setFunctionType(const_cast<SpirvType *>(
|
|
|
- lowerType(fn->getFunctionType(), SpirvLayoutRule::Void,
|
|
|
- fn->getSourceLocation())));
|
|
|
+ // Lower the function parameter types.
|
|
|
+ auto paramQualTypes = fn->getAstParamTypes();
|
|
|
+ llvm::SmallVector<const SpirvType *, 4> spirvParamTypes;
|
|
|
+ for (auto qualtype : paramQualTypes) {
|
|
|
+ const auto *spirvParamType =
|
|
|
+ lowerType(qualtype, SpirvLayoutRule::Void,
|
|
|
+ /*isRowMajor*/ llvm::None, fn->getSourceLocation());
|
|
|
+ spirvParamTypes.push_back(spvContext.getPointerType(
|
|
|
+ spirvParamType, spv::StorageClass::Function));
|
|
|
+ }
|
|
|
+ fn->setFunctionType(
|
|
|
+ spvContext.getFunctionType(spirvReturnType, spirvParamTypes));
|
|
|
}
|
|
|
return true;
|
|
|
}
|
|
@@ -154,22 +162,6 @@ const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
|
|
|
lowerType(imageAstType, rule, /*isRowMajor*/ llvm::None, loc);
|
|
|
assert(isa<ImageType>(imageSpirvType));
|
|
|
return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
|
|
|
- } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
|
|
|
- // Lower the return type.
|
|
|
- const QualType astReturnType = hybridFn->getReturnType();
|
|
|
- const SpirvType *spirvReturnType =
|
|
|
- lowerType(astReturnType, rule, /*isRowMajor*/ llvm::None, loc);
|
|
|
-
|
|
|
- // Go over all params and lower them.
|
|
|
- std::vector<const SpirvType *> paramTypes;
|
|
|
- for (auto paramType : hybridFn->getParamTypes()) {
|
|
|
- const auto *spirvParamType =
|
|
|
- lowerType(paramType, rule, /*isRowMajor*/ llvm::None, loc);
|
|
|
- paramTypes.push_back(spvContext.getPointerType(
|
|
|
- spirvParamType, spv::StorageClass::Function));
|
|
|
- }
|
|
|
-
|
|
|
- return spvContext.getFunctionType(spirvReturnType, paramTypes);
|
|
|
} else if (const auto *hybridStruct = dyn_cast<HybridStructType>(type)) {
|
|
|
// lower all fields of the struct.
|
|
|
auto loweredFields =
|
|
@@ -553,7 +545,8 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
|
|
|
|
|
|
// We have a runtime array of structures. So:
|
|
|
// The stride of the runtime array is the size of the struct.
|
|
|
- const auto *raType = spvContext.getRuntimeArrayType(structType, arrayStride);
|
|
|
+ const auto *raType =
|
|
|
+ spvContext.getRuntimeArrayType(structType, arrayStride);
|
|
|
const bool isReadOnly = (name == "StructuredBuffer");
|
|
|
|
|
|
// Attach matrix stride decorations if this is a *StructuredBuffer<matrix>.
|