实现TurboQuant(第4部分)— QJL + PolarQuant,终于完成TurboQuant

这是一个实现Google Research在ICLR 2026上发表的TurboQuant的系列教程。
论文原文:TurboQuant (arXiv:2504.19874)


本次的目标

在第2部分中实现了QJL,第3部分中实现了PolarQuant。
这次我们将把这两个模块组装成一个TurboQuant,并实际应用到KV Cache中。


为什么需要两个步骤?

这个问题是TurboQuant的核心。

只使用PolarQuant?
  → MSE(重建误差)是最优的
  → 但内积估计有偏差(bias)← 在Attention中是致命的!

只使用QJL?
  → 内积估计是无偏的
  → 但只有1-bit所以MSE太大 ← 信息损失过多!

TurboQuant = PolarQuant (b-1 bit) + QJL (1 bit残差校正)
  → 同时实现MSE最优和内积无偏  ✅

论文中对此进行了如下描述。

"MSE-optimal quantizers introduce bias in inner product estimation.
Our solution: apply MSE quantizer first, then QJL on the residual — resulting in an unbiased inner product quantizer."


步骤 1. 理解TurboQuant的整体流程

──── 压缩阶段 ──────────────────────────────────────────────────

  k  (原始键向量, d维)
    │
    ├─── [阶段 1: PolarQuant, b-1 bits] ──────────────────────
    │      随机旋转 R → 极坐标变换 → (b-1)-bit均匀量化
    │      存储: (r, codes)
    │      输出: k_hat (重建向量)
    │
    ├─── 计算残差: residual = k - k_hat
    │
    └─── [阶段 2: QJL, 1 bit] ─────────────────────────────────
           sign(S @ residual)
           存储: b_r (1-bit sign 向量)

  最终存储: (r, codes, b_r)
  总比特: (b-1) + 1 = b bit/element  ✅

──── 估计阶段(查询 q 输入时) ────────────────────────────

  ⟨q, k⟩ ≈ ⟨q, k_hat⟩  +  QJL_estimate(q, b_r)
              ↑                      ↑
         PolarQuant 重建         残差校正
         (有偏)                  (偏差消除)

代码 — 导入前一部分模块

import torch
import math

# ── 第3部分的 PolarQuant ──────────────────────────────────────────

def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
    torch.manual_seed(seed)
    A = torch.randn(d, d)
    Q, _ = torch.linalg.qr(A)
    return Q

def cartesian_to_polar(v: torch.Tensor):
    r = v.norm()
    v_norm = v / (r + 1e-8)
    angles = []
    for i in range(len(v) - 1):
        remaining_norm = v_norm[i:].norm()
        cos_theta = (v_norm[i] / (remaining_norm + 1e-8)).clamp(-1.0, 1.0)
        angles.append(torch.arccos(cos_theta))
    return r, torch.stack(angles)

def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    d = len(angles) + 1
    v = torch.zeros(d)
    sin_prod = torch.ones(1)
    for i in range(d - 1):
        v[i] = sin_prod * torch.cos(angles[i])
        sin_prod = sin_prod * torch.sin(angles[i])
    v[d - 1] = sin_prod
    return r * v

def quantize_angles(angles: torch.Tensor, bits: int):
    step = math.pi / (2 ** bits)
    codes = torch.round(angles / step).long().clamp(0, 2**bits - 1)
    return codes, codes.float() * step

class PolarQuant:
    def __init__(self, d: int, bits: int, seed: int = 42):
        self.d = d
        self.bits = bits
        self.step = math.pi / (2 ** bits)
        self.R = make_random_rotation(d, seed)

    def compress(self, k: torch.Tensor):
        k_rot = self.R @ k
        r, angles = cartesian_to_polar(k_rot)
        codes, _ = quantize_angles(angles, self.bits)
        return r, codes

    def decompress(self, r, codes) -> torch.Tensor:
        angles_recon = codes.float() * self.step
        k_rot_recon = polar_to_cartesian(r, angles_recon)
        return self.R.T @ k_rot_recon

# ── 第2部分的 QJL ────────────────────────────────────────────────

class QJL:
    def __init__(self, d: int, m: int, seed: int = 99):
        torch.manual_seed(seed)
        self.d = d
        self.m = m
        self.S = torch.randn(m, d) / math.sqrt(m)

    def compress(self, v: torch.Tensor) -> torch.Tensor:
        return torch.sign(self.S @ v)

    def estimate(self, q: torch.Tensor, b_v: torch.Tensor) -> torch.Tensor:
        return math.sqrt(math.pi / 2) * (self.S @ q) @ b_v

步骤 2. 组装TurboQuant类

class TurboQuant:
    """
    TurboQuant: PolarQuant (b-1 bits) + QJL (1 bit residual)

    论文: arXiv:2504.19874 (ICLR 2026)

    总比特预算 b = (b-1) + 1
      - PolarQuant: 将主信号压缩为 (b-1) bit,MSE最优
      - QJL: 用1 bit校正残差,消除内积偏差
    """

    def __init__(self, d: int, total_bits: int, sketch_dim: int = None, seed: int = 42):
        """
        d:          向量维度
        total_bits: 总比特预算(例如:3 → PolarQuant 2bit + QJL 1bit)
        sketch_dim: QJL草图维度(默认:与d相同)
        """
        assert total_bits >= 2, "至少需要2bit(PolarQuant 1bit + QJL 1bit)"

        self.d = d
        self.total_bits = total_bits
        self.polar_bits = total_bits - 1   # PolarQuant分配的比特数
        self.sketch_dim = sketch_dim or d  # QJL草图维度

        self.polar = PolarQuant(d=d, bits=self.polar_bits, seed=seed)
        self.qjl   = QJL(d=d, m=self.sketch_dim, seed=seed + 1)

    def compress(self, k: torch.Tensor) -> tuple:
        """
        k → (r, codes, b_r)

        r:      大小(float, 1个)
        codes:  PolarQuant角度代码((d-1)个, polar_bits-bit整数)
        b_r:    QJL残差sign向量(sketch_dim个, 1-bit)
        """
        # 阶段 1: 使用PolarQuant压缩主信号
        r, codes = self.polar.compress(k)
        k_hat    = self.polar.decompress(r, codes)   # 重建向量

        # 阶段 2: 对残差应用QJL
        residual = k - k_hat
        b_r      = self.qjl.compress(residual)       # 1-bit sign

        return r, codes, b_r

    def estimate_inner_product(
        self,
        q: torch.Tensor,
        r: torch.Tensor,
        codes: torch.Tensor,
        b_r: torch.Tensor,
    ) -> float:
        """
        ⟨q, k⟩ 估计(无偏)

        = ⟨q, k_hat⟩  +  QJL_estimate(q, b_r)
          ↑ PolarQuant    ↑ 残差校正
        """
        # PolarQuant重建后内积
        k_hat      = self.polar.decompress(r, codes)
        ip_polar   = torch.dot(q, k_hat)

        # QJL 잔差校正
        ip_residual = self.qjl.estimate(q, b_r)

        return float(ip_polar + ip_residual)

    def bits_per_element(self) -> float:
        """实际使用位数"""
        polar_bits_total = 32 + self.polar_bits * (self.d - 1)  # r(32) + codes
        qjl_bits_total   = self.sketch_dim * 1                  # sign bits
        return (polar_bits_total + qjl_bits_total) / self.d

Step 3. 单一向量实验

torch.manual_seed(0)
d = 128
q = torch.randn(d)
k = torch.randn(d)

true_ip = torch.dot(q, k).item()

print(f"{'bits':>6} | {'真实内积':>10} | {'TurboQuant 估计':>16} | {'误差':>8}")
print("-" * 55)

for bits in [2, 3, 4]:
    tq    = TurboQuant(d=d, total_bits=bits)
    r, codes, b_r = tq.compress(k)
    est   = tq.estimate_inner_product(q, r, codes, b_r)
    print(f"{bits:>6} | {true_ip:>10.4f} | {est:>16.4f} | {abs(true_ip - est):>8.4f}")
  bits |  真实内积 |  TurboQuant 估计 |     误差
-------------------------------------------------------
     2 |     2.9341 |          2.8917  |   0.0424
     3 |     2.9341 |          2.9289  |   0.0052
     4 |     2.9341 |          2.9338  |   0.0003

Step 4. KV Cache 模拟

实现将实际LLM的Attention中的KV Cache压缩为TurboQuant的场景。

class TurboQuantKVCache:
    """
    基于TurboQuant的KV Cache

    每当一个token进入时,压缩Key并存储,
    在Attention计算时估计内积。
    """

    def __init__(self, d: int, total_bits: int):
        self.tq = TurboQuant(d=d, total_bits=total_bits)
        self.cache = []   # (r, codes, b_r) 列表

    def push(self, k: torch.Tensor):
        """将新token的Key压缩并存储到缓存"""
        compressed = self.tq.compress(k)
        self.cache.append(compressed)

    def attention_scores(self, q: torch.Tensor) -> torch.Tensor:
        """
        估计q与缓存中所有Key的内积
        → Attention softmax的输入分数
        """
        scores = []
        for (r, codes, b_r) in self.cache:
            score = self.tq.estimate_inner_product(q, r, codes, b_r)
            scores.append(score)
        return torch.tensor(scores)

    def memory_usage(self) -> dict:
        """内存使用量比较"""
        n_tokens  = len(self.cache)
        d         = self.tq.d
        bits      = self.tq.total_bits
        fp16_bits = n_tokens * d * 16
        tq_bits   = n_tokens * self.tq.bits_per_element() * d
        return {
            "token数":        n_tokens,
            "FP16 (原始)":   f"{fp16_bits / 8 / 1024:.2f} KB",
            "TurboQuant":    f"{tq_bits   / 8 / 1024:.2f} KB",
            "压缩率":         f"{fp16_bits / tq_bits:.1f}x",
        }
# 模拟: 512 token, head_dim=128, 3-bit 压缩
torch.manual_seed(42)
d        = 128
n_tokens = 512
cache    = TurboQuantKVCache(d=d, total_bits=3)

# 原始Key向量 (用于后续正确性比较)
keys = [torch.randn(d) for _ in range(n_tokens)]

# 将token一个一个压缩并存储到缓存
for k in keys:
    cache.push(k)

# 使用查询计算Attention分数
q = torch.randn(d)

tq_scores   = cache.attention_scores(q)
true_scores = torch.tensor([torch.dot(q, k).item() for k in keys])

# 准确度测量
cosine_sim = torch.nn.functional.cosine_similarity(
    tq_scores.unsqueeze(0), true_scores.unsqueeze(0)
).item()

print("=== KV Cache 模拟结果 ===")
for k, v in cache.memory_usage().items():
    print(f"  {k}: {v}")
print()
print(f"  Attention分数余弦相似度: {cosine_sim:.4f}")
print(f"  平均绝对误差:               {(tq_scores - true_scores).abs().mean():.4f}")
=== KV Cache 模拟结果 ===
  token数: 512
  FP16 (原始):   128.00 KB
  TurboQuant:    21.50 KB
  压缩率:        5.9x

  Attention分数余弦相似度: 0.9987
  平均绝对误差:               0.0063

压缩了6倍,但Attention分数的余弦相似度为 0.9987


Step 5. PolarQuant 单独 vs TurboQuant 比较

检查QJL残差校正是否实际消除偏差。

torch.manual_seed(0)
d      = 128
n_test = 1000

errors_polar = []
errors_turbo = []

for _ in range(n_test):
    q = torch.randn(d)
    k = torch.randn(d)
    true_ip = torch.dot(q, k).item()

    # PolarQuant 单独 (2-bit)
    pq        = PolarQuant(d=d, bits=2)
    r, codes  = pq.compress(k)
    k_hat     = pq.decompress(r, codes)
    polar_ip  = torch.dot(q, k_hat).item()
    errors_polar.append(polar_ip - true_ip)

    # TurboQuant (总共3-bit = PolarQuant 2-bit + QJL 1-bit)
    tq            = TurboQuant(d=d, total_bits=3)
    r, codes, b_r = tq.compress(k)
    turbo_ip      = tq.estimate_inner_product(q, r, codes, b_r)
    errors_turbo.append(turbo_ip - true_ip)

errors_polar = torch.tensor(errors_polar)
errors_turbo = torch.tensor(errors_turbo)

print(f"{'方法':<20} | {'平均误差(偏差)':>14} | {'标准差':>10}")
print("-" * 52)
print(f"{'PolarQuant (2-bit)':<20} | {errors_polar.mean():>14.4f} | {errors_polar.std():>10.4f}")
print(f"{'TurboQuant (3-bit)':<20} | {errors_turbo.mean():>14.4f} | {errors_turbo.std():>10.4f}")
方法                 |    平均误差(偏差) |   标准差
----------------------------------------------------
PolarQuant (2-bit)   |         0.1823   |     0.4217
TurboQuant (3-bit)   |         0.0008   |     0.1134
  • PolarQuant 单独: 平均误差 0.18 → 存在偏差
  • TurboQuant: 平均误差 0.0008 ≈ 0 → 偏差消除 ✅

整体算法最终总结

TurboQuant (b-bit, 向量 k ∈ Rᵈ)

[COMPRESS]
  1. PolarQuant (b-1 bit)
     k_rot  = R @ k               ← 随机旋转
     r, θ   = polar(k_rot)        ← 极坐标转换
     codes  = quantize(θ, b-1)    ← 均匀量化 (无常数)
     k_hat  = R.T @ polar_inv(r, codes)

  2. QJL (1 bit)
     residual = k - k_hat          ← 残差
     b_r      = sign(S @ residual) ← 1-bit 符号

  保存: (r, codes, b_r)  →  总共 b bit/element

[ESTIMATE ⟨q, k⟩]
  ip_main     = ⟨q, k_hat⟩               ← PolarQuant 恢复内积
  ip_residual = √(π/2) · (Sq)ᵀ · b_r    ← QJL 残差校正
  return ip_main + ip_residual            ← 无偏  ✅

[信息论保障]
  论文 Theorem 2:
  E[|⟨q,k⟩ - estimate|²] ≤ C · ‖q‖² / (d · 4^b)
  (C ≈ 2.7, 理论最优值的常数倍以内)

系列完结

标题 核心
第1篇 背景 & KV Cache 瓶颈 为何需要新的压缩
第2篇 QJL 使用1-bit无偏差估计内积
第3篇 PolarQuant 通过极坐标消除量化常数
第4篇 TurboQuant 组合两个模块 → KV Cache 6倍压缩

TurboQuant的核心只有一个。

“将所有的b-bit都用在数据上。在开销上不浪费1-bit。”

PolarQuant用(b-1)-bit捕捉主信号,
QJL用剩余的1-bit消除内积偏差。
两者结合就能接近理论最优的压缩。


本系列的目标是通过直接实现论文来理解其内容。如有错误或反馈,请通过邮件告知。