Przeglądaj źródła

[spirv] Fix static variable init from different storage class (#969)

If the initializer is in a different storage class, we sometimes
need to decompose it and assign basic components recursively.

Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/968
Lei Zhang 7 lat temu
rodzic
commit
273e2a45af

+ 7 - 4
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -4018,7 +4018,10 @@ void SPIRVEmitter::initOnce(QualType varType, std::string varName,
   theBuilder.setInsertPoint(todoBB);
   // Do initialization and mark done
   if (varInit) {
-    theBuilder.createStore(varPtr, doExpr(varInit));
+    storeValue(
+        // Static function variable are of private storage class
+        SpirvEvalInfo(varPtr).setStorageClass(spv::StorageClass::Private),
+        doExpr(varInit), varInit->getType());
   } else {
     const auto typeId = typeTranslator.translateType(varType);
     theBuilder.createStore(varPtr, theBuilder.getConstantNull(typeId));
@@ -7109,15 +7112,15 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
 
   // Initialize all global variables at the beginning of the wrapper
   for (const VarDecl *varDecl : toInitGloalVars) {
-    const auto id = declIdMapper.getDeclResultId(varDecl);
+    const auto varInfo = declIdMapper.getDeclResultId(varDecl);
     if (const auto *init = varDecl->getInit()) {
-      theBuilder.createStore(id, doExpr(init));
+      storeValue(varInfo, doExpr(init), varDecl->getType());
 
       // Update counter variable associatd with global variables
       tryToAssignCounterVar(varDecl, init);
     } else {
       const auto typeId = typeTranslator.translateType(varDecl->getType());
-      theBuilder.createStore(id, theBuilder.getConstantNull(typeId));
+      theBuilder.createStore(varInfo, theBuilder.getConstantNull(typeId));
     }
   }
 

+ 49 - 0
tools/clang/test/CodeGenSPIRV/var.init.cross-storage-class.hlsl

@@ -0,0 +1,49 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// Tests struct variable initialization from different storage class
+
+struct S {
+    float4 pos : POSITION;
+};
+
+cbuffer Constants {
+  S uniform_struct;
+}
+
+// uniform_struct is of type %S, while (fn_)private_struct and s is of type %S_0.
+// So we need to decompose and assign for uniform_struct -> (fn_)private_struct,
+// but not private_struct -> s.
+
+// CHECK-LABEL:           %main = OpFunction
+
+// CHECK:          [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_S %var_Constants %int_0
+// CHECK-NEXT: [[uniform:%\d+]] = OpLoad %S [[ptr]]
+// CHECK-NEXT:     [[vec:%\d+]] = OpCompositeExtract %v4float [[uniform]] 0
+// CHECK-NEXT:     [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %private_struct %uint_0
+// CHECK-NEXT:                    OpStore [[ptr]] [[vec]]
+static const S private_struct = uniform_struct; // Unifrom -> Private
+
+float4 foo();
+
+S main()
+// CHECK-LABEL:       %src_main = OpFunction
+{
+    S Out;
+// CHECK:      [[private:%\d+]] = OpLoad %S_0 %private_struct
+// CHECK-NEXT:                    OpStore %s [[private]]
+    S s = private_struct; // Private -> Function
+    Out.pos = s.pos + foo();
+    return Out;
+}
+
+float4 foo()
+// CHECK-LABEL:            %foo = OpFunction
+{
+// CHECK:          [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_S %var_Constants %int_0
+// CHECK-NEXT: [[uniform:%\d+]] = OpLoad %S [[ptr]]
+// CHECK-NEXT:     [[vec:%\d+]] = OpCompositeExtract %v4float [[uniform]] 0
+// CHECK-NEXT:     [[ptr:%\d+]] = OpAccessChain %_ptr_Private_v4float %fn_private_struct %uint_0
+// CHECK-NEXT:                    OpStore [[ptr]] [[vec]]
+    static S fn_private_struct = uniform_struct; // Uniform -> Private
+    return fn_private_struct.pos;
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -102,6 +102,9 @@ TEST_F(FileTest, VarInitTbuffer) {
   runFileTest("var.init.tbuffer.hlsl", FileTest::Expect::Warning);
 }
 TEST_F(FileTest, VarInitOpaque) { runFileTest("var.init.opaque.hlsl"); }
+TEST_F(FileTest, VarInitCrossStorageClass) {
+  runFileTest("var.init.cross-storage-class.hlsl");
+}
 TEST_F(FileTest, StaticVar) { runFileTest("var.static.hlsl"); }
 
 // For prefix/postfix increment/decrement