Async fn in dyn traits without heap* is provably impossible

This is going to be a bit of a rough pseudo mini blog post with minimal editing, so hang on.

-> impl Trait in traits is a bit easier than generalized dynamically sized return values. This is because while you don't know the returned value's size at compile time, you do know it at runtime strictly before calling the function. That is, because the concrete impl returns a sized type, you could conceivably put the return type's layout in the vtable, and then use that to allocate space for the return slot.
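To make the "layout in the vtable" idea concrete, here is a minimal sketch on stable Rust. The trait `Produce`, the hand-rolled `ProduceVtable`, and `call_dyn` are hypothetical names for illustration only; this is not how rustc actually lays out vtables. Since stable Rust has no alloca, a fixed 64-byte, 16-aligned buffer stands in for the runtime-sized stack slot, and the assert marks where a real alloca would use the layout instead.

```rust
use std::alloc::Layout;
use std::fmt::Debug;
use std::mem::MaybeUninit;

// A trait whose method returns a Sized type the caller can't name.
trait Produce {
    type Out: Debug;
    fn produce(&self) -> Self::Out;
}

// Hypothetical hand-rolled vtable: besides function pointers it records the
// concrete impl's return layout, so the caller can size the slot first.
struct ProduceVtable {
    ret_layout: Layout,
    produce_into: unsafe fn(*const (), *mut u8),
    fmt_out: unsafe fn(*const u8) -> String,
    drop_out: unsafe fn(*mut u8),
}

unsafe fn produce_into<T: Produce>(this: *const (), out: *mut u8) {
    unsafe { out.cast::<T::Out>().write((*this.cast::<T>()).produce()) }
}

unsafe fn fmt_out<T: Produce>(out: *const u8) -> String {
    unsafe { format!("{:?}", &*out.cast::<T::Out>()) }
}

unsafe fn drop_out<T: Produce>(out: *mut u8) {
    unsafe { std::ptr::drop_in_place(out.cast::<T::Out>()) }
}

fn vtable_for<T: Produce>() -> ProduceVtable {
    ProduceVtable {
        ret_layout: Layout::new::<T::Out>(),
        produce_into: produce_into::<T>,
        fmt_out: fmt_out::<T>,
        drop_out: drop_out::<T>,
    }
}

// Stand-in for an alloca'd return slot: stable Rust has no alloca.
#[repr(C, align(16))]
struct Slot([MaybeUninit<u8>; 64]);

fn call_dyn(this: *const (), vt: &ProduceVtable) -> String {
    // The return layout is read from the vtable strictly before the call.
    let l = vt.ret_layout;
    assert!(l.size() <= 64 && l.align() <= 16, "return value does not fit the slot");
    let mut slot = Slot([MaybeUninit::uninit(); 64]);
    let out = slot.0.as_mut_ptr().cast::<u8>();
    // SAFETY: `out` satisfies `l`, and `this` points to the type `vt` was built for.
    unsafe {
        (vt.produce_into)(this, out);
        let s = (vt.fmt_out)(out);
        (vt.drop_out)(out);
        s
    }
}

struct Three;
impl Produce for Three {
    type Out = [u32; 3];
    fn produce(&self) -> [u32; 3] {
        [1, 2, 3]
    }
}

struct Greeting;
impl Produce for Greeting {
    type Out = String;
    fn produce(&self) -> String {
        String::from("hi")
    }
}

fn main() {
    let (three, greeting) = (Three, Greeting);
    let a = call_dyn(&three as *const Three as *const (), &vtable_for::<Three>());
    let b = call_dyn(&greeting as *const Greeting as *const (), &vtable_for::<Greeting>());
    println!("{a} {b}"); // prints: [1, 2, 3] "hi"
}
```

The two impls return different types with different layouts, yet `call_dyn` handles both through the same erased interface because the size is known before the call.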

In a regular function, you can at least conceivably use alloca/stack VLA[1]s to reserve space on the stack to avoid a heap allocation for such a function. Unfortunately, doing this for async fn (-> impl Future) runs into a nonobvious problem: you can't .await it.

The problem is that the impl Future that the async fn produces is itself just a value that has to live somewhere: it's Sized for each concrete impl, but through dyn its size is only known at runtime. In order to .await a value, any values on the logical stack need to be put into the pinned future's state, so that they'll still be there when the function resumes. And this means that the impl Future you're awaiting is part of your state, so you're also dynamically sized.
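A quick way to observe the "awaited future becomes part of your state" effect on stable Rust is to compare future sizes with `size_of_val`. No executor is needed, since the futures are only constructed, never polled. `inner` and `outer` are illustrative names, and exact byte counts are compiler-dependent, so only the inequalities are meaningful.

```rust
use std::future::ready;
use std::hint::black_box;
use std::mem::size_of_val;

// A future whose state must hold a 1 KiB buffer across an .await point.
async fn inner(x: u8) -> u8 {
    let buf = [x; 1024];
    ready(()).await; // suspension point: `buf` is live across it
    black_box(&buf)[0]
}

// Awaiting `inner` stores the whole `inner` future inside `outer`'s state.
async fn outer() -> u8 {
    inner(7).await
}

fn main() {
    let inner_size = size_of_val(&inner(7));
    let outer_size = size_of_val(&outer());
    println!("inner: {inner_size} bytes, outer: {outer_size} bytes");
    assert!(inner_size >= 1024);
    assert!(outer_size >= inner_size);
}
```

If `inner` were only reachable through `dyn`, `outer` could not know that 1 KiB figure at compile time, which is exactly the propagation problem described above.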

So, you might be able to get away with one layer of this by passing the buck, and saying that your async fn also returns this special runtime-sized value if it awaits an async trait method. If the &dyn is passed in to your async fn, the caller can know how much space you're going to want for your call, and so it can give it to you.
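One layer of "passing the buck" amounts to layout arithmetic that the caller can do once it holds the vtable. A sketch using `Layout::extend`, assuming a hypothetical vtable entry `future_layout` (real Rust vtables have no such field):

```rust
use std::alloc::Layout;

// Hypothetical: a `dyn` async method's vtable entry carries its future's layout.
struct AsyncMethodVtable {
    future_layout: Layout,
}

// "Passing the buck": our own future must embed the callee's future, so our
// layout becomes a runtime value computed from the callee's vtable.
// Returns the combined layout and the offset of the callee's future in it.
fn layout_with_callee(own_state: Layout, callee: &AsyncMethodVtable) -> (Layout, usize) {
    let (combined, callee_offset) = own_state
        .extend(callee.future_layout)
        .expect("layout overflow");
    (combined.pad_to_align(), callee_offset)
}

fn main() {
    let own = Layout::from_size_align(8, 8).unwrap();
    let callee = AsyncMethodVtable {
        future_layout: Layout::from_size_align(100, 16).unwrap(),
    };
    let (total, offset) = layout_with_callee(own, &callee);
    // prints: size 128, align 16, callee at offset 16
    println!("size {}, align {}, callee at offset {}", total.size(), total.align(), offset);
}
```

This only works because `callee` is in hand; a caller further up that never sees the vtable has nothing to feed into `extend`.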

But this doesn't work in the general case. In the general case, it falls back to whole-program analysis[2], and quickly to undecidability via the halting problem. If the caller doesn't have access to the vtable, it can't possibly know the correct size ahead of time. Recursive calls (without heap indirection) already don't work in async fn, but they become much harder to avoid and diagnose with async fn in traits, because it's no longer about what you do call, but what you can call via dynamic dispatch. Some cases might even only be diagnosable at runtime, even within the subset that kind of works.
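For reference, this is what the existing restriction on recursion looks like, and how heap indirection gets around it. `depth` and the tiny `block_on` executor are illustrative only; the executor busy-polls with a no-op waker, which is acceptable here only because these futures never return `Pending`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Rejected by rustc: the future would have to contain itself (infinite size).
// async fn depth(n: u32) -> u32 {
//     if n == 0 { 0 } else { depth(n - 1).await + 1 }
// }

// Accepted: each level's future lives behind a heap pointer, so the outer
// future only stores a fixed-size `Pin<Box<dyn Future>>`.
fn depth(n: u32) -> Pin<Box<dyn Future<Output = u32>>> {
    Box::pin(async move {
        if n == 0 { 0 } else { depth(n - 1).await + 1 }
    })
}

// Minimal executor with a no-op waker (sufficient because nothing pends here).
fn noop_raw_waker() -> RawWaker {
    unsafe fn clone(_: *const ()) -> RawWaker {
        noop_raw_waker()
    }
    unsafe fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    RawWaker::new(std::ptr::null(), &VTABLE)
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = std::pin::pin!(fut);
    // SAFETY: the no-op vtable trivially upholds the RawWaker contract.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    println!("{}", block_on(depth(10))); // prints 10
}
```

With a free function the cycle is visible at the definition site. Through a `dyn` async method, the same cycle can hide behind any implementor.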

The title of the post is clickbait because I don't have a formal proof that it's not possible. I just have a whole pile of really annoying edge cases (What if one call/await is a prerequisite of making a further decision on which instantiation to call? At some point taking the worst case of all possibilities breaks down.) and a hunch that it becomes truly impossible. And I've just ignored the issues of "colored functions": this really does introduce a semantic split between "sized return" functions and "runtime selected sized return" functions.

Perhaps it is possible to teach a compiler to allow certain "simple" cases, so that more code just works without having to box it. But I honestly think that down that road lie far too many nonlocal effects and poor whole-program errors, and that just requiring the use of dynamic allocation here (or a bridge from dynamic allocation to a static memory chunk, e.g. StackBox) is simpler and the better direction in general.

Down the alloca path lies madness.


* Of course, StackBox or other fixed-max-size-align containers work fine, but I'm referring to a general-purpose, language-supported ability that doesn't go through the dynamic allocation APIs. A compiler wanting to insert StackBox would have to do whole-program analysis[2] to determine the maximal size/align of any potential return value, and even if there isn't any actual recursion in the program, it would be very easy to accidentally create a potential recursive call at the type level, leading to infinite size requirements with painful "here's why" errors.
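A rough sketch of what a fixed-max-size-align container looks like, assuming a capacity of `N` bytes at 8-byte alignment chosen up front. `InlineDebug` is a hypothetical type written for illustration, not the API of the StackBox crate. It stores the value inline and keeps monomorphized function pointers to recover a `&dyn Debug` and to drop the value.

```rust
use std::fmt::Debug;
use std::marker::PhantomData;
use std::mem::{align_of, size_of, MaybeUninit};

// Fixed capacity of N bytes with alignment 8: the "max size/align".
#[repr(C, align(8))]
struct Buf<const N: usize>([MaybeUninit<u8>; N]);

// Hypothetical inline container for `dyn Debug` values; no dynamic allocation.
struct InlineDebug<const N: usize> {
    buf: Buf<N>,
    as_dyn: unsafe fn(*const u8) -> *const dyn Debug,
    drop_value: unsafe fn(*mut u8),
    // Opt out of Send/Sync, since the stored type might not be either.
    _not_send: PhantomData<*const ()>,
}

unsafe fn as_dyn<T: Debug + 'static>(p: *const u8) -> *const dyn Debug {
    p as *const T // unsizing coercion attaches T's Debug vtable
}

unsafe fn drop_value<T>(p: *mut u8) {
    unsafe { std::ptr::drop_in_place(p as *mut T) }
}

impl<const N: usize> InlineDebug<N> {
    fn new<T: Debug + 'static>(value: T) -> Self {
        // Anything larger than the fixed chunk is rejected at runtime.
        assert!(size_of::<T>() <= N && align_of::<T>() <= 8, "value does not fit");
        let mut buf = Buf([MaybeUninit::uninit(); N]);
        // SAFETY: size and alignment were just checked.
        unsafe { buf.0.as_mut_ptr().cast::<T>().write(value) };
        InlineDebug { buf, as_dyn: as_dyn::<T>, drop_value: drop_value::<T>, _not_send: PhantomData }
    }

    fn get(&self) -> &dyn Debug {
        // SAFETY: `buf` holds an initialized value of the type `as_dyn` was built for.
        unsafe { &*(self.as_dyn)(self.buf.0.as_ptr().cast()) }
    }
}

impl<const N: usize> Drop for InlineDebug<N> {
    fn drop(&mut self) {
        // SAFETY: the value was written in `new` and is dropped exactly once here.
        unsafe { (self.drop_value)(self.buf.0.as_mut_ptr().cast()) }
    }
}

fn main() {
    let items = [InlineDebug::<32>::new([1u32, 2, 3]), InlineDebug::<32>::new(String::from("hi"))];
    for item in &items {
        println!("{:?}", item.get());
    }
}
```

The capacity `N` is exactly the number the footnote says a compiler would have to infer by whole-program analysis. Here the programmer simply picks it, and anything larger panics in `new`.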

[1] Variable Length Array

[2] Plus, the whole point of anything dyn is to avoid monomorphization. A reliance on whole-program analysis means that the final binary has to be in charge of monomorphizing the function for that program (assuming the whole program is ever statically available, which it isn't, e.g., in the face of dynamic linking).


Fascinating! It seems to be related to the core tradeoff between stackless and stackful coroutines: knowing or not knowing the maximal stack size in advance. It makes sense that, if you are doing dynamic dispatch, you can’t know stack usage.


I am not sure I follow your reasoning.

This premise looks wrong to me.

The compiler always knows the size of an -> impl Trait type at compile time. Such types do not require stack VLAs to work. The problem with -> impl Trait in traits is that the compiler does not know its size until the code gets monomorphized, so it has to defer some compilation steps until then. But AFAIK this problem is not unique to -> impl Trait in traits; the same applies to the following trait based on const generics:

trait Foo {
    const N: usize;
    fn foo() -> [u8; Self::N];
}

Here the compiler also does not know the size of the return type until monomorphisation hits.

Another example of similar code is this function:

use std::ops::Add;

pub fn foo<T: Sized + Add<Output=T>>(val: T) -> impl FnOnce(T) -> T {
    move |t| val + t
}

Here the size of the returned closure is also not known to the compiler until monomorphisation, but the compiler nevertheless handles it well enough without relying on stack VLAs.
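To illustrate that each instantiation still has a statically known size, one can measure the returned closures with `size_of_val`. The sizes in the comments are what current rustc produces for a closure capturing a single value by move; closure layout is not a language guarantee.

```rust
use std::mem::size_of_val;
use std::ops::Add;

// Same function as above: the hidden closure type differs per `T`.
pub fn foo<T: Add<Output = T>>(val: T) -> impl FnOnce(T) -> T {
    move |t| val + t
}

fn main() {
    let small = foo(1u8);
    let big = foo(1u128);
    // Each instantiation has a fixed size once `T` is chosen:
    // the closure stores exactly one captured `T`.
    println!("{} {}", size_of_val(&small), size_of_val(&big)); // prints: 1 16
    println!("{} {}", small(2), big(40)); // prints: 3 41
}
```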

The same logic applies to async fns. If we have a "virtual" stack of size N and we call an async trait method when this stack is filled up to M (of course M <= N), then the compiler may have to extend the size of the "virtual" stack to M + T if that's bigger than N, where T is the stack size required by the async method. The problem is that T is not known to the compiler until monomorphisation, so it cannot perform a lot of optimizations locally. Yes, it's not ideal, but I don't think it's as bad as you make it out to be when you mention "whole-program analysis". And note that, as mentioned earlier, this problem is not unique to async code.
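The "virtual stack" arithmetic from the paragraph above, written out as a toy model. `virtual_stack_size` is a hypothetical helper; in a real compiler the `t` values would only become available after monomorphization, which is the point being made.

```rust
// Toy model: a frame whose own locals need `n` bytes, plus one `(m, t)` pair
// per async call site, made while `m` bytes are in use, to a callee needing `t`.
fn virtual_stack_size(n: usize, calls: &[(usize, usize)]) -> usize {
    calls.iter().fold(n, |size, &(m, t)| {
        assert!(m <= n, "a call site cannot use more than the frame's own stack");
        size.max(m + t) // extend to M + T only if that exceeds the current size
    })
}

fn main() {
    // Own frame: 64 bytes. Two awaits: at depth 16 (callee needs 32),
    // and at depth 40 (callee needs 100).
    println!("{}", virtual_stack_size(64, &[(16, 32), (40, 100)])); // prints 140
}
```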

Now, as for object safety of traits with async methods or impl Trait return types, I agree they should not be object safe. I think that any proposed solution should be universal enough to handle the const generic trait Foo, and personally I don't think such a solution could be practical. It would probably be better to simply recommend that users create object-safe trait wrappers (in the async case they would return boxed futures), as we do, for example, with the DynDigest trait.
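A minimal sketch of the wrapper approach, assuming Rust 1.75+ for `async fn` in traits. `Lookup`, `DynLookup`, and `lookup_boxed` are illustrative names, not an existing API. The blanket impl boxes the statically dispatched future so the wrapper trait can be used as `dyn`. The executor busy-polls with a no-op waker, which is adequate only because these futures never return `Pending`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

// Statically dispatched trait with an async method; not usable as `dyn`.
trait Lookup {
    async fn lookup(&self, key: u32) -> u32;
}

// Object-safe wrapper in the style of `DynDigest`: the future is boxed.
trait DynLookup {
    fn lookup_boxed<'a>(&'a self, key: u32) -> Pin<Box<dyn Future<Output = u32> + 'a>>;
}

impl<T: Lookup> DynLookup for T {
    fn lookup_boxed<'a>(&'a self, key: u32) -> Pin<Box<dyn Future<Output = u32> + 'a>> {
        Box::pin(self.lookup(key)) // the heap allocation lives here
    }
}

struct Doubler;
impl Lookup for Doubler {
    async fn lookup(&self, key: u32) -> u32 {
        key * 2
    }
}

struct Offset(u32);
impl Lookup for Offset {
    async fn lookup(&self, key: u32) -> u32 {
        key + self.0
    }
}

fn noop_raw_waker() -> RawWaker {
    unsafe fn clone(_: *const ()) -> RawWaker {
        noop_raw_waker()
    }
    unsafe fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    RawWaker::new(std::ptr::null(), &VTABLE)
}

fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = std::pin::pin!(fut);
    // SAFETY: the no-op vtable trivially upholds the RawWaker contract.
    let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
            return v;
        }
    }
}

fn main() {
    let tables: Vec<Box<dyn DynLookup>> = vec![Box::new(Doubler), Box::new(Offset(10))];
    for t in &tables {
        println!("{}", block_on(t.lookup_boxed(21))); // prints 42, then 31
    }
}
```

Implementors only write the plain `Lookup` impl; users who need dynamic dispatch opt into the boxing cost through `DynLookup`.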

Most optimizations already happen post-monomorphization anyway. When it comes to code generation, I'm pretty sure supporting async fns in traits for static dispatch presents zero additional difficulty compared to the existing support for freestanding async fns. There's plenty of difficulty when it comes to type system implementation and language design, but not code generation. (However, I don't speak from any authority here since I'm not a rustc contributor.)
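As a small illustration of static dispatch behaving like a freestanding async fn (Rust 1.75+): after monomorphization, a generic caller gets a separate, fixed-size future per implementor. The trait and type names are hypothetical, and exact sizes depend on the compiler, so only the comparison against the 512-byte buffer is meaningful.

```rust
use std::future::ready;
use std::hint::black_box;
use std::mem::size_of_val;

trait Source {
    async fn read(&self) -> u8;
}

struct Tiny;
impl Source for Tiny {
    async fn read(&self) -> u8 {
        1
    }
}

struct Buffered;
impl Source for Buffered {
    async fn read(&self) -> u8 {
        let buf = [2u8; 512];
        ready(()).await; // `buf` is held across this await point
        black_box(&buf)[0]
    }
}

// Statically dispatched: each `call::<S>` future has its own fixed size.
async fn call<S: Source>(s: &S) -> u8 {
    s.read().await
}

fn main() {
    let tiny = size_of_val(&call(&Tiny));
    let buffered = size_of_val(&call(&Buffered));
    println!("call::<Tiny>: {tiny} bytes, call::<Buffered>: {buffered} bytes");
}
```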

I agree that dynamic dispatch is a completely different story and likely not feasible.


They do require runtime allocation (stack or heap) when called through a dyn Trait, though, which is what the post is about.

And since you can't have a runtime-sized stack allocation in an async fn, you need the heap just to be able to .await it.


I think it might be possible. The placement by return RFC describes a transformation similar to generators that would allow some functions to return a DST. That transformation could be generalized to allow (almost) any function to return a DST.

The function would be transformed into a state machine, storing the locals that are alive for each state (like with generators/futures). Locals that are dynamically sized would be stored as a (move) reference pointing to memory allocated by the calling function (or one of its ancestors). By storing a reference the function state remains Sized. Each time the function requires memory for a DST (for one of its locals, the return value, etc.) it would yield control to the calling function returning the required Layout. The caller may then either allocate the memory and resume the function or yield the request to its caller.

For trait functions that may return DSTs (they would be required to opt in for functions that return associated types), the vtable for the trait would include the Layout of the state for the generator (which is Sized since all dynamically sized locals are stored as references) and the relevant function pointers.

The output of such a transformation might look like:

#![feature(
    const_mut_refs,
    const_fn_trait_bound,
    const_fn_fn_ptr_basics,
    unsize,
    generic_associated_types
)]

use core::{
    alloc::Layout,
    cell::{Cell, UnsafeCell},
    fmt::Debug,
    marker::{PhantomData, Unsize},
    mem::{self, MaybeUninit},
    ops,
    pin::Pin,
    ptr, slice,
};

pub enum FnState<T> {
    SpaceNeeded(Layout),
    Complete(T),
}

pub trait GeneratorFn<Args>: Sized {
    type State<'r>;
    type Output: ?Sized;

    fn start<'r>(args: Args) -> (Self::State<'r>, Layout);

    /// # Panics
    /// If a previous call to this function or [`start`] either returned `FnState::Complete` or panicked, then
    /// this function may panic.
    ///
    /// # Safety
    /// If this is the first call to this function or if the last call to this function returned
    /// `FnState::SpaceNeeded(layout)` then the caller **must** pass a slice meeting the
    /// requirements specified by `layout`.
    unsafe fn resume<'r>(
        this: Pin<&mut Self::State<'r>>, space: &'r mut [MaybeUninit<u8>],
    ) -> FnState<MoveRef<'r, Self::Output>>;
}

// Vtable for fn(Args) -> R where R is a dst
pub struct GeneratorFnVtable<Args, R: ?Sized> {
    layout: Layout,
    drop: unsafe fn(*mut ()),
    start: for<'r> unsafe fn(Args, &'r mut [MaybeUninit<u8>]) -> (*mut (), Layout),
    resume: for<'r> unsafe fn(*mut (), &'r mut [MaybeUninit<u8>]) -> FnState<MoveRef<'r, R>>,
}

impl<Args, R: ?Sized> Copy for GeneratorFnVtable<Args, R> {}
impl<Args, R: ?Sized> Clone for GeneratorFnVtable<Args, R> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<Args, R: ?Sized> GeneratorFnVtable<Args, R> {
    pub const fn new<F>() -> Self
    where F: GeneratorFn<Args, Output = R> {
        Self {
            layout: Layout::new::<F::State<'_>>(),
            drop: |p| unsafe { p.cast::<F::State<'_>>().drop_in_place() },
            start: |args, space| unsafe {
                let (s, l) = F::start(args);
                let p = space.as_mut_ptr().cast::<F::State<'_>>();
                p.write(s);
                (p.cast::<()>(), l)
            },
            resume: |p, space| unsafe {
                F::resume(Pin::new_unchecked(&mut *p.cast::<F::State<'_>>()), space)
            },
        }
    }

    pub fn layout(&self) -> Layout {
        self.layout
    }

    pub fn start<'r>(
        &self, args: Args, slot: &'r mut [MaybeUninit<u8>],
    ) -> (DynGeneratorFnState<'r, Args, R>, Layout) {
        assert!(slot.len() >= self.layout.size());
        assert_eq!((slot.as_ptr() as usize) % self.layout.align(), 0);
        let (state, layout) = unsafe { (self.start)(args, slot) };
        (
            DynGeneratorFnState {
                state,
                vtable: *self,
                _marker: PhantomData,
            },
            layout,
        )
    }
}

pub struct DynGeneratorFnState<'r, Args, R: ?Sized> {
    vtable: GeneratorFnVtable<Args, R>,
    state: *mut (),
    _marker: PhantomData<&'r mut ()>,
}

impl<Args, R: ?Sized> Drop for DynGeneratorFnState<'_, Args, R> {
    fn drop(&mut self) {
        unsafe {
            (self.vtable.drop)(self.state);
        }
    }
}

impl<'r, Args, R: ?Sized> DynGeneratorFnState<'r, Args, R> {
    pub fn resume(&mut self, space: &'r mut [MaybeUninit<u8>]) -> FnState<MoveRef<'r, R>> {
        unsafe { (self.vtable.resume)(self.state, space) }
    }
}

pub struct MoveRef<'a, T: ?Sized>(&'a mut T);

impl<T: ?Sized> Drop for MoveRef<'_, T> {
    fn drop(&mut self) {
        unsafe {
            ptr::drop_in_place(self.0);
        }
    }
}

impl<T: ?Sized> ops::Deref for MoveRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl<T: ?Sized> ops::DerefMut for MoveRef<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}

impl<'a, T: ?Sized> MoveRef<'a, T> {
    pub fn place<U: 'a>(slot: &'a mut [MaybeUninit<u8>], value: U) -> Self
    where U: Unsize<T> {
        assert!(slot.len() >= mem::size_of::<U>());
        assert_eq!((slot.as_ptr() as usize) % mem::align_of::<U>(), 0);
        let p = slot.as_mut_ptr().cast::<U>();
        // SAFETY: This is safe because the size and alignment of slot were just checked.
        unsafe {
            p.write(value);
            Self(&mut *p)
        }
    }
}

pub struct Foo;

pub enum FooState {
    Start,
    Complete,
}

// fn foo() -> [u32] {
//     [0, 1, 2, 3, 4]
// }
impl GeneratorFn<()> for Foo {
    type State<'r> = FooState;
    type Output = [u32];

    fn start<'r>(_args: ()) -> (Self::State<'r>, Layout) {
        (FooState::Start, Layout::new::<[u32; 5]>())
    }

    unsafe fn resume<'r>(
        mut this: Pin<&mut Self::State<'r>>, space: &'r mut [MaybeUninit<u8>],
    ) -> FnState<MoveRef<'r, Self::Output>> {
        match this.as_mut().get_mut() {
            FooState::Start => {
                *this.get_mut() = FooState::Complete;
                FnState::Complete(MoveRef::place(space, [0, 1, 2, 3, 4]))
            }
            FooState::Complete => panic!("resume called after complete"),
        }
    }
}

pub struct Bar;
pub enum BarState<'r> {
    Start(GeneratorFnVtable<(), [u32]>),
    WaitForSpace1(DynGeneratorFnState<'r, (), [u32]>),
    WaitForSpace2(MoveRef<'r, [u32]>),
    Complete,
}

// fn bar(f: fn() -> [u32]) -> dyn Debug {
//     f().iter().sum::<u32>()
// }
impl GeneratorFn<(GeneratorFnVtable<(), [u32]>,)> for Bar {
    type State<'r> = BarState<'r>;
    type Output = dyn Debug;

    fn start<'r>((args,): (GeneratorFnVtable<(), [u32]>,)) -> (Self::State<'r>, Layout) {
        let layout = args.layout;
        (BarState::Start(args), layout)
    }

    unsafe fn resume<'r>(
        mut this: Pin<&mut Self::State<'r>>, space: &'r mut [MaybeUninit<u8>],
    ) -> FnState<MoveRef<'r, Self::Output>> {
        match this.as_mut().get_mut() {
            BarState::Start(f) => {
                let (s, l) = f.start((), space);
                *this.get_mut() = BarState::WaitForSpace1(s);
                FnState::SpaceNeeded(l)
            }
            BarState::WaitForSpace1(s) => match s.resume(space) {
                FnState::SpaceNeeded(l) => FnState::SpaceNeeded(l),
                FnState::Complete(r) => {
                    *this.get_mut() = BarState::WaitForSpace2(r);
                    FnState::SpaceNeeded(Layout::new::<u32>())
                }
            },
            BarState::WaitForSpace2(s) => {
                let sum = s.iter().sum::<u32>();
                *this.get_mut() = BarState::Complete;
                FnState::Complete(MoveRef::place(space, sum))
            }
            BarState::Complete => panic!("resume called after complete"),
        }
    }
}

// Wrap a function that returns a non-dst in a generator
// pub struct RegularFn<F, R>(PhantomData<(F, R)>);
// pub enum RegularFnState<R> {
//     Start(R),
//     Complete,
// }
//
// impl<Args, R: Unpin, F> GeneratorFn<Args> for RegularFn<F, R>
// where F: Default + Fn(Args) -> R
// {
//     type State<'r> = RegularFnState<R>;
//     type Output = R;
//
//     fn start<'r>(args: Args) -> (Self::State<'r>, Layout) {
//         let r = (F::default())(args);
//         (RegularFnState::Start(r), Layout::new::<R>())
//     }
//
//     unsafe fn resume<'r>(
//         this: Pin<&mut Self::State<'r>>, space: &'r mut [MaybeUninit<u8>],
//     ) -> FnState<MoveRef<'r, Self::Output>> {
//         match mem::replace(this.as_mut().get_mut(), RegularFnState::Complete) {
//             RegularFnState::Start(r) => FnState::Complete(MoveRef::place(space, r)),
//             RegularFnState::Complete => panic!("resume called after complete"),
//         }
//     }
// }

// pub trait FutureDst {
//     type Output: ?Sized;
//
//     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>;
// }
// pub struct FutureVtable {
//     layout: Layout,
//     drop: unsafe fn(*mut ()),
//     poll: GeneratorFnVtable<(*mut (), *mut Context<'static>), Poll<dyn Drop>>,
// }

struct Arena {
    space: [UnsafeCell<MaybeUninit<u8>>; 16 * 1024],
    // INVARIANT: next is always <= space.len()
    next: Cell<usize>,
}

impl Arena {
    fn new() -> Self {
        Self {
            space: unsafe { MaybeUninit::uninit().assume_init() },
            next: Cell::new(0),
        }
    }

    fn allocate(&self, layout: Layout) -> &mut [MaybeUninit<u8>] {
        let next = self.next.get();
        let remaining = self.space.len() - next;
        // SAFETY: By the invariant of Arena, next <= space.len(), so this is at most a
        // one-past-the-end pointer; `raw_get` avoids creating a reference through it.
        let p = UnsafeCell::raw_get(unsafe { self.space.as_ptr().add(next) });
        let start = p.align_offset(layout.align());
        let total = start.checked_add(layout.size()).expect("stack overflow");
        assert!(total <= remaining, "stack overflow");
        self.next.set(next + total);
        // SAFETY: The bounds were just checked.
        unsafe { slice::from_raw_parts_mut(p.add(start).cast(), layout.size()) }
    }
}

// fn foo_bar(f: fn(fn() -> [u32]) -> dyn Debug) {
//      println!("{:?}", f(foo));
// }
fn foo_bar(f: GeneratorFnVtable<(GeneratorFnVtable<(), [u32]>,), dyn Debug>) {
    let arena = Arena::new();
    let r = 'outer: loop {
        let (mut s, mut l) = f.start(
            (GeneratorFnVtable::new::<Foo>(),),
            arena.allocate(f.layout()),
        );
        loop {
            match s.resume(arena.allocate(l)) {
                FnState::Complete(r) => break 'outer r,
                FnState::SpaceNeeded(n) => {
                    l = n;
                }
            }
        }
    };
    println!("{:?}", &*r);
}

fn main() {
    foo_bar(GeneratorFnVtable::new::<Bar>());
}


Note: This is just a proof of concept and there may be safety/lifetime issues. The transformation makes no distinction between memory requested for dynamically sized locals and memory for the return value, which doesn't meet the needs that motivated the placement by return RFC. The transformation would need to be modified so that each request for memory includes its purpose (storing locals vs. the return value), allowing the caller to adjust its allocation strategy.
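A hedged sketch of the modification described in the note. `SpaceRequest`, `Plan`, and `plan` are hypothetical and not part of the code above; the idea is only that tagging each request lets a caller, for example, pack locals into a scratch region while giving the return value its own slot that can outlive the call.

```rust
use std::alloc::Layout;

// Hypothetical refinement of `FnState::SpaceNeeded(Layout)`:
// each request says what the memory is for.
enum SpaceRequest {
    Local(Layout),
    ReturnValue(Layout),
}

// A toy caller strategy: locals are packed into one scratch region that can be
// released when the callee finishes; the return value gets a separate slot.
struct Plan {
    scratch: Layout,
    ret: Option<Layout>,
}

fn plan(requests: &[SpaceRequest]) -> Plan {
    let mut scratch = Layout::new::<()>(); // size 0, align 1
    let mut ret = None;
    for req in requests {
        match req {
            SpaceRequest::Local(l) => scratch = scratch.extend(*l).expect("layout overflow").0,
            SpaceRequest::ReturnValue(l) => ret = Some(*l),
        }
    }
    Plan { scratch: scratch.pad_to_align(), ret }
}

fn main() {
    let p = plan(&[
        SpaceRequest::Local(Layout::from_size_align(20, 4).unwrap()), // e.g. [u32; 5]
        SpaceRequest::Local(Layout::from_size_align(8, 8).unwrap()),
        SpaceRequest::ReturnValue(Layout::from_size_align(4, 4).unwrap()),
    ]);
    // prints: scratch 32 bytes, return slot Some(4)
    println!("scratch {} bytes, return slot {:?}", p.scratch.size(), p.ret.map(|l| l.size()));
}
```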


This theoretically works at arbitrary depth for straight-line, blocking code, yes.

But my argument here is that it can't work, even in theory, for async fn in particular because they a) already have a state machine and b) need their locals to be pinned.

Our async semicoroutines are stackless. The key point is that you can suspend execution of the async fn and resume it later at an .await point. There is no persistent stack in which to allocate an unsized local that will stay there between calls to Future::poll.

I'd honestly love guaranteed return placement used to allow unsized locals and return values. But a fundamental limitation of stackless semicoroutines (async fn) is that they don't have a dynamic stack, and thus they can't hold on to unsized locals across .await points (which includes .awaiting an unsized local future, e.g. a dyn trait's async fn).

(If it goes through dynamic allocation APIs, then it's fine. I'm only talking about skipping dynamic allocation entirely and having the compiler do automatic stack allocation. Arena allocation is still dynamic allocation, and thus considered the same as "heap." It's maybe not the most correct term, but I think it's still a useful one, meaning "uses Box" or "isn't just a transparent local." I probably want to support Box<T, LocalStorage<MaxSize, MaxAlign>> for inline stack storage behind the dynamic allocation API more than most people.)
