Переглянути джерело

Fix HLMatrixLower for library function UDT args

Tex Riddell 7 роки тому
батько
коміт
23301fc727
1 змінених файлів з 123 додано та 77 видалено
  1. 123 77
      lib/HLSL/HLMatrixLowerPass.cpp

+ 123 - 77
lib/HLSL/HLMatrixLowerPass.cpp

@@ -273,7 +273,8 @@ private:
   Value *GetMatrixForVec(Value *vecVal, Type *matTy);
   Value *GetMatrixForVec(Value *vecVal, Type *matTy);
 
 
   // Translate library function input/output to preserve function signatures
   // Translate library function input/output to preserve function signatures
-  void TranslateLibraryArgs(Function &F);
+  void TranslateArgForLibFunc(CallInst *CI);
+  void TranslateArgsForLibFunc(Function &F);
 
 
   // Replace matVal with vecVal on matUseInst.
   // Replace matVal with vecVal on matUseInst.
   void TrivialMatReplace(Value *matVal, Value *vecVal,
   void TrivialMatReplace(Value *matVal, Value *vecVal,
@@ -414,6 +415,23 @@ Instruction *HLMatrixLowerPass::MatCastToVec(CallInst *CI) {
   return MatIntrinsicToVec(CI);
   return MatIntrinsicToVec(CI);
 }
 }
 
 
+// Returns V as a GetElementPtrInst when V is a GEP that yields a matrix-typed element directly from an alloca of a non-matrix struct (UDT); returns nullptr otherwise.
+// Per the original note, such UDT allocas must be there for library function args, so matrix GEPs into them are recognized separately from plain matrix allocas.
+static GetElementPtrInst *GetIfMatrixGEPOfUDTAlloca(Value *V) {
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
+    if (IsMatrixType(GEP->getResultElementType())) { // the GEP result must point at a matrix
+      Value *ptr = GEP->getPointerOperand();
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(ptr)) { // only a GEP taken directly on an alloca qualifies (no intermediate GEP/cast)
+        Type *ATy = AI->getAllocatedType();
+        if (ATy->isStructTy() && !IsMatrixType(ATy)) { // allocated type is a struct that IsMatrixType does not classify as a matrix
+          return GEP;
+        }
+      }
+    }
+  }
+  return nullptr; // V is not a matrix GEP of a UDT alloca
+}
+
 Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
 Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
   IRBuilder<> Builder(CI);
   IRBuilder<> Builder(CI);
   unsigned opcode = GetHLOpcode(CI);
   unsigned opcode = GetHLOpcode(CI);
@@ -423,7 +441,7 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
   case HLMatLoadStoreOpcode::ColMatLoad:
   case HLMatLoadStoreOpcode::ColMatLoad:
   case HLMatLoadStoreOpcode::RowMatLoad: {
   case HLMatLoadStoreOpcode::RowMatLoad: {
     Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
     Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-    if (isa<AllocaInst>(matPtr)) {
+    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
       Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
       Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
       result = Builder.CreateLoad(vecPtr);
       result = Builder.CreateLoad(vecPtr);
     } else
     } else
@@ -432,7 +450,7 @@ Instruction *HLMatrixLowerPass::MatLdStToVec(CallInst *CI) {
   case HLMatLoadStoreOpcode::ColMatStore:
   case HLMatLoadStoreOpcode::ColMatStore:
   case HLMatLoadStoreOpcode::RowMatStore: {
   case HLMatLoadStoreOpcode::RowMatStore: {
     Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
     Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-    if (isa<AllocaInst>(matPtr)) {
+    if (isa<AllocaInst>(matPtr) || GetIfMatrixGEPOfUDTAlloca(matPtr)) {
       Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
       Value *vecPtr = matToVecMap[cast<Instruction>(matPtr)];
       Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
       Value *matVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
       Value *vecVal =
       Value *vecVal =
@@ -893,6 +911,16 @@ void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
     if (HLModule::HasPreciseAttributeWithMetadata(AI))
     if (HLModule::HasPreciseAttributeWithMetadata(AI))
       HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
       HLModule::MarkPreciseAttributeWithMetadata(cast<Instruction>(vecVal));
 
 
+  } else if (GetIfMatrixGEPOfUDTAlloca(matInst)) {
+    // If GEP from alloca of non-matrix UDT, bitcast
+    IRBuilder<> Builder(matInst->getNextNode());
+    vecVal = Builder.CreateBitCast(matInst,
+      HLMatrixLower::LowerMatrixType(
+        matInst->getType()->getPointerElementType() )->getPointerTo());
+    // matrix equivalent of this new vector will be the original, retained GEP
+    vecToMatMap[vecVal] = matInst;
+    // Add to matInstsToKeep so we don't delete this GEP
+    matInstsToKeep.push_back(matInst);
   } else {
   } else {
     DXASSERT(0, "invalid inst");
     DXASSERT(0, "invalid inst");
   }
   }
@@ -2152,7 +2180,7 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
       case HLOpcodeGroup::HLMatLoadStore: {
       case HLOpcodeGroup::HLMatLoadStore: {
         DXASSERT(matToVecMap.count(useCall), "must has vec version");
         DXASSERT(matToVecMap.count(useCall), "must has vec version");
         Value *vecUser = matToVecMap[useCall];
         Value *vecUser = matToVecMap[useCall];
-        if (AllocaInst *AI = dyn_cast<AllocaInst>(matVal)) {
+        if (isa<AllocaInst>(matVal) || GetIfMatrixGEPOfUDTAlloca(matVal)) {
           // Load Already translated in lowerToVec.
           // Load Already translated in lowerToVec.
           // Store val operand will be set by the val use.
           // Store val operand will be set by the val use.
           // Do nothing here.
           // Do nothing here.
@@ -2187,7 +2215,8 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
       }
       }
     } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
     } else if (BitCastInst *BCI = dyn_cast<BitCastInst>(useInst)) {
       // Just replace the src with vec version.
       // Just replace the src with vec version.
-      useInst->setOperand(0, vecVal);
+      if (useInst != vecVal)
+        useInst->setOperand(0, vecVal);
     } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
     } else if (ReturnInst *RI = dyn_cast<ReturnInst>(useInst)) {
       Value *newMatVal = GetMatrixForVec(vecVal, matTy);
       Value *newMatVal = GetMatrixForVec(vecVal, matTy);
       RI->setOperand(0, newMatVal);
       RI->setOperand(0, newMatVal);
@@ -2464,82 +2493,98 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   }
   }
 }
 }
 
 
-void HLMatrixLowerPass::TranslateLibraryArgs(Function &F) {
+void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) { // Rewrites one HL matrix cast or matrix load/store call into bitcast + load/store IR and erases the call; other calls are left untouched.
+  IRBuilder<> Builder(CI); // new instructions are inserted immediately before CI
+  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
+  switch (group) { // only HLCast and HLMatLoadStore are handled; every other group falls out of the switch unchanged
+  case HLOpcodeGroup::HLCast: {
+    HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
+    if (opcode == HLCastOpcode::RowMatrixToVecCast ||
+        opcode == HLCastOpcode::ColMatrixToVecCast) { // only matrix-to-vector casts are rewritten; other cast opcodes are kept as-is
+      Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
+      Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
+                                        /*bOrigAllocaTy*/false,
+                                        matVal->getName()); // presumably reinterprets matVal as CI's result type at CI's position - confirm against BitCastValueOrPtr
+      //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
+      //  unsigned row, col;
+      //  HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+      //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+      //}
+      CI->replaceAllUsesWith(vecVal); // users of the cast now consume the bitcast result
+      CI->eraseFromParent(); // CI must not be used by the caller after this point
+    }
+  } break;
+  case HLOpcodeGroup::HLMatLoadStore: {
+    HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
+    //bool bTranspose = false;
+    switch (opcode) {
+    case HLMatLoadStoreOpcode::ColMatStore: // intentionally shares the RowMatStore path; the column-major distinction is disabled below
+      //bTranspose = true;
+    case HLMatLoadStoreOpcode::RowMatStore: {
+      // Bitcast the destination pointer and store; the transpose shuffle is disabled (kept below as commented-out code).
+      Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx); // value operand of the store; presumably already vector-typed here - TODO confirm
+      Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
+      //if (bTranspose) {
+      //  unsigned row, col;
+      //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
+      //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+      //}
+      Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo()); // view the destination as a pointer to the stored value's type
+      Builder.CreateStore(vecVal, castPtr);
+      CI->eraseFromParent(); // no replaceAllUsesWith here: the store call's result is not forwarded to any replacement
+    } break;
+    case HLMatLoadStoreOpcode::ColMatLoad: // intentionally shares the RowMatLoad path; the column-major distinction is disabled below
+      //bTranspose = true;
+    case HLMatLoadStoreOpcode::RowMatLoad: {
+      // Bitcast the source pointer and load; the transpose shuffle is disabled (kept below as commented-out code).
+      Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
+      Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo()); // view the source as a pointer to the call's result type
+      Value *vecVal = Builder.CreateLoad(castPtr);
+      //if (bTranspose) {
+      //  unsigned row, col;
+      //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
+      //  // row/col swapped for col major source
+      //  vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
+      //}
+      CI->replaceAllUsesWith(vecVal); // users of the load call now consume the plain load
+      CI->eraseFromParent();
+    } break;
+    }
+  } break;
+  }
+}
+static CallInst *GetAsMatCastOrLdSt(Value* V) { // Returns V as a CallInst when its callee belongs to the HLCast or HLMatLoadStore opcode group; returns nullptr otherwise.
+  if (CallInst *CI = dyn_cast<CallInst>(V)) {
+    HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction()); // NOTE(review): getCalledFunction() is null for indirect calls - confirm GetHLOpcodeGroupByName tolerates that
+    switch (group) {
+    case HLOpcodeGroup::HLCast: // any HLCast call matches here; the specific cast opcode is not inspected
+    case HLOpcodeGroup::HLMatLoadStore:
+      return CI;
+      break; // unreachable after the return above
+    }
+  }
+  return nullptr; // not a call, or a call in some other opcode group
+}
+void HLMatrixLowerPass::TranslateArgsForLibFunc(Function &F) {
   // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
   // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
   // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
   // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
   // NOTE: Transpose has been removed, as it should have explicit cast op
   // NOTE: Transpose has been removed, as it should have explicit cast op
   //       ColMatrixToRowMatrix or RowMatrixToColMatrix.
   //       ColMatrixToRowMatrix or RowMatrixToColMatrix.
   for (auto &arg : F.args()) {
   for (auto &arg : F.args()) {
-    SmallVector<CallInst *, 4> Candidates;
-    for (User *U : arg.users()) {
-      if (CallInst *CI = dyn_cast<CallInst>(U)) {
-        HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-        switch (group) {
-        case HLOpcodeGroup::HLCast:
-        case HLOpcodeGroup::HLMatLoadStore:
-          Candidates.push_back(CI);
-          break;
-        }
-      }
-    }
-    for (CallInst *CI : Candidates) {
-      IRBuilder<> Builder(CI);
-      HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-      switch (group) {
-      case HLOpcodeGroup::HLCast: {
-        HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
-        if (opcode == HLCastOpcode::RowMatrixToVecCast ||
-            opcode == HLCastOpcode::ColMatrixToVecCast) {
-          Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
-          Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
-                                            /*bOrigAllocaTy*/false,
-                                            matVal->getName());
-          //if (opcode == HLCastOpcode::ColMatrixToVecCast) {
-          //  unsigned row, col;
-          //  HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
-          //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-          //}
-          CI->replaceAllUsesWith(vecVal);
-          CI->eraseFromParent();
-        }
-      } break;
-      case HLOpcodeGroup::HLMatLoadStore: {
-        HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
-        //bool bTranspose = false;
-        switch (opcode) {
-        case HLMatLoadStoreOpcode::ColMatStore:
-          //bTranspose = true;
-        case HLMatLoadStoreOpcode::RowMatStore: {
-          // shuffle if transposed, bitcast, and store
-          Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-          Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-          //if (bTranspose) {
-          //  unsigned row, col;
-          //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
-          //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-          //}
-          Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
-          Builder.CreateStore(vecVal, castPtr);
-          CI->eraseFromParent();
-        } break;
-        case HLMatLoadStoreOpcode::ColMatLoad:
-          //bTranspose = true;
-        case HLMatLoadStoreOpcode::RowMatLoad: {
-          // bitcast, load, and shuffle if transposed
-          Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-          Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
-          Value *vecVal = Builder.CreateLoad(castPtr);
-          //if (bTranspose) {
-          //  unsigned row, col;
-          //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
-          //  // row/col swapped for col major source
-          //  vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
-          //}
-          CI->replaceAllUsesWith(vecVal);
-          CI->eraseFromParent();
-        } break;
+    for (auto itU = arg.user_begin(); itU != arg.user_end();) {
+      Value *U = *(itU++);
+      if (CallInst *CI = GetAsMatCastOrLdSt(U)) {
+        TranslateArgForLibFunc(CI);
+      } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
+        if (IsMatrixType(GEP->getResultElementType())) {
+          // arg is struct, GEP returns matrix ptr
+          // look for load/store/cast to translate in GEP users
+          for (auto itGEPU = GEP->user_begin(); itGEPU != GEP->user_end();) {
+            Value *gepU = *(itGEPU++);
+            if (CallInst *CI = GetAsMatCastOrLdSt(gepU))
+              TranslateArgForLibFunc(CI);
+          }
         }
         }
-      } break;
       }
       }
     }
     }
   }
   }
@@ -2577,6 +2622,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
           // Lower it here to make sure it is ready before replace.
           // Lower it here to make sure it is ready before replace.
           lowerToVec(&I);
           lowerToVec(&I);
         }
         }
+      } else if (GetIfMatrixGEPOfUDTAlloca(&I)) {
+        lowerToVec(&I);
       }
       }
     }
     }
   }
   }
@@ -2616,9 +2663,8 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
   vecToMatMap.clear();
   vecToMatMap.clear();
 
 
   // If this is a library function, now fix input/output matrix params
   // If this is a library function, now fix input/output matrix params
-  // TODO: What about Patch Constant Shaders?
   if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
   if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
-    TranslateLibraryArgs(F);
+    TranslateArgsForLibFunc(F);
   }
   }
   return;
   return;
 }
 }