|
@@ -2193,6 +2193,36 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) {
|
|
|
return processCall(callExpr);
|
|
|
}
|
|
|
|
|
|
+SpirvInstruction *SpirvEmitter::getBaseOfMemberFunction(QualType objectType,
|
|
|
+ SpirvInstruction * objInstr,
|
|
|
+ const CXXMethodDecl* memberFn,
|
|
|
+ SourceLocation loc) {
|
|
|
+ // If objectType is different from the parent of memberFn, memberFn should be
|
|
|
+ // defined in a base struct/class of objectType. We create OpAccessChain with
|
|
|
+ // index 0 while iterating bases of objectType until we find the base with
|
|
|
+ // the definition of memberFn.
|
|
|
+ if (const auto *ptrType = objectType->getAs<PointerType>()) {
|
|
|
+ if (const auto *recordType = ptrType->getPointeeType()->getAs<RecordType>()) {
|
|
|
+ const auto *parentDeclOfMemberFn = memberFn->getParent();
|
|
|
+ if (recordType->getDecl() != parentDeclOfMemberFn) {
|
|
|
+ const auto *cxxRecordDecl = dyn_cast<CXXRecordDecl>(recordType->getDecl());
|
|
|
+ auto *zero =
|
|
|
+ spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0));
|
|
|
+ for (auto baseItr = cxxRecordDecl->bases_begin(), itrEnd = cxxRecordDecl->bases_end();
|
|
|
+ baseItr != itrEnd; baseItr++) {
|
|
|
+ const auto *baseType = baseItr->getType()->getAs<RecordType>();
|
|
|
+ objectType = astContext.getPointerType(baseType->desugar());
|
|
|
+ objInstr = spvBuilder.createAccessChain(objectType,
|
|
|
+ objInstr, {zero},
|
|
|
+ loc);
|
|
|
+ if (baseType->getDecl() == parentDeclOfMemberFn) return objInstr;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return nullptr;
|
|
|
+}
|
|
|
+
|
|
|
SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
|
|
|
const FunctionDecl *callee = getCalleeDefinition(callExpr);
|
|
|
|
|
@@ -2243,6 +2273,10 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
|
|
|
|
|
|
objectType = object->getType();
|
|
|
objInstr = doExpr(object);
|
|
|
+ if (auto *accessToBaseInstr = getBaseOfMemberFunction(objectType, objInstr, memberFn, memberCall->getExprLoc())) {
|
|
|
+ objInstr = accessToBaseInstr;
|
|
|
+ objectType = accessToBaseInstr->getAstResultType();
|
|
|
+ }
|
|
|
|
|
|
// If not already a variable, we need to create a temporary variable and
|
|
|
// pass the object pointer to the function. Example:
|