Browse Source

Add load for argument matrix subscript lowering (#3065)

Matrix lowering was not handling subscript operations when the matrix
was a shader input. Copying the matrix to a temporary, explicitly or
implicitly as part of a copy in a function call, worked around the problem.
When the argument copy was eliminated, this problem was exposed.

By adding a load for the indicated matrix for a subscript of a shader
input matrix, the lowering can proceed using the lowered matrix from the
load.

Additionally, hull shaders with their uninlineable functions were
failing to detect the constant hull function as a graphics function, so
the subscript lowering was being treated as if it were in a library.
Even when it got to signature lowering, there was no support for
lowering matrix loads.

The test for shaderism now tests the function attributes of the module's
entry function. The signature lowering for matrix loads is moved into a
function that is called by input lowering for both the constant hull and
entry functions.

Added general tests for matrix subscripts from different memory types as
well as specific tests for the argument passing problem in pixel and
domain shaders.

A more correct way to identify functions that should delay the lowering
of their matrix parameters to signature lowering time is to query
whether the function has a signature. This covers entry functions for
graphics shaders and also constant patch functions.

Fixes #2958
Greg Roth 5 years ago
parent
commit
03676eb639

+ 1 - 1
lib/HLSL/HLLowerUDT.cpp

@@ -408,7 +408,7 @@ void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
 
 
         std::vector<Instruction*> DeadInsts;
         std::vector<Instruction*> DeadInsts;
         HLMatrixSubscriptUseReplacer UseReplacer(
         HLMatrixSubscriptUseReplacer UseReplacer(
-          CI, NewV, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
+          CI, NewV, /*TempLoweredMatrix*/nullptr, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
         DXASSERT(CI->use_empty(),
         DXASSERT(CI->use_empty(),
                  "Expected all matrix subscript uses to have been replaced.");
                  "Expected all matrix subscript uses to have been replaced.");
         CI->eraseFromParent();
         CI->eraseFromParent();

+ 21 - 6
lib/HLSL/HLMatrixLowerPass.cpp

@@ -406,7 +406,7 @@ Value *HLMatrixLowerPass::tryGetLoweredPtrOperand(Value *Ptr, IRBuilder<> &Build
     RootPtr = GEP->getPointerOperand();
     RootPtr = GEP->getPointerOperand();
 
 
   Argument *Arg = dyn_cast<Argument>(RootPtr);
   Argument *Arg = dyn_cast<Argument>(RootPtr);
-  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsGraphicsShader(Arg->getParent());
+  bool IsNonShaderArg = Arg != nullptr && !m_pHLModule->IsEntryThatUsesSignatures(Arg->getParent());
   if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
   if (IsNonShaderArg || isa<AllocaInst>(RootPtr)) {
     // Bitcast the matrix pointer to its lowered equivalent.
     // Bitcast the matrix pointer to its lowered equivalent.
     // The HLMatrixBitcast pass will take care of this later.
     // The HLMatrixBitcast pass will take care of this later.
@@ -1474,17 +1474,32 @@ void HLMatrixLowerPass::lowerHLMatSubscript(CallInst *Call, Value *MatPtr, Small
 
 
   IRBuilder<> CallBuilder(Call);
   IRBuilder<> CallBuilder(Call);
   Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
   Value *LoweredPtr = tryGetLoweredPtrOperand(MatPtr, CallBuilder);
-  if (LoweredPtr == nullptr) return;
+  Value *LoweredMatrix = nullptr;
+  Value *RootPtr = LoweredPtr? LoweredPtr: MatPtr;
+  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
+    RootPtr = GEP->getPointerOperand();
 
 
+  if (LoweredPtr == nullptr) {
+    if (!isa<Argument>(RootPtr))
+      return;
+
+    // For a shader input, load the matrix into a lowered ptr
+    // The load will be handled by LowerSignature
+    HLMatLoadStoreOpcode Opcode = (HLSubscriptOpcode)GetHLOpcode(Call) == HLSubscriptOpcode::RowMatSubscript ?
+                                   HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
+    HLMatrixType MatTy = HLMatrixType::cast(MatPtr->getType()->getPointerElementType());
+    LoweredMatrix = callHLFunction(
+      *m_pModule, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(Opcode),
+      MatTy.getLoweredVectorTypeForReg(), { CallBuilder.getInt32((uint32_t)Opcode), MatPtr },
+      Call->getCalledFunction()->getAttributes().getFnAttributes(), CallBuilder);
+  }
   // For global variables, we can GEP directly into the lowered vector pointer.
   // For global variables, we can GEP directly into the lowered vector pointer.
   // This is necessary to support group shared memory atomics and the likes.
   // This is necessary to support group shared memory atomics and the likes.
-  Value *RootPtr = LoweredPtr;
-  while (GEPOperator *GEP = dyn_cast<GEPOperator>(RootPtr))
-    RootPtr = GEP->getPointerOperand();
   bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
   bool AllowLoweredPtrGEPs = isa<GlobalVariable>(RootPtr);
   
   
   // Just constructing this does all the work
   // Just constructing this does all the work
-  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
+  HLMatrixSubscriptUseReplacer UseReplacer(Call, LoweredPtr, LoweredMatrix,
+                                           ElemIndices, AllowLoweredPtrGEPs, m_deadInsts);
 
 
   DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
   DXASSERT(Call->use_empty(), "Expected all matrix subscript uses to have been replaced.");
   addToDeadInsts(Call);
   addToDeadInsts(Call);

+ 12 - 6
lib/HLSL/HLMatrixSubscriptUseReplacer.cpp

@@ -19,9 +19,10 @@
 using namespace llvm;
 using namespace llvm;
 using namespace hlsl;
 using namespace hlsl;
 
 
-HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value *LoweredPtr,
+HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value *LoweredPtr, Value *TempLoweredMatrix,
   SmallVectorImpl<Value*> &ElemIndices, bool AllowLoweredPtrGEPs, std::vector<Instruction*> &DeadInsts)
   SmallVectorImpl<Value*> &ElemIndices, bool AllowLoweredPtrGEPs, std::vector<Instruction*> &DeadInsts)
-  : LoweredPtr(LoweredPtr), ElemIndices(ElemIndices), DeadInsts(DeadInsts), AllowLoweredPtrGEPs(AllowLoweredPtrGEPs)
+  : LoweredPtr(LoweredPtr), ElemIndices(ElemIndices), DeadInsts(DeadInsts),
+    AllowLoweredPtrGEPs(AllowLoweredPtrGEPs), TempLoweredMatrix(TempLoweredMatrix)
 {
 {
   HasScalarResult = !Call->getType()->getPointerElementType()->isVectorTy();
   HasScalarResult = !Call->getType()->getPointerElementType()->isVectorTy();
 
 
@@ -32,6 +33,11 @@ HLMatrixSubscriptUseReplacer::HLMatrixSubscriptUseReplacer(CallInst* Call, Value
     }
     }
   }
   }
 
 
+  if (TempLoweredMatrix)
+    LoweredTy = TempLoweredMatrix->getType();
+  else
+    LoweredTy = LoweredPtr->getType()->getPointerElementType();
+
   replaceUses(Call, /* GEPIdx */ nullptr);
   replaceUses(Call, /* GEPIdx */ nullptr);
 }
 }
 
 
@@ -162,7 +168,8 @@ void HLMatrixSubscriptUseReplacer::cacheLoweredMatrix(bool ForDynamicIndexing, I
 
 
   // Load without memory to register representation conversion,
   // Load without memory to register representation conversion,
   // since the point is to mimic pointer semantics
   // since the point is to mimic pointer semantics
-  TempLoweredMatrix = Builder.CreateLoad(LoweredPtr);
+  if (!TempLoweredMatrix)
+    TempLoweredMatrix = Builder.CreateLoad(LoweredPtr);
 
 
   if (!ForDynamicIndexing) return;
   if (!ForDynamicIndexing) return;
 
 
@@ -238,7 +245,6 @@ Value *HLMatrixSubscriptUseReplacer::loadVector(IRBuilder<> &Builder) {
 
 
   // Otherwise load elements one by one
   // Otherwise load elements one by one
   // Lowered form may be array when AllowLoweredPtrGEPs == true.
   // Lowered form may be array when AllowLoweredPtrGEPs == true.
-  Type* LoweredTy = LoweredPtr->getType()->getPointerElementType();
   Type* ElemTy = LoweredTy->isVectorTy() ? LoweredTy->getScalarType() :
   Type* ElemTy = LoweredTy->isVectorTy() ? LoweredTy->getScalarType() :
               cast<ArrayType>(LoweredTy)->getArrayElementType();
               cast<ArrayType>(LoweredTy)->getArrayElementType();
   VectorType *VecTy = VectorType::get(ElemTy, static_cast<unsigned>(ElemIndices.size()));
   VectorType *VecTy = VectorType::get(ElemTy, static_cast<unsigned>(ElemIndices.size()));
@@ -270,7 +276,7 @@ void HLMatrixSubscriptUseReplacer::flushLoweredMatrix(IRBuilder<> &Builder) {
     // First re-create the vector from the temporary array
     // First re-create the vector from the temporary array
     DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
     DXASSERT_NOMSG(LazyTempElemArrayAlloca != nullptr);
 
 
-    VectorType *LoweredMatrixTy = cast<VectorType>(LoweredPtr->getType()->getPointerElementType());
+    VectorType *LoweredMatrixTy = cast<VectorType>(LoweredTy);
     TempLoweredMatrix = UndefValue::get(LoweredMatrixTy);
     TempLoweredMatrix = UndefValue::get(LoweredMatrixTy);
     Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
     Value *GEPIndices[2] = { Builder.getInt32(0), nullptr };
     for (unsigned ElemIdx = 0; ElemIdx < LoweredMatrixTy->getNumElements(); ++ElemIdx) {
     for (unsigned ElemIdx = 0; ElemIdx < LoweredMatrixTy->getNumElements(); ++ElemIdx) {
@@ -284,4 +290,4 @@ void HLMatrixSubscriptUseReplacer::flushLoweredMatrix(IRBuilder<> &Builder) {
   // Store back the lowered matrix to its pointer
   // Store back the lowered matrix to its pointer
   Builder.CreateStore(TempLoweredMatrix, LoweredPtr);
   Builder.CreateStore(TempLoweredMatrix, LoweredPtr);
   TempLoweredMatrix = nullptr;
   TempLoweredMatrix = nullptr;
-}
+}

+ 3 - 2
lib/HLSL/HLMatrixSubscriptUseReplacer.h

@@ -30,7 +30,7 @@ namespace hlsl {
 class HLMatrixSubscriptUseReplacer {
 class HLMatrixSubscriptUseReplacer {
 public:
 public:
   // The constructor does everything
   // The constructor does everything
-  HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr,
+  HLMatrixSubscriptUseReplacer(llvm::CallInst* Call, llvm::Value *LoweredPtr, llvm::Value *TempLoweredMatrix,
     llvm::SmallVectorImpl<llvm::Value*> &ElemIndices, bool AllowLoweredPtrGEPs,
     llvm::SmallVectorImpl<llvm::Value*> &ElemIndices, bool AllowLoweredPtrGEPs,
     std::vector<llvm::Instruction*> &DeadInsts);
     std::vector<llvm::Instruction*> &DeadInsts);
 
 
@@ -51,6 +51,7 @@ private:
   bool AllowLoweredPtrGEPs = false;
   bool AllowLoweredPtrGEPs = false;
   bool HasScalarResult = false;
   bool HasScalarResult = false;
   bool HasDynamicElemIndex = false;
   bool HasDynamicElemIndex = false;
+  llvm::Type *LoweredTy = nullptr;
 
 
   // The entire lowered matrix as loaded from LoweredPtr,
   // The entire lowered matrix as loaded from LoweredPtr,
   // nullptr if we copied it to a temporary array.
   // nullptr if we copied it to a temporary array.
@@ -64,4 +65,4 @@ private:
   // so we can dynamically index the level 1 indices.
   // so we can dynamically index the level 1 indices.
   llvm::AllocaInst *LazyTempElemIndicesArrayAlloca = nullptr;
   llvm::AllocaInst *LazyTempElemIndicesArrayAlloca = nullptr;
 };
 };
-} // namespace hlsl
+} // namespace hlsl

+ 112 - 76
lib/HLSL/HLSignatureLower.cpp

@@ -588,6 +588,7 @@ Value *GenerateLdInput(Function *loadInput, ArrayRef<Value *> args,
   }
   }
 }
 }
 
 
+
 Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
 Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
                             unsigned cols, MutableArrayRef<Value *> args,
                             unsigned cols, MutableArrayRef<Value *> args,
                             bool bCast) {
                             bool bCast) {
@@ -654,6 +655,96 @@ Value *replaceLdWithLdInput(Function *loadInput, LoadInst *ldInst,
   }
   }
 }
 }
 
 
+
+// Expands an HL matrix store call (CI) into one ldStFunc call per matrix
+// element, inserted before CI, and then erases CI.
+//   matOp          - ColMatStore walks the matrix column by column; every
+//                    other opcode takes the row-by-row path.
+//   ldStFunc       - callee invoked as (OpArg, ID, index, column, element);
+//                    presumably the DXIL store-output op -- confirm at callers.
+//   columnConsts   - constants indexed by the inner loop counter and passed as
+//                    the fourth call argument.
+//   vertexOrPrimID - accepted but not used by this function.
+//   idxVal         - base index; the outer loop counter is added to it to form
+//                    the third call argument.
+void replaceMatStWithStOutputs(CallInst *CI, HLMatLoadStoreOpcode matOp,
+                               Function *ldStFunc, Constant *OpArg, Constant *ID,
+                               Constant *columnConsts[], Value *vertexOrPrimID,
+                               Value *idxVal) {
+  IRBuilder<> Builder(CI);
+  Value *DstPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
+  HLMatrixType MatTy =
+      HLMatrixType::cast(DstPtr->getType()->getPointerElementType());
+
+  // Elements are extracted from the value returned by emitLoweredRegToMem,
+  // not from the raw store operand.
+  Value *MemVal = MatTy.emitLoweredRegToMem(
+      CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx), Builder);
+
+  // Both orientations share one loop nest: the outer counter offsets idxVal,
+  // the inner counter selects the column constant.
+  const bool ByColumn = matOp == HLMatLoadStoreOpcode::ColMatStore;
+  const unsigned OuterCount =
+      ByColumn ? MatTy.getNumColumns() : MatTy.getNumRows();
+  const unsigned InnerCount =
+      ByColumn ? MatTy.getNumRows() : MatTy.getNumColumns();
+
+  for (unsigned Outer = 0; Outer < OuterCount; ++Outer) {
+    Value *OuterIdx = Builder.CreateAdd(idxVal, Builder.getInt32(Outer));
+    for (unsigned Inner = 0; Inner < InnerCount; ++Inner) {
+      // (row, col) is (Inner, Outer) when walking columns, (Outer, Inner)
+      // when walking rows.
+      const unsigned EltIdx = ByColumn
+                                  ? MatTy.getColumnMajorIndex(Inner, Outer)
+                                  : MatTy.getRowMajorIndex(Outer, Inner);
+      Value *Elt = Builder.CreateExtractElement(MemVal, EltIdx);
+      Builder.CreateCall(ldStFunc,
+                         { OpArg, ID, OuterIdx, columnConsts[Inner], Elt });
+    }
+  }
+  CI->eraseFromParent();
+}
+
+
+// Expands an HL matrix load call (CI) into one ldStFunc call per matrix
+// element, rebuilds a vector from the results, replaces every use of CI with
+// it and erases CI.
+//   matOp          - ColMatLoad walks the matrix column by column; every
+//                    other opcode takes the row-by-row path.
+//   ldStFunc       - callee invoked as (OpArg, ID, index, column
+//                    [, vertexOrPrimID]); presumably the DXIL load-input op --
+//                    confirm at callers.
+//   columnConsts   - constants indexed by the inner loop counter and passed as
+//                    the fourth call argument.
+//   vertexOrPrimID - appended as a fifth call argument only when non-null.
+//   idxVal         - base index; the outer loop counter is added to it to form
+//                    the third call argument.
+void replaceMatLdWithLdInputs(CallInst *CI, HLMatLoadStoreOpcode matOp,
+                              Function *ldStFunc, Constant *OpArg, Constant *ID,
+                              Constant *columnConsts[], Value *vertexOrPrimID,
+                              Value *idxVal) {
+  IRBuilder<> Builder(CI);
+  Value *SrcPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
+  HLMatrixType MatTy =
+      HLMatrixType::cast(SrcPtr->getType()->getPointerElementType());
+  std::vector<Value *> Elts(MatTy.getNumElements());
+
+  // Both orientations share one loop nest: the outer counter offsets idxVal,
+  // the inner counter selects the column constant.
+  const bool ByColumn = matOp == HLMatLoadStoreOpcode::ColMatLoad;
+  const unsigned OuterCount =
+      ByColumn ? MatTy.getNumColumns() : MatTy.getNumRows();
+  const unsigned InnerCount =
+      ByColumn ? MatTy.getNumRows() : MatTy.getNumColumns();
+
+  for (unsigned Outer = 0; Outer < OuterCount; ++Outer) {
+    Value *OuterIdx = Builder.CreateAdd(idxVal, Builder.getInt32(Outer));
+    for (unsigned Inner = 0; Inner < InnerCount; ++Inner) {
+      SmallVector<Value *, 4> CallArgs = { OpArg, ID, OuterIdx, columnConsts[Inner] };
+      if (vertexOrPrimID)
+        CallArgs.emplace_back(vertexOrPrimID);
+
+      Value *Loaded = Builder.CreateCall(ldStFunc, CallArgs);
+      // (row, col) is (Inner, Outer) when walking columns, (Outer, Inner)
+      // when walking rows.
+      const unsigned EltIdx = ByColumn
+                                  ? MatTy.getColumnMajorIndex(Inner, Outer)
+                                  : MatTy.getRowMajorIndex(Outer, Inner);
+      Elts[EltIdx] = Loaded;
+    }
+  }
+
+  // Assemble the loaded scalars into a vector, pass it through
+  // emitLoweredMemToReg, and use the result in place of the original call.
+  Value *Result =
+    HLMatrixLower::BuildVector(Elts[0]->getType(), Elts, Builder);
+  Result = MatTy.emitLoweredMemToReg(Result, Builder);
+
+  CI->replaceAllUsesWith(Result);
+  CI->eraseFromParent();
+}
+
+
 void replaceDirectInputParameter(Value *param, Function *loadInput,
 void replaceDirectInputParameter(Value *param, Function *loadInput,
                                  unsigned cols, MutableArrayRef<Value *> args,
                                  unsigned cols, MutableArrayRef<Value *> args,
                                  bool bCast, OP *hlslOP, IRBuilder<> &Builder) {
                                  bool bCast, OP *hlslOP, IRBuilder<> &Builder) {
@@ -964,84 +1055,11 @@ void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertex
     switch (matOp) {
     switch (matOp) {
     case HLMatLoadStoreOpcode::ColMatLoad:
     case HLMatLoadStoreOpcode::ColMatLoad:
     case HLMatLoadStoreOpcode::RowMatLoad: {
     case HLMatLoadStoreOpcode::RowMatLoad: {
-      IRBuilder<> LocalBuilder(CI);
-      HLMatrixType MatTy = HLMatrixType::cast(
-        CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-          ->getType()->getPointerElementType());
-      std::vector<Value *> matElts(MatTy.getNumElements());
-
-      if (matOp == HLMatLoadStoreOpcode::ColMatLoad) {
-        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-          Constant *constRowIdx = LocalBuilder.getInt32(c);
-          Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-            SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[r] };
-            if (vertexOrPrimID)
-              args.emplace_back(vertexOrPrimID);
-
-            Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
-            matElts[matIdx] = input;
-          }
-        }
-      } else {
-        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-          Constant *constRowIdx = LocalBuilder.getInt32(r);
-          Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-            SmallVector<Value *, 4> args = { OpArg, ID, rowIdx, columnConsts[c] };
-            if (vertexOrPrimID)
-              args.emplace_back(vertexOrPrimID);
-
-            Value *input = LocalBuilder.CreateCall(ldStFunc, args);
-            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
-            matElts[matIdx] = input;
-          }
-        }
-      }
-
-      Value *newVec =
-          HLMatrixLower::BuildVector(matElts[0]->getType(), matElts, LocalBuilder);
-      newVec = MatTy.emitLoweredMemToReg(newVec, LocalBuilder);
-
-      CI->replaceAllUsesWith(newVec);
-      CI->eraseFromParent();
+      replaceMatLdWithLdInputs(CI, matOp, ldStFunc, OpArg, ID, columnConsts, vertexOrPrimID, idxVal);
     } break;
     } break;
     case HLMatLoadStoreOpcode::ColMatStore:
     case HLMatLoadStoreOpcode::ColMatStore:
     case HLMatLoadStoreOpcode::RowMatStore: {
     case HLMatLoadStoreOpcode::RowMatStore: {
-      IRBuilder<> LocalBuilder(CI);
-      Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      HLMatrixType MatTy = HLMatrixType::cast(
-        CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
-          ->getType()->getPointerElementType());
-
-      Val = MatTy.emitLoweredRegToMem(Val, LocalBuilder);
-
-      if (matOp == HLMatLoadStoreOpcode::ColMatStore) {
-        for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-          Constant *constColIdx = LocalBuilder.getInt32(c);
-          Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
-
-          for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-            unsigned matIdx = MatTy.getColumnMajorIndex(r, c);
-            Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
-            LocalBuilder.CreateCall(ldStFunc,
-              { OpArg, ID, colIdx, columnConsts[r], Elt });
-          }
-        }
-      } else {
-        for (unsigned r = 0; r < MatTy.getNumRows(); r++) {
-          Constant *constRowIdx = LocalBuilder.getInt32(r);
-          Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-          for (unsigned c = 0; c < MatTy.getNumColumns(); c++) {
-            unsigned matIdx = MatTy.getRowMajorIndex(r, c);
-            Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
-            LocalBuilder.CreateCall(ldStFunc,
-              { OpArg, ID, rowIdx, columnConsts[c], Elt });
-          }
-        }
-      }
-      CI->eraseFromParent();
+      replaceMatStWithStOutputs(CI, matOp, ldStFunc, OpArg, ID, columnConsts, vertexOrPrimID, idxVal);
     } break;
     } break;
     }
     }
   } else {
   } else {
@@ -1386,6 +1404,14 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
   Type *i1Ty = Type::getInt1Ty(constZero->getContext());
   Type *i1Ty = Type::getInt1Ty(constZero->getContext());
   Type *i32Ty = constZero->getType();
   Type *i32Ty = constZero->getType();
 
 
+  Constant *columnConsts[] = {
+      hlslOP->GetU8Const(0),  hlslOP->GetU8Const(1),  hlslOP->GetU8Const(2),
+      hlslOP->GetU8Const(3),  hlslOP->GetU8Const(4),  hlslOP->GetU8Const(5),
+      hlslOP->GetU8Const(6),  hlslOP->GetU8Const(7),  hlslOP->GetU8Const(8),
+      hlslOP->GetU8Const(9),  hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
+      hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
+      hlslOP->GetU8Const(15)};
+
   for (Argument &arg : patchConstantFunc->args()) {
   for (Argument &arg : patchConstantFunc->args()) {
     DxilParameterAnnotation &paramAnnotation =
     DxilParameterAnnotation &paramAnnotation =
         patchFuncAnnotation->GetParameterAnnotation(arg.getArgNo());
         patchFuncAnnotation->GetParameterAnnotation(arg.getArgNo());
@@ -1422,11 +1448,21 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
       collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
       collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
                                    /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
                                    /*hasVertexOrPrimID*/ true, true, bRowMajor, false);
       for (InputOutputAccessInfo &info : accessInfoList) {
       for (InputOutputAccessInfo &info : accessInfoList) {
+        Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
-          Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
           Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
           Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx,
                            info.vertexOrPrimID};
                            info.vertexOrPrimID};
           replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
           replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
+        } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
+          HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+          // Intrinsic will be translated later.
+          if (group == HLOpcodeGroup::HLIntrinsic || group == HLOpcodeGroup::NotHL)
+            return;
+          unsigned opcode = GetHLOpcode(CI);
+          DXASSERT_NOMSG(group == HLOpcodeGroup::HLMatLoadStore);
+          HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
+          if (matOp == HLMatLoadStoreOpcode::ColMatLoad || matOp == HLMatLoadStoreOpcode::RowMatLoad)
+            replaceMatLdWithLdInputs(CI, matOp, dxilLdFunc, OpArg, inputID, columnConsts, info.vertexOrPrimID, info.idx);
         } else {
         } else {
           DXASSERT(0, "input should only be ld");
           DXASSERT(0, "input should only be ld");
         }
         }

+ 69 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/matrix_subscript.hlsl

@@ -0,0 +1,69 @@
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T ps_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T lib_6_3 %s | FileCheck %s
+
+// Test for general subscript operations on matrix arrays: MIDX picks the array element and VIDX the row, each a literal or the dynamic input i/j.
+// Specifically focused on shader inputs, which previously failed to lower; global (g) and groupshared (gs) matrices are exercised alongside them.
+
+float3 GetRow(const float3x3 m, const int j)
+{
+  return m[j];
+}
+
+float3x3 g[2];
+groupshared float3x3 gs[2];
+
+struct JustMtx {
+  float3x3 mtx;
+};
+
+struct MtxArray {
+  float3x3 mtx[2];
+};
+
+[shader("pixel")]
+float3 main(const int i : I, const int j : J, const float3x3 m[2]: M, JustMtx jm[2] : JM, MtxArray ma : A) : SV_Target
+{
+  float3 ret = 0.0;
+
+  // CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+  // CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+  // CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+  ret += g[MIDX][VIDX];
+
+  // CHECK: load float, float addrspace(3)*
+  // CHECK: load float, float addrspace(3)*
+  // CHECK: load float, float addrspace(3)*
+  ret += gs[MIDX][VIDX];
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  ret += m[MIDX][VIDX];
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  ret += jm[MIDX].mtx[VIDX];
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 4, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 4, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 4, i32 {{%?[0-9]*}}, i8 2, i32 undef)
+  ret += ma.mtx[MIDX][VIDX];
+
+  ret += GetRow(g[MIDX], VIDX);
+  ret += GetRow(gs[MIDX], VIDX);
+  ret += GetRow(m[MIDX], VIDX);
+  ret += GetRow(jm[MIDX].mtx, VIDX);
+  ret += GetRow(ma.mtx[MIDX], VIDX);
+
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %{{.*}})
+  return ret;
+}

+ 40 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/matrix_subscript_ds.hlsl

@@ -0,0 +1,40 @@
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T ds_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T lib_6_3 %s | FileCheck %s
+
+// Specific test for subscript operations on OutputPatch matrix data in a domain shader, both directly in main and through the GetRow helper; MIDX/VIDX are a literal or the dynamic input i/j.
+
+struct MatStruct {
+ float4x4 mtx : M;
+};
+
+float4 GetRow(const OutputPatch<MatStruct, 3> tri, int i, int j)
+{
+  return tri[MIDX].mtx[VIDX];
+}
+
+[domain("tri")]
+[shader("domain")]
+float4 main(int i : I, int j : J, const OutputPatch<MatStruct, 3> tri) : SV_Position {
+
+  float4 ret = 0;
+
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  // CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 {{%?[0-9]*}}, i8 2, i32 {{%?[0-9]*}})
+  ret += tri[MIDX].mtx[VIDX];
+
+  ret += GetRow(tri, MIDX, VIDX);
+
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float %{{.*}})
+  // CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %{{.*}})
+  return ret;
+}

+ 62 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/matrix/matrix_subscript_hs.hlsl

@@ -0,0 +1,62 @@
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T hs_6_0 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=2 -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=1 -DVIDX=j -T lib_6_3 %s | FileCheck %s
+// RUN: %dxc -DMIDX=i -DVIDX=j -T lib_6_3 %s | FileCheck %s
+
+// Specific test for subscript operations on InputPatch matrix inputs in both the patch constant function (Patch) and the hull shader entry (main); MIDX/VIDX are a literal or the locals i/j read from inputs[0].uv.
+
+struct MatStruct {
+  int2 uv : TEXCOORD0;
+  float3x4  m_ObjectToWorld : TEXCOORD1;
+};
+
+struct Output {
+  float edges[3] : SV_TessFactor;
+  float inside : SV_InsideTessFactor;
+};
+
+// Instruction order here is a bit inconsistent,
+// so we can't test for all the outputs.
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call void @dx.op.storePatchConstant.f32
+// CHECK: call void @dx.op.storePatchConstant.f32
+Output Patch(InputPatch<MatStruct, 3> inputs)
+{
+  Output ret;
+  int i = inputs[0].uv.x;
+  int j = inputs[0].uv.y;
+
+  ret.edges[0] = inputs[MIDX].m_ObjectToWorld[VIDX][0];
+  ret.edges[1] = inputs[MIDX].m_ObjectToWorld[VIDX][1];
+  ret.edges[2] = inputs[MIDX].m_ObjectToWorld[VIDX][2];
+  ret.inside = 1.0f;
+  return ret;
+}
+
+
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call float @dx.op.loadInput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+// CHECK: call void @dx.op.storeOutput.f32
+[domain("tri")]
+[partitioning("fractional_odd")]
+[outputtopology("triangle_cw")]
+[patchconstantfunc("Patch")]
+[outputcontrolpoints(3)]
+[shader("hull")]
+float4 main(InputPatch<MatStruct, 3> inputs) : SV_Position
+{
+  int i = inputs[0].uv.x;
+  int j = inputs[0].uv.y;
+  return inputs[MIDX].m_ObjectToWorld[VIDX];
+}