I just spent some time reading the Swift proposal:
1. Writing a differentiable trait for types, with an associated derive macro (that would define the type of the `TangentVector`), should be fairly straightforward.
2. Marking a function as the user-defined forward|reverse derivative of another function (with respect to one or more parameters) could be a simple matter of naming it according to a predefined convention.
3. Defining a `gradient!` macro that takes a function and the names of the parameters of interest, and calls the appropriate forward|reverse derivative, should be doable.
4. Marking a function as `#differentiable` is harder, as we cannot know at macro-expansion time whether its user-defined input types are differentiable. It could default to all parameters unless told otherwise (the Swift proposal offers several ideas in that space).
5. Automatically building the forward derivative of a `#differentiable` function should be a mechanical process once we know which inputs should be differentiated.
6. I am not familiar enough with tape-based implementations (which is what Swift uses) to know how to automatically build the reverse derivative of a `#differentiable` function, but it seems doable (it can be done mechanically for simple functions with no control flow or mutability, before moving on to a proper, general implementation).
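To make steps 1 and 2 concrete, here is a hand-written sketch of what the trait and a derive macro's output might look like, together with user-defined derivatives following a naming convention. Everything here is hypothetical: the trait name `Differentiable`, the types `Point`/`PointTangent`, the function `norm_squared` and the `_forward`/`_reverse` suffixes are illustrative choices. As a simplification, the reverse derivative returns the gradient directly, whereas Swift's reverse derivatives return a pullback closure; returning the gradient only makes sense for scalar outputs.

```rust
/// Step 1 (hypothetical): a differentiable type names the type of its tangent
/// vectors. A `#[derive(Differentiable)]` macro would generate the impls that
/// are written by hand here.
trait Differentiable {
    type TangentVector;
}

impl Differentiable for f64 {
    type TangentVector = f64;
}

#[derive(Clone, Copy, Debug, PartialEq)]
struct Point {
    x: f64,
    y: f64,
}

/// What the derive could produce: one tangent field per differentiable field.
#[derive(Clone, Copy, Debug, PartialEq)]
struct PointTangent {
    x: <f64 as Differentiable>::TangentVector,
    y: <f64 as Differentiable>::TangentVector,
}

impl Differentiable for Point {
    type TangentVector = PointTangent;
}

fn norm_squared(p: Point) -> f64 {
    p.x * p.x + p.y * p.y
}

/// Step 2 (hypothetical convention): `<name>_forward` takes the input and a
/// tangent (a direction) and returns the value and the directional derivative.
fn norm_squared_forward(p: Point, dp: PointTangent) -> (f64, f64) {
    (norm_squared(p), 2.0 * p.x * dp.x + 2.0 * p.y * dp.y)
}

/// `<name>_reverse` returns the value and the gradient, expressed in the
/// tangent type of the input (simplified: no pullback closure).
fn norm_squared_reverse(p: Point) -> (f64, <Point as Differentiable>::TangentVector) {
    (norm_squared(p), PointTangent { x: 2.0 * p.x, y: 2.0 * p.y })
}
```

With `p = Point { x: 3.0, y: 4.0 }`, `norm_squared_reverse(p)` gives `(25.0, PointTangent { x: 6.0, y: 8.0 })`. The only thing a macro needs to find the derivative is the name, which is why a plain convention seems sufficient for step 2.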
Steps 1 to 3 should be doable in a relatively short time.
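For step 3, a procedural macro could build the name `square_reverse` from `square`, but `macro_rules!` cannot create new identifiers on stable Rust without a helper crate. The sketch below therefore assumes a different, hypothetical convention: the reverse derivative of `f` with respect to parameter `p` lives at the path `f::p::reverse` and returns `(value, derivative)`. This works because functions and modules live in different namespaces, so `fn square` and `mod square` can coexist. The `gradient!` syntax (`wrt = …, at = (…)`) is likewise an illustrative assumption.

```rust
fn square(x: f64) -> f64 {
    x * x
}

/// Derivatives of `square`, one submodule per parameter (hypothetical convention).
mod square {
    pub mod x {
        pub fn reverse(x: f64) -> (f64, f64) {
            (super::super::square(x), 2.0 * x)
        }
    }
}

fn mul(x: f64, y: f64) -> f64 {
    x * y
}

/// Derivatives of `mul` with respect to each of its two parameters.
mod mul {
    pub mod x {
        pub fn reverse(x: f64, y: f64) -> (f64, f64) {
            (super::super::mul(x, y), y)
        }
    }
    pub mod y {
        pub fn reverse(x: f64, y: f64) -> (f64, f64) {
            (super::super::mul(x, y), x)
        }
    }
}

/// `gradient!(f, wrt = p, at = (args...))` calls the derivative registered
/// under the convention above and keeps only the derivative.
macro_rules! gradient {
    ($f:ident, wrt = $param:ident, at = ($($arg:expr),* $(,)?)) => {
        $f::$param::reverse($($arg),*).1
    };
}
```

For example, `gradient!(mul, wrt = y, at = (3.0, 4.0))` expands to `mul::y::reverse(3.0, 4.0).1`, which is `3.0`. A real implementation would more likely be a procedural macro, which can compute names freely and report a clear error when no derivative is registered.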
Steps 4 and 5 are more delicate, as they require some AST manipulation, but they are doable (and would be a good way to explore how Rust's type system interacts with the transformations and what is truly required of a differentiable type: `Clone`?).
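To illustrate why step 5 is mechanical, here is, written by hand, the kind of code a `#differentiable` macro could emit for a straight-line function: every intermediate `v` gets a companion tangent `dv`, computed with the derivative rule of the primitive that produced `v`. The function `f`, the `_forward` suffix and the `(x, dx, y, dy)` argument layout are hypothetical choices for this sketch.

```rust
/// The original function, as a user might write it under `#differentiable`.
fn f(x: f64, y: f64) -> f64 {
    let a = x * y;
    let b = x.sin();
    a + b
}

/// A possible mechanically generated forward derivative: each `let` of the
/// original is followed by a `let` computing its tangent.
fn f_forward(x: f64, dx: f64, y: f64, dy: f64) -> (f64, f64) {
    let a = x * y;
    let da = dx * y + x * dy; // product rule
    let b = x.sin();
    let db = x.cos() * dx; // chain rule for sin
    let out = a + b;
    let dout = da + db; // sum rule
    (out, dout)
}
```

Seeding `(dx, dy) = (1.0, 0.0)` gives the partial derivative with respect to `x`, and `(0.0, 1.0)` gives the one with respect to `y`; for instance `f_forward(3.0, 0.0, 2.0, 1.0).1` is `3.0`. Control flow and mutation are what make the general transformation harder than this straight-line case.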
Step 6 requires a review of the state of the art, some design brainstorming, a solid tape implementation and the knowledge gained from the previous steps.
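As a starting point for the "solid tape implementation", here is a minimal run-time tape (a Wengert list) for reverse mode. It is a generic sketch, not how Swift implements differentiation; the names `Tape`, `Var` and `grad` are hypothetical, and only `+`, `*` and `sin` are supported. Each operation records its parents and local partial derivatives; the reverse pass walks the tape backwards and accumulates adjoints with the chain rule.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

/// One recorded operation: indices of its (at most two) parents on the tape
/// and the local partial derivative with respect to each of them.
#[derive(Clone, Copy)]
struct Node {
    parents: [usize; 2],
    partials: [f64; 2],
}

/// Operations are recorded in evaluation order, so every parent index is
/// smaller than the index of the node that uses it.
#[derive(Default)]
struct Tape {
    nodes: RefCell<Vec<Node>>,
}

/// A value tracked on a tape.
#[derive(Clone, Copy)]
struct Var<'t> {
    tape: &'t Tape,
    index: usize,
    value: f64,
}

impl Tape {
    fn new() -> Self {
        Self::default()
    }

    /// Appends a node and returns its index.
    fn push(&self, parents: [usize; 2], partials: [f64; 2]) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(Node { parents, partials });
        nodes.len() - 1
    }

    /// Registers an input. A leaf points to itself with zero weight so that
    /// every node has exactly two (parent, partial) pairs.
    fn var(&self, value: f64) -> Var<'_> {
        let mut nodes = self.nodes.borrow_mut();
        let index = nodes.len();
        nodes.push(Node { parents: [index, index], partials: [0.0, 0.0] });
        Var { tape: self, index, value }
    }
}

impl<'t> Add for Var<'t> {
    type Output = Var<'t>;
    fn add(self, other: Var<'t>) -> Var<'t> {
        let index = self.tape.push([self.index, other.index], [1.0, 1.0]);
        Var { tape: self.tape, index, value: self.value + other.value }
    }
}

impl<'t> Mul for Var<'t> {
    type Output = Var<'t>;
    fn mul(self, other: Var<'t>) -> Var<'t> {
        // d(ab)/da = b and d(ab)/db = a.
        let index = self.tape.push([self.index, other.index], [other.value, self.value]);
        Var { tape: self.tape, index, value: self.value * other.value }
    }
}

impl<'t> Var<'t> {
    fn sin(self) -> Var<'t> {
        // d(sin a)/da = cos a; the second slot is unused (zero weight).
        let index = self.tape.push([self.index, self.index], [self.value.cos(), 0.0]);
        Var { tape: self.tape, index, value: self.value.sin() }
    }

    /// Reverse pass: seeds d(self)/d(self) = 1, then walks the tape backwards,
    /// pushing each node's adjoint to its parents. Entry `i` of the result is
    /// d(self)/d(node i).
    fn grad(&self) -> Vec<f64> {
        let nodes = self.tape.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[self.index] = 1.0;
        for i in (0..nodes.len()).rev() {
            let node = nodes[i];
            let adjoint = adjoints[i];
            for (&parent, &partial) in node.parents.iter().zip(node.partials.iter()) {
                adjoints[parent] += partial * adjoint;
            }
        }
        adjoints
    }
}
```

With `x = tape.var(3.0)` and `y = tape.var(2.0)`, evaluating `z = x * y + x.sin()` and calling `z.grad()` yields `y + cos(x)` at `x.index` and `x` at `y.index`. A macro for step 6 would either rewrite the function body to run on such a tape or, like Swift, generate the backward code at compile time.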
The result could be as efficient as the Swift implementation; the hardest part would be producing good error messages when users forget to differentiate something along the way.
I do not have the time to start prototyping (I already have Advent of Code, a crate in progress (for which I compute gradients manually) and a PhD to finish), but I maintain my position that this is doable now, with macros, on stable Rust.
For those interested but not familiar with the subject, here is a nice tutorial on forward-mode vs reverse-mode differentiation with an associated, runtime-evaluated Rust implementation: reverse-mode-automatic-differentiation