
Compression of Language Reasoning Chains (CoLaR)

Tags: #LLMs, #reasoning, #CoT

Background

Chain-of-Thought (CoT), Tree-of-Thought (ToT), and Monte-Carlo Tree Search (MCTS) are well-known, state-of-the-art reasoning algorithms. They all operate at the discrete token level, producing explicit tokens for every reasoning step. However, previous research has shown that keeping reasoning chains in latent space can encode many reasoning paths without ever converting them back into discrete tokens.

Reasoning in this much denser latent space compresses reasoning chains while preserving the exploration and exploitation behaviour of traditional reasoning methods.

This blog post will dive into the Xiaomi paper, Think Silently, Think Fast: Dynamic Latent Compression of LLM Reasoning Chains by Tan et al.

Latent Space

In LLMs, a latent space is a continuous, high-dimensional vector space where the model internally represents language and reasoning. Each input token $t_i$ is embedded as $x_i = \text{Embed}(t_i) \in \mathbb{R}^d$ and passed through the Transformer layers: $h_i = \text{Transformer}(x_i)$, where $h_i \in \mathbb{R}^d$ is a latent vector, a dense, contextualized representation of the token.
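As a concrete toy illustration of this mapping (not the paper's model; the vocabulary size and hidden dimension are assumed), here is a minimal PyTorch sketch that embeds token ids and runs them through a small Transformer encoder to obtain contextualized latent vectors $h_i$:

```python
import torch
import torch.nn as nn

# Toy illustration: token ids -> embeddings x_i -> contextualized latents h_i.
vocab_size, d = 100, 64
embed = nn.Embedding(vocab_size, d)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d, nhead=4, batch_first=True),
    num_layers=2,
)

token_ids = torch.tensor([[3, 17, 42, 7]])   # a sequence of 4 tokens
x = embed(token_ids)                         # x_i = Embed(t_i), shape (1, 4, 64)
h = encoder(x)                               # h_i = Transformer(x_i), shape (1, 4, 64)
```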

Reasoning

Traditional reasoning methods like CoT and ToT operate by generating discrete tokens step by step. For example, in Chain-of-Thought reasoning, given an input prompt $x$, the model generates a reasoning chain:

$$y = [y_1, y_2, \dots, y_n]$$

where each $y_i$ represents a reasoning step in natural language. The final answer is derived from this chain:

$$\text{answer} = f(y_n \mid y_{1:n-1}, x)$$

This discrete token generation has some limitations:

  • Each step requires a full forward pass through the model
  • The reasoning is constrained to natural language tokens
  • Long chains can be computationally expensive
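To make the per-step cost concrete, here is a minimal sketch of greedy CoT decoding, assuming a Hugging Face-style causal LM interface; `model` and `tokenizer` are hypothetical placeholders. Every generated reasoning token costs one full forward pass over the growing prefix:

```python
import torch

@torch.no_grad()
def greedy_cot(model, tokenizer, prompt, max_steps=128):
    """Toy greedy decoder: each reasoning token requires a full forward pass."""
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_steps):
        logits = model(ids).logits                    # forward pass over the whole prefix
        next_id = logits[:, -1].argmax(-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
        if next_id.item() == tokenizer.eos_token_id:  # stop at end-of-sequence
            break
    return tokenizer.decode(ids[0])
```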

Latent Reasoning

To address these limitations, the authors instead predict latent vectors directly, which can encode multiple reasoning paths simultaneously (a minimal sketch follows the list below):

  • Predict the next latent:

    $z_{i+1} = f(z_i)$
  • Model uncertainty as a Gaussian distribution:

    $z \sim \mathcal{N}(\mu, \sigma^2)$
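Here is a minimal sketch of what such a latent head could look like, assuming a simple two-linear-layer design (the paper's exact architecture may differ): it predicts a per-dimension mean and standard deviation for the next latent and samples via the reparameterization trick.

```python
import torch
import torch.nn as nn

class LatentHead(nn.Module):
    """Predict a per-dimension Gaussian over the next latent and sample from it."""
    def __init__(self, d):
        super().__init__()
        self.mu = nn.Linear(d, d)
        self.log_sigma = nn.Linear(d, d)

    def forward(self, h):
        mu = self.mu(h)
        sigma = self.log_sigma(h).exp()   # keep sigma strictly positive
        eps = torch.randn_like(mu)        # reparameterization: z = mu + sigma * eps
        return mu + sigma * eps, mu, sigma

head = LatentHead(d=64)
h_i = torch.randn(1, 64)                  # hidden state at the current latent position
z_next, mu, sigma = head(h_i)             # z_{i+1} ~ N(mu, sigma^2)
```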

Supervised Fine Tuning (SFT)

Let $N$ be the number of tokens in a traditional CoT reasoning path.

We randomly select an integer compression factor $r$ that groups the reasoning path into $\lceil N/r \rceil$ blocks of $r$ consecutive tokens.

We then run the token embeddings of each block through an Embedding Compress module. This is a learnable function that takes a group of token embeddings and merges them into a single dense vector.

Given the $r$ token embeddings of a reasoning chunk:

$$\{h_1, h_2, \dots, h_r\}, \quad h_i \in \mathbb{R}^d,$$

the Embedding Compress module produces a single compressed vector:

$$z_1 = \text{Compress}(h_1, \dots, h_r) \in \mathbb{R}^d.$$
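A minimal sketch of such a compress module, assuming mean-pooling over the block followed by a learnable projection (the paper's exact design may differ):

```python
import torch
import torch.nn as nn

class EmbeddingCompress(nn.Module):
    """Merge the r token embeddings of one block into a single dense vector."""
    def __init__(self, d):
        super().__init__()
        self.proj = nn.Linear(d, d)

    def forward(self, h_block):                # h_block: (r, d) embeddings of one chunk
        return self.proj(h_block.mean(dim=0))  # compressed vector z_1 in R^d

compress = EmbeddingCompress(d=64)
h_block = torch.randn(2, 64)                   # r = 2 token embeddings
z1 = compress(h_block)                         # shape (64,)
```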

Given a latent chunk $z_1$, the transformer processes it and produces a hidden state $h_1$, from which the Latent Head predicts a distribution over the next latent:

$$\begin{aligned} h_1 &= \text{Transformer}(e_q, z_1)[\text{position of } z_1] \\ \mu_2, \sigma_2 &= \text{LatentHead}(h_1) \\ z_2 &\sim \mathcal{N}(\mu_2, \sigma_2^2) \end{aligned}$$

SFT Loss

The SFT loss consists of two components: one that trains the model to recover the actual CoT reasoning steps from the latent states, and one that learns to generate the next reasoning step by modelling a Gaussian distribution over each dimension.

$$\mathcal{L} = \mathcal{L}_{\text{latent}} + \mathcal{L}_{\text{comp}}$$

$\mathcal{L}_{\text{comp}}$ Loss Function

For the LLM to understand the compressed embeddings it is asked to predict, the authors also train its language-model head to decode these compressed reasoning tokens, at least approximately. To do this, at each reasoning block they randomly sample one ground-truth token from the block and train the model to predict it from the compressed embeddings.

The loss is as follows:

$$\mathcal{L}_{\mathrm{comp}} = -\frac{1}{L_a + L_c} \sum_{i=1}^{L_a + L_c} \log p\bigl(\,[t_c, t_a]^i \mid [e_c, e_a]^{1:i-1},\,e_q \bigr).$$
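In code this is just an average cross-entropy over the sampled block tokens followed by the answer tokens. A minimal sketch, assuming the LM-head logits at the compressed and answer positions have already been gathered into one tensor:

```python
import torch
import torch.nn.functional as F

def comp_loss(logits, target_ids):
    """L_comp: mean cross-entropy over [t_c, t_a].
    logits: (L_c + L_a, vocab) LM-head outputs at the compressed and answer positions.
    target_ids: (L_c + L_a,) sampled block tokens followed by answer tokens."""
    return F.cross_entropy(logits, target_ids)  # = -(1/(L_c+L_a)) * sum_i log p(.)

# Toy usage: 4 compressed blocks + 1 answer token, vocabulary of 1000 tokens.
logits = torch.randn(5, 1000)
targets = torch.tensor([11, 42, 7, 99, 5])      # hypothetical token ids
loss = comp_loss(logits, targets)
```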

Let's break it down via an example:

Setup

Let the query be:

$$e_q = \text{Embed}(\text{"What is 2 + 3?"})$$

The ground-truth reasoning token sequence is:

$$R = [\text{Compute}, 2, \text{plus}, 3, \text{to}, \text{get}, 5, \text{.}], \quad N = |R| = 8$$

With compression factor $r = 2$, we form $C = \lceil N/r \rceil = 4$ blocks:

$$\begin{aligned} (t_0^1,t_1^1) &= (\text{Compute},2) \\ (t_0^2,t_1^2) &= (\text{plus},3) \\ (t_0^3,t_1^3) &= (\text{to},\text{get}) \\ (t_0^4,t_1^4) &= (5,\text{.}) \end{aligned}$$

Each block $k$ is encoded as a compressed embedding $e_c^k$.

The desired answer token sequence is:

$$t_a = [5], \quad L_a = 1$$

Sampling compressed-block targets

For each block $k$, we randomly sample one token $t_c^k \in \{t_0^k, t_1^k\}$. Suppose the draw yields:

$$t_c = [\text{Compute}, 3, \text{to}, \text{.}]$$

Training inputs and targets

We feed the model the autoregressive embedding sequence:

$$[e_q, e_c^1, e_c^2, e_c^3, e_c^4, e_a^1]$$

and train it to predict the concatenated token sequence:

$$[t_c^1, t_c^2, t_c^3, t_c^4, t_a^1] = [\text{Compute}, 3, \text{to}, \text{.}, 5]$$

Position-by-position view

| Position | Input embedding | Target token |
|---|---|---|
| 1 | $e_q$ | |
| 2 | $e_c^1$ | Compute |
| 3 | $e_c^2$ | 3 |
| 4 | $e_c^3$ | to |
| 5 | $e_c^4$ | . |
| 6 | $e_a^1$ | 5 |
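The block construction and target sampling from this example can be reproduced in a few lines of plain Python (token strings stand in for token ids):

```python
import random

reasoning = ["Compute", "2", "plus", "3", "to", "get", "5", "."]
answer = ["5"]
r = 2                                               # compression factor

# Split the reasoning chain into ceil(N / r) blocks of r tokens.
blocks = [reasoning[i:i + r] for i in range(0, len(reasoning), r)]
# [['Compute', '2'], ['plus', '3'], ['to', 'get'], ['5', '.']]

t_c = [random.choice(block) for block in blocks]    # one sampled target per block
targets = t_c + answer                              # e.g. ['Compute', '3', 'to', '.', '5']
```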

In expectation, this training procedure is equivalent to supervising the compressed embedding on all $r$ tokens simultaneously. For any compression factor $r$, each compressed embedding $e_c^k$ learns to encode all $r$ tokens in its block $(t_0^k, \dots, t_{r-1}^k)$. By randomly sampling one token to predict at a time, we encourage the model to maintain information about all tokens in the block, since it does not know which one it will need to predict.

This compression scheme allows CoLaR to efficiently represent the full reasoning chain while training on individual tokens, striking a balance between computational efficiency and preservation of reasoning information.

$\mathcal{L}_{\text{latent}}$ Loss Function

In addition, for latent CoT reasoning, we must be able to predict the next hidden state $h_{i+1}^{c+1}$ given $h_i^c$. However, rather than deterministically predicting the entire hidden state, CoLaR predicts a distribution $(\mu, \sigma)$ over each dimension of the hidden state and randomly samples from it.

This allows exploration of alternative reasoning pathways while maintaining some grounding in the original CoT reasoning pathway. They employ the reparameterization trick to sample the next embedding as follows:

$$\hat{e}^{c}_{i+1} = \hat{\mu}^{c}_{i+1} + \hat{\sigma}^{c}_{i+1}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0,1)$$

From here, they introduce a new loss, soft-MSE, which combines MSE with an entropy regularization term.

$$\mathcal{L}_{\mathrm{latent}}(i) = \mathbb{E}_{\epsilon\sim\mathcal{N}(0,1)}\bigl[\bigl(\hat\mu_c^i + \hat\sigma_c^i\,\epsilon - e_c^i\bigr)^2\bigr] \;-\; \alpha\,\frac{1}{2}\log\bigl(2\pi e\,(\hat\sigma_c^i)^2\bigr).$$

The MSE term measures how well the predicted Gaussian embedding reconstructs the ground-truth embedding, while the entropy term preserves diversity in the reasoning latents by preventing the standard deviation $\hat{\sigma}_c^i$ from collapsing to zero.
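A minimal sketch of this soft-MSE loss with a single-sample Monte Carlo estimate of the expectation (the value of `alpha` is an assumption, not taken from the paper):

```python
import math
import torch

def soft_mse_loss(mu, sigma, e_target, alpha=1e-3):
    """MSE of the reparameterized sample against the ground-truth embedding,
    minus an entropy bonus that keeps sigma from collapsing to zero."""
    eps = torch.randn_like(mu)                              # epsilon ~ N(0, 1)
    mse = ((mu + sigma * eps - e_target) ** 2).mean()
    entropy = 0.5 * torch.log(2 * math.pi * math.e * sigma ** 2).mean()
    return mse - alpha * entropy

mu, sigma = torch.randn(64), torch.rand(64) + 0.1
e_target = torch.randn(64)
loss = soft_mse_loss(mu, sigma, e_target)
```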

This way, when the loss converges, the reasoning steps retain variety while still maintaining the accuracy needed to reach the final answer.

Experimental Results

CoLaR was evaluated on four standard mathematical reasoning benchmarks. Key findings include:

  • +14.1% accuracy over prior latent-based methods at comparable compression ratios.
  • –53.3% reasoning chain length compared to explicit CoT, with only a 4.8% drop in final-answer accuracy.
  • With RL fine-tuning, CoLaR achieves up to +5.4% accuracy while reducing latent chain length by 82.8%.

References

  • Tan, X., Li, Y., Zhang, Z., & Wang, Q. (2025). Think Silently, Think Fast: Dynamic Latent Compression of LLM Reasoning Chains. In Proceedings of the 2025 Conference on Neural Information Processing Systems (NeurIPS).
  • Wei, J., Wang, X., Schuurmans, D., Le, Q. V., & Zhou, D. (2022). Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).
  • Yao, S., Yu, D., Zhao, J., Shafran, I., Griffiths, T. L., Cao, Y., & Narasimhan, K. (2023). Tree of Thoughts: Deliberate Problem Solving with Large Language Models. Advances in Neural Information Processing Systems (NeurIPS).