Browse Source

[spirv] Fix missing return value cases (#809)

If there are no return value specified in the source code for a
certain control flow path, just return the null value.

Previously always use OpReturn to terminate basic blocks without
explict return statement from the source code, which may cause
illegal SPIR-V generated.
Lei Zhang 7 years ago
parent
commit
aa7c2bc402

+ 37 - 6
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -535,13 +535,29 @@ uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType,
 }
 }
 
 
 void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
 void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
+  // A RAII class for maintaining the current function under traversal.
+  class FnEnvRAII {
+  public:
+    // Creates a new instance which sets fnEnv to the newFn on creation,
+    // and resets fnEnv to its original value on destruction.
+    FnEnvRAII(const FunctionDecl **fnEnv, const FunctionDecl *newFn)
+        : oldFn(*fnEnv), fnSlot(fnEnv) {
+      *fnEnv = newFn;
+    }
+    ~FnEnvRAII() { *fnSlot = oldFn; }
+
+  private:
+    const FunctionDecl *oldFn;
+    const FunctionDecl **fnSlot;
+  };
+
+  FnEnvRAII fnEnvRAII(&curFunction, decl);
+
   // We are about to start translation for a new function. Clear the break stack
   // We are about to start translation for a new function. Clear the break stack
   // and the continue stack.
   // and the continue stack.
   breakStack = std::stack<uint32_t>();
   breakStack = std::stack<uint32_t>();
   continueStack = std::stack<uint32_t>();
   continueStack = std::stack<uint32_t>();
 
 
-  curFunction = decl;
-
   std::string funcName = decl->getName();
   std::string funcName = decl->getName();
 
 
   uint32_t funcId = 0;
   uint32_t funcId = 0;
@@ -627,15 +643,23 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     doStmt(decl->getBody());
     doStmt(decl->getBody());
 
 
     // We have processed all Stmts in this function and now in the last
     // We have processed all Stmts in this function and now in the last
-    // basic block. Make sure we have OpReturn if missing.
+    // basic block. Make sure we have a termination instruction.
     if (!theBuilder.isCurrentBasicBlockTerminated()) {
     if (!theBuilder.isCurrentBasicBlockTerminated()) {
-      theBuilder.createReturn();
+      const auto retType = decl->getReturnType();
+
+      if (retType->isVoidType()) {
+        theBuilder.createReturn();
+      } else {
+        // If the source code does not provide a proper return value for some
+        // control flow path, it's undefined behavior. We just return null
+        // value here.
+        theBuilder.createReturnValue(
+            theBuilder.getConstantNull(typeTranslator.translateType(retType)));
+      }
     }
     }
   }
   }
 
 
   theBuilder.endFunction();
   theBuilder.endFunction();
-
-  curFunction = nullptr;
 }
 }
 
 
 void SPIRVEmitter::validateVKAttributes(const Decl *decl) {
 void SPIRVEmitter::validateVKAttributes(const Decl *decl) {
@@ -1200,6 +1224,13 @@ void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) {
     theBuilder.createReturn();
     theBuilder.createReturn();
   }
   }
 
 
+  // We are translating a ReturnStmt, we should be in some function's body.
+  assert(curFunction->hasBody());
+  // If this return statement is the last statement in the function, then
+  // whe have no more work to do.
+  if (cast<CompoundStmt>(curFunction->getBody())->body_back() == stmt)
+    return;
+
   // Some statements that alter the control flow (break, continue, return, and
   // Some statements that alter the control flow (break, continue, return, and
   // discard), require creation of a new basic block to hold any statement that
   // discard), require creation of a new basic block to hold any statement that
   // may follow them. In this case, the newly created basic block will contain
   // may follow them. In this case, the newly created basic block will contain

+ 8 - 8
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -176,7 +176,7 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %95 = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint
 // %95 = OpTypeFunction %HS_CONSTANT_DATA_OUTPUT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint
 // %_ptr_Function_HS_CONSTANT_DATA_OUTPUT = OpTypePointer Function %HS_CONSTANT_DATA_OUTPUT
 // %_ptr_Function_HS_CONSTANT_DATA_OUTPUT = OpTypePointer Function %HS_CONSTANT_DATA_OUTPUT
 // %_ptr_Function_float = OpTypePointer Function %float
 // %_ptr_Function_float = OpTypePointer Function %float
-// %121 = OpTypeFunction %BEZIER_CONTROL_POINT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint %_ptr_Function_uint
+// %120 = OpTypeFunction %BEZIER_CONTROL_POINT %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3 %_ptr_Function_uint %_ptr_Function_uint
 // %_ptr_Function_VS_CONTROL_POINT_OUTPUT = OpTypePointer Function %VS_CONTROL_POINT_OUTPUT
 // %_ptr_Function_VS_CONTROL_POINT_OUTPUT = OpTypePointer Function %VS_CONTROL_POINT_OUTPUT
 // %_ptr_Function_BEZIER_CONTROL_POINT = OpTypePointer Function %BEZIER_CONTROL_POINT
 // %_ptr_Function_BEZIER_CONTROL_POINT = OpTypePointer Function %BEZIER_CONTROL_POINT
 // %_ptr_Function_v3float = OpTypePointer Function %v3float
 // %_ptr_Function_v3float = OpTypePointer Function %v3float
@@ -280,17 +280,17 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %119 = OpLoad %HS_CONSTANT_DATA_OUTPUT %Output
 // %119 = OpLoad %HS_CONSTANT_DATA_OUTPUT %Output
 // OpReturnValue %119
 // OpReturnValue %119
 // OpFunctionEnd
 // OpFunctionEnd
-// %src_SubDToBezierHS = OpFunction %BEZIER_CONTROL_POINT None %121
+// %src_SubDToBezierHS = OpFunction %BEZIER_CONTROL_POINT None %120
 // %ip_0 = OpFunctionParameter %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3
 // %ip_0 = OpFunctionParameter %_ptr_Function__arr_VS_CONTROL_POINT_OUTPUT_uint_3
 // %cpid = OpFunctionParameter %_ptr_Function_uint
 // %cpid = OpFunctionParameter %_ptr_Function_uint
 // %PatchID_0 = OpFunctionParameter %_ptr_Function_uint
 // %PatchID_0 = OpFunctionParameter %_ptr_Function_uint
 // %bb_entry_0 = OpLabel
 // %bb_entry_0 = OpLabel
 // %vsOutput = OpVariable %_ptr_Function_VS_CONTROL_POINT_OUTPUT Function
 // %vsOutput = OpVariable %_ptr_Function_VS_CONTROL_POINT_OUTPUT Function
 // %result = OpVariable %_ptr_Function_BEZIER_CONTROL_POINT Function
 // %result = OpVariable %_ptr_Function_BEZIER_CONTROL_POINT Function
-// %131 = OpAccessChain %_ptr_Function_v3float %vsOutput %int_0
-// %132 = OpLoad %v3float %131
-// %133 = OpAccessChain %_ptr_Function_v3float %result %int_0
-// OpStore %133 %132
-// %134 = OpLoad %BEZIER_CONTROL_POINT %result
-// OpReturnValue %134
+// %130 = OpAccessChain %_ptr_Function_v3float %vsOutput %int_0
+// %131 = OpLoad %v3float %130
+// %132 = OpAccessChain %_ptr_Function_v3float %result %int_0
+// OpStore %132 %131
+// %133 = OpLoad %BEZIER_CONTROL_POINT %result
+// OpReturnValue %133
 // OpFunctionEnd
 // OpFunctionEnd

+ 0 - 44
tools/clang/test/CodeGenSPIRV/constant-ps.hlsl2spv

@@ -1,44 +0,0 @@
-// Run: %dxc -T ps_6_0 -E main
-
-float4 main(): SV_Target
-{
-  return float4(1.0f, 2.0f, 3.5f, 4.7f);
-}
-
-// CHECK-WHOLE-SPIR-V:
-// ; SPIR-V
-// ; Version: 1.0
-// ; Generator: Google spiregg; 0
-// ; Bound: 19
-// ; Schema: 0
-// OpCapability Shader
-// OpMemoryModel Logical GLSL450
-// OpEntryPoint Fragment %main "main" %out_var_SV_Target
-// OpExecutionMode %main OriginUpperLeft
-// OpName %bb_entry "bb.entry"
-// OpName %src_main "src.main"
-// OpName %main "main"
-// OpName %out_var_SV_Target "out.var.SV_Target"
-// OpDecorate %out_var_SV_Target Location 0
-// %void = OpTypeVoid
-// %3 = OpTypeFunction %void
-// %float = OpTypeFloat 32
-// %v4float = OpTypeVector %float 4
-// %_ptr_Output_v4float = OpTypePointer Output %v4float
-// %11 = OpTypeFunction %v4float
-// %float_1 = OpConstant %float 1
-// %float_2 = OpConstant %float 2
-// %float_3_5 = OpConstant %float 3.5
-// %float_4_7 = OpConstant %float 4.7
-// %17 = OpConstantComposite %v4float %float_1 %float_2 %float_3_5 %float_4_7
-// %out_var_SV_Target = OpVariable %_ptr_Output_v4float Output
-// %main = OpFunction %void None %3
-// %5 = OpLabel
-// %8 = OpFunctionCall %v4float %src_main
-// OpStore %out_var_SV_Target %8
-// OpReturn
-// OpFunctionEnd
-// %src_main = OpFunction %v4float None %11
-// %bb_entry = OpLabel
-// OpReturnValue %17
-// OpFunctionEnd

+ 0 - 30
tools/clang/test/CodeGenSPIRV/empty-void-main.hlsl2spv

@@ -1,30 +0,0 @@
-// Run: %dxc -T ps_6_0 -E main
-void main()
-{
-
-}
-
-// CHECK-WHOLE-SPIR-V:
-// ; SPIR-V
-// ; Version: 1.0
-// ; Generator: Google spiregg; 0
-// ; Bound: 8
-// ; Schema: 0
-// OpCapability Shader
-// OpMemoryModel Logical GLSL450
-// OpEntryPoint Fragment %main "main"
-// OpExecutionMode %main OriginUpperLeft
-// OpName %bb_entry "bb.entry"
-// OpName %src_main "src.main"
-// OpName %main "main"
-// %void = OpTypeVoid
-// %3 = OpTypeFunction %void
-// %main = OpFunction %void None %3
-// %5 = OpLabel
-// %6 = OpFunctionCall %void %src_main
-// OpReturn
-// OpFunctionEnd
-// %src_main = OpFunction %void None %3
-// %bb_entry = OpLabel
-// OpReturn
-// OpFunctionEnd

+ 1 - 1
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -9,7 +9,7 @@ float4 main(float4 input: COLOR): SV_Target
 // ; SPIR-V
 // ; SPIR-V
 // ; Version: 1.0
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
 // ; Generator: Google spiregg; 0
-// ; Bound: 21
+// ; Bound: 20
 // ; Schema: 0
 // ; Schema: 0
 // OpCapability Shader
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
 // OpMemoryModel Logical GLSL450

+ 1 - 1
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -17,7 +17,7 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // ; SPIR-V
 // ; SPIR-V
 // ; Version: 1.0
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
 // ; Generator: Google spiregg; 0
-// ; Bound: 45
+// ; Bound: 44
 // ; Schema: 0
 // ; Schema: 0
 // OpCapability Shader
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
 // OpMemoryModel Logical GLSL450

+ 12 - 0
tools/clang/test/CodeGenSPIRV/spirv.cf.ret-missing.hlsl

@@ -0,0 +1,12 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK:[[null:%\d+]] = OpConstantNull %float
+
+float main(bool a: A) : B {
+    if (a) return 1.0;
+    // No return value for else
+
+// CHECK:      %if_merge = OpLabel
+// CHECK-NEXT: OpReturnValue [[null]]
+// CHECK-NEXT: OpFunctionEnd
+}

+ 4 - 8
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -16,10 +16,6 @@ using clang::spirv::WholeFileTest;
 
 
 // === Whole output tests ===
 // === Whole output tests ===
 
 
-TEST_F(WholeFileTest, EmptyVoidMain) {
-  runWholeFileTest("empty-void-main.hlsl2spv", /*generateHeader*/ true);
-}
-
 TEST_F(WholeFileTest, PassThruPixelShader) {
 TEST_F(WholeFileTest, PassThruPixelShader) {
   runWholeFileTest("passthru-ps.hlsl2spv", /*generateHeader*/ true);
   runWholeFileTest("passthru-ps.hlsl2spv", /*generateHeader*/ true);
 }
 }
@@ -32,10 +28,6 @@ TEST_F(WholeFileTest, PassThruComputeShader) {
   runWholeFileTest("passthru-cs.hlsl2spv", /*generateHeader*/ true);
   runWholeFileTest("passthru-cs.hlsl2spv", /*generateHeader*/ true);
 }
 }
 
 
-TEST_F(WholeFileTest, ConstantPixelShader) {
-  runWholeFileTest("constant-ps.hlsl2spv", /*generateHeader*/ true);
-}
-
 TEST_F(WholeFileTest, BezierHullShader) {
 TEST_F(WholeFileTest, BezierHullShader) {
   runWholeFileTest("bezier.hull.hlsl2spv");
   runWholeFileTest("bezier.hull.hlsl2spv");
 }
 }
@@ -865,6 +857,10 @@ TEST_F(FileTest, PrimitiveErrorGS) {
 // SPIR-V specific
 // SPIR-V specific
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }
 
 
+TEST_F(FileTest, SpirvControlFlowMissingReturn) {
+  runFileTest("spirv.cf.ret-missing.hlsl");
+}
+
 TEST_F(FileTest, SpirvEntryFunctionWrapper) {
 TEST_F(FileTest, SpirvEntryFunctionWrapper) {
   runFileTest("spirv.entry-function.wrapper.hlsl");
   runFileTest("spirv.entry-function.wrapper.hlsl");
 }
 }