big: Correct pow bugs from the original.

This commit is contained in:
Jeroen van Rijn
2021-07-24 09:59:38 +02:00
parent d953e40fb3
commit 5f63e3952e
3 changed files with 85 additions and 33 deletions
+3
View File
@@ -60,6 +60,9 @@ Error :: enum byte {
Buffer_Overflow = 6,
Integer_Overflow = 7,
Division_by_Zero = 8,
Math_Domain_Error = 9,
Unimplemented = 127,
};
+6 -7
View File
@@ -57,15 +57,14 @@ demo :: proc() {
a, b, c := &Int{}, &Int{}, &Int{};
defer destroy(a, b, c);
err = set(a, -1024);
err = set(b, -1024);
for base in -3..=3 {
for power in -3..=3 {
err = pow(a, base, power);
fmt.printf("err: %v | pow(%v, %v) = ", err, base, power); print("", a, 10);
}
}
print("a", a, 10);
print("b", b, 10);
fmt.println("--- mul ---");
mul(c, a, a);
print("c", c, 10);
}
main :: proc() {
+76 -26
View File
@@ -51,43 +51,93 @@ log :: proc { int_log, int_log_digit, };
Calculate c = a**b using a square-multiply algorithm.
*/
int_pow :: proc(dest, base: ^Int, power: int) -> (err: Error) {
if err = clear_if_uninitialized(dest); err != .None { return err; }
power := power;
if err = clear_if_uninitialized(base); err != .None { return err; }
if err = clear_if_uninitialized(dest); err != .None { return err; }
/*
Early outs.
*/
if z, _ := is_zero(base); z {
/*
A zero base is a special case.
*/
if power < 0 {
if err = zero(dest); err != .None { return err; }
return .Math_Domain_Error;
}
if power == 0 { return one(dest); }
if power > 0 { return zero(dest); }
// if ((err = mp_init_copy(&g, a)) != MP_OKAY) {
// return err;
// }
}
if power < 0 {
/*
Fraction, so we'll return zero.
*/
return zero(dest);
}
switch(power) {
case 0:
/*
Any base to the power zero is one.
*/
return one(dest);
case 1:
/*
Any base to the power one is itself.
*/
return copy(dest, base);
case 2:
return sqr(dest, base);
}
// /* set initial result */
// mp_set(c, 1uL);
g := &Int{};
if err = copy(g, base); err != .None { return err; }
// while (b > 0) {
// /* if the bit is set multiply */
// if ((b & 1) != 0) {
// if ((err = mp_mul(c, &g, c)) != MP_OKAY) {
// goto LBL_ERR;
// }
// }
/*
Set initial result.
*/
if err = set(dest, 1); err != .None { return err; }
// /* square */
// if (b > 1) {
// if ((err = mp_sqr(&g, &g)) != MP_OKAY) {
// goto LBL_ERR;
// }
// }
loop: for power > 0 {
/*
If the bit is set, multiply.
*/
if power & 1 != 0 {
if err = mul(dest, g, dest); err != .None {
break loop;
}
}
/*
Square.
*/
if power > 1 {
if err = sqr(g, g); err != .None {
break loop;
}
}
// /* shift to next bit */
// b >>= 1;
// }
/* shift to next bit */
power >>= 1;
}
// LBL_ERR:
// mp_clear(&g);
// return err;
destroy(g);
return err;
}
/*
Calculate c = a**b.
*/
int_pow_int :: proc(dest: ^Int, base, power: int) -> (err: Error) {
base_t := &Int{};
defer destroy(base_t);
pow :: proc { int_pow, };
if err = set(base_t, base); err != .None { return err; }
return int_pow(dest, base_t, power);
}
pow :: proc { int_pow, int_pow_int, };
exp :: pow;
/*
Returns the log2 of an `Int`, provided `base` is a power of two.