|
@@ -33,6 +33,23 @@ bool RemoveBufferBlockVisitor::visit(SpirvModule *mod, Phase phase) {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
+bool RemoveBufferBlockVisitor::hasStorageBufferInterfaceType(
|
|
|
+ const SpirvType *type) {
|
|
|
+ while (type != nullptr) {
|
|
|
+ if (const auto *structType = dyn_cast<StructType>(type)) {
|
|
|
+ return structType->getInterfaceType() ==
|
|
|
+ StructInterfaceType::StorageBuffer;
|
|
|
+ } else if (const auto *elemType = dyn_cast<ArrayType>(type)) {
|
|
|
+ type = elemType->getElementType();
|
|
|
+ } else if (const auto *elemType = dyn_cast<RuntimeArrayType>(type)) {
|
|
|
+ type = elemType->getElementType();
|
|
|
+ } else {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+}
|
|
|
+
|
|
|
bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
|
|
|
if (!inst->getResultType())
|
|
|
return true;
|
|
@@ -60,17 +77,11 @@ bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
|
|
|
// For all instructions, if the result type is a pointer pointing to a struct
|
|
|
// with StorageBuffer interface, the storage class must be updated.
|
|
|
if (auto *ptrResultType = dyn_cast<SpirvPointerType>(inst->getResultType())) {
|
|
|
- if (auto *structPointeeType =
|
|
|
- dyn_cast<StructType>(ptrResultType->getPointeeType())) {
|
|
|
- // Update the instruction's storage class if necessary
|
|
|
- if (structPointeeType->getInterfaceType() ==
|
|
|
- StructInterfaceType::StorageBuffer &&
|
|
|
- ptrResultType->getStorageClass() !=
|
|
|
- spv::StorageClass::StorageBuffer) {
|
|
|
- inst->setStorageClass(spv::StorageClass::StorageBuffer);
|
|
|
- inst->setResultType(context.getPointerType(
|
|
|
- ptrResultType->getPointeeType(), spv::StorageClass::StorageBuffer));
|
|
|
- }
|
|
|
+ if (hasStorageBufferInterfaceType(ptrResultType->getPointeeType()) &&
|
|
|
+ ptrResultType->getStorageClass() != spv::StorageClass::StorageBuffer) {
|
|
|
+ inst->setStorageClass(spv::StorageClass::StorageBuffer);
|
|
|
+ inst->setResultType(context.getPointerType(
|
|
|
+ ptrResultType->getPointeeType(), spv::StorageClass::StorageBuffer));
|
|
|
}
|
|
|
}
|
|
|
|