Explorar o código

[spirv] Add support for matrix swizzling (#514)

This PR handles two formats for indexing matrices: _mXX and _XX.
The operator[] format will be handled in the next PR.
Lei Zhang %!s(int64=8) %!d(string=hai) anos
pai
achega
9a2c5cc89d

+ 136 - 6
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -207,7 +207,6 @@ public:
         }
       }
     }
-    // TODO: enlarge the queue upon seeing a function call.
 
     // Translate all functions reachable from the entry function.
     // The queue can grow in the meanwhile; so need to keep evaluating
@@ -894,6 +893,10 @@ public:
       return doHLSLVectorElementExpr(vecElemExpr);
     }
 
+    if (const auto *matElemExpr = dyn_cast<ExtMatrixElementExpr>(expr)) {
+      return doExtMatrixElementExpr(matElemExpr);
+    }
+
     if (const auto *funcCall = dyn_cast<CallExpr>(expr)) {
       return doCallExpr(funcCall);
     }
@@ -1051,6 +1054,70 @@ public:
     return rhs;
   }
 
+  /// Tries to emit instructions for assigning to the given matrix element
+  /// accessing expression. Returns 0 if the trial fails and no instructions
+  /// are generated.
+  uint32_t tryToAssignToMatrixElements(const Expr *lhs, uint32_t rhs) {
+    const auto *lhsExpr = dyn_cast<ExtMatrixElementExpr>(lhs);
+    if (!lhsExpr)
+      return 0;
+
+    const Expr *baseMat = lhsExpr->getBase();
+    const uint32_t base = doExpr(baseMat);
+    const QualType elemType = hlsl::GetHLSLMatElementType(baseMat->getType());
+    const uint32_t elemTypeId = typeTranslator.translateType(elemType);
+
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(baseMat->getType(), rowCount, colCount);
+
+    // For each lhs element written to:
+    // 1. Extract the corresponding rhs element using OpCompositeExtract
+    // 2. Create access chain for the lhs element using OpAccessChain
+    // 3. Write using OpStore
+
+    const auto accessor = lhsExpr->getEncodedElementAccess();
+    for (uint32_t i = 0; i < accessor.Count; ++i) {
+      uint32_t row = 0, col = 0;
+      accessor.GetPosition(i, &row, &col);
+
+      llvm::SmallVector<uint32_t, 2> indices;
+      // If the matrix only has one row/column, we are indexing into a vector
+      // then. Only one index is needed for such cases.
+      if (rowCount > 1)
+        indices.push_back(row);
+      if (colCount > 1)
+        indices.push_back(col);
+
+      for (uint32_t i = 0; i < indices.size(); ++i)
+        indices[i] = theBuilder.getConstantInt32(indices[i]);
+
+      // If we are writing to only one element, the rhs should already be a
+      // scalar value.
+      uint32_t rhsElem = rhs;
+      if (accessor.Count > 1)
+        rhsElem = theBuilder.createCompositeExtract(elemTypeId, rhs, {i});
+
+      // TODO: select storage type based on the underlying variable
+      const uint32_t ptrType =
+          theBuilder.getPointerType(elemTypeId, spv::StorageClass::Function);
+
+      // If the lhs is actually a matrix of size 1x1, we don't need the access
+      // chain. base is already the dest pointer.
+      uint32_t lhsElemPtr = base;
+      if (!indices.empty()) {
+        // Load the element via access chain
+        lhsElemPtr = theBuilder.createAccessChain(ptrType, base, indices);
+      }
+
+      theBuilder.createStore(lhsElemPtr, rhsElem);
+    }
+
+    // TODO: OK, this return value is incorrect for compound assignments, for
+    // which cases we should return lvalues. Should at least emit errors if
+    // this return value is used (can be checked via ASTContext.getParents).
+    return rhs;
+  }
+
   /// Generates the necessary instructions for assigning rhs to lhs. If lhsPtr
   /// is not zero, it will be used as the pointer from lhs instead of evaluating
   /// lhs again.
@@ -1060,6 +1127,10 @@ public:
     if (const uint32_t result = tryToAssignToVectorElements(lhs, rhs)) {
       return result;
     }
+    // Assigning to matrix swizzling should be handled differently.
+    if (const uint32_t result = tryToAssignToMatrixElements(lhs, rhs)) {
+      return result;
+    }
 
     // Normal assignment procedure
     if (lhsPtr == 0)
@@ -1474,6 +1545,63 @@ public:
     return theBuilder.createVectorShuffle(type, baseVal, baseVal, selectors);
   }
 
+  uint32_t doExtMatrixElementExpr(const ExtMatrixElementExpr *expr) {
+    const Expr *baseExpr = expr->getBase();
+    const uint32_t base = doExpr(baseExpr);
+    const auto accessor = expr->getEncodedElementAccess();
+    const uint32_t elemType = typeTranslator.translateType(
+        hlsl::GetHLSLMatElementType(baseExpr->getType()));
+
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(baseExpr->getType(), rowCount, colCount);
+
+    // Construct a temporary vector out of all elements accessed:
+    // 1. Create access chain for each element using OpAccessChain
+    // 2. Load each element using OpLoad
+    // 3. Create the vector using OpCompositeConstruct
+
+    llvm::SmallVector<uint32_t, 4> elements;
+    for (uint32_t i = 0; i < accessor.Count; ++i) {
+      uint32_t row = 0, col = 0, elem = 0;
+      accessor.GetPosition(i, &row, &col);
+
+      llvm::SmallVector<uint32_t, 2> indices;
+      // If the matrix only have one row/column, we are indexing into a vector
+      // then. Only one index is needed for such cases.
+      if (rowCount > 1)
+        indices.push_back(row);
+      if (colCount > 1)
+        indices.push_back(col);
+
+      if (baseExpr->isGLValue()) {
+        for (uint32_t i = 0; i < indices.size(); ++i)
+          indices[i] = theBuilder.getConstantInt32(indices[i]);
+
+        // TODO: select storage type based on the underlying variable
+        const uint32_t ptrType =
+            theBuilder.getPointerType(elemType, spv::StorageClass::Function);
+        if (!indices.empty()) {
+          // Load the element via access chain
+          elem = theBuilder.createAccessChain(ptrType, base, indices);
+        } else {
+          // The matrix is of size 1x1. No need to use access chain, base should
+          // be the source pointer.
+          elem = base;
+        }
+        elem = theBuilder.createLoad(elemType, elem);
+      } else { // e.g., (mat1 + mat2)._m11
+        elem = theBuilder.createCompositeExtract(elemType, base, indices);
+      }
+      elements.push_back(elem);
+    }
+
+    if (elements.size() == 1)
+      return elements.front();
+
+    const uint32_t vecType = theBuilder.getVecType(elemType, elements.size());
+    return theBuilder.createCompositeConstruct(vecType, elements);
+  }
+
   /// Returns true if the given expression will be translated into a vector
   /// shuffle instruction in SPIR-V.
   ///
@@ -1522,11 +1650,13 @@ public:
     switch (expr->getCastKind()) {
     case CastKind::CK_LValueToRValue: {
       const uint32_t fromValue = doExpr(subExpr);
-      if (isVectorShuffle(subExpr)) {
-        // By reaching here, it means the vector element accessing operation is
-        // an lvalue. If we generated a vector shuffle for it and trying to use
-        // it as a rvalue, we cannot do the load here as normal. Need the upper
-        // nodes in the AST tree to handle it properly.
+      if (isVectorShuffle(subExpr) || isa<ExtMatrixElementExpr>(subExpr)) {
+        // By reaching here, it means the vector/matrix element accessing
+        // operation is an lvalue. For vector element accessing, if we generated
+        // a vector shuffle for it and trying to use it as a rvalue, we cannot
+        // do the load here as normal. Need the upper nodes in the AST tree to
+        // handle it properly. For matrix element accessing, load should have
+        // already happened after creating access chain for each element.
         return fromValue;
       }
 

+ 30 - 0
tools/clang/test/CodeGenSPIRV/op.matrix.access.1x1.hlsl

@@ -0,0 +1,30 @@
+// Run: %dxc -T vs_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    float1x1 mat;
+    float3 vec3;
+    float2 vec2;
+    float scalar;
+
+    // 1 element (from lvalue)
+// CHECK:      [[load0:%\d+]] = OpLoad %float %mat
+// CHECK-NEXT: OpStore %scalar [[load0]]
+    scalar = mat._m00; // Used as rvalue
+// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
+// CHECK-NEXT: OpStore %mat [[load1]]
+    mat._11 = scalar; // Used as lvalue
+
+    // >1 elements (from lvalue)
+// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float %mat
+// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float %mat
+// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v2float [[load2]] [[load3]]
+// CHECK-NEXT: OpStore %vec2 [[cc0]]
+    vec2 = mat._11_11; // Used as rvalue
+
+    // The following statements will trigger errors:
+    //   invalid format for vector swizzle
+    // scalar = (mat + mat)._m00;
+    // vec2 = (mat * mat)._11_11;
+}

+ 44 - 0
tools/clang/test/CodeGenSPIRV/op.matrix.access.1xn.hlsl

@@ -0,0 +1,44 @@
+// Run: %dxc -T vs_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    float1x3 mat;
+    float3 vec3;
+    float2 vec2;
+    float scalar;
+
+    // 1 element (from lvalue)
+// CHECK:      [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
+// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
+// CHECK-NEXT: OpStore %scalar [[load0]]
+    scalar = mat._m02; // Used as rvalue
+// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
+// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
+// CHECK-NEXT: OpStore [[access1]] [[load1]]
+    mat._12 = scalar; // Used as lvalue
+
+    // > 1 elements (from lvalue)
+// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
+// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
+// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
+// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
+// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
+// CHECK-NEXT: [[load4:%\d+]] = OpLoad %float [[access4]]
+// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3float [[load2]] [[load3]] [[load4]]
+// CHECK-NEXT: OpStore %vec3 [[cc0]]
+    vec3 = mat._11_13_12; // Used as rvalue
+// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v2float %vec2
+// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
+// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
+// CHECK-NEXT: OpStore [[access5]] [[ce0]]
+// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
+// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
+// CHECK-NEXT: OpStore [[access6]] [[ce1]]
+    mat._m00_m02 = vec2; // Used as lvalue
+
+    // The following statements will trigger errors:
+    //   invalid format for vector swizzle
+    // scalar = (mat + mat)._m02;
+    // vec2 = (mat * mat)._11_12;
+}

+ 61 - 0
tools/clang/test/CodeGenSPIRV/op.matrix.access.mx1.hlsl

@@ -0,0 +1,61 @@
+// Run: %dxc -T vs_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    float3x1 mat;
+    float3 vec3;
+    float2 vec2;
+    float scalar;
+
+    // 1 element (from lvalue)
+// CHECK:      [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
+// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
+// CHECK-NEXT: OpStore %scalar [[load0]]
+    scalar = mat._m20; // Used as rvalue
+// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
+// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
+// CHECK-NEXT: OpStore [[access1]] [[load1]]
+    mat._21 = scalar; // Used as lvalue
+
+    // > 1 elements (from lvalue)
+// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
+// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
+// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
+// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
+// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1
+// CHECK-NEXT: [[load4:%\d+]] = OpLoad %float [[access4]]
+// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v3float [[load2]] [[load3]] [[load4]]
+// CHECK-NEXT: OpStore %vec3 [[cc0]]
+    vec3 = mat._11_31_21; // Used as rvalue
+// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v2float %vec2
+// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
+// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0
+// CHECK-NEXT: OpStore [[access5]] [[ce0]]
+// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
+// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_2
+// CHECK-NEXT: OpStore [[access6]] [[ce1]]
+    mat._m00_m20 = vec2; // Used as lvalue
+
+    // 1 element (from rvalue)
+// CHECK-NEXT: [[load5:%\d+]] = OpLoad %v3float %mat
+// CHECK-NEXT: [[load6:%\d+]] = OpLoad %v3float %mat
+// CHECK-NEXT: [[add0:%\d+]] = OpFAdd %v3float [[load5]] [[load6]]
+// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[add0]] 2
+// CHECK-NEXT: OpStore %scalar [[ce2]]
+    // Codegen: construct a temporary vector first out of (mat + mat) and
+    // then extract the value
+    scalar = (mat + mat)._m20;
+
+    // > 1 element (from rvalue)
+// CHECK-NEXT: [[load7:%\d+]] = OpLoad %v3float %mat
+// CHECK-NEXT: [[load8:%\d+]] = OpLoad %v3float %mat
+// CHECK-NEXT: [[mul0:%\d+]] = OpFMul %v3float [[load7]] [[load8]]
+// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[mul0]] 0
+// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[mul0]] 1
+// CHECK-NEXT: [[cc1:%\d+]] = OpCompositeConstruct %v2float [[ce3]] [[ce4]]
+// CHECK-NEXT: OpStore %vec2 [[cc1]]
+    // Codegen: construct a temporary vector first out of (mat * mat) and
+    // then extract the value
+    vec2 = (mat * mat)._11_21;
+}

+ 58 - 0
tools/clang/test/CodeGenSPIRV/op.matrix.access.mxn.hlsl

@@ -0,0 +1,58 @@
+// Run: %dxc -T vs_6_0 -E main
+
+void main() {
+// CHECK-LABEL: %bb_entry = OpLabel
+
+    float2x3 mat;
+    float3 vec3;
+    float2 vec2;
+    float scalar;
+
+    // 1 element (from lvalue)
+// CHECK:      [[access0:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1 %int_2
+// CHECK-NEXT: [[load0:%\d+]] = OpLoad %float [[access0]]
+// CHECK-NEXT: OpStore %scalar [[load0]]
+    scalar = mat._m12; // Used as rvalue
+// CHECK-NEXT: [[load1:%\d+]] = OpLoad %float %scalar
+// CHECK-NEXT: [[access1:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
+// CHECK-NEXT: OpStore [[access1]] [[load1]]
+    mat._12 = scalar; // Used as lvalue
+
+    // >1 elements (from lvalue)
+// CHECK-NEXT: [[access2:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
+// CHECK-NEXT: [[load2:%\d+]] = OpLoad %float [[access2]]
+// CHECK-NEXT: [[access3:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_2
+// CHECK-NEXT: [[load3:%\d+]] = OpLoad %float [[access3]]
+// CHECK-NEXT: [[cc0:%\d+]] = OpCompositeConstruct %v2float [[load2]] [[load3]]
+// CHECK-NEXT: OpStore %vec2 [[cc0]]
+    vec2 = mat._m01_m02; // Used as rvalue
+// CHECK-NEXT: [[rhs0:%\d+]] = OpLoad %v3float %vec3
+// CHECK-NEXT: [[ce0:%\d+]] = OpCompositeExtract %float [[rhs0]] 0
+// CHECK-NEXT: [[access4:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_1 %int_0
+// CHECK-NEXT: OpStore [[access4]] [[ce0]]
+// CHECK-NEXT: [[ce1:%\d+]] = OpCompositeExtract %float [[rhs0]] 1
+// CHECK-NEXT: [[access5:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_1
+// CHECK-NEXT: OpStore [[access5]] [[ce1]]
+// CHECK-NEXT: [[ce2:%\d+]] = OpCompositeExtract %float [[rhs0]] 2
+// CHECK-NEXT: [[access6:%\d+]] = OpAccessChain %_ptr_Function_float %mat %int_0 %int_0
+// CHECK-NEXT: OpStore [[access6]] [[ce2]]
+    mat._21_12_11 = vec3; // Used as lvalue
+
+    // 1 element (from rvalue)
+// CHECK:      [[cc1:%\d+]] = OpCompositeConstruct %mat2v3float {{%\d+}} {{%\d+}}
+// CHECK-NEXT: [[ce3:%\d+]] = OpCompositeExtract %float [[cc1]] 1 2
+// CHECK-NEXT: OpStore %scalar [[ce3]]
+    // Codegen: construct a temporary matrix first out of (mat + mat) and
+    // then extract the value
+    scalar = (mat + mat)._m12;
+
+    // > 1 element (from rvalue)
+// CHECK:      [[cc2:%\d+]] = OpCompositeConstruct %mat2v3float {{%\d+}} {{%\d+}}
+// CHECK-NEXT: [[ce4:%\d+]] = OpCompositeExtract %float [[cc2]] 0 1
+// CHECK-NEXT: [[ce5:%\d+]] = OpCompositeExtract %float [[cc2]] 0 2
+// CHECK-NEXT: [[cc3:%\d+]] = OpCompositeConstruct %v2float [[ce4]] [[ce5]]
+// CHECK-NEXT: OpStore %vec2 [[cc3]]
+    // Codegen: construct a temporary matrix first out of (mat * mat) and
+    // then extract the value
+    vec2 = (mat * mat)._m01_m02;
+}

+ 14 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -153,6 +153,20 @@ TEST_F(FileTest, OpVectorSize1Swizzle) {
   runFileTest("op.vector.swizzle.size1.hlsl");
 }
 
+// For matrix accessing operators
+TEST_F(FileTest, OpMatrixAccessMxN) {
+  runFileTest("op.matrix.access.mxn.hlsl");
+}
+TEST_F(FileTest, OpMatrixAccessMx1) {
+  runFileTest("op.matrix.access.mx1.hlsl");
+}
+TEST_F(FileTest, OpMatrixAccess1xN) {
+  runFileTest("op.matrix.access.1xn.hlsl");
+}
+TEST_F(FileTest, OpMatrixAccess1x1) {
+  runFileTest("op.matrix.access.1x1.hlsl");
+}
+
 // For casting
 TEST_F(FileTest, CastNoOp) { runFileTest("cast.no-op.hlsl"); }
 TEST_F(FileTest, CastImplicit2Bool) { runFileTest("cast.2bool.implicit.hlsl"); }