// core/portable-simd/crates/core_simd/src/masks.rs

1//! Types and traits associated with masking elements of vectors.
2//! Types representing
3#![allow(non_camel_case_types)]
4
5#[cfg_attr(
6    not(all(target_arch = "x86_64", target_feature = "avx512f")),
7    path = "masks/full_masks.rs"
8)]
9#[cfg_attr(
10    all(target_arch = "x86_64", target_feature = "avx512f"),
11    path = "masks/bitmask.rs"
12)]
13mod mask_impl;
14
use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount};
use core::cmp::Ordering;
use core::{fmt, mem};
18
19mod sealed {
20    use super::*;
21
22    /// Not only does this seal the `MaskElement` trait, but these functions prevent other traits
23    /// from bleeding into the parent bounds.
24    ///
25    /// For example, `eq` could be provided by requiring `MaskElement: PartialEq`, but that would
26    /// prevent us from ever removing that bound, or from implementing `MaskElement` on
27    /// non-`PartialEq` types in the future.
28    pub trait Sealed {
29        fn valid<const N: usize>(values: Simd<Self, N>) -> bool
30        where
31            LaneCount<N>: SupportedLaneCount,
32            Self: SimdElement;
33
34        fn eq(self, other: Self) -> bool;
35
36        fn to_usize(self) -> usize;
37        fn max_unsigned() -> u64;
38
39        type Unsigned: SimdElement;
40
41        const TRUE: Self;
42
43        const FALSE: Self;
44    }
45}
46use sealed::Sealed;
47
48/// Marker trait for types that may be used as SIMD mask elements.
49///
50/// # Safety
51/// Type must be a signed integer.
52pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
53
// Implements `Sealed` and `MaskElement` for the signed integer type `$ty`, with `$unsigned`
// as its same-width unsigned counterpart. Masks encode `true` as -1 (all bits set) and
// `false` as 0.
macro_rules! impl_element {
    { $ty:ty, $unsigned:ty } => {
        impl Sealed for $ty {
            #[inline]
            fn valid<const N: usize>(value: Simd<Self, N>) -> bool
            where
                LaneCount<N>: SupportedLaneCount,
            {
                // We can't use `Simd` directly, because `Simd`'s functions call this function and
                // we will end up with an infinite loop.
                // Safety: `value` is an integer vector
                unsafe {
                    use core::intrinsics::simd;
                    // A lane is valid if it is 0 or -1; OR the two comparisons together and
                    // require every lane to pass.
                    let falses: Simd<Self, N> = simd::simd_eq(value, Simd::splat(0 as _));
                    let trues: Simd<Self, N> = simd::simd_eq(value, Simd::splat(-1 as _));
                    let valid: Simd<Self, N> = simd::simd_or(falses, trues);
                    simd::simd_reduce_all(valid)
                }
            }

            #[inline]
            fn eq(self, other: Self) -> bool { self == other }

            #[inline]
            fn to_usize(self) -> usize {
                self as usize
            }

            #[inline]
            fn max_unsigned() -> u64 {
                <$unsigned>::MAX as u64
            }

            type Unsigned = $unsigned;

            const TRUE: Self = -1;
            const FALSE: Self = 0;
        }

        // Safety: this is a valid mask element type
        unsafe impl MaskElement for $ty {}
    }
}
97
98impl_element! { i8, u8 }
99impl_element! { i16, u16 }
100impl_element! { i32, u32 }
101impl_element! { i64, u64 }
102impl_element! { isize, usize }
103
104/// A SIMD vector mask for `N` elements of width specified by `Element`.
105///
106/// Masks represent boolean inclusion/exclusion on a per-element basis.
107///
108/// The layout of this type is unspecified, and may change between platforms
109/// and/or Rust versions, and code should not assume that it is equivalent to
110/// `[T; N]`.
111#[repr(transparent)]
112pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
113where
114    T: MaskElement,
115    LaneCount<N>: SupportedLaneCount;
116
117impl<T, const N: usize> Copy for Mask<T, N>
118where
119    T: MaskElement,
120    LaneCount<N>: SupportedLaneCount,
121{
122}
123
124impl<T, const N: usize> Clone for Mask<T, N>
125where
126    T: MaskElement,
127    LaneCount<N>: SupportedLaneCount,
128{
129    #[inline]
130    fn clone(&self) -> Self {
131        *self
132    }
133}
134
135impl<T, const N: usize> Mask<T, N>
136where
137    T: MaskElement,
138    LaneCount<N>: SupportedLaneCount,
139{
140    /// Constructs a mask by setting all elements to the given value.
141    #[inline]
142    pub fn splat(value: bool) -> Self {
143        Self(mask_impl::Mask::splat(value))
144    }
145
146    /// Converts an array of bools to a SIMD mask.
147    #[inline]
148    pub fn from_array(array: [bool; N]) -> Self {
149        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
150        //     true:    0b_0000_0001
151        //     false:   0b_0000_0000
152        // Thus, an array of bools is also a valid array of bytes: [u8; N]
153        // This would be hypothetically valid as an "in-place" transmute,
154        // but these are "dependently-sized" types, so copy elision it is!
155        unsafe {
156            let bytes: [u8; N] = mem::transmute_copy(&array);
157            let bools: Simd<i8, N> =
158                core::intrinsics::simd::simd_ne(Simd::from_array(bytes), Simd::splat(0u8));
159            Mask::from_int_unchecked(core::intrinsics::simd::simd_cast(bools))
160        }
161    }
162
163    /// Converts a SIMD mask to an array of bools.
164    #[inline]
165    pub fn to_array(self) -> [bool; N] {
166        // This follows mostly the same logic as from_array.
167        // SAFETY: Rust's bool has a layout of 1 byte (u8) with a value of
168        //     true:    0b_0000_0001
169        //     false:   0b_0000_0000
170        // Thus, an array of bools is also a valid array of bytes: [u8; N]
171        // Since our masks are equal to integers where all bits are set,
172        // we can simply convert them to i8s, and then bitand them by the
173        // bitpattern for Rust's "true" bool.
174        // This would be hypothetically valid as an "in-place" transmute,
175        // but these are "dependently-sized" types, so copy elision it is!
176        unsafe {
177            let mut bytes: Simd<i8, N> = core::intrinsics::simd::simd_cast(self.to_int());
178            bytes &= Simd::splat(1i8);
179            mem::transmute_copy(&bytes)
180        }
181    }
182
183    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
184    /// represents `true`.
185    ///
186    /// # Safety
187    /// All elements must be either 0 or -1.
188    #[inline]
189    #[must_use = "method returns a new mask and does not mutate the original value"]
190    pub unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
191        // Safety: the caller must confirm this invariant
192        unsafe {
193            core::intrinsics::assume(<T as Sealed>::valid(value));
194            Self(mask_impl::Mask::from_int_unchecked(value))
195        }
196    }
197
198    /// Converts a vector of integers to a mask, where 0 represents `false` and -1
199    /// represents `true`.
200    ///
201    /// # Panics
202    /// Panics if any element is not 0 or -1.
203    #[inline]
204    #[must_use = "method returns a new mask and does not mutate the original value"]
205    #[track_caller]
206    pub fn from_int(value: Simd<T, N>) -> Self {
207        assert!(T::valid(value), "all values must be either 0 or -1",);
208        // Safety: the validity has been checked
209        unsafe { Self::from_int_unchecked(value) }
210    }
211
212    /// Converts the mask to a vector of integers, where 0 represents `false` and -1
213    /// represents `true`.
214    #[inline]
215    #[must_use = "method returns a new vector and does not mutate the original value"]
216    pub fn to_int(self) -> Simd<T, N> {
217        self.0.to_int()
218    }
219
220    /// Converts the mask to a mask of any other element size.
221    #[inline]
222    #[must_use = "method returns a new mask and does not mutate the original value"]
223    pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
224        Mask(self.0.convert())
225    }
226
227    /// Tests the value of the specified element.
228    ///
229    /// # Safety
230    /// `index` must be less than `self.len()`.
231    #[inline]
232    #[must_use = "method returns a new bool and does not mutate the original value"]
233    pub unsafe fn test_unchecked(&self, index: usize) -> bool {
234        // Safety: the caller must confirm this invariant
235        unsafe { self.0.test_unchecked(index) }
236    }
237
238    /// Tests the value of the specified element.
239    ///
240    /// # Panics
241    /// Panics if `index` is greater than or equal to the number of elements in the vector.
242    #[inline]
243    #[must_use = "method returns a new bool and does not mutate the original value"]
244    #[track_caller]
245    pub fn test(&self, index: usize) -> bool {
246        assert!(index < N, "element index out of range");
247        // Safety: the element index has been checked
248        unsafe { self.test_unchecked(index) }
249    }
250
251    /// Sets the value of the specified element.
252    ///
253    /// # Safety
254    /// `index` must be less than `self.len()`.
255    #[inline]
256    pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
257        // Safety: the caller must confirm this invariant
258        unsafe {
259            self.0.set_unchecked(index, value);
260        }
261    }
262
263    /// Sets the value of the specified element.
264    ///
265    /// # Panics
266    /// Panics if `index` is greater than or equal to the number of elements in the vector.
267    #[inline]
268    #[track_caller]
269    pub fn set(&mut self, index: usize, value: bool) {
270        assert!(index < N, "element index out of range");
271        // Safety: the element index has been checked
272        unsafe {
273            self.set_unchecked(index, value);
274        }
275    }
276
277    /// Returns true if any element is set, or false otherwise.
278    #[inline]
279    #[must_use = "method returns a new bool and does not mutate the original value"]
280    pub fn any(self) -> bool {
281        self.0.any()
282    }
283
284    /// Returns true if all elements are set, or false otherwise.
285    #[inline]
286    #[must_use = "method returns a new bool and does not mutate the original value"]
287    pub fn all(self) -> bool {
288        self.0.all()
289    }
290
291    /// Creates a bitmask from a mask.
292    ///
293    /// Each bit is set if the corresponding element in the mask is `true`.
294    /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
295    #[inline]
296    #[must_use = "method returns a new integer and does not mutate the original value"]
297    pub fn to_bitmask(self) -> u64 {
298        self.0.to_bitmask_integer()
299    }
300
301    /// Creates a mask from a bitmask.
302    ///
303    /// For each bit, if it is set, the corresponding element in the mask is set to `true`.
304    /// If the mask contains more than 64 elements, the remainder are set to `false`.
305    #[inline]
306    #[must_use = "method returns a new mask and does not mutate the original value"]
307    pub fn from_bitmask(bitmask: u64) -> Self {
308        Self(mask_impl::Mask::from_bitmask_integer(bitmask))
309    }
310
311    /// Finds the index of the first set element.
312    ///
313    /// ```
314    /// # #![feature(portable_simd)]
315    /// # #[cfg(feature = "as_crate")] use core_simd::simd;
316    /// # #[cfg(not(feature = "as_crate"))] use core::simd;
317    /// # use simd::mask32x8;
318    /// assert_eq!(mask32x8::splat(false).first_set(), None);
319    /// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
320    ///
321    /// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
322    /// assert_eq!(mask.first_set(), Some(1));
323    /// ```
324    #[inline]
325    #[must_use = "method returns the index and does not mutate the original value"]
326    pub fn first_set(self) -> Option<usize> {
327        // If bitmasks are efficient, using them is better
328        if cfg!(target_feature = "sse") && N <= 64 {
329            let tz = self.to_bitmask().trailing_zeros();
330            return if tz == 64 { None } else { Some(tz as usize) };
331        }
332
333        // To find the first set index:
334        // * create a vector 0..N
335        // * replace unset mask elements in that vector with -1
336        // * perform _unsigned_ reduce-min
337        // * check if the result is -1 or an index
338
339        let index = Simd::from_array(
340            const {
341                let mut index = [0; N];
342                let mut i = 0;
343                while i < N {
344                    index[i] = i;
345                    i += 1;
346                }
347                index
348            },
349        );
350
351        // Safety: the input and output are integer vectors
352        let index: Simd<T, N> = unsafe { core::intrinsics::simd::simd_cast(index) };
353
354        let masked_index = self.select(index, Self::splat(true).to_int());
355
356        // Safety: the input and output are integer vectors
357        let masked_index: Simd<T::Unsigned, N> =
358            unsafe { core::intrinsics::simd::simd_cast(masked_index) };
359
360        // Safety: the input is an integer vector
361        let min_index: T::Unsigned =
362            unsafe { core::intrinsics::simd::simd_reduce_min(masked_index) };
363
364        // Safety: the return value is the unsigned version of T
365        let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
366
367        if min_index.eq(T::TRUE) {
368            None
369        } else {
370            Some(min_index.to_usize())
371        }
372    }
373}
374
375// vector/array conversion
376impl<T, const N: usize> From<[bool; N]> for Mask<T, N>
377where
378    T: MaskElement,
379    LaneCount<N>: SupportedLaneCount,
380{
381    #[inline]
382    fn from(array: [bool; N]) -> Self {
383        Self::from_array(array)
384    }
385}
386
387impl<T, const N: usize> From<Mask<T, N>> for [bool; N]
388where
389    T: MaskElement,
390    LaneCount<N>: SupportedLaneCount,
391{
392    #[inline]
393    fn from(vector: Mask<T, N>) -> Self {
394        vector.to_array()
395    }
396}
397
398impl<T, const N: usize> Default for Mask<T, N>
399where
400    T: MaskElement,
401    LaneCount<N>: SupportedLaneCount,
402{
403    #[inline]
404    fn default() -> Self {
405        Self::splat(false)
406    }
407}
408
409impl<T, const N: usize> PartialEq for Mask<T, N>
410where
411    T: MaskElement + PartialEq,
412    LaneCount<N>: SupportedLaneCount,
413{
414    #[inline]
415    fn eq(&self, other: &Self) -> bool {
416        self.0 == other.0
417    }
418}
419
420impl<T, const N: usize> PartialOrd for Mask<T, N>
421where
422    T: MaskElement + PartialOrd,
423    LaneCount<N>: SupportedLaneCount,
424{
425    #[inline]
426    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
427        self.0.partial_cmp(&other.0)
428    }
429}
430
431impl<T, const N: usize> fmt::Debug for Mask<T, N>
432where
433    T: MaskElement + fmt::Debug,
434    LaneCount<N>: SupportedLaneCount,
435{
436    #[inline]
437    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
438        f.debug_list()
439            .entries((0..N).map(|i| self.test(i)))
440            .finish()
441    }
442}
443
444impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
445where
446    T: MaskElement,
447    LaneCount<N>: SupportedLaneCount,
448{
449    type Output = Self;
450    #[inline]
451    fn bitand(self, rhs: Self) -> Self {
452        Self(self.0 & rhs.0)
453    }
454}
455
456impl<T, const N: usize> core::ops::BitAnd<bool> for Mask<T, N>
457where
458    T: MaskElement,
459    LaneCount<N>: SupportedLaneCount,
460{
461    type Output = Self;
462    #[inline]
463    fn bitand(self, rhs: bool) -> Self {
464        self & Self::splat(rhs)
465    }
466}
467
468impl<T, const N: usize> core::ops::BitAnd<Mask<T, N>> for bool
469where
470    T: MaskElement,
471    LaneCount<N>: SupportedLaneCount,
472{
473    type Output = Mask<T, N>;
474    #[inline]
475    fn bitand(self, rhs: Mask<T, N>) -> Mask<T, N> {
476        Mask::splat(self) & rhs
477    }
478}
479
480impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
481where
482    T: MaskElement,
483    LaneCount<N>: SupportedLaneCount,
484{
485    type Output = Self;
486    #[inline]
487    fn bitor(self, rhs: Self) -> Self {
488        Self(self.0 | rhs.0)
489    }
490}
491
492impl<T, const N: usize> core::ops::BitOr<bool> for Mask<T, N>
493where
494    T: MaskElement,
495    LaneCount<N>: SupportedLaneCount,
496{
497    type Output = Self;
498    #[inline]
499    fn bitor(self, rhs: bool) -> Self {
500        self | Self::splat(rhs)
501    }
502}
503
504impl<T, const N: usize> core::ops::BitOr<Mask<T, N>> for bool
505where
506    T: MaskElement,
507    LaneCount<N>: SupportedLaneCount,
508{
509    type Output = Mask<T, N>;
510    #[inline]
511    fn bitor(self, rhs: Mask<T, N>) -> Mask<T, N> {
512        Mask::splat(self) | rhs
513    }
514}
515
516impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
517where
518    T: MaskElement,
519    LaneCount<N>: SupportedLaneCount,
520{
521    type Output = Self;
522    #[inline]
523    fn bitxor(self, rhs: Self) -> Self::Output {
524        Self(self.0 ^ rhs.0)
525    }
526}
527
528impl<T, const N: usize> core::ops::BitXor<bool> for Mask<T, N>
529where
530    T: MaskElement,
531    LaneCount<N>: SupportedLaneCount,
532{
533    type Output = Self;
534    #[inline]
535    fn bitxor(self, rhs: bool) -> Self::Output {
536        self ^ Self::splat(rhs)
537    }
538}
539
540impl<T, const N: usize> core::ops::BitXor<Mask<T, N>> for bool
541where
542    T: MaskElement,
543    LaneCount<N>: SupportedLaneCount,
544{
545    type Output = Mask<T, N>;
546    #[inline]
547    fn bitxor(self, rhs: Mask<T, N>) -> Self::Output {
548        Mask::splat(self) ^ rhs
549    }
550}
551
552impl<T, const N: usize> core::ops::Not for Mask<T, N>
553where
554    T: MaskElement,
555    LaneCount<N>: SupportedLaneCount,
556{
557    type Output = Mask<T, N>;
558    #[inline]
559    fn not(self) -> Self::Output {
560        Self(!self.0)
561    }
562}
563
564impl<T, const N: usize> core::ops::BitAndAssign for Mask<T, N>
565where
566    T: MaskElement,
567    LaneCount<N>: SupportedLaneCount,
568{
569    #[inline]
570    fn bitand_assign(&mut self, rhs: Self) {
571        self.0 = self.0 & rhs.0;
572    }
573}
574
575impl<T, const N: usize> core::ops::BitAndAssign<bool> for Mask<T, N>
576where
577    T: MaskElement,
578    LaneCount<N>: SupportedLaneCount,
579{
580    #[inline]
581    fn bitand_assign(&mut self, rhs: bool) {
582        *self &= Self::splat(rhs);
583    }
584}
585
586impl<T, const N: usize> core::ops::BitOrAssign for Mask<T, N>
587where
588    T: MaskElement,
589    LaneCount<N>: SupportedLaneCount,
590{
591    #[inline]
592    fn bitor_assign(&mut self, rhs: Self) {
593        self.0 = self.0 | rhs.0;
594    }
595}
596
597impl<T, const N: usize> core::ops::BitOrAssign<bool> for Mask<T, N>
598where
599    T: MaskElement,
600    LaneCount<N>: SupportedLaneCount,
601{
602    #[inline]
603    fn bitor_assign(&mut self, rhs: bool) {
604        *self |= Self::splat(rhs);
605    }
606}
607
608impl<T, const N: usize> core::ops::BitXorAssign for Mask<T, N>
609where
610    T: MaskElement,
611    LaneCount<N>: SupportedLaneCount,
612{
613    #[inline]
614    fn bitxor_assign(&mut self, rhs: Self) {
615        self.0 = self.0 ^ rhs.0;
616    }
617}
618
619impl<T, const N: usize> core::ops::BitXorAssign<bool> for Mask<T, N>
620where
621    T: MaskElement,
622    LaneCount<N>: SupportedLaneCount,
623{
624    #[inline]
625    fn bitxor_assign(&mut self, rhs: bool) {
626        *self ^= Self::splat(rhs);
627    }
628}
629
// Implements `From<Mask<$from, N>>` for `Mask<$to, N>` for each listed `$to`, delegating to
// `Mask::cast` so element `i` of the result is set exactly when element `i` of the input is.
macro_rules! impl_from {
    { $from:ty => $($to:ty),* } => {
        $(
        impl<const N: usize> From<Mask<$from, N>> for Mask<$to, N>
        where
            LaneCount<N>: SupportedLaneCount,
        {
            #[inline]
            fn from(value: Mask<$from, N>) -> Self {
                value.cast()
            }
        }
        )*
    }
}
645impl_from! { i8 => i16, i32, i64, isize }
646impl_from! { i16 => i32, i64, isize, i8 }
647impl_from! { i32 => i64, isize, i8, i16 }
648impl_from! { i64 => isize, i8, i16, i32 }
649impl_from! { isize => i8, i16, i32, i64 }
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy