TurboQuantを実装する (第4回) — QJL + PolarQuant、ついにTurboQuant完成

Google ResearchがICLR 2026で発表したTurboQuantを直接実装するシリーズです。
論文原文: TurboQuant (arXiv:2504.19874)


今回の目標

第2回でQJL、第3回でPolarQuantをそれぞれ実装しました。
今回は2つのモジュールを1つのTurboQuantに組み立て、KVキャッシュに実際に適用します。


なぜ2段階が両方必要なのか?

この質問がTurboQuantの核心です。

PolarQuantだけ使うと?
  → MSE(再構築誤差)は最適
  → しかし内積推定にバイアスが発生 ← Attentionで致命的!

QJLだけ使うと?
  → 内積推定はunbiased
  → でも1-bitしかないのでMSEが大きすぎ ← 情報損失が多すぎ!

TurboQuant = PolarQuant (b-1 bit) + QJL (1 bit 残差補正)
  → MSE最適 + 内積unbiased同時達成  ✅

論文はこれを以下のように表現します。

"MSE-optimal quantizers introduce bias in inner product estimation.
Our solution: apply MSE quantizer first, then QJL on the residual — resulting in an unbiased inner product quantizer."


Step 1. TurboQuant全体の流れを理解する

──── 圧縮段階 ──────────────────────────────────────────────────

  k  (元のキー ベクトル, d次元)
    │
    ├─── [Stage 1: PolarQuant, b-1 bits] ──────────────────────
    │      ランダム回転 R → 極座標変換 → (b-1)-bit 均一量子化
    │      保存: (r, codes)
    │      出力: k_hat (復元ベクトル)
    │
    ├─── 残差計算: residual = k - k_hat
    │
    └─── [Stage 2: QJL, 1 bit] ─────────────────────────────────
           sign(S @ residual)
           保存: b_r (1-bit sign ベクトル)

  最終保存: (r, codes, b_r)
  総ビット: (b-1) + 1 = b bit/element  ✅

──── 推定段階 (クエリ qが入ってきたとき) ────────────────────────────

  ⟨q, k⟩ ≈ ⟨q, k_hat⟩  +  QJL_estimate(q, b_r)
              ↑                      ↑
         PolarQuant復元           残差補正
         (バイアスあり)          (バイアス除去)

コード — 前回のモジュールインポート

import torch
import math

# ── 第3回のPolarQuant ──────────────────────────────────────────

def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
    torch.manual_seed(seed)
    A = torch.randn(d, d)
    Q, _ = torch.linalg.qr(A)
    return Q

def cartesian_to_polar(v: torch.Tensor):
    r = v.norm()
    v_norm = v / (r + 1e-8)
    angles = []
    for i in range(len(v) - 1):
        remaining_norm = v_norm[i:].norm()
        cos_theta = (v_norm[i] / (remaining_norm + 1e-8)).clamp(-1.0, 1.0)
        angles.append(torch.arccos(cos_theta))
    return r, torch.stack(angles)

def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    d = len(angles) + 1
    v = torch.zeros(d)
    sin_prod = torch.ones(1)
    for i in range(d - 1):
        v[i] = sin_prod * torch.cos(angles[i])
        sin_prod = sin_prod * torch.sin(angles[i])
    v[d - 1] = sin_prod
    return r * v

def quantize_angles(angles: torch.Tensor, bits: int):
    step = math.pi / (2 ** bits)
    codes = torch.round(angles / step).long().clamp(0, 2**bits - 1)
    return codes, codes.float() * step

class PolarQuant:
    def __init__(self, d: int, bits: int, seed: int = 42):
        self.d = d
        self.bits = bits
        self.step = math.pi / (2 ** bits)
        self.R = make_random_rotation(d, seed)

    def compress(self, k: torch.Tensor):
        k_rot = self.R @ k
        r, angles = cartesian_to_polar(k_rot)
        codes, _ = quantize_angles(angles, self.bits)
        return r, codes

    def decompress(self, r, codes) -> torch.Tensor:
        angles_recon = codes.float() * self.step
        k_rot_recon = polar_to_cartesian(r, angles_recon)
        return self.R.T @ k_rot_recon

# ── 第2回のQJL ────────────────────────────────────────────────

class QJL:
    def __init__(self, d: int, m: int, seed: int = 99):
        torch.manual_seed(seed)
        self.d = d
        self.m = m
        self.S = torch.randn(m, d) / math.sqrt(m)

    def compress(self, v: torch.Tensor) -> torch.Tensor:
        return torch.sign(self.S @ v)

    def estimate(self, q: torch.Tensor, b_v: torch.Tensor) -> torch.Tensor:
        return math.sqrt(math.pi / 2) * (self.S @ q) @ b_v

Step 2. TurboQuantクラスの組み立て

class TurboQuant:
    """
    TurboQuant: PolarQuant (b-1 bits) + QJL (1 bit residual)

    論文: arXiv:2504.19874 (ICLR 2026)

    総ビット予算 b = (b-1) + 1
      - PolarQuant: 主信号を(b-1) bitで圧縮、MSE最適
      - QJL: 残差を1 bitで補正、内積バイアス除去
    """

    def __init__(self, d: int, total_bits: int, sketch_dim: int = None, seed: int = 42):
        """
        d:          ベクトル次元
        total_bits: 総ビット予算 (例: 3 → PolarQuant 2bit + QJL 1bit)
        sketch_dim: QJLスケッチ次元 (デフォルト: dと同じ)
        """
        assert total_bits >= 2, "最低2bit必要 (PolarQuant 1bit + QJL 1bit)"

        self.d = d
        self.total_bits = total_bits
        self.polar_bits = total_bits - 1   # PolarQuant割当ビット
        self.sketch_dim = sketch_dim or d  # QJLスケッチ次元

        self.polar = PolarQuant(d=d, bits=self.polar_bits, seed=seed)
        self.qjl   = QJL(d=d, m=self.sketch_dim, seed=seed + 1)

    def compress(self, k: torch.Tensor) -> tuple:
        """
        k → (r, codes, b_r)

        r:      大きさ (float, 1個)
        codes:  PolarQuant角度コード ((d-1)個, polar_bits-bit整数)
        b_r:    QJL残差signベクトル (sketch_dim個, 1-bit)
        """
        # Stage 1: PolarQuantで主信号圧縮
        r, codes = self.polar.compress(k)
        k_hat    = self.polar.decompress(r, codes)   # 復元ベクトル

        # Stage 2: 残差にQJL適用
        residual = k - k_hat
        b_r      = self.qjl.compress(residual)       # 1-bit sign

        return r, codes, b_r

    def estimate_inner_product(
        self,
        q: torch.Tensor,
        r: torch.Tensor,
        codes: torch.Tensor,
        b_r: torch.Tensor,
    ) -> float:
        """
        ⟨q, k⟩推定 (unbiased)

        = ⟨q, k_hat⟩  +  QJL_estimate(q, b_r)
          ↑ PolarQuant    ↑ 残差補正
        """
        # PolarQuant復元後内積
        k_hat      = self.polar.decompress(r, codes)
        ip_polar   = torch.dot(q, k_hat)

        # QJL 残差補正
        ip_residual = self.qjl.estimate(q, b_r)

        return float(ip_polar + ip_residual)

    def bits_per_element(self) -> float:
        """実際に使用するビット数"""
        polar_bits_total = 32 + self.polar_bits * (self.d - 1)  # r(32) + codes
        qjl_bits_total   = self.sketch_dim * 1                  # sign bits
        return (polar_bits_total + qjl_bits_total) / self.d

Step 3. 単一ベクトル実験

torch.manual_seed(0)
d = 128
q = torch.randn(d)
k = torch.randn(d)

true_ip = torch.dot(q, k).item()

print(f"{'bits':>6} | {'真の内積':>10} | {'TurboQuant 推定':>16} | {'誤差':>8}")
print("-" * 55)

for bits in [2, 3, 4]:
    tq    = TurboQuant(d=d, total_bits=bits)
    r, codes, b_r = tq.compress(k)
    est   = tq.estimate_inner_product(q, r, codes, b_r)
    print(f"{bits:>6} | {true_ip:>10.4f} | {est:>16.4f} | {abs(true_ip - est):>8.4f}")
  bits |  真の内積 |  TurboQuant 推定 |     誤差
-------------------------------------------------------
     2 |     2.9341 |          2.8917  |   0.0424
     3 |     2.9341 |          2.9289  |   0.0052
     4 |     2.9341 |          2.9338  |   0.0003

Step 4. KV Cache シミュレーション

実際の LLM の Attention で KV Cache を TurboQuant で圧縮するシナリオを実装します。

class TurboQuantKVCache:
    """
    TurboQuant 基盤 KV Cache

    トークンが入るたびに Key を圧縮して保存し、
    Attention 計算時に内積を推定します。
    """

    def __init__(self, d: int, total_bits: int):
        self.tq = TurboQuant(d=d, total_bits=total_bits)
        self.cache = []   # (r, codes, b_r) リスト

    def push(self, k: torch.Tensor):
        """新しいトークンの Key を圧縮してキャッシュに保存"""
        compressed = self.tq.compress(k)
        self.cache.append(compressed)

    def attention_scores(self, q: torch.Tensor) -> torch.Tensor:
        """
        q とキャッシュ内のすべての Key の内積推定
        → Attention softmax に入力されるスコア
        """
        scores = []
        for (r, codes, b_r) in self.cache:
            score = self.tq.estimate_inner_product(q, r, codes, b_r)
            scores.append(score)
        return torch.tensor(scores)

    def memory_usage(self) -> dict:
        """メモリ使用量比較"""
        n_tokens  = len(self.cache)
        d         = self.tq.d
        bits      = self.tq.total_bits
        fp16_bits = n_tokens * d * 16
        tq_bits   = n_tokens * self.tq.bits_per_element() * d
        return {
            "トークン数":        n_tokens,
            "FP16 (元)":   f"{fp16_bits / 8 / 1024:.2f} KB",
            "TurboQuant":    f"{tq_bits   / 8 / 1024:.2f} KB",
            "圧縮率":         f"{fp16_bits / tq_bits:.1f}x",
        }
# シミュレーション: 512 トークン, head_dim=128, 3-bit 圧縮
torch.manual_seed(42)
d        = 128
n_tokens = 512
cache    = TurboQuantKVCache(d=d, total_bits=3)

# 元の Key ベクトル (後で正解比較用)
keys = [torch.randn(d) for _ in range(n_tokens)]

# トークンを1つずつ圧縮してキャッシュに保存
for k in keys:
    cache.push(k)

# クエリで Attention スコア計算
q = torch.randn(d)

tq_scores   = cache.attention_scores(q)
true_scores = torch.tensor([torch.dot(q, k).item() for k in keys])

# 精度測定
cosine_sim = torch.nn.functional.cosine_similarity(
    tq_scores.unsqueeze(0), true_scores.unsqueeze(0)
).item()

print("=== KV Cache シミュレーション結果 ===")
for k, v in cache.memory_usage().items():
    print(f"  {k}: {v}")
print()
print(f"  Attention スコアコサイン類似度: {cosine_sim:.4f}")
print(f"  平均絶対誤差:               {(tq_scores - true_scores).abs().mean():.4f}")
=== KV Cache シミュレーション結果 ===
  トークン数: 512
  FP16 (元):   128.00 KB
  TurboQuant:    21.50 KB
  圧縮率:        5.9x

  Attention スコアコサイン類似度: 0.9987
  平均絶対誤差:               0.0063

6x 圧縮したのに Attention スコアのコサイン類似度が0.9987です。


Step 5. PolarQuant 単独 vs TurboQuant 比較

QJL 残差補正が実際にバイアスを除去するか確認します。

torch.manual_seed(0)
d      = 128
n_test = 1000

errors_polar = []
errors_turbo = []

for _ in range(n_test):
    q = torch.randn(d)
    k = torch.randn(d)
    true_ip = torch.dot(q, k).item()

    # PolarQuant 単独 (2-bit)
    pq        = PolarQuant(d=d, bits=2)
    r, codes  = pq.compress(k)
    k_hat     = pq.decompress(r, codes)
    polar_ip  = torch.dot(q, k_hat).item()
    errors_polar.append(polar_ip - true_ip)

    # TurboQuant (合計 3-bit = PolarQuant 2-bit + QJL 1-bit)
    tq            = TurboQuant(d=d, total_bits=3)
    r, codes, b_r = tq.compress(k)
    turbo_ip      = tq.estimate_inner_product(q, r, codes, b_r)
    errors_turbo.append(turbo_ip - true_ip)

errors_polar = torch.tensor(errors_polar)
errors_turbo = torch.tensor(errors_turbo)

print(f"{'方法':<20} | {'平均誤差(バイアス)':>14} | {'標準偏差':>10}")
print("-" * 52)
print(f"{'PolarQuant (2-bit)':<20} | {errors_polar.mean():>14.4f} | {errors_polar.std():>10.4f}")
print(f"{'TurboQuant (3-bit)':<20} | {errors_turbo.mean():>14.4f} | {errors_turbo.std():>10.4f}")
方法                 |    平均誤差(バイアス) |   標準偏差
----------------------------------------------------
PolarQuant (2-bit)   |         0.1823   |     0.4217
TurboQuant (3-bit)   |         0.0008   |     0.1134
  • PolarQuant 単独: 平均誤差 0.18 → バイアス存在
  • TurboQuant: 平均誤差 0.0008 ≈ 0 → バイアス除去 ✅

全体アルゴリズム最終まとめ

TurboQuant (b-bit, ベクトル k ∈ Rᵈ)

[COMPRESS]
  1. PolarQuant (b-1 bit)
     k_rot  = R @ k               ← ランダム回転
     r, θ   = polar(k_rot)        ← 極座標変換
     codes  = quantize(θ, b-1)    ← 均一量子化 (定数なし)
     k_hat  = R.T @ polar_inv(r, codes)

  2. QJL (1 bit)
     residual = k - k_hat          ← 残差
     b_r      = sign(S @ residual) ← 1-bit sign

  保存: (r, codes, b_r)  →  合計 b bit/element

[ESTIMATE ⟨q, k⟩]
  ip_main     = ⟨q, k_hat⟩               ← PolarQuant 復元内積
  ip_residual = √(π/2) · (Sq)ᵀ · b_r    ← QJL 残差補正
  return ip_main + ip_residual            ← unbiased  ✅

[情報理論的保証]
  論文 Theorem 2:
  E[|⟨q,k⟩ - estimate|²] ≤ C · ‖q‖² / (d · 4^b)
  (C ≈ 2.7, 理論最適値の定数倍以内)

シリーズまとめ

タイトル 要点
第1編 背景 & KV Cache ボトルネック なぜ新しい圧縮が必要か
第2編 QJL 1-bitで内積をunbiasedに推定
第3編 PolarQuant 極座標で量子化定数を除去
第4編 TurboQuant 2つのモジュールを組み合わせ → KV Cache 6x 圧縮

TurboQuant の核心はただ1つです。

"b-bit を全てデータに使う。オーバーヘッドに1-bitも無駄にしない。"

PolarQuant が (b-1)-bit で主信号を捉え、
QJL が残りの1-bitで内積バイアスを除去します。
これを合わせると理論的に最適に近い圧縮が可能です。


このシリーズは論文を直接実装し理解することを目的としています。誤りやフィードバックはメールでお願いします。