Sfoglia il codice sorgente

[spirv] Support other primitive int/float types. (#628)

* [spirv] Support other primitive int/float types.

* Address comments.
Ehsan 8 anni fa
parent
commit
479f1cd9bf

+ 0 - 4
docs/SPIR-V.rst

@@ -104,8 +104,6 @@ compatibility. Direct3D 10 shader targets map all ``half`` data types to
 ``float`` data types." This may change in the future to map to 16-bit floating
 point numbers (possibly via a command-line option).
 
-Note: ``float`` and ``double`` not implemented yet
-
 Minimal precision scalar types
 ------------------------------
 
@@ -126,8 +124,6 @@ the corresponding 32-bit scalar types with the ``RelexedPrecision`` decoration:
 ``min16uint``  ``OpTypeInt 32 0`` ``RelexedPrecision``
 ============== ================== ====================
 
-Note: not implemented yet
-
 Vectors and matrices
 --------------------
 

+ 1 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -296,6 +296,7 @@ public:
   uint32_t getInt32Type();
   uint32_t getUint32Type();
   uint32_t getFloat32Type();
+  uint32_t getFloat64Type();
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
   uint32_t getMatType(uint32_t colType, uint32_t colCount);
   uint32_t getPointerType(uint32_t pointeeType, spv::StorageClass);

+ 18 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -553,6 +553,9 @@ void ModuleBuilder::decorate(uint32_t targetId, spv::Decoration decoration) {
   case spv::Decoration::Block:
     d = Decoration::getBlock(theContext);
     break;
+  case spv::Decoration::RelaxedPrecision:
+    d = Decoration::getRelaxedPrecision(theContext);
+    break;
   }
 
   assert(d && "unimplemented decoration");
@@ -577,6 +580,21 @@ IMPL_GET_PRIMITIVE_TYPE(Float32)
 
 #undef IMPL_GET_PRIMITIVE_TYPE
 
+#define IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(ty)                            \
+  \
+uint32_t ModuleBuilder::get##ty##Type() {                                      \
+    requireCapability(spv::Capability::ty);                                    \
+    const Type *type = Type::get##ty(theContext);                              \
+    const uint32_t typeId = theContext.getResultIdForType(type);               \
+    theModule.addType(type, typeId);                                           \
+    return typeId;                                                             \
+  \
+}
+
+IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY(Float64)
+
+#undef IMPL_GET_PRIMITIVE_TYPE_WITH_CAPABILITY
+
 uint32_t ModuleBuilder::getVecType(uint32_t elemType, uint32_t elemCount) {
   const Type *type = nullptr;
   switch (elemCount) {

+ 7 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -466,6 +466,8 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
 }
 
 void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
+  uint32_t varId = 0;
+
   // The contents in externally visible variables can be updated via the
   // pipeline. They should be handled differently from file and function scope
   // variables.
@@ -495,7 +497,6 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       constInit = llvm::Optional<uint32_t>(theBuilder.getConstantNull(varType));
     }
 
-    uint32_t varId;
     if (isFileScopeVar)
       varId = declIdMapper.createFileVar(varType, decl, constInit);
     else
@@ -521,7 +522,11 @@ void SPIRVEmitter::doVarDecl(const VarDecl *decl) {
       }
     }
   } else {
-    (void)declIdMapper.createExternVar(decl);
+    varId = declIdMapper.createExternVar(decl);
+  }
+
+  if (TypeTranslator::isRelaxedPrecisionType(decl->getType())) {
+    theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision);
   }
 }
 

+ 40 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -34,6 +34,33 @@ inline void roundToPow2(uint32_t *val, uint32_t pow2) {
 }
 } // anonymous namespace
 
+bool TypeTranslator::isRelaxedPrecisionType(QualType type) {
+  // Primitive types
+  {
+    QualType ty = {};
+    if (isScalarType(type, &ty))
+      if (const auto *builtinType = ty->getAs<BuiltinType>())
+        switch (builtinType->getKind()) {
+        case BuiltinType::Short:
+        case BuiltinType::UShort:
+        case BuiltinType::Min12Int:
+        case BuiltinType::Min10Float:
+        case BuiltinType::Half:
+          return true;
+        }
+  }
+
+  // Vector & Matrix types could use relaxed precision based on their element
+  // type.
+  {
+    QualType elemType = {};
+    if (isVectorType(type, &elemType) || isMxNMatrix(type, &elemType))
+      return isRelaxedPrecisionType(elemType);
+  }
+
+  return false;
+}
+
 uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
                                        bool isRowMajor) {
   // We can only apply row_major to matrices or arrays of matrices.
@@ -55,12 +82,25 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
           return theBuilder.getVoidType();
         case BuiltinType::Bool:
           return theBuilder.getBoolType();
+        // int, min16int (short), and min12int are all translated to 32-bit
+        // signed integers in SPIR-V.
         case BuiltinType::Int:
+        case BuiltinType::Short:
+        case BuiltinType::Min12Int:
           return theBuilder.getInt32Type();
+        // uint and min16uint (ushort) are both translated to 32-bit unsigned
+        // integers in SPIR-V.
+        case BuiltinType::UShort:
         case BuiltinType::UInt:
           return theBuilder.getUint32Type();
+        // float, min16float (half), and min10float are all translated to 32-bit
+        // float in SPIR-V.
         case BuiltinType::Float:
+        case BuiltinType::Half:
+        case BuiltinType::Min10Float:
           return theBuilder.getFloat32Type();
+        case BuiltinType::Double:
+          return theBuilder.getFloat64Type();
         default:
           emitError("Primitive type '%0' is not supported yet.")
               << builtinType->getTypeClassName();

+ 5 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -125,6 +125,11 @@ public:
   /// counts.
   static bool isSpirvAcceptableMatrixType(QualType type);
 
+  /// \brief Returns true if the given type can use relaxed precision
+  /// decoration. Integer and float types with lower than 32 bits can be
+  /// operated on with a relaxed precision.
+  static bool isRelaxedPrecisionType(QualType);
+
   /// \brief Returns the the element type for the given scalar/vector/matrix
   /// type. Returns empty QualType for other cases.
   QualType getElementType(QualType type);

+ 70 - 23
tools/clang/test/CodeGenSPIRV/type.scalar.hlsl

@@ -1,38 +1,85 @@
 // Run: %dxc -T ps_6_0 -E main
 
-// TODO
-// - 16bit & 64bit integers/floats (require additional capabilities)
+// CHECK: OpCapability Float64
+
+// CHECK: OpDecorate %m16i RelaxedPrecision
+// CHECK: OpDecorate %m12i RelaxedPrecision
+// CHECK: OpDecorate %m16u RelaxedPrecision
+// CHECK: OpDecorate %m16f RelaxedPrecision
+// CHECK: OpDecorate %m10f RelaxedPrecision
+// CHECK: OpDecorate %m16i1 RelaxedPrecision
+// CHECK: OpDecorate %m12i1 RelaxedPrecision
+// CHECK: OpDecorate %m16u1 RelaxedPrecision
+// CHECK: OpDecorate %m16f1 RelaxedPrecision
+// CHECK: OpDecorate %m10f1 RelaxedPrecision
 
 // CHECK-DAG: %void = OpTypeVoid
 // CHECK-DAG: %{{[0-9]+}} = OpTypeFunction %void
 void main() {
+
 // CHECK-DAG: %bool = OpTypeBool
 // CHECK-DAG: %_ptr_Function_bool = OpTypePointer Function %bool
-  bool a;
+  bool boolvar;
+
 // CHECK-DAG: %int = OpTypeInt 32 1
 // CHECK-DAG: %_ptr_Function_int = OpTypePointer Function %int
-  int b;
+  int      intvar;
+  min16int m16i;
+  min12int m12i;
+
 // CHECK-DAG: %uint = OpTypeInt 32 0
 // CHECK-DAG: %_ptr_Function_uint = OpTypePointer Function %uint
-  uint c;
-  dword d;
+  uint      uintvar;
+  dword     dwordvar;
+  min16uint m16u;
+
 // CHECK-DAG: %float = OpTypeFloat 32
 // CHECK-DAG: %_ptr_Function_float = OpTypePointer Function %float
-  float e;
-  bool1 a1;
-  int1 b1;
-  uint1 c1;
-  dword1 d1;
-  float1 e1;
-
-// CHECK: %a = OpVariable %_ptr_Function_bool Function
-// CHECK-NEXT: %b = OpVariable %_ptr_Function_int Function
-// CHECK-NEXT: %c = OpVariable %_ptr_Function_uint Function
-// CHECK-NEXT: %d = OpVariable %_ptr_Function_uint Function
-// CHECK-NEXT: %e = OpVariable %_ptr_Function_float Function
-// CHECK-NEXT: %a1 = OpVariable %_ptr_Function_bool Function
-// CHECK-NEXT: %b1 = OpVariable %_ptr_Function_int Function
-// CHECK-NEXT: %c1 = OpVariable %_ptr_Function_uint Function
-// CHECK-NEXT: %d1 = OpVariable %_ptr_Function_uint Function
-// CHECK-NEXT: %e1 = OpVariable %_ptr_Function_float Function
+  float      floatvar;
+  half       halfvar;
+  min16float m16f;
+  min10float m10f;
+
+// CHECK-DAG: %double = OpTypeFloat 64
+// CHECK-DAG: %_ptr_Function_double = OpTypePointer Function %double
+  double doublevar;
+
+// These following variables should use the types already defined above.
+  bool1       boolvar1;
+  int1        intvar1;
+  min16int1   m16i1;
+  min12int1   m12i1;
+  uint1       uintvar1;
+  dword1      dwordvar1;
+  min16uint1  m16u1;
+  float1      floatvar1;
+  half1       halfvar1;
+  min16float1 m16f1;
+  min10float1 m10f1;
+  double1     doublevar1;
+
+// CHECK:         %boolvar = OpVariable %_ptr_Function_bool Function
+// CHECK-NEXT:     %intvar = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT:       %m16i = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT:       %m12i = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT:    %uintvar = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT:   %dwordvar = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT:       %m16u = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT:   %floatvar = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:    %halfvar = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:       %m16f = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:       %m10f = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:  %doublevar = OpVariable %_ptr_Function_double Function
+// CHECK-NEXT:   %boolvar1 = OpVariable %_ptr_Function_bool Function
+// CHECK-NEXT:    %intvar1 = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT:      %m16i1 = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT:      %m12i1 = OpVariable %_ptr_Function_int Function
+// CHECK-NEXT:   %uintvar1 = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT:  %dwordvar1 = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT:      %m16u1 = OpVariable %_ptr_Function_uint Function
+// CHECK-NEXT:  %floatvar1 = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:   %halfvar1 = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:      %m16f1 = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT:      %m10f1 = OpVariable %_ptr_Function_float Function
+// CHECK-NEXT: %doublevar1 = OpVariable %_ptr_Function_double Function
 }