Google’s JAX: Flexible, High-Performance Machine Learning

How to deploy models from JAX to production


With the rate at which the machine learning ecosystem evolves, there's always an element of risk in investing in a young framework. As exciting as a new framework may be, whether in terms of its design, philosophy, feature set, or performance on benchmarks, most researchers and practitioners will wait for the initial wave of early adopters to try things out before deciding whether it's worth a chance.

Two years after the release of Google's JAX framework, it feels safe to say that it has staying power, judging by the machine learning community's response to it.

Given all the momentum behind it, we wanted to try JAX out from a production perspective and share our experience as well as a guide to deploying JAX models.

But first, let’s set some context around what JAX is and why it’s so popular.

JAX: The spiritual successor to Autograd

Seeing as it’s an entire numerical computation framework, it’s hard to comprehensively describe JAX in a single sentence, but Google offers a decent summation on the JAX repo:

“JAX is Autograd and XLA, brought together for high-performance machine learning research.”

Autograd, for those who are unfamiliar, is a library (no longer actively developed) that enabled very efficient automatic differentiation of Python and NumPy code, even for higher order derivatives, via both backpropagation and forward-mode differentiation. As you can imagine, this made it very useful to machine learning researchers of various concentrations.
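
JAX carries this same capability forward. As a quick illustration (our own sketch, not taken from the Autograd or JAX docs), higher-order and forward-mode derivatives are just function compositions:

```python
import jax

# a simple scalar function: f(x) = x^3 + 2x
f = lambda x: x ** 3 + 2 * x

df = jax.grad(f)        # reverse-mode (backprop) derivative: 3x^2 + 2
d2f = jax.grad(df)      # higher-order derivative by composing grad: 6x
df_fwd = jax.jacfwd(f)  # forward-mode differentiation of the same function

print(df(2.0), d2f(2.0), df_fwd(2.0))  # 14.0 12.0 14.0
```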

XLA, on the other hand, is an accelerated linear algebra compiler originally designed by Google to optimize TensorFlow models. Essentially, it takes a TensorFlow graph and optimizes it at compilation time. The official XLA documentation gives a simple example to explain this:

Imagine you have a function that performs a simple TensorFlow computation, like tf.reduce_sum(x + y * z). Without XLA, a new kernel is launched for each operation (the multiplication, the addition, and the reduction). With XLA, those operations are fused and executed in a single kernel.
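
The same idea carries over to JAX: wrapping a function in jit hands it to XLA, which can fuse the elementwise operations and the reduction into a single compiled kernel. A minimal sketch of our own:

```python
import jax
import jax.numpy as jnp

def fn(x, y, z):
    # three logical ops (multiply, add, reduce) that XLA can fuse
    return jnp.sum(x + y * z)

fused_fn = jax.jit(fn)  # compiled via XLA on first call

x = y = z = jnp.ones((1000, 1000))
print(fused_fn(x, y, z))  # 2000000.0
```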

JAX combines both of these tools (all of the main Autograd contributors work full time at Google Brain or have otherwise contributed to JAX) and offers even more features on top of them, including just-in-time compilation and automatic vectorization. And because JAX takes a purely functional approach, its transformations (gradient evaluation, JIT compilation, and auto-vectorization) are just functions that can be composed arbitrarily, which gives it the flexible, composable UX that appeals to many researchers.

As an example, here is a snippet along the lines of the one in the JAX repo's README, showing grad, jit, and vmap composed on an ordinary Python function:
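
```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    # a small fully-connected network written as plain Python + jax.numpy
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def loss_fun(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_fun = jit(grad(loss_fun))  # compiled gradient-evaluation function
per_example_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example gradients
```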

JAX is fast, flexible, and offers a UX that many find to be an improvement over existing frameworks.

It is, however, a lower-level framework, more comparable to something like NumPy, and so when we talk about putting JAX into production, we're often really talking about a JAX-based framework, of which there are many.

Speaking of…

Deploying JAX to production with Elegy and Cortex

Elegy is a neural network library built on top of JAX that takes a lot of inspiration from the Keras API. We've picked it because it's particularly simple to get set up, but most of the popular JAX libraries can be deployed using the same pattern we'll use here: serialize your model, load it into your predictor, and serve predictions.
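
To make that concrete, here is a rough sketch of training and serializing a small Elegy model. The module, loss, and metric names follow Elegy's Keras-style API as of its early releases and may differ in newer versions, and the training data here is just a random placeholder:

```python
# train.py: a sketch assuming Elegy's Keras-style API (names may vary by version)
import elegy
import jax
import numpy as np
import optax

class MLP(elegy.Module):
    def call(self, x):
        x = elegy.nn.Linear(64)(x)
        x = jax.nn.relu(x)
        return elegy.nn.Linear(10)(x)

model = elegy.Model(
    module=MLP(),
    loss=elegy.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=elegy.metrics.SparseCategoricalAccuracy(),
    optimizer=optax.adam(1e-3),
)

# placeholder data standing in for a real dataset
X_train = np.random.rand(256, 784).astype(np.float32)
y_train = np.random.randint(0, 10, size=(256,))

model.fit(x=X_train, y=y_train, epochs=5, batch_size=32)
model.save("model")  # serializes the model to the "model" directory
```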

Cortex, if you’re unfamiliar, is our open source platform for running inference at scale. It automates all of the underlying infrastructure work, from provisioning an inference cluster to serving your model, while allowing you to customize every aspect of your deployment, from autoscaling behavior to compute resources (CPUs, GPUs, ASICs, etc.) to logging.

Because Cortex deploys our model inside a prediction-serving API, we first need to define our request handling code in a predictor. Our predictor will have two methods, __init__() and predict(), and will look something like this (the exact Elegy loading call may vary by version):
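
```python
# predictor.py: a sketch of a Cortex Python predictor; elegy.load() is an
# assumption about Elegy's model-loading API and may differ by version.
import elegy
import numpy as np

class PythonPredictor:
    def __init__(self, config):
        # runs once when the API starts: load the serialized Elegy model
        self.model = elegy.load("model")

    def predict(self, payload):
        # payload is the parsed JSON request body
        inputs = np.array(payload["input"], dtype=np.float32)
        return self.model.predict(inputs).tolist()
```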

Simple, right?

Now all we have to do is write our configuration file, cortex.yaml, which will look something like this (field names may vary slightly between Cortex versions):
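
```yaml
# cortex.yaml: a sketch; the API name and exact fields are illustrative and
# may differ between Cortex versions.
- name: elegy-classifier
  kind: RealtimeAPI
  predictor:
    type: python
    path: predictor.py
  compute:
    cpu: 1
```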

And then we add a requirements.txt file (a minimal sketch follows below), and we're done. Now, when we run cortex deploy from the directory, Cortex will automatically package and containerize all of our prediction-serving code and requirements, and deploy it to our Cortex cluster behind a load balancer.
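
For reference, the requirements file just lists the Python dependencies our predictor imports (pin versions as appropriate for your setup):

```text
# requirements.txt (sketch)
jax
jaxlib
elegy
numpy
```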

Because JAX and Elegy are so composable, we can also take advantage of any advanced Cortex feature (server-side batching, GPU inference, multi-model caching, etc.) by simply toggling a field in our configuration YAML, with no changes to our Python needed. For example, switching the API over to GPU inference is just a change to the compute block, sketched below.
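
```yaml
# cortex.yaml: the same API configured for GPU inference (illustrative; a
# CUDA-enabled jaxlib build would also need to go in requirements.txt)
- name: elegy-classifier
  kind: RealtimeAPI
  predictor:
    type: python
    path: predictor.py
  compute:
    gpu: 1
```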

A better bridge between research and production

One of our initial motivators in working on Cortex was the difficulty we had in deploying models to production. The research and production ecosystems, at the time, were two different worlds, and machine learning frameworks were simply painful to get to work in prod.

The production ecosystem has come a long way since then, and we've put a lot of work into removing the pain from the deployment process. Even so, we were very pleasantly surprised by JAX and the broader JAX ecosystem, not just in terms of performance, but in how easy it was to use in production.

The JAX ecosystem isn’t just exciting, from our perspective, because of its benefits to research. It is exciting because it represents progress on the bridge between research and production—but this time, from the research side.




Like Cortex? Leave us a Star on GitHub
