Automatic Vectorization in JAX

Author: Murphy  |  View: 23491  |  Time: 2025-03-22 19:56:30

JAX is famous for its speed, efficiency, and flexibility when working with mathematical computations and machine learning. But one of its lesser-known superpowers – something that can save you from writing endless loops and boilerplate code – is automatic vectorization.

If you've ever written code that processes arrays or batched data, you know how tedious it can be to optimize for parallelism. But with JAX's vmap (vectorizing map) function, you can say goodbye to ugly loops and hello to concise, efficient, and parallelized code.

In this article, we're going to dive deep into automatic vectorization in JAX. We'll explore how vectorization works, why it's essential for speeding up computations, and how you can leverage JAX's vmap to avoid writing explicit loops. Along the way, we'll take some real-world examples and walk through code that will make you love JAX even more.

Ready? Let's go!


What's the Deal with Vectorization?

Before we get into JAX specifics, let's talk about vectorization in general. In traditional programming, you might write code that processes one data point at a time in a loop. For example, if you want to apply a function to every element of an array, you'd probably use a for loop to go through each element and compute the result. Something like this:

x = [1, 2, 3, 4, 5]
y = []

for i in x:
    y.append(i**2)

print(y)

This prints:

[1, 4, 9, 16, 25]

This is fine for small data, but when you're working with large arrays or complex computations (like deep learning models), this approach quickly becomes inefficient. Every iteration pays the overhead of the Python interpreter, so you waste CPU cycles going element by element when you could be processing the whole array at once!

Enter Vectorization

Vectorization is about doing operations on entire arrays (or batches of data) simultaneously, instead of processing one element at a time. The idea is that modern CPUs and GPUs are designed to handle such bulk operations in parallel. Libraries like NumPy, TensorFlow, and PyTorch encourage vectorization because it's much faster than writing explicit loops.

For example, you could rewrite the above example using NumPy to leverage vectorization:

import numpy as np

# NumPy: apply the function to the whole array at once
x = np.array([1, 2, 3, 4, 5])
y = x**2

print(y)

We get the following output:

[ 1  4  9 16 25]

In this case, instead of looping through x, NumPy applies the operation to the entire array at once, which is much faster.
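To make the speed claim concrete, here is a rough sketch that times both versions on a larger array. The array size and the repetition count are arbitrary choices made for illustration, and the helper names square_loop and square_vectorized are hypothetical. Timings depend heavily on your machine, so treat the printed numbers as indicative only.

```python
import timeit

import numpy as np

# Array size chosen arbitrarily for illustration
x = np.arange(100_000, dtype=np.float64)

def square_loop(values):
    # One Python-level operation per element
    return [v**2 for v in values]

def square_vectorized(values):
    # A single NumPy operation over the whole array
    return values**2

# Both versions compute the same result
assert np.allclose(square_loop(x), square_vectorized(x))

# Time five repetitions of each version
loop_time = timeit.timeit(lambda: square_loop(x), number=5)
vec_time = timeit.timeit(lambda: square_vectorized(x), number=5)
print(f"loop: {loop_time:.4f}s, vectorized: {vec_time:.4f}s")
```

On typical hardware the vectorized version should come out well ahead, but the exact ratio will vary from run to run.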


What Makes JAX's Vectorization So Special?

This is a natural question to ask, since we already have NumPy for vectorization.

JAX takes this concept of vectorization and supercharges it with automatic vectorization using vmap. Instead of having to manually restructure your code to vectorize it, vmap lets you write a function for a single example and then apply it to arrays or batches of data without writing a single loop. And the best part? JAX's vmap plays nicely with all the other tools in JAX's toolbox, like JIT compilation, automatic differentiation, and parallelism.

Let's start with the basics.


The Magic of vmap: Making Loops Disappear

In JAX, vmap stands for vectorizing map. It takes a function written for a single input and returns a version that applies it element-wise (or batch-wise) across an array or batch of inputs, without you writing loops manually. You can think of vmap as an automatic for loop, but instead of running the iterations one after another in Python, it rewrites the function to operate on the whole batch at once. The beauty of vmap is that it's composable with all the other JAX features you know and love (like jit and grad).

Let's see a simple example to illustrate what I mean:

Without vmap: The Loop Approach

import jax.numpy as jnp

# Define a simple function
def square(x):
    return x**2

# Create an array of inputs
x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Manually apply the function to each element
y = jnp.array([square(x_i) for x_i in x])

print(y)

This prints:

[ 1.  4.  9. 16.]

Here, we're applying the square function to each element of x manually by looping through it. But you can already see that this isn't ideal for large datasets. The loop is sequential, meaning it's slow, especially for bigger arrays.

With vmap: The Vectorized Approach

Now, let's use vmap to vectorize this function and apply it to the entire array without writing the loop ourselves:

import jax
import jax.numpy as jnp

# Define the same simple function
def square(x):
    return x**2

# Vectorize the function using vmap
vectorized_square = jax.vmap(square)

# Apply the vectorized function to the array
x = jnp.array([1.0, 2.0, 3.0, 4.0])
y = vectorized_square(x)
print(y)

Boom! No more explicit loops! vmap automatically maps the square function across the entire array, applying it in parallel and returning the result. It's cleaner, faster, and more scalable.
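Squaring is already element-wise, so the benefit is easier to see with a function that only makes sense for one example at a time. The sketch below uses a hypothetical normalize function and made-up sample vectors. It checks that the vmap version matches the explicit loop it replaces.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: written for ONE vector of shape (d,)
def normalize(v):
    return v / jnp.linalg.norm(v)

# A batch of three 2-D vectors, shape (3, 2)
batch = jnp.array([[3.0, 4.0], [6.0, 8.0], [0.0, 2.0]])

# vmap maps over the leading axis (axis 0) by default
batched_normalize = jax.vmap(normalize)
out = batched_normalize(batch)

# Reference: the explicit loop that vmap replaces
reference = jnp.stack([normalize(v) for v in batch])

print(out.shape)                      # (3, 2)
print(jnp.allclose(out, reference))   # True
```

Note that calling normalize(batch) directly would not give the same answer: jnp.linalg.norm would then compute a single norm over the whole matrix instead of one norm per row.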


Why vmap Is a Big Deal: Parallelism at Its Best

When you use vmap, you're not just avoiding the hassle of writing loops. You're unlocking serious performance improvements because JAX is designed to work efficiently on CPUs, GPUs, and TPUs. When you vectorize operations with vmap, JAX can dispatch them to your hardware and run them in parallel.
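If you're curious what vmap actually hands to the hardware, one way to peek is jax.make_jaxpr, which prints the program JAX traced from your function. This is only a sketch: the dot function and the input shapes are made up for illustration, and the exact printed text differs between JAX versions, so no output is shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: dot product of a weight vector and ONE input
def dot(w, x):
    return jnp.dot(w, x)

w = jnp.ones(4)          # shared weights, shape (4,)
xs = jnp.ones((3, 4))    # batch of three inputs, shape (3, 4)

# Share w across the batch, map over axis 0 of xs
batched_dot = jax.vmap(dot, in_axes=(None, 0))

# Trace the batched function and print the resulting program
jaxpr = jax.make_jaxpr(batched_dot)(w, xs)
jaxpr_text = str(jaxpr)
print(jaxpr_text)
```

In the printed program you should see array-level operations acting on the whole (3, 4) batch, not three separate per-example calls.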

This is especially powerful when you're dealing with batching in machine learning. Instead of iterating over each sample in a batch and applying your model or function, you can vectorize the whole thing and let JAX handle it.

Let's look at a simple neural network example to further understand this point.


Vectorizing Neural Network Predictions

Imagine you have a simple neural network, and you want to predict outputs for a batch of input data. Without vmap, you'd write a loop to make predictions one input at a time. But with vmap, you can predict for the entire batch in one go, maximizing the efficiency of your hardware.

Here's how you might define a simple neural network forward pass:

import jax.numpy as jnp

# Define a simple neural network
def forward(params, x):
    w, b = params
    return jnp.dot(x, w) + b

# Create fixed example parameters (weights and bias)
params = [jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([0.5, 0.5])]

# Define a batch of inputs
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Manually apply the forward pass to each input
outputs = jnp.array([forward(params, x) for x in inputs])

print(outputs)

We get:

[[ 7.5 10.5]
 [15.5 22.5]
 [23.5 34.5]]

Here, we have a simple neural network with a weight matrix w and bias b. We manually apply the forward function to each input in the batch.

But with vmap, we can avoid writing that loop and automatically vectorize the forward pass:

import jax

# Vectorize the forward pass function using vmap
vectorized_forward = jax.vmap(forward, in_axes=(None, 0))

# Apply the vectorized function to the batch of inputs
outputs = vectorized_forward(params, inputs)

print(outputs)

This outputs:

[[ 7.5 10.5]
 [15.5 22.5]
 [23.5 34.5]]

What's Happening Here?

  • The in_axes argument tells vmap which arguments of the function to vectorize. In this case, params remains the same for all inputs (so we set its in_axes to None), while inputs is batched along the 0th axis (the rows, representing different input samples).
  • JAX automatically applies the forward function to each sample in the batch, resulting in a batch of outputs—all without writing a loop!
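As a sanity check on the points above, the following sketch compares the vmap result against a hand-batched version of the same layer. It reuses the forward function and example values from this section. The names batched and manual are introduced here only for illustration.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    w, b = params
    return jnp.dot(x, w) + b

params = [jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([0.5, 0.5])]
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# None: share the whole params list across the batch; 0: map over rows of inputs
batched = jax.vmap(forward, in_axes=(None, 0))(params, inputs)

# Hand-batched equivalent: one matrix multiply plus a broadcast bias
w, b = params
manual = inputs @ w + b

print(batched.shape)                    # (3, 2)
print(jnp.allclose(batched, manual))    # True
```

For a layer this simple you could write the batched form by hand. The appeal of vmap is that you don't have to work out the batched formula yourself when the per-example function gets more involved.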

Not only is this cleaner, but it's also much faster, especially if you're using a GPU or TPU. JAX handles all the parallelism behind the scenes, making your code both elegant and performant.


Advanced vmap: Vectorizing Multiple Inputs

You can also use vmap to vectorize functions that take multiple arguments. For example, let's say we have a function that computes the dot product of two vectors. We can vectorize this function to work on batches of vectors:

import jax
import jax.numpy as jnp

# Define a function that computes the dot product of two vectors
def dot_product(x, y):
    return jnp.dot(x, y)

# Create two batches of vectors
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_batch = jnp.array([[0.5, 1.5], [2.0, 3.0], [4.0, 5.0]])

# Vectorize the dot product function to handle batches of inputs
vectorized_dot_product = jax.vmap(dot_product)

# Apply the vectorized function to the batches
results = vectorized_dot_product(x_batch, y_batch)

print(results)

This prints:

[ 3.5 18.  50. ]

Here, both x_batch and y_batch are batches of vectors, and vmap applies the dot_product function to each pair of vectors from the batches.
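When in_axes is omitted, vmap maps over axis 0 of every argument. The short sketch below makes that default explicit and compares it with a hand-written batched dot product. The variable names default_axes, explicit_axes, and hand_written are made up for this illustration.

```python
import jax
import jax.numpy as jnp

def dot_product(x, y):
    return jnp.dot(x, y)

x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_batch = jnp.array([[0.5, 1.5], [2.0, 3.0], [4.0, 5.0]])

# Default: map over axis 0 of both arguments
default_axes = jax.vmap(dot_product)(x_batch, y_batch)

# The same thing, spelled out
explicit_axes = jax.vmap(dot_product, in_axes=(0, 0))(x_batch, y_batch)

# Hand-written batched version: multiply element-wise, then sum each row
hand_written = jnp.sum(x_batch * y_batch, axis=1)

print(jnp.allclose(default_axes, explicit_axes))   # True
print(jnp.allclose(default_axes, hand_written))    # True
```

Both arguments must have the same size along the mapped axis, since vmap pairs up the i-th row of x_batch with the i-th row of y_batch.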

Controlling Which Axes to Vectorize

You can also control which axes of each input should be vectorized using the in_axes argument. For example, if you want to vectorize over the second axis of one input and the first axis of another, you can specify that in in_axes.

Here's an example where we map over the rows of one input and hold the other input fixed:

import jax
import jax.numpy as jnp

# Define a function that adds two arrays
def add_arrays(x, y):
    return x + y

# Create two arrays
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([5.0, 6.0])

# Map over the 0th axis of x (its rows); y is not mapped (None), so every row sees the same y
vectorized_add = jax.vmap(add_arrays, in_axes=(0, None))

# Apply the vectorized function
result = vectorized_add(x, y)
print(result)

This prints:

[[ 6.  8.]
 [ 8. 10.]]

In this case, vmap is applied across the 0th axis of x (the rows) and not vectorized for y (hence, None). This level of control makes vmap incredibly flexible for a wide range of use cases.
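For the case mentioned earlier, where the batch lives on the second axis of one input and the first axis of another, a minimal sketch might look like this. The arrays and the function names are made up for illustration. The second half also uses out_axes, a vmap argument that controls where the batch axis lands in the output.

```python
import jax
import jax.numpy as jnp

def dot_product(a, b):
    return jnp.dot(a, b)

def multiply(a, b):
    return a * b

# x holds three 2-vectors as COLUMNS, shape (2, 3)
x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])
# y holds three 2-vectors as ROWS, shape (3, 2)
y = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])

# Map over axis 1 of x (columns) and axis 0 of y (rows)
dots = jax.vmap(dot_product, in_axes=(1, 0))(x, y)
print(dots)          # [1. 5. 9.]

# out_axes=1 puts the batch axis second in the output, matching x's layout
products = jax.vmap(multiply, in_axes=(1, 0), out_axes=1)(x, y)
print(products.shape)  # (2, 3)
```

Each mapped call sees one column of x and one row of y, both of shape (2,), so the per-example functions never need to know how the batches were laid out.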


Composing vmap with jit, grad, and More

The cool thing about JAX is that everything plays well together. You can combine vmap with JIT compilation, automatic differentiation, and other transformations.

For example, let's say you want to compute the gradient of a vectorized function. No problem! You can use vmap and grad together:

import jax
import jax.numpy as jnp

# Define a simple quadratic function
def f(x):
    return x**2

# Vectorize the function
vectorized_f = jax.vmap(f)

# Vectorize the gradient: grad(f) differentiates f at one scalar input, and vmap maps it over the batch
vectorized_grad = jax.vmap(jax.grad(f))

# Create an array of inputs
x = jnp.array([1.0, 2.0, 3.0])

# Apply the vectorized gradient function
grads = vectorized_grad(x)

print(grads)

This prints:

[2. 4. 6.]

Here, we use vmap to apply grad(f) to each element of the array x. The result is the gradient of f evaluated at each point in x.
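The order matters here: jax.grad expects a function that returns a scalar, so we take the gradient of the per-example function first and then vmap it. A common use of this pattern is computing per-example gradients with respect to shared parameters. The sketch below assumes a hypothetical squared-error loss for a linear model, with made-up data, and checks the result against the closed-form gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss: squared error of a linear model
def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, -1.0])                                  # shared parameters
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])        # three inputs
ys = jnp.array([0.0, 1.0, 2.0])                             # three targets

# grad differentiates with respect to the first argument (w) by default;
# vmap shares w (None) and maps over axis 0 of xs and ys
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)

# Closed form for comparison: d/dw (w.x - y)^2 = 2 * (w.x - y) * x
expected = 2.0 * (xs @ w - ys)[:, None] * xs

print(per_example_grads.shape)                      # (3, 2)
print(jnp.allclose(per_example_grads, expected))    # True
```

Each row of per_example_grads is the gradient of the loss for one training example, which is awkward to get from a loss that has already been averaged over the batch.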

You can even throw jit into the mix to make things even faster:

# Apply both vmap and jit
vectorized_grad_jit = jax.jit(jax.vmap(jax.grad(f)))

# Now, apply the JIT-compiled, vectorized gradient function
grads_jit = vectorized_grad_jit(x)

print(grads_jit)

This prints:

[2. 4. 6.]

With vmap, grad, and jit all working together, you can handle large batches of computations in parallel, with gradients calculated automatically and compiled for maximum efficiency.


Conclusion: Let JAX Do the Heavy Lifting with vmap

JAX's automatic vectorization with vmap is one of the most powerful features for anyone working with large datasets, machine learning models, or mathematical computations. It eliminates the need for writing explicit loops, allowing you to write cleaner, more concise code while leveraging the full power of parallel hardware.

In summary:

  • Vectorization is about applying operations to whole arrays or batches of data in parallel, and it's much faster than using loops.
  • JAX's vmap allows you to automatically vectorize any function, meaning you can apply it across batches of data without writing a single loop.
  • vmap can be used for everything from simple element-wise operations to complex neural network training, and it integrates smoothly with other JAX tools like jit and grad.
  • By using vmap, you can avoid writing boilerplate code and let JAX handle the hard work of optimizing your computations for parallelism and performance.

So, the next time you find yourself writing a for loop, stop and ask yourself: Can I vmap this? The answer is probably yes, and your code (and your hardware) will thank you!
