Browse Source

[spirv] Support input modifiers on function parameters (#559)

We need to write out the values for parameters having out/inout
modifiers.
Lei Zhang 8 years ago
parent
commit
5423d05daf

+ 65 - 23
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1095,22 +1095,35 @@ uint32_t SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
   }
 
   if (callee) {
-    const uint32_t returnType =
-        typeTranslator.translateType(callExpr->getType());
+    const auto numParams = callee->getNumParams();
 
-    // Get or forward declare the function <result-id>
-    const uint32_t funcId = declIdMapper.getOrRegisterDeclResultId(callee);
+    llvm::SmallVector<uint32_t, 4> params;
+    llvm::SmallVector<uint32_t, 4> args;
 
     // Evaluate parameters
-    llvm::SmallVector<uint32_t, 4> params;
-    for (const auto *arg : callExpr->arguments()) {
+    for (uint32_t i = 0; i < numParams; ++i) {
+      const auto *arg = callExpr->getArg(i);
+      const auto *param = callee->getParamDecl(i);
+
       // We need to create variables for holding the values to be used as
       // arguments. The variables themselves are of pointer types.
       const uint32_t varType = typeTranslator.translateType(arg->getType());
-      const uint32_t tempVarId = theBuilder.addFnVar(varType);
-      theBuilder.createStore(tempVarId, doExpr(arg));
+      const std::string varName = "param.var." + param->getNameAsString();
+      const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);
 
       params.push_back(tempVarId);
+      args.push_back(doExpr(arg));
+
+      if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
+        // The current parameter is marked as out/inout. The argument then is
+        // essentially passed in by reference. We need to load the value
+        // explicitly here since the AST won't inject LValueToRValue implicit
+        // cast for this case.
+        const uint32_t value = theBuilder.createLoad(varType, args.back());
+        theBuilder.createStore(tempVarId, value);
+      } else {
+        theBuilder.createStore(tempVarId, args.back());
+      }
     }
 
     // Push the callee into the work queue if it is not there.
@@ -1118,7 +1131,24 @@ uint32_t SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
       workQueue.insert(callee);
     }
 
-    return theBuilder.createFunctionCall(returnType, funcId, params);
+    const uint32_t retType = typeTranslator.translateType(callExpr->getType());
+    // Get or forward declare the function <result-id>
+    const uint32_t funcId = declIdMapper.getOrRegisterDeclResultId(callee);
+
+    const uint32_t retVal =
+        theBuilder.createFunctionCall(retType, funcId, params);
+
+    // Go through all parameters and write those marked as out/inout
+    for (uint32_t i = 0; i < numParams; ++i) {
+      const auto *param = callee->getParamDecl(i);
+      if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
+        const uint32_t typeId = typeTranslator.translateType(param->getType());
+        const uint32_t value = theBuilder.createLoad(typeId, params[i]);
+        theBuilder.createStore(args[i], value);
+      }
+    }
+
+    return retVal;
   }
 
   emitError("calling non-function unimplemented");
@@ -3117,28 +3147,29 @@ uint32_t SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   const uint32_t entryLabel = theBuilder.createBasicBlock();
   theBuilder.setInsertPoint(entryLabel);
 
-  // Initialize all global variables at the beginning of the entry function
+  // Initialize all global variables at the beginning of the wrapper
   for (const VarDecl *varDecl : toInitGloalVars)
     theBuilder.createStore(declIdMapper.getDeclResultId(varDecl),
                            doExpr(varDecl->getInit()));
 
-  // Create and read stage input variables
-  // TODO: handle out/inout modifier
+  // Create temporary variables for holding function call arguments
   llvm::SmallVector<uint32_t, 4> params;
   for (const auto *param : decl->params()) {
-    // Create stage variable(s) from this parameter and load the composite value
-    uint32_t loadedValue = 0;
-    if (!declIdMapper.createStageInputVar(param, &loadedValue))
-      return 0;
-
-    // Create a temporary variable to hold the composite value of the stage
-    // variable
     const uint32_t typeId = typeTranslator.translateType(param->getType());
     std::string tempVarName = "param.var." + param->getNameAsString();
     const uint32_t tempVar = theBuilder.addFnVar(typeId, tempVarName);
-    theBuilder.createStore(tempVar, loadedValue);
 
     params.push_back(tempVar);
+
+    // Create the stage input variable for parameter not marked as pure out and
+    // initialize the corresponding temporary variable
+    if (!param->getAttr<HLSLOutAttr>()) {
+      uint32_t loadedValue = 0;
+      if (!declIdMapper.createStageInputVar(param, &loadedValue))
+        return 0;
+
+      theBuilder.createStore(tempVar, loadedValue);
+    }
   }
 
   // Call the original entry function
@@ -3146,12 +3177,23 @@ uint32_t SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   const uint32_t retVal =
       theBuilder.createFunctionCall(retType, entryFuncId, params);
 
-  // Create and write stage output variables
-  // TODO: handle out/inout modifier
-
+  // Create and write stage output variables for return value
   if (!declIdMapper.createStageOutputVar(decl, retVal))
     return 0;
 
+  // Create and write stage output variables for parameters marked as out/inout
+  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
+    const auto *param = decl->getParamDecl(i);
+    if (param->getAttr<HLSLOutAttr>() || param->getAttr<HLSLInOutAttr>()) {
+      // Load the value from the parameter after function call
+      const uint32_t typeId = typeTranslator.translateType(param->getType());
+      const uint32_t loadedParam = theBuilder.createLoad(typeId, params[i]);
+
+      if (!declIdMapper.createStageOutputVar(param, loadedParam))
+        return 0;
+    }
+  }
+
   theBuilder.createReturn();
   theBuilder.endFunction();
 

+ 13 - 2
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -52,6 +52,17 @@ uint32_t TypeTranslator::translateType(QualType type) {
     return translateType(typedefType->desugar());
   }
 
+  // Reference types
+  if (const auto *refType = dyn_cast<ReferenceType>(typePtr)) {
+    // Note: Pointer/reference types are disallowed in HLSL source code.
+    // Although developers cannot use them directly, they are generated into
+    // the AST by out/inout parameter modifiers in function signatures.
+    // We already pass function arguments via pointers to tempoary local
+    // variables. So it should be fine to drop the pointer type and treat it
+    // as the underlying pointee type here.
+    return translateType(refType->getPointeeType());
+  }
+
   // In AST, vector/matrix types are TypedefType of TemplateSpecializationType.
   // We handle them via HLSL type inspection functions.
 
@@ -108,7 +119,7 @@ uint32_t TypeTranslator::translateType(QualType type) {
       fieldNames.push_back(field->getName());
     }
 
-    return theBuilder.getStructType(fieldTypes, type.getAsString(), fieldNames);
+    return theBuilder.getStructType(fieldTypes, decl->getName(), fieldNames);
   }
 
   emitError("Type '%0' is not supported yet.") << type->getTypeClassName();
@@ -272,4 +283,4 @@ uint32_t TypeTranslator::getComponentVectorType(QualType matrixType) {
 }
 
 } // end namespace spirv
-} // end namespace clang
+} // end namespace clang

+ 5 - 5
tools/clang/test/CodeGenSPIRV/fn.call.hlsl

@@ -22,11 +22,11 @@ void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 // CHECK-NEXT: %v = OpVariable %_ptr_Function_int Function
     int v;
-// CHECK-NEXT: [[oneParam:%\d+]] = OpVariable %_ptr_Function_int Function
-// CHECK-NEXT: [[twoParam1:%\d+]] = OpVariable %_ptr_Function_int Function
-// CHECK-NEXT: [[twoParam2:%\d+]] = OpVariable %_ptr_Function_int Function
-// CHECK-NEXT: [[nestedParam1:%\d+]] = OpVariable %_ptr_Function_int Function
-// CHECK-NEXT: [[nestedParam2:%\d+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[oneParam:%\w+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[twoParam1:%\w+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[twoParam2:%\w+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[nestedParam1:%\w+]] = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: [[nestedParam2:%\w+]] = OpVariable %_ptr_Function_int Function
 
 // CHECK-NEXT: OpStore [[oneParam]] %int_1
 // CHECK-NEXT: [[call0:%\d+]] = OpFunctionCall %int %fnOneParm [[oneParam]]

+ 35 - 0
tools/clang/test/CodeGenSPIRV/fn.param.inout.hlsl

@@ -0,0 +1,35 @@
+// Run: %dxc -T vs_6_0 -E main
+
+float fnInOut(uniform float a, in float b, out float c, inout float d) {
+    float v = a + b + c + d;
+    d = c = a;
+    return v;
+}
+
+float main(float val: A) : B {
+// CHECK-LABEL: %src_main = OpFunction
+    float m, n;
+// CHECK:      %param_var_a = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT: %param_var_b = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT: %param_var_c = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT: %param_var_d = OpVariable %_ptr_Function_float Function
+
+// CHECK-NEXT:                OpStore %param_var_a %float_5
+// CHECK-NEXT: [[val:%\d+]] = OpLoad %float %val
+// CHECK-NEXT:                OpStore %param_var_b [[val]]
+// CHECK-NEXT:   [[m:%\d+]] = OpLoad %float %m
+// CHECK-NEXT:                OpStore %param_var_c [[m]]
+// CHECK-NEXT:   [[n:%\d+]] = OpLoad %float %n
+// CHECK-NEXT:                OpStore %param_var_d [[n]]
+
+// CHECK-NEXT: [[ret:%\d+]] = OpFunctionCall %float %fnInOut %param_var_a %param_var_b %param_var_c %param_var_d
+
+// CHECK-NEXT:   [[c:%\d+]] = OpLoad %float %param_var_c
+// CHECK-NEXT:                OpStore %m [[c]]
+// CHECK-NEXT:   [[d:%\d+]] = OpLoad %float %param_var_d
+// CHECK-NEXT:                OpStore %n [[d]]
+
+// CHECK-NEXT:                OpReturnValue [[ret]]
+    return fnInOut(5., val, m, n);
+// CHECK-NEXT: OpFunctionEnd
+}

+ 4 - 4
tools/clang/test/CodeGenSPIRV/passthru-ps.hlsl2spv

@@ -16,8 +16,8 @@ float4 main(float4 input: COLOR): SV_TARGET
 // OpEntryPoint Fragment %main "main" %in_var_COLOR %out_var_SV_Target
 // OpExecutionMode %main OriginUpperLeft
 // OpName %main "main"
-// OpName %in_var_COLOR "in.var.COLOR"
 // OpName %param_var_input "param.var.input"
+// OpName %in_var_COLOR "in.var.COLOR"
 // OpName %out_var_SV_Target "out.var.SV_Target"
 // OpName %src_main "src.main"
 // OpName %input "input"
@@ -28,8 +28,8 @@ float4 main(float4 input: COLOR): SV_TARGET
 // %3 = OpTypeFunction %void
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
-// %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %_ptr_Function_v4float = OpTypePointer Function %v4float
+// %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %_ptr_Output_v4float = OpTypePointer Output %v4float
 // %16 = OpTypeFunction %v4float %_ptr_Function_v4float
 // %in_var_COLOR = OpVariable %_ptr_Input_v4float Input
@@ -37,8 +37,8 @@ float4 main(float4 input: COLOR): SV_TARGET
 // %main = OpFunction %void None %3
 // %5 = OpLabel
 // %param_var_input = OpVariable %_ptr_Function_v4float Function
-// %10 = OpLoad %v4float %in_var_COLOR
-// OpStore %param_var_input %10
+// %12 = OpLoad %v4float %in_var_COLOR
+// OpStore %param_var_input %12
 // %13 = OpFunctionCall %v4float %src_main %param_var_input
 // OpStore %out_var_SV_Target %13
 // OpReturn

+ 7 - 7
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -23,10 +23,10 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // OpMemoryModel Logical GLSL450
 // OpEntryPoint Vertex %VSmain "VSmain" %in_var_POSITION %in_var_COLOR %gl_Position %out_var_COLOR
 // OpName %VSmain "VSmain"
-// OpName %in_var_POSITION "in.var.POSITION"
 // OpName %param_var_position "param.var.position"
-// OpName %in_var_COLOR "in.var.COLOR"
+// OpName %in_var_POSITION "in.var.POSITION"
 // OpName %param_var_color "param.var.color"
+// OpName %in_var_COLOR "in.var.COLOR"
 // OpName %PSInput "PSInput"
 // OpMemberName %PSInput 0 "position"
 // OpMemberName %PSInput 1 "color"
@@ -45,8 +45,8 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %3 = OpTypeFunction %void
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
-// %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %_ptr_Function_v4float = OpTypePointer Function %v4float
+// %_ptr_Input_v4float = OpTypePointer Input %v4float
 // %PSInput = OpTypeStruct %v4float %v4float
 // %_ptr_Output_v4float = OpTypePointer Output %v4float
 // %23 = OpTypeFunction %PSInput %_ptr_Function_v4float %_ptr_Function_v4float
@@ -61,10 +61,10 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %5 = OpLabel
 // %param_var_position = OpVariable %_ptr_Function_v4float Function
 // %param_var_color = OpVariable %_ptr_Function_v4float Function
-// %10 = OpLoad %v4float %in_var_POSITION
-// OpStore %param_var_position %10
-// %14 = OpLoad %v4float %in_var_COLOR
-// OpStore %param_var_color %14
+// %12 = OpLoad %v4float %in_var_POSITION
+// OpStore %param_var_position %12
+// %15 = OpLoad %v4float %in_var_COLOR
+// OpStore %param_var_color %15
 // %17 = OpFunctionCall %PSInput %src_VSmain %param_var_position %param_var_color
 // %18 = OpCompositeExtract %v4float %17 0
 // OpStore %gl_Position %18

+ 100 - 0
tools/clang/test/CodeGenSPIRV/spirv.entry-function.inout.hlsl

@@ -0,0 +1,100 @@
+// Run: %dxc -T vs_6_0 -E main
+
+struct X {
+    int    a: A;
+    float4 b: B;
+};
+
+struct Y {
+    uint   c: C;
+    float4 d: D;
+};
+
+struct Z {
+    float e: E;
+};
+
+// CHECK:       %in_var_O = OpVariable %_ptr_Input_v4int Input
+// CHECK-NEXT:  %in_var_Q = OpVariable %_ptr_Input_v4int Input
+// CHECK-NEXT:  %in_var_A = OpVariable %_ptr_Input_int Input
+// CHECK-NEXT:  %in_var_B = OpVariable %_ptr_Input_v4float Input
+// CHECK-NEXT:  %in_var_C = OpVariable %_ptr_Input_uint Input
+// CHECK-NEXT:  %in_var_D = OpVariable %_ptr_Input_v4float Input
+// CHECK-NEXT:  %in_var_R = OpVariable %_ptr_Input_float Input
+// CHECK-NEXT:  %in_var_E = OpVariable %_ptr_Input_float Input
+
+// CHECK-NEXT: %out_var_P = OpVariable %_ptr_Output_v4int Output
+// CHECK-NEXT: %out_var_Q = OpVariable %_ptr_Output_v4int Output
+// CHECK-NEXT: %out_var_A = OpVariable %_ptr_Output_int Output
+// CHECK-NEXT: %out_var_B = OpVariable %_ptr_Output_v4float Output
+// CHECK-NEXT: %out_var_C = OpVariable %_ptr_Output_uint Output
+// CHECK-NEXT: %out_var_D = OpVariable %_ptr_Output_v4float Output
+
+// CHECK:      %main = OpFunction %void None
+// CHECK-NEXT: OpLabel
+
+// CHECK-NEXT: %param_var_param1 = OpVariable %_ptr_Function_v4int Function
+// CHECK-NEXT: %param_var_param2 = OpVariable %_ptr_Function_v4int Function
+// CHECK-NEXT: %param_var_param3 = OpVariable %_ptr_Function_v4int Function
+// CHECK-NEXT: %param_var_param4 = OpVariable %_ptr_Function_X Function
+// CHECK-NEXT: %param_var_param5 = OpVariable %_ptr_Function_X Function
+// CHECK-NEXT: %param_var_param6 = OpVariable %_ptr_Function_Y Function
+// CHECK-NEXT: %param_var_param7 = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT: %param_var_param8 = OpVariable %_ptr_Function_Z Function
+
+// CHECK-NEXT:  [[inO:%\d+]] = OpLoad %v4int %in_var_O
+// CHECK-NEXT:                 OpStore %param_var_param1 [[inO]]
+// CHECK-NEXT:  [[inQ:%\d+]] = OpLoad %v4int %in_var_Q
+// CHECK-NEXT:                 OpStore %param_var_param3 [[inQ]]
+// CHECK-NEXT:  [[inA:%\d+]] = OpLoad %int %in_var_A
+// CHECK-NEXT:  [[inB:%\d+]] = OpLoad %v4float %in_var_B
+// CHECK-NEXT:  [[inX:%\d+]] = OpCompositeConstruct %X [[inA]] [[inB]]
+// CHECK-NEXT:                 OpStore %param_var_param4 [[inX]]
+// CHECK-NEXT:  [[inC:%\d+]] = OpLoad %uint %in_var_C
+// CHECK-NEXT:  [[inD:%\d+]] = OpLoad %v4float %in_var_D
+// CHECK-NEXT:  [[inY:%\d+]] = OpCompositeConstruct %Y [[inC]] [[inD]]
+// CHECK-NEXT:                 OpStore %param_var_param6 [[inY]]
+// CHECK-NEXT:  [[inR:%\d+]] = OpLoad %float %in_var_R
+// CHECK-NEXT:                 OpStore %param_var_param7 [[inR]]
+// CHECK-NEXT:  [[inE:%\d+]] = OpLoad %float %in_var_E
+// CHECK-NEXT:  [[inZ:%\d+]] = OpCompositeConstruct %Z [[inE]]
+// CHECK-NEXT:                 OpStore %param_var_param8 [[inZ]]
+
+// CHECK-NEXT:                 OpFunctionCall %void %src_main %param_var_param1 %param_var_param2 %param_var_param3 %param_var_param4 %param_var_param5 %param_var_param6 %param_var_param7 %param_var_param8
+// CHECK-NEXT: [[outP:%\d+]] = OpLoad %v4int %param_var_param2
+// CHECK-NEXT:                 OpStore %out_var_P [[outP]]
+// CHECK-NEXT: [[outQ:%\d+]] = OpLoad %v4int %param_var_param3
+// CHECK-NEXT:                 OpStore %out_var_Q [[outQ]]
+// CHECK-NEXT: [[outX:%\d+]] = OpLoad %X %param_var_param5
+// CHECK-NEXT: [[outA:%\d+]] = OpCompositeExtract %int [[outX]] 0
+// CHECK-NEXT:                 OpStore %out_var_A [[outA]]
+// CHECK-NEXT: [[outB:%\d+]] = OpCompositeExtract %v4float [[outX]] 1
+// CHECK-NEXT:                 OpStore %out_var_B [[outB]]
+// CHECK-NEXT: [[outY:%\d+]] = OpLoad %Y %param_var_param6
+// CHECK-NEXT: [[outC:%\d+]] = OpCompositeExtract %uint [[outY]] 0
+// CHECK-NEXT:                 OpStore %out_var_C [[outC]]
+// CHECK-NEXT: [[outD:%\d+]] = OpCompositeExtract %v4float [[outY]] 1
+// CHECK-NEXT:                 OpStore %out_var_D [[outD]]
+
+// CHECK-NEXT: OpReturn
+// CHECK-NEXT: OpFunctionEnd
+
+// Input  semantics: O, Q, A, B, C, D, R, E
+// Output semantics: P, Q, A, B, C, D
+void main(in      int4  param1: O,
+          out     int4  param2: P,
+          inout   int4  param3: Q,
+          in      X     param4,
+          out     X     param5,
+          inout   Y     param6,
+          uniform float param7: R,
+          uniform Z     param8)
+{
+// CHECK-LABEL: %src_main = OpFunction
+    param2 = param1;
+    param3 = param1;
+
+    param5 = param4;
+    param6.c = param4.a;
+    param6.d = param4.b;
+}

+ 9 - 1
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -239,9 +239,14 @@ TEST_F(FileTest, ControlFlowConditionalOp) { runFileTest("cf.cond-op.hlsl"); }
 // For function calls
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 
+// For function parameters
+TEST_F(FileTest, FunctionInOutParam) { runFileTest("fn.param.inout.hlsl"); }
+
 // For early returns
 TEST_F(FileTest, EarlyReturn) { runFileTest("cf.return.early.hlsl"); }
-TEST_F(FileTest, EarlyReturnFloat4) { runFileTest("cf.return.early.float4.hlsl"); }
+TEST_F(FileTest, EarlyReturnFloat4) {
+  runFileTest("cf.return.early.float4.hlsl");
+}
 
 // For discard
 TEST_F(FileTest, Discard) { runFileTest("cf.discard.hlsl"); }
@@ -315,5 +320,8 @@ TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }
 TEST_F(FileTest, SpirvEntryFunctionWrapper) {
   runFileTest("spirv.entry-function.wrapper.hlsl");
 }
+TEST_F(FileTest, SpirvEntryFunctionInOut) {
+  runFileTest("spirv.entry-function.inout.hlsl");
+}
 
 } // namespace