Introducing n-Step Temporal-Difference Methods

Author: Murphy  |  Time: 2025-03-22 19:09:55

In our previous post, we wrapped up the introductory series on fundamental reinforcement learning (RL) techniques by exploring Temporal-Difference (TD) learning. TD methods combine the strengths of Dynamic Programming (DP) and Monte Carlo (MC) methods: like DP, they bootstrap from existing value estimates, and like MC, they learn directly from sampled experience. Together, these ideas form the basis of some of the most important RL algorithms, such as Q-learning.

Building on that foundation, this post delves into n-step TD learning, a versatile approach introduced in Chapter 7 of Sutton and Barto's book [1]. This method bridges the gap between classical TD and MC techniques. Like TD, n-step methods use bootstrapping (they build their update target from existing value estimates), but before bootstrapping they incorporate the actual rewards of the next n steps. In a future post, we'll generalize this concept even further with eligibility traces.

We'll follow a structured approach, starting with the prediction problem before moving to control. Along the way, we'll:

  • Introduce n-step Sarsa,
  • Extend it to off-policy learning,
  • Explore the n-step tree backup algorithm, and
  • Present a unifying perspective with n-step Q(σ).

As always, you can find all accompanying code on GitHub. Let's dive in!

Photo by Alev Takil on Unsplash

n-Step TD Learning

As mentioned before, n-step TD allows us to freely move between classic TD learning and MC methods. To get a better understanding of this, let's recap the update formulas for both.

In MC methods, we update the value estimate towards the full observed return:

Image from [1]
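In the notation of [1], where the episode terminates at time step T, the full return and the corresponding Monte Carlo update read:

$$G_t = R_{t+1} + \gamma R_{t+2} + \gamma^2 R_{t+3} + \cdots + \gamma^{T-t-1} R_T$$

$$V(S_t) \leftarrow V(S_t) + \alpha \left[ G_t - V(S_t) \right]$$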

In contrast, in TD learning the update target is the observed reward plus the discounted (estimated) value of the next state:

Image from [1]
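Written in the same notation as [1], the one-step TD target keeps only the first reward and replaces everything after it with the current estimate of the next state's value:

$$G_{t:t+1} = R_{t+1} + \gamma V_t(S_{t+1})$$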

Intuitively, it makes sense to allow more flexibility here, in particular multi-step updates. Consider, for example, the 2-step update:

Image from [1]
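Here, two rewards are observed before we bootstrap from the estimated value of the state two steps later (notation as in [1]):

$$G_{t:t+2} = R_{t+1} + \gamma R_{t+2} + \gamma^2 V_{t+1}(S_{t+2})$$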

And, more generally, the n-step update:

Image from [1]
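Following [1], the n-step return and the corresponding update are given below. If the episode ends before time step t + n (that is, t + n ≥ T), the missing terms are dropped and the n-step return equals the full return G_t.

$$G_{t:t+n} = R_{t+1} + \gamma R_{t+2} + \cdots + \gamma^{n-1} R_{t+n} + \gamma^n V_{t+n-1}(S_{t+n})$$

$$V_{t+n}(S_t) = V_{t+n-1}(S_t) + \alpha \left[ G_{t:t+n} - V_{t+n-1}(S_t) \right], \quad 0 \le t < T$$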

This is the heart of n-step TD learning. Why is this beneficial? Often, neither 1-step TD nor MC methods perform best, and the optimum lies somewhere in between: on the random-walk task in [1], intermediate values of n learn fastest.
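To make this interpolation concrete, here is a minimal sketch that computes the n-step return for a finished episode. The function name and the list-based episode layout (rewards[k] holds R_{k+1}, values[k] holds the current estimate V(S_k)) are our own assumptions, not taken from the accompanying repository.

```python
from typing import Sequence


def n_step_return(
    rewards: Sequence[float],  # rewards[k] = R_{k+1}, the reward after step k
    values: Sequence[float],   # values[k] = V(S_k), current estimate for S_k
    t: int,
    n: int,
    gamma: float,
) -> float:
    """G_{t:t+n}: n discounted rewards, then bootstrap from V(S_{t+n}) if t + n < T."""
    T = len(rewards)  # the episode terminates in S_T
    # Discounted rewards R_{t+1}, ..., R_{min(t+n, T)}.
    G = sum(gamma ** (k - t) * rewards[k] for k in range(t, min(t + n, T)))
    # Bootstrap only if the episode has not ended within the n steps.
    if t + n < T:
        G += gamma**n * values[t + n]
    return G


rewards = [0.0, 0.0, 1.0]  # R_1, R_2, R_3 of an episode with T = 3
values = [0.5, 0.6, 0.7]   # current estimates V(S_0), V(S_1), V(S_2)
for n in (1, 2, 3):
    print(n, n_step_return(rewards, values, t=0, n=n, gamma=0.9))
```

For t = 0, n = 1 yields the one-step TD target 0.9 · V(S_1) ≈ 0.54, n = 2 gives 0.81 · V(S_2) ≈ 0.567, and any n ≥ 3 yields the Monte Carlo return 0.9² · 1 ≈ 0.81, since no bootstrap term remains.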

Another benefit is that this frees us from the "tyranny of the timestep" [1], as Sutton so nicely puts it: with 1-step TD methods, we appreciate being able to update our value estimate often (at every step), but we are also forced to look only one step into the future. With n-step methods, these two numbers are decoupled: we still update at every step, but look n steps ahead.

Another nice graphic from Sutton compares these methods visually:

Image from [1]

From these definitions, we can directly derive the prediction algorithm (don't worry if some of the indices seem unintuitive; we'll discuss them in detail in the next section):

Image from [1]
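The pseudocode translates into a short Python sketch. To keep it independent of the repository's grid world helpers, we assume a small random-walk task similar to the random-walk examples in [1]: five non-terminal states, a uniform random policy, and a reward of +1 only when exiting on the right, so the true values are 1/6, 2/6, 3/6, 4/6 and 5/6. The function name, parameters and defaults are our own choices.

```python
import random


def n_step_td_prediction(
    n: int,
    num_states: int = 5,
    num_episodes: int = 10_000,
    alpha: float = 0.01,
    gamma: float = 1.0,
    seed: int = 0,
) -> list[float]:
    """n-step TD estimate of v_pi on a random walk under the uniform random policy.

    States 1..num_states are non-terminal; 0 and num_states + 1 are terminal.
    Exiting on the right yields a reward of +1, every other transition 0.
    """
    rng = random.Random(seed)
    right_end = num_states + 1
    V = [0.0] * (num_states + 2)  # terminal states keep the value 0

    for _ in range(num_episodes):
        states = [(num_states + 1) // 2]  # S_0: start in the middle
        rewards = [0.0]  # rewards[i] = R_i; index 0 is only a placeholder
        T = float("inf")  # terminal time step, unknown until the episode ends
        t = 0
        while True:
            if t < T:
                # Step left or right with equal probability.
                s_next = states[t] + rng.choice((-1, 1))
                states.append(s_next)
                rewards.append(1.0 if s_next == right_end else 0.0)
                if s_next in (0, right_end):
                    T = t + 1
            tau = t - n + 1  # time step whose estimate is updated now
            if tau >= 0:
                # n-step return: discounted rewards, then bootstrap if possible.
                G = sum(
                    gamma ** (i - tau - 1) * rewards[i]
                    for i in range(tau + 1, min(tau + n, T) + 1)
                )
                if tau + n < T:
                    G += gamma**n * V[states[tau + n]]
                V[states[tau]] += alpha * (G - V[states[tau]])
            if tau == T - 1:
                break
            t += 1
    return V[1:right_end]
```

Calling n_step_td_prediction(n=3) should return estimates close to the true values; the remaining deviation stems from the constant step size alpha. The index handling (T, t and τ) is exactly the one discussed for n-step Sarsa below.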

So let's not dwell too long on prediction, and instead move on to control.

n-Step Sarsa

The idea is very similar to the prediction case. We begin with the pseudocode and then explain it in more detail:

Image from [1]

We keep track of three indices: T, t and τ. As usual, we keep playing episodes until they terminate, and t tracks the current time step. Since n-step methods need to observe n rewards before they can update a value estimate, τ = t - n + 1 tracks the time step whose estimate we update in the current iteration. During the first n - 1 iterations (t < n - 1), τ is negative and no update is possible, which is what the τ ≥ 0 check catches.

Conversely, once the episode has finished, we still want to update the remaining value estimates with the rewards we have collected. This is why we store the terminal time step in T and keep incrementing t without taking any further actions, until τ reaches T - 1 and the last n - 1 state-action pairs have received their updates.
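A quick way to convince yourself of this bookkeeping is to trace only the indices. The sketch below (the function name is hypothetical) replays the loop structure of the pseudocode for an episode of length T and records at which iteration t each τ is updated:

```python
def update_schedule(n: int, T: int) -> list[tuple[int, int]]:
    """Trace which time step tau n-step Sarsa updates at each iteration t."""
    schedule = []
    t = 0
    while True:
        # (While t < T an action would be taken here; afterwards we only update.)
        tau = t - n + 1
        if tau >= 0:  # enough rewards observed to update Q(S_tau, A_tau)
            schedule.append((t, tau))
        if tau == T - 1:  # the last state-action pair has been updated
            break
        t += 1
    return schedule


print(update_schedule(n=3, T=5))
# [(2, 0), (3, 1), (4, 2), (5, 3), (6, 4)]
```

Every τ from 0 to T - 1 is updated exactly once. The first update happens at t = n - 1, and the last n - 1 updates (here at t = 5 and t = 6) take place after the episode has already terminated.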

Apart from this, the algorithm should look familiar from conventional Sarsa in the previous post. Here's how it looks in Python:

def sarsa_n(env: ParametrizedEnv, n: int = 3) -> np.ndarray:
    observation_space, action_space = get_observation_action_space(env)
    Q = np.zeros((observation_space.n, action_space.n))

    for _ in range(NUM_STEPS):
        observation, _ = env.env.reset()
        terminated = truncated = False
        action = get_eps_greedy_action(Q[observation])

        replay_buffer = [ReplayItem(observation, action, 0)]

        T = float("inf")  # terminal step
        t = 0  # current step
        tau = 0  # update value estimate for this time step

        while True:
            if t < T:
                # While not terminal, continue playing episode.
                observation_new, reward, terminated, truncated, _ = env.env.step(action)
                action_new = get_eps_greedy_action(Q[observation_new])
                replay_buffer.append(ReplayItem(observation_new, action_new, reward))
                if terminated or truncated:
                    T = t + 1

                observation = observation_new
                action = action_new

            tau = t - n + 1
            if tau >= 0:
                # Discounted sum of the rewards R_{tau+1}, ..., R_{min(tau+n, T)}.
                G = sum(
                    replay_buffer[i].reward * env.gamma ** (i - tau - 1)
                    for i in range(tau + 1, min(tau + n, T) + 1)
                )

                # Bootstrap from Q(S_{tau+n}, A_{tau+n}) unless the episode ended first.
                if tau + n < T:
                    last = replay_buffer[tau + n]
                    G += env.gamma**n * Q[last.state, last.action]

                # Move Q(S_tau, A_tau) towards the n-step return G.
                s_tau, a_tau = replay_buffer[tau].state, replay_buffer[tau].action
                Q[s_tau, a_tau] += ALPHA * (G - Q[s_tau, a_tau])

            if tau == T - 1:
                break

            t += 1

    return np.array([np.argmax(Q[s]) for s in range(observation_space.n)])
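The snippet relies on helpers from the accompanying repository that are not shown in this post, notably ReplayItem and get_eps_greedy_action (plus constants such as ALPHA and NUM_STEPS). If you want to experiment with the code outside the repository, minimal stand-ins could look like the following. These definitions, including the default ε of 0.1, are our own assumptions; the repository's actual helpers may differ.

```python
import random
from typing import NamedTuple, Sequence


class ReplayItem(NamedTuple):
    """One stored transition: S_i, A_i and the reward R_i received on entering S_i."""

    state: int
    action: int
    reward: float


def get_eps_greedy_action(q_values: Sequence[float], eps: float = 0.1) -> int:
    """Pick a uniformly random action with probability eps, else a greedy one."""
    if random.random() < eps:
        return random.randrange(len(q_values))
    # Greedy choice; ties are broken in favor of the lowest action index.
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

Because get_eps_greedy_action only indexes into q_values, it works both with plain lists and with a NumPy row such as Q[observation].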

As usual, the full code can be found on GitHub, and you can directly test the algorithm on our grid world example via:

python grid_world.py --method=sarsa_n

Off-Policy Learning

With minor modifications, we can turn the previous algorithm into an off-policy one. For a thorough introduction to off-policy learning, I refer to my previous post about MC methods. To quickly recap: off-policy methods use a second policy, the behavior policy, to generate episodes while optimizing the original target policy (which, e.g., makes the exploration-exploitation trade-off easier). Because the episodes are then generated by a different policy than the one we evaluate, we need to correct for this mismatch, which we do by weighting the updates with importance sampling ratios. Sutton and Barto show the following pseudocode:

Image from [1]
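The ratio ρ in the pseudocode is a product of action-probability ratios π(A_k|S_k) / b(A_k|S_k). As a minimal standalone sketch, assuming an ε-greedy target policy and a uniformly random behavior policy (both function names are hypothetical), the ratio for one update window can be computed as follows; for n-step Sarsa, the window covers steps τ + 1 through min(τ + n, T - 1).

```python
import math
from typing import Sequence


def eps_greedy_probs(q_values: Sequence[float], eps: float) -> list[float]:
    """pi(a|s) for an epsilon-greedy policy; ties go to the lowest action index."""
    num_actions = len(q_values)
    greedy = max(range(num_actions), key=lambda a: q_values[a])
    probs = [eps / num_actions] * num_actions
    probs[greedy] += 1.0 - eps
    return probs


def importance_ratio(
    pi_probs: Sequence[Sequence[float]],  # pi(. | S_k) for each step k in the window
    b_probs: Sequence[Sequence[float]],   # b(. | S_k) for the same steps
    actions: Sequence[int],               # actions A_k actually taken
) -> float:
    """rho = product over the window of pi(A_k | S_k) / b(A_k | S_k)."""
    return math.prod(p[a] / q[a] for p, q, a in zip(pi_probs, b_probs, actions))


pi_s = eps_greedy_probs([1.0, 0.0], eps=0.1)  # approx. [0.95, 0.05]
uniform = [0.5, 0.5]                          # random behavior policy b
print(importance_ratio([pi_s, pi_s], [uniform, uniform], [0, 0]))  # approx. 3.61
print(importance_ratio([pi_s], [uniform], [1]))                    # approx. 0.1
```

Actions that the target policy prefers receive weights above 1, exploratory actions are down-weighted, and if b equals π every factor is exactly 1.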

We can easily extend our previously introduced Python code. The on-policy case is simply a special case of off-policy learning in which the behavior and target policies are identical. In the code, we use a random (!) policy as behavior policy when the off_policy flag is set, and otherwise use the target policy (isn't it fascinating how off-policy learning with importance sampling allows us to learn from completely random policies?).

We compute the importance sampling weights ρ, which fall back to 1 when the two policies agree:

def sarsa_n(env: ParametrizedEnv, n: int = 3, off_policy: bool = False) -> np.ndarray:
    observation_space, action_space = get_observation_action_space(env)
    Q = np.zeros((observation_space.n, action_space.n))

    for _ in range(NUM_STEPS):
        b = (
            np.random.rand(int(observation_space.n), int(action_space.n))
            if off_policy
            else Q
        )

        observation, _ = env.env.reset()
        terminated = truncated = False
        action = (
            get_eps_greedy_action(Q[observation])
            if not off_policy
            else get_eps_greedy_action(b[observation], eps=0)
        )

        replay_buffer = [ReplayItem(observation, action, 0)]

        T = float("inf")  # terminal step
        t = 0  # current step
        tau = 0  # update value estimate for this time step

        rhos = []  # importance sampling weights

        while True:
            if t < T:
                # While not terminal, continue playing episode.
                observation_new, reward, terminated, truncated, _ = env.env.step(action)
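                # Choose the next action with the behavior policy: greedy w.r.t. the
                # random preferences b when off-policy, eps-greedy w.r.t. Q otherwise.
                action_new = (
                    get_eps_greedy_action(Q[observation_new])
                    if not off_policy
                    else get_eps_greedy_action(b[observation_new], eps=0)
                )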
                replay_buffer.append(ReplayItem(observation_new, action_new, reward))
                if terminated or truncated:
                    T = t + 1

                observation = observation_new
                action = action_new

            tau = t - n + 1
            if tau >= 0:
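                # Importance sampling weight over steps tau+1 through min(tau+n, T-1).
                # The entries of Q and b stand in for the action probabilities
                # pi(a|s) and b(a|s) that appear in the pseudocode of [1].
                rho = math.prod(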
                    [
                        div_with_zero(
                            Q[replay_buffer[i].state, replay_buffer[i].action],
                            b[replay_buffer[i].state, replay_buffer[i].action],
                        )
                        for i in range(tau + 1, min(tau + n, T - 1) + 1)
                    ]
                )
                rhos.append(rho)

                G = sum(
                    [
                        replay_buffer[i].reward * env.gamma ** (i - tau - 1)
                        for i in range(tau + 1, min(tau + n, T) + 1)
                    ]
                )

                if tau + n < T:
                    G = (
                        G
                        + env.gamma**n
                        * Q[replay_buffer[tau + n].state, replay_buffer[tau + n].action]
                    )

                # Importance-weighted update. The step size is additionally divided by
                # (1 + sum of the weights so far in this episode); the pseudocode in [1]
                # uses ALPHA * rho directly.
                s_tau, a_tau = replay_buffer[tau].state, replay_buffer[tau].action
                Q[s_tau, a_tau] += ALPHA * rho / (sum(rhos) + 1) * (G - Q[s_tau, a_tau])

            if tau == T - 1:
                break

            t += 1

    return np.array([np.argmax(Q[s]) for s in range(observation_space.n)])

div_with_zero is a small helper function that evaluates 0 / 0 to 1, since this case arises frequently in the on-policy setting, where b is Q itself and Q starts out as all zeros:

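def div_with_zero(x: float, y: float) -> float:
    # Equal numerator and denominator (including 0 / 0, which is common in the
    # on-policy case where b is Q) yield a ratio of exactly 1.
    if x == y:
        return 1.0
    # Otherwise divide, with a small constant guarding against division by zero.
    return x / (y + 0.0001)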

n-step Tree Backup Algorithm

As it turns out, off-policy learning is also possible without importance sampling. To achieve this, we extend Expected Sarsa from the previous post to a tree-like structure, which yields the n-step tree backup algorithm.

The path through the tree is defined by the actions taken according to the (ε-greedy) target policy, and the rewards along this path are used in the same way as in n-step Sarsa:

Image from [1]

However, we apply the probabilistic weighting from Expected Sarsa: each leaf node in the tree corresponds to a value estimate we bootstrap from. On the first level, every leaf (i.e., every action not taken) is weighted by its probability under the target policy. The probability of the action actually taken is not applied to a leaf; instead, it weights everything below it in the tree, i.e., all subsequent rewards and leaf values.

For two levels of the tree, this is formalized as follows:

Image from [1]
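In the notation of [1], the two-level tree backup target and its general recursive form are:

$$G_{t:t+2} = R_{t+1} + \gamma \sum_{a \neq A_{t+1}} \pi(a \mid S_{t+1})\, Q_{t+1}(S_{t+1}, a) + \gamma\, \pi(A_{t+1} \mid S_{t+1}) \Big( R_{t+2} + \gamma \sum_{a} \pi(a \mid S_{t+2})\, Q_{t+1}(S_{t+2}, a) \Big)$$

$$G_{t:t+n} = R_{t+1} + \gamma \sum_{a \neq A_{t+1}} \pi(a \mid S_{t+1})\, Q_{t+n-1}(S_{t+1}, a) + \gamma\, \pi(A_{t+1} \mid S_{t+1})\, G_{t+1:t+n}$$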

Sutton gives the following pseudocode:

Image from [1]

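Note how "off-policy" is interpreted in our implementation: as opposed to "conventional" off-policy methods, where behavior and target policy can differ completely, we still use only a single target policy, both to generate episodes and as the policy we want to learn. (The pseudocode in [1] leaves the choice of the actions A_{t+1} open, so other behavior policies could be plugged in as well.)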
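To see the tree backup recursion in isolation, here is a minimal sketch that computes the target for a single update window. It assumes that the episode does not terminate within the window (t + n < T) and that the target-policy probabilities are given explicitly; the function and argument names are hypothetical.

```python
from typing import Sequence


def tree_backup_return(
    rewards: Sequence[float],             # rewards[k] = R_{t+1+k}, k = 0, ..., n-1
    q_values: Sequence[Sequence[float]],  # q_values[k] = Q(S_{t+1+k}, a) for all a
    pi: Sequence[Sequence[float]],        # pi[k] = pi(a | S_{t+1+k}) for all a
    actions: Sequence[int],               # actions[k] = A_{t+1+k}, k = 0, ..., n-2
    gamma: float,
) -> float:
    """n-step tree backup target G_{t:t+n}, assuming S_{t+n} is not terminal."""
    # Deepest level: full expectation over all actions in S_{t+n}.
    G = rewards[-1] + gamma * sum(p * q for p, q in zip(pi[-1], q_values[-1]))
    # Back up level by level, from S_{t+n-1} up to S_{t+1}.
    for k in range(len(rewards) - 2, -1, -1):
        taken = actions[k]
        # Leaves: actions not taken, weighted by their target-policy probability.
        leaves = sum(
            pi[k][a] * q_values[k][a] for a in range(len(pi[k])) if a != taken
        )
        # The taken action's probability weights the entire subtree below it.
        G = rewards[k] + gamma * leaves + gamma * pi[k][taken] * G
    return G


print(
    tree_backup_return(
        rewards=[1.0, 0.0],
        q_values=[[2.0, 4.0], [1.0, 3.0]],
        pi=[[0.25, 0.75], [0.5, 0.5]],
        actions=[1],
        gamma=0.5,
    )
)  # 1.625
```

With a deterministic target policy (probability 1 on every action taken), all leaf terms vanish and the target reduces to the discounted rewards plus an expected-value bootstrap at the last level.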

In Python, the code can be implemented as follows:

def tree_n(env: ParametrizedEnv, n: int = 3) -> np.ndarray:
    observation_space, action_space = get_observation_action_space(env)
    Q = np.zeros((observation_space.n, action_space.n)) + 0.1

    for _ in range(NUM_STEPS):
        observation, _ = env.env.reset()
        terminated = truncated = False
        action = get_eps_greedy_action(Q[observation])

        replay_buffer = [ReplayItem(observation, action, 0)]

        T = float("inf")  # terminal step
        t = 0  # current step
        tau = 0  # update value estimate for this time step

        while True:
            if t < T:
                observation_new, reward, terminated, truncated, _ = env.env.step(action)
                action_new = get_eps_greedy_action(Q[observation_new])
                replay_buffer.append(ReplayItem(observation_new, action_new, reward))
                if terminated or truncated:
                    T = t + 1

                observation = observation_new
                action = action_new

            tau = t - n + 1

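            if tau >= 0:
                # Target-policy probabilities are modeled as pi(a|s) = Q[s, a] / sum(Q[s, :]),
                # which needs nonzero row sums (Q is initialized to 0.1).
                if t + 1 >= T: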
                    G = replay_buffer[T].reward
                else:
                    G = replay_buffer[t + 1].reward + env.gamma * sum(
                        [
                            Q[replay_buffer[t + 1].state, a]
                            / sum(Q[replay_buffer[t + 1].state, :])
                            * Q[replay_buffer[t + 1].state, a]
                            for a in range(action_space.n)
                        ]
                    )

                # Walk back from k = min(t, T-1) down to k = tau+1 (inclusive), as in [1].
                for k in range(min(t, T - 1), tau, -1):
                    G = (
                        replay_buffer[k].reward
                        + env.gamma
                        * sum(
                            [
                                Q[replay_buffer[k].state, a]
                                / sum(Q[replay_buffer[k].state, :])
                                * Q[replay_buffer[k].state, a]
                                for a in range(action_space.n)
                                if a != replay_buffer[k].action
                            ]
                        )
                        + env.gamma
                        * Q[replay_buffer[k].state, replay_buffer[k].action]
                        / sum(Q[replay_buffer[k].state, :])
                        * G
                    )

                Q[replay_buffer[tau].state, replay_buffer[tau].action] = Q[
                    replay_buffer[tau].state, replay_buffer[tau].action
                ] + ALPHA * (G - Q[replay_buffer[tau].state, replay_buffer[tau].action])

            if tau == T - 1:
                break

            t += 1

    return np.array([np.argmax(Q[s]) for s in range(observation_space.n)])

A Unified View of n-Step Algorithms

I want to conclude this section with an outlook on how the algorithms introduced so far can be represented in a unified framework, the n-step Q(σ) algorithm. Let's recap the algorithms seen so far:

Image from [1]

n-step Sarsa uses only sample transitions, i.e., we follow the actions actually executed along the episode. At the other end of the spectrum, n-step tree backup branches out over all possible actions at every level. n-step Expected Sarsa lies somewhere in between: it follows the sampled actions and only branches out at the last level. It is then natural to unify these into n-step Q(σ): at each level, we flip a (biased) coin and do a sample transition with probability σ, branching out fully otherwise. More generally, [1] allows σ to take any value in [0, 1], interpolating between the sampled and the expected update.
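The following minimal sketch computes the n-step Q(σ) target in the recursive form given in [1], assuming the episode continues beyond t + n and that target-policy probabilities are available explicitly; the function and argument names are hypothetical. With σ = 0 at every level, it reduces exactly to the tree backup target. With σ = 1, it becomes a fully sampled, Sarsa-style target in the control-variate form used in [1], which in the on-policy case has the same expectation as the plain n-step Sarsa return.

```python
from typing import Optional, Sequence


def q_sigma_return(
    rewards: Sequence[float],             # rewards[k] = R_{t+1+k}
    q_values: Sequence[Sequence[float]],  # q_values[k] = Q(S_{t+1+k}, .)
    pi: Sequence[Sequence[float]],        # pi[k] = pi(. | S_{t+1+k})
    actions: Sequence[int],               # actions[k] = A_{t+1+k}
    sigmas: Sequence[float],              # sigmas[k] = sigma_{t+1+k} in [0, 1]
    gamma: float,
    rhos: Optional[Sequence[float]] = None,  # pi(A|S) / b(A|S) per step; 1 on-policy
) -> float:
    """n-step Q(sigma) target G_{t:t+n}, assuming S_{t+n} is not terminal."""
    n = len(rewards)
    if rhos is None:
        rhos = [1.0] * n
    # Base case: G_{t+n:t+n} = Q(S_{t+n}, A_{t+n}).
    G = q_values[-1][actions[-1]]
    for k in range(n - 1, -1, -1):
        a = actions[k]
        # Expected action value of S_{t+1+k} under the target policy.
        v_bar = sum(p * q for p, q in zip(pi[k], q_values[k]))
        # sigma mixes the sampled weight (rho) with the expected weight pi(a|s).
        weight = sigmas[k] * rhos[k] + (1.0 - sigmas[k]) * pi[k][a]
        G = rewards[k] + gamma * weight * (G - q_values[k][a]) + gamma * v_bar
    return G


example = dict(
    rewards=[1.0, 0.0],
    q_values=[[2.0, 4.0], [1.0, 3.0]],
    pi=[[0.25, 0.75], [0.5, 0.5]],
    actions=[1, 0],
    gamma=0.5,
)
print(q_sigma_return(sigmas=[0.0, 0.0], **example))  # 1.625, the tree backup target
print(q_sigma_return(sigmas=[1.0, 1.0], **example))  # 1.25, fully sampled
```

Passing per-step σ values of 0 or 1 drawn from a biased coin reproduces the coin-flip view described above, while fractional values blend the two targets.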

Conclusion

In this post, we unified Monte Carlo (MC) and Temporal-Difference (TD) approaches by introducing n-step TD algorithms. While MC and TD methods represent two extremes (MC relying on full episodes, TD bootstrapping after every single step), n-step methods strike a balance: they still update value estimates at each step, but use the rewards of the next n steps, rather than a single one, before bootstrapping.

The main benefit is that intermediate values of n often outperform both pure MC and pure one-step TD methods. The price is higher computational and memory cost: since n-step TD algorithms can only update the value of a state n steps after visiting it, they must keep the last n states, actions and rewards around. In a future post, we'll explore eligibility traces, a technique that addresses this memory overhead efficiently.

We began our exploration of n-step methods with n-step Sarsa, a straightforward extension of the basic Sarsa algorithm that uses the rewards of the next n steps. We then extended it to off-policy learning by incorporating importance sampling weights, allowing the algorithm to learn from episodes generated by a different behavior policy.

Moving beyond sample transitions at each step, we introduced the n-step tree backup algorithm, which takes the value estimates of all actions into account at each step, not just those of the action actually taken. Similar to Expected Sarsa, it weights these estimates by their action probabilities and propagates updates in a tree-like structure. Finally, we discussed n-step Q(σ), a unifying algorithm that enables a smooth transition between n-step Sarsa and n-step tree backup.

Thank you for reading! I hope you enjoyed this post and found it insightful. Stay tuned for the next installment in this series, where we'll dive into planning and its role in Reinforcement Learning.


References

[1] R. S. Sutton and A. G. Barto, Reinforcement Learning: An Introduction, 2nd ed., MIT Press, 2018. http://incompleteideas.net/book/RLbook2020.pdf

