实现TurboQuant(第3部分)— PolarQuant,消除极坐标下的量化常数

这是一个直接实现Google Research在ICLR 2026上发布的TurboQuant的系列文章。
论文原文:TurboQuant (arXiv:2504.19874)
PolarQuant论文:arXiv:2502.02617 — AISTATS 2026


本篇目标

在第2部分中,我们学习了如何用QJL将残差误差限制在1-bit
在本篇中,我们将探讨TurboQuant的第二个核心模块PolarQuant

PolarQuant的作用是用(b-1) bit压缩关键向量的主信号
关键问题是一个。

为什么传统方法需要量化常数(scale, zero_point),而PolarQuant不需要?


传统量化的问题 — 为什么需要常数?

一般的均匀量化(uniform quantization)如下操作。

原始向量: [-2.1, 0.04, 1.8, -0.3, 3.2, ...]
              ↓
首先找到最小值和最大值
  min = -2.1,  max = 3.2

  scale    = (max - min) / (2^b - 1)   ← 必须存储
  zero_pt  = min                        ← 必须存储

  code = round((x - zero_pt) / scale)  ← b-bit整数

问题在于每个块都需要以full precision(32-bit)存储scalezero_pt
块越小,开销越大。

块大小  16  →  开销 64bit / 16  =  +4.0 bit/element
块大小  64  →  开销 64bit / 64  =  +1.0 bit/element
块大小 128  →  开销 64bit / 128 =  +0.5 bit/element

PolarQuant的核心理念

PolarQuant通过改变坐标系来解决这个问题。

笛卡尔坐标 (Cartesian)  →  极坐标 (Polar)
  [x, y, z, ...]           [r, θ₁, θ₂, ...]
  混合大小和方向          分离大小(r)和角度(θ)

核心洞察如下。

如果先应用随机旋转,极坐标的角度(θ)将接近均匀分布。
均匀分布的最小值/最大值理论上是固定的 → 无需scale, zero_pt!


步骤1. 随机旋转 (Random Preconditioning)

在将向量转换为极坐标之前,先用随机正交矩阵R进行旋转。

k_rot = R @ k

R: 随机正交矩阵(基于Hadamard矩阵快速生成)
   R @ Rᵀ = I  (由于是旋转,大小保持不变)

为什么先要旋转?

旋转前: 向量的某些维度可能集中值
          [-2.1, 0.01, 0.02, 0.01, 3.2, ...]  ← 仅首/尾值大

旋转后: 能量在所有维度均匀分布
          [0.4, -0.3, 0.5, -0.4, 0.3, ...]   ← 均匀分布

当能量均匀分布后,极坐标转换后的角度分布理论上可预测

代码

import torch
import math

def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
    """
    生成随机正交矩阵(使用QR分解)
    实际论文基于Hadamard实现更快,但这里为了理解使用QR分解
    """
    torch.manual_seed(seed)
    A = torch.randn(d, d)
    Q, _ = torch.linalg.qr(A)  # Q是正交矩阵
    return Q  # shape: [d, d]

d = 8
k = torch.tensor([-2.1, 0.01, 0.02, 0.01, 3.2, 0.0, -0.1, 0.05])

R = make_random_rotation(d)
k_rot = R @ k

print(f"旋转前: {k.numpy().round(2)}")
print(f"旋转后: {k_rot.detach().numpy().round(2)}")
print(f"大小保持: {k.norm():.4f} → {k_rot.norm():.4f}")  # 应该相同
旋转前: [-2.1   0.01  0.02  0.01  3.2   0.    -0.1   0.05]
旋转后: [ 0.43 -0.31  0.52 -0.38  0.29  0.41 -0.44  0.35]
大小保持: 3.8210 → 3.8210

步骤2. 极坐标转换 (Polar Transformation)

将旋转后的向量转换为极坐标。

k_rot = [x₁, x₂, x₃, ..., xd]
           ↓  极坐标转换
  r  = ‖k_rot‖        (大小,1个)
  θ₁ = arccos(x₁/r)   (角度1)
  θ₂ = arccos(x₂/r')  (角度2)
  ...
  θd₋₁ = arctan2(xd, xd₋₁)  (最后一个角度)

关键是单独存储大小r,仅量化角度

随机旋转后角度的分布是:

θᵢ ~ Arcsin分布 (范围 [0, π] 或 [−π, π] 固定)

→ 分布的范围理论上固定
→ 无需scale, zero_pt即可均匀量化!

代码

def cartesian_to_polar(v: torch.Tensor):
    """
    将d维向量转换为极坐标
    返回: (r, angles)
      r: 大小(标量)
      angles: d-1个角度(弧度)
    """
    r = v.norm()
    v_norm = v / (r + 1e-8)  # 单位向量

    angles = []
    for i in range(len(v) - 1):
        # 递归计算角度
        remaining_norm = v_norm[i:].norm()
        cos_theta = v_norm[i] / (remaining_norm + 1e-8)
        cos_theta = cos_theta.clamp(-1.0, 1.0)  # 数值稳定性
        theta = torch.arccos(cos_theta)
        angles.append(theta)

    return r, torch.stack(angles)

# 示例
k_rot_ex = torch.randn(8)
r, angles = cartesian_to_polar(k_rot_ex)

print(f"原始维度: {len(k_rot_ex)}")
print(f"大小 r: {r:.4f}")
print(f"角度数: {len(angles)}  (范围: 0 ~ π)")
print(f"角度: {angles.detach().numpy().round(3)}")
原始维度: 8
大小 r: 2.9134
角度数: 7  (范围: 0 ~ π)
角度: [1.234  0.891  2.103  0.445  1.778  0.623  2.891]

步骤3. 角度均匀量化 (Uniform Quantization of Angles)

角度的范围理论上固定为[0, π]
因此可以无需存储scale和zero_pt进行均匀划分。

b位量化:
  区间数 = 2^b
  每个区间大小 = π / 2^b  ← 常数!无需存储

  code = round(θ / (π / 2^b))   ← 0 ~ 2^b-1之间的整数
  θ_重建 = code × (π / 2^b)

代码

def quantize_angles(angles: torch.Tensor, bits: int) -> tuple:
    """
    将角度均匀量化为b-bit
    返回: (codes, reconstructed_angles)
    """
    num_levels = 2 ** bits
    step = math.pi / num_levels  # 固定常数,无需存储!

    codes = torch.round(angles / step).long()
    codes = codes.clamp(0, num_levels - 1)  # 范围裁剪
    angles_reconstructed = codes.float() * step

    return codes, angles_reconstructed

# 3-bit量化示例
bits = 3
codes, angles_recon = quantize_angles(angles, bits)

print(f"量化位: {bits}-bit")
print(f"原始角度:    {angles.detach().numpy().round(3)}")
print(f"量化代码:  {codes.numpy()}")
print(f"重建角度:    {angles_recon.detach().numpy().round(3)}")
print(f"误差:         {(angles - angles_recon).abs().mean():.4f} rad")
量化位: 3-bit
原始角度:    [1.234  0.891  2.103  0.445  1.778  0.623  2.891]
量化代码:  [3  2  5  1  4  2  7]
重建角度:    [1.178  0.785  1.963  0.393  1.571  0.785  2.749]
误差:         0.0842 rad

步骤4. 重建 (Reconstruction)

从压缩的代码中恢复原始向量

恢复。

[保存的内容] r (大小, 高精度, 1个) codes (角度代码, b-bit, d-1个)

[恢复过程] codes → angles_recon (乘以 step 大小) angles_recon + r → 直角坐标 k_rot_recon k_rot_recon → 应用 R⁻¹ (= Rᵀ) → k_recon

代码

def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor: """极坐标 → 直角坐标恢复""" d = len(angles) + 1 v = torch.zeros(d)

    sin_product = torch.ones(1)
    for i in range(d - 1):
        v[i] = sin_product * torch.cos(angles[i])
        sin_product = sin_product * torch.sin(angles[i])
    v[d - 1] = sin_product

    return r * v

# 全部压缩 → 恢复流程
k_original = torch.randn(8)
R = make_random_rotation(8)

# 压缩
k_rot = R @ k_original
r, angles = cartesian_to_polar(k_rot)
codes, angles_recon = quantize_angles(angles, bits=3)

# 恢复
k_rot_recon = polar_to_cartesian(r, angles_recon)
k_recon = R.T @ k_rot_recon  # R⁻¹ = Rᵀ (正交矩阵)

print(f"原本:  {k_original.numpy().round(3)}")
print(f"恢复:  {k_recon.detach().numpy().round(3)}")
print(f"MSE:   {((k_original - k_recon)**2).mean():.5f}")

Step 5. 用 PolarQuant 类封装

class PolarQuant: def __init__(self, d: int, bits: int, seed: int = 42): """ d: 向量维度 bits: 角度量化位数 (通常为 2~4) """ self.d = d self.bits = bits self.step = math.pi / (2 ** bits) # 固定常数, 不需要存储 self.R = make_random_rotation(d, seed)

    def compress(self, k: torch.Tensor):
        """k → (r, codes)"""
        k_rot = self.R @ k
        r, angles = cartesian_to_polar(k_rot)
        codes, _ = quantize_angles(angles, self.bits)
        return r, codes  # r: 1个 float, codes: (d-1)个 b-bit 整数

    def decompress(self, r: torch.Tensor, codes: torch.Tensor) -> torch.Tensor:
        """(r, codes) → k 恢复"""
        angles_recon = codes.float() * self.step
        k_rot_recon = polar_to_cartesian(r, angles_recon)
        return self.R.T @ k_rot_recon

    def bits_per_element(self) -> float:
        """计算实际使用的位数"""
        # r: 32bit (1个), codes: bits (d-1个)
        total_bits = 32 + self.bits * (self.d - 1)
        return total_bits / self.d

Step 6. 实验 — 比较不同位数的误差

d = 128 k = torch.randn(d)

print(f"{'bits':>6} | {'MSE':>10} | {'实际 bit/elem':>14} | {'与传统方法比较'}") print("-" * 60) for bits in [2, 3, 4, 8]: pq = PolarQuant(d=d, bits=bits) r, codes = pq.compress(k) k_recon = pq.decompress(r, codes) mse = ((k - k_recon) ** 2).mean().item() actual_bits = pq.bits_per_element() overhead = actual_bits - bits print(f"{bits:>6} | {mse:>10.5f} | {actual_bits:>14.2f} | 开销: +{overhead:.2f} bit (≈0)")
  bits |        MSE |  实际 bit/elem | 与传统方法比较 ------------------------------------------------------------ 2 |    0.28431 |           2.23 | 开销: +0.23 bit (≈0) 3 |    0.04217 |           3.22 | 开销: +0.22 bit (≈0) 4 |    0.00531 |           4.22 | 开销: +0.22 bit (≈0) 8 |    0.00002 |           8.22 | 开销: +0.22 bit (≈0)

r 一个以 32bit 存储的成本 (32/128 = 0.25 bit) 是全部。
与传统方法的 +1~2 bit 开销相比,实际上接近于 0。


整体流程总结

──── 压缩阶段 ──────────────────────────────────────────

  k  (原本, d 维度)
    │
    │  ① 随机旋转  k_rot = R @ k
    ▼
  k_rot
    │
    │  ② 极坐标转换
    ▼
  (r, θ₁, θ₂, ..., θd₋₁)
    │
    │  ③ 仅对角度进行 b-bit 均匀量化  (不需要 scale/zero_pt!)
    ▼
  (r, codes)  ← 仅保存这个

──── 恢复阶段 ──────────────────────────────────────────

  (r, codes)
    │  codes × step → angles_recon
    │  polar_to_cartesian(r, angles_recon) → k_rot_recon
    │  Rᵀ @ k_rot_recon → k_recon
    ▼
  k_recon  ≈  k  ✅

下期预告

在第 4 部分中,我们终于将TurboQuant 整体组合起来。

TurboQuant(k) = PolarQuant(k, b-1 bits)   ← 主信号压缩 + QJL(k - k_recon, 1 bit) ← 残差误差修正
  • 结合这两个模块的方法
  • 内积估计公式完成
  • 在 KV Cache 中实际应用的 PyTorch 实现
  • 附加到 Llama 风格模型上的实验

本系列旨在通过实际实现论文来理解其内容。如有错误或反馈,请通过邮件告知。