Average function for primitives

The average of two numbers can be calculated naively as (a + b) / 2, but if both a and b are near i32::MAX, the addition overflows. Therefore we should avoid overflow carefully, e.g. with min(a, b) + abs(a / 2 - b / 2).
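For example, a quick check (with values chosen by me) shows how the naive formula fails at the extremes:

```rust
fn main() {
    let (a, b) = (i32::MAX, i32::MAX);
    // a + b does not fit in an i32, so the naive (a + b) / 2 overflows:
    assert_eq!(a.checked_add(b), None);
    // with wrapping semantics the naive formula even yields a negative result:
    assert_eq!(a.wrapping_add(b) / 2, -1);
    // computing in a wider type gives the intended answer:
    assert_eq!(((a as i64 + b as i64) / 2) as i32, i32::MAX);
}
```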

So I think this function would be useful if it were provided through an impl on each primitive integer type.

I searched this forum and the RFCs for this topic, but I couldn't find anything. Has this been proposed before? What do you think of this idea?


The above implementation is wrong: I need to account for a % 2 and b % 2 in some cases. Regarding the implementation, this thread may be helpful.

If a is i32::MIN, then this also underflows.

Yeah, it's tricky to define this to have consistent rounding behavior and avoid overflow in all cases. IIRC, I've sometimes implemented it like this:

(a >> 1).wrapping_add(b >> 1).wrapping_add(a & b & 0x1)

This rounds towards negative infinity; the third term corrects for the case where both of the first two terms round towards negative infinity separately.
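Wrapped into a function (my sketch around the formula above, not the poster's exact code), it can be sanity-checked against the extremes:

```rust
/// Floor midpoint via the shift-and-correct formula; cannot overflow.
fn midpoint_floor(a: i32, b: i32) -> i32 {
    (a >> 1).wrapping_add(b >> 1).wrapping_add(a & b & 0x1)
}

fn main() {
    assert_eq!(midpoint_floor(1, 2), 1);      // floor(1.5)
    assert_eq!(midpoint_floor(-1, -2), -2);   // floor(-1.5): rounds towards -inf
    assert_eq!(midpoint_floor(i32::MAX, i32::MAX), i32::MAX);
    assert_eq!(midpoint_floor(i32::MIN, i32::MAX), -1); // floor(-0.5)
}
```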


I think the difficulty of implementing correct rounding could itself be a motivation to provide this function in the standard library.


Given that Rust concretely defines what >> does on signed integers (unlike C, IIRC, where right-shifting a negative value is implementation-defined), I agree that this would be good to have.

Prior art: C++20 std::midpoint from P0811r3.

I think std::lerp is over-specified to the point where it's not useful for its intended purpose (games/graphics), since handling the edge cases is too expensive, but I do think that std::midpoint's specification describes the problem well and is generally what you want from a midpoint/mean/average.

There might be a way to do it branchless, but P0811r3 provides an example implementation along these lines (roughly translated into pseudo-Rust by me):

const fn midpoint<Int>(a: Int, b: Int) -> Int {
    if a > b {
        // a - (a - b) / 2; P0811r3 computes the difference in the
        // corresponding unsigned type so it cannot overflow
        a.wrapping_sub(a.wrapping_sub(b) / 2)
    } else {
        a.wrapping_add(b.wrapping_sub(a) / 2)
    }
}

const fn midpoint<Float>(a: Float, b: Float) -> Float {
    if a.is_normal() && b.is_normal() {
        a/2 + b/2
    } else {
        (a+b)/2
    }
}

Do note that a + (b - a)/2 works regardless of sign, so long as there isn't overflow. The trivial (a + b)/2 truncates towards 0; I'm not exactly sure what the rounding behavior of this integer implementation is, but P0811r3 claims it's what you'd typically want from e.g. a binary search midpoint.
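For concreteness, a non-generic i32 sketch of the integer version (my translation; the difference is computed through u32, as P0811r3 does with the unsigned type, so it cannot overflow):

```rust
/// Midpoint rounding towards `a` (sketch modelled on C++ std::midpoint).
fn midpoint(a: i32, b: i32) -> i32 {
    if a > b {
        // a - (a - b) / 2, with the difference taken in u32 so that
        // e.g. i32::MAX - i32::MIN does not overflow
        a.wrapping_sub((a.wrapping_sub(b) as u32 / 2) as i32)
    } else {
        a.wrapping_add((b.wrapping_sub(a) as u32 / 2) as i32)
    }
}

fn main() {
    assert_eq!(midpoint(1, 4), 2); // 2.5 rounded towards a = 1
    assert_eq!(midpoint(4, 1), 3); // 2.5 rounded towards a = 4
    assert_eq!(midpoint(i32::MAX, i32::MIN), 0); // -0.5 rounded towards a
}
```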


The num crate, which is practically standard, provides this function via the Average trait. One way to compute it efficiently is (a & b) + ((a ^ b) >> 1) (using an arithmetic shift if the values are signed). On Intel processors you can use add and then rcr.
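The identity behind the trick is a + b == 2*(a & b) + (a ^ b): the AND collects the bits both operands share (the carries), the XOR the bits where they differ. A minimal unsigned sketch (my example values):

```rust
/// Floor average of two u32s without overflow, via the AND/XOR identity.
fn avg_floor(a: u32, b: u32) -> u32 {
    (a & b) + ((a ^ b) >> 1)
}

fn main() {
    assert_eq!(avg_floor(1, 2), 1);
    assert_eq!(avg_floor(u32::MAX, u32::MAX), u32::MAX);
    assert_eq!(avg_floor(0, u32::MAX), u32::MAX / 2);
}
```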


That's really elegant. Is there a nice way to get LLVM to emit that? I didn't have any luck using the things I know how to get rustc to emit...

I don't know how to get LLVM to emit rcr, but the portable alternative (that does and, xor, shift, add) is probably better: it is correct for signed integers as well (replacing logical shift with arithmetic shift), it can be vectorized, and has the same latency as rcr on at least some microarchitectures. That being said, I've never written something where the performance was contingent on integer averages.

I mention the LLVM incantation because needing to emit LLVM intrinsics to get optimal codegen is one of the reasons it might make sense to have a new method in core.

If the num crate can implement it reasonably without needing anything extra, then it might fall in the same category as div_rem: helpers in num that I do like, but that might not be worth having in core.

Honestly I wish there were more methods from the num crate in core, not fewer. There are some methods I am recreating on my end because I don't want to bring in a dependency that would increase the size of my crate by a third. Having the methods in core would avoid this altogether.


The rcr alternative isn't necessarily optimal because rcr executes in two cycles (when its immediate operand is 1). The and and xor can execute simultaneously on separate ports, so I think that alternative ends up having the same latency. In general I find it challenging to get LLVM to recognize opportunities to use the carry flag.

This rounds towards minus infinity, while C++'s std::midpoint rounds towards the first operand a. That is, if b < a, rounding should be towards plus infinity. This can be remedied as:

(a&b) + (((a^b) + (b<a) as i32) >> 1)

Edit: would overflow if a^b is !0 and b<a. Better is:

let x = a ^ b;
let up = (b < a) as i32;
(a & b) + (x >> 1) + (x & up)
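As a quick check of the corrected version (test values mine):

```rust
/// Floor average adjusted to round towards `a` (the corrected formula above).
fn midpoint_toward_a(a: i32, b: i32) -> i32 {
    let x = a ^ b;
    let up = (b < a) as i32;
    // floor average, plus one when the sum is odd and we must round up towards a
    (a & b) + (x >> 1) + (x & up)
}

fn main() {
    assert_eq!(midpoint_toward_a(1, 4), 2); // 2.5 rounded towards 1
    assert_eq!(midpoint_toward_a(4, 1), 3); // 2.5 rounded towards 4
    assert_eq!(midpoint_toward_a(i32::MAX, i32::MIN), 0); // -0.5 towards i32::MAX
}
```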

Is there a compelling reason why std::midpoint(1, 4) != std::midpoint(4, 1)? Is it that signed integer division in C++ (and on plenty of hardware) rounds towards zero?

P0811R3 provides the standard definition for std::midpoint over integral types as "Returns: Half the sum of a and b. If T is an integer type and the sum is odd, the result is rounded towards a. Remarks: No overflow occurs."

The stated reasoning in the rest of P0811R3 is that (paraphrasing heavily) rounding towards a is typically what you'd expect when doing e.g. binary search. (Side note: this also allows the caller to control the rounding.)

P0811R3 also notes that you can change the condition to a>=b or a<=b to round half integers up or down, respectively.

From my understanding, there are five options:

  • Round towards a
  • Round towards b
  • Round up (towards MAX)
  • Round down (towards MIN)
  • Round towards 0

Personally, I think only the first or the fifth options are really in the running to be picked. The first isn't commutative, but it has predictable and consistent behavior for ++, +-, -+, and --: bias towards the first argument. The fifth leans on the fact that integer division truncates toward zero for familiarity, and gains commutativity, but loses translation invariance: midpoint(x + c, y + c) is not always midpoint(x, y) + c.

It's a trade-off. If always rounding toward zero can be done branchless and rounding toward a can't, that'd make me more likely to support rounding toward zero, but I still think rounding toward a is more useful. (Plus, if it inlines, you can just sort a and b on input to get the rounding you want at no cost.)
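To make the options concrete, here's a small comparison using i64 as an overflow-free reference (example values mine):

```rust
fn main() {
    let (a, b) = (2i64, 5i64);
    let s = a + b; // 7; the true midpoint is 3.5
    assert_eq!(s.div_euclid(2), 3);       // round down (towards MIN)
    assert_eq!((s + 1).div_euclid(2), 4); // round up (towards MAX)
    assert_eq!(s / 2, 3);                 // round towards 0 (Rust `/` truncates)
    // round towards a: floor, then bump by one when the sum is odd
    // and a is the larger operand
    let toward_a = s.div_euclid(2) + ((s.rem_euclid(2) != 0 && a > b) as i64);
    assert_eq!(toward_a, 3); // a = 2 is the smaller operand, so we round down

    // negate the inputs to watch the options diverge:
    let s = -s; // midpoint of (-2, -5) is -3.5
    assert_eq!(s.div_euclid(2), -4); // round down
    assert_eq!(s / 2, -3);           // round towards 0
}
```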


I looked at P0811R3, but I didn't see any language justifying rounding towards a. Perhaps I missed something.

There is a sixth option: rounding towards even. It's what I would use for fixed-point arithmetic.

All of these options can be implemented in a branch-free way.
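For instance, a branch-free round-half-to-even variant could build on the AND/XOR trick from earlier in the thread (my sketch: when the sum is odd, the low bit of the floor average decides whether to bump up to the even neighbour):

```rust
/// Average rounding halves to the nearest even value, without overflow (sketch).
fn avg_half_even(a: i32, b: i32) -> i32 {
    let floor = (a & b) + ((a ^ b) >> 1); // floor average, overflow-free
    // (a ^ b) & 1 is set when the sum is odd (midpoint = floor + 0.5);
    // in that case round up exactly when `floor` is odd
    floor + ((a ^ b) & floor & 1)
}

fn main() {
    assert_eq!(avg_half_even(1, 2), 2);    // 1.5 -> 2
    assert_eq!(avg_half_even(2, 3), 2);    // 2.5 -> 2
    assert_eq!(avg_half_even(-1, 0), 0);   // -0.5 -> 0
    assert_eq!(avg_half_even(-1, -2), -2); // -1.5 -> -2
}
```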


I think the paper mentions that for floating-point numbers it would make more sense to call this mean rather than midpoint, but the single name is kept so that one overload set covers all types in C++. I'm now thinking it could make sense for midpoint to round towards a for applications like binary search, while mean for floating-point or fixed-point could round like other operations.

I could've sworn it specifically called out its application to binary search; perhaps that's only for the midpoint of pointers? Anyway, the only actual remnants of the justification for the choice seem to be:

  • "it aims to replace the simple expression a+(b-a)/2 regardless of type",
  • "Returns: Half the sum of a and b. If T is an integer type and the sum is odd, the result is rounded towards a.",
  • (for the pointer version) "As with integers, when the midpoint lies between two pointer values the one closer to a is chosen; for the usual case of a<b, this is compatible with the usual half-open ranges by selecting a when [a,b) is [a,a+1).", and
  • "Returns: A pointer to x[i+(j-i)/2], where the result of the division is truncated towards zero."

So that's where I recall the justification from: always returning something in-bounds for a half-open range.


IMHO, while half-even is a decent option when rounding floating-point values to integer types, it makes less sense for purely integral types, which should stick to consistent truncation semantics.


With division rounding towards 0, (a+b)/2 corresponds to the fifth option, but a + (b-a)/2 corresponds to the first option (round towards a).

However (a+b).div_euclid(2) and a + (b-a).div_euclid(2) both correspond to the fourth option (round down).
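These correspondences can be checked directly (example values mine):

```rust
fn main() {
    let (a, b) = (-3i32, 2i32); // the true midpoint is -0.5
    assert_eq!((a + b) / 2, 0);                // `/` truncates: rounds towards 0
    assert_eq!(a + (b - a) / 2, -1);           // rounds towards a
    assert_eq!((a + b).div_euclid(2), -1);     // rounds down
    assert_eq!(a + (b - a).div_euclid(2), -1); // rounds down
}
```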

BTW, I think making "/" round towards zero was a serious mistake. This has caused many bugs in C; it's virtually never what you want for negative numbers. The argument that it's more efficient in hardware is also only partially true: for example, "/ 2" can be implemented as a right shift by 1 bit only if you round down rather than towards 0. With rounding towards 0, extra instructions have to be emitted.

In this case I think rounding towards a and rounding down are the only reasonable options.