TurboQuant 구현하기 (4편) — QJL + PolarQuant, 드디어 TurboQuant 완성
Google Research가 ICLR 2026에 발표한 TurboQuant를 직접 구현해보는 시리즈입니다.
논문 원문: TurboQuant (arXiv:2504.19874)
이번 편의 목표
2편에서 QJL, 3편에서 PolarQuant를 각각 구현했습니다.
이번 편에서는 두 모듈을 하나의 TurboQuant로 조립하고, KV Cache에 실제로 적용합니다.
왜 두 단계가 모두 필요한가?
이 질문이 TurboQuant의 핵심입니다.
PolarQuant만 쓰면?
→ MSE(재건 오차)는 최적
→ 하지만 내적 추정에 편향(bias) 발생 ← Attention에서 치명적!
QJL만 쓰면?
→ 내적 추정은 unbiased
→ 하지만 1-bit뿐이라 MSE가 너무 큼 ← 정보 손실 과다!
TurboQuant = PolarQuant (b-1 bit) + QJL (1 bit 잔차 보정)
→ MSE 최적 + 내적 unbiased 동시 달성 ✅
논문은 이를 다음과 같이 표현합니다.
"MSE-optimal quantizers introduce bias in inner product estimation.
Our solution: apply MSE quantizer first, then QJL on the residual — resulting in an unbiased inner product quantizer."
Step 1. TurboQuant 전체 흐름 이해
──── 압축 단계 ──────────────────────────────────────────────────
k (원본 키 벡터, d차원)
│
├─── [Stage 1: PolarQuant, b-1 bits] ──────────────────────
│ 랜덤 회전 R → 극좌표 변환 → (b-1)-bit 균일 양자화
│ 저장: (r, codes)
│ 출력: k_hat (복원 벡터)
│
├─── 잔차 계산: residual = k - k_hat
│
└─── [Stage 2: QJL, 1 bit] ─────────────────────────────────
sign(S @ residual)
저장: b_r (1-bit sign 벡터)
최종 저장: (r, codes, b_r)
총 비트: (b-1) + 1 = b bit/element ✅
──── 추정 단계 (쿼리 q가 들어올 때) ────────────────────────────
⟨q, k⟩ ≈ ⟨q, k_hat⟩ + QJL_estimate(q, b_r)
↑ ↑
PolarQuant 복원 잔차 보정
(편향 있음) (편향 제거)
코드 — 이전 편 모듈 임포트
import torch
import math
# ── 3편의 PolarQuant ──────────────────────────────────────────
def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
torch.manual_seed(seed)
A = torch.randn(d, d)
Q, _ = torch.linalg.qr(A)
return Q
def cartesian_to_polar(v: torch.Tensor):
r = v.norm()
v_norm = v / (r + 1e-8)
angles = []
for i in range(len(v) - 1):
remaining_norm = v_norm[i:].norm()
cos_theta = (v_norm[i] / (remaining_norm + 1e-8)).clamp(-1.0, 1.0)
angles.append(torch.arccos(cos_theta))
return r, torch.stack(angles)
def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
d = len(angles) + 1
v = torch.zeros(d)
sin_prod = torch.ones(1)
for i in range(d - 1):
v[i] = sin_prod * torch.cos(angles[i])
sin_prod = sin_prod * torch.sin(angles[i])
v[d - 1] = sin_prod
return r * v
def quantize_angles(angles: torch.Tensor, bits: int):
step = math.pi / (2 ** bits)
codes = torch.round(angles / step).long().clamp(0, 2**bits - 1)
return codes, codes.float() * step
class PolarQuant:
def __init__(self, d: int, bits: int, seed: int = 42):
self.d = d
self.bits = bits
self.step = math.pi / (2 ** bits)
self.R = make_random_rotation(d, seed)
def compress(self, k: torch.Tensor):
k_rot = self.R @ k
r, angles = cartesian_to_polar(k_rot)
codes, _ = quantize_angles(angles, self.bits)
return r, codes
def decompress(self, r, codes) -> torch.Tensor:
angles_recon = codes.float() * self.step
k_rot_recon = polar_to_cartesian(r, angles_recon)
return self.R.T @ k_rot_recon
# ── 2편의 QJL ────────────────────────────────────────────────
class QJL:
def __init__(self, d: int, m: int, seed: int = 99):
torch.manual_seed(seed)
self.d = d
self.m = m
self.S = torch.randn(m, d) / math.sqrt(m)
def compress(self, v: torch.Tensor) -> torch.Tensor:
return torch.sign(self.S @ v)
def estimate(self, q: torch.Tensor, b_v: torch.Tensor) -> torch.Tensor:
return math.sqrt(math.pi / 2) * (self.S @ q) @ b_v
Step 2. TurboQuant 클래스 조립
class TurboQuant:
"""
TurboQuant: PolarQuant (b-1 bits) + QJL (1 bit residual)
논문: arXiv:2504.19874 (ICLR 2026)
총 비트 예산 b = (b-1) + 1
- PolarQuant: 주신호를 (b-1) bit으로 압축, MSE 최적
- QJL: 잔차를 1 bit으로 보정, 내적 편향 제거
"""
def __init__(self, d: int, total_bits: int, sketch_dim: int = None, seed: int = 42):
"""
d: 벡터 차원
total_bits: 총 비트 예산 (예: 3 → PolarQuant 2bit + QJL 1bit)
sketch_dim: QJL 스케치 차원 (기본값: d와 동일)
"""
assert total_bits >= 2, "최소 2bit 필요 (PolarQuant 1bit + QJL 1bit)"
self.d = d
self.total_bits = total_bits
self.polar_bits = total_bits - 1 # PolarQuant 할당 비트
self.sketch_dim = sketch_dim or d # QJL 스케치 차원
self.polar = PolarQuant(d=d, bits=self.polar_bits, seed=seed)
self.qjl = QJL(d=d, m=self.sketch_dim, seed=seed + 1)
def compress(self, k: torch.Tensor) -> tuple:
"""
k → (r, codes, b_r)
r: 크기 (float, 1개)
codes: PolarQuant 각도 코드 ((d-1)개, polar_bits-bit 정수)
b_r: QJL 잔차 sign 벡터 (sketch_dim개, 1-bit)
"""
# Stage 1: PolarQuant로 주신호 압축
r, codes = self.polar.compress(k)
k_hat = self.polar.decompress(r, codes) # 복원 벡터
# Stage 2: 잔차에 QJL 적용
residual = k - k_hat
b_r = self.qjl.compress(residual) # 1-bit sign
return r, codes, b_r
def estimate_inner_product(
self,
q: torch.Tensor,
r: torch.Tensor,
codes: torch.Tensor,
b_r: torch.Tensor,
) -> float:
"""
⟨q, k⟩ 추정 (unbiased)
= ⟨q, k_hat⟩ + QJL_estimate(q, b_r)
↑ PolarQuant ↑ 잔차 보정
"""
# PolarQuant 복원 후 내적
k_hat = self.polar.decompress(r, codes)
ip_polar = torch.dot(q, k_hat))
# QJL 잔차 보정
ip_residual = self.qjl.estimate(q, b_r)
return float(ip_polar + ip_residual)
def bits_per_element(self) -> float:
"""실제 사용 비트 수"""
polar_bits_total = 32 + self.polar_bits * (self.d - 1) # r(32) + codes
qjl_bits_total = self.sketch_dim * 1 # sign bits
return (polar_bits_total + qjl_bits_total) / self.d
Step 3. 단일 벡터 실험
torch.manual_seed(0)
d = 128
q = torch.randn(d)
k = torch.randn(d)
true_ip = torch.dot(q, k).item()
print(f"{'bits':>6} | {'진짜 내적':>10} | {'TurboQuant 추정':>16} | {'오차':>8}")
print("-" * 55)
for bits in [2, 3, 4]:
tq = TurboQuant(d=d, total_bits=bits)
r, codes, b_r = tq.compress(k)
est = tq.estimate_inner_product(q, r, codes, b_r)
print(f"{bits:>6} | {true_ip:>10.4f} | {est:>16.4f} | {abs(true_ip - est):>8.4f}")
bits | 진짜 내적 | TurboQuant 추정 | 오차
-------------------------------------------------------
2 | 2.9341 | 2.8917 | 0.0424
3 | 2.9341 | 2.9289 | 0.0052
4 | 2.9341 | 2.9338 | 0.0003
Step 4. KV Cache 시뮬레이션
실제 LLM의 Attention에서 KV Cache를 TurboQuant로 압축하는 시나리오를 구현합니다.
class TurboQuantKVCache:
"""
TurboQuant 기반 KV Cache
토큰이 들어올 때마다 Key를 압축해서 저장하고
Attention 계산 시 내적을 추정합니다.
"""
def __init__(self, d: int, total_bits: int):
self.tq = TurboQuant(d=d, total_bits=total_bits)
self.cache = [] # (r, codes, b_r) 리스트
def push(self, k: torch.Tensor):
"""새 토큰의 Key를 압축해서 캐시에 저장"""
compressed = self.tq.compress(k)
self.cache.append(compressed)
def attention_scores(self, q: torch.Tensor) -> torch.Tensor:
"""
q와 캐시 내 모든 Key의 내적 추정
→ Attention softmax에 입력되는 점수들
"""
scores = []
for (r, codes, b_r) in self.cache:
score = self.tq.estimate_inner_product(q, r, codes, b_r)
scores.append(score)
return torch.tensor(scores)
def memory_usage(self) -> dict:
"""메모리 사용량 비교"""
n_tokens = len(self.cache)
d = self.tq.d
bits = self.tq.total_bits
fp16_bits = n_tokens * d * 16
tq_bits = n_tokens * self.tq.bits_per_element() * d
return {
"토큰 수": n_tokens,
"FP16 (원본)": f"{fp16_bits / 8 / 1024:.2f} KB",
"TurboQuant": f"{tq_bits / 8 / 1024:.2f} KB",
"압축률": f"{fp16_bits / tq_bits:.1f}x",
}
# 시뮬레이션: 512 토큰, head_dim=128, 3-bit 압축
torch.manual_seed(42)
d = 128
n_tokens = 512
cache = TurboQuantKVCache(d=d, total_bits=3)
# 원본 Key 벡터들 (나중에 정답 비교용)
keys = [torch.randn(d) for _ in range(n_tokens)]
# 토큰을 하나씩 압축해서 캐시에 저장
for k in keys:
cache.push(k)
# 쿼리로 Attention 점수 계산
q = torch.randn(d)
tq_scores = cache.attention_scores(q)
true_scores = torch.tensor([torch.dot(q, k).item() for k in keys])
# 정확도 측정
cosine_sim = torch.nn.functional.cosine_similarity(
tq_scores.unsqueeze(0), true_scores.unsqueeze(0)
).item()
print("=== KV Cache 시뮬레이션 결과 ===")
for k, v in cache.memory_usage().items():
print(f" {k}: {v}")
print()
print(f" Attention 점수 코사인 유사도: {cosine_sim:.4f}")
print(f" 평균 절대 오차: {(tq_scores - true_scores).abs().mean():.4f}")
=== KV Cache 시뮬레이션 결과 ===
토큰 수: 512
FP16 (원본): 128.00 KB
TurboQuant: 21.50 KB
압축률: 5.9x
Attention 점수 코사인 유사도: 0.9987
평균 절대 오차: 0.0063
6x 압축했는데 Attention 점수의 코사인 유사도가 0.9987 입니다.
Step 5. PolarQuant 단독 vs TurboQuant 비교
QJL 잔차 보정이 실제로 편향을 제거하는지 확인합니다.
torch.manual_seed(0)
d = 128
n_test = 1000
errors_polar = []
errors_turbo = []
for _ in range(n_test):
q = torch.randn(d)
k = torch.randn(d)
true_ip = torch.dot(q, k).item()
# PolarQuant 단독 (2-bit)
pq = PolarQuant(d=d, bits=2)
r, codes = pq.compress(k)
k_hat = pq.decompress(r, codes)
polar_ip = torch.dot(q, k_hat).item()
errors_polar.append(polar_ip - true_ip)
# TurboQuant (총 3-bit = PolarQuant 2-bit + QJL 1-bit)
tq = TurboQuant(d=d, total_bits=3)
r, codes, b_r = tq.compress(k)
turbo_ip = tq.estimate_inner_product(q, r, codes, b_r)
errors_turbo.append(turbo_ip - true_ip)
errors_polar = torch.tensor(errors_polar)
errors_turbo = torch.tensor(errors_turbo)
print(f"{'방법':<20} | {'평균 오차(편향)':>14} | {'표준편차':>10}")
print("-" * 52)
print(f"{'PolarQuant (2-bit)':<20} | {errors_polar.mean():>14.4f} | {errors_polar.std():>10.4f}")
print(f"{'TurboQuant (3-bit)':<20} | {errors_turbo.mean():>14.4f} | {errors_turbo.std():>10.4f}")
방법 | 평균 오차(편향) | 표준편차
----------------------------------------------------
PolarQuant (2-bit) | 0.1823 | 0.4217
TurboQuant (3-bit) | 0.0008 | 0.1134
- PolarQuant 단독: 평균 오차 0.18 → 편향 존재
- TurboQuant: 평균 오차 0.0008 ≈ 0 → 편향 제거 ✅
전체 알고리즘 최종 정리
TurboQuant (b-bit, 벡터 k ∈ Rᵈ)
[COMPRESS]
1. PolarQuant (b-1 bit)
k_rot = R @ k ← 랜덤 회전
r, θ = polar(k_rot) ← 극좌표 변환
codes = quantize(θ, b-1) ← 균일 양자화 (상수 없음)
k_hat = R.T @ polar_inv(r, codes)
2. QJL (1 bit)
residual = k - k_hat ← 잔차
b_r = sign(S @ residual) ← 1-bit sign
저장: (r, codes, b_r) → 총 b bit/element
[ESTIMATE ⟨q, k⟩]
ip_main = ⟨q, k_hat⟩ ← PolarQuant 복원 내적
ip_residual = √(π/2) · (Sq)ᵀ · b_r ← QJL 잔차 보정
return ip_main + ip_residual ← unbiased ✅
[정보이론적 보장]
논문 Theorem 2:
E[|⟨q,k⟩ - estimate|²] ≤ C · ‖q‖² / (d · 4^b)
(C ≈ 2.7, 이론 최적값의 상수 배 이내)
시리즈 마무리
편
제목
핵심
1편
배경 & KV Cache 병목
왜 새로운 압축이 필요한가
2편
QJL
1-bit으로 내적을 unbiased하게 추정
3편
PolarQuant
극좌표로 양자화 상수 없애기
4편
TurboQuant
두 모듈 조립 → KV Cache 6x 압축
TurboQuant의 핵심은 단 하나입니다.
"b-bit을 전부 데이터에 쓴다. 오버헤드에 단 1-bit도 낭비하지 않는다."
PolarQuant가 (b-1)-bit으로 주신호를 잡고,
QJL이 나머지 1-bit으로 내적 편향을 제거합니다.
둘을 합치면 이론적으로 최적에 가까운 압축이 됩니다.
이 시리즈는 논문을 직접 구현하며 이해하는 것을 목표로 합니다. 오류나 피드백은 메일로 주세요.