Compression of Language Reasoning Chains (CoLaR)
— #LLMs #reasoning #CoT
Background
Chain-of-Thought (CoT), Tree-of-Thought (ToT), and Monte-Carlo Tree Search (MCTS) are well-established, state-of-the-art reasoning algorithms. All of them perform reasoning at the discrete token level, meaning they produce discrete tokens to power their reasoning. However, previous research has shown that keeping reasoning chains in the latent space can encode many reasoning paths without converting them back to discrete tokens.
Performing reasoning in this much denser latent space compresses reasoning chains while preserving the exploration and exploitation of traditional reasoning methods.
This blog post will dive into the Xiaomi paper, Think Silently, Think Fast: Dynamic Latent Compression of LLM Reasoning Chains by Tan et al.
Latent Space
In LLMs, a latent space is a continuous, high-dimensional vector space where the model internally represents language and reasoning. Each input token $x_t$ is embedded as $e_t = \text{Embed}(x_t)$ and passed through the Transformer layers: $h_t = \text{Transformer}(e_1, \ldots, e_t)$, where $h_t \in \mathbb{R}^d$ is a latent vector: a dense, contextualized representation of the token.
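As a toy illustration (the dimensions and module choices here are arbitrary, not the paper's), here is how token ids become contextualized latent vectors in PyTorch:

```python
import torch
import torch.nn as nn

# Toy setup: a vocabulary of 100 tokens and a 16-dimensional latent space.
vocab_size, d_model = 100, 16

embed = nn.Embedding(vocab_size, d_model)  # token id -> embedding e_t
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

token_ids = torch.tensor([[5, 42, 7]])  # a 3-token input sequence
e = embed(token_ids)                    # (1, 3, 16): one embedding per token
h = encoder(e)                          # (1, 3, 16): contextualized latents h_t

print(h.shape)  # torch.Size([1, 3, 16])
```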
Reasoning
Traditional reasoning methods like CoT and ToT operate by generating discrete tokens step by step. For example, in Chain-of-Thought reasoning, given an input prompt $q$, the model generates a reasoning chain:

$r = (r_1, r_2, \ldots, r_n) \sim p_\theta(\cdot \mid q)$

where each $r_i$ represents a reasoning step in natural language. The final answer is derived from this chain:

$a \sim p_\theta(\cdot \mid q, r)$
This discrete token generation has some limitations:
- Each step requires a full forward pass through the model (see the sketch after this list)
- The reasoning is constrained to natural language tokens
- Long chains can be computationally expensive
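To make the first limitation concrete, here is a minimal sketch of the discrete decoding loop, assuming a generic causal LM `model` that maps token ids to logits (greedy decoding, no KV-cache):

```python
import torch

@torch.no_grad()
def generate_cot(model, input_ids, max_steps=256, eos_id=2):
    """Discrete CoT decoding: one full forward pass per generated token.

    `model` is a stand-in for any causal LM returning (B, T, V) logits;
    batch size 1 is assumed for the stopping check.
    """
    for _ in range(max_steps):
        logits = model(input_ids)                             # full forward pass
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy token
        input_ids = torch.cat([input_ids, next_id], dim=-1)
        if next_id.item() == eos_id:                          # stop at <eos>
            break
    return input_ids
```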
Latent Reasoning
To combat this, CoLaR instead predicts latent vectors, each of which can encode multiple paths of reasoning simultaneously.

Predict the next latent:

$\hat{z}_{t+1} = f_\theta(z_1, \ldots, z_t)$

Model uncertainty as a Gaussian distribution:

$z_{t+1} \sim \mathcal{N}(\mu_t, \sigma_t^2 I)$
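Inference then becomes a rollout in embedding space: no tokens are decoded until the final answer. A minimal sketch, where `backbone` (embedding sequence to hidden states) and `latent_head` (hidden state to Gaussian parameters) are hypothetical stand-ins:

```python
import torch

def latent_rollout(backbone, latent_head, prompt_embeds, num_steps=8):
    """Autoregressively roll out latent reasoning steps in embedding space."""
    seq = prompt_embeds                                  # (1, T, d)
    for _ in range(num_steps):
        h = backbone(seq)[:, -1]                         # last hidden state
        mu, sigma = latent_head(h)                       # Gaussian over next latent
        z_next = mu + sigma * torch.randn_like(sigma)    # sample one latent step
        seq = torch.cat([seq, z_next.unsqueeze(1)], dim=1)
    return seq
```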
Supervised Fine-Tuning (SFT)
Let $L$ be the length of the reasoning chain in a traditional CoT reasoning path.

We randomly select a compression factor $r$ that groups the reasoning path into $\lceil L/r \rceil$ blocks of $r$ tokens at a time.
We then feed these token embeddings into an Embedding Compress module. This is a learnable function that takes a group of token embeddings and merges them into a single dense vector.
Given a sequence of token embeddings from a reasoning chunk:

$(e_{(i-1)r+1},\ e_{(i-1)r+2},\ \ldots,\ e_{ir})$

the Embedding Compress module produces a single compressed vector:

$z_i = \text{EmbedCompress}\big(e_{(i-1)r+1}, \ldots, e_{ir}\big)$
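One plausible instantiation is mean-pooling followed by a small MLP; this sketch is an assumption for illustration, not necessarily the paper's exact architecture:

```python
import torch
import torch.nn as nn

class EmbedCompress(nn.Module):
    """Merges r consecutive token embeddings into one dense vector z_i."""
    def __init__(self, d_model):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, block):        # block: (B, r, d) token embeddings
        pooled = block.mean(dim=1)   # (B, d) pooled summary of the block
        return self.mlp(pooled)      # (B, d) compressed embedding

# Usage: compress 6 reasoning-token embeddings into 3 latents with r = 2.
e = torch.randn(1, 6, 32)
compress = EmbedCompress(32)
z = torch.stack([compress(b) for b in e.split(2, dim=1)], dim=1)  # (1, 3, 32)
```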
Given a latent chunk $z_i$, the transformer processes it and produces a hidden state $h_i$, from which the Latent Head predicts a distribution over the next latent:

$(\mu_i, \sigma_i) = \text{LatentHead}(h_i), \qquad \hat{z}_{i+1} \sim \mathcal{N}(\mu_i, \sigma_i^2 I)$
SFT Loss
The loss function in SFT consists of two components: a term that grounds the latent states in the actual CoT reasoning tokens, and a term that learns to generate the next reasoning step by modelling a Gaussian distribution over each dimension of the next latent.
Token Prediction Loss
For the LLM to understand the compressed embeddings it is being asked to predict, they train the language-model head of the LLM to model these compressed reasoning tokens as well. To help with this, at each reasoning block $i$ they randomly sample a ground-truth token $y_i$ from the block and train the language head to predict it from the hidden state at that block's position.

The loss is the cross-entropy over these sampled tokens:

$\mathcal{L}_{\text{token}} = -\sum_{i} \log p_\theta\big(y_i \mid q, z_{<i}\big)$
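A sketch of this sampling-and-supervision step, where `logits` stands for the language-head outputs at the k positions that predict each block, and `blocks` holds the ground-truth token ids grouped by block (both hypothetical inputs):

```python
import torch
import torch.nn.functional as F

def sampled_token_ce(logits, blocks):
    """logits: (B, k, V) language-head outputs, one per reasoning block.
    blocks: (B, k, r) ground-truth token ids, grouped into blocks of r.
    """
    B, k, r = blocks.shape
    idx = torch.randint(r, (B, k, 1))             # uniform draw per block
    targets = blocks.gather(-1, idx).squeeze(-1)  # (B, k) sampled target tokens
    return F.cross_entropy(logits.reshape(B * k, -1), targets.reshape(-1))
```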
Let's break it down via an example:
Setup
Let the query be:

$q =$ “Compute 3 + 2 .”

The ground-truth reasoning token sequence is:

$t = (\text{Adding},\ \text{3},\ \text{to},\ \text{2},\ \text{gives},\ \text{5})$, so $L = 6$.

With compression factor $r = 2$, we form $L/r = 3$ blocks:

$b_1 = (\text{Adding},\ \text{3}), \quad b_2 = (\text{to},\ \text{2}), \quad b_3 = (\text{gives},\ \text{5})$

Each block $b_i$ is encoded as a compressed embedding $z_i$.

The desired answer token sequence is:

$a = (\text{“5”},\ \text{“.”})$
Sampling compressed-block targets
For each block $b_i$, we randomly sample one token $y_i$ from $b_i$. Suppose the draw yields:

$y_1 = \text{“3”}, \quad y_2 = \text{“to”}, \quad y_3 = \text{“5”}$
Training inputs and targets
We feed the model the autoregressive embedding sequence:

$\big(\text{Embed}(q),\ z_1,\ z_2,\ z_3,\ \text{Embed}(\text{“5”}),\ \text{Embed}(\text{“.”})\big)$

and train it to predict the concatenated token sequence:

$(y_1,\ y_2,\ y_3,\ \text{“5”},\ \text{“.”},\ \langle\text{eos}\rangle) = (\text{“3”},\ \text{“to”},\ \text{“5”},\ \text{“5”},\ \text{“.”},\ \langle\text{eos}\rangle)$
Position-by-position view
| Position | Input embedding | Target token |
|---|---|---|
| 1 | $\text{Embed}(q)$ (the query tokens) | $y_1 =$ “3” |
| 2 | $z_1$ | $y_2 =$ “to” |
| 3 | $z_2$ | $y_3 =$ “5” |
| 4 | $z_3$ | “5” |
| 5 | $\text{Embed}(\text{“5”})$ | “.” |
| 6 | $\text{Embed}(\text{“.”})$ | $\langle\text{eos}\rangle$ |
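Pulling the example together, a sketch of how this training batch could be assembled (the dimensions, embeddings, and token ids are illustrative placeholders):

```python
import torch

d = 32
q_embeds = torch.randn(1, 5, d)    # stand-in embeddings of "Compute 3 + 2 ."
z = torch.randn(1, 3, d)           # compressed embeddings z_1, z_2, z_3
a_embeds = torch.randn(1, 2, d)    # stand-in embeddings of the answer "5 ."

# Input: query embeddings, then latents, then answer embeddings.
inputs = torch.cat([q_embeds, z, a_embeds], dim=1)   # (1, 10, d)

# Targets aligned to the last query position and each position after it:
# sampled block tokens, then answer tokens, then <eos> (ids are made up).
targets = torch.tensor([[13, 57, 98, 98, 4, 2]])     # "3","to","5","5",".",<eos>
```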
In expectation, this training procedure is equivalent to supervising the compressed embedding on all of its tokens simultaneously. For any compression factor $r$, each compressed embedding $z_i$ learns to encode all $r$ tokens in its block $b_i$. By randomly sampling one token to predict at a time, we encourage the model to maintain information about every token in the block, since it doesn't know which one it will need to predict.

This compression scheme allows CoLaR to efficiently represent the full reasoning chain while training on individual tokens, striking a balance between computational efficiency and preservation of reasoning information.
Latent Prediction Loss
In addition, for CoT-style reasoning we must be able to predict $z_{i+1}$ given $z_{\le i}$. However, rather than deterministically predicting the entire hidden state, CoLaR predicts a distribution over each dimension of the next latent and randomly samples from it.

This allows for exploration of alternative reasoning pathways while maintaining some grounding to the original CoT reasoning pathway. They employ the reparameterization trick to sample the next embedding as follows:

$\hat{z}_{i+1} = \mu_i + \sigma_i \odot \epsilon$, where $\epsilon \sim \mathcal{N}(0, I)$
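A minimal sketch of a latent head with this reparameterized sampling (the two linear layers are an assumed parameterization):

```python
import torch
import torch.nn as nn

class LatentHead(nn.Module):
    """Predicts a per-dimension Gaussian over the next compressed latent."""
    def __init__(self, d_model):
        super().__init__()
        self.mu = nn.Linear(d_model, d_model)
        self.log_sigma = nn.Linear(d_model, d_model)

    def forward(self, h):                   # h: (B, d) hidden state
        mu = self.mu(h)
        sigma = self.log_sigma(h).exp()     # exp keeps sigma positive
        eps = torch.randn_like(sigma)       # eps ~ N(0, I)
        return mu + sigma * eps, mu, sigma  # reparameterized sample
```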
From here, they employ a new loss called soft-MSE that combines MSE with an entropy regularization term.

The MSE term measures how well the predicted Gaussian embedding reconstructs the ground-truth embedding, while the entropy term preserves diversity in the reasoning latents by preventing the variance from collapsing to zero.

This way, when the loss converges, we retain variety in our reasoning steps while maintaining the accuracy needed to reach the final answer.
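A sketch of such an objective under these assumptions: MSE between the predicted mean and the ground-truth next latent, minus a scaled Gaussian entropy term so the variance cannot collapse (the paper's exact weighting may differ):

```python
import math
import torch

def soft_mse(mu, sigma, z_target, alpha=1e-3):
    """MSE to the ground-truth latent minus an entropy bonus.

    The entropy of N(mu, sigma^2) per dimension is 0.5 * log(2*pi*e*sigma^2),
    so maximizing it (subtracting it from the loss) keeps sigma above zero.
    """
    mse = ((mu - z_target) ** 2).mean()
    entropy = (0.5 * torch.log(2 * math.pi * math.e * sigma ** 2)).mean()
    return mse - alpha * entropy
```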
Experimental Results
CoLaR was evaluated on four standard mathematical reasoning benchmarks. Key findings include:
- +14.1% accuracy over prior latent-based methods at similar compression ratios.
- 53.3% shorter reasoning chains than explicit CoT, with only a 4.8% drop in final-answer accuracy.
- With RL fine-tuning, CoLaR achieves up to +5.4% accuracy while reducing latent chain length by 82.8%.
References
- Tan, W., Li, J., Ju, J., Luo, Z., Luan, J., & Song, R. (2025). Think Silently, Think Fast: Dynamic Latent Compression of LLM Reasoning Chains. arXiv preprint arXiv:2505.16552.
- Wei, J., Wang, X., Schuurmans, D., Bosma, M., Ichter, B., Xia, F., Chi, E., Le, Q. V., & Zhou, D. (2022). Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).
- Yao, S., Yu, D., Zhao, J., Shafran, I., Griffiths, T. L., Cao, Y., & Narasimhan, K. (2023). Tree of Thoughts: Deliberate Problem Solving with Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).