diff --git a/core/math/big/example.odin b/core/math/big/example.odin index e677af8b9..5d30c85ae 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -70,10 +70,10 @@ demo :: proc() { // r := &rnd.Rand{}; // rnd.init(r, 12345); - // as := cstring("596360079055148742691396559496540363"); + // as := cstring("12341234"); // bs := cstring("159671292010002348397151706347412301"); - // res := test_div_two(as, bs); + // res := test_log(as, 2, 10); // fmt.print(res); // destination, source, quotient, remainder, numerator, denominator := &Int{}, &Int{}, &Int{}, &Int{}, &Int{}, &Int{}; // defer destroy(destination, source, quotient, remainder, numerator, denominator); diff --git a/core/math/big/exp_log.odin b/core/math/big/exp_log.odin index 943fd51e9..4c31a99d8 100644 --- a/core/math/big/exp_log.odin +++ b/core/math/big/exp_log.odin @@ -14,8 +14,8 @@ int_log :: proc(a: ^Int, base: DIGIT) -> (res: int, err: Error) { return -1, .Invalid_Argument; } if err = clear_if_uninitialized(a); err != .None { return -1, err; } - if n, _ := is_neg(a); n { return -1, .Invalid_Argument; } - if z, _ := is_zero(a); z { return -1, .Invalid_Argument; } + if n, _ := is_neg(a); n { return -1, .Math_Domain_Error; } + if z, _ := is_zero(a); z { return -1, .Math_Domain_Error; } /* Fast path for bases that are a power of two. diff --git a/core/math/big/test.odin b/core/math/big/test.odin index 0b0d65c26..d05d3ecaa 100644 --- a/core/math/big/test.odin +++ b/core/math/big/test.odin @@ -95,4 +95,31 @@ PyRes :: struct { r, err = int_itoa_cstring(quotient, i8(radix), context.temp_allocator); if err != .None { return PyRes{res=":div_two:itoa(quotient):", err=err}; } return PyRes{res = r, err = .None}; -} \ No newline at end of file +} + + +/* + res = log(a, base) +*/ +@export test_log :: proc "c" (a: cstring, base := DIGIT(2), radix := int(10)) -> (res: PyRes) { + context = runtime.default_context(); + err: Error; + l: int; + + aa := &Int{}; + defer destroy(aa); + + if err = atoi(aa, string(a), i8(radix)); err != .None { return PyRes{res=":log:atoi(a):", err=err}; } + if l, err = log(aa, base); err != .None { return PyRes{res=":log:log(a, base):", err=err}; } + + zero(aa); + aa.digit[0] = DIGIT(l) & _MASK; + aa.digit[1] = DIGIT(l) >> _DIGIT_BITS; + aa.used = 2; + clamp(aa); + + r: cstring; + r, err = int_itoa_cstring(aa, i8(radix), context.temp_allocator); + if err != .None { return PyRes{res=":log:itoa(res):", err=err}; } + return PyRes{res = r, err = .None}; +} diff --git a/core/math/big/test.py b/core/math/big/test.py index e00af473d..f325d7787 100644 --- a/core/math/big/test.py +++ b/core/math/big/test.py @@ -6,7 +6,12 @@ import platform import time # -# Where is the DLL? If missing, build using: `odin build . -build-mode:dll` +# Fast tests? +# +FAST_TESTS = True + +# +# Where is the DLL? If missing, build using: `odin build . -build-mode:shared` # if platform.system() == "Windows": LIB_PATH = os.getcwd() + os.sep + "big.dll" @@ -16,7 +21,8 @@ elif platform.system() == "Darwin": LIB_PATH = os.getcwd() + os.sep + "big.dylib" else: print("Platform is unsupported.") - os.exit(1) + exit(1) + # # How many iterations of each random test do we want to run? # @@ -27,6 +33,14 @@ BITS_AND_ITERATIONS = [ (12_000, 10), ] +# +# Fast tests? +# +if FAST_TESTS: + for k in range(len(BITS_AND_ITERATIONS)): + b, i = BITS_AND_ITERATIONS[k] + BITS_AND_ITERATIONS[k] = (b, i // 10 if i >= 100 else 5) + # # Result values will be passed in a struct { res: cstring, err: Error } # @@ -58,76 +72,53 @@ except: print("Couldn't find or load " + LIB_PATH + ".") exit(1) +def load(export_name, args, res): + export_name.argtypes = args + export_name.restype = res + return export_name + +error_string = load(l.test_error_string, [c_byte], c_char_p) + # # res = a + b, err # -try: - l.test_add_two.argtypes = [c_char_p, c_char_p, c_longlong] - l.test_add_two.restype = Res -except: - print("Couldn't find exported function 'test_add_two'") - exit(2) - -add_two = l.test_add_two +add_two = load(l.test_add_two, [c_char_p, c_char_p, c_longlong], Res) # # res = a - b, err # -try: - l.test_sub_two.argtypes = [c_char_p, c_char_p, c_longlong] - l.test_sub_two.restype = Res -except: - print("Couldn't find exported function 'test_sub_two'") - exit(2) - -sub_two = l.test_sub_two +sub_two = load(l.test_sub_two, [c_char_p, c_char_p, c_longlong], Res) # # res = a * b, err # -try: - l.test_mul_two.argtypes = [c_char_p, c_char_p, c_longlong] - l.test_mul_two.restype = Res -except: - print("Couldn't find exported function 'test_add_two'") - exit(2) - -mul_two = l.test_mul_two +mul_two = load(l.test_mul_two, [c_char_p, c_char_p, c_longlong], Res) # # res = a / b, err # -try: - l.test_div_two.argtypes = [c_char_p, c_char_p, c_longlong] - l.test_div_two.restype = Res -except: - print("Couldn't find exported function 'test_div_two'") - exit(2) - -div_two = l.test_div_two +div_two = load(l.test_div_two, [c_char_p, c_char_p, c_longlong], Res) +# +# res = log(a, base) +# +int_log = load(l.test_log, [c_char_p, c_longlong, c_longlong], Res) -try: - l.test_error_string.argtypes = [c_byte] - l.test_error_string.restype = c_char_p -except: - print("Couldn't find exported function 'test_error_string'") - exit(2) def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_result = ""): passed = True r = None if res.err != expected_error: - error_type = l.test_error_string(res.err).decode('utf-8') + error_type = error_string(res.err).decode('utf-8') error_loc = res.res.decode('utf-8') - error_string = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc) + error = "{}: '{}' error in '{}'".format(test_name, error_type, error_loc) if len(param): - error_string += " with params {}".format(param) + error += " with params {}".format(param) - print(error_string, flush=True) + print(error, flush=True) passed = False elif res.err == E_None: try: @@ -137,50 +128,43 @@ def test(test_name: "", res: Res, param=[], expected_error = E_None, expected_re r = eval(res.res) if r != expected_result: - error_string = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result) + error = "{}: Result was '{}', expected '{}'".format(test_name, r, expected_result) if len(param): - error_string += " with params {}".format(param) + error += " with params {}".format(param) - print(error_string, flush=True) + print(error, flush=True) passed = False return passed + def test_add_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None): - sa = str(a) - sb = str(b) - sa_c = sa.encode('utf-8') - sb_c = sb.encode('utf-8') + args = [str(a), str(b), radix] + sa_c, sb_c = args[0].encode('utf-8'), args[1].encode('utf-8') res = add_two(sa_c, sb_c, radix) if expected_result == None: expected_result = a + b - return test("test_add_two", res, [sa, sb, radix], expected_error, expected_result) + return test("test_add_two", res, args, expected_error, expected_result) def test_sub_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None): - sa = str(a) - sb = str(b) - sa_c = sa.encode('utf-8') - sb_c = sb.encode('utf-8') + sa, sb = str(a), str(b) + sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8') res = sub_two(sa_c, sb_c, radix) if expected_result == None: expected_result = a - b return test("test_sub_two", res, [sa_c, sb_c, radix], expected_error, expected_result) def test_mul_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None): - sa = str(a) - sb = str(b) - sa_c = sa.encode('utf-8') - sb_c = sb.encode('utf-8') + sa, sb = str(a), str(b) + sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8') res = mul_two(sa_c, sb_c, radix) if expected_result == None: expected_result = a * b return test("test_mul_two", res, [sa_c, sb_c, radix], expected_error, expected_result) def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_result = None): - sa = str(a) - sb = str(b) - sa_c = sa.encode('utf-8') - sb_c = sb.encode('utf-8') + sa, sb = str(a), str(b) + sa_c, sb_c = sa.encode('utf-8'), sb.encode('utf-8') try: res = div_two(sa_c, sb_c, radix) except: @@ -190,6 +174,17 @@ def test_div_two(a = 0, b = 0, radix = 10, expected_error = E_None, expected_res expected_result = a // b if b != 0 else None return test("test_div_two", res, [sa_c, sb_c, radix], expected_error, expected_result) + +def test_log(a = 0, base = 0, radix = 10, expected_error = E_None, expected_result = None): + args = [str(a), base, radix] + sa_c = args[0].encode('utf-8') + res = int_log(sa_c, base, radix) + + if expected_result == None: + expected_result = int(log(a, base)) + return test("test_log", res, args, expected_error, expected_result) + + # TODO(Jeroen): Make sure tests cover edge cases, fast paths, and so on. # # The last two arguments in tests are the expected error and expected result. @@ -211,15 +206,26 @@ TESTS = { [ 1099243943008198766717263669950239669, 137638828577110581150675834234248871, 10, ] ], test_div_two: [ - [ 54321, 12345, 10, ], - [ 55431, 0, 10, E_Division_by_Zero, ], + [ 54321, 12345, 10, ], + [ 55431, 0, 10, E_Division_by_Zero, ], + ], + test_log: [ + [ 3192, 1, 10, E_Invalid_Argument, ":log:log(a, base):"], + [ -1234, 2, 10, E_Math_Domain_Error, ":log:log(a, base):"], + [ 0, 2, 10, E_Math_Domain_Error, ":log:log(a, base):"], + [ 1024, 2, 10, ], ], } TOTAL_TIME = 0 +total_passes = 0 total_failures = 0 + if __name__ == '__main__': + + test_log(1234, 2, 10) + print("---- core:math/big tests ----") print() @@ -240,6 +246,7 @@ if __name__ == '__main__': if res: count_pass += 1 + total_passes += 1 else: count_fail += 1 total_failures += 1 @@ -251,19 +258,25 @@ if __name__ == '__main__': print("---- core:math/big with two random {bits:,} bit numbers ----".format(bits=BITS)) print() - for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two]: + for test_proc in [test_add_two, test_sub_two, test_mul_two, test_div_two, test_log]: count_pass = 0 count_fail = 0 TIMINGS = {} for i in range(ITERATIONS): a = randint(0, 1 << BITS) - b = randint(0, 1 << BITS) - res = None - # We've already tested division by zero above. - if b == 0 and test_proc == test_div_two: - b = b + 1 + if test_proc == test_div_two: + # We've already tested division by zero above. + b = randint(1, 1 << BITS) + elif test_proc == test_log: + # We've already tested log's domain errors. + a = randint(1, 1 << BITS) + b = randint(2, 1 << 60) + else: + b = randint(0, 1 << BITS) + + res = None start = time.perf_counter() res = test_proc(a, b) @@ -277,13 +290,17 @@ if __name__ == '__main__': if res: count_pass += 1 + total_passes += 1 else: count_fail += 1 total_failures += 1 print("{name}: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(name=test_proc.__name__, count_pass=count_pass, count_fail=count_fail, timing=TIMINGS[test_proc] * 1_000)) - print("\ntotal: {0:.3f} ms".format(TOTAL_TIME * 1_000)) + print() + print("---- THE END ----") + print() + print("total: {count_pass:,} passes and {count_fail:,} failures in {timing:.3f} ms.".format(count_pass=total_passes, count_fail=total_failures, timing=TOTAL_TIME * 1_000)) if total_failures: - os.exit(1) \ No newline at end of file + exit(1) \ No newline at end of file