roots.rs 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. extern crate num_integer;
  2. extern crate num_traits;
  3. use num_integer::Roots;
  4. use num_traits::checked_pow;
  5. use num_traits::{AsPrimitive, PrimInt, Signed};
  6. use std::f64::MANTISSA_DIGITS;
  7. use std::fmt::Debug;
  8. use std::mem;
  9. trait TestInteger: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}
  10. impl<T> TestInteger for T where T: Roots + PrimInt + Debug + AsPrimitive<f64> + 'static {}
  11. /// Check that each root is correct
  12. ///
  13. /// If `x` is positive, check `rⁿ ≤ x < (r+1)ⁿ`.
  14. /// If `x` is negative, check `(r-1)ⁿ < x ≤ rⁿ`.
  15. fn check<T>(v: &[T], n: u32)
  16. where
  17. T: TestInteger,
  18. {
  19. for i in v {
  20. let rt = i.nth_root(n);
  21. // println!("nth_root({:?}, {}) = {:?}", i, n, rt);
  22. if n == 2 {
  23. assert_eq!(rt, i.sqrt());
  24. } else if n == 3 {
  25. assert_eq!(rt, i.cbrt());
  26. }
  27. if *i >= T::zero() {
  28. let rt1 = rt + T::one();
  29. assert!(rt.pow(n) <= *i);
  30. if let Some(x) = checked_pow(rt1, n as usize) {
  31. assert!(*i < x);
  32. }
  33. } else {
  34. let rt1 = rt - T::one();
  35. assert!(rt < T::zero());
  36. assert!(*i <= rt.pow(n));
  37. if let Some(x) = checked_pow(rt1, n as usize) {
  38. assert!(x < *i);
  39. }
  40. };
  41. }
  42. }
  43. /// Get the maximum value that will round down as `f64` (if any),
  44. /// and its successor that will round up.
  45. ///
  46. /// Important because the `std` implementations cast to `f64` to
  47. /// get a close approximation of the roots.
  48. fn mantissa_max<T>() -> Option<(T, T)>
  49. where
  50. T: TestInteger,
  51. {
  52. let bits = if T::min_value().is_zero() {
  53. 8 * mem::size_of::<T>()
  54. } else {
  55. 8 * mem::size_of::<T>() - 1
  56. };
  57. if bits > MANTISSA_DIGITS as usize {
  58. let rounding_bit = T::one() << (bits - MANTISSA_DIGITS as usize - 1);
  59. let x = T::max_value() - rounding_bit;
  60. let x1 = x + T::one();
  61. let x2 = x1 + T::one();
  62. assert!(x.as_() < x1.as_());
  63. assert_eq!(x1.as_(), x2.as_());
  64. Some((x, x1))
  65. } else {
  66. None
  67. }
  68. }
  69. fn extend<T>(v: &mut Vec<T>, start: T, end: T)
  70. where
  71. T: TestInteger,
  72. {
  73. let mut i = start;
  74. while i < end {
  75. v.push(i);
  76. i = i + T::one();
  77. }
  78. v.push(i);
  79. }
  80. fn extend_shl<T>(v: &mut Vec<T>, start: T, end: T, mask: T)
  81. where
  82. T: TestInteger,
  83. {
  84. let mut i = start;
  85. while i != end {
  86. v.push(i);
  87. i = (i << 1) & mask;
  88. }
  89. }
  90. fn extend_shr<T>(v: &mut Vec<T>, start: T, end: T)
  91. where
  92. T: TestInteger,
  93. {
  94. let mut i = start;
  95. while i != end {
  96. v.push(i);
  97. i = i >> 1;
  98. }
  99. }
  100. fn pos<T>() -> Vec<T>
  101. where
  102. T: TestInteger,
  103. i8: AsPrimitive<T>,
  104. {
  105. let mut v: Vec<T> = vec![];
  106. if mem::size_of::<T>() == 1 {
  107. extend(&mut v, T::zero(), T::max_value());
  108. } else {
  109. extend(&mut v, T::zero(), i8::max_value().as_());
  110. extend(
  111. &mut v,
  112. T::max_value() - i8::max_value().as_(),
  113. T::max_value(),
  114. );
  115. if let Some((i, j)) = mantissa_max::<T>() {
  116. v.push(i);
  117. v.push(j);
  118. }
  119. extend_shl(&mut v, T::max_value(), T::zero(), !T::min_value());
  120. extend_shr(&mut v, T::max_value(), T::zero());
  121. }
  122. v
  123. }
  124. fn neg<T>() -> Vec<T>
  125. where
  126. T: TestInteger + Signed,
  127. i8: AsPrimitive<T>,
  128. {
  129. let mut v: Vec<T> = vec![];
  130. if mem::size_of::<T>() <= 1 {
  131. extend(&mut v, T::min_value(), T::zero());
  132. } else {
  133. extend(&mut v, i8::min_value().as_(), T::zero());
  134. extend(
  135. &mut v,
  136. T::min_value(),
  137. T::min_value() - i8::min_value().as_(),
  138. );
  139. if let Some((i, j)) = mantissa_max::<T>() {
  140. v.push(-i);
  141. v.push(-j);
  142. }
  143. extend_shl(&mut v, -T::one(), T::min_value(), !T::zero());
  144. extend_shr(&mut v, T::min_value(), -T::one());
  145. }
  146. v
  147. }
  /// Generate two test modules: one named after the signed type `$I` and one
  /// named after the unsigned type `$U`. Each checks `sqrt`, `cbrt` and
  /// `nth_root` on that type via `check`.
  macro_rules! test_roots {
      ($I:ident, $U:ident) => {
          mod $I {
              use check;
              use neg;
              use num_integer::Roots;
              use pos;
              use std::mem;

              #[test]
              #[should_panic]
              fn zeroth_root() {
                  (123 as $I).nth_root(0);
              }

              #[test]
              fn sqrt() {
                  check(&pos::<$I>(), 2);
              }

              #[test]
              #[should_panic]
              fn sqrt_neg() {
                  (-123 as $I).sqrt();
              }

              #[test]
              fn cbrt() {
                  check(&pos::<$I>(), 3);
              }

              #[test]
              fn cbrt_neg() {
                  check(&neg::<$I>(), 3);
              }

              #[test]
              fn nth_root() {
                  // Value bits of the signed type, excluding the sign bit.
                  let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                  let pos = pos::<$I>();
                  for n in 4..bits {
                      check(&pos, n);
                  }
              }

              #[test]
              fn nth_root_neg() {
                  let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                  let neg = neg::<$I>();
                  // Negative inputs are only checked with odd roots (5, 7, ...).
                  for n in 2..bits / 2 {
                      check(&neg, 2 * n + 1);
                  }
              }

              #[test]
              fn bit_size() {
                  let bits = 8 * mem::size_of::<$I>() as u32 - 1;
                  assert_eq!($I::max_value().nth_root(bits - 1), 2);
                  assert_eq!($I::max_value().nth_root(bits), 1);
                  assert_eq!($I::min_value().nth_root(bits), -2);
                  assert_eq!(($I::min_value() + 1).nth_root(bits), -1);
              }
          }

          mod $U {
              use check;
              use num_integer::Roots;
              use pos;
              use std::mem;

              #[test]
              #[should_panic]
              fn zeroth_root() {
                  (123 as $U).nth_root(0);
              }

              #[test]
              fn sqrt() {
                  check(&pos::<$U>(), 2);
              }

              #[test]
              fn cbrt() {
                  check(&pos::<$U>(), 3);
              }

              #[test]
              fn nth_root() {
                  // Every bit of an unsigned type carries value. This must
                  // exercise `$U` itself, not the signed `$I`.
                  let bits = 8 * mem::size_of::<$U>() as u32;
                  let pos = pos::<$U>();
                  for n in 4..bits {
                      check(&pos, n);
                  }
              }

              #[test]
              fn bit_size() {
                  let bits = 8 * mem::size_of::<$U>() as u32;
                  assert_eq!($U::max_value().nth_root(bits - 1), 2);
                  assert_eq!($U::max_value().nth_root(bits), 1);
              }
          }
      };
  }
  238. test_roots!(i8, u8);
  239. test_roots!(i16, u16);
  240. test_roots!(i32, u32);
  241. test_roots!(i64, u64);
  242. #[cfg(has_i128)]
  243. test_roots!(i128, u128);
  244. test_roots!(isize, usize);