TurboQuantを実装する (3部) — PolarQuant、極座標で量子化定数をなくす
Google ResearchがICLR 2026で発表したTurboQuantを直接実装してみるシリーズです。
論文原文: TurboQuant (arXiv:2504.19874)
PolarQuant論文: arXiv:2502.02617 — AISTATS 2026
今回の目標
2部ではQJLで残差誤差を1-bitで抑える方法を学びました。
今回はTurboQuantの2番目の重要なブロックであるPolarQuantを扱います。
PolarQuantの役割は(b-1)ビットでキー・ベクトルの主信号を圧縮することです。
重要な質問は一つです。
なぜ従来の方法は量子化定数(scale, zero_point)が必要で、PolarQuantは必要ないのか?
従来の量子化の問題 — なぜ定数が必要か?
一般的な均一量子化(uniform quantization)は次のように動作します。
元のベクトル: [-2.1, 0.04, 1.8, -0.3, 3.2, ...]
↓
最小値、最大値をまず求める
min = -2.1, max = 3.2
scale = (max - min) / (2^b - 1) ← 必ず保存する必要あり
zero_pt = min ← 必ず保存する必要あり
code = round((x - zero_pt) / scale) ← b-bit 整数
問題は各ブロックごとにscaleとzero_ptをfull precision(32-bit)で保存しなければならない点です。
ブロックサイズが小さくなるほどオーバーヘッドが大きくなります。
ブロックサイズ 16 → オーバーヘッド 64bit / 16 = +4.0 bit/element
ブロックサイズ 64 → オーバーヘッド 64bit / 64 = +1.0 bit/element
ブロックサイズ 128 → オーバーヘッド 64bit / 128 = +0.5 bit/element
PolarQuantの核心アイデア
PolarQuantはこの問題を座標系を変えることで解決します。
直交座標 (Cartesian) → 極座標 (Polar)
[x, y, z, ...] [r, θ₁, θ₂, ...]
大きさ + 方向混在 大きさ(r)と角度(θ)分離
核心の洞察は次のとおりです。
ランダム回転を先に適用すると、極座標の角度(θ)が均一分布に近くなる。
均一分布は最小値/最大値が理論的に固定 → scale, zero_pt不要!
Step 1. ランダム回転 (Random Preconditioning)
ベクトルを極座標に変換する前に、ランダム直交行列 Rでまず回転します。
k_rot = R @ k
R: ランダム直交行列 (Hadamard行列ベースで迅速に生成)
R @ Rᵀ = I (回転なので大きさ保持)
なぜ回転を先にするのでしょうか?
回転前: ベクトルの特定次元に値が集中することがある
[-2.1, 0.01, 0.02, 0.01, 3.2, ...] ← 最初/最後の値のみ大きい
回転後: エネルギーがすべての次元に均等に分散される
[0.4, -0.3, 0.5, -0.4, 0.3, ...] ← 均等に広がる
エネルギーが均等に分散されれば、極座標変換後の角度分布が理論的に予測可能な形になります。
コード
import torch
import math
def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
"""
ランダム直交行列生成 (QR分解使用)
実際の論文はHadamardベースでより速く実装されていますが、
ここでは理解のためにQR分解を使用
"""
torch.manual_seed(seed)
A = torch.randn(d, d)
Q, _ = torch.linalg.qr(A) # Qは直交行列
return Q # shape: [d, d]
d = 8
k = torch.tensor([-2.1, 0.01, 0.02, 0.01, 3.2, 0.0, -0.1, 0.05])
R = make_random_rotation(d)
k_rot = R @ k
print(f"回転前: {k.numpy().round(2)}")
print(f"回転後: {k_rot.detach().numpy().round(2)}")
print(f"大きさ保持: {k.norm():.4f} → {k_rot.norm():.4f}") # 同じであるはず
回転前: [-2.1 0.01 0.02 0.01 3.2 0. -0.1 0.05]
回転後: [ 0.43 -0.31 0.52 -0.38 0.29 0.41 -0.44 0.35]
大きさ保持: 3.8210 → 3.8210
Step 2. 極座標変換 (Polar Transformation)
回転されたベクトルを極座標に変換します。
k_rot = [x₁, x₂, x₃, ..., xd]
↓ 極座標変換
r = ‖k_rot‖ (大きさ, 1個)
θ₁ = arccos(x₁/r) (角度 1)
θ₂ = arccos(x₂/r') (角度 2)
...
θd₋₁ = arctan2(xd, xd₋₁) (最後の角度)
核心は大きさrは別途保存し、角度だけを量子化するという点です。
ランダム回転後の角度の分布は:
θᵢ ~ Arcsin分布 (範囲 [0, π] または [−π, π] で固定)
→ 分布の範囲が理論的に固定される
→ scale, zero_ptなしで均一量子化可能!
コード
def cartesian_to_polar(v: torch.Tensor):
"""
d次元ベクトルを極座標に変換
返り値: (r, angles)
r: 大きさ (スカラー)
angles: d-1個の角度 (ラジアン)
"""
r = v.norm()
v_norm = v / (r + 1e-8) # 単位ベクトル
angles = []
for i in range(len(v) - 1):
# 再帰的に角度計算
remaining_norm = v_norm[i:].norm()
cos_theta = v_norm[i] / (remaining_norm + 1e-8)
cos_theta = cos_theta.clamp(-1.0, 1.0) # 数値安定性
theta = torch.arccos(cos_theta)
angles.append(theta)
return r, torch.stack(angles)
# 例
k_rot_ex = torch.randn(8)
r, angles = cartesian_to_polar(k_rot_ex)
print(f"元の次元: {len(k_rot_ex)}")
print(f"大きさ r: {r:.4f}")
print(f"角度の数: {len(angles)} (範囲: 0 ~ π)")
print(f"角度: {angles.detach().numpy().round(3)}")
元の次元: 8
大きさ r: 2.9134
角度の数: 7 (範囲: 0 ~ π)
角度: [1.234 0.891 2.103 0.445 1.778 0.623 2.891]
Step 3. 角度均一量子化 (Uniform Quantization of Angles)
角度の範囲は[0, π]で理論的に固定されています。
したがってscaleとzero_ptを保存する必要なく均一に分割できます。
bビット量子化ならば:
区間数 = 2^b
各区間の大きさ = π / 2^b ← 定数!保存不要
code = round(θ / (π / 2^b)) ← 0 ~ 2^b-1 の整数
θ_再構成 = code × (π / 2^b)
コード
def quantize_angles(angles: torch.Tensor, bits: int) -> tuple:
"""
角度をb-bitで均一量子化
返り値: (codes, reconstructed_angles)
"""
num_levels = 2 ** bits
step = math.pi / num_levels # 固定定数、保存不要!
codes = torch.round(angles / step).long()
codes = codes.clamp(0, num_levels - 1) # 範囲クリッピング
angles_reconstructed = codes.float() * step
return codes, angles_reconstructed
# 3ビット量子化の例
bits = 3
codes, angles_recon = quantize_angles(angles, bits)
print(f"量子化ビット: {bits}-bit")
print(f"元の角度: {angles.detach().numpy().round(3)}")
print(f"量子化コード: {codes.numpy()}")
print(f"復元角度: {angles_recon.detach().numpy().round(3)}")
print(f"誤差: {(angles - angles_recon).abs().mean():.4f} rad")
量子化ビット: 3-bit
元の角度: [1.234 0.891 2.103 0.445 1.778 0.623 2.891]
量子化コード: [3 2 5 1 4 2 7]
復元角度: [1.178 0.785 1.963 0.393 1.571 0.785 2.749]
誤差: 0.0842 rad
Step 4. 復元 (Reconstruction)
圧縮されたコードから元のベクトルを復元します。
[保存したもの]
r (サイズ, 高精度, 1個)
codes (角度コード, b-bit, d-1個)
[復元プロセス]
codes → angles_recon (step サイズ掛ける)
angles_recon + r → 直交座標 k_rot_recon
k_rot_recon → R⁻¹ 適用 (= Rᵀ) → k_recon
コード
def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
"""極座標 → 直交座標復元"""
d = len(angles) + 1
v = torch.zeros(d)
sin_product = torch.ones(1)
for i in range(d - 1):
v[i] = sin_product * torch.cos(angles[i])
sin_product = sin_product * torch.sin(angles[i])
v[d - 1] = sin_product
return r * v
# 全体圧縮 → 復元の流れ
k_original = torch.randn(8)
R = make_random_rotation(8)
# 圧縮
k_rot = R @ k_original
r, angles = cartesian_to_polar(k_rot)
codes, angles_recon = quantize_angles(angles, bits=3)
# 復元
k_rot_recon = polar_to_cartesian(r, angles_recon)
k_recon = R.T @ k_rot_recon # R⁻¹ = Rᵀ (直交行列)
print(f"元の: {k_original.numpy().round(3)}")
print(f"復元: {k_recon.detach().numpy().round(3)}")
print(f"MSE: {((k_original - k_recon)**2).mean():.5f}")
Step 5. PolarQuant クラスでまとめる
class PolarQuant:
def __init__(self, d: int, bits: int, seed: int = 42):
"""
d: ベクトル次元
bits: 角度量子化ビット数 (通常2~4)
"""
self.d = d
self.bits = bits
self.step = math.pi / (2 ** bits) # 固定定数、保存不要
self.R = make_random_rotation(d, seed)
def compress(self, k: torch.Tensor):
"""k → (r, codes)"""
k_rot = self.R @ k
r, angles = cartesian_to_polar(k_rot)
codes, _ = quantize_angles(angles, self.bits)
return r, codes # r: 1個 float, codes: (d-1)個 b-bit 整数
def decompress(self, r: torch.Tensor, codes: torch.Tensor) -> torch.Tensor:
"""(r, codes) → k 復元"""
angles_recon = codes.float() * self.step
k_rot_recon = polar_to_cartesian(r, angles_recon)
return self.R.T @ k_rot_recon
def bits_per_element(self) -> float:
"""実際使用ビット数計算"""
# r: 32bit (1個), codes: bits (d-1個)
total_bits = 32 + self.bits * (self.d - 1)
return total_bits / self.d
Step 6. 実験 — ビット数による誤差比較
d = 128
k = torch.randn(d)
print(f"{'bits':>6} | {'MSE':>10} | {'実際 bit/elem':>14} | {'既存方法比較'}")
print("-" * 60)
for bits in [2, 3, 4, 8]:
pq = PolarQuant(d=d, bits=bits)
r, codes = pq.compress(k)
k_recon = pq.decompress(r, codes)
mse = ((k - k_recon) ** 2).mean().item()
actual_bits = pq.bits_per_element()
overhead = actual_bits - bits
print(f"{bits:>6} | {mse:>10.5f} | {actual_bits:>14.2f} | オーバーヘッド: +{overhead:.2f} bit (≈0)")
bits | MSE | 実際 bit/elem | 既存方法比較
------------------------------------------------------------
2 | 0.28431 | 2.23 | オーバーヘッド: +0.23 bit (≈0)
3 | 0.04217 | 3.22 | オーバーヘッド: +0.22 bit (≈0)
4 | 0.00531 | 4.22 | オーバーヘッド: +0.22 bit (≈0)
8 | 0.00002 | 8.22 | オーバーヘッド: +0.22 bit (≈0)
r一つを32bitで保存するコスト(32/128 = 0.25 bit)が全てです。
既存方法の+1~2 bitオーバーヘッドと比較すると事実上0に近いです。
全体の流れまとめ
──── 圧縮段階 ──────────────────────────────────────────
k (元の, d次元)
│
│ ① ランダム回転 k_rot = R @ k
▼
k_rot
│
│ ② 極座標変換
▼
(r, θ₁, θ₂, ..., θd₋₁)
│
│ ③ 角度のみ b-bit 均一量子化 (scale/zero_pt 不要!)
▼
(r, codes) ← これだけ保存
──── 復元段階 ──────────────────────────────────────────
(r, codes)
│ codes × step → angles_recon
│ polar_to_cartesian(r, angles_recon) → k_rot_recon
│ Rᵀ @ k_rot_recon → k_recon
▼
k_recon ≈ k ✅
次回予告
第4回ではいよいよ TurboQuant 全体を組み立てます。
TurboQuant(k) =
PolarQuant(k, b-1 bits) ← 主信号圧縮
+ QJL(k - k_recon, 1 bit) ← 残差誤差補正
- 二つのモジュールを結合する方法
- 内積推定公式完成
- KV Cacheに実際に適用するPyTorch実装
- Llamaスタイルモデルに付けて実験
このシリーズは論文を直接実装し理解することを目標としています。誤りやフィードバックはメールでお願いします。