Przeglądaj źródła

Merge branch 'master' of https://github.com/odin-lang/Odin

gingerBill 4 lat temu
rodzic
commit
c3a64c2a59

+ 23 - 26
Makefile

@@ -8,31 +8,31 @@ CC=clang
 OS=$(shell uname)
 
 ifeq ($(OS), Darwin)
-	LLVM_CONFIG=llvm-config
-	ifneq ($(shell llvm-config --version | grep '^11\.'),)
-		LLVM_CONFIG=llvm-config
-	else
-		$(error "Requirement: llvm-config must be version 11")
-	endif
-
-	LDFLAGS:=$(LDFLAGS) -liconv
-	CFLAGS:=$(CFLAGS) $(shell $(LLVM_CONFIG) --cxxflags --ldflags)
-	LDFLAGS:=$(LDFLAGS) -lLLVM-C
+    LLVM_CONFIG=llvm-config
+    ifneq ($(shell llvm-config --version | grep '^11\.'),)
+        LLVM_CONFIG=llvm-config
+    else
+        $(error "Requirement: llvm-config must be version 11")
+    endif
+
+    LDFLAGS:=$(LDFLAGS) -liconv
+    CFLAGS:=$(CFLAGS) $(shell $(LLVM_CONFIG) --cxxflags --ldflags)
+    LDFLAGS:=$(LDFLAGS) -lLLVM-C
 endif
 ifeq ($(OS), Linux)
-	LLVM_CONFIG=llvm-config-11
-	ifneq ($(shell which llvm-config-11 2>/dev/null),)
-		LLVM_CONFIG=llvm-config-11
-	else
-		ifneq ($(shell llvm-config --version | grep '^11\.'),)
-			LLVM_CONFIG=llvm-config
-		else
-			$(error "Requirement: llvm-config must be version 11")
-		endif
-	endif
-
-	CFLAGS:=$(CFLAGS) $(shell $(LLVM_CONFIG) --cxxflags --ldflags)
-	LDFLAGS:=$(LDFLAGS) $(shell $(LLVM_CONFIG) --libs core native --system-libs)
+    LLVM_CONFIG=llvm-config-11
+    ifneq ($(shell which llvm-config-11 2>/dev/null),)
+        LLVM_CONFIG=llvm-config-11
+    else
+        ifneq ($(shell llvm-config --version | grep '^11\.'),)
+            LLVM_CONFIG=llvm-config
+        else
+            $(error "Requirement: llvm-config must be version 11")
+        endif
+    endif
+
+    CFLAGS:=$(CFLAGS) $(shell $(LLVM_CONFIG) --cxxflags --ldflags)
+    LDFLAGS:=$(LDFLAGS) $(shell $(LLVM_CONFIG) --libs core native --system-libs)
 endif
 
 all: debug demo
@@ -51,6 +51,3 @@ release_native:
 
 nightly:
 	$(CC) src/main.cpp src/libtommath.cpp $(DISABLED_WARNINGS) $(CFLAGS) -DNIGHTLY -O3 $(LDFLAGS) -o odin
-
-
-

+ 5 - 1
core/c/libc/complex.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.3 Complex arithmetic
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 @(default_calling_convention="c")
 foreign libc {

+ 5 - 1
core/c/libc/ctype.odin

@@ -1,6 +1,10 @@
 package libc
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 // 7.4 Character handling
 

+ 5 - 1
core/c/libc/errno.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.5 Errors
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 // C11 standard only requires the definition of:
 //	EDOM,

+ 5 - 1
core/c/libc/math.odin

@@ -4,7 +4,11 @@ package libc
 
 import "core:intrinsics"
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 // To support C's tgmath behavior we use Odin's explicit procedure overloading,
 // but we cannot use the same names as exported by libc so use @(link_name)

+ 43 - 9
core/c/libc/setjmp.odin

@@ -2,18 +2,52 @@ package libc
 
 // 7.13 Nonlocal jumps
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
+
+when ODIN_OS == "windows" {
+	@(default_calling_convention="c")
+	foreign libc {
+		// 7.13.1 Save calling environment
+		//
+		// NOTE(dweiler): C11 requires setjmp be a macro, which means it won't
+		// necessarily export a symbol named setjmp but rather _setjmp in the case
+		// of musl, glibc, BSD libc, and msvcrt.
+		//
+		/// NOTE(dweiler): UCRT has two implementations of longjmp. One that performs
+		// stack unwinding and one that doesn't. The choice of which to use depends on a
+		// flag which is set inside the jmp_buf structure given to setjmp. The default
+		// behavior is to unwind the stack. Within Odin, we cannot use the stack
+		// unwinding version as the unwinding information isn't present. To opt-in to
+		// the regular non-unwinding version we need a way to set this flag. Since the
+		// location of the flag within the struct is not defined or part of the ABI and
+		// can change between versions of UCRT, we must rely on setjmp to set it. It
+		// turns out that setjmp receives this flag in the RDX register on Win64, this
+		// just so happens to coincide with the second argument of a function in the
+		// Win64 ABI. By giving our setjmp a second argument with the value of zero,
+		// the RDX register will contain zero and correctly set the flag to disable
+		// stack unwinding.
+		@(link_name="_setjmp")
+		setjmp  :: proc(env: ^jmp_buf, hack: rawptr = nil) -> int ---;
+	}
+} else {
+	@(default_calling_convention="c")
+	foreign libc {
+		// 7.13.1 Save calling environment
+		//
+		// NOTE(dweiler): C11 requires setjmp be a macro, which means it won't
+		// necessarily export a symbol named setjmp but rather _setjmp in the case
+		// of musl, glibc, BSD libc, and msvcrt.
+		@(link_name="_setjmp")
+		setjmp  :: proc(env: ^jmp_buf) -> int ---;
+	}
+}
 
 @(default_calling_convention="c")
 foreign libc {
-	// 7.13.1 Save calling environment
-	//
-	// NOTE(dweiler): C11 requires setjmp be a macro, which means it won't
-	// necessarily export a symbol named setjmp but rather _setjmp in the case
-	// of musl, glibc, BSD libc, and msvcrt.
-	@(link_name="_setjmp")
-	setjmp  :: proc(env: ^jmp_buf) -> int ---;
-
 	// 7.13.2 Restore calling environment
 	longjmp :: proc(env: ^jmp_buf, val: int) -> ! ---;
 }

+ 5 - 1
core/c/libc/signal.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.14 Signal handling
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 sig_atomic_t :: distinct atomic_int;
 

+ 5 - 1
core/c/libc/stdio.odin

@@ -1,6 +1,10 @@
 package libc
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 // 7.21 Input/output
 

+ 6 - 2
core/c/libc/stdlib.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.22 General utilities
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 when ODIN_OS == "windows" {
 	RAND_MAX :: 0x7fff;
@@ -14,7 +18,7 @@ when ODIN_OS == "windows" {
 	}
 
 	MB_CUR_MAX :: #force_inline proc() -> size_t {
-		return ___mb_cur_max_func();
+		return size_t(___mb_cur_max_func());
 	}
 }
 

+ 5 - 1
core/c/libc/string.odin

@@ -3,7 +3,11 @@ package libc
 
 // 7.24 String handling
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 foreign libc {
 	// 7.24.2 Copying functions

+ 50 - 0
core/c/libc/tests/general.odin

@@ -0,0 +1,50 @@
+package libc_tests
+
+import "core:c/libc"
+
+test_stdio :: proc() {
+    c: libc.char = 'C';
+    libc.puts("Hello from puts");
+    libc.printf("Hello from printf in %c\n", c);
+}
+test_thread :: proc() {
+    thread_proc :: proc "c" (rawptr) -> libc.int {
+        libc.printf("Hello from thread");
+        return 42;
+    }
+    thread: libc.thrd_t;
+    libc.thrd_create(&thread, thread_proc, nil);
+    result: libc.int;
+    libc.thrd_join(thread, &result);
+    libc.printf(" %d\n", result);
+}
+
+jmp: libc.jmp_buf;
+test_sjlj :: proc() {
+    if libc.setjmp(&jmp) != 0 {
+        libc.printf("Hello from longjmp\n");
+        return;
+    }
+    libc.printf("Hello from setjmp\n");
+    libc.longjmp(&jmp, 1);
+}
+test_signal :: proc() {
+    handler :: proc "c" (sig: libc.int) {
+        libc.printf("Hello from signal handler\n");
+    }
+    libc.signal(libc.SIGABRT, handler);
+    libc.raise(libc.SIGABRT);
+}
+test_atexit :: proc() {
+    handler :: proc "c" () {
+        libc.printf("Hello from atexit\n");
+    }
+    libc.atexit(handler);
+}
+main :: proc() {
+    test_stdio();
+    test_thread();
+    test_sjlj();
+    test_signal();
+    test_atexit();
+}

+ 7 - 3
core/c/libc/threads.odin

@@ -6,7 +6,10 @@ thrd_start_t :: proc "c" (rawptr) -> int;
 tss_dtor_t   :: proc "c" (rawptr);
 
 when ODIN_OS == "windows" {
-	foreign import libc "system:c"
+	foreign import libc {
+		"system:libucrt.lib", 
+		"system:msvcprt.lib"
+	}
 
 	thrd_success        :: 0;                             // _Thrd_success
 	thrd_nomem          :: 1;                             // _Thrd_nomem
@@ -24,6 +27,7 @@ when ODIN_OS == "windows" {
 	thrd_t              :: struct { _: rawptr, _: uint, } // _Thrd_t
 	tss_t               :: distinct int;                  // _Tss_imp_t
 	cnd_t               :: distinct rawptr;               // _Cnd_imp_t
+	mtx_t               :: distinct rawptr;               // _Mtx_imp_t
 
 	// MSVCRT does not expose the C11 symbol names as what they are in C11
 	// because they held off implementing <threads.h> and C11 support for so
@@ -52,9 +56,9 @@ when ODIN_OS == "windows" {
 		@(link_name="_Mtx_unlock")    mtx_unlock    :: proc(mtx: ^mtx_t) -> int ---;
 
 		// 7.26.5 Thread functions
-		@(link_name="_Thrd_create")   thrd_create   :: proc(thr: ^thr_t, func: thrd_start_t, arg: rawptr) -> int ---;
+		@(link_name="_Thrd_create")   thrd_create   :: proc(thr: ^thrd_t, func: thrd_start_t, arg: rawptr) -> int ---;
 		@(link_name="_Thrd_current")  thrd_current  :: proc() -> thrd_t ---;
-		@(link_name="_Thrd_detach")   thrd_detach   :: proc(thr: thr_t) -> int ---;
+		@(link_name="_Thrd_detach")   thrd_detach   :: proc(thr: thrd_t) -> int ---;
 		@(link_name="_Thrd_equal")    thrd_equal    :: proc(lhs, rhs: thrd_t) -> int ---;
 		@(link_name="_Thrd_exit")     thrd_exit     :: proc(res: int) -> ! ---;
 		@(link_name="_Thrd_join")     thrd_join     :: proc(thr: thrd_t, res: ^int) -> int ---;

+ 5 - 1
core/c/libc/time.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.27 Date and time
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 // We enforce 64-bit time_t and timespec as there is no reason to use 32-bit as
 // we approach the 2038 problem. Windows has defaulted to this since VC8 (2005).

+ 5 - 1
core/c/libc/uchar.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.28 Unicode utilities
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 @(default_calling_convention="c")
 foreign libc {

+ 5 - 1
core/c/libc/wchar.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.29 Extended multibyte and wide character utilities
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 @(default_calling_convention="c")
 foreign libc {

+ 5 - 1
core/c/libc/wctype.odin

@@ -2,7 +2,11 @@ package libc
 
 // 7.30 Wide character classification and mapping utilities
 
-foreign import libc "system:c"
+when ODIN_OS == "windows" {
+	foreign import libc "system:libucrt.lib"
+} else {
+	foreign import libc "system:c"
+}
 
 when ODIN_OS == "windows" {
 	wctrans_t :: distinct wchar_t;

+ 8 - 8
core/math/big/build.bat

@@ -1,9 +1,9 @@
 @echo off
-odin run . -vet -o:size
-: -o:size
-:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
-:odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py -fast-tests
-:odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py -fast-tests
-:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py
-: -fast-tests
-:odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests
+:odin run . -vet
+
+set TEST_ARGS=-fast-tests
+:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+odin build . -build-mode:shared -show-timings -o:size -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+:odin build . -build-mode:shared -show-timings -o:size -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check -define:MATH_BIG_EXE=false && python test.py %TEST_ARGS%
+:odin build . -build-mode:shared -show-timings -o:speed -define:MATH_BIG_EXE=false && python test.py -fast-tests %TEST_ARGS%

Plik diff jest za duży
+ 0 - 0
core/math/big/example.odin


+ 93 - 30
core/math/big/internal.odin

@@ -659,8 +659,7 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 			Can we use the balance method? Check sizes.
 			* The smaller one needs to be larger than the Karatsuba cut-off.
 			* The bigger one needs to be at least about one `_MUL_KARATSUBA_CUTOFF` bigger
-			* to make some sense, but it depends on architecture, OS, position of the
-			* stars... so YMMV.
+			* to make some sense, but it depends on architecture, OS, position of the stars... so YMMV.
 			* Using it to cut the input into slices small enough for _mul_comba
 			* was actually slower on the author's machine, but YMMV.
 		*/
@@ -669,13 +668,11 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc
 		max_used := max(src.used, multiplier.used);
 		digits   := src.used + multiplier.used + 1;
 
-		if        false &&  min_used     >= MUL_KARATSUBA_CUTOFF &&
-						    max_used / 2 >= MUL_KARATSUBA_CUTOFF &&
+		if min_used >= MUL_KARATSUBA_CUTOFF && (max_used / 2) >= MUL_KARATSUBA_CUTOFF && max_used >= (2 * min_used) {
 			/*
 				Not much effect was observed below a ratio of 1:2, but again: YMMV.
 			*/
-							max_used     >= 2 * min_used {
-			// err = s_mp_mul_balance(a,b,c);
+			err = _private_int_mul_balance(dest, src, multiplier);
 		} else if min_used >= MUL_TOOM_CUTOFF {
 			/*
 				Toom path commented out until it no longer fails Factorial 10k or 100k,
@@ -861,7 +858,12 @@ internal_int_mod :: proc(remainder, numerator, denominator: ^Int, allocator := c
 
 	return #force_inline internal_add(remainder, remainder, numerator, allocator);
 }
-internal_mod :: proc{ internal_int_mod, };
+
+internal_int_mod_digit :: proc(numerator: ^Int, denominator: DIGIT, allocator := context.allocator) -> (remainder: DIGIT, err: Error) {
+	return internal_int_divmod_digit(nil, numerator, denominator, allocator);
+}
+
+internal_mod :: proc{ internal_int_mod, internal_int_mod_digit};
 
 /*
 	remainder = (number + addend) % modulus.
@@ -909,7 +911,7 @@ internal_int_factorial :: proc(res: ^Int, n: int, allocator := context.allocator
 	context.allocator = allocator;
 
 	if n >= FACTORIAL_BINARY_SPLIT_CUTOFF {
-		return #force_inline _private_int_factorial_binary_split(res, n);
+		return _private_int_factorial_binary_split(res, n);
 	}
 
 	i := len(_factorial_table);
@@ -1106,10 +1108,10 @@ internal_compare :: proc { internal_int_compare, internal_int_compare_digit, };
 internal_cmp :: internal_compare;
 
 /*
-    Compare an `Int` to an unsigned number upto `DIGIT & _MASK`.
-    Returns -1 if `a` < `b`, 0 if `a` == `b` and 1 if `b` > `a`.
+	Compare an `Int` to an unsigned number upto `DIGIT & _MASK`.
+	Returns -1 if `a` < `b`, 0 if `a` == `b` and 1 if `b` > `a`.
 
-    Expects: `a` and `b` both to be valid `Int`s, i.e. initialized and not `nil`.
+	Expects: `a` and `b` both to be valid `Int`s, i.e. initialized and not `nil`.
 */
 internal_int_compare_digit :: #force_inline proc(a: ^Int, b: DIGIT) -> (comparison: int) {
 	a_is_negative := #force_inline internal_is_negative(a);
@@ -1165,11 +1167,72 @@ internal_int_compare_magnitude :: #force_inline proc(a, b: ^Int) -> (comparison:
 		}
 	}
 
-   	return 0;
+	return 0;
 }
 internal_compare_magnitude :: proc { internal_int_compare_magnitude, };
 internal_cmp_mag :: internal_compare_magnitude;
 
+/*
+	Check if remainders are possible squares - fast exclude non-squares.
+
+	Returns `true` if `a` is a square, `false` if not.
+	Assumes `a` not to be `nil` and to have been initialized.
+*/
+internal_int_is_square :: proc(a: ^Int, allocator := context.allocator) -> (square: bool, err: Error) {
+	context.allocator = allocator;
+
+	/*
+		Default to Non-square :)
+	*/
+	square = false;
+
+	if internal_is_negative(a)                                       { return; }
+	if internal_is_zero(a)                                           { return; }
+
+	/*
+		First check mod 128 (suppose that _DIGIT_BITS is at least 7).
+	*/
+	if _private_int_rem_128[127 & a.digit[0]] == 1                   { return; }
+
+	/*
+		Next check mod 105 (3*5*7).
+	*/
+	c: DIGIT;
+	c, err = internal_mod(a, 105);
+	if _private_int_rem_105[c] == 1                                  { return; }
+
+	t := &Int{};
+	defer destroy(t);
+
+	set(t, 11 * 13 * 17 * 19 * 23 * 29 * 31) or_return;
+	internal_mod(t, a, t) or_return;
+
+	r: u64;
+	r, err = internal_int_get(t, u64);
+
+	/*
+		Check for other prime modules, note it's not an ERROR but we must
+		free "t" so the easiest way is to goto LBL_ERR.  We know that err
+		is already equal to MP_OKAY from the mp_mod call
+	*/
+	if (1 << (r % 11) &      0x5C4) != 0                             { return; }
+	if (1 << (r % 13) &      0x9E4) != 0                             { return; }
+	if (1 << (r % 17) &     0x5CE8) != 0                             { return; }
+	if (1 << (r % 19) &    0x4F50C) != 0                             { return; }
+	if (1 << (r % 23) &   0x7ACCA0) != 0                             { return; }
+	if (1 << (r % 29) &  0xC2EDD0C) != 0                             { return; }
+	if (1 << (r % 31) & 0x6DE2B848) != 0                             { return; }
+
+	/*
+		Final check - is sqr(sqrt(arg)) == arg?
+	*/
+	sqrt(t, a) or_return;
+	sqr(t, t)  or_return;
+
+	square = internal_cmp_mag(t, a) == 0;
+
+	return;
+}
 
 /*
 	=========================    Logs, powers and roots    ============================
@@ -2300,12 +2363,12 @@ internal_int_shrmod :: proc(quotient, remainder, numerator: ^Int, bits: int, all
 			/*
 				Shift the current word and mix in the carry bits from the previous word.
 			*/
-	        quotient.digit[x] = (quotient.digit[x] >> uint(bits)) | (carry << shift);
+			quotient.digit[x] = (quotient.digit[x] >> uint(bits)) | (carry << shift);
 
-	        /*
-	        	Update carry from forward carry.
-	        */
-	        carry = fwd_carry;
+			/*
+				Update carry from forward carry.
+			*/
+			carry = fwd_carry;
 		}
 
 	}
@@ -2331,17 +2394,17 @@ internal_int_shr_digit :: proc(quotient: ^Int, digits: int, allocator := context
 	*/
 	if digits > quotient.used { return internal_zero(quotient); }
 
-   	/*
+	/*
 		Much like `int_shl_digit`, this is implemented using a sliding window,
 		except the window goes the other way around.
 
 		b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
-		            /\                   |      ---->
-		             \-------------------/      ---->
-    */
+					/\                   |      ---->
+					 \-------------------/      ---->
+	*/
 
 	#no_bounds_check for x := 0; x < (quotient.used - digits); x += 1 {
-    	quotient.digit[x] = quotient.digit[x + digits];
+		quotient.digit[x] = quotient.digit[x + digits];
 	}
 	quotient.used -= digits;
 	internal_zero_unused(quotient);
@@ -2445,14 +2508,14 @@ internal_int_shl_digit :: proc(quotient: ^Int, digits: int, allocator := context
 	/*
 		Much like `int_shr_digit`, this is implemented using a sliding window,
 		except the window goes the other way around.
-    */
-    #no_bounds_check for x := quotient.used; x > 0; x -= 1 {
-    	quotient.digit[x+digits-1] = quotient.digit[x-1];
-    }
-
-   	quotient.used += digits;
-    mem.zero_slice(quotient.digit[:digits]);
-    return nil;
+	*/
+	#no_bounds_check for x := quotient.used; x > 0; x -= 1 {
+		quotient.digit[x+digits-1] = quotient.digit[x-1];
+	}
+
+	quotient.used += digits;
+	mem.zero_slice(quotient.digit[:digits]);
+	return nil;
 }
 internal_shl_digit :: proc { internal_int_shl_digit, };
 

+ 177 - 0
core/math/big/prime.odin

@@ -33,6 +33,183 @@ int_prime_is_divisible :: proc(a: ^Int, allocator := context.allocator) -> (res:
 	return false, nil;
 }
 
+/*
+	Computes xR**-1 == x (mod N) via Montgomery Reduction.
+*/
+internal_int_montgomery_reduce :: proc(x, n: ^Int, rho: DIGIT, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+	/*
+		Can the fast reduction [comba] method be used?
+		Note that unlike in mul, you're safely allowed *less* than the available columns [255 per default],
+		since carries are fixed up in the inner loop.
+	*/
+	digs := (n.used * 2) + 1;
+	if digs < _WARRAY && x.used <= _WARRAY && n.used < _MAX_COMBA {
+		return _private_montgomery_reduce_comba(x, n, rho);
+	}
+
+	/*
+		Grow the input as required
+	*/
+	internal_grow(x, digs)                                           or_return;
+	x.used = digs;
+
+	for ix := 0; ix < n.used; ix += 1 {
+		/*
+			`mu = ai * rho mod b`
+			The value of rho must be precalculated via `int_montgomery_setup()`,
+			such that it equals -1/n0 mod b this allows the following inner loop
+			to reduce the input one digit at a time.
+		*/
+
+		mu := DIGIT((_WORD(x.digit[ix]) * _WORD(rho)) & _WORD(_MASK));
+
+		/*
+			a = a + mu * m * b**i
+			Multiply and add in place.
+		*/
+		u  := DIGIT(0);
+		iy := int(0);
+		for ; iy < n.used; iy += 1 {
+			/*
+				Compute product and sum.
+			*/
+			r := (_WORD(mu) * _WORD(n.digit[iy]) + _WORD(u) + _WORD(x.digit[ix + iy]));
+
+			/*
+				Get carry.
+			*/
+			u = DIGIT(r >> _DIGIT_BITS);
+
+			/*
+				Fix digit.
+			*/
+			x.digit[ix + iy] = DIGIT(r & _WORD(_MASK));
+		}
+
+		/*
+			At this point the ix'th digit of x should be zero.
+			Propagate carries upwards as required.
+		*/
+		for u != 0 {
+			x.digit[ix + iy] += u;
+			u = x.digit[ix + iy] >> _DIGIT_BITS;
+			x.digit[ix + iy] &= _MASK;
+			iy += 1;
+		}
+	}
+
+	/*
+		At this point the n.used'th least significant digits of x are all zero,
+		which means we can shift x to the right by n.used digits and the
+		residue is unchanged.
+
+		x = x/b**n.used.
+	*/
+	internal_clamp(x);
+	internal_shr_digit(x, n.used);
+
+	/*
+		if x >= n then x = x - n
+	*/
+	if internal_cmp_mag(x, n) != -1 {
+		return internal_sub(x, x, n);
+	}
+
+	return nil;
+}
+
+int_montgomery_reduce :: proc(x, n: ^Int, rho: DIGIT, allocator := context.allocator) -> (err: Error) {
+	assert_if_nil(x, n);
+	context.allocator = allocator;
+
+	internal_clear_if_uninitialized(x, n) or_return;
+
+	return #force_inline internal_int_montgomery_reduce(x, n, rho);
+}
+
+/*
+	Shifts with subtractions when the result is greater than b.
+
+	The method is slightly modified to shift B unconditionally upto just under
+	the leading bit of b.  This saves alot of multiple precision shifting.
+*/
+internal_int_montgomery_calc_normalization :: proc(a, b: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+	/*
+		How many bits of last digit does b use.
+	*/
+	bits := internal_count_bits(b) % _DIGIT_BITS;
+
+	if b.used > 1 {
+		power := ((b.used - 1) * _DIGIT_BITS) + bits - 1;
+		internal_int_power_of_two(a, power)                          or_return;
+	} else {
+		internal_one(a);
+		bits = 1;
+	}
+
+	/*
+		Now compute C = A * B mod b.
+	*/
+	for x := bits - 1; x < _DIGIT_BITS; x += 1 {
+		internal_int_shl1(a, a)                                      or_return;
+		if internal_cmp_mag(a, b) != -1 {
+			internal_sub(a, a, b)                                    or_return;
+		}
+	}
+	return nil;
+}
+
+int_montgomery_calc_normalization :: proc(a, b: ^Int, allocator := context.allocator) -> (err: Error) {
+	assert_if_nil(a, b);
+	context.allocator = allocator;
+
+	internal_clear_if_uninitialized(a, b) or_return;
+
+	return #force_inline internal_int_montgomery_calc_normalization(a, b);
+}
+
+/*
+	Sets up the Montgomery reduction stuff.
+*/
+internal_int_montgomery_setup :: proc(n: ^Int) -> (rho: DIGIT, err: Error) {
+	/*
+		Fast inversion mod 2**k
+		Based on the fact that:
+
+		XA = 1 (mod 2**n) => (X(2-XA)) A = 1 (mod 2**2n)
+		                  =>  2*X*A - X*X*A*A = 1
+		                  =>  2*(1) - (1)     = 1
+	*/
+	b := n.digit[0];
+	if b & 1 == 0 { return 0, .Invalid_Argument; }
+
+	x := (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
+	x *= 2 - (b * x);              /* here x*a==1 mod 2**8 */
+	x *= 2 - (b * x);              /* here x*a==1 mod 2**16 */
+	when _WORD_TYPE_BITS == 64 {
+		x *= 2 - (b * x);              /* here x*a==1 mod 2**32 */
+		x *= 2 - (b * x);              /* here x*a==1 mod 2**64 */
+	}
+
+	/*
+		rho = -1/m mod b
+	*/
+	rho = DIGIT(((_WORD(1) << _WORD(_DIGIT_BITS)) - _WORD(x)) & _WORD(_MASK));
+	return rho, nil;
+}
+
+int_montgomery_setup :: proc(n: ^Int, allocator := context.allocator) -> (rho: DIGIT, err: Error) {
+	assert_if_nil(n);
+	internal_clear_if_uninitialized(n, allocator) or_return;
+
+	return #force_inline internal_int_montgomery_setup(n);
+}
+
+/*
+	Returns the number of Rabin-Miller trials needed for a given bit size.
+*/
 number_of_rabin_miller_trials :: proc(bit_size: int) -> (number_of_trials: int) {
 	switch {
 	case bit_size <=    80:

+ 356 - 20
core/math/big/private.odin

@@ -113,7 +113,7 @@ _private_int_mul_toom :: proc(dest, a, b: ^Int, allocator := context.allocator)
 	context.allocator = allocator;
 
 	S1, S2, T1, a0, a1, a2, b0, b1, b2 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
-	defer destroy(S1, S2, T1, a0, a1, a2, b0, b1, b2);
+	defer internal_destroy(S1, S2, T1, a0, a1, a2, b0, b1, b2);
 
 	/*
 		Init temps.
@@ -258,7 +258,7 @@ _private_int_mul_karatsuba :: proc(dest, a, b: ^Int, allocator := context.alloca
 	context.allocator = allocator;
 
 	x0, x1, y0, y1, t1, x0y0, x1y1 := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
-	defer destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
+	defer internal_destroy(x0, x1, y0, y1, t1, x0y0, x1y1);
 
 	/*
 		min # of digits, divided by two.
@@ -426,6 +426,195 @@ _private_int_mul_comba :: proc(dest, a, b: ^Int, digits: int, allocator := conte
 	return internal_clamp(dest);
 }
 
+/*
+	Multiplies |a| * |b| and does not compute the lower digs digits
+	[meant to get the higher part of the product]
+*/
+_private_int_mul_high :: proc(dest, a, b: ^Int, digits: int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	/*
+		Can we use the fast multiplier?
+	*/
+	if a.used + b.used + 1 < _WARRAY && min(a.used, b.used) < _MAX_COMBA {
+		return _private_int_mul_high_comba(dest, a, b, digits);
+	}
+
+	internal_grow(dest, a.used + b.used + 1) or_return;
+	dest.used = a.used + b.used + 1;
+
+	pa := a.used;
+	pb := b.used;
+	for ix := 0; ix < pa; ix += 1 {
+		carry := DIGIT(0);
+
+		for iy := digits - ix; iy < pb; iy += 1 {
+			/*
+				Calculate the double precision result.
+			*/
+			r := _WORD(dest.digit[ix + iy]) + _WORD(a.digit[ix]) * _WORD(b.digit[iy]) + _WORD(carry);
+
+			/*
+				Get the lower part.
+			*/
+			dest.digit[ix + iy] = DIGIT(r & _WORD(_MASK));
+
+			/*
+				Carry the carry.
+			*/
+			carry = DIGIT(r >> _WORD(_DIGIT_BITS));
+		}
+		dest.digit[ix + pb] = carry;
+	}
+	return internal_clamp(dest);
+}
+
+/*
+	This is a modified version of `_private_int_mul_comba` that only produces output digits *above* `digits`.
+	See the comments for `_private_int_mul_comba` to see how it works.
+
+	This is used in the Barrett reduction since for one of the multiplications
+	only the higher digits were needed.  This essentially halves the work.
+
+	Based on Algorithm 14.12 on pp.595 of HAC.
+*/
+_private_int_mul_high_comba :: proc(dest, a, b: ^Int, digits: int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
+	W: [_WARRAY]DIGIT = ---;
+	_W: _WORD = 0;
+
+	/*
+		Number of output digits to produce. Grow the destination as required.
+	*/
+	pa := a.used + b.used;
+	internal_grow(dest, pa) or_return;
+
+	ix: int;
+	for ix = digits; ix < pa; ix += 1 {
+		/*
+			Get offsets into the two bignums.
+		*/
+		ty := min(b.used - 1, ix);
+		tx := ix - ty;
+
+		/*
+			This is the number of times the loop will iterrate, essentially it's
+			while (tx++ < a->used && ty-- >= 0) { ... }
+		*/
+		iy := min(a.used - tx, ty + 1);
+
+		/*
+			Execute loop.
+		*/
+		for iz := 0; iz < iy; iz += 1 {
+			_W += _WORD(a.digit[tx + iz]) * _WORD(b.digit[ty - iz]);
+		}
+
+		/*
+			Store term.
+		*/
+		W[ix] = DIGIT(_W) & DIGIT(_MASK);
+
+		/*
+			Make next carry.
+		*/
+		_W = _W >> _WORD(_DIGIT_BITS);
+	}
+
+	/*
+		Setup dest
+	*/
+	old_used := dest.used;
+	dest.used = pa;
+
+	for ix = digits; ix < pa; ix += 1 {
+		/*
+			Now extract the previous digit [below the carry].
+		*/
+		dest.digit[ix] = W[ix];
+	}
+
+	/*
+		Zero remainder.
+	*/
+	internal_zero_unused(dest, old_used);
+
+	/*
+		Adjust dest.used based on leading zeroes.
+	*/
+	return internal_clamp(dest);
+}
+
+/*
+	Single-digit multiplication with the smaller number as the single-digit.
+*/
+_private_int_mul_balance :: proc(dest, a, b: ^Int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+	a, b := a, b;
+
+	a0, tmp, r := &Int{}, &Int{}, &Int{};
+	defer internal_destroy(a0, tmp, r);
+
+	b_size   := min(a.used, b.used);
+	n_blocks := max(a.used, b.used) / b_size;
+
+	internal_grow(a0, b_size + 2) or_return;
+	internal_init_multi(tmp, r)   or_return;
+
+	/*
+		Make sure that `a` is the larger one.
+	*/
+	if a.used < b.used {
+		a, b = b, a;
+	}
+	assert(a.used >= b.used);
+
+	i, j := 0, 0;
+	for ; i < n_blocks; i += 1 {
+		/*
+			Cut a slice off of `a`.
+		*/
+
+		a0.used = b_size;
+		internal_copy_digits(a0, a, a0.used, j);
+		j += a0.used;
+		internal_clamp(a0);
+
+		/*
+			Multiply with `b`.
+		*/
+		internal_mul(tmp, a0, b)                                     or_return;
+
+		/*
+			Shift `tmp` to the correct position.
+		*/
+		internal_shl_digit(tmp, b_size * i)                          or_return;
+
+		/*
+			Add to output. No carry needed.
+		*/
+		internal_add(r, r, tmp)                                      or_return;
+	}
+
+	/*
+		The left-overs; there are always left-overs.
+	*/
+	if j < a.used {
+		a0.used = a.used - j;
+		internal_copy_digits(a0, a, a0.used, j);
+		j += a0.used;
+		internal_clamp(a0);
+
+		internal_mul(tmp, a0, b)                                     or_return;
+		internal_shl_digit(tmp, b_size * i)                          or_return;
+		internal_add(r, r, tmp)                                      or_return;
+	}
+
+	internal_swap(dest, r);
+	return;
+}
+
 /*
 	Low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
 	Assumes `dest` and `src` to not be `nil`, and `src` to have been initialized.
@@ -1188,7 +1377,7 @@ _private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int
 
 	ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
 	c: int;
-	defer destroy(ta, tb, tq, q);
+	defer internal_destroy(ta, tb, tq, q);
 
 	for {
 		internal_one(tq) or_return;
@@ -1241,31 +1430,34 @@ _private_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int
 	Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html
 */
 _private_int_factorial_binary_split :: proc(res: ^Int, n: int, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
 
 	inner, outer, start, stop, temp := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
 	defer internal_destroy(inner, outer, start, stop, temp);
 
-	internal_one(inner, false, allocator) or_return;
-	internal_one(outer, false, allocator) or_return;
+	internal_one(inner, false)                                       or_return;
+	internal_one(outer, false)                                       or_return;
 
 	bits_used := int(_DIGIT_TYPE_BITS - intrinsics.count_leading_zeros(n));
 
 	for i := bits_used; i >= 0; i -= 1 {
 		start := (n >> (uint(i) + 1)) + 1 | 1;
 		stop  := (n >> uint(i)) + 1 | 1;
-		_private_int_recursive_product(temp, start, stop, 0, allocator) or_return;
-		internal_mul(inner, inner, temp, allocator) or_return;
-		internal_mul(outer, outer, inner, allocator) or_return;
+		_private_int_recursive_product(temp, start, stop, 0)         or_return;
+		internal_mul(inner, inner, temp)                             or_return;
+		internal_mul(outer, outer, inner)                            or_return;
 	}
 	shift := n - intrinsics.count_ones(n);
 
-	return internal_shl(res, outer, int(shift), allocator);
+	return internal_shl(res, outer, int(shift));
 }
 
 /*
 	Recursive product used by binary split factorial algorithm.
 */
 _private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int(0), allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+
 	t1, t2 := &Int{}, &Int{};
 	defer internal_destroy(t1, t2);
 
@@ -1275,28 +1467,28 @@ _private_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int
 
 	num_factors := (stop - start) >> 1;
 	if num_factors == 2 {
-		internal_set(t1, start, false, allocator) or_return;
+		internal_set(t1, start, false)                               or_return;
 		when true {
-			internal_grow(t2, t1.used + 1, false, allocator) or_return;
-			internal_add(t2, t1, 2, allocator) or_return;
+			internal_grow(t2, t1.used + 1, false)                    or_return;
+			internal_add(t2, t1, 2)                                  or_return;
 		} else {
-			add(t2, t1, 2) or_return;
+			internal_add(t2, t1, 2)                                  or_return;
 		}
-		return internal_mul(res, t1, t2, allocator);
+		return internal_mul(res, t1, t2);
 	}
 
 	if num_factors > 1 {
 		mid := (start + num_factors) | 1;
-		_private_int_recursive_product(t1, start,  mid, level + 1, allocator) or_return;
-		_private_int_recursive_product(t2,   mid, stop, level + 1, allocator) or_return;
-		return internal_mul(res, t1, t2, allocator);
+		_private_int_recursive_product(t1, start,  mid, level + 1)   or_return;
+		_private_int_recursive_product(t2,   mid, stop, level + 1)   or_return;
+		return internal_mul(res, t1, t2);
 	}
 
 	if num_factors == 1 {
-		return #force_inline internal_set(res, start, true, allocator);
+		return #force_inline internal_set(res, start, true);
 	}
 
-	return #force_inline internal_one(res, true, allocator);
+	return #force_inline internal_one(res, true);
 }
 
 /*
@@ -1542,6 +1734,126 @@ _private_int_log :: proc(a: ^Int, base: DIGIT, allocator := context.allocator) -
 
 
 
+/*
+	Computes xR**-1 == x (mod N) via Montgomery Reduction.
+	This is an optimized implementation of `internal_montgomery_reduce`
+	which uses the comba method to quickly calculate the columns of the reduction.
+	Based on Algorithm 14.32 on pp.601 of HAC.
+*/
+_private_montgomery_reduce_comba :: proc(x, n: ^Int, rho: DIGIT, allocator := context.allocator) -> (err: Error) {
+	context.allocator = allocator;
+	W: [_WARRAY]_WORD = ---;
+
+	if x.used > _WARRAY { return .Invalid_Argument; }
+
+	/*
+		Get old used count.
+	*/
+	old_used := x.used;
+
+	/*
+		Grow `x` as required.
+	*/
+	internal_grow(x, n.used + 1) or_return;
+
+	/*
+		First we have to get the digits of the input into an array of double precision words W[...]
+		Copy the digits of `x` into W[0..`x.used` - 1]
+	*/
+	ix: int;
+	for ix = 0; ix < x.used; ix += 1 {
+		W[ix] = _WORD(x.digit[ix]);
+	}
+
+	/*
+		Zero the high words of W[a->used..m->used*2].
+	*/
+	zero_upper := (n.used * 2) + 1;
+	if ix < zero_upper {
+		for ix = x.used; ix < zero_upper; ix += 1 {
+			W[ix] = {};
+		}
+	}
+
+	/*
+		Now we proceed to zero successive digits from the least significant upwards.
+	*/
+	for ix = 0; ix < n.used; ix += 1 {
+		/*
+			`mu = ai * m' mod b`
+
+			We avoid a double precision multiplication (which isn't required)
+			by casting the value down to a DIGIT.  Note this requires
+			that W[ix-1] have the carry cleared (see after the inner loop)
+		*/
+		mu := ((W[ix] & _WORD(_MASK)) * _WORD(rho)) & _WORD(_MASK);
+
+		/*
+			`a = a + mu * m * b**i`
+		
+			This is computed in place and on the fly.  The multiplication
+		 	by b**i is handled by offseting which columns the results
+		 	are added to.
+		
+			Note the comba method normally doesn't handle carries in the
+			inner loop In this case we fix the carry from the previous
+			column since the Montgomery reduction requires digits of the
+			result (so far) [see above] to work.
+
+			This is	handled by fixing up one carry after the inner loop.
+			The carry fixups are done in order so after these loops the
+			first m->used words of W[] have the carries fixed.
+		*/
+		for iy := 0; iy < n.used; iy += 1 {
+			W[ix + iy] += mu * _WORD(n.digit[iy]);
+		}
+
+		/*
+			Now fix carry for next digit, W[ix+1].
+		*/
+		W[ix + 1] += (W[ix] >> _DIGIT_BITS);
+	}
+
+	/*
+		Now we have to propagate the carries and shift the words downward
+		[all those least significant digits we zeroed].
+	*/
+
+	for ; ix < n.used * 2; ix += 1 {
+		W[ix + 1] += (W[ix] >> _DIGIT_BITS);
+	}
+
+	/* copy out, A = A/b**n
+	 *
+	 * The result is A/b**n but instead of converting from an
+	 * array of mp_word to mp_digit than calling mp_rshd
+	 * we just copy them in the right order
+	 */
+
+	for ix = 0; ix < (n.used + 1); ix += 1 {
+		x.digit[ix] = DIGIT(W[n.used + ix] & _WORD(_MASK));
+	}
+
+	/*
+		Set the max used.
+	*/
+	x.used = n.used + 1;
+
+	/*
+		Zero old_used digits, if the input a was larger than m->used+1 we'll have to clear the digits.
+	*/
+	internal_zero_unused(x, old_used);
+	internal_clamp(x);
+
+	/*
+		if A >= m then A = A - m
+	*/
+	if internal_cmp_mag(x, n) != -1 {
+		return internal_sub(x, x, n);
+	}
+	return nil;
+}
+
 /*
 	hac 14.61, pp608
 */
@@ -1887,7 +2199,30 @@ _private_copy_digits :: proc(dest, src: ^Int, digits: int, offset := int(0)) ->
 	Tables used by `internal_*` and `_*`.
 */
 
-_private_prime_table := []DIGIT{
+_private_int_rem_128 := [?]DIGIT{
+	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
+};
+#assert(128 * size_of(DIGIT) == size_of(_private_int_rem_128));
+
+_private_int_rem_105 := [?]DIGIT{
+	0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
+	0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1,
+	0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1,
+	1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
+	0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
+	1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1,
+	1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1,
+};
+#assert(105 * size_of(DIGIT) == size_of(_private_int_rem_105));
+
+_private_prime_table := [?]DIGIT{
 	0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
 	0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
 	0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
@@ -1924,6 +2259,7 @@ _private_prime_table := []DIGIT{
 	0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
 	0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653,
 };
+#assert(256 * size_of(DIGIT) == size_of(_private_prime_table));
 
 when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
 	_factorial_table := [35]_WORD{

+ 15 - 0
core/math/big/public.odin

@@ -555,4 +555,19 @@ int_compare_magnitude :: proc(a, b: ^Int, allocator := context.allocator) -> (re
 	internal_clear_if_uninitialized(a, b) or_return;
 
 	return #force_inline internal_cmp_mag(a, b), nil;
+}
+
+/*
+	Check if remainders are possible squares - fast exclude non-squares.
+
+	Returns `true` if `a` is a square, `false` if not.
+	Assumes `a` not to be `nil` and to have been initialized.
+*/
+int_is_square :: proc(a: ^Int, allocator := context.allocator) -> (square: bool, err: Error) {
+	assert_if_nil(a);
+	context.allocator = allocator;
+
+	internal_clear_if_uninitialized(a) or_return;
+
+	return #force_inline internal_int_is_square(a);
 }

+ 1 - 1
core/math/big/radix.odin

@@ -235,7 +235,7 @@ int_to_cstring :: int_itoa_cstring;
 /*
 	Read a string [ASCII] in a given radix.
 */
-int_atoi :: proc(res: ^Int, input: string, radix: i8, allocator := context.allocator) -> (err: Error) {
+int_atoi :: proc(res: ^Int, input: string, radix := i8(10), allocator := context.allocator) -> (err: Error) {
 	assert_if_nil(res);
 	input := input;
 	context.allocator = allocator;

+ 19 - 0
core/math/big/test.odin

@@ -369,3 +369,22 @@ PyRes :: struct {
 	return PyRes{res = r, err = nil};
 }
 
+/*
+	dest = lcm(a, b)
+*/
+@export test_is_square :: proc "c" (a: cstring) -> (res: PyRes) {
+	context = runtime.default_context();
+	err:    Error;
+	square: bool;
+
+	ai := &Int{};
+	defer internal_destroy(ai);
+
+	if err = atoi(ai, string(a), 16); err != nil { return PyRes{res=":is_square:atoi(a):", err=err}; }
+	if square, err = #force_inline internal_int_is_square(ai); err != nil { return PyRes{res=":is_square:is_square(a):", err=err}; }
+
+	if square {
+		return PyRes{"True", nil};
+	}
+	return PyRes{"False", nil};
+}

+ 41 - 17
core/math/big/test.py

@@ -160,11 +160,11 @@ print("initialize_constants: ", initialize_constants())
 
 error_string = load(l.test_error_string, [c_byte], c_char_p)
 
-add        =     load(l.test_add,        [c_char_p, c_char_p],   Res)
-sub        =     load(l.test_sub,        [c_char_p, c_char_p],   Res)
-mul        =     load(l.test_mul,        [c_char_p, c_char_p],   Res)
-sqr        =     load(l.test_sqr,        [c_char_p          ],   Res)
-div        =     load(l.test_div,        [c_char_p, c_char_p],   Res)
+add        =     load(l.test_add,        [c_char_p, c_char_p  ], Res)
+sub        =     load(l.test_sub,        [c_char_p, c_char_p  ], Res)
+mul        =     load(l.test_mul,        [c_char_p, c_char_p  ], Res)
+sqr        =     load(l.test_sqr,        [c_char_p            ], Res)
+div        =     load(l.test_div,        [c_char_p, c_char_p  ], Res)
 
 # Powers and such
 int_log    =     load(l.test_log,        [c_char_p, c_longlong], Res)
@@ -179,9 +179,11 @@ int_shl        = load(l.test_shl,        [c_char_p, c_longlong], Res)
 int_shr        = load(l.test_shr,        [c_char_p, c_longlong], Res)
 int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
 
-int_factorial  = load(l.test_factorial,  [c_uint64], Res)
-int_gcd        = load(l.test_gcd,        [c_char_p, c_char_p], Res)
-int_lcm        = load(l.test_lcm,        [c_char_p, c_char_p], Res)
+int_factorial  = load(l.test_factorial,  [c_uint64            ], Res)
+int_gcd        = load(l.test_gcd,        [c_char_p, c_char_p  ], Res)
+int_lcm        = load(l.test_lcm,        [c_char_p, c_char_p  ], Res)
+
+is_square      = load(l.test_is_square,  [c_char_p            ], Res)
 
 def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = "", radix=16):
 	passed = True
@@ -401,14 +403,21 @@ def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
 		
 	return test("test_shr_signed", res, [a, bits], expected_error, expected_result)
 
-def test_factorial(n = 0, expected_error = Error.Okay):
-	args  = [n]
-	res   = int_factorial(*args)
+def test_factorial(number = 0, expected_error = Error.Okay):
+	print("Factorial:", number)
+	args  = [number]
+	try:
+		res = int_factorial(*args)
+	except OSError as e:
+		print("{} while trying to factorial {}.".format(e, number))
+		if EXIT_ON_FAIL: exit(3)
+		return False
+
 	expected_result = None
 	if expected_error == Error.Okay:
-		expected_result = math.factorial(n)
+		expected_result = math.factorial(number)
 		
-	return test("test_factorial", res, [n], expected_error, expected_result)
+	return test("test_factorial", res, [number], expected_error, expected_result)
 
 def test_gcd(a = 0, b = 0, expected_error = Error.Okay):
 	args  = [arg_to_odin(a), arg_to_odin(b)]
@@ -428,6 +437,15 @@ def test_lcm(a = 0, b = 0, expected_error = Error.Okay):
 		
 	return test("test_lcm", res, [a, b], expected_error, expected_result)
 
+def test_is_square(a = 0, b = 0, expected_error = Error.Okay):
+	args  = [arg_to_odin(a)]
+	res   = is_square(*args)
+	expected_result = None
+	if expected_error == Error.Okay:
+		expected_result = str(math.isqrt(a) ** 2 == a) if a > 0 else "False"
+		
+	return test("test_is_square", res, [a], expected_error, expected_result)
+
 # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
 #
 # The last two arguments in tests are the expected error and expected result.
@@ -527,6 +545,10 @@ TESTS = {
 		[   0, 0,  ],
 		[   0, 125,],
 	],
+	test_is_square: [
+		[ 12, ],
+		[ 92232459121502451677697058974826760244863271517919321608054113675118660929276431348516553336313179167211015633639725554914519355444316239500734169769447134357534241879421978647995614218985202290368055757891124109355450669008628757662409138767505519391883751112010824030579849970582074544353971308266211776494228299586414907715854328360867232691292422194412634523666770452490676515117702116926803826546868467146319938818238521874072436856528051486567230096290549225463582766830777324099589751817442141036031904145041055454639783559905920619197290800070679733841430619962318433709503256637256772215111521321630777950145713049902839937043785039344243357384899099910837463164007565230287809026956254332260375327814271845678201, ]
+	],
 }
 
 if not args.fast_tests:
@@ -545,7 +567,7 @@ RANDOM_TESTS = [
 	test_add, test_sub, test_mul, test_sqr, test_div,
 	test_log, test_pow, test_sqrt, test_root_n,
 	test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
-	test_gcd, test_lcm,
+	test_gcd, test_lcm, test_is_square,
 ]
 SKIP_LARGE   = [
 	test_pow, test_root_n, # test_gcd,
@@ -648,11 +670,13 @@ if __name__ == '__main__':
 					a = abs(a)
 					b = randint(0, 10);
 				elif test_proc == test_shl:
-					b = randint(0, min(BITS, 120));
+					b = randint(0, min(BITS, 120))
 				elif test_proc == test_shr_signed:
-					b = randint(0, min(BITS, 120));
+					b = randint(0, min(BITS, 120))
+				elif test_proc == test_is_square:
+					a = randint(0, 1 << BITS)
 				else:
-					b = randint(0, 1 << BITS)					
+					b = randint(0, 1 << BITS)
 
 				res = None
 

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików