Browse Source

Fix row index for array of matrix input/output. (#660)

* Fix row index for array of matrix input/output.
Also remove sret when flattening functions.
Xiang Li 8 years ago
parent
commit
36b2a12695

+ 47 - 7
lib/HLSL/HLSignatureLower.cpp

@@ -700,7 +700,7 @@ struct InputOutputAccessInfo {
 void collectInputOutputAccessInfo(
     Value *GV, Constant *constZero,
     std::vector<InputOutputAccessInfo> &accessInfoList, bool hasVertexID,
-    bool bInput) {
+    bool bInput, bool bRowMajor) {
   auto User = GV->user_begin();
   auto UserE = GV->user_end();
   for (; User != UserE;) {
@@ -743,6 +743,8 @@ void collectInputOutputAccessInfo(
           DXASSERT_NOMSG((++GEPIt) == E);
         } else {
           // Array which may have vector indexing.
+          // The index of the highest dimension is saved in rowIdx;
+          //  the array size of the highest dimension does not affect the index.
           GEPIt++;
           IRBuilder<> Builder(GEP);
           Type *idxTy = rowIdx->getType();
@@ -764,6 +766,15 @@ void collectInputOutputAccessInfo(
               vectorIdx = GEPIt.getOperand();
             }
           }
+          if (HLMatrixLower::IsMatrixType(*GEPIt)) {
+            unsigned row, col;
+            HLMatrixLower::GetMatrixInfo(*GEPIt, col, row);
+            Constant *arraySize = ConstantInt::get(idxTy, col);
+            if (bRowMajor) {
+              arraySize = ConstantInt::get(idxTy, row);
+            }
+            rowIdx = Builder.CreateMul(rowIdx, arraySize);
+          }
         }
       } else
         rowIdx = constZero;
@@ -1016,6 +1027,9 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
   DxilSignature &Sig =
       bInput ? EntrySig.InputSignature : EntrySig.OutputSignature;
 
+  DxilTypeSystem &typeSys = HLM.GetTypeSystem();
+  DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
+
   Type *i1Ty = Type::getInt1Ty(constZero->getContext());
   Type *i32Ty = constZero->getType();
 
@@ -1073,10 +1087,18 @@ void HLSignatureLower::GenerateDxilInputsOutputs(bool bInput) {
     bool bIsPrecise = m_preciseSigSet.count(SE);
     if (bIsPrecise)
       HLModule::MarkPreciseAttributeOnPtrWithFunctionCall(GV, M);
-
+    bool bRowMajor = false;
+    if (Argument *Arg = dyn_cast<Argument>(GV)) {
+      if (pFuncAnnot) {
+        auto &paramAnnot = pFuncAnnot->GetParameterAnnotation(Arg->getArgNo());
+        if (paramAnnot.HasMatrixAnnotation())
+          bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
+                      MatrixOrientation::RowMajor;
+      }
+    }
     std::vector<InputOutputAccessInfo> accessInfoList;
     collectInputOutputAccessInfo(GV, constZero, accessInfoList,
-                                 bNeedVertexID && bIsArrayTy, bInput);
+                                 bNeedVertexID && bIsArrayTy, bInput, bRowMajor);
 
     for (InputOutputAccessInfo &info : accessInfoList) {
       GenerateInputOutputUserCall(info, undefVertexIdx, dxilFunc, OpArg, ID,
@@ -1172,6 +1194,8 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
   Module &M = *(HLM.GetModule());
   Constant *constZero = hlslOP->GetU32Const(0);
   DxilSignature &Sig = EntrySig.PatchConstantSignature;
+  DxilTypeSystem &typeSys = HLM.GetTypeSystem();
+  DxilFunctionAnnotation *pFuncAnnot = typeSys.GetFunctionAnnotation(Entry);
   auto InsertPt = Entry->getEntryBlock().getFirstInsertionPt();
   const bool bIsHs = props.IsHS();
   const bool bIsInput = !bIsHs;
@@ -1234,10 +1258,18 @@ void HLSignatureLower::GenerateDxilPatchConstantLdSt() {
                                   Builder);
       continue;
     }
-
+    bool bRowMajor = false;
+    if (Argument *Arg = dyn_cast<Argument>(GV)) {
+      if (pFuncAnnot) {
+        auto &paramAnnot = pFuncAnnot->GetParameterAnnotation(Arg->getArgNo());
+        if (paramAnnot.HasMatrixAnnotation())
+          bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
+                      MatrixOrientation::RowMajor;
+      }
+    }
     std::vector<InputOutputAccessInfo> accessInfoList;
     collectInputOutputAccessInfo(GV, constZero, accessInfoList, bNeedVertexID,
-                                 bIsInput);
+                                 bIsInput, bRowMajor);
 
     bool bIsArrayTy = GV->getType()->getPointerElementType()->isArrayTy();
     bool isPrecise = m_preciseSigSet.count(SE);
@@ -1291,10 +1323,18 @@ void HLSignatureLower::GenerateDxilPatchConstantFunctionInputs() {
                               ? OP::OpCode::LoadInput
                               : OP::OpCode::LoadOutputControlPoint;
       Function *dxilLdFunc = hlslOP->GetOpFunc(opcode, Ty);
-
+      bool bRowMajor = false;
+      if (Argument *Arg = dyn_cast<Argument>(&arg)) {
+        if (patchFuncAnnotation) {
+          auto &paramAnnot = patchFuncAnnotation->GetParameterAnnotation(Arg->getArgNo());
+          if (paramAnnot.HasMatrixAnnotation())
+            bRowMajor = paramAnnot.GetMatrixAnnotation().Orientation ==
+                      MatrixOrientation::RowMajor;
+        }
+      }
       std::vector<InputOutputAccessInfo> accessInfoList;
       collectInputOutputAccessInfo(&arg, constZero, accessInfoList,
-                                   /*hasVertexID*/ true, true);
+                                   /*hasVertexID*/ true, true, bRowMajor);
       for (InputOutputAccessInfo &info : accessInfoList) {
         if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
           Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);

+ 11 - 0
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -5896,6 +5896,17 @@ void SROA_Parameter_HLSL::createFlattenedFunction(Function *F) {
   }
 
   // Function Attr and Parameter Attr.
+  // Remove sret first.
+  if (F->hasStructRetAttr())
+    F->removeFnAttr(Attribute::StructRet);
+  for (Argument &arg : F->args()) {
+    if (arg.hasStructRetAttr()) {
+      Attribute::AttrKind SRet [] = {Attribute::StructRet};
+      AttributeSet SRetAS = AttributeSet::get(Ctx, arg.getArgNo() + 1, SRet);
+      arg.removeAttr(SRetAS);
+    }
+  }
+
   AttributeSet AS = F->getAttributes();
   AttrBuilder FnAttrs(AS.getFnAttributes(), AttributeSet::FunctionIndex);
   AttributeSet flatAS;

+ 38 - 0
tools/clang/test/CodeGenHLSL/MatArrayOutput.hlsl

@@ -0,0 +1,38 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Make sure every row is stored.
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 0, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 0, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 1, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 1, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 2, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 2, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 3, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 3, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 4, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 4, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 5, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 5, i8 1
+
+
+
+struct Vertex
+{
+    float2x2 pos     : POSITION0;
+    float3x2 t     : T;
+};
+
+struct Interpolants
+{
+    float4 pos     : SV_POSITION0;
+    row_major float3x2 p[2] : O;
+};
+
+Interpolants main( Vertex In )
+{
+    Interpolants o;
+    o.pos = (float4)In.pos;
+    o.p[0] = In.t*2;
+    o.p[1] = In.t*3;
+    return o;
+}

+ 41 - 0
tools/clang/test/CodeGenHLSL/MatArrayOutput2.hlsl

@@ -0,0 +1,41 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// Make sure every row is stored.
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 0, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 0, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 1, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 1, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 2, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 2, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 3, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 3, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 4, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 4, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 5, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 5, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 6, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 6, i8 1
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 7, i8 0
+// CHECK: dx.op.storeOutput.f32(i32 5, i32 1, i32 7, i8 1
+
+struct Vertex
+{
+    float2x2 pos     : POSITION0;
+};
+
+struct Interpolants
+{
+    float4 pos     : SV_POSITION0;
+    float2x2 p[2][2] : O;
+};
+
+Interpolants main( Vertex In )
+{
+    Interpolants o;
+    o.pos = (float4)In.pos;
+    o.p[0][0] = In.pos*2;
+    o.p[0][1] = In.pos*3;
+    o.p[1][0] = In.pos*4;
+    o.p[1][1] = In.pos*5;
+    return o;
+}

+ 10 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -597,6 +597,8 @@ public:
   TEST_METHOD(CodeGenMatParam)
   TEST_METHOD(CodeGenMatParam2)
  // TEST_METHOD(CodeGenMatParam3)
+  TEST_METHOD(CodeGenMatArrayOutput)
+  TEST_METHOD(CodeGenMatArrayOutput2)
   TEST_METHOD(CodeGenMatElt)
   TEST_METHOD(CodeGenMatInit)
   TEST_METHOD(CodeGenMatMulMat)
@@ -3530,6 +3532,14 @@ TEST_F(CompilerTest, CodeGenMatParam2) {
 //  CodeGenTestCheck(L"..\\CodeGenHLSL\\mat_param3.hlsl");
 //}
 
+TEST_F(CompilerTest, CodeGenMatArrayOutput) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\MatArrayOutput.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenMatArrayOutput2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\MatArrayOutput2.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenMatElt) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\matElt.hlsl");
 }