Browse Source

[spirv] allow illegal storage/isomorphic function arg (#1991)

Before this change, decision of creating temporary variables for
function arguments were affected by three things: storage class,
param type, out or inout keyword. For example, if storage class of
a function arg is different from the param, we create a temporary
variable whose storage class is the same with the param and store
the function arg before the function call. After the function call
we store the value to the actual arg if needed.

Now based on HLSL legalization supported by spirv-opt, we do not
need to consider storage class. Similarly, we do not need to
consider type difference between arg and param if they have an
isomorphic type e.g., same structure but different names.
Jaebaek Seo 6 years ago
parent
commit
c9b258a40c

+ 48 - 28
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -192,8 +192,8 @@ bool spirvToolsOptimize(spv_target_env env, std::vector<uint32_t> *module,
 }
 }
 
 
 bool spirvToolsValidate(spv_target_env env, const SpirvCodeGenOptions &opts,
 bool spirvToolsValidate(spv_target_env env, const SpirvCodeGenOptions &opts,
-                        bool relaxLogicalPointer, std::vector<uint32_t> *module,
-                        std::string *messages) {
+                        bool beforeHlslLegalization, bool relaxLogicalPointer,
+                        std::vector<uint32_t> *module, std::string *messages) {
   spvtools::SpirvTools tools(env);
   spvtools::SpirvTools tools(env);
 
 
   tools.SetMessageConsumer(
   tools.SetMessageConsumer(
@@ -202,7 +202,16 @@ bool spirvToolsValidate(spv_target_env env, const SpirvCodeGenOptions &opts,
                  const char *message) { *messages += message; });
                  const char *message) { *messages += message; });
 
 
   spvtools::ValidatorOptions options;
   spvtools::ValidatorOptions options;
-  options.SetRelaxLogicalPointer(relaxLogicalPointer);
+  options.SetBeforeHlslLegalization(beforeHlslLegalization);
+  // When beforeHlslLegalization is true and relaxLogicalPointer is false,
+  // options.SetBeforeHlslLegalization() enables --before-hlsl-legalization
+  // and --relax-logical-pointer. If options.SetRelaxLogicalPointer() is
+  // called, it disables --relax-logical-pointer that is not expected
+  // behavior. When beforeHlslLegalization is true, we must enable both
+  // options.
+  if (!beforeHlslLegalization)
+    options.SetRelaxLogicalPointer(relaxLogicalPointer);
+
   // GL: strict block layout rules
   // GL: strict block layout rules
   // VK: relaxed block layout rules
   // VK: relaxed block layout rules
   // DX: Skip block layout rules
   // DX: Skip block layout rules
@@ -486,7 +495,7 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
                    spirvOptions),
                    spirvOptions),
       entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
       entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
-      mainSourceFile(nullptr) {
+      beforeHlslLegalization(false), mainSourceFile(nullptr) {
 
 
   // Get ShaderModel from command line hlsl profile option.
   // Get ShaderModel from command line hlsl profile option.
   const hlsl::ShaderModel *shaderModel =
   const hlsl::ShaderModel *shaderModel =
@@ -670,7 +679,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
   // Validate the generated SPIR-V code
   // Validate the generated SPIR-V code
   if (!spirvOptions.disableValidation) {
   if (!spirvOptions.disableValidation) {
     std::string messages;
     std::string messages;
-    if (!spirvToolsValidate(targetEnv, spirvOptions,
+    if (!spirvToolsValidate(targetEnv, spirvOptions, beforeHlslLegalization,
                             declIdMapper.requiresLegalization(), &m,
                             declIdMapper.requiresLegalization(), &m,
                             &messages)) {
                             &messages)) {
       emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
       emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
@@ -2033,7 +2042,6 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
   bool isNonStaticMemberCall = false;
   bool isNonStaticMemberCall = false;
   QualType objectType = {};             // Type of the object (if exists)
   QualType objectType = {};             // Type of the object (if exists)
   SpirvInstruction *objInstr = nullptr; // EvalInfo for the object (if exists)
   SpirvInstruction *objInstr = nullptr; // EvalInfo for the object (if exists)
-  bool needsTempVar = false;            // Whether we need temporary variable.
 
 
   llvm::SmallVector<SpirvInstruction *, 4> vars; // Variables for function call
   llvm::SmallVector<SpirvInstruction *, 4> vars; // Variables for function call
   llvm::SmallVector<bool, 4> isTempVar;          // Temporary variable or not
   llvm::SmallVector<bool, 4> isTempVar;          // Temporary variable or not
@@ -2060,15 +2068,18 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
       // getObject().objectMethod();
       // getObject().objectMethod();
       // Also, any parameter passed to the member function must be of Function
       // Also, any parameter passed to the member function must be of Function
       // storage class.
       // storage class.
-      needsTempVar = objInstr->isRValue() ||
-                     objInstr->getStorageClass() != spv::StorageClass::Function;
-
-      if (needsTempVar) {
+      if (objInstr->isRValue()) {
         args.push_back(createTemporaryVar(
         args.push_back(createTemporaryVar(
             objectType, getAstTypeName(objectType),
             objectType, getAstTypeName(objectType),
             // May need to load to use as initializer
             // May need to load to use as initializer
-            loadIfGLValue(object, objInstr), object->getExprLoc()));
+            loadIfGLValue(object, objInstr), object->getLocStart()));
       } else {
       } else {
+        // Based on SPIR-V spec, function parameter must always be in Function
+        // scope. If we pass a non-function scope argument, we need
+        // the legalization.
+        if (objInstr->getStorageClass() != spv::StorageClass::Function)
+          beforeHlslLegalization = true;
+
         args.push_back(objInstr);
         args.push_back(objInstr);
       }
       }
 
 
@@ -2089,19 +2100,32 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
 
 
     // Get the evaluation info if this argument is referencing some variable
     // Get the evaluation info if this argument is referencing some variable
     // *as a whole*, in which case we can avoid creating the temporary variable
     // *as a whole*, in which case we can avoid creating the temporary variable
-    // for it if it is Function scope and can act as out parameter.
+    // for it if it can act as out parameter.
     SpirvInstruction *argInfo = nullptr;
     SpirvInstruction *argInfo = nullptr;
     if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(arg)) {
     if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(arg)) {
       argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl(),
       argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl(),
                                              arg->getLocStart());
                                              arg->getLocStart());
     }
     }
 
 
-    if (argInfo && argInfo->getStorageClass() == spv::StorageClass::Function &&
+    auto *argInst = doExpr(arg);
+    auto argType = arg->getType();
+
+    // If argInfo is nullptr and argInst is a rvalue, we do not have a proper
+    // pointer to pass to the function. we need a temporary variable in that
+    // case.
+    if ((argInfo || (argInst && !argInst->isRValue())) &&
         canActAsOutParmVar(param) &&
         canActAsOutParmVar(param) &&
         paramTypeMatchesArgType(param->getType(), arg->getType())) {
         paramTypeMatchesArgType(param->getType(), arg->getType())) {
-      vars.push_back(argInfo);
+      // Based on SPIR-V spec, function parameter must be always Function
+      // scope. In addition, we must pass memory object declaration argument
+      // to function. If we pass an argument that is not function scope
+      // or not memory object declaration, we need the legalization.
+      if (!argInfo || argInfo->getStorageClass() != spv::StorageClass::Function)
+        beforeHlslLegalization = true;
+
       isTempVar.push_back(false);
       isTempVar.push_back(false);
-      args.push_back(doExpr(arg));
+      args.push_back(argInst);
+      vars.push_back(argInfo ? argInfo : argInst);
     } else {
     } else {
       // We need to create variables for holding the values to be used as
       // We need to create variables for holding the values to be used as
       // arguments. The variables themselves are of pointer types.
       // arguments. The variables themselves are of pointer types.
@@ -2119,7 +2143,7 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
 
 
       vars.push_back(tempVar);
       vars.push_back(tempVar);
       isTempVar.push_back(true);
       isTempVar.push_back(true);
-      args.push_back(doExpr(arg));
+      args.push_back(argInst);
 
 
       // Update counter variable associated with function parameters
       // Update counter variable associated with function parameters
       tryToAssignCounterVar(param, arg);
       tryToAssignCounterVar(param, arg);
@@ -2150,6 +2174,9 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     }
     }
   }
   }
 
 
+  if (beforeHlslLegalization)
+    needsLegalization = true;
+
   assert(vars.size() == isTempVar.size());
   assert(vars.size() == isTempVar.size());
   assert(vars.size() == args.size());
   assert(vars.size() == args.size());
 
 
@@ -2165,23 +2192,16 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
   auto *retVal = spvBuilder.createFunctionCall(
   auto *retVal = spvBuilder.createFunctionCall(
       retType, func, vars, callExpr->getCallee()->getExprLoc());
       retType, func, vars, callExpr->getCallee()->getExprLoc());
 
 
-  // If we created a temporary variable for the lvalue object this method is
-  // invoked upon, we need to copy the contents in the temporary variable back
-  // to the original object's variable in case there are side effects.
-  if (needsTempVar && !objInstr->isRValue()) {
-    auto *value =
-        spvBuilder.createLoad(objectType, vars.front(), callExpr->getLocEnd());
-    storeValue(objInstr, value, objectType, callExpr->getLocEnd());
-  }
-
   // Go through all parameters and write those marked as out/inout
   // Go through all parameters and write those marked as out/inout
   for (uint32_t i = 0; i < numParams; ++i) {
   for (uint32_t i = 0; i < numParams; ++i) {
     const auto *param = callee->getParamDecl(i);
     const auto *param = callee->getParamDecl(i);
-    if (isTempVar[i] && canActAsOutParmVar(param)) {
+    // If it calls a non-static member function, the object itself is argument
+    // 0, and therefore all other argument positions are shifted by 1.
+    const uint32_t index = i + isNonStaticMemberCall;
+    if (isTempVar[index] && canActAsOutParmVar(param)) {
       const auto *arg = callExpr->getArg(i);
       const auto *arg = callExpr->getArg(i);
-      const uint32_t index = i + isNonStaticMemberCall;
       SpirvInstruction *value = spvBuilder.createLoad(
       SpirvInstruction *value = spvBuilder.createLoad(
-          param->getType(), vars[index], callExpr->getLocEnd());
+          param->getType(), vars[index], arg->getLocStart());
 
 
       // Now we want to assign 'value' to arg. But first, in rare cases when
       // Now we want to assign 'value' to arg. But first, in rare cases when
       // using 'out' or 'inout' where the parameter and argument have a type
       // using 'out' or 'inout' where the parameter and argument have a type

+ 4 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -1056,6 +1056,10 @@ private:
   /// Note: legalization specific code
   /// Note: legalization specific code
   bool needsLegalization;
   bool needsLegalization;
 
 
+  /// Whether the translated SPIR-V binary passes --before-hlsl-legalization
+  /// option to spirv-val because of illegal function parameter scope.
+  bool beforeHlslLegalization;
+
   /// Mapping from methods to the decls to represent their implicit object
   /// Mapping from methods to the decls to represent their implicit object
   /// parameters
   /// parameters
   ///
   ///

+ 42 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.function-param.hlsl

@@ -0,0 +1,42 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: %A = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_int Uniform
+RWStructuredBuffer<int> A;
+
+// CHECK: %B = OpVariable %_ptr_Private_int Private
+static int B;
+
+// CHECK: %C = OpVariable %_ptr_Workgroup_int Workgroup
+groupshared int C;
+
+// CHECK: %D = OpVariable %_ptr_Workgroup__arr_int_uint_10 Workgroup
+groupshared int D[10];
+
+int foo(int x, int y, int z, int w[10], int v) {
+  return x | y | z | w[0] | v;
+}
+
+void main() {
+// CHECK: %E = OpVariable %_ptr_Function_int Function
+  int E;
+
+// CHECK:      %param_var_x = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %param_var_y = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %param_var_z = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %param_var_w = OpVariable %_ptr_Function__arr_int_uint_10 Function
+// CHECK-NEXT: %param_var_v = OpVariable %_ptr_Function_int Function
+
+
+// CHECK:      [[A:%\d+]] = OpLoad %int {{%\d+}}
+// CHECK-NEXT:              OpStore %param_var_x [[A]]
+// CHECK-NEXT: [[B:%\d+]] = OpLoad %int %B
+// CHECK-NEXT:              OpStore %param_var_y [[B]]
+// CHECK-NEXT: [[C:%\d+]] = OpLoad %int %C
+// CHECK-NEXT:              OpStore %param_var_z [[C]]
+// CHECK-NEXT: [[D:%\d+]] = OpLoad %_arr_int_uint_10 %D
+// CHECK-NEXT:              OpStore %param_var_w [[D]]
+// CHECK-NEXT: [[E:%\d+]] = OpLoad %int %E
+// CHECK-NEXT:              OpStore %param_var_v [[E]]
+// CHECK-NEXT:   {{%\d+}} = OpFunctionCall %int %foo %param_var_x %param_var_y %param_var_z %param_var_w %param_var_v
+  A[0] = foo(A[0], B, C, D, E);
+}

+ 36 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.function-param.out.hlsl

@@ -0,0 +1,36 @@
+// Run: %dxc -T cs_6_0 -E main
+
+struct S {
+  int a;
+  float b;
+};
+
+void foo(out int x, out int y, out int z, out S w, out int v) {
+  x = 1;
+  y = x << 1;
+  z = x << 2;
+  w.a = x << 3;
+  v = x << 4;
+}
+
+// CHECK: %A = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_int Uniform
+RWStructuredBuffer<int> A;
+
+// CHECK: %B = OpVariable %_ptr_Private_int Private
+static int B;
+
+// CHECK: %C = OpVariable %_ptr_Workgroup_int Workgroup
+groupshared int C;
+
+// CHECK: %D = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S D;
+
+void main() {
+// CHECK: %E = OpVariable %_ptr_Function_int Function
+  int E;
+
+// CHECK:        [[A:%\d+]] = OpAccessChain %_ptr_Uniform_int %A %int_0 %uint_0
+// CHECK-NEXT:     {{%\d+}} = OpFunctionCall %void %foo [[A]] %B %C %D %E
+  foo(A[0], B, C, D, E);
+  A[0] = A[0] | B | C | D.a | E;
+}

+ 47 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.struct-function.hlsl

@@ -0,0 +1,47 @@
+// Run: %dxc -T cs_6_0 -E main
+
+struct S {
+  int a;
+  float b;
+
+  void foo(S x, inout S y, out S z, S w) {
+    a = 1;
+    x.a = a << 1;
+    y.a = x.a << 1;
+    z.a = x.a << 2;
+    w.a = x.a << 3;
+  }
+};
+
+// CHECK: %A = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform
+RWStructuredBuffer<S> A;
+
+// CHECK: %B = OpVariable %_ptr_Private_S Private
+static S B;
+
+// CHECK: %C = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S C;
+
+// CHECK: %D = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S D;
+
+void main() {
+// CHECK: %E = OpVariable %_ptr_Function_S Function
+  S E;
+
+// CHECK: %param_var_x = OpVariable %_ptr_Function_S Function
+// CHECK: %param_var_w = OpVariable %_ptr_Function_S Function
+
+// CHECK:        [[A:%\d+]] = OpAccessChain %_ptr_Uniform_S_0 %A %int_0 %uint_0
+// CHECK-NEXT: [[A_0:%\d+]] = OpLoad %S_0 [[A]]
+// CHECK-NEXT:   [[a:%\d+]] = OpCompositeExtract %int [[A_0]] 0
+// CHECK-NEXT:   [[b:%\d+]] = OpCompositeExtract %float [[A_0]] 1
+// CHECK-NEXT: [[A_0:%\d+]] = OpCompositeConstruct %S [[a]] [[b]]
+// CHECK-NEXT:                OpStore %param_var_x [[A_0]]
+// CHECK-NEXT:   [[E:%\d+]] = OpLoad %S %E
+// CHECK-NEXT:                OpStore %param_var_w [[E]]
+// CHECK-NEXT:     {{%\d+}} = OpFunctionCall %void %S_foo %D %param_var_x %B %C %param_var_w
+  D.foo(A[0], B, C, E);
+
+  A[0].a = A[0].a | B.a | C.a | D.a;
+}

+ 1 - 1
tools/clang/test/CodeGenSPIRV/decoration.no-contraction.hlsl

@@ -92,7 +92,7 @@ precise float func3(float e, float f, float g, float h) {
 // CHECK-NEXT:                        OpLoad %float %g_1
 // CHECK-NEXT:                        OpLoad %float %g_1
 // CHECK-NEXT:                        OpLoad %float %h_1
 // CHECK-NEXT:                        OpLoad %float %h_1
 // CHECK-NEXT:    [[func3_g_mul_h]] = OpFMul %float
 // CHECK-NEXT:    [[func3_g_mul_h]] = OpFMul %float
-// CHECK-NEXT: [[func3_ef_plus_gh]] = OpFAdd %float %162 %165
+// CHECK-NEXT: [[func3_ef_plus_gh]] = OpFAdd %float [[func3_e_mul_f]] [[func3_g_mul_h]]
   float result = (e*f) + (g*h); // precise because it's the function return value.
   float result = (e*f) + (g*h); // precise because it's the function return value.
   return result;
   return result;
 }
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/fn.call.hlsl

@@ -40,10 +40,10 @@ void main() {
 // CHECK-NEXT: [[call2:%\d+]] = OpFunctionCall %int %fnTwoParm [[twoParam1]] [[twoParam2]]
 // CHECK-NEXT: [[call2:%\d+]] = OpFunctionCall %int %fnTwoParm [[twoParam1]] [[twoParam2]]
     fnTwoParm(v, v);  // Pass in variable; ignore return value
     fnTwoParm(v, v);  // Pass in variable; ignore return value
 
 
-// CHECK-NEXT: OpStore [[nestedParam2]] %int_1
-// CHECK-NEXT: [[call3:%\d+]] = OpFunctionCall %int %fnOneParm [[nestedParam2]]
-// CHECK-NEXT: OpStore [[nestedParam1]] [[call3]]
-// CHECK-NEXT: [[call4:%\d+]] = OpFunctionCall %int %fnCallOthers [[nestedParam1]]
+// CHECK-NEXT: OpStore [[nestedParam1]] %int_1
+// CHECK-NEXT: [[call3:%\d+]] = OpFunctionCall %int %fnOneParm [[nestedParam1]]
+// CHECK-NEXT: OpStore [[nestedParam2]] [[call3]]
+// CHECK-NEXT: [[call4:%\d+]] = OpFunctionCall %int %fnCallOthers [[nestedParam2]]
 // CHECK-NEXT: OpReturn
 // CHECK-NEXT: OpReturn
 // CHECK-NEXT: OpFunctionEnd
 // CHECK-NEXT: OpFunctionEnd
     fnCallOthers(fnOneParm(1)); // Nested function calls
     fnCallOthers(fnOneParm(1)); // Nested function calls

+ 1 - 6
tools/clang/test/CodeGenSPIRV/fn.ctbuffer.hlsl

@@ -26,13 +26,8 @@ tbuffer MyTBuffer {
 
 
 float4 main() : SV_Target {
 float4 main() : SV_Target {
 // %S vs %S_0: need destruction and construction
 // %S vs %S_0: need destruction and construction
-// CHECK:         %temp_var_S = OpVariable %_ptr_Function_S_0 Function
 // CHECK:       [[tb_s:%\d+]] = OpAccessChain %_ptr_Uniform_S %MyTBuffer %int_1
 // CHECK:       [[tb_s:%\d+]] = OpAccessChain %_ptr_Uniform_S %MyTBuffer %int_1
-// CHECK-NEXT:     [[s:%\d+]] = OpLoad %S [[tb_s]]
-// CHECK-NEXT: [[s_val:%\d+]] = OpCompositeExtract %v3float [[s]] 0
-// CHECK-NEXT:   [[tmp:%\d+]] = OpCompositeConstruct %S_0 [[s_val]]
-// CHECK-NEXT:                  OpStore %temp_var_S [[tmp]]
-// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %v3float %S_get_s_val %temp_var_S
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %v3float %S_get_s_val [[tb_s]]
     return get_cb_val() + float4(tb_s.get_s_val(), 0.) * get_tb_val();
     return get_cb_val() + float4(tb_s.get_s_val(), 0.) * get_tb_val();
 }
 }
 
 

+ 1 - 11
tools/clang/test/CodeGenSPIRV/fn.param.inout.storage-class.hlsl

@@ -9,22 +9,12 @@ void foo(in float a, inout float b, out float c) {
 
 
 void main(float input : INPUT) {
 void main(float input : INPUT) {
 // CHECK: %param_var_a = OpVariable %_ptr_Function_float Function
 // CHECK: %param_var_a = OpVariable %_ptr_Function_float Function
-// CHECK: %param_var_b = OpVariable %_ptr_Function_float Function
-// CHECK: %param_var_c = OpVariable %_ptr_Function_float Function
 
 
 // CHECK: [[val:%\d+]] = OpLoad %float %input
 // CHECK: [[val:%\d+]] = OpLoad %float %input
 // CHECK:                OpStore %param_var_a [[val]]
 // CHECK:                OpStore %param_var_a [[val]]
 // CHECK:  [[p0:%\d+]] = OpAccessChain %_ptr_Uniform_float %Data %int_0 %uint_0
 // CHECK:  [[p0:%\d+]] = OpAccessChain %_ptr_Uniform_float %Data %int_0 %uint_0
-// CHECK: [[val:%\d+]] = OpLoad %float [[p0]]
-// CHECK:                OpStore %param_var_b [[val]]
 // CHECK:  [[p1:%\d+]] = OpAccessChain %_ptr_Uniform_float %Data %int_0 %uint_1
 // CHECK:  [[p1:%\d+]] = OpAccessChain %_ptr_Uniform_float %Data %int_0 %uint_1
-// CHECK: [[val:%\d+]] = OpLoad %float [[p1]]
-// CHECK:                OpStore %param_var_c [[val]]
 
 
-// CHECK:                OpFunctionCall %void %foo %param_var_a %param_var_b %param_var_c
+// CHECK:                OpFunctionCall %void %foo %param_var_a [[p0]] [[p1]]
     foo(input, Data[0], Data[1]);
     foo(input, Data[0], Data[1]);
-// CHECK: [[val:%\d+]] = OpLoad %float %param_var_b
-// CHECK:                OpStore [[p0]] [[val]]
-// CHECK: [[val:%\d+]] = OpLoad %float %param_var_c
-// CHECK:                OpStore [[p1]] [[val]]
 }
 }

+ 1 - 3
tools/clang/test/CodeGenSPIRV/fn.param.inout.vector.hlsl

@@ -18,7 +18,7 @@ float4 main() : C {
 
 
     float4 val;
     float4 val;
 // CHECK:    [[z_ptr:%\d+]] = OpAccessChain %_ptr_Function_float %val %int_2
 // CHECK:    [[z_ptr:%\d+]] = OpAccessChain %_ptr_Function_float %val %int_2
-// CHECK:          {{%\d+}} = OpFunctionCall %void %bar %val %param_var_y %param_var_z %param_var_w
+// CHECK:          {{%\d+}} = OpFunctionCall %void %bar %val %param_var_y %param_var_z [[z_ptr]]
 // CHECK-NEXT:   [[y:%\d+]] = OpLoad %v3float %param_var_y
 // CHECK-NEXT:   [[y:%\d+]] = OpLoad %v3float %param_var_y
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
     // Write to val.zwx:
     // Write to val.zwx:
@@ -37,8 +37,6 @@ float4 main() : C {
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
 // CHECK-NEXT: [[new:%\d+]] = OpVectorShuffle %v4float [[old]] [[z]] 4 5 2 3
 // CHECK-NEXT: [[new:%\d+]] = OpVectorShuffle %v4float [[old]] [[z]] 4 5 2 3
 // CHECK-NEXT:                OpStore %val [[new]]
 // CHECK-NEXT:                OpStore %val [[new]]
-// CHECK-NEXT:   [[w:%\d+]] = OpLoad %float %param_var_w
-// CHECK-NEXT:                OpStore [[z_ptr]] [[w]]
     bar(val, val.zwx, val.xy, val.z);
     bar(val, val.zwx, val.xy, val.z);
 
 
     return MyRWBuffer[0];
     return MyRWBuffer[0];

+ 106 - 0
tools/clang/test/CodeGenSPIRV/fn.param.isomorphism.hlsl

@@ -0,0 +1,106 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct R {
+  int a;
+  void incr() { ++a; }
+};
+
+// CHECK: %rwsb = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_R Uniform
+RWStructuredBuffer<R> rwsb;
+
+struct S {
+  int a;
+  void incr() { ++a; }
+};
+
+// CHECK: %gs = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S gs;
+
+// CHECK: %st = OpVariable %_ptr_Private_S Private
+static S st;
+
+void decr(inout R foo) {
+  foo.a--;
+};
+
+void decr2(inout S foo) {
+  foo.a--;
+};
+
+void int_decr(out int foo) {
+  ++foo;
+}
+
+// CHECK: %gsarr = OpVariable %_ptr_Workgroup__arr_S_uint_10 Workgroup
+groupshared S gsarr[10];
+
+// CHECK: %starr = OpVariable %_ptr_Private__arr_S_uint_10 Private
+static S starr[10];
+
+void main() {
+// CHECK:    %fn = OpVariable %_ptr_Function_S Function
+  S fn;
+
+// CHECK: %fnarr = OpVariable %_ptr_Function__arr_S_uint_10 Function
+  S fnarr[10];
+
+// CHECK:   %arr = OpVariable %_ptr_Function__arr_int_uint_10 Function
+  int arr[10];
+
+// CHECK:      [[rwsb:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:      {{%\d+}} = OpFunctionCall %void %R_incr [[rwsb]]
+  rwsb[0].incr();
+
+// CHECK: OpFunctionCall %void %S_incr %gs
+  gs.incr();
+
+// CHECK: OpFunctionCall %void %S_incr %st
+  st.incr();
+
+// CHECK: OpFunctionCall %void %S_incr %fn
+  fn.incr();
+
+// CHECK:      [[rwsb:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:      {{%\d+}} = OpFunctionCall %void %decr [[rwsb]]
+  decr(rwsb[0]);
+
+// CHECK: OpFunctionCall %void %decr2 %gs
+  decr2(gs);
+
+// CHECK: OpFunctionCall %void %decr2 %st
+  decr2(st);
+
+// CHECK: OpFunctionCall %void %decr2 %fn
+  decr2(fn);
+
+// CHECK:      [[gsarr:%\d+]] = OpAccessChain %_ptr_Workgroup_S %gsarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %S_incr [[gsarr]]
+  gsarr[0].incr();
+
+// CHECK:      [[starr:%\d+]] = OpAccessChain %_ptr_Private_S %starr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %S_incr [[starr]]
+  starr[0].incr();
+
+// CHECK:      [[fnarr:%\d+]] = OpAccessChain %_ptr_Function_S %fnarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %S_incr [[fnarr]]
+  fnarr[0].incr();
+
+// CHECK:      [[gsarr:%\d+]] = OpAccessChain %_ptr_Workgroup_S %gsarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %decr2 [[gsarr]]
+  decr2(gsarr[0]);
+
+// CHECK:      [[starr:%\d+]] = OpAccessChain %_ptr_Private_S %starr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %decr2 [[starr]]
+  decr2(starr[0]);
+
+// CHECK:      [[fnarr:%\d+]] = OpAccessChain %_ptr_Function_S %fnarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %decr2 [[fnarr]]
+  decr2(fnarr[0]);
+
+// CHECK:        [[arr:%\d+]] = OpAccessChain %_ptr_Function_int %arr %int_0
+// CHECK-NEXT: [[arr_0:%\d+]] = OpLoad %int [[arr]]
+// CHECK-NEXT: [[arr_0:%\d+]] = OpIAdd %int [[arr_0]] %int_1
+// CHECK-NEXT:                  OpStore [[arr]] [[arr_0]]
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %int_decr [[arr]]
+  int_decr(++arr[0]);
+}

+ 2 - 7
tools/clang/test/CodeGenSPIRV/oo.method.on-static-var.hlsl

@@ -9,12 +9,7 @@ struct S {
 static S gSVar = {4.2};
 static S gSVar = {4.2};
 
 
 float main() : A {
 float main() : A {
-// CHECK:      %temp_var_S = OpVariable %_ptr_Function_S Function
-
-// CHECK:       [[s:%\d+]] = OpLoad %S %gSVar
-// CHECK-NEXT:               OpStore %temp_var_S [[s]]
-// CHECK-NEXT:    {{%\d+}} = OpFunctionCall %float %S_getVal %temp_var_S
-// CHECK-NEXT:  [[s:%\d+]] = OpLoad %S %temp_var_S
-// CHECK-NEXT:               OpStore %gSVar [[s]]
+// CHECK:      [[ret:%\d+]] = OpFunctionCall %float %S_getVal %gSVar
+// CHECK-NEXT:                OpReturnValue [[ret]]
     return gSVar.getVal();
     return gSVar.getVal();
 }
 }

+ 2 - 10
tools/clang/test/CodeGenSPIRV/oo.struct.method.hlsl

@@ -104,16 +104,8 @@ float main() : A {
 // CHECK-NEXT:        {{%\d+}} = OpFunctionCall %float %S_fn_ref %temp_var_S
 // CHECK-NEXT:        {{%\d+}} = OpFunctionCall %float %S_fn_ref %temp_var_S
   float f2 = foo().get_S().fn_ref();
   float f2 = foo().get_S().fn_ref();
 
 
-// CHECK:         [[uniformPtr:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
-// CHECK-NEXT:   [[originalObj:%\d+]] = OpLoad %R [[uniformPtr]]
-// CHECK-NEXT:        [[member:%\d+]] = OpCompositeExtract %int [[originalObj]] 0
-// CHECK-NEXT:       [[tempVar:%\d+]] = OpCompositeConstruct %R_0 [[member]]
-// CHECK-NEXT:                          OpStore %temp_var_R [[tempVar]]
-// CHECK-NEXT:                          OpFunctionCall %void %R_incr %temp_var_R
-// CHECK-NEXT:       [[tempVar:%\d+]] = OpLoad %R_0 %temp_var_R
-// CHECK-NEXT: [[tempVarMember:%\d+]] = OpCompositeExtract %int [[tempVar]] 0
-// CHECK-NEXT:          [[newR:%\d+]] = OpCompositeConstruct %R [[tempVarMember]]
-// CHECK-NEXT:                          OpStore [[uniformPtr]] [[newR]]
+// CHECK:      [[rwsb_0:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:                   OpFunctionCall %void %R_incr [[rwsb_0]]
   rwsb[0].incr();
   rwsb[0].incr();
 
 
   return f1;
   return f1;

+ 34 - 11
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -88,7 +88,7 @@ TEST_F(FileTest, StructuredBufferType) {
   runFileTest("type.structured-buffer.hlsl");
   runFileTest("type.structured-buffer.hlsl");
 }
 }
 TEST_F(FileTest, StructuredByteBufferArray) {
 TEST_F(FileTest, StructuredByteBufferArray) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("type.structured-buffer.array.hlsl");
   runFileTest("type.structured-buffer.array.hlsl");
 }
 }
 TEST_F(FileTest, StructuredByteBufferArrayError) {
 TEST_F(FileTest, StructuredByteBufferArrayError) {
@@ -496,11 +496,17 @@ TEST_F(FileTest, FunctionInOutParam) {
   runFileTest("fn.param.inout.hlsl");
   runFileTest("fn.param.inout.hlsl");
 }
 }
 TEST_F(FileTest, FunctionInOutParamVector) {
 TEST_F(FileTest, FunctionInOutParamVector) {
+  setBeforeHLSLLegalization();
   runFileTest("fn.param.inout.vector.hlsl");
   runFileTest("fn.param.inout.vector.hlsl");
 }
 }
 TEST_F(FileTest, FunctionInOutParamDiffStorageClass) {
 TEST_F(FileTest, FunctionInOutParamDiffStorageClass) {
+  setBeforeHLSLLegalization();
   runFileTest("fn.param.inout.storage-class.hlsl");
   runFileTest("fn.param.inout.storage-class.hlsl");
 }
 }
+TEST_F(FileTest, FunctionInOutParamIsomorphism) {
+  setBeforeHLSLLegalization();
+  runFileTest("fn.param.isomorphism.hlsl");
+}
 TEST_F(FileTest, FunctionInOutParamNoNeedToCopy) {
 TEST_F(FileTest, FunctionInOutParamNoNeedToCopy) {
   // Tests that referencing function scope variables as a whole with out/inout
   // Tests that referencing function scope variables as a whole with out/inout
   // annotation does not create temporary variables
   // annotation does not create temporary variables
@@ -517,15 +523,18 @@ TEST_F(FileTest, FunctionInOutParamTypeMismatch) {
 TEST_F(FileTest, FunctionFowardDeclaration) {
 TEST_F(FileTest, FunctionFowardDeclaration) {
   runFileTest("fn.foward-declaration.hlsl");
   runFileTest("fn.foward-declaration.hlsl");
 }
 }
-TEST_F(FileTest, FunctionInCTBuffer) { runFileTest("fn.ctbuffer.hlsl"); }
+TEST_F(FileTest, FunctionInCTBuffer) {
+  setBeforeHLSLLegalization();
+  runFileTest("fn.ctbuffer.hlsl");
+}
 
 
 // For OO features
 // For OO features
 TEST_F(FileTest, StructMethodCall) {
 TEST_F(FileTest, StructMethodCall) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("oo.struct.method.hlsl");
   runFileTest("oo.struct.method.hlsl");
 }
 }
 TEST_F(FileTest, ClassMethodCall) {
 TEST_F(FileTest, ClassMethodCall) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("oo.class.method.hlsl");
   runFileTest("oo.class.method.hlsl");
 }
 }
 TEST_F(FileTest, StructStaticMember) {
 TEST_F(FileTest, StructStaticMember) {
@@ -538,6 +547,7 @@ TEST_F(FileTest, StaticMemberInitializer) {
   runFileTest("oo.static.member.init.hlsl");
   runFileTest("oo.static.member.init.hlsl");
 }
 }
 TEST_F(FileTest, MethodCallOnStaticVar) {
 TEST_F(FileTest, MethodCallOnStaticVar) {
+  setBeforeHLSLLegalization();
   runFileTest("oo.method.on-static-var.hlsl");
   runFileTest("oo.method.on-static-var.hlsl");
 }
 }
 TEST_F(FileTest, Inheritance) { runFileTest("oo.inheritance.hlsl"); }
 TEST_F(FileTest, Inheritance) { runFileTest("oo.inheritance.hlsl"); }
@@ -1366,34 +1376,34 @@ TEST_F(FileTest, SpirvLegalizationOpaqueStruct) {
   runFileTest("spirv.legal.opaque-struct.hlsl");
   runFileTest("spirv.legal.opaque-struct.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferUsage) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferUsage) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.usage.hlsl");
   runFileTest("spirv.legal.sbuffer.usage.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferMethods) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferMethods) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.methods.hlsl");
   runFileTest("spirv.legal.sbuffer.methods.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.counter.hlsl");
   runFileTest("spirv.legal.sbuffer.counter.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
   // Tests using struct/class having associated counters
   // Tests using struct/class having associated counters
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.counter.struct.hlsl");
   runFileTest("spirv.legal.sbuffer.counter.struct.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInMethod) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInMethod) {
   // Tests using methods whose enclosing struct/class having associated counters
   // Tests using methods whose enclosing struct/class having associated counters
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.counter.method.hlsl");
   runFileTest("spirv.legal.sbuffer.counter.method.hlsl");
 }
 }
 TEST_F(FileTest,
 TEST_F(FileTest,
        SpirvLegalizationCounterVarAssignAcrossDifferentNestedStructLevel) {
        SpirvLegalizationCounterVarAssignAcrossDifferentNestedStructLevel) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.counter.nested-struct.hlsl");
   runFileTest("spirv.legal.counter.nested-struct.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.struct.hlsl");
   runFileTest("spirv.legal.sbuffer.struct.hlsl");
 }
 }
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
@@ -1782,6 +1792,18 @@ TEST_F(FileTest, GeometryShaderEmit) { runFileTest("gs.emit.hlsl"); }
 TEST_F(FileTest, ComputeShaderGroupShared) {
 TEST_F(FileTest, ComputeShaderGroupShared) {
   runFileTest("cs.groupshared.hlsl");
   runFileTest("cs.groupshared.hlsl");
 }
 }
+TEST_F(FileTest, ComputeShaderGroupSharedFunctionParam) {
+  setRelaxLogicalPointer();
+  runFileTest("cs.groupshared.function-param.hlsl");
+}
+TEST_F(FileTest, ComputeShaderGroupSharedFunctionParamOut) {
+  setBeforeHLSLLegalization();
+  runFileTest("cs.groupshared.function-param.out.hlsl");
+}
+TEST_F(FileTest, ComputeShaderGroupSharedStructFunction) {
+  setBeforeHLSLLegalization();
+  runFileTest("cs.groupshared.struct-function.hlsl");
+}
 
 
 // === Legalization examples ===
 // === Legalization examples ===
 
 
@@ -1914,6 +1936,7 @@ TEST_F(FileTest, DecorationRelaxedPrecisionImage) {
 
 
 // For NoContraction decorations
 // For NoContraction decorations
 TEST_F(FileTest, DecorationNoContraction) {
 TEST_F(FileTest, DecorationNoContraction) {
+  setBeforeHLSLLegalization();
   runFileTest("decoration.no-contraction.hlsl");
   runFileTest("decoration.no-contraction.hlsl");
 }
 }
 TEST_F(FileTest, DecorationNoContractionVariableReuse) {
 TEST_F(FileTest, DecorationNoContractionVariableReuse) {

+ 9 - 9
tools/clang/unittests/SPIRV/FileTestFixture.cpp

@@ -62,7 +62,7 @@ bool FileTest::parseInputFile() {
 
 
 void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
 void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
                            bool runValidation) {
                            bool runValidation) {
-  if (relaxLogicalPointer)
+  if (relaxLogicalPointer || beforeHLSLLegalization)
     assert(runValidation);
     assert(runValidation);
 
 
   inputFilePath = utils::getAbsPathOfInputDataFile(filename);
   inputFilePath = utils::getAbsPathOfInputDataFile(filename);
@@ -102,9 +102,9 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
     ASSERT_EQ(result.status(), effcee::Result::Status::Ok);
     ASSERT_EQ(result.status(), effcee::Result::Status::Ok);
 
 
     if (runValidation)
     if (runValidation)
-      EXPECT_TRUE(utils::validateSpirvBinary(targetEnv, generatedBinary,
-                                             relaxLogicalPointer, glLayout,
-                                             dxLayout, scalarLayout));
+      EXPECT_TRUE(utils::validateSpirvBinary(
+          targetEnv, generatedBinary, relaxLogicalPointer,
+          beforeHLSLLegalization, glLayout, dxLayout, scalarLayout));
   } else if (expect == Expect::Warning) {
   } else if (expect == Expect::Warning) {
     ASSERT_TRUE(compileOk);
     ASSERT_TRUE(compileOk);
 
 
@@ -128,9 +128,9 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
     ASSERT_EQ(result.status(), effcee::Result::Status::Ok);
     ASSERT_EQ(result.status(), effcee::Result::Status::Ok);
 
 
     if (runValidation)
     if (runValidation)
-      EXPECT_TRUE(utils::validateSpirvBinary(targetEnv, generatedBinary,
-                                             relaxLogicalPointer, glLayout,
-                                             dxLayout, scalarLayout));
+      EXPECT_TRUE(utils::validateSpirvBinary(
+          targetEnv, generatedBinary, relaxLogicalPointer,
+          beforeHLSLLegalization, glLayout, dxLayout, scalarLayout));
   } else if (expect == Expect::Failure) {
   } else if (expect == Expect::Failure) {
     ASSERT_FALSE(compileOk);
     ASSERT_FALSE(compileOk);
 
 
@@ -157,8 +157,8 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
 
 
     std::string valMessages;
     std::string valMessages;
     EXPECT_FALSE(utils::validateSpirvBinary(
     EXPECT_FALSE(utils::validateSpirvBinary(
-        targetEnv, generatedBinary, relaxLogicalPointer, glLayout, dxLayout,
-        scalarLayout, &valMessages));
+        targetEnv, generatedBinary, relaxLogicalPointer, beforeHLSLLegalization,
+        glLayout, dxLayout, scalarLayout, &valMessages));
     auto options = effcee::Options()
     auto options = effcee::Options()
                        .SetChecksName(filename.str())
                        .SetChecksName(filename.str())
                        .SetInputName("<val-message>");
                        .SetInputName("<val-message>");

+ 3 - 1
tools/clang/unittests/SPIRV/FileTestFixture.h

@@ -29,10 +29,11 @@ public:
 
 
   FileTest()
   FileTest()
       : targetEnv(SPV_ENV_VULKAN_1_0), relaxLogicalPointer(false),
       : targetEnv(SPV_ENV_VULKAN_1_0), relaxLogicalPointer(false),
-        glLayout(false), dxLayout(false) {}
+        beforeHLSLLegalization(false), glLayout(false), dxLayout(false) {}
 
 
   void useVulkan1p1() { targetEnv = SPV_ENV_VULKAN_1_1; }
   void useVulkan1p1() { targetEnv = SPV_ENV_VULKAN_1_1; }
   void setRelaxLogicalPointer() { relaxLogicalPointer = true; }
   void setRelaxLogicalPointer() { relaxLogicalPointer = true; }
+  void setBeforeHLSLLegalization() { beforeHLSLLegalization = true; }
   void setGlLayout() { glLayout = true; }
   void setGlLayout() { glLayout = true; }
   void setDxLayout() { dxLayout = true; }
   void setDxLayout() { dxLayout = true; }
   void setScalarLayout() { scalarLayout = true; }
   void setScalarLayout() { scalarLayout = true; }
@@ -54,6 +55,7 @@ private:
   std::string generatedSpirvAsm;         ///< Disassembled binary (SPIR-V code)
   std::string generatedSpirvAsm;         ///< Disassembled binary (SPIR-V code)
   spv_target_env targetEnv;              ///< Environment to validate against
   spv_target_env targetEnv;              ///< Environment to validate against
   bool relaxLogicalPointer;
   bool relaxLogicalPointer;
+  bool beforeHLSLLegalization;
   bool glLayout;
   bool glLayout;
   bool dxLayout;
   bool dxLayout;
   bool scalarLayout;
   bool scalarLayout;

+ 4 - 2
tools/clang/unittests/SPIRV/FileTestUtils.cpp

@@ -35,10 +35,12 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
 }
 }
 
 
 bool validateSpirvBinary(spv_target_env env, std::vector<uint32_t> &binary,
 bool validateSpirvBinary(spv_target_env env, std::vector<uint32_t> &binary,
-                         bool relaxLogicalPointer, bool glLayout, bool dxLayout,
-                         bool scalarLayout, std::string *message) {
+                         bool relaxLogicalPointer, bool beforeHlslLegalization,
+                         bool glLayout, bool dxLayout, bool scalarLayout,
+                         std::string *message) {
   spvtools::ValidatorOptions options;
   spvtools::ValidatorOptions options;
   options.SetRelaxLogicalPointer(relaxLogicalPointer);
   options.SetRelaxLogicalPointer(relaxLogicalPointer);
+  options.SetBeforeHlslLegalization(beforeHlslLegalization);
   if (dxLayout || scalarLayout) {
   if (dxLayout || scalarLayout) {
     options.SetSkipBlockLayout(true);
     options.SetSkipBlockLayout(true);
   } else if (glLayout) {
   } else if (glLayout) {

+ 3 - 2
tools/clang/unittests/SPIRV/FileTestUtils.h

@@ -33,8 +33,9 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
 /// \brief Runs the SPIR-V Tools validation on the given SPIR-V binary.
 /// \brief Runs the SPIR-V Tools validation on the given SPIR-V binary.
 /// Returns true if validation is successful; false otherwise.
 /// Returns true if validation is successful; false otherwise.
 bool validateSpirvBinary(spv_target_env, std::vector<uint32_t> &binary,
 bool validateSpirvBinary(spv_target_env, std::vector<uint32_t> &binary,
-                         bool relaxLogicalPointer, bool glLayout, bool dxLayout,
-                         bool scalarLayout, std::string *message = nullptr);
+                         bool relaxLogicalPointer, bool beforeHlslLegalization,
+                         bool glLayout, bool dxLayout, bool scalarLayout,
+                         std::string *message = nullptr);
 
 
 /// \brief Parses the Target Profile and Entry Point from the Run command
 /// \brief Parses the Target Profile and Entry Point from the Run command
 /// Returns the target profile, entry point, and the rest via arguments.
 /// Returns the target profile, entry point, and the rest via arguments.

+ 2 - 2
tools/clang/unittests/SPIRV/WholeFileTestFixture.cpp

@@ -109,8 +109,8 @@ void WholeFileTest::runWholeFileTest(llvm::StringRef filename,
   if (runSpirvValidation) {
   if (runSpirvValidation) {
     EXPECT_TRUE(utils::validateSpirvBinary(
     EXPECT_TRUE(utils::validateSpirvBinary(
         targetEnv, generatedBinary,
         targetEnv, generatedBinary,
-        /*relaxLogicalPointer=*/false, /*glLayout=*/false, /*dxLayout=*/false,
-        /*scalarLayout=*/false));
+        /*relaxLogicalPointer=*/false, /*beforeHlslLegalization=*/false,
+        /*glLayout=*/false, /*dxLayout=*/false, /*scalarLayout=*/false));
   }
   }
 }
 }