big: Fix sqrt, div, add with certain inputs.

This commit is contained in:
Jeroen van Rijn
2021-07-31 17:58:52 +02:00
parent 7afd1b15a8
commit 149c7b88df
7 changed files with 455 additions and 122 deletions
+13 -17
View File
@@ -359,7 +359,7 @@ int_halve :: proc(dest, src: ^Int) -> (err: Error) {
*/
fwd_carry := DIGIT(0);
for x := dest.used; x >= 0; x -= 1 {
for x := dest.used - 1; x >= 0; x -= 1 {
/*
Get the carry for the next iteration.
*/
@@ -761,21 +761,16 @@ sqrmod :: proc { int_sqrmod, };
*/
_int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
dest := dest; x := a; y := b;
if err = clear_if_uninitialized(x); err != .None {
return err;
}
if err = clear_if_uninitialized(y); err != .None {
return err;
}
old_used, min_used, max_used, i: int;
if x.used < y.used {
x, y = y, x;
assert(x.used >= y.used);
}
min_used = x.used;
max_used = y.used;
min_used = y.used;
max_used = x.used;
old_used = dest.used;
if err = grow(dest, max(max_used + 1, _DEFAULT_DIGIT_COUNT)); err != .None {
@@ -827,7 +822,6 @@ _int_add :: proc(dest, a, b: ^Int) -> (err: Error) {
Add remaining carry.
*/
dest.digit[i] = carry;
zero_count := old_used - dest.used;
/*
Zero remainder.
@@ -1111,6 +1105,7 @@ _int_div_3 :: proc(quotient, numerator: ^Int) -> (remainder: int, err: Error) {
_int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (err: Error) {
ta, tb, tq, q := &Int{}, &Int{}, &Int{}, &Int{};
c: int;
goto_end: for {
if err = one(tq); err != .None { break goto_end; }
@@ -1121,20 +1116,21 @@ _int_div_small :: proc(quotient, remainder, numerator, denominator: ^Int) -> (er
if err = abs(ta, numerator); err != .None { break goto_end; }
if err = abs(tb, denominator); err != .None { break goto_end; }
if err = shl(tb, tb, n); err != .None { break goto_end; }
if err = shl(tq, tq, n); err != .None { break goto_end; }
for ; n >= 0; n -= 1 {
c: int;
if c, err = cmp(tb, ta); err != .None { break goto_end; }
if c != 1 {
for n >= 0 {
if c, _ = cmp_mag(ta, tb); c == 0 || c == 1 {
// ta -= tb
if err = sub(ta, ta, tb); err != .None { break goto_end; }
if err = add( q, tq, q); err != .None { break goto_end; }
// q += tq
if err = add( q, q, tq); err != .None { break goto_end; }
}
if err = shr1(tb, tb); err != .None { break goto_end; }
if err = shr1(tq, tq); err != .None { break goto_end; }
}
n -= 1;
}
/*
Now q == quotient and ta == remainder.
-1
View File
@@ -1,5 +1,4 @@
@echo off
clear
:odin run . -vet
:odin build . -build-mode:shared -show-timings -o:minimal -use-separate-modules
odin build . -build-mode:shared -show-timings -o:size -use-separate-modules
+18 -9
View File
@@ -66,17 +66,26 @@ print :: proc(name: string, a: ^Int, base := i8(10)) {
}
demo :: proc() {
// err: Error;
// r := &rnd.Rand{};
// rnd.init(r, 12345);
err: Error;
destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(destination, source, quotient, remainder, numerator, denominator);
// as := cstring("12341234");
// bs := cstring("159671292010002348397151706347412301");
err = atoi(source, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10);
print("src ", source);
// res := test_log(as, 2, 10);
// fmt.print(res);
// destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
// defer destroy(destination, source, quotient, remainder, numerator, denominator);
fmt.println("sqrt should be 843478780275248664696797599030708027195155136953848512749494");
fmt.println();
err = sqrt(destination, source);
fmt.printf("sqrt returned: %v\n", err);
print("sqrt ", destination);
err = atoi(denominator, "711456452774621215865929644892071691538299606591173717356248653735056872543694196490784640730887936656406546625676792022", 10);
err = root_n(quotient, denominator, 2);
fmt.printf("root_n(2) returned: %v\n", err);
print("root_n(2)", quotient);
// fmt.println();
}
main :: proc() {
+58 -29
View File
@@ -214,42 +214,53 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) {
This function is less generic than `root_n`, simpler and faster.
*/
int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(src); err != .None { return err; }
/* Must be positive. */
if src.sign == .Negative { return .Invalid_Argument; }
when true {
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(src); err != .None { return err; }
/* Easy out. If src is zero, so is dest. */
if z, _ := is_zero(src); z { return zero(dest); }
/* Must be positive. */
if src.sign == .Negative { return .Invalid_Argument; }
/* Set up temporaries. */
t1, t2 := &Int{}, &Int{};
defer destroy(t1, t2);
/* Easy out. If src is zero, so is dest. */
if z, _ := is_zero(src); z { return zero(dest); }
if err = copy(t1, src); err != .None { return err; }
if err = zero(t2); err != .None { return err; }
/* Set up temporaries. */
x, y, t1, t2 := &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(x, y, t1, t2);
/* First approximation. Not very bad for large arguments. */
if err = shr_digit(t1, t1.used / 2); err != .None { return err; }
/* t1 > 0 */
if err = div(t2, src, t1); err != .None { return err; }
if err = add(t1, t1, t2); err != .None { return err; }
if err = shr(t1, t1, 1); err != .None { return err; }
count: int;
if count, err = count_bits(src); err != .None { return err; }
/* And now t1 > sqrt(arg). */
for {
if err = div(t2, src, t1); err != .None { return err; }
if err = add(t1, t1, t2); err != .None { return err; }
if err = shr(t1, t1, 1); err != .None { return err; }
/* t1 >= sqrt(arg) >= t2 at this point */
a, b := count >> 1, count & 1;
err = power_of_two(x, a+b);
cm, _ := cmp_mag(t1, t2);
if cm != 1 { break; }
iter := 0;
for {
iter += 1;
if iter > 100 {
swap(dest, x);
return .Max_Iterations_Reached;
}
/*
y = (x + n//x)//2
*/
div(t1, src, x);
add(t2, t1, x);
shr(y, t2, 1);
if c, _ := cmp(y, x); c == 0 || c == 1 {
swap(dest, x);
return .None;
}
swap(x, y);
}
swap(dest, x);
return err;
} else {
// return root_n(dest, src, 2);
}
swap(dest, t1);
return err;
}
sqrt :: proc { int_sqrt, };
@@ -263,7 +274,7 @@ sqrt :: proc { int_sqrt, };
*/
int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
/* Fast path for n == 2 */
if n == 2 { return sqrt(dest, src); }
// if n == 2 { return sqrt(dest, src); }
/* Initialize dest + src if needed. */
if err = clear_if_uninitialized(dest); err != .None { return err; }
@@ -321,6 +332,7 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
if err = power_of_two(t2, ilog2); err != .None { return err; }
c: int;
iterations := 0;
for {
/* t1 = t2 */
if err = copy(t1, t2); err != .None { return err; }
@@ -353,12 +365,23 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
break;
}
if c, err = cmp(t1, t2); c == 0 { break; }
iterations += 1;
if iterations == 101 {
return .Max_Iterations_Reached;
}
}
/* Result can be off by a few so check. */
/* Loop beneath can overshoot by one if found root is smaller than actual root. */
iterations = 0;
for {
if iterations == 101 {
return .Max_Iterations_Reached;
}
//fmt.printf("root_n iteration: %v\n", iterations);
iterations += 1;
if err = pow(t2, t1, n); err != .None { return err; }
c, err = cmp(t2, a);
@@ -372,8 +395,14 @@ int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
}
}
iterations = 0;
/* Correct overshoot from above or from recurrence. */
for {
if iterations == 101 {
return .Max_Iterations_Reached;
}
iterations += 1;
if err = pow(t2, t1, n); err != .None { return err; }
c, err = cmp(t2, a);
+5 -5
View File
@@ -285,7 +285,7 @@ int_shrmod :: proc(quotient, remainder, numerator: ^Int, bits: int) -> (err: Err
shift := DIGIT(_DIGIT_BITS - bits);
carry := DIGIT(0);
for x := quotient.used; x >= 0; x -= 1 {
for x := quotient.used - 1; x >= 0; x -= 1 {
/*
Get the lower bits of this word in a temp.
*/
@@ -344,7 +344,7 @@ int_shr_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
}
quotient.used -= digits;
_zero_unused(quotient);
return .None;
return clamp(quotient);
}
shr_digit :: proc { int_shr_digit, };
@@ -446,16 +446,16 @@ int_shl_digit :: proc(quotient: ^Int, digits: int) -> (err: Error) {
/*
Increment the used by the shift amount then copy upwards.
*/
quotient.used += digits;
/*
Much like `int_shr_digit`, this is implemented using a sliding window,
except the window goes the other way around.
*/
for x := quotient.used; x >= digits; x -= 1 {
quotient.digit[x] = quotient.digit[x - digits];
for x := quotient.used; x > 0; x -= 1 {
quotient.digit[x+digits-1] = quotient.digit[x-1];
}
quotient.used += digits;
mem.zero_slice(quotient.digit[:digits]);
return .None;
}
+137 -22
View File
@@ -26,74 +26,74 @@ PyRes :: struct {
return strings.clone_to_cstring(es[err], context.temp_allocator);
}
@export test_add_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
@export test_add :: proc "c" (a, b: cstring) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, sum := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, sum);
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add_two:atoi(b):", err=err}; }
if err = add(sum, aa, bb); err != .None { return PyRes{res=":add_two:add(sum,a,b):", err=err}; }
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":add:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":add:atoi(b):", err=err}; }
if err = add(sum, aa, bb); err != .None { return PyRes{res=":add:add(sum,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
if err != .None { return PyRes{res=":add_two:itoa(sum):", err=err}; }
if err != .None { return PyRes{res=":add:itoa(sum):", err=err}; }
return PyRes{res = r, err = .None};
}
@export test_sub_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
@export test_sub :: proc "c" (a, b: cstring) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, sum := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, sum);
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub_two:atoi(b):", err=err}; }
if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub_two:sub(sum,a,b):", err=err}; }
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":sub:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":sub:atoi(b):", err=err}; }
if err = sub(sum, aa, bb); err != .None { return PyRes{res=":sub:sub(sum,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(sum, 10, context.temp_allocator);
if err != .None { return PyRes{res=":sub_two:itoa(sum):", err=err}; }
if err != .None { return PyRes{res=":sub:itoa(sum):", err=err}; }
return PyRes{res = r, err = .None};
}
@export test_mul_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
@export test_mul :: proc "c" (a, b: cstring) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, product := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, product);
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul_two:atoi(b):", err=err}; }
if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul_two:mul(product,a,b):", err=err}; }
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":mul:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":mul:atoi(b):", err=err}; }
if err = mul(product, aa, bb); err != .None { return PyRes{res=":mul:mul(product,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(product, 10, context.temp_allocator);
if err != .None { return PyRes{res=":mul_two:itoa(product):", err=err}; }
if err != .None { return PyRes{res=":mul:itoa(product):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
NOTE(Jeroen): For simplicity, we don't return the quotient and the remainder, just the quotient.
*/
@export test_div_two :: proc "c" (a, b: cstring) -> (res: PyRes) {
@export test_div :: proc "c" (a, b: cstring) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
aa, bb, quotient := &Int{}, &Int{}, &Int{};
defer destroy(aa, bb, quotient);
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div_two:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div_two:atoi(b):", err=err}; }
if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div_two:div(quotient,a,b):", err=err}; }
if err = atoi(aa, string(a), 10); err != .None { return PyRes{res=":div:atoi(a):", err=err}; }
if err = atoi(bb, string(b), 10); err != .None { return PyRes{res=":div:atoi(b):", err=err}; }
if err = div(quotient, aa, bb); err != .None { return PyRes{res=":div:div(quotient,a,b):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(quotient, 10, context.temp_allocator);
if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
if err != .None { return PyRes{res=":div:itoa(quotient):", err=err}; }
return PyRes{res = r, err = .None};
}
@@ -130,7 +130,6 @@ PyRes :: struct {
@export test_pow :: proc "c" (base: cstring, power := int(2)) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
l: int;
dest, bb := &Int{}, &Int{};
defer destroy(dest, bb);
@@ -142,4 +141,120 @@ PyRes :: struct {
r, err = int_itoa_cstring(dest, 10, context.temp_allocator);
if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
}
/*
dest = sqrt(src)
*/
@export test_sqrt :: proc "c" (source: cstring) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":sqrt:atoi(src):", err=err}; }
if err = sqrt(src, src); err != .None { return PyRes{res=":sqrt:sqrt(src):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
dest = shr_digit(src, digits)
*/
@export test_shr_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_digit:atoi(src):", err=err}; }
if err = shr_digit(src, digits); err != .None { return PyRes{res=":shr_digit:shr_digit(src):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":shr_digit:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
dest = shl_digit(src, digits)
*/
@export test_shl_digit :: proc "c" (source: cstring, digits: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl_digit:atoi(src):", err=err}; }
if err = shl_digit(src, digits); err != .None { return PyRes{res=":shl_digit:shr_digit(src):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":shl_digit:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
dest = shr(src, bits)
*/
@export test_shr :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr:atoi(src):", err=err}; }
if err = shr(src, src, bits); err != .None { return PyRes{res=":shr:shr(src, bits):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":shr:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
dest = shr_signed(src, bits)
*/
@export test_shr_signed :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shr_signed:atoi(src):", err=err}; }
if err = shr_signed(src, src, bits); err != .None { return PyRes{res=":shr_signed:shr_signed(src, bits):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":shr_signed:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
/*
dest = shl(src, bits)
*/
@export test_shl :: proc "c" (source: cstring, bits: int) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
src := &Int{};
defer destroy(src);
if err = atoi(src, string(source), 10); err != .None { return PyRes{res=":shl:atoi(src):", err=err}; }
if err = shl(src, src, bits); err != .None { return PyRes{res=":shl:shl(src, bits):", err=err}; }
r: cstring;
r, err = int_itoa_cstring(src, 10, context.temp_allocator);
if err != .None { return PyRes{res=":shl:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
+224 -39
View File
@@ -1,6 +1,6 @@
from math import *
from ctypes import *
from random import *
import math
import os
import platform
import time
@@ -10,12 +10,14 @@ from enum import Enum
# Normally, we report the number of passes and fails.
# With EXIT_ON_FAIL set, we exit at the first fail.
#
EXIT_ON_FAIL = True
EXIT_ON_FAIL = False
#
# We skip randomized tests altogether if NO_RANDOM_TESTS is set.
#
NO_RANDOM_TESTS = False #True
NO_RANDOM_TESTS = True
NO_RANDOM_TESTS = False
#
# If TIMED_TESTS == False and FAST_TESTS == True, we cut down the number of iterations.
@@ -113,13 +115,23 @@ class Res(Structure):
error_string = load(l.test_error_string, [c_byte], c_char_p)
add_two = load(l.test_add_two, [c_char_p, c_char_p], Res)
sub_two = load(l.test_sub_two, [c_char_p, c_char_p], Res)
mul_two = load(l.test_mul_two, [c_char_p, c_char_p], Res)
div_two = load(l.test_div_two, [c_char_p, c_char_p], Res)
add = load(l.test_add, [c_char_p, c_char_p], Res)
sub = load(l.test_sub, [c_char_p, c_char_p], Res)
mul = load(l.test_mul, [c_char_p, c_char_p], Res)
div = load(l.test_div, [c_char_p, c_char_p], Res)
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
# Powers and such
int_log = load(l.test_log, [c_char_p, c_longlong], Res)
int_pow = load(l.test_pow, [c_char_p, c_longlong], Res)
int_sqrt = load(l.test_sqrt, [c_char_p], Res)
# Logical operations
int_shl_digit = load(l.test_shl_digit, [c_char_p, c_longlong], Res)
int_shr_digit = load(l.test_shr_digit, [c_char_p, c_longlong], Res)
int_shl = load(l.test_shl, [c_char_p, c_longlong], Res)
int_shr = load(l.test_shr, [c_char_p, c_longlong], Res)
int_shr_signed = load(l.test_shr_signed, [c_char_p, c_longlong], Res)
def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expected_result = ""):
passed = True
@@ -156,38 +168,38 @@ def test(test_name: "", res: Res, param=[], expected_error = Error.Okay, expecte
return passed
def test_add_two(a = 0, b = 0, expected_error = Error.Okay):
def test_add(a = 0, b = 0, expected_error = Error.Okay):
args = [str(a), str(b)]
sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
res = add_two(sa_c, sb_c)
res = add(sa_c, sb_c)
expected_result = None
if expected_error == Error.Okay:
expected_result = a + b
return test("test_add_two", res, args, expected_error, expected_result)
return test("test_add", res, args, expected_error, expected_result)
def test_sub_two(a = 0, b = 0, expected_error = Error.Okay):
def test_sub(a = 0, b = 0, expected_error = Error.Okay):
sa, sb = str(a), str(b)
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
res = sub_two(sa_c, sb_c)
res = sub(sa_c, sb_c)
expected_result = None
if expected_error == Error.Okay:
expected_result = a - b
return test("test_sub_two", res, [sa_c, sb_c], expected_error, expected_result)
return test("test_sub", res, [sa_c, sb_c], expected_error, expected_result)
def test_mul_two(a = 0, b = 0, expected_error = Error.Okay):
def test_mul(a = 0, b = 0, expected_error = Error.Okay):
sa, sb = str(a), str(b)
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
res = mul_two(sa_c, sb_c)
res = mul(sa_c, sb_c)
expected_result = None
if expected_error == Error.Okay:
expected_result = a * b
return test("test_mul_two", res, [sa_c, sb_c], expected_error, expected_result)
return test("test_mul", res, [sa_c, sb_c], expected_error, expected_result)
def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
def test_div(a = 0, b = 0, expected_error = Error.Okay):
sa, sb = str(a), str(b)
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
try:
res = div_two(sa_c, sb_c)
res = div(sa_c, sb_c)
except:
print("Exception with arguments:", a, b)
return False
@@ -197,12 +209,12 @@ def test_div_two(a = 0, b = 0, expected_error = Error.Okay):
# We don't round the division results, so if one component is negative, we're off by one.
#
if a < 0 and b > 0:
expected_result = int(-(abs(a) / b))
expected_result = int(-(abs(a) // b))
elif b < 0 and a > 0:
expected_result = int(-(a / abs((b))))
expected_result = int(-(a // abs((b))))
else:
expected_result = a // b if b != 0 else None
return test("test_div_two", res, [sa_c, sb_c], expected_error, expected_result)
return test("test_div", res, [sa_c, sb_c], expected_error, expected_result)
def test_log(a = 0, base = 0, expected_error = Error.Okay):
@@ -212,7 +224,7 @@ def test_log(a = 0, base = 0, expected_error = Error.Okay):
expected_result = None
if expected_error == Error.Okay:
expected_result = int(log(a, base))
expected_result = int(math.log(a, base))
return test("test_log", res, args, expected_error, expected_result)
def test_pow(base = 0, power = 0, expected_error = Error.Okay):
@@ -225,9 +237,108 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
if power < 0:
expected_result = 0
else:
expected_result = int(base**power)
# NOTE(Jeroen): Don't use `math.pow`, it's a floating point approximation.
# Use built-in `pow` or `a**b` instead.
expected_result = pow(base, power)
return test("test_pow", res, args, expected_error, expected_result)
def test_sqrt(number = 0, expected_error = Error.Okay):
args = [str(number)]
sa_c = args[0].encode('utf-8')
try:
res = int_sqrt(sa_c)
except:
print("sqrt:", number)
expected_result = None
if expected_error == Error.Okay:
if number < 0:
expected_result = 0
else:
expected_result = int(math.isqrt(number))
return test("test_sqrt", res, args, expected_error, expected_result)
def root_n(number, root):
u, s = number, number + 1
while u < s:
s = u
t = (root-1) * s + number // pow(s, root - 1)
u = t // root
return s
def test_shl_digit(a = 0, digits = 0, expected_error = Error.Okay):
args = [str(a), digits]
sa_c = args[0].encode('utf-8')
res = int_shl_digit(sa_c, digits)
expected_result = None
if expected_error == Error.Okay:
expected_result = a << (digits * 60)
return test("test_shl_digit", res, args, expected_error, expected_result)
def test_shr_digit(a = 0, digits = 0, expected_error = Error.Okay):
args = [str(a), digits]
sa_c = args[0].encode('utf-8')
try:
res = int_shr_digit(sa_c, digits)
except:
print("int_shr_digit", a, digits)
exit()
expected_result = None
if expected_error == Error.Okay:
if a < 0:
# Don't pass negative numbers. We have a shr_signed.
return False
else:
expected_result = a >> (digits * 60)
return test("test_shr_digit", res, args, expected_error, expected_result)
def test_shl(a = 0, bits = 0, expected_error = Error.Okay):
args = [str(a), bits]
sa_c = args[0].encode('utf-8')
res = int_shl(sa_c, bits)
expected_result = None
if expected_error == Error.Okay:
expected_result = a << bits
return test("test_shl", res, args, expected_error, expected_result)
def test_shr(a = 0, bits = 0, expected_error = Error.Okay):
args = [str(a), bits]
sa_c = args[0].encode('utf-8')
try:
res = int_shr(sa_c, bits)
except:
print("int_shr", a, bits)
exit()
expected_result = None
if expected_error == Error.Okay:
if a < 0:
# Don't pass negative numbers. We have a shr_signed.
return False
else:
expected_result = a >> bits
return test("test_shr", res, args, expected_error, expected_result)
def test_shr_signed(a = 0, bits = 0, expected_error = Error.Okay):
args = [str(a), bits]
sa_c = args[0].encode('utf-8')
try:
res = int_shr_signed(sa_c, bits)
except:
print("int_shr_signed", a, bits)
exit()
expected_result = None
if expected_error == Error.Okay:
expected_result = a >> bits
return test("test_shr_signed", res, args, expected_error, expected_result)
# TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
#
# The last two arguments in tests are the expected error and expected result.
@@ -237,19 +348,20 @@ def test_pow(base = 0, power = 0, expected_error = Error.Okay):
# You can override that by supplying an expected result as the last argument instead.
TESTS = {
test_add_two: [
test_add: [
[ 1234, 5432],
],
test_sub_two: [
test_sub: [
[ 1234, 5432],
],
test_mul_two: [
test_mul: [
[ 1234, 5432],
[ 0xd3b4e926aaba3040e1c12b5ea553b5, 0x1a821e41257ed9281bee5bc7789ea7]
],
test_div_two: [
test_div: [
[ 54321, 12345],
[ 55431, 0, Error.Division_by_Zero],
[ 12980742146337069150589594264770969721, 4611686018427387904 ],
],
test_log: [
[ 3192, 1, Error.Invalid_Argument],
@@ -260,29 +372,87 @@ TESTS = {
test_pow: [
[ 0, -1, Error.Math_Domain_Error ], # Math
[ 0, 0 ], # 1
[ 0, 2 ], # 0
[ 42, -1,], # 0
[ 42, 1 ], # 1
[ 42, 0 ], # 42
[ 42, 2 ], # 42*42
[ 0, 2 ], # 0
[ 42, -1,], # 0
[ 42, 1 ], # 1
[ 42, 0 ], # 42
[ 42, 2 ], # 42*42
],
test_sqrt: [
[ -1, Error.Invalid_Argument, ],
[ 42, Error.Okay, ],
[ 12345678901234567890, Error.Okay, ],
[ 1298074214633706907132624082305024, Error.Okay, ],
],
test_shl_digit: [
[ 3192, 1 ],
[ 1298074214633706907132624082305024, 2 ],
[ 1024, 3 ],
],
test_shr_digit: [
[ 3680125442705055547392, 1 ],
[ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
[ 219504133884436710204395031992179571, 2 ],
],
test_shl: [
[ 3192, 1 ],
[ 1298074214633706907132624082305024, 2 ],
[ 1024, 3 ],
],
test_shr: [
[ 3680125442705055547392, 1 ],
[ 1725436586697640946858688965569256363112777243042596638790631055949824, 2 ],
[ 219504133884436710204395031992179571, 2 ],
],
test_shr_signed: [
[ -611105530635358368578155082258244262, 12 ],
[ -149195686190273039203651143129455, 12 ],
[ 611105530635358368578155082258244262, 12 ],
[ 149195686190273039203651143129455, 12 ],
]
}
total_passes = 0
total_failures = 0
#
# test_shr_signed also tests shr, so we're not going to test shr randomly.
#
RANDOM_TESTS = [
test_add_two, test_sub_two, test_mul_two, test_div_two,
test_log, test_pow,
test_add, test_sub, test_mul, test_div,
test_log, test_pow, test_sqrt,
test_shl_digit, test_shr_digit, test_shl, test_shr_signed,
]
SKIP_LARGE = [test_pow]
SKIP_LARGEST = []
# Untimed warmup.
for test_proc in TESTS:
for t in TESTS[test_proc]:
res = test_proc(*t)
def isqrt(x):
n = int(x)
a, b = divmod(n.bit_length(), 2)
print("isqrt({}), a: {}, b: {}". format(n, a, b))
x = 2**(a+b)
print("initial: {}".format(x))
i = 0
while True:
# y = (x + n//x)//2
t1 = n // x
t2 = x + t1
t3 = t2 // 2
y = (x + n//x)//2
i += 1
print("iter {}\n\t x: {}\n\t y: {}\n\tt1: {}\n\tt2: {}\n\tsrc: {}".format(i, x, y, t1, t2, n));
if y >= x:
return x
x = y
if __name__ == '__main__':
print("---- math/big tests ----")
print()
@@ -317,7 +487,8 @@ if __name__ == '__main__':
print()
for test_proc in RANDOM_TESTS:
if test_proc == test_pow and BITS > 1_200: continue
if BITS > 1_200 and test_proc in SKIP_LARGE: continue
if BITS > 4_096 and test_proc in SKIP_LARGEST: continue
count_pass = 0
count_fail = 0
@@ -331,8 +502,10 @@ if __name__ == '__main__':
a = randint(-(1 << BITS), 1 << BITS)
b = randint(-(1 << BITS), 1 << BITS)
if test_proc == test_div_two:
if test_proc == test_div:
# We've already tested division by zero above.
bits = int(BITS * 0.6)
b = randint(-(1 << bits), 1 << bits)
if b == 0:
b == 42
elif test_proc == test_log:
@@ -341,6 +514,18 @@ if __name__ == '__main__':
b = randint(2, 1 << 60)
elif test_proc == test_pow:
b = randint(1, 10)
elif test_proc == test_sqrt:
a = randint(1, 1 << BITS)
b = Error.Okay
elif test_proc == test_shl_digit:
b = randint(0, 10);
elif test_proc == test_shr_digit:
a = abs(a)
b = randint(0, 10);
elif test_proc == test_shl:
b = randint(0, min(BITS, 120));
elif test_proc == test_shr_signed:
b = randint(0, min(BITS, 120));
else:
b = randint(0, 1 << BITS)