Average function for primitives

The average of two numbers can be computed naively as (a + b) / 2, but if a and b are both i32::MAX, the sum overflows. We therefore need to avoid the overflow carefully, for example with min(a, b) + abs(a / 2 - b / 2).
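To make the overflow concrete, here is a minimal sketch, assuming i32 operands. The function name naive_average is hypothetical, and checked_add is used only so that the overflow shows up as None instead of a panic or a wrapped value:

```rust
/// Naive average: returns None when the intermediate sum a + b overflows.
/// (Hypothetical helper, for illustration only.)
fn naive_average(a: i32, b: i32) -> Option<i32> {
    // checked_add reports overflow of the intermediate sum as None.
    a.checked_add(b).map(|sum| sum / 2)
}
```

With this sketch, naive_average(2, 4) gives Some(3), while naive_average(i32::MAX, i32::MAX) gives None even though the true average, i32::MAX, fits in an i32.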

So I think it would be useful if an implementation were provided as a method on each primitive integer type.

I searched this forum and the RFCs for this topic, but I couldn't find anything. Has this been proposed before? What do you think of the idea?

The above implementation is wrong: in some cases I need to account for a % 2 and b % 2. As for the implementation, this thread may be helpful.

If a is i32::MIN, then this also underflows.

Yeah, it's tricky to define this to have consistent rounding behavior and avoid overflow in all cases. IIRC, I've sometimes implemented it like this:

(a >> 1).wrapping_add(b >> 1).wrapping_add(a & b & 0x1)

This rounds towards negative infinity; the third term is needed to correct for the case where both of the first two terms were rounded down separately.
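As a rough sketch of the expression above, assuming i32 operands (the name average_floor is hypothetical here, not an existing std method):

```rust
/// Average of two i32 values, rounded towards negative infinity.
/// (Hypothetical helper wrapping the expression from the post above.)
fn average_floor(a: i32, b: i32) -> i32 {
    // Halve each operand first; >> on a signed integer is an arithmetic
    // shift, so each half is rounded down. Then add back the 1 that is
    // lost when both operands are odd.
    (a >> 1).wrapping_add(b >> 1).wrapping_add(a & b & 0x1)
}
```

For example, average_floor(1, 2) is 1, average_floor(-1, 0) is -1, and average_floor(i32::MIN, i32::MAX) is -1, with no intermediate overflow at either extreme.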

I think the difficulty of implementing the rounding correctly could be a motivation for providing this function in the standard library.

Given that Rust concretely defines what >> does on signed integers (unlike C, where it is implementation-defined, IIRC), I agree that this would be good to have.

Prior art: C++20 std::midpoint from P0811r3.

I think std::lerp is over-specified to the point where it's not useful for its intended purpose (games/gfx), since handling the edge cases is too expensive, but I do think that std::midpoint's specification describes the problem well and is generally what you want from a midpoint/mean/average.

There might be a way to do it branchless, but P0811r3 provides the example implementation of (roughly translated into pseudo Rust by me)

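const fn midpoint<Int>(a: Int, b: Int) -> Int {
    // UInt is the unsigned counterpart of Int: the difference is halved
    // as an unsigned value so that it cannot overflow.
    if a > b {
        a.wrapping_sub((a.wrapping_sub(b) as UInt / 2) as Int)
    } else {
        a.wrapping_add((b.wrapping_sub(a) as UInt / 2) as Int)
    }
}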

const fn midpoint<Float>(a: Float, b: Float) -> Float {
    if a.is_normal() && b.is_normal() {
        // Halving first avoids overflow to infinity for large normal values.
        a / 2.0 + b / 2.0
    } else {
        (a + b) / 2.0
    }
}

Do note that a+(b-a)/2 works no matter the sign, so long as there isn't overflow. The trivial (a+b)/2 truncates towards 0; I'm not exactly sure what the rounding behavior of this integer implementation is, but P0811r3 claims it's what you'd typically want from e.g. a binary search midpoint.
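For a concrete version of the integer case, here is a minimal sketch specialized to i32, assuming the P0811r3 approach of halving the difference as an unsigned value. The name midpoint_i32 is hypothetical. In this sketch the result is a plus or minus the floor of half the distance, so an inexact midpoint ends up on the side of a:

```rust
/// Midpoint of two i32 values without intermediate overflow.
/// (Hypothetical helper; a sketch, not a proposed std signature.)
fn midpoint_i32(a: i32, b: i32) -> i32 {
    if a > b {
        // The wrapped difference, reinterpreted as u32, is the true
        // distance a - b, which always fits in u32.
        let half = (a.wrapping_sub(b) as u32) / 2;
        // half is at most i32::MAX, so the cast back is lossless.
        a.wrapping_sub(half as i32)
    } else {
        let half = (b.wrapping_sub(a) as u32) / 2;
        a.wrapping_add(half as i32)
    }
}
```

For example, midpoint_i32(1, 4) is 2 while midpoint_i32(4, 1) is 3, and midpoint_i32(i32::MIN, i32::MAX) is -1 while midpoint_i32(i32::MAX, i32::MIN) is 0.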

The num crate, which is practically standard, provides this function with the Average trait. One way to compute it efficiently is with (a&b) + ((a^b) >> 1) (using arithmetic shift if the values are signed). On Intel processors you can use add and then rcr.
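The bitwise identity can be sketched as follows, assuming u32 and i32 operands. The function names are hypothetical and are not the num crate's API. It relies on a + b == 2 * (a & b) + (a ^ b), so the single addition below yields the floor of the true average and stays in range:

```rust
/// Floor average of two u32 values via (a & b) + ((a ^ b) >> 1).
/// (Hypothetical helper, for illustration only.)
fn average_u32(a: u32, b: u32) -> u32 {
    // a & b holds the bits set in both operands (counted once);
    // a ^ b holds the bits set in only one, which contribute half.
    (a & b) + ((a ^ b) >> 1)
}

/// Same identity for i32; >> on a signed integer in Rust is an
/// arithmetic shift, so the result rounds towards negative infinity.
fn average_i32(a: i32, b: i32) -> i32 {
    (a & b) + ((a ^ b) >> 1)
}
```

For example, average_u32(u32::MAX, 0) is 2147483647 and average_i32(-3, -4) is -4.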

That's really elegant. Is there a nice way to get LLVM to emit that? I didn't have any luck using the things I know how to get rustc to emit...

I don't know how to get LLVM to emit rcr, but the portable alternative (that does and, xor, shift, add) is probably better: it is correct for signed integers as well (replacing logical shift with arithmetic shift), it can be vectorized, and has the same latency as rcr on at least some microarchitectures. That being said, I've never written something where the performance was contingent on integer averages.

I mention the LLVM incantation because needing to emit LLVM intrinsics to get optimal codegen is one of the reasons it might make sense to have a new method in core.

If the num crate can implement it reasonably without needing anything extra, then it might fall in the same category as div_rem: helpers in num that I do like, but that might not be worth having in core.

Honestly I wish there were more methods from the num crate in core, not fewer. There are some methods I am recreating on my end because I don't want to bring in a dependency that would increase the size of my crate by a third. Having the methods in core would avoid this altogether.

The rcr alternative isn't necessarily optimal because rcr executes in two cycles (when its immediate operand is 1). The and and xor can execute simultaneously on separate ports, so I think that alternative ends up having the same latency. In general I find it challenging to get LLVM to recognize opportunities to use the carry flag.