Special Case Unsigned Integer Range `Iterator::sum`

I'm not 100% sure what the process for making this change would be, or exactly which files in the std crate would need to be changed, but there is an optimization for sums of consecutive unsigned integers that I would like to implement, if it's welcome. (I would like to special-case Iterator::sum for ops::Range<u*>, and ops::RangeInclusive<u*>).

The best explanation of this formula I've found online is "Integer Sum Formula (Gauss Sum)".

Basically, any consecutive sequence of integers from 1 to n inclusive has a sum equal to n*(n+1)/2. This also works for ranges from 0 to n (inclusive) because sum(0..=n) == 0 + sum(1..=n). We can generalize this to any range by subtracting the lower sum from the upper sum. For example, to compute sum(5..=10) we would compute sum(1..=10) - sum(1..=4): 10*11/2 - 4*5/2 == 55 - 10 == 45 == 5+6+7+8+9+10 == sum(5..=10).

Thus for any RangeInclusive<u*>: range.start..=range.end we can replace the current implementation of Iterator::sum (which does a slow iterative sum) with range.end*(range.end+1)/2 - range.start*(range.start-1)/2.
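A minimal sketch of that closed form, for illustration only. The name `range_inclusive_sum` is hypothetical and this is not the std implementation. It assumes `end * (end + 1)` fits in `u64`, and it guards two cases the formula as written does not cover: an empty range, and `start == 0`, where `start - 1` would underflow on an unsigned type.

```rust
use std::ops::RangeInclusive;

/// Hypothetical helper: closed-form sum of an inclusive range.
/// Assumes `end * (end + 1)` does not overflow u64.
pub fn range_inclusive_sum(range: RangeInclusive<u64>) -> u64 {
    let (start, end) = (*range.start(), *range.end());
    if start > end {
        return 0; // empty range
    }
    // sum(1..=end); sum(0..=end) has the same value.
    let upper = end * (end + 1) / 2;
    // sum(1..=start-1); guarded because `start - 1` underflows when start == 0.
    let lower = if start == 0 { 0 } else { start * (start - 1) / 2 };
    upper - lower
}
```

With the example from above, `range_inclusive_sum(5..=10)` evaluates 55 - 10 and returns 45.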

The same solution can be extended to ops::Range<u*> except when range.end == u*::max (and Iterator::sum overflows in that case anyway).

I think this will yield a significant performance increase in some (perhaps uncommon) cases, with no drawbacks, but I'm curious about others' thoughts.

This seems reasonable. It'll need specialization internally, but that's nothing new. One thing to keep in mind is that it must consume the iterator.
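One way to read "consume the iterator": when `sum` is called through `by_ref()`, the underlying range is left exhausted afterwards, and a closed-form shortcut would presumably need to leave it in the same state. The sketch below shows only the current observable behaviour; the function name is hypothetical.

```rust
/// Hypothetical demo: sums a range through `by_ref()` and then reports
/// what the range yields afterwards.
pub fn sum_and_leftover() -> (u32, Option<u32>) {
    let mut range = 0..5u32;
    // `by_ref()` borrows the range, so it can still be inspected afterwards.
    let total: u32 = range.by_ref().sum();
    // The range has been driven to its end, so `next()` yields nothing.
    (total, range.next())
}
```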


The intermediate values will overflow in cases where the naïve solution doesn't.

Perhaps we do a checked multiplication in that case and revert back to the naive solution in cases of overflow.

Even jumping up to the next largest integer size would probably be faster than iterating through the range, even if that happens to be 128-bit.
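A sketch of the widening idea for `u64`, assuming `u128` intermediates are acceptable. The name `sum_to_widened` is hypothetical, and returning `None` when the true sum does not fit is a choice made here for illustration (the real `Iterator::sum` panics or wraps on overflow depending on build settings).

```rust
/// Hypothetical helper: sum of 0..=n computed in u128 so that the
/// intermediate product n * (n + 1) cannot overflow.
/// Returns None when the final sum itself does not fit in u64.
pub fn sum_to_widened(n: u64) -> Option<u64> {
    let wide = u128::from(n);
    // At most (2^64 - 1) * 2^64, which fits comfortably in u128.
    let total = wide * (wide + 1) / 2;
    if total > u128::from(u64::MAX) {
        None
    } else {
        Some(total as u64)
    }
}
```

For `n = 1 << 32` the product `n * (n + 1)` already exceeds `u64::MAX`, yet the sum itself fits, so this is exactly the kind of input where the unwidened formula fails and the naive loop does not.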


Agreed -- overflowing intermediate values is an easily mitigated edge-case, but I'm glad someone mentioned it because it's worth being aware of.

Definitely this shouldn't include bombs like that.

It should be simple enough to just make sure the intermediate values cannot overflow, e.g. n * (n + 1) / 2 becomes

if n % 2 == 0 {
    (n / 2) * (n + 1)
} else {
    ((n + 1) / 2) * n
}

This also works, and I'm happy to implement it this way. Casting to a larger type may be faster, since it would avoid the branch in the if statement, but this is a good solution for something like u128, where no larger type exists.

Consider computing the sum of the range directly. One easy way to derive the formula is that it is the number of terms times the average of terms, and with a consecutive range like this, those are just last-first+1 and (last+first)/2.
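A sketch of the "number of terms times average" form. The name `range_sum` is hypothetical, and the code assumes both `first + last` and `last - first + 1` fit in `u64`. Because the term count and `first + last` always add up to the odd number `2 * last + 1`, exactly one of them is even, so the division by 2 can be applied to that one before multiplying.

```rust
/// Hypothetical helper: sum of first..=last as (number of terms) * (average).
/// Assumes `first + last` and `last - first + 1` do not overflow u64.
pub fn range_sum(first: u64, last: u64) -> u64 {
    if first > last {
        return 0; // empty range
    }
    let count = last - first + 1; // number of terms
    let ends = first + last; // twice the average
    // Exactly one of `count` and `ends` is even; halve that one first
    // to keep the intermediate value small.
    if count % 2 == 0 {
        (count / 2) * ends
    } else {
        count * (ends / 2)
    }
}
```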


RFC opened, by the way: "Special-cased performance improvement for `Iterator::sum` on `Range<u*>` and `RangeInclusive<u*>`" by Alfriadox (Pull Request #3481, rust-lang/rfcs on GitHub).

Branchless:

#[inline(never)]
pub fn sum_to(n: u64) -> u64 {
    let odd: u64 = (n % 2 == 1).into();
    ((n + (!odd)) / 2) * (n + odd)
}

This seems like an ideal solution, I'll add it to the RFC.

This doesn't need an RFC, by the way. It's an implementation detail, so going straight for a PR should be acceptable.


Well, LLVM already does this for both, actually: https://rust.godbolt.org/z/1M1hThe1h (note the lack of backwards jumps).

So I don't know that it's worth bothering to do it in the Rust code -- it's not "no drawbacks" because it's more code to maintain and more metadata to load in libcore (which slows down compilation for everyone, if perhaps not by that much).

Also, I'm curious what real code you have where this actually happens? Summing something that's just a range seems rare.


And it goes for the version with a branch. Whoops. I'm curious if there's any benefit to guaranteeing this optimization from the Rust side; maybe the only benefit would be to debug builds. That actually does seem marginally beneficial, as it's an algorithmic difference between the two compiled versions, not just an "optimized is faster".

Isn't this just n % 2? It seems a bit funny to take the roundabout route through a bool.


yeah it is

The branch checks whether the range is empty (start >= end). It's possible to make that check branch-less using conditional moves, but branch-less isn't always a performance win.

This assumes start == 0, so it's a different problem.

There are a few mistakes here: !odd is a bitwise negation rather than a boolean negation, and you are dividing the odd number rather than the even number. Fixing the mistakes, we get:

pub fn sum_to(n: u64) -> u64 {
    ((n + n % 2) / 2) * (n | 1)
}

However, this still has a problem: it will still give the wrong answer for n = u64::MAX in release mode due to the overflow on addition.

I think this works for (0..=n).sum():

pub fn sum_to(n: u64) -> u64 {
    (n / 2 + n % 2) * (n | 1)
}

and is maybe slightly simpler than what LLVM generates for this special case.
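To compare the last two variants at the boundary, here is a sketch that writes both with explicit wrapping operations, so the behaviour is the same in debug and release builds. The names `sum_to_add_first` and `sum_to_halve_first` are hypothetical. Mathematically, the sum of 0..=u64::MAX is 2^63 * (2^64 - 1), which is congruent to 2^63 modulo 2^64, so that is the value a wrapping sum should produce.

```rust
/// Hypothetical name for the earlier variant: rounds n up to even before
/// halving. The addition wraps to 0 when n == u64::MAX.
pub fn sum_to_add_first(n: u64) -> u64 {
    (n.wrapping_add(n % 2) / 2).wrapping_mul(n | 1)
}

/// Hypothetical name for the later variant: halves first, then rounds up.
/// `n / 2 + n % 2` is at most 2^63, so only the multiplication can wrap.
pub fn sum_to_halve_first(n: u64) -> u64 {
    (n / 2 + n % 2).wrapping_mul(n | 1)
}
```

Both agree with the iterative sum for small n. At `n = u64::MAX` the first returns 0, while the second returns `1 << 63`, matching the wrapped value of the true sum.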

