Przeglądaj źródła

fix handling of values around infinity for e4m3 constants

Jeff Bolz 1 miesiąc temu
rodzic
commit
d09cc26649

+ 19 - 4
SPIRV/hex_float.h

@@ -252,6 +252,7 @@ struct HexFloatTraits {
   // The bias of the exponent. (How much we need to subtract from the stored
   // value to get the correct value.)
   static const uint32_t exponent_bias = 0;
+  static bool supportsInfinity() { return true; }
 };
 
 // Traits for IEEE float.
@@ -266,6 +267,7 @@ struct HexFloatTraits<FloatProxy<float>> {
   static const uint_type num_exponent_bits = 8;
   static const uint_type num_fraction_bits = 23;
   static const uint_type exponent_bias = 127;
+  static bool supportsInfinity() { return true; }
 };
 
 // Traits for IEEE double.
@@ -280,6 +282,7 @@ struct HexFloatTraits<FloatProxy<double>> {
   static const uint_type num_exponent_bits = 11;
   static const uint_type num_fraction_bits = 52;
   static const uint_type exponent_bias = 1023;
+  static bool supportsInfinity() { return true; }
 };
 
 // Traits for IEEE half.
@@ -294,6 +297,7 @@ struct HexFloatTraits<FloatProxy<Float16>> {
   static const uint_type num_exponent_bits = 5;
   static const uint_type num_fraction_bits = 10;
   static const uint_type exponent_bias = 15;
+  static bool supportsInfinity() { return true; }
 };
 
 template <>
@@ -306,6 +310,7 @@ struct HexFloatTraits<FloatProxy<FloatE5M2>> {
   static const uint_type num_exponent_bits = 5;
   static const uint_type num_fraction_bits = 2;
   static const uint_type exponent_bias = 15;
+  static bool supportsInfinity() { return true; }
 };
 
 template <>
@@ -318,6 +323,7 @@ struct HexFloatTraits<FloatProxy<FloatE4M3>> {
   static const uint_type num_exponent_bits = 4;
   static const uint_type num_fraction_bits = 3;
   static const uint_type exponent_bias = 7;
+  static bool supportsInfinity() { return false; }
 };
 
 enum round_direction {
@@ -337,6 +343,7 @@ class HexFloat {
   typedef typename Traits::int_type int_type;
   typedef typename Traits::underlying_type underlying_type;
   typedef typename Traits::native_type native_type;
+  using Traits_T = Traits;
 
   explicit HexFloat(T f) : value_(f) {}
 
@@ -678,14 +685,22 @@ class HexFloat {
         (getBits() & exponent_mask) == exponent_mask && significand != 0;
     bool is_inf =
         !is_nan &&
-        ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) ||
+        (((exponent + carried) > static_cast<int_type>(other_T::exponent_bias) && other_T::Traits_T::supportsInfinity()) ||
+         ((exponent + carried) > static_cast<int_type>(other_T::exponent_bias + 1) && !other_T::Traits_T::supportsInfinity()) ||
          (significand == 0 && (getBits() & exponent_mask) == exponent_mask));
 
     // If we are Nan or Inf we should pass that through.
     if (is_inf) {
-      other.set_value(BitwiseCast<typename other_T::underlying_type>(
-          static_cast<typename other_T::uint_type>(
-              (negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
+      if (other_T::Traits_T::supportsInfinity()) {
+        // encode as +/-inf
+        other.set_value(BitwiseCast<typename other_T::underlying_type>(
+            static_cast<typename other_T::uint_type>(
+                (negate ? other_T::sign_mask : 0) | other_T::exponent_mask)));
+      } else {
+        // encode as +/-nan
+        other.set_value(BitwiseCast<typename other_T::underlying_type>(
+            static_cast<typename other_T::uint_type>(negate ? ~0 : ~other_T::sign_mask)));
+      }
       return;
     }
     if (is_nan) {

+ 110 - 0
Test/baseResults/spv.floate4m3.const.comp.out

@@ -0,0 +1,110 @@
+spv.floate4m3.const.comp
+// Module Version 10600
+// Generated by (magic number): 8000b
+// Id's are bound by 48
+
+                              Capability Shader
+                              Capability Int8
+                              Capability CapabilityFloat8EXT
+                              Extension  "SPV_EXT_float8"
+               1:             ExtInstImport  "GLSL.std.450"
+                              MemoryModel Logical GLSL450
+                              EntryPoint GLCompute 4  "main"
+                              ExecutionMode 4 LocalSize 1 1 1
+                              Source GLSL 450
+                              SourceExtension  "GL_EXT_bfloat16"
+                              SourceExtension  "GL_EXT_float_e4m3"
+                              SourceExtension  "GL_EXT_scalar_block_layout"
+                              SourceExtension  "GL_EXT_shader_explicit_arithmetic_types"
+                              SourceExtension  "GL_KHR_cooperative_matrix"
+                              SourceExtension  "GL_KHR_memory_scope_semantics"
+                              SourceExtension  "GL_NV_cooperative_matrix2"
+                              Name 4  "main"
+                              Name 8  "c01111_000"
+                              Name 10  "c01111_001"
+                              Name 12  "c01111_010"
+                              Name 14  "c01111_011"
+                              Name 16  "c01111_100"
+                              Name 18  "c01111_101"
+                              Name 20  "c01111_110"
+                              Name 22  "c01111_110_2"
+                              Name 23  "c01111_111"
+                              Name 25  "c01111_111_2"
+                              Name 26  "c11111_000"
+                              Name 28  "c11111_001"
+                              Name 30  "c11111_010"
+                              Name 32  "c11111_011"
+                              Name 34  "c11111_100"
+                              Name 36  "c11111_101"
+                              Name 38  "c11111_110"
+                              Name 40  "c11111_110_2"
+                              Name 41  "c11111_111"
+                              Name 43  "c11111_111_2"
+               2:             TypeVoid
+               3:             TypeFunction 2
+               6:             TypeFloat 8 4214
+               7:             TypePointer Function 6(floate4m3_t)
+               9:6(floate4m3_t) Constant 120
+              11:6(floate4m3_t) Constant 121
+              13:6(floate4m3_t) Constant 122
+              15:6(floate4m3_t) Constant 123
+              17:6(floate4m3_t) Constant 124
+              19:6(floate4m3_t) Constant 125
+              21:6(floate4m3_t) Constant 126
+              24:6(floate4m3_t) Constant 127
+              27:6(floate4m3_t) Constant 248
+              29:6(floate4m3_t) Constant 249
+              31:6(floate4m3_t) Constant 250
+              33:6(floate4m3_t) Constant 251
+              35:6(floate4m3_t) Constant 252
+              37:6(floate4m3_t) Constant 253
+              39:6(floate4m3_t) Constant 254
+              42:6(floate4m3_t) Constant 255
+              44:             TypeInt 32 0
+              45:             TypeVector 44(int) 3
+              46:     44(int) Constant 1
+              47:   45(ivec3) ConstantComposite 46 46 46
+         4(main):           2 Function None 3
+               5:             Label
+   8(c01111_000):      7(ptr) Variable Function
+  10(c01111_001):      7(ptr) Variable Function
+  12(c01111_010):      7(ptr) Variable Function
+  14(c01111_011):      7(ptr) Variable Function
+  16(c01111_100):      7(ptr) Variable Function
+  18(c01111_101):      7(ptr) Variable Function
+  20(c01111_110):      7(ptr) Variable Function
+22(c01111_110_2):      7(ptr) Variable Function
+  23(c01111_111):      7(ptr) Variable Function
+25(c01111_111_2):      7(ptr) Variable Function
+  26(c11111_000):      7(ptr) Variable Function
+  28(c11111_001):      7(ptr) Variable Function
+  30(c11111_010):      7(ptr) Variable Function
+  32(c11111_011):      7(ptr) Variable Function
+  34(c11111_100):      7(ptr) Variable Function
+  36(c11111_101):      7(ptr) Variable Function
+  38(c11111_110):      7(ptr) Variable Function
+40(c11111_110_2):      7(ptr) Variable Function
+  41(c11111_111):      7(ptr) Variable Function
+43(c11111_111_2):      7(ptr) Variable Function
+                              Store 8(c01111_000) 9
+                              Store 10(c01111_001) 11
+                              Store 12(c01111_010) 13
+                              Store 14(c01111_011) 15
+                              Store 16(c01111_100) 17
+                              Store 18(c01111_101) 19
+                              Store 20(c01111_110) 21
+                              Store 22(c01111_110_2) 21
+                              Store 23(c01111_111) 24
+                              Store 25(c01111_111_2) 24
+                              Store 26(c11111_000) 27
+                              Store 28(c11111_001) 29
+                              Store 30(c11111_010) 31
+                              Store 32(c11111_011) 33
+                              Store 34(c11111_100) 35
+                              Store 36(c11111_101) 37
+                              Store 38(c11111_110) 39
+                              Store 40(c11111_110_2) 39
+                              Store 41(c11111_111) 42
+                              Store 43(c11111_111_2) 42
+                              Return
+                              FunctionEnd

+ 80 - 0
Test/baseResults/spv.floate5m2.const.comp.out

@@ -0,0 +1,80 @@
+spv.floate5m2.const.comp
+// Module Version 10600
+// Generated by (magic number): 8000b
+// Id's are bound by 34
+
+                              Capability Shader
+                              Capability Int8
+                              Capability CapabilityFloat8EXT
+                              Extension  "SPV_EXT_float8"
+               1:             ExtInstImport  "GLSL.std.450"
+                              MemoryModel Logical GLSL450
+                              EntryPoint GLCompute 4  "main"
+                              ExecutionMode 4 LocalSize 1 1 1
+                              Source GLSL 450
+                              SourceExtension  "GL_EXT_bfloat16"
+                              SourceExtension  "GL_EXT_float_e5m2"
+                              SourceExtension  "GL_EXT_scalar_block_layout"
+                              SourceExtension  "GL_EXT_shader_explicit_arithmetic_types"
+                              SourceExtension  "GL_KHR_cooperative_matrix"
+                              SourceExtension  "GL_KHR_memory_scope_semantics"
+                              SourceExtension  "GL_NV_cooperative_matrix2"
+                              Name 4  "main"
+                              Name 8  "c011110_00"
+                              Name 10  "c011110_01"
+                              Name 12  "c011110_10"
+                              Name 14  "c011110_11"
+                              Name 16  "c011110_11_2"
+                              Name 17  "c011111_00"
+                              Name 19  "c111110_00"
+                              Name 21  "c111110_01"
+                              Name 23  "c111110_10"
+                              Name 25  "c111110_11"
+                              Name 27  "c111110_11_2"
+                              Name 28  "c111111_00"
+               2:             TypeVoid
+               3:             TypeFunction 2
+               6:             TypeFloat 8 4215
+               7:             TypePointer Function 6(floate5m2_t)
+               9:6(floate5m2_t) Constant 120
+              11:6(floate5m2_t) Constant 121
+              13:6(floate5m2_t) Constant 122
+              15:6(floate5m2_t) Constant 123
+              18:6(floate5m2_t) Constant 124
+              20:6(floate5m2_t) Constant 248
+              22:6(floate5m2_t) Constant 249
+              24:6(floate5m2_t) Constant 250
+              26:6(floate5m2_t) Constant 251
+              29:6(floate5m2_t) Constant 252
+              30:             TypeInt 32 0
+              31:             TypeVector 30(int) 3
+              32:     30(int) Constant 1
+              33:   31(ivec3) ConstantComposite 32 32 32
+         4(main):           2 Function None 3
+               5:             Label
+   8(c011110_00):      7(ptr) Variable Function
+  10(c011110_01):      7(ptr) Variable Function
+  12(c011110_10):      7(ptr) Variable Function
+  14(c011110_11):      7(ptr) Variable Function
+16(c011110_11_2):      7(ptr) Variable Function
+  17(c011111_00):      7(ptr) Variable Function
+  19(c111110_00):      7(ptr) Variable Function
+  21(c111110_01):      7(ptr) Variable Function
+  23(c111110_10):      7(ptr) Variable Function
+  25(c111110_11):      7(ptr) Variable Function
+27(c111110_11_2):      7(ptr) Variable Function
+  28(c111111_00):      7(ptr) Variable Function
+                              Store 8(c011110_00) 9
+                              Store 10(c011110_01) 11
+                              Store 12(c011110_10) 13
+                              Store 14(c011110_11) 15
+                              Store 16(c011110_11_2) 15
+                              Store 17(c011111_00) 18
+                              Store 19(c111110_00) 20
+                              Store 21(c111110_01) 22
+                              Store 23(c111110_10) 24
+                              Store 25(c111110_11) 26
+                              Store 27(c111110_11_2) 26
+                              Store 28(c111111_00) 29
+                              Return
+                              FunctionEnd

+ 35 - 0
Test/spv.floate4m3.const.comp

@@ -0,0 +1,35 @@
+#version 450 core
+
+#extension GL_EXT_bfloat16 : require
+#extension GL_EXT_float_e4m3 : require
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_EXT_shader_explicit_arithmetic_types : enable
+#extension GL_EXT_scalar_block_layout : enable
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+
+void main()
+{
+    floate4m3_t c01111_000 = floate4m3_t(1.0 * 256);
+    floate4m3_t c01111_001 = floate4m3_t(1.125 * 256);
+    floate4m3_t c01111_010 = floate4m3_t(1.25 * 256);
+    floate4m3_t c01111_011 = floate4m3_t(1.375 * 256);
+    floate4m3_t c01111_100 = floate4m3_t(1.5 * 256);
+    floate4m3_t c01111_101 = floate4m3_t(1.625 * 256);
+    floate4m3_t c01111_110 = floate4m3_t(1.75 * 256);
+    floate4m3_t c01111_110_2 = floate4m3_t(1.85 * 256);
+    floate4m3_t c01111_111 = floate4m3_t(1.95 * 256);
+    floate4m3_t c01111_111_2 = floate4m3_t(2.0 * 256);
+    floate4m3_t c11111_000 = floate4m3_t(-1.0 * 256);
+    floate4m3_t c11111_001 = floate4m3_t(-1.125 * 256);
+    floate4m3_t c11111_010 = floate4m3_t(-1.25 * 256);
+    floate4m3_t c11111_011 = floate4m3_t(-1.375 * 256);
+    floate4m3_t c11111_100 = floate4m3_t(-1.5 * 256);
+    floate4m3_t c11111_101 = floate4m3_t(-1.625 * 256);
+    floate4m3_t c11111_110 = floate4m3_t(-1.75 * 256);
+    floate4m3_t c11111_110_2 = floate4m3_t(-1.85 * 256);
+    floate4m3_t c11111_111 = floate4m3_t(-1.95 * 256);
+    floate4m3_t c11111_111_2 = floate4m3_t(-2.0 * 256);
+}

+ 27 - 0
Test/spv.floate5m2.const.comp

@@ -0,0 +1,27 @@
+#version 450 core
+
+#extension GL_EXT_bfloat16 : require
+#extension GL_EXT_float_e5m2 : require
+#extension GL_KHR_cooperative_matrix : enable
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_KHR_memory_scope_semantics : enable
+#extension GL_EXT_shader_explicit_arithmetic_types : enable
+#extension GL_EXT_scalar_block_layout : enable
+
+layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
+
+void main()
+{
+    floate5m2_t c011110_00 = floate5m2_t(1.0 * 32768);
+    floate5m2_t c011110_01 = floate5m2_t(1.25 * 32768);
+    floate5m2_t c011110_10 = floate5m2_t(1.5 * 32768);
+    floate5m2_t c011110_11 = floate5m2_t(1.75 * 32768);
+    floate5m2_t c011110_11_2 = floate5m2_t(1.85 * 32768);
+    floate5m2_t c011111_00 = floate5m2_t(2.0 * 32768);
+    floate5m2_t c111110_00 = floate5m2_t(-1.0 * 32768);
+    floate5m2_t c111110_01 = floate5m2_t(-1.25 * 32768);
+    floate5m2_t c111110_10 = floate5m2_t(-1.5 * 32768);
+    floate5m2_t c111110_11 = floate5m2_t(-1.75 * 32768);
+    floate5m2_t c111110_11_2 = floate5m2_t(-1.85 * 32768);
+    floate5m2_t c111111_00 = floate5m2_t(-2.0 * 32768);
+}

+ 2 - 0
gtests/Spv.FromFile.cpp

@@ -795,8 +795,10 @@ INSTANTIATE_TEST_SUITE_P(
         "spv.1.6.nontemporalimage.frag",
         "spv.noexplicitlayout.comp",
         "spv.floate4m3.comp",
+        "spv.floate4m3.const.comp",
         "spv.floate4m3_error.comp",
         "spv.floate5m2.comp",
+        "spv.floate5m2.const.comp",
         "spv.floate5m2_error.comp",
     })),
     FileNameAsCustomTestSuffix