Procházet zdrojové kódy

[spirv] Add support for OO inheritance (#1039)

An field will be created at the very beginning of the struct
for the base class.
Lei Zhang před 7 roky
rodič
revize
e9120a1594

+ 17 - 5
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -93,6 +93,14 @@ inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
   }
   return decl->getType();
 }
+
+/// Returns the number of base classes if this type is a derived class/struct.
+/// Returns zero otherwise.
+inline uint32_t getNumBaseClasses(QualType type) {
+  if (const auto *cxxDecl = type->getAsCXXRecordDecl())
+    return cxxDecl->getNumBases();
+  return 0;
+}
 } // anonymous namespace
 
 std::string StageVar::getSemanticStr() const {
@@ -687,7 +695,8 @@ void DeclResultIdMapper::createFieldCounterVars(
   const auto *recordDecl = recordType->getDecl();
 
   for (const auto *field : recordDecl->fields()) {
-    indices->push_back(field->getFieldIndex()); // Build up the index chain
+    // Build up the index chain
+    indices->push_back(getNumBaseClasses(type) + field->getFieldIndex());
 
     const QualType fieldType = field->getType();
     if (TypeTranslator::isRWAppendConsumeSBuffer(fieldType))
@@ -1412,7 +1421,9 @@ bool DeclResultIdMapper::createStageVars(
         const uint32_t fieldType =
             typeTranslator.translateType(field->getType());
         fields.push_back(theBuilder.createCompositeExtract(
-            fieldType, subValues[field->getFieldIndex()], {arrayIndex}));
+            fieldType,
+            subValues[getNumBaseClasses(type) + field->getFieldIndex()],
+            {arrayIndex}));
       }
       // Compose a new struct out of them
       arrayElements.push_back(
@@ -1438,8 +1449,9 @@ bool DeclResultIdMapper::createStageVars(
       const uint32_t fieldType = typeTranslator.translateType(field->getType());
       uint32_t subValue = 0;
       if (!noWriteBack)
-        subValue = theBuilder.createCompositeExtract(fieldType, *value,
-                                                     {field->getFieldIndex()});
+        subValue = theBuilder.createCompositeExtract(
+            fieldType, *value,
+            {getNumBaseClasses(type) + field->getFieldIndex()});
 
       if (!createStageVars(sigPoint, field, asInput, field->getType(),
                            arraySize, namePrefix, invocationId, &subValue,
@@ -1514,7 +1526,7 @@ bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
   for (const auto *field : structDecl->fields()) {
     const uint32_t fieldType = typeTranslator.translateType(field->getType());
     const uint32_t subValue = theBuilder.createCompositeExtract(
-        fieldType, value, {field->getFieldIndex()});
+        fieldType, value, {getNumBaseClasses(type) + field->getFieldIndex()});
 
     if (!writeBackOutputStream(field, subValue))
       return false;

+ 60 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -426,6 +426,49 @@ const DeclaratorDecl *getReferencedDef(const Expr *expr) {
   return nullptr;
 }
 
+/// Returns the number of base classes if this type is a derived class/struct.
+/// Returns zero otherwise.
+inline uint32_t getNumBaseClasses(QualType type) {
+  if (const auto *cxxDecl = type->getAsCXXRecordDecl())
+    return cxxDecl->getNumBases();
+  return 0;
+}
+
+/// Gets the index sequence of casting a derived object to a base object by
+/// following the cast chain.
+void getBaseClassIndices(const CastExpr *expr,
+                         llvm::SmallVectorImpl<uint32_t> *indices) {
+  assert(expr->getCastKind() == CK_UncheckedDerivedToBase);
+
+  indices->clear();
+
+  QualType derivedType = expr->getSubExpr()->getType();
+  const auto *derivedDecl = derivedType->getAsCXXRecordDecl();
+
+  // Go through the base cast chain: for each of the derived to base cast, find
+  // the index of the base in question in the derived's bases.
+  for (auto pathIt = expr->path_begin(), pathIe = expr->path_end();
+       pathIt != pathIe; ++pathIt) {
+    // The type of the base in question
+    const auto baseType = (*pathIt)->getType();
+
+    uint32_t index = 0;
+    for (auto baseIt = derivedDecl->bases_begin(),
+              baseIe = derivedDecl->bases_end();
+         baseIt != baseIe; ++baseIt, ++index)
+      if (baseIt->getType() == baseType) {
+        indices->push_back(index);
+        break;
+      }
+
+    assert(index < derivedDecl->getNumBases());
+
+    // Continue to proceed the next base in the chain
+    derivedType = baseType;
+    derivedDecl = derivedType->getAsCXXRecordDecl();
+  }
+}
+
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -2123,6 +2166,18 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
         processFlatConversion(toType, evalType, subExprId, expr->getExprLoc());
     return SpirvEvalInfo(valId).setRValue();
   }
+  case CastKind::CK_UncheckedDerivedToBase: {
+    // Find the index sequence of the base to which we are casting
+    llvm::SmallVector<uint32_t, 4> baseIndices;
+    getBaseClassIndices(expr, &baseIndices);
+
+    // Turn them in to SPIR-V constants
+    for (uint32_t i = 0; i < baseIndices.size(); ++i)
+      baseIndices[i] = theBuilder.getConstantUint32(baseIndices[i]);
+
+    auto derivedInfo = doExpr(subExpr);
+    return turnIntoElementPtr(derivedInfo, expr->getType(), baseIndices);
+  }
   default:
     emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc())
         << expr->getCastKindName() << expr->getSourceRange();
@@ -5345,7 +5400,11 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
     // Append the index of the current level
     const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
     assert(fieldDecl);
-    const uint32_t index = fieldDecl->getFieldIndex();
+    // If we are accessing a derived struct, we need to account for the number
+    // of base structs, since they are placed as fields at the beginning of the
+    // derived struct.
+    const uint32_t index = getNumBaseClasses(indexing->getBase()->getType()) +
+                           fieldDecl->getFieldIndex();
     indices->push_back(rawIndex ? index : theBuilder.getConstantInt32(index));
 
     return base;

+ 10 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -378,6 +378,16 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
     // Collect all fields' types and names.
     llvm::SmallVector<uint32_t, 4> fieldTypes;
     llvm::SmallVector<llvm::StringRef, 4> fieldNames;
+
+    // If this struct is derived from some other struct, place an implicit field
+    // at the very beginning for the base struct.
+    if (const auto *cxxDecl = dyn_cast<CXXRecordDecl>(decl))
+      for (const auto base : cxxDecl->bases()) {
+        fieldTypes.push_back(translateType(base.getType(), rule));
+        fieldNames.push_back("");
+      }
+
+    // Create fields for all members of this struct
     for (const auto *field : decl->fields()) {
       fieldTypes.push_back(translateType(
           field->getType(), rule, isRowMajorMatrix(field->getType(), field)));

+ 90 - 0
tools/clang/test/CodeGenSPIRV/oo.inheritance.hlsl

@@ -0,0 +1,90 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct Base {
+    float4 a;
+    float4 b;
+};
+
+// Make sure we have the correct indices for fields
+// CHECK: OpMemberName %Derived 1 "b"
+// CHECK: OpMemberName %Derived 2 "c"
+// CHECK: OpMemberName %Derived 3 "x"
+
+// Placing the implicit base object at the beginning
+// CHECK: %Derived = OpTypeStruct %Base %v4float %v4float %Base
+struct Derived : Base {
+    float4 b;
+    float4 c;
+    Base   x;
+};
+
+// CHECK: %DerivedAgain = OpTypeStruct %Derived %v4float %v4float
+struct DerivedAgain : Derived {
+    float4 c;
+    float4 d;
+};
+
+float4 main() : SV_Target {
+    Derived d;
+
+    // Accessing a field from the implicit base object
+// CHECK:       [[base:%\d+]] = OpAccessChain %_ptr_Function_Base %d %uint_0
+// CHECK-NEXT: [[base_a:%\d+]] = OpAccessChain %_ptr_Function_v4float [[base]] %int_0
+// CHECK-NEXT:                   OpStore [[base_a]] {{%\d+}}
+    d.a   = 1.;
+
+    // Accessing fields from the derived object (shadowing)
+ // CHECK-NEXT:      [[b:%\d+]] = OpAccessChain %_ptr_Function_v4float %d %int_1
+// CHECK-NEXT:                   OpStore [[b]] {{%\d+}}
+    d.b   = 2.;
+
+    // Accessing fields from the derived object
+// CHECK-NEXT:      [[c:%\d+]] = OpAccessChain %_ptr_Function_v4float %d %int_2
+// CHECK-NEXT:                   OpStore [[c]] {{%\d+}}
+    d.c   = 3.;
+
+    // Embedding another object of the implict base object's type
+// CHECK-NEXT:    [[x_a:%\d+]] = OpAccessChain %_ptr_Function_v4float %d %int_3 %int_0
+// CHECK-NEXT:                   OpStore [[x_a]] {{%\d+}}
+// CHECK-NEXT:    [[x_b:%\d+]] = OpAccessChain %_ptr_Function_v4float %d %int_3 %int_1
+// CHECK-NEXT:                   OpStore [[x_b]] {{%\d+}}
+    d.x.a = 4.;
+    d.x.b = 5.;
+
+    DerivedAgain dd;
+
+    // Accessing a field from the deep implicit base object
+// CHECK-NEXT:   [[base:%\d+]] = OpAccessChain %_ptr_Function_Base %dd %uint_0 %uint_0
+// CHECK-NEXT: [[base_a:%\d+]] = OpAccessChain %_ptr_Function_v4float [[base]] %int_0
+// CHECK-NEXT:                   OpStore [[base_a]] {{%\d+}}
+    dd.a  = 6.;
+    // Accessing a field from the immediate implicit base object
+// CHECK-NEXT:    [[drv:%\d+]] = OpAccessChain %_ptr_Function_Derived %dd %uint_0
+// CHECK-NEXT:  [[drv_b:%\d+]] = OpAccessChain %_ptr_Function_v4float [[drv]] %int_1
+// CHECK-NEXT:                   OpStore [[drv_b]] {{%\d+}}
+    dd.b  = 7.;
+    // Accessing fields from the derived object (shadowing)
+// CHECK-NEXT:      [[c:%\d+]] = OpAccessChain %_ptr_Function_v4float %dd %int_1
+// CHECK-NEXT:                   OpStore [[c]] {{%\d+}}
+    // Accessing fields from the derived object
+    dd.c  = 8.;
+// CHECK-NEXT:      [[d:%\d+]] = OpAccessChain %_ptr_Function_v4float %dd %int_2
+// CHECK-NEXT:                   OpStore [[d]] {{%\d+}}
+    dd.d  = 9.;
+
+    // Make sure reads are good
+// CHECK:        [[base:%\d+]] = OpAccessChain %_ptr_Function_Base %d %uint_0
+// CHECK-NEXT:        {{%\d+}} = OpAccessChain %_ptr_Function_v4float [[base]] %int_0
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_Function_v4float %d %int_1
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_Function_v4float %d %int_2
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_Function_v4float %d %int_3 %int_0
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_Function_v4float %d %int_3 %int_1
+    return d.a + d.b + d.c + d.x.a + d.x.b +
+// CHECK:        [[base:%\d+]] = OpAccessChain %_ptr_Function_Base %dd %uint_0 %uint_0
+// CHECK-NEXT:        {{%\d+}} = OpAccessChain %_ptr_Function_v4float [[base]] %int_0
+// CHECK:         [[drv:%\d+]] = OpAccessChain %_ptr_Function_Derived %dd %uint_0
+// CHECK-NEXT:        {{%\d+}} = OpAccessChain %_ptr_Function_v4float [[drv]] %int_1
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_Function_v4float %dd %int_1
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_Function_v4float %dd %int_2
+           dd.a + dd.b + dd.c + dd.d;
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -416,6 +416,7 @@ TEST_F(FileTest, StaticMemberInitializer) {
 TEST_F(FileTest, MethodCallOnStaticVar) {
   runFileTest("oo.method.on-static-var.hlsl");
 }
+TEST_F(FileTest, Inheritance) { runFileTest("oo.inheritance.hlsl"); }
 
 // For semantics
 // SV_Position, SV_ClipDistance, and SV_CullDistance are covered in