Explorar el Código

[SPIR-V] Fix default arguments for function templates (#4665)

Fixes #4169
Cassandra Beckley hace 3 años
padre
commit
eec7261fa5

+ 9 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1056,7 +1056,13 @@ SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr,
   } else if (const auto *condExpr = dyn_cast<ConditionalOperator>(expr)) {
     result = doConditionalOperator(condExpr);
   } else if (const auto *defaultArgExpr = dyn_cast<CXXDefaultArgExpr>(expr)) {
-    result = doExpr(defaultArgExpr->getParam()->getDefaultArg());
+    if (defaultArgExpr->getParam()->hasUninstantiatedDefaultArg()) {
+      auto defaultArg = defaultArgExpr->getParam()->getUninstantiatedDefaultArg();
+      result = castToType(doExpr(defaultArg), defaultArg->getType(), defaultArgExpr->getType(), defaultArg->getLocStart(), defaultArg->getSourceRange());
+      result->setRValue();
+    } else {
+      result = doExpr(defaultArgExpr->getParam()->getDefaultArg());
+    }
   } else if (isa<CXXThisExpr>(expr)) {
     assert(curThis);
     result = curThis;
@@ -13190,7 +13196,7 @@ SpirvEmitter::storeDataToRawAddress(SpirvInstruction *addressInUInt64,
 
   SpirvStore *storeInst = spvBuilder.createStore(address, value, loc);
   storeInst->setAlignment(alignment);
-  return nullptr; 
+  return nullptr;
 }
 
 SpirvInstruction *SpirvEmitter::processRawBufferStore(const CallExpr *callExpr) {
@@ -13200,7 +13206,7 @@ SpirvInstruction *SpirvEmitter::processRawBufferStore(const CallExpr *callExpr)
 
   SpirvInstruction *address = doExpr(callExpr->getArg(0));
   SpirvInstruction *value = doExpr(callExpr->getArg(1));
-  QualType bufferType = value->getAstResultType(); 
+  QualType bufferType = value->getAstResultType();
   clang::SourceLocation loc = callExpr->getExprLoc();
   if (!isBoolOrVecMatOfBoolType(bufferType)) {
     return storeDataToRawAddress(address, value, bufferType, alignment, loc);

+ 20 - 0
tools/clang/test/CodeGenSPIRV/fn.param.default.hlsl

@@ -0,0 +1,20 @@
+// RUN: %dxc -T vs_6_0 -E main -HV 2021
+
+template<typename T>
+T test(const T a, const T b = 0)
+{
+  return a + b;
+}
+
+float4 main(uint vertex_id : SV_VertexID) : SV_Position
+{
+  // CHECK: OpStore %param_var_a %float_1
+  // CHECK: OpStore %param_var_b %float_2
+  // CHECK: [[first:%\d+]] = OpFunctionCall %float %test %param_var_a %param_var_b
+  // CHECK: OpStore %param_var_a_0 %float_4
+  // CHECK: [[default:%\d+]] = OpConvertSToF %float %int_0
+  // CHECK: OpStore %param_var_b_0 [[default]]
+  // CHECK: [[second:%\d+]] = OpFunctionCall %float %test %param_var_a_0 %param_var_b_0
+  // CHECK: {{%\d+}} = OpCompositeConstruct %v4float [[first]] [[second]] %float_0 %float_0
+  return float4(test<float>(1,2), test<float>(4), 0, 0);
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -653,6 +653,7 @@ TEST_F(FileTest, FunctionInCTBuffer) {
 }
 
 TEST_F(FileTest, FunctionNoInline) { runFileTest("fn.noinline.hlsl"); }
+TEST_F(FileTest, FunctionDefaultParam) { runFileTest("fn.param.default.hlsl"); }
 TEST_F(FileTest, FunctionExport) { runFileTest("fn.export.hlsl"); }
 
 TEST_F(FileTest, FixFunctionCall) {