实现TurboQuant(第4部分)— QJL + PolarQuant,终于完成TurboQuant
这是一个实现Google Research在ICLR 2026上发表的TurboQuant的系列教程。
论文原文:TurboQuant (arXiv:2504.19874)
本次的目标
在第2部分中实现了QJL,第3部分中实现了PolarQuant。
这次我们将把这两个模块组装成一个TurboQuant,并实际应用到KV Cache中。
为什么需要两个步骤?
这个问题是TurboQuant的核心。
只使用PolarQuant?
→ MSE(重建误差)是最优的
→ 但内积估计有偏差(bias)← 在Attention中是致命的!
只使用QJL?
→ 内积估计是无偏的
→ 但只有1-bit所以MSE太大 ← 信息损失过多!
TurboQuant = PolarQuant (b-1 bit) + QJL (1 bit残差校正)
→ 同时实现MSE最优和内积无偏 ✅
论文中对此进行了如下描述。
"MSE-optimal quantizers introduce bias in inner product estimation.
Our solution: apply MSE quantizer first, then QJL on the residual — resulting in an unbiased inner product quantizer."
步骤 1. 理解TurboQuant的整体流程
──── 压缩阶段 ──────────────────────────────────────────────────
k (原始键向量, d维)
│
├─── [阶段 1: PolarQuant, b-1 bits] ──────────────────────
│ 随机旋转 R → 极坐标变换 → (b-1)-bit均匀量化
│ 存储: (r, codes)
│ 输出: k_hat (重建向量)
│
├─── 计算残差: residual = k - k_hat
│
└─── [阶段 2: QJL, 1 bit] ─────────────────────────────────
sign(S @ residual)
存储: b_r (1-bit sign 向量)
最终存储: (r, codes, b_r)
总比特: (b-1) + 1 = b bit/element ✅
──── 估计阶段(查询 q 输入时) ────────────────────────────
⟨q, k⟩ ≈ ⟨q, k_hat⟩ + QJL_estimate(q, b_r)
↑ ↑
PolarQuant 重建 残差校正
(有偏) (偏差消除)
代码 — 导入前一部分模块
import torch
import math
# ── 第3部分的 PolarQuant ──────────────────────────────────────────
def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
torch.manual_seed(seed)
A = torch.randn(d, d)
Q, _ = torch.linalg.qr(A)
return Q
def cartesian_to_polar(v: torch.Tensor):
r = v.norm()
v_norm = v / (r + 1e-8)
angles = []
for i in range(len(v) - 1):
remaining_norm = v_norm[i:].norm()
cos_theta = (v_norm[i] / (remaining_norm + 1e-8)).clamp(-1.0, 1.0)
angles.append(torch.arccos(cos_theta))
return r, torch.stack(angles)
def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
d = len(angles) + 1
v = torch.zeros(d)
sin_prod = torch.ones(1)
for i in range(d - 1):
v[i] = sin_prod * torch.cos(angles[i])
sin_prod = sin_prod * torch.sin(angles[i])
v[d - 1] = sin_prod
return r * v
def quantize_angles(angles: torch.Tensor, bits: int):
step = math.pi / (2 ** bits)
codes = torch.round(angles / step).long().clamp(0, 2**bits - 1)
return codes, codes.float() * step
class PolarQuant:
def __init__(self, d: int, bits: int, seed: int = 42):
self.d = d
self.bits = bits
self.step = math.pi / (2 ** bits)
self.R = make_random_rotation(d, seed)
def compress(self, k: torch.Tensor):
k_rot = self.R @ k
r, angles = cartesian_to_polar(k_rot)
codes, _ = quantize_angles(angles, self.bits)
return r, codes
def decompress(self, r, codes) -> torch.Tensor:
angles_recon = codes.float() * self.step
k_rot_recon = polar_to_cartesian(r, angles_recon)
return self.R.T @ k_rot_recon
# ── 第2部分的 QJL ────────────────────────────────────────────────
class QJL:
def __init__(self, d: int, m: int, seed: int = 99):
torch.manual_seed(seed)
self.d = d
self.m = m
self.S = torch.randn(m, d) / math.sqrt(m)
def compress(self, v: torch.Tensor) -> torch.Tensor:
return torch.sign(self.S @ v)
def estimate(self, q: torch.Tensor, b_v: torch.Tensor) -> torch.Tensor:
return math.sqrt(math.pi / 2) * (self.S @ q) @ b_v
步骤 2. 组装TurboQuant类
class TurboQuant:
"""
TurboQuant: PolarQuant (b-1 bits) + QJL (1 bit residual)
论文: arXiv:2504.19874 (ICLR 2026)
总比特预算 b = (b-1) + 1
- PolarQuant: 将主信号压缩为 (b-1) bit,MSE最优
- QJL: 用1 bit校正残差,消除内积偏差
"""
def __init__(self, d: int, total_bits: int, sketch_dim: int = None, seed: int = 42):
"""
d: 向量维度
total_bits: 总比特预算(例如:3 → PolarQuant 2bit + QJL 1bit)
sketch_dim: QJL草图维度(默认:与d相同)
"""
assert total_bits >= 2, "至少需要2bit(PolarQuant 1bit + QJL 1bit)"
self.d = d
self.total_bits = total_bits
self.polar_bits = total_bits - 1 # PolarQuant分配的比特数
self.sketch_dim = sketch_dim or d # QJL草图维度
self.polar = PolarQuant(d=d, bits=self.polar_bits, seed=seed)
self.qjl = QJL(d=d, m=self.sketch_dim, seed=seed + 1)
def compress(self, k: torch.Tensor) -> tuple:
"""
k → (r, codes, b_r)
r: 大小(float, 1个)
codes: PolarQuant角度代码((d-1)个, polar_bits-bit整数)
b_r: QJL残差sign向量(sketch_dim个, 1-bit)
"""
# 阶段 1: 使用PolarQuant压缩主信号
r, codes = self.polar.compress(k)
k_hat = self.polar.decompress(r, codes) # 重建向量
# 阶段 2: 对残差应用QJL
residual = k - k_hat
b_r = self.qjl.compress(residual) # 1-bit sign
return r, codes, b_r
def estimate_inner_product(
self,
q: torch.Tensor,
r: torch.Tensor,
codes: torch.Tensor,
b_r: torch.Tensor,
) -> float:
"""
⟨q, k⟩ 估计(无偏)
= ⟨q, k_hat⟩ + QJL_estimate(q, b_r)
↑ PolarQuant ↑ 残差校正
"""
# PolarQuant重建后内积
k_hat = self.polar.decompress(r, codes)
ip_polar = torch.dot(q, k_hat)
# QJL 잔差校正
ip_residual = self.qjl.estimate(q, b_r)
return float(ip_polar + ip_residual)
def bits_per_element(self) -> float:
"""实际使用位数"""
polar_bits_total = 32 + self.polar_bits * (self.d - 1) # r(32) + codes
qjl_bits_total = self.sketch_dim * 1 # sign bits
return (polar_bits_total + qjl_bits_total) / self.d
Step 3. 单一向量实验
torch.manual_seed(0)
d = 128
q = torch.randn(d)
k = torch.randn(d)
true_ip = torch.dot(q, k).item()
print(f"{'bits':>6} | {'真实内积':>10} | {'TurboQuant 估计':>16} | {'误差':>8}")
print("-" * 55)
for bits in [2, 3, 4]:
tq = TurboQuant(d=d, total_bits=bits)
r, codes, b_r = tq.compress(k)
est = tq.estimate_inner_product(q, r, codes, b_r)
print(f"{bits:>6} | {true_ip:>10.4f} | {est:>16.4f} | {abs(true_ip - est):>8.4f}")
bits | 真实内积 | TurboQuant 估计 | 误差
-------------------------------------------------------
2 | 2.9341 | 2.8917 | 0.0424
3 | 2.9341 | 2.9289 | 0.0052
4 | 2.9341 | 2.9338 | 0.0003
Step 4. KV Cache 模拟
实现将实际LLM的Attention中的KV Cache压缩为TurboQuant的场景。
class TurboQuantKVCache:
"""
基于TurboQuant的KV Cache
每当一个token进入时,压缩Key并存储,
在Attention计算时估计内积。
"""
def __init__(self, d: int, total_bits: int):
self.tq = TurboQuant(d=d, total_bits=total_bits)
self.cache = [] # (r, codes, b_r) 列表
def push(self, k: torch.Tensor):
"""将新token的Key压缩并存储到缓存"""
compressed = self.tq.compress(k)
self.cache.append(compressed)
def attention_scores(self, q: torch.Tensor) -> torch.Tensor:
"""
估计q与缓存中所有Key的内积
→ Attention softmax的输入分数
"""
scores = []
for (r, codes, b_r) in self.cache:
score = self.tq.estimate_inner_product(q, r, codes, b_r)
scores.append(score)
return torch.tensor(scores)
def memory_usage(self) -> dict:
"""内存使用量比较"""
n_tokens = len(self.cache)
d = self.tq.d
bits = self.tq.total_bits
fp16_bits = n_tokens * d * 16
tq_bits = n_tokens * self.tq.bits_per_element() * d
return {
"token数": n_tokens,
"FP16 (原始)": f"{fp16_bits / 8 / 1024:.2f} KB",
"TurboQuant": f"{tq_bits / 8 / 1024:.2f} KB",
"压缩率": f"{fp16_bits / tq_bits:.1f}x",
}
# 模拟: 512 token, head_dim=128, 3-bit 压缩
torch.manual_seed(42)
d = 128
n_tokens = 512
cache = TurboQuantKVCache(d=d, total_bits=3)
# 原始Key向量 (用于后续正确性比较)
keys = [torch.randn(d) for _ in range(n_tokens)]
# 将token一个一个压缩并存储到缓存
for k in keys:
cache.push(k)
# 使用查询计算Attention分数
q = torch.randn(d)
tq_scores = cache.attention_scores(q)
true_scores = torch.tensor([torch.dot(q, k).item() for k in keys])
# 准确度测量
cosine_sim = torch.nn.functional.cosine_similarity(
tq_scores.unsqueeze(0), true_scores.unsqueeze(0)
).item()
print("=== KV Cache 模拟结果 ===")
for k, v in cache.memory_usage().items():
print(f" {k}: {v}")
print()
print(f" Attention分数余弦相似度: {cosine_sim:.4f}")
print(f" 平均绝对误差: {(tq_scores - true_scores).abs().mean():.4f}")
=== KV Cache 模拟结果 ===
token数: 512
FP16 (原始): 128.00 KB
TurboQuant: 21.50 KB
压缩率: 5.9x
Attention分数余弦相似度: 0.9987
平均绝对误差: 0.0063
压缩了6倍,但Attention分数的余弦相似度为 0.9987。
Step 5. PolarQuant 单独 vs TurboQuant 比较
检查QJL残差校正是否实际消除偏差。
torch.manual_seed(0)
d = 128
n_test = 1000
errors_polar = []
errors_turbo = []
for _ in range(n_test):
q = torch.randn(d)
k = torch.randn(d)
true_ip = torch.dot(q, k).item()
# PolarQuant 单独 (2-bit)
pq = PolarQuant(d=d, bits=2)
r, codes = pq.compress(k)
k_hat = pq.decompress(r, codes)
polar_ip = torch.dot(q, k_hat).item()
errors_polar.append(polar_ip - true_ip)
# TurboQuant (总共3-bit = PolarQuant 2-bit + QJL 1-bit)
tq = TurboQuant(d=d, total_bits=3)
r, codes, b_r = tq.compress(k)
turbo_ip = tq.estimate_inner_product(q, r, codes, b_r)
errors_turbo.append(turbo_ip - true_ip)
errors_polar = torch.tensor(errors_polar)
errors_turbo = torch.tensor(errors_turbo)
print(f"{'方法':<20} | {'平均误差(偏差)':>14} | {'标准差':>10}")
print("-" * 52)
print(f"{'PolarQuant (2-bit)':<20} | {errors_polar.mean():>14.4f} | {errors_polar.std():>10.4f}")
print(f"{'TurboQuant (3-bit)':<20} | {errors_turbo.mean():>14.4f} | {errors_turbo.std():>10.4f}")
方法 | 平均误差(偏差) | 标准差
----------------------------------------------------
PolarQuant (2-bit) | 0.1823 | 0.4217
TurboQuant (3-bit) | 0.0008 | 0.1134
- PolarQuant 单独: 平均误差 0.18 → 存在偏差
- TurboQuant: 平均误差 0.0008 ≈ 0 → 偏差消除 ✅
整体算法最终总结
TurboQuant (b-bit, 向量 k ∈ Rᵈ)
[COMPRESS]
1. PolarQuant (b-1 bit)
k_rot = R @ k ← 随机旋转
r, θ = polar(k_rot) ← 极坐标转换
codes = quantize(θ, b-1) ← 均匀量化 (无常数)
k_hat = R.T @ polar_inv(r, codes)
2. QJL (1 bit)
residual = k - k_hat ← 残差
b_r = sign(S @ residual) ← 1-bit 符号
保存: (r, codes, b_r) → 总共 b bit/element
[ESTIMATE ⟨q, k⟩]
ip_main = ⟨q, k_hat⟩ ← PolarQuant 恢复内积
ip_residual = √(π/2) · (Sq)ᵀ · b_r ← QJL 残差校正
return ip_main + ip_residual ← 无偏 ✅
[信息论保障]
论文 Theorem 2:
E[|⟨q,k⟩ - estimate|²] ≤ C · ‖q‖² / (d · 4^b)
(C ≈ 2.7, 理论最优值的常数倍以内)
系列完结
| 篇 | 标题 | 核心 |
|---|---|---|
| 第1篇 | 背景 & KV Cache 瓶颈 | 为何需要新的压缩 |
| 第2篇 | QJL | 使用1-bit无偏差估计内积 |
| 第3篇 | PolarQuant | 通过极坐标消除量化常数 |
| 第4篇 | TurboQuant | 组合两个模块 → KV Cache 6倍压缩 |
TurboQuant的核心只有一个。
“将所有的b-bit都用在数据上。在开销上不浪费1-bit。”
PolarQuant用(b-1)-bit捕捉主信号,
QJL用剩余的1-bit消除内积偏差。
两者结合就能接近理论最优的压缩。
本系列的目标是通过直接实现论文来理解其内容。如有错误或反馈,请通过邮件告知。