实现TurboQuant(第3部分)— PolarQuant,消除极坐标下的量化常数
这是一个直接实现Google Research在ICLR 2026上发布的TurboQuant的系列文章。
论文原文:TurboQuant (arXiv:2504.19874)
PolarQuant论文:arXiv:2502.02617 — AISTATS 2026
本篇目标
在第2部分中,我们学习了如何用QJL将残差误差限制在1-bit。
在本篇中,我们将探讨TurboQuant的第二个核心模块PolarQuant。
PolarQuant的作用是用(b-1) bit压缩关键向量的主信号。
关键问题是一个。
为什么传统方法需要量化常数(scale, zero_point),而PolarQuant不需要?
传统量化的问题 — 为什么需要常数?
一般的均匀量化(uniform quantization)如下操作。
原始向量: [-2.1, 0.04, 1.8, -0.3, 3.2, ...]
↓
首先找到最小值和最大值
min = -2.1, max = 3.2
scale = (max - min) / (2^b - 1) ← 必须存储
zero_pt = min ← 必须存储
code = round((x - zero_pt) / scale) ← b-bit整数
问题在于每个块都需要以full precision(32-bit)存储scale和zero_pt。
块越小,开销越大。
块大小 16 → 开销 64bit / 16 = +4.0 bit/element
块大小 64 → 开销 64bit / 64 = +1.0 bit/element
块大小 128 → 开销 64bit / 128 = +0.5 bit/element
PolarQuant的核心理念
PolarQuant通过改变坐标系来解决这个问题。
笛卡尔坐标 (Cartesian) → 极坐标 (Polar)
[x, y, z, ...] [r, θ₁, θ₂, ...]
混合大小和方向 分离大小(r)和角度(θ)
核心洞察如下。
如果先应用随机旋转,极坐标的角度(θ)将接近均匀分布。
均匀分布的最小值/最大值理论上是固定的 → 无需scale, zero_pt!
步骤1. 随机旋转 (Random Preconditioning)
在将向量转换为极坐标之前,先用随机正交矩阵R进行旋转。
k_rot = R @ k
R: 随机正交矩阵(基于Hadamard矩阵快速生成)
R @ Rᵀ = I (由于是旋转,大小保持不变)
为什么先要旋转?
旋转前: 向量的某些维度可能集中值
[-2.1, 0.01, 0.02, 0.01, 3.2, ...] ← 仅首/尾值大
旋转后: 能量在所有维度均匀分布
[0.4, -0.3, 0.5, -0.4, 0.3, ...] ← 均匀分布
当能量均匀分布后,极坐标转换后的角度分布理论上可预测。
代码
import torch
import math
def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
"""
生成随机正交矩阵(使用QR分解)
实际论文基于Hadamard实现更快,但这里为了理解使用QR分解
"""
torch.manual_seed(seed)
A = torch.randn(d, d)
Q, _ = torch.linalg.qr(A) # Q是正交矩阵
return Q # shape: [d, d]
d = 8
k = torch.tensor([-2.1, 0.01, 0.02, 0.01, 3.2, 0.0, -0.1, 0.05])
R = make_random_rotation(d)
k_rot = R @ k
print(f"旋转前: {k.numpy().round(2)}")
print(f"旋转后: {k_rot.detach().numpy().round(2)}")
print(f"大小保持: {k.norm():.4f} → {k_rot.norm():.4f}") # 应该相同
旋转前: [-2.1 0.01 0.02 0.01 3.2 0. -0.1 0.05]
旋转后: [ 0.43 -0.31 0.52 -0.38 0.29 0.41 -0.44 0.35]
大小保持: 3.8210 → 3.8210
步骤2. 极坐标转换 (Polar Transformation)
将旋转后的向量转换为极坐标。
k_rot = [x₁, x₂, x₃, ..., xd]
↓ 极坐标转换
r = ‖k_rot‖ (大小,1个)
θ₁ = arccos(x₁/r) (角度1)
θ₂ = arccos(x₂/r') (角度2)
...
θd₋₁ = arctan2(xd, xd₋₁) (最后一个角度)
关键是单独存储大小r,仅量化角度。
随机旋转后角度的分布是:
θᵢ ~ Arcsin分布 (范围 [0, π] 或 [−π, π] 固定)
→ 分布的范围理论上固定
→ 无需scale, zero_pt即可均匀量化!
代码
def cartesian_to_polar(v: torch.Tensor):
"""
将d维向量转换为极坐标
返回: (r, angles)
r: 大小(标量)
angles: d-1个角度(弧度)
"""
r = v.norm()
v_norm = v / (r + 1e-8) # 单位向量
angles = []
for i in range(len(v) - 1):
# 递归计算角度
remaining_norm = v_norm[i:].norm()
cos_theta = v_norm[i] / (remaining_norm + 1e-8)
cos_theta = cos_theta.clamp(-1.0, 1.0) # 数值稳定性
theta = torch.arccos(cos_theta)
angles.append(theta)
return r, torch.stack(angles)
# 示例
k_rot_ex = torch.randn(8)
r, angles = cartesian_to_polar(k_rot_ex)
print(f"原始维度: {len(k_rot_ex)}")
print(f"大小 r: {r:.4f}")
print(f"角度数: {len(angles)} (范围: 0 ~ π)")
print(f"角度: {angles.detach().numpy().round(3)}")
原始维度: 8
大小 r: 2.9134
角度数: 7 (范围: 0 ~ π)
角度: [1.234 0.891 2.103 0.445 1.778 0.623 2.891]
步骤3. 角度均匀量化 (Uniform Quantization of Angles)
角度的范围理论上固定为[0, π]。
因此可以无需存储scale和zero_pt进行均匀划分。
b位量化:
区间数 = 2^b
每个区间大小 = π / 2^b ← 常数!无需存储
code = round(θ / (π / 2^b)) ← 0 ~ 2^b-1之间的整数
θ_重建 = code × (π / 2^b)
代码
def quantize_angles(angles: torch.Tensor, bits: int) -> tuple:
"""
将角度均匀量化为b-bit
返回: (codes, reconstructed_angles)
"""
num_levels = 2 ** bits
step = math.pi / num_levels # 固定常数,无需存储!
codes = torch.round(angles / step).long()
codes = codes.clamp(0, num_levels - 1) # 范围裁剪
angles_reconstructed = codes.float() * step
return codes, angles_reconstructed
# 3-bit量化示例
bits = 3
codes, angles_recon = quantize_angles(angles, bits)
print(f"量化位: {bits}-bit")
print(f"原始角度: {angles.detach().numpy().round(3)}")
print(f"量化代码: {codes.numpy()}")
print(f"重建角度: {angles_recon.detach().numpy().round(3)}")
print(f"误差: {(angles - angles_recon).abs().mean():.4f} rad")
量化位: 3-bit
原始角度: [1.234 0.891 2.103 0.445 1.778 0.623 2.891]
量化代码: [3 2 5 1 4 2 7]
重建角度: [1.178 0.785 1.963 0.393 1.571 0.785 2.749]
误差: 0.0842 rad
步骤4. 重建 (Reconstruction)
从压缩的代码中恢复原始向量
恢复。[保存的内容] r (大小, 高精度, 1个) codes (角度代码, b-bit, d-1个)
[恢复过程] codes → angles_recon (乘以 step 大小) angles_recon + r → 直角坐标 k_rot_recon k_rot_recon → 应用 R⁻¹ (= Rᵀ) → k_recon 代码
def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor: """极坐标 → 直角坐标恢复""" d = len(angles) + 1 v = torch.zeros(d)
sin_product = torch.ones(1)
for i in range(d - 1):
v[i] = sin_product * torch.cos(angles[i])
sin_product = sin_product * torch.sin(angles[i])
v[d - 1] = sin_product
return r * v
# 全部压缩 → 恢复流程
k_original = torch.randn(8)
R = make_random_rotation(8)
# 压缩
k_rot = R @ k_original
r, angles = cartesian_to_polar(k_rot)
codes, angles_recon = quantize_angles(angles, bits=3)
# 恢复
k_rot_recon = polar_to_cartesian(r, angles_recon)
k_recon = R.T @ k_rot_recon # R⁻¹ = Rᵀ (正交矩阵)
print(f"原本: {k_original.numpy().round(3)}")
print(f"恢复: {k_recon.detach().numpy().round(3)}")
print(f"MSE: {((k_original - k_recon)**2).mean():.5f}") Step 5. 用 PolarQuant 类封装
class PolarQuant: def __init__(self, d: int, bits: int, seed: int = 42): """ d: 向量维度 bits: 角度量化位数 (通常为 2~4) """ self.d = d self.bits = bits self.step = math.pi / (2 ** bits) # 固定常数, 不需要存储 self.R = make_random_rotation(d, seed)
def compress(self, k: torch.Tensor):
"""k → (r, codes)"""
k_rot = self.R @ k
r, angles = cartesian_to_polar(k_rot)
codes, _ = quantize_angles(angles, self.bits)
return r, codes # r: 1个 float, codes: (d-1)个 b-bit 整数
def decompress(self, r: torch.Tensor, codes: torch.Tensor) -> torch.Tensor:
"""(r, codes) → k 恢复"""
angles_recon = codes.float() * self.step
k_rot_recon = polar_to_cartesian(r, angles_recon)
return self.R.T @ k_rot_recon
def bits_per_element(self) -> float:
"""计算实际使用的位数"""
# r: 32bit (1个), codes: bits (d-1个)
total_bits = 32 + self.bits * (self.d - 1)
return total_bits / self.d Step 6. 实验 — 比较不同位数的误差
d = 128 k = torch.randn(d)
print(f"{'bits':>6} | {'MSE':>10} | {'实际 bit/elem':>14} | {'与传统方法比较'}") print("-" * 60) for bits in [2, 3, 4, 8]: pq = PolarQuant(d=d, bits=bits) r, codes = pq.compress(k) k_recon = pq.decompress(r, codes) mse = ((k - k_recon) ** 2).mean().item() actual_bits = pq.bits_per_element() overhead = actual_bits - bits print(f"{bits:>6} | {mse:>10.5f} | {actual_bits:>14.2f} | 开销: +{overhead:.2f} bit (≈0)") bits | MSE | 实际 bit/elem | 与传统方法比较 ------------------------------------------------------------ 2 | 0.28431 | 2.23 | 开销: +0.23 bit (≈0) 3 | 0.04217 | 3.22 | 开销: +0.22 bit (≈0) 4 | 0.00531 | 4.22 | 开销: +0.22 bit (≈0) 8 | 0.00002 | 8.22 | 开销: +0.22 bit (≈0)
r一个以 32bit 存储的成本 (32/128 = 0.25 bit) 是全部。
与传统方法的 +1~2 bit 开销相比,实际上接近于 0。
整体流程总结
──── 压缩阶段 ──────────────────────────────────────────
k (原本, d 维度)
│
│ ① 随机旋转 k_rot = R @ k
▼
k_rot
│
│ ② 极坐标转换
▼
(r, θ₁, θ₂, ..., θd₋₁)
│
│ ③ 仅对角度进行 b-bit 均匀量化 (不需要 scale/zero_pt!)
▼
(r, codes) ← 仅保存这个
──── 恢复阶段 ──────────────────────────────────────────
(r, codes)
│ codes × step → angles_recon
│ polar_to_cartesian(r, angles_recon) → k_rot_recon
│ Rᵀ @ k_rot_recon → k_recon
▼
k_recon ≈ k ✅ 下期预告
在第 4 部分中,我们终于将TurboQuant 整体组合起来。
TurboQuant(k) = PolarQuant(k, b-1 bits) ← 主信号压缩 + QJL(k - k_recon, 1 bit) ← 残差误差修正 - 结合这两个模块的方法
- 内积估计公式完成
- 在 KV Cache 中实际应用的 PyTorch 实现
- 附加到 Llama 风格模型上的实验
本系列旨在通过实际实现论文来理解其内容。如有错误或反馈,请通过邮件告知。