// roots.rs -- tests for the integer root methods (`nth_root`, `sqrt`, `cbrt`) of `BigUint` and `BigInt`.
  1. mod biguint {
  2. use num_bigint::BigUint;
  3. use num_traits::{One, Zero};
  4. use std::{i32, u32};
  5. fn check<T: Into<BigUint>>(x: T, n: u32) {
  6. let x: BigUint = x.into();
  7. let root = x.nth_root(n);
  8. println!("check {}.nth_root({}) = {}", x, n, root);
  9. if n == 2 {
  10. assert_eq!(root, x.sqrt())
  11. } else if n == 3 {
  12. assert_eq!(root, x.cbrt())
  13. }
  14. let lo = root.pow(n);
  15. assert!(lo <= x);
  16. assert_eq!(lo.nth_root(n), root);
  17. if !lo.is_zero() {
  18. assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32);
  19. }
  20. let hi = (&root + 1u32).pow(n);
  21. assert!(hi > x);
  22. assert_eq!(hi.nth_root(n), &root + 1u32);
  23. assert_eq!((&hi - 1u32).nth_root(n), root);
  24. }
  25. #[test]
  26. fn test_sqrt() {
  27. check(99u32, 2);
  28. check(100u32, 2);
  29. check(120u32, 2);
  30. }
  31. #[test]
  32. fn test_cbrt() {
  33. check(8u32, 3);
  34. check(26u32, 3);
  35. }
  36. #[test]
  37. fn test_nth_root() {
  38. check(0u32, 1);
  39. check(10u32, 1);
  40. check(100u32, 4);
  41. }
  42. #[test]
  43. #[should_panic]
  44. fn test_nth_root_n_is_zero() {
  45. check(4u32, 0);
  46. }
  47. #[test]
  48. fn test_nth_root_big() {
  49. let x = BigUint::from(123_456_789_u32);
  50. let expected = BigUint::from(6u32);
  51. assert_eq!(x.nth_root(10), expected);
  52. check(x, 10);
  53. }
  54. #[test]
  55. fn test_nth_root_googol() {
  56. let googol = BigUint::from(10u32).pow(100u32);
  57. // perfect divisors of 100
  58. for &n in &[2, 4, 5, 10, 20, 25, 50, 100] {
  59. let expected = BigUint::from(10u32).pow(100u32 / n);
  60. assert_eq!(googol.nth_root(n), expected);
  61. check(googol.clone(), n);
  62. }
  63. }
  64. #[test]
  65. fn test_nth_root_twos() {
  66. const EXP: u32 = 12;
  67. const LOG2: usize = 1 << EXP;
  68. let x = BigUint::one() << LOG2;
  69. // the perfect divisors are just powers of two
  70. for exp in 1..=EXP {
  71. let n = 2u32.pow(exp);
  72. let expected = BigUint::one() << (LOG2 / n as usize);
  73. assert_eq!(x.nth_root(n), expected);
  74. check(x.clone(), n);
  75. }
  76. // degenerate cases should return quickly
  77. assert!(x.nth_root(x.bits() as u32).is_one());
  78. assert!(x.nth_root(i32::MAX as u32).is_one());
  79. assert!(x.nth_root(u32::MAX).is_one());
  80. }
  81. #[test]
  82. fn test_roots_rand1() {
  83. // A random input that found regressions
  84. let s = "575981506858479247661989091587544744717244516135539456183849\
  85. 986593934723426343633698413178771587697273822147578889823552\
  86. 182702908597782734558103025298880194023243541613924361007059\
  87. 353344183590348785832467726433749431093350684849462759540710\
  88. 026019022227591412417064179299354183441181373862905039254106\
  89. 4781867";
  90. let x: BigUint = s.parse().unwrap();
  91. check(x.clone(), 2);
  92. check(x.clone(), 3);
  93. check(x.clone(), 10);
  94. check(x, 100);
  95. }
  96. }
  97. mod bigint {
  98. use num_bigint::BigInt;
  99. use num_traits::Signed;
  100. fn check(x: i64, n: u32) {
  101. let big_x = BigInt::from(x);
  102. let res = big_x.nth_root(n);
  103. if n == 2 {
  104. assert_eq!(&res, &big_x.sqrt())
  105. } else if n == 3 {
  106. assert_eq!(&res, &big_x.cbrt())
  107. }
  108. if big_x.is_negative() {
  109. assert!(res.pow(n) >= big_x);
  110. assert!((res - 1u32).pow(n) < big_x);
  111. } else {
  112. assert!(res.pow(n) <= big_x);
  113. assert!((res + 1u32).pow(n) > big_x);
  114. }
  115. }
  116. #[test]
  117. fn test_nth_root() {
  118. check(-100, 3);
  119. }
  120. #[test]
  121. #[should_panic]
  122. fn test_nth_root_x_neg_n_even() {
  123. check(-100, 4);
  124. }
  125. #[test]
  126. #[should_panic]
  127. fn test_sqrt_x_neg() {
  128. check(-4, 2);
  129. }
  130. #[test]
  131. fn test_cbrt() {
  132. check(8, 3);
  133. check(-8, 3);
  134. }
  135. }