File size: 3,704 Bytes
8bfcf43
 
 
cf0a8ed
8bfcf43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""Tests for RotateKVQuantizer — TASK-005."""
import pytest
import numpy as np
from apohara_context_forge.quantization.rotate_kv import RotateKVQuantizer, RotateKVConfig, QuantizedKVBlock


class TestRotateKVQuantizer:
    """Tests for RotateKV quantization (INVARIANT 10: pre-RoPE only).

    All tests are synchronous: quantize_pre_rope()/dequantize() are called
    directly and return plain values, so no asyncio marker is needed.
    """

    @staticmethod
    def _make_pre_rope_inputs(seq_len: int = 64, num_heads: int = 8, head_dim: int = 64):
        """Build deterministic-shape pre-RoPE K/V tensors and positions.

        Layout assumed by the quantizer: (batch=1, seq_len, num_heads, head_dim)
        — TODO(review): confirm against RotateKVQuantizer's documented contract.
        """
        k_tensor = np.random.randn(1, seq_len, num_heads, head_dim).astype(np.float32)
        v_tensor = np.random.randn(1, seq_len, num_heads, head_dim).astype(np.float32)
        positions = np.arange(seq_len, dtype=np.float32)
        return k_tensor, v_tensor, positions

    def test_rotate_kv_config_defaults(self):
        """RotateKVConfig has sensible defaults (4-bit, group 64, 4 sink tokens)."""
        config = RotateKVConfig()
        assert config.bits == 4
        assert config.group_size == 64
        assert config.sink_tokens == 4

    def test_quantized_kv_block_has_pre_rope_metadata(self):
        """QuantizedKVBlock is constructible and records its bit width.

        NOTE(review): the original docstring claimed a pre_rope flag is stored
        in metadata, but only `bits` is asserted below — confirm whether
        QuantizedKVBlock actually exposes such a flag and assert it if so.
        """
        # This tests the invariant: pre-RoPE tensors are what we quantize
        block = QuantizedKVBlock(
            keys_int4=np.zeros((10, 8, 64), dtype=np.float32),
            values_int4=np.zeros((10, 8, 64), dtype=np.float32),
            keys_sink_fp16=np.zeros((4, 8, 128), dtype=np.float16),
            values_sink_fp16=np.zeros((4, 8, 128), dtype=np.float16),
            scales_k=np.ones((1, 8, 64), dtype=np.float32),
            zero_points_k=np.zeros((1, 8, 64), dtype=np.float32),
            scales_v=np.ones((1, 8, 128), dtype=np.float32),
            zero_points_v=np.zeros((1, 8, 128), dtype=np.float32),
            channel_order=np.arange(128, dtype=np.int32),
            positions=np.arange(14, dtype=np.float32),
            bits=4,
        )
        assert block.bits == 4

    def test_quantize_pre_rope_returns_quantized_block(self):
        """quantize_pre_rope() returns (QuantizedKVBlock, ndarray) tuple (INVARIANT 10)."""
        config = RotateKVConfig(bits=4, group_size=64, sink_tokens=4)
        quantizer = RotateKVQuantizer(config)

        k_tensor, v_tensor, positions = self._make_pre_rope_inputs()

        result = quantizer.quantize_pre_rope(k_tensor, v_tensor, positions)
        assert isinstance(result, tuple)
        qblock, remaining = result
        assert isinstance(qblock, QuantizedKVBlock)
        # Non-sink tokens must actually have been quantized into the block.
        assert qblock.keys_int4.shape[0] > 0
        assert qblock.values_int4.shape[0] > 0

    def test_quantize_pre_rope_sink_tokens_preserved(self):
        """First sink_tokens are preserved at FP16 (attention-sink protection)."""
        config = RotateKVConfig(bits=4, sink_tokens=4)
        quantizer = RotateKVQuantizer(config)

        k_tensor, v_tensor, positions = self._make_pre_rope_inputs()

        qblock, _ = quantizer.quantize_pre_rope(k_tensor, v_tensor, positions)

        # Sink shape mirrors the input layout with seq trimmed to sink_tokens.
        assert qblock.keys_sink_fp16.shape == (1, 4, 8, 64)
        assert qblock.values_sink_fp16.shape == (1, 4, 8, 64)

    def test_dequantize_returns_fp32_tensors(self):
        """dequantize() round-trips a quantized block back to FP32 ndarrays."""
        config = RotateKVConfig(bits=4, group_size=64, sink_tokens=4)
        quantizer = RotateKVQuantizer(config)

        k_tensor, v_tensor, positions = self._make_pre_rope_inputs()

        qblock, _ = quantizer.quantize_pre_rope(k_tensor, v_tensor, positions)
        k_deq, v_deq = quantizer.dequantize(qblock)

        assert isinstance(k_deq, np.ndarray)
        assert isinstance(v_deq, np.ndarray)
        assert k_deq.dtype == np.float32
        assert v_deq.dtype == np.float32