Bladeren bron

[spirv] Add support for struct/class static members (#714)

These static members are translated into SPIR-V variables of
the Private storage class.
Lei Zhang 8 jaren geleden
bovenliggende
commit
9cb179024b

+ 20 - 9
docs/SPIR-V.rst

@@ -1165,6 +1165,17 @@ HLSL Intrinsic Function   GLSL Extended Instruction
 ``trunc``               ``Trunc``
 ======================= ===============================
 
+HLSL OO features
+================
+
+A HLSL struct/class member method is translated into a normal SPIR-V function,
+whose signature has an additional first parameter for the struct/class called
+upon. Every calling site of the method is generated to pass in the object as
+the first argument.
+
+HLSL struct/class static member variables are translated into SPIR-V variables
+in the ``Private`` storage class.
+
 HLSL Methods
 ============
 
@@ -1691,7 +1702,7 @@ and are translated to SPIR-V execution modes according to the table below:
 +-------------------------+---------------------+--------------------------+
 
 The ``patchconstfunc`` attribute does not have a direct equivalent in SPIR-V.
-It specifies the name of the Patch Constant Function. This function is run only 
+It specifies the name of the Patch Constant Function. This function is run only
 once per patch. This is further described below.
 
 InputPatch and OutputPatch
@@ -1707,18 +1718,18 @@ OutputPatch is an array containing ``N`` elements (where ``N`` is the number of
 output vertices). Each element of the array contains information about an
 output vertex. OutputPatch may also be passed to the patch constant function.
 
-The SPIR-V ``InvocationID`` (``SV_OutputControlPointID`` in HLSL) is used to index 
+The SPIR-V ``InvocationID`` (``SV_OutputControlPointID`` in HLSL) is used to index
 into the InputPatch and OutputPatch arrays to read/write information for the given
 vertex.
 
-The hull main entry function in HLSL returns only one value (say, of type ``T``), but 
+The hull main entry function in HLSL returns only one value (say, of type ``T``), but
 that function is in fact executed once for each control point. The Vulkan spec requires that
-"Tessellation control shader per-vertex output variables and blocks, and tessellation control, 
+"Tessellation control shader per-vertex output variables and blocks, and tessellation control,
 tessellation evaluation, and geometry shader per-vertex input variables and blocks are required
-to be declared as arrays, with each element representing input or output values for a single vertex 
+to be declared as arrays, with each element representing input or output values for a single vertex
 of a multi-vertex primitive". Therefore, we need to create a stage output variable that is an array
-with elements of type ``T``. The number of elements of the array is equal to the number of 
-output control points. Each final output control point is written into the corresponding element in 
+with elements of type ``T``. The number of elements of the array is equal to the number of
+output control points. Each final output control point is written into the corresponding element in
 the array using SV_OutputControlPointID as the index.
 
 Patch Constant Function
@@ -1729,8 +1740,8 @@ main entry function, and then use an ``OpControlBarrier`` to wait for all vertex
 processing to finish. After the barrier, *only* the first thread (with InvocationID of 0)
 will invoke the patch constant function.
 
-The information resulting from the patch constant function will also be returned 
+The information resulting from the patch constant function will also be returned
 as stage output variables. The output struct of the patch constant function must include
 ``SV_TessFactor`` and ``SV_InsideTessFactor`` fields which will translate to
-``TessLevelOuter`` and ``TessLevelInner`` builtin variables, respectively. And the rest 
+``TessLevelOuter`` and ``TessLevelInner`` builtin variables, respectively. And the rest
 will be flattened and translated into normal stage output variables, one for each field.

+ 1 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -90,7 +90,7 @@ SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const NamedDecl *decl) {
 
       return {elemId, info->storageClass, info->layoutRule};
     } else {
-      return {info->resultId, info->storageClass, info->layoutRule};
+      return *info;
     }
 
   assert(false && "found unregistered decl");

+ 5 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -195,6 +195,11 @@ public:
         : resultId(result), storageClass(sc), layoutRule(lr),
           indexInCTBuffer(indexInCTB) {}
 
+    /// Implicit conversion to SpirvEvalInfo.
+    operator SpirvEvalInfo() const {
+      return SpirvEvalInfo(resultId, storageClass, layoutRule);
+    }
+
     uint32_t resultId;
     spv::StorageClass storageClass;
     /// Layout rule for this decl.

+ 23 - 9
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -604,7 +604,7 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
   // File scope variables (static "global" and "local" variables) belongs to
   // the Private storage class, while function scope variables (normal "local"
   // variables) belongs to the Function storage class.
-  if (!decl->isExternallyVisible()) {
+  if (!decl->isExternallyVisible() || decl->isStaticDataMember()) {
     // Note: cannot move varType outside of this scope because it generates
     // SPIR-V types without decorations, while external visible variable should
     // have SPIR-V type with decorations.
@@ -1193,10 +1193,12 @@ SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
   const auto *base = collectArrayStructIndices(expr, &indices);
   auto info = doExpr(base);
 
-  const uint32_t ptrType = theBuilder.getPointerType(
-      typeTranslator.translateType(expr->getType(), info.layoutRule),
-      info.storageClass);
-  info.resultId = theBuilder.createAccessChain(ptrType, info, indices);
+  if (!indices.empty()) {
+    const uint32_t ptrType = theBuilder.getPointerType(
+        typeTranslator.translateType(expr->getType(), info.layoutRule),
+        info.storageClass);
+    info.resultId = theBuilder.createAccessChain(ptrType, info, indices);
+  }
 
   return info;
 }
@@ -2596,13 +2598,16 @@ SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
 
 SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
   llvm::SmallVector<uint32_t, 4> indices;
+
   const Expr *base = collectArrayStructIndices(expr, &indices);
   auto info = doExpr(base);
 
-  const uint32_t ptrType = theBuilder.getPointerType(
-      typeTranslator.translateType(expr->getType(), info.layoutRule),
-      info.storageClass);
-  info.resultId = theBuilder.createAccessChain(ptrType, info, indices);
+  if (!indices.empty()) {
+    const uint32_t ptrType = theBuilder.getPointerType(
+        typeTranslator.translateType(expr->getType(), info.layoutRule),
+        info.storageClass);
+    info.resultId = theBuilder.createAccessChain(ptrType, info, indices);
+  }
 
   return info;
 }
@@ -3539,6 +3544,15 @@ SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
 const Expr *SPIRVEmitter::collectArrayStructIndices(
     const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices) {
   if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
+    // First check whether this is referring to a static member. If it is, we
+    // create a DeclRefExpr for it.
+    if (auto *varDecl = dyn_cast<VarDecl>(indexing->getMemberDecl()))
+      if (varDecl->isStaticDataMember())
+        return DeclRefExpr::Create(
+            astContext, NestedNameSpecifierLoc(), SourceLocation(), varDecl,
+            /*RefersToEnclosingVariableOrCapture=*/false, SourceLocation(),
+            varDecl->getType(), VK_LValue);
+
     const Expr *base = collectArrayStructIndices(
         indexing->getBase()->IgnoreParenNoopCasts(astContext), indices);
 

+ 0 - 0
tools/clang/test/CodeGenSPIRV/method.class.method.hlsl → tools/clang/test/CodeGenSPIRV/oo.class.method.hlsl


+ 60 - 0
tools/clang/test/CodeGenSPIRV/oo.class.static.member.hlsl

@@ -0,0 +1,60 @@
+// Run: %dxc -T ps_6_0 -E main
+
+class S {
+    float  a;
+    float4 b;
+};
+
+class T {
+    static float4 M;
+    static S      N;
+
+    static const float4 U;
+
+    int val;
+};
+
+// CHECK: [[v4fc:%\d+]] = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
+
+// CHECK: %M = OpVariable %_ptr_Private_v4float Private [[v4fc]]
+// CHECK: %N = OpVariable %_ptr_Private_S Private
+// CHECK: %U = OpVariable %_ptr_Private_v4float Private [[v4fc]]
+
+float4 T::M = float4(1., 2., 3., 4.);
+S      T::N = {5.0, 1., 2., 3., 4.};
+
+const float4 T::U = float4(1., 2., 3., 4.);
+
+// T::M is intialized using embeded initializer in the variable declaration.
+// T::N is intialized at the beginning of the main function.
+
+// CHECK-LABEL: %main = OpFunction
+// CHECK:      [[v1to4:%\d+]] = OpCompositeConstruct %v4float %float_1 %float_2 %float_3 %float_4
+// CHECK-NEXT: [[v1to5:%\d+]] = OpCompositeConstruct %S %float_5 [[v1to4]]
+// CHECK-NEXT:                  OpStore %N [[v1to5]]
+
+// CHECK-LABEL: %src_main = OpFunction
+float4 main(float4 input: A) : SV_Target {
+    T t;
+
+// CHECK: OpStore %M {{%\d+}}
+    T::M = input;
+// CHECK: OpStore %M {{%\d+}}
+    t.M = input;
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Private_float %N %int_0
+// CHECK-NEXT:           OpStore [[ptr]] %float_1
+    T::N.a = 1.0;
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %N %int_1
+// CHECK-NEXT:           OpStore [[ptr]] {{%\d+}}
+    t.N.b = input;
+
+// CHECK:      {{%\d+}} = OpLoad %v4float %M
+// CHECK:      {{%\d+}} = OpLoad %v4float %M
+// CHECK:  [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %N %int_1
+// CHECK-NEXT: {{%\d+}} = OpLoad %v4float [[ptr]]
+// CHECK:  [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %N %int_1
+// CHECK-NEXT: {{%\d+}} = OpLoad %v4float [[ptr]]
+// CHECK:      {{%\d+}} = OpLoad %v4float %U
+// CHECK:      {{%\d+}} = OpLoad %v4float %U
+    return T::M + t.M + T::N.b + t.N.b + T::U + t.U;
+}

+ 0 - 0
tools/clang/test/CodeGenSPIRV/method.struct.method.hlsl → tools/clang/test/CodeGenSPIRV/oo.struct.method.hlsl


+ 60 - 0
tools/clang/test/CodeGenSPIRV/oo.struct.static.member.hlsl

@@ -0,0 +1,60 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct S {
+    float  a;
+    float4 b;
+};
+
+struct T {
+    static float4 M;
+    static S      N;
+
+    static const float4 U;
+
+    int val;
+};
+
+// CHECK: [[v4fc:%\d+]] = OpConstantComposite %v4float %float_1 %float_2 %float_3 %float_4
+
+// CHECK: %M = OpVariable %_ptr_Private_v4float Private [[v4fc]]
+// CHECK: %N = OpVariable %_ptr_Private_S Private
+// CHECK: %U = OpVariable %_ptr_Private_v4float Private [[v4fc]]
+
+float4 T::M = float4(1., 2., 3., 4.);
+S      T::N = {5.0, 1., 2., 3., 4.};
+
+const float4 T::U = float4(1., 2., 3., 4.);
+
+// T::M is intialized using embeded initializer in the variable declaration.
+// T::N is intialized at the beginning of the main function.
+
+// CHECK-LABEL: %main = OpFunction
+// CHECK:      [[v1to4:%\d+]] = OpCompositeConstruct %v4float %float_1 %float_2 %float_3 %float_4
+// CHECK-NEXT: [[v1to5:%\d+]] = OpCompositeConstruct %S %float_5 [[v1to4]]
+// CHECK-NEXT:                  OpStore %N [[v1to5]]
+
+// CHECK-LABEL: %src_main = OpFunction
+float4 main(float4 input: A) : SV_Target {
+    T t;
+
+// CHECK: OpStore %M {{%\d+}}
+    T::M = input;
+// CHECK: OpStore %M {{%\d+}}
+    t.M = input;
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Private_float %N %int_0
+// CHECK-NEXT:           OpStore [[ptr]] %float_1
+    T::N.a = 1.0;
+// CHECK: [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %N %int_1
+// CHECK-NEXT:           OpStore [[ptr]] {{%\d+}}
+    t.N.b = input;
+
+// CHECK:      {{%\d+}} = OpLoad %v4float %M
+// CHECK:      {{%\d+}} = OpLoad %v4float %M
+// CHECK:  [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %N %int_1
+// CHECK-NEXT: {{%\d+}} = OpLoad %v4float [[ptr]]
+// CHECK:  [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %N %int_1
+// CHECK-NEXT: {{%\d+}} = OpLoad %v4float [[ptr]]
+// CHECK:      {{%\d+}} = OpLoad %v4float %U
+// CHECK:      {{%\d+}} = OpLoad %v4float %U
+    return T::M + t.M + T::N.b + t.N.b + T::U + t.U;
+}

+ 8 - 8
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -324,14 +324,14 @@ TEST_F(FileTest, ControlFlowConditionalOp) { runFileTest("cf.cond-op.hlsl"); }
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 TEST_F(FileTest, FunctionInOutParam) { runFileTest("fn.param.inout.hlsl"); }
 
-// For struct methods
-TEST_F(FileTest, StructMethodCallNormal) {
-  runFileTest("method.struct.method.hlsl");
-}
-
-// For class methods
-TEST_F(FileTest, ClassMethodCallNormal) {
-  runFileTest("method.class.method.hlsl");
+// For OO features
+TEST_F(FileTest, StructMethodCall) { runFileTest("oo.struct.method.hlsl"); }
+TEST_F(FileTest, ClassMethodCall) { runFileTest("oo.class.method.hlsl"); }
+TEST_F(FileTest, StructStaticMember) {
+  runFileTest("oo.struct.static.member.hlsl");
+}
+TEST_F(FileTest, ClassStaticMember) {
+  runFileTest("oo.struct.static.member.hlsl");
 }
 
 // For semantics