Ver código fonte

[spirv] Fix returning struct containing arrays (#1392)

We should always create a temporary variable for rvalues if we
are trying to get an access chain from it. Now do it in the
turnIntoElementPtr() function, which is used everywhere.

Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/1387
Lei Zhang 7 anos atrás
pai
commit
15885f01ac

+ 42 - 24
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1939,7 +1939,7 @@ SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
   }
 
   if (!indices.empty()) {
-    (void)turnIntoElementPtr(info, expr->getType(), indices);
+    (void)turnIntoElementPtr(base->getType(), info, expr->getType(), indices);
   }
 
   return info;
@@ -2432,7 +2432,8 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
       baseIndices[i] = theBuilder.getConstantUint32(baseIndices[i]);
 
     auto derivedInfo = doExpr(subExpr);
-    return turnIntoElementPtr(derivedInfo, expr->getType(), baseIndices);
+    return turnIntoElementPtr(subExpr->getType(), derivedInfo, expr->getType(),
+                              baseIndices);
   }
   default:
     emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc())
@@ -3279,7 +3280,7 @@ SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) {
   const uint32_t zero = theBuilder.getConstantInt32(0);
   const uint32_t index = doExpr(expr->getArg(0));
 
-  return turnIntoElementPtr(info, structType, {zero, index});
+  return turnIntoElementPtr(buffer->getType(), info, structType, {zero, index});
 }
 
 uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
@@ -3473,7 +3474,8 @@ SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
 
   const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType());
 
-  (void)turnIntoElementPtr(bufferInfo, bufferElemTy, {zero, index});
+  (void)turnIntoElementPtr(object->getType(), bufferInfo, bufferElemTy,
+                           {zero, index});
 
   if (isAppend) {
     // Write out the value
@@ -4396,7 +4398,8 @@ SPIRVEmitter::doCXXOperatorCallExpr(const CXXOperatorCallExpr *expr) {
     base = createTemporaryVar(baseExpr->getType(), "vector", base);
   }
 
-  return turnIntoElementPtr(base, expr->getType(), indices);
+  return turnIntoElementPtr(baseExpr->getType(), base, expr->getType(),
+                            indices);
 }
 
 SpirvEvalInfo
@@ -4581,19 +4584,7 @@ SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {
   auto info = loadIfAliasVarRef(base);
 
   if (!indices.empty()) {
-    // Sometime we are accessing the member of a rvalue, e.g.,
-    // <some-function-returing-a-struct>().<some-field>
-    // Create a temporary variable to hold the rvalue so that we can use access
-    // chain to index into it.
-    if (info.isRValue()) {
-      SpirvEvalInfo tempVar = createTemporaryVar(
-          base->getType(), TypeTranslator::getName(base->getType()), info);
-      (void)turnIntoElementPtr(tempVar, expr->getType(), indices);
-      info.setResultId(theBuilder.createLoad(
-          typeTranslator.translateType(expr->getType()), tempVar));
-    } else {
-      (void)turnIntoElementPtr(info, expr->getType(), indices);
-    }
+    (void)turnIntoElementPtr(base->getType(), info, expr->getType(), indices);
   }
 
   return info;
@@ -6031,13 +6022,40 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
 }
 
 SpirvEvalInfo &SPIRVEmitter::turnIntoElementPtr(
-    SpirvEvalInfo &info, QualType elemType,
+    QualType baseType, SpirvEvalInfo &base, QualType elemType,
     const llvm::SmallVector<uint32_t, 4> &indices) {
-  assert(!info.isRValue());
-  const uint32_t ptrType = theBuilder.getPointerType(
-      typeTranslator.translateType(elemType, info.getLayoutRule()),
-      info.getStorageClass());
-  return info.setResultId(theBuilder.createAccessChain(ptrType, info, indices));
+  // If this is a rvalue, we need a temporary object to hold it
+  // so that we can get access chain from it.
+  const bool needTempVar = base.isRValue();
+
+  if (needTempVar) {
+    auto varName = TypeTranslator::getName(baseType);
+    const auto var = createTemporaryVar(baseType, varName, base);
+    base.setResultId(var)
+        .setLayoutRule(LayoutRule::Void)
+        .setStorageClass(spv::StorageClass::Function);
+  }
+
+  const uint32_t elemTypeId =
+      typeTranslator.translateType(elemType, base.getLayoutRule());
+  const uint32_t ptrType =
+      theBuilder.getPointerType(elemTypeId, base.getStorageClass());
+  base.setResultId(theBuilder.createAccessChain(ptrType, base, indices));
+
+  // Okay, this part seems weird, but it is intended:
+  // If the base is originally a rvalue, the whole AST involving the base
+  // is consistently set up to handle rvalues. By copying the base into
+  // a temporary variable and grab an access chain from it, we are breaking
+  // the consistency by turning the base from rvalue into lvalue. Keep in
+  // mind that there will be no LValueToRValue casts in the AST for us
+  // to rely on to load the access chain if a rvalue is expected. Therefore,
+  // we must do the load here. Otherwise, it's up to the consumer of this
+  // access chain to do the load, and that can be everywhere.
+  if (needTempVar) {
+    base.setResultId(theBuilder.createLoad(elemTypeId, base));
+  }
+
+  return base;
 }
 
 uint32_t SPIRVEmitter::castToBool(const uint32_t fromVal, QualType fromType,

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -288,7 +288,7 @@ private:
   /// Creates an access chain to index into the given SPIR-V evaluation result
   /// and overwrites and returns the new SPIR-V evaluation result.
   SpirvEvalInfo &
-  turnIntoElementPtr(SpirvEvalInfo &info, QualType elemType,
+  turnIntoElementPtr(QualType baseType, SpirvEvalInfo &base, QualType elemType,
                      const llvm::SmallVector<uint32_t, 4> &indices);
 
 private:

+ 29 - 3
tools/clang/test/CodeGenSPIRV/op.struct.access.hlsl

@@ -1,10 +1,13 @@
 // Run: %dxc -T ps_6_0 -E main
 
 struct S {
-    bool a;
-    uint2 b;
+    bool     a;
+    uint2    b;
     float2x3 c;
-    float4 d;
+    float4   d;
+    float4   e[1];
+    float    f[4];
+    int      g;
 };
 
 struct T {
@@ -17,6 +20,17 @@ T foo() {
     return ret;
 }
 
+S bar() {
+    S ret = (S)0;
+    return ret;
+}
+
+ConstantBuffer<S> MyBuffer;
+
+S baz() {
+    return MyBuffer;
+}
+
 float4 main() : SV_Target {
     T t;
 
@@ -75,6 +89,18 @@ float4 main() : SV_Target {
 // CHECK-NEXT: OpStore [[c0]] {{%\d+}}
     t.i.c[0] = v6;
 
+// CHECK:       [[baz:%\d+]] = OpFunctionCall %S %baz
+// CHECK-NEXT:                 OpStore %temp_var_S [[baz]]
+// CHECK-NEXT:                 OpAccessChain %_ptr_Function_v4float %temp_var_S %int_4 %int_0
+// CHECK:       [[bar:%\d+]] = OpFunctionCall %S %bar
+// CHECK-NEXT:                 OpStore %temp_var_S_0 [[bar]]
+// CHECK-NEXT:                 OpAccessChain %_ptr_Function_float %temp_var_S_0 %int_5 %int_1
+    float4 val1 = bar().f[1] * baz().e[0];
+
+// CHECK:        [[ac:%\d+]] = OpAccessChain %_ptr_Function_int %temp_var_S_1 %int_6
+// CHECK-NEXT:                 OpLoad %int [[ac]]
+    bool val2 = bar().g; // Need cast on rvalue function return
+
 // CHECK:      [[ret:%\d+]] = OpFunctionCall %T %foo
 // CHECK-NEXT: OpStore %temp_var_T [[ret]]
 // CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Function_v4float %temp_var_T %int_1 %int_3