|
@@ -584,6 +584,54 @@ const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions,
|
|
return output;
|
|
return output;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+// Calls `operation` on for each field in the base and derives class defined by
|
|
|
|
+// `recordType`. The `operation` will receive the AST type linked to the field,
|
|
|
|
+// the SPIRV type linked to the field, and the index of the field in the final
|
|
|
|
+// SPIR-V representation. This index of the field can vary from the AST
|
|
|
|
+// field-index because bitfields are merged into a single field in the SPIR-V
|
|
|
|
+// representation.
|
|
|
|
+//
|
|
|
|
+// If the operation returns false, we stop processing fields.
|
|
|
|
+void forEachSpirvField(
|
|
|
|
+ const RecordType *recordType, const StructType *spirvType,
|
|
|
|
+ std::function<bool(size_t spirvFieldIndex, const QualType &fieldType,
|
|
|
|
+ const StructType::FieldInfo &field)>
|
|
|
|
+ operation) {
|
|
|
|
+ const auto *cxxDecl = recordType->getAsCXXRecordDecl();
|
|
|
|
+ const auto *recordDecl = recordType->getDecl();
|
|
|
|
+
|
|
|
|
+ // Iterate through the base class (one field per base class).
|
|
|
|
+ // Bases cannot be melded into 1 field like bitfields, simple iteration.
|
|
|
|
+ uint32_t lastConvertedIndex = 0;
|
|
|
|
+ size_t astFieldIndex = 0;
|
|
|
|
+ for (const auto &base : cxxDecl->bases()) {
|
|
|
|
+ const auto &type = base.getType();
|
|
|
|
+ const auto &spirvField = spirvType->getFields()[astFieldIndex];
|
|
|
|
+ if (!operation(spirvField.fieldIndex, type, spirvField)) {
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+ lastConvertedIndex = spirvField.fieldIndex;
|
|
|
|
+ ++astFieldIndex;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // Iterate through the derived class fields. Field could be merged.
|
|
|
|
+ for (const auto *field : recordDecl->fields()) {
|
|
|
|
+ const auto &spirvField = spirvType->getFields()[astFieldIndex];
|
|
|
|
+ const uint32_t currentFieldIndex = spirvField.fieldIndex;
|
|
|
|
+ if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
|
|
|
|
+ ++astFieldIndex;
|
|
|
|
+ continue;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ const auto &type = field->getType();
|
|
|
|
+ if (!operation(currentFieldIndex, type, spirvField)) {
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+ lastConvertedIndex = currentFieldIndex;
|
|
|
|
+ ++astFieldIndex;
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
} // namespace
|
|
} // namespace
|
|
|
|
|
|
SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
|
|
SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
|
|
@@ -6410,29 +6458,26 @@ SpirvInstruction *SpirvEmitter::reconstructValue(SpirvInstruction *srcVal,
|
|
|
|
|
|
// Structs
|
|
// Structs
|
|
if (const auto *recordType = valType->getAs<RecordType>()) {
|
|
if (const auto *recordType = valType->getAs<RecordType>()) {
|
|
- uint32_t index = 0;
|
|
|
|
- llvm::SmallVector<SpirvInstruction *, 4> elements;
|
|
|
|
|
|
+ assert(recordType->isStructureType());
|
|
|
|
|
|
- // If the struct inherits from other structs, visit the bases.
|
|
|
|
- const auto *decl = valType->getAsCXXRecordDecl();
|
|
|
|
- for (auto baseIt = decl->bases_begin(), baseIe = decl->bases_end();
|
|
|
|
- baseIt != baseIe; ++baseIt, ++index) {
|
|
|
|
- SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
|
|
|
|
- baseIt->getType(), srcVal, {index}, loc, range);
|
|
|
|
- subSrcVal->setLayoutRule(srcVal->getLayoutRule());
|
|
|
|
- elements.push_back(
|
|
|
|
- reconstructValue(subSrcVal, baseIt->getType(), dstLR, loc, range));
|
|
|
|
- }
|
|
|
|
|
|
+ LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
|
|
|
|
+ const StructType *spirvStructType =
|
|
|
|
+ lowerStructType(spirvOptions, lowerTypeVisitor, recordType->desugar());
|
|
|
|
+
|
|
|
|
+ llvm::SmallVector<SpirvInstruction *, 4> elements;
|
|
|
|
+ forEachSpirvField(
|
|
|
|
+ recordType, spirvStructType,
|
|
|
|
+ [&](size_t spirvFieldIndex, const QualType &fieldType,
|
|
|
|
+ const auto &field) {
|
|
|
|
+ SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
|
|
|
|
+ fieldType, srcVal, {static_cast<uint32_t>(spirvFieldIndex)}, loc, range);
|
|
|
|
+ subSrcVal->setLayoutRule(srcVal->getLayoutRule());
|
|
|
|
+ elements.push_back(
|
|
|
|
+ reconstructValue(subSrcVal, fieldType, dstLR, loc, range));
|
|
|
|
+
|
|
|
|
+ return true;
|
|
|
|
+ });
|
|
|
|
|
|
- // Go over struct fields.
|
|
|
|
- for (const auto *field : recordType->getDecl()->fields()) {
|
|
|
|
- SpirvInstruction *subSrcVal = spvBuilder.createCompositeExtract(
|
|
|
|
- field->getType(), srcVal, {index}, loc, range);
|
|
|
|
- subSrcVal->setLayoutRule(srcVal->getLayoutRule());
|
|
|
|
- elements.push_back(
|
|
|
|
- reconstructValue(subSrcVal, field->getType(), dstLR, loc, range));
|
|
|
|
- ++index;
|
|
|
|
- }
|
|
|
|
auto *result = spvBuilder.createCompositeConstruct(
|
|
auto *result = spvBuilder.createCompositeConstruct(
|
|
valType, elements, srcVal->getSourceLocation(), range);
|
|
valType, elements, srcVal->getSourceLocation(), range);
|
|
result->setLayoutRule(dstLR);
|
|
result->setLayoutRule(dstLR);
|
|
@@ -6947,47 +6992,35 @@ SpirvInstruction *SpirvEmitter::convertVectorToStruct(QualType astStructType,
|
|
SourceRange range) {
|
|
SourceRange range) {
|
|
assert(astStructType->isStructureType());
|
|
assert(astStructType->isStructureType());
|
|
|
|
|
|
- const auto *structDecl = astStructType->getAsStructureType()->getDecl();
|
|
|
|
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
|
|
LowerTypeVisitor lowerTypeVisitor(astContext, spvContext, spirvOptions);
|
|
const StructType *spirvStructType =
|
|
const StructType *spirvStructType =
|
|
lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
|
|
lowerStructType(spirvOptions, lowerTypeVisitor, astStructType);
|
|
-
|
|
|
|
uint32_t vectorIndex = 0;
|
|
uint32_t vectorIndex = 0;
|
|
uint32_t elemCount = 1;
|
|
uint32_t elemCount = 1;
|
|
- uint32_t lastConvertedIndex = 0;
|
|
|
|
llvm::SmallVector<SpirvInstruction *, 4> members;
|
|
llvm::SmallVector<SpirvInstruction *, 4> members;
|
|
- for (auto field = structDecl->field_begin(); field != structDecl->field_end();
|
|
|
|
- field++) {
|
|
|
|
- // Multiple bitfields can share the same storing type. In such case, we only
|
|
|
|
- // want to append the whole storage once.
|
|
|
|
- const size_t astFieldIndex =
|
|
|
|
- std::distance(structDecl->field_begin(), field);
|
|
|
|
- const uint32_t currentFieldIndex =
|
|
|
|
- spirvStructType->getFields()[astFieldIndex].fieldIndex;
|
|
|
|
- if (astFieldIndex > 0 && currentFieldIndex == lastConvertedIndex) {
|
|
|
|
- continue;
|
|
|
|
- }
|
|
|
|
- lastConvertedIndex = currentFieldIndex;
|
|
|
|
-
|
|
|
|
- if (isScalarType(field->getType())) {
|
|
|
|
- members.push_back(spvBuilder.createCompositeExtract(
|
|
|
|
- elemType, vector, {vectorIndex++}, loc, range));
|
|
|
|
- continue;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if (isVectorType(field->getType(), nullptr, &elemCount)) {
|
|
|
|
- llvm::SmallVector<uint32_t, 4> indices;
|
|
|
|
- for (uint32_t i = 0; i < elemCount; ++i)
|
|
|
|
- indices.push_back(vectorIndex++);
|
|
|
|
-
|
|
|
|
- members.push_back(spvBuilder.createVectorShuffle(
|
|
|
|
- astContext.getExtVectorType(elemType, elemCount), vector, vector,
|
|
|
|
- indices, loc, range));
|
|
|
|
- continue;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- assert(false && "unhandled type");
|
|
|
|
- }
|
|
|
|
|
|
+ forEachSpirvField(astStructType->getAs<RecordType>(), spirvStructType,
|
|
|
|
+ [&](size_t spirvFieldIndex, const QualType &fieldType,
|
|
|
|
+ const auto &field) {
|
|
|
|
+ if (isScalarType(fieldType)) {
|
|
|
|
+ members.push_back(spvBuilder.createCompositeExtract(
|
|
|
|
+ elemType, vector, {vectorIndex++}, loc, range));
|
|
|
|
+ return true;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (isVectorType(fieldType, nullptr, &elemCount)) {
|
|
|
|
+ llvm::SmallVector<uint32_t, 4> indices;
|
|
|
|
+ for (uint32_t i = 0; i < elemCount; ++i)
|
|
|
|
+ indices.push_back(vectorIndex++);
|
|
|
|
+
|
|
|
|
+ members.push_back(spvBuilder.createVectorShuffle(
|
|
|
|
+ astContext.getExtVectorType(elemType, elemCount),
|
|
|
|
+ vector, vector, indices, loc, range));
|
|
|
|
+ return true;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ assert(false && "unhandled type");
|
|
|
|
+ return false;
|
|
|
|
+ });
|
|
|
|
|
|
return spvBuilder.createCompositeConstruct(
|
|
return spvBuilder.createCompositeConstruct(
|
|
astStructType, members, vector->getSourceLocation(), range);
|
|
astStructType, members, vector->getSourceLocation(), range);
|