# Special Case Unsigned Integer Range `Iterator::sum`

I'm not 100% sure what the process for making this change would be, or exactly which files in the std crate would need to be changed, but there is an optimization for sums of consecutive unsigned integers that I would like to implement, if it's welcome. (Specifically, I would like to special-case `Iterator::sum` for `ops::Range<u*>` and `ops::RangeInclusive<u*>`.)

The best explanation of this formula I've found online is the article "Integer Sum Formula (Gauss Sum)".

Basically, any consecutive sequence of integers from 1 to n inclusive has a sum equal to `n*(n+1)/2`. This also works for ranges from 0 to n (inclusive) because `sum(0..=n) == 0 + sum(1..=n)`. We can generalize this to any range by subtracting the lower sum from the upper sum. For example, to compute `sum(5..=10)` we would do `sum(1..=10) - sum(1..=4)` (`10*11/2 - 4*5/2 == 55 - 10 == 45 == 5+6+7+8+9+10 == sum(5..=10)`).
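A quick way to sanity-check the worked example is a small Rust sketch. `gauss` is a hypothetical helper name, and it uses plain `u64` arithmetic, so `n * (n + 1)` would overflow for large `n`; that is fine for a demo of the formula itself.

```rust
/// Sum of 1..=n via Gauss's formula n*(n+1)/2 (hypothetical helper).
/// Plain u64 arithmetic: n * (n + 1) overflows for large n, which is
/// acceptable here because this only demonstrates the formula.
fn gauss(n: u64) -> u64 {
    // n * (n + 1) is always even, so the division is exact.
    n * (n + 1) / 2
}
```

For instance, `gauss(10) - gauss(4)` evaluates to 45, the same value as `(5..=10u64).sum::<u64>()`.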

Thus, for any `RangeInclusive<u*>` `start..=end`, we can replace the current implementation of `Iterator::sum` (which does a slow iterative sum) with `end*(end+1)/2 - start*(start-1)/2`. The second term is `sum(1..=start-1)`, so it must be taken as 0 when `start == 0`; otherwise `start - 1` would underflow.
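A sketch of the proposed replacement for `RangeInclusive<u64>`, assuming checked arithmetic is an acceptable way to detect overflow. `prefix_sum` and `range_inclusive_sum` are hypothetical helpers, not std APIs. Note that it returns `None` even for `u64::MAX..=u64::MAX`, whose sum fits in `u64`, because the intermediate `end + 1` overflows; that is the intermediate-overflow concern raised in the replies.

```rust
use std::ops::RangeInclusive;

/// Sum of 1..=n, or None if n + 1 or n * (n + 1) overflows u64.
/// (Hypothetical helper, not part of std.)
fn prefix_sum(n: u64) -> Option<u64> {
    // One of n and n + 1 is even, so the division by 2 is exact.
    n.checked_add(1)?.checked_mul(n).map(|p| p / 2)
}

/// sum(start..=end) = sum(1..=end) - sum(1..=start-1), computed without iterating.
/// Returns Some(0) for an empty range and None if an intermediate value overflows.
/// (Hypothetical helper, not part of std.)
fn range_inclusive_sum(r: RangeInclusive<u64>) -> Option<u64> {
    if r.is_empty() {
        return Some(0);
    }
    let (start, end) = (*r.start(), *r.end());
    let upper = prefix_sum(end)?;
    // For start == 0, start - 1 would underflow; the lower sum is just 0.
    let lower = if start == 0 { 0 } else { prefix_sum(start - 1)? };
    // start - 1 < end, so lower <= upper and the subtraction cannot underflow.
    Some(upper - lower)
}
```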

The same solution can be extended to `ops::Range<u*>` except when `range.end == u*::MAX` (and `Iterator::sum` overflows in that case anyway).
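For `Range<u64>`, one option (a sketch, not necessarily how std would structure it) is to treat a non-empty `start..end` as `start..=end - 1` and reuse the same prefix-sum subtraction. The helper names are hypothetical.

```rust
use std::ops::Range;

/// Sum of 1..=n, or None on intermediate overflow (hypothetical helper).
fn prefix_sum(n: u64) -> Option<u64> {
    n.checked_add(1)?.checked_mul(n).map(|p| p / 2)
}

/// sum(start..end): a non-empty exclusive range is the inclusive range
/// start..=end-1, so the same subtraction applies.
/// (Hypothetical helper, not part of std.)
fn range_sum(r: Range<u64>) -> Option<u64> {
    if r.is_empty() {
        return Some(0);
    }
    // Non-empty implies end > start, so end - 1 cannot underflow.
    let upper = prefix_sum(r.end - 1)?;
    let lower = if r.start == 0 { 0 } else { prefix_sum(r.start - 1)? };
    Some(upper - lower)
}
```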

I think this will yield a significant performance increase in some (perhaps uncommon) cases, with no drawbacks, but I'm curious to hear others' thoughts.

This seems reasonable. It'll need specialization internally, but that's nothing new. One thing to keep in mind is that it must consume the iterator.


The intermediate values will overflow in cases where the naïve solution doesn't.

Perhaps we could do a checked multiplication in that case and fall back to the naïve solution on overflow.
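A minimal sketch of the checked-multiplication idea, assuming `u64` bounds passed as two integers. `sum_inclusive` is a hypothetical name. The closed form is used when every intermediate fits; otherwise it falls back to the ordinary iterator sum, so the overflow behavior in that case is whatever `Iterator::sum` already does.

```rust
/// Closed form when intermediates fit in u64, otherwise the naive loop.
/// `sum_inclusive` is a hypothetical name used for illustration.
fn sum_inclusive(start: u64, end: u64) -> u64 {
    if start > end {
        return 0;
    }
    // sum(1..=n) via checked arithmetic; None means an intermediate overflowed.
    let prefix = |n: u64| n.checked_add(1).and_then(|m| m.checked_mul(n)).map(|p| p / 2);
    let lower = if start == 0 { Some(0) } else { prefix(start - 1) };
    match (prefix(end), lower) {
        (Some(hi), Some(lo)) => hi - lo,
        // Fall back to iterating, which keeps the usual overflow behavior
        // of `Iterator::sum` (panic with debug assertions, wrap otherwise).
        _ => (start..=end).sum(),
    }
}
```

For example, `sum_inclusive(u64::MAX, u64::MAX)` takes the fallback path (because `u64::MAX + 1` overflows) and returns `u64::MAX`.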

Even jumping up to the next larger integer size would probably be faster than iterating through the range, even if that happens to be 128-bit.
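To illustrate widening, here is a sketch for `u32` that does the arithmetic in `u64`; the same pattern would apply to `u64` with `u128` intermediates. The function name is hypothetical. Because `n * (n + 1) < 2^64` for every `n <= u32::MAX`, no intermediate can overflow, and only the final narrowing can fail.

```rust
/// Widening sketch: sum a u32 range using u64 intermediates.
/// `sum_u32_inclusive` is a hypothetical name used for illustration.
fn sum_u32_inclusive(start: u32, end: u32) -> Option<u32> {
    if start > end {
        return Some(0);
    }
    // For n <= u32::MAX, n * (n + 1) < 2^64, so this cannot overflow.
    let prefix = |n: u64| n * (n + 1) / 2;
    let (s, e) = (u64::from(start), u64::from(end));
    let lower = if s == 0 { 0 } else { prefix(s - 1) };
    let total = prefix(e) - lower;
    // Narrow back; None means the true sum does not fit in u32,
    // which is exactly when the naive u32 sum would overflow.
    if total <= u64::from(u32::MAX) {
        Some(total as u32)
    } else {
        None
    }
}
```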


Agreed -- overflowing intermediate values is an easily mitigated edge-case, but I'm glad someone mentioned it because it's worth being aware of.

Definitely this shouldn't include bombs like that.

It should be simple enough to just make sure the intermediate values cannot overflow, e.g. `n * (n + 1) / 2` becomes

```rust
if n % 2 == 0 {
    (n / 2) * (n + 1)
} else {
    ((n + 1) / 2) * n
}
```

This also works, and I'm happy to implement it this way. Casting to a larger type may be faster, since it avoids the branch in the `if` statement, but the branching version is a good solution for something like `u128`, which has no larger type to cast to.
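Here is the branching form written out for `u128` as a sketch (`sum_to_u128` is a hypothetical name). Dividing the even factor by 2 before multiplying means the only product computed is the final result, so for `n = 2^64` it returns `2^127 + 2^63` even though `n * (n + 1)` would be `2^128 + 2^64` and overflow. The `n + 1` in the odd branch still overflows for `n = u128::MAX`, where the result could not fit anyway.

```rust
/// Branching Gauss sum for u128, which has no wider type to cast to.
/// `sum_to_u128` is a hypothetical name used for illustration.
fn sum_to_u128(n: u128) -> u128 {
    if n % 2 == 0 {
        // n is even: halve n, then multiply by n + 1.
        (n / 2) * (n + 1)
    } else {
        // n is odd, so n + 1 is even: halve n + 1, then multiply by n.
        ((n + 1) / 2) * n
    }
}
```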

Consider computing the sum of the range directly. One easy way to derive the formula is that it is the number of terms times the average of terms, and with a consecutive range like this, those are just `last-first+1` and `(last+first)/2`.
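A sketch of the count-times-average form, with `u64` inputs and a `u128` result (an assumption made so nothing overflows; `sum_range_direct` is a hypothetical name). Computing `(last + first) / 2` first would truncate whenever `first + last` is odd (for `5..=10` it gives `6 * 7 = 42` instead of 45), so the sketch multiplies first and divides once; `count * (first + last)` is always even, so that division is exact.

```rust
/// sum(first..=last) = count * (first + last) / 2, with count = last - first + 1.
/// `sum_range_direct` is a hypothetical name used for illustration.
fn sum_range_direct(first: u64, last: u64) -> u128 {
    if first > last {
        return 0; // empty range
    }
    let (f, l) = (u128::from(first), u128::from(last));
    let count = l - f + 1;
    // count * (f + l) <= last * (last + 1) < 2^128, so this cannot overflow.
    // If count is odd, last - first is even, so f + l is even: the product is always even.
    count * (f + l) / 2
}
```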


Branchless:

```rust
#[inline(never)]
pub fn sum_to(n: u64) -> u64 {
    let odd: u64 = (n % 2 == 1).into();
    ((n + (!odd)) / 2) * (n + odd)
}
```

This seems like an ideal solution, I'll add it to the RFC.

This doesn't need an RFC, by the way. It's an implementation detail, so going straight for a PR should be acceptable.


Well, LLVM already does this for both, actually: https://rust.godbolt.org/z/1M1hThe1h (note the lack of backwards jumps).

So I don't know that it's worth bothering to do it in the Rust code: it's not "no drawbacks", because it's more code to maintain and more metadata to load in libcore (which slows down compilation for everyone, if perhaps not by that much).

Also, I'm curious what real code you have where this actually happens? Summing something that's just a range seems rare.


And it goes for the version with a branch. Whoops. I'm curious if there's any benefit to guaranteeing this optimization from the Rust side; maybe the only benefit would be to debug builds. That actually does seem marginally beneficial, as it's an algorithmic difference between the two compiled versions, not just an "optimized is faster".

Isn't this just `n % 2`? Seems a bit funny to take the roundabout route through `bool`.


yeah it is

The branch checks whether the range is empty (start >= end). It's possible to make that check branch-less using conditional moves, but branch-less isn't always a performance win.

This assumes start == 0, so it's a different problem.

There are a few mistakes here: `!odd` is a bitwise negation rather than a boolean negation, and you are dividing the odd number rather than the even number. Fixing the mistakes, we get:

```rust
pub fn sum_to(n: u64) -> u64 {
    ((n + n % 2) / 2) * (n | 1)
}
```

However, this still has a problem: for `n = u64::MAX`, the addition `n + n % 2` overflows, so in release mode (where it wraps to 0) it gives the wrong answer.
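To make the release-mode failure concrete, the corrected function can be restated with explicit wrapping operations (a sketch; `sum_to_release` and `wrapped_reference` are hypothetical names). For `n = u64::MAX`, `n.wrapping_add(1)` is 0, so the result is 0, whereas the true sum reduced mod 2^64 (what a wrapping iterative sum would give) is 2^63.

```rust
/// The corrected `sum_to`, with its operations made explicitly wrapping
/// so the release-mode behavior shows up in any build profile.
fn sum_to_release(n: u64) -> u64 {
    (n.wrapping_add(n % 2) / 2).wrapping_mul(n | 1)
}

/// Reference value: sum(0..=n) reduced mod 2^64, computed exactly via u128.
fn wrapped_reference(n: u64) -> u64 {
    let n = u128::from(n);
    // n <= u64::MAX, so n * (n + 1) < 2^128 and does not overflow;
    // the `as u64` cast keeps the low 64 bits, i.e. reduces mod 2^64.
    (n * (n + 1) / 2) as u64
}
```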

I think this works for `(0..=n).sum()`:

```rust
pub fn sum_to(n: u64) -> u64 {
    (n / 2 + n % 2) * (n | 1)
}
```
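One way to gain confidence in this version is a brute-force comparison against the iterator for small `n`, plus a check at `u64::MAX`. This is a sketch: the multiplication is written as `wrapping_mul` so the release-mode result is visible in any build, and `check_small` is a hypothetical helper. At `n = u64::MAX` it yields 2^63, which matches the true sum reduced mod 2^64.

```rust
/// The final `sum_to`, with the multiplication made explicitly wrapping
/// so it matches release-mode behavior in any build.
fn sum_to(n: u64) -> u64 {
    // n / 2 + n % 2 is ceil(n / 2) and cannot overflow;
    // n | 1 is n when n is odd and n + 1 when n is even.
    (n / 2 + n % 2).wrapping_mul(n | 1)
}

/// Brute-force check against the iterative sum for every n in 0..=limit.
fn check_small(limit: u64) -> bool {
    (0..=limit).all(|n| sum_to(n) == (0..=n).sum::<u64>())
}
```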

and is maybe slightly simpler than what LLVM generates for this special case.