mirror of
https://github.com/Ed94/Odin.git
synced 2026-06-18 03:42:23 -07:00
Merge pull request #5249 from Kelimion/simd_prefix
Vectorize `strings.prefix_length`.
This commit is contained in:
@@ -405,6 +405,74 @@ memory_compare_zero :: proc "contextless" (a: rawptr, n: int) -> int #no_bounds_
|
||||
return 0
|
||||
}
|
||||
|
||||
memory_prefix_length :: proc "contextless" (x, y: rawptr, n: int) -> (idx: int) #no_bounds_check {
|
||||
switch {
|
||||
case x == y: return n
|
||||
case x == nil: return 0
|
||||
case y == nil: return 0
|
||||
}
|
||||
a, b := cast([^]byte)x, cast([^]byte)y
|
||||
|
||||
n := uint(n)
|
||||
i := uint(0)
|
||||
m := uint(0)
|
||||
|
||||
when HAS_HARDWARE_SIMD {
|
||||
when ODIN_ARCH == .amd64 && intrinsics.has_target_feature("avx2") {
|
||||
m = n / 32 * 32
|
||||
for /**/; i < m; i += 32 {
|
||||
load_a := intrinsics.unaligned_load(cast(^#simd[32]u8)&a[i])
|
||||
load_b := intrinsics.unaligned_load(cast(^#simd[32]u8)&b[i])
|
||||
comparison := intrinsics.simd_lanes_ne(load_a, load_b)
|
||||
if intrinsics.simd_reduce_or(comparison) != 0 {
|
||||
sentinel: #simd[32]u8 = u8(0xFF)
|
||||
indices := intrinsics.simd_indices(#simd[32]u8)
|
||||
index_select := intrinsics.simd_select(comparison, indices, sentinel)
|
||||
index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
|
||||
return i + index_reduce
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m = (n-i) / 16 * 16
|
||||
for /**/; i < m; i += 16 {
|
||||
load_a := intrinsics.unaligned_load(cast(^#simd[16]u8)&a[i])
|
||||
load_b := intrinsics.unaligned_load(cast(^#simd[16]u8)&b[i])
|
||||
comparison := intrinsics.simd_lanes_ne(load_a, load_b)
|
||||
if intrinsics.simd_reduce_or(comparison) != 0 {
|
||||
sentinel: #simd[16]u8 = u8(0xFF)
|
||||
indices := intrinsics.simd_indices(#simd[16]u8)
|
||||
index_select := intrinsics.simd_select(comparison, indices, sentinel)
|
||||
index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
|
||||
return int(i + index_reduce)
|
||||
}
|
||||
}
|
||||
|
||||
// 64-bit SIMD is faster than using a `uintptr` to detect a difference then
|
||||
// re-iterating with the byte-by-byte loop, at least on AMD64.
|
||||
m = (n-i) / 8 * 8
|
||||
for /**/; i < m; i += 8 {
|
||||
load_a := intrinsics.unaligned_load(cast(^#simd[8]u8)&a[i])
|
||||
load_b := intrinsics.unaligned_load(cast(^#simd[8]u8)&b[i])
|
||||
comparison := intrinsics.simd_lanes_ne(load_a, load_b)
|
||||
if intrinsics.simd_reduce_or(comparison) != 0 {
|
||||
sentinel: #simd[8]u8 = u8(0xFF)
|
||||
indices := intrinsics.simd_indices(#simd[8]u8)
|
||||
index_select := intrinsics.simd_select(comparison, indices, sentinel)
|
||||
index_reduce := cast(uint)intrinsics.simd_reduce_min(index_select)
|
||||
return int(i + index_reduce)
|
||||
}
|
||||
}
|
||||
|
||||
for /**/; i < n; i += 1 {
|
||||
if a[i] ~ b[i] != 0 {
|
||||
return int(i)
|
||||
}
|
||||
}
|
||||
return int(n)
|
||||
}
|
||||
|
||||
string_eq :: proc "contextless" (lhs, rhs: string) -> bool {
|
||||
x := transmute(Raw_String)lhs
|
||||
y := transmute(Raw_String)rhs
|
||||
|
||||
+46
-17
@@ -2,6 +2,7 @@
|
||||
package strings
|
||||
|
||||
import "base:intrinsics"
|
||||
import "base:runtime"
|
||||
import "core:bytes"
|
||||
import "core:io"
|
||||
import "core:mem"
|
||||
@@ -458,6 +459,7 @@ equal_fold :: proc(u, v: string) -> (res: bool) {
|
||||
|
||||
return s == t
|
||||
}
|
||||
|
||||
/*
|
||||
Returns the prefix length common between strings `a` and `b`
|
||||
|
||||
@@ -488,30 +490,57 @@ Output:
|
||||
0
|
||||
|
||||
*/
|
||||
prefix_length :: proc(a, b: string) -> (n: int) {
|
||||
_len := min(len(a), len(b))
|
||||
prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
|
||||
RUNE_ERROR :: '\ufffd'
|
||||
RUNE_SELF :: 0x80
|
||||
UTF_MAX :: 4
|
||||
|
||||
// Scan for matches including partial codepoints.
|
||||
#no_bounds_check for n < _len && a[n] == b[n] {
|
||||
n += 1
|
||||
}
|
||||
|
||||
// Now scan to ignore partial codepoints.
|
||||
if n > 0 {
|
||||
s := a[:n]
|
||||
n = 0
|
||||
for {
|
||||
r0, w := utf8.decode_rune(s[n:])
|
||||
if r0 != utf8.RUNE_ERROR {
|
||||
n += w
|
||||
} else {
|
||||
break
|
||||
n = runtime.memory_prefix_length(raw_data(a), raw_data(b), min(len(a), len(b)))
|
||||
lim := max(n - UTF_MAX + 1, 0)
|
||||
for l := n; l > lim; l -= 1 {
|
||||
r, _ := runtime.string_decode_rune(a[l - 1:])
|
||||
if r != RUNE_ERROR {
|
||||
if l > 0 && (a[l - 1] & 0xc0 == 0xc0) {
|
||||
return l - 1
|
||||
}
|
||||
return l
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
/*
|
||||
Returns the common prefix between strings `a` and `b`
|
||||
|
||||
Inputs:
|
||||
- a: The first input string
|
||||
- b: The second input string
|
||||
|
||||
Returns:
|
||||
- n: The string prefix common between strings `a` and `b`
|
||||
|
||||
Example:
|
||||
|
||||
import "core:fmt"
|
||||
import "core:strings"
|
||||
|
||||
common_prefix_example :: proc() {
|
||||
fmt.println(strings.common_prefix("testing", "test"))
|
||||
fmt.println(strings.common_prefix("testing", "te"))
|
||||
fmt.println(strings.common_prefix("telephone", "te"))
|
||||
}
|
||||
|
||||
Output:
|
||||
|
||||
test
|
||||
te
|
||||
te
|
||||
|
||||
|
||||
*/
|
||||
common_prefix :: proc(a, b: string) -> string {
|
||||
return a[:prefix_length(a, b)]
|
||||
}
|
||||
/*
|
||||
Determines if a string `s` starts with a given `prefix`
|
||||
|
||||
Inputs:
|
||||
|
||||
@@ -4,3 +4,4 @@ package benchmarks
|
||||
@(require) import "crypto"
|
||||
@(require) import "hash"
|
||||
@(require) import "text/regex"
|
||||
@(require) import "strings"
|
||||
@@ -0,0 +1,131 @@
|
||||
package benchmark_strings
|
||||
|
||||
import "base:intrinsics"
|
||||
import "core:fmt"
|
||||
import "core:log"
|
||||
import "core:testing"
|
||||
import "core:strings"
|
||||
import "core:text/table"
|
||||
import "core:time"
|
||||
import "core:unicode/utf8"
|
||||
|
||||
RUNS_PER_SIZE :: 2500
|
||||
|
||||
sizes := [?]int {
|
||||
7, 8, 9,
|
||||
15, 16, 17,
|
||||
31, 32, 33,
|
||||
63, 64, 65,
|
||||
95, 96, 97,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
4096,
|
||||
}
|
||||
|
||||
// These are the normal, unoptimized algorithms.
|
||||
|
||||
plain_prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
|
||||
_len := min(len(a), len(b))
|
||||
|
||||
// Scan for matches including partial codepoints.
|
||||
#no_bounds_check for n < _len && a[n] == b[n] {
|
||||
n += 1
|
||||
}
|
||||
|
||||
// Now scan to ignore partial codepoints.
|
||||
if n > 0 {
|
||||
s := a[:n]
|
||||
n = 0
|
||||
for {
|
||||
r0, w := utf8.decode_rune(s[n:])
|
||||
if r0 != utf8.RUNE_ERROR {
|
||||
n += w
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
run_trial_size_prefix :: proc(p: proc "contextless" (string, string) -> $R, suffix: string, size: int, idx: int, runs: int, loc := #caller_location) -> (timing: time.Duration) {
|
||||
left := make([]u8, size)
|
||||
right := make([]u8, size)
|
||||
defer {
|
||||
delete(left)
|
||||
delete(right)
|
||||
}
|
||||
|
||||
if len(suffix) > 0 {
|
||||
copy(left [idx:], suffix)
|
||||
copy(right[idx:], suffix)
|
||||
|
||||
} else {
|
||||
right[idx] = 'A'
|
||||
}
|
||||
|
||||
accumulator: int
|
||||
|
||||
watch: time.Stopwatch
|
||||
|
||||
time.stopwatch_start(&watch)
|
||||
for _ in 0..<runs {
|
||||
result := p(string(left[:size]), string(right[:size]))
|
||||
accumulator += result
|
||||
}
|
||||
time.stopwatch_stop(&watch)
|
||||
timing = time.stopwatch_duration(watch)
|
||||
|
||||
log.debug(accumulator)
|
||||
return
|
||||
}
|
||||
|
||||
run_trial_size :: proc {
|
||||
run_trial_size_prefix,
|
||||
}
|
||||
|
||||
bench_table_size :: proc(algo_name: string, plain, simd: $P, suffix := "") {
|
||||
string_buffer := strings.builder_make()
|
||||
defer strings.builder_destroy(&string_buffer)
|
||||
|
||||
tbl: table.Table
|
||||
table.init(&tbl)
|
||||
defer table.destroy(&tbl)
|
||||
|
||||
table.aligned_header_of_values(&tbl, .Right, "Algorithm", "Size", "Iterations", "Scalar", "SIMD", "SIMD Relative (%)", "SIMD Relative (x)")
|
||||
|
||||
for size in sizes {
|
||||
// Place the non-zero byte somewhere in the middle.
|
||||
needle_index := size / 2
|
||||
|
||||
plain_timing := run_trial_size(plain, suffix, size, needle_index, RUNS_PER_SIZE)
|
||||
simd_timing := run_trial_size(simd, suffix, size, needle_index, RUNS_PER_SIZE)
|
||||
|
||||
_plain := fmt.tprintf("%8M", plain_timing)
|
||||
_simd := fmt.tprintf("%8M", simd_timing)
|
||||
_relp := fmt.tprintf("%.3f %%", f64(simd_timing) / f64(plain_timing) * 100.0)
|
||||
_relx := fmt.tprintf("%.3f x", 1 / (f64(simd_timing) / f64(plain_timing)))
|
||||
|
||||
table.aligned_row_of_values(
|
||||
&tbl,
|
||||
.Right,
|
||||
algo_name,
|
||||
size, RUNS_PER_SIZE, _plain, _simd, _relp, _relx)
|
||||
}
|
||||
|
||||
builder_writer := strings.to_writer(&string_buffer)
|
||||
|
||||
fmt.sbprintln(&string_buffer)
|
||||
table.write_plain_table(builder_writer, &tbl)
|
||||
|
||||
my_table_string := strings.to_string(string_buffer)
|
||||
log.info(my_table_string)
|
||||
}
|
||||
|
||||
@test
|
||||
benchmark_memory_procs :: proc(t: ^testing.T) {
|
||||
bench_table_size("prefix_length ascii", plain_prefix_length, strings.prefix_length)
|
||||
bench_table_size("prefix_length unicode", plain_prefix_length, strings.prefix_length, "🦉")
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
package test_core_strings
|
||||
|
||||
import "base:runtime"
|
||||
import "core:mem"
|
||||
import "core:strings"
|
||||
import "core:testing"
|
||||
import "base:runtime"
|
||||
import "core:unicode/utf8"
|
||||
|
||||
@test
|
||||
test_index_any_small_string_not_found :: proc(t: ^testing.T) {
|
||||
@@ -218,3 +219,51 @@ test_builder_to_cstring :: proc(t: ^testing.T) {
|
||||
testing.expect(t, err == .Out_Of_Memory)
|
||||
}
|
||||
}
|
||||
|
||||
@test
|
||||
test_prefix_length :: proc(t: ^testing.T) {
|
||||
prefix_length :: proc "contextless" (a, b: string) -> (n: int) {
|
||||
_len := min(len(a), len(b))
|
||||
|
||||
// Scan for matches including partial codepoints.
|
||||
#no_bounds_check for n < _len && a[n] == b[n] {
|
||||
n += 1
|
||||
}
|
||||
|
||||
// Now scan to ignore partial codepoints.
|
||||
if n > 0 {
|
||||
s := a[:n]
|
||||
n = 0
|
||||
for {
|
||||
r0, w := utf8.decode_rune(s[n:])
|
||||
if r0 != utf8.RUNE_ERROR {
|
||||
n += w
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
cases := [][2]string{
|
||||
{"Hellope, there!", "Hellope, world!"},
|
||||
{"Hellope, there!", "Foozle"},
|
||||
{"Hellope, there!", "Hell"},
|
||||
{"Hellope! 🦉", "Hellope! 🦉"},
|
||||
}
|
||||
|
||||
for v in cases {
|
||||
p_scalar := prefix_length(v[0], v[1])
|
||||
p_simd := strings.prefix_length(v[0], v[1])
|
||||
testing.expect_value(t, p_simd, p_scalar)
|
||||
|
||||
s := v[0]
|
||||
for len(s) > 0 {
|
||||
p_scalar = prefix_length(v[0], s)
|
||||
p_simd = strings.prefix_length(v[0], s)
|
||||
testing.expect_value(t, p_simd, p_scalar)
|
||||
s = s[:len(s) - 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user