Browse Source

Make scalar static global matrices col major. (#154)

Xiang Li 8 years ago
parent
commit
df5f52024b

+ 31 - 32
lib/HLSL/HLMatrixLowerPass.cpp

@@ -1416,7 +1416,7 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
     Value *matGlobal, ArrayRef<Value *> vecGlobals,
     Value *matGlobal, ArrayRef<Value *> vecGlobals,
     CallInst *matLdStInst) {
     CallInst *matLdStInst) {
   // No dynamic indexing on matrix, flatten matrix to scalars.
   // No dynamic indexing on matrix, flatten matrix to scalars.
-  // Internal global matrix use row major follow the initializer.
+  // vecGlobals already in col major.
   Type *matType = matGlobal->getType()->getPointerElementType();
   Type *matType = matGlobal->getType()->getPointerElementType();
   unsigned col, row;
   unsigned col, row;
   HLMatrixLower::GetMatrixInfo(matType, col, row);
   HLMatrixLower::GetMatrixInfo(matType, col, row);
@@ -1429,23 +1429,19 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
   case HLMatLoadStoreOpcode::ColMatLoad:
   case HLMatLoadStoreOpcode::ColMatLoad:
   case HLMatLoadStoreOpcode::RowMatLoad: {
   case HLMatLoadStoreOpcode::RowMatLoad: {
     Value *Result = UndefValue::get(vecType);
     Value *Result = UndefValue::get(vecType);
-    for (unsigned c = 0; c < col; c++)
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = c * row + r;
-        Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
-        Result = Builder.CreateInsertElement(Result, Elt, matIdx);
-      }
+    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
+      Value *Elt = Builder.CreateLoad(vecGlobals[matIdx]);
+      Result = Builder.CreateInsertElement(Result, Elt, matIdx);
+    }
     matLdStInst->replaceAllUsesWith(Result);
     matLdStInst->replaceAllUsesWith(Result);
   } break;
   } break;
   case HLMatLoadStoreOpcode::ColMatStore:
   case HLMatLoadStoreOpcode::ColMatStore:
   case HLMatLoadStoreOpcode::RowMatStore: {
   case HLMatLoadStoreOpcode::RowMatStore: {
     Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
     Value *Val = matLdStInst->getArgOperand(HLOperandIndex::kMatStoreValOpIdx);
-    for (unsigned c = 0; c < col; c++)
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = c * row + r;
-        Value *Elt = Builder.CreateExtractElement(Val, matIdx);
-        Builder.CreateStore(Elt, vecGlobals[matIdx]);
-      }
+    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
+      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
+      Builder.CreateStore(Elt, vecGlobals[matIdx]);
+    }
   } break;
   } break;
   }
   }
 }
 }
@@ -1453,6 +1449,8 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(
 void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
 void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
                                                       GlobalVariable *scalarArrayGlobal,
                                                       GlobalVariable *scalarArrayGlobal,
                                                       CallInst *matLdStInst) {
                                                       CallInst *matLdStInst) {
+  // scalarArrayGlobal is already in col major.
+  const bool bColMajor = true;
   HLMatLoadStoreOpcode opcode =
   HLMatLoadStoreOpcode opcode =
       static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
       static_cast<HLMatLoadStoreOpcode>(GetHLOpcode(matLdStInst));
   switch (opcode) {
   switch (opcode) {
@@ -1466,16 +1464,14 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
 
 
     std::vector<Value *> matElts(col * row);
     std::vector<Value *> matElts(col * row);
 
 
-    for (unsigned c = 0; c < col; c++)
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = c * row + r;
-        Value *GEP = Builder.CreateInBoundsGEP(
-            scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
-        matElts[matIdx] = Builder.CreateLoad(GEP);
-      }
+    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
+      Value *GEP = Builder.CreateInBoundsGEP(
+          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
+      matElts[matIdx] = Builder.CreateLoad(GEP);
+    }
 
 
     Value *newVec =
     Value *newVec =
-        HLMatrixLower::BuildMatrix(EltTy, col, row, false, matElts, Builder);
+        HLMatrixLower::BuildMatrix(EltTy, col, row, bColMajor, matElts, Builder);
     matLdStInst->replaceAllUsesWith(newVec);
     matLdStInst->replaceAllUsesWith(newVec);
     matLdStInst->eraseFromParent();
     matLdStInst->eraseFromParent();
   } break;
   } break;
@@ -1491,14 +1487,12 @@ void HLMatrixLowerPass::TranslateMatLoadStoreOnGlobal(GlobalVariable *matGlobal,
 
 
     std::vector<Value *> matElts(col * row);
     std::vector<Value *> matElts(col * row);
 
 
-    for (unsigned c = 0; c < col; c++)
-      for (unsigned r = 0; r < row; r++) {
-        unsigned matIdx = c * row + r;
-        Value *GEP = Builder.CreateInBoundsGEP(
-            scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
-        Value *Elt = Builder.CreateExtractElement(Val, matIdx);
-        Builder.CreateStore(Elt, GEP);
-      }
+    for (unsigned matIdx = 0; matIdx < col * row; matIdx++) {
+      Value *GEP = Builder.CreateInBoundsGEP(
+          scalarArrayGlobal, {zeroIdx, Builder.getInt32(matIdx)});
+      Value *Elt = Builder.CreateExtractElement(Val, matIdx);
+      Builder.CreateStore(Elt, GEP);
+    }
 
 
     matLdStInst->eraseFromParent();
     matLdStInst->eraseFromParent();
   } break;
   } break;
@@ -2229,6 +2223,7 @@ static Constant *LowerMatrixArrayConst(Constant *MA, ArrayType *ResultTy) {
 
 
 void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
 void HLMatrixLowerPass::runOnGlobalMatrixArray(GlobalVariable *GV) {
   // Lower to array of vector array like float[row][col].
   // Lower to array of vector array like float[row][col].
+  // It's row major.
   // DynamicIndexingVectorToArray will change it to scalar array.
   // DynamicIndexingVectorToArray will change it to scalar array.
   Type *Ty = GV->getType()->getPointerElementType();
   Type *Ty = GV->getType()->getPointerElementType();
   std::vector<unsigned> arraySizeList;
   std::vector<unsigned> arraySizeList;
@@ -2311,9 +2306,11 @@ static void FlattenMatConst(Constant *M, std::vector<Constant *> &Elts) {
       Elts.emplace_back(Elt);
       Elts.emplace_back(Elt);
   } else {
   } else {
     M = M->getAggregateElement((unsigned)0);
     M = M->getAggregateElement((unsigned)0);
-    for (unsigned r = 0; r < row; r++) {
-      Constant *R = M->getAggregateElement(r);
-      for (unsigned c = 0; c < col; c++) {
+    // Initializer is row major.
+    // Make it col major to match temp matrix.
+    for (unsigned c = 0; c < col; c++) {
+      for (unsigned r = 0; r < row; r++) {
+        Constant *R = M->getAggregateElement(r);
         Elts.emplace_back(R->getAggregateElement(c));
         Elts.emplace_back(R->getAggregateElement(c));
       }
       }
     }
     }
@@ -2339,6 +2336,8 @@ void HLMatrixLowerPass::runOnGlobal(GlobalVariable *GV) {
   const DataLayout &DL = M->getDataLayout();
   const DataLayout &DL = M->getDataLayout();
 
 
   std::vector<Constant *> Elts;
   std::vector<Constant *> Elts;
+  // Lower to vector or array for scalar matrix.
+  // Make it col major so no shuffle is needed on load/store.
   FlattenMatConst(GV->getInitializer(), Elts);
   FlattenMatConst(GV->getInitializer(), Elts);
 
 
   if (onlyLdSt) {
   if (onlyLdSt) {

+ 17 - 0
tools/clang/test/CodeGenHLSL/constMat.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: fmul
+// CHECK: 2.000000e+00
+// CHECK: fmul
+// CHECK: 3.000000e+00
+// CHECK: @dx.op.storeOutput.f32
+static const float3x3 g_mat1 = {
+    1, 2, 3,
+    4, 5, 6,
+    7, 8, 9,
+};
+
+float4 main(float a : A) : SV_Target {
+    float3 v = float3(a, 0, 0);
+    return float4(mul(v, g_mat1), 0);
+}

+ 17 - 0
tools/clang/test/CodeGenHLSL/constMat2.hlsl

@@ -0,0 +1,17 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: fmul
+// CHECK: 4.000000e+00
+// CHECK: fmul
+// CHECK: 7.000000e+00
+// CHECK: @dx.op.storeOutput.f32
+static const float3x3 g_mat1 = {
+    1, 2, 3,
+    4, 5, 6,
+    7, 8, 9,
+};
+
+float4 main(float a : A) : SV_Target {
+    float3 v = float3(a, 0, 0);
+    return float4(mul(g_mat1, v), 0);
+}

+ 23 - 0
tools/clang/test/CodeGenHLSL/constMat3.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: [9 x float] [float 1.000000e+00, float 4.000000e+00, float 7.000000e+00, float 2.000000e+00, float 5.000000e+00, float 8.000000e+00, float 3.000000e+00, float 6.000000e+00, float 9.000000e+00]
+// CHECK: fmul
+// CHECK: 2.000000e+00
+// CHECK: fmul
+// CHECK: 3.000000e+00
+// CHECK: add i32
+// CHECK: , 3
+// CHECK: add i32
+// CHECK: , 6
+
+static const float3x3 g_mat1 = {
+    1, 2, 3,
+    4, 5, 6,
+    7, 8, 9,
+};
+
+float4 main(float a : A) : SV_Target {
+    float3 v = float3(a, 0, 0);
+    float4 c = float4(mul(v, g_mat1), 0);
+    return c + float4(g_mat1[a],0);
+}

+ 23 - 0
tools/clang/test/CodeGenHLSL/constMat4.hlsl

@@ -0,0 +1,23 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// CHECK: [9 x float] [float 1.000000e+00, float 4.000000e+00, float 7.000000e+00, float 2.000000e+00, float 5.000000e+00, float 8.000000e+00, float 3.000000e+00, float 6.000000e+00, float 9.000000e+00]
+// CHECK: fmul
+// CHECK: 4.000000e+00
+// CHECK: fmul
+// CHECK: 7.000000e+00
+// CHECK: add i32
+// CHECK: , 3
+// CHECK: add i32
+// CHECK: , 6
+
+static const float3x3 g_mat1 = {
+    1, 2, 3,
+    4, 5, 6,
+    7, 8, 9,
+};
+
+float4 main(float a : A) : SV_Target {
+    float3 v = float3(a, 0, 0);
+    float4 c = float4(mul(g_mat1, v), 0);
+    return c + float4(g_mat1[a],0);
+}

+ 2 - 2
tools/clang/test/CodeGenHLSL/staticGlobals.hlsl

@@ -6,8 +6,8 @@
 // CHECK: [3 x float] [float 6.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 6.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 7.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 7.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 8.000000e+00, float 0.000000e+00, float 0.000000e+00]
 // CHECK: [3 x float] [float 8.000000e+00, float 0.000000e+00, float 0.000000e+00]
-// CHECK: [16 x float] [float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01, float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01, float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01, float 1.500000e+01, float 1.600000e+01, float 1.700000e+01, float 1.800000e+01]
-// CHECK: [16 x float] [float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 0.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 2.000000e+00, float 2.000000e+00, float 2.000000e+00, float 2.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00, float 3.000000e+00]
+// CHECK: [16 x float] [float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.500000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.700000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01, float 1.800000e+01]
+// CHECK: [16 x float] [float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 0.000000e+00, float 1.000000e+00, float 2.000000e+00, float 3.000000e+00]
 // CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
 // CHECK: [4 x float] [float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00]
 // CHECK: [16 x float] [float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01]
 // CHECK: [16 x float] [float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01, float 2.500000e+01, float 2.600000e+01, float 2.700000e+01, float 2.800000e+01]
 
 

+ 20 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -333,6 +333,10 @@ public:
   TEST_METHOD(CodeGenClip)
   TEST_METHOD(CodeGenClip)
   TEST_METHOD(CodeGenClipPlanes)
   TEST_METHOD(CodeGenClipPlanes)
   TEST_METHOD(CodeGenConstoperand1)
   TEST_METHOD(CodeGenConstoperand1)
+  TEST_METHOD(CodeGenConstMat)
+  TEST_METHOD(CodeGenConstMat2)
+  TEST_METHOD(CodeGenConstMat3)
+  TEST_METHOD(CodeGenConstMat4)
   TEST_METHOD(CodeGenDiscard)
   TEST_METHOD(CodeGenDiscard)
   TEST_METHOD(CodeGenDivZero)
   TEST_METHOD(CodeGenDivZero)
   TEST_METHOD(CodeGenDot1)
   TEST_METHOD(CodeGenDot1)
@@ -2053,6 +2057,22 @@ TEST_F(CompilerTest, CodeGenConstoperand1) {
   CodeGenTest(L"..\\CodeGenHLSL\\constoperand1.hlsl");
   CodeGenTest(L"..\\CodeGenHLSL\\constoperand1.hlsl");
 }
 }
 
 
+TEST_F(CompilerTest, CodeGenConstMat) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\constMat.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenConstMat2) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\constMat2.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenConstMat3) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\constMat3.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenConstMat4) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\constMat4.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenDiscard) {
 TEST_F(CompilerTest, CodeGenDiscard) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\discard.hlsl");
   CodeGenTestCheck(L"..\\CodeGenHLSL\\discard.hlsl");
 }
 }