diff --git a/src/lib.rs b/src/lib.rs index d6e64d5e73fe..6c5f94fa6827 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,11 @@ maybe_uninit_slice, maybe_uninit_uninit_array_transpose )] +#![cfg_attr( + target_arch = "loongarch64", + feature(stdarch_loongarch, stdarch_loongarch_feature_detection, loongarch_target_feature), + allow(clippy::incompatible_msrv) +)] #![allow(clippy::missing_transmute_annotations, clippy::new_without_default, stable_features)] #[macro_use] diff --git a/src/simd/memchr2.rs b/src/simd/memchr2.rs index c04cc8ea0d95..795af336a512 100644 --- a/src/simd/memchr2.rs +++ b/src/simd/memchr2.rs @@ -21,7 +21,7 @@ pub fn memchr2(needle1: u8, needle2: u8, haystack: &[u8], offset: usize) -> usiz } unsafe fn memchr2_raw(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "loongarch64"))] return unsafe { MEMCHR2_DISPATCH(needle1, needle2, beg, end) }; #[cfg(target_arch = "aarch64")] @@ -53,7 +53,7 @@ unsafe fn memchr2_fallback( // itself to the correct implementation on the first call. This reduces binary size. // It would also reduce branches if we had >2 implementations (a jump still needs to be predicted). // NOTE that this ONLY works if Control Flow Guard is disabled on Windows. -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "loongarch64"))] static mut MEMCHR2_DISPATCH: unsafe fn( needle1: u8, needle2: u8, @@ -102,6 +102,91 @@ unsafe fn memchr2_avx2(needle1: u8, needle2: u8, mut beg: *const u8, end: *const } } +#[cfg(target_arch = "loongarch64")] +unsafe fn memchr2_dispatch(needle1: u8, needle2: u8, beg: *const u8, end: *const u8) -> *const u8 { + use std::arch::is_loongarch_feature_detected; + + let func = if is_loongarch_feature_detected!("lasx") { + memchr2_lasx + } else if is_loongarch_feature_detected!("lsx") { + memchr2_lsx + } else { + memchr2_fallback + }; + unsafe { MEMCHR2_DISPATCH = func }; + unsafe { func(needle1, needle2, beg, end) } +} + +#[cfg(target_arch = "loongarch64")] +#[target_feature(enable = "lasx")] +unsafe fn memchr2_lasx(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 { + unsafe { + use std::arch::loongarch64::*; + use std::mem::transmute as T; + + let n1 = lasx_xvreplgr2vr_b(needle1 as i32); + let n2 = lasx_xvreplgr2vr_b(needle2 as i32); + + let off = beg.align_offset(32); + if off != 0 && off < end.offset_from_unsigned(beg) { + beg = memchr2_lsx(needle1, needle2, beg, beg.add(off)); + } + + while end.offset_from_unsigned(beg) >= 32 { + let v = lasx_xvld::<0>(beg as *const _); + let a = lasx_xvseq_b(v, n1); + let b = lasx_xvseq_b(v, n2); + let c = lasx_xvor_v(T(a), T(b)); + let m = lasx_xvmskltz_b(T(c)); + let l = lasx_xvpickve2gr_wu::<0>(T(m)); + let h = lasx_xvpickve2gr_wu::<4>(T(m)); + let m = (h << 16) | l; + + if m != 0 { + return beg.add(m.trailing_zeros() as usize); + } + + beg = beg.add(32); + } + + memchr2_fallback(needle1, needle2, beg, end) + } +} + +#[cfg(target_arch = "loongarch64")] +#[target_feature(enable = "lsx")] +unsafe fn memchr2_lsx(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 { + unsafe { + use std::arch::loongarch64::*; + use std::mem::transmute as T; + + let n1 = lsx_vreplgr2vr_b(needle1 as i32); + let n2 = lsx_vreplgr2vr_b(needle2 as i32); + + let off = beg.align_offset(16); + if off != 0 && off < end.offset_from_unsigned(beg) { + beg = memchr2_fallback(needle1, needle2, beg, beg.add(off)); + } + + while end.offset_from_unsigned(beg) >= 16 { + let v = lsx_vld::<0>(beg as *const _); + let a = lsx_vseq_b(v, n1); + let b = lsx_vseq_b(v, n2); + let c = lsx_vor_v(T(a), T(b)); + let m = lsx_vmskltz_b(T(c)); + let m = lsx_vpickve2gr_wu::<0>(T(m)); + + if m != 0 { + return beg.add(m.trailing_zeros() as usize); + } + + beg = beg.add(16); + } + + memchr2_fallback(needle1, needle2, beg, end) + } +} + #[cfg(target_arch = "aarch64")] unsafe fn memchr2_neon(needle1: u8, needle2: u8, mut beg: *const u8, end: *const u8) -> *const u8 { unsafe {