2
0
Эх сурвалжийг харах

[spirv] Fix returning from non-Function storage class (#712)

If we are returning some struct value not in the Function storage
class, we need to decompose it and write each component to the
corresponding component of a temporary variable and then return
that temporary variable.

Also fixed a crash regarding implicitly generated constructors
and destructors for structs.
Lei Zhang 8 жил өмнө
parent
commit
2b4a131e6d

+ 5 - 5
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -174,8 +174,8 @@ DeclResultIdMapper::createVarOfExplicitLayoutStruct(const DeclContext *decl,
   llvm::SmallVector<uint32_t, 4> fieldTypes;
   llvm::SmallVector<llvm::StringRef, 4> fieldNames;
   for (const auto *subDecl : decl->decls()) {
-    // Implicit generated struct declarations should be ignored.
-    if (isa<CXXRecordDecl>(subDecl) && subDecl->isImplicit())
+    // Ignore implicit generated struct declarations/constructors/destructors.
+    if (subDecl->isImplicit())
       continue;
 
     // The field can only be FieldDecl (for normal structs) or VarDecl (for
@@ -567,9 +567,9 @@ bool DeclResultIdMapper::createStageVars(const DeclaratorDecl *decl,
 
     // Error out when the given semantic is invalid in this shader model
     if (hlsl::SigPoint::GetInterpretation(
-                             semantic->GetKind(), sigPoint->GetKind(),
-                             shaderModel.GetMajor(), shaderModel.GetMinor()) ==
-                             hlsl::DXIL::SemanticInterpretationKind::NA) {
+            semantic->GetKind(), sigPoint->GetKind(), shaderModel.GetMajor(),
+            shaderModel.GetMinor()) ==
+        hlsl::DXIL::SemanticInterpretationKind::NA) {
       emitError("invalid semantic %0 for shader module %1")
           << semanticStr << shaderModel.GetName();
       return false;

+ 20 - 5
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1088,10 +1088,25 @@ void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt) {
 }
 
 void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
-  if (const auto *retVal = stmt->getRetValue())
-    theBuilder.createReturnValue(doExpr(retVal));
-  else
+  if (const auto *retVal = stmt->getRetValue()) {
+    const auto retInfo = doExpr(retVal);
+    const auto retType = retVal->getType();
+    if (retInfo.storageClass != spv::StorageClass::Function &&
+        retType->isStructureType()) {
+      // We are returning some value from a non-Function storage class. Need to
+      // create a temporary variable to "convert" the value to Function storage
+      // class and then return.
+      const uint32_t valType = typeTranslator.translateType(retType);
+      const uint32_t tempVar = theBuilder.addFnVar(valType, "temp.var.ret");
+      storeValue(tempVar, retInfo, retType);
+
+      theBuilder.createReturnValue(theBuilder.createLoad(valType, tempVar));
+    } else {
+      theBuilder.createReturnValue(retInfo);
+    }
+  } else {
     theBuilder.createReturn();
+  }
 
   // Some statements that alter the control flow (break, continue, return, and
   // discard), require creation of a new basic block to hold any statement that
@@ -2817,8 +2832,8 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
   } else if (const auto *recordType = valType->getAs<RecordType>()) {
     uint32_t index = 0;
     for (const auto *decl : recordType->getDecl()->decls()) {
-      // Implicit generated struct declarations should be ignored.
-      if (isa<CXXRecordDecl>(decl) && decl->isImplicit())
+      // Ignore implicit generated struct declarations/constructors/destructors.
+      if (decl->isImplicit())
         continue;
 
       const auto *field = cast<FieldDecl>(decl);

+ 2 - 2
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -545,8 +545,8 @@ TypeTranslator::getLayoutDecorations(const DeclContext *decl, LayoutRule rule) {
   uint32_t offset = 0, index = 0;
 
   for (const auto *field : decl->decls()) {
-    // Implicit generated struct declarations should be ignored.
-    if (isa<CXXRecordDecl>(field) && field->isImplicit())
+    // Ignore implicit generated struct declarations/constructors/destructors.
+    if (field->isImplicit())
       continue;
 
     // The field can only be FieldDecl (for normal structs) or VarDecl (for

+ 34 - 0
tools/clang/test/CodeGenSPIRV/cf.return.storage-class.hlsl

@@ -0,0 +1,34 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct BufferType {
+    float     a;
+    float3    b;
+    float3x2  c;
+};
+
+RWStructuredBuffer<BufferType> sbuf;  // %BufferType
+
+// CHECK: %retSBuffer5 = OpFunction %BufferType_0 None {{%\d+}}
+BufferType retSBuffer5() {            // BufferType_0
+// CHECK:    %temp_var_ret = OpVariable %_ptr_Function_BufferType_0 Function
+
+// CHECK-NEXT: [[sbuf:%\d+]] = OpAccessChain %_ptr_Uniform_BufferType %sbuf %int_0 %uint_5
+// CHECK-NEXT:  [[val:%\d+]] = OpLoad %BufferType [[sbuf]]
+// CHECK-NEXT:    [[a:%\d+]] = OpCompositeExtract %float [[val]] 0
+// CHECK-NEXT: [[tmp0:%\d+]] = OpAccessChain %_ptr_Function_float %temp_var_ret %uint_0
+// CHECK-NEXT:                 OpStore [[tmp0]] [[a]]
+// CHECK-NEXT:    [[b:%\d+]] = OpCompositeExtract %v3float [[val]] 1
+// CHECK-NEXT: [[tmp1:%\d+]] = OpAccessChain %_ptr_Function_v3float %temp_var_ret %uint_1
+// CHECK-NEXT:                 OpStore [[tmp1]] [[b]]
+// CHECK-NEXT:    [[c:%\d+]] = OpCompositeExtract %mat3v2float [[val]] 2
+// CHECK-NEXT: [[tmp2:%\d+]] = OpAccessChain %_ptr_Function_mat3v2float %temp_var_ret %uint_2
+// CHECK-NEXT:                 OpStore [[tmp2]] [[c]]
+// CHECK-NEXT:  [[tmp:%\d+]] = OpLoad %BufferType_0 %temp_var_ret
+// CHECK-NEXT:       OpReturnValue [[tmp]]
+// CHECK-NEXT:       OpFunctionEnd
+    return sbuf[5];
+}
+
+void main() {
+    sbuf[6] = retSBuffer5();
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -310,6 +310,9 @@ TEST_F(FileTest, EarlyReturnFloat4) {
   runFileTest("cf.return.early.float4.hlsl");
 }
 TEST_F(FileTest, ReturnStruct) { runFileTest("cf.return.struct.hlsl"); }
+TEST_F(FileTest, ReturnFromDifferentStorageClass) {
+  runFileTest("cf.return.storage-class.hlsl");
+}
 
 // For control flows
 TEST_F(FileTest, ControlFlowNestedIfForStmt) { runFileTest("cf.if.for.hlsl"); }