TurboQuantを実装する (3部) — PolarQuant、極座標で量子化定数をなくす

Google ResearchがICLR 2026で発表したTurboQuantを直接実装してみるシリーズです。
論文原文: TurboQuant (arXiv:2504.19874)
PolarQuant論文: arXiv:2502.02617 — AISTATS 2026


今回の目標

2部ではQJLで残差誤差を1-bitで抑える方法を学びました。
今回はTurboQuantの2番目の重要なブロックであるPolarQuantを扱います。

PolarQuantの役割は(b-1)ビットでキー・ベクトルの主信号を圧縮することです。
重要な質問は一つです。

なぜ従来の方法は量子化定数(scale, zero_point)が必要で、PolarQuantは必要ないのか?


従来の量子化の問題 — なぜ定数が必要か?

一般的な均一量子化(uniform quantization)は次のように動作します。

元のベクトル: [-2.1, 0.04, 1.8, -0.3, 3.2, ...]
              ↓
最小値、最大値をまず求める
  min = -2.1,  max = 3.2

  scale    = (max - min) / (2^b - 1)   ← 必ず保存する必要あり
  zero_pt  = min                        ← 必ず保存する必要あり

  code = round((x - zero_pt) / scale)  ← b-bit 整数

問題は各ブロックごとにscalezero_ptfull precision(32-bit)で保存しなければならない点です。
ブロックサイズが小さくなるほどオーバーヘッドが大きくなります。

ブロックサイズ  16  →  オーバーヘッド 64bit / 16  =  +4.0 bit/element
ブロックサイズ  64  →  オーバーヘッド 64bit / 64  =  +1.0 bit/element
ブロックサイズ 128  →  オーバーヘッド 64bit / 128 =  +0.5 bit/element

PolarQuantの核心アイデア

PolarQuantはこの問題を座標系を変えることで解決します。

直交座標 (Cartesian)  →  極座標 (Polar)
  [x, y, z, ...]           [r, θ₁, θ₂, ...]
  大きさ + 方向混在        大きさ(r)と角度(θ)分離

核心の洞察は次のとおりです。

ランダム回転を先に適用すると、極座標の角度(θ)が均一分布に近くなる。
均一分布は最小値/最大値が理論的に固定 → scale, zero_pt不要!


Step 1. ランダム回転 (Random Preconditioning)

ベクトルを極座標に変換する前に、ランダム直交行列 Rでまず回転します。

k_rot = R @ k

R: ランダム直交行列 (Hadamard行列ベースで迅速に生成)
   R @ Rᵀ = I  (回転なので大きさ保持)

なぜ回転を先にするのでしょうか?

回転前: ベクトルの特定次元に値が集中することがある
          [-2.1, 0.01, 0.02, 0.01, 3.2, ...]  ← 最初/最後の値のみ大きい

回転後: エネルギーがすべての次元に均等に分散される
          [0.4, -0.3, 0.5, -0.4, 0.3, ...]   ← 均等に広がる

エネルギーが均等に分散されれば、極座標変換後の角度分布が理論的に予測可能な形になります。

コード

import torch
import math

def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
    """
    ランダム直交行列生成 (QR分解使用)
    実際の論文はHadamardベースでより速く実装されていますが、
    ここでは理解のためにQR分解を使用
    """
    torch.manual_seed(seed)
    A = torch.randn(d, d)
    Q, _ = torch.linalg.qr(A)  # Qは直交行列
    return Q  # shape: [d, d]

d = 8
k = torch.tensor([-2.1, 0.01, 0.02, 0.01, 3.2, 0.0, -0.1, 0.05])

R = make_random_rotation(d)
k_rot = R @ k

print(f"回転前: {k.numpy().round(2)}")
print(f"回転後: {k_rot.detach().numpy().round(2)}")
print(f"大きさ保持: {k.norm():.4f} → {k_rot.norm():.4f}")  # 同じであるはず
回転前: [-2.1   0.01  0.02  0.01  3.2   0.    -0.1   0.05]
回転後: [ 0.43 -0.31  0.52 -0.38  0.29  0.41 -0.44  0.35]
大きさ保持: 3.8210 → 3.8210

Step 2. 極座標変換 (Polar Transformation)

回転されたベクトルを極座標に変換します。

k_rot = [x₁, x₂, x₃, ..., xd]
           ↓  極座標変換
  r  = ‖k_rot‖        (大きさ, 1個)
  θ₁ = arccos(x₁/r)   (角度 1)
  θ₂ = arccos(x₂/r')  (角度 2)
  ...
  θd₋₁ = arctan2(xd, xd₋₁)  (最後の角度)

核心は大きさrは別途保存し、角度だけを量子化するという点です。

ランダム回転後の角度の分布は:

θᵢ ~ Arcsin分布 (範囲 [0, π] または [−π, π] で固定)

→ 分布の範囲が理論的に固定される
→ scale, zero_ptなしで均一量子化可能!

コード

def cartesian_to_polar(v: torch.Tensor):
    """
    d次元ベクトルを極座標に変換
    返り値: (r, angles)
      r: 大きさ (スカラー)
      angles: d-1個の角度 (ラジアン)
    """
    r = v.norm()
    v_norm = v / (r + 1e-8)  # 単位ベクトル

    angles = []
    for i in range(len(v) - 1):
        # 再帰的に角度計算
        remaining_norm = v_norm[i:].norm()
        cos_theta = v_norm[i] / (remaining_norm + 1e-8)
        cos_theta = cos_theta.clamp(-1.0, 1.0)  # 数値安定性
        theta = torch.arccos(cos_theta)
        angles.append(theta)

    return r, torch.stack(angles)

# 例
k_rot_ex = torch.randn(8)
r, angles = cartesian_to_polar(k_rot_ex)

print(f"元の次元: {len(k_rot_ex)}")
print(f"大きさ r: {r:.4f}")
print(f"角度の数: {len(angles)}  (範囲: 0 ~ π)")
print(f"角度: {angles.detach().numpy().round(3)}")
元の次元: 8
大きさ r: 2.9134
角度の数: 7  (範囲: 0 ~ π)
角度: [1.234  0.891  2.103  0.445  1.778  0.623  2.891]

Step 3. 角度均一量子化 (Uniform Quantization of Angles)

角度の範囲は[0, π]で理論的に固定されています。
したがってscaleとzero_ptを保存する必要なく均一に分割できます。

bビット量子化ならば:
  区間数 = 2^b
  各区間の大きさ = π / 2^b  ← 定数!保存不要

  code = round(θ / (π / 2^b))   ← 0 ~ 2^b-1 の整数
  θ_再構成 = code × (π / 2^b)

コード

def quantize_angles(angles: torch.Tensor, bits: int) -> tuple:
    """
    角度をb-bitで均一量子化
    返り値: (codes, reconstructed_angles)
    """
    num_levels = 2 ** bits
    step = math.pi / num_levels  # 固定定数、保存不要!

    codes = torch.round(angles / step).long()
    codes = codes.clamp(0, num_levels - 1)  # 範囲クリッピング
    angles_reconstructed = codes.float() * step

    return codes, angles_reconstructed

# 3ビット量子化の例
bits = 3
codes, angles_recon = quantize_angles(angles, bits)

print(f"量子化ビット: {bits}-bit")
print(f"元の角度:    {angles.detach().numpy().round(3)}")
print(f"量子化コード:  {codes.numpy()}")
print(f"復元角度:    {angles_recon.detach().numpy().round(3)}")
print(f"誤差:         {(angles - angles_recon).abs().mean():.4f} rad")
量子化ビット: 3-bit
元の角度:    [1.234  0.891  2.103  0.445  1.778  0.623  2.891]
量子化コード:  [3  2  5  1  4  2  7]
復元角度:    [1.178  0.785  1.963  0.393  1.571  0.785  2.749]
誤差:         0.0842 rad

Step 4. 復元 (Reconstruction)

圧縮されたコードから元のベクトルを復元します。

[保存したもの]
  r (サイズ, 高精度, 1個)
  codes (角度コード, b-bit, d-1個)

[復元プロセス]
  codes → angles_recon (step サイズ掛ける)
  angles_recon + r → 直交座標 k_rot_recon
  k_rot_recon → R⁻¹ 適用 (= Rᵀ) → k_recon

コード

def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """極座標 → 直交座標復元"""
    d = len(angles) + 1
    v = torch.zeros(d)

    sin_product = torch.ones(1)
    for i in range(d - 1):
        v[i] = sin_product * torch.cos(angles[i])
        sin_product = sin_product * torch.sin(angles[i])
    v[d - 1] = sin_product

    return r * v

# 全体圧縮 → 復元の流れ
k_original = torch.randn(8)
R = make_random_rotation(8)

# 圧縮
k_rot = R @ k_original
r, angles = cartesian_to_polar(k_rot)
codes, angles_recon = quantize_angles(angles, bits=3)

# 復元
k_rot_recon = polar_to_cartesian(r, angles_recon)
k_recon = R.T @ k_rot_recon  # R⁻¹ = Rᵀ (直交行列)

print(f"元の:  {k_original.numpy().round(3)}")
print(f"復元:  {k_recon.detach().numpy().round(3)}")
print(f"MSE:   {((k_original - k_recon)**2).mean():.5f}")

Step 5. PolarQuant クラスでまとめる

class PolarQuant:
    def __init__(self, d: int, bits: int, seed: int = 42):
        """
        d:    ベクトル次元
        bits: 角度量子化ビット数 (通常2~4)
        """
        self.d = d
        self.bits = bits
        self.step = math.pi / (2 ** bits)  # 固定定数、保存不要
        self.R = make_random_rotation(d, seed)

    def compress(self, k: torch.Tensor):
        """k → (r, codes)"""
        k_rot = self.R @ k
        r, angles = cartesian_to_polar(k_rot)
        codes, _ = quantize_angles(angles, self.bits)
        return r, codes  # r: 1個 float, codes: (d-1)個 b-bit 整数

    def decompress(self, r: torch.Tensor, codes: torch.Tensor) -> torch.Tensor:
        """(r, codes) → k 復元"""
        angles_recon = codes.float() * self.step
        k_rot_recon = polar_to_cartesian(r, angles_recon)
        return self.R.T @ k_rot_recon

    def bits_per_element(self) -> float:
        """実際使用ビット数計算"""
        # r: 32bit (1個), codes: bits (d-1個)
        total_bits = 32 + self.bits * (self.d - 1)
        return total_bits / self.d

Step 6. 実験 — ビット数による誤差比較

d = 128
k = torch.randn(d)

print(f"{'bits':>6} | {'MSE':>10} | {'実際 bit/elem':>14} | {'既存方法比較'}")
print("-" * 60)
for bits in [2, 3, 4, 8]:
    pq = PolarQuant(d=d, bits=bits)
    r, codes = pq.compress(k)
    k_recon = pq.decompress(r, codes)
    mse = ((k - k_recon) ** 2).mean().item()
    actual_bits = pq.bits_per_element()
    overhead = actual_bits - bits
    print(f"{bits:>6} | {mse:>10.5f} | {actual_bits:>14.2f} | オーバーヘッド: +{overhead:.2f} bit (≈0)")
  bits |        MSE |  実際 bit/elem | 既存方法比較
------------------------------------------------------------
     2 |    0.28431 |           2.23 | オーバーヘッド: +0.23 bit (≈0)
     3 |    0.04217 |           3.22 | オーバーヘッド: +0.22 bit (≈0)
     4 |    0.00531 |           4.22 | オーバーヘッド: +0.22 bit (≈0)
     8 |    0.00002 |           8.22 | オーバーヘッド: +0.22 bit (≈0)

r 一つを32bitで保存するコスト(32/128 = 0.25 bit)が全てです。
既存方法の+1~2 bitオーバーヘッドと比較すると事実上0に近いです。


全体の流れまとめ

──── 圧縮段階 ──────────────────────────────────────────

  k  (元の, d次元)
    │
    │  ① ランダム回転  k_rot = R @ k
    ▼
  k_rot
    │
    │  ② 極座標変換
    ▼
  (r, θ₁, θ₂, ..., θd₋₁)
    │
    │  ③ 角度のみ b-bit 均一量子化  (scale/zero_pt 不要!)
    ▼
  (r, codes)  ← これだけ保存

──── 復元段階 ──────────────────────────────────────────

  (r, codes)
    │  codes × step → angles_recon
    │  polar_to_cartesian(r, angles_recon) → k_rot_recon
    │  Rᵀ @ k_rot_recon → k_recon
    ▼
  k_recon  ≈  k  ✅

次回予告

第4回ではいよいよ TurboQuant 全体を組み立てます。

TurboQuant(k) =
  PolarQuant(k, b-1 bits)   ← 主信号圧縮
  + QJL(k - k_recon, 1 bit) ← 残差誤差補正
  • 二つのモジュールを結合する方法
  • 内積推定公式完成
  • KV Cacheに実際に適用するPyTorch実装
  • Llamaスタイルモデルに付けて実験

このシリーズは論文を直接実装し理解することを目標としています。誤りやフィードバックはメールでお願いします。