Kaynağa Gözat

Fix matrix support in GS OutputStream (#1769)

A couple things were missing here:
- SROA_Parameter did not properly handle DxilFieldAttributes when generating the flattened OutputStream.Append call, this would cause at least an assert to be hit.
- SROA_Parameter did not expect outputstream-qualified matrices and would replace all uses of such matrices with a local variable, with the result that we would never actually write to the outputstream value.
- HLMatrixLowerPass did not support lowering OutputStream.Append(matrix)

Also adds a test and includes a few cleanups.
Tristan Labelle 6 yıl önce
ebeveyn
işleme
9d0b011eab

+ 39 - 13
lib/HLSL/HLMatrixLowerPass.cpp

@@ -2183,20 +2183,46 @@ void HLMatrixLowerPass::replaceMatWithVec(Value *matVal,
           MatIntrinsicReplace(matCI, vecVal, useCall);
         } else {
           IntrinsicOp opcode = static_cast<IntrinsicOp>(GetHLOpcode(useCall));
-          DXASSERT_LOCALVAR(opcode, opcode == IntrinsicOp::IOP_frexp,
-                   "otherwise, unexpected opcode with matrix out parameter");
-          // NOTE: because out param use copy out semantic, so the operand of
-          // out must be temp alloca.
-          DXASSERT(isa<AllocaInst>(matVal), "else invalid mat ptr for frexp");
-          auto it = matToVecMap.find(useCall);
-          DXASSERT(it != matToVecMap.end(),
-                   "else fail to create vec version of useCall");
-          CallInst *vecUseInst = cast<CallInst>(it->second);
-
-          for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
-            if (useCall->getArgOperand(i) == matVal) {
-              vecUseInst->setArgOperand(i, vecVal);
+          if (opcode == IntrinsicOp::MOP_Append) {
+            // Replace matrix with vector representation and update intrinsic signature
+            // We don't care about matrix orientation here, since that will need to be
+            // taken into account anyways when generating the store output calls.
+            SmallVector<Value *, 4> flatArgs;
+            SmallVector<Type *, 4> flatParamTys;
+            for (Value *arg : useCall->arg_operands()) {
+              Value *flagArg = arg == matVal ? vecVal : arg;
+              flatArgs.emplace_back(arg == matVal ? vecVal : arg);
+              flatParamTys.emplace_back(flagArg->getType());
             }
+
+            // Don't need flat return type for Append.
+            FunctionType *flatFuncTy =
+              FunctionType::get(useInst->getType(), flatParamTys, false);
+            Function *flatF = GetOrCreateHLFunction(*m_pModule, flatFuncTy, group, static_cast<unsigned int>(opcode));
+            
+            // Append returns void, so the old call should have no users
+            DXASSERT(useInst->getType()->isVoidTy(), "Unexpected MOP_Append intrinsic return type");
+            DXASSERT(useInst->use_empty(), "Unexpected users of MOP_Append intrinsic return value");
+            IRBuilder<> Builder(useCall);
+            Builder.CreateCall(flatF, flatArgs);
+            AddToDeadInsts(useCall);
+          }
+          else if (opcode == IntrinsicOp::IOP_frexp) {
+            // NOTE: because out param use copy out semantic, so the operand of
+            // out must be temp alloca.
+            DXASSERT(isa<AllocaInst>(matVal), "else invalid mat ptr for frexp");
+            auto it = matToVecMap.find(useCall);
+            DXASSERT(it != matToVecMap.end(),
+              "else fail to create vec version of useCall");
+            CallInst *vecUseInst = cast<CallInst>(it->second);
+
+            for (unsigned i = 0; i < vecUseInst->getNumArgOperands(); i++) {
+              if (useCall->getArgOperand(i) == matVal) {
+                vecUseInst->setArgOperand(i, vecVal);
+              }
+            }
+          } else {
+            DXASSERT(false, "Unexpected matrix user intrinsic.");
           }
         }
       } break;

+ 80 - 75
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -2272,7 +2272,7 @@ static bool IsMemCpyTy(Type *Ty, DxilTypeSystem &typeSys) {
 static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
                      SmallVector<Value *, 16> &idxList, IRBuilder<> &Builder,
                      const DataLayout &DL, DxilTypeSystem &typeSys,
-                     DxilFieldAnnotation *fieldAnnotation, const bool bEltMemCpy = true) {
+                     const DxilFieldAnnotation *fieldAnnotation, const bool bEltMemCpy = true) {
   if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
     Constant *idx = Constant::getIntegerValue(
         IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
@@ -2293,31 +2293,31 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
                   MatrixOrientation::RowMajor;
     }
     Module *M = Builder.GetInsertPoint()->getModule();
-    Value *DestGEP = Builder.CreateInBoundsGEP(Dest, idxList);
-    Value *SrcGEP = Builder.CreateInBoundsGEP(Src, idxList);
-    if (bRowMajor) {
-      Value *Load = HLModule::EmitHLOperationCall(
-          Builder, HLOpcodeGroup::HLMatLoadStore,
-          static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad), Ty, {SrcGEP},
-          *M);
-
-      // Generate Matrix Store.
-      HLModule::EmitHLOperationCall(
-          Builder, HLOpcodeGroup::HLMatLoadStore,
-          static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore), Ty,
-          {DestGEP, Load}, *M);
-    } else {
-      Value *Load = HLModule::EmitHLOperationCall(
-          Builder, HLOpcodeGroup::HLMatLoadStore,
-          static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad), Ty, {SrcGEP},
-          *M);
 
-      // Generate Matrix Store.
-      HLModule::EmitHLOperationCall(
-          Builder, HLOpcodeGroup::HLMatLoadStore,
-          static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
-          {DestGEP, Load}, *M);
+    Value *DestMatPtr;
+    Value *SrcMatPtr;
+    if (idxList.size() == 1 && idxList[0] == ConstantInt::get(
+      IntegerType::get(Ty->getContext(), 32), APInt(32, 0))) {
+      // Avoid creating GEP(0)
+      DestMatPtr = Dest;
+      SrcMatPtr = Src;
+    }
+    else {
+      DestMatPtr = Builder.CreateInBoundsGEP(Dest, idxList);
+      SrcMatPtr = Builder.CreateInBoundsGEP(Src, idxList);
     }
+
+    HLMatLoadStoreOpcode loadOp = bRowMajor
+      ? HLMatLoadStoreOpcode::RowMatLoad : HLMatLoadStoreOpcode::ColMatLoad;
+    HLMatLoadStoreOpcode storeOp = bRowMajor
+      ? HLMatLoadStoreOpcode::RowMatStore : HLMatLoadStoreOpcode::ColMatStore;
+
+    Value *Load = HLModule::EmitHLOperationCall(
+      Builder, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(loadOp),
+      Ty, { SrcMatPtr }, *M);
+    HLModule::EmitHLOperationCall(
+      Builder, HLOpcodeGroup::HLMatLoadStore, static_cast<unsigned>(storeOp),
+      Ty, { DestMatPtr, Load }, *M);
   } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
     if (dxilutil::IsHLSLObjectType(ST)) {
       // Avoid split HLSL object.
@@ -2365,44 +2365,55 @@ static void SplitCpy(Type *Ty, Value *Dest, Value *Src,
   }
 }
 
-static void SplitPtr(Type *Ty, Value *Ptr, SmallVector<Value *, 16> &idxList,
-                     SmallVector<Value *, 16> &EltPtrList,
-                     IRBuilder<> &Builder) {
+// Given a pointer to a value, produces a list of pointers to
+// all scalar elements of that value and their field annotations, at any nesting level.
+static void SplitPtr(Value *Ptr, // The root value pointer
+  SmallVectorImpl<Value *> &IdxList, // GEP indices stack during recursion
+  Type *Ty, // Type at the current GEP indirection level
+  const DxilFieldAnnotation &Annotation, // Annotation at the current GEP indirection level
+  SmallVectorImpl<Value *> &EltPtrList, // Accumulates pointers to each element found
+  SmallVectorImpl<const DxilFieldAnnotation*> &EltAnnotationList, // Accumulates field annotations for each element found
+  DxilTypeSystem &TypeSys,
+  IRBuilder<> &Builder) {
+
   if (PointerType *PT = dyn_cast<PointerType>(Ty)) {
     Constant *idx = Constant::getIntegerValue(
         IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
-    idxList.emplace_back(idx);
+    IdxList.emplace_back(idx);
 
-    SplitPtr(PT->getElementType(), Ptr, idxList, EltPtrList, Builder);
+    SplitPtr(Ptr, IdxList, PT->getElementType(), Annotation,
+      EltPtrList, EltAnnotationList, TypeSys, Builder);
 
-    idxList.pop_back();
-  } else if (HLMatrixLower::IsMatrixType(Ty)) {
-    Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
-    EltPtrList.emplace_back(GEP);
-  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
-    if (dxilutil::IsHLSLObjectType(ST)) {
-      // Avoid split HLSL object.
-      Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
-      EltPtrList.emplace_back(GEP);
-      return;
-    }
-    for (uint32_t i = 0; i < ST->getNumElements(); i++) {
-      llvm::Type *ET = ST->getElementType(i);
+    IdxList.pop_back();
+    return;
+  }
+  
+  if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    if (!HLMatrixLower::IsMatrixType(Ty) && !dxilutil::IsHLSLObjectType(ST)) {
+      const DxilStructAnnotation* SA = TypeSys.GetStructAnnotation(ST);
 
-      Constant *idx = llvm::Constant::getIntegerValue(
+      for (uint32_t i = 0; i < ST->getNumElements(); i++) {
+        llvm::Type *EltTy = ST->getElementType(i);
+
+        Constant *idx = llvm::Constant::getIntegerValue(
           IntegerType::get(Ty->getContext(), 32), APInt(32, i));
-      idxList.emplace_back(idx);
+        IdxList.emplace_back(idx);
 
-      SplitPtr(ET, Ptr, idxList, EltPtrList, Builder);
+        SplitPtr(Ptr, IdxList, EltTy, SA->GetFieldAnnotation(i),
+          EltPtrList, EltAnnotationList, TypeSys, Builder);
 
-      idxList.pop_back();
+        IdxList.pop_back();
+      }
+      return;
     }
-
-  } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
-    if (AT->getNumContainedTypes() == 0) {
-      // Skip case like [0 x %struct].
+  }
+  
+  if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
+    if (AT->getArrayNumElements() == 0) {
+      // Skip cases like [0 x %struct], nothing to copy
       return;
     }
+
     Type *ElTy = AT->getElementType();
     SmallVector<ArrayType *, 4> nestArrayTys;
 
@@ -2414,19 +2425,16 @@ static void SplitPtr(Type *Ty, Value *Ptr, SmallVector<Value *, 16> &idxList,
       ElTy = ElAT->getElementType();
     }
 
-    if (!ElTy->isStructTy() ||
-        HLMatrixLower::IsMatrixType(ElTy)) {
-      // Not split array of basic type.
-      Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
-      EltPtrList.emplace_back(GEP);
-    }
-    else {
+    if (ElTy->isStructTy() && !HLMatrixLower::IsMatrixType(ElTy)) {
       DXASSERT(0, "Not support array of struct when split pointers.");
+      return;
     }
-  } else {
-    Value *GEP = Builder.CreateInBoundsGEP(Ptr, idxList);
-    EltPtrList.emplace_back(GEP);
   }
+
+  // Return a pointer to the current element and its annotation
+  Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
+  EltPtrList.emplace_back(GEP);
+  EltAnnotationList.emplace_back(&Annotation);
 }
 
 // Support case when bitcast (gep ptr, 0,0) is transformed into bitcast ptr.
@@ -5622,23 +5630,24 @@ void SROA_Parameter_HLSL::flattenArgument(
                 // Must be struct to be flatten.
                 IRBuilder<> Builder(CI);
 
-                llvm::SmallVector<llvm::Value *, 16> idxList;
+                llvm::SmallVector<llvm::Value *, 16> IdxList;
                 llvm::SmallVector<llvm::Value *, 16> EltPtrList;
+                llvm::SmallVector<const DxilFieldAnnotation*, 16> EltAnnotationList;
                 // split
-                SplitPtr(outputVal->getType(), outputVal, idxList, EltPtrList,
-                         Builder);
+                SplitPtr(outputVal, IdxList, outputVal->getType(), flatParamAnnotation,
+                  EltPtrList, EltAnnotationList, dxilTypeSys, Builder);
 
                 unsigned eltCount = CI->getNumArgOperands()-2;
                 DXASSERT_LOCALVAR(eltCount, eltCount == EltPtrList.size(), "invalid element count");
 
                 for (unsigned i = HLOperandIndex::kStreamAppendDataOpIndex; i < CI->getNumArgOperands(); i++) {
                   Value *DataPtr = CI->getArgOperand(i);
-                  Value *EltPtr =
-                      EltPtrList[i - HLOperandIndex::kStreamAppendDataOpIndex];
+                  Value *EltPtr = EltPtrList[i - HLOperandIndex::kStreamAppendDataOpIndex];
+                  const DxilFieldAnnotation *EltAnnotation = EltAnnotationList[i - HLOperandIndex::kStreamAppendDataOpIndex];
 
-                  llvm::SmallVector<llvm::Value *, 16> idxList;
-                  SplitCpy(DataPtr->getType(), EltPtr, DataPtr, idxList,
-                           Builder, DL, dxilTypeSys, &flatParamAnnotation);
+                  llvm::SmallVector<llvm::Value *, 16> IdxList;
+                  SplitCpy(DataPtr->getType(), EltPtr, DataPtr, IdxList,
+                           Builder, DL, dxilTypeSys, EltAnnotation);
                   CI->setArgOperand(i, EltPtr);
                 }
               }
@@ -5931,22 +5940,18 @@ static void LegalizeDxilInputOutputs(Function *F,
     bool bStore = false;
     CheckArgUsage(&arg, bLoad, bStore);
 
-    bool bNeedTemp = false;
     bool bStoreInputToTemp = false;
     bool bLoadOutputFromTemp = false;
 
     if (qual == DxilParamInputQual::In && bStore) {
-      bNeedTemp = true;
       bStoreInputToTemp = true;
     } else if (qual == DxilParamInputQual::Out && bLoad) {
-      bNeedTemp = true;
       bLoadOutputFromTemp = true;
     } else if (bLoad && bStore) {
       switch (qual) {
       case DxilParamInputQual::InputPrimitive:
       case DxilParamInputQual::InputPatch:
       case DxilParamInputQual::OutputPatch: {
-        bNeedTemp = true;
         bStoreInputToTemp = true;
       } break;
       case DxilParamInputQual::Inout:
@@ -5956,13 +5961,11 @@ static void LegalizeDxilInputOutputs(Function *F,
       }
     } else if (qual == DxilParamInputQual::Inout) {
       // Only replace inout when (bLoad && bStore) == false.
-      bNeedTemp = true;
       bLoadOutputFromTemp = true;
       bStoreInputToTemp = true;
     }
 
     if (HLMatrixLower::IsMatrixType(Ty)) {
-      bNeedTemp = true;
       if (qual == DxilParamInputQual::In)
         bStoreInputToTemp = bLoad;
       else if (qual == DxilParamInputQual::Out)
@@ -5973,7 +5976,7 @@ static void LegalizeDxilInputOutputs(Function *F,
       }
     }
 
-    if (bNeedTemp) {
+    if (bStoreInputToTemp || bLoadOutputFromTemp) {
       IRBuilder<> AllocaBuilder(EntryBlk.getFirstInsertionPt());
       IRBuilder<> Builder(dxilutil::FirstNonAllocaInsertionPt(&EntryBlk));
 
@@ -6344,6 +6347,8 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
       }
 
       flatArg->replaceAllUsesWith(Arg);
+      if (isa<Instruction>(flatArg))
+        DeadInsts.emplace_back(flatArg);
 
       HLModule::MergeGepUse(Arg);
       // Flatten store of array parameter.

+ 45 - 0
tools/clang/test/CodeGenHLSL/quick-test/streamout_matrix_all_orientations.hlsl

@@ -0,0 +1,45 @@
+// RUN: %dxc -E main -T gs_6_0 %s | FileCheck %s
+
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 0, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 0, i32 0, i8 1, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 0, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 1, i32 0, i8 1, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 0, i8 0, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 2, i32 1, i8 0, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 0, i8 0, i32 0)
+// CHECK: call float @dx.op.loadInput.f32(i32 4, i32 3, i32 1, i8 0, i32 0)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 1, i32 0, i8 0, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 1, i32 1, i8 0, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 2, i32 0, i8 0, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 2, i32 0, i8 1, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 3, i32 0, i8 0, float {{.*}})
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 3, i32 1, i8 0, float {{.*}})
+
+struct GSIn
+{
+    row_major float1x2 a : A;
+    row_major float1x2 b : B;
+    column_major float1x2 c : C;
+    column_major float1x2 d : D;
+};
+
+struct GSOut
+{
+    row_major float1x2 a : A;
+    column_major float1x2 b : B;
+    row_major float1x2 c : C;
+    column_major float1x2 d : D;
+};
+
+[maxvertexcount(1)]
+void main(point GSIn input[1], inout PointStream<GSOut> output)
+{
+    GSOut result;
+    result.a = input[0].a;
+    result.b = input[0].b;
+    result.c = input[0].c;
+    result.d = input[0].d;
+    output.Append(result);
+}