Machine learning primitives in rustc -- An Opportunity

As the recent wyrm crate demonstrates, and as several in this thread have pointed out, these types of operations can be done in libraries today. The example linked above shows what wyrm’s define-by-run syntax looks like.

This is already quite nice syntax, and while I do think there are ergonomic improvements to be made in the future, those are outside the scope of this internals discussion.
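For readers without the link handy, here is a toy sketch of the define-by-run style, using hypothetical types of my own (this is NOT wyrm’s actual API): each operation records itself onto a tape as it runs, so the graph is built by executing ordinary Rust code.

```rust
// Toy define-by-run tape (hypothetical types, not wyrm's API).
#[derive(Clone, Copy, Debug)]
enum Op {
    Leaf(f32),
    Add(usize, usize),
    Mul(usize, usize),
}

struct Tape {
    ops: Vec<Op>,   // the recorded graph, in execution order
    vals: Vec<f32>, // forward values, one per recorded node
}

impl Tape {
    fn new() -> Self {
        Tape { ops: Vec::new(), vals: Vec::new() }
    }
    fn leaf(&mut self, v: f32) -> usize {
        self.ops.push(Op::Leaf(v));
        self.vals.push(v);
        self.vals.len() - 1
    }
    fn add(&mut self, a: usize, b: usize) -> usize {
        self.ops.push(Op::Add(a, b));
        self.vals.push(self.vals[a] + self.vals[b]);
        self.vals.len() - 1
    }
    fn mul(&mut self, a: usize, b: usize) -> usize {
        self.ops.push(Op::Mul(a, b));
        self.vals.push(self.vals[a] * self.vals[b]);
        self.vals.len() - 1
    }
}
```

A backward pass would walk `ops` in reverse to accumulate gradients; the point is that graph construction is just ordinary Rust control flow, which is exactly what makes compiler-level analysis of it attractive.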

The part that I think could benefit from compiler assistance/integration is the set of compiler optimizations that should be happening under the hood to support this syntax/library (and that is probably not happening, e.g. in wyrm, because of the difficulties involved):

  1. CFG and dataflow analysis (including liveness analysis for memory savings)
  2. Operator fusion (both vertical and horizontal)
  3. Type safety for e.g. matrix computations (rustc can/will help here via const generics, but as @gavento points out, this might also benefit from RTTI à la trait objects or monomorphization, similar to TensorFlow’s recurrent-length bucketing strategies).
  4. Etc…
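To make item 2 concrete, here is a minimal sketch of what vertical fusion of element-wise ops buys you (hypothetical helper functions, just to illustrate the transformation a fusion pass would perform):

```rust
// Unfused: (a * b) materializes a temporary vector, then the addition
// scans that temporary again -- two passes over memory plus an allocation.
fn unfused(a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    let tmp: Vec<f32> = a.iter().zip(b).map(|(x, y)| x * y).collect();
    tmp.iter().zip(c).map(|(t, z)| t + z).collect()
}

// Fused: one pass, no intermediate allocation. This is the rewrite a
// compiler-level fusion pass would apply automatically, rather than
// requiring the user to hand-write the combined kernel.
fn fused(a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    a.iter().zip(b).zip(c).map(|((x, y), z)| x * y + z).collect()
}
```

Horizontal fusion is the analogous move across independent ops that share an input, batching them into one traversal.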

I can go into more detail on any of these analyses/optimizations, since I work on an open source deep learning compiler (Intel nGraph) for my day job, though this post/idea is all on my own time.
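On item 3, here is a sketch of the const-generics direction (assuming the const generics feature as it has since landed; the types here are my own illustration): matrix dimensions live in the type, so a shape mismatch in a matrix multiply fails to compile rather than panicking at runtime.

```rust
// Dimensions are type-level constants: Matrix<2, 3> is a 2x3 matrix.
struct Matrix<const R: usize, const C: usize> {
    data: Vec<f32>, // row-major, length R * C
}

impl<const R: usize, const C: usize> Matrix<R, C> {
    fn zeros() -> Self {
        Matrix { data: vec![0.0; R * C] }
    }

    // [R x C] * [C x K] -> [R x K]. The inner dimension C is shared in the
    // signature, so multiplying mismatched shapes is a compile error.
    fn matmul<const K: usize>(&self, rhs: &Matrix<C, K>) -> Matrix<R, K> {
        let mut out = Matrix::<R, K>::zeros();
        for r in 0..R {
            for k in 0..K {
                let mut acc = 0.0;
                for c in 0..C {
                    acc += self.data[r * C + c] * rhs.data[c * K + k];
                }
                out.data[r * K + k] = acc;
            }
        }
        out
    }
}
```

This covers shapes known at compile time; the RTTI/bucketing point in item 3 is about the remaining cases where shapes are only known at runtime.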

In general, I agree with @notriddle’s opinion that things like concurrency should stay outside the compiler to improve velocity, experimentation, etc. But given how much of the above list covers things the compiler already reasons about, building an out-of-compiler set of representations, optimization-pass machinery, and analysis passes is not only wasteful, it will likely make it harder to reason about and optimize across the domains of the two compilers. Example: if I take a one-dimensional slice out of a tensor, sort it, and then broadcast it back across multiple dimensions to do a binary operation against another tensor, I lose the ability to optimize this into an in-place sort unless I know the liveness of the pre-sliced tensor.
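The liveness point in that example can be miniaturized in one dimension (hypothetical helper functions, purely to illustrate the distinction):

```rust
// If the pre-sliced tensor may still be live, the sort must copy first.
fn sort_copying(slice: &[f32]) -> Vec<f32> {
    let mut out = slice.to_vec(); // defensive copy: source may be read later
    out.sort_by(|a, b| a.partial_cmp(b).unwrap());
    out
}

// If liveness analysis proves the slice is the last use of that storage,
// the copy disappears and the sort runs in place. Taking ownership here
// mirrors the fact the optimizer would have to establish.
fn sort_in_place(mut slice: Vec<f32>) -> Vec<f32> {
    slice.sort_by(|a, b| a.partial_cmp(b).unwrap());
    slice
}
```

Rust’s ownership model already encodes exactly this kind of fact, which is part of why doing the analysis inside rustc, rather than in a bolted-on graph compiler that cannot see it, is appealing.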

I think this is a somewhat different situation from keeping the concurrency primitives outside of rustc, given that the primitives of concurrency are quite dissimilar to the core operations of rustc, unlike what I’ve described above.

That said, I do think we are a long way from adding all of this to rustc today, but the process of laying the groundwork can start. For example, it sounds like the current consensus is that external MIR plugins/passes are discouraged (due to the understandable desire for MIR APIs/details to remain internal). So perhaps the next questions to figure out are:

  1. Given that we’d want to start out using some of the existing/new deep learning compilers (nGraph, NNVM+TVM, etc.) at the library level, is there a natural way to iteratively transition toward rustc-aware optimizations that can leverage and extend planned rustc MIR optimizations?
  2. Is MIR the right level at which to operate on these kinds of constructs? I think so, because once you lower to LLVM IR you have lost much of the higher-level information.
  3. How can we enable quick iteration/experimentation with MIR passes given the desire to keep this interface private? A similar analogy is the Linux device driver contract: if you upstream your driver, its use of internal APIs will likely change, but those changes are taken care of by the people making them; if you don’t upstream, the kernel is free to change its internal APIs and break your driver, and it’s your responsibility to fix it. But this seems like a burden on rustc…