Using dynamically sized types for RISC-V V vector registers

RISC-V (and similar architectures) support dynamically sized vector types. The V extension defines 32 vector registers that all have the same width within a given design, but that width (VLEN) varies between microarchitectures. Designs with different vector lengths share the same instructions and operand encodings: for example, we add two vectors with vadd.vv v0, v1, v2 regardless of how many bytes v0 holds. This is the major difference from x86-64 and similar architectures, where the width of a SIMD value is fixed by the instruction set extension; AVX2 registers, for instance, are always 256 bits. Hence, fixed-size types are not enough to describe RISC-V vectors.
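To see why the element count cannot be a compile-time constant, recall that the V specification derives the maximum number of elements per register group as VLMAX = LMUL × VLEN / SEW. The helper below is a plain-Rust sketch of that formula; `vlmax` and its parameters are names made up for this example, not an existing API.

```rust
/// Maximum number of elements in one vector register group (VLMAX),
/// following VLMAX = LMUL * VLEN / SEW from the RISC-V V specification.
/// LMUL is given as the fraction `lmul_num / lmul_den`, so 1/8 (mf8) is (1, 8).
/// Hypothetical helper for illustration only.
pub fn vlmax(vlen_bits: usize, sew_bits: usize, lmul_num: usize, lmul_den: usize) -> usize {
    assert!(sew_bits > 0 && lmul_den > 0, "SEW and LMUL denominator must be non-zero");
    vlen_bits * lmul_num / (sew_bits * lmul_den)
}
```

With LMUL = 1, a 128-bit implementation holds 16 `u8` lanes per register and a 256-bit one holds 32, yet both execute the same `vadd.vv`.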

The main idea of this post is to use dynamically sized types (DSTs) for these dynamic vectors. To begin with, we represent a vector register value as a slice of an arithmetic type, wrapped in a struct:

```rust
#[repr(simd)]
#[allow(non_camel_case_types)]
struct u8xN([u8]);
```

or, generically:

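```rust
#[repr(simd)] struct UnsizedSimd<T>([T]);
type vint8mf8_t = UnsizedSimd<i8>;  // signed 8-bit elements, LMUL = 1/8
type vuint8mf8_t = UnsizedSimd<u8>; // unsigned 8-bit elements, LMUL = 1/8
```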
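The proposal relies on `#[repr(simd)]` accepting an unsized struct, which Rust does not support today. As a hedged approximation that compiles on stable, the sketch below builds the same kind of custom DST with `#[repr(transparent)]` (the pattern std uses for types such as `Path` over `OsStr`) and converts a borrowed slice into it. `UnsizedSimd::from_slice` and `lanes` are hypothetical helpers; this models only the shape of the type, not register allocation.

```rust
/// Stable-Rust stand-in for the proposed `#[repr(simd)] struct UnsizedSimd<T>([T]);`.
/// `repr(transparent)` guarantees the same layout as `[T]`, so the wrapper
/// is a custom dynamically sized type whose length is known only at runtime.
#[repr(transparent)]
pub struct UnsizedSimd<T>([T]);

impl<T> UnsizedSimd<T> {
    /// Reinterprets a borrowed slice as the wrapper DST (hypothetical helper).
    pub fn from_slice(s: &[T]) -> &UnsizedSimd<T> {
        // SAFETY: `UnsizedSimd<T>` is `repr(transparent)` over `[T]`, so the
        // fat pointer (data pointer + length) can be cast directly.
        unsafe { &*(s as *const [T] as *const UnsizedSimd<T>) }
    }

    /// Number of lanes described by this value (hypothetical helper).
    pub fn lanes(&self) -> usize {
        self.0.len()
    }
}

#[allow(non_camel_case_types)]
pub type vuint8mf8_t = UnsizedSimd<u8>;
```

Because the wrapper is unsized, it can only be used behind a pointer such as `&UnsizedSimd<u8>`, and `std::mem::size_of_val` reports its runtime size; passing it by value is what requires the unsized-parameter features mentioned later in the post.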

The u8xN type represents a vector register full of u8 values. Similar wrappers around [u16], [f32], [i8] and so on describe registers configured for other element types. RISC-V switches the vector context (element width, register grouping and vector length) to match these types with the vsetvl/vsetvli instructions. We wrap this instruction in a function and express the resulting context with a Rust lifetime:

```rust
// in core::arch::riscv{32, 64}
pub fn vsetvl<'v, V: Vector + ?Sized>(len: usize) -> Vcsr<'v, V> { ... }
// in application
let vl = vsetvl::<vint8mf8_t>(len); // vl: Vcsr<'v, vint8mf8_t>
```
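To illustrate what `vsetvl` and a lifetime-bound context could do, the sketch below simulates `vsetvli` in plain Rust with LMUL = 1 (fractional LMUL such as mf8 is not modeled). `VectorUnit`, `Vcsr` and `vl()` are hypothetical stand-ins for the proposed API. The granted length is vl = min(AVL, VLMAX), one of the results the specification permits. Because `vsetvl` borrows the unit mutably for as long as the returned context lives, the borrow checker rejects a second `vsetvl` while an earlier context is still in use.

```rust
use std::marker::PhantomData;

/// Hypothetical model of a hart's vector unit; `vlen_bits` is VLEN.
pub struct VectorUnit {
    pub vlen_bits: usize,
}

/// Hypothetical vector-context token for element type `T`. It holds a
/// mutable borrow of the `VectorUnit` for `'v`, so only one context
/// can be live at a time.
pub struct Vcsr<'v, T> {
    vl: usize,
    _unit: PhantomData<&'v mut VectorUnit>,
    _elem: PhantomData<T>,
}

impl<'v, T> Vcsr<'v, T> {
    /// Number of elements the following vector operations will process.
    pub fn vl(&self) -> usize {
        self.vl
    }
}

impl VectorUnit {
    /// Models `vsetvli` with LMUL = 1: requests `avl` elements of type `T`
    /// and grants vl = min(avl, VLMAX), where VLMAX = VLEN / SEW.
    pub fn vsetvl<T>(&mut self, avl: usize) -> Vcsr<'_, T> {
        let sew_bits = 8 * std::mem::size_of::<T>();
        assert!(sew_bits > 0, "zero-sized types cannot be vector elements");
        let vlmax = self.vlen_bits / sew_bits;
        Vcsr { vl: avl.min(vlmax), _unit: PhantomData, _elem: PhantomData }
    }
}
```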

To add two [u8] vectors we use vadd.vv, but we can't add them with the + operator, because the operator has no access to the vector context (the active vl). Instead, we introduce a dedicated function:

```rust
// in core::arch::riscv{32, 64}
pub fn vaddvv<'v, V>(a: V, b: V, vl: Vcsr<'v, V>) -> V
where V: Vector + ?Sized;
// in application
let c = vaddvv(a, b, vl);
```

Here a and b are passed by value as unsized function parameters, and c is an unsized return value. Unsized function parameters are available on nightly (the unsized_fn_params feature), but unsized return values are not implemented, even on nightly.
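The reason vl must be passed explicitly shows up in a plain-Rust model of `vadd.vv`: only the first vl elements are active, and integer addition wraps modulo 2^SEW. `vaddvv_model` is a hypothetical helper, not the proposed intrinsic, and it returns a `Vec<u8>` precisely because an unsized `[u8]` cannot be returned by value.

```rust
/// Plain-Rust model of `vadd.vv vd, vs2, vs1` under a given vl for 8-bit
/// elements: only the first `vl` elements are active, and addition wraps
/// modulo 2^8. Returns a heap-allocated `Vec<u8>` as a stand-in for the
/// unsized return value the proposal needs. Hypothetical helper.
pub fn vaddvv_model(vs2: &[u8], vs1: &[u8], vl: usize) -> Vec<u8> {
    assert!(vl <= vs2.len() && vl <= vs1.len(), "vl exceeds operand length");
    vs2[..vl]
        .iter()
        .zip(&vs1[..vl])
        .map(|(a, b)| a.wrapping_add(*b))
        .collect()
}
```

The same two operands produce different results for different vl values, which is why a plain `+` operator without the context would be ambiguous.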

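Finally, we need to load values into vector registers and store them back. We wrap the load instruction in vlv and the store instruction in vsv, which the full example below also uses:

```rust
pub unsafe fn vlv<'v, V>(base: *const V::Element, vl: Vcsr<'v, V>) -> V
where V: Vector + ?Sized;
pub unsafe fn vsv<'v, V>(base: *mut V::Element, value: V, vl: Vcsr<'v, V>)
where V: Vector + ?Sized;
```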
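The load/store pair can likewise be modeled on ordinary memory. In RVV 1.0 the unit-stride instructions for 8-bit elements are `vle8.v` and `vse8.v`; the sketch below copies `vl` elements between a raw pointer and a fixed-size array that stands in for one register. `vlv_model` and `vsv_model` are hypothetical names, and the array capacity `N` plays the role of VLMAX.

```rust
// Plain-Rust models of a unit-stride vector load and store (what the
// proposed `vlv`/`vsv` wrap). A fixed-size array stands in for a register;
// elements past `vl` are left as default (load) or untouched (store).

/// Loads `vl` elements starting at `base` into a register image.
///
/// # Safety
/// `base` must be valid for reads of `vl` elements of `T`.
pub unsafe fn vlv_model<T: Copy + Default, const N: usize>(base: *const T, vl: usize) -> [T; N] {
    assert!(vl <= N, "vl exceeds register capacity");
    let mut reg = [T::default(); N];
    // SAFETY: the caller guarantees `base` is valid for `vl` reads,
    // and `reg` has room for `N >= vl` elements.
    unsafe { std::ptr::copy_nonoverlapping(base, reg.as_mut_ptr(), vl) };
    reg
}

/// Stores the first `vl` elements of `reg` to memory starting at `base`.
///
/// # Safety
/// `base` must be valid for writes of `vl` elements of `T`.
pub unsafe fn vsv_model<T: Copy, const N: usize>(base: *mut T, reg: &[T; N], vl: usize) {
    assert!(vl <= N, "vl exceeds register capacity");
    // SAFETY: the caller guarantees `base` is valid for `vl` writes.
    unsafe { std::ptr::copy_nonoverlapping(reg.as_ptr(), base, vl) };
}
```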

This approach needs only a small number of intrinsic functions to describe the whole RISC-V V subsystem. Compared with the C intrinsics, whose header of function declarations alone takes about 1 MB of disk space, it saves a lot of code and spares us from constantly consulting the manual. Backends such as LLVM already support scalable vectors through the vscale parameter: for example, [i8] at LMUL = 1/8 (vint8mf8_t) can be expressed in LLVM IR as <vscale x 1 x i8>, and there are intrinsics for arithmetic and logical operations on these types.
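As a worked example of the vscale mapping, assume the convention used by LLVM's RISC-V backend that one vscale unit corresponds to 64 bits of VLEN, so vscale = VLEN / 64. Then `<vscale x 1 x i8>` holds VLEN / 64 bytes, which agrees with VLMAX for 8-bit elements at LMUL = 1/8 (VLEN × 1/8 / 8). The helpers `vscale` and `scalable_lanes` are hypothetical.

```rust
/// Assumption: LLVM's RISC-V backend maps one vscale unit to 64 bits of
/// VLEN, so vscale = VLEN / 64. Hypothetical helper for illustration.
pub fn vscale(vlen_bits: usize) -> usize {
    vlen_bits / 64
}

/// Number of lanes in the LLVM scalable vector type
/// `<vscale x min_lanes x T>` on an implementation with the given VLEN.
pub fn scalable_lanes(vlen_bits: usize, min_lanes: usize) -> usize {
    vscale(vlen_bits) * min_lanes
}
```

Under this assumption, `<vscale x 1 x i8>` has 2 lanes on a 128-bit implementation and 8 on a 512-bit one, while `<vscale x 8 x i8>` (LMUL = 1) has VLEN / 8 lanes.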

Putting it all together, here is a complete program that uses DST-based SIMD in Rust to add two arrays element-wise:

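```rust
// Assumes the proposed API (vsetvl, vlv, vaddvv, vsv, vuint8mf8_t) is in scope,
// e.g. via a hypothetical `use core::arch::riscv64::*;`.
fn add<const N: usize>(a: &[u8; N], b: &[u8; N], c: &mut [u8; N]) {
    let mut n = N; // elements still to process
    let mut a_ptr = a.as_ptr();
    let mut b_ptr = b.as_ptr();
    let mut c_ptr = c.as_mut_ptr();
    while n > 0 {
        // Request up to n elements; the hardware grants vl <= VLMAX.
        let vl = vsetvl::<vuint8mf8_t>(n);
        // For u8 elements the byte count equals the element count.
        n -= vl.as_bytes();
        unsafe {
            let vs1 = vlv(a_ptr, vl); // load vl elements of a
            let vs2 = vlv(b_ptr, vl); // load vl elements of b
            let vd = vaddvv(vs1, vs2, vl); // element-wise wrapping add
            vsv(c_ptr, vd, vl); // store vl elements into c
            // Advance all three pointers past the processed elements.
            a_ptr = a_ptr.add(vl.as_bytes());
            b_ptr = b_ptr.add(vl.as_bytes());
            c_ptr = c_ptr.add(vl.as_bytes());
        }
    }
}

pub fn main() {
    let src1 = [1, 2, 3, 4];
    let src2 = [5, 6, 7, 8];
    let mut dst = [0; 4];
    add(&src1, &src2, &mut dst);
    assert_eq!(dst, [6, 8, 10, 12]);
}
```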
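For readers without RVV hardware or the proposed intrinsics, the same strip-mining loop can be emulated in plain Rust. This is a hedged sketch: `add_strip_mined` is a hypothetical function, VLMAX is passed in rather than read from hardware, and vl = min(remaining, VLMAX) stands in for `vsetvl`. It returns the number of loop iterations so the effect of different vector lengths is visible.

```rust
/// Portable emulation of the strip-mined `add` loop with an explicit VLMAX.
/// Each iteration asks for the remaining element count, receives
/// vl = min(remaining, vlmax) (the simplest choice `vsetvli` may make),
/// and processes exactly vl lanes. Returns the number of iterations.
/// Hypothetical helper for illustration.
pub fn add_strip_mined<const N: usize>(
    a: &[u8; N],
    b: &[u8; N],
    c: &mut [u8; N],
    vlmax: usize,
) -> usize {
    assert!(vlmax > 0, "VLMAX must be non-zero");
    let mut done = 0; // elements processed so far
    let mut iterations = 0;
    while done < N {
        let vl = (N - done).min(vlmax); // models vsetvl
        // Models load, vadd.vv (wrapping) and store on the active lanes.
        for i in done..done + vl {
            c[i] = a[i].wrapping_add(b[i]);
        }
        done += vl;
        iterations += 1;
    }
    iterations
}
```

With VLMAX = 16 the four elements fit in one pass; with VLMAX = 3 the loop runs twice (vl = 3, then vl = 1) and produces the same result, which is the vector-length-agnostic behaviour the DST-based types are meant to capture.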

I have put this example in the GitHub repository luojia65/rust-rvv-intrinsics (a test repository).

As I write this post, it feels like this could be one of the first practical uses of Rust's DSTs, but also like we are still far from it. Is this approach practical? Thanks, all Rustaceans! :slight_smile:


There is an RFC on scalable vectors: "RFC: Add a scalable representation to allow support for scalable vectors" by JamieCunliffe, rust-lang/rfcs pull request #3268 on GitHub.