Tracking issue: Tracking Issue for get_many_mut (rust-lang/rust #104642)
I would like to move the stabilization of get_many_mut forward. That begins with answering its unresolved question.
One question raised throughout the thread is whether the current API, which takes an array of indices directly, is the best one. An alternative that was proposed is to encapsulate the "disjoint indices" invariant in a type. This has the benefit that the validated indices can be reused. It also makes it possible to distinguish between the "overlapping indices" and "out-of-bounds indices" error conditions (of course, this can also be done with the current API, but that requires making the error type no longer a ZST, which may harm performance a bit).
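To illustrate the error-reporting point, here is a minimal sketch of a check that tells the two failure modes apart. `IndexError` and `check_indices` are hypothetical names, not the API from the tracking issue; the point is only that a two-variant error enum occupies a byte, whereas `()` is zero-sized.

```rust
use std::mem::size_of;

/// Hypothetical error type that distinguishes the two failure modes.
/// Unlike `()`, it is not a zero-sized type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexError {
    OutOfBounds,
    Overlapping,
}

/// Checks `indices` against a slice of length `len` and reports which
/// condition failed. Bounds are checked first, then pairwise disjointness.
pub fn check_indices<const N: usize>(indices: &[usize; N], len: usize) -> Result<(), IndexError> {
    if indices.iter().any(|&i| i >= len) {
        return Err(IndexError::OutOfBounds);
    }
    for (i, &idx) in indices.iter().enumerate() {
        // Compare each index only with the ones before it (O(N^2), fine for small N).
        if indices[..i].contains(&idx) {
            return Err(IndexError::Overlapping);
        }
    }
    Ok(())
}

fn main() {
    assert_eq!(check_indices(&[0, 2], 3), Ok(()));
    assert_eq!(check_indices(&[0, 3], 3), Err(IndexError::OutOfBounds));
    assert_eq!(check_indices(&[1, 1], 3), Err(IndexError::Overlapping));
    // `()` occupies no space; the two-variant enum needs one byte.
    assert_eq!(size_of::<()>(), 0);
    assert_eq!(size_of::<IndexError>(), 1);
}
```

Because bounds are checked first, an array that is both out of bounds and overlapping reports `OutOfBounds`; the opposite order would be just as valid.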
I quite like this idea. There are two main ways to implement it: one keeps track of only the invariant, the other also finds the maximum index while checking it. The latter has the advantage that, inside the method, we can check with a single comparison whether any of the indices is out of bounds (OOB).
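The difference between the two approaches comes down to where the bounds-checking work happens. A sketch with hypothetical helper names: the first helper compares every index with the length on each call, while the second pair computes the maximum once and then needs only one comparison per call. Representing N == 0 as `None` is a choice of this sketch; the benchmarked code instead stores 0 and guards with `N != 0`.

```rust
/// Approach 1 (invariant only): every index is compared with `len`,
/// i.e. N comparisons on each call.
fn in_bounds_each<const N: usize>(indices: &[usize; N], len: usize) -> bool {
    indices.iter().all(|&i| i < len)
}

/// Approach 2 (also track the maximum): computed once, at validation time.
/// Returns `None` when N == 0, since there is no maximum.
fn max_index<const N: usize>(indices: &[usize; N]) -> Option<usize> {
    indices.iter().copied().max()
}

/// ...so each later bounds check is a single comparison.
/// `None` (no indices) is trivially in bounds.
fn in_bounds_max(max: Option<usize>, len: usize) -> bool {
    max.map_or(true, |m| m < len)
}

fn main() {
    let indices = [4, 1, 7];
    let max = max_index(&indices);
    for len in 0..10 {
        // Both strategies agree; they only differ in where the work happens.
        assert_eq!(in_bounds_each(&indices, len), in_bounds_max(max, len));
    }
    assert!(in_bounds_max(max_index::<0>(&[]), 0));
}
```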
However, when I tried to implement the second approach, I found that LLVM generates worse code for it with few indices (2), and it performs worse in benchmarks too (487.40 ps vs. 555.49 ps, measured with criterion). I think this is something we cannot afford, given that most uses will probably involve 2 indices.
Code:
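```rust
#![feature(array_windows, get_many_mut)]

/// Indices that are known to be pairwise distinct, together with the largest of them.
#[derive(Clone, Copy)]
pub struct DisjointIndices<const N: usize> {
    indices: [usize; N],
    max_index: usize,
}

impl<const N: usize> DisjointIndices<N> {
    #[inline]
    pub fn new(mut indices: [usize; N]) -> Result<Self, ()> {
        // Larger N: sort, then compare neighbours (O(N log N)) instead of every pair.
        // Note that this reorders the indices.
        if N > 10 {
            indices.sort_unstable();
            return Self::new_sorted(indices);
        }
        // Small N: O(N^2) pairwise comparison without early exit,
        // tracking the maximum along the way.
        let mut max_index = 0;
        let mut valid = true;
        for (i, &idx) in indices.iter().enumerate() {
            max_index = std::cmp::max(idx, max_index);
            for &idx2 in &indices[..i] {
                valid &= idx != idx2;
            }
        }
        match valid {
            true => Ok(Self { indices, max_index }),
            // Variant without the `max_index` field:
            // true => Ok(Self { indices }),
            false => Err(()),
        }
    }

    #[inline]
    pub fn new_sorted(indices: [usize; N]) -> Result<Self, ()> {
        if N == 0 {
            return Ok(Self {
                indices,
                max_index: 0,
            });
        }
        // Sorted input is disjoint iff it is strictly increasing;
        // the maximum is then the last element.
        let mut valid = true;
        for [prev, next] in indices.array_windows() {
            valid &= next > prev;
        }
        match valid {
            true => Ok(Self {
                max_index: *indices.last().unwrap(),
                indices,
            }),
            false => Err(()),
        }
    }
}

#[inline]
pub fn my_get_many_mut<'a, T, const N: usize>(
    slice: &'a mut [T],
    indices: &DisjointIndices<N>,
) -> Result<[&'a mut T; N], ()> {
    // One comparison against the cached maximum (skipped for N == 0).
    if N != 0 && indices.max_index >= slice.len() {
    // Alternative: compare every index with the length instead.
    // if N != 0 && indices.indices.iter().any(|&i| i >= slice.len()) {
        Err(())
    } else {
        // SAFETY: all indices are in bounds (checked above) and pairwise distinct
        // (the `DisjointIndices` invariant).
        unsafe { Ok(slice.get_many_unchecked_mut(indices.indices)) }
    }
}
```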
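The code above requires nightly (`array_windows`, `get_many_mut`). For trying the idea on stable Rust, here is a minimal sketch under a few simplifying assumptions: it returns `Option` instead of `Result<_, ()>`, keeps only the quadratic small-N check (no sorted path), and performs the unchecked access by hand with `<[T]>::as_mut_ptr`, pointer `add` and `<[T; N]>::map`. That last part is a stand-in for `get_many_unchecked_mut`, not a copy of its implementation.

```rust
/// Indices that are known to be pairwise distinct, plus their maximum.
#[derive(Clone, Copy, Debug)]
pub struct DisjointIndices<const N: usize> {
    indices: [usize; N],
    max_index: usize,
}

impl<const N: usize> DisjointIndices<N> {
    /// O(N^2) pairwise check; returns `None` if any two indices are equal.
    pub fn new(indices: [usize; N]) -> Option<Self> {
        let mut max_index = 0;
        for (i, &idx) in indices.iter().enumerate() {
            max_index = max_index.max(idx);
            if indices[..i].contains(&idx) {
                return None;
            }
        }
        Some(Self { indices, max_index })
    }
}

pub fn my_get_many_mut<'a, T, const N: usize>(
    slice: &'a mut [T],
    indices: &DisjointIndices<N>,
) -> Option<[&'a mut T; N]> {
    // Single bounds check against the cached maximum (nothing to check for N == 0).
    if N != 0 && indices.max_index >= slice.len() {
        return None;
    }
    let ptr = slice.as_mut_ptr();
    // SAFETY: every index is < slice.len() (checked via max_index above) and the
    // indices are pairwise distinct, so the references are in bounds and never alias.
    Some(indices.indices.map(|i| unsafe { &mut *ptr.add(i) }))
}

fn main() {
    let mut v = vec![10, 20, 30, 40];
    let idx = DisjointIndices::new([3, 0]).unwrap();
    if let Some([a, b]) = my_get_many_mut(&mut v, &idx) {
        std::mem::swap(a, b);
    }
    assert_eq!(v, [40, 20, 30, 10]);
    assert!(DisjointIndices::new([1, 1]).is_none());
    let far = DisjointIndices::new([0, 4]).unwrap();
    assert!(my_get_many_mut(&mut v, &far).is_none());
}
```

Since a `DisjointIndices` value can be built once and passed by reference, the same validated indices can be reused for several slices, which is the reuse benefit mentioned earlier.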
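The alternative (the commented-out lines, which compare every index with the slice length instead of relying on the stored maximum) performed just as well as the current get_many_mut() with few indices, but became worse with many (100) indices.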
So, if we want maximum performance and we also want DisjointIndices, the only way I can see is to switch strategies based on N. If N is less than <insert constant based on benchmarking here>, check each index against the slice length in get_many_mut(); otherwise, rely on the maximum index computed in DisjointIndices.
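A minimal sketch of the N-dependent switch. `SMALL_N` is a hypothetical placeholder value (8), not a benchmarked threshold, and `in_bounds` is a hypothetical helper that a `get_many_mut` implementation would call. Because `N` is a const generic, `N < SMALL_N` is a constant after monomorphization, so the compiler can be expected to drop the unused arm.

```rust
/// Hypothetical cut-over point standing in for the benchmarked constant.
/// NOT a measured value.
const SMALL_N: usize = 8;

pub struct DisjointIndices<const N: usize> {
    indices: [usize; N],
    // Only meaningful when N >= SMALL_N; left at 0 (and unused) for small N.
    max_index: usize,
}

impl<const N: usize> DisjointIndices<N> {
    pub fn new(indices: [usize; N]) -> Option<Self> {
        for (i, &idx) in indices.iter().enumerate() {
            if indices[..i].contains(&idx) {
                return None;
            }
        }
        // Only pay for computing the maximum when it will actually be read.
        let max_index = if N >= SMALL_N {
            indices.iter().copied().max().unwrap_or(0)
        } else {
            0
        };
        Some(Self { indices, max_index })
    }

    /// Bounds check whose strategy depends on N.
    pub fn in_bounds(&self, len: usize) -> bool {
        if N < SMALL_N {
            // Small N: compare each index with the length.
            self.indices.iter().all(|&i| i < len)
        } else {
            // Large N (so N > 0): one comparison against the cached maximum.
            self.max_index < len
        }
    }
}

fn main() {
    let small = DisjointIndices::new([2, 0]).unwrap(); // per-index path
    assert!(small.in_bounds(3));
    assert!(!small.in_bounds(2));

    let arr: [usize; 10] = std::array::from_fn(|i| i * 2); // 0, 2, ..., 18
    let large = DisjointIndices::new(arr).unwrap(); // cached-maximum path
    assert!(large.in_bounds(19));
    assert!(!large.in_bounds(18));
}
```

For small N the struct still carries `max_index`, which never gets read; this is the redundant field discussed next.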
The disadvantages are that we carry a redundant field when N is small (minor), and that we cannot provide a fully zero-cost DisjointIndices::new_unchecked() API: it would have to compute the maximum index itself. Letting the caller pass the maximum in manually does not help, because for small N the value would simply be ignored, and it would be weird and inconsistent to use or ignore the passed number depending on an implementation-detail constant. In any case, I am already opposed to an API that takes the maximum index, because I feel it exposes an implementation detail.
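The cost of an unchecked constructor under the max-index design can be sketched as follows (`new_unchecked` and the accessors are hypothetical). Even when the caller vouches for disjointness, the constructor still scans all N indices to find the maximum, so it is O(N) rather than free; avoiding that scan would require a signature that also takes the maximum, which is the kind of API argued against above.

```rust
pub struct DisjointIndices<const N: usize> {
    indices: [usize; N],
    max_index: usize,
}

impl<const N: usize> DisjointIndices<N> {
    /// Hypothetical unchecked constructor.
    ///
    /// # Safety
    /// The caller must guarantee that all indices are pairwise distinct.
    pub unsafe fn new_unchecked(indices: [usize; N]) -> Self {
        // Disjointness is trusted, but the bounds check still needs the maximum,
        // so this is an O(N) scan rather than a no-op.
        let max_index = indices.iter().copied().max().unwrap_or(0);
        Self { indices, max_index }
    }

    pub fn indices(&self) -> &[usize; N] {
        &self.indices
    }

    pub fn max_index(&self) -> usize {
        self.max_index
    }
}

fn main() {
    // SAFETY: 5, 1 and 9 are pairwise distinct.
    let idx = unsafe { DisjointIndices::new_unchecked([5, 1, 9]) };
    assert_eq!(idx.max_index(), 9);
    assert_eq!(idx.indices(), &[5, 1, 9]);
}
```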
What are your thoughts?