2
0
Эх сурвалжийг харах

[spirv] Support static struct member functions (#640)

These functions are just translated into standalone functions
through the normal translation path.
Lei Zhang 8 жил өмнө
parent
commit
5f2b598050

+ 32 - 21
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -430,16 +430,23 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   // Construct the function signature.
   llvm::SmallVector<uint32_t, 4> paramTypes;
 
-  bool isMemberFn = false;
-  // For member function, the first parameter should be the object on which we
-  // are invoking this method.
+  bool isNonStaticMemberFn = false;
   if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
-    isMemberFn = true;
-    const uint32_t valueType = typeTranslator.translateType(
-        memberFn->getThisType(astContext)->getPointeeType());
-    const uint32_t ptrType =
-        theBuilder.getPointerType(valueType, spv::StorageClass::Function);
-    paramTypes.push_back(ptrType);
+    isNonStaticMemberFn = !memberFn->isStatic();
+
+    if (isNonStaticMemberFn) {
+      // For non-static member function, the first parameter should be the
+      // object on which we are invoking this method.
+      const uint32_t valueType = typeTranslator.translateType(
+          memberFn->getThisType(astContext)->getPointeeType());
+      const uint32_t ptrType =
+          theBuilder.getPointerType(valueType, spv::StorageClass::Function);
+      paramTypes.push_back(ptrType);
+    }
+
+    // Prefix the function name with the struct name
+    if (const auto *st = dyn_cast<CXXRecordDecl>(memberFn->getDeclContext()))
+      funcName = st->getName().str() + "." + funcName;
   }
 
   for (const auto *param : decl->params()) {
@@ -452,7 +459,7 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes);
   theBuilder.beginFunction(funcType, retType, funcName, funcId);
 
-  if (isMemberFn) {
+  if (isNonStaticMemberFn) {
     // Remember the parameter for the this object so later we can handle
     // CXXThisExpr correctly.
     curThis = theBuilder.addFnParam(paramTypes[0], "param.this");
@@ -461,7 +468,8 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   // Create all parameters.
   for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
     const ParmVarDecl *paramDecl = decl->getParamDecl(i);
-    (void)declIdMapper.createFnParam(paramTypes[i + isMemberFn], paramDecl);
+    (void)declIdMapper.createFnParam(paramTypes[i + isNonStaticMemberFn],
+                                     paramDecl);
   }
 
   if (decl->hasBody()) {
@@ -1118,20 +1126,23 @@ uint32_t SPIRVEmitter::processCall(const CallExpr *callExpr) {
 
   if (callee) {
     const auto numParams = callee->getNumParams();
-    bool isMemberCall = false;
+    bool isNonStaticMemberCall = false;
 
     llvm::SmallVector<uint32_t, 4> params; // Temporary variables
     llvm::SmallVector<uint32_t, 4> args;   // Evaluated arguments
 
-    // For normal member calls, evaluate the object and pass it as the first
-    // argument.
     if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr)) {
-      isMemberCall = true;
-      const auto *object = memberCall->getImplicitObjectArgument();
-      args.push_back(doExpr(object));
-      // We do not need to create a new temporary variable for the this object.
-      // Use the evaluated argument.
-      params.push_back(args.back());
+      isNonStaticMemberCall =
+          !cast<CXXMethodDecl>(memberCall->getCalleeDecl())->isStatic();
+      if (isNonStaticMemberCall) {
+        // For non-static member calls, evaluate the object and pass it as the
+        // first argument.
+        const auto *object = memberCall->getImplicitObjectArgument();
+        args.push_back(doExpr(object));
+        // We do not need to create a new temporary variable for the this
+        // object. Use the evaluated argument.
+        params.push_back(args.back());
+      }
     }
 
     // Evaluate parameters
@@ -1176,7 +1187,7 @@ uint32_t SPIRVEmitter::processCall(const CallExpr *callExpr) {
     for (uint32_t i = 0; i < numParams; ++i) {
       const auto *param = callee->getParamDecl(i);
       if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
-        const uint32_t index = i + isMemberCall;
+        const uint32_t index = i + isNonStaticMemberCall;
         const uint32_t typeId = typeTranslator.translateType(param->getType());
         const uint32_t value = theBuilder.createLoad(typeId, params[index]);
         theBuilder.createStore(args[index], value);

+ 44 - 19
tools/clang/test/CodeGenSPIRV/method.struct.method.hlsl

@@ -28,17 +28,28 @@ struct S {
     float fn_unused() {
         return 2.4;
     }
+
+    // Static method
+    static float fn_static() {
+        return 3.5;
+    }
 };
 
 struct T {
-  S s;
+    S s;
 
-  // Calling method in nested struct
-  float fn_nested() {
-    return s.fn_ref();
-  }
+    // Calling method in nested struct
+    float fn_nested() {
+        return s.fn_ref();
+    }
+
+    // Static method with the same name as S
+    static float fn_static() {
+        return 6.7;
+    }
 };
 
+// CHECK:   [[ft_f32:%\d+]] = OpTypeFunction %float
 // CHECK:     [[ft_S:%\d+]] = OpTypeFunction %float %_ptr_Function_S
 // CHECK: [[ft_S_f32:%\d+]] = OpTypeFunction %float %_ptr_Function_S %_ptr_Function_float
 // CHECK:     [[ft_T:%\d+]] = OpTypeFunction %float %_ptr_Function_T
@@ -48,24 +59,29 @@ struct T {
 // CHECK-NEXT:           %s = OpVariable %_ptr_Function_S Function
 // CHECK-NEXT:           %t = OpVariable %_ptr_Function_T Function
 // CHECK-NEXT: %param_var_c = OpVariable %_ptr_Function_float Function
-// CHECK:          {{%\d+}} = OpFunctionCall %float %fn_no_ref %s
-// CHECK:          {{%\d+}} = OpFunctionCall %float %fn_ref %s
-// CHECK:          {{%\d+}} = OpFunctionCall %float %fn_call %s %param_var_c
-// CHECK:          {{%\d+}} = OpFunctionCall %float %fn_nested %t
+// CHECK:          {{%\d+}} = OpFunctionCall %float %S_fn_no_ref %s
+// CHECK:          {{%\d+}} = OpFunctionCall %float %S_fn_ref %s
+// CHECK:          {{%\d+}} = OpFunctionCall %float %S_fn_call %s %param_var_c
+// CHECK:          {{%\d+}} = OpFunctionCall %float %T_fn_nested %t
+// CHECK:          {{%\d+}} = OpFunctionCall %float %S_fn_static
+// CHECK:          {{%\d+}} = OpFunctionCall %float %T_fn_static
+// CHECK:          {{%\d+}} = OpFunctionCall %float %S_fn_static
+// CHECK:          {{%\d+}} = OpFunctionCall %float %T_fn_static
 // CHECK:                     OpFunctionEnd
 float main() : A {
     S s;
     T t;
-    return s.fn_no_ref() + s.fn_ref() + s.fn_call(5.0) + t.fn_nested();
+    return s.fn_no_ref() + s.fn_ref() + s.fn_call(5.0) + t.fn_nested() +
+           s.fn_static() + t.fn_static() + S::fn_static() + T::fn_static();
 }
 
-// CHECK:         %fn_no_ref = OpFunction %float None [[ft_S]]
+// CHECK:       %S_fn_no_ref = OpFunction %float None [[ft_S]]
 // CHECK-NEXT:   %param_this = OpFunctionParameter %_ptr_Function_S
 // CHECK-NEXT:   %bb_entry_0 = OpLabel
 // CHECK:                      OpFunctionEnd
 
 
-// CHECK:            %fn_ref = OpFunction %float None [[ft_S]]
+// CHECK:          %S_fn_ref = OpFunction %float None [[ft_S]]
 // CHECK-NEXT: %param_this_0 = OpFunctionParameter %_ptr_Function_S
 // CHECK-NEXT:   %bb_entry_1 = OpLabel
 // CHECK:           {{%\d+}} = OpAccessChain %_ptr_Function_float %param_this_0 %int_0
@@ -73,26 +89,35 @@ float main() : A {
 // CHECK:                      OpFunctionEnd
 
 
-// CHECK:            %fn_call = OpFunction %float None [[ft_S_f32]]
+// CHECK:          %S_fn_call = OpFunction %float None [[ft_S_f32]]
 // CHECK-NEXT:  %param_this_1 = OpFunctionParameter %_ptr_Function_S
 // CHECK-NEXT:             %c = OpFunctionParameter %_ptr_Function_float
 // CHECK-NEXT:    %bb_entry_2 = OpLabel
 // CHECK-NEXT: %param_var_c_0 = OpVariable %_ptr_Function_float Function
-// CHECK:            {{%\d+}} = OpFunctionCall %float %fn_param %param_this_1 %param_var_c_0
+// CHECK:            {{%\d+}} = OpFunctionCall %float %S_fn_param %param_this_1 %param_var_c_0
 // CHECK:                       OpFunctionEnd
 
-
-// CHECK:         %fn_nested = OpFunction %float None [[ft_T]]
+// CHECK:       %T_fn_nested = OpFunction %float None [[ft_T]]
 // CHECK-NEXT: %param_this_2 = OpFunctionParameter %_ptr_Function_T
 // CHECK-NEXT:   %bb_entry_3 = OpLabel
 // CHECK:       [[t_s:%\d+]] = OpAccessChain %_ptr_Function_S %param_this_2 %int_0
-// CHECK:           {{%\d+}} = OpFunctionCall %float %fn_ref [[t_s]]
+// CHECK:           {{%\d+}} = OpFunctionCall %float %S_fn_ref [[t_s]]
 // CHECK:                      OpFunctionEnd
 
 
-// CHECK:          %fn_param = OpFunction %float None [[ft_S_f32]]
+// CHECK:        %S_fn_static = OpFunction %float None [[ft_f32]]
+// CHECK-NEXT:    %bb_entry_4 = OpLabel
+// CHECK:                       OpFunctionEnd
+
+
+// CHECK:        %T_fn_static = OpFunction %float None [[ft_f32]]
+// CHECK-NEXT:    %bb_entry_5 = OpLabel
+// CHECK:                       OpFunctionEnd
+
+
+// CHECK:        %S_fn_param = OpFunction %float None [[ft_S_f32]]
 // CHECK-NEXT: %param_this_3 = OpFunctionParameter %_ptr_Function_S
 // CHECK-NEXT:          %c_0 = OpFunctionParameter %_ptr_Function_float
-// CHECK-NEXT:   %bb_entry_4 = OpLabel
+// CHECK-NEXT:   %bb_entry_6 = OpLabel
 // CHECK:           {{%\d+}} = OpAccessChain %_ptr_Function_float %param_this_3 %int_0
 // CHECK:                      OpFunctionEnd