In my previous blog post, I discussed JAX, a framework for high-performance numerical computing and machine learning, in an atypical manner. I didn't create a single training loop, and only showed a couple of patterns that looked vaguely machine-learning-like. If you haven't read that blog post yet, you can read it here.
This approach was deliberate, as I felt that JAX, although designed for machine learning research, is more general-purpose than that. The steps to use it are: define what you want to happen, wrap it in jax.jit, and let JAX trace your function into an intermediate graph representation, which is then passed to XLA to be compiled and optimised. The result is a single, heavily optimised binary blob, ready and waiting to receive your data. This approach is a natural fit for many machine learning applications, as well as other scientific computing tasks, so targeting machine learning alone didn't make sense. It is also ground that has already been extensively covered; I wanted to do a different take on introductory JAX.
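As a minimal sketch of that workflow (the normalise function and the input array here are just placeholders, not anything from the previous post), wrapping an ordinary numerical function in jax.jit looks like this:

```python
import jax
import jax.numpy as jnp

# A plain numerical function -- nothing machine-learning specific.
def normalise(x):
    return (x - x.mean()) / x.std()

# jax.jit traces the function on first call, hands the traced graph to XLA
# for compilation and optimisation, and caches the compiled result.
normalise_jit = jax.jit(normalise)

x = jnp.arange(10.0)
print(normalise_jit(x))  # first call triggers tracing + compilation
print(normalise_jit(x))  # later calls with the same shapes reuse the compiled blob
```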
...