From 5f63e3952e3345a340cd4c7a1b38f08530277f89 Mon Sep 17 00:00:00 2001 From: Jeroen van Rijn Date: Sat, 24 Jul 2021 09:59:38 +0200 Subject: [PATCH] big: Correct `pow` bugs from the original. --- core/math/big/common.odin | 3 ++ core/math/big/example.odin | 13 +++-- core/math/big/log.odin | 102 +++++++++++++++++++++++++++---------- 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/core/math/big/common.odin b/core/math/big/common.odin index fcfdb3973..dd17a678f 100644 --- a/core/math/big/common.odin +++ b/core/math/big/common.odin @@ -60,6 +60,9 @@ Error :: enum byte { Buffer_Overflow = 6, Integer_Overflow = 7, + Division_by_Zero = 8, + Math_Domain_Error = 9, + Unimplemented = 127, }; diff --git a/core/math/big/example.odin b/core/math/big/example.odin index a9dde60f2..a6ba77667 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -57,15 +57,14 @@ demo :: proc() { a, b, c := &Int{}, &Int{}, &Int{}; defer destroy(a, b, c); - err = set(a, -1024); - err = set(b, -1024); + for base in -3..=3 { + for power in -3..=3 { + err = pow(a, base, power); + fmt.printf("err: %v | pow(%v, %v) = ", err, base, power); print("", a, 10); + } + } - print("a", a, 10); - print("b", b, 10); - fmt.println("--- mul ---"); - mul(c, a, a); - print("c", c, 10); } main :: proc() { diff --git a/core/math/big/log.odin b/core/math/big/log.odin index 14b4c593b..632973ec8 100644 --- a/core/math/big/log.odin +++ b/core/math/big/log.odin @@ -51,43 +51,93 @@ log :: proc { int_log, int_log_digit, }; Calculate c = a**b using a square-multiply algorithm. */ int_pow :: proc(dest, base: ^Int, power: int) -> (err: Error) { - if err = clear_if_uninitialized(dest); err != .None { return err; } + power := power; if err = clear_if_uninitialized(base); err != .None { return err; } + if err = clear_if_uninitialized(dest); err != .None { return err; } + /* + Early outs. + */ + if z, _ := is_zero(base); z { + /* + A zero base is a special case. + */ + if power < 0 { + if err = zero(dest); err != .None { return err; } + return .Math_Domain_Error; + } + if power == 0 { return one(dest); } + if power > 0 { return zero(dest); } -// if ((err = mp_init_copy(&g, a)) != MP_OKAY) { -// return err; -// } + } + if power < 0 { + /* + Fraction, so we'll return zero. + */ + return zero(dest); + } + switch(power) { + case 0: + /* + Any base to the power zero is one. + */ + return one(dest); + case 1: + /* + Any base to the power one is itself. + */ + return copy(dest, base); + case 2: + return sqr(dest, base); + } -// /* set initial result */ -// mp_set(c, 1uL); + g := &Int{}; + if err = copy(g, base); err != .None { return err; } -// while (b > 0) { -// /* if the bit is set multiply */ -// if ((b & 1) != 0) { -// if ((err = mp_mul(c, &g, c)) != MP_OKAY) { -// goto LBL_ERR; -// } -// } + /* + Set initial result. + */ + if err = set(dest, 1); err != .None { return err; } -// /* square */ -// if (b > 1) { -// if ((err = mp_sqr(&g, &g)) != MP_OKAY) { -// goto LBL_ERR; -// } -// } + loop: for power > 0 { + /* + If the bit is set, multiply. + */ + if power & 1 != 0 { + if err = mul(dest, g, dest); err != .None { + break loop; + } + } + /* + Square. + */ + if power > 1 { + if err = sqr(g, g); err != .None { + break loop; + } + } -// /* shift to next bit */ -// b >>= 1; -// } + /* shift to next bit */ + power >>= 1; + } -// LBL_ERR: -// mp_clear(&g); -// return err; + destroy(g); return err; } +/* + Calculate c = a**b. +*/ +int_pow_int :: proc(dest: ^Int, base, power: int) -> (err: Error) { + base_t := &Int{}; + defer destroy(base_t); -pow :: proc { int_pow, }; + if err = set(base_t, base); err != .None { return err; } + + return int_pow(dest, base_t, power); +} + +pow :: proc { int_pow, int_pow_int, }; +exp :: pow; /* Returns the log2 of an `Int`, provided `base` is a power of two.