Learning about JAX: axes in vmap()
In recent weeks I've been learning about JAX, a Python library for machine learning developed by the Google Research team and used extensively by DeepMind.
What is JAX and why is it different?
JAX uses a NumPy-based API for numerical computation. It has three main features that make it attractive for implementing deep learning models quickly and simply:
- Auto-differentiation (a.k.a. grad in JAX) for custom functions. When optimizing the loss of a machine learning model, instead of having to derive the gradient and code it yourself, or work within syntactic and semantic constraints, you let the grad function compute the gradient of the loss function for you.
- Fast and efficient vectorization and parallelization (a.k.a. vmap and pmap in JAX). Like NumPy vectorization, it simplifies the code and speeds up computation drastically.
- Just-in-time (JIT) compilation (a.k.a. jit in JAX). I think this is the main secret sauce that lets JAX perform so much faster than other packages; take a look at the comparisons in the quick start on jit and on combining jit with vmap.
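To make the three features a little more concrete, here is a minimal sketch on a toy squared-error loss. The function loss and the data x and y are hypothetical, made up for this illustration; only jax.grad, jax.vmap and jax.jit come from JAX itself.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical toy loss: mean squared error of the linear model y ≈ w * x.
    return jnp.mean((w * x - y) ** 2)


# grad: derivative of loss with respect to its first argument, w.
dloss_dw = jax.grad(loss)

# vmap: evaluate the loss for a batch of candidate weights in a single call
# (map over axis 0 of w, reuse the same x and y for every call).
batched_loss = jax.vmap(loss, in_axes=(0, None, None))

# jit: compile the gradient function with XLA.
fast_dloss_dw = jax.jit(dloss_dw)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

print(dloss_dw(1.0, x, y))  # mean(2 * (w*x - y) * x) at w = 1, i.e. -28/3
print(batched_loss(jnp.array([0.0, 1.0, 2.0]), x, y))
print(fast_dloss_dw(1.0, x, y))
```

The jitted gradient should return the same value as the plain one; the speed-up only becomes noticeable on larger workloads.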
Resources to learn about JAX
The JAX 101 tutorial in the JAX docs is a great place to start, though as a beginner I found some of the examples a bit challenging to understand.
My mentor and friend Eric Ma wrote a course on gradient-based machine learning called dl-workshop, which uses JAX to illustrate machine learning concepts.
Robert Tjarko Lange has several blog posts on using JAX in deep learning projects, and Colin Raffel demonstrates key JAX concepts with code in You don't know JAX.
How does the vmap() function work?
According to the JAX docs on vmap, jax.vmap(function, in_axes=0, out_axes=0) returns a function that maps the given function over the input axes specified by in_axes and stacks the results along the output axis specified by out_axes. The concept is simple, but it took me a while to understand what happens when in_axes or out_axes is not set to its default.
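One way to picture "map and stack" is to compare vmap with an explicit Python loop. In this minimal sketch the per-example function f (a sum of squares) and the input xs are assumptions chosen for illustration; with the default in_axes=0 and out_axes=0, both versions should give the same result.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical per-example function: sum of squares of one vector.
    return jnp.sum(x ** 2)


xs = jnp.arange(6.0).reshape(3, 2)  # three example vectors of length 2

# vmap maps f over axis 0 of xs and stacks the results along axis 0.
vmapped = jax.vmap(f, in_axes=0, out_axes=0)(xs)

# The same computation written as an explicit loop plus a stack.
looped = jnp.stack([f(xs[i]) for i in range(xs.shape[0])], axis=0)

print(vmapped)  # values 1, 13, 41
print(looped)
```

The difference is that vmap produces a single vectorized computation instead of running f once per row in Python.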
To understand what it is really doing when in_axes and out_axes change, I'm creating two simple 2x2 matrices filled with unique numbers, called a and b, and passing them into a simple function, jnp.add, which does element-wise addition.
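A minimal version of this setup might look like the following. The particular values in a and b are my own assumption (any distinct numbers work); the point is that with default axes, vmap over jnp.add gives the same result as plain element-wise addition.

```python
import jax
import jax.numpy as jnp

# Two 2x2 matrices with distinct entries, so each output value can be traced
# back to its inputs (illustrative values).
a = jnp.array([[1, 2],
               [3, 4]])
b = jnp.array([[10, 20],
               [30, 40]])

# Plain element-wise addition.
plain = jnp.add(a, b)
print(plain)
# [[11 22]
#  [33 44]]

# vmap with the defaults in_axes=0, out_axes=0: row i of a is added to
# row i of b, and the two row results are stacked along axis 0.
mapped = jax.vmap(jnp.add)(a, b)
print(mapped)
```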
- Let's see an example with in_axes set to its default.
With the default in_axes=0, vmap maps over axis 0 of every input; written out in full, in_axes is the tuple (0, 0), with one number per input since there are two inputs, a and b. in_axes[0] indicates which axis of a to map over, and it can take these values: None; 0 or -2 (map over the 0th axis, which is also the -2 axis in this case, so the operation is performed row-wise); or 1 or -1 (map over the 1st axis, also the -1 axis in this case, so the operation is performed column-wise). Since a only has two dimensions, there is no other axis available for vmap to map over. in_axes[1] takes the same values as for a, but applies to b. Note that at least one entry in in_axes must be an integer rather than None. out_axes simply says how to stack the results, either by row (0) or by column (1).
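The points above can be checked with a short sketch, using small illustrative values for a and b: the explicit tuple (0, 0) matches the default, negative axes name the same axes for 2-D inputs, out_axes=1 stacks the per-row results as columns, and an in_axes with no integer at all is rejected (JAX raises a ValueError for this).

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])      # illustrative values
b = jnp.array([[10, 20], [30, 40]])

# The default in_axes=0, written out explicitly as one entry per input.
rows = jax.vmap(jnp.add, in_axes=(0, 0))(a, b)

# For 2-D inputs, axis -2 is the same axis as 0 (and -1 the same as 1).
rows_neg = jax.vmap(jnp.add, in_axes=(-2, -2))(a, b)

# out_axes=1 stacks each per-row result as a column instead of a row.
cols = jax.vmap(jnp.add, in_axes=(0, 0), out_axes=1)(a, b)
print(cols)
# [[11 33]
#  [22 44]]

# At least one entry of in_axes must be an integer.
try:
    jax.vmap(jnp.add, in_axes=(None, None))(a, b)
except ValueError as err:
    print("ValueError:", err)
```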
- What does the function return if we change in_axes to other integers?
The operation is done along the specified axes. Overall, if every entry of in_axes is an integer, the output keeps the shape (2,2), just like what you would see with a regular np.add(a, b).
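Here is a sketch of non-default integer in_axes, with illustrative values for a and b. The shape stays (2,2), but which elements get paired changes: in_axes=(0, 1) adds row i of a to column i of b, which works out to a + b.T, while in_axes=(1, 0) gives a.T + b.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])      # illustrative values
b = jnp.array([[10, 20], [30, 40]])

# in_axes=(0, 1): row i of a is added to column i of b.
row_plus_col = jax.vmap(jnp.add, in_axes=(0, 1))(a, b)
print(row_plus_col)
# [[11 32]
#  [23 44]]

# in_axes=(1, 0): column i of a is added to row i of b.
col_plus_row = jax.vmap(jnp.add, in_axes=(1, 0))(a, b)
print(col_plus_row)
# [[11 23]
#  [32 44]]

# Same (2, 2) shape as jnp.add(a, b), different pairing of elements.
print(row_plus_col.shape, col_plus_row.shape)
```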
- What does the result look like if we change one of the integers in in_axes to None (remember that at least one entry must stay an integer)?
The example outputs are all of shape (2,2,2) instead of the original shape (2,2). Why is that? The in_axes entry for a is always None, which means a is not mapped: each call receives the whole of a, together with one slice of b taken along the axis specified for b, and that slice is added to every row of a. Since a has shape (2,2) and b has two rows and two columns, you can think of each operation as making a copy of a and adding one row (or column) of b to it; stacking the two (2,2) results gives (2,2,2).
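A minimal sketch of the None case, again with illustrative values for a and b. With in_axes=(None, 0), slice i of the output is a + b[i]; with in_axes=(None, 1), it is a + b[:, i]. In both cases the vector from b is broadcast across every row of a.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])      # illustrative values
b = jnp.array([[10, 20], [30, 40]])

# a is not mapped, so every call sees all of a; row i of b is broadcast-added
# to each row of a, and the two (2, 2) results are stacked.
out_rows = jax.vmap(jnp.add, in_axes=(None, 0))(a, b)
print(out_rows.shape)  # (2, 2, 2)
print(out_rows[0])
# [[11 22]
#  [13 24]]

# Column i of b is broadcast-added to each row of a instead.
out_cols = jax.vmap(jnp.add, in_axes=(None, 1))(a, b)
print(out_cols[0])
# [[11 32]
#  [13 34]]
```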
- A more complicated example
The example above is simple: each input has the same shape and two dimensions. What if we have two arrays, each of shape (2,3,5,7,9)? What should we expect when changing in_axes? Note that although a and b are initialized differently, they have the same shape.
Because vmap maps over the specified axis of every input in lockstep, only inputs whose mapped axes have matching sizes will work together. Try vmap(jnp.add, in_axes = (0, 1), out_axes = 0)(a,b).shape: the error tells you that axis 0 of a has size 2 while axis 1 of b has size 3, so the computation can't happen. From the combinations that do work, it becomes clear that the operation takes out whichever axis is specified and stacks the results along the 0th axis.
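The 5-D case can be sketched like this. The contents of a and b (an arange and an array of ones) are assumptions; only their shape, (2,3,5,7,9), matters here. Mapping the same axis of both inputs moves that axis to the front of the output, while mapping axes of different sizes raises a ValueError.

```python
import jax
import jax.numpy as jnp

shape = (2, 3, 5, 7, 9)
a = jnp.arange(2 * 3 * 5 * 7 * 9).reshape(shape)  # illustrative contents
b = jnp.ones(shape)

# Map axis 1 of both inputs: each call sees slices of shape (2, 5, 7, 9),
# and the 3 results are stacked along a new axis 0.
print(jax.vmap(jnp.add, in_axes=(1, 1), out_axes=0)(a, b).shape)  # (3, 2, 5, 7, 9)

# Map the last axis of both inputs: 9 results of shape (2, 3, 5, 7).
print(jax.vmap(jnp.add, in_axes=(4, 4), out_axes=0)(a, b).shape)  # (9, 2, 3, 5, 7)

# Axis 0 of a has size 2, axis 1 of b has size 3: they cannot be paired up.
try:
    jax.vmap(jnp.add, in_axes=(0, 1), out_axes=0)(a, b)
except ValueError as err:
    print("ValueError:", err)
```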
Let's set one of the in_axes entries to None and change out_axes.
You will notice an additional axis in the output, just like in the simple example, but the returned shape can be a bit confusing. Is there a pattern? Here is how I summarize it intuitively, though it might also be possible to infer the output shape in a different way. First, write down the output shape when out_axes=0. Then, sequentially swap output axes: the output shape for out_axes == n equals the output shape for out_axes == (n-1) with its (n-1)th and nth entries swapped.
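The shapes for in_axes=(0, None) and each out_axes value can be printed with a short loop, next to the shape predicted by the swap rule. swap_rule_shape is a hypothetical helper written for this illustration, not a JAX function, and the array contents are again illustrative. With these shapes, (0, None) is the only (n, None) choice that works for jnp.add, because only a[i], of shape (3, 5, 7, 9), broadcasts against b's (2, 3, 5, 7, 9).

```python
import jax
import jax.numpy as jnp

shape = (2, 3, 5, 7, 9)
a = jnp.arange(2 * 3 * 5 * 7 * 9).reshape(shape)  # illustrative contents
b = jnp.ones(shape)


def swap_rule_shape(shape0, n):
    # Hypothetical helper for the rule above: start from the out_axes=0 shape
    # and swap entries (0, 1), then (1, 2), ..., up to (n - 1, n).
    s = list(shape0)
    for k in range(1, n + 1):
        s[k - 1], s[k] = s[k], s[k - 1]
    return tuple(s)


# Each call adds a[i] (shape (3, 5, 7, 9)) to all of b, giving a (2, 3, 5, 7, 9)
# result; vmap then inserts the mapped axis (size 2) at position out_axes.
out0 = jax.vmap(jnp.add, in_axes=(0, None), out_axes=0)(a, b)
for n in range(out0.ndim):
    out = jax.vmap(jnp.add, in_axes=(0, None), out_axes=n)(a, b)
    print(n, out.shape, swap_rule_shape(out0.shape, n))
# 0 (2, 2, 3, 5, 7, 9) (2, 2, 3, 5, 7, 9)
# 1 (2, 2, 3, 5, 7, 9) (2, 2, 3, 5, 7, 9)
# 2 (2, 3, 2, 5, 7, 9) (2, 3, 2, 5, 7, 9)
# 3 (2, 3, 5, 2, 7, 9) (2, 3, 5, 2, 7, 9)
# 4 (2, 3, 5, 7, 2, 9) (2, 3, 5, 7, 2, 9)
# 5 (2, 3, 5, 7, 9, 2) (2, 3, 5, 7, 9, 2)
```

Another way to read the same rule: out_axes=n places the mapped axis at position n, which is what jnp.moveaxis(out0, 0, n) does to the out_axes=0 result.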
Summary of in_axes and out_axes in vmap()
Imagine you have two arrays that the custom function is applied to: A1 of shape (a1, b1, c1, …) with x dimensions, A2 of shape (a2, b2, c2, …) with y dimensions, and an output with z dimensions.
- The values used for in_axes: in_axes = (x', y'), where x' is an integer in [-x, x) and y' is an integer in [-y, y). x' or y' can also be None, but at least one of them must be an integer.
- The scenario in_axes = (n, None) takes slices of the first array along its nth axis and applies each slice, together with the whole second array, to the function. For example, with A1.shape = (a1, b1, c1) and A2.shape = (a2, b2, c2), in_axes = (1, None) applies each of the b1 slices of shape (a1, c1), which broadcasts like an array of shape (1, a1, c1), to all of A2, while in_axes = (2, None) applies each of the c1 slices of shape (a1, b1), broadcasting like (1, a1, b1), to all of A2.
- out_axes is an integer in [-z, z).
- To figure out how the shape changes when you change out_axes: first, write down the output shape when out_axes=0. Then, sequentially swap output axes: the output shape for out_axes == n equals the output shape for out_axes == (n-1) with its (n-1)th and nth entries swapped.
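The (n, None) scenario and the out_axes range can be sketched with small 3-D arrays. The shapes are hypothetical: A2 gets a size-1 last axis so that both the (a1, c1) slices and the (a1, b1) slices of A1 broadcast against it. As I understand the JAX vmap docs, out_axes may be any integer in [-z, z), where z counts the dimensions of the vmapped output, and negative values count from the end.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes: A1 is (a1, b1, c1) = (2, 3, 4), A2 is (a2, b2, c2) = (5, 2, 1).
A1 = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)
A2 = jnp.arange(5 * 2).reshape(5, 2, 1) * 100

# in_axes=(1, None): call i gets A1[:, i, :], shape (a1, c1) = (2, 4), plus all of A2.
# jnp.add broadcasts the slice like a (1, 2, 4) array, so each call returns (5, 2, 4).
out1 = jax.vmap(jnp.add, in_axes=(1, None))(A1, A2)
print(out1.shape)  # (3, 5, 2, 4): b1 = 3 results stacked

# in_axes=(2, None): call i gets A1[:, :, i], shape (a1, b1) = (2, 3).
out2 = jax.vmap(jnp.add, in_axes=(2, None))(A1, A2)
print(out2.shape)  # (4, 5, 2, 3): c1 = 4 results stacked

# Every out_axes in [-z, z) is accepted; out_axes=-1 puts the mapped axis last.
z = out1.ndim
for n in range(-z, z):
    out_n = jax.vmap(jnp.add, in_axes=(1, None), out_axes=n)(A1, A2)
    print(n, out_n.shape)
```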
I want to thank Arkadij Kummer for giving valuable advice on making this blog post easier to understand :)