// Patch summary: adds a loongarch64 code path to the SIMD memset.
//  * src/lib.rs: for target_arch = "loongarch64" only, enables three LoongArch feature gates
//    and allows clippy::incompatible_msrv.
//  * src/simd/memset.rs: makes memset_raw and MEMSET_DISPATCH available on loongarch64 and adds
//    memset_dispatch, memset_lasx and memset_lsx for that target.
// NOTE(review): line breaks and indentation were reconstructed from a whitespace-collapsed copy;
// check context-line whitespace against the target files before applying.
diff --git a/src/lib.rs b/src/lib.rs
index d6e64d5e73fe..6c5f94fa6827 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,6 +11,11 @@
     maybe_uninit_slice,
     maybe_uninit_uninit_array_transpose
 )]
+#![cfg_attr( // The attributes listed after the predicate apply only when it holds.
+    target_arch = "loongarch64",
+    feature(stdarch_loongarch, stdarch_loongarch_feature_detection, loongarch_target_feature),
+    allow(clippy::incompatible_msrv) // NOTE(review): presumably for the LoongArch std::arch items used in simd/memset.rs — confirm.
+)]
 #![allow(clippy::missing_transmute_annotations, clippy::new_without_default, stable_features)]
 
 #[macro_use]
diff --git a/src/simd/memset.rs b/src/simd/memset.rs
index 01285ceabd70..8e8e7acea2e3 100644
--- a/src/simd/memset.rs
+++ b/src/simd/memset.rs
@@ -72,7 +72,7 @@ pub fn memset(dst: &mut [T], val: T) {
 
 #[inline]
 fn memset_raw(beg: *mut u8, end: *mut u8, val: u64) {
-    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+    #[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "loongarch64"))] // loongarch64 now also goes through MEMSET_DISPATCH.
     return unsafe { MEMSET_DISPATCH(beg, end, val) };
 
     #[cfg(target_arch = "aarch64")]
@@ -108,7 +108,7 @@ unsafe fn memset_fallback(mut beg: *mut u8, end: *mut u8, val: u64) {
     }
 }
 
-#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "loongarch64"))] // The dispatch pointer is now also compiled for loongarch64.
 static mut MEMSET_DISPATCH: unsafe fn(beg: *mut u8, end: *mut u8, val: u64) = memset_dispatch;
 
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
@@ -235,6 +235,130 @@ fn memset_avx2(mut beg: *mut u8, end: *mut u8, val: u64) {
     }
 }
 
+#[cfg(target_arch = "loongarch64")]
+fn memset_dispatch(beg: *mut u8, end: *mut u8, val: u64) { // Initial value of MEMSET_DISPATCH: selects an implementation, stores it, then runs it.
+    use std::arch::is_loongarch_feature_detected;
+
+    let func = if is_loongarch_feature_detected!("lasx") { // LASX is checked first, then LSX; otherwise memset_fallback is used.
+        memset_lasx
+    } else if is_loongarch_feature_detected!("lsx") {
+        memset_lsx
+    } else {
+        memset_fallback
+    };
+    unsafe { MEMSET_DISPATCH = func }; // Later calls through MEMSET_DISPATCH reach func directly. NOTE(review): plain write to a `static mut`; no synchronization is visible here.
+    unsafe { func(beg, end, val) } // Serve the current call with the selected implementation.
+}
+
+#[cfg(target_arch = "loongarch64")]
+#[target_feature(enable = "lasx")]
+fn memset_lasx(mut beg: *mut u8, end: *mut u8, val: u64) { // Fills the bytes from beg up to (not including) end from val, using 32-byte LASX stores where possible.
+    unsafe { // NOTE(review): declared as a safe `fn` although it writes through raw pointers; memset_lsx is `unsafe fn` — confirm this is intended.
+        use std::arch::loongarch64::*;
+        use std::mem::transmute as T;
+
+        let fill: v32i8 = T(lasx_xvreplgr2vr_d(val as i64)); // val replicated into every 64-bit element, reinterpreted as 32 bytes.
+
+        if end.offset_from_unsigned(beg) >= 32 { // At least 32 bytes: one store at beg, then advance beg by align_offset(32).
+            lasx_xvst::<0>(fill, beg as *mut _);
+            let off = beg.align_offset(32); // NOTE(review): assumes the real offset (< 32) is returned, so the store above covers the skipped bytes — confirm.
+            beg = beg.add(off);
+        }
+
+        if end.offset_from_unsigned(beg) >= 128 { // Main loop: four 32-byte stores (offsets 0/32/64/96) per 128-byte step.
+            loop {
+                lasx_xvst::<0>(fill, beg as *mut _);
+                lasx_xvst::<32>(fill, beg as *mut _);
+                lasx_xvst::<64>(fill, beg as *mut _);
+                lasx_xvst::<96>(fill, beg as *mut _);
+
+                beg = beg.add(128);
+                if end.offset_from_unsigned(beg) < 128 {
+                    break;
+                }
+            }
+        }
+
+        if end.offset_from_unsigned(beg) >= 16 { // Fewer than 128 bytes remain here; continue with 16-byte LSX stores.
+            let fill: v16i8 = T(lsx_vreplgr2vr_d(val as i64)); // 16-byte fill vector; shadows the 32-byte one inside this block.
+
+            loop {
+                lsx_vst::<0>(fill, beg as *mut _);
+
+                beg = beg.add(16);
+                if end.offset_from_unsigned(beg) < 16 {
+                    break;
+                }
+            }
+        }
+
+        if end.offset_from_unsigned(beg) >= 8 { // Fewer than 16 bytes remain: one store at beg and one ending at end; the two may overlap.
+            // 8-15 bytes
+            (beg as *mut u64).write_unaligned(val);
+            (end.sub(8) as *mut u64).write_unaligned(val);
+        } else if end.offset_from_unsigned(beg) >= 4 {
+            // 4-7 bytes
+            (beg as *mut u32).write_unaligned(val as u32);
+            (end.sub(4) as *mut u32).write_unaligned(val as u32);
+        } else if end.offset_from_unsigned(beg) >= 2 {
+            // 2-3 bytes
+            (beg as *mut u16).write_unaligned(val as u16);
+            (end.sub(2) as *mut u16).write_unaligned(val as u16);
+        } else if end.offset_from_unsigned(beg) >= 1 {
+            // 1 byte
+            beg.write(val as u8);
+        }
+    }
+}
+
+#[cfg(target_arch = "loongarch64")]
+#[target_feature(enable = "lsx")]
+unsafe fn memset_lsx(mut beg: *mut u8, end: *mut u8, val: u64) { // Fills the bytes from beg up to (not including) end from val, using 16-byte LSX stores where possible.
+    unsafe {
+        use std::arch::loongarch64::*;
+        use std::mem::transmute as T;
+
+        if end.offset_from_unsigned(beg) >= 16 { // At least 16 bytes: vector path.
+            let fill: v16i8 = T(lsx_vreplgr2vr_d(val as i64)); // val replicated into both 64-bit elements, reinterpreted as 16 bytes.
+
+            lsx_vst::<0>(fill, beg as *mut _); // Store at the original beg before advancing it by align_offset(16).
+            let off = beg.align_offset(16); // NOTE(review): assumes the real offset (< 16) is returned, so the store above covers the skipped bytes — confirm.
+            beg = beg.add(off);
+
+            while end.offset_from_unsigned(beg) >= 32 { // Two 16-byte stores (offsets 0/16) per 32-byte step.
+                lsx_vst::<0>(fill, beg as *mut _);
+                lsx_vst::<16>(fill, beg as *mut _);
+
+                beg = beg.add(32);
+            }
+
+            if end.offset_from_unsigned(beg) >= 16 {
+                // 16-31 bytes remaining
+                lsx_vst::<0>(fill, beg as *mut _);
+                lsx_vst::<-16>(fill, end as *mut _); // Offset -16 from end: the last 16 bytes; may overlap the store above.
+                return; // Both stores together cover the rest, so the scalar tail is skipped.
+            }
+        }
+
+        if end.offset_from_unsigned(beg) >= 8 { // Fewer than 16 bytes remain: one store at beg and one ending at end; the two may overlap.
+            // 8-15 bytes remaining
+            (beg as *mut u64).write_unaligned(val);
+            (end.sub(8) as *mut u64).write_unaligned(val);
+        } else if end.offset_from_unsigned(beg) >= 4 {
+            // 4-7 bytes remaining
+            (beg as *mut u32).write_unaligned(val as u32);
+            (end.sub(4) as *mut u32).write_unaligned(val as u32);
+        } else if end.offset_from_unsigned(beg) >= 2 {
+            // 2-3 bytes remaining
+            (beg as *mut u16).write_unaligned(val as u16);
+            (end.sub(2) as *mut u16).write_unaligned(val as u16);
+        } else if end.offset_from_unsigned(beg) >= 1 {
+            // 1 byte remaining
+            beg.write(val as u8);
+        }
+    }
+}
+
 #[cfg(target_arch = "aarch64")]
 unsafe fn memset_neon(mut beg: *mut u8, end: *mut u8, val: u64) {
     unsafe {