Prechádzať zdrojové kódy

Lower vector/matrix early for UDT ptrs used directly such as Payload

Tex Riddell 5 rokov pred
rodič
commit
6eb541244a

+ 5 - 2
include/dxc/HLSL/HLLowerUDT.h

@@ -23,9 +23,12 @@ class Value;
 } // namespace llvm
 
 namespace hlsl {
+class DxilTypeSystem;
 
-llvm::StructType *GetLoweredUDT(llvm::StructType *structTy);
-llvm::Constant *TranslateInitForLoweredUDT(llvm::Constant *Init, llvm::Type *NewTy,
+llvm::StructType *GetLoweredUDT(
+  llvm::StructType *structTy, hlsl::DxilTypeSystem *pTypeSys = nullptr);
+llvm::Constant *TranslateInitForLoweredUDT(
+    llvm::Constant *Init, llvm::Type *NewTy,
     // We need orientation for matrix fields
     hlsl::DxilTypeSystem *pTypeSys,
     hlsl::MatrixOrientation matOrientation = hlsl::MatrixOrientation::Undefined);

+ 26 - 5
lib/HLSL/HLLowerUDT.cpp

@@ -54,7 +54,7 @@ static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, un
 // Lowered UDT is the same layout, but with vectors and matrices translated to
 // arrays.
 // Returns nullptr for failure due to embedded HLSL object type.
-StructType *hlsl::GetLoweredUDT(StructType *structTy) {
+StructType *hlsl::GetLoweredUDT(StructType *structTy, DxilTypeSystem *pTypeSys) {
   bool changed = false;
   SmallVector<Type*, 8> NewElTys(structTy->getNumContainedTypes());
 
@@ -106,17 +106,29 @@ StructType *hlsl::GetLoweredUDT(StructType *structTy) {
   }
 
   if (changed) {
-    return StructType::create(
+    StructType *newStructTy = StructType::create(
       structTy->getContext(), NewElTys, structTy->getStructName());
+    if (DxilStructAnnotation *pSA = pTypeSys ?
+          pTypeSys->GetStructAnnotation(structTy) : nullptr) {
+      if (!pTypeSys->GetStructAnnotation(newStructTy)) {
+        DxilStructAnnotation &NewSA = *pTypeSys->AddStructAnnotation(newStructTy);
+        for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
+          NewSA.GetFieldAnnotation(iField) = pSA->GetFieldAnnotation(iField);
+        }
+      }
+    }
+    return newStructTy;
   }
 
   return structTy;
 }
 
-Constant *hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
+Constant *hlsl::TranslateInitForLoweredUDT(
+    Constant *Init, Type *NewTy,
     // We need orientation for matrix fields
     DxilTypeSystem *pTypeSys,
     MatrixOrientation matOrientation) {
+
   // handle undef and zero init
   if (isa<UndefValue>(Init))
     return UndefValue::get(NewTy);
@@ -159,14 +171,23 @@ Constant *hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
       }
     }
   } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    DxilStructAnnotation *pStructAnnotation =
+      pTypeSys ? pTypeSys->GetStructAnnotation(ST) : nullptr;
     values.reserve(ST->getNumContainedTypes());
     ConstantStruct *CS = cast<ConstantStruct>(Init);
     for (unsigned i = 0; i < ST->getStructNumElements(); ++i) {
+      MatrixOrientation matFieldOrientation = matOrientation;
+      if (pStructAnnotation) {
+        DxilFieldAnnotation &FA = pStructAnnotation->GetFieldAnnotation(i);
+        if (FA.HasMatrixAnnotation()) {
+          matFieldOrientation = FA.GetMatrixAnnotation().Orientation;
+        }
+      }
       values.emplace_back(
         TranslateInitForLoweredUDT(
           cast<Constant>(CS->getAggregateElement(i)),
           NewTy->getStructElementType(i),
-          pTypeSys, matOrientation));
+          pTypeSys, matFieldOrientation));
     }
     return ConstantStruct::get(cast<StructType>(NewTy), values);
   }
@@ -411,7 +432,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
         }
       } break;
 
-      case HLOpcodeGroup::NotHL:
+      //case HLOpcodeGroup::NotHL:  // TODO: Support lib functions
       case HLOpcodeGroup::HLIntrinsic: {
         // Just bitcast for now
         IRBuilder<> Builder(CI);

+ 233 - 13
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -58,6 +58,7 @@
 #include "dxc/HLSL/HLMatrixLowerHelper.h"
 #include "dxc/HLSL/HLMatrixType.h"
 #include "dxc/DXIL/DxilOperations.h"
+#include "dxc/HLSL/HLLowerUDT.h"
 #include <deque>
 #include <unordered_map>
 #include <unordered_set>
@@ -777,6 +778,200 @@ static unsigned getNestedLevelInStruct(const Type *ty) {
   return lvl;
 }
 
+/// Returns first GEP index that indexes a struct member, or 0 otherwise.
+/// Ignores initial ptr index.
+static unsigned FindFirstStructMemberIdxInGEP(GEPOperator *GEP) {
+  StructType *ST = dyn_cast<StructType>(
+    GEP->getPointerOperandType()->getPointerElementType());
+  int index = 1;
+  for (auto it = gep_type_begin(GEP), E = gep_type_end(GEP); it != E;
+       ++it, ++index) {
+    if (ST) {
+      DXASSERT(!HLMatrixType::isa(ST) && !dxilutil::IsHLSLObjectType(ST),
+               "otherwise, indexing into hlsl object");
+      return index;
+    }
+    ST = dyn_cast<StructType>(it->getPointerElementType());
+  }
+  return 0;
+}
+
+/// Return true when ptr should not be SROA'd or copied, but used directly
+/// by a function in its lowered form.  Also collect uses for translation.
+/// What is meant by directly here:
+///   Possibly accessed through GEP array index or address space cast, but
+///   not under another struct member (always allow SROA of outer struct).
+typedef SmallMapVector<CallInst*, unsigned, 4> FunctionUseMap;
+static bool IsPtrUsedByLoweredFn(
+    Value *V, FunctionUseMap &CollectedUses) {
+  bool bFound = false;
+  for (Use &U : V->uses()) {
+    User *user = U.getUser();
+
+    if (CallInst *CI = dyn_cast<CallInst>(user)) {
+      unsigned foundIdx = (unsigned)-1;
+      Function *F = CI->getCalledFunction();
+      Type *Ty = V->getType();
+      if (F && F->isDeclaration() && !F->isIntrinsic() &&
+          Ty->isPointerTy()) {
+        HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByName(F);
+        if (group == HLOpcodeGroup::HLIntrinsic) {
+          unsigned opIdx = U.getOperandNo();
+          switch ((IntrinsicOp)hlsl::GetHLOpcode(CI)) {
+            // TODO: Lower these as well, along with function parameter types
+            //case IntrinsicOp::IOP_TraceRay:
+            //  if (opIdx != HLOperandIndex::kTraceRayPayLoadOpIdx)
+            //    continue;
+            //  break;
+            //case IntrinsicOp::IOP_ReportHit:
+            //  if (opIdx != HLOperandIndex::kReportIntersectionAttributeOpIdx)
+            //    continue;
+            //  break;
+            //case IntrinsicOp::IOP_CallShader:
+            //  if (opIdx != HLOperandIndex::kCallShaderPayloadOpIdx)
+            //    continue;
+            //  break;
+            case IntrinsicOp::IOP_DispatchMesh:
+              if (opIdx != HLOperandIndex::kDispatchMeshOpPayload)
+                continue;
+              break;
+            default:
+              continue;
+          }
+          foundIdx = opIdx;
+
+        // TODO: Lower these as well, along with function parameter types
+        //} else if (group == HLOpcodeGroup::NotHL) {
+        //  foundIdx = U.getOperandNo();
+        }
+      }
+      if (foundIdx != (unsigned)-1) {
+        bFound = true;
+        auto insRes = CollectedUses.insert(std::make_pair(CI, foundIdx));
+        DXASSERT_LOCALVAR(insRes, insRes.second,
+            "otherwise, multiple uses in single call");
+      }
+
+    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
+      // Not what we are looking for if GEP result is not [array of] struct.
+      // If use is under struct member, we can still SROA the outer struct.
+      if (!dxilutil::StripArrayTypes(GEP->getType()->getPointerElementType())
+            ->isStructTy() ||
+          FindFirstStructMemberIdxInGEP(cast<GEPOperator>(GEP)))
+        continue;
+      if (IsPtrUsedByLoweredFn(user, CollectedUses))
+        bFound = true;
+
+    } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
+      if (IsPtrUsedByLoweredFn(user, CollectedUses))
+        bFound = true;
+
+    } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
+      unsigned opcode = CE->getOpcode();
+      if (opcode == Instruction::AddrSpaceCast || Instruction::GetElementPtr)
+        if (IsPtrUsedByLoweredFn(user, CollectedUses))
+          bFound = true;
+    }
+  }
+  return bFound;
+}
+
+/// Rewrite call to natively use an argument with addrspace cast/bitcast
+static CallInst *RewriteIntrinsicCallForCastedArg(CallInst *CI, unsigned argIdx) {
+  Function *F = CI->getCalledFunction();
+  HLOpcodeGroup group = GetHLOpcodeGroupByName(F);
+  DXASSERT_NOMSG(group == HLOpcodeGroup::HLIntrinsic);
+  unsigned opcode = GetHLOpcode(CI);
+  SmallVector<Type *, 8> newArgTypes(CI->getFunctionType()->param_begin(),
+                                     CI->getFunctionType()->param_end());
+  SmallVector<Value *, 8> newArgs(CI->arg_operands());
+
+  Value *newArg = CI->getOperand(argIdx)->stripPointerCasts();
+  newArgTypes[argIdx] = newArg->getType();
+  newArgs[argIdx] = newArg;
+
+  FunctionType *newFuncTy = FunctionType::get(CI->getType(), newArgTypes, false);
+  Function *newF = GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode);
+
+  IRBuilder<> Builder(CI);
+  return Builder.CreateCall(newF, newArgs);
+}
+
+/// Translate pointer for cases where intrinsics use UDT pointers directly
+/// Return existing or new ptr if needs preserving,
+/// otherwise nullptr to proceed with existing checks and SROA.
+static Value *TranslatePtrIfUsedByLoweredFn(
+    Value *Ptr, DxilTypeSystem &TypeSys) {
+  if (!Ptr->getType()->isPointerTy())
+    return nullptr;
+  Type *Ty = Ptr->getType()->getPointerElementType();
+  SmallVector<unsigned, 4> outerToInnerLengths;
+  Ty = dxilutil::StripArrayTypes(Ty, &outerToInnerLengths);
+  if (!Ty->isStructTy())
+    return nullptr;
+  if (HLMatrixType::isa(Ty) || dxilutil::IsHLSLObjectType(Ty))
+    return nullptr;
+  unsigned AddrSpace = Ptr->getType()->getPointerAddressSpace();
+  FunctionUseMap FunctionUses;
+  if (!IsPtrUsedByLoweredFn(Ptr, FunctionUses))
+    return nullptr;
+  // Translate vectors to arrays in type, but don't SROA
+  Type *NewTy = GetLoweredUDT(cast<StructType>(Ty), &TypeSys);
+
+  // No work to do here, but prevent SROA.
+  if (Ty == NewTy && AddrSpace != DXIL::kTGSMAddrSpace)
+    return Ptr;
+
+  // If type changed, replace value, otherwise casting may still
+  // require a rewrite of the calls.
+  Value *NewPtr = Ptr;
+  if (Ty != NewTy) {
+    // TODO: Transfer type annotation
+    DxilStructAnnotation *pOldAnnotation = TypeSys.GetStructAnnotation(cast<StructType>(Ty));
+    if (pOldAnnotation) {
+
+    }
+    NewTy = dxilutil::WrapInArrayTypes(NewTy, outerToInnerLengths);
+    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
+      Module &M = *GV->getParent();
+      // Rewrite init expression for arrays instead of vectors
+      Constant *Init = GV->hasInitializer() ?
+        GV->getInitializer() : UndefValue::get(Ptr->getType());
+      Constant *NewInit = TranslateInitForLoweredUDT(
+        Init, NewTy, &TypeSys);
+      // Replace with new GV, and rewrite vector load/store users
+      GlobalVariable *NewGV = new GlobalVariable(
+          M, NewTy, GV->isConstant(), GV->getLinkage(),
+          NewInit, GV->getName(), /*InsertBefore*/ GV,
+          GV->getThreadLocalMode(), AddrSpace);
+      NewPtr = NewGV;
+    } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
+      IRBuilder<> Builder(AI);
+      AllocaInst * NewAI = Builder.CreateAlloca(NewTy, nullptr, AI->getName());
+      NewPtr = NewAI;
+    } else {
+      DXASSERT(false, "Ptr must be global or alloca");
+    }
+    // This will rewrite vector load/store users
+    // and insert bitcasts for CallInst users
+    ReplaceUsesForLoweredUDT(Ptr, NewPtr);
+  }
+
+  // Rewrite the HLIntrinsic calls
+  for (auto it : FunctionUses) {
+    CallInst *CI = it.first;
+    HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+    if (group == HLOpcodeGroup::NotHL)
+      continue;
+    CallInst *newCI = RewriteIntrinsicCallForCastedArg(CI, it.second);
+    CI->replaceAllUsesWith(newCI);
+    CI->eraseFromParent();
+  }
+
+  return NewPtr;
+}
+
+
 // performScalarRepl - This algorithm is a simple worklist driven algorithm,
 // which runs on all of the alloca instructions in the entry block, removing
 // them if they are only used by getelementptr instructions.
@@ -866,6 +1061,15 @@ bool SROA_HLSL::performScalarRepl(Function &F, DxilTypeSystem &typeSys) {
       continue;
     }
 
+    if (Value *NewV = TranslatePtrIfUsedByLoweredFn(AI, typeSys)) {
+      if (NewV != AI) {
+        DXASSERT(AI->getNumUses() == 0, "must have zero users.");
+        AI->eraseFromParent();
+        Changed = true;
+      }
+      continue;
+    }
+
     // If the alloca looks like a good candidate for scalar replacement, and
     // if
     // all its users can be transformed, then split up the aggregate into its
@@ -1053,8 +1257,7 @@ void SROA_HLSL::isSafeForScalarRepl(Instruction *I, uint64_t Offset,
         IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(CI));
         if (IntrinsicOp::IOP_TraceRay == opcode ||
             IntrinsicOp::IOP_ReportHit == opcode ||
-            IntrinsicOp::IOP_CallShader == opcode ||
-            IntrinsicOp::IOP_DispatchMesh == opcode) {
+            IntrinsicOp::IOP_CallShader == opcode) {
           return MarkUnsafe(Info, User);
         }
       }
@@ -2588,13 +2791,6 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
         RewriteCallArg(CI, HLOperandIndex::kCallShaderPayloadOpIdx,
                        /*bIn*/ true, /*bOut*/ true);
       } break;
-      case IntrinsicOp::IOP_DispatchMesh: {
-        if (OldVal ==
-            CI->getArgOperand(HLOperandIndex::kDispatchMeshOpPayload)) {
-          RewriteCallArg(CI, HLOperandIndex::kDispatchMeshOpPayload,
-                         /*bIn*/ true, /*bOut*/ false);
-        }
-      } break;
       case IntrinsicOp::MOP_TraceRayInline: {
         if (OldVal ==
             CI->getArgOperand(HLOperandIndex::kTraceRayInlineRayDescOpIdx)) {
@@ -4068,10 +4264,20 @@ void SROA_Parameter_HLSL::flattenGlobal(GlobalVariable *GV) {
       bFlatVector = false;
 
     std::vector<Value *> Elts;
-    bool SROAed = SROA_Helper::DoScalarReplacement(
-        EltGV, Elts, Builder, bFlatVector,
-        // TODO: set precise.
-        /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts);
+    bool SROAed = false;
+    if (GlobalVariable *NewEltGV = dyn_cast_or_null<GlobalVariable>(
+        TranslatePtrIfUsedByLoweredFn(EltGV, dxilTypeSys))) {
+      if (GV != EltGV) {
+        EltGV->removeDeadConstantUsers();
+        EltGV->eraseFromParent();
+      }
+      EltGV = NewEltGV;
+    } else {
+      SROAed = SROA_Helper::DoScalarReplacement(
+          EltGV, Elts, Builder, bFlatVector,
+          // TODO: set precise.
+          /*hasPrecise*/ false, dxilTypeSys, DL, DeadInsts);
+    }
 
     if (SROAed) {
       // Push Elts into workList.
@@ -4722,6 +4928,19 @@ Value *SROA_Parameter_HLSL::castArgumentIfRequired(
   Module &M = *m_pHLModule->GetModule();
   IRBuilder<> AllocaBuilder(dxilutil::FindAllocaInsertionPt(Builder.GetInsertPoint()));
 
+  if (inputQual == DxilParamInputQual::InPayload) {
+    DXASSERT_NOMSG(isa<StructType>(Ty));
+    // Lower payload type here
+    StructType *LoweredTy = GetLoweredUDT(cast<StructType>(Ty));
+    if (LoweredTy != Ty) {
+      Value *Ptr = AllocaBuilder.CreateAlloca(LoweredTy);
+      ReplaceUsesForLoweredUDT(V, Ptr);
+      castParamMap[V] = std::make_pair(Ptr, inputQual);
+      V = Ptr;
+    }
+    return V;
+  }
+
   // Remove pointer for vector/scalar which is not out.
   if (V->getType()->isPointerTy() && !Ty->isAggregateType() && !bOut) {
     Value *Ptr = AllocaBuilder.CreateAlloca(Ty);
@@ -5419,6 +5638,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       bLoadOutputFromTemp = true;
     } else if (bLoad && bStore) {
       switch (qual) {
+      case DxilParamInputQual::InPayload:
       case DxilParamInputQual::InputPrimitive:
       case DxilParamInputQual::InputPatch:
       case DxilParamInputQual::OutputPatch: {

+ 62 - 0
tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload-matrix.hlsl

@@ -0,0 +1,62 @@
+// RUN: %dxc -E main -T as_6_5 %s | FileCheck %s
+
+// CHECK: define void @main
+
+struct MeshPayload
+{
+  int4 data;
+  bool2x2 mat;
+};
+
+struct GSStruct
+{
+  row_major bool2x2 mat;
+  int4 vecmat;
+  MeshPayload pld[2];
+};
+
+groupshared GSStruct gs[2];
+
+row_major bool2x2 row_mat_array[2];
+
+int i, j;
+
+[numthreads(4,1,1)]
+void main(uint gtid : SV_GroupIndex)
+{
+  // write to dynamic row/col
+  gs[j].pld[i].mat[gtid >> 1][gtid & 1] = (int)gtid;
+  gs[j].vecmat[gtid] = (int)gtid;
+
+  int2x2 mat = gs[j].pld[i].mat;
+  gs[j].pld[i].mat = (bool2x2)gs[j].vecmat;
+
+  // subscript + constant GEP for component
+  gs[j].pld[i].mat[1].x = mat[1].y;
+  mat[0].y = gs[j].pld[i].mat[0].x;
+
+  // dynamic subscript + constant component index
+  gs[j].pld[i].mat[gtid & 1].x = mat[gtid & 1].y;
+  mat[gtid & 1].y = gs[j].pld[i].mat[gtid & 1].x;
+
+  // dynamic subscript + GEP for component
+  gs[j].pld[i].mat[gtid & 1] = mat[gtid & 1].y;
+  mat[gtid & 1].y = gs[j].pld[i].mat[gtid & 1].x;
+
+  // subscript element
+  gs[j].pld[i].mat._m01_m10 = mat[1];
+  mat[0] = gs[j].pld[i].mat._m00_m11;
+
+  // dynamic index of subscript element vector
+  mat[0].x = gs[j].pld[i].mat._m00_m11_m10[gtid & 1];
+  gs[j].pld[i].mat._m11_m10[gtid & 1] = gtid;
+
+  // Dynamic index into vector
+  int idx = gs[j].vecmat.x;
+  gs[j].pld[i].mat[1][idx] = mat[1].y;
+  mat[0].y = gs[j].pld[i].mat[0][idx];
+  int2 vec = gs[j].mat[0];
+  int2 multiplied = mul(mat, vec);
+  gs[j].pld[i].data = multiplied.xyxy;
+  DispatchMesh(1,1,1,gs[j].pld[i]);
+}

+ 21 - 4
tools/clang/test/HLSLFileCheck/shader_targets/mesh/as-groupshared-payload.hlsl

@@ -1,17 +1,34 @@
 // RUN: %dxc -E amplification -T as_6_5 %s | FileCheck %s
 
+// Make sure we pass constant gep of groupshared mesh payload directly
+// in to DispatchMesh, with no alloca involved.
+
 // CHECK: define void @amplification
+// CHECK-NOT: alloca
+// CHECK-NOT: addrspacecast
+// CHECK-NOT: bitcast
+// CHECK: call void @dx.op.dispatchMesh.struct.MeshPayload{{[^ ]*}}(i32 173, i32 1, i32 1, i32 1, %struct.MeshPayload{{[^ ]*}} addrspace(3)* getelementptr inbounds (%struct.GSStruct{{[^ ]*}}, %struct.GSStruct{{[^ ]*}} addrspace(3)* @"\01?gs@@3UGSStruct@@A{{[^ ]*}}", i32 0, i32 1))
+// CHECK: ret void
 
 struct MeshPayload
 {
-  uint data[4];
+  uint4 data;
+};
+
+struct GSStruct
+{
+  uint i;
+  MeshPayload pld;
 };
 
-groupshared MeshPayload pld;
+groupshared GSStruct gs;
+GSStruct cb_gs;
 
 [numthreads(4,1,1)]
 void amplification(uint gtid : SV_GroupIndex)
 {
-  pld.data[gtid] = gtid;
-  DispatchMesh(1,1,1,pld);
+  gs = cb_gs;
+  gs.i = 1;
+  gs.pld.data[gtid] = gtid;
+  DispatchMesh(1,1,1,gs.pld);
 }

+ 81 - 0
tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh-payload-matrix.hlsl

@@ -0,0 +1,81 @@
+// RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
+
+// CHECK: %[[pld:[^ ]+]] = call %struct.MeshPayload{{[^ ]*}} @dx.op.getMeshPayload.struct.MeshPayload{{.*}}(i32 170)
+// CHECK: call void @dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
+// CHECK: call void @dx.op.emitIndices
+
+// Verify bool translated from mem type
+// CHECK: %[[ppld0:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 0
+// CHECK: %[[pld0:[^ ]+]] = load i32, i32* %[[ppld0]], align 4
+// CHECK: %[[ppld1:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 1
+// CHECK: %[[pld1:[^ ]+]] = load i32, i32* %[[ppld1]], align 4
+// CHECK: %[[ppld2:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 2
+// CHECK: %[[pld2:[^ ]+]] = load i32, i32* %[[ppld2]], align 4
+// CHECK: %[[ppld3:[^ ]+]] = getelementptr inbounds %struct.MeshPayload{{[^ ]*}}, %struct.MeshPayload{{[^ ]*}}* %[[pld]], i32 0, i32 2, i32 3
+// CHECK: %[[pld3:[^ ]+]] = load i32, i32* %[[ppld3]], align 4
+// Inner components reversed due to column_major
+// CHECK: icmp ne i32 %[[pld0]], 0
+// CHECK: icmp ne i32 %[[pld2]], 0
+// CHECK: icmp ne i32 %[[pld1]], 0
+// CHECK: icmp ne i32 %[[pld3]], 0
+
+// CHECK: call void @dx.op.storePrimitiveOutput
+// CHECK: call void @dx.op.storeVertexOutput
+
+// CHECK: ret void
+
+#define MAX_VERT 32
+#define MAX_PRIM 16
+#define NUM_THREADS 32
+struct MeshPerVertex {
+    float4 position : SV_Position;
+    float color[4] : COLOR;
+};
+
+struct MeshPerPrimitive {
+    float normal : NORMAL;
+};
+
+struct MeshPayload {
+    float normal;
+    int4 data;
+    bool2x2 mat;
+};
+
+groupshared float gsMem[MAX_PRIM];
+
+[numthreads(NUM_THREADS, 1, 1)]
+[outputtopology("triangle")]
+void main(
+            out indices uint3 primIndices[MAX_PRIM],
+            out vertices MeshPerVertex verts[MAX_VERT],
+            out primitives MeshPerPrimitive prims[MAX_PRIM],
+            in payload MeshPayload mpl,
+            in uint tig : SV_GroupIndex,
+            in uint vid : SV_ViewID
+         )
+{
+    SetMeshOutputCounts(MAX_VERT, MAX_PRIM);
+    MeshPerVertex ov;
+    if (vid % 2) {
+        ov.position = float4(4.0,5.0,6.0,7.0);
+        ov.color[0] = 4.0;
+        ov.color[1] = 5.0;
+        ov.color[2] = 6.0;
+        ov.color[3] = 7.0;
+    } else {
+        ov.position = float4(14.0,15.0,16.0,17.0);
+        ov.color[0] = 14.0;
+        ov.color[1] = 15.0;
+        ov.color[2] = 16.0;
+        ov.color[3] = 17.0;
+    }
+    if (tig % 3) {
+        primIndices[tig / 3] = uint3(tig, tig + 1, tig + 2);
+        MeshPerPrimitive op;
+        op.normal = dot(mpl.normal.xx, mul(mpl.data.xy, mpl.mat));
+        gsMem[tig / 3] = op.normal;
+        prims[tig / 3] = op;
+    }
+    verts[tig] = ov;
+}

+ 4 - 4
tools/clang/test/HLSLFileCheck/shader_targets/mesh/mesh.hlsl

@@ -1,10 +1,10 @@
 // RUN: %dxc -E main -T ms_6_5 %s | FileCheck %s
 
-// CHECK: dx.op.getMeshPayload.struct.MeshPayload
+// CHECK: dx.op.getMeshPayload.struct.MeshPayload(i32 170)
 // CHECK: dx.op.setMeshOutputCounts(i32 168, i32 32, i32 16)
-// CHECK: dx.op.emitIndices
-// CHECK: dx.op.storeVertexOutput
-// CHECK: dx.op.storePrimitiveOutput
+// CHECK: dx.op.emitIndices(i32 169,
+// CHECK: dx.op.storePrimitiveOutput.f32(i32 172,
+// CHECK: dx.op.storeVertexOutput.f32(i32 171,
 // CHECK: !"cullPrimitive", i32 3, i32 100, i32 4, !"SV_CullPrimitive", i32 7, i32 1}
 
 #define MAX_VERT 32