Merge pull request #5249 from Kelimion/simd_prefix

Vectorize `strings.prefix_length`.
This commit is contained in:
Jeroen van Rijn
2025-05-31 20:42:15 +02:00
committed by GitHub
5 changed files with 296 additions and 18 deletions
+68
View File
@@ -405,6 +405,74 @@ memory_compare_zero :: proc "contextless" (a: rawptr, n: int) -> int #no_bounds_
return 0
}
// Returns the number of leading bytes that are equal in the two `n`-byte
// regions starting at `x` and `y`.
//
// Inputs:
// - x, y: Start of the two regions to compare.
// - n:    Number of bytes to compare in each region.
//
// Returns:
// - idx: Index of the first differing byte, or `n` if no difference is found.
//        Returns `n` when both pointers are equal, and `0` when exactly one of
//        them is `nil`.
memory_prefix_length :: proc "contextless" (x, y: rawptr, n: int) -> (idx: int) #no_bounds_check {
	switch {
	case x == y:   return n
	case x == nil: return 0
	case y == nil: return 0
	}
	a, b := cast([^]byte)x, cast([^]byte)y

	n := uint(n)
	i := uint(0) // Absolute index of the next byte to compare.
	m := uint(0) // Absolute index at which the current chunked loop must stop.

	when HAS_HARDWARE_SIMD {
		when ODIN_ARCH == .amd64 && intrinsics.has_target_feature("avx2") {
			m = n / 32 * 32
			for /**/; i < m; i += 32 {
				load_a := intrinsics.unaligned_load(cast(^#simd[32]u8)&a[i])
				load_b := intrinsics.unaligned_load(cast(^#simd[32]u8)&b[i])
				comparison := intrinsics.simd_lanes_ne(load_a, load_b)
				if intrinsics.simd_reduce_or(comparison) != 0 {
					// Keep the lane index where the bytes differ, 0xFF elsewhere;
					// the minimum is then the first differing lane.
					sentinel: #simd[32]u8 = u8(0xFF)
					indices := intrinsics.simd_indices(#simd[32]u8)
					index_select := intrinsics.simd_select(comparison, indices, sentinel)
					index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
					// The result type is `int`, so convert like the narrower paths do.
					return int(i + index_reduce)
				}
			}
		}
	}

	// `m` is an absolute end index, so it has to be offset by the bytes that
	// were already consumed; otherwise this loop is skipped whenever `i > 0`.
	m = i + (n-i) / 16 * 16
	for /**/; i < m; i += 16 {
		load_a := intrinsics.unaligned_load(cast(^#simd[16]u8)&a[i])
		load_b := intrinsics.unaligned_load(cast(^#simd[16]u8)&b[i])
		comparison := intrinsics.simd_lanes_ne(load_a, load_b)
		if intrinsics.simd_reduce_or(comparison) != 0 {
			sentinel: #simd[16]u8 = u8(0xFF)
			indices := intrinsics.simd_indices(#simd[16]u8)
			index_select := intrinsics.simd_select(comparison, indices, sentinel)
			index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
			return int(i + index_reduce)
		}
	}

	// 64-bit SIMD is faster than using a `uintptr` to detect a difference then
	// re-iterating with the byte-by-byte loop, at least on AMD64.
	m = i + (n-i) / 8 * 8
	for /**/; i < m; i += 8 {
		load_a := intrinsics.unaligned_load(cast(^#simd[8]u8)&a[i])
		load_b := intrinsics.unaligned_load(cast(^#simd[8]u8)&b[i])
		comparison := intrinsics.simd_lanes_ne(load_a, load_b)
		if intrinsics.simd_reduce_or(comparison) != 0 {
			sentinel: #simd[8]u8 = u8(0xFF)
			indices := intrinsics.simd_indices(#simd[8]u8)
			index_select := intrinsics.simd_select(comparison, indices, sentinel)
			index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
			return int(i + index_reduce)
		}
	}

	// Byte-by-byte tail for whatever the chunked loops left over.
	for /**/; i < n; i += 1 {
		if a[i] ~ b[i] != 0 {
			return int(i)
		}
	}
	return int(n)
}
string_eq :: proc "contextless" (lhs, rhs: string) -> bool {
x := transmute(Raw_String)lhs
y := transmute(Raw_String)rhs
+46 -17
View File
@@ -2,6 +2,7 @@
package strings
import "base:intrinsics"
import "base:runtime"
import "core:bytes"
import "core:io"
import "core:mem"
@@ -458,6 +459,7 @@ equal_fold :: proc(u, v: string) -> (res: bool) {
return s == t
}
/*
Returns the prefix length common between strings `a` and `b`
@@ -488,30 +490,57 @@ Output:
0
*/
prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
	RUNE_ERROR :: '\ufffd'
	RUNE_SELF  :: 0x80
	UTF_MAX    :: 4

	// Length of the common prefix in bytes; this may end inside a codepoint.
	n = runtime.memory_prefix_length(raw_data(a), raw_data(b), min(len(a), len(b)))

	// Nothing matched, or the last matched byte is ASCII: already on a rune boundary.
	if n == 0 || a[n - 1] < RUNE_SELF {
		return
	}

	// Walk back over at most UTF_MAX bytes to find where the last rune starts.
	lim := max(n - UTF_MAX, 0)
	for l := n; l > lim; l -= 1 {
		if a[l - 1] & 0xc0 == 0x80 {
			// Continuation byte, keep looking for the byte that starts the rune.
			continue
		}
		start := l - 1
		// Decode only the matched bytes, so a rune that is cut off by the end of
		// the common prefix is reported as an error instead of being completed
		// with bytes that `b` does not share.
		r, w := runtime.string_decode_rune(a[start:n])
		if r != RUNE_ERROR {
			// The last rune is complete, keep it.
			return start + w
		}
		// Partial codepoint, drop it.
		return start
	}
	return
}
/*
Returns the prefix that strings `a` and `b` have in common

Inputs:
- a: The first input string
- b: The second input string

Returns:
- The common prefix, as a slice of `a` whose length is `prefix_length(a, b)`

Example:

	import "core:fmt"
	import "core:strings"

	common_prefix_example :: proc() {
		fmt.println(strings.common_prefix("testing", "test"))
		fmt.println(strings.common_prefix("testing", "te"))
		fmt.println(strings.common_prefix("telephone", "te"))
	}

Output:

	test
	te
	te

*/
common_prefix :: proc(a, b: string) -> string {
	shared := prefix_length(a, b)
	return a[:shared]
}
/*
Determines if a string `s` starts with a given `prefix`
Inputs:
+1
View File
@@ -4,3 +4,4 @@ package benchmarks
@(require) import "crypto"
@(require) import "hash"
@(require) import "text/regex"
@(require) import "strings"
@@ -0,0 +1,131 @@
package benchmark_strings
import "base:intrinsics"
import "core:fmt"
import "core:log"
import "core:testing"
import "core:strings"
import "core:text/table"
import "core:time"
import "core:unicode/utf8"
RUNS_PER_SIZE :: 2500
sizes := [?]int {
7, 8, 9,
15, 16, 17,
31, 32, 33,
63, 64, 65,
95, 96, 97,
128,
256,
512,
1024,
4096,
}
// These are the normal, unoptimized algorithms.
// Scalar reference implementation used as the baseline in the benchmark.
// Returns the length in bytes of the common prefix of `a` and `b`, cut back so
// that it does not end in a partial codepoint.
plain_prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
	limit := min(len(a), len(b))

	// Count matching bytes; this may stop in the middle of a codepoint.
	matched := 0
	#no_bounds_check for matched < limit {
		if a[matched] != b[matched] {
			break
		}
		matched += 1
	}

	if matched == 0 {
		return
	}

	// Re-walk the matched bytes rune by rune and stop at the first decode error.
	shared := a[:matched]
	for {
		r, width := utf8.decode_rune(shared[n:])
		if r == utf8.RUNE_ERROR {
			break
		}
		n += width
	}
	return
}
// Times `runs` calls of `p` on two freshly allocated `size`-byte inputs.
//
// With an empty `suffix` the inputs differ at byte `idx`; otherwise `suffix` is
// copied into both inputs starting at `idx`. The results of `p` are summed and
// logged at debug level.
//
// Returns the stopwatch duration measured around the loop of calls.
run_trial_size_prefix :: proc(p: proc "contextless" (string, string) -> $R, suffix: string, size: int, idx: int, runs: int, loc := #caller_location) -> (timing: time.Duration) {
	lhs := make([]u8, size)
	rhs := make([]u8, size)
	// Deferred statements run in reverse order: `lhs` is released first.
	defer delete(rhs)
	defer delete(lhs)

	if len(suffix) == 0 {
		rhs[idx] = 'A'
	} else {
		copy(lhs[idx:], suffix)
		copy(rhs[idx:], suffix)
	}

	total: int
	sw: time.Stopwatch

	time.stopwatch_start(&sw)
	for run := 0; run < runs; run += 1 {
		total += p(string(lhs[:size]), string(rhs[:size]))
	}
	time.stopwatch_stop(&sw)

	timing = time.stopwatch_duration(sw)
	log.debug(total)
	return
}
// Procedure group for the benchmark trial runners. It currently has a single
// member, `run_trial_size_prefix`.
run_trial_size :: proc {
run_trial_size_prefix,
}
// Runs `plain` and `simd` through `run_trial_size` for every entry of `sizes`
// and logs the timings, plus the SIMD/scalar ratio, as a plain-text table.
//
// Inputs:
// - algo_name:   Label written in the first column of every row.
// - plain, simd: The two procedures to compare.
// - suffix:      Forwarded to `run_trial_size`.
bench_table_size :: proc(algo_name: string, plain, simd: $P, suffix := "") {
	out := strings.builder_make()
	defer strings.builder_destroy(&out)

	results: table.Table
	table.init(&results)
	defer table.destroy(&results)

	table.aligned_header_of_values(&results, .Right, "Algorithm", "Size", "Iterations", "Scalar", "SIMD", "SIMD Relative (%)", "SIMD Relative (x)")

	for size in sizes {
		// The index handed to the trial is the middle of the input.
		midpoint := size / 2

		scalar_time := run_trial_size(plain, suffix, size, midpoint, RUNS_PER_SIZE)
		vector_time := run_trial_size(simd,  suffix, size, midpoint, RUNS_PER_SIZE)

		ratio := f64(vector_time) / f64(scalar_time)

		table.aligned_row_of_values(
			&results,
			.Right,
			algo_name,
			size,
			RUNS_PER_SIZE,
			fmt.tprintf("%8M", scalar_time),
			fmt.tprintf("%8M", vector_time),
			fmt.tprintf("%.3f %%", ratio * 100.0),
			fmt.tprintf("%.3f x", 1 / ratio),
		)
	}

	w := strings.to_writer(&out)

	// Start the logged text with an empty line, then append the table.
	fmt.sbprintln(&out)
	table.write_plain_table(w, &results)

	log.info(strings.to_string(out))
}
// Benchmarks `strings.prefix_length` against the scalar `plain_prefix_length`,
// once without a suffix and once with a multi-byte suffix.
@test
benchmark_memory_procs :: proc(t: ^testing.T) {
	Variant :: struct {
		label:  string,
		suffix: string,
	}

	variants := [?]Variant{
		{"prefix_length ascii",   ""},
		{"prefix_length unicode", "🦉"},
	}

	for v in variants {
		bench_table_size(v.label, plain_prefix_length, strings.prefix_length, v.suffix)
	}
}
+50 -1
View File
@@ -1,9 +1,10 @@
package test_core_strings
import "base:runtime"
import "core:mem"
import "core:strings"
import "core:testing"
import "base:runtime"
import "core:unicode/utf8"
@test
test_index_any_small_string_not_found :: proc(t: ^testing.T) {
@@ -218,3 +219,51 @@ test_builder_to_cstring :: proc(t: ^testing.T) {
testing.expect(t, err == .Out_Of_Memory)
}
}
// Checks `strings.prefix_length` against a local scalar reference
// implementation, for each pair in `cases` and for every truncation of both
// strings of the pair.
@test
test_prefix_length :: proc(t: ^testing.T) {
	// Scalar reference: count matching bytes, then re-walk them rune by rune so
	// the result does not end in a partial codepoint.
	prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
		_len := min(len(a), len(b))

		// Scan for matches including partial codepoints.
		#no_bounds_check for n < _len && a[n] == b[n] {
			n += 1
		}

		// Now scan to ignore partial codepoints.
		if n > 0 {
			s := a[:n]
			n = 0
			for {
				r0, w := utf8.decode_rune(s[n:])
				if r0 != utf8.RUNE_ERROR {
					n += w
				} else {
					break
				}
			}
		}
		return
	}

	cases := [][2]string{
		{"Hellope, there!", "Hellope, world!"},
		{"Hellope, there!", "Foozle"},
		{"Hellope, there!", "Hell"},
		{"Hellope! 🦉", "Hellope! 🦉"},
		// Longer than 32 bytes, so the 16- and 32-byte wide comparison loops run too.
		{"The quick brown fox jumps over the lazy dog, twice over.", "The quick brown fox jumps over the lazy cat, twice over."},
		{"Hellope, this string is longer than thirty-two bytes! 🦉 hoot", "Hellope, this string is longer than thirty-two bytes! 🦉 toot"},
	}

	for v in cases {
		p_scalar := prefix_length(v[0], v[1])
		p_simd   := strings.prefix_length(v[0], v[1])
		testing.expect_value(t, p_simd, p_scalar)

		// Compare v[0] against each of its own prefixes. This exercises the
		// handling of partial codepoints at the end of the shorter string.
		s := v[0]
		for len(s) > 0 {
			p_scalar = prefix_length(v[0], s)
			p_simd   = strings.prefix_length(v[0], s)
			testing.expect_value(t, p_simd, p_scalar)
			s = s[:len(s) - 1]
		}

		// Compare v[0] against each prefix of v[1]. A slice of v[0] starts at the
		// same address as v[0] itself, so the loop above only reaches the
		// equal-pointer shortcut of the byte comparison; v[1] is presumably
		// stored separately whenever its contents differ.
		s = v[1]
		for len(s) > 0 {
			p_scalar = prefix_length(v[0], s)
			p_simd   = strings.prefix_length(v[0], s)
			testing.expect_value(t, p_simd, p_scalar)
			s = s[:len(s) - 1]
		}
	}
}