Browse Source

Add f16 specific procedures to core:math

gingerBill 4 years ago
parent
commit
63bb26c0e0
1 changed files with 240 additions and 59 deletions
  1. 240 59
      core/math/math.odin

+ 240 - 59
core/math/math.odin

@@ -31,6 +31,7 @@ LN10         :: 2.30258509299404568401799145468436421;
 
 MAX_F64_PRECISION :: 16; // Maximum number of meaningful digits after the decimal point for 'f64'
 MAX_F32_PRECISION ::  8; // Maximum number of meaningful digits after the decimal point for 'f32'
+MAX_F16_PRECISION ::  4; // Maximum number of meaningful digits after the decimal point for 'f16'
 
 RAD_PER_DEG :: TAU/360.0;
 DEG_PER_RAD :: 360.0/TAU;
@@ -38,81 +39,101 @@ DEG_PER_RAD :: 360.0/TAU;
 
 @(default_calling_convention="none")
 foreign _ {
+	@(link_name="llvm.sqrt.f16")
+	sqrt_f16 :: proc(x: f16) -> f16 ---;
 	@(link_name="llvm.sqrt.f32")
 	sqrt_f32 :: proc(x: f32) -> f32 ---;
 	@(link_name="llvm.sqrt.f64")
 	sqrt_f64 :: proc(x: f64) -> f64 ---;
 
+	@(link_name="llvm.sin.f16")
+	sin_f16 :: proc(θ: f16) -> f16 ---;
 	@(link_name="llvm.sin.f32")
 	sin_f32 :: proc(θ: f32) -> f32 ---;
 	@(link_name="llvm.sin.f64")
 	sin_f64 :: proc(θ: f64) -> f64 ---;
 
+	@(link_name="llvm.cos.f16")
+	cos_f16 :: proc(θ: f16) -> f16 ---;
 	@(link_name="llvm.cos.f32")
 	cos_f32 :: proc(θ: f32) -> f32 ---;
 	@(link_name="llvm.cos.f64")
 	cos_f64 :: proc(θ: f64) -> f64 ---;
 
+	@(link_name="llvm.pow.f16")
+	pow_f16 :: proc(x, power: f16) -> f16 ---;
 	@(link_name="llvm.pow.f32")
 	pow_f32 :: proc(x, power: f32) -> f32 ---;
 	@(link_name="llvm.pow.f64")
 	pow_f64 :: proc(x, power: f64) -> f64 ---;
 
+	@(link_name="llvm.fmuladd.f16")
+	fmuladd_f16 :: proc(a, b, c: f16) -> f16 ---;
 	@(link_name="llvm.fmuladd.f32")
 	fmuladd_f32 :: proc(a, b, c: f32) -> f32 ---;
 	@(link_name="llvm.fmuladd.f64")
 	fmuladd_f64 :: proc(a, b, c: f64) -> f64 ---;
 
+	@(link_name="llvm.log.f16")
+	ln_f16 :: proc(x: f16) -> f16 ---;
 	@(link_name="llvm.log.f32")
 	ln_f32 :: proc(x: f32) -> f32 ---;
 	@(link_name="llvm.log.f64")
 	ln_f64 :: proc(x: f64) -> f64 ---;
 
+	@(link_name="llvm.exp.f16")
+	exp_f16 :: proc(x: f16) -> f16 ---;
 	@(link_name="llvm.exp.f32")
 	exp_f32 :: proc(x: f32) -> f32 ---;
 	@(link_name="llvm.exp.f64")
 	exp_f64 :: proc(x: f64) -> f64 ---;
 
+	@(link_name="llvm.ldexp.f16")
+	ldexp_f16 :: proc(val: f16, exp: i32) -> f16 ---;
 	@(link_name="llvm.ldexp.f32")
 	ldexp_f32 :: proc(val: f32, exp: i32) -> f32 ---;
-
 	@(link_name="llvm.ldexp.f64")
 	ldexp_f64 :: proc(val: f64, exp: i32) -> f64 ---;
 }
 
-sqrt      :: proc{sqrt_f32, sqrt_f64};
-sin       :: proc{sin_f32, sin_f64};
-cos       :: proc{cos_f32, cos_f64};
-pow       :: proc{pow_f32, pow_f64};
-fmuladd   :: proc{fmuladd_f32, fmuladd_f64};
-ln        :: proc{ln_f32, ln_f64};
-exp       :: proc{exp_f32, exp_f64};
+sqrt      :: proc{sqrt_f16,    sqrt_f32,    sqrt_f64};
+sin       :: proc{sin_f16,     sin_f32,     sin_f64};
+cos       :: proc{cos_f16,     cos_f32,     cos_f64};
+pow       :: proc{pow_f16,     pow_f32,     pow_f64};
+fmuladd   :: proc{fmuladd_f16, fmuladd_f32, fmuladd_f64};
+ln        :: proc{ln_f16,      ln_f32,      ln_f64};
+exp       :: proc{exp_f16,     exp_f32,     exp_f64};
 
-ldexp :: proc{ldexp_f32, ldexp_f64};
+ldexp :: proc{ldexp_f16, ldexp_f32, ldexp_f64};
 
+// log_f16 returns the logarithm of x in the given base, computed as ln(x)/ln(base).
+log_f16 :: proc(x, base: f16) -> f16 {
+	numerator   := ln(x);
+	denominator := ln(base);
+	return numerator / denominator;
+}
 log_f32 :: proc(x, base: f32) -> f32 { return ln(x) / ln(base); }
 log_f64 :: proc(x, base: f64) -> f64 { return ln(x) / ln(base); }
-log     :: proc{log_f32, log_f64};
+log     :: proc{log_f16, log_f32, log_f64};
 
+// log2_f16 returns ln(x) divided by the LN2 constant.
+log2_f16 :: proc(x: f16) -> f16 {
+	natural := ln(x);
+	return natural / LN2;
+}
 log2_f32 :: proc(x: f32) -> f32 { return ln(x)/LN2; }
 log2_f64 :: proc(x: f64) -> f64 { return ln(x)/LN2; }
-log2     :: proc{log2_f32, log2_f64};
+log2     :: proc{log2_f16, log2_f32, log2_f64};
 
+// log10_f16 returns ln(x) divided by the LN10 constant.
+log10_f16 :: proc(x: f16) -> f16 {
+	natural := ln(x);
+	return natural / LN10;
+}
 log10_f32 :: proc(x: f32) -> f32 { return ln(x)/LN10; }
 log10_f64 :: proc(x: f64) -> f64 { return ln(x)/LN10; }
-log10     :: proc{log10_f32, log10_f64};
+log10     :: proc{log10_f16, log10_f32, log10_f64};
 
 
+// tan_f16 returns sin(θ)/cos(θ).
+tan_f16 :: proc(θ: f16) -> f16 {
+	s := sin(θ);
+	c := cos(θ);
+	return s / c;
+}
 tan_f32 :: proc(θ: f32) -> f32 { return sin(θ)/cos(θ); }
 tan_f64 :: proc(θ: f64) -> f64 { return sin(θ)/cos(θ); }
-tan     :: proc{tan_f32, tan_f64};
+tan     :: proc{tan_f16, tan_f32, tan_f64};
 
 lerp :: proc(a, b: $T, t: $E) -> (x: T) { return a*(1-t) + b*t; }
 saturate :: proc(a: $T) -> (x: T) { return clamp(a, 0, 1); };
 
+// unlerp_f16 returns t = (x-a)/(b-a), the position of x relative to the span from a to b.
+unlerp_f16 :: proc(a, b, x: f16) -> (t: f16) {
+	offset := x - a;
+	span   := b - a;
+	t = offset / span;
+	return;
+}
 unlerp_f32 :: proc(a, b, x: f32) -> (t: f32) { return (x-a)/(b-a); }
 unlerp_f64 :: proc(a, b, x: f64) -> (t: f64) { return (x-a)/(b-a); }
-unlerp     :: proc{unlerp_f32, unlerp_f64};
+unlerp     :: proc{unlerp_f16, unlerp_f32, unlerp_f64};
 
 
 wrap :: proc(x, y: $T) -> T where intrinsics.type_is_numeric(T), !intrinsics.type_is_array(T) {
@@ -149,18 +170,30 @@ gain :: proc(t, g: $T) -> T where intrinsics.type_is_numeric(T) {
 }
 
 
+// sign_f16 returns 1 when x > 0, -1 when x < 0, and 0 otherwise.
+sign_f16 :: proc(x: f16) -> f16 {
+	positive := int(0 < x);
+	negative := int(x < 0);
+	return f16(positive - negative);
+}
 sign_f32 :: proc(x: f32) -> f32 { return f32(int(0 < x) - int(x < 0)); }
 sign_f64 :: proc(x: f64) -> f64 { return f64(int(0 < x) - int(x < 0)); }
-sign     :: proc{sign_f32, sign_f64};
+sign     :: proc{sign_f16, sign_f32, sign_f64};
+
 
+// sign_bit_f16 reports whether bit 15 of x's 16-bit pattern is set.
+sign_bit_f16 :: proc(x: f16) -> bool {
+	bits := transmute(u16)x;
+	return bits & (1<<15) != 0;
+}
 sign_bit_f32 :: proc(x: f32) -> bool {
 	return (transmute(u32)x) & (1<<31) != 0;
 }
 sign_bit_f64 :: proc(x: f64) -> bool {
 	return (transmute(u64)x) & (1<<63) != 0;
 }
-sign_bit :: proc{sign_bit_f32, sign_bit_f64};
-
+sign_bit :: proc{sign_bit_f16, sign_bit_f32, sign_bit_f64};
+
+// copy_sign_f16 returns a value whose low 15 bits come from x and whose
+// top bit (0x8000) comes from y.
+copy_sign_f16 :: proc(x, y: f16) -> f16 {
+	magnitude := (transmute(u16)x) & 0x7fff;
+	sign_of_y := (transmute(u16)y) & 0x8000;
+	return transmute(f16)(magnitude | sign_of_y);
+}
 copy_sign_f32 :: proc(x, y: f32) -> f32 {
 	ix := transmute(u32)x;
 	iy := transmute(u32)y;
@@ -175,15 +208,47 @@ copy_sign_f64 :: proc(x, y: f64) -> f64 {
 	ix |= iy & 0x8000_0000_0000_0000;
 	return transmute(f64)ix;
 }
-copy_sign :: proc{copy_sign_f32, copy_sign_f64};
+copy_sign :: proc{copy_sign_f16, copy_sign_f32, copy_sign_f64};
 
 
+// to_radians_f16 scales degrees by RAD_PER_DEG.
+to_radians_f16 :: proc(degrees: f16) -> f16 {
+	return RAD_PER_DEG * degrees;
+}
 to_radians_f32 :: proc(degrees: f32) -> f32 { return degrees * RAD_PER_DEG; }
 to_radians_f64 :: proc(degrees: f64) -> f64 { return degrees * RAD_PER_DEG; }
+// to_degrees_f16 scales radians by DEG_PER_RAD.
+to_degrees_f16 :: proc(radians: f16) -> f16 {
+	return DEG_PER_RAD * radians;
+}
 to_degrees_f32 :: proc(radians: f32) -> f32 { return radians * DEG_PER_RAD; }
 to_degrees_f64 :: proc(radians: f64) -> f64 { return radians * DEG_PER_RAD; }
-to_radians     :: proc{to_radians_f32, to_radians_f64};
-to_degrees     :: proc{to_degrees_f32, to_degrees_f64};
+to_radians     :: proc{to_radians_f16, to_radians_f32, to_radians_f64};
+to_degrees     :: proc{to_degrees_f16, to_degrees_f32, to_degrees_f64};
+
+// trunc_f16 returns x with its fractional part discarded (rounding toward zero).
+// Zeros, NaN and the infinities are returned unchanged.
+trunc_f16 :: proc(x: f16) -> f16 {
+	trunc_internal :: proc(f: f16) -> f16 {
+		mask :: 0x1f;    // 5-bit exponent field
+		shift :: 16 - 6; // 10 bits sit below the exponent field
+		bias :: 0xf;     // value subtracted from the raw exponent field
+
+		if f < 1 {
+			switch {
+			case f < 0:  return -trunc_internal(-f);
+			case f == 0: return f;
+			case:        return 0;
+			}
+		}
+
+		x := transmute(u16)f;
+		e := (x >> shift) & mask - bias;
+
+		// Clear the low (shift-e) bits, which hold the fractional part.
+		// The mask has to be ((1 << (shift-e)) - 1); the previous form
+		// `~(1 << (shift-e)) - 1` complemented only the single shifted bit
+		// and then subtracted one, clearing the wrong bits.
+		if e < shift {
+			x &~= 1 << (shift-e) - 1;
+		}
+		return transmute(f16)x;
+	}
+	switch classify(x) {
+	case .Zero, .Neg_Zero, .NaN, .Inf, .Neg_Inf:
+		return x;
+	case .Normal, .Subnormal: // carry on
+	}
+	return trunc_internal(x);
+}
 
 trunc_f32 :: proc(x: f32) -> f32 {
 	trunc_internal :: proc(f: f32) -> f32 {
@@ -245,21 +310,39 @@ trunc_f64 :: proc(x: f64) -> f64 {
 	return trunc_internal(x);
 }
 
-trunc :: proc{trunc_f32, trunc_f64};
+trunc :: proc{trunc_f16, trunc_f32, trunc_f64};
 
+// round_f16 returns ceil(x - 0.5) for negative x and floor(x + 0.5) otherwise.
+round_f16 :: proc(x: f16) -> f16 {
+	if x < 0 {
+		return ceil(x - 0.5);
+	}
+	return floor(x + 0.5);
+}
 round_f32 :: proc(x: f32) -> f32 {
 	return ceil(x - 0.5) if x < 0 else floor(x + 0.5);
 }
 round_f64 :: proc(x: f64) -> f64 {
 	return ceil(x - 0.5) if x < 0 else floor(x + 0.5);
 }
-round :: proc{round_f32, round_f64};
+round :: proc{round_f16, round_f32, round_f64};
 
 
+// ceil_f16 is expressed through floor: ceil(x) = -floor(-x).
+ceil_f16 :: proc(x: f16) -> f16 {
+	floored := floor(-x);
+	return -floored;
+}
 ceil_f32 :: proc(x: f32) -> f32 { return -floor(-x); }
 ceil_f64 :: proc(x: f64) -> f64 { return -floor(-x); }
-ceil :: proc{ceil_f32, ceil_f64};
+ceil :: proc{ceil_f16, ceil_f32, ceil_f64};
 
+// floor_f16 rounds x toward negative infinity using modf.
+floor_f16 :: proc(x: f16) -> f16 {
+	// Zero, NaN and the infinities are returned as-is.
+	if x == 0 || is_nan(x) || is_inf(x) {
+		return x;
+	}
+	if x < 0 {
+		// Work on the magnitude; any leftover fraction pushes the result one further from zero.
+		whole, rest := modf(-x);
+		if rest != 0.0 {
+			whole = whole + 1;
+		}
+		return -whole;
+	}
+	whole, _ := modf(x);
+	return whole;
+}
 floor_f32 :: proc(x: f32) -> f32 {
 	if x == 0 || is_nan(x) || is_inf(x) {
 		return x;
@@ -288,7 +371,7 @@ floor_f64 :: proc(x: f64) -> f64 {
 	d, _ := modf(x);
 	return d;
 }
-floor :: proc{floor_f32, floor_f64};
+floor :: proc{floor_f16, floor_f32, floor_f64};
 
 
 floor_div :: proc(x, y: $T) -> T
@@ -310,7 +393,32 @@ floor_mod :: proc(x, y: $T) -> T
 	return r;
 }
 
+// modf_f16 splits x into an integer part `int` and a fractional part `frac`
+// such that frac = x - int. Negative inputs are handled by splitting -x and
+// negating both results.
+// NOTE(review): the constants below assume the IEEE 754 binary16 layout — confirm.
+modf_f16 :: proc(x: f16) -> (int: f16, frac: f16) {
+	shift :: 16 - 5 - 1;
+	mask  :: 0x1f;
+	bias  :: 15;
 
+	if x < 1 {
+		switch {
+		case x < 0:
+			int, frac = modf(-x);
+			return -int, -frac;
+		case x == 0:
+			return x, x;
+		}
+		// 0 < x < 1: there is no integer part.
+		return 0, x;
+	}
+
+	i := transmute(u16)x;
+	// Exponent field with the bias removed.
+	e := uint(i>>shift)&mask - bias;
+
+	// Clear the low (shift-e) bits so only the integer part remains.
+	if e < shift {
+		i &~= 1<<(shift-e) - 1;
+	}
+	int = transmute(f16)i;
+	frac = x - int;
+	return;
+}
 modf_f32 :: proc(x: f32) -> (int: f32, frac: f32) {
 	shift :: 32 - 8 - 1;
 	mask  :: 0xff;
@@ -363,9 +471,17 @@ modf_f64 :: proc(x: f64) -> (int: f64, frac: f64) {
 	frac = x - int;
 	return;
 }
-modf :: proc{modf_f32, modf_f64};
+modf :: proc{modf_f16, modf_f32, modf_f64};
 split_decimal :: modf;
 
+// mod_f16 computes remainder(|x|, |y|), adds |y| when that result is negative,
+// and returns it carrying the sign of x.
+mod_f16 :: proc(x, y: f16) -> (n: f16) {
+	divisor := abs(y);
+	n = remainder(abs(x), divisor);
+	if sign(n) < 0 {
+		n = n + divisor;
+	}
+	return copy_sign(n, x);
+}
 mod_f32 :: proc(x, y: f32) -> (n: f32) {
 	z := abs(y);
 	n = remainder(abs(x), z);
@@ -382,11 +498,12 @@ mod_f64 :: proc(x, y: f64) -> (n: f64) {
 	}
 	return copy_sign(n, x);
 }
-mod :: proc{mod_f32, mod_f64};
+mod :: proc{mod_f16, mod_f32, mod_f64};
 
+// remainder_f16 returns x - round(x/y)*y.
+remainder_f16 :: proc(x, y: f16) -> f16 {
+	quotient := round(x/y);
+	return x - quotient * y;
+}
 remainder_f32 :: proc(x, y: f32) -> f32 { return x - round(x/y) * y; }
 remainder_f64 :: proc(x, y: f64) -> f64 { return x - round(x/y) * y; }
-remainder :: proc{remainder_f32, remainder_f64};
+remainder :: proc{remainder_f16, remainder_f32, remainder_f64};
 
 
 
@@ -405,25 +522,13 @@ lcm :: proc(x, y: $T) -> T
 	return x / gcd(x, y) * y;
 }
 
+// frexp_f16 widens x to f64, delegates to frexp_f64, and narrows the
+// significand back to f16; the exponent is passed through unchanged.
+frexp_f16 :: proc(x: f16) -> (significand: f16, exponent: int) {
+	wide, e := frexp_f64(f64(x));
+	significand = f16(wide);
+	exponent = e;
+	return;
+}
 frexp_f32 :: proc(x: f32) -> (significand: f32, exponent: int) {
-	switch {
-	case x == 0:
-		return 0, 0;
-	case x < 0:
-		significand, exponent = frexp(-x);
-		return -significand, exponent;
-	}
-	ex := trunc(log2(x));
-	exponent = int(ex);
-	significand = x / pow(2.0, ex);
-	if abs(significand) >= 1 {
-		exponent += 1;
-		significand /= 2;
-	}
-	if exponent == 1024 && significand == 0 {
-		significand = 0.99999999999999988898;
-	}
-	return;
+	f, e := frexp_f64(f64(x));
+	return f32(f), e;
 }
 frexp_f64 :: proc(x: f64) -> (significand: f64, exponent: int) {
 	switch {
@@ -445,7 +550,7 @@ frexp_f64 :: proc(x: f64) -> (significand: f64, exponent: int) {
 	}
 	return;
 }
-frexp :: proc{frexp_f32, frexp_f64};
+frexp :: proc{frexp_f16, frexp_f32, frexp_f64};
 
 
 
@@ -511,6 +616,30 @@ factorial :: proc(n: int) -> int {
 	return table[n];
 }
 
+// classify_f16 returns the Float_Class of x: .Zero/.Neg_Zero, .Inf/.Neg_Inf,
+// .NaN, .Subnormal or .Normal. The cases are tested in order.
+classify_f16 :: proc(x: f16) -> Float_Class {
+	switch {
+	case x == 0:
+		// Reinterpreted as i16, a set top bit makes the value negative: that is -0.
+		i := transmute(i16)x;
+		if i < 0 {
+			return .Neg_Zero;
+		}
+		return .Zero;
+	case x*0.5 == x:
+		// Zero was handled above; a value unchanged by halving is treated as an infinity.
+		if x < 0 {
+			return .Neg_Inf;
+		}
+		return .Inf;
+	case !(x == x):
+		// A value that does not compare equal to itself is NaN.
+		return .NaN;
+	}
+
+	u := transmute(u16)x;
+	// 5-bit field starting at bit 10; an all-zero field is reported as .Subnormal.
+	exp := int(u>>10) & (1<<5 - 1);
+	if exp == 0 {
+		return .Subnormal;
+	}
+	return .Normal;
+}
 classify_f32 :: proc(x: f32) -> Float_Class {
 	switch {
 	case x == 0:
@@ -558,17 +687,28 @@ classify_f64 :: proc(x: f64) -> Float_Class {
 	}
 	return .Normal;
 }
-classify :: proc{classify_f32, classify_f64};
+classify :: proc{classify_f16, classify_f32, classify_f64};
 
+// is_nan_f16 reports whether classify(x) is .NaN.
+is_nan_f16 :: proc(x: f16) -> bool {
+	class := classify(x);
+	return class == .NaN;
+}
 is_nan_f32 :: proc(x: f32) -> bool { return classify(x) == .NaN; }
 is_nan_f64 :: proc(x: f64) -> bool { return classify(x) == .NaN; }
-is_nan :: proc{is_nan_f32, is_nan_f64};
+is_nan :: proc{is_nan_f16, is_nan_f32, is_nan_f64};
 
 
 // is_inf reports whether f is an infinity, according to sign.
 // If sign > 0, is_inf reports whether f is positive infinity.
 // If sign < 0, is_inf reports whether f is negative infinity.
 // If sign == 0, is_inf reports whether f is either infinity.
+is_inf_f16 :: proc(x: f16, sign: int = 0) -> bool {
+	// Classify x itself, not abs(x): taking the absolute value first discards
+	// the sign, so .Neg_Inf could never be observed. That made `sign < 0`
+	// always report false and `sign > 0` report true for negative infinity.
+	class := classify(x);
+	switch {
+	case sign > 0:
+		return class == .Inf;
+	case sign < 0:
+		return class == .Neg_Inf;
+	}
+	// sign == 0: either infinity qualifies.
+	return class == .Inf || class == .Neg_Inf;
+}
 is_inf_f32 :: proc(x: f32, sign: int = 0) -> bool {
 	class := classify(abs(x));
 	switch {
@@ -589,10 +729,12 @@ is_inf_f64 :: proc(x: f64, sign: int = 0) -> bool {
 	}
 	return class == .Inf || class == .Neg_Inf;
 }
-is_inf :: proc{is_inf_f32, is_inf_f64};
-
+is_inf :: proc{is_inf_f16, is_inf_f32, is_inf_f64};
 
 
+// inf_f16 returns the result of inf_f64(sign) narrowed to f16; the meaning
+// of `sign` is whatever inf_f64 defines.
+inf_f16 :: proc(sign: int) -> f16 {
+	// Delegate to the f64 implementation, the same way inf_f32 does.
+	// Calling inf_f16 here would recurse without ever terminating.
+	return f16(inf_f64(sign));
+}
 inf_f32 :: proc(sign: int) -> f32 {
 	return f32(inf_f64(sign));
 }
@@ -606,7 +748,9 @@ inf_f64 :: proc(sign: int) -> f64 {
 	return transmute(f64)v;
 }
 
-
+// nan_f16 returns the result of nan_f64() narrowed to f16.
+nan_f16 :: proc() -> f16 {
+	wide := nan_f64();
+	return f16(wide);
+}
 nan_f32 :: proc() -> f32 {
 	return f32(nan_f64());
 }
@@ -672,7 +816,10 @@ cumsum :: proc(dst, src: $T/[]$E) -> T
 }
 
 
-
+// atan2_f16 widens both arguments to f64, calls atan2_f64, and narrows the result.
+atan2_f16 :: proc(y, x: f16) -> f16 {
+	// TODO(bill): Better atan2_f16
+	wide := atan2_f64(f64(y), f64(x));
+	return f16(wide);
+}
 atan2_f32 :: proc(y, x: f32) -> f32 {
 	// TODO(bill): Better atan2_f32
 	return f32(atan2_f64(f64(y), f64(x)));
@@ -765,49 +912,68 @@ atan2_f64 :: proc(y, x: f64) -> f64 {
 }
 
 
-atan2 :: proc{atan2_f32, atan2_f64};
+atan2 :: proc{atan2_f16, atan2_f32, atan2_f64};
 
+// atan_f16 is atan2_f16 with the second argument fixed at 1.
+atan_f16 :: proc(x: f16) -> f16 {
+	unit: f16 = 1;
+	return atan2_f16(x, unit);
+}
 atan_f32 :: proc(x: f32) -> f32 {
 	return atan2_f32(x, 1);
 }
 atan_f64 :: proc(x: f64) -> f64 {
 	return atan2_f64(x, 1);
 }
-atan :: proc{atan_f32, atan_f64};
+atan :: proc{atan_f16, atan_f32, atan_f64};
 
+// asin_f16 returns atan2_f16(x, 1 + sqrt(1 - x*x)).
+asin_f16 :: proc(x: f16) -> f16 {
+	root := sqrt_f16(1 - x*x);
+	return atan2_f16(x, 1 + root);
+}
 asin_f32 :: proc(x: f32) -> f32 {
 	return atan2_f32(x, 1 + sqrt_f32(1 - x*x));
 }
 asin_f64 :: proc(x: f64) -> f64 {
 	return atan2_f64(x, 1 + sqrt_f64(1 - x*x));
 }
-asin :: proc{asin_f32, asin_f64};
+asin :: proc{asin_f16, asin_f32, asin_f64};
 
+// acos_f16 returns 2 * atan2_f16(sqrt(1 - x), sqrt(1 + x)).
+acos_f16 :: proc(x: f16) -> f16 {
+	upper := sqrt_f16(1 - x);
+	lower := sqrt_f16(1 + x);
+	return 2 * atan2_f16(upper, lower);
+}
 acos_f32 :: proc(x: f32) -> f32 {
 	return 2 * atan2_f32(sqrt_f32(1 - x), sqrt_f32(1 + x));
 }
 acos_f64 :: proc(x: f64) -> f64 {
 	return 2 * atan2_f64(sqrt_f64(1 - x), sqrt_f64(1 + x));
 }
-acos :: proc{acos_f32, acos_f64};
+acos :: proc{acos_f16, acos_f32, acos_f64};
 
 
+// sinh_f16 returns (exp(x) - exp(-x)) * 0.5.
+sinh_f16 :: proc(x: f16) -> f16 {
+	grow  := exp(x);
+	decay := exp(-x);
+	return (grow - decay)*0.5;
+}
 sinh_f32 :: proc(x: f32) -> f32 {
 	return (exp(x) - exp(-x))*0.5;
 }
 sinh_f64 :: proc(x: f64) -> f64 {
 	return (exp(x) - exp(-x))*0.5;
 }
-sinh :: proc{sinh_f32, sinh_f64};
+sinh :: proc{sinh_f16, sinh_f32, sinh_f64};
 
+// cosh_f16 returns (exp(x) + exp(-x)) * 0.5.
+cosh_f16 :: proc(x: f16) -> f16 {
+	grow  := exp(x);
+	decay := exp(-x);
+	return (grow + decay)*0.5;
+}
 cosh_f32 :: proc(x: f32) -> f32 {
 	return (exp(x) + exp(-x))*0.5;
 }
 cosh_f64 :: proc(x: f64) -> f64 {
 	return (exp(x) + exp(-x))*0.5;
 }
-cosh :: proc{cosh_f32, cosh_f64};
+cosh :: proc{cosh_f16, cosh_f32, cosh_f64};
 
+// tanh_f16 returns (e^(2x) - 1) / (e^(2x) + 1).
+tanh_f16 :: proc(x: f16) -> f16 {
+	e2x := exp(2*x);
+	numerator   := e2x - 1;
+	denominator := e2x + 1;
+	return numerator / denominator;
+}
 tanh_f32 :: proc(x: f32) -> f32 {
 	t := exp(2*x);
 	return (t - 1) / (t + 1);
@@ -816,7 +982,22 @@ tanh_f64 :: proc(x: f64) -> f64 {
 	t := exp(2*x);
 	return (t - 1) / (t + 1);
 }
-tanh :: proc{tanh_f32, tanh_f64};
+tanh :: proc{tanh_f16, tanh_f32, tanh_f64};
+
+
+// Numeric characteristics of the f16 type.
+// NOTE(review): the names appear to mirror the FLT_* limits from C's <float.h> — confirm.
+// NOTE(review): if the <float.h> convention is intended, F16_MAX_EXP / F16_MIN_EXP
+// would be expected to be 16 / -13 rather than 15 / -14 — verify against the F32/F64 tables.
+F16_DIG        :: 3;
+F16_EPSILON    :: 0.00097656;
+F16_GUARD      :: 0;
+F16_MANT_DIG   :: 11;
+F16_MAX        :: 65504.0;
+F16_MAX_10_EXP :: 4;
+F16_MAX_EXP    :: 15;
+F16_MIN        :: 6.10351562e-5;
+F16_MIN_10_EXP :: -4;
+F16_MIN_EXP    :: -14;
+F16_NORMALIZE  :: 0;
+F16_RADIX      :: 2;
+F16_ROUNDS     :: 1;
 
 
 F32_DIG        :: 6;