big: Add root_n.

This commit is contained in:
Jeroen van Rijn
2021-07-27 20:33:32 +02:00
parent 2aae1016ab
commit 531c4936dd
3 changed files with 195 additions and 51 deletions
-44
View File
@@ -749,50 +749,6 @@ int_sqrmod :: proc(remainder, number, modulus: ^Int) -> (err: Error) {
}
sqrmod :: proc { int_sqrmod, };
/*
This function is less generic than `nth_root`, simpler and faster.
*/
int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(src); err != .None { return err; }
/* Must be positive. */
if src.sign == .Negative { return .Invalid_Argument; }
/* Easy out. If src is zero, so is dest. */
if z, _ := is_zero(src); z { return zero(dest); }
/* Set up temporaries. */
t1, t2 := &Int{}, &Int{};
defer destroy(t1, t2);
if err = copy(t1, src); err != .None { return err; }
if err = zero(t2); err != .None { return err; }
/* First approximation. Not very bad for large arguments. */
if err = shr_digit(t1, t1.used / 2); err != .None { return err; }
/* t1 > 0 */
if err = div(t2, src, t1); err != .None { return err; }
if err = add(t1, t1, t2); err != .None { return err; }
if err = shr(t1, t1, 1); err != .None { return err; }
/* And now t1 > sqrt(arg). */
for {
if err = div(t2, src, t1); err != .None { return err; }
if err = add(t1, t1, t2); err != .None { return err; }
if err = shr(t1, t1, 1); err != .None { return err; }
/* t1 >= sqrt(arg) >= t2 at this point */
cm, _ := cmp_mag(t1, t2);
if cm != 1 { break; }
}
swap(dest, t1);
return err;
}
sqrt :: proc { int_sqrt, };
/*
==========================
Low-level routines
+11 -7
View File
@@ -63,15 +63,19 @@ demo :: proc() {
// defer delete(string_buffer);
err = set (numerator, 1024);
err = int_sqrt(destination, numerator);
err = sqrt(destination, numerator);
fmt.printf("int_sqrt returned: %v\n", err);
print("destination", destination);
// print("source ", source);
// print("quotient ", quotient);
// print("remainder ", remainder);
print("numerator ", numerator);
// print("denominator", denominator);
print("num ", numerator);
print("sqrt(num)", destination);
fmt.println("\n\n");
err = root_n(destination, numerator, 2);
fmt.printf("root_n(2) returned: %v\n", err);
print("num ", numerator);
print("root_n(num)", destination);
}
main :: proc() {
@@ -210,6 +210,190 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) {
}
}
/*
This function is less generic than `root_n`, simpler and faster.
*/
int_sqrt :: proc(dest, src: ^Int) -> (err: Error) {
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(src); err != .None { return err; }
/* Must be positive. */
if src.sign == .Negative { return .Invalid_Argument; }
/* Easy out. If src is zero, so is dest. */
if z, _ := is_zero(src); z { return zero(dest); }
/* Set up temporaries. */
t1, t2 := &Int{}, &Int{};
defer destroy(t1, t2);
if err = copy(t1, src); err != .None { return err; }
if err = zero(t2); err != .None { return err; }
/* First approximation. Not very bad for large arguments. */
if err = shr_digit(t1, t1.used / 2); err != .None { return err; }
/* t1 > 0 */
if err = div(t2, src, t1); err != .None { return err; }
if err = add(t1, t1, t2); err != .None { return err; }
if err = shr(t1, t1, 1); err != .None { return err; }
/* And now t1 > sqrt(arg). */
for {
if err = div(t2, src, t1); err != .None { return err; }
if err = add(t1, t1, t2); err != .None { return err; }
if err = shr(t1, t1, 1); err != .None { return err; }
/* t1 >= sqrt(arg) >= t2 at this point */
cm, _ := cmp_mag(t1, t2);
if cm != 1 { break; }
}
swap(dest, t1);
return err;
}
sqrt :: proc { int_sqrt, };
/*
Find the nth root of an Integer.
Result found such that `(dest)**n <= src` and `(dest+1)**n > src`
This algorithm uses Newton's approximation `x[i+1] = x[i] - f(x[i])/f'(x[i])`,
which will find the root in `log(n)` time where each step involves a fair bit.
*/
int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) {
/* Fast path for n == 2 */
if n == 2 { return sqrt(dest, src); }
/* Initialize dest + src if needed. */
if err = clear_if_uninitialized(dest); err != .None { return err; }
if err = clear_if_uninitialized(src); err != .None { return err; }
if n < 0 || n > int(_DIGIT_MAX) {
return .Invalid_Argument;
}
neg: bool;
if n & 1 == 0 {
if neg, err = is_neg(src); neg || err != .None { return .Invalid_Argument; }
}
/* Set up temporaries. */
t1, t2, t3, a := &Int{}, &Int{}, &Int{}, &Int{};
defer destroy(t1, t2, t3);
/* If a is negative fudge the sign but keep track. */
a.sign = .Zero_or_Positive;
a.used = src.used;
a.digit = src.digit;
/*
If "n" is larger than INT_MAX it is also larger than
log_2(src) because the bit-length of the "src" is measured
with an int and hence the root is always < 2 (two).
*/
if n > max(int) / 2 {
err = set(dest, 1);
dest.sign = a.sign;
return err;
}
/* Compute seed: 2^(log_2(src)/n + 2) */
ilog2: int;
ilog2, err = count_bits(src);
/* "src" is smaller than max(int), we can cast safely. */
if ilog2 < n {
err = set(dest, 1);
dest.sign = a.sign;
return err;
}
ilog2 /= n;
if ilog2 == 0 {
err = set(dest, 1);
dest.sign = a.sign;
return err;
}
/* Start value must be larger than root. */
ilog2 += 2;
if err = power_of_two(t2, ilog2); err != .None { return err; }
c: int;
for {
/* t1 = t2 */
if err = copy(t1, t2); err != .None { return err; }
/* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */
/* t3 = t1**(b-1) */
if err = pow(t3, t1, n-1); err != .None { return err; }
/* numerator */
/* t2 = t1**b */
if err = mul(t2, t1, t3); err != .None { return err; }
/* t2 = t1**b - a */
if err = sub(t2, t2, a); err != .None { return err; }
/* denominator */
/* t3 = t1**(b-1) * b */
if err = mul(t3, t3, DIGIT(n)); err != .None { return err; }
/* t3 = (t1**b - a)/(b * t1**(b-1)) */
if err = div(t3, t2, t3); err != .None { return err; }
if err = sub(t2, t1, t3); err != .None { return err; }
/*
Number of rounds is at most log_2(root). If it is more it
got stuck, so break out of the loop and do the rest manually.
*/
if ilog2 -= 1; ilog2 == 0 {
break;
}
if c, err = cmp(t1, t2); c == 0 { break; }
}
/* Result can be off by a few so check. */
/* Loop beneath can overshoot by one if found root is smaller than actual root. */
for {
if err = pow(t2, t1, n); err != .None { return err; }
c, err = cmp(t2, a);
if c == 0 {
swap(dest, t1);
return .None;
} else if c == -1 {
if err = add(t1, t1, DIGIT(1)); err != .None { return err; }
} else {
break;
}
}
/* Correct overshoot from above or from recurrence. */
for {
if err = pow(t2, t1, n); err != .None { return err; }
c, err = cmp(t2, a);
if c == 1 {
if err = sub(t1, t1, DIGIT(1)); err != .None { return err; }
} else {
break;
}
}
/* Set the result. */
swap(dest, t1);
/* set the sign of the result */
dest.sign = src.sign;
return err;
}
root_n :: proc { int_root_n, };
/*
Internal implementation of log.
*/