big: Add tests for log.

This commit is contained in:
Jeroen van Rijn
2021-07-29 21:12:59 +02:00
parent 922df6a438
commit 385b9c9922
4 changed files with 123 additions and 79 deletions
+2 -2
View File
@@ -70,10 +70,10 @@ demo :: proc() {
// r := &rnd.Rand{};
// rnd.init(r, 12345);
// as := cstring("596360079055148742691396559496540363");
// as := cstring("12341234");
// bs := cstring("159671292010002348397151706347412301");
// res := test_div_two(as, bs);
// res := test_log(as, 2, 10);
// fmt.print(res);
// destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{};
// defer destroy(destination, source, quotient, remainder, numerator, denominator);
+2 -2
View File
@@ -14,8 +14,8 @@ int_log :: proc(a: ^Int, base: DIGIT) -> (res: int, err: Error) {
return -1, .Invalid_Argument;
}
if err = clear_if_uninitialized(a); err != .None { return -1, err; }
if n, _ := is_neg(a); n { return -1, .Invalid_Argument; }
if z, _ := is_zero(a); z { return -1, .Invalid_Argument; }
if n, _ := is_neg(a); n { return -1, .Math_Domain_Error; }
if z, _ := is_zero(a); z { return -1, .Math_Domain_Error; }
/*
Fast path for bases that are a power of two.
+28 -1
View File
@@ -95,4 +95,31 @@ PyRes :: struct {
r, err = int_itoa_cstring(quotient, i8(radix), context.temp_allocator);
if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; }
return PyRes{res = r, err = .None};
}
}
/*
res = log(a, base)
*/
@export test_log :: proc "c" (a: cstring, base := DIGIT(2), radix := int(10)) -> (res: PyRes) {
context = runtime.default_context();
err: Error;
l: int;
aa := &Int{};
defer destroy(aa);
if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":log:atoi(a):", err=err}; }
if l, err = log(aa, base); err != .None { return PyRes{res=":log:log(a, base):", err=err}; }
zero(aa);
aa.digit[0] = DIGIT(l) & _MASK;
aa.digit[1] = DIGIT(l) >> _DIGIT_BITS;
aa.used = 2;
clamp(aa);
r: cstring;
r, err = int_itoa_cstring(aa, i8(radix), context.temp_allocator);
if err != .None { return PyRes{res=":log:itoa(res):", err=err}; }
return PyRes{res = r, err = .None};
}
+91 -74
View File
@@ -6,7 +6,12 @@ import platform
import time
#
# Where is the DLL? If missing, build using: `odin build . -build-mode:dll`
# Fast tests?
#
FAST_TESTS = True
#
# Where is the DLL? If missing, build using: `odin build . -build-mode:shared`
#
if platform.system() == "Windows":
LIB_PATH = os.getcwd() + os.sep + "big.dll"
@@ -16,7 +21,8 @@ elif platform.system() == "Darwin":
LIB_PATH = os.getcwd() + os.sep + "big.dylib"
else:
print("Platform is unsupported.")
os.exit(1)
exit(1)
#
# How many iterations of each random test do we want to run?
#
@@ -27,6 +33,14 @@ BITS_AND_ITERATIONS = [
(12_000, 10),
]
#
# Fast tests?
#
if FAST_TESTS:
for k in range(len(BITS_AND_ITERATIONS)):
b, i = BITS_AND_ITERATIONS[k]
BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5)
#
# Result values will be passed in a struct { res: cstring, err: Error }
#
@@ -58,76 +72,53 @@ except:
print("Couldn't find or load " + LIB_PATH + ".")
exit(1)
def load(export_name, args, res):
export_name.argtypes = args
export_name.restype = res
return export_name
error_string = load(l.test_error_string, [c_byte], c_char_p)
#
# res = a + b, err
#
try:
l.test_add_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_add_two.restype = Res
except:
print("Couldn't find exported function 'test_add_two'")
exit(2)
add_two = l.test_add_two
add_two = load(l.test_add_two, [c_char_p, c_char_p, c_longlong], Res)
#
# res = a - b, err
#
try:
l.test_sub_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_sub_two.restype = Res
except:
print("Couldn't find exported function 'test_sub_two'")
exit(2)
sub_two = l.test_sub_two
sub_two = load(l.test_sub_two, [c_char_p, c_char_p, c_longlong], Res)
#
# res = a * b, err
#
try:
l.test_mul_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_mul_two.restype = Res
except:
print("Couldn't find exported function 'test_add_two'")
exit(2)
mul_two = l.test_mul_two
mul_two = load(l.test_mul_two, [c_char_p, c_char_p, c_longlong], Res)
#
# res = a / b, err
#
try:
l.test_div_two.argtypes = [c_char_p, c_char_p, c_longlong]
l.test_div_two.restype = Res
except:
print("Couldn't find exported function 'test_div_two'")
exit(2)
div_two = l.test_div_two
div_two = load(l.test_div_two, [c_char_p, c_char_p, c_longlong], Res)
#
# res = log(a, base)
#
int_log = load(l.test_log, [c_char_p, c_longlong, c_longlong], Res)
try:
l.test_error_string.argtypes = [c_byte]
l.test_error_string.restype = c_char_p
except:
print("Couldn't find exported function 'test_error_string'")
exit(2)
def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_result = ""):
passed = True
r = None
if res.err != expected_error:
error_type = l.test_error_string(res.err).decode('utf-8')
error_type = error_string(res.err).decode('utf-8')
error_loc = res.res.decode('utf-8')
error_string = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc)
error = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc)
if len(param):
error_string += " with params {}".format(param)
error += " with params {}".format(param)
print(error_string, flush=True)
print(error, flush=True)
passed = False
elif res.err == E_None:
try:
@@ -137,50 +128,43 @@ def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_re
r = eval(res.res)
if r != expected_result:
error_string = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result)
if len(param):
error_string += " with params {}".format(param)
error += " with params {}".format(param)
print(error_string, flush=True)
print(error, flush=True)
passed = False
return passed
def test_add_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
sa = str(a)
sb = str(b)
sa_c = sa.encode('utf-8')
sb_c = sb.encode('utf-8')
args = [str(a), str(b), radix]
sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8')
res = add_two(sa_c, sb_c, radix)
if expected_result == None:
expected_result = a + b
return test("test_add_two", res, [sa, sb, radix], expected_error, expected_result)
return test("test_add_two", res, args, expected_error, expected_result)
def test_sub_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
sa = str(a)
sb = str(b)
sa_c = sa.encode('utf-8')
sb_c = sb.encode('utf-8')
sa, sb = str(a), str(b)
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
res = sub_two(sa_c, sb_c, radix)
if expected_result == None:
expected_result = a - b
return test("test_sub_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
def test_mul_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
sa = str(a)
sb = str(b)
sa_c = sa.encode('utf-8')
sb_c = sb.encode('utf-8')
sa, sb = str(a), str(b)
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
res = mul_two(sa_c, sb_c, radix)
if expected_result == None:
expected_result = a * b
return test("test_mul_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None):
sa = str(a)
sb = str(b)
sa_c = sa.encode('utf-8')
sb_c = sb.encode('utf-8')
sa, sb = str(a), str(b)
sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8')
try:
res = div_two(sa_c, sb_c, radix)
except:
@@ -190,6 +174,17 @@ def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_res
expected_result = a // b if b != 0 else None
return test("test_div_two", res, [sa_c, sb_c, radix], expected_error, expected_result)
def test_log(a = 0, base = 0, radix = 10, expected_error = E_None, expected_result = None):
args = [str(a), base, radix]
sa_c = args[0].encode('utf-8')
res = int_log(sa_c, base, radix)
if expected_result == None:
expected_result = int(log(a, base))
return test("test_log", res, args, expected_error, expected_result)
# TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on.
#
# The last two arguments in tests are the expected error and expected result.
@@ -211,15 +206,26 @@ TESTS = {
[ 1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10, ]
],
test_div_two: [
[ 54321, 12345, 10, ],
[ 55431, 0, 10, E_Division_by_Zero, ],
[ 54321, 12345, 10, ],
[ 55431, 0, 10, E_Division_by_Zero, ],
],
test_log: [
[ 3192, 1, 10, E_Invalid_Argument, ":log:log(a, base):"],
[ -1234, 2, 10, E_Math_Domain_Error, ":log:log(a, base):"],
[ 0, 2, 10, E_Math_Domain_Error, ":log:log(a, base):"],
[ 1024, 2, 10, ],
],
}
TOTAL_TIME = 0
total_passes = 0
total_failures = 0
if __name__ == '__main__':
test_log(1234, 2, 10)
print("---- core:math/big tests ----")
print()
@@ -240,6 +246,7 @@ if __name__ == '__main__':
if res:
count_pass += 1
total_passes += 1
else:
count_fail += 1
total_failures += 1
@@ -251,19 +258,25 @@ if __name__ == '__main__':
print("---- core:math/big with two random {bits:,} bit numbers ----".format(bits=BITS))
print()
for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two]:
for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two, test_log]:
count_pass = 0
count_fail = 0
TIMINGS = {}
for i in range(ITERATIONS):
a = randint(0, 1 << BITS)
b = randint(0, 1 << BITS)
res = None
# We've already tested division by zero above.
if b == 0 and test_proc == test_div_two:
b = b + 1
if test_proc == test_div_two:
# We've already tested division by zero above.
b = randint(1, 1 << BITS)
elif test_proc == test_log:
# We've already tested log's domain errors.
a = randint(1, 1 << BITS)
b = randint(2, 1 << 60)
else:
b = randint(0, 1 << BITS)
res = None
start = time.perf_counter()
res = test_proc(a, b)
@@ -277,13 +290,17 @@ if __name__ == '__main__':
if res:
count_pass += 1
total_passes += 1
else:
count_fail += 1
total_failures += 1
print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000))
print("\ntotal: {0:.3f} ms".format(TOTAL_TIME * 1_000))
print()
print("---- THE END ----")
print()
print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000))
if total_failures:
os.exit(1)
exit(1)