Browse Source

bit: Optimized `int_bitfield_extract`.

Jeroen van Rijn 4 years ago
parent
commit
35d8976de4
3 changed files with 68 additions and 66 deletions
  1. 6 26
      core/math/big/example.odin
  2. 60 38
      core/math/big/helpers.odin
  3. 2 2
      core/math/big/test.py

+ 6 - 26
core/math/big/example.odin

@@ -81,8 +81,7 @@ Category :: enum {
 	choose,
 	lsb,
 	ctz,
-	bitfield_extract_old,
-	bitfield_extract_new,
+	bitfield_extract,
 };
 Event :: struct {
 	t: time.Duration,
@@ -123,20 +122,8 @@ demo :: proc() {
 	err = factorial(a, 1224);
 	count, _ := count_bits(a);
 
-	bits   :=  101;
-	be1, be2: _WORD;
-
-	/*
-		Sanity check loop.
-	*/
-	for o := 0; o < count - bits; o += 1 {
-		be1, _ = int_bitfield_extract(a, o, bits);
-		be2, _ = int_bitfield_extract_fast(a, o, bits);
-		if be1 != be2 {
-			fmt.printf("Offset: %v | Expected: %v | Got: %v\n", o, be1, be2);
-			assert(false);
-		}
-	}
+	bits :=  51;
+	be1: _WORD;
 
 	/*
 		Timing loop
@@ -145,16 +132,9 @@ demo :: proc() {
 	for o := 0; o < count - bits; o += 1 {
 		be1, _ = int_bitfield_extract(a, o, bits);
 	}
-	Timings[.bitfield_extract_old].t += time.tick_since(s_old);
-	Timings[.bitfield_extract_old].c += (count - bits);
-
-	s_new := time.tick_now();
-	for o := 0; o < count - bits; o += 1 {
-		be2, _ = int_bitfield_extract_fast(a, o, bits);
-	}
-	Timings[.bitfield_extract_new].t += time.tick_since(s_new);
-	Timings[.bitfield_extract_new].c += (count - bits);
-	assert(be1 == be2);
+	Timings[.bitfield_extract].t += time.tick_since(s_old);
+	Timings[.bitfield_extract].c += (count - bits);
+	fmt.printf("be1: %v\n", be1);
 }
 
 main :: proc() {

+ 60 - 38
core/math/big/helpers.odin

@@ -191,20 +191,8 @@ neg :: proc(dest, src: ^Int, allocator := context.allocator) -> (err: Error) {
 /*
 	Helpers to extract values from the `Int`.
 */
-extract_bit :: proc(a: ^Int, bit_offset: int) -> (bit: DIGIT, err: Error) {
-	/*
-		Check that `a`is usable.
-	*/
-	if err = clear_if_uninitialized(a); err != nil { return 0, err; }
-
-	limb := bit_offset / _DIGIT_BITS;
-	if limb < 0 || limb >= a.used {
-		return 0, .Invalid_Argument;
-	}
-
-	i := DIGIT(1 << DIGIT((bit_offset % _DIGIT_BITS)));
-
-	return 1 if ((a.digit[limb] & i) != 0) else 0, nil;
+int_bitfield_extract_single :: proc(a: ^Int, offset: int) -> (bit: _WORD, err: Error) {
+	return int_bitfield_extract(a, offset, 1);
 }
 
 int_bitfield_extract :: proc(a: ^Int, offset, count: int) -> (res: _WORD, err: Error) {
@@ -212,38 +200,72 @@ int_bitfield_extract :: proc(a: ^Int, offset, count: int) -> (res: _WORD, err: E
 		Check that `a` is usable.
 	*/
 	if err = clear_if_uninitialized(a); err != nil { return 0, err; }
+	/*
+		Early out for single bit.
+	*/
+	if count == 1 {
+		limb := offset / _DIGIT_BITS;
+		if limb < 0 || limb >= a.used { return 0, .Invalid_Argument; }
+		i := _WORD(1 << _WORD((offset % _DIGIT_BITS)));
+		return 1 if ((_WORD(a.digit[limb]) & i) != 0) else 0, nil;
+	}
+
 	if count > _WORD_BITS || count < 1             { return 0, .Invalid_Argument; }
 
-	for shift := 0; shift < count; shift += 1 {
-		bit_offset := offset + shift;
+	/*
+		There are 3 possible cases.
+		-	[offset:][:count] covers 1 DIGIT,
+				e.g. offset: 0, count: 60 = bits 0..59
+		-	[offset:][:count] covers 2 DIGITS,
+				e.g. offset: 5, count: 60 = bits 5..59, 0..4
+				e.g. offset:  0, count: 120 = bits 0..59, 60..119
+		-	[offset:][:count] covers 3 DIGITS,
+				e.g. offset: 40, count: 100 = bits 40..59, 0..59, 0..19
+				e.g. offset: 40, count: 120 = bits 40..59, 0..59, 0..39
+	*/
 
-		limb := bit_offset / _DIGIT_BITS;
-		mask := DIGIT(1 << DIGIT((bit_offset % _DIGIT_BITS)));
+	limb        := offset / _DIGIT_BITS;
+	bits_left   := count;
+	bits_offset := offset % _DIGIT_BITS;
 
-		if (a.digit[limb] & mask) != 0 {
-			res += _WORD(1) << uint(shift);
-		}
-	}
-	return res, nil;
-}
+	num_bits    := min(bits_left, _DIGIT_BITS - bits_offset);
 
-int_bitfield_extract_fast :: proc(a: ^Int, offset, count: int) -> (res: _WORD, err: Error) {
-	/*
-		Check that `a` is usable.
-	*/
-	if err = clear_if_uninitialized(a); err != nil { return 0, err; }
-	if count > _WORD_BITS || count < 1             { return 0, .Invalid_Argument; }
+	// fmt.printf("offset: %v | count: %v\n\n", offset, count);
+	// fmt.printf("left:   %v | bits_offset: %v | limb:  %v | num: %v\n\n", bits_left, bits_offset, limb, num_bits);
 
-	for shift := 0; shift < count; shift += 1 {
-		bit_offset := offset + shift;
+	shift := offset % _DIGIT_BITS;
+	mask  := (_WORD(1) << uint(num_bits)) - 1;
+
+	// fmt.printf("shift: %v | mask: %v\n", shift, mask);
+	// fmt.printf("d: %v\n", a.digit[limb]);
+
+	res  = (_WORD(a.digit[limb]) >> uint(shift)) & mask;
+
+	// fmt.printf("res: %v\n", res);
+
+	bits_left -= num_bits;
+	if bits_left == 0 { return res, nil; }
+
+	res_shift := num_bits;
+
+	num_bits = min(bits_left, _DIGIT_BITS);
+	mask     = (1 << uint(num_bits)) - 1;
+
+	v := (_WORD(a.digit[limb + 1]) & mask) << uint(res_shift);
+	res |= v;
+
+	bits_left -= num_bits;
+	if bits_left == 0 { return res, nil; }
+
+	// fmt.printf("bits_left: %v | offset: %v | num: %v\n", bits_left, offset, num_bits);
+
+	mask     = (1 << uint(bits_left)) - 1;
+	res_shift += _DIGIT_BITS;
+
+	v = (_WORD(a.digit[limb + 2]) & mask) << uint(res_shift);
+	res |= v;
 
-		limb := bit_offset / _DIGIT_BITS;
-		mask := DIGIT(1 << DIGIT((bit_offset % _DIGIT_BITS)));
 
-		if (a.digit[limb] & mask) != 0 {
-			res += _WORD(1) << uint(shift);
-		}
-	}
 	return res, nil;
 }
 

+ 2 - 2
core/math/big/test.py

@@ -11,13 +11,13 @@ from enum import Enum
 # With EXIT_ON_FAIL set, we exit at the first fail.
 #
 EXIT_ON_FAIL = True
-#EXIT_ON_FAIL = False
+EXIT_ON_FAIL = False
 
 #
 # We skip randomized tests altogether if NO_RANDOM_TESTS is set.
 #
 NO_RANDOM_TESTS = True
-#NO_RANDOM_TESTS = False
+NO_RANDOM_TESTS = False
 
 #
 # If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.