diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin index 7d0c467a8..fde11eba6 100644 --- a/core/math/big/basic.odin +++ b/core/math/big/basic.odin @@ -359,7 +359,7 @@ int_halve :: proc(dest, src: ^Int) -> (err: Error) { */ fwd_carry := DIGIT(0); - for x := dest.used; x >= 0; x -= 1 { + for x := dest.used - 1; x >= 0; x -= 1 { /* Get the carry for the next iteration. */ @@ -761,21 +761,16 @@ sqrmod :: proc { int_sqrmod, }; */ _int_add :: proc(dest, a, b: ^Int) -> (err: Error) { dest := dest; x := a; y := b; - if err = clear_if_uninitialized(x); err != .None { - return err; - } - if err = clear_if_uninitialized(y); err != .None { - return err; - } old_used, min_used, max_used, i: int; if x.used < y.used { x, y = y, x; + assert(x.used >= y.used); } - min_used = x.used; - max_used = y.used; + min_used = y.used; + max_used = x.used; old_used = dest.used; if err = grow(dest, max(max_used + 1, _DEFAULT_DIGIT_COUNT)); err != .None { @@ -827,7 +822,6 @@ _int_add :: proc(dest, a, b: ^Int) -> (err: Error) { Add remaining carry. */ dest.digit[i] = carry; - zero_count := old_used - dest.used; /* Zero remainder. @@ -1111,6 +1105,7 @@ _int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: int, err: Error) { _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) { ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{}; + c: int; goto_end: for { if err = one(tq); err != .None { break goto_end; } @@ -1121,20 +1116,21 @@ _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (er if err = abs(ta, numerator); err != .None { break goto_end; } if err = abs(tb, denominator); err != .None { break goto_end; } - if err = shl(tb, tb, n); err != .None { break goto_end; } if err = shl(tq, tq, n); err != .None { break goto_end; } - for ; n >= 0; n -= 1 { - c: int; - if c, err = cmp(tb, ta); err != .None { break goto_end; } - if c != 1 { + for n >= 0 { + if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 { + // ta -= tb if err = sub(ta, ta, tb); err != .None { break goto_end; } - if err = add( q, tq, q); err != .None { break goto_end; } + // q += tq + if err = add( q, q, tq); err != .None { break goto_end; } } if err = shr1(tb, tb); err != .None { break goto_end; } if err = shr1(tq, tq); err != .None { break goto_end; } - } + + n -= 1; + } /* Now q == quotient and ta == remainder. diff --git a/core/math/big/build.bat b/core/math/big/build.bat index ac533db16..69dfe7995 100644 --- a/core/math/big/build.bat +++ b/core/math/big/build.bat @@ -1,5 +1,4 @@ @echo off -clear :odin run . -vet :odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules odin build . -build-mode:shared -show-timings -o:size -use-separate-modules diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 5d30c85ae..0c8ef35e1 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -66,17 +66,26 @@ print :: proc(name: string, a: ^Int, base := i8(10)) { } demo :: proc() { - // err: Error; - // r := &rnd.Rand{}; - // rnd.init(r, 12345); + err: Error; + destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; + defer destroy(destination, source, quotient, remainder, numerator, denominator); - // as := cstring("12341234"); - // bs := cstring("159671292010002348397151706347412301"); + err = atoi(source, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10); + print("src ", source); - // res := test_log(as, 2, 10); - // fmt.print(res); - // destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; - // defer destroy(destination, source, quotient, remainder, numerator, denominator); + fmt.println("sqrt should be 843478780275248664696797599030708027195155136953848512749494"); + + fmt.println(); + err = sqrt(destination, source); + fmt.printf("sqrt returned: %v\n", err); + print("sqrt ", destination); + + err = atoi(denominator, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10); + err = root_n(quotient, denominator, 2); + fmt.printf("root_n(2) returned: %v\n", err); + print("root_n(2)", quotient); + + // fmt.println(); } main :: proc() { diff --git a/core/math/big/exp_log.odin b/core/math/big/exp_log.odin index 4c31a99d8..840aec8df 100644 --- a/core/math/big/exp_log.odin +++ b/core/math/big/exp_log.odin @@ -214,42 +214,53 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) { This function is less generic than `root_n`, simpler and faster. */ int_sqrt :: proc(dest, src: ^Int) -> (err: Error) { - if err = clear_if_uninitialized(dest); err != .None { return err; } - if err = clear_if_uninitialized(src); err != .None { return err; } - /* Must be positive. */ - if src.sign == .Negative { return .Invalid_Argument; } + when true { + if err = clear_if_uninitialized(dest); err != .None { return err; } + if err = clear_if_uninitialized(src); err != .None { return err; } - /* Easy out. If src is zero, so is dest. */ - if z, _ := is_zero(src); z { return zero(dest); } + /* Must be positive. */ + if src.sign == .Negative { return .Invalid_Argument; } - /* Set up temporaries. */ - t1, t2 := &Int{}, &Int{}; - defer destroy(t1, t2); + /* Easy out. If src is zero, so is dest. */ + if z, _ := is_zero(src); z { return zero(dest); } - if err = copy(t1, src); err != .None { return err; } - if err = zero(t2); err != .None { return err; } + /* Set up temporaries. */ + x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{}; + defer destroy(x, y, t1, t2); - /* First approximation. Not very bad for large arguments. */ - if err = shr_digit(t1, t1.used / 2); err != .None { return err; } - /* t1 > 0 */ - if err = div(t2, src, t1); err != .None { return err; } - if err = add(t1, t1, t2); err != .None { return err; } - if err = shr(t1, t1, 1); err != .None { return err; } + count: int; + if count, err = count_bits(src); err != .None { return err; } - /* And now t1 > sqrt(arg). */ - for { - if err = div(t2, src, t1); err != .None { return err; } - if err = add(t1, t1, t2); err != .None { return err; } - if err = shr(t1, t1, 1); err != .None { return err; } - /* t1 >= sqrt(arg) >= t2 at this point */ + a, b := count >> 1, count & 1; + err = power_of_two(x, a+b); - cm, _ := cmp_mag(t1, t2); - if cm != 1 { break; } + iter := 0; + for { + iter += 1; + if iter > 100 { + swap(dest, x); + return .Max_Iterations_Reached; + } + /* + y = (x + n//x)//2 + */ + div(t1, src, x); + add(t2, t1, x); + shr(y, t2, 1); + + if c, _ := cmp(y, x); c == 0 || c == 1 { + swap(dest, x); + return .None; + } + swap(x, y); + } + + swap(dest, x); + return err; + } else { + // return root_n(dest, src, 2); } - - swap(dest, t1); - return err; } sqrt :: proc { int_sqrt, }; @@ -263,7 +274,7 @@ sqrt :: proc { int_sqrt, }; */ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { /* Fast path for n == 2 */ - if n == 2 { return sqrt(dest, src); } + // if n == 2 { return sqrt(dest, src); } /* Initialize dest + src if needed. */ if err = clear_if_uninitialized(dest); err != .None { return err; } @@ -321,6 +332,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { if err = power_of_two(t2, ilog2); err != .None { return err; } c: int; + iterations := 0; for { /* t1 = t2 */ if err = copy(t1, t2); err != .None { return err; } @@ -353,12 +365,23 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { break; } if c, err = cmp(t1, t2); c == 0 { break; } + iterations += 1; + if iterations == 101 { + return .Max_Iterations_Reached; + } } /* Result can be off by a few so check. */ /* Loop beneath can overshoot by one if found root is smaller than actual root. */ + iterations = 0; for { + if iterations == 101 { + return .Max_Iterations_Reached; + } + //fmt.printf("root_n iteration: %v\n", iterations); + iterations += 1; + if err = pow(t2, t1, n); err != .None { return err; } c, err = cmp(t2, a); @@ -372,8 +395,14 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { } } + iterations = 0; /* Correct overshoot from above or from recurrence. */ for { + if iterations == 101 { + return .Max_Iterations_Reached; + } + iterations += 1; + if err = pow(t2, t1, n); err != .None { return err; } c, err = cmp(t2, a); diff --git a/core/math/big/logical.odin b/core/math/big/logical.odin index 1d2ba0895..e721db603 100644 --- a/core/math/big/logical.odin +++ b/core/math/big/logical.odin @@ -285,7 +285,7 @@ int_shrmod :: proc(quotient, remainder, numerator: ^Int, bits: int) -> (err: Err shift := DIGIT(_DIGIT_BITS - bits); carry := DIGIT(0); - for x := quotient.used; x >= 0; x -= 1 { + for x := quotient.used - 1; x >= 0; x -= 1 { /* Get the lower bits of this word in a temp. */ @@ -344,7 +344,7 @@ int_shr_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) { } quotient.used -= digits; _zero_unused(quotient); - return .None; + return clamp(quotient); } shr_digit :: proc { int_shr_digit, }; @@ -446,16 +446,16 @@ int_shl_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) { /* Increment the used by the shift amount then copy upwards. */ - quotient.used += digits; /* Much like `int_shr_digit`, this is implemented using a sliding window, except the window goes the other way around. */ - for x := quotient.used; x >= digits; x -= 1 { - quotient.digit[x] = quotient.digit[x - digits]; + for x := quotient.used; x > 0; x -= 1 { + quotient.digit[x+digits-1] = quotient.digit[x-1]; } + quotient.used += digits; mem.zero_slice(quotient.digit[:digits]); return .None; } diff --git a/core/math/big/test.odin b/core/math/big/test.odin index 1ce653023..04635295d 100644 --- a/core/math/big/test.odin +++ b/core/math/big/test.odin @@ -26,74 +26,74 @@ PyRes :: struct { return strings.clone_to_cstring(es[err], context.temp_allocator); } -@export test_add_two :: proc "c" (a, b: cstring) -> (res: PyRes) { +@export test_add :: proc "c" (a, b: cstring) -> (res: PyRes) { context = runtime.default_context(); err: Error; aa, bb, sum := &Int{}, &Int{}, &Int{}; defer destroy(aa, bb, sum); - if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add_two:atoi(a):", err=err}; } - if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add_two:atoi(b):", err=err}; } - if err = add(sum, aa, bb); err != .None { return PyRes{res=":add_two:add(sum,a,b):", err=err}; } + if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add:atoi(a):", err=err}; } + if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add:atoi(b):", err=err}; } + if err = add(sum, aa, bb); err != .None { return PyRes{res=":add:add(sum,a,b):", err=err}; } r: cstring; r, err = int_itoa_cstring(sum, 10, context.temp_allocator); - if err != .None { return PyRes{res=":add_two:itoa(sum):", err=err}; } + if err != .None { return PyRes{res=":add:itoa(sum):", err=err}; } return PyRes{res = r, err = .None}; } -@export test_sub_two :: proc "c" (a, b: cstring) -> (res: PyRes) { +@export test_sub :: proc "c" (a, b: cstring) -> (res: PyRes) { context = runtime.default_context(); err: Error; aa, bb, sum := &Int{}, &Int{}, &Int{}; defer destroy(aa, bb, sum); - if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub_two:atoi(a):", err=err}; } - if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub_two:atoi(b):", err=err}; } - if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub_two:sub(sum,a,b):", err=err}; } + if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub:atoi(a):", err=err}; } + if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub:atoi(b):", err=err}; } + if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub:sub(sum,a,b):", err=err}; } r: cstring; r, err = int_itoa_cstring(sum, 10, context.temp_allocator); - if err != .None { return PyRes{res=":sub_two:itoa(sum):", err=err}; } + if err != .None { return PyRes{res=":sub:itoa(sum):", err=err}; } return PyRes{res = r, err = .None}; } -@export test_mul_two :: proc "c" (a, b: cstring) -> (res: PyRes) { +@export test_mul :: proc "c" (a, b: cstring) -> (res: PyRes) { context = runtime.default_context(); err: Error; aa, bb, product := &Int{}, &Int{}, &Int{}; defer destroy(aa, bb, product); - if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul_two:atoi(a):", err=err}; } - if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul_two:atoi(b):", err=err}; } - if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul_two:mul(product,a,b):", err=err}; } + if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul:atoi(a):", err=err}; } + if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul:atoi(b):", err=err}; } + if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul:mul(product,a,b):", err=err}; } r: cstring; r, err = int_itoa_cstring(product, 10, context.temp_allocator); - if err != .None { return PyRes{res=":mul_two:itoa(product):", err=err}; } + if err != .None { return PyRes{res=":mul:itoa(product):", err=err}; } return PyRes{res = r, err = .None}; } /* NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient. */ -@export test_div_two :: proc "c" (a, b: cstring) -> (res: PyRes) { +@export test_div :: proc "c" (a, b: cstring) -> (res: PyRes) { context = runtime.default_context(); err: Error; aa, bb, quotient := &Int{}, &Int{}, &Int{}; defer destroy(aa, bb, quotient); - if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div_two:atoi(a):", err=err}; } - if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div_two:atoi(b):", err=err}; } - if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div_two:div(quotient,a,b):", err=err}; } + if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div:atoi(a):", err=err}; } + if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div:atoi(b):", err=err}; } + if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div:div(quotient,a,b):", err=err}; } r: cstring; r, err = int_itoa_cstring(quotient, 10, context.temp_allocator); - if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; } + if err != .None { return PyRes{res=":div:itoa(quotient):", err=err}; } return PyRes{res = r, err = .None}; } @@ -130,7 +130,6 @@ PyRes :: struct { @export test_pow :: proc "c" (base: cstring, power := int(2)) -> (res: PyRes) { context = runtime.default_context(); err: Error; - l: int; dest, bb := &Int{}, &Int{}; defer destroy(dest, bb); @@ -142,4 +141,120 @@ PyRes :: struct { r, err = int_itoa_cstring(dest, 10, context.temp_allocator); if err != .None { return PyRes{res=":log:itoa(res):", err=err}; } return PyRes{res = r, err = .None}; -} \ No newline at end of file +} + +/* + dest = sqrt(src) +*/ +@export test_sqrt :: proc "c" (source: cstring) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":sqrt:atoi(src):", err=err}; } + if err = sqrt(src, src); err != .None { return PyRes{res=":sqrt:sqrt(src):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":log:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} + + +/* + dest = shr_digit(src, digits) +*/ +@export test_shr_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_digit:atoi(src):", err=err}; } + if err = shr_digit(src, digits); err != .None { return PyRes{res=":shr_digit:shr_digit(src):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":shr_digit:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} + +/* + dest = shl_digit(src, digits) +*/ +@export test_shl_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl_digit:atoi(src):", err=err}; } + if err = shl_digit(src, digits); err != .None { return PyRes{res=":shl_digit:shr_digit(src):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":shl_digit:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} + +/* + dest = shr(src, bits) +*/ +@export test_shr :: proc "c" (source: cstring, bits: int) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr:atoi(src):", err=err}; } + if err = shr(src, src, bits); err != .None { return PyRes{res=":shr:shr(src, bits):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":shr:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} + +/* + dest = shr_signed(src, bits) +*/ +@export test_shr_signed :: proc "c" (source: cstring, bits: int) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_signed:atoi(src):", err=err}; } + if err = shr_signed(src, src, bits); err != .None { return PyRes{res=":shr_signed:shr_signed(src, bits):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":shr_signed:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} + +/* + dest = shl(src, bits) +*/ +@export test_shl :: proc "c" (source: cstring, bits: int) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + + src := &Int{}; + defer destroy(src); + + if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl:atoi(src):", err=err}; } + if err = shl(src, src, bits); err != .None { return PyRes{res=":shl:shl(src, bits):", err=err}; } + + r: cstring; + r, err = int_itoa_cstring(src, 10, context.temp_allocator); + if err != .None { return PyRes{res=":shl:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} + diff --git a/core/math/big/test.py b/core/math/big/test.py index 8308c2d5e..956d75e1c 100644 --- a/core/math/big/test.py +++ b/core/math/big/test.py @@ -1,6 +1,6 @@ -from math import * from ctypes import * from random import * +import math import os import platform import time @@ -10,12 +10,14 @@ from enum import Enum # Normally, we report the number of passes and fails. # With EXIT_ON_FAIL set, we exit at the first fail. # +EXIT_ON_FAIL = True EXIT_ON_FAIL = False # # We skip randomized tests altogether if NO_RANDOM_TESTS is set. # -NO_RANDOM_TESTS = False #True +NO_RANDOM_TESTS = True +NO_RANDOM_TESTS = False # # If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations. @@ -113,13 +115,23 @@ class Res(Structure): error_string = load(l.test_error_string, [c_byte], c_char_p) -add_two = load(l.test_add_two, [c_char_p, c_char_p], Res) -sub_two = load(l.test_sub_two, [c_char_p, c_char_p], Res) -mul_two = load(l.test_mul_two, [c_char_p, c_char_p], Res) -div_two = load(l.test_div_two, [c_char_p, c_char_p], Res) +add = load(l.test_add, [c_char_p, c_char_p], Res) +sub = load(l.test_sub, [c_char_p, c_char_p], Res) +mul = load(l.test_mul, [c_char_p, c_char_p], Res) +div = load(l.test_div, [c_char_p, c_char_p], Res) -int_log = load(l.test_log, [c_char_p, c_longlong], Res) -int_pow = load(l.test_pow, [c_char_p, c_longlong], Res) +# Powers and such +int_log = load(l.test_log, [c_char_p, c_longlong], Res) +int_pow = load(l.test_pow, [c_char_p, c_longlong], Res) +int_sqrt = load(l.test_sqrt, [c_char_p], Res) + +# Logical operations + +int_shl_digit = load(l.test_shl_digit, [c_char_p, c_longlong], Res) +int_shr_digit = load(l.test_shr_digit, [c_char_p, c_longlong], Res) +int_shl = load(l.test_shl, [c_char_p, c_longlong], Res) +int_shr = load(l.test_shr, [c_char_p, c_longlong], Res) +int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res) def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = ""): passed = True @@ -156,38 +168,38 @@ def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expecte return passed -def test_add_two(a = 0, b = 0, expected_error = Error.Okay): +def test_add(a = 0, b = 0, expected_error = Error.Okay): args = [str(a), str(b)] sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8') - res = add_two(sa_c, sb_c) + res = add(sa_c, sb_c) expected_result = None if expected_error == Error.Okay: expected_result = a + b - return test("test_add_two", res, args, expected_error, expected_result) + return test("test_add", res, args, expected_error, expected_result) -def test_sub_two(a = 0, b = 0, expected_error = Error.Okay): +def test_sub(a = 0, b = 0, expected_error = Error.Okay): sa, sb = str(a), str(b) sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8') - res = sub_two(sa_c, sb_c) + res = sub(sa_c, sb_c) expected_result = None if expected_error == Error.Okay: expected_result = a - b - return test("test_sub_two", res, [sa_c, sb_c], expected_error, expected_result) + return test("test_sub", res, [sa_c, sb_c], expected_error, expected_result) -def test_mul_two(a = 0, b = 0, expected_error = Error.Okay): +def test_mul(a = 0, b = 0, expected_error = Error.Okay): sa, sb = str(a), str(b) sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8') - res = mul_two(sa_c, sb_c) + res = mul(sa_c, sb_c) expected_result = None if expected_error == Error.Okay: expected_result = a * b - return test("test_mul_two", res, [sa_c, sb_c], expected_error, expected_result) + return test("test_mul", res, [sa_c, sb_c], expected_error, expected_result) -def test_div_two(a = 0, b = 0, expected_error = Error.Okay): +def test_div(a = 0, b = 0, expected_error = Error.Okay): sa, sb = str(a), str(b) sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8') try: - res = div_two(sa_c, sb_c) + res = div(sa_c, sb_c) except: print("Exception with arguments:", a, b) return False @@ -197,12 +209,12 @@ def test_div_two(a = 0, b = 0, expected_error = Error.Okay): # We don't round the division results, so if one component is negative, we're off by one. # if a < 0 and b > 0: - expected_result = int(-(abs(a) / b)) + expected_result = int(-(abs(a) // b)) elif b < 0 and a > 0: - expected_result = int(-(a / abs((b)))) + expected_result = int(-(a // abs((b)))) else: expected_result = a // b if b != 0 else None - return test("test_div_two", res, [sa_c, sb_c], expected_error, expected_result) + return test("test_div", res, [sa_c, sb_c], expected_error, expected_result) def test_log(a = 0, base = 0, expected_error = Error.Okay): @@ -212,7 +224,7 @@ def test_log(a = 0, base = 0, expected_error = Error.Okay): expected_result = None if expected_error == Error.Okay: - expected_result = int(log(a, base)) + expected_result = int(math.log(a, base)) return test("test_log", res, args, expected_error, expected_result) def test_pow(base = 0, power = 0, expected_error = Error.Okay): @@ -225,9 +237,108 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay): if power < 0: expected_result = 0 else: - expected_result = int(base**power) + # NOTE(Jeroen): Don't use `math.pow`, it's a floating point approximation. + # Use built-in `pow` or `a**b` instead. + expected_result = pow(base, power) return test("test_pow", res, args, expected_error, expected_result) +def test_sqrt(number = 0, expected_error = Error.Okay): + args = [str(number)] + sa_c = args[0].encode('utf-8') + try: + res = int_sqrt(sa_c) + except: + print("sqrt:", number) + + expected_result = None + if expected_error == Error.Okay: + if number < 0: + expected_result = 0 + else: + expected_result = int(math.isqrt(number)) + return test("test_sqrt", res, args, expected_error, expected_result) + +def root_n(number, root): + u, s = number, number + 1 + while u < s: + s = u + t = (root-1) * s + number // pow(s, root - 1) + u = t // root + return s + +def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay): + args = [str(a), digits] + sa_c = args[0].encode('utf-8') + res = int_shl_digit(sa_c, digits) + + expected_result = None + if expected_error == Error.Okay: + expected_result = a << (digits * 60) + return test("test_shl_digit", res, args, expected_error, expected_result) + +def test_shr_digit(a = 0, digits = 0, expected_error = Error.Okay): + args = [str(a), digits] + sa_c = args[0].encode('utf-8') + try: + res = int_shr_digit(sa_c, digits) + except: + print("int_shr_digit", a, digits) + exit() + + expected_result = None + if expected_error == Error.Okay: + if a < 0: + # Don't pass negative numbers. We have a shr_signed. + return False + else: + expected_result = a >> (digits * 60) + + return test("test_shr_digit", res, args, expected_error, expected_result) + +def test_shl(a = 0, bits = 0, expected_error = Error.Okay): + args = [str(a), bits] + sa_c = args[0].encode('utf-8') + res = int_shl(sa_c, bits) + + expected_result = None + if expected_error == Error.Okay: + expected_result = a << bits + return test("test_shl", res, args, expected_error, expected_result) + +def test_shr(a = 0, bits = 0, expected_error = Error.Okay): + args = [str(a), bits] + sa_c = args[0].encode('utf-8') + try: + res = int_shr(sa_c, bits) + except: + print("int_shr", a, bits) + exit() + + expected_result = None + if expected_error == Error.Okay: + if a < 0: + # Don't pass negative numbers. We have a shr_signed. + return False + else: + expected_result = a >> bits + + return test("test_shr", res, args, expected_error, expected_result) + +def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay): + args = [str(a), bits] + sa_c = args[0].encode('utf-8') + try: + res = int_shr_signed(sa_c, bits) + except: + print("int_shr_signed", a, bits) + exit() + + expected_result = None + if expected_error == Error.Okay: + expected_result = a >> bits + + return test("test_shr_signed", res, args, expected_error, expected_result) + # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on. # # The last two arguments in tests are the expected error and expected result. @@ -237,19 +348,20 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay): # You can override that by supplying an expected result as the last argument instead. TESTS = { - test_add_two: [ + test_add: [ [ 1234, 5432], ], - test_sub_two: [ + test_sub: [ [ 1234, 5432], ], - test_mul_two: [ + test_mul: [ [ 1234, 5432], [ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7] ], - test_div_two: [ + test_div: [ [ 54321, 12345], [ 55431, 0, Error.Division_by_Zero], + [ 12980742146337069150589594264770969721, 4611686018427387904 ], ], test_log: [ [ 3192, 1, Error.Invalid_Argument], @@ -260,29 +372,87 @@ TESTS = { test_pow: [ [ 0, -1, Error.Math_Domain_Error ], # Math [ 0, 0 ], # 1 - [ 0, 2 ], # 0 - [ 42, -1,], # 0 - [ 42, 1 ], # 1 - [ 42, 0 ], # 42 - [ 42, 2 ], # 42*42 - - + [ 0, 2 ], # 0 + [ 42, -1,], # 0 + [ 42, 1 ], # 1 + [ 42, 0 ], # 42 + [ 42, 2 ], # 42*42 ], + test_sqrt: [ + [ -1, Error.Invalid_Argument, ], + [ 42, Error.Okay, ], + [ 12345678901234567890, Error.Okay, ], + [ 1298074214633706907132624082305024, Error.Okay, ], + ], + test_shl_digit: [ + [ 3192, 1 ], + [ 1298074214633706907132624082305024, 2 ], + [ 1024, 3 ], + ], + test_shr_digit: [ + [ 3680125442705055547392, 1 ], + [ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ], + [ 219504133884436710204395031992179571, 2 ], + ], + test_shl: [ + [ 3192, 1 ], + [ 1298074214633706907132624082305024, 2 ], + [ 1024, 3 ], + ], + test_shr: [ + [ 3680125442705055547392, 1 ], + [ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ], + [ 219504133884436710204395031992179571, 2 ], + ], + test_shr_signed: [ + [ -611105530635358368578155082258244262, 12 ], + [ -149195686190273039203651143129455, 12 ], + [ 611105530635358368578155082258244262, 12 ], + [ 149195686190273039203651143129455, 12 ], + ] } total_passes = 0 total_failures = 0 +# +# test_shr_signed also tests shr, so we're not going to test shr randomly. +# RANDOM_TESTS = [ - test_add_two, test_sub_two, test_mul_two, test_div_two, - test_log, test_pow, + test_add, test_sub, test_mul, test_div, + test_log, test_pow, test_sqrt, + test_shl_digit, test_shr_digit, test_shl, test_shr_signed, ] +SKIP_LARGE = [test_pow] +SKIP_LARGEST = [] # Untimed warmup. for test_proc in TESTS: for t in TESTS[test_proc]: res = test_proc(*t) + +def isqrt(x): + n = int(x) + a, b = divmod(n.bit_length(), 2) + print("isqrt({}), a: {}, b: {}". format(n, a, b)) + x = 2**(a+b) + print("initial: {}".format(x)) + i = 0 + while True: + # y = (x + n//x)//2 + t1 = n // x + t2 = x + t1 + t3 = t2 // 2 + y = (x + n//x)//2 + + i += 1 + print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n)); + + if y >= x: + return x + x = y + if __name__ == '__main__': print("---- math/big tests ----") print() @@ -317,7 +487,8 @@ if __name__ == '__main__': print() for test_proc in RANDOM_TESTS: - if test_proc == test_pow and BITS > 1_200: continue + if BITS > 1_200 and test_proc in SKIP_LARGE: continue + if BITS > 4_096 and test_proc in SKIP_LARGEST: continue count_pass = 0 count_fail = 0 @@ -331,8 +502,10 @@ if __name__ == '__main__': a = randint(-(1 << BITS), 1 << BITS) b = randint(-(1 << BITS), 1 << BITS) - if test_proc == test_div_two: + if test_proc == test_div: # We've already tested division by zero above. + bits = int(BITS * 0.6) + b = randint(-(1 << bits), 1 << bits) if b == 0: b == 42 elif test_proc == test_log: @@ -341,6 +514,18 @@ if __name__ == '__main__': b = randint(2, 1 << 60) elif test_proc == test_pow: b = randint(1, 10) + elif test_proc == test_sqrt: + a = randint(1, 1 << BITS) + b = Error.Okay + elif test_proc == test_shl_digit: + b = randint(0, 10); + elif test_proc == test_shr_digit: + a = abs(a) + b = randint(0, 10); + elif test_proc == test_shl: + b = randint(0, min(BITS, 120)); + elif test_proc == test_shr_signed: + b = randint(0, min(BITS, 120)); else: b = randint(0, 1 << BITS)