Преглед на файлове

Support matrix transpose. (#900)

Xiang Li преди 7 години
родител
ревизия
455885b3fc

+ 20 - 6
lib/HLSL/HLMatrixLowerPass.cpp

@@ -251,7 +251,7 @@ private:
   void TranslateMatCast(CallInst *matInst, Instruction *vecInst,
                         CallInst *castInst);
   void TranslateMatMajorCast(CallInst *matInst, Instruction *vecInst,
-                        CallInst *castInst, bool rowToCol);
+                        CallInst *castInst, bool rowToCol, bool transpose);
   // Replace matInst with vecInst in matSubscript
   void TranslateMatSubscript(Value *matInst, Value *vecInst,
                              CallInst *matSubInst);
@@ -1073,7 +1073,8 @@ void HLMatrixLowerPass::TranslateMatTranspose(CallInst *matInst,
                                               Instruction *vecInst,
                                               CallInst *transposeInst) {
   // The matrix value is row major, so transposing it is a cast to col major.
-  TranslateMatMajorCast(matInst, vecInst, transposeInst, /*bRowToCol*/ true);
+  TranslateMatMajorCast(matInst, vecInst, transposeInst,
+      /*bRowToCol*/ true, /*bTranspose*/ true);
 }
 
 static Value *Determinant2x2(Value *m00, Value *m01, Value *m10, Value *m11,
@@ -1194,10 +1195,22 @@ void HLMatrixLowerPass::TrivialMatReplace(CallInst *matInst,
 void HLMatrixLowerPass::TranslateMatMajorCast(CallInst *matInst,
                                               Instruction *vecInst,
                                               CallInst *castInst,
-                                              bool bRowToCol) {
+                                              bool bRowToCol,
+                                              bool bTranspose) {
   unsigned col, row;
-  GetMatrixInfo(castInst->getType(), col, row);
-  DXASSERT(castInst->getType() == matInst->getType(), "type must match");
+  if (!bTranspose) {
+    GetMatrixInfo(castInst->getType(), col, row);
+    DXASSERT(castInst->getType() == matInst->getType(), "type must match");
+  } else {
+    unsigned castCol, castRow;
+    Type *castTy = GetMatrixInfo(castInst->getType(), castCol, castRow);
+    unsigned srcCol, srcRow;
+    Type *srcTy = GetMatrixInfo(matInst->getType(), srcCol, srcRow);
+    DXASSERT(srcTy == castTy, "type must match");
+    DXASSERT(castCol == srcRow && castRow == srcCol, "col row must match");
+    col = srcCol;
+    row = srcRow;
+  }
 
   IRBuilder<> Builder(castInst);
 
@@ -1321,7 +1334,8 @@ void HLMatrixLowerPass::TranslateMatCast(CallInst *matInst,
   if (opcode == HLCastOpcode::ColMatrixToRowMatrix ||
       opcode == HLCastOpcode::RowMatrixToColMatrix) {
     TranslateMatMajorCast(matInst, vecInst, castInst,
-                          opcode == HLCastOpcode::RowMatrixToColMatrix);
+                          opcode == HLCastOpcode::RowMatrixToColMatrix,
+                          /*bTranspose*/false);
   } else {
     bool ToMat = IsMatrixType(castInst->getType());
     bool FromMat = IsMatrixType(matInst->getType());

+ 2 - 1
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -3397,7 +3397,8 @@ static Value *CastLdValue(Value *Ptr, llvm::Type *FromTy, llvm::Type *ToTy, IRBu
       Value *V = Builder.CreateLoad(Ptr);
       // VectorTrunc
       // Change vector into vec1.
-      return Builder.CreateShuffleVector(V, V, {0});
+      int mask[] = {0};
+      return Builder.CreateShuffleVector(V, V, mask);
     } else if (FromTy->isArrayTy()) {
       llvm::Type *FromEltTy = FromTy->getArrayElementType();
 

+ 14 - 0
tools/clang/test/CodeGenHLSL/quick-test/mat_transpose.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure we get cb0[0].y and cb0[1].y.
+// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 1
+// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 1
+
+row_major float2x3 m;
+
+float2 main(int i : A) : SV_TARGET
+{
+  return transpose(m)[1];
+}

+ 13 - 0
tools/clang/test/CodeGenHLSL/quick-test/mat_transpose_2.hlsl

@@ -0,0 +1,13 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure we get cb0[1].xy.
+// CHECK: call %dx.types.CBufRet.f32 @dx.op.cbufferLoadLegacy.f32
+// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 0
+// CHECK: extractvalue %dx.types.CBufRet.f32 {{.*}}, 1
+
+float2x3 m;
+
+float2 main(int i : A) : SV_TARGET
+{
+  return transpose(m)[1];
+}