What if match statements could generate perfect hash functions?

I was benchmarking match statements and realized that data structures built at runtime can be faster than hard-coding the patterns in a match.

e.g.

match string {
    // ... lots of string literal arms
}

let mut data = sometDataStructure();
data.insert_lots_of_strings();
data.get(value) // this is faster

It seems like a match statement always gets translated into a jump table. Do you think it would be beneficial if the compiler were more aggressive in its optimization strategy, for example by inserting a hash function for strings?
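To make the comparison concrete, here is a minimal sketch of the two shapes being compared, using hypothetical keys. lookup_match hard-codes the arms; lookup_map builds a HashMap once at runtime (via std::sync::OnceLock, stable since Rust 1.70) and reuses it afterwards. No timings are claimed here: which one wins depends on the number of keys, their lengths, the input distribution and the compiler version, so it is worth measuring both on real data.

```rust
use std::collections::HashMap;
use std::sync::OnceLock;

/// Lookup written as a plain `match` on string literals.
pub fn lookup_match(s: &str) -> i64 {
    match s {
        "alpha" => 1,
        "beta" => 2,
        "gamma" => 3,
        "delta" => 4,
        _ => 0,
    }
}

/// The same lookup backed by a `HashMap` that is built at runtime,
/// once, on first use, and then shared by every later call.
pub fn lookup_map(s: &str) -> i64 {
    static TABLE: OnceLock<HashMap<&'static str, i64>> = OnceLock::new();
    let table = TABLE.get_or_init(|| {
        HashMap::from([("alpha", 1), ("beta", 2), ("gamma", 3), ("delta", 4)])
    });
    // Unknown keys fall back to 0, like the `_` arm above.
    table.get(s).copied().unwrap_or(0)
}
```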

If I were to implement this as a trait, it would look something like this:

trait MatchStmtTrait<T>
where
    T: Eq, // `Eq` already implies `PartialEq`
{
    /// Is a function for narrowing down the possible candidates implemented?
    const IS_MATCH_HINT_IMPLEMENTED: bool = false;
    /// Creates a value that helps narrow down the options.
    fn match_hint(&self) -> T;
}
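As a rough illustration of how such a hint might be used, here is a sketch that implements the (hypothetical, not in std) trait for str with the string length as the hint, then dispatches on the hint before comparing strings. The lookup function and its keys are illustrative only; the idea in the proposal is that the compiler would generate the narrowing step itself.

```rust
/// Hypothetical trait from the proposal (not part of std).
trait MatchStmtTrait<T>
where
    T: Eq,
{
    const IS_MATCH_HINT_IMPLEMENTED: bool = false;
    fn match_hint(&self) -> T;
}

/// Use the length as the hint: strings of different lengths can never be
/// equal, so the length cheaply rules out most candidates.
impl MatchStmtTrait<usize> for str {
    const IS_MATCH_HINT_IMPLEMENTED: bool = true;
    fn match_hint(&self) -> usize {
        self.len()
    }
}

/// Roughly the code a hint-aware `match` could lower to.
fn lookup(s: &str) -> i64 {
    if <str as MatchStmtTrait<usize>>::IS_MATCH_HINT_IMPLEMENTED {
        // Narrow down by hint, then compare only candidates sharing it.
        match s.match_hint() {
            4 => match s {
                "arm1" => 1,
                "arm2" => 2,
                _ => 0,
            },
            5 => match s {
                "arm11" => 11,
                "arm12" => 12,
                _ => 0,
            },
            _ => 0,
        }
    } else {
        // Without a hint, fall back to comparing every arm in order.
        match s {
            "arm1" => 1,
            "arm2" => 2,
            "arm11" => 11,
            "arm12" => 12,
            _ => 0,
        }
    }
}
```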

This is something that I just came up with and I must admit that it's not well thought out.
I'd appreciate any feedback.


Match does get optimized by LLVM, but obviously LLVM can't get it right in every case. For instance, if the early match arms were much more likely than the later ones, I would expect a naïve match to outperform a more complex search. sometDataStructure also has different semantics from match in the case of duplicate keys, which may be harder for LLVM to reason about.
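A small sketch of the duplicate-key difference, assuming a HashMap stands in for sometDataStructure: in a match the first matching arm wins (rustc warns that the second arm is unreachable), while a map keeps whichever value was inserted last.

```rust
use std::collections::HashMap;

/// Arms are tried top to bottom, so the first "key" arm always wins.
#[allow(unreachable_patterns)]
pub fn via_match(s: &str) -> i64 {
    match s {
        "key" => 1,
        "key" => 2, // never taken
        _ => 0,
    }
}

/// A later `insert` for the same key overwrites the earlier value,
/// so here the last entry wins.
pub fn via_map(s: &str) -> i64 {
    let mut map = HashMap::new();
    map.insert("key", 1);
    map.insert("key", 2); // replaces the value 1
    map.get(s).copied().unwrap_or(0)
}
```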

In general, this seems more like a case of: if you understand your data, you can come up with faster code for it.


This is up to LLVM, generally. It's allowed to mostly do what it wants with the control flow. So the default answer for "I'd like it to codegen differently" is to file an LLVM issue describing how it should have picked a different strategy. (After all, LLVM has the target-specific knowledge about things like branch costs, predicated instructions, etc.)

That said, it's possible that there are some cases where rustc isn't communicating enough information to LLVM for it to be able to make good choices. It would perhaps help to demonstrate some of those, in order to focus what would be best on the Rust side to be able to convey intent and thus let LLVM know better what to do.

A classic example: it might be good to have a way to mark some match arms as much more likely than others, in case pre-checking the likely one as a normal branch and then using a jump table for the rest would be a good choice.

But that does tend to quickly run into "well, if you care about that you should probably use PGO instead" to get far more information -- and more accurate information -- than the programmer can reasonably include in source.
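One rough source-level approximation of such a hint is to test the hot key first and move the remaining arms into a function marked #[cold], a stable attribute that tells the compiler a function is rarely called. This sketch assumes, hypothetically, that "GET" dominates the input. It only expresses intent: the optimizer may still restructure it, and PGO remains the more reliable source of this information.

```rust
/// Check the (assumed) hot key with a single comparison first,
/// then fall through to the full `match` for everything else.
pub fn classify(s: &str) -> i64 {
    if s == "GET" {
        return 1;
    }
    rare(s)
}

/// `#[cold]` hints that this path is rarely taken.
#[cold]
fn rare(s: &str) -> i64 {
    match s {
        "POST" => 2,
        "PUT" => 3,
        "DELETE" => 4,
        _ => 0,
    }
}
```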


I don't think it should be left up to LLVM. Generating naive code, then having LLVM guess that it looks like a match and optimize it somehow, is a lot of work, and in practice it doesn't work well (e.g. a match on lots of strings ends up as a series of calls to memcmp). Rustc should be able to generate smarter match code in the first place.


We generally lower MIR SwitchInt to LLVM switch, so in at least some cases we already give it about what it wants.

What kind of "smarter" are you thinking here? Trying to change match lowering -- already one of the most complicated parts of the compiler -- sounds scary. And optimizing it later makes me think it'd be easier to just do in LLVM. (LLVM knows what memcmp is; it could optimize those string compares itself.)


Here is an example.

I believe the code rustc generates for match statements has room for improvement.

Given this code,

pub fn example_function(a: &str) -> i64 {
    match a {
        "arm1" => 1,
        "arm2" => 2,
        "arm3" => 3,
        "arm4" => 4,
        "arm11" => 11,
        "arm12" => 12,
        _ => 0
    }
}

This compiles down to the assembly below, which simply calls PartialEq once per arm.

I think we could make it faster by comparing the length of the input against the lengths of the patterns before calling PartialEq. That would let us skip most of those costly calls.

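Each PartialEq call is preceded by loading the address of an .Lalloc_ symbol. At first glance this looks like an allocation, but those symbols are the constant string literals ("arm1", "arm2", ...) being passed as the second argument, so no memory is allocated; the cost is in the repeated out-of-line PartialEq calls.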

_ZN7example16example_function17h6f467ebc3bc986a4E:
.Lfunc_begin5:
        .loc    6 22 0 is_stmt 1
        .cfi_startproc
        subq    $24, %rsp
        .cfi_def_cfa_offset 32
        movq    %rdi, (%rsp)
        movq    %rsi, 8(%rsp)
.Ltmp17:
        .loc    6 24 9 prologue_end
        leaq    .Lalloc_e7c84e874c0afbfdcca932ed40900bb2(%rip), %rdx
        movl    $4, %ecx
        callq   _ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE
        testb   $1, %al
        jne     .LBB5_2
        .loc    6 0 9 is_stmt 0
        movq    8(%rsp), %rsi
        movq    (%rsp), %rdi
        .loc    6 25 9 is_stmt 1
        leaq    .Lalloc_1368b64af3f9ec0d7a68ddc8c0fcf9f2(%rip), %rdx
        movl    $4, %ecx
        callq   _ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE
        testb   $1, %al
        jne     .LBB5_5
        jmp     .LBB5_4
.LBB5_2:
        .loc    6 24 19
        movq    $1, 16(%rsp)
.LBB5_3:
        .loc    6 32 2
        movq    16(%rsp), %rax
        .loc    6 32 2 epilogue_begin is_stmt 0
        addq    $24, %rsp
        .cfi_def_cfa_offset 8
        retq
.LBB5_4:
        .cfi_def_cfa_offset 32
        .loc    6 0 2
        movq    8(%rsp), %rsi
        movq    (%rsp), %rdi
        .loc    6 26 9 is_stmt 1
        leaq    .Lalloc_268e26cfe670dbf832d80322ec41d9f8(%rip), %rdx
        movl    $4, %ecx
        callq   _ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE
        testb   $1, %al
        jne     .LBB5_7
        jmp     .LBB5_6
.LBB5_5:
        .loc    6 25 19
        movq    $2, 16(%rsp)
        jmp     .LBB5_3

By the way, LLVM IR ends up like this.


; example::example_function
define i64 @example::example_function(ptr align 1 %a.0, i64 %a.1) unnamed_addr {
start:
  %0 = alloca i64, align 8
; call core::str::traits::<impl core::cmp::PartialEq for str>::eq
  %_2 = call zeroext i1 @"_ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE"(ptr align 1 %a.0, i64 %a.1, ptr align 1 @alloc_e7c84e874c0afbfdcca932ed40900bb2, i64 4)
  br i1 %_2, label %bb13, label %bb2

bb2:                                              ; preds = %start
; call core::str::traits::<impl core::cmp::PartialEq for str>::eq
  %_3 = call zeroext i1 @"_ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE"(ptr align 1 %a.0, i64 %a.1, ptr align 1 @alloc_1368b64af3f9ec0d7a68ddc8c0fcf9f2, i64 4)
  br i1 %_3, label %bb14, label %bb4

bb13:                                             ; preds = %start
  store i64 1, ptr %0, align 8
  br label %bb19

bb19:                                             ; preds = %bb12, %bb18, %bb17, %bb16, %bb15, %bb14, %bb13
  %1 = load i64, ptr %0, align 8
  ret i64 %1

bb4:                                              ; preds = %bb2
; call core::str::traits::<impl core::cmp::PartialEq for str>::eq
  %_4 = call zeroext i1 @"_ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE"(ptr align 1 %a.0, i64 %a.1, ptr align 1 @alloc_268e26cfe670dbf832d80322ec41d9f8, i64 4)
  br i1 %_4, label %bb15, label %bb6

bb14:                                             ; preds = %bb2
  store i64 2, ptr %0, align 8
  br label %bb19

bb6:                                              ; preds = %bb4
; call core::str::traits::<impl core::cmp::PartialEq for str>::eq
  %_5 = call zeroext i1 @"_ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE"(ptr align 1 %a.0, i64 %a.1, ptr align 1 @alloc_a5a9b17e752b14683358567ce3b24a8b, i64 4)
  br i1 %_5, label %bb16, label %bb8

bb15:                                             ; preds = %bb4
  store i64 3, ptr %0, align 8
  br label %bb19

bb8:                                              ; preds = %bb6
; call core::str::traits::<impl core::cmp::PartialEq for str>::eq
  %_6 = call zeroext i1 @"_ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE"(ptr align 1 %a.0, i64 %a.1, ptr align 1 @alloc_68707e444069684395fa79bc597d56e1, i64 5)
  br i1 %_6, label %bb17, label %bb10

bb16:                                             ; preds = %bb6
  store i64 4, ptr %0, align 8
  br label %bb19

bb10:                                             ; preds = %bb8
; call core::str::traits::<impl core::cmp::PartialEq for str>::eq
  %_7 = call zeroext i1 @"_ZN4core3str6traits54_$LT$impl$u20$core..cmp..PartialEq$u20$for$u20$str$GT$2eq17h329e7f59881ae1daE"(ptr align 1 %a.0, i64 %a.1, ptr align 1 @alloc_721bce25910cd3959179c6a60d679dc9, i64 5)
  br i1 %_7, label %bb18, label %bb12

bb17:                                             ; preds = %bb8
  store i64 11, ptr %0, align 8
  br label %bb19

bb12:                                             ; preds = %bb10
  store i64 0, ptr %0, align 8
  br label %bb19

bb18:                                             ; preds = %bb10
  store i64 12, ptr %0, align 8
  br label %bb19
}
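For comparison, here is a hand-written sketch of the length-first idea for the same example: switch on the length, then compare only the patterns of that length. example_function_by_len is a hypothetical name, and an optimized build may already produce similar code from the plain match, so it is worth checking the release-mode output before treating this as an improvement.

```rust
/// Length-first version of `example_function`: strings of different
/// lengths can never be equal, so dispatch on the length first and then
/// compare only the candidates that have that length.
pub fn example_function_by_len(a: &str) -> i64 {
    match a.len() {
        4 => match a {
            "arm1" => 1,
            "arm2" => 2,
            "arm3" => 3,
            "arm4" => 4,
            _ => 0,
        },
        5 => match a {
            "arm11" => 11,
            "arm12" => 12,
            _ => 0,
        },
        _ => 0,
    }
}
```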

I think this should be handled on the Rust side.
Do you think it's worth writing an RFC?

Optimizations don't need an RFC unless they guarantee user-observable behavior (which this does not).


That's not quite true. The first thing the generated code does is switch on the length of the string that was passed in:

define noundef i64 @_ZN7example16example_function17hb1208ef3045187ebE(ptr noalias nocapture noundef nonnull readonly align 1 %a.0, i64 noundef %a.1) unnamed_addr #0 {
start:
  switch i64 %a.1, label %bb13 [
    i64 4, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit"
    i64 5, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit20"
  ]

https://rust.godbolt.org/z/sdcrrGocv

I just read the assembly, and yes, you are right: it does check the length first.

example::example_function:
        .cfi_startproc
        cmp     rsi, 5
        je      .LBB0_10
        cmp     rsi, 4
        jne     .LBB0_2
        cmp     dword ptr [rdi], 829256289
        je      .LBB0_4
        cmp     dword ptr [rdi], 846033505
        je      .LBB0_6
        cmp     dword ptr [rdi], 862810721
        je      .LBB0_8
        xor     eax, eax
        cmp     dword ptr [rdi], 879587937
        sete    al
        shl     rax, 2
        ret

Thank you for letting me know. I completely missed it!

Oh, this bit is surprisingly nice: 829256289 = 0x316D7261, which is the bytes of "arm1" read as a little-endian 32-bit integer. So it'd be hard to do much better than that.

(Though of course it's just good luck from the 4-byte pattern.)
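The trick can be reproduced by hand to check the constants: reading the four bytes of "arm1" as a little-endian u32 gives 829256289, and the other three 4-byte arms differ only in the top byte (846033505, 862810721, 879587937). This sketch uses illustrative names (pack4, match_4_byte_arms), handles only the 4-byte arms, and needs nothing beyond core.

```rust
/// Pack a 4-byte key into a little-endian `u32` at compile time.
const fn pack4(bytes: &[u8; 4]) -> u32 {
    u32::from_le_bytes(*bytes)
}

const ARM1: u32 = pack4(b"arm1");
const ARM2: u32 = pack4(b"arm2");
const ARM3: u32 = pack4(b"arm3");
const ARM4: u32 = pack4(b"arm4");

/// One integer comparison per candidate instead of a call to a
/// string-equality function, mirroring `cmp dword ptr [rdi], ...`.
pub fn match_4_byte_arms(a: &str) -> i64 {
    let b = a.as_bytes();
    if b.len() != 4 {
        return 0;
    }
    match u32::from_le_bytes([b[0], b[1], b[2], b[3]]) {
        ARM1 => 1,
        ARM2 => 2,
        ARM3 => 3,
        ARM4 => 4,
        _ => 0,
    }
}
```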


I tried it again with a tuple.

pub fn example_function(a: (Thing, &str)) -> i64 {
    match a {
        (Thing::This, "arm1") => 1,
        (Thing::That, "arm2") => 2,
        (Thing::That, "arm3") => 3,
        (Thing::Else, "arm4") => 4,
        (Thing::Else, "arm11") => 11,
        (Thing::This, "arm11") => 12,
        _ => 0
    }
}
pub enum Thing {
    This,
    That,
    Else
}

This generates:

; example::example_function
; Function Attrs: mustprogress nofree nounwind nonlazybind willreturn memory(read) uwtable
define noundef i64 @_ZN7example16example_function17hb96e950ca2d2a1d2E(ptr noalias nocapture noundef readonly align 8 dereferenceable(24) %a) unnamed_addr #0 {
start:
  %0 = load i8, ptr %a, align 8, !range !3, !noundef !4
  %_8 = zext i8 %0 to i64
  switch i64 %_8, label %bb15 [
    i64 0, label %bb1
    i64 1, label %bb4
    i64 2, label %bb6
  ]


bb1:                                              ; preds = %start
  %1 = getelementptr inbounds { i8, [7 x i8], { ptr, i64 } }, ptr %a, i64 0, i32 2
  %self.0 = load ptr, ptr %1, align 8, !nonnull !4, !align !5, !noundef !4
  %2 = getelementptr inbounds { i8, [7 x i8], { ptr, i64 } }, ptr %a, i64 0, i32 2, i32 1
  %self.1 = load i64, ptr %2, align 8, !noundef !4
  switch i64 %self.1, label %bb14 [
    i64 4, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit"
    i64 5, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit27"
  ]

"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit": ; preds = %bb1
  %bcmp.i = tail call i32 @bcmp(ptr noundef nonnull dereferenceable(4) %self.0, ptr noundef nonnull dereferenceable(4) @alloc_144a1a6b591638e98904068c7edeac19, i64 4), !alias.scope !6
  %3 = icmp eq i32 %bcmp.i, 0
  %spec.select53 = zext i1 %3 to i64
  br label %bb14

bb4:                                              ; preds = %start
  %4 = getelementptr inbounds { i8, [7 x i8], { ptr, i64 } }, ptr %a, i64 0, i32 2
  %self.03 = load ptr, ptr %4, align 8, !nonnull !4, !align !5, !noundef !4
  %5 = getelementptr inbounds { i8, [7 x i8], { ptr, i64 } }, ptr %a, i64 0, i32 2, i32 1
  %self.14 = load i64, ptr %5, align 8, !noundef !4
  %_3.not.i13 = icmp eq i64 %self.14, 4
  br i1 %_3.not.i13, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit17", label %bb14

"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit17": ; preds = %bb4
  %bcmp.i14 = tail call i32 @bcmp(ptr noundef nonnull dereferenceable(4) %self.03, ptr noundef nonnull dereferenceable(4) @alloc_c9ab67d65d2f6de5047d23501518397c, i64 4), !alias.scope !10
  %6 = icmp eq i32 %bcmp.i14, 0
  br i1 %6, label %bb14, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit32"

bb6:                                              ; preds = %start
  %7 = getelementptr inbounds { i8, [7 x i8], { ptr, i64 } }, ptr %a, i64 0, i32 2
  %self.07 = load ptr, ptr %7, align 8, !nonnull !4, !align !5, !noundef !4
  %8 = getelementptr inbounds { i8, [7 x i8], { ptr, i64 } }, ptr %a, i64 0, i32 2, i32 1
  %self.18 = load i64, ptr %8, align 8, !noundef !4
  switch i64 %self.18, label %bb14 [
    i64 4, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit22"
    i64 5, label %"_ZN73_$LT$$u5b$A$u5d$$u20$as$u20$core..slice..cmp..SlicePartialEq$LT$B$GT$$GT$5equal17h99f0a08325237071E.exit37"
  ]

This first switches on the enum discriminant and only then compares the strings.
I think we might be able to optimize this further by telling LLVM that it is free to reorder those checks.
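One way to express the reordering manually is to match on the string first and only then check the enum, here with match guards. This sketch assumes Thing derives Clone and Copy (the original does not) to keep the test simple. It fixes the order of the checks in the source, but whether it is actually faster than the tuple match has not been measured, and the optimizer may still rearrange it.

```rust
#[derive(Clone, Copy)]
pub enum Thing {
    This,
    That,
    Else,
}

/// The original tuple match, for comparison.
pub fn example_function(a: (Thing, &str)) -> i64 {
    match a {
        (Thing::This, "arm1") => 1,
        (Thing::That, "arm2") => 2,
        (Thing::That, "arm3") => 3,
        (Thing::Else, "arm4") => 4,
        (Thing::Else, "arm11") => 11,
        (Thing::This, "arm11") => 12,
        _ => 0,
    }
}

/// String-first version: the string narrows the result to one or two
/// candidates, and the enum is checked only for those.
pub fn example_function_string_first(a: (Thing, &str)) -> i64 {
    let (thing, s) = a;
    match s {
        "arm1" if matches!(thing, Thing::This) => 1,
        "arm2" if matches!(thing, Thing::That) => 2,
        "arm3" if matches!(thing, Thing::That) => 3,
        "arm4" if matches!(thing, Thing::Else) => 4,
        "arm11" => match thing {
            Thing::Else => 11,
            Thing::This => 12,
            Thing::That => 0,
        },
        // A failed guard falls through to here as well.
        _ => 0,
    }
}
```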

As a data point, Zig found at least one case where an if/else-if chain of memcmp calls was faster than a PHF: ziglang/zig issue #3863, "Optimise stringToEnum".


Cool, let me look into it. Thank you for the info.

I wouldn't expect rustc to be smart about complex bindings with mixed moves and if guard clauses, but I think it could have predefined solutions for situations when the arms are all simple constants with strings or integers, and generate perfect hash tables/jump tables, or even use bisection.

I'm not sure what to do about the assumption that the first few match arms may be more likely. Perhaps rustc should leave match statements with few arms to LLVM, and be algorithmically clever only for huge ones. Maybe it could do a direct linear check for the first 1-3 arms and treat everything else as equally likely, and thus suitable for clever algorithms.
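A minimal sketch of that hybrid, using hypothetical color keys: the first arm is checked directly, and the remaining arms live in a table sorted by key and are found by bisection with slice::binary_search_by_key. The table must stay sorted in str's Ord order for the search to be correct.

```rust
/// Remaining arms, sorted by key so they can be bisected.
const SORTED_ARMS: &[(&str, i64)] = &[
    ("blue", 3),
    ("cyan", 7),
    ("green", 2),
    ("magenta", 6),
    ("orange", 5),
    ("purple", 8),
    ("yellow", 4),
];

pub fn lookup_hybrid(s: &str) -> i64 {
    // The first arm gets a direct check, preserving the
    // "earlier arms are probably more likely" assumption.
    if s == "red" {
        return 1;
    }
    // Everything else is treated as equally likely: O(log n) comparisons.
    match SORTED_ARMS.binary_search_by_key(&s, |&(k, _)| k) {
        Ok(i) => SORTED_ARMS[i].1,
        Err(_) => 0,
    }
}
```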

LLVM on its own is not clever. This code is awful:

The thing that's unclear to me is whether that's fundamental to LLVM somehow.

Is there a reason that LLVM shouldn't do better for that LLVM-IR input? My instinct is that it's better for LLVM to recognize matching on bytes like this and do something smarter so that clang and rustc and co don't all need to special-case it somehow and emit something harder for LLVM to recognize.

but I think it could have predefined solutions for situations when the arms are all simple constants with strings or integers, and generate perfect hash tables/jump tables, or even use bisection.

I would note that any switch/if-ladder can be optimized by LLVM (and GCC) into:

  • An if-ladder.
  • An if-tree (bisection).
  • A jump table.

And combinations of the above.
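For a small, dense integer key, the three shapes can be written out by hand as below. This is only an illustration; the backend picks among them (and mixes them) regardless of how the source is written. The table version is a value lookup table, which is what a jump table reduces to when every arm just produces a constant.

```rust
/// If-ladder: one comparison after another; cheap when early cases dominate.
pub fn ladder(x: u32) -> &'static str {
    if x == 0 {
        "zero"
    } else if x == 1 {
        "one"
    } else if x == 2 {
        "two"
    } else if x == 3 {
        "three"
    } else {
        "other"
    }
}

/// If-tree (bisection): halve the key range at each step.
pub fn tree(x: u32) -> &'static str {
    if x < 2 {
        if x == 0 { "zero" } else { "one" }
    } else if x < 4 {
        if x == 2 { "two" } else { "three" }
    } else {
        "other"
    }
}

/// Table: a dense key range indexes directly into an array.
pub fn table(x: u32) -> &'static str {
    const NAMES: [&str; 4] = ["zero", "one", "two", "three"];
    NAMES.get(x as usize).copied().unwrap_or("other")
}
```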

Going straight to PHF in rustc may impact the ability of the backend to apply the above, and thus incur regressions on existing code, which would be fairly sad.

LLVM on its own is not clever. This code is awful:

It doesn't look great indeed, but how does it perform?

First of all, the first comparison separates 4-byte strings from 5-byte strings.

Secondly, LLVM does apply the string-to-integer optimization, packing the bytes into a register and comparing the registers themselves... but it suffers from the fact that this strategy only works for power-of-2 sizes, and your identifiers foo10 to foo99 are 5 bytes long, requiring a 4-byte comparison followed by a 1-byte comparison.
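A sketch of that 4 + 1 byte comparison written out in Rust, with a hypothetical helper eq5: the first four bytes are compared as a single u32, and the fifth byte separately.

```rust
/// Compare a string against a 5-byte key as one 4-byte integer
/// comparison plus one 1-byte comparison.
pub fn eq5(s: &str, key: &[u8; 5]) -> bool {
    let b = s.as_bytes();
    if b.len() != 5 {
        return false;
    }
    let head = u32::from_le_bytes([b[0], b[1], b[2], b[3]]);
    let key_head = u32::from_le_bytes([key[0], key[1], key[2], key[3]]);
    head == key_head && b[4] == key[4]
}
```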

I can see that LLVM picked an if-ladder for the 4-byte section, for example, instead of an if-tree. This does not seem algorithmically ideal, but I don't know how it actually performs. That performance may also depend on the input: it will likely vary quite a lot depending on whether foo1 appears much more often than foo2 (etc.) or whether they're all equally likely.

A PHF's performance profile should be flat regardless of the input, which is better for random data but worse for heavily biased data where the first branch matches most often. Preserving this "edge" for the first (few) branch(es) may be part of the reason LLVM keeps the if-ladder: it assumes the developer knows best.


I don't want to dismiss the PHF idea, but I am afraid that because it won't automatically be a win, it'd be hard to apply systematically.

A more guided approach would avoid the issue. For example, if a PHF were only used when the match statement is annotated with a #[rustc::phf] attribute, it would be easy to benchmark whether a PHF is beneficial and apply it only where it is.

And interestingly, this seems like something a proc-macro could do, to validate the experiment.
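Before writing a proc-macro, it may help to see what its output could look like. This is a hand-written perfect hash for the six keys from the first example (arm1 to arm4, arm11, arm12): the hash (len * 8 + last byte) mod 16 was chosen by hand so that every key lands in its own slot, whereas a real generator would search for such a function automatically. Because a PHF only guarantees distinct slots for the known keys, the lookup still has to compare the stored key with the input.

```rust
/// Hand-picked hash (hypothetical): (len * 8 + last byte) mod 16.
/// Wrapping arithmetic is safe here because 16 divides 2^64 (and 2^32).
fn slot(s: &str) -> Option<usize> {
    let last = *s.as_bytes().last()?; // empty string: no slot
    Some(s.len().wrapping_mul(8).wrapping_add(usize::from(last)) % 16)
}

/// 16-slot table; `None` marks an unused slot.
const TABLE: [Option<(&str, i64)>; 16] = [
    None,                          // 0
    Some(("arm1", 1)),             // 1
    Some(("arm2", 2)),             // 2
    Some(("arm3", 3)),             // 3
    Some(("arm4", 4)),             // 4
    None, None, None, None,        // 5-8
    Some(("arm11", 11)),           // 9
    Some(("arm12", 12)),           // 10
    None, None, None, None, None,  // 11-15
];

pub fn lookup_phf(s: &str) -> i64 {
    match slot(s).and_then(|i| TABLE[i]) {
        // One hash, one table load, one string comparison.
        Some((key, value)) if key == s => value,
        _ => 0,
    }
}
```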


I think I will try to create a PHF implementation as you suggested.

We could do something like matching a common prefix before comparing the rest.
It shouldn't regress for biased data.

e.g.

// we can convert this code:
let result = match string {
    "foo1" => 1,
    "foo2" => 2,
    "foo11" => 11,
    "foo12" => 12,
    _ => 0,
};

// into this, which checks the shared "foo" prefix only once:
let result = if let Some(rest) = string.strip_prefix("foo") {
    match rest {
        "1" => 1,
        "2" => 2,
        "11" => 11,
        "12" => 12,
        _ => 0,
    }
} else {
    0
};

I think I can write a few macros to explore the idea.

Stay tuned...

Note that if one is willing to use an attribute, there are crates like rust-phf (rust-phf/rust-phf on GitHub: compile-time static maps for Rust) that provide that.

