A Practical Guide to Proximal Policy Optimization in JAX

Since its publication in a 2017 paper by OpenAI, Proximal Policy Optimization (PPO) has been widely regarded as one of the state-of-the-art algorithms in Reinforcement Learning. Indeed, PPO has demonstrated remarkable performance across various tasks, from reaching superhuman level in team-based Dota 2 to solving a Rubik's cube with a single robotic hand, while maintaining three main advantages: simplicity, stability, and sample efficiency.
However, implementing RL algorithms from scratch is notoriously difficult and error-prone, given the numerous sources of error and implementation details to be aware of.
In this article, we'll focus on breaking down the clever tricks and programming concepts used in a popular implementation of PPO in JAX. Specifically, we'll focus on the implementation featured in the PureJaxRL library, developed by Chris Lu.
Disclaimer: Rather than diving too deep into theory, this article covers the practical implementation details and (numerous) tricks used in popular versions of PPO. Should you require any reminders about PPO's theory, please refer to the "references" section at the end of this article. Additionally, all the code (minus the added comments) is copied directly from PureJaxRL for pedagogical purposes.
GitHub – luchris429/purejaxrl: Really Fast End-to-End Jax RL Implementations
Actor-Critic Architectures
Proximal Policy Optimization belongs to the policy gradient family of algorithms, which includes actor-critic methods. The designation 'actor-critic' reflects the dual components of the model:
- The actor network creates a distribution over actions given the current state of the environment and returns an action sampled from this distribution. Here, the actor network comprises three dense layers separated by two activation layers (either ReLU or hyperbolic tangent) and a final categorical layer that applies the softmax function to the output logits.
- The critic network estimates the value function of the current state, in other words, how good it is to be in that state in terms of expected return. Its architecture is almost identical to that of the actor network, except for the final softmax layer: the critic doesn't apply any activation function to the outputs of its final dense layer, as it performs a regression task.

Additionally, this implementation pays particular attention to weight initialization in dense layers. Indeed, all dense layers are initialized with orthogonal matrices scaled by specific coefficients. This initialization strategy has been shown to preserve the norm (i.e. scale) of activations during forward passes and of gradients during backpropagation, leading to smoother convergence and limiting the risks of vanishing or exploding gradients[1].
Orthogonal initialization is used in conjunction with specific scaling coefficients:
- Square root of 2: Used for the first two dense layers of both networks, this factor aims to compensate for the variance reduction induced by ReLU activations (as inputs with negative values are set to 0). For the tanh activation, the Xavier initialization is a popular alternative[2].
- 0.01: Used in the last dense layer of the actor network, this factor helps to minimize the initial differences in logit values before applying the softmax function. This will reduce the difference in action probabilities and thus encourage early exploration.
- 1: As the critic network is performing a regression task, we do not scale the initial weights.
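To make the scaling coefficients above concrete, here is a minimal sketch in plain JAX. It is not taken from PureJaxRL: the helper names `init_mlp` and `mlp`, the layer sizes, the tanh activation, and the dummy observation are all assumptions chosen for illustration. It only relies on `jax.nn.initializers.orthogonal`, which accepts a scale factor.

```python
import jax
import jax.numpy as jnp
from jax.nn.initializers import orthogonal


def init_mlp(key, sizes, scales):
    """Hypothetical helper: orthogonal weights (one scale per layer), zero biases."""
    keys = jax.random.split(key, len(scales))
    params = []
    for k, n_in, n_out, scale in zip(keys, sizes[:-1], sizes[1:], scales):
        w = orthogonal(scale)(k, (n_in, n_out), jnp.float32)
        params.append((w, jnp.zeros(n_out)))
    return params


def mlp(params, x):
    """Hypothetical helper: tanh between layers, no activation on the last one."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


obs_dim, hidden, n_actions = 4, 64, 2  # illustrative sizes
k_actor, k_critic = jax.random.split(jax.random.PRNGKey(0))

sqrt2 = 2.0 ** 0.5
# sqrt(2) for the hidden layers, 0.01 for the actor head, 1 for the critic head
actor = init_mlp(k_actor, [obs_dim, hidden, hidden, n_actions], [sqrt2, sqrt2, 0.01])
critic = init_mlp(k_critic, [obs_dim, hidden, hidden, 1], [sqrt2, sqrt2, 1.0])

obs = jnp.ones(obs_dim)  # dummy observation
probs = jax.nn.softmax(mlp(actor, obs))  # action probabilities
value = mlp(critic, obs).squeeze(-1)     # scalar state-value estimate
```

With the 0.01 scale on the actor head, the initial logits stay very close to zero, so `probs` starts out nearly uniform, which is the exploration effect described in the list.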
Training Loop
The training loop is divided into 3 main blocks that share similar coding patterns, taking advantage of JAX's functionalities:
- Trajectory collection: First, we'll interact with the environment for a set number of steps and collect observations and rewards.
- Generalized Advantage Estimation (GAE): Then, we'll estimate how much better each action was than expected by computing the generalized advantage estimation, which also yields the return targets used to train the critic.
- Update step: Finally, we'll compute the gradient of the loss and update the network parameters via gradient descent.
Before going through each block in detail, here's a quick reminder about the jax.lax.scan function that will show up multiple times throughout the code:
jax.lax.scan
A common programming pattern in JAX consists of defining a function that acts on a single sample and using jax.lax.scan to iteratively apply it to the elements of a sequence or an array, while carrying along some state.
For instance, we'll apply it to the step function to step our environment N consecutive times while carrying the new state of the environment through each iteration.
In pure Python, we could proceed as follows:
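trajectories = []
for step in range(n_steps):
    # select an action from the current observation
    action = actor_network(obs)
    # step the environment and carry its new state forward
    obs, state, reward, done, info = env.step(action, state)
    # store the transition as a tuple
    trajectories.append((obs, state, reward, done, info))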
However, we avoid writing such loops in JAX for performance reasons: under JIT compilation, a Python loop is unrolled at trace time, which inflates compilation time as the number of iterations grows. The alternative is jax.lax.scan, whose semantics are roughly equivalent to:
import numpy as np

def scan(f, init, xs, length=None):
    """Example provided in the JAX documentation."""
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        # apply function f to current state
        # and element x
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
Using jax.lax.scan is more efficient than a Python loop because it allows the transformation to be optimized and executed as a single compiled operation rather than interpreting each loop iteration at runtime.
We can see that the scan function takes multiple arguments:
- f: A function that is applied at each step. It takes the current state and an element of xs (or a placeholder if xs is None) and returns the updated state and an output.
- init: The initial state that f will use in its first invocation.
- xs: A sequence of inputs that are iteratively processed by f. If xs is None, the function simulates a loop with length iterations using None as the input for each iteration.
- length: Specifies the number of iterations if xs is None, ensuring that the function can still operate without explicit inputs.
Additionally, scan returns:
- carry: The final state after all iterations.
- ys: An array of outputs corresponding to each step's application of f, stacked for easy analysis or further processing.
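As a small illustration of these arguments and return values, the sketch below uses scan to compute a running sum. The function name `step` and the input values are arbitrary choices made for this example.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    """f: takes the current state and one element of xs."""
    new_carry = carry + x
    # first output is the updated state, second is the per-step output
    return new_carry, new_carry


xs = jnp.arange(1.0, 5.0)   # [1., 2., 3., 4.]
init = jnp.float32(0.0)     # initial state

final, running = jax.lax.scan(step, init, xs)
# final   -> 10.0               (carry after the last iteration)
# running -> [1., 3., 6., 10.]  (stacked per-step outputs)
```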
Finally, scan can be used in combination with vmap to scan a function over multiple dimensions in parallel. As we'll see in the next section, this allows us to interact with several environments in parallel to collect trajectories rapidly.
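One way to combine the two is to wrap a scanned function in vmap, as in the toy sketch below. The `rollout` function and its "add one per step" dynamics are made up for illustration. Here xs is None, so the `length` argument sets the number of iterations.

```python
import jax
import jax.numpy as jnp


def rollout(init_pos, n_steps=5):
    """Hypothetical rollout: move one unit to the right at every step."""
    def step(pos, _):
        new_pos = pos + 1.0
        return new_pos, new_pos

    # no inputs to iterate over, so we pass xs=None and a length
    return jax.lax.scan(step, init_pos, None, length=n_steps)


starts = jnp.array([0.0, 10.0, 100.0])  # three independent starting points

# vmap runs the scanned rollout for every starting point in parallel
final_pos, traj = jax.vmap(rollout)(starts)
# final_pos has shape (3,), traj has shape (3, 5)
```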

1. Trajectory Collection
As mentioned in the previous section, the trajectory collection block consists of a step function scanned across N iterations. This step function successively:
- Selects an action using the actor network
- Steps the environment
- Stores transition data in a transition tuple
- Stores the model parameters, the environment state, the current observation, and rng keys in a runner_state tuple
- Returns runner_state and transition
Scanning this function returns the latest runner_state and traj_batch, an array of transition tuples. In practice, transitions are collected from multiple environments in parallel for efficiency as indicated by the use of jax.vmap(env.step, ...) (for more details about vectorized environments and vmap, refer to my previous article).
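The pattern described above can be sketched on a toy problem. Everything below is an assumption made for illustration and not PureJaxRL code: `toy_env_step` is a hypothetical 1-D environment, the "actor" is a two-parameter linear map, and the `Transition` fields are reduced to a minimum.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

NUM_ENVS, NUM_STEPS = 4, 8  # illustrative sizes


class Transition(NamedTuple):
    obs: jnp.ndarray
    action: jnp.ndarray
    reward: jnp.ndarray
    done: jnp.ndarray


def toy_env_step(key, state, action):
    """Hypothetical env: move left/right on a line, episode ends at |x| >= 3."""
    move = jnp.where(action == 1, 1.0, -1.0)
    new_state = state + move + 0.1 * jax.random.normal(key)
    done = jnp.abs(new_state) >= 3.0
    reward = -jnp.abs(new_state)
    new_state = jnp.where(done, 0.0, new_state)  # reset finished episodes
    return new_state, reward, done


def env_step(runner_state, _):
    params, env_state, rng = runner_state
    rng, k_act, k_env = jax.random.split(rng, 3)
    obs = env_state  # the toy env is fully observed
    # 1. select an action with the (toy) actor
    logits = obs[:, None] * params  # shape (NUM_ENVS, 2)
    action = jax.random.categorical(k_act, logits)
    # 2. step all environments in parallel
    env_keys = jax.random.split(k_env, NUM_ENVS)
    new_state, reward, done = jax.vmap(toy_env_step)(env_keys, env_state, action)
    # 3. store the transition and carry the runner state forward
    transition = Transition(obs, action, reward, done)
    return (params, new_state, rng), transition


runner_state = (jnp.array([0.5, -0.5]), jnp.zeros(NUM_ENVS), jax.random.PRNGKey(0))
runner_state, traj_batch = jax.lax.scan(env_step, runner_state, None, length=NUM_STEPS)
# every field of traj_batch has leading dimensions (NUM_STEPS, NUM_ENVS)
```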
2. Generalized Advantage Estimation
After collecting trajectories, we need to compute the advantage function, a crucial component of PPO's loss function. The advantage function measures how much better a specific action is compared to the average action in a given state:

$$A_t = G_t - V(s_t)$$

Where $G_t$ is the return at time $t$ and $V(s_t)$ is the value of the state $s_t$ visited at time $t$.
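As the return is generally unknown, we have to approximate the advantage function. A popular solution is generalized advantage estimation[3], defined as follows:

$$\hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}$$

With $\gamma$ the discount factor, $\lambda$ a parameter that controls the trade-off between bias and variance in the estimate, and $\delta_t$ the temporal difference error at time $t$, computed from the reward $r_t$:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$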
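As we can see, the value of the GAE at time t depends on the GAE at future timesteps. Therefore, we compute it backward, starting from the end of a trajectory. For example, for a trajectory of 3 transitions, we would have:

$$\hat{A}_3 = \delta_3$$
$$\hat{A}_2 = \delta_2 + \gamma \lambda \, \delta_3$$
$$\hat{A}_1 = \delta_1 + \gamma \lambda \, \delta_2 + (\gamma \lambda)^2 \, \delta_3$$

Which is equivalent to the following recursive form:

$$\hat{A}_t = \delta_t + \gamma \lambda \, \hat{A}_{t+1}$$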
Once again, we use jax.lax.scan on the trajectory batch (this time in reverse order) to iteratively compute the GAE.
Note that the function returns advantages + traj_batch.value as a second output, which is equivalent to the return according to the first equation of this section.
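A minimal sketch of such a backward scan is given below. It follows the recursive form of the GAE but is not a copy of the PureJaxRL code: the function signature, the argument names, and the default values of `gamma` and `gae_lambda` are assumptions. The `(1 - done)` mask, which stops bootstrapping across episode boundaries, is an extra detail not covered by the equations of this section.

```python
import jax
import jax.numpy as jnp


def calculate_gae(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Sketch: compute GAE and return targets with a reverse scan."""
    def step(carry, transition):
        gae, next_value = carry
        reward, value, done = transition
        # temporal difference error, masked at episode ends
        delta = reward + gamma * next_value * (1.0 - done) - value
        # recursive form: A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + gamma * gae_lambda * (1.0 - done) * gae
        return (gae, value), gae

    init = (jnp.zeros_like(last_value), last_value)
    _, advantages = jax.lax.scan(
        step, init, (rewards, values, dones), reverse=True
    )
    # advantage + value gives back an estimate of the return
    return advantages, advantages + values


# 3 transitions, no episode end, final bootstrap value of 0
rewards = jnp.array([1.0, 1.0, 1.0])
values = jnp.array([0.5, 0.5, 0.5])
dones = jnp.zeros(3)
advantages, targets = calculate_gae(
    rewards, values, dones, jnp.zeros(()), gamma=1.0, gae_lambda=1.0
)
```

With `gamma = gae_lambda = 1` and no episode end, the targets reduce to the undiscounted sum of future rewards, which is a convenient sanity check.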
3. Update step
The final block of the training loop defines the loss function, computes its gradient, and performs gradient descent on minibatches. Similarly to previous sections, the update step is an arrangement of several functions in a hierarchical order:
def _update_epoch(update_state, unused):
    """
    Scans update_minibatch over shuffled and permuted
    mini batches created from the trajectory batch.
    """
    def _update_minbatch(train_state, batch_info):
        """
        Wraps loss_fn and computes its gradient over the
        trajectory batch before updating the network parameters.
        """
        ...
        def _loss_fn(params, traj_batch, gae, targets):
            """
            Defines the PPO loss and computes its value.
            """
            ...
Let's break them down one by one, starting from the innermost function of the update step.
3.1 Loss function
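This function aims to define and compute the PPO loss, originally defined as:

$$L_t^{CLIP+VF+S}(\theta) = \hat{\mathbb{E}}_t \left[ L_t^{CLIP}(\theta) - c_1 L_t^{VF}(\theta) + c_2 S[\pi_\theta](s_t) \right]$$

Where:

$$L_t^{CLIP}(\theta) = \min\left( r_t(\theta) \hat{A}_t, \ \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right), \quad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)}$$

$$L_t^{VF}(\theta) = \left( V_\theta(s_t) - V_t^{targ} \right)^2$$

Here, $S$ denotes an entropy bonus, $c_1$ and $c_2$ are weighting coefficients, and $\epsilon$ is the clipping parameter.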
However, the PureJaxRL implementation features some tricks and differences compared to the original PPO paper[4]:
- The paper defines the PPO loss in the context of gradient ascent whereas the implementation performs gradient descent. Therefore, the sign of each loss component is reversed.
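- The value function term is modified to include an additional clipped term. This could be seen as a way to make the value function updates more conservative (as for the clipped surrogate objective):

$$L_t^{VF}(\theta) = \frac{1}{2} \max\left[ \left( V_\theta(s_t) - V_t^{targ} \right)^2, \ \left( V_t^{clip} - V_t^{targ} \right)^2 \right], \quad V_t^{clip} = V_{\theta_{old}}(s_t) + \text{clip}\left( V_\theta(s_t) - V_{\theta_{old}}(s_t), -\epsilon, \epsilon \right)$$
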
- The GAE is standardized: advantages are rescaled to zero mean and unit variance before entering the surrogate objective.
Here's the complete loss function:
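As a rough illustration of how the three points above fit together, here is a minimal sketch of such a loss. It is not PureJaxRL's exact code: the function name `ppo_loss`, its flat array arguments (log-probabilities, values, and entropies are assumed to be precomputed by the networks), and the default coefficients are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def ppo_loss(log_prob, old_log_prob, value, old_value, entropy, gae, targets,
             clip_eps=0.2, vf_coef=0.5, ent_coef=0.01):
    """Sketch of a PPO loss written for gradient descent (signs reversed)."""
    # value loss with an additional clipped term
    value_clipped = old_value + jnp.clip(value - old_value, -clip_eps, clip_eps)
    value_loss = 0.5 * jnp.maximum(
        jnp.square(value - targets), jnp.square(value_clipped - targets)
    ).mean()

    # clipped surrogate objective on standardized advantages
    ratio = jnp.exp(log_prob - old_log_prob)
    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
    surrogate = jnp.minimum(
        ratio * gae, jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * gae
    )
    actor_loss = -surrogate.mean()

    # the entropy bonus is subtracted since we minimize
    mean_entropy = entropy.mean()
    total_loss = actor_loss + vf_coef * value_loss - ent_coef * mean_entropy
    return total_loss, (value_loss, actor_loss, mean_entropy)


# dummy batch of 3 samples where the policy and values are unchanged
ones = jnp.ones(3)
total, (v_loss, a_loss, ent) = ppo_loss(
    log_prob=jnp.zeros(3), old_log_prob=jnp.zeros(3),
    value=ones, old_value=ones, entropy=jnp.zeros(3),
    gae=jnp.array([1.0, 2.0, 3.0]), targets=ones,
)
```

Returning the individual components as an auxiliary output makes it easy to log them separately when the function is differentiated with `jax.value_and_grad(..., has_aux=True)`.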
3.2 Update Minibatch
The update_minibatch function is essentially a wrapper around loss_fn used to compute its gradient over the trajectory batch and update the model parameters stored in train_state.
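The wrapper pattern can be sketched as follows. To keep the example self-contained, a toy regression loss stands in for the PPO loss and a plain gradient descent step stands in for the optimizer update applied to train_state. Both substitutions, as well as the names and the learning rate, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """Toy stand-in for the PPO loss: mean squared error of a linear model."""
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def update_minibatch(params, batch, lr=0.1):
    # compute the loss value and its gradient in a single call
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # plain SGD step, applied leaf by leaf to the parameter pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


params = {"w": jnp.zeros(1), "b": jnp.zeros(())}
batch = (jnp.array([[1.0], [2.0]]), jnp.array([2.0, 4.0]))

params_1, loss_0 = update_minibatch(params, batch)
params_2, loss_1 = update_minibatch(params_1, batch)
# loss_1 is lower than loss_0 after one update
```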
3.3 Update Epoch
Finally, update_epoch wraps update_minibatch and applies it on minibatches. Once again, jax.lax.scan is used to apply the update function on all minibatches iteratively.
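The shuffle, split, and scan sequence can be sketched on a toy problem. The one-parameter regression task, the constant `NUM_MINIBATCHES`, and the learning rate below are assumptions chosen for illustration, not values from PureJaxRL.

```python
import jax
import jax.numpy as jnp

NUM_MINIBATCHES = 4  # illustrative value


def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x * w - y) ** 2)


def update_minibatch(w, minibatch):
    loss, grad = jax.value_and_grad(loss_fn)(w, minibatch)
    return w - 0.05 * grad, loss


def update_epoch(update_state, _):
    w, batch, rng = update_state
    rng, k_perm = jax.random.split(rng)
    batch_size = batch[0].shape[0]
    # shuffle every array of the batch with the same permutation
    perm = jax.random.permutation(k_perm, batch_size)
    shuffled = jax.tree_util.tree_map(lambda a: jnp.take(a, perm, axis=0), batch)
    # reshape to (NUM_MINIBATCHES, minibatch_size, ...)
    minibatches = jax.tree_util.tree_map(
        lambda a: a.reshape((NUM_MINIBATCHES, -1) + a.shape[1:]), shuffled
    )
    # scan the update over the leading minibatch axis
    w, losses = jax.lax.scan(update_minibatch, w, minibatches)
    return (w, batch, rng), losses


x = jnp.linspace(-1.0, 1.0, 16)
batch = (x, 3.0 * x)  # the target slope is 3
init = (jnp.zeros(()), batch, jax.random.PRNGKey(0))

# an outer scan repeats the epoch update 50 times
(w, _, _), losses = jax.lax.scan(update_epoch, init, None, length=50)
```

Because the permutation is applied through `tree_map`, inputs and targets stay aligned after shuffling, which is the property the real update relies on for the trajectory batch, advantages, and targets.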
Conclusion
From there, we can wrap all of the previous functions in an update_step function and use scan one last time for N steps to complete the training loop.
A global view of the training loop would look like this:
We can now run a fully compiled training loop using jax.jit(train)(rng) or even train multiple agents in parallel using jax.vmap(train)(rngs), where rngs is a batch of PRNG keys.
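The sketch below shows this calling pattern on a toy `train` function. The `make_train` factory, the quadratic objective, and the hyperparameters are stand-ins invented for this example. Only the way jit and vmap wrap a function of the rng key is the point here.

```python
import jax
import jax.numpy as jnp


def make_train(num_updates=100, lr=0.1):
    """Hypothetical factory returning a pure train(rng) function."""
    def train(rng):
        w0 = jax.random.normal(rng)  # random initial parameter

        def update_step(w, _):
            # toy objective with its minimum at w = 2
            loss, grad = jax.value_and_grad(lambda p: (p - 2.0) ** 2)(w)
            return w - lr * grad, loss

        w, losses = jax.lax.scan(update_step, w0, None, length=num_updates)
        return {"w": w, "losses": losses}

    return train


train = make_train()

# one fully compiled training run
out = jax.jit(train)(jax.random.PRNGKey(0))

# four independent runs in parallel, one per PRNG key
rngs = jax.random.split(jax.random.PRNGKey(1), 4)
outs = jax.jit(jax.vmap(train))(rngs)
```

Note that jit and vmap take the function itself and return a new function, which is then called with the key or the batch of keys.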
There we have it! We covered the essential building blocks of the PPO training loop as well as common programming patterns in JAX.
To go further, I highly recommend reading the full training script in detail and running example notebooks on the PureJaxRL repository.
GitHub – luchris429/purejaxrl: Really Fast End-to-End Jax RL Implementations
Thank you very much for your support, until next time