Optimization in the enum/match pattern

First of all, I'm not sure whether this optimization already exists. I tried to find out by combing through the generated assembly and LLVM IR and didn't find it, although my x86_64 assembly is not very good and my LLVM IR knowledge is even worse.

I was doing software architecture for a framework I'm writing and may have discovered an opportunity for performance improvements in Rust.

It seems to be a common pattern in Rust to return an enum instead of a dynamic trait object and then match on the enum afterwards (often at the very next use), destructuring it to get at the contents and do something with them.

Rust Example

Open in playground

#![allow(dead_code)]
// std::hint::black_box is stable since Rust 1.66, so no feature gate is needed.

use std::hint::black_box;

enum Variants {
    Usize(usize),
    U64(u64),
    U32(u32),
    U16(u16),
    U8(u8),
}

enum InVariant {
    Usize,
    U64,
    U32,
    U16,
    U8,
}

#[inline(always)]
fn mapping_in_to_out(input: InVariant) -> Variants {
    match input {
        InVariant::Usize => {
            Variants::Usize(usize::MAX)
        }
        InVariant::U64 => Variants::U64(u64::MAX),
        InVariant::U32 => Variants::U32(u32::MAX),
        InVariant::U16 => Variants::U16(u16::MAX),
        InVariant::U8 => Variants::U8(u8::MAX),
    }
}

fn perform_complete(input: InVariant) {
    let variants = mapping_in_to_out(input);
    unsafe { calculating_out(variants); }
}

#[inline(always)]
unsafe fn calculating_out(input: Variants) {
    match input {
        Variants::Usize(d) => {
            let e = d - 10;
            println!("{}", e);
        }
        Variants::U64(d) => {
            let e = d - 10;
            println!("{}", e);
        }
        Variants::U32(d) => {
            let e = d - 10;
            println!("{}", e);
        }
        Variants::U16(d) => {
            let e = d - 10;
            println!("{}", e);
        }
        Variants::U8(d) => {
            let e = d - 10;
            println!("{}", e);
        }
    }
}

fn main() {
    let in_variant = InVariant::Usize;
    perform_complete(black_box(in_variant));    
}

Possible optimization opportunity

From the code sample one can see that it's not necessary to construct the intermediate Variants enum: the second branch can be avoided altogether by jumping straight into the respective match arm.

That's it.
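Written out by hand in Rust, the result of the proposed optimization might look like the sketch below. The names `perform_two_step` and `perform_fused` are hypothetical and not part of the original sample, and both functions return a `String` instead of printing so that the two paths can be compared. This only illustrates the intended shape; it says nothing about what the compiler actually emits.

```rust
#[derive(Clone, Copy)]
enum InVariant {
    Usize,
    U64,
    U32,
    U16,
    U8,
}

enum Variants {
    Usize(usize),
    U64(u64),
    U32(u32),
    U16(u16),
    U8(u8),
}

// Original shape: build the intermediate enum, then match on it again.
fn perform_two_step(input: InVariant) -> String {
    let variants = match input {
        InVariant::Usize => Variants::Usize(usize::MAX),
        InVariant::U64 => Variants::U64(u64::MAX),
        InVariant::U32 => Variants::U32(u32::MAX),
        InVariant::U16 => Variants::U16(u16::MAX),
        InVariant::U8 => Variants::U8(u8::MAX),
    };
    match variants {
        Variants::Usize(d) => (d - 10).to_string(),
        Variants::U64(d) => (d - 10).to_string(),
        Variants::U32(d) => (d - 10).to_string(),
        Variants::U16(d) => (d - 10).to_string(),
        Variants::U8(d) => (d - 10).to_string(),
    }
}

// Hand-fused shape: a single match, with no intermediate enum and no second branch.
fn perform_fused(input: InVariant) -> String {
    match input {
        InVariant::Usize => (usize::MAX - 10).to_string(),
        InVariant::U64 => (u64::MAX - 10).to_string(),
        InVariant::U32 => (u32::MAX - 10).to_string(),
        InVariant::U16 => (u16::MAX - 10).to_string(),
        InVariant::U8 => (u8::MAX - 10).to_string(),
    }
}
```

For example, `perform_fused(InVariant::U8)` and `perform_two_step(InVariant::U8)` both yield "245".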

Sample / Benchmark C implementation

I've rewritten the Rust sample in C and employed goto to perform the proposed optimization. The whole benchmark is available in the Compiler Explorer.

Here is the relevant code from the C implementation:

void constructVariantASM(enum NumInVariantEnum input) {
    uint8_t uint8_variant_val;
    uint16_t uint16_variant_val;
    uint32_t uint32_variant_val;
    uint64_t uint64_variant_val;
    struct NumVariant ret;
    switch (input) {
    case u8In: {
        uint8_variant_val = UINT8_MAX;
        goto uint8_variant;
        ret.type = u8;
        ret.data.u8 = UINT8_MAX;
        break;
    }
    case u16In: {
        uint16_variant_val = UINT16_MAX;
        goto uint16_variant;
        ret.type = u16;
        ret.data.u16 = UINT16_MAX;
        break;
    }
    case u32In: {
        uint32_variant_val = UINT32_MAX;
        goto uint32_variant;
        ret.type = u32;
        ret.data.u32 = UINT32_MAX;
        break;
    }
    case u64In: {
        uint64_variant_val = UINT64_MAX;
        goto uint64_variant;
        ret.type = u64;
        ret.data.u64 = UINT64_MAX;
    }
    }

    switch (ret.type) {
    case u8: {
        uint8_t structData = ret.data.u8;

    uint8_variant:
        structData = uint8_variant_val;
        structData -= uint8_one;
#ifdef PRINT_OUPUT
        printf("%i\n", structData);
#endif
        break;
    }
    case u16: {
        uint16_t structData = ret.data.u16;
    uint16_variant:
        structData = uint16_variant_val;
        structData -= uint16_one;
#ifdef PRINT_OUPUT
        printf("%i\n", structData);
#endif
        break;
    }
    case u32: {
        uint32_t structData = ret.data.u32;
    uint32_variant:
        structData = uint32_variant_val;
        structData -= uint32_one;
#ifdef PRINT_OUPUT
        printf("%u\n", structData);
#endif
        break;
    }
    case u64: {
        uint64_t structData = ret.data.u64;
    uint64_variant:
        structData = uint64_variant_val;
        structData -= uint64_one;
#ifdef PRINT_OUPUT
        printf("%llu\n", structData);
#endif
        break;
    }
    }
}

Bench Results

I've run the benchmark and got the following results:

Unoptimized took 1.477011 seconds
Optimized   took 1.443288 seconds

Granted, the benefit is really small, but it might be bigger in more complex code.

Summary

I just thought about this opportunity and would be interested in comments by people with more expertise and knowledge in Rust and compiler design. Sadly I don't have the time, resources or knowledge to push any efforts here.

Benchmark machine

OS: macOS 12.3.1 21E258 x86_64
Host: MacBookAir7,2
CPU: Intel i5-5350U (4) @ 1.80 GHz
GPU: Intel HD Graphics 6000
Memory: 8192 MiB

Thanks for your time and interest 🙂

A difference that small probably comes down to the difference between println! and printf. LLVM isn't the best at transposing control flow like this, but it does do it.

Afaik, rustc relies on LLVM for essentially all optimization. So I'd guess there might not be too much we can do besides hoping LLVM already does such optimizations in some cases. (Note that I'm not familiar enough with rustc to make such a claim with any authority; it's just a guess, maybe there are things we can do, IDK.)

If we had Rust-specific optimizations on MIR (IIRC, there are at least plans to do this eventually, no clue what the status is), we could probably discuss adapting, in some form or another, the “case-of-case” optimization that Haskell does, as described e.g. here (starting at page 12; when reading the point about the let bindings to avoid duplication, note that Haskell is lazy). In other words, a transformation that turns

match (match E {
        P1 => E1,
        P2 => E2,
    }) {
    Q1 => F1,
    Q2 => F2,
}

into

match E {
    P1 => match E1 {
        Q1 => F1,
        Q2 => F2,
    }
    P2 => match E2 {
        Q1 => F1,
        Q2 => F2,
    }
}

(Note that these two forms might differ in temporary scopes; I haven't checked that. I assume optimization passes would track those more explicitly anyway, or maybe they are, or would be, desugared already in MIR after all.[1])

with the intention that the match E1/match E2 expressions could optimize better (and with the downside that – at least when done naively – the expressions F1 and F2, and hence the code evaluating them, could get duplicated).

The perform_complete function body in the code example you show is essentially of this form (if the let is “inlined”), and the result would feature expressions like match Variants::Usize(usize::MAX) { Variants::Usize(d) => { … }, /* more non-`Usize` cases */ … } which could be further optimized (since the variant is known).
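The two steps described above (case-of-case, then simplifying a match on a known variant) can be sketched on a smaller example. The types `Src` and `Mid` and the three function names are hypothetical, chosen only for illustration; the stages are written by hand here and are not the output of any existing rustc or MIR pass.

```rust
#[derive(Clone, Copy)]
enum Src {
    A,
    B,
}

enum Mid {
    Wide(u32),
    Narrow(u8),
}

// Stage 0: a match whose scrutinee is itself a match.
#[allow(unused_parens)]
fn nested(e: Src) -> u64 {
    match (match e {
        Src::A => Mid::Wide(u32::MAX),
        Src::B => Mid::Narrow(u8::MAX),
    }) {
        Mid::Wide(d) => u64::from(d) - 10,
        Mid::Narrow(d) => u64::from(d) - 10,
    }
}

// Stage 1: case-of-case pushes the outer match into every inner arm.
// The outer arms are duplicated, which is the code-size downside.
fn case_of_case(e: Src) -> u64 {
    match e {
        Src::A => match Mid::Wide(u32::MAX) {
            Mid::Wide(d) => u64::from(d) - 10,
            Mid::Narrow(d) => u64::from(d) - 10,
        },
        Src::B => match Mid::Narrow(u8::MAX) {
            Mid::Wide(d) => u64::from(d) - 10,
            Mid::Narrow(d) => u64::from(d) - 10,
        },
    }
}

// Stage 2: each inner scrutinee is a known variant, so only the matching
// arm survives and the intermediate `Mid` value disappears.
fn simplified(e: Src) -> u64 {
    match e {
        Src::A => u64::from(u32::MAX) - 10,
        Src::B => u64::from(u8::MAX) - 10,
    }
}
```

All three functions agree for every input; for instance each returns 245 for `Src::B`.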


  1. and I just remembered, I don’t even know what a match looks like in MIR; I suppose those details don’t matter to get the main idea across. ↩︎


LLVM does actually see through this example. Excerpting the optimized LLVM IR with annotations:

define internal fastcc void playground::perform_complete(i8 noundef %input) unnamed_addr #0 {
start:
  ; snip allocas
  switch i8 %input, label %bb2.i [ ; jump table
    i8 0, label %bb3.i2
    i8 1, label %bb7.i
    i8 2, label %bb11.i
    i8 3, label %bb15.i
    i8 4, label %bb1.i3
  ]

; fallthrough
bb2.i:                                            ; preds = %start
  unreachable

; case 1
bb7.i:                                            ; preds = %start
  ; compute e
  store i64 -11, i64* %e1.i, align 8, !noalias !8
  ; snip formatting machinery
  ; call std::io::stdio::_print
  call void std::io::stdio::_print(%"core::fmt::Arguments"* noalias nocapture noundef nonnull dereferenceable(48) %_22.i), !noalias !8
  ; jump to function epilogue
  br label playground::calculating_out.exit

; snip other cases

; function epilogue
playground::calculating_out.exit: ; preds = %bb3.i2, %bb7.i, %bb11.i, %bb15.i, %bb1.i3
  ret void

[playground]
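The `store i64 -11` in the excerpt is the already-folded result of `u64::MAX - 10`: LLVM IR prints integer constants as signed values, so the bit pattern of 18446744073709551605 shows up as -11. A minimal sketch to check this, where the constant `E` and the helper `as_printed_in_ir` are hypothetical names:

```rust
// The value computed in the U64 arm of the sample, folded at compile time.
const E: u64 = u64::MAX - 10;

// Reinterpret the same 64 bits as a signed integer, which is how the
// constant appears in the textual IR.
fn as_printed_in_ir(x: u64) -> i64 {
    x as i64
}
```

Here `as_printed_in_ir(E)` evaluates to -11.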

LLVM will do it if the function is inlined. Nothing we can do if it isn't (how will we justify this optimization in that case?)
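To experiment with the non-inlined case, a sketch like the following could be pasted into Compiler Explorer. `produce` and `consume` are hypothetical names; `#[inline(never)]` asks the compiler to keep `produce` as a separate function, so `consume` only sees a returned enum and has to inspect its tag. No claim is made here about what a particular rustc or LLVM version emits.

```rust
pub enum Variants {
    U64(u64),
    U8(u8),
}

// Hypothetical producer kept out of line: callers only see an opaque call.
#[inline(never)]
pub fn produce(wide: bool) -> Variants {
    if wide {
        Variants::U64(u64::MAX)
    } else {
        Variants::U8(u8::MAX)
    }
}

// The consumer must match on whatever variant comes back from the call.
pub fn consume(wide: bool) -> u64 {
    match produce(wide) {
        Variants::U64(d) => d - 10,
        Variants::U8(d) => u64::from(d) - 10,
    }
}
```

Removing the attribute, or replacing it with `#[inline(always)]` as in the original sample, gives the inlined variant to compare against.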


Here's an example: Compiler Explorer.

The `?` desugaring actually wants this too, and nikic got some LLVM changes in to make it better: https://github.com/rust-lang/rust/issues/85133#issuecomment-1072168354.

I'm not sure if those hit what's discussed in this thread specifically, or what LLVM version those are in and thus whether they're in rustc nightly yet, but hopefully it'll be better soon!

As an aside, I've actually had great luck lately with filing issues on LLVM; now that they're on GitHub it's much easier than it used to be. For example, llvm/llvm-project issue #54824, "`mul nuw`+`lshr exact` should fold to a single multiplication (when the latter is a factor)", got picked up by someone in about a day and had a fix merged within two weeks.
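The arithmetic behind that issue title can be illustrated with concrete numbers. This is only a sketch of the identity, with the constants 12 and 2 chosen arbitrarily and the function names hypothetical; it is not taken from the LLVM issue itself. If `x * 12` does not wrap (the `nuw` condition) and the shift drops no set bits (the `exact` condition), then `(x * 12) >> 2` equals `x * 3`.

```rust
// Two-step form: multiply, then shift right by 2.
// Returns None when the multiplication would wrap, since `mul nuw`
// only describes the non-wrapping case.
fn mul_then_shift(x: u32) -> Option<u32> {
    let product = x.checked_mul(12)?;
    // `lshr exact`: no set bits are shifted out (12 is a multiple of 4).
    debug_assert_eq!(product & 0b11, 0);
    Some(product >> 2)
}

// Folded form: a single multiplication by 12 >> 2 == 3.
fn single_mul(x: u32) -> u32 {
    x.wrapping_mul(3)
}
```

Whenever `mul_then_shift(x)` returns `Some(v)`, `v` equals `single_mul(x)`; for example both give 21 for `x = 7`.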
