mirror of
https://github.com/Ed94/Odin.git
synced 2026-06-17 11:22:22 -07:00
big: Fix sqrt, div, add with certain inputs.
This commit is contained in:
+13
-17
@@ -359,7 +359,7 @@ int_halve :: proc(dest, src: ^Int) -> (err: Error) {
|
||||
*/
|
||||
fwd_carry := DIGIT(0);
|
||||
|
||||
for x := dest.used; x >= 0; x -= 1 {
|
||||
for x := dest.used - 1; x >= 0; x -= 1 {
|
||||
/*
|
||||
Get the carry for the next iteration.
|
||||
*/
|
||||
@@ -761,21 +761,16 @@ sqrmod :: proc { int_sqrmod, };
|
||||
*/
|
||||
_int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
|
||||
dest := dest; x := a; y := b;
|
||||
if err = clear_if_uninitialized(x); err != .None {
|
||||
return err;
|
||||
}
|
||||
if err = clear_if_uninitialized(y); err != .None {
|
||||
return err;
|
||||
}
|
||||
|
||||
old_used, min_used, max_used, i: int;
|
||||
|
||||
if x.used < y.used {
|
||||
x, y = y, x;
|
||||
assert(x.used >= y.used);
|
||||
}
|
||||
|
||||
min_used = x.used;
|
||||
max_used = y.used;
|
||||
min_used = y.used;
|
||||
max_used = x.used;
|
||||
old_used = dest.used;
|
||||
|
||||
if err = grow(dest, max(max_used + 1, _DEFAULT_DIGIT_COUNT)); err != .None {
|
||||
@@ -827,7 +822,6 @@ _int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
|
||||
Add remaining carry.
|
||||
*/
|
||||
dest.digit[i] = carry;
|
||||
|
||||
zero_count := old_used - dest.used;
|
||||
/*
|
||||
Zero remainder.
|
||||
@@ -1111,6 +1105,7 @@ _int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: int, err: Error) {
|
||||
_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
|
||||
|
||||
ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
|
||||
c: int;
|
||||
|
||||
goto_end: for {
|
||||
if err = one(tq); err != .None { break goto_end; }
|
||||
@@ -1121,20 +1116,21 @@ _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (er
|
||||
|
||||
if err = abs(ta, numerator); err != .None { break goto_end; }
|
||||
if err = abs(tb, denominator); err != .None { break goto_end; }
|
||||
|
||||
if err = shl(tb, tb, n); err != .None { break goto_end; }
|
||||
if err = shl(tq, tq, n); err != .None { break goto_end; }
|
||||
|
||||
for ; n >= 0; n -= 1 {
|
||||
c: int;
|
||||
if c, err = cmp(tb, ta); err != .None { break goto_end; }
|
||||
if c != 1 {
|
||||
for n >= 0 {
|
||||
if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 {
|
||||
// ta -= tb
|
||||
if err = sub(ta, ta, tb); err != .None { break goto_end; }
|
||||
if err = add( q, tq, q); err != .None { break goto_end; }
|
||||
// q += tq
|
||||
if err = add( q, q, tq); err != .None { break goto_end; }
|
||||
}
|
||||
if err = shr1(tb, tb); err != .None { break goto_end; }
|
||||
if err = shr1(tq, tq); err != .None { break goto_end; }
|
||||
}
|
||||
|
||||
n -= 1;
|
||||
}
|
||||
|
||||
/*
|
||||
Now q == quotient and ta == remainder.
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
@echo off
|
||||
clear
|
||||
:odin run . -vet
|
||||
:odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules
|
||||
odin build . -build-mode:shared -show-timings -o:size -use-separate-modules
|
||||
|
||||
@@ -66,17 +66,26 @@ print :: proc(name: string, a: ^Int, base := i8(10)) {
|
||||
}
|
||||
|
||||
demo :: proc() {
|
||||
// err: Error;
|
||||
// r := &rnd.Rand{};
|
||||
// rnd.init(r, 12345);
|
||||
err: Error;
|
||||
destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
|
||||
defer destroy(destination, source, quotient, remainder, numerator, denominator);
|
||||
|
||||
// as := cstring("12341234");
|
||||
// bs := cstring("159671292010002348397151706347412301");
|
||||
err = atoi(source, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10);
|
||||
print("src ", source);
|
||||
|
||||
// res := test_log(as, 2, 10);
|
||||
// fmt.print(res);
|
||||
// destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
|
||||
// defer destroy(destination, source, quotient, remainder, numerator, denominator);
|
||||
fmt.println("sqrt should be 843478780275248664696797599030708027195155136953848512749494");
|
||||
|
||||
fmt.println();
|
||||
err = sqrt(destination, source);
|
||||
fmt.printf("sqrt returned: %v\n", err);
|
||||
print("sqrt ", destination);
|
||||
|
||||
err = atoi(denominator, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10);
|
||||
err = root_n(quotient, denominator, 2);
|
||||
fmt.printf("root_n(2) returned: %v\n", err);
|
||||
print("root_n(2)", quotient);
|
||||
|
||||
// fmt.println();
|
||||
}
|
||||
|
||||
main :: proc() {
|
||||
|
||||
+58
-29
@@ -214,42 +214,53 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) {
|
||||
This function is less generic than `root_n`, simpler and faster.
|
||||
*/
|
||||
int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
|
||||
if err = clear_if_uninitialized(dest); err != .None { return err; }
|
||||
if err = clear_if_uninitialized(src); err != .None { return err; }
|
||||
|
||||
/* Must be positive. */
|
||||
if src.sign == .Negative { return .Invalid_Argument; }
|
||||
when true {
|
||||
if err = clear_if_uninitialized(dest); err != .None { return err; }
|
||||
if err = clear_if_uninitialized(src); err != .None { return err; }
|
||||
|
||||
/* Easy out. If src is zero, so is dest. */
|
||||
if z, _ := is_zero(src); z { return zero(dest); }
|
||||
/* Must be positive. */
|
||||
if src.sign == .Negative { return .Invalid_Argument; }
|
||||
|
||||
/* Set up temporaries. */
|
||||
t1, t2 := &Int{}, &Int{};
|
||||
defer destroy(t1, t2);
|
||||
/* Easy out. If src is zero, so is dest. */
|
||||
if z, _ := is_zero(src); z { return zero(dest); }
|
||||
|
||||
if err = copy(t1, src); err != .None { return err; }
|
||||
if err = zero(t2); err != .None { return err; }
|
||||
/* Set up temporaries. */
|
||||
x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{};
|
||||
defer destroy(x, y, t1, t2);
|
||||
|
||||
/* First approximation. Not very bad for large arguments. */
|
||||
if err = shr_digit(t1, t1.used / 2); err != .None { return err; }
|
||||
/* t1 > 0 */
|
||||
if err = div(t2, src, t1); err != .None { return err; }
|
||||
if err = add(t1, t1, t2); err != .None { return err; }
|
||||
if err = shr(t1, t1, 1); err != .None { return err; }
|
||||
count: int;
|
||||
if count, err = count_bits(src); err != .None { return err; }
|
||||
|
||||
/* And now t1 > sqrt(arg). */
|
||||
for {
|
||||
if err = div(t2, src, t1); err != .None { return err; }
|
||||
if err = add(t1, t1, t2); err != .None { return err; }
|
||||
if err = shr(t1, t1, 1); err != .None { return err; }
|
||||
/* t1 >= sqrt(arg) >= t2 at this point */
|
||||
a, b := count >> 1, count & 1;
|
||||
err = power_of_two(x, a+b);
|
||||
|
||||
cm, _ := cmp_mag(t1, t2);
|
||||
if cm != 1 { break; }
|
||||
iter := 0;
|
||||
for {
|
||||
iter += 1;
|
||||
if iter > 100 {
|
||||
swap(dest, x);
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
/*
|
||||
y = (x + n//x)//2
|
||||
*/
|
||||
div(t1, src, x);
|
||||
add(t2, t1, x);
|
||||
shr(y, t2, 1);
|
||||
|
||||
if c, _ := cmp(y, x); c == 0 || c == 1 {
|
||||
swap(dest, x);
|
||||
return .None;
|
||||
}
|
||||
swap(x, y);
|
||||
}
|
||||
|
||||
swap(dest, x);
|
||||
return err;
|
||||
} else {
|
||||
// return root_n(dest, src, 2);
|
||||
}
|
||||
|
||||
swap(dest, t1);
|
||||
return err;
|
||||
}
|
||||
sqrt :: proc { int_sqrt, };
|
||||
|
||||
@@ -263,7 +274,7 @@ sqrt :: proc { int_sqrt, };
|
||||
*/
|
||||
int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
/* Fast path for n == 2 */
|
||||
if n == 2 { return sqrt(dest, src); }
|
||||
// if n == 2 { return sqrt(dest, src); }
|
||||
|
||||
/* Initialize dest + src if needed. */
|
||||
if err = clear_if_uninitialized(dest); err != .None { return err; }
|
||||
@@ -321,6 +332,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
if err = power_of_two(t2, ilog2); err != .None { return err; }
|
||||
|
||||
c: int;
|
||||
iterations := 0;
|
||||
for {
|
||||
/* t1 = t2 */
|
||||
if err = copy(t1, t2); err != .None { return err; }
|
||||
@@ -353,12 +365,23 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
break;
|
||||
}
|
||||
if c, err = cmp(t1, t2); c == 0 { break; }
|
||||
iterations += 1;
|
||||
if iterations == 101 {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
}
|
||||
|
||||
/* Result can be off by a few so check. */
|
||||
/* Loop beneath can overshoot by one if found root is smaller than actual root. */
|
||||
|
||||
iterations = 0;
|
||||
for {
|
||||
if iterations == 101 {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
//fmt.printf("root_n iteration: %v\n", iterations);
|
||||
iterations += 1;
|
||||
|
||||
if err = pow(t2, t1, n); err != .None { return err; }
|
||||
|
||||
c, err = cmp(t2, a);
|
||||
@@ -372,8 +395,14 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
|
||||
}
|
||||
}
|
||||
|
||||
iterations = 0;
|
||||
/* Correct overshoot from above or from recurrence. */
|
||||
for {
|
||||
if iterations == 101 {
|
||||
return .Max_Iterations_Reached;
|
||||
}
|
||||
iterations += 1;
|
||||
|
||||
if err = pow(t2, t1, n); err != .None { return err; }
|
||||
|
||||
c, err = cmp(t2, a);
|
||||
|
||||
@@ -285,7 +285,7 @@ int_shrmod :: proc(quotient, remainder, numerator: ^Int, bits: int) -> (err: Err
|
||||
shift := DIGIT(_DIGIT_BITS - bits);
|
||||
carry := DIGIT(0);
|
||||
|
||||
for x := quotient.used; x >= 0; x -= 1 {
|
||||
for x := quotient.used - 1; x >= 0; x -= 1 {
|
||||
/*
|
||||
Get the lower bits of this word in a temp.
|
||||
*/
|
||||
@@ -344,7 +344,7 @@ int_shr_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
|
||||
}
|
||||
quotient.used -= digits;
|
||||
_zero_unused(quotient);
|
||||
return .None;
|
||||
return clamp(quotient);
|
||||
}
|
||||
shr_digit :: proc { int_shr_digit, };
|
||||
|
||||
@@ -446,16 +446,16 @@ int_shl_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
|
||||
/*
|
||||
Increment the used by the shift amount then copy upwards.
|
||||
*/
|
||||
quotient.used += digits;
|
||||
|
||||
/*
|
||||
Much like `int_shr_digit`, this is implemented using a sliding window,
|
||||
except the window goes the other way around.
|
||||
*/
|
||||
for x := quotient.used; x >= digits; x -= 1 {
|
||||
quotient.digit[x] = quotient.digit[x - digits];
|
||||
for x := quotient.used; x > 0; x -= 1 {
|
||||
quotient.digit[x+digits-1] = quotient.digit[x-1];
|
||||
}
|
||||
|
||||
quotient.used += digits;
|
||||
mem.zero_slice(quotient.digit[:digits]);
|
||||
return .None;
|
||||
}
|
||||
|
||||
+137
-22
@@ -26,74 +26,74 @@ PyRes :: struct {
|
||||
return strings.clone_to_cstring(es[err], context.temp_allocator);
|
||||
}
|
||||
|
||||
@export test_add_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
@export test_add :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
aa, bb, sum := &Int{}, &Int{}, &Int{};
|
||||
defer destroy(aa, bb, sum);
|
||||
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add_two:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add_two:atoi(b):", err=err}; }
|
||||
if err = add(sum, aa, bb); err != .None { return PyRes{res=":add_two:add(sum,a,b):", err=err}; }
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add:atoi(b):", err=err}; }
|
||||
if err = add(sum, aa, bb); err != .None { return PyRes{res=":add:add(sum,a,b):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":add_two:itoa(sum):", err=err}; }
|
||||
if err != .None { return PyRes{res=":add:itoa(sum):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
@export test_sub_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
@export test_sub :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
aa, bb, sum := &Int{}, &Int{}, &Int{};
|
||||
defer destroy(aa, bb, sum);
|
||||
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub_two:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub_two:atoi(b):", err=err}; }
|
||||
if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub_two:sub(sum,a,b):", err=err}; }
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub:atoi(b):", err=err}; }
|
||||
if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub:sub(sum,a,b):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":sub_two:itoa(sum):", err=err}; }
|
||||
if err != .None { return PyRes{res=":sub:itoa(sum):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
@export test_mul_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
@export test_mul :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
aa, bb, product := &Int{}, &Int{}, &Int{};
|
||||
defer destroy(aa, bb, product);
|
||||
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul_two:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul_two:atoi(b):", err=err}; }
|
||||
if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul_two:mul(product,a,b):", err=err}; }
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul:atoi(b):", err=err}; }
|
||||
if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul:mul(product,a,b):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(product, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":mul_two:itoa(product):", err=err}; }
|
||||
if err != .None { return PyRes{res=":mul:itoa(product):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
|
||||
*/
|
||||
@export test_div_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
@export test_div :: proc "c" (a, b: cstring) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
aa, bb, quotient := &Int{}, &Int{}, &Int{};
|
||||
defer destroy(aa, bb, quotient);
|
||||
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div_two:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div_two:atoi(b):", err=err}; }
|
||||
if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div_two:div(quotient,a,b):", err=err}; }
|
||||
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div:atoi(a):", err=err}; }
|
||||
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div:atoi(b):", err=err}; }
|
||||
if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div:div(quotient,a,b):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(quotient, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
|
||||
if err != .None { return PyRes{res=":div:itoa(quotient):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
@@ -130,7 +130,6 @@ PyRes :: struct {
|
||||
@export test_pow :: proc "c" (base: cstring, power := int(2)) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
l: int;
|
||||
|
||||
dest, bb := &Int{}, &Int{};
|
||||
defer destroy(dest, bb);
|
||||
@@ -142,4 +141,120 @@ PyRes :: struct {
|
||||
r, err = int_itoa_cstring(dest, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
dest = sqrt(src)
|
||||
*/
|
||||
@export test_sqrt :: proc "c" (source: cstring) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":sqrt:atoi(src):", err=err}; }
|
||||
if err = sqrt(src, src); err != .None { return PyRes{res=":sqrt:sqrt(src):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
dest = shr_digit(src, digits)
|
||||
*/
|
||||
@export test_shr_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_digit:atoi(src):", err=err}; }
|
||||
if err = shr_digit(src, digits); err != .None { return PyRes{res=":shr_digit:shr_digit(src):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":shr_digit:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
dest = shl_digit(src, digits)
|
||||
*/
|
||||
@export test_shl_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl_digit:atoi(src):", err=err}; }
|
||||
if err = shl_digit(src, digits); err != .None { return PyRes{res=":shl_digit:shr_digit(src):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":shl_digit:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
dest = shr(src, bits)
|
||||
*/
|
||||
@export test_shr :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr:atoi(src):", err=err}; }
|
||||
if err = shr(src, src, bits); err != .None { return PyRes{res=":shr:shr(src, bits):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":shr:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
dest = shr_signed(src, bits)
|
||||
*/
|
||||
@export test_shr_signed :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_signed:atoi(src):", err=err}; }
|
||||
if err = shr_signed(src, src, bits); err != .None { return PyRes{res=":shr_signed:shr_signed(src, bits):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":shr_signed:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
/*
|
||||
dest = shl(src, bits)
|
||||
*/
|
||||
@export test_shl :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
|
||||
context = runtime.default_context();
|
||||
err: Error;
|
||||
|
||||
src := &Int{};
|
||||
defer destroy(src);
|
||||
|
||||
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl:atoi(src):", err=err}; }
|
||||
if err = shl(src, src, bits); err != .None { return PyRes{res=":shl:shl(src, bits):", err=err}; }
|
||||
|
||||
r: cstring;
|
||||
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
|
||||
if err != .None { return PyRes{res=":shl:itoa(res):", err=err}; }
|
||||
return PyRes{res = r, err = .None};
|
||||
}
|
||||
|
||||
|
||||
+224
-39
@@ -1,6 +1,6 @@
|
||||
from math import *
|
||||
from ctypes import *
|
||||
from random import *
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
@@ -10,12 +10,14 @@ from enum import Enum
|
||||
# Normally, we report the number of passes and fails.
|
||||
# With EXIT_ON_FAIL set, we exit at the first fail.
|
||||
#
|
||||
EXIT_ON_FAIL = True
|
||||
EXIT_ON_FAIL = False
|
||||
|
||||
#
|
||||
# We skip randomized tests altogether if NO_RANDOM_TESTS is set.
|
||||
#
|
||||
NO_RANDOM_TESTS = False #True
|
||||
NO_RANDOM_TESTS = True
|
||||
NO_RANDOM_TESTS = False
|
||||
|
||||
#
|
||||
# If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
|
||||
@@ -113,13 +115,23 @@ class Res(Structure):
|
||||
|
||||
error_string = load(l.test_error_string, [c_byte], c_char_p)
|
||||
|
||||
add_two = load(l.test_add_two, [c_char_p, c_char_p], Res)
|
||||
sub_two = load(l.test_sub_two, [c_char_p, c_char_p], Res)
|
||||
mul_two = load(l.test_mul_two, [c_char_p, c_char_p], Res)
|
||||
div_two = load(l.test_div_two, [c_char_p, c_char_p], Res)
|
||||
add = load(l.test_add, [c_char_p, c_char_p], Res)
|
||||
sub = load(l.test_sub, [c_char_p, c_char_p], Res)
|
||||
mul = load(l.test_mul, [c_char_p, c_char_p], Res)
|
||||
div = load(l.test_div, [c_char_p, c_char_p], Res)
|
||||
|
||||
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
|
||||
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
|
||||
# Powers and such
|
||||
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
|
||||
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
|
||||
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
|
||||
|
||||
# Logical operations
|
||||
|
||||
int_shl_digit = load(l.test_shl_digit, [c_char_p, c_longlong], Res)
|
||||
int_shr_digit = load(l.test_shr_digit, [c_char_p, c_longlong], Res)
|
||||
int_shl = load(l.test_shl, [c_char_p, c_longlong], Res)
|
||||
int_shr = load(l.test_shr, [c_char_p, c_longlong], Res)
|
||||
int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
|
||||
|
||||
def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = ""):
|
||||
passed = True
|
||||
@@ -156,38 +168,38 @@ def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expecte
|
||||
return passed
|
||||
|
||||
|
||||
def test_add_two(a = 0, b = 0, expected_error = Error.Okay):
|
||||
def test_add(a = 0, b = 0, expected_error = Error.Okay):
|
||||
args = [str(a), str(b)]
|
||||
sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
|
||||
res = add_two(sa_c, sb_c)
|
||||
res = add(sa_c, sb_c)
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = a + b
|
||||
return test("test_add_two", res, args, expected_error, expected_result)
|
||||
return test("test_add", res, args, expected_error, expected_result)
|
||||
|
||||
def test_sub_two(a = 0, b = 0, expected_error = Error.Okay):
|
||||
def test_sub(a = 0, b = 0, expected_error = Error.Okay):
|
||||
sa, sb = str(a), str(b)
|
||||
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
|
||||
res = sub_two(sa_c, sb_c)
|
||||
res = sub(sa_c, sb_c)
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = a - b
|
||||
return test("test_sub_two", res, [sa_c, sb_c], expected_error, expected_result)
|
||||
return test("test_sub", res, [sa_c, sb_c], expected_error, expected_result)
|
||||
|
||||
def test_mul_two(a = 0, b = 0, expected_error = Error.Okay):
|
||||
def test_mul(a = 0, b = 0, expected_error = Error.Okay):
|
||||
sa, sb = str(a), str(b)
|
||||
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
|
||||
res = mul_two(sa_c, sb_c)
|
||||
res = mul(sa_c, sb_c)
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = a * b
|
||||
return test("test_mul_two", res, [sa_c, sb_c], expected_error, expected_result)
|
||||
return test("test_mul", res, [sa_c, sb_c], expected_error, expected_result)
|
||||
|
||||
def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
|
||||
def test_div(a = 0, b = 0, expected_error = Error.Okay):
|
||||
sa, sb = str(a), str(b)
|
||||
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
|
||||
try:
|
||||
res = div_two(sa_c, sb_c)
|
||||
res = div(sa_c, sb_c)
|
||||
except:
|
||||
print("Exception with arguments:", a, b)
|
||||
return False
|
||||
@@ -197,12 +209,12 @@ def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
|
||||
# We don't round the division results, so if one component is negative, we're off by one.
|
||||
#
|
||||
if a < 0 and b > 0:
|
||||
expected_result = int(-(abs(a) / b))
|
||||
expected_result = int(-(abs(a) // b))
|
||||
elif b < 0 and a > 0:
|
||||
expected_result = int(-(a / abs((b))))
|
||||
expected_result = int(-(a // abs((b))))
|
||||
else:
|
||||
expected_result = a // b if b != 0 else None
|
||||
return test("test_div_two", res, [sa_c, sb_c], expected_error, expected_result)
|
||||
return test("test_div", res, [sa_c, sb_c], expected_error, expected_result)
|
||||
|
||||
|
||||
def test_log(a = 0, base = 0, expected_error = Error.Okay):
|
||||
@@ -212,7 +224,7 @@ def test_log(a = 0, base = 0, expected_error = Error.Okay):
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = int(log(a, base))
|
||||
expected_result = int(math.log(a, base))
|
||||
return test("test_log", res, args, expected_error, expected_result)
|
||||
|
||||
def test_pow(base = 0, power = 0, expected_error = Error.Okay):
|
||||
@@ -225,9 +237,108 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
|
||||
if power < 0:
|
||||
expected_result = 0
|
||||
else:
|
||||
expected_result = int(base**power)
|
||||
# NOTE(Jeroen): Don't use `math.pow`, it's a floating point approximation.
|
||||
# Use built-in `pow` or `a**b` instead.
|
||||
expected_result = pow(base, power)
|
||||
return test("test_pow", res, args, expected_error, expected_result)
|
||||
|
||||
def test_sqrt(number = 0, expected_error = Error.Okay):
|
||||
args = [str(number)]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
try:
|
||||
res = int_sqrt(sa_c)
|
||||
except:
|
||||
print("sqrt:", number)
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
if number < 0:
|
||||
expected_result = 0
|
||||
else:
|
||||
expected_result = int(math.isqrt(number))
|
||||
return test("test_sqrt", res, args, expected_error, expected_result)
|
||||
|
||||
def root_n(number, root):
|
||||
u, s = number, number + 1
|
||||
while u < s:
|
||||
s = u
|
||||
t = (root-1) * s + number // pow(s, root - 1)
|
||||
u = t // root
|
||||
return s
|
||||
|
||||
def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
|
||||
args = [str(a), digits]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
res = int_shl_digit(sa_c, digits)
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = a << (digits * 60)
|
||||
return test("test_shl_digit", res, args, expected_error, expected_result)
|
||||
|
||||
def test_shr_digit(a = 0, digits = 0, expected_error = Error.Okay):
|
||||
args = [str(a), digits]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
try:
|
||||
res = int_shr_digit(sa_c, digits)
|
||||
except:
|
||||
print("int_shr_digit", a, digits)
|
||||
exit()
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
if a < 0:
|
||||
# Don't pass negative numbers. We have a shr_signed.
|
||||
return False
|
||||
else:
|
||||
expected_result = a >> (digits * 60)
|
||||
|
||||
return test("test_shr_digit", res, args, expected_error, expected_result)
|
||||
|
||||
def test_shl(a = 0, bits = 0, expected_error = Error.Okay):
|
||||
args = [str(a), bits]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
res = int_shl(sa_c, bits)
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = a << bits
|
||||
return test("test_shl", res, args, expected_error, expected_result)
|
||||
|
||||
def test_shr(a = 0, bits = 0, expected_error = Error.Okay):
|
||||
args = [str(a), bits]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
try:
|
||||
res = int_shr(sa_c, bits)
|
||||
except:
|
||||
print("int_shr", a, bits)
|
||||
exit()
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
if a < 0:
|
||||
# Don't pass negative numbers. We have a shr_signed.
|
||||
return False
|
||||
else:
|
||||
expected_result = a >> bits
|
||||
|
||||
return test("test_shr", res, args, expected_error, expected_result)
|
||||
|
||||
def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
|
||||
args = [str(a), bits]
|
||||
sa_c = args[0].encode('utf-8')
|
||||
try:
|
||||
res = int_shr_signed(sa_c, bits)
|
||||
except:
|
||||
print("int_shr_signed", a, bits)
|
||||
exit()
|
||||
|
||||
expected_result = None
|
||||
if expected_error == Error.Okay:
|
||||
expected_result = a >> bits
|
||||
|
||||
return test("test_shr_signed", res, args, expected_error, expected_result)
|
||||
|
||||
# TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
|
||||
#
|
||||
# The last two arguments in tests are the expected error and expected result.
|
||||
@@ -237,19 +348,20 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
|
||||
# You can override that by supplying an expected result as the last argument instead.
|
||||
|
||||
TESTS = {
|
||||
test_add_two: [
|
||||
test_add: [
|
||||
[ 1234, 5432],
|
||||
],
|
||||
test_sub_two: [
|
||||
test_sub: [
|
||||
[ 1234, 5432],
|
||||
],
|
||||
test_mul_two: [
|
||||
test_mul: [
|
||||
[ 1234, 5432],
|
||||
[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7]
|
||||
],
|
||||
test_div_two: [
|
||||
test_div: [
|
||||
[ 54321, 12345],
|
||||
[ 55431, 0, Error.Division_by_Zero],
|
||||
[ 12980742146337069150589594264770969721, 4611686018427387904 ],
|
||||
],
|
||||
test_log: [
|
||||
[ 3192, 1, Error.Invalid_Argument],
|
||||
@@ -260,29 +372,87 @@ TESTS = {
|
||||
test_pow: [
|
||||
[ 0, -1, Error.Math_Domain_Error ], # Math
|
||||
[ 0, 0 ], # 1
|
||||
[ 0, 2 ], # 0
|
||||
[ 42, -1,], # 0
|
||||
[ 42, 1 ], # 1
|
||||
[ 42, 0 ], # 42
|
||||
[ 42, 2 ], # 42*42
|
||||
|
||||
|
||||
[ 0, 2 ], # 0
|
||||
[ 42, -1,], # 0
|
||||
[ 42, 1 ], # 1
|
||||
[ 42, 0 ], # 42
|
||||
[ 42, 2 ], # 42*42
|
||||
],
|
||||
test_sqrt: [
|
||||
[ -1, Error.Invalid_Argument, ],
|
||||
[ 42, Error.Okay, ],
|
||||
[ 12345678901234567890, Error.Okay, ],
|
||||
[ 1298074214633706907132624082305024, Error.Okay, ],
|
||||
],
|
||||
test_shl_digit: [
|
||||
[ 3192, 1 ],
|
||||
[ 1298074214633706907132624082305024, 2 ],
|
||||
[ 1024, 3 ],
|
||||
],
|
||||
test_shr_digit: [
|
||||
[ 3680125442705055547392, 1 ],
|
||||
[ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
|
||||
[ 219504133884436710204395031992179571, 2 ],
|
||||
],
|
||||
test_shl: [
|
||||
[ 3192, 1 ],
|
||||
[ 1298074214633706907132624082305024, 2 ],
|
||||
[ 1024, 3 ],
|
||||
],
|
||||
test_shr: [
|
||||
[ 3680125442705055547392, 1 ],
|
||||
[ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
|
||||
[ 219504133884436710204395031992179571, 2 ],
|
||||
],
|
||||
test_shr_signed: [
|
||||
[ -611105530635358368578155082258244262, 12 ],
|
||||
[ -149195686190273039203651143129455, 12 ],
|
||||
[ 611105530635358368578155082258244262, 12 ],
|
||||
[ 149195686190273039203651143129455, 12 ],
|
||||
]
|
||||
}
|
||||
|
||||
total_passes = 0
|
||||
total_failures = 0
|
||||
|
||||
#
|
||||
# test_shr_signed also tests shr, so we're not going to test shr randomly.
|
||||
#
|
||||
RANDOM_TESTS = [
|
||||
test_add_two, test_sub_two, test_mul_two, test_div_two,
|
||||
test_log, test_pow,
|
||||
test_add, test_sub, test_mul, test_div,
|
||||
test_log, test_pow, test_sqrt,
|
||||
test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
|
||||
]
|
||||
SKIP_LARGE = [test_pow]
|
||||
SKIP_LARGEST = []
|
||||
|
||||
# Untimed warmup.
|
||||
for test_proc in TESTS:
|
||||
for t in TESTS[test_proc]:
|
||||
res = test_proc(*t)
|
||||
|
||||
|
||||
def isqrt(x):
|
||||
n = int(x)
|
||||
a, b = divmod(n.bit_length(), 2)
|
||||
print("isqrt({}), a: {}, b: {}". format(n, a, b))
|
||||
x = 2**(a+b)
|
||||
print("initial: {}".format(x))
|
||||
i = 0
|
||||
while True:
|
||||
# y = (x + n//x)//2
|
||||
t1 = n // x
|
||||
t2 = x + t1
|
||||
t3 = t2 // 2
|
||||
y = (x + n//x)//2
|
||||
|
||||
i += 1
|
||||
print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n));
|
||||
|
||||
if y >= x:
|
||||
return x
|
||||
x = y
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("---- math/big tests ----")
|
||||
print()
|
||||
@@ -317,7 +487,8 @@ if __name__ == '__main__':
|
||||
print()
|
||||
|
||||
for test_proc in RANDOM_TESTS:
|
||||
if test_proc == test_pow and BITS > 1_200: continue
|
||||
if BITS > 1_200 and test_proc in SKIP_LARGE: continue
|
||||
if BITS > 4_096 and test_proc in SKIP_LARGEST: continue
|
||||
|
||||
count_pass = 0
|
||||
count_fail = 0
|
||||
@@ -331,8 +502,10 @@ if __name__ == '__main__':
|
||||
a = randint(-(1 << BITS), 1 << BITS)
|
||||
b = randint(-(1 << BITS), 1 << BITS)
|
||||
|
||||
if test_proc == test_div_two:
|
||||
if test_proc == test_div:
|
||||
# We've already tested division by zero above.
|
||||
bits = int(BITS * 0.6)
|
||||
b = randint(-(1 << bits), 1 << bits)
|
||||
if b == 0:
|
||||
b == 42
|
||||
elif test_proc == test_log:
|
||||
@@ -341,6 +514,18 @@ if __name__ == '__main__':
|
||||
b = randint(2, 1 << 60)
|
||||
elif test_proc == test_pow:
|
||||
b = randint(1, 10)
|
||||
elif test_proc == test_sqrt:
|
||||
a = randint(1, 1 << BITS)
|
||||
b = Error.Okay
|
||||
elif test_proc == test_shl_digit:
|
||||
b = randint(0, 10);
|
||||
elif test_proc == test_shr_digit:
|
||||
a = abs(a)
|
||||
b = randint(0, 10);
|
||||
elif test_proc == test_shl:
|
||||
b = randint(0, min(BITS, 120));
|
||||
elif test_proc == test_shr_signed:
|
||||
b = randint(0, min(BITS, 120));
|
||||
else:
|
||||
b = randint(0, 1 << BITS)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user