HalfFloat.h 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. // SPDX-FileCopyrightText: 2021 Jorrit Rouwe
  2. // SPDX-License-Identifier: MIT
  3. #pragma once
#include <Jolt/Math/Vec4.h>
#include <cstring>
  5. JPH_NAMESPACE_BEGIN
  6. using HalfFloat = uint16;
  7. // Define half float constant values
  8. static constexpr HalfFloat HALF_FLT_MAX = 0x7bff;
  9. static constexpr HalfFloat HALF_FLT_MAX_NEGATIVE = 0xfbff;
  10. static constexpr HalfFloat HALF_FLT_INF = 0x7c00;
  11. static constexpr HalfFloat HALF_FLT_INF_NEGATIVE = 0xfc00;
  12. static constexpr HalfFloat HALF_FLT_NANQ = 0x7e00;
  13. static constexpr HalfFloat HALF_FLT_NANQ_NEGATIVE = 0xfe00;
  14. namespace HalfFloatConversion {
  // Layout of a float (32 bits): [sign:1][exponent:8][mantissa:23].
  // Field positions and the exponent bias are derived from the field widths.
  static constexpr int FLOAT_MANTISSA_BITS = 23;
  static constexpr int FLOAT_EXPONENT_BITS = 8;
  static constexpr int FLOAT_EXPONENT_POS = FLOAT_MANTISSA_BITS;                  // Exponent sits directly above the mantissa
  static constexpr int FLOAT_SIGN_POS = FLOAT_EXPONENT_POS + FLOAT_EXPONENT_BITS; // Sign sits directly above the exponent
  static constexpr int FLOAT_EXPONENT_BIAS = (1 << (FLOAT_EXPONENT_BITS - 1)) - 1;
  static constexpr int FLOAT_MANTISSA_MASK = (1 << FLOAT_MANTISSA_BITS) - 1;
  static constexpr int FLOAT_EXPONENT_MASK = (1 << FLOAT_EXPONENT_BITS) - 1;      // Unshifted, i.e. relative to FLOAT_EXPONENT_POS
  static constexpr int FLOAT_EXPONENT_AND_MANTISSA_MASK = FLOAT_MANTISSA_MASK | (FLOAT_EXPONENT_MASK << FLOAT_EXPONENT_POS);

  // Layout of a half float (16 bits): [sign:1][exponent:5][mantissa:10]
  static constexpr int HALF_FLT_MANTISSA_BITS = 10;
  static constexpr int HALF_FLT_EXPONENT_BITS = 5;
  static constexpr int HALF_FLT_EXPONENT_POS = HALF_FLT_MANTISSA_BITS;
  static constexpr int HALF_FLT_SIGN_POS = HALF_FLT_EXPONENT_POS + HALF_FLT_EXPONENT_BITS;
  static constexpr int HALF_FLT_EXPONENT_BIAS = (1 << (HALF_FLT_EXPONENT_BITS - 1)) - 1;
  static constexpr int HALF_FLT_MANTISSA_MASK = (1 << HALF_FLT_MANTISSA_BITS) - 1;
  static constexpr int HALF_FLT_EXPONENT_MASK = (1 << HALF_FLT_EXPONENT_BITS) - 1; // Unshifted, i.e. relative to HALF_FLT_EXPONENT_POS
  static constexpr int HALF_FLT_EXPONENT_AND_MANTISSA_MASK = HALF_FLT_MANTISSA_MASK | (HALF_FLT_EXPONENT_MASK << HALF_FLT_EXPONENT_POS);
  /// Rounding modes for converting a float to a half float (passed as the RoundingMode template argument)
  enum ERoundingMode
  {
      ROUND_TO_NEG_INF = 0,   ///< Round towards negative infinity
      ROUND_TO_POS_INF = 1,   ///< Round towards positive infinity
      ROUND_TO_NEAREST = 2,   ///< Round to the nearest representable value, ties to even
  };
  40. /// Convert a float (32-bits) to a half float (16-bits), fallback version when no intrinsics available
  41. template <int RoundingMode>
  42. inline HalfFloat FromFloatFallback(float inV)
  43. {
  44. // Reinterpret the float as an uint32
  45. static_assert(sizeof(float) == sizeof(uint32));
  46. union FloatToInt
  47. {
  48. float f;
  49. uint32 i;
  50. };
  51. FloatToInt f_to_i;
  52. f_to_i.f = inV;
  53. uint32 value = f_to_i.i;
  54. // Extract exponent
  55. uint32 exponent = (value >> FLOAT_EXPONENT_POS) & FLOAT_EXPONENT_MASK;
  56. // Extract mantissa
  57. uint32 mantissa = value & FLOAT_MANTISSA_MASK;
  58. // Extract the sign and move it into the right spot for the half float (so we can just or it in at the end)
  59. HalfFloat hf_sign = HalfFloat(value >> (FLOAT_SIGN_POS - HALF_FLT_SIGN_POS)) & (1 << HALF_FLT_SIGN_POS);
  60. // Check NaN or INF
  61. if (exponent == FLOAT_EXPONENT_MASK) // NaN or INF
  62. return hf_sign | (mantissa == 0? HALF_FLT_INF : HALF_FLT_NANQ);
  63. // Rebias the exponent for half floats
  64. int rebiased_exponent = int(exponent) - FLOAT_EXPONENT_BIAS + HALF_FLT_EXPONENT_BIAS;
  65. // Check overflow to infinity
  66. if (rebiased_exponent >= HALF_FLT_EXPONENT_MASK)
  67. {
  68. bool round_up = RoundingMode == ROUND_TO_NEAREST || (hf_sign == 0) == (RoundingMode == ROUND_TO_POS_INF);
  69. return hf_sign | (round_up? HALF_FLT_INF : HALF_FLT_MAX);
  70. }
  71. // Check underflow to zero
  72. if (rebiased_exponent < -HALF_FLT_MANTISSA_BITS)
  73. {
  74. bool round_up = RoundingMode != ROUND_TO_NEAREST && (hf_sign == 0) == (RoundingMode == ROUND_TO_POS_INF) && (value & FLOAT_EXPONENT_AND_MANTISSA_MASK) != 0;
  75. return hf_sign | (round_up? 1 : 0);
  76. }
  77. HalfFloat hf_exponent;
  78. int shift;
  79. if (rebiased_exponent <= 0)
  80. {
  81. // Underflow to denormalized number
  82. hf_exponent = 0;
  83. mantissa |= 1 << FLOAT_MANTISSA_BITS; // Add the implicit 1 bit to the mantissa
  84. shift = FLOAT_MANTISSA_BITS - HALF_FLT_MANTISSA_BITS + 1 - rebiased_exponent;
  85. }
  86. else
  87. {
  88. // Normal half float
  89. hf_exponent = HalfFloat(rebiased_exponent << HALF_FLT_EXPONENT_POS);
  90. shift = FLOAT_MANTISSA_BITS - HALF_FLT_MANTISSA_BITS;
  91. }
  92. // Compose the half float
  93. HalfFloat hf_mantissa = HalfFloat(mantissa >> shift);
  94. HalfFloat hf = hf_sign | hf_exponent | hf_mantissa;
  95. // Calculate the remaining bits that we're discarding
  96. uint remainder = mantissa & ((1 << shift) - 1);
  97. if constexpr (RoundingMode == ROUND_TO_NEAREST)
  98. {
  99. // Round to nearest
  100. uint round_threshold = 1 << (shift - 1);
  101. if (remainder > round_threshold // Above threshold, we must always round
  102. || (remainder == round_threshold && (hf_mantissa & 1))) // When equal, round to nearest even
  103. hf++; // May overflow to infinity
  104. }
  105. else
  106. {
  107. // Round up or down (truncate) depending on the rounding mode
  108. bool round_up = (hf_sign == 0) == (RoundingMode == ROUND_TO_POS_INF) && remainder != 0;
  109. if (round_up)
  110. hf++; // May overflow to infinity
  111. }
  112. return hf;
  113. }
  114. /// Convert a float (32-bits) to a half float (16-bits)
  115. template <int RoundingMode>
  116. JPH_INLINE HalfFloat FromFloat(float inV)
  117. {
  118. #ifdef JPH_USE_F16C
  119. union
  120. {
  121. __m128i u128;
  122. HalfFloat u16[8];
  123. } hf;
  124. __m128 val = _mm_load_ss(&inV);
  125. switch (RoundingMode)
  126. {
  127. case ROUND_TO_NEG_INF:
  128. hf.u128 = _mm_cvtps_ph(val, _MM_FROUND_TO_NEG_INF);
  129. break;
  130. case ROUND_TO_POS_INF:
  131. hf.u128 = _mm_cvtps_ph(val, _MM_FROUND_TO_POS_INF);
  132. break;
  133. case ROUND_TO_NEAREST:
  134. hf.u128 = _mm_cvtps_ph(val, _MM_FROUND_TO_NEAREST_INT);
  135. break;
  136. }
  137. return hf.u16[0];
  138. #else
  139. return FromFloatFallback<RoundingMode>(inV);
  140. #endif
  141. }
  142. /// Convert 4 half floats (lower 64 bits) to floats, fallback version when no intrinsics available
  143. inline Vec4 ToFloatFallback(UVec4Arg inValue)
  144. {
  145. // Unpack half floats to 4 uint32's
  146. UVec4 value = inValue.Expand4Uint16Lo();
  147. // Normal half float path, extract the exponent and mantissa, shift them into place and update the exponent bias
  148. UVec4 exponent_mantissa = UVec4::sAnd(value, UVec4::sReplicate(HALF_FLT_EXPONENT_AND_MANTISSA_MASK)).LogicalShiftLeft<FLOAT_EXPONENT_POS - HALF_FLT_EXPONENT_POS>() + UVec4::sReplicate((FLOAT_EXPONENT_BIAS - HALF_FLT_EXPONENT_BIAS) << FLOAT_EXPONENT_POS);
  149. // Denormalized half float path, renormalize the float
  150. UVec4 exponent_mantissa_denormalized = ((exponent_mantissa + UVec4::sReplicate(1 << FLOAT_EXPONENT_POS)).ReinterpretAsFloat() - UVec4::sReplicate((FLOAT_EXPONENT_BIAS - HALF_FLT_EXPONENT_BIAS + 1) << FLOAT_EXPONENT_POS).ReinterpretAsFloat()).ReinterpretAsInt();
  151. // NaN / INF path, set all exponent bits
  152. UVec4 exponent_mantissa_nan_inf = UVec4::sOr(exponent_mantissa, UVec4::sReplicate(FLOAT_EXPONENT_MASK << FLOAT_EXPONENT_POS));
  153. // Get the exponent to determine which of the paths we should take
  154. UVec4 exponent_mask = UVec4::sReplicate(HALF_FLT_EXPONENT_MASK << HALF_FLT_EXPONENT_POS);
  155. UVec4 exponent = UVec4::sAnd(value, exponent_mask);
  156. UVec4 is_denormalized = UVec4::sEquals(exponent, UVec4::sZero());
  157. UVec4 is_nan_inf = UVec4::sEquals(exponent, exponent_mask);
  158. // Select the correct result
  159. UVec4 result_exponent_mantissa = UVec4::sSelect(UVec4::sSelect(exponent_mantissa, exponent_mantissa_nan_inf, is_nan_inf), exponent_mantissa_denormalized, is_denormalized);
  160. // Extract the sign bit and shift it to the left
  161. UVec4 sign = UVec4::sAnd(value, UVec4::sReplicate(1 << HALF_FLT_SIGN_POS)).LogicalShiftLeft<FLOAT_SIGN_POS - HALF_FLT_SIGN_POS>();
  162. // Construct the float
  163. return UVec4::sOr(sign, result_exponent_mantissa).ReinterpretAsFloat();
  164. }
  165. /// Convert 4 half floats (lower 64 bits) to floats
  166. JPH_INLINE Vec4 ToFloat(UVec4Arg inValue)
  167. {
  168. #if defined(JPH_USE_F16C)
  169. return _mm_cvtph_ps(inValue.mValue);
  170. #elif defined(JPH_USE_NEON)
  171. return vcvt_f32_f16(vreinterpret_f16_f32(vget_low_f32(inValue.mValue)));
  172. #else
  173. return ToFloatFallback(inValue);
  174. #endif
  175. }
  176. } // HalfFloatConversion
  177. JPH_NAMESPACE_END