Forráskód Böngészése

For vector writemask, create gep + store instead of load + shuffle + store. (#120)

Xiang Li 8 éve
szülő
commit
fc2d2154cd

+ 225 - 200
lib/HLSL/DxilGenerationPass.cpp

@@ -895,18 +895,16 @@ static void GenerateStOutput(Function *stOutput, MutableArrayRef<Value *> args,
 }
 
 static void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
-                                  OP::OpCode opcode, Constant *outputID,
-                                  Value *idx, unsigned cols, bool bI1Cast, OP *hlslOP) {
+                                  Constant *OpArg, Constant *outputID,
+                                  Value *idx, unsigned cols, bool bI1Cast) {
   IRBuilder<> Builder(stInst);
   Value *val = stInst->getValueOperand();
 
-  Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
-
   if (VectorType *VT = dyn_cast<VectorType>(val->getType())) {
     DXASSERT(cols == VT->getNumElements(), "vec size must match");
     for (unsigned col = 0; col < cols; col++) {
       Value *subVal = Builder.CreateExtractElement(val, col);
-      Value *colIdx = hlslOP->GetU8Const(col);
+      Value *colIdx = Builder.getInt8(col);
       Value *args[] = {OpArg, outputID, idx, colIdx, subVal};
       GenerateStOutput(stOutput, args, Builder, bI1Cast);
     }
@@ -915,7 +913,7 @@ static void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
   } else if (!val->getType()->isArrayTy()) {
     // TODO: support the case where cols is not 1.
     DXASSERT(cols == 1, "only support scalar here");
-    Value *colIdx = hlslOP->GetU8Const(0);
+    Value *colIdx = Builder.getInt8(0);
     Value *args[] = {OpArg, outputID, idx, colIdx, val};
     GenerateStOutput(stOutput, args, Builder, bI1Cast);
     // remove stInst
@@ -923,7 +921,7 @@ static void replaceStWithStOutput(Function *stOutput, StoreInst *stInst,
   } else {
     DXASSERT(0, "not support array yet");
     // TODO: support array.
-    Value *colIdx = hlslOP->GetU8Const(0);
+    Value *colIdx = Builder.getInt8(0);
     ArrayType *AT = cast<ArrayType>(val->getType());
     Value *args[] = {OpArg, outputID, idx, colIdx, /*val*/nullptr};
     args;
@@ -950,19 +948,18 @@ static Value *replaceLdWithLdInput(Function *loadInput,
                                  LoadInst *ldInst,
                                  unsigned cols, 
                                  MutableArrayRef<Value *>args,
-                                 bool bCast,
-                                 OP *hlslOP) {
+                                 bool bCast) {
   IRBuilder<> Builder(ldInst);
   Type *Ty = ldInst->getType();
   Type *EltTy = Ty->getScalarType();
   // Change i1 to i32 for load input.
-  Value *zero = hlslOP->GetU32Const(0);
+  Value *zero = Builder.getInt32(0);
 
   if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
     Value *newVec = llvm::UndefValue::get(VT);
     DXASSERT(cols == VT->getNumElements(), "vec size must match");
     for (unsigned col = 0; col < cols; col++) {
-      Value *colIdx = hlslOP->GetU8Const(col);
+      Value *colIdx = Builder.getInt8(col);
       args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
       Value *input = GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
       newVec = Builder.CreateInsertElement(newVec, input, col);
@@ -974,7 +971,7 @@ static Value *replaceLdWithLdInput(Function *loadInput,
     Value *colIdx = args[DXIL::OperandIndex::kLoadInputColOpIdx];
     if (colIdx == nullptr) {
       DXASSERT(cols == 1, "only support scalar here");
-      colIdx = hlslOP->GetU8Const(0);
+      colIdx = Builder.getInt8(0);
     }
 
     if (isa<ConstantInt>(colIdx)) {
@@ -989,10 +986,10 @@ static Value *replaceLdWithLdInput(Function *loadInput,
       // Load to array.
       ArrayType *AT = ArrayType::get(ldInst->getType(), cols);
       Value *arrayVec = Builder.CreateAlloca(AT);
-      Value *zeroIdx = hlslOP->GetU32Const(0);
+      Value *zeroIdx = Builder.getInt32(0);
 
       for (unsigned col = 0; col < cols; col++) {
-        Value *colIdx = hlslOP->GetU8Const(col);
+        Value *colIdx = Builder.getInt8(col);
         args[DXIL::OperandIndex::kLoadInputColOpIdx] = colIdx;
         Value *input = GenerateLdInput(loadInput, args, Builder, zero, bCast, EltTy);
         Value *GEP = Builder.CreateInBoundsGEP(arrayVec, {zeroIdx, colIdx});
@@ -1273,6 +1270,180 @@ static void replaceInputOutputWithIntrinsic(DXIL::SemanticKind semKind, Value *G
   }
 }
 
+void GenerateInputOutputUserCall(InputOutputAccessInfo &info, Value *undefVertexIdx,
+    Function *ldStFunc, Constant *OpArg, Constant *ID, unsigned cols, bool bI1Cast,
+    Constant *columnConsts[], // per-column index constants; indexed by r / c in the matrix cases
+    bool bNeedVertexID, bool isArrayTy, bool bInput, bool bIsInout) { // Rewrites the single access info.user into call(s) to ldStFunc.
+  Value *idxVal = info.idx;
+  Value *vertexID = undefVertexIdx; // fallback vertex operand; when null, no vertex operand is appended
+  if (bNeedVertexID && isArrayTy) { // only then use the vertex ID recorded for this access
+    vertexID = info.vertexID;
+  }
+
+  if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) { // load use
+    SmallVector<Value *, 4> args = {OpArg, ID, idxVal, info.vectorIdx};
+    if (vertexID) // append the vertex operand only when there is one
+      args.emplace_back(vertexID);
+
+    replaceLdWithLdInput(ldStFunc, ldInst, cols, args, bI1Cast);
+  } else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user)) { // store use
+    if (bInput) { // store to an input: nothing is emitted, only bIsInout is asserted
+      DXASSERT_LOCALVAR(bIsInout, bIsInout, "input should not have store use.");
+    } else {
+      if (!info.vectorIdx) { // no per-component index: lower the whole stored value
+        replaceStWithStOutput(ldStFunc, stInst, OpArg, ID, idxVal, cols,
+                              bI1Cast);
+      } else { // a single component selected by info.vectorIdx
+        Value *V = stInst->getValueOperand();
+        Type *Ty = V->getType();
+        DXASSERT(Ty == Ty->getScalarType() && !Ty->isAggregateType(),
+                 "only support scalar here");
+
+        if (ConstantInt *ColIdx = dyn_cast<ConstantInt>(info.vectorIdx)) { // constant column: one direct store
+          IRBuilder<> Builder(stInst);
+          if (ColIdx->getType()->getBitWidth() != 8) { // re-create the index as an i8 constant
+            ColIdx = Builder.getInt8(ColIdx->getValue().getLimitedValue());
+          }
+          Value *args[] = {OpArg, ID, idxVal, ColIdx, V};
+          GenerateStOutput(ldStFunc, args, Builder, bI1Cast);
+        } else { // dynamic column: switch on the index, one case block per column
+          BasicBlock *BB = stInst->getParent();
+          BasicBlock *EndBB = BB->splitBasicBlock(stInst);
+
+          TerminatorInst *TI = BB->getTerminator();
+          IRBuilder<> SwitchBuilder(TI);
+          LLVMContext &Ctx = stInst->getContext();
+          SwitchInst *Switch =
+              SwitchBuilder.CreateSwitch(info.vectorIdx, EndBB, cols); // EndBB is the default target
+          TI->eraseFromParent(); // the switch replaces BB's previous terminator
+
+          Function *F = EndBB->getParent();
+          for (unsigned i = 0; i < cols; i++) {
+            BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case", F, EndBB);
+            Switch->addCase(SwitchBuilder.getInt32(i), CaseBB);
+            IRBuilder<> CaseBuilder(CaseBB);
+
+            ConstantInt *CaseIdx = SwitchBuilder.getInt8(i); // fixed i8 column for this case
+
+            Value *args[] = {OpArg, ID, idxVal, CaseIdx, V};
+            GenerateStOutput(ldStFunc, args, CaseBuilder, bI1Cast);
+
+            CaseBuilder.CreateBr(EndBB);
+          }
+        }
+        // Remove the original store; both paths above have emitted its replacement.
+        stInst->eraseFromParent();
+      }
+    }
+  } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) { // HL call use
+    HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+    // Intrinsics are translated later; leave the call untouched here.
+    if (group == HLOpcodeGroup::HLIntrinsic)
+      return;
+    unsigned opcode = GetHLOpcode(CI);
+    DXASSERT(group == HLOpcodeGroup::HLMatLoadStore, "");
+    HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
+    switch (matOp) {
+    case HLMatLoadStoreOpcode::ColMatLoad: { // one ldStFunc call per element, stored at c * row + r
+      IRBuilder<> LocalBuilder(CI);
+      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+                        ->getType()
+                        ->getPointerElementType();
+      unsigned col, row;
+      Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
+      std::vector<Value *> matElts(col * row);
+      for (unsigned c = 0; c < col; c++) {
+        Constant *constRowIdx = LocalBuilder.getInt32(c);
+        Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx); // index operand is idxVal + c
+        for (unsigned r = 0; r < row; r++) {
+          SmallVector<Value *, 4> args = {OpArg, ID, rowIdx, columnConsts[r]};
+          if (vertexID)
+            args.emplace_back(vertexID);
+
+          Value *input = LocalBuilder.CreateCall(ldStFunc, args);
+          unsigned matIdx = c * row + r;
+          matElts[matIdx] = input;
+        }
+      }
+      Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, true, matElts,
+                                                 LocalBuilder);
+      CI->replaceAllUsesWith(newVec);
+      CI->eraseFromParent();
+    } break;
+    case HLMatLoadStoreOpcode::RowMatLoad: { // same as ColMatLoad with loops swapped; element index r * col + c
+      IRBuilder<> LocalBuilder(CI);
+      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
+                        ->getType()
+                        ->getPointerElementType();
+      unsigned col, row;
+      Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
+      std::vector<Value *> matElts(col * row);
+      for (unsigned r = 0; r < row; r++) {
+        Constant *constRowIdx = LocalBuilder.getInt32(r);
+        Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx); // index operand is idxVal + r
+        for (unsigned c = 0; c < col; c++) {
+          SmallVector<Value *, 4> args = {OpArg, ID, rowIdx, columnConsts[c]};
+          if (vertexID)
+            args.emplace_back(vertexID);
+
+          Value *input = LocalBuilder.CreateCall(ldStFunc, args);
+          unsigned matIdx = r * col + c;
+          matElts[matIdx] = input;
+        }
+      }
+      Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, false,
+                                                 matElts, LocalBuilder);
+      CI->replaceAllUsesWith(newVec);
+      CI->eraseFromParent();
+    } break;
+    case HLMatLoadStoreOpcode::ColMatStore: { // one ldStFunc call per element; no vertex operand is passed
+      IRBuilder<> LocalBuilder(CI);
+      Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
+                        ->getType()
+                        ->getPointerElementType();
+      unsigned col, row;
+      HLMatrixLower::GetMatrixInfo(matTy, col, row);
+
+      for (unsigned c = 0; c < col; c++) {
+        Constant *constColIdx = LocalBuilder.getInt32(c);
+        Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx); // index operand is idxVal + c
+
+        for (unsigned r = 0; r < row; r++) {
+          unsigned matIdx = c * row + r;
+          Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
+          LocalBuilder.CreateCall(ldStFunc,
+                                  {OpArg, ID, colIdx, columnConsts[r], Elt});
+        }
+      }
+      CI->eraseFromParent();
+    } break;
+    case HLMatLoadStoreOpcode::RowMatStore: { // same as ColMatStore with loops swapped; element index r * col + c
+      IRBuilder<> LocalBuilder(CI);
+      Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+      Type *matTy = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
+                        ->getType()
+                        ->getPointerElementType();
+      unsigned col, row;
+      HLMatrixLower::GetMatrixInfo(matTy, col, row);
+
+      for (unsigned r = 0; r < row; r++) {
+        Constant *constRowIdx = LocalBuilder.getInt32(r);
+        Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx); // index operand is idxVal + r
+        for (unsigned c = 0; c < col; c++) {
+          unsigned matIdx = r * col + c;
+          Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
+          LocalBuilder.CreateCall(ldStFunc,
+                                  {OpArg, ID, rowIdx, columnConsts[c], Elt});
+        }
+      }
+      CI->eraseFromParent();
+    } break;
+    }
+  } else
+    DXASSERT(0, "invalid operation on input output");
+}
+
 void DxilGenerationPass::GenerateDxilInputsOutputs(bool bInput) {
   OP *hlslOP = m_pHLModule->GetOP();
   const ShaderModel *pSM = m_pHLModule->GetShaderModel();
@@ -1342,175 +1513,18 @@ void DxilGenerationPass::GenerateDxilInputsOutputs(bool bInput) {
       continue;
     }
 
-    bool isArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
-    bool isPrecise = m_preciseSigSet.count(SE);
-    if (isPrecise)
+    bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
+    bool bIsPrecise = m_preciseSigSet.count(SE);
+    if (bIsPrecise)
       HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
 
     std::vector<InputOutputAccessInfo> accessInfoList;
-    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexID && isArrayTy, bInput);
+    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexID && bIsArrayTy, bInput);
 
     for (InputOutputAccessInfo &info : accessInfoList) {
-      Value *idxVal = info.idx;
-      Value *vertexID = undefVertexIdx;
-      if (bNeedVertexID && isArrayTy) {
-        vertexID = info.vertexID;
-      }
-
-      if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
-        Value *args[] = {OpArg, ID, idxVal, info.vectorIdx, vertexID};
-        replaceLdWithLdInput(dxilFunc, ldInst, cols, args, bI1Cast, hlslOP);
-      }
-      else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user)) {
-        if (bInput) {
-          DXASSERT_LOCALVAR(bIsInout, bIsInout, "input should not have store use.");
-        } else {
-          if (!info.vectorIdx) {
-            replaceStWithStOutput(dxilFunc, stInst, opcode, ID, idxVal, cols,
-                                  bI1Cast, hlslOP);
-          } else {
-            Value *V = stInst->getValueOperand();
-            Type *Ty = V->getType();
-            DXASSERT(Ty == Ty->getScalarType() && !Ty->isAggregateType(),
-                     "only support scalar here");
-
-            if (ConstantInt *ColIdx = dyn_cast<ConstantInt>(info.vectorIdx)) {
-              IRBuilder<> Builder(stInst);
-              if (ColIdx->getType()->getBitWidth() != 8) {
-                ColIdx = Builder.getInt8(ColIdx->getValue().getLimitedValue());
-              }
-              Value *args[] = {OpArg, ID, idxVal, ColIdx, V};
-              GenerateStOutput(dxilFunc, args, Builder, bI1Cast);
-            } else {
-              BasicBlock *BB = stInst->getParent();
-              BasicBlock *EndBB = BB->splitBasicBlock(stInst);
-
-              TerminatorInst *TI = BB->getTerminator();
-              IRBuilder<> SwitchBuilder(TI);
-              LLVMContext &Ctx = m_pHLModule->GetCtx();
-              SwitchInst *Switch =
-                  SwitchBuilder.CreateSwitch(info.vectorIdx, EndBB, cols);
-              TI->eraseFromParent();
-
-              Function *F = EndBB->getParent();
-              for (unsigned i = 0; i < cols; i++) {
-                BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case", F, EndBB);
-                Switch->addCase(SwitchBuilder.getInt32(i), CaseBB);
-                IRBuilder<> CaseBuilder(CaseBB);
-
-                ConstantInt *CaseIdx = SwitchBuilder.getInt8(i);
-
-                Value *args[] = {OpArg, ID, idxVal, CaseIdx, V};
-                GenerateStOutput(dxilFunc, args, CaseBuilder, bI1Cast);
-
-                CaseBuilder.CreateBr(EndBB);
-              }
-            }
-            // remove stInst
-            stInst->eraseFromParent();
-          }
-        }
-      } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
-        HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-        // Intrinsic will be translated later.
-        if (group == HLOpcodeGroup::HLIntrinsic)
-          continue;
-        unsigned opcode = GetHLOpcode(CI);
-        DXASSERT(group == HLOpcodeGroup::HLMatLoadStore, "");
-        HLMatLoadStoreOpcode matOp = static_cast<HLMatLoadStoreOpcode>(opcode);
-        switch (matOp) {
-        case HLMatLoadStoreOpcode::ColMatLoad: {
-          IRBuilder<> LocalBuilder(CI);
-          Type *matTy = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-                            ->getType()
-                            ->getPointerElementType();
-          unsigned col, row;
-          Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-          std::vector<Value *> matElts(col * row);
-          for (unsigned c = 0; c < col; c++) {
-            Constant *constRowIdx = hlslOP->GetI32Const(c);
-            Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-            for (unsigned r = 0; r < row; r++) {
-              Value *input = LocalBuilder.CreateCall(
-                  dxilFunc, {OpArg, ID, rowIdx, columnConsts[r], vertexID});
-              unsigned matIdx = c * row + r;
-              matElts[matIdx] = input;
-            }
-          }
-          Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, true,
-                                                     matElts, LocalBuilder);
-          CI->replaceAllUsesWith(newVec);
-          CI->eraseFromParent();
-        } break;
-        case HLMatLoadStoreOpcode::RowMatLoad: {
-          IRBuilder<> LocalBuilder(CI);
-          Type *matTy = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx)
-                            ->getType()
-                            ->getPointerElementType();
-          unsigned col, row;
-          Type *EltTy = HLMatrixLower::GetMatrixInfo(matTy, col, row);
-          std::vector<Value *> matElts(col * row);
-          for (unsigned r = 0; r < row; r++) {
-            Constant *constRowIdx = hlslOP->GetI32Const(r);
-            Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-            for (unsigned c = 0; c < col; c++) {
-              Value *input = LocalBuilder.CreateCall(
-                  dxilFunc, {OpArg, ID, rowIdx, columnConsts[c], vertexID});
-              unsigned matIdx = r * col + c;
-              matElts[matIdx] = input;
-            }
-          }
-          Value *newVec = HLMatrixLower::BuildMatrix(EltTy, col, row, false,
-                                                     matElts, LocalBuilder);
-          CI->replaceAllUsesWith(newVec);
-          CI->eraseFromParent();
-        } break;
-        case HLMatLoadStoreOpcode::ColMatStore: {
-          IRBuilder<> LocalBuilder(CI);
-          Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-          Type *matTy = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
-                            ->getType()
-                            ->getPointerElementType();
-          unsigned col, row;
-          HLMatrixLower::GetMatrixInfo(matTy, col, row);
-
-          for (unsigned c = 0; c < col; c++) {
-            Constant *constColIdx = hlslOP->GetI32Const(c);
-            Value *colIdx = LocalBuilder.CreateAdd(idxVal, constColIdx);
-
-            for (unsigned r = 0; r < row; r++) {
-              unsigned matIdx = c * row + r;
-              Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
-              LocalBuilder.CreateCall(
-                  dxilFunc, {OpArg, ID, colIdx, columnConsts[r], Elt});
-            }
-          }
-          CI->eraseFromParent();
-        } break;
-        case HLMatLoadStoreOpcode::RowMatStore: {
-          IRBuilder<> LocalBuilder(CI);
-          Value *Val = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-          Type *matTy = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx)
-                            ->getType()
-                            ->getPointerElementType();
-          unsigned col, row;
-          HLMatrixLower::GetMatrixInfo(matTy, col, row);
-
-          for (unsigned r = 0; r < row; r++) {
-            Constant *constRowIdx = hlslOP->GetI32Const(r);
-            Value *rowIdx = LocalBuilder.CreateAdd(idxVal, constRowIdx);
-            for (unsigned c = 0; c < col; c++) {
-              unsigned matIdx = r * col + c;
-              Value *Elt = LocalBuilder.CreateExtractElement(Val, matIdx);
-              LocalBuilder.CreateCall(
-                  dxilFunc, {OpArg, ID, rowIdx, columnConsts[c], Elt});
-            }
-          }
-          CI->eraseFromParent();
-        } break;
-        }
-      } else
-        DXASSERT(0, "invalid operation on input output");
+      GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
+                                  cols, bI1Cast, columnConsts, bNeedVertexID,
+                                  bIsArrayTy, bInput, bIsInout);
     }
   }
 }
@@ -1600,14 +1614,32 @@ void DxilGenerationPass::GenerateDxilPatchConstantLdSt() {
   DxilSignature &Sig = m_pHLModule->GetPatchConstantSignature();
   Function *EntryFunc = m_pHLModule->GetEntryFunction();
   auto InsertPt = EntryFunc->getEntryBlock().getFirstInsertionPt();
-  if (m_pHLModule->GetShaderModel()->IsHS()) {
+  const bool bIsHs = m_pHLModule->GetShaderModel()->IsHS();
+  const bool bIsInput = !bIsHs;
+  const bool bIsInout = false;
+  const bool bNeedVertexID = false;
+  if (bIsHs) {
     HLFunctionProps &EntryQual = m_pHLModule->GetHLFunctionProps(EntryFunc);
     Function *patchConstantFunc = EntryQual.ShaderProps.HS.patchConstantFunc;
     InsertPt = patchConstantFunc->getEntryBlock().getFirstInsertionPt();
   }
   IRBuilder<> Builder(InsertPt);
-  Type *i1Ty = Type::getInt1Ty(constZero->getContext());
-  Type *i32Ty = constZero->getType();
+  Type *i1Ty = Builder.getInt1Ty();
+  Type *i32Ty = Builder.getInt32Ty();
+  // LoadPatchConstant does not take a vertexIdx operand.
+  Value *undefVertexIdx = nullptr;
+
+  Constant *columnConsts[] = {
+      hlslOP->GetU8Const(0),  hlslOP->GetU8Const(1),  hlslOP->GetU8Const(2),
+      hlslOP->GetU8Const(3),  hlslOP->GetU8Const(4),  hlslOP->GetU8Const(5),
+      hlslOP->GetU8Const(6),  hlslOP->GetU8Const(7),  hlslOP->GetU8Const(8),
+      hlslOP->GetU8Const(9),  hlslOP->GetU8Const(10), hlslOP->GetU8Const(11),
+      hlslOP->GetU8Const(12), hlslOP->GetU8Const(13), hlslOP->GetU8Const(14),
+      hlslOP->GetU8Const(15)};
+
+  OP::OpCode opcode =
+      bIsInput ? OP::OpCode::LoadPatchConstant : OP::OpCode::StorePatchConstant;
+  Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
 
   for (unsigned i = 0; i < Sig.GetElements().size(); i++) {
     DxilSignatureElement *SE = &Sig.GetElement(i);
@@ -1629,39 +1661,32 @@ void DxilGenerationPass::GenerateDxilPatchConstantLdSt() {
       bI1Cast = true;
       Ty = i32Ty;
     }
-    Function *dxilLdFunc = hlslOP->GetOpFunc(OP::OpCode::LoadPatchConstant, Ty);
-    Function *dxilStFunc = hlslOP->GetOpFunc(OP::OpCode::StorePatchConstant, Ty);
 
     unsigned cols = SE->GetCols();
 
+    Function *dxilFunc = hlslOP->GetOpFunc(opcode, Ty);
+
     if (!GV->getType()->isPointerTy()) {
-      // Must be DS input.
+      DXASSERT(bIsInput, "Must be DS input.");
       Constant *OpArg = hlslOP->GetU32Const(static_cast<unsigned>(OP::OpCode::LoadPatchConstant));
       Value *args[] = {OpArg, ID, /*rowIdx*/constZero, /*colIdx*/nullptr};
-      replaceDirectInputParameter(GV, dxilLdFunc, cols, args, bI1Cast, hlslOP, Builder);
+      replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP, Builder);
       continue;
     }
     
     std::vector<InputOutputAccessInfo> accessInfoList;
-    collectInputOutputAccessInfo(GV, constZero, accessInfoList, /*hasVertexID*/ false,
-      !m_pHLModule->GetShaderModel()->IsHS());
+    collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexID,
+      bIsInput);
+
+    bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
     bool isPrecise = m_preciseSigSet.count(SE);
     if (isPrecise)
       HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
 
     for (InputOutputAccessInfo &info : accessInfoList) {
-      Value *idxVal = info.idx;
-      if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
-        OP::OpCode opcode = OP::OpCode::LoadPatchConstant;
-        Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
-        Value *args[] = {OpArg, ID, idxVal, /*colIdx*/nullptr};
-        replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast, hlslOP);
-      } else if (StoreInst *stInst = dyn_cast<StoreInst>(info.user))
-        replaceStWithStOutput(dxilStFunc, stInst,
-                              OP::OpCode::StorePatchConstant, ID, idxVal, cols,
-                              bI1Cast, hlslOP);
-      else
-        DXASSERT(0, "invalid instruction on patch constant");
+      GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
+                                  cols, bI1Cast, columnConsts, bNeedVertexID,
+                                  bIsArrayTy, bIsInput, bIsInout);
     }
   }
 }
@@ -1709,7 +1734,7 @@ void DxilGenerationPass::GenerateDxilPatchConstantFunctionInputs() {
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
           Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
           Value *args[] = {OpArg, inputID, info.idx, info.vectorIdx, info.vertexID};
-          replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast, hlslOP);
+          replaceLdWithLdInput(dxilLdFunc, ldInst, cols, args, bI1Cast);
         } else
           DXASSERT(0, "input should only be ld");
       }

+ 18 - 0
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -1690,6 +1690,24 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
   if (VTy == nullptr && getContext().getLangOpts().HLSL)
     VTy =
         hlsl::ConvertHLSLVecMatTypeToExtVectorType(getContext(), Dst.getType());
+  llvm::Value * VecDstPtr = Dst.getExtVectorAddr();
+  llvm::Value *Zero = Builder.getInt32(0);
+  if (VTy) {
+    llvm::Type *VecTy = VecDstPtr->getType()->getPointerElementType();
+    for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
+      if (llvm::Constant *Elt = Elts->getAggregateElement(i)) {
+        llvm::Value *EltGEP = Builder.CreateGEP(VecDstPtr, {Zero, Elt});
+        llvm::Value *SrcElt = Builder.CreateExtractElement(SrcVal, i);
+        Builder.CreateStore(SrcElt, EltGEP);
+      }
+    }
+  } else {
+    // If the Src is a scalar (not a vector) it must be updating one element.
+    llvm::Value *EltGEP = Builder.CreateGEP(
+        VecDstPtr, {Zero, Elts->getAggregateElement((unsigned)0)});
+    Builder.CreateStore(SrcVal, EltGEP);
+  }
+  return;
   // HLSL Change Ends
   if (VTy) {  // HLSL Change
     unsigned NumSrcElts = VTy->getNumElements();

+ 63 - 0
tools/clang/test/CodeGenHLSL/inout4.hlsl

@@ -0,0 +1,63 @@
+// RUN: %dxc -E main -T ds_6_0 %s | FileCheck %s
+
+// CHECK: SV_RenderTargetArrayIndex or SV_ViewportArrayIndex from any shader feeding rasterizer
+// CHECK: InputControlPointCount=3
+// CHECK: OutputPositionPresent=1
+// CHECK: domainLocation.f32
+
+// loadPatchConstant for the inout signature.
+// CHECK: loadPatchConstant
+
+//--------------------------------------------------------------------------------------
+// SimpleTessellation.hlsl
+//
+// Advanced Technology Group (ATG)
+// Copyright (C) Microsoft Corporation. All rights reserved.
+//--------------------------------------------------------------------------------------
+
+struct VSSceneIn {
+  float3 pos : POSITION;
+  float3 norm : NORMAL;
+  float2 tex : TEXCOORD0;
+};
+
+struct PSSceneIn {
+  float4 pos : SV_Position;
+  float2 tex : TEXCOORD0;
+  float3 norm : NORMAL;
+
+uint   RTIndex      : SV_RenderTargetArrayIndex;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////
+// Simple forwarding Tessellation shaders
+
+struct HSPerVertexData {
+  // This is just the original vertex verbatim. In many real-life cases this would be a
+  // control point instead.
+  PSSceneIn v;
+};
+
+struct HSPerPatchData {
+  // We at least have to specify tess factors per patch
+  // As we're tessellating triangles, there will be 4 tess factors
+  // In a real-life case this might contain the face normal, for example
+  float edges[3] : SV_TessFactor;
+  float inside : SV_InsideTessFactor;
+};
+
+// domain shader that actually outputs the triangle vertices
+[domain("tri")] PSSceneIn main(const float3 bary
+                               : SV_DomainLocation,
+                                 const OutputPatch<HSPerVertexData, 3> patch,
+                                 const HSPerPatchData perPatchData,
+                                 inout float x : X) {
+  PSSceneIn v;
+
+  // Compute interpolated coordinates
+  v.pos = patch[0].v.pos * bary.x + patch[1].v.pos * bary.y + patch[2].v.pos * bary.z;
+  v.tex = patch[0].v.tex * bary.x + patch[1].v.tex * bary.y + patch[2].v.tex * bary.z;
+  v.norm = patch[0].v.norm * bary.x + patch[1].v.norm * bary.y + patch[2].v.norm * bary.z;
+  v.RTIndex = 0;
+  return v;
+}

+ 22 - 0
tools/clang/test/CodeGenHLSL/inout5.hlsl

@@ -0,0 +1,22 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: loadInput.f32(i32 4
+// CHECK: loadInput.f32(i32 4
+// CHECK: loadInput.f32(i32 4
+// CHECK: loadInput.f32(i32 4
+
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+// CHECK: storeOutput.f32(i32 5
+
+
+float4 main(inout float4 a : A) : SV_POSITION
+{
+  return 0;
+}
+

+ 15 - 0
tools/clang/test/CodeGenHLSL/writeMaskBuf.hlsl

@@ -0,0 +1,15 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK-NOT: dx.op.bufferLoad
+
+struct foo {
+  uint4 bar;
+};
+
+RWStructuredBuffer<foo> buf;
+
+[numthreads(8, 8, 1)]
+void main(uint2 id : SV_DispatchThreadId) {
+  buf[id.x].bar.w = 1;
+  buf[id.y].bar.xz = int2(2,3);
+}

+ 15 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -374,6 +374,8 @@ public:
   TEST_METHOD(CodeGenInout1)
   TEST_METHOD(CodeGenInout2)
   TEST_METHOD(CodeGenInout3)
+  TEST_METHOD(CodeGenInout4)
+  TEST_METHOD(CodeGenInout5)
   TEST_METHOD(CodeGenInput1)
   TEST_METHOD(CodeGenInput2)
   TEST_METHOD(CodeGenInput3)
@@ -536,6 +538,7 @@ public:
   TEST_METHOD(CodeGenVec_Comp_Arg)
   TEST_METHOD(CodeGenVecCmpCond)
   TEST_METHOD(CodeGenWave)
+  TEST_METHOD(CodeGenWriteMaskBuf)
   TEST_METHOD(CodeGenWriteToInput)
   TEST_METHOD(CodeGenWriteToInput2)
   TEST_METHOD(CodeGenWriteToInput3)
@@ -2175,6 +2178,14 @@ TEST_F(CompilerTest, CodeGenInout3) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\inout3.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenInout4) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\inout4.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenInout5) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\inout5.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenInput1) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\input1.hlsl");
 }
@@ -2824,6 +2835,10 @@ TEST_F(CompilerTest, CodeGenWave) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\wave.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenWriteMaskBuf) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\writeMaskBuf.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenWriteToInput) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\writeToInput.hlsl");
 }