Selaa lähdekoodia

Remove/refactor dead code paths for flattenning library functions

Tex Riddell 7 vuotta sitten
vanhempi
commit
d8588efbb1
1 muutettua tiedostoa jossa 143 lisäystä ja 543 poistoa
  1. 143 543
      lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

+ 143 - 543
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -4097,27 +4097,20 @@ private:
   void moveFunctionBody(Function *F, Function *flatF);
   void replaceCall(Function *F, Function *flatF);
   void createFlattenedFunction(Function *F);
-  void createFlattenedFunctionCall(Function *F, Function *flatF, CallInst *CI);
   void
   flattenArgument(Function *F, Value *Arg, bool bForParam,
                   DxilParameterAnnotation &paramAnnotation,
                   std::vector<Value *> &FlatParamList,
                   std::vector<DxilParameterAnnotation> &FlatRetAnnotationList,
-                  IRBuilder<> &Builder, DbgDeclareInst *DDI,
-                  bool hasShaderInputOutput);
+                  IRBuilder<> &Builder, DbgDeclareInst *DDI);
   Value *castResourceArgIfRequired(Value *V, Type *Ty, bool bOut,
                                    DxilParamInputQual inputQual,
                                    IRBuilder<> &Builder);
   Value *castArgumentIfRequired(Value *V, Type *Ty, bool bOut,
-                                bool hasShaderInputOutput,
                                 DxilParamInputQual inputQual,
                                 DxilFieldAnnotation &annotation,
                                 std::deque<Value *> &WorkList,
                                 IRBuilder<> &Builder);
-  // Replace argument which changed type when flatten.
-  void replaceCastArgument(Value *&NewArg, Value *OldArg,
-                           DxilParamInputQual inputQual,
-                           IRBuilder<> &CallBuilder, IRBuilder<> &RetBuilder);
   // Replace use of parameter which changed type when flatten.
   // Also add information to Arg if required.
   void replaceCastParameter(Value *NewParam, Value *OldParam, Function &F,
@@ -4718,84 +4711,6 @@ static void CastCopyNewPtrToOldPtr(Value *NewPtr, Value *OldPtr, HLModule &HLM,
   }
 }
 
-void SROA_Parameter_HLSL::replaceCastArgument(Value *&NewArg, Value *OldArg,
-                                              DxilParamInputQual inputQual,
-                                              IRBuilder<> &CallBuilder,
-                                              IRBuilder<> &RetBuilder) {
-  Type *HandleTy = m_pHLModule->GetOP()->GetHandleType();
-
-  Type *NewTy = NewArg->getType();
-  Type *OldTy = OldArg->getType();
-
-  bool bIn = inputQual == DxilParamInputQual::Inout ||
-             inputQual == DxilParamInputQual::In;
-  bool bOut = inputQual == DxilParamInputQual::Inout ||
-              inputQual == DxilParamInputQual::Out;
-
-  if (NewArg->getType() == HandleTy) {
-    Value *Handle =
-        CastResourcePtrToHandle(OldArg, HandleTy, *m_pHLModule, CallBuilder);
-    // Use Handle as NewArg.
-    NewArg = Handle;
-  } else if (vectorEltsMap.count(NewArg)) {
-    Type *VecTy = OldTy;
-    if (VecTy->isPointerTy())
-      VecTy = VecTy->getPointerElementType();
-
-    // Flattened vector.
-    SmallVector<Value *, 4> &elts = vectorEltsMap[NewArg];
-    unsigned vecSize = elts.size();
-
-    if (NewTy->isPointerTy()) {
-      if (bIn) {
-        // Copy OldArg to NewArg before Call.
-        CopyVectorPtrToEltsPtr(OldArg, elts, vecSize, CallBuilder);
-      }
-
-      // bOut must be true here.
-      // Store NewArg to  OldArg after Call.
-      CopyEltsPtrToVectorPtr(elts, OldArg, VecTy, vecSize, RetBuilder);
-    } else {
-      // Must be in parameter.
-      // Copy OldArg to NewArg before Call.
-      Value *Vec = OldArg;
-      if (OldTy->isPointerTy()) {
-        Vec = CallBuilder.CreateLoad(OldArg);
-      }
-
-      for (unsigned i = 0; i < vecSize; i++) {
-        Value *Elt = CallBuilder.CreateExtractElement(Vec, i);
-        // Save elt to update arg in createFlattenedFunctionCall.
-        elts[i] = Elt;
-      }
-    }
-    // Don't need elts anymore.
-    vectorEltsMap.erase(NewArg);
-  } else if (!NewTy->isPointerTy()) {
-    // Ptr param is cast to non-ptr param.
-    // Must be in param.
-    // Load OldArg as NewArg before call.
-    NewArg = CallBuilder.CreateLoad(OldArg);
-  } else if (HLMatrixLower::IsMatrixType(OldTy)) {
-    bool bRowMajor = castRowMajorParamMap.count(NewArg);
-    CopyMatToArrayPtr(OldArg, NewArg, /*arrayBaseIdx*/ 0, *m_pHLModule,
-                      CallBuilder, bRowMajor);
-  } else {
-    bool bRowMajor = castRowMajorParamMap.count(NewArg);
-    // NewTy is pointer type.
-    // Copy OldArg to NewArg before Call.
-    if (bIn) {
-      CastCopyOldPtrToNewPtr(OldArg, NewArg, *m_pHLModule, HandleTy,
-                             CallBuilder, bRowMajor);
-    }
-    if (bOut) {
-      // Store NewArg to OldArg after Call.
-      CastCopyNewPtrToOldPtr(NewArg, OldArg, *m_pHLModule, HandleTy, RetBuilder,
-                             bRowMajor);
-    }
-  }
-}
-
 void SROA_Parameter_HLSL::replaceCastParameter(
     Value *NewParam, Value *OldParam, Function &F, Argument *Arg,
     const DxilParamInputQual inputQual, IRBuilder<> &Builder) {
@@ -4956,7 +4871,7 @@ Value *SROA_Parameter_HLSL::castResourceArgIfRequired(
 }
 
 Value *SROA_Parameter_HLSL::castArgumentIfRequired(
-    Value *V, Type *Ty, bool bOut, bool hasShaderInputOutput,
+    Value *V, Type *Ty, bool bOut,
     DxilParamInputQual inputQual, DxilFieldAnnotation &annotation,
     std::deque<Value *> &WorkList, IRBuilder<> &Builder) {
   Module &M = *m_pHLModule->GetModule();
@@ -4965,8 +4880,7 @@ Value *SROA_Parameter_HLSL::castArgumentIfRequired(
     Value *Ptr = Builder.CreateAlloca(Ty);
     V->replaceAllUsesWith(Ptr);
     // Create load here to make correct type.
-    // The Ptr will be store with correct value in replaceCastParameter and
-    // replaceCastArgument.
+    // The Ptr will be store with correct value in replaceCastParameter.
     if (Ptr->hasOneUse()) {
       // Load after existing user for call arg replace.
       // If not, call arg will load undef.
@@ -4983,153 +4897,61 @@ Value *SROA_Parameter_HLSL::castArgumentIfRequired(
 
   V = castResourceArgIfRequired(V, Ty, bOut, inputQual, Builder);
 
-  if (!hasShaderInputOutput) {
-    if (Ty->isVectorTy()) {
-      Value *OldV = V;
-      Type *EltTy = Ty->getVectorElementType();
-      unsigned vecSize = Ty->getVectorNumElements();
-
-      // Split vector into scalars.
-      if (OldV->getType()->isPointerTy()) {
-        // Split into scalar ptr.
-        V = Builder.CreateAlloca(EltTy);
-        vectorEltsMap[V].emplace_back(V);
-        for (unsigned i = 1; i < vecSize; i++) {
-          Value *Elt = Builder.CreateAlloca(EltTy);
-          vectorEltsMap[V].emplace_back(Elt);
-        }
-      } else {
-        IRBuilder<> TmpBuilder(Builder.GetInsertPoint());
-        // Make sure extract element after OldV.
-        if (Instruction *OldI = dyn_cast<Instruction>(OldV)) {
-          TmpBuilder.SetInsertPoint(OldI->getNextNode());
-        }
-        // Split into scalar.
-        V = TmpBuilder.CreateExtractElement(OldV, (uint64_t)0);
-        vectorEltsMap[V].emplace_back(V);
-        for (unsigned i = 1; i < vecSize; i++) {
-          Value *Elt = TmpBuilder.CreateExtractElement(OldV, i);
-          vectorEltsMap[V].emplace_back(Elt);
-        }
-      }
-      // Add to work list by reverse order.
-      for (unsigned i = vecSize - 1; i > 0; i--) {
-        Value *Elt = vectorEltsMap[V][i];
-        WorkList.push_front(Elt);
-      }
-      // For case OldV is from input vector ptr.
-      if (castParamMap.count(OldV)) {
-        OldV = castParamMap[OldV].first;
-      }
-      castParamMap[V] = std::make_pair(OldV, inputQual);
-    } else if (HLMatrixLower::IsMatrixType(Ty)) {
-      unsigned col, row;
-      Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
-      Value *Mat = V;
-      // Cast matrix to array.
-      Type *AT = ArrayType::get(EltTy, col * row);
-      V = Builder.CreateAlloca(AT);
-      castParamMap[V] = std::make_pair(Mat, inputQual);
-
-      DXASSERT(annotation.HasMatrixAnnotation(), "need matrix annotation here");
-      if (annotation.GetMatrixAnnotation().Orientation ==
-          hlsl::MatrixOrientation::RowMajor) {
-        castRowMajorParamMap.insert(V);
-      }
-    } else if (Ty->isArrayTy()) {
-      unsigned arraySize = 1;
-      Type *AT = Ty;
-      unsigned dim = 0;
-      while (AT->isArrayTy()) {
-        ++dim;
-        arraySize *= AT->getArrayNumElements();
-        AT = AT->getArrayElementType();
-      }
-
-      if (VectorType *VT = dyn_cast<VectorType>(AT)) {
-        Value *VecArray = V;
-        Type *AT = ArrayType::get(VT->getElementType(),
-                                  arraySize * VT->getNumElements());
-        V = Builder.CreateAlloca(AT);
-        castParamMap[V] = std::make_pair(VecArray, inputQual);
-      } else if (HLMatrixLower::IsMatrixType(AT)) {
-        unsigned col, row;
-        Type *EltTy = HLMatrixLower::GetMatrixInfo(AT, col, row);
-        Value *MatArray = V;
-        Type *AT = ArrayType::get(EltTy, arraySize * col * row);
-        V = Builder.CreateAlloca(AT);
-        castParamMap[V] = std::make_pair(MatArray, inputQual);
-        DXASSERT(annotation.HasMatrixAnnotation(),
-                 "need matrix annotation here");
-        if (annotation.GetMatrixAnnotation().Orientation ==
-            hlsl::MatrixOrientation::RowMajor) {
-          castRowMajorParamMap.insert(V);
-        }
-      } else if (dim > 1) {
-        // Flatten multi-dim array to 1dim.
-        Value *MultiArray = V;
-        V = Builder.CreateAlloca(
-            ArrayType::get(VT->getElementType(), arraySize));
-        castParamMap[V] = std::make_pair(MultiArray, inputQual);
-      }
-    }
-  } else {
-    // Entry function matrix value parameter has major.
-    // Make sure its user use row major matrix value.
-    bool updateToColMajor = annotation.HasMatrixAnnotation() &&
-                            annotation.GetMatrixAnnotation().Orientation ==
-                                MatrixOrientation::ColumnMajor;
-    if (updateToColMajor) {
-      if (V->getType()->isPointerTy()) {
-        for (User *user : V->users()) {
-          CallInst *CI = dyn_cast<CallInst>(user);
-          if (!CI)
-            continue;
+  // Entry function matrix value parameter has major.
+  // Make sure its user use row major matrix value.
+  bool updateToColMajor = annotation.HasMatrixAnnotation() &&
+                          annotation.GetMatrixAnnotation().Orientation ==
+                              MatrixOrientation::ColumnMajor;
+  if (updateToColMajor) {
+    if (V->getType()->isPointerTy()) {
+      for (User *user : V->users()) {
+        CallInst *CI = dyn_cast<CallInst>(user);
+        if (!CI)
+          continue;
 
-          HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-          if (group != HLOpcodeGroup::HLMatLoadStore)
-            continue;
-          HLMatLoadStoreOpcode opcode =
-              static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
-          Type *opcodeTy = Builder.getInt32Ty();
-          switch (opcode) {
-          case HLMatLoadStoreOpcode::RowMatLoad: {
-            // Update matrix function opcode to col major version.
-            Value *rowOpArg = ConstantInt::get(
-                opcodeTy,
-                static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad));
-            CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
-            // Cast it to row major.
-            CallInst *RowMat = HLModule::EmitHLOperationCall(
-                Builder, HLOpcodeGroup::HLCast,
-                (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {CI}, M);
-            CI->replaceAllUsesWith(RowMat);
-            // Set arg to CI again.
-            RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, CI);
-          } break;
-          case HLMatLoadStoreOpcode::RowMatStore:
-            // Update matrix function opcode to col major version.
-            Value *rowOpArg = ConstantInt::get(
-                opcodeTy,
-                static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore));
-            CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
-            Value *Mat = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-            // Cast it to col major.
-            CallInst *RowMat = HLModule::EmitHLOperationCall(
-                Builder, HLOpcodeGroup::HLCast,
-                (unsigned)HLCastOpcode::RowMatrixToColMatrix, Ty, {Mat}, M);
-            CI->setArgOperand(HLOperandIndex::kMatStoreValOpIdx, RowMat);
-            break;
-          }
+        HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+        if (group != HLOpcodeGroup::HLMatLoadStore)
+          continue;
+        HLMatLoadStoreOpcode opcode =
+            static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(CI));
+        Type *opcodeTy = Builder.getInt32Ty();
+        switch (opcode) {
+        case HLMatLoadStoreOpcode::RowMatLoad: {
+          // Update matrix function opcode to col major version.
+          Value *rowOpArg = ConstantInt::get(
+              opcodeTy,
+              static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad));
+          CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
+          // Cast it to row major.
+          CallInst *RowMat = HLModule::EmitHLOperationCall(
+              Builder, HLOpcodeGroup::HLCast,
+              (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {CI}, M);
+          CI->replaceAllUsesWith(RowMat);
+          // Set arg to CI again.
+          RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, CI);
+        } break;
+        case HLMatLoadStoreOpcode::RowMatStore:
+          // Update matrix function opcode to col major version.
+          Value *rowOpArg = ConstantInt::get(
+              opcodeTy,
+              static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore));
+          CI->setOperand(HLOperandIndex::kOpcodeIdx, rowOpArg);
+          Value *Mat = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+          // Cast it to col major.
+          CallInst *RowMat = HLModule::EmitHLOperationCall(
+              Builder, HLOpcodeGroup::HLCast,
+              (unsigned)HLCastOpcode::RowMatrixToColMatrix, Ty, {Mat}, M);
+          CI->setArgOperand(HLOperandIndex::kMatStoreValOpIdx, RowMat);
+          break;
         }
-      } else {
-        CallInst *RowMat = HLModule::EmitHLOperationCall(
-            Builder, HLOpcodeGroup::HLCast,
-            (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {V}, M);
-        V->replaceAllUsesWith(RowMat);
-        // Set arg to V again.
-        RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, V);
       }
+    } else {
+      CallInst *RowMat = HLModule::EmitHLOperationCall(
+          Builder, HLOpcodeGroup::HLCast,
+          (unsigned)HLCastOpcode::ColMatrixToRowMatrix, Ty, {V}, M);
+      V->replaceAllUsesWith(RowMat);
+      // Set arg to V again.
+      RowMat->setArgOperand(HLOperandIndex::kUnaryOpSrc0Idx, V);
     }
   }
   return V;
@@ -5140,8 +4962,7 @@ void SROA_Parameter_HLSL::flattenArgument(
     DxilParameterAnnotation &paramAnnotation,
     std::vector<Value *> &FlatParamList,
     std::vector<DxilParameterAnnotation> &FlatAnnotationList,
-    IRBuilder<> &Builder, DbgDeclareInst *DDI,
-    bool hasShaderInputOutput) {
+    IRBuilder<> &Builder, DbgDeclareInst *DDI) {
   std::deque<Value *> WorkList;
   WorkList.push_back(Arg);
 
@@ -5343,7 +5164,7 @@ void SROA_Parameter_HLSL::flattenArgument(
       }
 
       // Cast vector/matrix/resource parameter.
-      V = castArgumentIfRequired(V, Ty, bOut, hasShaderInputOutput, inputQual,
+      V = castArgumentIfRequired(V, Ty, bOut, inputQual,
                                  annotation, WorkList, Builder);
 
       // Cannot SROA, save it to final parameter list.
@@ -5781,6 +5602,11 @@ static void LegalizeDxilInputOutputs(Function *F,
 void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
 
+  DXASSERT(F == m_pHLModule->GetEntryFunction() ||
+           m_pHLModule->IsEntryThatUsesSignatures(F),
+    "otherwise, createFlattenedFunction called on library function "
+    "that should not be flattened.");
+
   // Skip void (void) function.
   if (F->getReturnType()->isVoidTy() && F->getArgumentList().empty()) {
     return;
@@ -5796,21 +5622,13 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
 
   LLVMContext &Ctx = m_pHLModule->GetCtx();
   std::unique_ptr<BasicBlock> TmpBlockForFuncDecl;
-  bool hasShaderInputOutput = false;
   if (F->isDeclaration()) {
     TmpBlockForFuncDecl.reset(BasicBlock::Create(Ctx));
     // Create return as terminator.
     IRBuilder<> RetBuilder(TmpBlockForFuncDecl.get());
     RetBuilder.CreateRetVoid();
-  } else {
-    hasShaderInputOutput = F == m_pHLModule->GetEntryFunction() ||
-                           m_pHLModule->IsEntryThatUsesSignatures(F);
   }
 
-  // Skip flattenning for library functions
-  if (!hasShaderInputOutput)
-    return;
-
   std::vector<Value *> FlatParamList;
   std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
   std::vector<int> FlatParamOriArgNoList;
@@ -5835,8 +5653,7 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
         funcAnnotation->GetParameterAnnotation(Arg.getArgNo());
     DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(&Arg);
     flattenArgument(F, &Arg, bForParamTrue, paramAnnotation, FlatParamList,
-                    FlatParamAnnotationList, Builder, DDI,
-                    hasShaderInputOutput);
+                    FlatParamAnnotationList, Builder, DDI);
 
     unsigned newFlatParamCount = FlatParamList.size() - prevFlatParamCount;
     for (unsigned i = 0; i < newFlatParamCount; i++) {
@@ -5845,97 +5662,95 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   }
 
   Type *retType = F->getReturnType();
-  if (hasShaderInputOutput) {
-    // Only flatten return parameter if this is a shader entry function using signatures
-    std::vector<Value *> FlatRetList;
-    std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
-    // Split and change to out parameter.
-    if (!retType->isVoidTy()) {
-      IRBuilder<> Builder(Ctx);
-      if (!F->isDeclaration()) {
-        Builder.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
-      } else {
-        Builder.SetInsertPoint(TmpBlockForFuncDecl->getFirstInsertionPt());
-      }
-      Value *retValAddr = Builder.CreateAlloca(retType);
-      DxilParameterAnnotation &retAnnotation =
-          funcAnnotation->GetRetTypeAnnotation();
-      Module &M = *m_pHLModule->GetModule();
-      Type *voidTy = Type::getVoidTy(m_pHLModule->GetCtx());
-      // Create DbgDecl for the ret value.
-      if (DISubprogram *funcDI = getDISubprogram(F)) {
-         DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
-         DITypeIdentifierMap EmptyMap;
-         DIType * RetDIType = RetDITyRef.resolve(EmptyMap);
-         DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
-         DILocalVariable *RetVar = DIB.createLocalVariable(llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI, F->getName().str() + ".Ret", funcDI->getFile(),
-             funcDI->getLine(), RetDIType);
-         DIExpression *Expr = nullptr;
-         // TODO: how to get col?
-         DILocation *DL = DILocation::get(F->getContext(), funcDI->getLine(), 0,  funcDI);
-         DIB.insertDeclare(retValAddr, RetVar, Expr, DL, Builder.GetInsertPoint());
-      }
-      for (BasicBlock &BB : F->getBasicBlockList()) {
-        if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
-          // Create store for return.
-          IRBuilder<> RetBuilder(RI);
-          if (!retAnnotation.HasMatrixAnnotation()) {
-            RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
-          } else {
-            bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
-                              MatrixOrientation::RowMajor;
-            Value *RetVal = RI->getReturnValue();
-            if (!isRowMajor) {
-              // Matrix value is row major. ColMatStore require col major.
-              // Cast before store.
-              RetVal = HLModule::EmitHLOperationCall(
-                  RetBuilder, HLOpcodeGroup::HLCast,
-                  static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
-                  RetVal->getType(), {RetVal}, M);
-            }
-            unsigned opcode = static_cast<unsigned>(
-                isRowMajor ? HLMatLoadStoreOpcode::RowMatStore
-                           : HLMatLoadStoreOpcode::ColMatStore);
-            HLModule::EmitHLOperationCall(RetBuilder,
-                                          HLOpcodeGroup::HLMatLoadStore, opcode,
-                                          voidTy, {retValAddr, RetVal}, M);
+
+  std::vector<Value *> FlatRetList;
+  std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
+  // Split and change to out parameter.
+  if (!retType->isVoidTy()) {
+    IRBuilder<> Builder(Ctx);
+    if (!F->isDeclaration()) {
+      Builder.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
+    } else {
+      Builder.SetInsertPoint(TmpBlockForFuncDecl->getFirstInsertionPt());
+    }
+    Value *retValAddr = Builder.CreateAlloca(retType);
+    DxilParameterAnnotation &retAnnotation =
+        funcAnnotation->GetRetTypeAnnotation();
+    Module &M = *m_pHLModule->GetModule();
+    Type *voidTy = Type::getVoidTy(m_pHLModule->GetCtx());
+    // Create DbgDecl for the ret value.
+    if (DISubprogram *funcDI = getDISubprogram(F)) {
+        DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
+        DITypeIdentifierMap EmptyMap;
+        DIType * RetDIType = RetDITyRef.resolve(EmptyMap);
+        DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
+        DILocalVariable *RetVar = DIB.createLocalVariable(llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI, F->getName().str() + ".Ret", funcDI->getFile(),
+            funcDI->getLine(), RetDIType);
+        DIExpression *Expr = nullptr;
+        // TODO: how to get col?
+        DILocation *DL = DILocation::get(F->getContext(), funcDI->getLine(), 0,  funcDI);
+        DIB.insertDeclare(retValAddr, RetVar, Expr, DL, Builder.GetInsertPoint());
+    }
+    for (BasicBlock &BB : F->getBasicBlockList()) {
+      if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) {
+        // Create store for return.
+        IRBuilder<> RetBuilder(RI);
+        if (!retAnnotation.HasMatrixAnnotation()) {
+          RetBuilder.CreateStore(RI->getReturnValue(), retValAddr);
+        } else {
+          bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
+                            MatrixOrientation::RowMajor;
+          Value *RetVal = RI->getReturnValue();
+          if (!isRowMajor) {
+            // Matrix value is row major. ColMatStore require col major.
+            // Cast before store.
+            RetVal = HLModule::EmitHLOperationCall(
+                RetBuilder, HLOpcodeGroup::HLCast,
+                static_cast<unsigned>(HLCastOpcode::RowMatrixToColMatrix),
+                RetVal->getType(), {RetVal}, M);
           }
+          unsigned opcode = static_cast<unsigned>(
+              isRowMajor ? HLMatLoadStoreOpcode::RowMatStore
+                          : HLMatLoadStoreOpcode::ColMatStore);
+          HLModule::EmitHLOperationCall(RetBuilder,
+                                        HLOpcodeGroup::HLMatLoadStore, opcode,
+                                        voidTy, {retValAddr, RetVal}, M);
         }
       }
-      // Create a fake store to keep retValAddr so it can be flattened.
-      if (retValAddr->user_empty()) {
-        Builder.CreateStore(UndefValue::get(retType), retValAddr);
-      }
+    }
+    // Create a fake store to keep retValAddr so it can be flattened.
+    if (retValAddr->user_empty()) {
+      Builder.CreateStore(UndefValue::get(retType), retValAddr);
+    }
 
-      DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(retValAddr);
-      flattenArgument(F, retValAddr, bForParamTrue,
-                      funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
-                      FlatRetAnnotationList, Builder, DDI,
-                      hasShaderInputOutput);
+    DbgDeclareInst *DDI = llvm::FindAllocaDbgDeclare(retValAddr);
+    flattenArgument(F, retValAddr, bForParamTrue,
+                    funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
+                    FlatRetAnnotationList, Builder, DDI);
 
-      const int kRetArgNo = -1;
-      for (unsigned i = 0; i < FlatRetList.size(); i++) {
-        FlatParamOriArgNoList.emplace_back(kRetArgNo);
-      }
+    const int kRetArgNo = -1;
+    for (unsigned i = 0; i < FlatRetList.size(); i++) {
+      FlatParamOriArgNoList.emplace_back(kRetArgNo);
     }
+  }
 
-    // Always change return type as parameter.
-    // By doing this, no need to check return when generate storeOutput.
-    if (FlatRetList.size() ||
-        // For empty struct return type.
-        !retType->isVoidTy()) {
-      // Return value is flattened.
-      // Change return value into out parameter.
-      retType = Type::getVoidTy(retType->getContext());
-      // Merge return data info param data.
-      FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(), FlatRetList.end());
+  // Always change return type as parameter.
+  // By doing this, no need to check return when generate storeOutput.
+  if (FlatRetList.size() ||
+      // For empty struct return type.
+      !retType->isVoidTy()) {
+    // Return value is flattened.
+    // Change return value into out parameter.
+    retType = Type::getVoidTy(retType->getContext());
+    // Merge return data info param data.
+    FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(), FlatRetList.end());
 
-      FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
-                                     FlatRetAnnotationList.begin(),
-                                     FlatRetAnnotationList.end());
-    }
+    FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
+                                    FlatRetAnnotationList.begin(),
+                                    FlatRetAnnotationList.end());
   }
 
+
   std::vector<Type *> FinalTypeList;
   for (Value * arg : FlatParamList) {
     FinalTypeList.emplace_back(arg->getType());
@@ -6106,216 +5921,6 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   }
 }
 
-void SROA_Parameter_HLSL::createFlattenedFunctionCall(Function *F, Function *flatF, CallInst *CI) {
-  DxilFunctionAnnotation *funcAnnotation = m_pHLModule->GetFunctionAnnotation(F);
-  DXASSERT(funcAnnotation, "must find annotation for function");
-
-  // Clear maps for cast.
-  castParamMap.clear();
-  vectorEltsMap.clear();
-
-  DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
-
-  std::vector<Value *> FlatParamList;
-  std::vector<DxilParameterAnnotation> FlatParamAnnotationList;
-
-  IRBuilder<> AllocaBuilder(
-      CI->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
-  IRBuilder<> CallBuilder(CI);
-  IRBuilder<> RetBuilder(CI->getNextNode());
-
-  const bool bForParamFalse = false;
-#if 0 // Disable return parameter movement to argument and flattening
-  Type *retType = F->getReturnType();
-  std::vector<Value *> FlatRetList;
-  std::vector<DxilParameterAnnotation> FlatRetAnnotationList;
-  // Split and change to out parameter.
-  if (!retType->isVoidTy()) {
-    Value *retValAddr = AllocaBuilder.CreateAlloca(retType);
-    // Create DbgDecl for the ret value.
-    if (DISubprogram *funcDI = getDISubprogram(F)) {
-       DITypeRef RetDITyRef = funcDI->getType()->getTypeArray()[0];
-       DITypeIdentifierMap EmptyMap;
-       DIType * RetDIType = RetDITyRef.resolve(EmptyMap);
-       DIBuilder DIB(*F->getParent(), /*AllowUnresolved*/ false);
-       DILocalVariable *RetVar = DIB.createLocalVariable(llvm::dwarf::Tag::DW_TAG_arg_variable, funcDI, F->getName().str() + ".Ret", funcDI->getFile(),
-           funcDI->getLine(), RetDIType);
-       DIExpression *Expr = nullptr;
-       // TODO: how to get col?
-       DILocation *DL = DILocation::get(F->getContext(), funcDI->getLine(), 0,  funcDI);
-       DIB.insertDeclare(retValAddr, RetVar, Expr, DL, CI);
-    }
-
-    DxilParameterAnnotation &retAnnotation = funcAnnotation->GetRetTypeAnnotation();
-    // Load ret value and replace CI.
-    Value *newRetVal = nullptr;
-    if (!retAnnotation.HasMatrixAnnotation()) {
-      newRetVal = RetBuilder.CreateLoad(retValAddr);
-    } else {
-      bool isRowMajor = retAnnotation.GetMatrixAnnotation().Orientation ==
-                        MatrixOrientation::RowMajor;
-      unsigned opcode =
-          static_cast<unsigned>(isRowMajor ? HLMatLoadStoreOpcode::RowMatLoad
-                                           : HLMatLoadStoreOpcode::ColMatLoad);
-      newRetVal = HLModule::EmitHLOperationCall(RetBuilder, HLOpcodeGroup::HLMatLoadStore,
-                                    opcode, retType, {retValAddr},
-                                    *m_pHLModule->GetModule());
-      if (!isRowMajor) {
-        // ColMatLoad will return a col major.
-        // Matrix value should be row major.
-        // Cast it here.
-        newRetVal = HLModule::EmitHLOperationCall(
-            RetBuilder, HLOpcodeGroup::HLCast,
-            static_cast<unsigned>(HLCastOpcode::ColMatrixToRowMatrix), retType,
-            {newRetVal}, *m_pHLModule->GetModule());
-      }
-    }
-    CI->replaceAllUsesWith(newRetVal);
-    // Flat ret val
-    flattenArgument(flatF, retValAddr, bForParamFalse,
-                    funcAnnotation->GetRetTypeAnnotation(), FlatRetList,
-                    FlatRetAnnotationList, AllocaBuilder,
-                    /*DbgDeclareInst*/ nullptr,
-                    /*hasShaderInputOutput*/false);
-  }
-#endif // Disable return parameter movement to argument and flattening
-
-  std::vector<Value *> args;
-  for (auto &arg : CI->arg_operands()) {
-    args.emplace_back(arg.get());
-  }
-  // Remove CI from user of args.
-  CI->dropAllReferences();
-
-  // Add all argument to worklist.
-  for (unsigned i=0;i<args.size();i++) {
-    DxilParameterAnnotation &paramAnnotation =
-        funcAnnotation->GetParameterAnnotation(i);
-    Value *arg = args[i];
-    Type *Ty = arg->getType();
-    if (Ty->isPointerTy()) {
-      // For pointer, alloca another pointer, replace in CI.
-      Value *tempArg =
-          AllocaBuilder.CreateAlloca(arg->getType()->getPointerElementType());
-
-      DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
-      // TODO: support special InputQual like InputPatch.
-      if (inputQual == DxilParamInputQual::In ||
-          inputQual == DxilParamInputQual::Inout) {
-        // Copy in param.
-        llvm::SmallVector<llvm::Value *, 16> idxList;
-        // split copy to avoid load of struct.
-        SplitCpy(Ty, tempArg, arg, idxList, CallBuilder, typeSys,
-                 &paramAnnotation);
-      }
-
-      if (inputQual == DxilParamInputQual::Out ||
-          inputQual == DxilParamInputQual::Inout) {
-        // Copy out param.
-        llvm::SmallVector<llvm::Value *, 16> idxList;
-        // split copy to avoid load of struct.
-        SplitCpy(Ty, arg, tempArg, idxList, RetBuilder, typeSys,
-                 &paramAnnotation);
-      }
-      arg = tempArg;
-      flattenArgument(flatF, arg, bForParamFalse, paramAnnotation,
-                      FlatParamList, FlatParamAnnotationList, AllocaBuilder,
-                      /*DbgDeclareInst*/ nullptr,
-                      /*hasShaderInputOutput*/false);
-    } else {
-      // Cast vector into array.
-      if (Ty->isVectorTy()) {
-        unsigned vecSize = Ty->getVectorNumElements();
-        for (unsigned vi = 0; vi < vecSize; vi++) {
-          Value *Elt = CallBuilder.CreateExtractElement(arg, vi);
-          // Cannot SROA, save it to final parameter list.
-          FlatParamList.emplace_back(Elt);
-          // Create ParamAnnotation for V.
-          FlatParamAnnotationList.emplace_back(DxilParameterAnnotation());
-          DxilParameterAnnotation &flatParamAnnotation =
-            FlatParamAnnotationList.back();
-          flatParamAnnotation = paramAnnotation;
-        }
-      } else if (HLMatrixLower::IsMatrixType(Ty)) {
-        unsigned col, row;
-        Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
-        Value *Mat = arg;
-        // Cast matrix to array.
-        Type *AT = ArrayType::get(EltTy, col * row);
-        arg = AllocaBuilder.CreateAlloca(AT);
-        DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual();
-        castParamMap[arg] = std::make_pair(Mat, inputQual);
-
-        DXASSERT(paramAnnotation.HasMatrixAnnotation(),
-                 "need matrix annotation here");
-        if (paramAnnotation.GetMatrixAnnotation().Orientation ==
-            hlsl::MatrixOrientation::RowMajor) {
-          castRowMajorParamMap.insert(arg);
-        }
-
-        // Cannot SROA, save it to final parameter list.
-        FlatParamList.emplace_back(arg);
-        // Create ParamAnnotation for V.
-        FlatParamAnnotationList.emplace_back(DxilParameterAnnotation());
-        DxilParameterAnnotation &flatParamAnnotation =
-          FlatParamAnnotationList.back();
-        flatParamAnnotation = paramAnnotation;
-      } else {
-        // Cannot SROA, save it to final parameter list.
-        FlatParamList.emplace_back(arg);
-        // Create ParamAnnotation for V.
-        FlatParamAnnotationList.emplace_back(DxilParameterAnnotation());
-        DxilParameterAnnotation &flatParamAnnotation =
-          FlatParamAnnotationList.back();
-        flatParamAnnotation = paramAnnotation;
-      }
-    }
-  }
-
-#if 0 // Disable return parameter movement to argument and flattening
-  // Always change return type as parameter.
-  // By doing this, no need to check return when generate storeOutput.
-  if (FlatRetList.size() ||
-      // For empty struct return type.
-      !retType->isVoidTy()) {
-    // Merge return data info param data.
-    FlatParamList.insert(FlatParamList.end(), FlatRetList.begin(), FlatRetList.end());
-
-    FlatParamAnnotationList.insert(FlatParamAnnotationList.end(),
-                                   FlatRetAnnotationList.begin(),
-                                   FlatRetAnnotationList.end());
-  }
-#endif // Disable return parameter movement to argument and flattening
-
-  RetBuilder.SetInsertPoint(CI->getNextNode());
-  unsigned paramSize = FlatParamList.size();
-  for (unsigned i = 0; i < paramSize; i++) {
-    Value *&flatArg = FlatParamList[i];
-    if (castParamMap.count(flatArg)) {
-      replaceCastArgument(flatArg, castParamMap[flatArg].first,
-                          castParamMap[flatArg].second, CallBuilder,
-                          RetBuilder);
-      if (vectorEltsMap.count(flatArg) && !flatArg->getType()->isPointerTy()) {
-        // Vector elements need to be updated.
-        SmallVector<Value *, 4> &elts = vectorEltsMap[flatArg];
-        // Back one step.
-        --i;
-        for (Value *elt : elts) {
-          FlatParamList[++i] = elt;
-        }
-        // Don't need elts anymore.
-        vectorEltsMap.erase(flatArg);
-      }
-    }
-  }
-
-  CallInst *NewCI = CallBuilder.CreateCall(flatF, FlatParamList);
-
-  CallBuilder.SetInsertPoint(NewCI);
-
-  CI->eraseFromParent();
-}
-
 void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
   // Update entry function.
   if (F == m_pHLModule->GetEntryFunction()) {
@@ -6332,12 +5937,7 @@ void SROA_Parameter_HLSL::replaceCall(Function *F, Function *flatF) {
       }
     }
   }
-  // TODO: flatten vector argument and lower resource argument when flatten
-  // functions.
-  for (auto it = F->user_begin(); it != F->user_end(); ) {
-    CallInst *CI = cast<CallInst>(*(it++));
-    createFlattenedFunctionCall(F, flatF, CI);
-  }
+  DXASSERT(F->user_empty(), "otherwise we flattened a library function.");
 }
 
 // Public interface to the SROA_Parameter_HLSL pass