Two subsets of types

One of the two or three things I think are missing in Rust (namely, ways to formally prove that some code is panic-free, or even functionally correct) is actively being worked on by multiple people/groups, so I think we'll eventually have (various versions of) it.

But even if we eventually build ways to formally prove some aspects of (Rust) code correct, that doesn't make less powerful ways of ensuring some aspects of your code are correct useless. This is because formal proofs are somewhat painful: they often require a lot of programmer work, a lot of slow proof computation (done by tools like Z3), or both. So the more you can guarantee through the less flexible type system, the less verification work/time you need.

So a (simple) type-system feature that I think could be useful in Rust, and that could avoid some more complex verification work, is a way to handle subsets of types. Below I show two ways of doing this.

As example code I use solutions to two Project Euler problems, because they are complex enough to show both the problem and the solution, yet very self-contained.


This code is from: https://projecteuler.net/problem=277

In this problem there's a coded sequence of steps, represented by a given hardcoded string of 'D', 'U' and 'd'. This could lead to Rust code like this, which uses a byte string:

const SEQ: &[u8] = b"UDDDUdddDDUDDddDdDddDDUDDdUUDd"; // Input.
for &s in SEQ {
    let c = match s { b'D' => 0, b'U' => 1, b'd' => 2, _ => panic!() };
    // ...
}

But here we're talking about panic-free code, so I could use a Rust enum instead. With repr(u8) it has the same (good enough) memory representation as the byte string, one byte per step, and now the match{} doesn't contain any panics (every symbol here uses one byte but needs fewer than two bits; I think there are languages that offer a standard way to represent this Cs in two bits):

#[derive(Copy, Clone)]
#[repr(u8)] enum Cs { D1, D2, U }
use Cs::*;
// Input.
const SEQ: &[Cs] = &[U, D2, D2, D2, U, D1, D1, D1, D2, D2,
                     U, D2, D2, D1, D1, D2, D1, D2, D1, D1,
                     D2, D2, U, D2, D2, D1, U, U, D2, D1];

for &s in SEQ {
    let c = match s { D2 => 0, U => 1, D1 => 2 };
    // ...
}
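For comparison, here is a hypothetical sketch of packing the steps by hand, two bits each, into a `u64` in today's Rust (the `pack` and `unpack` functions are made up for illustration). It shows the cost of doing it manually: decoding reintroduces a catch-all arm, because the compiler can't know that the bit pattern 3 is never stored.

```rust
// Hypothetical sketch: store each Cs step in 2 bits of a u64 (up to 32 steps).
#[derive(Copy, Clone, Debug, PartialEq)]
#[repr(u8)]
enum Cs { D1, D2, U } // discriminants 0, 1, 2

/// Packs up to 32 steps into one u64; step i goes in bits 2*i and 2*i+1.
const fn pack(seq: &[Cs]) -> u64 {
    assert!(seq.len() <= 32, "at most 32 steps fit in a u64");
    let mut packed = 0_u64;
    let mut i = 0;
    while i < seq.len() {
        packed |= (seq[i] as u64) << (2 * i);
        i += 1;
    }
    packed
}

/// Decodes the first `len` steps (len <= 32). The pattern 3 is never written
/// by `pack`, but the compiler can't know that, so a catch-all arm is needed.
fn unpack(packed: u64, len: usize) -> impl Iterator<Item = Cs> {
    (0..len).map(move |i| match (packed >> (2 * i)) & 0b11 {
        0 => Cs::D1,
        1 => Cs::D2,
        _ => Cs::U,
    })
}
```

Since `pack` is a `const fn`, a literal like `SEQ` could be packed at compile time with `const PACKED: u64 = pack(SEQ);`.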

But it's conceivable to have in Rust a way to specify subsets of types. That would allow you to write code more like this (hypothetical syntax, similar to Pascal/Ada):

type Step = b'U' | b'd' | b'D';
const SEQ: &[Step] = b"UDDDUdddDDUDDddDdDddDDUDDdUUDd"; // Input.
for &s in SEQ {
    let c = match s { b'D' => 0, b'U' => 1, b'd' => 2 };
    // ...
}

This enumeration of alternative values allows using compact literals like SEQ, while still avoiding panics in the match{}.
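Until something like this exists, part of the benefit can be approximated in today's Rust. In this sketch (the `Step` variant names and the `parse` and `step_codes` functions are made up) a `const fn` turns the compact byte-string literal into an enum array during compilation. A wrong byte makes the compile-time evaluation panic, which is reported as a compile error, so the runtime match needs neither a catch-all arm nor a panic.

```rust
/// The three step kinds of the sequence (hypothetical names).
#[derive(Copy, Clone, Debug, PartialEq)]
#[repr(u8)]
enum Step { BigD, U, SmallD } // 'D', 'U', 'd'

/// Converts the byte string to steps. An invalid byte panics; when this is
/// evaluated in a static initializer, that panic is a compile-time error.
const fn parse<const N: usize>(s: &[u8; N]) -> [Step; N] {
    let mut out = [Step::BigD; N];
    let mut i = 0;
    while i < N {
        out[i] = match s[i] {
            b'D' => Step::BigD,
            b'U' => Step::U,
            b'd' => Step::SmallD,
            _ => panic!("invalid step byte"),
        };
        i += 1;
    }
    out
}

// Input, still written as a compact literal.
static SEQ: [Step; 30] = parse(b"UDDDUdddDDUDDddDdDddDDUDDdUUDd");

/// Same mapping as the byte-string version: D => 0, U => 1, d => 2.
fn step_codes() -> impl Iterator<Item = u8> {
    SEQ.iter().map(|&s| match s {
        Step::BigD => 0,
        Step::U => 1,
        Step::SmallD => 2,
    })
}
```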


Besides an enumeration with the | alternation symbol, the other basic way to express a subset of a finite type is with an interval. The syntax is simple: a type followed by a slicing with the usual range syntax:

type Month = u8[1 ..= 12];

In Ada you can use something similar to specify intervals of floating-point types too (and other things), but in Rust, to keep the design simple enough, I think single intervals of integral values are enough.

This strong typing of integral intervals is useful to ensure the correctness of input values and of computations, to remove catch-all cases from match{}, and so on. In Ada, for performance, such bounds are tested only at variable assignment (unlike Rust in debug mode, which checks integers for overflow in every arithmetic operation). One important advantage of using such interval types is that they let you spot wrong values very close to where they are created. This fail-fast behavior catches bugs as soon as they are generated, so the programmer can find and fix them sooner.
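A sketch of how `type Month = u8[1 ..= 12];` can be approximated today with a newtype (the `new`, `get` and `days_in` names are illustrative). The bound test happens once, in the only constructor, which gives the fail-fast behavior described above; the field is private to a module, so no other code can build an out-of-range Month.

```rust
mod month {
    /// Emulates `type Month = u8[1 ..= 12];`: the private field is always in 1..=12.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    pub struct Month(u8);

    impl Month {
        /// The only constructor, so the only bound test: invalid values are
        /// rejected where they are created, not where they are used.
        pub const fn new(v: u8) -> Option<Month> {
            if v >= 1 && v <= 12 { Some(Month(v)) } else { None }
        }

        pub const fn get(self) -> u8 {
            self.0
        }
    }
}

use month::Month;

/// Days in a month of a non-leap year. A Month is known to be in 1..=12,
/// so no re-check is needed here.
fn days_in(m: Month) -> u8 {
    match m.get() {
        2 => 28,
        4 | 6 | 9 | 11 => 30,
        _ => 31, // 1, 3, 5, 7, 8, 10, 12
    }
}
```

Unlike the proposed language feature, this emulation can't remove the `_` arm: the compiler doesn't know the invariant of the private field.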

But for a feature to be added to a language it needs to be as useful as possible, which means it should be usable in various parts of the language. Another important use of interval integral types is as indexes of fixed-size arrays. A value that is statically known to be strictly less than the length of an array doesn't need an array bound test (the upper bound of the interval type can be any value smaller than the length):

let mut array = [0_u32; 100];
let i: usize[0 .. 10] = 5;
array[i] = 2; // Panic-free.
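A minimal sketch of how this can be emulated in today's Rust with const generics, assuming a hypothetical `Idx<N>` newtype (not the Idx struct mentioned later in this post, whose code isn't shown). The only bound test is in `new`; indexing a `[T; N]` with an `Idx<N>` then uses `get_unchecked`, which is sound only because the private field guarantees the value is < N. For simplicity this sketch ties the index bound to the exact array length, while the text's more general case (a bound smaller than the length) would need more machinery.

```rust
mod idx {
    use std::ops::{Index, IndexMut};

    /// Hypothetical index type emulating `usize[0 .. N]`: the private field
    /// always holds a value < N, and code outside this module can't change that.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub struct Idx<const N: usize>(usize);

    impl<const N: usize> Idx<N> {
        /// The single bound test, done when the index is created.
        pub const fn new(i: usize) -> Option<Self> {
            if i < N { Some(Idx(i)) } else { None }
        }

        pub const fn get(self) -> usize {
            self.0
        }
    }

    impl<T, const N: usize> Index<Idx<N>> for [T; N] {
        type Output = T;
        fn index(&self, i: Idx<N>) -> &T {
            // SAFETY: an Idx<N> always holds a value < N, and this array has length N.
            unsafe { self.get_unchecked(i.0) }
        }
    }

    impl<T, const N: usize> IndexMut<Idx<N>> for [T; N] {
        fn index_mut(&mut self, i: Idx<N>) -> &mut T {
            // SAFETY: same invariant as in `index`.
            unsafe { self.get_unchecked_mut(i.0) }
        }
    }
}

use idx::Idx;
```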

In some cases array bound tests slow down my code in critical inner loops. In such cases writing panic-free code isn't just about correctness but also about performance: bounded integrals allow writing faster code. At first sight this seems untrue, because to create a bounded value you need to test that it's in bounds anyway, so you still have one bound test, just as when indexing arrays with regular usize values. But an optimizing compiler changes the situation. In many cases, when you create a ranged value, the compiler can infer that it's in bounds and remove the test. In other cases you directly generate the bounded values in a way that already guarantees they are in bounds:

type MyIdx = usize[0 ..= 10];
for i in MyIdx::span() {
    // Here all i are in bound of MyIdx.
}

You can then store such values and pass them to functions, using them to index arrays multiple times in various parts of your code, and in all those cases no bound test is needed. So you move the bound test early: to where the compiler is often able to infer the values are in bounds, or to where the compiler needs to test them only once (and then you can use the index many times later), or even to where the tests reduce to just the extrema, as in this example, where the compiler tests only the first and last values and none of the intermediate indexes needs to be tested:

let a: usize = ...;
let b: usize = ...;
for i in MyIdx::span_from(a ..= b) {
    // Only a and b are tested; every i here is in bounds of MyIdx.
}
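A sketch of the `span()` and `span_from()` ideas emulated in today's Rust, assuming a hypothetical `Idx<N>` newtype whose private field is always < N (so `MyIdx = usize[0 ..= 10]` becomes `Idx<11>`). Neither method tests individual items: `span()` can't produce an out-of-range value, and `span_from()` tests only the upper end of the range, because the lower bound of the type is 0.

```rust
mod idx {
    use std::ops::RangeInclusive;

    /// Hypothetical ranged index: the private field always holds a value < N.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub struct Idx<const N: usize>(usize);

    impl<const N: usize> Idx<N> {
        /// One bound test, done where the value is created.
        pub const fn new(i: usize) -> Option<Self> {
            if i < N { Some(Idx(i)) } else { None }
        }

        pub const fn get(self) -> usize {
            self.0
        }

        /// All values of the type in order: `0..N` is in bounds by
        /// construction, so no per-item test is needed.
        pub fn span() -> impl Iterator<Item = Self> {
            (0..N).map(|i| Idx(i))
        }

        /// Tests only the extremum: if b < N, every value in a ..= b is < N
        /// too (the lower bound is 0, so `a` needs no test). An empty range
        /// (a > b) is always accepted and yields nothing.
        pub fn span_from(r: RangeInclusive<usize>) -> Option<impl Iterator<Item = Self>> {
            let (a, b) = r.into_inner();
            if a > b || b < N {
                Some((a..=b).map(|i| Idx(i)))
            } else {
                None
            }
        }
    }
}

use idx::Idx;

// Emulates `type MyIdx = usize[0 ..= 10];`.
type MyIdx = Idx<11>;
```

Unlike the loop in the text, this `span_from` returns `None` instead of panicking when the range leaves the type; panicking or returning an `Option` is a design choice.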

To test some of these ideas I've created a little ranged-index struct (that also allows indexing with types other than usize), and using it I've measured small but real performance improvements in some of my code. The code becomes a little more complex/long to write, but the parts of the code fit together better, like higher-precision cogs :-) It also caught a couple of bugs early in indexing-heavy code.

The following example code is from another Project Euler problem: https://projecteuler.net/problem=280

fn round(x: f64, n_digits: u32) -> f64 {
    let precision = 10_u64.pow(n_digits) as f64;
    (x * precision).round() / precision
}

fn e280() -> f64 {
    const N: usize = 5 * 5; // Total grid size. Input.
    const PRECISION: u32 = 6; // Input.
    const M: usize = 31;
    const K: usize = 20;

    #[derive(Default)]
    struct E280 {
        x: [[[f64; N]; N]; M], // 19_375 items.
        b: [[u32; N]; M], // 775 items.
    }

    impl E280 {
        fn copy(&mut self, r: usize, s: usize) {
            for i in 0 .. N {
                for j in 0 .. N {
                    self.x[s][i][j] = self.x[r][i][j];
                }
                self.b[s][i] = self.b[r][i];
            }
        }


        fn inverse(&mut self, s: usize, i: usize) {
            self.b[s][i] = 1;
            let z = -1.0 / self.x[s][i][i];
            self.x[s][i][i] = z;

            for j in 0 .. N {
                if j == i {
                    continue;
                }
                let w = self.x[s][i][j] * z;
                if w == 0.0 {
                    continue;
                }
                for k in 0 .. N {
                    if k != i {
                        self.x[s][k][j] += w * self.x[s][k][i];
                    }
                }
                self.x[s][i][j] = w;
            }

            for k in 0 .. N {
                if k != i {
                    self.x[s][k][i] *= z;
                }
            }
        }


        fn initialize(&mut self) {
            const F: [u8; N] = [2, 3, 3, 3, 2,
                                 3, 4, 4, 4, 3,
                                 3, 4, 4, 4, 3,
                                 3, 4, 4, 4, 3,
                                 2, 3, 3, 3, 2];

            for (i, &fi) in F.iter().enumerate() {
                let z = f64::from(fi).recip();
                if i >= 5    { self.x[0][i][i - 5] = z; }
                if i + 5 < N { self.x[0][i][i + 5] = z; }
                //if i % 5 > 0 { self.x[0][i][i - 1] = z; } // Bound test with LLVM13.
                // Redundant "i>0" added to avoid a bound test.
                if i > 0 && i % 5 > 0 { self.x[0][i][i - 1] = z; }
                //if i % 5 < 4 { self.x[0][i][i + 1] = z; }  // Bound test with LLVM13.
                // Redundant "i<N-1" added to avoid a bound test.
                if i < N - 1 && i % 5 < 4 { self.x[0][i][i + 1] = z; }
                self.x[0][i][i] = -1.0;
                self.b[0][i] = 0;
            }

            for i in 0 .. K {
                self.inverse(0, i);
            }

            for k in 1 .. M {
                let ktz = k.trailing_zeros(); // in 0 ..= 4.
                self.copy(k - (1 << ktz), k);
                self.inverse(k, (ktz as usize) + K);
            }
        }


        fn tran(&mut self, pz: &mut f64, s: usize, d: usize, r: usize, y: f64) {
            for k in 0 .. N {
                if self.b[d][k] != 0 {
                    *pz += y * self.x[d][r][k];
                }
            }

            if s >= M { return; }

            for k in 0 .. 5 {
                if self.b[d][k + K] == 0 {
                    self.tran(pz, d + (1 << k), s, k, y * self.x[d][r][k + K]);
                }
            }
        }
    }


    let mut e = E280::default();
    e.initialize();

    let mut z = 0.0;
    for i in 0 .. K {
        z += e.x[0][12][i];
    }
    for i in 0 .. 5 {
        e.tran(&mut z, 1 << i, 0, i, e.x[0][12][K + i]);
    }
    round(z, PRECISION)
}


fn main() {
    assert_eq!(e280(), 430.088247);
}

Using my Idx struct the code becomes:

fn e280() -> f64 {
    const N: usize = 5 * 5; // Total grid size. Input.
    const PRECISION: u32 = 6; // Input.
    const M: usize = 31;
    const K: usize = 20;
    type IdM = Idx<usize, M>;
    type IdN = Idx<usize, N>;

    #[derive(Default)]
    struct E280 {
        x: [[[f64; N]; N]; M],
        b: [[u32; N]; M],
    }

    impl E280 {
        fn copy(&mut self, r: IdM, s: IdM) {
            for i in 0 .. N {
                for j in 0 .. N {
                    self.x[s][i][j] = self.x[r][i][j];
                }
                self.b[s][i] = self.b[r][i];
            }
        }


        fn inverse(&mut self, s: IdM, i: IdN) {
            self.b[s][i] = 1;
            let z = -1.0 / self.x[s][i][i];
            self.x[s][i][i] = z;

            for j in 0 .. N {
                if j == i.get() {
                    continue;
                }
                let w = self.x[s][i][j] * z;
                if w == 0.0 {
                    continue;
                }
                for k in 0 .. N {
                    if k != i.get() {
                        self.x[s][k][j] += w * self.x[s][k][i];
                    }
                }
                self.x[s][i][j] = w;
            }

            for k in 0 .. N {
                if k != i.get() {
                    self.x[s][k][i] *= z;
                }
            }
        }


        fn initialize(&mut self) {
            const F: [u8; N] = [2, 3, 3, 3, 2,
                                 3, 4, 4, 4, 3,
                                 3, 4, 4, 4, 3,
                                 3, 4, 4, 4, 3,
                                 2, 3, 3, 3, 2];

            for (i, fi) in F.into_iter().enumerate() {
                let z = f64::from(fi).recip();
                if i >= 5    { self.x[0][i][i - 5] = z; }
                if i + 5 < N { self.x[0][i][i + 5] = z; }
                //if i % 5 > 0 { self.x[0][i][i - 1] = z; } // Bound test with LLVM13.
                // Redundant "i>0" added to avoid a bound test.
                if i > 0 && i % 5 > 0 { self.x[0][i][i - 1] = z; }
                //if i % 5 < 4 { self.x[0][i][i + 1] = z; }  // Bound test with LLVM13.
                // Redundant "i<N-1" added to avoid a bound test.
                if i < N - 1 && i % 5 < 4 { self.x[0][i][i + 1] = z; }
                self.x[0][i][i] = -1.0;
                self.b[0][i] = 0;
            }

            for i in IdN::FIRST .. IdN::new(K).unwrap() {
                self.inverse(IdM::FIRST, i);
            }

            for k in IdM::new(1).unwrap() ..= IdM::LAST {
                let ktz = k.get().trailing_zeros(); // in [0 ..= 4].
                self.copy(IdM::new(k.get() - (1 << ktz)).unwrap(), k);
                self.inverse(k, IdN::new(to!{ktz, usize} + K).unwrap());
            }
        }


        fn tran(&mut self, pz: &mut f64, s: usize, d: usize, r: IdN, y: f64) {
            for k in 0 .. N {
                if self.b[d][k] != 0 {
                    *pz += y * self.x[d][r][k];
                }
            }

            if s >= M { return; }

            for k in IdN::FIRST .. IdN::new(5).unwrap() {
                if self.b[d][k.get() + K] == 0 {
                    self.tran(pz, d + (1 << k.get()), s, k, y * self.x[d][r][k.get() + K]);
                }
            }
        }
    }


    let mut e = E280::default();
    e.initialize();

    let mut z = 0.0;
    for i in 0 .. K {
        z += e.x[0][12][i];
    }
    for i in IdN::FIRST .. IdN::new(5).unwrap() {
        e.tran(&mut z, 1 << i.get(), 0, i, e.x[0][12][K + i.get()]);
    }
    round(z, PRECISION)
}

In this second version many (most) array bound tests are absent.

A well-rounded language feature should also allow defining interval types from an array:

type MyMat1 = [u32; M];
type IdM = MyMat1::Index;
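`MyMat1::Index` isn't valid Rust today (inherent associated types are unstable), but a trait with an associated type gets close. A sketch, assuming a hypothetical `HasIndex` trait and a hypothetical `Idx<N>` newtype whose private field is always < N:

```rust
mod idx {
    /// Hypothetical ranged index: the private field always holds a value < N.
    #[derive(Copy, Clone, Debug, PartialEq, Eq)]
    pub struct Idx<const N: usize>(usize);

    impl<const N: usize> Idx<N> {
        pub const fn new(i: usize) -> Option<Self> {
            if i < N { Some(Idx(i)) } else { None }
        }

        pub const fn get(self) -> usize {
            self.0
        }
    }
}

use idx::Idx;

/// Hypothetical trait letting an array type name its own index type.
trait HasIndex {
    type Index;
}

impl<T, const N: usize> HasIndex for [T; N] {
    type Index = Idx<N>;
}

const M: usize = 31;
type MyMat1 = [u32; M];
// Stands in for `MyMat1::Index`: changing M updates the index type too.
type IdM = <MyMat1 as HasIndex>::Index;
```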

Note that I'm not suggesting adding that Idx struct to the Rust stdlib; it's just a way to emulate one of the uses of interval integral types.


...hey don't tease - how much? %)

Totally not saying that speed gain is the sole motivation, but still?..

That Idx can only remove some (or, in the best cases, most) array bound tests. It doesn't remove bound tests on slices, Vecs, VecDeques, etc. So you get a speedup only if your code uses arrays a lot, the optimizer is able to remove most bound tests where the indexes are created, and your code has kernels doing a lot of computation over those arrays. I have also seen speedups in code where indexes are stored inside another array (or another data structure) and used later. With so many variables I don't think giving percentages is meaningful; I wasn't trying to tease. The gains were single-digit percentages, but in some cases it let me avoid unsafe array indexing (or unsafe assume()s) and still have Rust code as fast as the equivalent C code. I use arrays a lot, so this happened several times.