Proximal Policy Optimization (PPO)
— #LLMs #RL
Hi everyone! Today, I'll be discussing a famous reinforcement learning algorithm that has been the foundation of many post-training regimes in LLMs and the stepping stone for many novel algorithms: Proximal Policy Optimization.
I will go into the math in some depth, but I'll always tie it back to LLMs and how it shows up in post-training objectives. That's what helped me understand why PPO matters, and I hope it helps you too!
PPO Objective
In PPO, we train an actor and a critic to generate better LLM responses. Our actor generates the output tokens, while our critic estimates how good or bad the resulting response is. Using these two, we build the PPO objective function. Specifically, I'm discussing the clipped variant (often called PPO2), introduced by John Schulman and colleagues.
Some quick terminology, before we begin explaining the algorithm.
- $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ is the probability ratio between the new and old policies. It shows how likely the current policy (LLM) is to predict the action $a_t$ taken at step $t$ compared to our older LLM weights.
- $\hat{A}_t$ is the advantage estimate: it tells us whether the action we took was better or worse than the LLM's average action in that state.
- In this context, the state $s_t$ is the prefix of the current LLM input (prompt plus tokens generated so far), and the action $a_t$ is the next token the LLM chooses to emit.
Putting these pieces together, the clipped PPO objective is

$$L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\Big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t\Big)\right]$$

So overall, this objective function tells us to update our policy (LLM) in a way that improves upon good actions (positive advantage) while limiting how far we can move from our old policy using the clip function. This prevents the policy from making drastic changes that could destabilize training.
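To make this concrete, here is a minimal sketch of the clipped surrogate loss in PyTorch (the function name ppo_clip_loss and the per-token tensor shapes are my own assumptions, not from the paper):

```python
import torch

def ppo_clip_loss(new_logprobs, old_logprobs, advantages, eps=0.2):
    """Clipped surrogate objective, negated so it can be minimized.

    new_logprobs / old_logprobs: log π_θ(a_t|s_t) and log π_θ_old(a_t|s_t)
    for the sampled tokens, shape (batch, seq_len).
    advantages: GAE estimates Â_t, same shape.
    """
    ratio = torch.exp(new_logprobs - old_logprobs)            # r_t(θ)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()              # maximize L^CLIP ⇒ minimize its negative
```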
Actor Critic Architecture
The Actor-Critic architecture consists of two neural networks working together: the Actor (policy network) that decides which actions to take, and the Critic (value network) that evaluates how good those actions are.
In PPO for LLMs, both the Actor and Critic typically share the same transformer-based architecture, with the only difference being their output heads.
The Actor uses the standard language model head to predict token probabilities, while the Critic adds a separate Value head - usually a small MLP that outputs a single scalar value representing the expected future reward from the current state.
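As a rough sketch of that shared-backbone setup (assuming a Hugging Face-style causal LM that can return hidden states; the class name ActorCritic and the two-layer value head are my own illustrative choices):

```python
import torch.nn as nn

class ActorCritic(nn.Module):
    """Shared transformer backbone with two heads: the LM head (actor) and a value head (critic)."""
    def __init__(self, base_lm, hidden_size):
        super().__init__()
        self.base_lm = base_lm  # any causal LM exposing .logits and .hidden_states
        self.value_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids, attention_mask=None):
        out = self.base_lm(input_ids, attention_mask=attention_mask,
                           output_hidden_states=True)
        logits = out.logits                            # actor: next-token distribution
        hidden = out.hidden_states[-1]                 # (batch, seq_len, hidden_size)
        values = self.value_head(hidden).squeeze(-1)   # critic: V(s_t) for every prefix
        return logits, values
```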
Here's how they interact:
- The Actor (policy) generates token probabilities
- The Critic evaluates each resulting state (prefix) and estimates its expected future reward
- The difference between the rewards actually received and the Critic's estimates forms the advantage (via GAE, explained below)
- This advantage guides the Actor's policy updates through the PPO objective
This creates a feedback loop where the Critic helps the Actor learn better token generation strategies while staying close to its original behavior.
Generalized Advantage Estimation
Now, I will explain how we compute this advantage function $\hat{A}_t$. First, think of two functions that are common throughout reinforcement learning: the Value function and the Q-function.
- $V(s_t)$: the expected future reward starting from the current state $s_t$ (averaging over the actions the policy might take).
- $Q(s_t, a_t)$: the expected future reward conditioned on choosing a certain next action (token) $a_t$ in state $s_t$.
Intuitively, if we subtract these two values ($Q(s_t, a_t) - V(s_t)$), we get how much better or worse taking action $a_t$ is compared to the average action in state $s_t$. This is exactly what we want for our advantage function! Thus,

$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

Computing this directly would mean having to learn both a value function and a Q-function. Generalized Advantage Estimation finds a work-around to this.
We can estimate this advantage using just the reward signal and the current value function. Schulman et al. describe it as follows:

$$\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}$$

where

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$
Deriving GAE:
GAE provides a method to estimate $A(s_t, a_t)$ by finding a balance between two approaches: temporal-difference estimates (which have low variance but are biased) and Monte Carlo estimates (which are unbiased but have high variance). This balance helps us get more stable and accurate advantage estimates.
TD Estimate
From previous literature, we have that the 1-step temporal-difference error is:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$$

This formula represents the temporal difference (TD) error at time $t$, which measures the difference between our predicted value of the current state, $V(s_t)$, and a more accurate estimate based on the actual reward received, $r_t$, plus the discounted value of the next state, $\gamma V(s_{t+1})$.
In other words, it tells us how "surprised" we were by the actual outcome compared to what we predicted. This is a noisy estimate of $A(s_t, a_t)$, using just one-step information.
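For intuition, here is a tiny worked example with made-up numbers: say the critic predicted $V(s_t) = 0.4$, we then received a reward $r_t = 1$, and the critic values the next prefix at $V(s_{t+1}) = 0.5$, with $\gamma = 0.99$. Then

$$\delta_t = 1 + 0.99 \cdot 0.5 - 0.4 = 1.095$$

a positive surprise: the outcome was better than the critic expected, so the chosen token looks better than average from this one step of evidence.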
Multi-step Advantage Estimators
We can define longer-horizon advantage estimators using multiple steps of rewards. For 1 step, we get:

$$\hat{A}_t^{(1)} = r_t + \gamma V(s_{t+1}) - V(s_t) = \delta_t$$

For 2 steps, we get:

$$\hat{A}_t^{(2)} = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t)$$

For 3 steps, we get:

$$\hat{A}_t^{(3)} = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \gamma^3 V(s_{t+3}) - V(s_t)$$
GAE:
In general, for $k$ steps, we get:

$$\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l r_{t+l} + \gamma^k V(s_{t+k}) - V(s_t)$$

We can rewrite this in terms of the TD errors:

$$\hat{A}_t^{(k)} = \sum_{l=0}^{k-1} \gamma^l \, \delta_{t+l}$$
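To see why this rewrite works, expand the 2-step case: the intermediate value terms telescope, and the general case follows the same pattern.

$$\delta_t + \gamma \delta_{t+1} = \big(r_t + \gamma V(s_{t+1}) - V(s_t)\big) + \gamma \big(r_{t+1} + \gamma V(s_{t+2}) - V(s_{t+1})\big) = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2}) - V(s_t) = \hat{A}_t^{(2)}$$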
GAE takes a weighted average of these k-step estimators, with weights determined by $\lambda$:

$$\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = (1 - \lambda)\left(\hat{A}_t^{(1)} + \lambda \hat{A}_t^{(2)} + \lambda^2 \hat{A}_t^{(3)} + \cdots\right)$$

which simplifies to:

$$\hat{A}_t^{\text{GAE}(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}$$
- $\lambda \in [0, 1]$ controls the trade-off between bias and variance: $\lambda = 0$ recovers the low-variance (but biased) 1-step TD estimate, while $\lambda = 1$ gives the high-variance (but unbiased) Monte Carlo-style estimate.
- $\gamma$ is the traditional discount factor.
- The sum runs from $t$ to the end of the trajectory (when the LLM generates the end-of-sequence token).
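Here is a minimal sketch of how this GAE computation is typically implemented, as a backward pass over the trajectory (the function name compute_gae and the argument layout are assumptions of mine, following the recursion $\hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}$):

```python
import numpy as np

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """Compute GAE advantages for one trajectory.

    rewards: per-step rewards r_t, length T
    values:  critic estimates V(s_t) for t = 0..T (T+1 entries; values[T] is the
             value after the last action, typically 0 once the sequence has ended)
    Returns advantages Â_t of length T.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    # Walk backwards so each step reuses the accumulated sum of future TD errors
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]   # TD error δ_t
        gae = delta + gamma * lam * gae                          # Â_t = δ_t + γλ Â_{t+1}
        advantages[t] = gae
    return advantages
```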
Value Function Ground-Truth
We are learning the value function on the fly. To get a better value-function estimate, we build a regression target for it from the GAE. This target, acting as our "ground truth", is defined as

$$V^{\text{target}}_t = \hat{A}_t^{\text{GAE}(\gamma, \lambda)} + V(s_t)$$

Thus, we now have a value loss to update the critic (and, through the shared backbone, the LLM):

$$L^{V}(\phi) = \left(V_\phi(s_t) - V^{\text{target}}_t\right)^2$$

In practice, this value loss is added to the clipped policy objective with a weighting coefficient to form the total PPO loss.
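A minimal sketch of that value loss, under the assumption that we saved the critic's rollout-time values (old_values) so the targets stay fixed during the update epochs (the function name and shapes are illustrative):

```python
def value_loss(current_values, old_values, advantages):
    """MSE between current critic predictions and the GAE-based targets.

    current_values: V_φ(s_t) from the *current* critic, shape (batch, seq_len)
    old_values:     V(s_t) recorded when the trajectory was collected
    advantages:     GAE estimates Â_t from the same rollout
    """
    targets = (old_values + advantages).detach()   # V_target = Â_t + V(s_t), held fixed
    return ((current_values - targets) ** 2).mean()
```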
Pseudo-Code
# Initialize policy π_θ and value function V_φ
# Initialize hyperparameters: clip ratio ε, GAE params γ and λ, value-loss coefficient c1
for iteration = 1, 2, ... do:
    # SAVE current policy as π_θ_old for this iteration
    π_θ_old = copy(π_θ)
    # Start a fresh rollout buffer (PPO is on-policy: data is reused only within this iteration)
    D = empty buffer

    # Collect trajectories using the frozen policy π_θ_old
    for episode = 1, 2, ..., N do:
        state = initial_prompt
        trajectory = []
        while not done:
            # Get the next-token distribution from the old policy
            logits = π_θ_old(state)
            action = sample_token(logits)
            old_logprob = log_prob(logits, action)   # log π_θ_old(a_t | s_t), needed for the ratio later

            # Take the action and get a reward - usually from a reward model or human feedback
            next_state = append(state, action)
            reward = compute_reward(next_state)

            # Store the transition
            trajectory.append((state, action, reward, old_logprob))
            state = next_state

        # Compute per-token values and GAE advantages for this trajectory
        old_values = V_φ(trajectory.states)
        advantages = compute_gae(trajectory.rewards, old_values, γ, λ)
        returns = advantages + old_values            # fixed regression targets for the value function

        # Add to the rollout buffer
        D.add(trajectory, advantages, returns)

    # Update policy and value function using the SAME collected data
    for epoch = 1, 2, ..., K do:
        # Sample a mini-batch from the trajectories collected above
        states, actions, old_logprobs, advantages, returns = D.sample()

        # Compute the policy ratio r_t(θ) = π_θ / π_θ_old
        new_logprobs = log_prob(π_θ(states), actions)   # current policy
        ratio = exp(new_logprobs - old_logprobs)

        # Clipped surrogate objective (negated, since we minimize a loss)
        clip_adv = clip(ratio, 1-ε, 1+ε) * advantages
        policy_loss = -mean(min(ratio * advantages, clip_adv))

        # Value loss: regress the current critic towards the fixed GAE-based returns
        values = V_φ(states)
        value_loss = mean((values - returns)^2)

        # Update parameters
        loss = policy_loss + c1 * value_loss
        loss.backward()
        actor.step()    # π_θ gets updated here
        critic.step()

    # After K epochs, π_θ has moved away from π_θ_old
    # The next iteration collects fresh trajectories with the updated π_θ
    # and treats it as the new π_θ_old
Important Notes:
Multiple Updates Per Data Collection:
- We collect data once per iteration using π_θ_old
- We then update π_θ multiple times (K epochs) using that same data
- The clipping prevents π_θ from deviating too much from π_θ_old
Why This Matters:
- Sample Efficiency: We get multiple gradient updates from each expensive data collection
- Stability: Clipping ensures we don't change the policy too drastically
- Performance: Better than vanilla policy gradients which only do one update per data collection
Disadvantages
- For large-scale post-training, a value head with just a couple of extra layers and parameters is often insufficient to produce good value estimates; practitioners sometimes need a completely separate value model.
- In the LLM context, usually only the last token is assigned a reward (good answer or bad answer), but PPO needs a value estimate $V(s_t)$ at every time step to compute $\hat{A}_t$. Propagating a single sparse reward back across all time steps is not a great idea (see the sketch after this list).
- That's why DeepSeek came up with Group Relative Policy Optimization (GRPO), which removes the need for a value head entirely!
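To make the sparse-reward point concrete, here is a small illustrative sketch with made-up numbers, reusing the compute_gae helper sketched earlier: only the final token carries the reward-model score, so every earlier token's credit is pieced together from the critic's own value estimates.

```python
import numpy as np

# A 5-token completion where only the final token receives the reward-model score
rewards = np.array([0.0, 0.0, 0.0, 0.0, 0.87])

# Critic estimates V(s_t) for each prefix, plus a terminal 0 after the last token
values = np.array([0.30, 0.35, 0.40, 0.50, 0.60, 0.0])

advantages = compute_gae(rewards, values, gamma=1.0, lam=0.95)
print(advantages)
# The earlier tokens' advantages are driven largely by the critic's guesses,
# not by any direct signal about whether that particular token was a good choice.
```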
References
- Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347.