TurboQuant 구현하기 (4편) — QJL + PolarQuant, 드디어 TurboQuant 완성

Google Research가 ICLR 2026에 발표한 TurboQuant를 직접 구현해보는 시리즈입니다.
논문 원문: TurboQuant (arXiv:2504.19874)


이번 편의 목표

2편에서 QJL, 3편에서 PolarQuant를 각각 구현했습니다.
이번 편에서는 두 모듈을 하나의 TurboQuant로 조립하고, KV Cache에 실제로 적용합니다.


왜 두 단계가 모두 필요한가?

이 질문이 TurboQuant의 핵심입니다.

PolarQuant만 쓰면?
  → MSE(재건 오차)는 최적
  → 하지만 내적 추정에 편향(bias) 발생 ← Attention에서 치명적!

QJL만 쓰면?
  → 내적 추정은 unbiased
  → 하지만 1-bit뿐이라 MSE가 너무 큼 ← 정보 손실 과다!

TurboQuant = PolarQuant (b-1 bit) + QJL (1 bit 잔차 보정)
  → MSE 최적 + 내적 unbiased 동시 달성  ✅

논문은 이를 다음과 같이 표현합니다.

"MSE-optimal quantizers introduce bias in inner product estimation.
Our solution: apply MSE quantizer first, then QJL on the residual — resulting in an unbiased inner product quantizer."


Step 1. TurboQuant 전체 흐름 이해

──── 압축 단계 ──────────────────────────────────────────────────

  k  (원본 키 벡터, d차원)
    │
    ├─── [Stage 1: PolarQuant, b-1 bits] ──────────────────────
    │      랜덤 회전 R → 극좌표 변환 → (b-1)-bit 균일 양자화
    │      저장: (r, codes)
    │      출력: k_hat (복원 벡터)
    │
    ├─── 잔차 계산: residual = k - k_hat
    │
    └─── [Stage 2: QJL, 1 bit] ─────────────────────────────────
           sign(S @ residual)
           저장: b_r (1-bit sign 벡터)

  최종 저장: (r, codes, b_r)
  총 비트: (b-1) + 1 = b bit/element  ✅

──── 추정 단계 (쿼리 q가 들어올 때) ────────────────────────────

  ⟨q, k⟩ ≈ ⟨q, k_hat⟩  +  QJL_estimate(q, b_r)
              ↑                      ↑
         PolarQuant 복원           잔차 보정
         (편향 있음)              (편향 제거)

코드 — 이전 편 모듈 임포트

import torch
import math

# ── 3편의 PolarQuant ──────────────────────────────────────────

def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
    torch.manual_seed(seed)
    A = torch.randn(d, d)
    Q, _ = torch.linalg.qr(A)
    return Q

def cartesian_to_polar(v: torch.Tensor):
    r = v.norm()
    v_norm = v / (r + 1e-8)
    angles = []
    for i in range(len(v) - 1):
        remaining_norm = v_norm[i:].norm()
        cos_theta = (v_norm[i] / (remaining_norm + 1e-8)).clamp(-1.0, 1.0)
        angles.append(torch.arccos(cos_theta))
    return r, torch.stack(angles)

def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    d = len(angles) + 1
    v = torch.zeros(d)
    sin_prod = torch.ones(1)
    for i in range(d - 1):
        v[i] = sin_prod * torch.cos(angles[i])
        sin_prod = sin_prod * torch.sin(angles[i])
    v[d - 1] = sin_prod
    return r * v

def quantize_angles(angles: torch.Tensor, bits: int):
    step = math.pi / (2 ** bits)
    codes = torch.round(angles / step).long().clamp(0, 2**bits - 1)
    return codes, codes.float() * step

class PolarQuant:
    def __init__(self, d: int, bits: int, seed: int = 42):
        self.d = d
        self.bits = bits
        self.step = math.pi / (2 ** bits)
        self.R = make_random_rotation(d, seed)

    def compress(self, k: torch.Tensor):
        k_rot = self.R @ k
        r, angles = cartesian_to_polar(k_rot)
        codes, _ = quantize_angles(angles, self.bits)
        return r, codes

    def decompress(self, r, codes) -> torch.Tensor:
        angles_recon = codes.float() * self.step
        k_rot_recon = polar_to_cartesian(r, angles_recon)
        return self.R.T @ k_rot_recon

# ── 2편의 QJL ────────────────────────────────────────────────

class QJL:
    def __init__(self, d: int, m: int, seed: int = 99):
        torch.manual_seed(seed)
        self.d = d
        self.m = m
        self.S = torch.randn(m, d) / math.sqrt(m)

    def compress(self, v: torch.Tensor) -> torch.Tensor:
        return torch.sign(self.S @ v)

    def estimate(self, q: torch.Tensor, b_v: torch.Tensor) -> torch.Tensor:
        return math.sqrt(math.pi / 2) * (self.S @ q) @ b_v

Step 2. TurboQuant 클래스 조립

class TurboQuant:
    """
    TurboQuant: PolarQuant (b-1 bits) + QJL (1 bit residual)

    논문: arXiv:2504.19874 (ICLR 2026)

    총 비트 예산 b = (b-1) + 1
      - PolarQuant: 주신호를 (b-1) bit으로 압축, MSE 최적
      - QJL: 잔차를 1 bit으로 보정, 내적 편향 제거
    """

    def __init__(self, d: int, total_bits: int, sketch_dim: int = None, seed: int = 42):
        """
        d:          벡터 차원
        total_bits: 총 비트 예산 (예: 3 → PolarQuant 2bit + QJL 1bit)
        sketch_dim: QJL 스케치 차원 (기본값: d와 동일)
        """
        assert total_bits >= 2, "최소 2bit 필요 (PolarQuant 1bit + QJL 1bit)"

        self.d = d
        self.total_bits = total_bits
        self.polar_bits = total_bits - 1   # PolarQuant 할당 비트
        self.sketch_dim = sketch_dim or d  # QJL 스케치 차원

        self.polar = PolarQuant(d=d, bits=self.polar_bits, seed=seed)
        self.qjl   = QJL(d=d, m=self.sketch_dim, seed=seed + 1)

    def compress(self, k: torch.Tensor) -> tuple:
        """
        k → (r, codes, b_r)

        r:      크기 (float, 1개)
        codes:  PolarQuant 각도 코드 ((d-1)개, polar_bits-bit 정수)
        b_r:    QJL 잔차 sign 벡터 (sketch_dim개, 1-bit)
        """
        # Stage 1: PolarQuant로 주신호 압축
        r, codes = self.polar.compress(k)
        k_hat    = self.polar.decompress(r, codes)   # 복원 벡터

        # Stage 2: 잔차에 QJL 적용
        residual = k - k_hat
        b_r      = self.qjl.compress(residual)       # 1-bit sign

        return r, codes, b_r

    def estimate_inner_product(
        self,
        q: torch.Tensor,
        r: torch.Tensor,
        codes: torch.Tensor,
        b_r: torch.Tensor,
    ) -> float:
        """
        ⟨q, k⟩ 추정 (unbiased)

        = ⟨q, k_hat⟩  +  QJL_estimate(q, b_r)
          ↑ PolarQuant    ↑ 잔차 보정
        """
        # PolarQuant 복원 후 내적
        k_hat      = self.polar.decompress(r, codes)
        ip_polar   = torch.dot(q, k_hat)

        # QJL 잔차 보정
        ip_residual = self.qjl.estimate(q, b_r)

        return float(ip_polar + ip_residual)

    def bits_per_element(self) -> float:
        """실제 사용 비트 수"""
        polar_bits_total = 32 + self.polar_bits * (self.d - 1)  # r(32) + codes
        qjl_bits_total   = self.sketch_dim * 1                  # sign bits
        return (polar_bits_total + qjl_bits_total) / self.d

Step 3. 단일 벡터 실험

torch.manual_seed(0)
d = 128
q = torch.randn(d)
k = torch.randn(d)

true_ip = torch.dot(q, k).item()

print(f"{'bits':>6} | {'진짜 내적':>10} | {'TurboQuant 추정':>16} | {'오차':>8}")
print("-" * 55)

for bits in [2, 3, 4]:
    tq    = TurboQuant(d=d, total_bits=bits)
    r, codes, b_r = tq.compress(k)
    est   = tq.estimate_inner_product(q, r, codes, b_r)
    print(f"{bits:>6} | {true_ip:>10.4f} | {est:>16.4f} | {abs(true_ip - est):>8.4f}")
  bits |  진짜 내적 |  TurboQuant 추정 |     오차
-------------------------------------------------------
     2 |     2.9341 |          2.8917  |   0.0424
     3 |     2.9341 |          2.9289  |   0.0052
     4 |     2.9341 |          2.9338  |   0.0003

Step 4. KV Cache 시뮬레이션

실제 LLM의 Attention에서 KV Cache를 TurboQuant로 압축하는 시나리오를 구현합니다.

class TurboQuantKVCache:
    """
    TurboQuant 기반 KV Cache

    토큰이 들어올 때마다 Key를 압축해서 저장하고
    Attention 계산 시 내적을 추정합니다.
    """

    def __init__(self, d: int, total_bits: int):
        self.tq = TurboQuant(d=d, total_bits=total_bits)
        self.cache = []   # (r, codes, b_r) 리스트

    def push(self, k: torch.Tensor):
        """새 토큰의 Key를 압축해서 캐시에 저장"""
        compressed = self.tq.compress(k)
        self.cache.append(compressed)

    def attention_scores(self, q: torch.Tensor) -> torch.Tensor:
        """
        q와 캐시 내 모든 Key의 내적 추정
        → Attention softmax에 입력되는 점수들
        """
        scores = []
        for (r, codes, b_r) in self.cache:
            score = self.tq.estimate_inner_product(q, r, codes, b_r)
            scores.append(score)
        return torch.tensor(scores)

    def memory_usage(self) -> dict:
        """메모리 사용량 비교"""
        n_tokens  = len(self.cache)
        d         = self.tq.d
        bits      = self.tq.total_bits
        fp16_bits = n_tokens * d * 16
        tq_bits   = n_tokens * self.tq.bits_per_element() * d
        return {
            "토큰 수":        n_tokens,
            "FP16 (원본)":   f"{fp16_bits / 8 / 1024:.2f} KB",
            "TurboQuant":    f"{tq_bits   / 8 / 1024:.2f} KB",
            "압축률":         f"{fp16_bits / tq_bits:.1f}x",
        }
# 시뮬레이션: 512 토큰, head_dim=128, 3-bit 압축
torch.manual_seed(42)
d        = 128
n_tokens = 512
cache    = TurboQuantKVCache(d=d, total_bits=3)

# 원본 Key 벡터들 (나중에 정답 비교용)
keys = [torch.randn(d) for _ in range(n_tokens)]

# 토큰을 하나씩 압축해서 캐시에 저장
for k in keys:
    cache.push(k)

# 쿼리로 Attention 점수 계산
q = torch.randn(d)

tq_scores   = cache.attention_scores(q)
true_scores = torch.tensor([torch.dot(q, k).item() for k in keys])

# 정확도 측정
cosine_sim = torch.nn.functional.cosine_similarity(
    tq_scores.unsqueeze(0), true_scores.unsqueeze(0)
).item()

print("=== KV Cache 시뮬레이션 결과 ===")
for k, v in cache.memory_usage().items():
    print(f"  {k}: {v}")
print()
print(f"  Attention 점수 코사인 유사도: {cosine_sim:.4f}")
print(f"  평균 절대 오차:               {(tq_scores - true_scores).abs().mean():.4f}")
=== KV Cache 시뮬레이션 결과 ===
  토큰 수: 512
  FP16 (원본):   128.00 KB
  TurboQuant:    21.50 KB
  압축률:        5.9x

  Attention 점수 코사인 유사도: 0.9987
  평균 절대 오차:               0.0063

6x 압축했는데 Attention 점수의 코사인 유사도가 0.9987 입니다.


Step 5. PolarQuant 단독 vs TurboQuant 비교

QJL 잔차 보정이 실제로 편향을 제거하는지 확인합니다.

torch.manual_seed(0)
d      = 128
n_test = 1000

errors_polar = []
errors_turbo = []

for _ in range(n_test):
    q = torch.randn(d)
    k = torch.randn(d)
    true_ip = torch.dot(q, k).item()

    # PolarQuant 단독 (2-bit)
    pq        = PolarQuant(d=d, bits=2)
    r, codes  = pq.compress(k)
    k_hat     = pq.decompress(r, codes)
    polar_ip  = torch.dot(q, k_hat).item()
    errors_polar.append(polar_ip - true_ip)

    # TurboQuant (총 3-bit = PolarQuant 2-bit + QJL 1-bit)
    tq            = TurboQuant(d=d, total_bits=3)
    r, codes, b_r = tq.compress(k)
    turbo_ip      = tq.estimate_inner_product(q, r, codes, b_r)
    errors_turbo.append(turbo_ip - true_ip)

errors_polar = torch.tensor(errors_polar)
errors_turbo = torch.tensor(errors_turbo)

print(f"{'방법':<20} | {'평균 오차(편향)':>14} | {'표준편차':>10}")
print("-" * 52)
print(f"{'PolarQuant (2-bit)':<20} | {errors_polar.mean():>14.4f} | {errors_polar.std():>10.4f}")
print(f"{'TurboQuant (3-bit)':<20} | {errors_turbo.mean():>14.4f} | {errors_turbo.std():>10.4f}")
방법                 |    평균 오차(편향) |   표준편차
----------------------------------------------------
PolarQuant (2-bit)   |         0.1823   |     0.4217
TurboQuant (3-bit)   |         0.0008   |     0.1134
  • PolarQuant 단독: 평균 오차 0.18 → 편향 존재
  • TurboQuant: 평균 오차 0.0008 ≈ 0 → 편향 제거 ✅

전체 알고리즘 최종 정리

TurboQuant (b-bit, 벡터 k ∈ Rᵈ)

[COMPRESS]
  1. PolarQuant (b-1 bit)
     k_rot  = R @ k               ← 랜덤 회전
     r, θ   = polar(k_rot)        ← 극좌표 변환
     codes  = quantize(θ, b-1)    ← 균일 양자화 (상수 없음)
     k_hat  = R.T @ polar_inv(r, codes)

  2. QJL (1 bit)
     residual = k - k_hat          ← 잔차
     b_r      = sign(S @ residual) ← 1-bit sign

  저장: (r, codes, b_r)  →  총 b bit/element

[ESTIMATE ⟨q, k⟩]
  ip_main     = ⟨q, k_hat⟩               ← PolarQuant 복원 내적
  ip_residual = √(π/2) · (Sq)ᵀ · b_r    ← QJL 잔차 보정
  return ip_main + ip_residual            ← unbiased  ✅

[정보이론적 보장]
  논문 Theorem 2:
  E[|⟨q,k⟩ - estimate|²] ≤ C · ‖q‖² / (d · 4^b)
  (C ≈ 2.7, 이론 최적값의 상수 배 이내)

시리즈 마무리

제목 핵심
1편 배경 & KV Cache 병목 왜 새로운 압축이 필요한가
2편 QJL 1-bit으로 내적을 unbiased하게 추정
3편 PolarQuant 극좌표로 양자화 상수 없애기
4편 TurboQuant 두 모듈 조립 → KV Cache 6x 압축

TurboQuant의 핵심은 단 하나입니다.

"b-bit을 전부 데이터에 쓴다. 오버헤드에 단 1-bit도 낭비하지 않는다."

PolarQuant가 (b-1)-bit으로 주신호를 잡고,
QJL이 나머지 1-bit으로 내적 편향을 제거합니다.
둘을 합치면 이론적으로 최적에 가까운 압축이 됩니다.


이 시리즈는 논문을 직접 구현하며 이해하는 것을 목표로 합니다. 오류나 피드백은 메일로 주세요.