diff --git a/base/runtime/internal.odin b/base/runtime/internal.odin index a35dbff8a..3faf3e178 100644 --- a/base/runtime/internal.odin +++ b/base/runtime/internal.odin @@ -405,6 +405,74 @@ memory_compare_zero :: proc "contextless" (a: rawptr, n: int) -> int #no_bounds_ return 0 } +memory_prefix_length :: proc "contextless" (x, y: rawptr, n: int) -> (idx: int) #no_bounds_check { + switch { + case x == y: return n + case x == nil: return 0 + case y == nil: return 0 + } + a, b := cast([^]byte)x, cast([^]byte)y + + n := uint(n) + i := uint(0) + m := uint(0) + + when HAS_HARDWARE_SIMD { + when ODIN_ARCH == .amd64 && intrinsics.has_target_feature("avx2") { + m = n / 32 * 32 + for /**/; i < m; i += 32 { + load_a := intrinsics.unaligned_load(cast(^#simd[32]u8)&a[i]) + load_b := intrinsics.unaligned_load(cast(^#simd[32]u8)&b[i]) + comparison := intrinsics.simd_lanes_ne(load_a, load_b) + if intrinsics.simd_reduce_or(comparison) != 0 { + sentinel: #simd[32]u8 = u8(0xFF) + indices := intrinsics.simd_indices(#simd[32]u8) + index_select := intrinsics.simd_select(comparison, indices, sentinel) + index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select) + return i + index_reduce + } + } + } + } + + m = (n-i) / 16 * 16 + for /**/; i < m; i += 16 { + load_a := intrinsics.unaligned_load(cast(^#simd[16]u8)&a[i]) + load_b := intrinsics.unaligned_load(cast(^#simd[16]u8)&b[i]) + comparison := intrinsics.simd_lanes_ne(load_a, load_b) + if intrinsics.simd_reduce_or(comparison) != 0 { + sentinel: #simd[16]u8 = u8(0xFF) + indices := intrinsics.simd_indices(#simd[16]u8) + index_select := intrinsics.simd_select(comparison, indices, sentinel) + index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select) + return int(i + index_reduce) + } + } + + // 64-bit SIMD is faster than using a `uintptr` to detect a difference then + // re-iterating with the byte-by-byte loop, at least on AMD64. + m = (n-i) / 8 * 8 + for /**/; i < m; i += 8 { + load_a := intrinsics.unaligned_load(cast(^#simd[8]u8)&a[i]) + load_b := intrinsics.unaligned_load(cast(^#simd[8]u8)&b[i]) + comparison := intrinsics.simd_lanes_ne(load_a, load_b) + if intrinsics.simd_reduce_or(comparison) != 0 { + sentinel: #simd[8]u8 = u8(0xFF) + indices := intrinsics.simd_indices(#simd[8]u8) + index_select := intrinsics.simd_select(comparison, indices, sentinel) + index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select) + return int(i + index_reduce) + } + } + + for /**/; i < n; i += 1 { + if a[i] ~ b[i] != 0 { + return int(i) + } + } + return int(n) +} + string_eq :: proc "contextless" (lhs, rhs: string) -> bool { x := transmute(Raw_String)lhs y := transmute(Raw_String)rhs diff --git a/core/strings/strings.odin b/core/strings/strings.odin index e15754dff..e45b177d7 100644 --- a/core/strings/strings.odin +++ b/core/strings/strings.odin @@ -2,6 +2,7 @@ package strings import "base:intrinsics" +import "base:runtime" import "core:bytes" import "core:io" import "core:mem" @@ -458,6 +459,7 @@ equal_fold :: proc(u, v: string) -> (res: bool) { return s == t } + /* Returns the prefix length common between strings `a` and `b` @@ -488,30 +490,57 @@ Output: 0 */ -prefix_length :: proc(a, b: string) -> (n: int) { - _len := min(len(a), len(b)) +prefix_length :: proc "contextless" (a, b: string) -> (n: int) { + RUNE_ERROR :: '\ufffd' + RUNE_SELF :: 0x80 + UTF_MAX :: 4 - // Scan for matches including partial codepoints. - #no_bounds_check for n < _len && a[n] == b[n] { - n += 1 - } - - // Now scan to ignore partial codepoints. - if n > 0 { - s := a[:n] - n = 0 - for { - r0, w := utf8.decode_rune(s[n:]) - if r0 != utf8.RUNE_ERROR { - n += w - } else { - break + n = runtime.memory_prefix_length(raw_data(a), raw_data(b), min(len(a), len(b))) + lim := max(n - UTF_MAX + 1, 0) + for l := n; l > lim; l -= 1 { + r, _ := runtime.string_decode_rune(a[l - 1:]) + if r != RUNE_ERROR { + if l > 0 && (a[l - 1] & 0xc0 == 0xc0) { + return l - 1 } + return l } } return } /* +Returns the common prefix between strings `a` and `b` + +Inputs: +- a: The first input string +- b: The second input string + +Returns: +- n: The string prefix common between strings `a` and `b` + +Example: + + import "core:fmt" + import "core:strings" + + common_prefix_example :: proc() { + fmt.println(strings.common_prefix("testing", "test")) + fmt.println(strings.common_prefix("testing", "te")) + fmt.println(strings.common_prefix("telephone", "te")) + } + +Output: + + test + te + te + + +*/ +common_prefix :: proc(a, b: string) -> string { + return a[:prefix_length(a, b)] +} +/* Determines if a string `s` starts with a given `prefix` Inputs: diff --git a/tests/benchmark/all.odin b/tests/benchmark/all.odin index a48872cc6..30640ac87 100644 --- a/tests/benchmark/all.odin +++ b/tests/benchmark/all.odin @@ -4,3 +4,4 @@ package benchmarks @(require) import "crypto" @(require) import "hash" @(require) import "text/regex" +@(require) import "strings" \ No newline at end of file diff --git a/tests/benchmark/strings/benchmark_strings.odin b/tests/benchmark/strings/benchmark_strings.odin new file mode 100644 index 000000000..866e8f756 --- /dev/null +++ b/tests/benchmark/strings/benchmark_strings.odin @@ -0,0 +1,131 @@ +package benchmark_strings + +import "base:intrinsics" +import "core:fmt" +import "core:log" +import "core:testing" +import "core:strings" +import "core:text/table" +import "core:time" +import "core:unicode/utf8" + +RUNS_PER_SIZE :: 2500 + +sizes := [?]int { + 7, 8, 9, + 15, 16, 17, + 31, 32, 33, + 63, 64, 65, + 95, 96, 97, + 128, + 256, + 512, + 1024, + 4096, +} + +// These are the normal, unoptimized algorithms. + +plain_prefix_length :: proc "contextless" (a, b: string) -> (n: int) { + _len := min(len(a), len(b)) + + // Scan for matches including partial codepoints. + #no_bounds_check for n < _len && a[n] == b[n] { + n += 1 + } + + // Now scan to ignore partial codepoints. + if n > 0 { + s := a[:n] + n = 0 + for { + r0, w := utf8.decode_rune(s[n:]) + if r0 != utf8.RUNE_ERROR { + n += w + } else { + break + } + } + } + return +} + +run_trial_size_prefix :: proc(p: proc "contextless" (string, string) -> $R, suffix: string, size: int, idx: int, runs: int, loc := #caller_location) -> (timing: time.Duration) { + left := make([]u8, size) + right := make([]u8, size) + defer { + delete(left) + delete(right) + } + + if len(suffix) > 0 { + copy(left [idx:], suffix) + copy(right[idx:], suffix) + + } else { + right[idx] = 'A' + } + + accumulator: int + + watch: time.Stopwatch + + time.stopwatch_start(&watch) + for _ in 0.. (n: int) { + _len := min(len(a), len(b)) + + // Scan for matches including partial codepoints. + #no_bounds_check for n < _len && a[n] == b[n] { + n += 1 + } + + // Now scan to ignore partial codepoints. + if n > 0 { + s := a[:n] + n = 0 + for { + r0, w := utf8.decode_rune(s[n:]) + if r0 != utf8.RUNE_ERROR { + n += w + } else { + break + } + } + } + return + } + + cases := [][2]string{ + {"Hellope, there!", "Hellope, world!"}, + {"Hellope, there!", "Foozle"}, + {"Hellope, there!", "Hell"}, + {"Hellope! 🦉", "Hellope! 🦉"}, + } + + for v in cases { + p_scalar := prefix_length(v[0], v[1]) + p_simd := strings.prefix_length(v[0], v[1]) + testing.expect_value(t, p_simd, p_scalar) + + s := v[0] + for len(s) > 0 { + p_scalar = prefix_length(v[0], s) + p_simd = strings.prefix_length(v[0], s) + testing.expect_value(t, p_simd, p_scalar) + s = s[:len(s) - 1] + } + } +} \ No newline at end of file