On Learning JAX – A Framework for High Performance Machine Learning
Recently, I took part in the Hugging Face x Google Cloud community sprint, which (despite being named a ControlNet sprint) had a very broad scope: involve diffusion models, use JAX, and use TPUs provided for free by Google Cloud. A lot of cool projects came out of it in a relatively short span of time. Our project was quite ambitious: take my master’s dissertation work on combining step-unrolled denoising autoencoders (loosely adjacent to discrete diffusion models) with VQ-GAN, port it all to JAX, then add support for text-conditioned generation....