Ver código fonte

[spirv] Add support for floating point matrix types (#504)

Lei Zhang 8 anos atrás
pai
commit
7ecb030694

+ 1 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -201,6 +201,7 @@ public:
   uint32_t getUint32Type();
   uint32_t getFloat32Type();
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
+  uint32_t getMatType(uint32_t colType, uint32_t colCount);
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);
   uint32_t getStructType(llvm::ArrayRef<uint32_t> fieldTypes);
   uint32_t getFunctionType(uint32_t returnType,

+ 8 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -357,6 +357,14 @@ uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
   return typeId;
 }
 
+uint32_t ModuleBuilder::getMatType(uint32_t colType, uint32_t colCount) {
+  const Type *type = Type::getMatrix(theContext, colType, colCount);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+
+  return typeId;
+}
+
 uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
                                        spv::StorageClass storageClass) {
   const Type *type = Type::getPointer(theContext, storageClass, pointeeType);

+ 37 - 1
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -40,8 +40,9 @@ uint32_t TypeTranslator::translateType(QualType type) {
     return translateType(typedefType->desugar());
   }
 
-  // In AST, vector types are TypedefType of TemplateSpecializationType.
+  // In AST, vector/matrix types are TypedefType of TemplateSpecializationType.
   // We handle them via HLSL type inspection functions.
+
   if (hlsl::IsHLSLVecType(type)) {
     const auto elemType = hlsl::GetHLSLVecElementType(type);
     const auto elemCount = hlsl::GetHLSLVecSize(type);
@@ -53,6 +54,41 @@ uint32_t TypeTranslator::translateType(QualType type) {
     return theBuilder.getVecType(translateType(elemType), elemCount);
   }
 
+  if (hlsl::IsHLSLMatType(type)) {
+    const auto elemTy = hlsl::GetHLSLMatElementType(type);
+    // NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
+    // Validation Rules":
+    //   Matrix types can only be parameterized with floating-point types.
+    //
+    // So we need special handling of non-fp matrices, probably by emulating
+    // them using other types. But for now just disable them.
+    if (!elemTy->isFloatingType()) {
+      emitError("Non-floating-point matrices not supported yet");
+      return 0;
+    }
+    const auto elemType = translateType(elemTy);
+
+    uint32_t rowCount = 0, colCount = 0;
+    hlsl::GetHLSLMatRowColCount(type, rowCount, colCount);
+
+    // In SPIR-V, matrices must have two or more columns.
+    // Handle degenerated cases first.
+
+    if (rowCount == 1 && colCount == 1)
+      return elemType;
+
+    if (rowCount == 1)
+      return theBuilder.getVecType(elemType, colCount);
+
+    if (colCount == 1)
+      return theBuilder.getVecType(elemType, rowCount);
+
+    // HLSL matrices are row major, while SPIR-V matrices are column major.
+    // We are mapping what HLSL semantically mean a row into a column here.
+    const uint32_t vecType = theBuilder.getVecType(elemType, colCount);
+    return theBuilder.getMatType(vecType, rowCount);
+  }
+
   // Struct type
   if (const auto *structType = dyn_cast<RecordType>(typePtr)) {
     const auto *decl = structType->getDecl();

+ 99 - 0
tools/clang/test/CodeGenSPIRV/type.matrix.hlsl

@@ -0,0 +1,99 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// NOTE: According to Item "Data rules" of SPIR-V Spec 2.16.1 "Universal
+// Validation Rules":
+//   Matrix types can only be parameterized with floating-point types.
+//
+// So we need special handling for matrices with non-fp elements. An extension
+// to SPIR-V to lessen the above rule is a possible way, which will enable the
+// generation of SPIR-V currently commented out. Or we can emulate them using
+// other types.
+
+void main() {
+// XXXXX: %int = OpTypeInt 32 1
+// XXXXX: %uint = OpTypeInt 32 0
+
+// CHECK: %float = OpTypeFloat 32
+    float1x1 mat11;
+// XXXXX: %v2int = OpTypeVector %int 2
+    //int1x2   mat12;
+// XXXXX: %v3uint = OpTypeVector %uint 3
+    //uint1x3  mat13;
+// XXXXX: %bool = OpTypeBool
+// XXXXX-NEXT: %v4bool = OpTypeVector %bool 4
+    //bool1x4  mat14;
+
+    //int2x1   mat21;
+// XXXXX: %v2uint = OpTypeVector %uint 2
+// XXXXX-NEXT: %mat2v2uint = OpTypeMatrix %v2uint 2
+    //uint2x2  mat22;
+// XXXXX: %v3bool = OpTypeVector %bool 3
+// XXXXX-NEXT: %mat2v3bool = OpTypeMatrix %v3bool 2
+    //bool2x3  mat23;
+// CHECK: %v4float = OpTypeVector %float 4
+// CHECK-NEXT: %mat2v4float = OpTypeMatrix %v4float 2
+    float2x4 mat24;
+
+    //uint3x1  mat31;
+// XXXXX: %v2bool = OpTypeVector %bool 2
+// XXXXX-NEXT: %mat3v2bool = OpTypeMatrix %v2bool 3
+    //bool3x2  mat32;
+// CHECK: %v3float = OpTypeVector %float 3
+// CHECK-NEXT: %mat3v3float = OpTypeMatrix %v3float 3
+    float3x3 mat33;
+// XXXXX: %v4int = OpTypeVector %int 4
+// XXXXX-NEXT: %mat3v4int = OpTypeMatrix %v4int 3
+    //int3x4   mat34;
+
+    //bool4x1  mat41;
+// CHECK: %v2float = OpTypeVector %float 2
+// CHECK-NEXT: %mat4v2float = OpTypeMatrix %v2float 4
+    float4x2 mat42;
+// XXXXX: %v3int = OpTypeVector %int 3
+// XXXXX-NEXT: %mat4v3int = OpTypeMatrix %v3int 4
+    //int4x3   mat43;
+// XXXXX: %v4uint = OpTypeVector %uint 4
+// XXXXX-NEXT: %mat4v4uint = OpTypeMatrix %v4uint 4
+    //uint4x4  mat44;
+
+// CHECK: %mat4v4float = OpTypeMatrix %v4float 4
+    matrix mat;
+
+    //matrix<int, 1, 1>   imat11;
+    //matrix<uint, 1, 3>  umat23;
+    matrix<float, 2, 1> fmat21;
+    matrix<float, 1, 2> fmat12;
+// XXXXX: %mat3v4bool = OpTypeMatrix %v4bool 3
+    //matrix<bool, 3, 4>  bmat34;
+
+// CHECK-LABEL: %bb_entry = OpLabel
+
+
+// CHECK-NEXT: %mat11 = OpVariable %_ptr_Function_float Function
+// XXXXX-NEXT: %mat12 = OpVariable %_ptr_Function_v2int Function
+// XXXXX-NEXT: %mat13 = OpVariable %_ptr_Function_v3uint Function
+// XXXXX-NEXT: %mat14 = OpVariable %_ptr_Function_v4bool Function
+
+// XXXXX-NEXT: %mat21 = OpVariable %_ptr_Function_v2int Function
+// XXXXX-NEXT: %mat22 = OpVariable %_ptr_Function_mat2v2uint Function
+// XXXXX-NEXT: %mat23 = OpVariable %_ptr_Function_mat2v3bool Function
+// CHECK-NEXT: %mat24 = OpVariable %_ptr_Function_mat2v4float Function
+
+// XXXXX-NEXT: %mat31 = OpVariable %_ptr_Function_v3uint Function
+// XXXXX-NEXT: %mat32 = OpVariable %_ptr_Function_mat3v2bool Function
+// CHECK-NEXT: %mat33 = OpVariable %_ptr_Function_mat3v3float Function
+// XXXXX-NEXT: %mat34 = OpVariable %_ptr_Function_mat3v4int Function
+
+// XXXXX-NEXT: %mat41 = OpVariable %_ptr_Function_v4bool Function
+// CHECK-NEXT: %mat42 = OpVariable %_ptr_Function_mat4v2float Function
+// XXXXX-NEXT: %mat43 = OpVariable %_ptr_Function_mat4v3int Function
+// XXXXX-NEXT: %mat44 = OpVariable %_ptr_Function_mat4v4uint Function
+
+// CHECK-NEXT: %mat = OpVariable %_ptr_Function_mat4v4float Function
+
+// XXXXX-NEXT: %imat11 = OpVariable %_ptr_Function_int Function
+// XXXXX-NEXT: %umat23 = OpVariable %_ptr_Function_v3uint Function
+// CHECK-NEXT: %fmat21 = OpVariable %_ptr_Function_v2float Function
+// CHECK-NEXT: %fmat12 = OpVariable %_ptr_Function_v2float Function
+// XXXXX-NEXT: %bmat34 = OpVariable %_ptr_Function_mat3v4bool Function
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -37,6 +37,7 @@ TEST_F(WholeFileTest, ConstantPixelShader) {
 // For types
 TEST_F(FileTest, ScalarTypes) { runFileTest("type.scalar.hlsl"); }
 TEST_F(FileTest, VectorTypes) { runFileTest("type.vector.hlsl"); }
+TEST_F(FileTest, MatrixTypes) { runFileTest("type.matrix.hlsl"); }
 
 // For constants
 TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }