瀏覽代碼

Merged PR 34: Fix matrix handling when array and other cases within function args.

Fix matrix handling when array and other cases within function args.
Tex Riddell 7 年之前
父節點
當前提交
2a9e758f06
共有 3 個文件被更改,包括 119 次插入146 次删除
  1. 0 141
      lib/HLSL/HLMatrixLowerPass.cpp
  2. 101 5
      lib/HLSL/HLOperationLower.cpp
  3. 18 0
      tools/clang/test/CodeGenHLSL/quick-test/lib_mat_array.hlsl

+ 0 - 141
lib/HLSL/HLMatrixLowerPass.cpp

@@ -807,37 +807,6 @@ static Instruction *BitCastValueOrPtr(Value* V, Instruction *Insert, Type *Ty, b
   }
 }
 
-static bool IsUserCall(Value *V) {
-  if (CallInst *CI = dyn_cast<CallInst>(V)) {
-    if (GetHLOpcodeGroupByName(CI->getCalledFunction()) == HLOpcodeGroup::NotHL) {
-      return true;
-    }
-  }
-  return false;
-}
-
-static bool UsedByUserCall(Value *V) {
-  for (auto U : V->users()) {
-    if (CallInst *CI = dyn_cast<CallInst>(U)) {
-      if (IsUserCall(U)) {
-        return true;
-      }
-    }
-    else if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
-      if (UsedByUserCall(U))
-        return true;
-    }
-    else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
-      if (IsUserCall(SI->getValueOperand()))
-        return true;
-    }
-    if (UsedByUserCall(U)) {
-      return true;
-    }
-  }
-  return false;
-}
-
 void HLMatrixLowerPass::lowerToVec(Instruction *matInst) {
   Value *vecVal = nullptr;
 
@@ -2498,112 +2467,6 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   }
 }
 
-void HLMatrixLowerPass::TranslateArgForLibFunc(CallInst *CI) {
-  IRBuilder<> Builder(CI);
-  HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-  switch (group) {
-  case HLOpcodeGroup::HLCast: {
-    HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
-    bool bTranspose = false;
-    bool bColSource = false;
-    switch (opcode) {
-    case HLCastOpcode::ColMatrixToRowMatrix:
-      bColSource = true;
-    case HLCastOpcode::RowMatrixToColMatrix:
-      bTranspose = true;
-    case HLCastOpcode::ColMatrixToVecCast:
-    case HLCastOpcode::RowMatrixToVecCast: {
-      Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
-      Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
-                                        /*bOrigAllocaTy*/false,
-                                        matVal->getName());
-      if (bTranspose) {
-        unsigned row, col;
-        HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
-        if (bColSource) std::swap(row,col);
-        vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-      }
-      CI->replaceAllUsesWith(vecVal);
-      CI->eraseFromParent();
-    } break;
-    }
-  } break;
-  case HLOpcodeGroup::HLMatLoadStore: {
-    HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
-    //bool bTranspose = false;
-    switch (opcode) {
-    case HLMatLoadStoreOpcode::ColMatStore:
-      //bTranspose = true;
-    case HLMatLoadStoreOpcode::RowMatStore: {
-      // shuffle if transposed, bitcast, and store
-      Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-      Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
-      //if (bTranspose) {
-      //  unsigned row, col;
-      //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
-      //  vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
-      //}
-      Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
-      Builder.CreateStore(vecVal, castPtr);
-      CI->eraseFromParent();
-    } break;
-    case HLMatLoadStoreOpcode::ColMatLoad:
-      //bTranspose = true;
-    case HLMatLoadStoreOpcode::RowMatLoad: {
-      // bitcast, load, and shuffle if transposed
-      Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
-      Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
-      Value *vecVal = Builder.CreateLoad(castPtr);
-      //if (bTranspose) {
-      //  unsigned row, col;
-      //  HLMatrixLower::GetMatrixInfo(matPtr->getType()->getPointerElementType(), col, row);
-      //  // row/col swapped for col major source
-      //  vecVal = CreateTransposeShuffle(Builder, vecVal, col, row);
-      //}
-      CI->replaceAllUsesWith(vecVal);
-      CI->eraseFromParent();
-    } break;
-    }
-  } break;
-  }
-}
-static CallInst *GetAsMatCastOrLdSt(Value* V) {
-  if (CallInst *CI = dyn_cast<CallInst>(V)) {
-    HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
-    switch (group) {
-    case HLOpcodeGroup::HLCast:
-    case HLOpcodeGroup::HLMatLoadStore:
-      return CI;
-      break;
-    }
-  }
-  return nullptr;
-}
-void HLMatrixLowerPass::TranslateArgsForLibFunc(Function &F) {
-  // Replace HLCast with BitCastValueOrPtr (+ transpose for colMatToVec)
-  // Replace HLMatLoadStore with bitcast + load/store + shuffle if col major
-  // NOTE: Transpose has been removed, as it should have explicit cast op
-  //       ColMatrixToRowMatrix or RowMatrixToColMatrix.
-  for (auto &arg : F.args()) {
-    for (auto itU = arg.user_begin(); itU != arg.user_end();) {
-      Value *U = *(itU++);
-      if (CallInst *CI = GetAsMatCastOrLdSt(U)) {
-        TranslateArgForLibFunc(CI);
-      } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
-        if (IsMatrixType(GEP->getResultElementType())) {
-          // arg is struct, GEP returns matrix ptr
-          // look for load/store/cast to translate in GEP users
-          for (auto itGEPU = GEP->user_begin(); itGEPU != GEP->user_end();) {
-            Value *gepU = *(itGEPU++);
-            if (CallInst *CI = GetAsMatCastOrLdSt(gepU))
-              TranslateArgForLibFunc(CI);
-          }
-        }
-      }
-    }
-  }
-}
-
 void HLMatrixLowerPass::runOnFunction(Function &F) {
   // Skip hl function definition (like createhandle)
   if (hlsl::GetHLOpcodeGroupByName(&F) != HLOpcodeGroup::NotHL)
@@ -2676,9 +2539,5 @@ void HLMatrixLowerPass::runOnFunction(Function &F) {
   matToVecMap.clear();
   vecToMatMap.clear();
 
-  // If this is a library function, now fix input/output matrix params
-  if (!m_pHLModule->IsEntryThatUsesSignatures(&F)) {
-    TranslateArgsForLibFunc(F);
-  }
   return;
 }

+ 101 - 5
lib/HLSL/HLOperationLower.cpp

@@ -6715,6 +6715,38 @@ void TranslateSubscriptOperation(Function *F, HLOperationLowerHelper &helper,  H
   }
 }
 
+// Create BitCast if ptr, otherwise, create alloca of new type, write to bitcast of alloca, and return load from alloca
+// If bOrigAllocaTy is true: create alloca of old type instead, write to alloca, and return load from bitcast of alloca
+static Instruction *BitCastValueOrPtr(Value* V, Instruction *Insert, Type *Ty, bool bOrigAllocaTy = false, const Twine &Name = "") {
+  if (Ty->isPointerTy()) {
+    // If pointer, we can bitcast directly
+    IRBuilder<> Builder(Insert);
+    return cast<Instruction>(Builder.CreateBitCast(V, Ty, Name));
+  }
+  else {
+    // If value, we have to alloca, store to bitcast ptr, and load
+    IRBuilder<> EntryBuilder(Insert->getParent()->getParent()->getEntryBlock().begin());
+    Type *allocaTy = bOrigAllocaTy ? V->getType() : Ty;
+    Type *otherTy = bOrigAllocaTy ? Ty : V->getType();
+    Instruction *allocaInst = EntryBuilder.CreateAlloca(allocaTy);
+    IRBuilder<> Builder(Insert);
+    Instruction *bitCast = cast<Instruction>(Builder.CreateBitCast(allocaInst, otherTy->getPointerTo()));
+    Builder.CreateStore(V, bOrigAllocaTy ? allocaInst : bitCast);
+    return Builder.CreateLoad(bOrigAllocaTy ? bitCast : allocaInst, Name);
+  }
+}
+
+static Instruction *CreateTransposeShuffle(IRBuilder<> &Builder, Value *vecVal, unsigned row, unsigned col) {
+  SmallVector<int, 16> castMask(col * row);
+  unsigned idx = 0;
+  for (unsigned c = 0; c < col; c++)
+    for (unsigned r = 0; r < row; r++)
+      castMask[idx++] = r * col + c;
+  return cast<Instruction>(
+    Builder.CreateShuffleVector(vecVal, vecVal, castMask));
+}
+
+
 void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
                                hlsl::HLOpcodeGroup group, HLObjectOperationLowerHelper *pObjHelper) {
   if (group == HLOpcodeGroup::HLIntrinsic) {
@@ -6744,14 +6776,78 @@ void TranslateHLBuiltinOperation(Function *F, HLOperationLowerHelper &helper,
       Type *PtrTy =
           F->getFunctionType()->getParamType(HLOperandIndex::kMatLoadPtrOpIdx);
 
-      if (PtrTy->getPointerAddressSpace() == DXIL::kTGSMAddrSpace ||
-          // TODO: use DeviceAddressSpace for SRV/UAV and CBufferAddressSpace
-          // for CBuffer.
-          PtrTy->getPointerAddressSpace() == DXIL::kDefaultAddrSpace) {
-        // Translate matrix into vector of array for share memory or local
+      if (PtrTy->getPointerAddressSpace() == DXIL::kTGSMAddrSpace) {
+        // Translate matrix into vector of array for shared memory
         // variable should be done in HLMatrixLowerPass.
         if (!F->user_empty())
           F->getContext().emitError("Fail to lower matrix load/store.");
+      } else if (PtrTy->getPointerAddressSpace() == DXIL::kDefaultAddrSpace) {
+        // Default address space may be function argument in lib target
+        if (!F->user_empty()) {
+          for (auto U = F->user_begin(); U != F->user_end();) {
+            Value *User = *(U++);
+            if (!isa<Instruction>(User))
+              continue;
+            // must be call inst
+            CallInst *CI = cast<CallInst>(User);
+            IRBuilder<> Builder(CI);
+            HLMatLoadStoreOpcode opcode = static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
+            switch (opcode) {
+            case HLMatLoadStoreOpcode::ColMatStore:
+            case HLMatLoadStoreOpcode::RowMatStore: {
+              Value *vecVal = CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
+              Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatStoreDstPtrOpIdx);
+              Value *castPtr = Builder.CreateBitCast(matPtr, vecVal->getType()->getPointerTo());
+              Builder.CreateStore(vecVal, castPtr);
+              CI->eraseFromParent();
+            } break;
+            case HLMatLoadStoreOpcode::ColMatLoad:
+            case HLMatLoadStoreOpcode::RowMatLoad: {
+              Value *matPtr = CI->getArgOperand(HLOperandIndex::kMatLoadPtrOpIdx);
+              Value *castPtr = Builder.CreateBitCast(matPtr, CI->getType()->getPointerTo());
+              Value *vecVal = Builder.CreateLoad(castPtr);
+              CI->replaceAllUsesWith(vecVal);
+              CI->eraseFromParent();
+            } break;
+            }
+          }
+        }
+      }
+    } else if (group == HLOpcodeGroup::HLCast) {
+      // HLCast may be used on matrix value function argument in lib target
+      if (!F->user_empty()) {
+        for (auto U = F->user_begin(); U != F->user_end();) {
+          Value *User = *(U++);
+          if (!isa<Instruction>(User))
+            continue;
+          // must be call inst
+          CallInst *CI = cast<CallInst>(User);
+          IRBuilder<> Builder(CI);
+          HLCastOpcode opcode = static_cast<HLCastOpcode>(hlsl::GetHLOpcode(CI));
+          bool bTranspose = false;
+          bool bColSource = false;
+          switch (opcode) {
+          case HLCastOpcode::ColMatrixToRowMatrix:
+            bColSource = true;
+          case HLCastOpcode::RowMatrixToColMatrix:
+            bTranspose = true;
+          case HLCastOpcode::ColMatrixToVecCast:
+          case HLCastOpcode::RowMatrixToVecCast: {
+            Value *matVal = CI->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
+            Value *vecVal = BitCastValueOrPtr(matVal, CI, CI->getType(),
+              /*bOrigAllocaTy*/false,
+              matVal->getName());
+            if (bTranspose) {
+              unsigned row, col;
+              HLMatrixLower::GetMatrixInfo(matVal->getType(), col, row);
+              if (bColSource) std::swap(row, col);
+              vecVal = CreateTransposeShuffle(Builder, vecVal, row, col);
+            }
+            CI->replaceAllUsesWith(vecVal);
+            CI->eraseFromParent();
+          } break;
+          }
+        }
       }
     } else if (group == HLOpcodeGroup::HLSubscript) {
       TranslateSubscriptOperation(F, helper, pObjHelper);

+ 18 - 0
tools/clang/test/CodeGenHLSL/quick-test/lib_mat_array.hlsl

@@ -0,0 +1,18 @@
+// RUN: %dxc -T lib_6_3 -Zpr %s | FileCheck %s
+
+// check that matrix lowering succeeds
+// CHECK-NOT: Fail to lower matrix load/store.
+// make sure no transpose is present
+// CHECK-NOT: shufflevector
+
+// Check that compile succeeds
+// CHECK: ret %class.matrix.float.3.4
+
+struct Foo {
+  float3x4 mat_array[2];
+  int i;
+};
+
+float3x4 lookup(Foo f, inout float3x4 mat) {
+  return f.mat_array[f.i];
+}