Jelajahi Sumber

[spirv] Use Workgroup storage class for groupshared variables (#788)

Lei Zhang 7 tahun lalu
induk
melakukan
909351314c

+ 4 - 2
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -250,8 +250,10 @@ uint32_t DeclResultIdMapper::createExternVar(const VarDecl *var) {
   auto rule = LayoutRule::Void;
   auto rule = LayoutRule::Void;
   bool isACSBuffer = false; // Whether its {Append|Consume}StructuredBuffer
   bool isACSBuffer = false; // Whether its {Append|Consume}StructuredBuffer
 
 
-  // TODO: Figure out other cases where the storage class should be Uniform.
-  if (auto *t = var->getType()->getAs<RecordType>()) {
+  if (var->getAttr<HLSLGroupSharedAttr>()) {
+    // For CS groupshared variables
+    storageClass = spv::StorageClass::Workgroup;
+  } else if (auto *t = var->getType()->getAs<RecordType>()) {
     const llvm::StringRef typeName = t->getDecl()->getName();
     const llvm::StringRef typeName = t->getDecl()->getName();
     if (typeName == "StructuredBuffer" || typeName == "RWStructuredBuffer" ||
     if (typeName == "StructuredBuffer" || typeName == "RWStructuredBuffer" ||
         typeName == "ByteAddressBuffer" || typeName == "RWByteAddressBuffer" ||
         typeName == "ByteAddressBuffer" || typeName == "RWByteAddressBuffer" ||

+ 26 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.hlsl

@@ -0,0 +1,26 @@
+// Run: %dxc -T cs_6_0 -E main
+
+struct S {
+    float  f1;
+    float3 f2;
+};
+
+// CHECK: %a = OpVariable %_ptr_Workgroup_float Workgroup
+groupshared              float    a;
+// CHECK: %b = OpVariable %_ptr_Workgroup_v3float Workgroup
+groupshared              float3   b;
+// CHECK: %c = OpVariable %_ptr_Workgroup_mat2v3float Workgroup
+groupshared column_major float2x3 c;
+// CHECK: %d = OpVariable %_ptr_Workgroup__arr_v2float_uint_5 Workgroup
+groupshared              float2   d[5];
+// CHECK: %s = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared              S        s;
+
+[numthreads(8, 8, 8)]
+void main(uint2 tid : SV_DispatchThreadID, uint2 gid : SV_GroupID) {
+// Make sure pointers have the correct storage class
+// CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float %s %int_0
+// CHECK: [[d0:%\d+]] = OpAccessChain %_ptr_Workgroup_v2float %d %int_0
+// CHECK:    {{%\d+}} = OpAccessChain %_ptr_Workgroup_float [[d0]] %int_1
+    d[0].y = s.f1;
+}

+ 5 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -883,4 +883,9 @@ TEST_F(FileTest, HullShaderStructure) { runFileTest("hs.structure.hlsl"); }
 // GS: emit vertex and emit primitive
 // GS: emit vertex and emit primitive
 TEST_F(FileTest, GeometryShaderEmit) { runFileTest("gs.emit.hlsl"); }
 TEST_F(FileTest, GeometryShaderEmit) { runFileTest("gs.emit.hlsl"); }
 
 
+// CS: groupshared
+TEST_F(FileTest, ComputeShaderGroupShared) {
+  runFileTest("cs.groupshared.hlsl");
+}
+
 } // namespace
 } // namespace