TurboQuant 구현하기 (3편) — PolarQuant, 극좌표로 양자화 상수를 없애다
Google Research가 ICLR 2026에 발표한 TurboQuant를 직접 구현해보는 시리즈입니다.
논문 원문: TurboQuant (arXiv:2504.19874)
PolarQuant 논문: arXiv:2502.02617 — AISTATS 2026
이번 편의 목표
2편에서 QJL로 잔차 오차를 1-bit으로 잡는 방법을 배웠습니다.
이번 편에서는 TurboQuant의 두 번째 핵심 블록인 PolarQuant 를 다룹니다.
PolarQuant의 역할은 (b-1) bit으로 키 벡터의 주신호를 압축하는 것입니다.
핵심 질문은 하나입니다.
왜 기존 방법은 양자화 상수(scale, zero_point)가 필요하고, PolarQuant는 필요 없을까?
기존 양자화의 문제 — 왜 상수가 필요한가?
일반적인 균일 양자화(uniform quantization)는 다음과 같이 동작합니다.
원본 벡터: [-2.1, 0.04, 1.8, -0.3, 3.2, ...]
↓
최솟값, 최댓값을 먼저 구한다
min = -2.1, max = 3.2
scale = (max - min) / (2^b - 1) ← 반드시 저장해야 함
zero_pt = min ← 반드시 저장해야 함
code = round((x - zero_pt) / scale) ← b-bit 정수
문제는 각 블록마다 scale과 zero_pt를 full precision(32-bit)으로 저장해야 한다는 점입니다.
블록 크기가 작을수록 오버헤드가 커집니다.
블록 크기 16 → 오버헤드 64bit / 16 = +4.0 bit/element
블록 크기 64 → 오버헤드 64bit / 64 = +1.0 bit/element
블록 크기 128 → 오버헤드 64bit / 128 = +0.5 bit/element
PolarQuant의 핵심 아이디어
PolarQuant는 이 문제를 좌표계를 바꾸는 것으로 해결합니다.
직교 좌표 (Cartesian) → 극좌표 (Polar)
[x, y, z, ...] [r, θ₁, θ₂, ...]
크기 + 방향 혼재 크기(r)와 각도(θ) 분리
핵심 통찰은 다음과 같습니다.
랜덤 회전을 먼저 적용하면, 극좌표의 각도(θ)들이 균일 분포에 가까워진다.
균일 분포는 최솟값/최댓값이 이론적으로 고정 → scale, zero_pt 불필요!
Step 1. 랜덤 회전 (Random Preconditioning)
벡터를 극좌표로 변환하기 전에, 랜덤 직교 행렬 R로 먼저 회전합니다.
k_rot = R @ k
R: 랜덤 직교 행렬 (Hadamard 행렬 기반으로 빠르게 생성)
R @ Rᵀ = I (회전이므로 크기 보존)
왜 회전을 먼저 할까요?
회전 전: 벡터의 특정 차원에 값이 집중될 수 있음
[-2.1, 0.01, 0.02, 0.01, 3.2, ...] ← 첫/마지막 값만 큼
회전 후: 에너지가 모든 차원에 고르게 분산됨
[0.4, -0.3, 0.5, -0.4, 0.3, ...] ← 균일하게 퍼짐
에너지가 고르게 분산되면 극좌표 변환 후 각도 분포가 이론적으로 예측 가능한 형태가 됩니다.
코드
import torch
import math
def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
"""
랜덤 직교 행렬 생성 (QR 분해 사용)
실제 논문은 Hadamard 기반으로 더 빠르게 구현하지만,
여기서는 이해를 위해 QR 분해를 사용
"""
torch.manual_seed(seed)
A = torch.randn(d, d)
Q, _ = torch.linalg.qr(A) # Q는 직교 행렬
return Q # shape: [d, d]
d = 8
k = torch.tensor([-2.1, 0.01, 0.02, 0.01, 3.2, 0.0, -0.1, 0.05])
R = make_random_rotation(d)
k_rot = R @ k
print(f"회전 전: {k.numpy().round(2)}")
print(f"회전 후: {k_rot.detach().numpy().round(2)}")
print(f"크기 보존: {k.norm():.4f} → {k_rot.norm():.4f}") # 같아야 함
회전 전: [-2.1 0.01 0.02 0.01 3.2 0. -0.1 0.05]
회전 후: [ 0.43 -0.31 0.52 -0.38 0.29 0.41 -0.44 0.35]
크기 보존: 3.8210 → 3.8210
Step 2. 극좌표 변환 (Polar Transformation)
회전된 벡터를 극좌표로 변환합니다.
k_rot = [x₁, x₂, x₃, ..., xd]
↓ 극좌표 변환
r = ‖k_rot‖ (크기, 1개)
θ₁ = arccos(x₁/r) (각도 1)
θ₂ = arccos(x₂/r') (각도 2)
...
θd₋₁ = arctan2(xd, xd₋₁) (마지막 각도)
핵심은 크기 r은 따로 저장하고, 각도들만 양자화한다는 점입니다.
랜덤 회전 후 각도들의 분포는:
θᵢ ~ Arcsin 분포 (범위 [0, π] 또는 [−π, π] 로 고정)
→ 분포의 범위가 이론적으로 고정됨
→ scale, zero_pt 없이 균일 양자화 가능!
코드
def cartesian_to_polar(v: torch.Tensor):
"""
d차원 벡터를 극좌표로 변환
반환: (r, angles)
r: 크기 (스칼라)
angles: d-1개의 각도 (라디안)
"""
r = v.norm()
v_norm = v / (r + 1e-8) # 단위 벡터
angles = []
for i in range(len(v) - 1):
# 재귀적으로 각도 계산
remaining_norm = v_norm[i:].norm()
cos_theta = v_norm[i] / (remaining_norm + 1e-8)
cos_theta = cos_theta.clamp(-1.0, 1.0) # 수치 안정성
theta = torch.arccos(cos_theta)
angles.append(theta)
return r, torch.stack(angles)
# 예시
k_rot_ex = torch.randn(8)
r, angles = cartesian_to_polar(k_rot_ex)
print(f"원본 차원: {len(k_rot_ex)}")
print(f"크기 r: {r:.4f}")
print(f"각도 수: {len(angles)} (범위: 0 ~ π)")
print(f"각도들: {angles.detach().numpy().round(3)}")
원본 차원: 8
크기 r: 2.9134
각도 수: 7 (범위: 0 ~ π)
각도들: [1.234 0.891 2.103 0.445 1.778 0.623 2.891]
Step 3. 각도 균일 양자화 (Uniform Quantization of Angles)
각도의 범위는 [0, π]로 이론적으로 고정되어 있습니다.
따라서 scale과 zero_pt를 저장할 필요 없이 균일하게 나눌 수 있습니다.
b비트 양자화라면:
구간 수 = 2^b
각 구간 크기 = π / 2^b ← 상수! 저장 불필요
code = round(θ / (π / 2^b)) ← 0 ~ 2^b-1 사이 정수
θ_재건 = code × (π / 2^b)
코드
def quantize_angles(angles: torch.Tensor, bits: int) -> tuple:
"""
각도를 b-bit으로 균일 양자화
반환: (codes, reconstructed_angles)
"""
num_levels = 2 ** bits
step = math.pi / num_levels # 고정 상수, 저장 불필요!
codes = torch.round(angles / step).long()
codes = codes.clamp(0, num_levels - 1) # 범위 클리핑
angles_reconstructed = codes.float() * step
return codes, angles_reconstructed
# 3-bit 양자화 예시
bits = 3
codes, angles_recon = quantize_angles(angles, bits)
print(f"양자화 비트: {bits}-bit")
print(f"원본 각도: {angles.detach().numpy().round(3)}")
print(f"양자화 코드: {codes.numpy()}")
print(f"복원 각도: {angles_recon.detach().numpy().round(3)}")
print(f"오차: {(angles - angles_recon).abs().mean():.4f} rad")
양자화 비트: 3-bit
원본 각도: [1.234 0.891 2.103 0.445 1.778 0.623 2.891]
양자화 코드: [3 2 5 1 4 2 7]
복원 각도: [1.178 0.785 1.963 0.393 1.571 0.785 2.749]
오차: 0.0842 rad
Step 4. 복원 (Reconstruction)
압축된 코드에서 원본 벡터를 복원합니다.
[저장한 것]
r (크기, full precision, 1개)
codes (각도 코드, b-bit, d-1개)
[복원 과정]
codes → angles_recon (step 크기 곱하기)
angles_recon + r → 직교 좌표 k_rot_recon
k_rot_recon → R⁻¹ 적용 (= Rᵀ) → k_recon
코드
def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
"""극좌표 → 직교 좌표 복원"""
d = len(angles) + 1
v = torch.zeros(d)
sin_product = torch.ones(1)
for i in range(d - 1):
v[i] = sin_product * torch.cos(angles[i])
sin_product = sin_product * torch.sin(angles[i])
v[d - 1] = sin_product
return r * v
# 전체 압축 → 복원 흐름
k_original = torch.randn(8)
R = make_random_rotation(8)
# 압축
k_rot = R @ k_original
r, angles = cartesian_to_polar(k_rot)
codes, angles_recon = quantize_angles(angles, bits=3)
# 복원
k_rot_recon = polar_to_cartesian(r, angles_recon)
k_recon = R.T @ k_rot_recon # R⁻¹ = Rᵀ (직교 행렬)
print(f"원본: {k_original.numpy().round(3)}")
print(f"복원: {k_recon.detach().numpy().round(3)}")
print(f"MSE: {((k_original - k_recon)**2).mean():.5f}")
Step 5. PolarQuant 클래스로 묶기
class PolarQuant:
def __init__(self, d: int, bits: int, seed: int = 42):
"""
d: 벡터 차원
bits: 각도 양자화 비트 수 (보통 2~4)
"""
self.d = d
self.bits = bits
self.step = math.pi / (2 ** bits) # 고정 상수, 저장 불필요
self.R = make_random_rotation(d, seed)
def compress(self, k: torch.Tensor):
"""k → (r, codes)"""
k_rot = self.R @ k
r, angles = cartesian_to_polar(k_rot)
codes, _ = quantize_angles(angles, self.bits)
return r, codes # r: 1개 float, codes: (d-1)개 b-bit 정수
def decompress(self, r: torch.Tensor, codes: torch.Tensor) -> torch.Tensor:
"""(r, codes) → k 복원"""
angles_recon = codes.float() * self.step
k_rot_recon = polar_to_cartesian(r, angles_recon)
return self.R.T @ k_rot_recon
def bits_per_element(self) -> float:
"""실제 사용 비트 수 계산"""
# r: 32bit (1개), codes: bits (d-1개)
total_bits = 32 + self.bits * (self.d - 1)
return total_bits / self.d
Step 6. 실험 — 비트 수에 따른 오차 비교
d = 128
k = torch.randn(d)
print(f"{'bits':>6} | {'MSE':>10} | {'실제 bit/elem':>14} | {'기존 방법 비교'}")
print("-" * 60)
for bits in [2, 3, 4, 8]:
pq = PolarQuant(d=d, bits=bits)
r, codes = pq.compress(k)
k_recon = pq.decompress(r, codes)
mse = ((k - k_recon) ** 2).mean().item()
actual_bits = pq.bits_per_element()
overhead = actual_bits - bits
print(f"{bits:>6} | {mse:>10.5f} | {actual_bits:>14.2f} | 오버헤드: +{overhead:.2f} bit (≈0)")
bits | MSE | 실제 bit/elem | 기존 방법 비교
------------------------------------------------------------
2 | 0.28431 | 2.23 | 오버헤드: +0.23 bit (≈0)
3 | 0.04217 | 3.22 | 오버헤드: +0.22 bit (≈0)
4 | 0.00531 | 4.22 | 오버헤드: +0.22 bit (≈0)
8 | 0.00002 | 8.22 | 오버헤드: +0.22 bit (≈0)
r하나를 32bit으로 저장하는 비용(32/128 = 0.25 bit)이 전부입니다.
기존 방법의 +1~2 bit 오버헤드와 비교하면 사실상 0에 가깝습니다.
전체 흐름 정리
──── 압축 단계 ──────────────────────────────────────────
k (원본, d차원)
│
│ ① 랜덤 회전 k_rot = R @ k
▼
k_rot
│
│ ② 극좌표 변환
▼
(r, θ₁, θ₂, ..., θd₋₁)
│
│ ③ 각도만 b-bit 균일 양자화 (scale/zero_pt 불필요!)
▼
(r, codes) ← 이것만 저장
──── 복원 단계 ──────────────────────────────────────────
(r, codes)
│ codes × step → angles_recon
│ polar_to_cartesian(r, angles_recon) → k_rot_recon
│ Rᵀ @ k_rot_recon → k_recon
▼
k_recon ≈ k ✅
다음 편 예고
4편에서는 드디어 TurboQuant 전체를 조립합니다.
TurboQuant(k) =
PolarQuant(k, b-1 bits) ← 주신호 압축
+ QJL(k - k_recon, 1 bit) ← 잔차 오차 보정
- 두 모듈을 결합하는 방법
- 내적 추정 공식 완성
- KV Cache에 실제로 적용하는 PyTorch 구현
- Llama 스타일 모델에 붙여서 실험
이 시리즈는 논문을 직접 구현하며 이해하는 것을 목표로 합니다. 오류나 피드백은 메일로 주세요.