
Proximal Policy Optimization (PPO)

#LLMs #RL

Hi everyone! Today, I'll be discussing a famous reinforcement learning algorithm that has been the foundation of many post-training regimes in LLMs and the stepping stone for many novel algorithms: Proximal Policy Optimization.

I will go into significant mathematical detail, but I'll always tie it back to LLMs and how PPO is implemented in post-training objectives. That's what helped me understand why it matters, and it may help you too!

PPO Objective

In PPO, we train an actor and a critic to generate better LLM responses. The actor generates the output tokens while the critic estimates how good or bad the response is. From these two pieces, we build the PPO objective function. Specifically, I'm discussing the clipped variant (often called PPO2), introduced by John Schulman and colleagues.

$$J_{\text{PPO}}(\theta) = \mathbb{E}_t\left[\min\left(\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}A_t,\ \text{clip}\left(\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)},\ 1-\epsilon,\ 1+\epsilon\right)A_t\right)\right]$$

Some quick terminology before we begin explaining the algorithm:

  • $r_t = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}$ is the probability ratio between the new and old policies. It shows how likely the current policy (LLM) $\pi_\theta$ is to predict the action taken at step $t$ versus our older LLM weights $\pi_{\theta_\text{old}}$.
  • $A_t$ is the advantage: it tells us whether the action we took was better or worse than the LLM's average action in that state.
  • In this context, the state represents the prefix of the current LLM input, and the action represents which next token the LLM chooses to emit.

So overall, this objective function tells us to update our policy (LLM) in a way that improves upon good actions (positive advantage) while limiting how far we can move from our old policy using the clip function. This prevents the policy from making drastic changes that could destabilize training.
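To make the clipping concrete, here is a minimal PyTorch-style sketch of the clipped surrogate loss (a sketch under my own naming assumptions, not the exact implementation from any particular library):

import torch

def ppo_clip_loss(new_logprobs: torch.Tensor,   # log π_θ(a_t|s_t) for the sampled tokens
                  old_logprobs: torch.Tensor,   # log π_θ_old(a_t|s_t), recorded at rollout time
                  advantages: torch.Tensor,     # A_t, one per token
                  clip_eps: float = 0.2) -> torch.Tensor:
    # Probability ratio r_t = π_θ / π_θ_old, computed in log-space for numerical stability
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Pessimistic (element-wise minimum) objective, negated so we can minimize it
    return -torch.min(unclipped, clipped).mean()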

Actor Critic Architecture

The Actor-Critic architecture consists of two neural networks working together: the Actor (policy network) that decides which actions to take, and the Critic (value network) that evaluates how good those actions are.

In PPO for LLMs, both the Actor and Critic typically share the same transformer-based architecture, with the only difference being their output heads.

The Actor uses the standard language model head to predict token probabilities, while the Critic adds a separate Value head - usually a small MLP that outputs a single scalar value representing the expected future reward from the current state.
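As a rough sketch of this architecture (the module names and the two-layer MLP value head below are my own illustrative choices, not a specific library's API):

import torch
import torch.nn as nn

class ActorCriticLM(nn.Module):
    """One shared transformer backbone with two heads: an LM head (Actor) and a value head (Critic)."""
    def __init__(self, backbone: nn.Module, hidden_size: int, vocab_size: int):
        super().__init__()
        self.backbone = backbone  # assumed to map token ids to hidden states [batch, seq, hidden]
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # Actor: next-token logits
        self.value_head = nn.Sequential(                               # Critic: one scalar per prefix
            nn.Linear(hidden_size, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 1)
        )

    def forward(self, input_ids: torch.Tensor):
        hidden = self.backbone(input_ids)             # [batch, seq_len, hidden_size]
        logits = self.lm_head(hidden)                 # token probabilities for the Actor
        values = self.value_head(hidden).squeeze(-1)  # expected future reward at each prefix
        return logits, values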

Here's how they interact:

  1. The Actor (policy) generates token probabilities
  2. The Critic evaluates the chosen tokens and estimates their value
  3. The difference between actual rewards and the Critic's estimates forms the advantage $A_t$
  4. This advantage guides the Actor's policy updates through the PPO objective

This creates a feedback loop where the Critic helps the Actor learn better token generation strategies while staying close to its original behavior.

Generalized Advantage Estimation

Now, I will explain how we compute this advantage function $A_t$. First, consider two functions that are common throughout reinforcement learning: the Q-function and the value function.

  • $V_\pi(s) = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} \mid s_t = s\right]$. This is the expected future (discounted) reward from the current state.
  • $Q_\pi(s,a) = \mathbb{E}_\pi\left[\sum_{k=0}^{\infty} \gamma^k r_{t+k+1} \mid s_t = s, a_t = a\right]$. This is the expected future (discounted) reward conditioned on choosing a certain next action (token).

Intuitively, if we subtract these two values ($Q_\pi(s,a) - V_\pi(s)$), we get how much better or worse taking action $a$ is compared to the average action in state $s$. This is exactly what we want for our advantage function! Thus,

$$A_\pi(s,a) = Q_\pi(s,a) - V_\pi(s)$$
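As a toy example (numbers made up purely for illustration): if $Q_\pi(s,a) = 2.0$ (the expected return when forcing action $a$) and $V_\pi(s) = 1.5$ (the expected return under the policy's usual behaviour), then $A_\pi(s,a) = 0.5$, so action $a$ is better than the policy's average choice in that state; a negative advantage would mean it is worse.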

This would require learning both a value function and a Q-function. Generalized Advantage Estimation (GAE) provides a work-around.

We can estimate the advantage using just the rewards and the current value function. Schulman et al. describe this as follows:

$$A_t^{\text{GAE}} = \sum_{i=0}^{\infty} (\lambda\gamma)^i \delta_{t+i}, \quad \text{where} \quad \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

Deriving GAE:

GAE provides a method to estimate $Q(s_t, a_t)$ by finding a balance between two approaches: temporal-difference estimates (which have low variance but are biased) and Monte Carlo estimates (which are unbiased but have high variance). This balance helps us get more stable and accurate advantage estimates.

TD Estimate

From previous literature, we have that the 1-step temporal difference error is:

$$\delta_t^{(1)} = r_t + \gamma V(s_{t+1}) - V(s_t)$$

This formula represents the temporal difference (TD) error at time $t$, which measures the difference between our predicted value of the current state $V(s_t)$ and a more accurate estimate based on the actual reward received $r_t$ plus the discounted value of the next state $\gamma V(s_{t+1})$.

In other words, it tells us how "surprised" we were by the actual outcome compared to what we predicted. This is a noisy estimate of $A_t$, using just one-step information.
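As a made-up numeric example with $\gamma = 1$: if the critic predicted $V(s_t) = 0.4$, we then received $r_t = 0$, and the critic values the new prefix at $V(s_{t+1}) = 0.7$, then $\delta_t^{(1)} = 0 + 0.7 - 0.4 = 0.3$, a positive surprise suggesting the chosen token was better than expected.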

Multi-step Advantage Estimators

We can define longer-horizon advantage estimators using multiple steps of rewards. For 1-step, we get:

$$\hat{A}_t^{(1)} = \delta_t^{(1)} = r_t + \gamma V(s_{t+1}) - V(s_t)$$

For 2-step, we get:

$$\hat{A}_t^{(2)} = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)$$

For 3-step, we get:

$$\hat{A}_t^{(3)} = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \gamma^3 V(s_{t+3}) - V(s_t)$$

GAE:

In general, for k-steps, we get:

$$\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l r_{t+l} + \gamma^k V(s_{t+k}) - V(s_t)$$

We can rewrite this in terms of the TD errors:

$$\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l \delta_{t+l}^{(1)}$$
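To see why this rewrite holds, write out the $k=2$ case; the intermediate $\gamma V(s_{t+1})$ terms cancel telescopically:

$$\hat{A}_t^{(2)} = \underbrace{\left(r_t + \gamma V(s_{t+1}) - V(s_t)\right)}_{\delta_t^{(1)}} + \gamma \underbrace{\left(r_{t+1} + \gamma V(s_{t+2}) - V(s_{t+1})\right)}_{\delta_{t+1}^{(1)}} = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)$$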

GAE takes a weighted average of these k-step estimators, with weights determined by $\lambda$:

$$\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = (1-\lambda)\left(\hat{A}_t^{(1)} + \lambda \hat{A}_t^{(2)} + \lambda^2 \hat{A}_t^{(3)} + \dots\right)$$

Which simplifies to the following (each TD error $\delta_{t+l}^{(1)}$ appears in every $k$-step estimator with $k \geq l+1$, so its total weight works out to $(1-\lambda)(\lambda^l + \lambda^{l+1} + \dots) = \lambda^l$, on top of its $\gamma^l$ discount):

$$\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \delta_{t+l}^{(1)}$$
  • $\lambda \in [0,1]$ controls the trade-off between bias and variance: $\lambda = 0$ recovers the low-variance (but biased) one-step TD estimate, while $\lambda = 1$ recovers the high-variance (but unbiased) Monte Carlo-style estimate
  • $\gamma$ is the traditional discount factor
  • The sum runs from $t$ to the end of the trajectory (when the LLM generates the end token)
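A minimal Python sketch of this computation (it matches the compute_gae call used in the pseudo-code below; the backward-recursion form and argument names are my own assumptions):

import torch

def compute_gae(rewards: torch.Tensor,  # r_t for each generated token, length T
                values: torch.Tensor,   # V(s_t) for t = 0..T (one extra bootstrap value at the end)
                gamma: float = 0.99,
                lam: float = 0.95) -> torch.Tensor:
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    # Work backwards: A_t = δ_t + γλ·A_{t+1}, which unrolls into Σ_l (γλ)^l δ_{t+l}
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages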

Value Function Ground-Truth

We are learning the value function on the fly. To obtain a better value-function estimate, we build a regression target from the GAE. This target, treated as the ground truth, is defined as

$$V_{\text{GAE}}(s_t) = V(s_t) + \hat{A}_t^{\text{GAE}(\gamma,\lambda)}$$

Thus, we now have a loss $L_{\text{value}}$ to update the value head:

$$L_{\text{value}} = \frac{1}{2}\left[V_{\text{GAE}}(s_t) - V_\phi(s_t)\right]^2$$

where $V_{\text{GAE}}(s_t)$ is computed from the values recorded during rollout and held fixed, and $V_\phi(s_t)$ is the current critic's prediction.
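A minimal PyTorch sketch of this loss (names are my own; the target is built from the rollout values plus the GAE advantages and detached so gradients only flow through the current critic):

import torch

def value_loss(new_values: torch.Tensor,      # V_φ(s_t) from the current critic
               rollout_values: torch.Tensor,  # V(s_t) recorded while generating the response
               advantages: torch.Tensor) -> torch.Tensor:
    targets = (rollout_values + advantages).detach()  # V_GAE(s_t), treated as ground truth
    return 0.5 * ((new_values - targets) ** 2).mean()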

Pseudo-Code


# Initialize policy π_θ and value function V_φ
# Initialize rollout buffer D (PPO is on-policy, so D is cleared every iteration)
# Initialize hyperparameters: clip_ratio ε, GAE params γ and λ, value-loss coefficient c1

for iteration = 1,2,... do:
    # SAVE current policy as π_θ_old for this iteration
    π_θ_old = copy(π_θ)
    
    # Collect trajectories using current policy π_θ
    for episode = 1,2,...,N do:
        # Generate sequence using policy π_θ
        state = initial_prompt
        trajectory = []
        
        while not done:
            # Get action probabilities from old policy
            logits = π_θ_old(state)
            action = sample_token(logits)
            
            # Take action and get reward (usually from a reward model or human feedback)
            next_state = append(state, action)
            reward = compute_reward(next_state)
            
            # Store transition (including the old log-prob of the chosen token for the ratio)
            old_logprob = log_prob(logits, action)
            trajectory.append((state, action, reward, old_logprob))
            state = next_state
            
        # Compute advantages using GAE, plus value targets (returns) for the critic
        values = V_φ(states)
        advantages = compute_gae(rewards, values, γ, λ)
        returns = values + advantages
        
        # Add to rollout buffer
        D.add(trajectory, advantages, returns)
    
    # Update policy and value function using SAME collected data
    for epoch = 1,2,...,K do:
        # Sample mini-batch from trajectories collected above
        batch = D.sample()
        states, actions, old_logprobs, advantages, returns = batch
        
        # Compute policy ratio (π_θ vs π_θ_old) from log-probs of the chosen tokens
        new_logprobs = log_prob(π_θ(states), actions)  # Current policy
        ratio = exp(new_logprobs - old_logprobs)       # old_logprobs from π_θ_old
        
        # Compute clipped surrogate objective
        clip_adv = clip(ratio, 1-ε, 1+ε) * advantages
        policy_loss = -mean(min(ratio * advantages, clip_adv))
        
        # Compute value loss against the returns computed during rollout
        values = V_φ(states)
        value_loss = mean((values - returns)^2)
        
        # Update parameters
        loss = policy_loss + c1 * value_loss
        loss.backward()
        actor.step()  # π_θ gets updated here
        critic.step()
    
    # After K epochs, π_θ has changed from π_θ_old
    # Next iteration will use updated π_θ to collect new trajectories
    # and treat current π_θ as the new π_θ_old

Important Notes:

Multiple Updates Per Data Collection:

  • We collect data once per iteration using π_θ_old
  • We then update π_θ multiple times (K epochs) using that same data
  • The clipping prevents π_θ from deviating too much from π_θ_old

Why This Matters:

  • Sample Efficiency: We get multiple gradient updates from each expensive data collection
  • Stability: Clipping ensures we don't change the policy too drastically
  • Performance: Better than vanilla policy gradients which only do one update per data collection

Disadvantages

  • For large-scale post-training, just a couple of value layers on top of the policy is often insufficient for estimating the value function; a completely separate critic model is sometimes required.
  • In the LLM context, usually only the last token is assigned a reward (good answer or bad answer), but PPO needs $r_t$ at each time step to compute $\delta_t$. Propagating a single sequence-level reward across all time steps is a weak learning signal (see the sketch after this list).
  • That's why DeepSeek came up with Group Relative Policy Optimization (GRPO), which removes the need for a value head!
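For illustration, here is how the per-token reward vector is often constructed in this setup (a common convention, not something prescribed by the PPO paper; the function name is my own):

import torch

def sequence_to_token_rewards(sequence_reward: float, num_tokens: int) -> torch.Tensor:
    # Only the final generated token receives the sequence-level score (from a reward
    # model or human label); every earlier r_t is zero, so the per-step TD errors there
    # depend entirely on the learned value estimates.
    rewards = torch.zeros(num_tokens)
    rewards[-1] = sequence_reward
    return rewards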

References

  • Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347.