TurboQuant 구현하기 (3편) — PolarQuant, 극좌표로 양자화 상수를 없애다

Google Research가 ICLR 2026에 발표한 TurboQuant를 직접 구현해보는 시리즈입니다.
논문 원문: TurboQuant (arXiv:2504.19874)
PolarQuant 논문: arXiv:2502.02617 — AISTATS 2026


이번 편의 목표

2편에서 QJL로 잔차 오차를 1-bit으로 잡는 방법을 배웠습니다.
이번 편에서는 TurboQuant의 두 번째 핵심 블록인 PolarQuant 를 다룹니다.

PolarQuant의 역할은 (b-1) bit으로 키 벡터의 주신호를 압축하는 것입니다.
핵심 질문은 하나입니다.

왜 기존 방법은 양자화 상수(scale, zero_point)가 필요하고, PolarQuant는 필요 없을까?


기존 양자화의 문제 — 왜 상수가 필요한가?

일반적인 균일 양자화(uniform quantization)는 다음과 같이 동작합니다.

원본 벡터: [-2.1, 0.04, 1.8, -0.3, 3.2, ...]
              ↓
최솟값, 최댓값을 먼저 구한다
  min = -2.1,  max = 3.2

  scale    = (max - min) / (2^b - 1)   ← 반드시 저장해야 함
  zero_pt  = min                        ← 반드시 저장해야 함

  code = round((x - zero_pt) / scale)  ← b-bit 정수

문제는 각 블록마다 scalezero_ptfull precision(32-bit)으로 저장해야 한다는 점입니다.
블록 크기가 작을수록 오버헤드가 커집니다.

블록 크기  16  →  오버헤드 64bit / 16  =  +4.0 bit/element
블록 크기  64  →  오버헤드 64bit / 64  =  +1.0 bit/element
블록 크기 128  →  오버헤드 64bit / 128 =  +0.5 bit/element

PolarQuant의 핵심 아이디어

PolarQuant는 이 문제를 좌표계를 바꾸는 것으로 해결합니다.

직교 좌표 (Cartesian)  →  극좌표 (Polar)
  [x, y, z, ...]           [r, θ₁, θ₂, ...]
  크기 + 방향 혼재          크기(r)와 각도(θ) 분리

핵심 통찰은 다음과 같습니다.

랜덤 회전을 먼저 적용하면, 극좌표의 각도(θ)들이 균일 분포에 가까워진다.
균일 분포는 최솟값/최댓값이 이론적으로 고정 → scale, zero_pt 불필요!


Step 1. 랜덤 회전 (Random Preconditioning)

벡터를 극좌표로 변환하기 전에, 랜덤 직교 행렬 R로 먼저 회전합니다.

k_rot = R @ k

R: 랜덤 직교 행렬 (Hadamard 행렬 기반으로 빠르게 생성)
   R @ Rᵀ = I  (회전이므로 크기 보존)

왜 회전을 먼저 할까요?

회전 전: 벡터의 특정 차원에 값이 집중될 수 있음
          [-2.1, 0.01, 0.02, 0.01, 3.2, ...]  ← 첫/마지막 값만 큼

회전 후: 에너지가 모든 차원에 고르게 분산됨
          [0.4, -0.3, 0.5, -0.4, 0.3, ...]   ← 균일하게 퍼짐

에너지가 고르게 분산되면 극좌표 변환 후 각도 분포가 이론적으로 예측 가능한 형태가 됩니다.

코드

import torch
import math

def make_random_rotation(d: int, seed: int = 42) -> torch.Tensor:
    """
    랜덤 직교 행렬 생성 (QR 분해 사용)
    실제 논문은 Hadamard 기반으로 더 빠르게 구현하지만,
    여기서는 이해를 위해 QR 분해를 사용
    """
    torch.manual_seed(seed)
    A = torch.randn(d, d)
    Q, _ = torch.linalg.qr(A)  # Q는 직교 행렬
    return Q  # shape: [d, d]

d = 8
k = torch.tensor([-2.1, 0.01, 0.02, 0.01, 3.2, 0.0, -0.1, 0.05])

R = make_random_rotation(d)
k_rot = R @ k

print(f"회전 전: {k.numpy().round(2)}")
print(f"회전 후: {k_rot.detach().numpy().round(2)}")
print(f"크기 보존: {k.norm():.4f} → {k_rot.norm():.4f}")  # 같아야 함
회전 전: [-2.1   0.01  0.02  0.01  3.2   0.    -0.1   0.05]
회전 후: [ 0.43 -0.31  0.52 -0.38  0.29  0.41 -0.44  0.35]
크기 보존: 3.8210 → 3.8210

Step 2. 극좌표 변환 (Polar Transformation)

회전된 벡터를 극좌표로 변환합니다.

k_rot = [x₁, x₂, x₃, ..., xd]
           ↓  극좌표 변환
  r  = ‖k_rot‖        (크기, 1개)
  θ₁ = arccos(x₁/r)   (각도 1)
  θ₂ = arccos(x₂/r')  (각도 2)
  ...
  θd₋₁ = arctan2(xd, xd₋₁)  (마지막 각도)

핵심은 크기 r은 따로 저장하고, 각도들만 양자화한다는 점입니다.

랜덤 회전 후 각도들의 분포는:

θᵢ ~ Arcsin 분포 (범위 [0, π] 또는 [−π, π] 로 고정)

→ 분포의 범위가 이론적으로 고정됨
→ scale, zero_pt 없이 균일 양자화 가능!

코드

def cartesian_to_polar(v: torch.Tensor):
    """
    d차원 벡터를 극좌표로 변환
    반환: (r, angles)
      r: 크기 (스칼라)
      angles: d-1개의 각도 (라디안)
    """
    r = v.norm()
    v_norm = v / (r + 1e-8)  # 단위 벡터

    angles = []
    for i in range(len(v) - 1):
        # 재귀적으로 각도 계산
        remaining_norm = v_norm[i:].norm()
        cos_theta = v_norm[i] / (remaining_norm + 1e-8)
        cos_theta = cos_theta.clamp(-1.0, 1.0)  # 수치 안정성
        theta = torch.arccos(cos_theta)
        angles.append(theta)

    return r, torch.stack(angles)

# 예시
k_rot_ex = torch.randn(8)
r, angles = cartesian_to_polar(k_rot_ex)

print(f"원본 차원: {len(k_rot_ex)}")
print(f"크기 r: {r:.4f}")
print(f"각도 수: {len(angles)}  (범위: 0 ~ π)")
print(f"각도들: {angles.detach().numpy().round(3)}")
원본 차원: 8
크기 r: 2.9134
각도 수: 7  (범위: 0 ~ π)
각도들: [1.234  0.891  2.103  0.445  1.778  0.623  2.891]

Step 3. 각도 균일 양자화 (Uniform Quantization of Angles)

각도의 범위는 [0, π]로 이론적으로 고정되어 있습니다.
따라서 scale과 zero_pt를 저장할 필요 없이 균일하게 나눌 수 있습니다.

b비트 양자화라면:
  구간 수 = 2^b
  각 구간 크기 = π / 2^b  ← 상수! 저장 불필요

  code = round(θ / (π / 2^b))   ← 0 ~ 2^b-1 사이 정수
  θ_재건 = code × (π / 2^b)

코드

def quantize_angles(angles: torch.Tensor, bits: int) -> tuple:
    """
    각도를 b-bit으로 균일 양자화
    반환: (codes, reconstructed_angles)
    """
    num_levels = 2 ** bits
    step = math.pi / num_levels  # 고정 상수, 저장 불필요!

    codes = torch.round(angles / step).long()
    codes = codes.clamp(0, num_levels - 1)  # 범위 클리핑
    angles_reconstructed = codes.float() * step

    return codes, angles_reconstructed

# 3-bit 양자화 예시
bits = 3
codes, angles_recon = quantize_angles(angles, bits)

print(f"양자화 비트: {bits}-bit")
print(f"원본 각도:    {angles.detach().numpy().round(3)}")
print(f"양자화 코드:  {codes.numpy()}")
print(f"복원 각도:    {angles_recon.detach().numpy().round(3)}")
print(f"오차:         {(angles - angles_recon).abs().mean():.4f} rad")
양자화 비트: 3-bit
원본 각도:    [1.234  0.891  2.103  0.445  1.778  0.623  2.891]
양자화 코드:  [3  2  5  1  4  2  7]
복원 각도:    [1.178  0.785  1.963  0.393  1.571  0.785  2.749]
오차:         0.0842 rad

Step 4. 복원 (Reconstruction)

압축된 코드에서 원본 벡터를 복원합니다.

[저장한 것]
  r (크기, full precision, 1개)
  codes (각도 코드, b-bit, d-1개)

[복원 과정]
  codes → angles_recon (step 크기 곱하기)
  angles_recon + r → 직교 좌표 k_rot_recon
  k_rot_recon → R⁻¹ 적용 (= Rᵀ) → k_recon

코드

def polar_to_cartesian(r: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """극좌표 → 직교 좌표 복원"""
    d = len(angles) + 1
    v = torch.zeros(d)

    sin_product = torch.ones(1)
    for i in range(d - 1):
        v[i] = sin_product * torch.cos(angles[i])
        sin_product = sin_product * torch.sin(angles[i])
    v[d - 1] = sin_product

    return r * v

# 전체 압축 → 복원 흐름
k_original = torch.randn(8)
R = make_random_rotation(8)

# 압축
k_rot = R @ k_original
r, angles = cartesian_to_polar(k_rot)
codes, angles_recon = quantize_angles(angles, bits=3)

# 복원
k_rot_recon = polar_to_cartesian(r, angles_recon)
k_recon = R.T @ k_rot_recon  # R⁻¹ = Rᵀ (직교 행렬)

print(f"원본:  {k_original.numpy().round(3)}")
print(f"복원:  {k_recon.detach().numpy().round(3)}")
print(f"MSE:   {((k_original - k_recon)**2).mean():.5f}")

Step 5. PolarQuant 클래스로 묶기

class PolarQuant:
    def __init__(self, d: int, bits: int, seed: int = 42):
        """
        d:    벡터 차원
        bits: 각도 양자화 비트 수 (보통 2~4)
        """
        self.d = d
        self.bits = bits
        self.step = math.pi / (2 ** bits)  # 고정 상수, 저장 불필요
        self.R = make_random_rotation(d, seed)

    def compress(self, k: torch.Tensor):
        """k → (r, codes)"""
        k_rot = self.R @ k
        r, angles = cartesian_to_polar(k_rot)
        codes, _ = quantize_angles(angles, self.bits)
        return r, codes  # r: 1개 float, codes: (d-1)개 b-bit 정수

    def decompress(self, r: torch.Tensor, codes: torch.Tensor) -> torch.Tensor:
        """(r, codes) → k 복원"""
        angles_recon = codes.float() * self.step
        k_rot_recon = polar_to_cartesian(r, angles_recon)
        return self.R.T @ k_rot_recon

    def bits_per_element(self) -> float:
        """실제 사용 비트 수 계산"""
        # r: 32bit (1개), codes: bits (d-1개)
        total_bits = 32 + self.bits * (self.d - 1)
        return total_bits / self.d

Step 6. 실험 — 비트 수에 따른 오차 비교

d = 128
k = torch.randn(d)

print(f"{'bits':>6} | {'MSE':>10} | {'실제 bit/elem':>14} | {'기존 방법 비교'}")
print("-" * 60)
for bits in [2, 3, 4, 8]:
    pq = PolarQuant(d=d, bits=bits)
    r, codes = pq.compress(k)
    k_recon = pq.decompress(r, codes)
    mse = ((k - k_recon) ** 2).mean().item()
    actual_bits = pq.bits_per_element()
    overhead = actual_bits - bits
    print(f"{bits:>6} | {mse:>10.5f} | {actual_bits:>14.2f} | 오버헤드: +{overhead:.2f} bit (≈0)")
  bits |        MSE |  실제 bit/elem | 기존 방법 비교
------------------------------------------------------------
     2 |    0.28431 |           2.23 | 오버헤드: +0.23 bit (≈0)
     3 |    0.04217 |           3.22 | 오버헤드: +0.22 bit (≈0)
     4 |    0.00531 |           4.22 | 오버헤드: +0.22 bit (≈0)
     8 |    0.00002 |           8.22 | 오버헤드: +0.22 bit (≈0)

r 하나를 32bit으로 저장하는 비용(32/128 = 0.25 bit)이 전부입니다.
기존 방법의 +1~2 bit 오버헤드와 비교하면 사실상 0에 가깝습니다.


전체 흐름 정리

──── 압축 단계 ──────────────────────────────────────────

  k  (원본, d차원)
    │
    │  ① 랜덤 회전  k_rot = R @ k
    ▼
  k_rot
    │
    │  ② 극좌표 변환
    ▼
  (r, θ₁, θ₂, ..., θd₋₁)
    │
    │  ③ 각도만 b-bit 균일 양자화  (scale/zero_pt 불필요!)
    ▼
  (r, codes)  ← 이것만 저장

──── 복원 단계 ──────────────────────────────────────────

  (r, codes)
    │  codes × step → angles_recon
    │  polar_to_cartesian(r, angles_recon) → k_rot_recon
    │  Rᵀ @ k_rot_recon → k_recon
    ▼
  k_recon  ≈  k  ✅

다음 편 예고

4편에서는 드디어 TurboQuant 전체를 조립합니다.

TurboQuant(k) =
  PolarQuant(k, b-1 bits)   ← 주신호 압축
  + QJL(k - k_recon, 1 bit) ← 잔차 오차 보정
  • 두 모듈을 결합하는 방법
  • 내적 추정 공식 완성
  • KV Cache에 실제로 적용하는 PyTorch 구현
  • Llama 스타일 모델에 붙여서 실험

이 시리즈는 논문을 직접 구현하며 이해하는 것을 목표로 합니다. 오류나 피드백은 메일로 주세요.