In recent weeks I’ve been learning about JAX, a Python library for machine learning developed by the Google Research team and used extensively by DeepMind.
What is JAX and why is it different?
JAX uses a NumPy-based API for numerical computation. It has three main features that make it attractive for fast and simple deep learning model implementation:
- Automatic differentiation (a.k.a. `grad` in JAX) for customized functions. This means that when optimizing the loss for a machine learning model, instead of having to derive the gradient and code it yourself, the `grad` function will just compute the gradient for you based on the loss function.
- Fast and efficient vectorization and parallelization (a.k.a. `vmap` and `pmap` in JAX). Like NumPy vectorization, this simplifies the coding process and speeds up computation drastically.
- Just-in-time (JIT) compilation (a.k.a. `jit` in JAX). I think this is the main secret sauce behind JAX performing so much faster than other packages; take a look at the comparisons in the quick start of the JAX docs. A minimal sketch of all three features follows below.
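To make the three features concrete, here is a minimal sketch (the toy loss function and the use of `jnp.dot` are my own illustration, not from any particular tutorial):

```python
import jax
import jax.numpy as jnp

# a toy squared-error loss (hypothetical example)
def loss(w, x, y):
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

grad_loss = jax.grad(loss)       # gradient w.r.t. w, no manual derivation needed
fast_loss = jax.jit(loss)        # JIT-compiled version of the same function
batched_dot = jax.vmap(jnp.dot)  # jnp.dot vectorized over a leading batch axis
```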
Resources to learn about JAX
The JAX 101 tutorial from the JAX docs is a great place to start, though as a beginner I found some of the examples a bit challenging to understand.
My mentor and friend Eric Ma wrote a gradient-based machine learning course called dl-workshop, which uses the JAX framework to illustrate machine learning concepts.
Robert Tjarko Lange has several blog posts on using JAX for deep learning projects, while Colin Raffel demonstrated key concepts of JAX, with code, in You don’t know JAX.
How does the vmap() function work?
According to the JAX docs on vmap, `jax.vmap(function, in_axes=0, out_axes=0)` returns a new function that maps the given function over the input axes specified by `in_axes` and stacks the results together along `out_axes`. The concept is simple, but it took me a while to understand what happens when `out_axes` is not left at its default.
To understand what it is really doing when `out_axes` changes, I’m creating simple arrays with unique numbers in matrix form, called `a` and `b`, and passing them into a simple element-wise addition function, `jnp.add` (the same function used in the later examples).
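A minimal setup might look like this (the exact values are my own choice; any 2×2 matrices with distinct entries will do):

```python
import jax.numpy as jnp
from jax import vmap

a = jnp.array([[1, 2],
               [3, 4]])
b = jnp.array([[10, 20],
               [30, 40]])
```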
- Let’s see an example with `in_axes`. It is a tuple with two numbers, since there are two inputs. The first entry indicates which axis to map over for `a`, and it can take different values: None; 0 or -2 (map over the 0th axis, also known as the -2 axis in this case, so the operation is performed row-wise); or 1 or -1 (map over the 1st axis, also known as the -1 axis in this case, so the operation is performed column-wise). Given that `a` only has two dimensions, there won’t be any additional axis available for vmap to work on.
- The second entry of `in_axes` can take the same four values as in the case of `a`, but the computation would be on `b`. Note that at least one entry in `in_axes` should be an integer instead of None.
- `out_axes` simply specifies how one would like to stack the results, either by row (as 0) or by column (as 1), as shown in the sketch below.
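With the matrices above, a quick sketch of how `out_axes` controls the stacking (mapping over the rows of both inputs):

```python
# each mapped call computes a[i] + b[i], a vector of shape (2,)
by_row = vmap(jnp.add, in_axes=(0, 0), out_axes=0)(a, b)  # results stacked as rows
by_col = vmap(jnp.add, in_axes=(0, 0), out_axes=1)(a, b)  # results stacked as columns

print(by_row)  # [[11 22] [33 44]]
print(by_col)  # [[11 33] [22 44]] -- the transpose of by_row
```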
- Set `in_axes` to customized integers; what would the function return?
The operation is done on the specified axes. Overall, if `in_axes` contains all integers, the shape of the output remains (2,2), just like what you would see in regular element-wise addition.
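A quick check of all four integer combinations confirms this (the loop is my own illustration, using the `a` and `b` defined above):

```python
for axes in [(0, 0), (0, 1), (1, 0), (1, 1)]:
    out = vmap(jnp.add, in_axes=axes, out_axes=0)(a, b)
    print(axes, out.shape)  # every combination yields shape (2, 2)
```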
- Mutate some integer in `in_axes` to None (note that at least one of the entries should be an integer); what would the result be like?
The example outputs are all of shape (2,2,2) instead of the original shape (2,2). Why is that? Notice that the entry for `a` is always None, which indicates that the operation picks whatever axis is specified for `b` and applies it to all of `a`. Since `a` is of shape (2,2) and `b` has two rows and two columns, you can think of each operation as producing two copies of `a`, then applying a row or column of `b` to each copy.
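The sketch below illustrates this for one such combination, reusing the matrices from above:

```python
# a is passed whole (None); vmap iterates over the rows of b
out = vmap(jnp.add, in_axes=(None, 0), out_axes=0)(a, b)
print(out.shape)  # (2, 2, 2)
print(out[0])     # a + b[0]: a full copy of a with b's first row added
```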
- A more complicated example
The above example is simple: each input is of the same shape and has only two dimensions. What if we have two arrays, each of shape (2,3,5,7,9)? What would we expect when changing the `in_axes`? Note that although the `a` and `b` arrays are initialized differently, they have the same shape.
With a user-supplied function, only inputs with matching dimensions will work together. Try `vmap(jnp.add, in_axes=(0, 1), out_axes=0)(a, b).shape`: it will tell you that the 0-axis of `a` is of size 2 while the 1-axis of `b` is of size 3, so the computation can’t happen. From the combinations that do work, it becomes obvious that the operation takes out whatever axes are specified and stacks the results as the 0th axis.
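Here is a sketch of both the failing and the working cases (`a` and `b` are initialized differently here, but the exact values don’t matter for the shapes):

```python
import jax.numpy as jnp
from jax import vmap

a = jnp.arange(2 * 3 * 5 * 7 * 9, dtype=jnp.float32).reshape(2, 3, 5, 7, 9)
b = jnp.ones((2, 3, 5, 7, 9))

# vmap(jnp.add, in_axes=(0, 1), out_axes=0)(a, b)  # error: mapped sizes differ (2 vs 3)

print(vmap(jnp.add, in_axes=(0, 0), out_axes=0)(a, b).shape)  # (2, 3, 5, 7, 9)
print(vmap(jnp.add, in_axes=(1, 1), out_axes=0)(a, b).shape)  # (3, 2, 5, 7, 9)
```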
Let’s set one of the entries in `in_axes` to None and change the `out_axes`.
You’ll notice there is an additional axis in the output, just like in the simple example, but the returned shape can be a bit confusing. Is there a pattern? I found an intuitive way to summarize it, though it might also be possible to infer the output shape differently: first, write down the output shape when `out_axes=0`; then sequentially swap output axes. The output shape for `out_axes=n` equals the `out_axes=n-1` shape with its (n-1)th and nth indices swapped.
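With the (2,3,5,7,9) arrays above and `in_axes=(None, 0)`, the swapping pattern can be verified directly (my own check, with shapes chosen as in the example):

```python
for n in range(3):
    out = vmap(jnp.add, in_axes=(None, 0), out_axes=n)(a, b)
    print(n, out.shape)
# 0 (2, 2, 3, 5, 7, 9)
# 1 (2, 2, 3, 5, 7, 9)  -- swapping indices 0 and 1 (both of size 2) changes nothing
# 2 (2, 3, 2, 5, 7, 9)  -- indices 1 and 2 of the previous shape swapped
```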
Imagine you have two arrays to which the customized function is applied: A1 of shape (a1,b1,c1,…) with x dimensions, A2 of shape (a2,b2,c2,…) with y dimensions, and an output with z dimensions.
The values allowed for `in_axes = (x', y')`: x' is an integer in [-x, x) and y' is an integer in [-y, y). x' or y' can also be None, but at least one of (x', y') must be an integer.
The scenario `in_axes = (n, None)` takes the nth axis out of the first array and applies each slice (with the remaining shape of the first array) to the second array. E.g., with A1.shape = (a1,b1,c1) and A2.shape = (a2,b2,c2), `in_axes = (1, None)` will apply slices of shape (a1,c1) to the whole of A2, while `in_axes = (2, None)` will apply slices of shape (a1,b1) to the whole of A2.
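A concrete sketch of the (n, None) scenario (the shapes are my own, chosen so that broadcasting inside the mapped function succeeds):

```python
import jax.numpy as jnp
from jax import vmap

A1 = jnp.ones((2, 3, 4))  # (a1, b1, c1)
A2 = jnp.ones((5, 2, 4))  # each (a1, c1) slice of A1 broadcasts against A2

out = vmap(jnp.add, in_axes=(1, None), out_axes=0)(A1, A2)
print(out.shape)  # (3, 5, 2, 4): b1 slices of shape (2, 4), each applied to all of A2
```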
`out_axes` is an integer with a value in [-z, z).
To figure out how the shape changes when you change `out_axes`: first, write down the output shape when `out_axes=0`; then sequentially swap output axes. The output shape for `out_axes=n` equals the `out_axes=n-1` output shape with its (n-1)th and nth indices swapped.
I want to thank Arkadij Kummer for giving valuable advice on making this blog post easier to understand :)