From 4be48973adc323838682d7f8e1f38b4082a48024 Mon Sep 17 00:00:00 2001 From: Jeroen van Rijn Date: Fri, 6 Aug 2021 14:57:53 +0200 Subject: [PATCH] big: Squashed shl1 bug when a larger dest was reused for a smaller result. --- core/math/big/basic.odin | 49 +++----------- core/math/big/build.bat | 2 +- core/math/big/example.odin | 24 ++----- core/math/big/exp_log.odin | 2 +- core/math/big/internal.odin | 124 ++++++++++++++++++++++++------------ core/math/big/test.odin | 4 +- core/math/big/test.py | 10 ++- 7 files changed, 113 insertions(+), 102 deletions(-) diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin index a06c275b8..6c7d46ed9 100644 --- a/core/math/big/basic.odin +++ b/core/math/big/basic.odin @@ -152,45 +152,18 @@ int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: E /* Early out if neither of the results is wanted. */ - if quotient == nil && remainder == nil { return nil; } + if quotient == nil && remainder == nil { return nil; } + if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; } - if err = clear_if_uninitialized(numerator); err != nil { return err; } - if err = clear_if_uninitialized(denominator); err != nil { return err; } - - z: bool; - if z, err = is_zero(denominator); z { return .Division_by_Zero; } - - /* - If numerator < denominator then quotient = 0, remainder = numerator. - */ - c: int; - if c, err = cmp_mag(numerator, denominator); c == -1 { - if remainder != nil { - if err = copy(remainder, numerator); err != nil { return err; } - } - if quotient != nil { - zero(quotient); - } - return nil; - } - - if false && (denominator.used > 2 * _MUL_KARATSUBA_CUTOFF) && (denominator.used <= (numerator.used/3) * 2) { - // err = _int_div_recursive(quotient, remainder, numerator, denominator); - } else { - err = _int_div_school(quotient, remainder, numerator, denominator); - /* - NOTE(Jeroen): We no longer need or use `_int_div_small`. - We'll keep it around for a bit. - err = _int_div_small(quotient, remainder, numerator, denominator); - */ - } - - return err; + return #force_inline internal_int_divmod(quotient, remainder, numerator, denominator); } divmod :: proc{ int_divmod, }; int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) { - return int_divmod(quotient, nil, numerator, denominator); + if quotient == nil { return .Invalid_Pointer; }; + if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; } + + return #force_inline internal_int_divmod(quotient, nil, numerator, denominator); } div :: proc { int_div, }; @@ -200,11 +173,10 @@ div :: proc { int_div, }; denominator < remainder <= 0 if denominator < 0 */ int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) { - if err = divmod(nil, remainder, numerator, denominator); err != nil { return err; } + if remainder == nil { return .Invalid_Pointer; }; + if err = clear_if_uninitialized(numerator, denominator); err != nil { return err; } - z: bool; - if z, err = is_zero(remainder); z || denominator.sign == remainder.sign { return nil; } - return add(remainder, remainder, numerator); + return #force_inline internal_int_mod(remainder, numerator, denominator); } int_mod_digit :: proc(numerator: ^Int, denominator: DIGIT) -> (remainder: DIGIT, err: Error) { @@ -776,7 +748,6 @@ _int_div_school :: proc(quotient, remainder, numerator, denominator: ^Int) -> (e t2.used = 3; if t1_t2, _ := cmp_mag(t1, t2); t1_t2 != 1 { - break; } iter += 1; if iter > 100 { return .Max_Iterations_Reached; } diff --git a/core/math/big/build.bat b/core/math/big/build.bat index 94c8c7144..b1d2a00ee 100644 --- a/core/math/big/build.bat +++ b/core/math/big/build.bat @@ -1,5 +1,5 @@ @echo off -:odin run . -vet-more +:odin run . -vet : -o:size -no-bounds-check :odin build . -build-mode:shared -show-timings -o:minimal -no-bounds-check :odin build . -build-mode:shared -show-timings -o:size -no-bounds-check diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 806aafda6..4777654e1 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -62,30 +62,16 @@ print :: proc(name: string, a: ^Int, base := i8(10), print_name := true, newline } demo :: proc() { - err: Error; - as: string; - defer delete(as); - a, b, c, d, e, f := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; defer destroy(a, b, c, d, e, f); - err = factorial(a, 1224); - count, _ := count_bits(a); + foo := "686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214"; + err := atoi(a, foo, 10); - bits := 51; - be1: _WORD; + if err != nil do fmt.printf("atoi returned %v\n", err); + + print("foo: ", a); - /* - Timing loop - */ - { - SCOPED_TIMING(.bitfield_extract); - for o := 0; o < count - bits; o += 1 { - be1, _ = int_bitfield_extract(a, o, bits); - } - } - SCOPED_COUNT_ADD(.bitfield_extract, count - bits - 1); - fmt.printf("be1: %v\n", be1); } main :: proc() { diff --git a/core/math/big/exp_log.odin b/core/math/big/exp_log.odin index e0569c586..33ebe9e21 100644 --- a/core/math/big/exp_log.odin +++ b/core/math/big/exp_log.odin @@ -253,7 +253,7 @@ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) { swap(dest, x); return err; } else { - // return root_n(dest, src, 2); + return root_n(dest, src, 2); } } sqrt :: proc { int_sqrt, }; diff --git a/core/math/big/internal.odin b/core/math/big/internal.odin index df0acbf43..ed92b9f57 100644 --- a/core/math/big/internal.odin +++ b/core/math/big/internal.odin @@ -491,49 +491,30 @@ internal_int_shr1 :: proc(dest, src: ^Int) -> (err: Error) { dest = src << 1 */ internal_int_shl1 :: proc(dest, src: ^Int) -> (err: Error) { - old_used := dest.used; dest.used = src.used + 1; - + if err = copy(dest, src); err != nil { return err; } /* - Forward carry + Grow `dest` to accommodate the additional bits. */ + digits_needed := dest.used + 1; + if err = grow(dest, digits_needed); err != nil { return err; } + dest.used = digits_needed; + + mask := (DIGIT(1) << uint(1)) - DIGIT(1); + shift := DIGIT(_DIGIT_BITS - 1); carry := DIGIT(0); - #no_bounds_check for x := 0; x < src.used; x += 1 { - /* - Get what will be the *next* carry bit from the MSB of the current digit. - */ - src_digit := src.digit[x]; - fwd_carry := src_digit >> (_DIGIT_BITS - 1); - /* - Now shift up this digit, add in the carry [from the previous] - */ - dest.digit[x] = (src_digit << 1 | carry) & _MASK; - - /* - Update carry - */ + #no_bounds_check for x:= 0; x < dest.used; x+= 1 { + fwd_carry := (dest.digit[x] >> shift) & mask; + dest.digit[x] = (dest.digit[x] << uint(1) | carry) & _MASK; carry = fwd_carry; } /* - New leading digit? + Use final carry. */ if carry != 0 { - /* - Add a MSB which is always 1 at this point. - */ - dest.digit[dest.used] = 1; + dest.digit[dest.used] = carry; + dest.used += 1; } - zero_count := old_used - dest.used; - /* - Zero remainder. - */ - if zero_count > 0 { - mem.zero_slice(dest.digit[dest.used:][:zero_count]); - } - /* - Adjust dest.used based on leading zeroes. - */ - dest.sign = src.sign; return clamp(dest); } @@ -552,7 +533,7 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator := Power of two? */ if multiplier == 2 { - return #force_inline shl1(dest, src); + return #force_inline internal_int_shl1(dest, src); } if is_power_of_two(int(multiplier)) { ix: int; @@ -581,7 +562,7 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator := Compute columns. */ ix := 0; - #no_bounds_check for ; ix < src.used; ix += 1 { + for ; ix < src.used; ix += 1 { /* Compute product and carry sum for this term */ @@ -600,13 +581,15 @@ internal_int_mul_digit :: proc(dest, src: ^Int, multiplier: DIGIT, allocator := Store final carry [if any] and increment used. */ dest.digit[ix] = DIGIT(carry); + dest.used = src.used + 1; /* Zero unused digits. */ + //_zero_unused(dest); zero_count := old_used - dest.used; - if zero_count > 0 { - mem.zero_slice(dest.digit[zero_count:]); + if zero_count > 0 { + mem.zero_slice(dest.digit[dest.used:][:zero_count]); } return clamp(dest); } @@ -675,9 +658,72 @@ internal_int_mul :: proc(dest, src, multiplier: ^Int, allocator := context.alloc err = _int_mul(dest, src, multiplier, digits); } } - neg := src.sign != multiplier.sign; + neg := src.sign != multiplier.sign; dest.sign = .Negative if dest.used > 0 && neg else .Zero_or_Positive; return err; } -internal_mul :: proc { internal_int_mul, internal_int_mul_digit, }; \ No newline at end of file +internal_mul :: proc { internal_int_mul, internal_int_mul_digit, }; + +/* + divmod. + Both the quotient and remainder are optional and may be passed a nil. +*/ +internal_int_divmod :: proc(quotient, remainder, numerator, denominator: ^Int, allocator := context.allocator) -> (err: Error) { + + if denominator.used == 0 { return .Division_by_Zero; } + /* + If numerator < denominator then quotient = 0, remainder = numerator. + */ + c: int; + if c, err = #force_inline cmp_mag(numerator, denominator); c == -1 { + if remainder != nil { + if err = copy(remainder, numerator, false, allocator); err != nil { return err; } + } + if quotient != nil { + zero(quotient); + } + return nil; + } + + if false && (denominator.used > 2 * _MUL_KARATSUBA_CUTOFF) && (denominator.used <= (numerator.used/3) * 2) { + // err = _int_div_recursive(quotient, remainder, numerator, denominator); + } else { + when true { + err = _int_div_school(quotient, remainder, numerator, denominator); + } else { + /* + NOTE(Jeroen): We no longer need or use `_int_div_small`. + We'll keep it around for a bit until we're reasonably certain div_school is bug free. + err = _int_div_small(quotient, remainder, numerator, denominator); + */ + err = _int_div_small(quotient, remainder, numerator, denominator); + } + } + return; +} +internal_divmod :: proc { internal_int_divmod, }; + +/* + Asssumes quotient, numerator and denominator to have been initialized and not to be nil. +*/ +internal_int_div :: proc(quotient, numerator, denominator: ^Int) -> (err: Error) { + return #force_inline internal_int_divmod(quotient, nil, numerator, denominator); +} +internal_div :: proc { internal_int_div, }; + +/* + remainder = numerator % denominator. + 0 <= remainder < denominator if denominator > 0 + denominator < remainder <= 0 if denominator < 0 + + Asssumes quotient, numerator and denominator to have been initialized and not to be nil. +*/ +internal_int_mod :: proc(remainder, numerator, denominator: ^Int) -> (err: Error) { + if err = #force_inline internal_int_divmod(nil, remainder, numerator, denominator); err != nil { return err; } + + if remainder.used == 0 || denominator.sign == remainder.sign { return nil; } + + return #force_inline internal_add(remainder, remainder, numerator); +} +internal_mod :: proc{ internal_int_mod, }; \ No newline at end of file diff --git a/core/math/big/test.odin b/core/math/big/test.odin index ae105969c..744a4436a 100644 --- a/core/math/big/test.odin +++ b/core/math/big/test.odin @@ -24,9 +24,9 @@ PyRes :: struct { err: Error, } -@export test_initialize_constants :: proc "c" () -> (res: int) { +@export test_initialize_constants :: proc "c" () -> (res: u64) { context = runtime.default_context(); - return initialize_constants(); + return u64(initialize_constants()); } @export test_error_string :: proc "c" (err: Error) -> (res: cstring) { diff --git a/core/math/big/test.py b/core/math/big/test.py index 836f03353..6dcb235de 100644 --- a/core/math/big/test.py +++ b/core/math/big/test.py @@ -254,7 +254,13 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay): def test_sqrt(number = 0, expected_error = Error.Okay): args = [arg_to_odin(number)] - res = int_sqrt(*args) + try: + res = int_sqrt(*args) + except OSError as e: + print("{} while trying to sqrt {}.".format(e, number)) + if EXIT_ON_FAIL: exit(3) + return False + expected_result = None if expected_error == Error.Okay: if number < 0: @@ -384,6 +390,7 @@ TESTS = { [ 54321, 12345], [ 55431, 0, Error.Division_by_Zero], [ 12980742146337069150589594264770969721, 4611686018427387904 ], + [ 831956404029821402159719858789932422, 243087903122332132 ], ], test_log: [ [ 3192, 1, Error.Invalid_Argument], @@ -405,6 +412,7 @@ TESTS = { [ 42, Error.Okay, ], [ 12345678901234567890, Error.Okay, ], [ 1298074214633706907132624082305024, Error.Okay, ], + [ 686885735734829009541949746871140768343076607029752932751182108475420900392874228486622313727012705619148037570309621219533087263900443932890792804879473795673302686046941536636874184361869252299636701671980034458333859202703255467709267777184095435235980845369829397344182319113372092844648570818726316581751114346501124871729572474923695509057166373026411194094493240101036672016770945150422252961487398124677567028263059046193391737576836378376192651849283925197438927999526058932679219572030021792914065825542626400207956134072247020690107136531852625253942429167557531123651471221455967386267137846791963149859804549891438562641323068751514370656287452006867713758971418043865298618635213551059471668293725548570452377976322899027050925842868079489675596835389444833567439058609775325447891875359487104691935576723532407937236505941186660707032433807075470656782452889754501872408562496805517394619388777930253411467941214807849472083814447498068636264021405175653742244368865090604940094889189800007448083930490871954101880815781177612910234741529950538835837693870921008635195545246771593130784786737543736434086434015200264933536294884482218945403958647118802574342840790536176272341586020230110889699633073513016344826709214, Error.Okay, ], ], test_root_n: [ [ 1298074214633706907132624082305024, 2, Error.Okay, ],