big: Split more into public and internal.

This commit is contained in:
Jeroen van Rijn
2021-08-06 21:22:51 +02:00
parent 9890e7cfeb
commit 9321616c80
7 changed files with 236 additions and 169 deletions
+23 -149
View File
@@ -12,7 +12,6 @@ package big
*/
import "core:mem"
import "core:intrinsics"
/*
===========================
@@ -140,9 +139,7 @@ int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.allocator) ->
mul :: proc { int_mul, int_mul_digit, };
sqr :: proc(dest, src: ^Int) -> (err: Error) {
return mul(dest, src, src);
}
sqr :: proc(dest, src: ^Int) -> (err: Error) { return mul(dest, src, src); }
/*
divmod.
@@ -205,8 +202,10 @@ mod :: proc { int_mod, int_mod_digit, };
remainder = (number + addend) % modulus.
*/
int_addmod :: proc(remainder, number, addend, modulus: ^Int) -> (err: Error) {
if err = add(remainder, number, addend); err != nil { return err; }
return mod(remainder, remainder, modulus);
if remainder == nil { return .Invalid_Pointer; };
if err = clear_if_uninitialized(number, addend, modulus); err != nil { return err; }
return #force_inline internal_addmod(remainder, number, addend, modulus);
}
addmod :: proc { int_addmod, };
@@ -214,8 +213,10 @@ addmod :: proc { int_addmod, };
remainder = (number - decrease) % modulus.
*/
int_submod :: proc(remainder, number, decrease, modulus: ^Int) -> (err: Error) {
if err = add(remainder, number, decrease); err != nil { return err; }
return mod(remainder, remainder, modulus);
if remainder == nil { return .Invalid_Pointer; };
if err = clear_if_uninitialized(number, decrease, modulus); err != nil { return err; }
return #force_inline internal_submod(remainder, number, decrease, modulus);
}
submod :: proc { int_submod, };
@@ -223,8 +224,10 @@ submod :: proc { int_submod, };
remainder = (number * multiplicand) % modulus.
*/
int_mulmod :: proc(remainder, number, multiplicand, modulus: ^Int) -> (err: Error) {
if err = mul(remainder, number, multiplicand); err != nil { return err; }
return mod(remainder, remainder, modulus);
if remainder == nil { return .Invalid_Pointer; };
if err = clear_if_uninitialized(number, multiplicand, modulus); err != nil { return err; }
return #force_inline internal_mulmod(remainder, number, multiplicand, modulus);
}
mulmod :: proc { int_mulmod, };
@@ -232,89 +235,23 @@ mulmod :: proc { int_mulmod, };
remainder = (number * number) % modulus.
*/
int_sqrmod :: proc(remainder, number, modulus: ^Int) -> (err: Error) {
if err = sqr(remainder, number); err != nil { return err; }
return mod(remainder, remainder, modulus);
if remainder == nil { return .Invalid_Pointer; };
if err = clear_if_uninitialized(number, modulus); err != nil { return err; }
return #force_inline internal_sqrmod(remainder, number, modulus);
}
sqrmod :: proc { int_sqrmod, };
/*
TODO: Use Stirling's approximation to estimate log2(N!) to size the result.
This way we'll have to reallocate less, possibly not at all.
*/
int_factorial :: proc(res: ^Int, n: DIGIT) -> (err: Error) {
if n < 0 || n > _FACTORIAL_MAX_N || res == nil { return .Invalid_Argument; }
int_factorial :: proc(res: ^Int, n: int) -> (err: Error) {
if n < 0 || n > _FACTORIAL_MAX_N { return .Invalid_Argument; }
if res == nil { return .Invalid_Pointer; }
i := DIGIT(len(_factorial_table));
if n < i {
return set(res, _factorial_table[n]);
}
if n >= _FACTORIAL_BINARY_SPLIT_CUTOFF {
return int_factorial_binary_split(res, n);
}
if err = set(res, _factorial_table[i - 1]); err != nil { return err; }
for {
if err = mul(res, res, DIGIT(i)); err != nil || i == n { return err; }
i += 1;
}
return nil;
return #force_inline internal_int_factorial(res, n);
}
_int_recursive_product :: proc(res: ^Int, start, stop: DIGIT, level := int(0)) -> (err: Error) {
t1, t2 := &Int{}, &Int{};
defer destroy(t1, t2);
if level > _FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS { return .Max_Iterations_Reached; }
num_factors := (stop - start) >> 1;
if num_factors == 2 {
if err = set(t1, start); err != nil { return err; }
if err = add(t2, t1, 2); err != nil { return err; }
return mul(res, t1, t2);
}
if num_factors > 1 {
mid := (start + num_factors) | 1;
if err = _int_recursive_product(t1, start, mid, level + 1); err != nil { return err; }
if err = _int_recursive_product(t2, mid, stop, level + 1); err != nil { return err; }
return mul(res, t1, t2);
}
if num_factors == 1 { return set(res, start); }
return set(res, 1);
}
/*
Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html
*/
int_factorial_binary_split :: proc(res: ^Int, n: DIGIT) -> (err: Error) {
if n < 0 || n > _FACTORIAL_MAX_N || res == nil { return .Invalid_Argument; }
inner, outer, start, stop, temp := &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(inner, outer, start, stop, temp);
if err = set(inner, 1); err != nil { return err; }
if err = set(outer, 1); err != nil { return err; }
bits_used := int(_DIGIT_TYPE_BITS - intrinsics.count_leading_zeros(n));
for i := bits_used; i >= 0; i -= 1 {
start := (n >> (uint(i) + 1)) + 1 | 1;
stop := (n >> uint(i)) + 1 | 1;
if err = _int_recursive_product(temp, start, stop); err != nil { return err; }
if err = mul(inner, inner, temp); err != nil { return err; }
if err = mul(outer, outer, inner); err != nil { return err; }
}
shift := n - intrinsics.count_ones(n);
return shl(res, outer, int(shift));
}
factorial :: proc { int_factorial, };
/*
Number of ways to choose `k` items from `n` items.
Also known as the binomial coefficient.
@@ -330,7 +267,7 @@ factorial :: proc { int_factorial, };
k, start from previous result
*/
int_choose_digit :: proc(res: ^Int, n, k: DIGIT) -> (err: Error) {
int_choose_digit :: proc(res: ^Int, n, k: int) -> (err: Error) {
if res == nil { return .Invalid_Pointer; }
if err = clear_if_uninitialized(res); err != nil { return err; }
@@ -1120,66 +1057,3 @@ int_mod_bits :: proc(remainder, numerator: ^Int, bits: int) -> (err: Error) {
mod_bits :: proc { int_mod_bits, };
when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
_factorial_table := [35]_WORD{
/* f(00): */ 1,
/* f(01): */ 1,
/* f(02): */ 2,
/* f(03): */ 6,
/* f(04): */ 24,
/* f(05): */ 120,
/* f(06): */ 720,
/* f(07): */ 5_040,
/* f(08): */ 40_320,
/* f(09): */ 362_880,
/* f(10): */ 3_628_800,
/* f(11): */ 39_916_800,
/* f(12): */ 479_001_600,
/* f(13): */ 6_227_020_800,
/* f(14): */ 87_178_291_200,
/* f(15): */ 1_307_674_368_000,
/* f(16): */ 20_922_789_888_000,
/* f(17): */ 355_687_428_096_000,
/* f(18): */ 6_402_373_705_728_000,
/* f(19): */ 121_645_100_408_832_000,
/* f(20): */ 2_432_902_008_176_640_000,
/* f(21): */ 51_090_942_171_709_440_000,
/* f(22): */ 1_124_000_727_777_607_680_000,
/* f(23): */ 25_852_016_738_884_976_640_000,
/* f(24): */ 620_448_401_733_239_439_360_000,
/* f(25): */ 15_511_210_043_330_985_984_000_000,
/* f(26): */ 403_291_461_126_605_635_584_000_000,
/* f(27): */ 10_888_869_450_418_352_160_768_000_000,
/* f(28): */ 304_888_344_611_713_860_501_504_000_000,
/* f(29): */ 8_841_761_993_739_701_954_543_616_000_000,
/* f(30): */ 265_252_859_812_191_058_636_308_480_000_000,
/* f(31): */ 8_222_838_654_177_922_817_725_562_880_000_000,
/* f(32): */ 263_130_836_933_693_530_167_218_012_160_000_000,
/* f(33): */ 8_683_317_618_811_886_495_518_194_401_280_000_000,
/* f(34): */ 295_232_799_039_604_140_847_618_609_643_520_000_000,
};
} else {
_factorial_table := [21]_WORD{
/* f(00): */ 1,
/* f(01): */ 1,
/* f(02): */ 2,
/* f(03): */ 6,
/* f(04): */ 24,
/* f(05): */ 120,
/* f(06): */ 720,
/* f(07): */ 5_040,
/* f(08): */ 40_320,
/* f(09): */ 362_880,
/* f(10): */ 3_628_800,
/* f(11): */ 39_916_800,
/* f(12): */ 479_001_600,
/* f(13): */ 6_227_020_800,
/* f(14): */ 87_178_291_200,
/* f(15): */ 1_307_674_368_000,
/* f(16): */ 20_922_789_888_000,
/* f(17): */ 355_687_428_096_000,
/* f(18): */ 6_402_373_705_728_000,
/* f(19): */ 121_645_100_408_832_000,
/* f(20): */ 2_432_902_008_176_640_000,
};
};
+2 -3
View File
@@ -1,10 +1,9 @@
@echo off
:odin run . -vet
: -o:size -no-bounds-check
odin run . -vet -o:size
:odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:size -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:size
:odin build . -build-mode:shared -show-timings -o:speed -no-bounds-check
:odin build . -build-mode:shared -show-timings -o:speed
python test.py
:python test.py
+18 -11
View File
@@ -65,19 +65,26 @@ demo :: proc() {
a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(a, b, c, d, e, f);
N :: 12345;
D :: 4;
N := 10_000;
set(a, N);
print("a: ", a);
div(b, a, D);
rem, _ := mod(a, D);
print("b: ", b);
fmt.printf("rem: %v\n", rem);
FACTORIAL_10_000_FIRST_100 :: "46AB3AE48966202D0FDE097BFA88FADC512AE8AFC0EA1D1D376A4109F10105E9E21F1E907151E85F926B8D82737B9030D572";
mul(b, b, D);
add(b, b, rem);
print("b: ", b);
for _ in 0..10
{
SCOPED_TIMING(.factorial);
factorial(a, N);
}
as, _ := itoa(a, 16);
defer delete(as);
fmt.printf("factorial(%v): %v (first 50 hex digits)\n", N, as[:50]);
if as[:100] == FACTORIAL_10_000_FIRST_100 {
fmt.println("\nCorrect!");
} else {
fmt.printf("\nWrong. Expected: %v\n", FACTORIAL_10_000_FIRST_100);
}
}
main :: proc() {
+187 -2
View File
@@ -28,6 +28,7 @@ package big
*/
import "core:mem"
import "core:intrinsics"
/*
Low-level addition, unsigned. Handbook of Applied Cryptography, algorithm 14.7.
@@ -419,7 +420,7 @@ internal_int_sub_digit :: proc(dest, number: ^Int, digit: DIGIT) -> (err: Error)
#no_bounds_check for i := 0; i < number.used; i += 1 {
dest.digit[i] = number.digit[i] - carry;
carry := dest.digit[i] >> (_DIGIT_TYPE_BITS - 1);
carry = dest.digit[i] >> (_DIGIT_TYPE_BITS - 1);
dest.digit[i] &= _MASK;
}
}
@@ -803,6 +804,122 @@ internal_int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error
}
internal_mod :: proc{ internal_int_mod, };
/*
	Modular addition: remainder = (number + addend) % modulus.

	The sum is written to `remainder` first and then reduced in place, so
	`remainder` is overwritten before `modulus` is read.
	Returns the error of the first step that fails, otherwise the result of the reduction.
*/
internal_int_addmod :: proc(remainder, number, addend, modulus: ^Int) -> (err: Error) {
	err = #force_inline internal_add(remainder, number, addend);
	if err != nil {
		return err;
	}
	return #force_inline internal_mod(remainder, remainder, modulus);
}
internal_addmod :: proc { internal_int_addmod, };
/*
	Modular subtraction: remainder = (number - decrease) % modulus.

	The difference is written to `remainder` first and then reduced in place, so
	`remainder` is overwritten before `modulus` is read.
	Returns the error of the first step that fails, otherwise the result of the reduction.
*/
internal_int_submod :: proc(remainder, number, decrease, modulus: ^Int) -> (err: Error) {
	err = #force_inline internal_sub(remainder, number, decrease);
	if err != nil {
		return err;
	}
	return #force_inline internal_mod(remainder, remainder, modulus);
}
internal_submod :: proc { internal_int_submod, };
/*
	Modular multiplication: remainder = (number * multiplicand) % modulus.

	The product is written to `remainder` first and then reduced in place, so
	`remainder` is overwritten before `modulus` is read.
	Returns the error of the first step that fails, otherwise the result of the reduction.
*/
internal_int_mulmod :: proc(remainder, number, multiplicand, modulus: ^Int) -> (err: Error) {
	err = #force_inline internal_mul(remainder, number, multiplicand);
	if err != nil {
		return err;
	}
	return #force_inline internal_mod(remainder, remainder, modulus);
}
internal_mulmod :: proc { internal_int_mulmod, };
/*
	Modular squaring: remainder = (number * number) % modulus.

	The square is computed by multiplying `number` with itself into `remainder`,
	which is then reduced in place, so `remainder` is overwritten before `modulus` is read.
	Returns the error of the first step that fails, otherwise the result of the reduction.
*/
internal_int_sqrmod :: proc(remainder, number, modulus: ^Int) -> (err: Error) {
	err = #force_inline internal_mul(remainder, number, number);
	if err != nil {
		return err;
	}
	return #force_inline internal_mod(remainder, remainder, modulus);
}
internal_sqrmod :: proc { internal_int_sqrmod, };
/*
	res = n!

	Strategy, in order:
	- `n` at or above `_FACTORIAL_BINARY_SPLIT_CUTOFF`: delegate to the binary split algorithm.
	- `n` inside `_factorial_table`: copy the precomputed value.
	- otherwise: start from the last table entry and multiply up to `n` one factor at a time.

	No argument validation happens here; `n` is used directly as a table index.

	TODO: Use Stirling's approximation to estimate log2(N!) to size the result.
	This way we'll have to reallocate less, possibly not at all.
*/
internal_int_factorial :: proc(res: ^Int, n: int) -> (err: Error) {
	if n >= _FACTORIAL_BINARY_SPLIT_CUTOFF {
		return #force_inline _int_factorial_binary_split(res, n);
	}

	table_size := len(_factorial_table);
	if n < table_size {
		return #force_inline set(res, _factorial_table[n]);
	}

	/*
		The last table entry is (table_size - 1)!, so the remaining factors are table_size ..= n.
	*/
	err = #force_inline set(res, _factorial_table[table_size - 1]);
	if err != nil {
		return err;
	}

	for factor := table_size; factor <= n; factor += 1 {
		err = #force_inline internal_mul(res, res, DIGIT(factor));
		if err != nil {
			return err;
		}
	}
	return nil;
}
/*
	Helper for the binary split factorial.

	Sets `res` to the product of `(stop - start) >> 1` values spaced two apart,
	beginning at `start`: start * (start + 2) * (start + 4) * ...
	When that count is zero (or negative), `res` is set to 1.

	The range is halved recursively; `level` is the current recursion depth and
	`.Max_Iterations_Reached` is returned once it exceeds `_FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS`.

	NOTE(review): `_int_factorial_binary_split` passes odd `start` and `stop` (both are OR-ed with 1);
	other inputs look unintended - confirm before calling this from elsewhere.
*/
_int_recursive_product :: proc(res: ^Int, start, stop: int, level := int(0)) -> (err: Error) {
/*
	Scratch values for the two halves (or the two factors); released on every exit path.
*/
t1, t2 := &Int{}, &Int{};
defer destroy(t1, t2);
if level > _FACTORIAL_BINARY_SPLIT_MAX_RECURSIONS { return .Max_Iterations_Reached; }
num_factors := (stop - start) >> 1;
/*
	Exactly two factors: res = start * (start + 2).
*/
if num_factors == 2 {
if err = set(t1, start); err != nil { return err; }
/*
	Compile-time toggle: the active branch sizes `t2` by hand and calls `internal_add`,
	the disabled branch calls the public `add`.
	NOTE(review): presumably `internal_add` does not grow its destination itself - confirm.
*/
when true {
if err = grow(t2, t1.used + 1); err != nil { return err; }
if err = internal_add(t2, t1, 2); err != nil { return err; }
} else {
if err = add(t2, t1, 2); err != nil { return err; }
}
return internal_mul(res, t1, t2);
}
/*
	More than two factors: split at an odd midpoint and multiply the two partial products.
*/
if num_factors > 1 {
mid := (start + num_factors) | 1;
if err = _int_recursive_product(t1, start, mid, level + 1); err != nil { return err; }
if err = _int_recursive_product(t2, mid, stop, level + 1); err != nil { return err; }
return internal_mul(res, t1, t2);
}
/*
	A single factor is `start` itself; an empty range yields 1.
*/
if num_factors == 1 { return #force_inline set(res, start); }
return #force_inline set(res, 1);
}
/*
	Binary split factorial algo due to: http://www.luschny.de/math/factorial/binarysplitfact.html

	res = n!

	For every bit position `i` of `n`, from the top down, the product of the values
	start, start + 2, ... below `stop` is computed, where
	`start = ((n >> (i + 1)) + 1) | 1` and `stop = ((n >> i) + 1) | 1`.
	These partial products are accumulated into `inner`, and `outer` is multiplied by
	`inner` after each step. Finally `outer` is shifted left by `n - count_ones(n)` bits.

	No argument validation happens here; errors from the underlying operations are returned as-is.
*/
_int_factorial_binary_split :: proc(res: ^Int, n: int) -> (err: Error) {
	inner, outer, temp := &Int{}, &Int{}, &Int{};
	defer destroy(inner, outer, temp);

	if err = set(inner, 1); err != nil { return err; }
	if err = set(outer, 1); err != nil { return err; }

	/*
		`n` is an `int`, so its bit length is measured against the width of `int`
		rather than against the DIGIT-type constant, which may differ from it when the
		word size is forced (see MATH_BIG_FORCE_32_BIT / MATH_BIG_FORCE_64_BIT).
		With a narrower constant this value went negative and the loop below never ran.
		Overshooting is harmless: for those `i`, `start == stop == 1` and the partial product is 1.
	*/
	bits_used := 8 * size_of(int) - int(intrinsics.count_leading_zeros(n));

	for i := bits_used; i >= 0; i -= 1 {
		start := (n >> (uint(i) + 1)) + 1 | 1;
		stop  := (n >> uint(i)) + 1 | 1;
		if err = _int_recursive_product(temp, start, stop); err != nil { return err; }
		if err = internal_mul(inner, inner, temp);          err != nil { return err; }
		if err = internal_mul(outer, outer, inner);         err != nil { return err; }
	}

	/*
		Apply the power of two: shift left by n minus the number of set bits in n.
	*/
	shift := n - intrinsics.count_ones(n);
	return shl(res, outer, int(shift));
}
internal_int_zero_unused :: #force_inline proc(dest: ^Int, old_used := -1) {
@@ -824,4 +941,72 @@ internal_int_zero_unused :: #force_inline proc(dest: ^Int, old_used := -1) {
}
}
internal_zero_unused :: proc { internal_int_zero_unused, };
internal_zero_unused :: proc { internal_int_zero_unused, };
/*
	Tables.

	`_factorial_table[n]` holds n! as a `_WORD` literal, so small factorials can be
	set directly instead of being computed.

	Which table is compiled in depends on `MATH_BIG_FORCE_64_BIT`, `MATH_BIG_FORCE_32_BIT`
	and the pointer size: the first has 35 entries (0! through 34!), the second 21 (0! through 20!).
	34! is the largest factorial below 2^128 and 20! the largest below 2^64.
	NOTE(review): so `_WORD` is presumably 128 bits wide in the first case and 64 bits in the second - confirm against its declaration.
*/
when MATH_BIG_FORCE_64_BIT || (!MATH_BIG_FORCE_32_BIT && size_of(rawptr) == 8) {
_factorial_table := [35]_WORD{
/* f(00): */ 1,
/* f(01): */ 1,
/* f(02): */ 2,
/* f(03): */ 6,
/* f(04): */ 24,
/* f(05): */ 120,
/* f(06): */ 720,
/* f(07): */ 5_040,
/* f(08): */ 40_320,
/* f(09): */ 362_880,
/* f(10): */ 3_628_800,
/* f(11): */ 39_916_800,
/* f(12): */ 479_001_600,
/* f(13): */ 6_227_020_800,
/* f(14): */ 87_178_291_200,
/* f(15): */ 1_307_674_368_000,
/* f(16): */ 20_922_789_888_000,
/* f(17): */ 355_687_428_096_000,
/* f(18): */ 6_402_373_705_728_000,
/* f(19): */ 121_645_100_408_832_000,
/* f(20): */ 2_432_902_008_176_640_000,
/* f(21): */ 51_090_942_171_709_440_000,
/* f(22): */ 1_124_000_727_777_607_680_000,
/* f(23): */ 25_852_016_738_884_976_640_000,
/* f(24): */ 620_448_401_733_239_439_360_000,
/* f(25): */ 15_511_210_043_330_985_984_000_000,
/* f(26): */ 403_291_461_126_605_635_584_000_000,
/* f(27): */ 10_888_869_450_418_352_160_768_000_000,
/* f(28): */ 304_888_344_611_713_860_501_504_000_000,
/* f(29): */ 8_841_761_993_739_701_954_543_616_000_000,
/* f(30): */ 265_252_859_812_191_058_636_308_480_000_000,
/* f(31): */ 8_222_838_654_177_922_817_725_562_880_000_000,
/* f(32): */ 263_130_836_933_693_530_167_218_012_160_000_000,
/* f(33): */ 8_683_317_618_811_886_495_518_194_401_280_000_000,
/* f(34): */ 295_232_799_039_604_140_847_618_609_643_520_000_000,
};
} else {
_factorial_table := [21]_WORD{
/* f(00): */ 1,
/* f(01): */ 1,
/* f(02): */ 2,
/* f(03): */ 6,
/* f(04): */ 24,
/* f(05): */ 120,
/* f(06): */ 720,
/* f(07): */ 5_040,
/* f(08): */ 40_320,
/* f(09): */ 362_880,
/* f(10): */ 3_628_800,
/* f(11): */ 39_916_800,
/* f(12): */ 479_001_600,
/* f(13): */ 6_227_020_800,
/* f(14): */ 87_178_291_200,
/* f(15): */ 1_307_674_368_000,
/* f(16): */ 20_922_789_888_000,
/* f(17): */ 355_687_428_096_000,
/* f(18): */ 6_402_373_705_728_000,
/* f(19): */ 121_645_100_408_832_000,
/* f(20): */ 2_432_902_008_176_640_000,
};
};
+1 -1
View File
@@ -296,7 +296,7 @@ PyRes :: struct {
/*
dest = factorial(n)
*/
@export test_factorial :: proc "c" (n: DIGIT) -> (res: PyRes) {
@export test_factorial :: proc "c" (n: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
+4 -2
View File
@@ -17,7 +17,7 @@ EXIT_ON_FAIL = False
# We skip randomized tests altogether if NO_RANDOM_TESTS is set.
#
NO_RANDOM_TESTS = True
NO_RANDOM_TESTS = False
#NO_RANDOM_TESTS = False
#
# If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
@@ -444,7 +444,9 @@ TESTS = {
[ 149195686190273039203651143129455, 12 ],
],
test_factorial: [
[ 12_345 ],
[ 6_000 ], # Regular factorial, see cutoff in common.odin.
[ 12_345 ], # Binary split factorial
[ 100_000 ],
],
test_gcd: [
[ 23, 25, ],
+1 -1
View File
@@ -45,7 +45,7 @@ print_timings :: proc() {
for v in Timings {
if v.count > 0 {
fmt.println("Timings:");
fmt.println("\nTimings:");
break;
}
}