|
@@ -25,12 +25,10 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
|
|
/*SourceLocation*/ {});
|
|
/*SourceLocation*/ {});
|
|
fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
|
|
fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
|
|
|
|
|
|
- // In case the function type is a hybrid type, we should also lower the
|
|
|
|
- // return type of the SPIR-V function type.
|
|
|
|
- if (auto *fnRetType = dyn_cast<HybridType>(fn->getFunctionType())) {
|
|
|
|
- fn->setFunctionType(const_cast<SpirvType *>(lowerType(
|
|
|
|
- fnRetType, SpirvLayoutRule::Void, fn->getSourceLocation())));
|
|
|
|
- }
|
|
|
|
|
|
+ // Lower the SPIR-V function type if necessary.
|
|
|
|
+ fn->setFunctionType(const_cast<SpirvType *>(
|
|
|
|
+ lowerType(fn->getFunctionType(), SpirvLayoutRule::Void,
|
|
|
|
+ fn->getSourceLocation())));
|
|
}
|
|
}
|
|
return true;
|
|
return true;
|
|
}
|
|
}
|
|
@@ -48,11 +46,9 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
|
|
}
|
|
}
|
|
// Lower Hybrid type to SpirvType
|
|
// Lower Hybrid type to SpirvType
|
|
else if (hybridType) {
|
|
else if (hybridType) {
|
|
- if (const auto *hybridType = dyn_cast<HybridType>(instr->getResultType())) {
|
|
|
|
- const SpirvType *spirvType = lowerType(hybridType, instr->getLayoutRule(),
|
|
|
|
- instr->getSourceLocation());
|
|
|
|
- instr->setResultType(spirvType);
|
|
|
|
- }
|
|
|
|
|
|
+ const SpirvType *spirvType = lowerType(hybridType, instr->getLayoutRule(),
|
|
|
|
+ instr->getSourceLocation());
|
|
|
|
+ instr->setResultType(spirvType);
|
|
}
|
|
}
|
|
|
|
|
|
// The instruction does not have a result-type, so nothing to do.
|
|
// The instruction does not have a result-type, so nothing to do.
|
|
@@ -81,21 +77,35 @@ bool LowerTypeVisitor::visit(SpirvFunctionParameter *param) {
|
|
return true;
|
|
return true;
|
|
}
|
|
}
|
|
|
|
|
|
-const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
|
|
|
|
|
|
+bool LowerTypeVisitor::visit(SpirvSampledImage *instr) {
|
|
|
|
+ if (!visitInstruction(instr))
|
|
|
|
+ return false;
|
|
|
|
+
|
|
|
|
+ // Wrap the image type in sampled image type if necessary.
|
|
|
|
+ const auto *resultType = instr->getResultType();
|
|
|
|
+ if (!isa<SampledImageType>(resultType)) {
|
|
|
|
+ assert(isa<ImageType>(resultType));
|
|
|
|
+ instr->setResultType(
|
|
|
|
+ spvContext.getSampledImageType(cast<ImageType>(resultType)));
|
|
|
|
+ }
|
|
|
|
+ return true;
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
|
|
SpirvLayoutRule rule,
|
|
SpirvLayoutRule rule,
|
|
SourceLocation loc) {
|
|
SourceLocation loc) {
|
|
- if (const auto *hybridPointer = dyn_cast<HybridPointerType>(hybrid)) {
|
|
|
|
|
|
+ if (const auto *hybridPointer = dyn_cast<HybridPointerType>(type)) {
|
|
const QualType pointeeType = hybridPointer->getPointeeType();
|
|
const QualType pointeeType = hybridPointer->getPointeeType();
|
|
const SpirvType *pointeeSpirvType = lowerType(pointeeType, rule, loc);
|
|
const SpirvType *pointeeSpirvType = lowerType(pointeeType, rule, loc);
|
|
return spvContext.getPointerType(pointeeSpirvType,
|
|
return spvContext.getPointerType(pointeeSpirvType,
|
|
hybridPointer->getStorageClass());
|
|
hybridPointer->getStorageClass());
|
|
} else if (const auto *hybridSampledImage =
|
|
} else if (const auto *hybridSampledImage =
|
|
- dyn_cast<HybridSampledImageType>(hybrid)) {
|
|
|
|
|
|
+ dyn_cast<HybridSampledImageType>(type)) {
|
|
const QualType imageAstType = hybridSampledImage->getImageType();
|
|
const QualType imageAstType = hybridSampledImage->getImageType();
|
|
const SpirvType *imageSpirvType = lowerType(imageAstType, rule, loc);
|
|
const SpirvType *imageSpirvType = lowerType(imageAstType, rule, loc);
|
|
assert(isa<ImageType>(imageSpirvType));
|
|
assert(isa<ImageType>(imageSpirvType));
|
|
return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
|
|
return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
|
|
- } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(hybrid)) {
|
|
|
|
|
|
+ } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
|
|
// Lower the return type.
|
|
// Lower the return type.
|
|
const QualType astReturnType = hybridFn->getAstReturnType();
|
|
const QualType astReturnType = hybridFn->getAstReturnType();
|
|
const SpirvType *spirvReturnType = lowerType(astReturnType, rule, loc);
|
|
const SpirvType *spirvReturnType = lowerType(astReturnType, rule, loc);
|
|
@@ -111,7 +121,7 @@ const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
|
|
}
|
|
}
|
|
|
|
|
|
return spvContext.getFunctionType(spirvReturnType, paramTypes);
|
|
return spvContext.getFunctionType(spirvReturnType, paramTypes);
|
|
- } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(hybrid)) {
|
|
|
|
|
|
+ } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(type)) {
|
|
// lower all fields of the struct.
|
|
// lower all fields of the struct.
|
|
std::vector<StructType::FieldInfo> structFields;
|
|
std::vector<StructType::FieldInfo> structFields;
|
|
for (auto field : hybridStruct->getFields()) {
|
|
for (auto field : hybridStruct->getFields()) {
|
|
@@ -124,6 +134,100 @@ const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
|
|
hybridStruct->isReadOnly(),
|
|
hybridStruct->isReadOnly(),
|
|
hybridStruct->getInterfaceType());
|
|
hybridStruct->getInterfaceType());
|
|
}
|
|
}
|
|
|
|
+ // Void, bool, int, float cannot be further lowered.
|
|
|
|
+ // Matrices cannot contain hybrid types. Only matrices of scalars are valid.
|
|
|
|
+ // sampledType in image types can only be numberical type.
|
|
|
|
+ // Sampler types cannot be further lowered.
|
|
|
|
+ // SampledImage types cannot be further lowered.
|
|
|
|
+ else if (isa<VoidType>(type) || isa<ScalarType>(type) ||
|
|
|
|
+ isa<MatrixType>(type) || isa<ImageType>(type) ||
|
|
|
|
+ isa<SamplerType>(type) || isa<SampledImageType>(type)) {
|
|
|
|
+ return type;
|
|
|
|
+ }
|
|
|
|
+ // Vectors could contain a hybrid type
|
|
|
|
+ else if (const auto *vecType = dyn_cast<VectorType>(type)) {
|
|
|
|
+ const auto *loweredElemType =
|
|
|
|
+ lowerType(vecType->getElementType(), rule, loc);
|
|
|
|
+ // If vector didn't contain any hybrid types, return itself.
|
|
|
|
+ if (vecType->getElementType() == loweredElemType)
|
|
|
|
+ return vecType;
|
|
|
|
+ return spvContext.getVectorType(loweredElemType,
|
|
|
|
+ vecType->getElementCount());
|
|
|
|
+ }
|
|
|
|
+ // Arrays could contain a hybrid type
|
|
|
|
+ else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
|
|
|
|
+ const auto *loweredElemType =
|
|
|
|
+ lowerType(arrType->getElementType(), rule, loc);
|
|
|
|
+ // If array didn't contain any hybrid types, return itself.
|
|
|
|
+ if (arrType->getElementType() == loweredElemType)
|
|
|
|
+ return arrType;
|
|
|
|
+ return spvContext.getArrayType(loweredElemType, arrType->getElementCount());
|
|
|
|
+ }
|
|
|
|
+ // Runtime arrays could contain a hybrid type
|
|
|
|
+ else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
|
|
|
|
+ const auto *loweredElemType =
|
|
|
|
+ lowerType(raType->getElementType(), rule, loc);
|
|
|
|
+ // If runtime array didn't contain any hybrid types, return itself.
|
|
|
|
+ if (raType->getElementType() == loweredElemType)
|
|
|
|
+ return arrType;
|
|
|
|
+ return spvContext.getRuntimeArrayType(loweredElemType);
|
|
|
|
+ }
|
|
|
|
+ // Struct types could contain a hybrid type
|
|
|
|
+ else if (const auto *structType = dyn_cast<StructType>(type)) {
|
|
|
|
+ const auto &fields = structType->getFields();
|
|
|
|
+ llvm::SmallVector<StructType::FieldInfo, 4> loweredFields;
|
|
|
|
+ bool wasLowered = false;
|
|
|
|
+ for (auto &field : fields) {
|
|
|
|
+ const auto *loweredFieldType = lowerType(field.type, rule, loc);
|
|
|
|
+ if (loweredFieldType != field.type) {
|
|
|
|
+ wasLowered = true;
|
|
|
|
+ loweredFields.push_back(
|
|
|
|
+ StructType::FieldInfo(loweredFieldType, field.name,
|
|
|
|
+ field.vkOffsetAttr, field.packOffsetAttr));
|
|
|
|
+ } else {
|
|
|
|
+ loweredFields.push_back(field);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // If the struct didn't contain any hybrid types, return itself.
|
|
|
|
+ if (!wasLowered)
|
|
|
|
+ return structType;
|
|
|
|
+
|
|
|
|
+ return spvContext.getStructType(loweredFields, structType->getStructName(),
|
|
|
|
+ structType->isReadOnly(),
|
|
|
|
+ structType->getInterfaceType());
|
|
|
|
+ }
|
|
|
|
+ // Pointer types could point to a hybrid type.
|
|
|
|
+ else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
|
|
|
|
+ const auto *loweredPointee =
|
|
|
|
+ lowerType(ptrType->getPointeeType(), rule, loc);
|
|
|
|
+ // If the pointer type didn't point to any hybrid type, return itself.
|
|
|
|
+ if (ptrType->getPointeeType() == loweredPointee)
|
|
|
|
+ return ptrType;
|
|
|
|
+
|
|
|
|
+ return spvContext.getPointerType(loweredPointee,
|
|
|
|
+ ptrType->getStorageClass());
|
|
|
|
+ }
|
|
|
|
+ // Function types may have a parameter or return type that is hybrid.
|
|
|
|
+ else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
|
|
|
|
+ const auto *loweredRetType = lowerType(fnType->getReturnType(), rule, loc);
|
|
|
|
+ bool wasLowered = fnType->getReturnType() != loweredRetType;
|
|
|
|
+ llvm::SmallVector<const SpirvType *, 4> loweredParams;
|
|
|
|
+ const auto ¶mTypes = fnType->getParamTypes();
|
|
|
|
+ for (auto *paramType : paramTypes) {
|
|
|
|
+ const auto *loweredParamType = lowerType(paramType, rule, loc);
|
|
|
|
+ loweredParams.push_back(loweredParamType);
|
|
|
|
+ if (loweredParamType != paramType) {
|
|
|
|
+ wasLowered = true;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // If the function type didn't include any hybrid types, return itself.
|
|
|
|
+ if (!wasLowered) {
|
|
|
|
+ return fnType;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return spvContext.getFunctionType(loweredRetType, loweredParams);
|
|
|
|
+ }
|
|
|
|
+
|
|
llvm_unreachable("lowering of hybrid type not implemented");
|
|
llvm_unreachable("lowering of hybrid type not implemented");
|
|
}
|
|
}
|
|
|
|
|