Kaynağa Gözat

[spirv] allow illegal storage/isomorphic function arg (#1991)

Before this change, decision of creating temporary variables for
function arguments were affected by three things: storage class,
param type, out or inout keyword. For example, if storage class of
a function arg is different from the param, we create a temporary
variable whose storage class is the same with the param and store
the function arg before the function call. After the function call
we store the value to the actual arg if needed.

Now based on HLSL legalization supported by spirv-opt, we do not
need to consider storage class. Similarly, we do not need to
consider type difference between arg and param if they have an
isomorphic type e.g., same structure but different names.
Jaebaek Seo 6 yıl önce
ebeveyn
işleme
c9b258a40c

+ 48 - 28
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -192,8 +192,8 @@ bool spirvToolsOptimize(spv_target_env env, std::vector<uint32_t> *module,
 }
 
 bool spirvToolsValidate(spv_target_env env, const SpirvCodeGenOptions &opts,
-                        bool relaxLogicalPointer, std::vector<uint32_t> *module,
-                        std::string *messages) {
+                        bool beforeHlslLegalization, bool relaxLogicalPointer,
+                        std::vector<uint32_t> *module, std::string *messages) {
   spvtools::SpirvTools tools(env);
 
   tools.SetMessageConsumer(
@@ -202,7 +202,16 @@ bool spirvToolsValidate(spv_target_env env, const SpirvCodeGenOptions &opts,
                  const char *message) { *messages += message; });
 
   spvtools::ValidatorOptions options;
-  options.SetRelaxLogicalPointer(relaxLogicalPointer);
+  options.SetBeforeHlslLegalization(beforeHlslLegalization);
+  // When beforeHlslLegalization is true and relaxLogicalPointer is false,
+  // options.SetBeforeHlslLegalization() enables --before-hlsl-legalization
+  // and --relax-logical-pointer. If options.SetRelaxLogicalPointer() is
+  // called, it disables --relax-logical-pointer that is not expected
+  // behavior. When beforeHlslLegalization is true, we must enable both
+  // options.
+  if (!beforeHlslLegalization)
+    options.SetRelaxLogicalPointer(relaxLogicalPointer);
+
   // GL: strict block layout rules
   // VK: relaxed block layout rules
   // DX: Skip block layout rules
@@ -486,7 +495,7 @@ SpirvEmitter::SpirvEmitter(CompilerInstance &ci)
                    spirvOptions),
       entryFunction(nullptr), curFunction(nullptr), curThis(nullptr),
       seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
-      mainSourceFile(nullptr) {
+      beforeHlslLegalization(false), mainSourceFile(nullptr) {
 
   // Get ShaderModel from command line hlsl profile option.
   const hlsl::ShaderModel *shaderModel =
@@ -670,7 +679,7 @@ void SpirvEmitter::HandleTranslationUnit(ASTContext &context) {
   // Validate the generated SPIR-V code
   if (!spirvOptions.disableValidation) {
     std::string messages;
-    if (!spirvToolsValidate(targetEnv, spirvOptions,
+    if (!spirvToolsValidate(targetEnv, spirvOptions, beforeHlslLegalization,
                             declIdMapper.requiresLegalization(), &m,
                             &messages)) {
       emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
@@ -2033,7 +2042,6 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
   bool isNonStaticMemberCall = false;
   QualType objectType = {};             // Type of the object (if exists)
   SpirvInstruction *objInstr = nullptr; // EvalInfo for the object (if exists)
-  bool needsTempVar = false;            // Whether we need temporary variable.
 
   llvm::SmallVector<SpirvInstruction *, 4> vars; // Variables for function call
   llvm::SmallVector<bool, 4> isTempVar;          // Temporary variable or not
@@ -2060,15 +2068,18 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
       // getObject().objectMethod();
       // Also, any parameter passed to the member function must be of Function
       // storage class.
-      needsTempVar = objInstr->isRValue() ||
-                     objInstr->getStorageClass() != spv::StorageClass::Function;
-
-      if (needsTempVar) {
+      if (objInstr->isRValue()) {
         args.push_back(createTemporaryVar(
             objectType, getAstTypeName(objectType),
             // May need to load to use as initializer
-            loadIfGLValue(object, objInstr), object->getExprLoc()));
+            loadIfGLValue(object, objInstr), object->getLocStart()));
       } else {
+        // Based on SPIR-V spec, function parameter must always be in Function
+        // scope. If we pass a non-function scope argument, we need
+        // the legalization.
+        if (objInstr->getStorageClass() != spv::StorageClass::Function)
+          beforeHlslLegalization = true;
+
         args.push_back(objInstr);
       }
 
@@ -2089,19 +2100,32 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
 
     // Get the evaluation info if this argument is referencing some variable
     // *as a whole*, in which case we can avoid creating the temporary variable
-    // for it if it is Function scope and can act as out parameter.
+    // for it if it can act as out parameter.
     SpirvInstruction *argInfo = nullptr;
     if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(arg)) {
       argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl(),
                                              arg->getLocStart());
     }
 
-    if (argInfo && argInfo->getStorageClass() == spv::StorageClass::Function &&
+    auto *argInst = doExpr(arg);
+    auto argType = arg->getType();
+
+    // If argInfo is nullptr and argInst is a rvalue, we do not have a proper
+    // pointer to pass to the function. we need a temporary variable in that
+    // case.
+    if ((argInfo || (argInst && !argInst->isRValue())) &&
         canActAsOutParmVar(param) &&
         paramTypeMatchesArgType(param->getType(), arg->getType())) {
-      vars.push_back(argInfo);
+      // Based on SPIR-V spec, function parameter must be always Function
+      // scope. In addition, we must pass memory object declaration argument
+      // to function. If we pass an argument that is not function scope
+      // or not memory object declaration, we need the legalization.
+      if (!argInfo || argInfo->getStorageClass() != spv::StorageClass::Function)
+        beforeHlslLegalization = true;
+
       isTempVar.push_back(false);
-      args.push_back(doExpr(arg));
+      args.push_back(argInst);
+      vars.push_back(argInfo ? argInfo : argInst);
     } else {
       // We need to create variables for holding the values to be used as
       // arguments. The variables themselves are of pointer types.
@@ -2119,7 +2143,7 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
 
       vars.push_back(tempVar);
       isTempVar.push_back(true);
-      args.push_back(doExpr(arg));
+      args.push_back(argInst);
 
       // Update counter variable associated with function parameters
       tryToAssignCounterVar(param, arg);
@@ -2150,6 +2174,9 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     }
   }
 
+  if (beforeHlslLegalization)
+    needsLegalization = true;
+
   assert(vars.size() == isTempVar.size());
   assert(vars.size() == args.size());
 
@@ -2165,23 +2192,16 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
   auto *retVal = spvBuilder.createFunctionCall(
       retType, func, vars, callExpr->getCallee()->getExprLoc());
 
-  // If we created a temporary variable for the lvalue object this method is
-  // invoked upon, we need to copy the contents in the temporary variable back
-  // to the original object's variable in case there are side effects.
-  if (needsTempVar && !objInstr->isRValue()) {
-    auto *value =
-        spvBuilder.createLoad(objectType, vars.front(), callExpr->getLocEnd());
-    storeValue(objInstr, value, objectType, callExpr->getLocEnd());
-  }
-
   // Go through all parameters and write those marked as out/inout
   for (uint32_t i = 0; i < numParams; ++i) {
     const auto *param = callee->getParamDecl(i);
-    if (isTempVar[i] && canActAsOutParmVar(param)) {
+    // If it calls a non-static member function, the object itself is argument
+    // 0, and therefore all other argument positions are shifted by 1.
+    const uint32_t index = i + isNonStaticMemberCall;
+    if (isTempVar[index] && canActAsOutParmVar(param)) {
       const auto *arg = callExpr->getArg(i);
-      const uint32_t index = i + isNonStaticMemberCall;
       SpirvInstruction *value = spvBuilder.createLoad(
-          param->getType(), vars[index], callExpr->getLocEnd());
+          param->getType(), vars[index], arg->getLocStart());
 
       // Now we want to assign 'value' to arg. But first, in rare cases when
       // using 'out' or 'inout' where the parameter and argument have a type

+ 4 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -1056,6 +1056,10 @@ private:
   /// Note: legalization specific code
   bool needsLegalization;
 
+  /// Whether the translated SPIR-V binary passes --before-hlsl-legalization
+  /// option to spirv-val because of illegal function parameter scope.
+  bool beforeHlslLegalization;
+
   /// Mapping from methods to the decls to represent their implicit object
   /// parameters
   ///

+ 42 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.function-param.hlsl

@@ -0,0 +1,42 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: %A = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_int Uniform
+RWStructuredBuffer<int> A;
+
+// CHECK: %B = OpVariable %_ptr_Private_int Private
+static int B;
+
+// CHECK: %C = OpVariable %_ptr_Workgroup_int Workgroup
+groupshared int C;
+
+// CHECK: %D = OpVariable %_ptr_Workgroup__arr_int_uint_10 Workgroup
+groupshared int D[10];
+
+int foo(int x, int y, int z, int w[10], int v) {
+  return x | y | z | w[0] | v;
+}
+
+void main() {
+// CHECK: %E = OpVariable %_ptr_Function_int Function
+  int E;
+
+// CHECK:      %param_var_x = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %param_var_y = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %param_var_z = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT: %param_var_w = OpVariable %_ptr_Function__arr_int_uint_10 Function
+// CHECK-NEXT: %param_var_v = OpVariable %_ptr_Function_int Function
+
+
+// CHECK:      [[A:%\d+]] = OpLoad %int {{%\d+}}
+// CHECK-NEXT:              OpStore %param_var_x [[A]]
+// CHECK-NEXT: [[B:%\d+]] = OpLoad %int %B
+// CHECK-NEXT:              OpStore %param_var_y [[B]]
+// CHECK-NEXT: [[C:%\d+]] = OpLoad %int %C
+// CHECK-NEXT:              OpStore %param_var_z [[C]]
+// CHECK-NEXT: [[D:%\d+]] = OpLoad %_arr_int_uint_10 %D
+// CHECK-NEXT:              OpStore %param_var_w [[D]]
+// CHECK-NEXT: [[E:%\d+]] = OpLoad %int %E
+// CHECK-NEXT:              OpStore %param_var_v [[E]]
+// CHECK-NEXT:   {{%\d+}} = OpFunctionCall %int %foo %param_var_x %param_var_y %param_var_z %param_var_w %param_var_v
+  A[0] = foo(A[0], B, C, D, E);
+}

+ 36 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.function-param.out.hlsl

@@ -0,0 +1,36 @@
+// Run: %dxc -T cs_6_0 -E main
+
+struct S {
+  int a;
+  float b;
+};
+
+void foo(out int x, out int y, out int z, out S w, out int v) {
+  x = 1;
+  y = x << 1;
+  z = x << 2;
+  w.a = x << 3;
+  v = x << 4;
+}
+
+// CHECK: %A = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_int Uniform
+RWStructuredBuffer<int> A;
+
+// CHECK: %B = OpVariable %_ptr_Private_int Private
+static int B;
+
+// CHECK: %C = OpVariable %_ptr_Workgroup_int Workgroup
+groupshared int C;
+
+// CHECK: %D = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S D;
+
+void main() {
+// CHECK: %E = OpVariable %_ptr_Function_int Function
+  int E;
+
+// CHECK:        [[A:%\d+]] = OpAccessChain %_ptr_Uniform_int %A %int_0 %uint_0
+// CHECK-NEXT:     {{%\d+}} = OpFunctionCall %void %foo [[A]] %B %C %D %E
+  foo(A[0], B, C, D, E);
+  A[0] = A[0] | B | C | D.a | E;
+}

+ 47 - 0
tools/clang/test/CodeGenSPIRV/cs.groupshared.struct-function.hlsl

@@ -0,0 +1,47 @@
+// Run: %dxc -T cs_6_0 -E main
+
+struct S {
+  int a;
+  float b;
+
+  void foo(S x, inout S y, out S z, S w) {
+    a = 1;
+    x.a = a << 1;
+    y.a = x.a << 1;
+    z.a = x.a << 2;
+    w.a = x.a << 3;
+  }
+};
+
+// CHECK: %A = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_S Uniform
+RWStructuredBuffer<S> A;
+
+// CHECK: %B = OpVariable %_ptr_Private_S Private
+static S B;
+
+// CHECK: %C = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S C;
+
+// CHECK: %D = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S D;
+
+void main() {
+// CHECK: %E = OpVariable %_ptr_Function_S Function
+  S E;
+
+// CHECK: %param_var_x = OpVariable %_ptr_Function_S Function
+// CHECK: %param_var_w = OpVariable %_ptr_Function_S Function
+
+// CHECK:        [[A:%\d+]] = OpAccessChain %_ptr_Uniform_S_0 %A %int_0 %uint_0
+// CHECK-NEXT: [[A_0:%\d+]] = OpLoad %S_0 [[A]]
+// CHECK-NEXT:   [[a:%\d+]] = OpCompositeExtract %int [[A_0]] 0
+// CHECK-NEXT:   [[b:%\d+]] = OpCompositeExtract %float [[A_0]] 1
+// CHECK-NEXT: [[A_0:%\d+]] = OpCompositeConstruct %S [[a]] [[b]]
+// CHECK-NEXT:                OpStore %param_var_x [[A_0]]
+// CHECK-NEXT:   [[E:%\d+]] = OpLoad %S %E
+// CHECK-NEXT:                OpStore %param_var_w [[E]]
+// CHECK-NEXT:     {{%\d+}} = OpFunctionCall %void %S_foo %D %param_var_x %B %C %param_var_w
+  D.foo(A[0], B, C, E);
+
+  A[0].a = A[0].a | B.a | C.a | D.a;
+}

+ 1 - 1
tools/clang/test/CodeGenSPIRV/decoration.no-contraction.hlsl

@@ -92,7 +92,7 @@ precise float func3(float e, float f, float g, float h) {
 // CHECK-NEXT:                        OpLoad %float %g_1
 // CHECK-NEXT:                        OpLoad %float %h_1
 // CHECK-NEXT:    [[func3_g_mul_h]] = OpFMul %float
-// CHECK-NEXT: [[func3_ef_plus_gh]] = OpFAdd %float %162 %165
+// CHECK-NEXT: [[func3_ef_plus_gh]] = OpFAdd %float [[func3_e_mul_f]] [[func3_g_mul_h]]
   float result = (e*f) + (g*h); // precise because it's the function return value.
   return result;
 }

+ 4 - 4
tools/clang/test/CodeGenSPIRV/fn.call.hlsl

@@ -40,10 +40,10 @@ void main() {
 // CHECK-NEXT: [[call2:%\d+]] = OpFunctionCall %int %fnTwoParm [[twoParam1]] [[twoParam2]]
     fnTwoParm(v, v);  // Pass in variable; ignore return value
 
-// CHECK-NEXT: OpStore [[nestedParam2]] %int_1
-// CHECK-NEXT: [[call3:%\d+]] = OpFunctionCall %int %fnOneParm [[nestedParam2]]
-// CHECK-NEXT: OpStore [[nestedParam1]] [[call3]]
-// CHECK-NEXT: [[call4:%\d+]] = OpFunctionCall %int %fnCallOthers [[nestedParam1]]
+// CHECK-NEXT: OpStore [[nestedParam1]] %int_1
+// CHECK-NEXT: [[call3:%\d+]] = OpFunctionCall %int %fnOneParm [[nestedParam1]]
+// CHECK-NEXT: OpStore [[nestedParam2]] [[call3]]
+// CHECK-NEXT: [[call4:%\d+]] = OpFunctionCall %int %fnCallOthers [[nestedParam2]]
 // CHECK-NEXT: OpReturn
 // CHECK-NEXT: OpFunctionEnd
     fnCallOthers(fnOneParm(1)); // Nested function calls

+ 1 - 6
tools/clang/test/CodeGenSPIRV/fn.ctbuffer.hlsl

@@ -26,13 +26,8 @@ tbuffer MyTBuffer {
 
 float4 main() : SV_Target {
 // %S vs %S_0: need destruction and construction
-// CHECK:         %temp_var_S = OpVariable %_ptr_Function_S_0 Function
 // CHECK:       [[tb_s:%\d+]] = OpAccessChain %_ptr_Uniform_S %MyTBuffer %int_1
-// CHECK-NEXT:     [[s:%\d+]] = OpLoad %S [[tb_s]]
-// CHECK-NEXT: [[s_val:%\d+]] = OpCompositeExtract %v3float [[s]] 0
-// CHECK-NEXT:   [[tmp:%\d+]] = OpCompositeConstruct %S_0 [[s_val]]
-// CHECK-NEXT:                  OpStore %temp_var_S [[tmp]]
-// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %v3float %S_get_s_val %temp_var_S
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %v3float %S_get_s_val [[tb_s]]
     return get_cb_val() + float4(tb_s.get_s_val(), 0.) * get_tb_val();
 }
 

+ 1 - 11
tools/clang/test/CodeGenSPIRV/fn.param.inout.storage-class.hlsl

@@ -9,22 +9,12 @@ void foo(in float a, inout float b, out float c) {
 
 void main(float input : INPUT) {
 // CHECK: %param_var_a = OpVariable %_ptr_Function_float Function
-// CHECK: %param_var_b = OpVariable %_ptr_Function_float Function
-// CHECK: %param_var_c = OpVariable %_ptr_Function_float Function
 
 // CHECK: [[val:%\d+]] = OpLoad %float %input
 // CHECK:                OpStore %param_var_a [[val]]
 // CHECK:  [[p0:%\d+]] = OpAccessChain %_ptr_Uniform_float %Data %int_0 %uint_0
-// CHECK: [[val:%\d+]] = OpLoad %float [[p0]]
-// CHECK:                OpStore %param_var_b [[val]]
 // CHECK:  [[p1:%\d+]] = OpAccessChain %_ptr_Uniform_float %Data %int_0 %uint_1
-// CHECK: [[val:%\d+]] = OpLoad %float [[p1]]
-// CHECK:                OpStore %param_var_c [[val]]
 
-// CHECK:                OpFunctionCall %void %foo %param_var_a %param_var_b %param_var_c
+// CHECK:                OpFunctionCall %void %foo %param_var_a [[p0]] [[p1]]
     foo(input, Data[0], Data[1]);
-// CHECK: [[val:%\d+]] = OpLoad %float %param_var_b
-// CHECK:                OpStore [[p0]] [[val]]
-// CHECK: [[val:%\d+]] = OpLoad %float %param_var_c
-// CHECK:                OpStore [[p1]] [[val]]
 }

+ 1 - 3
tools/clang/test/CodeGenSPIRV/fn.param.inout.vector.hlsl

@@ -18,7 +18,7 @@ float4 main() : C {
 
     float4 val;
 // CHECK:    [[z_ptr:%\d+]] = OpAccessChain %_ptr_Function_float %val %int_2
-// CHECK:          {{%\d+}} = OpFunctionCall %void %bar %val %param_var_y %param_var_z %param_var_w
+// CHECK:          {{%\d+}} = OpFunctionCall %void %bar %val %param_var_y %param_var_z [[z_ptr]]
 // CHECK-NEXT:   [[y:%\d+]] = OpLoad %v3float %param_var_y
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
     // Write to val.zwx:
@@ -37,8 +37,6 @@ float4 main() : C {
 // CHECK-NEXT: [[old:%\d+]] = OpLoad %v4float %val
 // CHECK-NEXT: [[new:%\d+]] = OpVectorShuffle %v4float [[old]] [[z]] 4 5 2 3
 // CHECK-NEXT:                OpStore %val [[new]]
-// CHECK-NEXT:   [[w:%\d+]] = OpLoad %float %param_var_w
-// CHECK-NEXT:                OpStore [[z_ptr]] [[w]]
     bar(val, val.zwx, val.xy, val.z);
 
     return MyRWBuffer[0];

+ 106 - 0
tools/clang/test/CodeGenSPIRV/fn.param.isomorphism.hlsl

@@ -0,0 +1,106 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct R {
+  int a;
+  void incr() { ++a; }
+};
+
+// CHECK: %rwsb = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_R Uniform
+RWStructuredBuffer<R> rwsb;
+
+struct S {
+  int a;
+  void incr() { ++a; }
+};
+
+// CHECK: %gs = OpVariable %_ptr_Workgroup_S Workgroup
+groupshared S gs;
+
+// CHECK: %st = OpVariable %_ptr_Private_S Private
+static S st;
+
+void decr(inout R foo) {
+  foo.a--;
+};
+
+void decr2(inout S foo) {
+  foo.a--;
+};
+
+void int_decr(out int foo) {
+  ++foo;
+}
+
+// CHECK: %gsarr = OpVariable %_ptr_Workgroup__arr_S_uint_10 Workgroup
+groupshared S gsarr[10];
+
+// CHECK: %starr = OpVariable %_ptr_Private__arr_S_uint_10 Private
+static S starr[10];
+
+void main() {
+// CHECK:    %fn = OpVariable %_ptr_Function_S Function
+  S fn;
+
+// CHECK: %fnarr = OpVariable %_ptr_Function__arr_S_uint_10 Function
+  S fnarr[10];
+
+// CHECK:   %arr = OpVariable %_ptr_Function__arr_int_uint_10 Function
+  int arr[10];
+
+// CHECK:      [[rwsb:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:      {{%\d+}} = OpFunctionCall %void %R_incr [[rwsb]]
+  rwsb[0].incr();
+
+// CHECK: OpFunctionCall %void %S_incr %gs
+  gs.incr();
+
+// CHECK: OpFunctionCall %void %S_incr %st
+  st.incr();
+
+// CHECK: OpFunctionCall %void %S_incr %fn
+  fn.incr();
+
+// CHECK:      [[rwsb:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:      {{%\d+}} = OpFunctionCall %void %decr [[rwsb]]
+  decr(rwsb[0]);
+
+// CHECK: OpFunctionCall %void %decr2 %gs
+  decr2(gs);
+
+// CHECK: OpFunctionCall %void %decr2 %st
+  decr2(st);
+
+// CHECK: OpFunctionCall %void %decr2 %fn
+  decr2(fn);
+
+// CHECK:      [[gsarr:%\d+]] = OpAccessChain %_ptr_Workgroup_S %gsarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %S_incr [[gsarr]]
+  gsarr[0].incr();
+
+// CHECK:      [[starr:%\d+]] = OpAccessChain %_ptr_Private_S %starr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %S_incr [[starr]]
+  starr[0].incr();
+
+// CHECK:      [[fnarr:%\d+]] = OpAccessChain %_ptr_Function_S %fnarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %S_incr [[fnarr]]
+  fnarr[0].incr();
+
+// CHECK:      [[gsarr:%\d+]] = OpAccessChain %_ptr_Workgroup_S %gsarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %decr2 [[gsarr]]
+  decr2(gsarr[0]);
+
+// CHECK:      [[starr:%\d+]] = OpAccessChain %_ptr_Private_S %starr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %decr2 [[starr]]
+  decr2(starr[0]);
+
+// CHECK:      [[fnarr:%\d+]] = OpAccessChain %_ptr_Function_S %fnarr %int_0
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %decr2 [[fnarr]]
+  decr2(fnarr[0]);
+
+// CHECK:        [[arr:%\d+]] = OpAccessChain %_ptr_Function_int %arr %int_0
+// CHECK-NEXT: [[arr_0:%\d+]] = OpLoad %int [[arr]]
+// CHECK-NEXT: [[arr_0:%\d+]] = OpIAdd %int [[arr_0]] %int_1
+// CHECK-NEXT:                  OpStore [[arr]] [[arr_0]]
+// CHECK-NEXT:       {{%\d+}} = OpFunctionCall %void %int_decr [[arr]]
+  int_decr(++arr[0]);
+}

+ 2 - 7
tools/clang/test/CodeGenSPIRV/oo.method.on-static-var.hlsl

@@ -9,12 +9,7 @@ struct S {
 static S gSVar = {4.2};
 
 float main() : A {
-// CHECK:      %temp_var_S = OpVariable %_ptr_Function_S Function
-
-// CHECK:       [[s:%\d+]] = OpLoad %S %gSVar
-// CHECK-NEXT:               OpStore %temp_var_S [[s]]
-// CHECK-NEXT:    {{%\d+}} = OpFunctionCall %float %S_getVal %temp_var_S
-// CHECK-NEXT:  [[s:%\d+]] = OpLoad %S %temp_var_S
-// CHECK-NEXT:               OpStore %gSVar [[s]]
+// CHECK:      [[ret:%\d+]] = OpFunctionCall %float %S_getVal %gSVar
+// CHECK-NEXT:                OpReturnValue [[ret]]
     return gSVar.getVal();
 }

+ 2 - 10
tools/clang/test/CodeGenSPIRV/oo.struct.method.hlsl

@@ -104,16 +104,8 @@ float main() : A {
 // CHECK-NEXT:        {{%\d+}} = OpFunctionCall %float %S_fn_ref %temp_var_S
   float f2 = foo().get_S().fn_ref();
 
-// CHECK:         [[uniformPtr:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
-// CHECK-NEXT:   [[originalObj:%\d+]] = OpLoad %R [[uniformPtr]]
-// CHECK-NEXT:        [[member:%\d+]] = OpCompositeExtract %int [[originalObj]] 0
-// CHECK-NEXT:       [[tempVar:%\d+]] = OpCompositeConstruct %R_0 [[member]]
-// CHECK-NEXT:                          OpStore %temp_var_R [[tempVar]]
-// CHECK-NEXT:                          OpFunctionCall %void %R_incr %temp_var_R
-// CHECK-NEXT:       [[tempVar:%\d+]] = OpLoad %R_0 %temp_var_R
-// CHECK-NEXT: [[tempVarMember:%\d+]] = OpCompositeExtract %int [[tempVar]] 0
-// CHECK-NEXT:          [[newR:%\d+]] = OpCompositeConstruct %R [[tempVarMember]]
-// CHECK-NEXT:                          OpStore [[uniformPtr]] [[newR]]
+// CHECK:      [[rwsb_0:%\d+]] = OpAccessChain %_ptr_Uniform_R %rwsb %int_0 %uint_0
+// CHECK-NEXT:                   OpFunctionCall %void %R_incr [[rwsb_0]]
   rwsb[0].incr();
 
   return f1;

+ 34 - 11
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -88,7 +88,7 @@ TEST_F(FileTest, StructuredBufferType) {
   runFileTest("type.structured-buffer.hlsl");
 }
 TEST_F(FileTest, StructuredByteBufferArray) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("type.structured-buffer.array.hlsl");
 }
 TEST_F(FileTest, StructuredByteBufferArrayError) {
@@ -496,11 +496,17 @@ TEST_F(FileTest, FunctionInOutParam) {
   runFileTest("fn.param.inout.hlsl");
 }
 TEST_F(FileTest, FunctionInOutParamVector) {
+  setBeforeHLSLLegalization();
   runFileTest("fn.param.inout.vector.hlsl");
 }
 TEST_F(FileTest, FunctionInOutParamDiffStorageClass) {
+  setBeforeHLSLLegalization();
   runFileTest("fn.param.inout.storage-class.hlsl");
 }
+TEST_F(FileTest, FunctionInOutParamIsomorphism) {
+  setBeforeHLSLLegalization();
+  runFileTest("fn.param.isomorphism.hlsl");
+}
 TEST_F(FileTest, FunctionInOutParamNoNeedToCopy) {
   // Tests that referencing function scope variables as a whole with out/inout
   // annotation does not create temporary variables
@@ -517,15 +523,18 @@ TEST_F(FileTest, FunctionInOutParamTypeMismatch) {
 TEST_F(FileTest, FunctionFowardDeclaration) {
   runFileTest("fn.foward-declaration.hlsl");
 }
-TEST_F(FileTest, FunctionInCTBuffer) { runFileTest("fn.ctbuffer.hlsl"); }
+TEST_F(FileTest, FunctionInCTBuffer) {
+  setBeforeHLSLLegalization();
+  runFileTest("fn.ctbuffer.hlsl");
+}
 
 // For OO features
 TEST_F(FileTest, StructMethodCall) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("oo.struct.method.hlsl");
 }
 TEST_F(FileTest, ClassMethodCall) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("oo.class.method.hlsl");
 }
 TEST_F(FileTest, StructStaticMember) {
@@ -538,6 +547,7 @@ TEST_F(FileTest, StaticMemberInitializer) {
   runFileTest("oo.static.member.init.hlsl");
 }
 TEST_F(FileTest, MethodCallOnStaticVar) {
+  setBeforeHLSLLegalization();
   runFileTest("oo.method.on-static-var.hlsl");
 }
 TEST_F(FileTest, Inheritance) { runFileTest("oo.inheritance.hlsl"); }
@@ -1366,34 +1376,34 @@ TEST_F(FileTest, SpirvLegalizationOpaqueStruct) {
   runFileTest("spirv.legal.opaque-struct.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferUsage) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.usage.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferMethods) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.methods.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounter) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.counter.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInStruct) {
   // Tests using struct/class having associated counters
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.counter.struct.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferCounterInMethod) {
   // Tests using methods whose enclosing struct/class having associated counters
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.counter.method.hlsl");
 }
 TEST_F(FileTest,
        SpirvLegalizationCounterVarAssignAcrossDifferentNestedStructLevel) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.counter.nested-struct.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationStructuredBufferInStruct) {
-  setRelaxLogicalPointer();
+  setBeforeHLSLLegalization();
   runFileTest("spirv.legal.sbuffer.struct.hlsl");
 }
 TEST_F(FileTest, SpirvLegalizationConstantBuffer) {
@@ -1782,6 +1792,18 @@ TEST_F(FileTest, GeometryShaderEmit) { runFileTest("gs.emit.hlsl"); }
 TEST_F(FileTest, ComputeShaderGroupShared) {
   runFileTest("cs.groupshared.hlsl");
 }
+TEST_F(FileTest, ComputeShaderGroupSharedFunctionParam) {
+  setRelaxLogicalPointer();
+  runFileTest("cs.groupshared.function-param.hlsl");
+}
+TEST_F(FileTest, ComputeShaderGroupSharedFunctionParamOut) {
+  setBeforeHLSLLegalization();
+  runFileTest("cs.groupshared.function-param.out.hlsl");
+}
+TEST_F(FileTest, ComputeShaderGroupSharedStructFunction) {
+  setBeforeHLSLLegalization();
+  runFileTest("cs.groupshared.struct-function.hlsl");
+}
 
 // === Legalization examples ===
 
@@ -1914,6 +1936,7 @@ TEST_F(FileTest, DecorationRelaxedPrecisionImage) {
 
 // For NoContraction decorations
 TEST_F(FileTest, DecorationNoContraction) {
+  setBeforeHLSLLegalization();
   runFileTest("decoration.no-contraction.hlsl");
 }
 TEST_F(FileTest, DecorationNoContractionVariableReuse) {

+ 9 - 9
tools/clang/unittests/SPIRV/FileTestFixture.cpp

@@ -62,7 +62,7 @@ bool FileTest::parseInputFile() {
 
 void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
                            bool runValidation) {
-  if (relaxLogicalPointer)
+  if (relaxLogicalPointer || beforeHLSLLegalization)
     assert(runValidation);
 
   inputFilePath = utils::getAbsPathOfInputDataFile(filename);
@@ -102,9 +102,9 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
     ASSERT_EQ(result.status(), effcee::Result::Status::Ok);
 
     if (runValidation)
-      EXPECT_TRUE(utils::validateSpirvBinary(targetEnv, generatedBinary,
-                                             relaxLogicalPointer, glLayout,
-                                             dxLayout, scalarLayout));
+      EXPECT_TRUE(utils::validateSpirvBinary(
+          targetEnv, generatedBinary, relaxLogicalPointer,
+          beforeHLSLLegalization, glLayout, dxLayout, scalarLayout));
   } else if (expect == Expect::Warning) {
     ASSERT_TRUE(compileOk);
 
@@ -128,9 +128,9 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
     ASSERT_EQ(result.status(), effcee::Result::Status::Ok);
 
     if (runValidation)
-      EXPECT_TRUE(utils::validateSpirvBinary(targetEnv, generatedBinary,
-                                             relaxLogicalPointer, glLayout,
-                                             dxLayout, scalarLayout));
+      EXPECT_TRUE(utils::validateSpirvBinary(
+          targetEnv, generatedBinary, relaxLogicalPointer,
+          beforeHLSLLegalization, glLayout, dxLayout, scalarLayout));
   } else if (expect == Expect::Failure) {
     ASSERT_FALSE(compileOk);
 
@@ -157,8 +157,8 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
 
     std::string valMessages;
     EXPECT_FALSE(utils::validateSpirvBinary(
-        targetEnv, generatedBinary, relaxLogicalPointer, glLayout, dxLayout,
-        scalarLayout, &valMessages));
+        targetEnv, generatedBinary, relaxLogicalPointer, beforeHLSLLegalization,
+        glLayout, dxLayout, scalarLayout, &valMessages));
     auto options = effcee::Options()
                        .SetChecksName(filename.str())
                        .SetInputName("<val-message>");

+ 3 - 1
tools/clang/unittests/SPIRV/FileTestFixture.h

@@ -29,10 +29,11 @@ public:
 
   FileTest()
       : targetEnv(SPV_ENV_VULKAN_1_0), relaxLogicalPointer(false),
-        glLayout(false), dxLayout(false) {}
+        beforeHLSLLegalization(false), glLayout(false), dxLayout(false) {}
 
   void useVulkan1p1() { targetEnv = SPV_ENV_VULKAN_1_1; }
   void setRelaxLogicalPointer() { relaxLogicalPointer = true; }
+  void setBeforeHLSLLegalization() { beforeHLSLLegalization = true; }
   void setGlLayout() { glLayout = true; }
   void setDxLayout() { dxLayout = true; }
   void setScalarLayout() { scalarLayout = true; }
@@ -54,6 +55,7 @@ private:
   std::string generatedSpirvAsm;         ///< Disassembled binary (SPIR-V code)
   spv_target_env targetEnv;              ///< Environment to validate against
   bool relaxLogicalPointer;
+  bool beforeHLSLLegalization;
   bool glLayout;
   bool dxLayout;
   bool scalarLayout;

+ 4 - 2
tools/clang/unittests/SPIRV/FileTestUtils.cpp

@@ -35,10 +35,12 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
 }
 
 bool validateSpirvBinary(spv_target_env env, std::vector<uint32_t> &binary,
-                         bool relaxLogicalPointer, bool glLayout, bool dxLayout,
-                         bool scalarLayout, std::string *message) {
+                         bool relaxLogicalPointer, bool beforeHlslLegalization,
+                         bool glLayout, bool dxLayout, bool scalarLayout,
+                         std::string *message) {
   spvtools::ValidatorOptions options;
   options.SetRelaxLogicalPointer(relaxLogicalPointer);
+  options.SetBeforeHlslLegalization(beforeHlslLegalization);
   if (dxLayout || scalarLayout) {
     options.SetSkipBlockLayout(true);
   } else if (glLayout) {

+ 3 - 2
tools/clang/unittests/SPIRV/FileTestUtils.h

@@ -33,8 +33,9 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
 /// \brief Runs the SPIR-V Tools validation on the given SPIR-V binary.
 /// Returns true if validation is successful; false otherwise.
 bool validateSpirvBinary(spv_target_env, std::vector<uint32_t> &binary,
-                         bool relaxLogicalPointer, bool glLayout, bool dxLayout,
-                         bool scalarLayout, std::string *message = nullptr);
+                         bool relaxLogicalPointer, bool beforeHlslLegalization,
+                         bool glLayout, bool dxLayout, bool scalarLayout,
+                         std::string *message = nullptr);
 
 /// \brief Parses the Target Profile and Entry Point from the Run command
 /// Returns the target profile, entry point, and the rest via arguments.

+ 2 - 2
tools/clang/unittests/SPIRV/WholeFileTestFixture.cpp

@@ -109,8 +109,8 @@ void WholeFileTest::runWholeFileTest(llvm::StringRef filename,
   if (runSpirvValidation) {
     EXPECT_TRUE(utils::validateSpirvBinary(
         targetEnv, generatedBinary,
-        /*relaxLogicalPointer=*/false, /*glLayout=*/false, /*dxLayout=*/false,
-        /*scalarLayout=*/false));
+        /*relaxLogicalPointer=*/false, /*beforeHlslLegalization=*/false,
+        /*glLayout=*/false, /*dxLayout=*/false, /*scalarLayout=*/false));
   }
 }