Browse Source

Add load for argument matrix subscript lowering (#3065)

Matrix lowering was not handling subscript oeprations when the matrix
was a shader input. copying the matrix to a temporary explicitly or
implicitly as part of a copy in function call worked around the problem.
When the argument copy was eliminated, this problem was exposed.

By adding a load for the indicated matrix for a subscript of a shader
input matrix, the lowering can procede using the lowered matrix from the
load.

Additionally, hull shaders with their uninlineable functions were
failing to detect the constanthull function as a graphics function so
the subscript lowering was being treated as if it were in a library.
Even when it got to signature lowering, there was no support for
lowering matrix loads.

The test for shaderism now tests the function attributes of the module's
entry function. The signature lowering for matrix loads is moved into a
function that is called by input lowering for both the constant hull and
entry functions.

Added general tests for matrix subscripts from different memory types as
well as specific tests for the argument passing problem in pixel and
domain shaders.

A more correct way to identify functions that should delay the lowering
of their matrix parameters to signature lowering time is to query
whether the function has a signature. This covers entry functions for
graphics shaders and also constant patch functions.

Fixes #2958
Greg Roth 5 năm trước cách đây
mục cha
commit
03676eb639

+ 1 - 1
lib/HLSL/HLLowerUDT.cpp

@@ -408,7 +408,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
 
         std::vector<Instruction*> DeadInsts;
         HLMatrixSubscriptUseReplacer UseReplacer(
-          CI, NewV, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
+          CI, NewV, /*TempLoweredMatrix*/nullptr, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
         DXASSERT(CI->use_empty(),
                  "Expected all matrix subscript uses to have been replaced.");
         CI->eraseFromParent();

+ 21 - 6
lib/HLSL/HLMatrixLowerPass.cpp

@@ -406,7 +406,7 @@ Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Build
     RootPtr = GEP->getPointerOperand();
 
   Argument *Arg = dyn_cast<Argument>(RootPtr);
-  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
+  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsEntryThatUsesSignatures(Arg->getParent());
   if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
     // Bitcast the matrix pointer to its lowered equivalent.
     // The HLMatrixBitcast pass will take care of this later.
@@ -1474,17 +1474,32 @@ void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, Small
 
   IRBuilder<> CallBuilder(Call);
   Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
-  if (LoweredPtr == nullptr) return;
+  Value *LoweredMatrix = nullptr;
+  Value *RootPtr = LoweredPtr? LoweredPtr: MatPtr;
+  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
+    RootPtr = GEP->getPointerOperand();
 
+  if (LoweredPtr == nullptr) {
+    if (!isa<Argument>(RootPtr))
+      return;
+
+    // For a shader input, load the matrix into a lowered ptr
+    // The load will be handled by LowerSignature
+    HLMatLoadStoreOpcode Opcode = (HLSubscriptOpcode)GetHLOpcode(Call) == HLSubscriptOpcode::RowMatSubscript ?
+                                   HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
+    HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
+    LoweredMatrix = callHLFunction(
+      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
+      MatTy.getLoweredVectorTypeForReg(), { CallBuilder.getInt32((uint32_t)Opcode), MatPtr },
+      Call->getCalledFunction()->getAttributes().getFnAttributes(), CallBuilder);
+  }
   // For global variables, we can GEP directly into the lowered vector pointer.
   // This is necessary to support group shared memory atomics and the likes.
-  Value *RootPtr = LoweredPtr;
-  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
-    RootPtr = GEP->getPointerOperand();
   bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
   
   // Just constructing this does all the work
-  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
+  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, LoweredMatrix,
+                                           ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
 
   DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
   addToDeadInsts(Call);

+ 12 - 6
lib/HLSL/HLMatrixSubscriptUseReplacer.cpp

@@ -19,9 +19,10 @@
 using namespace llvm;
 using namespace hlsl;
 
-HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value *LoweredPtr,
+HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value *LoweredPtr, Value *TempLoweredMatrix,
   SmallVectorImpl<Value*> &ElemIndices, bool AllowLoweredPtrGEPs, std::vector<Instruction*> &DeadInsts)
-  : LoweredPtr(LoweredPtr), ElemIndices(ElemIndices), DeadInsts(DeadInsts), AllowLoweredPtrGEPs(AllowLoweredPtrGEPs)
+  : LoweredPtr(LoweredPtr), ElemIndices(ElemIndices), DeadInsts(DeadInsts),
+    AllowLoweredPtrGEPs(AllowLoweredPtrGEPs), TempLoweredMatrix(TempLoweredMatrix)
 {
   HasScalarResult = !Call->getType()->getPointerElementType()->isVectorTy();
 
@@ -32,6 +33,11 @@ HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value
     }
   }
 
+  if (TempLoweredMatrix)
+    LoweredTy = TempLoweredMatrix->getType();
+  else
+    LoweredTy = LoweredPtr->getType()->getPointerElementType();
+
   replaceUses(Call, /* GEPIdx */ nullptr);
 }
 
@@ -162,7 +168,8 @@ void HLMatrixSubscriptUseReplacer::cacheLoweredMatrix(bool ForDynamicIndexing, I
 
   // Load without memory to register representation conversion,
   // since the point is to mimic pointer semantics
-  TempLoweredMatrix = Builder.CreateLoad(LoweredPtr);
+  if (!TempLoweredMatrix)
+    TempLoweredMatrix = Builder.CreateLoad(LoweredPtr);
 
   if (!ForDynamicIndexing) return;
 
@@ -238,7 +245,6 @@ Value *HLMatrixSubscriptUseReplacer::loadVector(IRBuilder<> &Builder) {
 
   // Otherwise load elements one by one
   // Lowered form may be array when AllowLoweredPtrGEPs == true.
-  Type* LoweredTy = LoweredPtr->getType()->getPointerElementType();
   Type* ElemTy = LoweredTy->isVectorTy() ? LoweredTy->getScalarType() :
               cast<ArrayType>(LoweredTy)->getArrayElementType();
   VectorType *VecTy = VectorType::get(ElemTy, static_cast<unsigned>(ElemIndices.size()));
@@ -270,7 +276,7 @@ void HLMatrixSubscriptUseReplacer::flushLoweredMatrix(IRBuilder<> &Builder) {
     // First re-create the vector from the temporary array
     DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
 
-    VectorType *LoweredMatrixTy = cast<VectorType>(LoweredPtr->getType()->getPointerElementType());
+    VectorType *LoweredMatrixTy = cast<VectorType>(LoweredTy);
     TempLoweredMatrix = UndefValue::get(LoweredMatrixTy);
     Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
     for (unsigned ElemIdx = 0; ElemIdx < LoweredMatrixTy->getNumElements(); ++ElemIdx) {
@@ -284,4 +290,4 @@ void HLMatrixSubscriptUseReplacer::flushLoweredMatrix(IRBuilder<> &Builder) {
   // Store back the lowered matrix to its pointer
   Builder.CreateStore(TempLoweredMatrix, LoweredPtr);
   TempLoweredMatrix = nullptr;
-}
+}

+ 3 - 2
lib/HLSL/HLMatrixSubscriptUseReplacer.h

@@ -30,7 +30,7 @@ namespace hlsl {
 class HLMatrixSubscriptUseReplacer {
 public:
   // The constructor does everything
-  HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr,
+  HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr, llvm::Value *TempLoweredMatrix,
     llvm::SmallVectorImpl<llvm::Value*> &ElemIndices, bool AllowLoweredPtrGEPs,
     std::vector<llvm::Instruction*> &DeadInsts);
 
@@ -51,6 +51,7 @@ private:
   bool AllowLoweredPtrGEPs = false;
   bool HasScalarResult = false;
   bool HasDynamicElemIndex = false;
+  llvm::Type *LoweredTy = nullptr;
 
   // The entire lowered matrix as loaded from LoweredPtr,
   // nullptr if we copied it to a temporary array.
@@ -64,4 +65,4 @@ private:
   // so we can dynamically index the level 1 indices.
   llvm::AllocaInst *LazyTempElemIndicesArrayAlloca = nullptr;
 };
-} // namespace hlsl
+} // namespace hlsl

+ 112 - 76
lib/HLSL/HLSignatureLower.cpp

@@ -588,6 +588,7 @@ Value *GenerateLdInput(Function *loadInput, ArrayRef<Value *> args,
   }
 }
 
+
 Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
                             unsigned cols, MutableArrayRef<Value *> args,
                             bool bCast) {
@@ -654,6 +655,96 @@ Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
   }
 }
 
+
+void replaceMatStWithStOutputs(CallInst *CI, HLMatLoadStoreOpcode matOp,
+                               Function *ldStFunc, Constant *OpArg, Constant *ID,
+                               Constant *columnConsts[],Value *vertexOrPrimID,
+                               Value *idxVal) {
+  IRBuilder<> LocalBuilder(CI);
+  Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+  HLMatrixType MatTy = HLMatrixType::cast(
+                                          CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
+                                          ->getType()->getPointerElementType());
+
+  Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
+
+  if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      Constant *constColIdx = LocalBuilder.getInt32(c);
+      Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
+
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+        unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
+        Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
+        LocalBuilder.CreateCall(ldStFunc,
+                                { OpArg, ID, colIdx, columnConsts[r], Elt });
+      }
+    }
+  } else {
+    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+      Constant *constRowIdx = LocalBuilder.getInt32(r);
+      Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+        unsigned matIdx = MatTy.getRowMajorIndex(r, c);
+        Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
+        LocalBuilder.CreateCall(ldStFunc,
+                                { OpArg, ID, rowIdx, columnConsts[c], Elt });
+      }
+    }
+  }
+  CI->eraseFromParent();
+}
+
+
+void replaceMatLdWithLdInputs(CallInst *CI, HLMatLoadStoreOpcode matOp,
+                              Function *ldStFunc, Constant *OpArg, Constant *ID,
+                              Constant *columnConsts[],Value *vertexOrPrimID,
+                              Value *idxVal) {
+  IRBuilder<> LocalBuilder(CI);
+  HLMatrixType MatTy = HLMatrixType::cast(
+                                          CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+                                          ->getType()->getPointerElementType());
+  std::vector<Value *> matElts(MatTy.getNumElements());
+
+  if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
+    for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+      Constant *constRowIdx = LocalBuilder.getInt32(c);
+      Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
+      for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+        SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
+        if (vertexOrPrimID)
+          args.emplace_back(vertexOrPrimID);
+
+        Value *input = LocalBuilder.CreateCall(ldStFunc, args);
+        unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
+        matElts[matIdx] = input;
+      }
+    }
+  } else {
+    for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
+      Constant *constRowIdx = LocalBuilder.getInt32(r);
+      Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
+      for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
+        SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
+        if (vertexOrPrimID)
+          args.emplace_back(vertexOrPrimID);
+
+        Value *input = LocalBuilder.CreateCall(ldStFunc, args);
+        unsigned matIdx = MatTy.getRowMajorIndex(r, c);
+        matElts[matIdx] = input;
+      }
+    }
+  }
+
+  Value *newVec =
+    HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
+  newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
+
+  CI->replaceAllUsesWith(newVec);
+  CI->eraseFromParent();
+}
+
+
 void replaceDirectInputParameter(Value *param, Function *loadInput,
                                  unsigned cols, MutableArrayRef<Value *> args,
                                  bool bCast, OP *hlslOP, IRBuilder<> &Builder) {
@@ -964,84 +1055,11 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     switch (matOp) {
     case HLMatLoadStoreOpcode::ColMatLoad:
     case HLMatLoadStoreOpcode::RowMatLoad: {
-      IRBuilder<> LocalBuilder(CI);
-      HLMatrixType MatTy = HLMatrixType::cast(
-        CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-          ->getType()->getPointerElementType());
-      std::vector<Value *> matElts(MatTy.getNumElements());
-
-      if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
-        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-          Constant *constRowIdx = LocalBuilder.getInt32(c);
-          Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-            SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
-            if (vertexOrPrimID)
-              args.emplace_back(vertexOrPrimID);
-
-            Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
-            matElts[matIdx] = input;
-          }
-        }
-      } else {
-        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-          Constant *constRowIdx = LocalBuilder.getInt32(r);
-          Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-            SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
-            if (vertexOrPrimID)
-              args.emplace_back(vertexOrPrimID);
-
-            Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
-            matElts[matIdx] = input;
-          }
-        }
-      }
-
-      Value *newVec =
-          HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
-      newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
-
-      CI->replaceAllUsesWith(newVec);
-      CI->eraseFromParent();
+      replaceMatLdWithLdInputs(CI, matOp, ldStFunc, OpArg, ID, columnConsts, vertexOrPrimID, idxVal);
     } break;
     case HLMatLoadStoreOpcode::ColMatStore:
     case HLMatLoadStoreOpcode::RowMatStore: {
-      IRBuilder<> LocalBuilder(CI);
-      Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      HLMatrixType MatTy = HLMatrixType::cast(
-        CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
-          ->getType()->getPointerElementType());
-
-      Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
-
-      if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
-        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-          Constant *constColIdx = LocalBuilder.getInt32(c);
-          Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
-
-          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
-            Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
-            LocalBuilder.CreateCall(ldStFunc,
-              { OpArg, ID, colIdx, columnConsts[r], Elt });
-          }
-        }
-      } else {
-        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-          Constant *constRowIdx = LocalBuilder.getInt32(r);
-          Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
-            Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
-            LocalBuilder.CreateCall(ldStFunc,
-              { OpArg, ID, rowIdx, columnConsts[c], Elt });
-          }
-        }
-      }
-      CI->eraseFromParent();
+      replaceMatStWithStOutputs(CI, matOp, ldStFunc, OpArg, ID, columnConsts, vertexOrPrimID, idxVal);
     } break;
     }
   } else {
@@ -1386,6 +1404,14 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
   Type *i1Ty = Type::getInt1Ty(constZero->getContext());
   Type *i32Ty = constZero->getType();
 
+  Constant *columnConsts[] = {
+      hlslOP->GetU8Const(0),  hlslOP->GetU8Const(1),  hlslOP->GetU8Const(2),
+      hlslOP->GetU8Const(3),  hlslOP->GetU8Const(4),  hlslOP->GetU8Const(5),
+      hlslOP->GetU8Const(6),  hlslOP->GetU8Const(7),  hlslOP->GetU8Const(8),
+      hlslOP->GetU8Const(9),  hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
+      hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
+      hlslOP->GetU8Const(15)};
+
   for (Argument &arg : patchConstantFunc->args()) {
     DxilParameterAnnotation &paramAnnotation =
         patchFuncAnnotation->GetParameterAnnotation(arg.getArgNo());
@@ -1422,11 +1448,21 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
       collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
                                    /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
       for (InputOutputAccessInfo &info : accessInfoList) {
+        Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
-          Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
           Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
                            info.vertexOrPrimID};
           replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
+        } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
+          HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+          // Intrinsic will be translated later.
+          if (group == HLOpcodeGroup::HLIntrinsic || group == HLOpcodeGroup::NotHL)
+            return;
+          unsigned opcode = GetHLOpcode(CI);
+          DXASSERT_NOMSG(group == HLOpcodeGroup::HLMatLoadStore);
+          HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
+          if (matOp == HLMatLoadStoreOpcode::ColMatLoad || matOp == HLMatLoadStoreOpcode::RowMatLoad)
+            replaceMatLdWithLdInputs(CI, matOp, dxilLdFunc, OpArg, inputID, columnConsts, info.vertexOrPrimID, info.idx);
         } else {
           DXASSERT(0, "input should only be ld");
         }

+ 69 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/matrix_subscript.hlsl

@@ -0,0 +1,69 @@
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T lib_6_3 %s | FileCheck %s
+
+// Test for general subscript operations on matrix arrays.
+// Specifically focused on shader inputs which failed to lower previously
+
+float3 GetRow(const float3x3 m, const int j)
+{
+  return m[j];
+}
+
+float3x3 g[2];
+groupshared float3x3 gs[2];
+
+struct JustMtx {
+  float3x3 mtx;
+};
+
+struct MtxArray {
+  float3x3 mtx[2];
+};
+
+[shader("pixel")]
+float3 main(const int i : I, const int j : J, const float3x3 m[2]: M, JustMtx jm[2] : JM, MtxArray ma : A) : SV_Target
+{
+  float3 ret = 0.0;
+
+  // CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+  // CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+  // CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+  ret += g[MIDX][VIDX];
+
+  // CHECK: load float, float addrspace(3)*
+  // CHECK: load float, float addrspace(3)*
+  // CHECK: load float, float addrspace(3)*
+  ret += gs[MIDX][VIDX];
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  ret += m[MIDX][VIDX];
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  ret += jm[MIDX].mtx[VIDX];
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 4, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 4, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 4, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  ret += ma.mtx[MIDX][VIDX];
+
+  ret += GetRow(g[MIDX], VIDX);
+  ret += GetRow(gs[MIDX], VIDX);
+  ret += GetRow(m[MIDX], VIDX);
+  ret += GetRow(jm[MIDX].mtx, VIDX);
+  ret += GetRow(ma.mtx[MIDX], VIDX);
+
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %{{.*}})
+  return ret;
+}

+ 40 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/matrix_subscript_ds.hlsl

@@ -0,0 +1,40 @@
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T lib_6_3 %s | FileCheck %s
+
+// Specific test for subscript operation on OutputPatch matrix data
+
+struct MatStruct {
+ float4x4 mtx : M;
+};
+
+float4 GetRow(const OutputPatch<MatStruct, 3> tri, int i, int j)
+{
+  return tri[MIDX].mtx[VIDX];
+}
+
+[domain("tri")]
+[shader("domain")]
+float4 main(int i : I, int j : J, const OutputPatch<MatStruct, 3> tri) : SV_Position {
+
+  float4 ret = 0;
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  ret += tri[MIDX].mtx[VIDX];
+
+  ret += GetRow(tri, MIDX, VIDX);
+
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %{{.*}})
+  return ret;
+}

+ 62 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/matrix_subscript_hs.hlsl

@@ -0,0 +1,62 @@
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T lib_6_3 %s | FileCheck %s
+
+// Specific test for subscript operation on matrix array inputs in patch functions
+
+struct MatStruct {
+  int2 uv : TEXCOORD0;
+  float3x4  m_ObjectToWorld : TEXCOORD1;
+};
+
+struct Output {
+  float edges[3] : SV_TessFactor;
+  float inside : SV_InsideTessFactor;
+};
+
+// Instruction order here is a bit inconsistent.
+// So we can't test for all the outputs
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call void @dx.op.storePatchConstant.f32
+// CHECK: call void @dx.op.storePatchConstant.f32
+Output Patch(InputPatch<MatStruct, 3> inputs)
+{
+  Output ret;
+  int i = inputs[0].uv.x;
+  int j = inputs[0].uv.y;
+
+  ret.edges[0] = inputs[MIDX].m_ObjectToWorld[VIDX][0];
+  ret.edges[1] = inputs[MIDX].m_ObjectToWorld[VIDX][1];
+  ret.edges[2] = inputs[MIDX].m_ObjectToWorld[VIDX][2];
+  ret.inside = 1.0f;
+  return ret;
+}
+
+
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+[domain("tri")]
+[partitioning("fractional_odd")]
+[outputtopology("triangle_cw")]
+[patchconstantfunc("Patch")]
+[outputcontrolpoints(3)]
+[shader("hull")]
+float4 main(InputPatch<MatStruct, 3> inputs) : SV_Position
+{
+  int i = inputs[0].uv.x;
+  int j = inputs[0].uv.y;
+  return inputs[MIDX].m_ObjectToWorld[VIDX];
+}