From 531c4936dd2ccf026748cc390094c9b51691a534 Mon Sep 17 00:00:00 2001 From: Jeroen van Rijn Date: Tue, 27 Jul 2021 20:33:32 +0200 Subject: [PATCH] big: Add `root_n`. --- core/math/big/basic.odin | 44 ------ core/math/big/example.odin | 18 ++- core/math/big/{log.odin => exp_log.odin} | 184 +++++++++++++++++++++++ 3 files changed, 195 insertions(+), 51 deletions(-) rename core/math/big/{log.odin => exp_log.odin} (54%) diff --git a/core/math/big/basic.odin b/core/math/big/basic.odin index 567bc05d8..ef069de17 100644 --- a/core/math/big/basic.odin +++ b/core/math/big/basic.odin @@ -749,50 +749,6 @@ int_sqrmod :: proc(remainder, number, modulus: ^Int) -> (err: Error) { } sqrmod :: proc { int_sqrmod, }; -/* - This function is less generic than `nth_root`, simpler and faster. -*/ -int_sqrt :: proc(dest, src: ^Int) -> (err: Error) { - if err = clear_if_uninitialized(dest); err != .None { return err; } - if err = clear_if_uninitialized(src); err != .None { return err; } - - /* Must be positive. */ - if src.sign == .Negative { return .Invalid_Argument; } - - /* Easy out. If src is zero, so is dest. */ - if z, _ := is_zero(src); z { return zero(dest); } - - /* Set up temporaries. */ - t1, t2 := &Int{}, &Int{}; - defer destroy(t1, t2); - - if err = copy(t1, src); err != .None { return err; } - if err = zero(t2); err != .None { return err; } - - /* First approximation. Not very bad for large arguments. */ - if err = shr_digit(t1, t1.used / 2); err != .None { return err; } - /* t1 > 0 */ - if err = div(t2, src, t1); err != .None { return err; } - if err = add(t1, t1, t2); err != .None { return err; } - if err = shr(t1, t1, 1); err != .None { return err; } - - /* And now t1 > sqrt(arg). */ - for { - if err = div(t2, src, t1); err != .None { return err; } - if err = add(t1, t1, t2); err != .None { return err; } - if err = shr(t1, t1, 1); err != .None { return err; } - /* t1 >= sqrt(arg) >= t2 at this point */ - - cm, _ := cmp_mag(t1, t2); - if cm != 1 { break; } - } - - swap(dest, t1); - return err; -} - -sqrt :: proc { int_sqrt, }; - /* ========================== Low-level routines diff --git a/core/math/big/example.odin b/core/math/big/example.odin index 7f64967de..7b0df89ae 100644 --- a/core/math/big/example.odin +++ b/core/math/big/example.odin @@ -63,15 +63,19 @@ demo :: proc() { // defer delete(string_buffer); err = set (numerator, 1024); - err = int_sqrt(destination, numerator); + err = sqrt(destination, numerator); fmt.printf("int_sqrt returned: %v\n", err); - print("destination", destination); - // print("source ", source); - // print("quotient ", quotient); - // print("remainder ", remainder); - print("numerator ", numerator); - // print("denominator", denominator); + print("num ", numerator); + print("sqrt(num)", destination); + + fmt.println("\n\n"); + + err = root_n(destination, numerator, 2); + fmt.printf("root_n(2) returned: %v\n", err); + + print("num ", numerator); + print("root_n(num)", destination); } main :: proc() { diff --git a/core/math/big/log.odin b/core/math/big/exp_log.odin similarity index 54% rename from core/math/big/log.odin rename to core/math/big/exp_log.odin index 777cc1ea2..2ccfa7146 100644 --- a/core/math/big/log.odin +++ b/core/math/big/exp_log.odin @@ -210,6 +210,190 @@ int_log_digit :: proc(a: DIGIT, base: DIGIT) -> (log: int, err: Error) { } } +/* + This function is less generic than `root_n`, simpler and faster. +*/ +int_sqrt :: proc(dest, src: ^Int) -> (err: Error) { + if err = clear_if_uninitialized(dest); err != .None { return err; } + if err = clear_if_uninitialized(src); err != .None { return err; } + + /* Must be positive. */ + if src.sign == .Negative { return .Invalid_Argument; } + + /* Easy out. If src is zero, so is dest. */ + if z, _ := is_zero(src); z { return zero(dest); } + + /* Set up temporaries. */ + t1, t2 := &Int{}, &Int{}; + defer destroy(t1, t2); + + if err = copy(t1, src); err != .None { return err; } + if err = zero(t2); err != .None { return err; } + + /* First approximation. Not very bad for large arguments. */ + if err = shr_digit(t1, t1.used / 2); err != .None { return err; } + /* t1 > 0 */ + if err = div(t2, src, t1); err != .None { return err; } + if err = add(t1, t1, t2); err != .None { return err; } + if err = shr(t1, t1, 1); err != .None { return err; } + + /* And now t1 > sqrt(arg). */ + for { + if err = div(t2, src, t1); err != .None { return err; } + if err = add(t1, t1, t2); err != .None { return err; } + if err = shr(t1, t1, 1); err != .None { return err; } + /* t1 >= sqrt(arg) >= t2 at this point */ + + cm, _ := cmp_mag(t1, t2); + if cm != 1 { break; } + } + + swap(dest, t1); + return err; +} +sqrt :: proc { int_sqrt, }; + + +/* + Find the nth root of an Integer. + Result found such that `(dest)**n <= src` and `(dest+1)**n > src` + + This algorithm uses Newton's approximation `x[i+1] = x[i] - f(x[i])/f'(x[i])`, + which will find the root in `log(n)` time where each step involves a fair bit. +*/ +int_root_n :: proc(dest, src: ^Int, n: int) -> (err: Error) { + /* Fast path for n == 2 */ + if n == 2 { return sqrt(dest, src); } + + /* Initialize dest + src if needed. */ + if err = clear_if_uninitialized(dest); err != .None { return err; } + if err = clear_if_uninitialized(src); err != .None { return err; } + + if n < 0 || n > int(_DIGIT_MAX) { + return .Invalid_Argument; + } + + neg: bool; + if n & 1 == 0 { + if neg, err = is_neg(src); neg || err != .None { return .Invalid_Argument; } + } + + /* Set up temporaries. */ + t1, t2, t3, a := &Int{}, &Int{}, &Int{}, &Int{}; + defer destroy(t1, t2, t3); + + /* If a is negative fudge the sign but keep track. */ + a.sign = .Zero_or_Positive; + a.used = src.used; + a.digit = src.digit; + + /* + If "n" is larger than INT_MAX it is also larger than + log_2(src) because the bit-length of the "src" is measured + with an int and hence the root is always < 2 (two). + */ + if n > max(int) / 2 { + err = set(dest, 1); + dest.sign = a.sign; + return err; + } + + /* Compute seed: 2^(log_2(src)/n + 2) */ + ilog2: int; + ilog2, err = count_bits(src); + + /* "src" is smaller than max(int), we can cast safely. */ + if ilog2 < n { + err = set(dest, 1); + dest.sign = a.sign; + return err; + } + + ilog2 /= n; + if ilog2 == 0 { + err = set(dest, 1); + dest.sign = a.sign; + return err; + } + + /* Start value must be larger than root. */ + ilog2 += 2; + if err = power_of_two(t2, ilog2); err != .None { return err; } + + c: int; + for { + /* t1 = t2 */ + if err = copy(t1, t2); err != .None { return err; } + + /* t2 = t1 - ((t1**b - a) / (b * t1**(b-1))) */ + + /* t3 = t1**(b-1) */ + if err = pow(t3, t1, n-1); err != .None { return err; } + + /* numerator */ + /* t2 = t1**b */ + if err = mul(t2, t1, t3); err != .None { return err; } + + /* t2 = t1**b - a */ + if err = sub(t2, t2, a); err != .None { return err; } + + /* denominator */ + /* t3 = t1**(b-1) * b */ + if err = mul(t3, t3, DIGIT(n)); err != .None { return err; } + + /* t3 = (t1**b - a)/(b * t1**(b-1)) */ + if err = div(t3, t2, t3); err != .None { return err; } + if err = sub(t2, t1, t3); err != .None { return err; } + + /* + Number of rounds is at most log_2(root). If it is more it + got stuck, so break out of the loop and do the rest manually. + */ + if ilog2 -= 1; ilog2 == 0 { + break; + } + if c, err = cmp(t1, t2); c == 0 { break; } + } + + /* Result can be off by a few so check. */ + /* Loop beneath can overshoot by one if found root is smaller than actual root. */ + + for { + if err = pow(t2, t1, n); err != .None { return err; } + + c, err = cmp(t2, a); + if c == 0 { + swap(dest, t1); + return .None; + } else if c == -1 { + if err = add(t1, t1, DIGIT(1)); err != .None { return err; } + } else { + break; + } + } + + /* Correct overshoot from above or from recurrence. */ + for { + if err = pow(t2, t1, n); err != .None { return err; } + + c, err = cmp(t2, a); + if c == 1 { + if err = sub(t1, t1, DIGIT(1)); err != .None { return err; } + } else { + break; + } + } + + /* Set the result. */ + swap(dest, t1); + + /* set the sign of the result */ + dest.sign = src.sign; + + return err; +} +root_n :: proc { int_root_n, }; + /* Internal implementation of log. */