Tensors static typing

In the 1960s a programming language named APL was invented, designed first of all to process tensors (n-dimensional arrays). Its syntax was arcane and baroque. It was dynamically typed, so operations and tensor dimensions were verified at run time (and often adapted in flexible ways). Later, other people invented derived languages like K that share similar semantics but have a less esoteric syntax (no language-specific glyphs). I think a few people tried to create a static type system for APL-like languages, but I don't think those experiments led to long-lasting developments. Much later, people created Python libraries to process tensors, and that usage has become quite common in several scientific and technological fields (today there's NumPy).

Later, people created libraries, often in C++, to build and train neural networks and to run them on CPUs and GPUs. Using such libraries, including from Python, is now common. Meanwhile, dynamic languages like Python, Ruby, JavaScript, PHP, Racket, and others have started to develop ways to add static typing (gradual typing) and to verify such type annotations, both soundly and unsoundly. On these topics I remember many fun articles, like https://journal.stuffwithstuff.com/2010/10/29/bootstrapping-a-type-system/ or https://blog.acolyer.org/2016/02/05/is-sound-gradual-typing-dead/ .

So lately people have been trying to design a static verifier (again a form of gradual typing) for the shapes and sizes of the tensors used by the neural network libraries available from Python, inside Python type checkers like Pyre. Some of them say they do not want to use dependent typing (I am not sure why), but what they are building is close to it.

At first sight it may seem curious that advancements in the static (gradual) typing of tensors don't come from the world of statically typed languages. But I think dynamic languages (or flexible, mostly statically typed languages like Swift) are used a LOT for machine learning, so advancements are expected mostly where the real work is done. Another factor is that dynamic languages are flexible and can do almost anything, so it's simpler to put a static type system on top of them to restrict that flexibility, catch some bugs, and make the code easier to write and understand. So I think such type systems are simpler to design for languages like Python (though Haskell programmers may disagree), and later some of these ideas may be added to static languages like Swift and Rust (or partially static languages like Julia). I hope some of these ideas will eventually become usable in Rust.

Another curious thing is that these advancements seem to come from the world of practical language implementations and not from academia, where many people are experts in advanced type systems.

Some links about such recent explorations, with slides, videos and more:

See also:

Some extracts:

# How to type this?
image: Tensor[64, 128, 3]
a = tf.reduce_mean(image, axis=0) # Shape: 128, 3
a = tf.reduce_mean(image, axis=1) # Shape: 64, 3
a = tf.reduce_mean(image, axis=2) # Shape: 64, 128

# Ideally:
def reduce_mean(x: Tensor[AxesSet], axis: int) -> Tensor[AxesSet - {axis}]: ...

(But I think this is not exact: AxesSet needs to be an ordered sequence, not a set. And this looks close to dependent typing.)

# For now:
@overload
def reduce_mean(x: Tensor[A0, A1, A2], axis: Literal[0]) -> Tensor[A1, A2]: ...
@overload
def reduce_mean(x: Tensor[A0, A1, A2], axis: Literal[1]) -> Tensor[A0, A2]: ...
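As a point of comparison, the "for now" approach (one signature per literal axis) can already be written in today's stable Rust with const generics, as long as the rank is fixed. This is only an illustrative sketch: Tensor3, Tensor2, from_fn and the reduce_mean_axisN methods are hypothetical names, not an existing Rust tensor library, and the data is assumed to be stored as a plain row-major Vec<f32>.

```rust
/// Hypothetical rank-3 tensor: the three dimensions live in the type.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor3<const A: usize, const B: usize, const C: usize> {
    pub data: Vec<f32>, // row-major, length A * B * C
}

/// Hypothetical rank-2 tensor.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor2<const A: usize, const B: usize> {
    pub data: Vec<f32>, // row-major, length A * B
}

impl<const A: usize, const B: usize, const C: usize> Tensor3<A, B, C> {
    /// Builds the tensor by calling f(a, b, c) for every index.
    pub fn from_fn(f: impl Fn(usize, usize, usize) -> f32) -> Self {
        let mut data = Vec::with_capacity(A * B * C);
        for a in 0..A {
            for b in 0..B {
                for c in 0..C {
                    data.push(f(a, b, c));
                }
            }
        }
        Tensor3 { data }
    }

    fn at(&self, a: usize, b: usize, c: usize) -> f32 {
        self.data[(a * B + b) * C + c]
    }

    /// reduce_mean(x, axis=0): Tensor[A0, A1, A2] -> Tensor[A1, A2]
    pub fn reduce_mean_axis0(&self) -> Tensor2<B, C> {
        let data = (0..B * C)
            .map(|i| (0..A).map(|a| self.at(a, i / C, i % C)).sum::<f32>() / A as f32)
            .collect();
        Tensor2 { data }
    }

    /// reduce_mean(x, axis=1): Tensor[A0, A1, A2] -> Tensor[A0, A2]
    pub fn reduce_mean_axis1(&self) -> Tensor2<A, C> {
        let data = (0..A * C)
            .map(|i| (0..B).map(|b| self.at(i / C, b, i % C)).sum::<f32>() / B as f32)
            .collect();
        Tensor2 { data }
    }

    /// reduce_mean(x, axis=2): Tensor[A0, A1, A2] -> Tensor[A0, A1]
    pub fn reduce_mean_axis2(&self) -> Tensor2<A, B> {
        let data = (0..A * B)
            .map(|i| (0..C).map(|c| self.at(i / B, i % B, c)).sum::<f32>() / C as f32)
            .collect();
        Tensor2 { data }
    }
}
```

The shapes are checked by the compiler (assigning the result of reduce_mean_axis0 on a Tensor3<2, 3, 4> to a Tensor2<2, 4> is a type error), but each axis still needs its own method, just like the Literal[0]/Literal[1] overloads above.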

# Variadics:
# Without variadics, one signature per rank:
def add(x: Tensor1D[A], y: float) -> Tensor1D[A]: ...
def add(x: Tensor2D[A, B], y: float) -> Tensor2D[A, B]: ...
def add(x: Tensor3D[A, B, C], y: float) -> Tensor3D[A, B, C]: ...
# With a variadic shape, a single signature:
Shape = ListVariadic("Shape", bound=int)
class Tensor(Generic[Shape]): ...
def add(x: Tensor[Shape], y: float) -> Tensor[Shape]: ...
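Rust has no variadic generics, but a similar effect (one shape-preserving signature that works for every rank) can be approximated by putting the whole shape into a single type parameter. A minimal sketch, assuming hypothetical marker types D1/D2/D3 (one per rank, written by hand) and a hypothetical Shape trait; this is not the API of an existing crate.

```rust
use std::marker::PhantomData;

/// Hypothetical trait for type-level shapes; NUM_ELEMENTS is the product of the dims.
pub trait Shape {
    const NUM_ELEMENTS: usize;
}

/// Hypothetical shape markers, one per rank (there are no variadic const generics).
pub struct D1<const A: usize>;
pub struct D2<const A: usize, const B: usize>;
pub struct D3<const A: usize, const B: usize, const C: usize>;

impl<const A: usize> Shape for D1<A> {
    const NUM_ELEMENTS: usize = A;
}
impl<const A: usize, const B: usize> Shape for D2<A, B> {
    const NUM_ELEMENTS: usize = A * B;
}
impl<const A: usize, const B: usize, const C: usize> Shape for D3<A, B, C> {
    const NUM_ELEMENTS: usize = A * B * C;
}

/// A tensor whose shape exists only in the type parameter S.
pub struct Tensor<S: Shape> {
    pub data: Vec<f32>,
    _shape: PhantomData<S>,
}

impl<S: Shape> Tensor<S> {
    pub fn zeros() -> Self {
        Tensor { data: vec![0.0; S::NUM_ELEMENTS], _shape: PhantomData }
    }
}

/// One generic function instead of one signature per rank:
/// the output has exactly the same shape type S as the input.
pub fn add<S: Shape>(x: Tensor<S>, y: f32) -> Tensor<S> {
    Tensor { data: x.data.into_iter().map(|v| v + y).collect(), _shape: PhantomData }
}
```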

# Type Arithmetic:
def concat(x: Tensor[A], y: Tensor[B]) -> Tensor[A + B]: ...
def __getitem__(t: Tensor[N], start: A, end: B) -> Tensor[B - A]: ...
def range(start: A, end: B, step: S) -> Tensor[(B - A + S - 1) // S]: ...  # ceil((B - A) / S), for S > 0
def duplicate(t: Tensor[N]) -> Tensor[N * 2]: ...
def flatten(t: Tensor[Shape]) -> Tensor[Prod[Shape]]: ...
def __len__(t: Tuple[Ts]) -> Length[Ts]: ...
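Stable Rust cannot yet write a return type like [T; A + B] for generic A and B (that needs the unstable generic_const_exprs feature). A common workaround, sketched below under that assumption, is to let the caller name the output length C and have the compiler check A + B == C when the function is instantiated. AssertSum and concat are hypothetical names.

```rust
/// Hypothetical compile-time check: evaluating OK fails the build
/// (as a post-monomorphization error) unless A + B == C.
struct AssertSum<const A: usize, const B: usize, const C: usize>;

impl<const A: usize, const B: usize, const C: usize> AssertSum<A, B, C> {
    const OK: () = assert!(A + B == C, "concat: output length must equal A + B");
}

/// Sketch of concat(x: Tensor[A], y: Tensor[B]) -> Tensor[A + B] on plain arrays.
pub fn concat<T: Copy + Default, const A: usize, const B: usize, const C: usize>(
    x: [T; A],
    y: [T; B],
) -> [T; C] {
    // Forces the check for this particular (A, B, C) instantiation.
    let () = AssertSum::<A, B, C>::OK;
    let mut out = [T::default(); C];
    out[..A].copy_from_slice(&x);
    out[A..].copy_from_slice(&y);
    out
}
```

A call such as concat::<i32, 2, 3, 6>(...) is then rejected when the program is built, which is weaker and less readable than real type arithmetic but catches the same mistake.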

# Convolution (padding P, dilation D, kernel size K, stride S):
def __call__(self, t: Tensor[N, C_in, H_in, W_in]) -> Tensor[
  N,
  C_out,
  1 + ((H_in + 2 * P - D * (K - 1) - 1) // S),
  1 + ((W_in + 2 * P - D * (K - 1) - 1) // S),
]: ...
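The spatial part of the convolution shape formula can be checked with a small const fn. This sketch assumes the same parameters (padding P, dilation D, kernel size K, stride S) and the floor division used in PyTorch's Conv2d documentation; conv_out_len and OutputRow are hypothetical names. With concrete arguments the result can be used as an array length on stable Rust, while using it with generic parameters in a signature would again need generic_const_exprs.

```rust
/// Output length of one spatial dimension of a convolution:
/// (input + 2*padding - dilation*(kernel - 1) - 1) / stride + 1, with floor division.
pub const fn conv_out_len(
    input: usize,
    padding: usize,
    dilation: usize,
    kernel: usize,
    stride: usize,
) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

// With concrete numbers the result is usable as an array length on stable Rust:
// a 3x3 kernel with padding 1 and stride 1 keeps a 32-wide row at 32.
const H_OUT: usize = conv_out_len(32, 1, 1, 3, 1);
pub type OutputRow = [f32; H_OUT];
```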

# Reduce operators on variadics:
def __len__(x: Tensor[Ts]) -> Length[Ts]: ...
def flatten(x: Tensor[Ts]) -> Tensor[Prod[Ts]]: ...
def view(x: Tensor[Ts1], a: L[-1], *ts: Ts2) -> Tensor[Prod[Ts1] // Prod[Ts2], Ts2]: ...
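The view signature above computes the size of the -1 dimension as Prod[Ts1] // Prod[Ts2]. A run-time sketch of the same computation follows; infer_view_dim is a hypothetical helper. It returns None when the product of the requested dimensions doesn't divide the total size, which is the case a shape checker would have to reject (PyTorch's view, for example, refuses such a shape at run time).

```rust
/// Hypothetical run-time counterpart of view(x, -1, *ts): given the input
/// dims (Ts1) and the explicitly requested dims (Ts2), return the size that
/// replaces -1, i.e. Prod[Ts1] / Prod[Ts2], if the division is exact.
pub fn infer_view_dim(input_dims: &[usize], requested: &[usize]) -> Option<usize> {
    let total: usize = input_dims.iter().product();
    let known: usize = requested.iter().product();
    if known == 0 || total % known != 0 {
        return None; // no integer size for the -1 dimension fits
    }
    Some(total / known)
}
```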

# Equality:
# Syntactically, Tensor[A+B] and Tensor[B+A] are not the same type, but
# any expression with addition and multiplication can be normalized to an ordered list of monomials.
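The normalization idea can be sketched in a few lines: represent each size expression as an ordered map from monomials to integer coefficients, where each monomial is itself an ordered map from variable to exponent. Because both levels are ordered, two expressions built from + and * are equal exactly when their normal forms are structurally equal. The Poly type is a hypothetical illustration, not how Pyre actually implements it.

```rust
use std::collections::BTreeMap;

/// A monomial: variable name -> exponent, kept sorted by BTreeMap.
type Monomial = BTreeMap<String, u32>;

/// A polynomial in normal form: ordered map from monomial to nonzero coefficient.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Poly(BTreeMap<Monomial, i64>);

impl Poly {
    pub fn var(name: &str) -> Poly {
        let mut m = Monomial::new();
        m.insert(name.to_string(), 1);
        Poly(BTreeMap::from([(m, 1)]))
    }

    pub fn constant(c: i64) -> Poly {
        let mut terms = BTreeMap::new();
        if c != 0 {
            terms.insert(Monomial::new(), c);
        }
        Poly(terms)
    }

    pub fn add(&self, other: &Poly) -> Poly {
        let mut terms = self.0.clone();
        for (m, c) in &other.0 {
            *terms.entry(m.clone()).or_insert(0) += c;
        }
        terms.retain(|_, c| *c != 0); // drop cancelled monomials
        Poly(terms)
    }

    pub fn mul(&self, other: &Poly) -> Poly {
        let mut terms: BTreeMap<Monomial, i64> = BTreeMap::new();
        for (m1, c1) in &self.0 {
            for (m2, c2) in &other.0 {
                // Multiply monomials by adding exponents variable by variable.
                let mut m = m1.clone();
                for (v, e) in m2 {
                    *m.entry(v.clone()).or_insert(0) += e;
                }
                *terms.entry(m).or_insert(0) += c1 * c2;
            }
        }
        terms.retain(|_, c| *c != 0);
        Poly(terms)
    }
}
```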

# GCD on multivariate polynomials is not trivial.
(x^2 - 1)/(x - 1) -> (x + 1)(x - 1)/(x - 1) -> x + 1
# Fully decidable equality would require Gröbner bases.
# For simplicity we limit ourselves to:
(2xy + 4xz)/(6x + 4x^2) -> (y + 2z)/(3 + 2x)
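The restricted simplification in the last line only cancels a common monomial factor (here 2x), which avoids a general multivariate polynomial GCD. A sketch of that idea, assuming terms are stored as (coefficient, variable-to-exponent map) pairs; Term, term and cancel_common_monomial are hypothetical names.

```rust
use std::collections::BTreeMap;

/// A term: integer coefficient and a map from variable to exponent.
pub type Term = (i64, BTreeMap<char, u32>);

/// Hypothetical helper to build a term, e.g. term(4, &[('x', 2)]) for 4x^2.
pub fn term(coef: i64, vars: &[(char, u32)]) -> Term {
    (coef, vars.iter().copied().collect())
}

fn gcd(a: i64, b: i64) -> i64 {
    if b == 0 { a.abs() } else { gcd(b, a % b) }
}

/// Divides numerator and denominator by their largest common monomial factor:
/// the gcd of all coefficients times the lowest power of each variable that
/// appears in every term. No general polynomial GCD is attempted.
pub fn cancel_common_monomial(num: &[Term], den: &[Term]) -> (Vec<Term>, Vec<Term>) {
    let all: Vec<&Term> = num.iter().chain(den.iter()).collect();

    // gcd of all coefficients (1 if they are all zero, to avoid dividing by 0).
    let g = all.iter().fold(0, |acc, (c, _)| gcd(acc, *c));
    let g = if g == 0 { 1 } else { g };

    // Start from the first term's variables and keep, for each variable that
    // occurs in every term, its minimum exponent.
    let mut common: BTreeMap<char, u32> =
        all.first().map(|(_, e)| e.clone()).unwrap_or_default();
    for (_, exps) in &all {
        common = common
            .into_iter()
            .filter_map(|(v, e)| exps.get(&v).map(|&e2| (v, e.min(e2))))
            .collect();
    }

    let divide = |terms: &[Term]| -> Vec<Term> {
        terms
            .iter()
            .map(|(c, exps)| {
                let mut e = exps.clone();
                for (v, k) in &common {
                    let left = e[v] - k;
                    if left == 0 { e.remove(v); } else { e.insert(*v, left); }
                }
                (c / g, e)
            })
            .collect()
    };
    (divide(num), divide(den))
}
```

Note that this leaves (x^2 - 1)/(x - 1) unchanged, because the common factor there is the polynomial x - 1, not a monomial.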

tldr :)))

Could you summarize what you want to express?


Currently I am not asking for much; I've just tried to raise awareness among Rust designers of this part of the design space. Machine learning and tensors could become important for Rust too, and since Rust is often interested in using its static type system to help code correctness, I think this stuff could fit in Rust too.

Currently I'd like to be able to write code like this in Rust:

fn foo(a: &mut [u32], i: usize) -> &mut [u32; 3] {
    // Does not compile today: there is no From<&mut [u32]> impl for &mut [u32; 3].
    (&mut a[i + 1 .. i + 4]).into()
}

Instead of:

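fn foo(a: &mut [u32], i: usize) -> &mut [u32; 3] {
    // TryInto is in the prelude since the 2021 edition; older editions need `use std::convert::TryInto;`.
    (&mut a[i + 1 .. i + 4]).try_into().unwrap()
}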

Hmm, I think that might require implementing dependent types, which would be a pretty big change to the type system.


No need for dependent types: if i + 1 .. i + 4 is a valid range for the slice a, then its length is statically known to be 3. Take a look at the numerous examples and links in the first post.

But the type of the expression &mut a[i + 1 .. i + 4] is &mut [u32]. Yes, the bounds are compile time constants, but these compile time constants aren't a part of the type system in any way.


The whole point of this thread is to discuss possible future extensions of the type system(s) that track slice/Vec sizes, while still avoiding full dependent typing.

I need to improve my exposition skills in English.


Oddly enough, dependent types just came up in the thread "Idea: make assert!() a keyword".

@leonardo, would that thread be helpful to you?

Sounds like safer transmutation would be of interest?

Are you suggesting implementing a subset of dependent types for just slice and Vec? I'm not sure about slice, but I don't think it would be good to special-case Vec since it's implemented in the standard library; the compiler shouldn't need to know or care about anything in the standard library, including Vec.


I do think that dependent types are interesting, but I'm not sure they're the right fit for Rust since it's a fairly mature language at this point.

This looks a bit weird to me :)

let _: &[u8; 3] = (&[0u8; 9]).transmute_into();

I am not an expert in type systems, but I think dependent types are types that can depend on the run-time value of some variable. That's not happening above. Once the subslice is verified to be contained inside the slice (this is the regular run-time slicing bounds test that Rust already performs, and it isn't done by .into()), you don't need the run-time value of any variable to know the type of the result. After the panic-prone bounds test, the result is always of type &mut [u32; 3], as requested by the return type of foo().

So the code is similar to:


use std::convert::TryInto;

#[inline(never)]
pub fn take_array_from_mut<T, const N: usize>(data: &mut [T], start: usize) -> &mut [T; N] {
    (&mut data[start .. start + N]).try_into().unwrap()
}

fn main() {
    let mut arr: Vec<u32> = (0 .. 35).collect();

    for i in 0 .. 30 {
        let window: &mut [_; 3] = take_array_from_mut(&mut arr, i);
        println!("{:?}", window);
    }
}

Where the function gets compiled to (rustc removes the unwrap):

    push    rax
    lea     rax, [rdx + 3]
    cmp     rdx, -4
    ja      .LBB3_3
    cmp     rax, rsi
    ja      .LBB3_4
    lea     rax, [rdi + 4*rdx]
    pop     rcx
    ret
.LBB3_3:
    lea     rcx, [rip + .L__unnamed_2]
    mov     rdi, rdx
    mov     rsi, rax
    mov     rdx, rcx
    call    qword ptr [rip + core::slice::index::slice_index_order_fail@GOTPCREL]
    ud2
.LBB3_4:
    lea     rdx, [rip + .L__unnamed_2]
    mov     rdi, rax
    call    qword ptr [rip + core::slice::index::slice_end_index_len_fail@GOTPCREL]
    ud2

This means that to allow the use of .into() there, the type system has to verify that the slice has length 3. This can't be done in general, but in a lot of useful practical cases it could be done, and then it becomes syntax sugar for a function like take_array_from / take_array_from_mut.

I see 2 conditional jumps (ja, jump if above) to .LBB3_3 and .LBB3_4 and they call slice_index_order_fail and slice_end_index_len_fail respectively. Those calls are followed by ud2, which is the same instruction emitted by unreachable_unchecked, i.e. the program must either throw an exception (panic; inside the called functions) or abort. In conclusion, the call to unwrap hasn't been removed, but inlined. That's an important difference.

A never-inlined function can never be optimized in the way you want, because you explicitly tell the compiler not to look outside the function to optimize what's inside the function¹.

¹ Except maybe for profile-guided optimization, which I'm unfamiliar with, but that wouldn't change the function semantically, IIRC.

Are you sure a failed unwrap doesn't call core::result::unwrap_failed instead?

I used #[inline(never)] because that function doesn't need any information from outside it to remove the final unwrap.

What your program does:

  1. Create a vector with 35 values
  2. Set i to 0
  3. Loop while i < 30
    1. Call take_array_from_mut with a reference to your vector, i and the constant 3 for N
    2. Print the returned value
    3. Increment i

When looking at your program as a whole, you're right to assume that there will never be an out-of-bounds access and that the conditional jumps could be removed. However, when looking only at the function signature, the compiler sees data as a reference to any slice, i.e. one whose length could be anything from 0 up to the isize::MAX limit, and it sees start as any usize, i.e. a value between 0 and 2⁶⁴-1 (on 64-bit systems). You can easily construct an example where indexing an arbitrary slice with an arbitrary value is out of bounds. In conclusion, the function by itself does not have enough information to optimize away the panics caused by unwrapping.
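To make the point concrete: inside take_array_from_mut the range check cannot go away, but it does not have to panic. Below is a sketch of a non-panicking variant (try_take_array_from_mut is a hypothetical name) that reports an overflowing start + N or an out-of-bounds range as None; the final conversion cannot fail, because the subslice has exactly N elements by then.

```rust
use std::convert::TryInto;

/// Hypothetical non-panicking variant of take_array_from_mut.
pub fn try_take_array_from_mut<T, const N: usize>(
    data: &mut [T],
    start: usize,
) -> Option<&mut [T; N]> {
    // The function cannot know data.len() or start at compile time,
    // so both checks below must stay; they just return None instead of panicking.
    let end = start.checked_add(N)?; // start + N would overflow
    let window = data.get_mut(start..end)?; // range not inside the slice
    // window has exactly N elements here, so this conversion always succeeds.
    window.try_into().ok()
}
```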

You could look at the unoptimized code and see whether there are calls to slice_index_order_fail and slice_end_index_len_fail. If they're still there and you see the call to unwrap, you could be right. In that case, post the ASM output here.

	sub     rsp, 88
	mov     rax, rdx
	add     rax, 3
	setb    cl
	test    cl, 1
	mov     qword ptr [rsp + 64], rdi
	mov     qword ptr [rsp + 56], rsi
	mov     qword ptr [rsp + 48], rdx
	mov     qword ptr [rsp + 40], rax
	jne     .LBB196_5
	lea     rax, [rip + .L__unnamed_9]
	mov     rcx, qword ptr [rsp + 48]
	mov     qword ptr [rsp + 72], rcx
	mov     rdx, qword ptr [rsp + 40]
	mov     qword ptr [rsp + 80], rdx
	mov     rdx, qword ptr [rsp + 72]
	mov     rcx, qword ptr [rsp + 80]
	mov     rdi, qword ptr [rsp + 64]
	mov     rsi, qword ptr [rsp + 56]
	mov     r8, rax
	call    qword ptr [rip + core::slice::index::<impl core::ops::index::IndexMut<I> for [T]>::index_mut@GOTPCREL]
	mov     qword ptr [rsp + 32], rax
	mov     qword ptr [rsp + 24], rdx
	mov     rdi, qword ptr [rsp + 32]
	mov     rsi, qword ptr [rsp + 24]
	call    qword ptr [rip + <T as core::convert::TryInto<U>>::try_into@GOTPCREL]
	mov     qword ptr [rsp + 16], rax
	lea     rax, [rip + .L__unnamed_10]
	mov     rdi, qword ptr [rsp + 16]
	mov     rsi, rax
	call    qword ptr [rip + core::result::Result<T,E>::unwrap@GOTPCREL]
	mov     qword ptr [rsp + 8], rax
	mov     rax, qword ptr [rsp + 8]
	add     rsp, 88
	lea     rdi, [rip + str.1]
	lea     rdx, [rip + .L__unnamed_11]
	mov     rax, qword ptr [rip + core::panicking::panic@GOTPCREL]
	mov     esi, 28
	call    rax

"In conclusion, the function by itself does not have enough information to optimize away the panics caused by unwrapping."

I think you're confusing the slice-bounds-check panics with the unwrap panic. The former are still present in the isolated (never-inlined) function, but LLVM is able to remove the latter: once you have taken data[start .. start + N], its length is always N, and LLVM understands this, so it removes the unwrap's branch to the panic.

fn foo(a: &mut [u32], i: usize) -> &mut [u32; 3] {
    (&mut a[i + 1 .. i + 4]).into()
}

This can be solved with const generics by reordering the expressions so that the length 3 is a separate compile-time constant, not mixed into the dynamic indices. Amazingly, type deduction works.

fn foo(a: &mut [u32], i: usize) -> &mut [u32; 3] {
    // Unstable at the time of writing: slice::array_chunks_mut requires a nightly feature.
    a[i + 1..].array_chunks_mut().next().unwrap()
}
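On stable Rust the same reordering idea can also be expressed with <[T]>::first_chunk_mut, which (to my knowledge, since Rust 1.77) returns the first N elements of a slice as Option<&mut [T; N]>, with N again inferred from the return type. A short sketch:

```rust
/// Same signature as foo above, using only stable slice methods.
pub fn foo(a: &mut [u32], i: usize) -> &mut [u32; 3] {
    // Slicing panics if i + 1 > a.len(); unwrap panics if fewer than 3 elements remain.
    a[i + 1..].first_chunk_mut().unwrap()
}
```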

It doesn't require compiler intervention, so I've made it into a crate. I find it more readable:

use index_ext::array::Prefix;
fn foo(a: &mut [u32], i: usize) -> &mut [u32; 3] {
    &mut a[i + 1..][Prefix]
}