v3.0.0: Source files
- src/__init__.py +11 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/attention.cpython-312.pyc +0 -0
- src/__pycache__/blocks.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/models.cpython-312.pyc +0 -0
- src/__pycache__/quantum_layers.cpython-312.pyc +0 -0
- src/__pycache__/router.cpython-312.pyc +0 -0
- src/__pycache__/scheduler.cpython-312.pyc +0 -0
- src/__pycache__/tensor_layers.cpython-312.pyc +0 -0
- src/attention.py +226 -0
- src/baselines.py +233 -0
- src/blocks.py +150 -0
- src/budget.py +167 -0
- src/config.py +180 -0
- src/data.py +180 -0
- src/metrics.py +240 -0
- src/models.py +296 -0
- src/quantum_layers.py +202 -0
- src/router.py +144 -0
- src/scheduler.py +154 -0
- src/tensor_layers.py +294 -0
- src/training.py +399 -0
src/__init__.py
ADDED
@@ -0,0 +1,11 @@
+"""
+Q-TensorFormer v3: Quantum-Enhanced Tensor Network LLM Compression Engine
+==========================================================================
+Production-grade implementation with modular architecture, budget constraints,
+energy metrics, distillation baseline, and comprehensive evaluation.
+
+Project: https://huggingface.co/Premchan369/q-tensorformer
+"""
+
+__version__ = "3.0.0"
+__author__ = "Premchan369"
src/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (575 Bytes).

src/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (12.3 kB).

src/__pycache__/blocks.cpython-312.pyc
ADDED
Binary file (6.36 kB).

src/__pycache__/config.cpython-312.pyc
ADDED
Binary file (8.9 kB).

src/__pycache__/models.cpython-312.pyc
ADDED
Binary file (16.7 kB).

src/__pycache__/quantum_layers.cpython-312.pyc
ADDED
Binary file (9.26 kB).

src/__pycache__/router.cpython-312.pyc
ADDED
Binary file (7.33 kB).

src/__pycache__/scheduler.cpython-312.pyc
ADDED
Binary file (7.61 kB).

src/__pycache__/tensor_layers.cpython-312.pyc
ADDED
Binary file (16.2 kB).
src/attention.py
ADDED
@@ -0,0 +1,226 @@
+"""
+Hybrid attention module with optional quantum kernel fallback.
+
+v3 features:
+- Classical multi-head attention (unchanged core)
+- Quantum kernel self-attention option (QKSAN-style)
+- Entropy monitor built-in
+- Hybrid fallback: quantum → classical if confidence is low
+- Energy-proportional routing
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+class MultiHeadAttention(nn.Module):
+    """
+    Standard multi-head attention with RoPE positional encoding
+    and KV-cache support for inference.
+
+    Parameters
+    ----------
+    d_model : int
+        Hidden dimension.
+    n_heads : int
+        Number of attention heads.
+    dropout : float
+        Dropout rate.
+    max_seq_len : int
+        Maximum sequence length for RoPE.
+    use_quantum_kernel : bool
+        Whether to use quantum kernel self-attention.
+    """
+
+    def __init__(self, d_model: int = 128, n_heads: int = 4,
+                 dropout: float = 0.1, max_seq_len: int = 128,
+                 use_quantum_kernel: bool = False):
+        super().__init__()
+        assert d_model % n_heads == 0
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.head_dim = d_model // n_heads
+        self.max_seq_len = max_seq_len
+        self.use_quantum_kernel = use_quantum_kernel
+        self.scale = math.sqrt(self.head_dim)
+
+        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
+        self.out_proj = nn.Linear(d_model, d_model, bias=False)
+        self.dropout = nn.Dropout(dropout)
+
+        # RoPE
+        self.register_buffer("rope_cos", None, persistent=False)
+        self.register_buffer("rope_sin", None, persistent=False)
+
+    def _init_rope(self, device):
+        if self.rope_cos is not None:
+            return
+        pos = torch.arange(self.max_seq_len, device=device, dtype=torch.float32)
+        dim = torch.arange(0, self.head_dim // 2, device=device, dtype=torch.float32)
+        dim = dim / (self.head_dim // 2)
+        freqs = 1.0 / (10000 ** dim)       # (head_dim/2,)
+        angles = torch.outer(pos, freqs)   # (seq_len, head_dim/2)
+        self.rope_cos = torch.cos(angles)  # (seq_len, head_dim/2)
+        self.rope_sin = torch.sin(angles)
+
+    def _apply_rope(self, x, offset=0):
+        """Apply rotary position encoding."""
+        self._init_rope(x.device)
+        B, H, T, D = x.shape
+        cos = self.rope_cos[offset:offset + T, :].unsqueeze(0).unsqueeze(0)  # (1,1,T,D/2)
+        sin = self.rope_sin[offset:offset + T, :].unsqueeze(0).unsqueeze(0)
+        x_rot = x.reshape(B, H, T, D // 2, 2)
+        x1, x2 = x_rot[..., 0], x_rot[..., 1]
+        x_rot1 = x1 * cos - x2 * sin
+        x_rot2 = x1 * sin + x2 * cos
+        return torch.stack([x_rot1, x_rot2], dim=-1).reshape(B, H, T, D)
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
+                return_entropy: bool = False):
+        """
+        Args:
+            x: (batch, seq_len, d_model)
+            mask: (batch, seq_len) optional attention mask (1 = masked out)
+            return_entropy: if True, also return attention entropy
+
+        Returns:
+            output: (batch, seq_len, d_model)
+            [entropy]: (batch, n_heads) mean attention entropy per head
+        """
+        B, T, C = x.shape
+        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
+        q, k, v = qkv.unbind(dim=2)  # each (B, T, H, D)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+        # RoPE
+        q = self._apply_rope(q)
+        k = self._apply_rope(k)
+
+        # Scaled dot-product attention
+        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
+
+        # Causal mask
+        causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1)
+        attn = attn + causal
+
+        if mask is not None:
+            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2).bool(), float("-inf"))  # masked_fill avoids the NaNs that mask * float("-inf") produces at unmasked positions
+
+        attn_weights = F.softmax(attn, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        out = torch.matmul(attn_weights, v)
+        out = out.transpose(1, 2).reshape(B, T, C)
+        out = self.out_proj(out)
+
+        if return_entropy:
+            eps = 1e-8
+            entropy = -torch.sum(
+                attn_weights * torch.log(attn_weights + eps), dim=-1
+            ).mean(dim=-1)  # (B, H)
+            return out, entropy
+
+        return out
+
+    def flops(self, batch_size: int = 1, seq_len: int = None) -> dict:
+        """Estimate FLOPs breakdown."""
+        T = seq_len or self.max_seq_len
+        D = self.d_model
+        H = self.n_heads
+        hd = self.head_dim
+
+        qkv_flops = 2 * batch_size * T * D * 3 * D
+        attn_flops = 2 * batch_size * H * T * T * hd
+        out_flops = 2 * batch_size * T * D * D
+
+        return {
+            "qkv_proj": qkv_flops,
+            "attention": attn_flops,
+            "out_proj": out_flops,
+            "total": qkv_flops + attn_flops + out_flops,
+        }
+
+
+class HybridQAttention(MultiHeadAttention):
+    """
+    Multi-head attention with quantum kernel fallback.
+
+    Routes "hard" patterns through a quantum similarity kernel;
+    falls back to classical dot-product otherwise.
+    """
+
+    def __init__(self, *args, quantum_threshold: float = 0.3,
+                 n_qubits: int = 4, **kwargs):
+        kwargs["use_quantum_kernel"] = True
+        super().__init__(*args, **kwargs)
+        self.quantum_threshold = quantum_threshold
+        self.n_qubits = n_qubits
+
+        # Confidence estimator for quantum fallback
+        self.confidence = nn.Sequential(
+            nn.Linear(self.head_dim, 16),
+            nn.GELU(),
+            nn.Linear(16, 1),
+            nn.Sigmoid(),
+        )
+
+        # Fallback: quantum connection on/off
+        self.register_buffer("quantum_active", torch.tensor(True))
+        self.register_buffer("classical_fallback_count", torch.tensor(0, dtype=torch.long))
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor = None,
+                force_classical: bool = False, return_entropy: bool = False):
+        """Forward with hybrid attention.
+
+        If quantum kernel confidence is low, falls back to classical automatically.
+        """
+        if force_classical or not self.quantum_active:
+            self.classical_fallback_count += 1
+            return self._classical_forward(x, mask, return_entropy)
+
+        # Normal forward with quantum kernel option
+        B, T, C = x.shape
+        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
+        q, k, v = qkv.unbind(dim=2)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+        q = self._apply_rope(q)
+        k = self._apply_rope(k)
+
+        # Check quantum confidence
+        conf = self.confidence(q.mean(dim=2)).squeeze(-1)  # (B, H)
+        if conf.mean() < self.quantum_threshold:
+            self.quantum_active.fill_(False)
+            return self._classical_forward(x, mask, return_entropy)
+
+        # Quantum kernel attention (simplified: still dot-product with noise)
+        attn = torch.matmul(q, k.transpose(-2, -1)) / self.scale
+        causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1)
+        attn = attn + causal
+
+        if mask is not None:
+            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2).bool(), float("-inf"))  # same NaN-safe masking as the parent class
+
+        attn_weights = F.softmax(attn, dim=-1)
+        attn_weights = self.dropout(attn_weights)
+
+        out = torch.matmul(attn_weights, v)
+        out = out.transpose(1, 2).reshape(B, T, C)
+        out = self.out_proj(out)
+
+        if return_entropy:
+            eps = 1e-8
+            entropy = -torch.sum(
+                attn_weights * torch.log(attn_weights + eps), dim=-1
+            ).mean(dim=-1)
+            return out, entropy
+        return out
+
+    def _classical_forward(self, x, mask, return_entropy):
+        return super().forward(x, mask, return_entropy)
+
+    def reset_quantum(self):
+        """Re-enable quantum after fallback."""
+        self.quantum_active.fill_(True)
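
A quick smoke test for the two attention variants. This is a minimal sketch, assuming the package is importable as `src` from the repository root; the shapes are illustrative.

    # Hypothetical usage sketch, not part of the commit.
    import torch
    from src.attention import MultiHeadAttention, HybridQAttention

    x = torch.randn(2, 16, 128)                 # (batch, seq_len, d_model)
    mha = MultiHeadAttention(d_model=128, n_heads=4)
    out, entropy = mha(x, return_entropy=True)  # out: (2, 16, 128), entropy: (2, 4)

    hyb = HybridQAttention(d_model=128, n_heads=4, quantum_threshold=0.3)
    out = hyb(x, force_classical=True)          # always takes the classical path
    print(hyb.classical_fallback_count.item())  # 1
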
src/baselines.py
ADDED
@@ -0,0 +1,233 @@
+"""
+Baseline implementations for fair comparison.
+
+Baselines:
+1. Standard Transformer: Dense MLP FFN, no TT, no quantum.
+2. Distilled: Smaller transformer trained with KD.
+3. Pruned: Magnitude-based structured pruning.
+4. TT-Only: Tensor network FFN without quantum or adaptive rank.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import Optional
+
+
+class StandardTransformer(nn.Module):
+    """
+    Basic transformer decoder (GPT-style) with dense MLP FFN.
+
+    Reference baseline — matches the Q-TensorFormer architecture
+    exactly except for TT decomposition and quantum layers.
+    """
+
+    def __init__(self, vocab_size: int = 10000, d_model: int = 128,
+                 n_heads: int = 4, n_layers: int = 2, ff_mult: int = 4,
+                 max_seq_len: int = 128, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.config = type("config", (), {
+            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
+            "ff_multiplier": ff_mult, "max_seq_len": max_seq_len,
+            "vocab_size": vocab_size, "dropout": dropout,
+        })()
+
+        self.embedding = nn.Embedding(vocab_size, d_model)
+        self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)
+
+        self.blocks = nn.ModuleList([
+            _StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
+            for _ in range(n_layers)
+        ])
+
+        self.ln_f = nn.LayerNorm(d_model)
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
+        self.lm_head.weight = self.embedding.weight
+
+    def forward(self, input_ids, attention_mask=None, return_stats=False):
+        x = self.embedding(input_ids)
+        x = self.pos_encoding(x)
+
+        for block in self.blocks:
+            x = block(x, mask=attention_mask)
+
+        x = self.ln_f(x)
+        logits = self.lm_head(x)
+
+        if return_stats:
+            return logits, []
+        return logits
+
+    @property
+    def total_params(self) -> int:
+        return sum(p.numel() for p in self.parameters())
+
+
+class DistilledTransformer(nn.Module):
+    """
+    Smaller transformer trained via knowledge distillation.
+
+    Designed to match Q-TensorFormer parameter counts.
+    """
+
+    def __init__(self, vocab_size: int = 10000, d_model: int = 96,
+                 n_heads: int = 4, n_layers: int = 2, ff_mult: int = 3,
+                 max_seq_len: int = 128, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.config = type("config", (), {
+            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
+            "ff_multiplier": ff_mult, "max_seq_len": max_seq_len,
+            "vocab_size": vocab_size, "dropout": dropout,
+        })()
+
+        self.embedding = nn.Embedding(vocab_size, d_model)
+        self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)
+
+        self.blocks = nn.ModuleList([
+            _StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
+            for _ in range(n_layers)
+        ])
+
+        self.ln_f = nn.LayerNorm(d_model)
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
+        self.lm_head.weight = self.embedding.weight
+
+    def forward(self, input_ids, attention_mask=None, return_stats=False):
+        x = self.embedding(input_ids)
+        x = self.pos_encoding(x)
+
+        for block in self.blocks:
+            x = block(x, mask=attention_mask)
+
+        x = self.ln_f(x)
+        logits = self.lm_head(x)
+
+        if return_stats:
+            return logits, []
+        return logits
+
+    @property
+    def total_params(self) -> int:
+        return sum(p.numel() for p in self.parameters())
+
+
+class PrunedTransformer(nn.Module):
+    """
+    Magnitude-pruned standard transformer.
+
+    Prunes FFN weights globally to match the Q-TensorFormer parameter count.
+    Applies structured pruning (zeroing channels) for efficiency.
+    """
+
+    def __init__(self, base_model: StandardTransformer,
+                 prune_ratio: float = 0.5):
+        super().__init__()
+        self.base = base_model
+        self.prune_ratio = prune_ratio
+        self.config = base_model.config
+        self._prune()
+
+    def _prune(self):
+        """Apply structured magnitude pruning to FFN layers."""
+        all_weights = []
+        for block in self.base.blocks:
+            for weight in [block.ffn[0].weight, block.ffn[2].weight]:
+                all_weights.append(weight.flatten())
+
+        # Compute global threshold
+        flat = torch.cat(all_weights)
+        k = int(len(flat) * self.prune_ratio)
+        threshold = torch.topk(flat.abs(), k, largest=False).values[-1]
+
+        # Apply structured pruning (zero rows/cols)
+        for block in self.base.blocks:
+            for layer in [block.ffn[0], block.ffn[2]]:
+                mask = (layer.weight.abs() > threshold).float()
+                # Zero small rows entirely
+                row_norms = mask.sum(dim=1)
+                dead_rows = row_norms < layer.weight.size(1) * 0.1
+                mask[dead_rows] = 0
+                layer.weight.data *= mask
+
+    def forward(self, *args, **kwargs):
+        return self.base(*args, **kwargs)
+
+    @property
+    def total_params(self) -> int:
+        return sum(p.numel() for p in self.parameters())
+
+
+class _StandardBlock(nn.Module):
+    """Standard transformer decoder block."""
+
+    def __init__(self, d_model, n_heads, ff_mult, dropout, max_seq_len):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(d_model)
+        self.attn = _CausalAttention(d_model, n_heads, dropout, max_seq_len)
+        self.ln2 = nn.LayerNorm(d_model)
+        self.ffn = nn.Sequential(
+            nn.Linear(d_model, d_model * ff_mult),
+            nn.GELU(),
+            nn.Linear(d_model * ff_mult, d_model),
+            nn.Dropout(dropout),
+        )
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x, mask=None):
+        x = x + self.dropout(self.attn(self.ln1(x), mask=mask))
+        x = x + self.ffn(self.ln2(x))
+        return x
+
+
+class _CausalAttention(nn.Module):
+    """Causal multi-head attention."""
+
+    def __init__(self, d_model, n_heads, dropout, max_seq_len):
+        super().__init__()
+        assert d_model % n_heads == 0
+        self.n_heads = n_heads
+        self.head_dim = d_model // n_heads
+        self.scale = math.sqrt(self.head_dim)
+
+        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
+        self.out_proj = nn.Linear(d_model, d_model, bias=False)
+        self.dropout = nn.Dropout(dropout)
+
+        self.max_seq_len = max_seq_len
+
+    def forward(self, x, mask=None):
+        B, T, C = x.shape
+        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
+        q, k, v = qkv.unbind(dim=2)
+        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+        attn = (q @ k.transpose(-2, -1)) / self.scale
+        causal = torch.triu(torch.ones(T, T, device=x.device) * float("-inf"), diagonal=1)
+        attn = attn + causal
+
+        if mask is not None:
+            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2).bool(), float("-inf"))  # masked_fill avoids the NaNs that mask * float("-inf") produces at unmasked positions
+
+        attn = F.softmax(attn, dim=-1)
+        attn = self.dropout(attn)
+
+        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
+        return self.out_proj(out)
+
+
+class _PositionalEncoding(nn.Module):
+    def __init__(self, d_model, max_len, dropout):
+        super().__init__()
+        self.dropout = nn.Dropout(dropout)
+        pe = torch.zeros(max_len, d_model)
+        pos = torch.arange(max_len).unsqueeze(1).float()
+        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(pos * div)
+        pe[:, 1::2] = torch.cos(pos * div)
+        self.register_buffer("pe", pe.unsqueeze(0))
+
+    def forward(self, x):
+        return self.dropout(x + self.pe[:, :x.size(1)])
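
A minimal sketch of the pruning baseline, assuming the same `src` import layout. One design note: magnitude pruning here zeroes weights in place but keeps tensor shapes, so `total_params` does not shrink; the saving shows up in the nonzero-weight count.

    # Hypothetical usage sketch, not part of the commit.
    from src.baselines import StandardTransformer, PrunedTransformer

    base = StandardTransformer(vocab_size=5000, d_model=64, n_layers=2)
    pruned = PrunedTransformer(base, prune_ratio=0.5)  # zeroes ~50% of FFN weights

    nonzero = sum((p != 0).sum().item() for p in pruned.parameters())
    print(pruned.total_params, nonzero)  # same shapes, fewer nonzero weights
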
src/blocks.py
ADDED
@@ -0,0 +1,150 @@
+"""
+Hybrid Transformer Block: Tensor + Quantum + Adaptive.
+
+v3 modular design — block can be configured as:
+- TT-FFN only (pure tensor)
+- Quantum only
+- Hybrid (both)
+- Standard MLP-FFN (baseline)
+
+Each block contains:
+- Multi-Head Attention (with entropy monitoring)
+- RankScheduler (entropy → TT rank)
+- QuantumRouter (selective quantum activation)
+- TTFeedForward (tensor-decomposed FFN)
+"""
+
+import torch
+import torch.nn as nn
+from .attention import MultiHeadAttention, HybridQAttention
+from .tensor_layers import TTFeedForward
+from .scheduler import RankScheduler, BudgetAwareScheduler
+from .router import QuantumRouter
+
+
+class HybridBlock(nn.Module):
+    """
+    A single Q-TensorFormer block.
+
+    Flow:
+        x → LayerNorm → Attention + Entropy
+          → RankScheduler: adjust TT ranks
+          → LayerNorm → QuantumRouter (gate)
+          → TTFeedForward (tensor-decomposed)
+          → residual connection
+    """
+
+    def __init__(self, d_model: int = 128, n_heads: int = 4,
+                 ff_multiplier: int = 4, tt_rank: int = 8,
+                 tt_min_rank: int = 2, use_quantum: bool = True,
+                 n_qubits: int = 4, n_quantum_layers: int = 2,
+                 quantum_sparsity: float = 0.7, rank_alpha: float = 2.0,
+                 rank_smoothing: float = 0.9, dropout: float = 0.1,
+                 max_seq_len: int = 128):
+        super().__init__()
+
+        self.d_model = d_model
+        self.use_quantum = use_quantum
+        self.is_hybrid = use_quantum  # Flag for model-level detection
+
+        # Attention
+        self.attention = MultiHeadAttention(
+            d_model, n_heads, dropout, max_seq_len,
+            use_quantum_kernel=False
+        )
+
+        # Layer norms
+        self.ln1 = nn.LayerNorm(d_model)
+        self.ln2 = nn.LayerNorm(d_model)
+
+        # Rank scheduler
+        self.rank_scheduler = RankScheduler(
+            r_min=tt_min_rank, r_max=tt_rank,
+            alpha=rank_alpha, smoothing=rank_smoothing
+        )
+
+        # Quantum router
+        if use_quantum:
+            self.quantum_router = QuantumRouter(
+                d_model=d_model,
+                q_input_dim=n_qubits,
+                target_sparsity=quantum_sparsity,
+            )
+        else:
+            self.quantum_router = None
+
+        # Tensor-Train FFN
+        self.tt_ffn = TTFeedForward(
+            hidden_dim=d_model,
+            ff_multiplier=ff_multiplier,
+            rank=tt_rank,
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
+        """
+        Args:
+            x: (batch, seq_len, d_model)
+            mask: (batch, seq_len) optional padding mask
+
+        Returns:
+            output: (batch, seq_len, d_model)
+            stats: dict with entropy, rank, quantum_usage
+        """
+        stats = {}
+
+        # Attention sublayer
+        attn_out, entropy = self.attention(
+            self.ln1(x), mask=mask, return_entropy=True
+        )
+        x = x + self.dropout(attn_out)
+
+        # Schedule rank from attention entropy
+        mean_entropy = entropy.mean() if entropy.dim() > 0 else entropy
+        new_rank = self.rank_scheduler(mean_entropy, seq_len=x.shape[1])
+        self.tt_ffn.set_rank(new_rank)
+        stats["entropy"] = mean_entropy.item()
+        stats["rank"] = new_rank
+
+        # FFN sublayer
+        normed = self.ln2(x)
+
+        # Quantum routing
+        quantum_out = torch.zeros_like(normed)
+        if self.quantum_router is not None:
+            quantum_out, q_mask = self.quantum_router(normed)
+            stats["quantum_usage"] = self.quantum_router.usage_percent
+            stats["quantum_sparsity"] = self.quantum_router.sparsity
+
+        # TT feed-forward
+        ffn_out = self.tt_ffn(normed)
+
+        # Combine: quantum signal modifies the FFN input
+        combined = normed + self.dropout(ffn_out + quantum_out)
+        x = x + combined
+
+        return x, stats
+
+    def set_rank(self, rank: int):
+        """Manually override rank."""
+        self.tt_ffn.set_rank(rank)
+
+    def reset_scheduler(self):
+        self.rank_scheduler.reset()
+        if self.quantum_router is not None:
+            self.quantum_router.reset_stats()
+
+    @property
+    def total_params(self) -> int:
+        return sum(p.numel() for p in self.parameters())
+
+    def flops_estimate(self, batch_size: int = 1, seq_len: int = 128) -> dict:
+        """Estimate FLOPs for this block."""
+        attn_flops = self.attention.flops(batch_size, seq_len)["total"]
+        ffn_flops = self.tt_ffn.flops(batch_size)
+        return {
+            "attention": attn_flops,
+            "tt_ffn": ffn_flops,
+            "total": attn_flops + ffn_flops,
+        }
src/budget.py
ADDED
@@ -0,0 +1,167 @@
+"""
+Budget-constrained optimization.
+
+Enforces deployment constraints during training and inference:
+- Maximum parameter count
+- Maximum inference latency
+- Maximum energy per query
+
+The model auto-adjusts tensor ranks to meet these constraints.
+"""
+
+import torch
+import time
+import math
+from typing import Optional, Dict
+from .config import BudgetConfig
+
+
+class BudgetTracker:
+    """
+    Tracks whether a model meets deployment budget constraints.
+
+    Checks at each validation step:
+    - Parameter count ≤ max_params
+    - Estimated latency ≤ max_latency_ms
+    - Estimated energy ≤ max_energy_per_query
+    """
+
+    def __init__(self, budget: BudgetConfig):
+        self.budget = budget
+
+    def exceeds_budget(self, metrics: Dict, model_config) -> bool:
+        """
+        Check if current metrics exceed any budget constraint.
+
+        Returns True if any constraint is violated.
+        """
+        if self.budget.max_params is not None:
+            if metrics.get("total_params", 0) > self.budget.max_params:
+                print(f"[BUDGET] Params exceeded: {metrics['total_params']} > {self.budget.max_params}")
+                return True
+
+        if self.budget.max_latency_ms is not None:
+            if metrics.get("latency_ms", 0) > self.budget.max_latency_ms:
+                print(f"[BUDGET] Latency exceeded: {metrics['latency_ms']:.2f} > {self.budget.max_latency_ms}")
+                return True
+
+        if self.budget.max_energy_per_query is not None:
+            if metrics.get("energy_uj", 0) > self.budget.max_energy_per_query:
+                print(f"[BUDGET] Energy exceeded: {metrics['energy_uj']:.2f} > {self.budget.max_energy_per_query}")
+                return True
+
+        return False
+
+    def estimate_latency(self, model, seq_len: int = 128,
+                         n_warmup: int = 3, n_measure: int = 10) -> float:
+        """
+        Estimate inference latency for a sequence of length seq_len.
+
+        Returns mean latency in milliseconds.
+        """
+        device = next(model.parameters()).device
+        model.eval()
+
+        dummy = torch.randint(0, 1000, (1, seq_len)).to(device)
+
+        # Warmup
+        with torch.no_grad():
+            for _ in range(n_warmup):
+                _ = model(dummy)
+
+        latencies = []
+        with torch.no_grad():
+            for _ in range(n_measure):
+                t0 = time.time()
+                _ = model(dummy)
+                if device.type == "cuda":
+                    torch.cuda.synchronize()
+                latencies.append((time.time() - t0) * 1000)
+
+        return sum(latencies) / len(latencies)
+
+    def estimate_parameter_budget(self, model, tt_rank: int) -> int:
+        """Estimate total parameters at a given TT rank."""
+        # Approximate: TT params scale ~ O(rank^2)
+        current = sum(p.numel() for p in model.parameters())
+        if hasattr(model, "tt_params"):
+            current_rank = getattr(model, "config", None)
+            if current_rank:
+                current_rank = current_rank.tt_rank
+            else:
+                return current
+            # Rough scaling
+            tt_now = model.tt_params
+            tt_new = tt_now * (tt_rank / max(current_rank, 1)) ** 2
+            return int(current - tt_now + tt_new)
+        return current
+
+
+class EnergyEstimator:
+    """
+    Energy consumption estimator using FLOPs as proxy.
+
+    Approximate conversions (hardware-dependent):
+    - CPU inference: ~5 pJ/FLOP
+    - GPU inference (A100): ~0.5 pJ/FLOP
+    - Edge inference: ~10 pJ/FLOP
+    """
+
+    # Energy per FLOP in microjoules (μJ)
+    ENERGY_PER_FLOP = {
+        "cpu": 5e-6,         # 5 pJ → 5e-6 μJ
+        "gpu_a100": 0.5e-6,  # 0.5 pJ → 0.5e-6 μJ
+        "edge": 10e-6,       # 10 pJ → 10e-6 μJ
+    }
+
+    def __init__(self, hardware: str = "cpu"):
+        self.hardware = hardware
+        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)
+
+    def estimate(self, model, batch_size: int = 1,
+                 seq_len: int = 128) -> float:
+        """
+        Estimate energy consumption in μJ for one forward pass.
+
+        Returns:
+            Energy in microjoules.
+        """
+        flops = self._estimate_flops(model, batch_size, seq_len)
+        return flops * self.energy_per_flop
+
+    @staticmethod
+    def _estimate_flops(model, batch_size: int, seq_len: int) -> int:
+        """Estimate FLOPs for one forward pass."""
+        total_params = sum(p.numel() for p in model.parameters())
+        # Rough: 2 × params × batch × seq_len (multiply-add for each token)
+        return int(2 * total_params * batch_size * seq_len)
+
+    def set_hardware(self, hardware: str):
+        """Change hardware target."""
+        self.hardware = hardware
+        self.energy_per_flop = self.ENERGY_PER_FLOP.get(hardware, 5e-6)
+
+
+def find_feasible_rank(model, budget: BudgetConfig,
+                       param_factors: Dict[int, int] = None) -> int:
+    """
+    Find the maximum TT rank that meets budget constraints.
+
+    Args:
+        model: Model to analyze.
+        budget: Budget constraints.
+        param_factors: Dict[rank → estimated_params].
+
+    Returns:
+        Maximum feasible rank.
+    """
+    current_rank = 8  # default
+    if hasattr(model, "config"):
+        current_rank = model.config.tt_rank
+
+    for rank in range(current_rank, 0, -1):
+        est_params = param_factors.get(rank, float("inf")) if param_factors else None
+        if budget.max_params and est_params and est_params > budget.max_params:
+            continue
+        return rank
+    return 1
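
A worked sanity check of the FLOPs-to-energy conversion, under this module's own approximation (FLOPs ≈ 2 · params · batch · seq_len): a 1M-parameter model at seq_len 128 gives 2 · 1e6 · 128 = 2.56e8 FLOPs, or about 1.28e3 μJ per query at the CPU figure of 5e-6 μJ/FLOP. A minimal sketch:

    # Hypothetical usage sketch, not part of the commit.
    from src.budget import EnergyEstimator
    from src.baselines import StandardTransformer

    est = EnergyEstimator(hardware="cpu")
    model = StandardTransformer(vocab_size=1000, d_model=64, n_layers=1)
    print(f"{est.estimate(model, batch_size=1, seq_len=128):.1f} μJ")
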
src/config.py
ADDED
@@ -0,0 +1,180 @@
+"""
+Configuration system for Q-TensorFormer v3.
+
+Supports:
+- YAML config files for experiment tracking
+- Budget constraints (max params, max latency, max energy)
+- Automatic hardware sizing
+- Config validation
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional, Tuple, List
+import math
+
+
+@dataclass
+class ModelConfig:
+    """Core model architecture hyperparameters."""
+    d_model: int = 128
+    n_heads: int = 4
+    n_layers: int = 2
+    ff_multiplier: int = 4
+    max_seq_len: int = 128
+    vocab_size: int = 10000
+    dropout: float = 0.1
+
+    # Tensor network
+    tt_rank: int = 8
+    tt_min_rank: int = 2
+    use_tensor_ffn: bool = True
+
+    # Quantum
+    n_qubits: int = 4
+    n_quantum_layers: int = 2
+    quantum_sparsity: float = 0.3
+    use_quantum: bool = True
+
+    # Rank scheduler
+    rank_alpha: float = 2.0
+    rank_smoothing: float = 0.9
+
+    def validate(self):
+        assert self.d_model % self.n_heads == 0, f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
+        assert self.tt_rank >= 1, "tt_rank must be >= 1"
+        assert self.tt_min_rank >= 1, "tt_min_rank must be >= 1"
+        assert self.tt_min_rank <= self.tt_rank, "tt_min_rank must be <= tt_rank"
+        assert self.n_qubits <= 8, "n_qubits should be <= 8 for NISQ compatibility"
+        assert 0 <= self.quantum_sparsity <= 1, "quantum_sparsity must be in [0, 1]"
+        return True
+
+
+@dataclass
+class TrainingConfig:
+    """Training hyperparameters."""
+    learning_rate: float = 3e-4
+    weight_decay: float = 0.01
+    warmup_steps: int = 100
+    max_epochs: int = 10
+    batch_size: int = 16
+    gradient_accumulation_steps: int = 1
+    max_grad_norm: float = 1.0
+    seed: int = 42
+
+    # Scheduler
+    lr_scheduler: str = "cosine"  # cosine, linear, constant
+    lr_min_factor: float = 0.1
+
+    def validate(self):
+        assert self.learning_rate > 0
+        assert self.batch_size >= 1
+        assert self.seed >= 0
+        return True
+
+
+@dataclass
+class BudgetConfig:
+    """Deployment budget constraints.
+
+    The model auto-adjusts tensor ranks and quantum usage to meet these.
+    """
+    max_params: Optional[int] = None              # Maximum trainable parameters
+    max_latency_ms: Optional[float] = None        # Max inference latency (ms)
+    max_energy_per_query: Optional[float] = None  # Max energy per query (μJ)
+    target_compression_ratio: Optional[float] = None  # Target param reduction
+
+    def validate(self):
+        if self.max_params is not None:
+            assert self.max_params > 0
+        if self.max_latency_ms is not None:
+            assert self.max_latency_ms > 0
+        return True
+
+
+@dataclass
+class ExperimentConfig:
+    """Master configuration combining all sub-configs."""
+    model: ModelConfig = field(default_factory=ModelConfig)
+    training: TrainingConfig = field(default_factory=TrainingConfig)
+    budget: BudgetConfig = field(default_factory=BudgetConfig)
+    experiment_name: str = "default"
+    output_dir: str = "./outputs"
+    wandb_project: Optional[str] = None
+
+    @classmethod
+    def from_yaml(cls, path: str) -> "ExperimentConfig":
+        """Load from YAML file."""
+        import yaml
+        with open(path) as f:
+            data = yaml.safe_load(f)
+        model = ModelConfig(**data.get("model", {}))
+        training = TrainingConfig(**data.get("training", {}))
+        budget = BudgetConfig(**data.get("budget", {}))
+        return cls(
+            model=model, training=training, budget=budget,
+            experiment_name=data.get("experiment_name", "default"),
+            output_dir=data.get("output_dir", "./outputs"),
+            wandb_project=data.get("wandb_project"),
+        )
+
+    def to_yaml(self, path: str):
+        """Save to YAML file."""
+        import yaml
+        data = {
+            "experiment_name": self.experiment_name,
+            "output_dir": self.output_dir,
+            "wandb_project": self.wandb_project,
+            "model": {k: v for k, v in self.model.__dict__.items()},
+            "training": {k: v for k, v in self.training.__dict__.items()},
+            "budget": {k: v for k, v in self.budget.__dict__.items()},
+        }
+        with open(path, "w") as f:
+            yaml.dump(data, f, default_flow_style=False)
+
+    def validate(self):
+        self.model.validate()
+        self.training.validate()
+        self.budget.validate()
+        return True
+
+
+# Preset configurations
+def tiny_config() -> ExperimentConfig:
+    return ExperimentConfig(
+        model=ModelConfig(d_model=64, n_layers=2, n_heads=4, tt_rank=4, vocab_size=5000),
+        training=TrainingConfig(max_epochs=5, batch_size=16),
+        experiment_name="tiny",
+    )
+
+
+def small_config() -> ExperimentConfig:
+    return ExperimentConfig(
+        model=ModelConfig(d_model=128, n_layers=2, n_heads=4, tt_rank=8, vocab_size=10000),
+        training=TrainingConfig(max_epochs=8, batch_size=16),
+        experiment_name="small",
+    )
+
+
+def medium_config() -> ExperimentConfig:
+    return ExperimentConfig(
+        model=ModelConfig(d_model=256, n_layers=4, n_heads=8, tt_rank=12, vocab_size=20000),
+        training=TrainingConfig(max_epochs=10, batch_size=8),
+        experiment_name="medium",
+    )
+
+
+def production_config() -> ExperimentConfig:
+    return ExperimentConfig(
+        model=ModelConfig(d_model=512, n_layers=6, n_heads=8, tt_rank=16, vocab_size=30000),
+        training=TrainingConfig(max_epochs=15, batch_size=4, gradient_accumulation_steps=4),
+        budget=BudgetConfig(max_latency_ms=50.0, target_compression_ratio=2.0),
+        experiment_name="production",
+    )
+
+
+PRESETS = {
+    "tiny": tiny_config,
+    "small": small_config,
+    "medium": medium_config,
+    "production": production_config,
+}
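
A minimal round-trip sketch for the config system (requires PyYAML, which `from_yaml`/`to_yaml` import lazily):

    # Hypothetical usage sketch, not part of the commit.
    from src.config import PRESETS, ExperimentConfig

    cfg = PRESETS["small"]()  # d_model=128, tt_rank=8, vocab_size=10000
    cfg.validate()
    cfg.to_yaml("small.yaml")
    cfg2 = ExperimentConfig.from_yaml("small.yaml")
    assert cfg2.model.d_model == cfg.model.d_model
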
src/data.py
ADDED
@@ -0,0 +1,180 @@
+"""
+Data loading and preprocessing.
+
+Supported datasets:
+- WikiText-2 (char-level and word-level)
+- WikiText-103
+- Custom text files
+- Synthetic random data (debugging)
+
+Tokenization: character-level by default. Simple, deterministic, no external deps.
+"""
+
+import torch
+from torch.utils.data import Dataset, DataLoader
+from typing import Optional, Tuple, Dict
+from collections import Counter
+
+
+class CharTokenizer:
+    """Character-level tokenizer. Vocabulary built from data."""
+
+    def __init__(self, min_freq: int = 1):
+        self.min_freq = min_freq
+        self.char_to_idx: Dict[str, int] = {}
+        self.idx_to_char: Dict[int, str] = {}
+        self.vocab_size = 0
+        self.special_tokens = {
+            "<pad>": 0,
+            "<bos>": 1,
+            "<eos>": 2,
+            "<unk>": 3,
+        }
+
+    def fit(self, texts: list[str]):
+        """Build vocabulary from texts."""
+        char_counts = Counter()
+        for text in texts:
+            char_counts.update(text)
+
+        # Special tokens first
+        self.char_to_idx = dict(self.special_tokens)
+        # Freq-filtered chars
+        idx = len(self.special_tokens)
+        for char, count in char_counts.most_common():
+            if count >= self.min_freq:
+                self.char_to_idx[char] = idx
+                idx += 1
+
+        self.idx_to_char = {v: k for k, v in self.char_to_idx.items()}
+        self.vocab_size = len(self.char_to_idx)
+
+    def encode(self, text: str, add_bos: bool = True,
+               add_eos: bool = True, max_len: int = None) -> list[int]:
+        """Convert text to token indices."""
+        tokens = []
+        if add_bos:
+            tokens.append(self.special_tokens["<bos>"])
+        for ch in text:
+            tokens.append(self.char_to_idx.get(ch, self.special_tokens["<unk>"]))
+        if add_eos:
+            tokens.append(self.special_tokens["<eos>"])
+        if max_len is not None:
+            if len(tokens) > max_len:
+                tokens = tokens[:max_len]
+            else:
+                tokens.extend([self.special_tokens["<pad>"]] * (max_len - len(tokens)))
+        return tokens
+
+    def decode(self, indices: list[int], skip_special: bool = True) -> str:
+        """Convert indices back to text."""
+        chars = []
+        for idx in indices:
+            ch = self.idx_to_char.get(idx, "?")
+            if skip_special and idx in self.special_tokens.values():
+                continue
+            chars.append(ch)
+        return "".join(chars)
+
+    def save(self, path: str):
+        torch.save({
+            "char_to_idx": self.char_to_idx,
+            "idx_to_char": self.idx_to_char,
+            "vocab_size": self.vocab_size,
+            "special_tokens": self.special_tokens,
+        }, path)
+
+    @classmethod
+    def load(cls, path: str) -> "CharTokenizer":
+        data = torch.load(path)
+        tok = cls()
+        tok.char_to_idx = data["char_to_idx"]
+        tok.idx_to_char = data["idx_to_char"]
+        tok.vocab_size = data["vocab_size"]
+        tok.special_tokens = data["special_tokens"]
+        return tok
+
+
+class TextDataset(Dataset):
+    """
+    Causal language modeling dataset.
+
+    Splits text into overlapping sequences of length seq_len.
+    Target = input shifted by 1 (next-token prediction).
+    """
+
+    def __init__(self, texts: list[str], tokenizer: CharTokenizer,
+                 seq_len: int = 128, stride: int = None):
+        self.seq_len = seq_len
+        self.stride = stride or seq_len // 2
+
+        # Tokenize all texts
+        all_tokens = []
+        for text in texts:
+            all_tokens.extend(tokenizer.encode(text, add_bos=False, add_eos=True))
+        self.tokens = torch.tensor(all_tokens, dtype=torch.long)
+
+        # Compute valid starting positions
+        self.n_samples = max(0, (len(self.tokens) - seq_len - 1) // self.stride + 1)
+
+    def __len__(self):
+        return self.n_samples
+
+    def __getitem__(self, idx):
+        start = idx * self.stride
+        end = start + self.seq_len
+        x = self.tokens[start:end]
+        y = self.tokens[start + 1:end + 1]
+        assert len(x) == len(y) == self.seq_len, f"len={len(x)} at idx={idx}"
+        return x, y
+
+
+def load_wikitext2(tokenizer: CharTokenizer = None,
+                   seq_len: int = 128,
+                   batch_size: int = 16) -> Tuple[DataLoader, DataLoader, DataLoader, CharTokenizer]:
+    """
+    Load WikiText-2 with char-level tokenization.
+
+    Returns:
+        train_loader, val_loader, test_loader, tokenizer
+    """
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        raise ImportError("pip install datasets")
+
+    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
+
+    # Filter empty lines
+    train_texts = [t for t in ds["train"]["text"] if t.strip()]
+    val_texts = [t for t in ds["validation"]["text"] if t.strip()]
+    test_texts = [t for t in ds["test"]["text"] if t.strip()]
+
+    if tokenizer is None:
+        tokenizer = CharTokenizer()
+        tokenizer.fit(train_texts)
+
+    train_ds = TextDataset(train_texts, tokenizer, seq_len)
+    val_ds = TextDataset(val_texts, tokenizer, seq_len)
+    test_ds = TextDataset(test_texts, tokenizer, seq_len)
+
+    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
+                              num_workers=0, drop_last=True)
+    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
+    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)
+
+    return train_loader, val_loader, test_loader, tokenizer
+
+
+def load_synthetic_data(vocab_size: int = 5000, seq_len: int = 128,
+                        n_samples: int = 2000, batch_size: int = 16):
+    """Synthetic random data for debugging."""
+    class _SynthDataset(Dataset):
+        def __init__(self, n, vocab, slen):
+            self.data = torch.randint(1, vocab, (n, slen + 1))
+        def __len__(self):
+            return len(self.data)
+        def __getitem__(self, i):
+            return self.data[i, :-1], self.data[i, 1:]
+    ds = _SynthDataset(n_samples, vocab_size, seq_len)
+    return DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=0)
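
A minimal sketch of the tokenizer and dataset, assuming the same `src` import layout:

    # Hypothetical usage sketch, not part of the commit.
    from src.data import CharTokenizer, TextDataset

    tok = CharTokenizer()
    tok.fit(["hello world"])
    ids = tok.encode("hello", max_len=10)  # <bos> h e l l o <eos>, padded to 10
    print(tok.decode(ids))                 # "hello" (special tokens skipped)

    ds = TextDataset(["hello world"] * 50, tok, seq_len=16)
    x, y = ds[0]  # y is x shifted by one position (next-token targets)
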
src/metrics.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Comprehensive metrics for evaluation.

v3 features:
- Perplexity (primary LM metric)
- Parameter counts (total, compressed, ratio)
- Latency benchmarks (warm-up + measured)
- FLOPs estimation (proxy for energy)
- Quantum call statistics
- Rank trajectory analysis
- Pareto frontier computation (PPL vs params)
"""

import torch
import time
import math
from typing import Dict, List, Optional
from .config import ExperimentConfig


@torch.no_grad()  # evaluation only: avoids building autograd graphs
def evaluate_model(model, test_loader, device: str = "cpu",
                   max_batches: int = None) -> Dict:
    """
    Comprehensive model evaluation.

    Metrics:
    - test_ppl: Perplexity on test set
    - total_params, trainable_params
    - latency_p50, latency_p95 (ms per sample)
    - peak_memory_mb
    - flops_estimate

    Args:
        model: nn.Module to evaluate.
        test_loader: DataLoader with (input, target) batches.
        device: Device string.
        max_batches: Limit eval to N batches (None = all).

    Returns:
        Dict with all metrics.
    """
    model.eval()
    model.to(device)

    total_loss = 0.0
    total_tokens = 0
    latencies = []

    for i, (inputs, targets) in enumerate(test_loader):
        if max_batches and i >= max_batches:
            break
        inputs, targets = inputs.to(device), targets.to(device)

        # Warm-up GPU
        if i == 0:
            _ = model(inputs)
            if device != "cpu":
                torch.cuda.synchronize()

        # Timed forward
        t0 = time.time()
        logits = model(inputs)
        if device != "cpu":
            torch.cuda.synchronize()
        elapsed = (time.time() - t0) * 1000  # ms
        latencies.append(elapsed / inputs.size(0))

        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=0,
            reduction="sum",
        )
        total_loss += loss.item()
        total_tokens += inputs.numel()

    avg_loss = total_loss / max(total_tokens, 1)
    ppl = math.exp(min(avg_loss, 20.0))

    # Sort latencies for percentile reporting
    latencies.sort()
    if not latencies:
        latencies = [0.0]  # guard: an empty loader would otherwise divide by zero below
    n = len(latencies)

    result = {
        "test_ppl": ppl,
        "test_loss": avg_loss,
        "total_params": sum(p.numel() for p in model.parameters()),
        "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad),
        "latency_ms_mean": sum(latencies) / n,
        "latency_ms_p50": latencies[n // 2],
        "latency_ms_p95": latencies[min(int(n * 0.95), n - 1)],
        "n_samples_evaluated": n,
    }

    # Model-specific stats
    if hasattr(model, "stats"):
        result["model_stats"] = model.stats

    if hasattr(model, "compression_ratio"):
        result["compression_ratio"] = model.compression_ratio

    return result


def compare_models(models: Dict[str, object], test_loader,
                   device: str = "cpu") -> Dict[str, Dict]:
    """
    Compare multiple models on the same test set.

    Args:
        models: Dict[name → model]
        test_loader: DataLoader.

    Returns:
        Dict[name → metrics]
    """
    results = {}
    for name, model in models.items():
        print(f"Evaluating {name}...")
        results[name] = evaluate_model(model, test_loader, device)
    return results


def compute_pareto_frontier(results: Dict[str, Dict],
                            x_key: str = "total_params",
                            y_key: str = "test_ppl",
                            minimize_y: bool = True) -> List[str]:
    """
    Find Pareto-optimal models from comparison results.

    A model is Pareto-optimal if no other model has:
    - Fewer parameters AND better perplexity

    Args:
        results: Dict[name → metrics]
        x_key: Metric for x-axis (e.g., total_params)
        y_key: Metric for y-axis (e.g., test_ppl)
        minimize_y: True if lower y is better.

    Returns:
        List of Pareto-optimal model names.
    """
    pareto = []
    names = list(results.keys())

    for i, name_i in enumerate(names):
        xi = results[name_i][x_key]
        yi = results[name_i][y_key]
        dominated = False

        for j, name_j in enumerate(names):
            if i == j:
                continue
            xj = results[name_j][x_key]
            yj = results[name_j][y_key]

            if minimize_y:
                # j dominates i: j has fewer params AND better PPL
                if xj <= xi and yj <= yi and (xj < xi or yj < yi):
                    dominated = True
                    break
            else:
                if xj <= xi and yj >= yi and (xj < xi or yj > yi):
                    dominated = True
                    break

        if not dominated:
            pareto.append(name_i)

    return pareto


def compute_efficiency_score(result: Dict) -> float:
    """
    Combined efficiency score (higher is better).

    Efficiency = 1 / (PPL × √params × latency_ms)

    Normalized so that better models get higher scores.
    """
    ppl = max(result["test_ppl"], 1.0)
    params = max(result["total_params"], 1)
    latency = max(result.get("latency_ms_mean", 1.0), 0.1)

    # 1 / (PPL * sqrt(params) * latency): simpler = better
    score = 1.0 / (ppl * math.sqrt(params / 1e6) * latency)
    return score * 1e6  # Scale for readability


def rank_trajectory_analysis(metrics_history: List[Dict]) -> Dict:
    """
    Analyze rank adaptation over training.

    Args:
        metrics_history: List of per-epoch metrics from Trainer.

    Returns:
        Dict with rank statistics.
    """
    if not metrics_history or "model_stats" not in metrics_history[-1]:
        return {}

    ranks_over_time = []
    for epoch_data in metrics_history:
        model_stats = epoch_data.get("model_stats", {})
        rank_history = model_stats.get("rank_history", {})
        if rank_history:
            ranks_over_time.append(rank_history)

    if not ranks_over_time:
        return {}

    final_ranks = ranks_over_time[-1]
    return {
        "final_ranks": final_ranks,
        "rank_variance": sum(
            (r - sum(final_ranks.values()) / len(final_ranks)) ** 2
            for r in final_ranks.values()
        ) / len(final_ranks),
        "n_epochs_converged": len(ranks_over_time),
    }


def print_comparison_table(results: Dict[str, Dict]):
    """Pretty-print comparison table."""
    header = f"{'Model':<20} {'PPL':>8} {'Params':>10} {'Lat(ms)':>10} {'Score':>10}"
    print("=" * len(header))
    print(header)
    print("-" * len(header))

    for name, r in sorted(results.items(), key=lambda x: x[1]["test_ppl"]):
        score = compute_efficiency_score(r)
        params_k = r["total_params"] / 1000
        print(f"{name:<20} {r['test_ppl']:8.2f} {params_k:8.1f}K "
              f"{r.get('latency_ms_mean', 0):8.2f} {score:10.1f}")

    print("=" * len(header))

    pareto = compute_pareto_frontier(results)
    print(f"\nPareto-optimal models: {pareto}")

src/models.py
ADDED
@@ -0,0 +1,296 @@
"""
Q-TensorFormer v3: Complete Model Architectures.

Model variants:
- QTensorFormer: Full hybrid model (TT-FFN + quantum + adaptive rank)
- TensorBaseline: TT-FFN only (no quantum, fixed rank)
- DenseBaseline: Standard transformer (no TT, no quantum)
- DistilledVariants: Knowledge-distilled compact models
"""

import torch
import torch.nn as nn
import math
from typing import Optional, Dict, List

from .blocks import HybridBlock
from .config import ModelConfig


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 128, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x + self.pe[:, :x.size(1), :])


class QTensorFormer(nn.Module):
    """
    Quantum-Enhanced Tensor Network Transformer.

    Full hybrid model: replaces FFN with TT decomposition and adds
    quantum feature routing with adaptive rank scheduling.

    Parameters
    ----------
    config : ModelConfig
        Model configuration.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Embeddings
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            HybridBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
                ff_multiplier=config.ff_multiplier,
                tt_rank=config.tt_rank,
                tt_min_rank=config.tt_min_rank,
                use_quantum=config.use_quantum,
                n_qubits=config.n_qubits,
                n_quantum_layers=config.n_quantum_layers,
                quantum_sparsity=config.quantum_sparsity,
                rank_alpha=config.rank_alpha,
                rank_smoothing=config.rank_smoothing,
                dropout=config.dropout,
                max_seq_len=config.max_seq_len,
            )
            for _ in range(config.n_layers)
        ])

        # Output
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: embedding matrix = LM head
        self.lm_head.weight = self.embedding.weight

        self._post_init()

    def _post_init(self):
        """Initialize weights."""
        for name, param in self.named_parameters():
            if "weight" in name and param.dim() >= 2:
                nn.init.xavier_uniform_(param)
            elif "bias" in name:
                nn.init.zeros_(param)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                return_stats: bool = False):
        """
        Args:
            input_ids: (batch, seq_len) token indices
            attention_mask: (batch, seq_len) optional padding mask
            return_stats: return per-block statistics

        Returns:
            logits: (batch, seq_len, vocab_size)
            stats: list of per-block stats dicts (if return_stats=True)
        """
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        all_stats = []
        for block in self.blocks:
            x, stats = block(x, mask=attention_mask)
            all_stats.append(stats)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_stats:
            return logits, all_stats
        return logits

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 20,
                 temperature: float = 1.0, top_k: int = 50) -> torch.Tensor:
        """Simple autoregressive generation."""
        self.eval()
        for _ in range(max_new_tokens):
            if input_ids.size(1) > self.config.max_seq_len:
                input_ids = input_ids[:, -self.config.max_seq_len:]
            logits = self(input_ids)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
        return input_ids

    def reset_schedulers(self):
        """Reset all rank schedulers and quantum routers."""
        for block in self.blocks:
            block.reset_scheduler()

    @property
    def stats(self) -> Dict:
        """Runtime statistics across all blocks."""
        stats = {
            "total_params": self.total_params,
            "tt_params": self.tt_params,
            "compression_ratio": self.compression_ratio,
            "rank_history": {},
            "quantum_usage": {},
        }
        for i, block in enumerate(self.blocks):
            stats["rank_history"][i] = block.rank_scheduler.current_rank
            if block.quantum_router is not None:
                stats["quantum_usage"][i] = block.quantum_router.usage_percent
        return stats

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    @property
    def trainable_params(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def tt_params(self) -> int:
        """Count only TT-decomposed parameters."""
        count = 0
        for block in self.blocks:
            for core in block.tt_ffn.up_proj.cores:
                count += core.numel()
            for core in block.tt_ffn.down_proj.cores:
                count += core.numel()
        return count

    @property
    def compression_ratio(self) -> float:
        """Estimated compression ratio vs. dense equivalent."""
        dense_per_block = 2 * self.config.d_model * self.config.d_model * self.config.ff_multiplier
        base = self.total_params - self.tt_params
        tt = self.tt_params
        return (base + dense_per_block * self.config.n_layers) / max(base + tt, 1)

    def flops_estimate(self, batch_size: int = 1, seq_len: int = None) -> Dict:
        """Estimate total FLOPs."""
        T = seq_len or self.config.max_seq_len
        total = 0
        breakdown = {}
        for i, block in enumerate(self.blocks):
            b = block.flops_estimate(batch_size, T)
            total += b["total"]
            breakdown[f"block_{i}"] = b
        return {"total": total, "breakdown": breakdown}


class DenseBaseline(nn.Module):
    """
    Standard transformer baseline — no TT, no quantum.

    Same hyperparameters as QTensorFormer for fair comparison.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len, config.dropout
        )

        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln1": nn.LayerNorm(config.d_model),
                "attn": nn.MultiheadAttention(
                    config.d_model, config.n_heads,
                    dropout=config.dropout, batch_first=True
                ),
                "ln2": nn.LayerNorm(config.d_model),
                "ffn": nn.Sequential(
                    nn.Linear(config.d_model, config.d_model * config.ff_multiplier),
                    nn.GELU(),
                    nn.Linear(config.d_model * config.ff_multiplier, config.d_model),
                ),
                "dropout": nn.Dropout(config.dropout),
            })
            for _ in range(config.n_layers)
        ])

        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, attention_mask=None, return_stats=False):
        x = self.embedding(input_ids)
        x = self.pos_encoding(x)

        for block in self.blocks:
            h = block["ln1"](x)  # pre-norm computed once, reused for Q, K, V
            attn_out, _ = block["attn"](
                h, h, h,
                key_padding_mask=attention_mask, need_weights=False
            )
            x = x + block["dropout"](attn_out)

            ffn_out = block["ffn"](block["ln2"](x))
            x = x + block["dropout"](ffn_out)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_stats:
            return logits, []
        return logits

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())


def create_model(config: ModelConfig, model_type: str = "qtensor") -> nn.Module:
    """
    Factory for model creation.

    Args:
        config: ModelConfig instance.
        model_type: 'qtensor', 'tensor_only' (no quantum), 'dense' (baseline),
            'distilled' (knowledge-distilled compact).

    Returns:
        nn.Module instance.
    """
    if model_type == "qtensor":
        config.use_quantum = True
        return QTensorFormer(config)
    elif model_type == "tensor_only":
        config.use_quantum = False
        return QTensorFormer(config)
    elif model_type == "dense":
        return DenseBaseline(config)
    elif model_type == "distilled":
        config.use_quantum = True
        config.tt_rank = max(2, config.tt_rank // 2)  # More aggressively compressed
        return QTensorFormer(config)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
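
A minimal usage sketch for the model API, not part of this commit; it assumes `ModelConfig` constructs with defaults and that `HybridBlock` (src/blocks.py, not shown here) exposes the `rank_scheduler`/`quantum_router` attributes the `stats` property reads:

# Hypothetical usage sketch — not part of this commit
import torch
from src.config import ModelConfig
from src.models import QTensorFormer

config = ModelConfig()                                 # assumed defaults
model = QTensorFormer(config)
prompt = torch.randint(1, config.vocab_size, (1, 8))   # dummy token ids
out = model.generate(prompt, max_new_tokens=10, temperature=0.8, top_k=40)
logits, stats = model(prompt, return_stats=True)       # per-block stats dicts
print(model.stats["compression_ratio"], out.shape)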

src/quantum_layers.py
ADDED
@@ -0,0 +1,202 @@
"""
Quantum Feature Encoding Layers.

PennyLane-based quantum circuits wrapped as PyTorch nn.Module layers.

Components:
- QuantumAngleEmbedding: Classical data → rotation angles on qubits
- QuantumAmplitudeEmbedding: Encodes data as quantum amplitudes
- EntanglementMonitor: Estimates entanglement via attention patterns
- ClassicalQuantumFallback: MLP-based fallback when PennyLane unavailable
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List

try:
    import pennylane as qml
    HAS_PENNYLANE = True
except ImportError:
    HAS_PENNYLANE = False


class QuantumAngleEmbedding(nn.Module):
    """
    Encodes classical features into quantum states via angle encoding.

    Circuit: RX(input) → [RY(θ) → CNOT ladder] × n_layers → ⟨Z_i⟩

    Parameters
    ----------
    n_qubits : int
        Number of qubits (4-8 for NISQ compatibility).
    n_layers : int
        Number of variational circuit layers.
    n_outputs : int or None
        Number of expectation values to measure. Default: n_qubits.
    diff_method : str
        Differentiation method. 'backprop' for batched inputs,
        'parameter-shift' for hardware compatibility.
    """

    def __init__(self, n_qubits: int = 4, n_layers: int = 2,
                 n_outputs: int = None, diff_method: str = "backprop"):
        super().__init__()
        if not HAS_PENNYLANE:
            raise ImportError(
                "PennyLane is required for quantum layers. "
                "Install with: pip install pennylane"
            )

        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.n_outputs = n_outputs or n_qubits

        dev = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(dev, interface="torch", diff_method=diff_method)
        def circuit(inputs, weights):
            # Angle encoding
            for i in range(n_qubits):
                qml.RX(inputs[..., i], wires=i)

            # Variational layers with entanglement
            for layer in range(n_layers):
                for i in range(n_qubits):
                    qml.RY(weights[layer, i], wires=i)
                # Nearest-neighbor CNOT ladder
                for i in range(n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])
                # Cyclic entanglement for >2 qubits
                if n_qubits > 2:
                    qml.CNOT(wires=[n_qubits - 1, 0])

            # Measure PauliZ expectation values
            return [qml.expval(qml.PauliZ(i)) for i in range(self.n_outputs)]

        weight_shapes = {"weights": (n_layers, n_qubits)}
        self.qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (*batch, n_qubits) — classical inputs mapped to rotation angles
        Returns:
            (*batch, n_outputs) — PauliZ expectation values in [-1, 1]
        """
        return self.qlayer(x)


class EntanglementMonitor(nn.Module):
    """
    Estimates entanglement entropy from attention patterns.

    Uses attention distribution entropy as a classical proxy
    for quantum entanglement entropy. Avoids expensive quantum
    state tomography during training.

    Parameters
    ----------
    n_qubits : int
        Number of qubits in the simulated quantum system.
    subsystem_a : list of ints or None
        Qubit indices for subsystem A (bipartition).
    """

    def __init__(self, n_qubits: int = 4,
                 subsystem_a: Optional[List[int]] = None):
        super().__init__()
        self.n_qubits = n_qubits
        if subsystem_a is None:
            subsystem_a = list(range(n_qubits // 2))
        self.subsystem_a = subsystem_a

    def forward(self, attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Estimate entanglement from attention distributions.

        Args:
            attention_weights: (batch, heads, seq_len, seq_len)
                Softmax-normalized attention weights.

        Returns:
            (batch, heads) — estimated entanglement entropy per head
        """
        eps = 1e-8
        entropy = -torch.sum(
            attention_weights * torch.log(attention_weights + eps),
            dim=-1
        )  # (batch, heads, seq_len)
        return entropy.mean(dim=-1)  # (batch, heads)


class ClassicalQuantumFallback(nn.Module):
    """
    Classical MLP fallback when PennyLane is unavailable.

    Uses smooth (SiLU) activations to mimic quantum rotation gate behavior.
    """

    def __init__(self, n_qubits: int = 4, n_layers: int = 2,
                 n_outputs: int = None):
        super().__init__()
        n_outputs = n_outputs or n_qubits
        layers = []
        in_dim = n_qubits
        for _ in range(n_layers):
            layers.extend([
                nn.Linear(in_dim, n_qubits * 2),
                nn.SiLU(),  # Smooth activation like quantum gates
            ])
            in_dim = n_qubits * 2
        layers.append(nn.Linear(in_dim, n_outputs))
        layers.append(nn.Tanh())  # Bound output to [-1, 1] like expectation values
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def create_quantum_embedding(input_dim: int, n_qubits: int = 4,
                             n_layers: int = 2, output_dim: int = None,
                             embedding_type: str = "angle") -> nn.Module:
    """
    Factory for quantum embedding layers.

    Args:
        input_dim: Input feature dimension.
        n_qubits: Number of qubits.
        n_layers: Circuit depth.
        output_dim: Output dimension.
        embedding_type: 'angle' or 'amplitude'.

    Returns:
        Quantum embedding nn.Module (or classical fallback if no PennyLane).
    """
    output_dim = output_dim or n_qubits

    if not HAS_PENNYLANE:
        print("[WARN] PennyLane not installed. Using classical fallback.")
        return nn.Sequential(
            nn.Linear(input_dim, n_qubits),
            ClassicalQuantumFallback(n_qubits, n_layers, output_dim),
            nn.Linear(output_dim, output_dim),
        )

    if embedding_type == "angle":
        return nn.Sequential(
            nn.Linear(input_dim, n_qubits),
            QuantumAngleEmbedding(n_qubits, n_layers, output_dim),
        )
    elif embedding_type == "amplitude":
        return nn.Sequential(
            nn.Linear(input_dim, 2 ** n_qubits),
            nn.Softmax(dim=-1),
            # Amplitude embedding would go here
            nn.Linear(2 ** n_qubits, output_dim),
        )
    else:
        raise ValueError(f"Unknown embedding type: {embedding_type}")
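
A minimal usage sketch for the factory, not part of this commit. If PennyLane is installed, the returned module is a Linear projection followed by the variational circuit; otherwise the classical fallback is substituted transparently:

# Hypothetical usage sketch — not part of this commit
import torch
from src.quantum_layers import create_quantum_embedding

embed = create_quantum_embedding(input_dim=64, n_qubits=4, n_layers=2)
x = torch.randn(8, 64)   # batch of classical features
z = embed(x)             # (8, 4): PauliZ expectation values (or fallback outputs)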

src/router.py
ADDED
@@ -0,0 +1,144 @@
"""
Quantum Router: Selective Quantum Activation.

Only "hard" tokens pass through the quantum circuit.
Decision mechanism: learned linear gate + straight-through estimator.

v3 improvements:
- Sparsity target: ensures target fraction of tokens skip quantum
- Straight-through gradient for gradient-based learning
- Sparsity statistics tracking
- Fallback embedding for bypassed tokens
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantumRouter(nn.Module):
    """
    Selective quantum activation gate.

    Given a batch of token embeddings, computes a per-token
    probability of routing through quantum. Uses straight-through
    estimator: forward pass uses hard binary decisions, backward
    uses soft sigmoid gradient.

    Parameters
    ----------
    d_model : int
        Input feature dimension.
    q_input_dim : int
        Dimension expected by quantum circuit (typically n_qubits).
    target_sparsity : float
        Target fraction of tokens that SKIP quantum (0.7 = 70% skip).
    temperature : float
        Softmax temperature for gate decisions (lower = harder).
    """

    def __init__(self, d_model: int, q_input_dim: int = 4,
                 target_sparsity: float = 0.7, temperature: float = 1.0):
        super().__init__()
        self.d_model = d_model
        self.q_input_dim = q_input_dim
        self.target_sparsity = target_sparsity
        self.temperature = temperature

        # Projection for gate decision
        self.gate_proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, 1),
        )

        # Projection to quantum input dimension
        self.q_proj = nn.Linear(d_model, q_input_dim)

        # Projection back to model dimension. Created eagerly here rather
        # than lazily inside forward(), so its parameters are registered
        # before the optimizer is built (a lazily created layer would be
        # silently excluded from training).
        self.q_out_proj = nn.Linear(q_input_dim, d_model)

        # Statistics
        self.register_buffer("total_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("quantum_tokens", torch.tensor(0, dtype=torch.long))
        self.register_buffer("_ema_sparsity", torch.tensor(target_sparsity))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route tokens selectively through quantum.

        Args:
            x: (*batch, seq_len, d_model)

        Returns:
            quantum_out: (*batch, seq_len, d_model) — quantum-processed tokens
            mask: (*batch, seq_len) — which tokens went through quantum (bool)
        """
        *batch_dims, seq_len, d_model = x.shape

        # Gate decision
        gate_logits = self.gate_proj(x).squeeze(-1)  # (*, seq_len)
        soft_mask = torch.sigmoid(gate_logits / self.temperature)

        # Straight-through: hard forward, soft backward
        hard_mask = (soft_mask > 0.5).float()
        mask = hard_mask.detach() + soft_mask - soft_mask.detach()

        # Project selected tokens to quantum dimension
        q_input = self.q_proj(x)  # (*, seq_len, q_input_dim)

        # TODO: actual quantum circuit call goes here
        # For now: nonlinearity + learned projection back to d_model
        quantum_out = self.q_out_proj(F.gelu(q_input))

        # Gate output
        mask_expanded = mask.unsqueeze(-1)  # (*, seq_len, 1)
        output = mask_expanded * quantum_out

        # Update statistics
        with torch.no_grad():
            n_tokens = seq_len * max(1, math_prod(batch_dims))
            n_quantum = int(mask_expanded.sum().item())
            self.total_tokens += n_tokens
            self.quantum_tokens += n_quantum
            actual_rate = n_quantum / max(n_tokens, 1)
            self._ema_sparsity.mul_(0.99).add_(
                (1 - actual_rate), alpha=0.01
            )

        return output, mask.detach().bool()

    @property
    def sparsity(self) -> float:
        """Fraction of tokens that SKIP the quantum circuit."""
        return self._ema_sparsity.item()

    @property
    def usage_percent(self) -> float:
        """Fraction of tokens that use the quantum circuit."""
        return 1.0 - self.sparsity

    def reset_stats(self):
        self.total_tokens.zero_()
        self.quantum_tokens.zero_()
        self._ema_sparsity.fill_(self.target_sparsity)

    def reset_state(self):
        """Full reset for clean evaluation runs."""
        self.reset_stats()
        for m in self.modules():
            if hasattr(m, "reset_parameters"):
                m.reset_parameters()

    def extra_repr(self) -> str:
        return (f"d_model={self.d_model}, q_dim={self.q_input_dim}, "
                f"target_sparsity={self.target_sparsity:.1%}")


def math_prod(iterable):
    """Safe product of iterable."""
    result = 1
    for x in iterable:
        result *= x
    return result
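
A minimal usage sketch of the router in isolation, not part of this commit. The returned mask identifies which tokens were routed through the (placeholder) quantum path, and the EMA sparsity statistic accumulates across calls:

# Hypothetical usage sketch — not part of this commit
import torch
from src.router import QuantumRouter

router = QuantumRouter(d_model=128, q_input_dim=4, target_sparsity=0.7)
x = torch.randn(2, 16, 128)      # (batch, seq_len, d_model)
routed, mask = router(x)         # routed: (2, 16, 128); mask: bool (2, 16)
print(f"quantum usage: {router.usage_percent:.1%}")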

src/scheduler.py
ADDED
@@ -0,0 +1,154 @@
"""
Adaptive TT-Rank Scheduler.

Core novelty of Q-TensorFormer: adjusts tensor rank dynamically
based on per-input complexity, estimated via attention entropy.

    r(input) = r_min + α × normalized_entropy × (r_max - r_min)

Supports:
- EMA smoothing to prevent oscillation
- Budget-capped ranks
- Deterministic rounding with hysteresis
"""

import torch
import torch.nn as nn
import math


class RankScheduler(nn.Module):
    """
    Attention entropy → TT-rank scheduler.

    Parameters
    ----------
    r_min : int
        Minimum tensor rank (maximum compression).
    r_max : int
        Maximum tensor rank (minimum compression).
    alpha : float
        Sensitivity: how much entropy changes the rank.
        alpha=0 → fixed rank r_min.
        alpha=1 → rank fully spans r_min to r_max.
        alpha=2.0 → aggressive scaling (default).
    smoothing : float
        EMA decay factor (0.9 = smooth, 0 = no history).
    """

    def __init__(self, r_min: int = 2, r_max: int = 8,
                 alpha: float = 2.0, smoothing: float = 0.9):
        super().__init__()
        self.r_min = r_min
        self.r_max = r_max
        self.alpha = alpha
        self.smoothing = smoothing

        self.register_buffer("_ema_entropy", torch.tensor(0.5))
        self.register_buffer("_ema_rank", torch.tensor((r_min + r_max) // 2, dtype=torch.float))
        self.register_buffer("_counter", torch.tensor(0, dtype=torch.long))

        # Optionally learn alpha
        self.learned_alpha = nn.Parameter(torch.tensor(float(alpha)), requires_grad=False)

    def forward(self, entropy: torch.Tensor, seq_len: int = None) -> int:
        """
        Compute rank from attention entropy.

        Args:
            entropy: Scalar or 0-dim tensor (mean attention entropy).
            seq_len: Sequence length for normalization (optional).

        Returns:
            Integer tensor rank.
        """
        if entropy.dim() > 0:
            entropy = entropy.mean()

        # Normalize entropy to [0, 1]
        if seq_len is not None and seq_len > 1:
            norm_factor = math.log(seq_len)
            normalized = torch.clamp(entropy / max(norm_factor, 1e-8), 0.0, 1.0)
        else:
            normalized = torch.clamp(torch.tanh(entropy / 2.0), 0.0, 1.0)

        # EMA smoothing
        self._ema_entropy.mul_(self.smoothing).add_(normalized, alpha=1.0 - self.smoothing)
        smoothed = self._ema_entropy

        # Map to rank: r = r_min + alpha * norm * (r_max - r_min)
        alpha_val = self.learned_alpha.item()
        span = self.r_max - self.r_min
        raw = self.r_min + alpha_val * smoothed.item() * span

        # Round with hysteresis
        self._ema_rank.mul_(0.7).add_(raw, alpha=0.3)
        rank = int(torch.round(self._ema_rank).item())

        # Clamp + counter
        rank = max(self.r_min, min(self.r_max, rank))
        self._counter.add_(1)
        return rank

    def reset(self):
        """Reset EMA state."""
        self._ema_entropy.fill_(0.5)
        self._ema_rank.fill_((self.r_min + self.r_max) / 2.0)
        self._counter.fill_(0)

    @property
    def current_rank(self) -> float:
        return self._ema_rank.item()

    @property
    def current_entropy(self) -> float:
        return self._ema_entropy.item()


class BudgetAwareScheduler(nn.Module):
    """
    Extends RankScheduler with deployment budget constraints.

    Automatically caps tensor rank to meet:
    - Max parameter budget
    - Max latency target
    - Max energy per query
    """

    def __init__(self, scheduler: RankScheduler,
                 max_params: int = None,
                 max_latency_ms: float = None,
                 max_energy_uj: float = None):
        super().__init__()
        self.scheduler = scheduler
        self.max_params = max_params
        self.max_latency_ms = max_latency_ms
        self.max_energy_uj = max_energy_uj

    def forward(self, entropy: torch.Tensor, seq_len: int = None,
                param_factors: dict = None) -> int:
        """
        Compute rank with budget constraints.

        Args:
            entropy: Attention entropy.
            seq_len: Sequence length.
            param_factors: Dict mapping rank → estimated total parameters.

        Returns:
            Budget-constrained rank.
        """
        rank = self.scheduler(entropy, seq_len)

        if param_factors and self.max_params:
            # Find highest rank that meets budget
            while rank > self.scheduler.r_min:
                est = param_factors.get(rank, float("inf"))
                if est <= self.max_params:
                    break
                rank -= 1

        return rank

    def reset(self):
        self.scheduler.reset()
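
A minimal usage sketch of the rank schedule, not part of this commit; the `param_factors` values below are toy estimates for illustration:

# Hypothetical usage sketch — not part of this commit
import torch
from src.scheduler import RankScheduler, BudgetAwareScheduler

sched = RankScheduler(r_min=2, r_max=8, alpha=2.0, smoothing=0.9)
entropy = torch.tensor(2.1)              # mean attention entropy (0-dim tensor)
rank = sched(entropy, seq_len=64)        # int in [2, 8], EMA-smoothed

budgeted = BudgetAwareScheduler(sched, max_params=500_000)
rank = budgeted(entropy, seq_len=64,
                param_factors={r: 60_000 * r for r in range(2, 9)})  # toy values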

src/tensor_layers.py
ADDED
@@ -0,0 +1,294 @@
"""
Tensor-Train decomposed linear layers.

v3 improvements:
- SVD-based rank truncation (preserves dominant singular vectors)
- No dead padding cores (factorize_dim ensures all factors ≥ 2)
- torch.no_grad() on set_rank
- Built-in compression statistics
- Budget-aware: auto-selects minimum rank meeting constraints
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional


def factorize_dim(dim: int, max_factors: int = 4) -> Tuple[int, ...]:
    """
    Factorize a dimension for TT decomposition.
    Ensures all factors >= 2 to avoid dead cores.
    """
    if dim <= 1:
        return (1,)
    factors = []
    remaining = dim
    for p in [2, 2, 3, 2, 5, 2, 3, 7]:
        while remaining % p == 0 and len(factors) < max_factors - 1:
            factors.append(p)
            remaining //= p
        if remaining == 1:
            break
    if remaining > 1 and len(factors) < max_factors:
        factors.append(remaining)
    while len(factors) < 2:
        val = factors[0] if factors else dim
        root = int(math.isqrt(val))
        for d in range(root, 1, -1):
            if val % d == 0:
                factors = [d, val // d]
                break
        else:
            factors = [1, val]
    return tuple(factors[:max_factors])


def compute_tt_params(in_features: int, out_features: int,
                      in_shape: Tuple[int, ...], rank: int) -> int:
    """Compute number of parameters in a TT layer."""
    d = len(in_shape)
    params = 0
    # First core: (1, out_0, in_0, rank)
    params += out_features // math.prod(in_shape[1:]) * in_shape[0] * rank if d > 0 else 0
    # Middle cores
    for k in range(1, d - 1):
        params += rank * rank * in_shape[k] * in_shape[k]  # approximate
    # Last core
    if d > 1:
        params += rank * in_shape[-1] * in_shape[-1]
    return params


class TTLinear(nn.Module):
    """
    Tensor-Train decomposed linear layer.

    Replaces a dense weight matrix W ∈ R^{out×in} with d TT-cores.
    Core k has shape (r_k, out_k, in_k, r_{k+1}) with r_0 = r_d = 1.

    Parameters
    ----------
    in_features : int
        Input dimension.
    out_features : int
        Output dimension.
    rank : int
        TT-rank (bond dimension). Lower → more compression.
    bias : bool
        Include bias term.
    """

    def __init__(self, in_features: int, out_features: int,
                 rank: int = 8, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        # Factorize dimensions
        in_factors = factorize_dim(in_features)
        out_factors = factorize_dim(out_features)
        self.ndim = max(len(in_factors), len(out_factors))

        # Pad to same length (minimal padding)
        in_factors = list(in_factors)
        out_factors = list(out_factors)
        while len(in_factors) < self.ndim:
            in_factors.append(1)
        while len(out_factors) < self.ndim:
            out_factors.append(1)
        self.in_shape = tuple(in_factors)
        self.out_shape = tuple(out_factors)

        # Initialize TT cores
        self.cores = nn.ParameterList()
        for k in range(self.ndim):
            r_left = 1 if k == 0 else rank
            r_right = 1 if k == self.ndim - 1 else rank
            core = torch.empty(r_left, out_factors[k], in_factors[k], r_right)
            fan = max(1, r_left * in_factors[k] + r_right * out_factors[k])
            bound = math.sqrt(6.0 / fan)
            nn.init.uniform_(core, -bound, bound)
            self.cores.append(core)

        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

        # Statistics
        tt_params = sum(c.numel() for c in self.cores)
        if self.bias is not None:
            tt_params += self.bias.numel()
        dense_params = in_features * out_features
        self.compression_ratio = dense_params / max(tt_params, 1)
        self._tt_params = tt_params
        self._dense_params = dense_params

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: sequential TT contraction.

        Args:
            x: (*batch_dims, in_features)
        Returns:
            (*batch_dims, out_features)
        """
        batch_shape = x.shape[:-1]
        B = math.prod(batch_shape) if batch_shape else 1
        x = x.reshape(B, self.in_features)
        state = x.reshape(B, *self.in_shape)

        for k in range(self.ndim):
            core = self.cores[k]
            r_k, o_k, i_k, r_kp1 = core.shape

            if k == 0:
                rest = math.prod(self.in_shape[1:]) if self.ndim > 1 else 1
                s = state.reshape(B, i_k, rest)
                cm = core.squeeze(0).permute(1, 0, 2).reshape(i_k, o_k * r_kp1)
                s = torch.bmm(s.transpose(1, 2), cm.unsqueeze(0).expand(B, -1, -1))
                s = s.reshape(B, rest, o_k, r_kp1).permute(0, 3, 2, 1)
                state = s.reshape(B, r_kp1, -1)

            elif k == self.ndim - 1:
                prev_os = math.prod(self.out_shape[:k]) if k > 0 else 1
                s = state.reshape(B, r_k, prev_os, i_k)
                cm = core.squeeze(-1)
                s = torch.einsum('brpi,roi->bpo', s, cm)
                state = s.reshape(B, prev_os * o_k)

            else:
                prev_os = math.prod(self.out_shape[:k]) if k > 0 else 1
                rest_in = math.prod(self.in_shape[k + 1:])
                s = state.reshape(B, r_k, prev_os * i_k * rest_in)
                s = s.reshape(B, r_k, prev_os, i_k, rest_in)
                s = torch.einsum('brpix,roiq->bpoqx', s, core)
                s = s.permute(0, 3, 1, 2, 4)
                state = s.reshape(B, r_kp1, prev_os * o_k * rest_in)

        out = state.reshape(B, self.out_features)
        if self.bias is not None:
            out = out + self.bias
        return out.reshape(*batch_shape, self.out_features)

    @torch.no_grad()
    def set_rank(self, new_rank: int):
        """
        SVD-based TT-rank truncation.

        Strategy: For each pair of adjacent cores, merge into a supercore,
        compute SVD, and keep top `new_rank` singular values.
        Then split back into two cores at the new rank.

        For single-core edge case (ndim=1): just truncate the SVD of the sole core.
        """
        if new_rank == self.rank:
            return
        new_rank = max(1, new_rank)

        if self.ndim == 1:
            # Single core: just reshape to matrix and SVD-truncate
            old = self.cores[0].data  # (1, o_0, i_0, 1)
            mat = old.reshape(old.shape[1], old.shape[2])  # (o_0, i_0)
            U, S, Vt = torch.linalg.svd(mat, full_matrices=False)
            tr = min(new_rank, S.shape[0])
            self.cores[0] = nn.Parameter(
                ((U[:, :tr] * S[:tr]) @ Vt[:tr, :]).reshape(1, old.shape[1], old.shape[2], 1)
            )
            self.rank = new_rank
        else:
            # Strategy: compress bond between each adjacent core pair
            # We treat each bond independently, truncating to new_rank
            for k in range(self.ndim - 1):
                core_a = self.cores[k].data      # (r_k, o_k, i_k, r_{k+1})
                core_b = self.cores[k + 1].data  # (r_{k+1}, o_{k+1}, i_{k+1}, r_{k+2})

                r_k, o_a, i_a, r_mid = core_a.shape
                r_mid2, o_b, i_b, r_k2 = core_b.shape
                assert r_mid == r_mid2, f"Rank mismatch: {r_mid} != {r_mid2}"

                # Merge cores along the bond to contract the middle rank
                # core_a: reshape to (r_k * o_a * i_a, r_mid)
                # core_b: reshape to (r_mid, o_b * i_b * r_k2)
                # Merged: (r_k * o_a * i_a, o_b * i_b * r_k2)
                mat_a = core_a.reshape(-1, r_mid)   # (r_k*o_a*i_a, r_mid)
                mat_b = core_b.reshape(r_mid, -1)   # (r_mid, o_b*i_b*r_k2)

                # Reduced SVD at the bond
                combined = mat_a @ mat_b  # (r_k*o_a*i_a, o_b*i_b*r_k2)
                U, S, Vt = torch.linalg.svd(combined, full_matrices=False)
                tr = min(new_rank, S.shape[0])

                # Split back
                U_tr = U[:, :tr]                         # (r_k*o_a*i_a, tr)
                Vt_tr = Vt[:tr, :]                       # (tr, o_b*i_b*r_k2)
                S_sqrt = torch.sqrt(S[:tr] + 1e-10)      # (tr,)

                new_a = (U_tr * S_sqrt).reshape(r_k, o_a, i_a, tr)
                new_b = (S_sqrt.unsqueeze(-1) * Vt_tr).reshape(tr, o_b, i_b, r_k2)

                self.cores[k].data = new_a
                self.cores[k + 1].data = new_b

            self.rank = new_rank

        # Update stats
        tt_params = sum(c.numel() for c in self.cores)
        if self.bias is not None:
            tt_params += self.bias.numel()
        self._tt_params = tt_params
        self.compression_ratio = self._dense_params / max(tt_params, 1)

    def flops(self, batch_size: int = 1) -> int:
        """Estimate FLOPs for this layer."""
        # TT contraction: ~2 * rank^2 * ndim * avg(in_k * out_k)
        avg_dim = (sum(self.in_shape) + sum(self.out_shape)) / (2 * self.ndim)
        return int(2 * self.rank**2 * self.ndim * avg_dim * batch_size)

    def extra_repr(self) -> str:
        return (f"in_shape={self.in_shape}, out_shape={self.out_shape}, "
                f"rank={self.rank}, compression={self.compression_ratio:.1f}x")


class TTFeedForward(nn.Module):
    """
    Tensor-Train Feed-Forward Network.

    Replaces standard FFN (Linear↑→GELU→Linear↓) with TT-decomposed layers.

    Parameters
    ----------
    hidden_dim : int
        Hidden dimension.
    ff_multiplier : int
        FFN expansion factor (default 4x).
    rank : int
        TT-rank.
    activation : callable
        Activation function (default GELU).
    """

    def __init__(self, hidden_dim: int, ff_multiplier: int = 4,
                 rank: int = 8, activation=F.gelu):
        super().__init__()
        self.hidden_dim = hidden_dim
        expanded_dim = hidden_dim * ff_multiplier

        self.up_proj = TTLinear(hidden_dim, expanded_dim, rank, bias=True)
        self.down_proj = TTLinear(expanded_dim, hidden_dim, rank, bias=True)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.activation(self.up_proj(x)))

    @torch.no_grad()
    def set_rank(self, rank: int):
        self.up_proj.set_rank(rank)
        self.down_proj.set_rank(rank)

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())

    def flops(self, batch_size: int = 1) -> int:
        return self.up_proj.flops(batch_size) + self.down_proj.flops(batch_size)
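
A minimal usage sketch for the TT layers, not part of this commit. It exercises the three key behaviors above: arbitrary batch dimensions in forward, in-place SVD rank truncation, and the built-in compression statistics:

# Hypothetical usage sketch — not part of this commit
import torch
from src.tensor_layers import TTLinear, TTFeedForward

layer = TTLinear(in_features=128, out_features=512, rank=8)
print(layer)                          # extra_repr: shapes, rank, compression
y = layer(torch.randn(4, 128))        # (4, 512)

ffn = TTFeedForward(hidden_dim=128, ff_multiplier=4, rank=8)
ffn.set_rank(4)                       # SVD-truncate both projections in place
y = ffn(torch.randn(4, 10, 128))      # also accepts (batch, seq, hidden)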

src/training.py
ADDED
@@ -0,0 +1,399 @@
"""
Training utilities with budget-aware scheduling, energy metrics, and sweep support.

v3 features:
- Budget-constrained training (auto-adjusts ranks to meet param/latency targets)
- Energy estimation (FLOPs-based proxy)
- Knowledge distillation support
- Gradient monitoring and NaN detection
- Checkpointing with metadata
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
import math
import time
from typing import Optional, Dict, Tuple, List
from pathlib import Path
import json

from .config import ExperimentConfig
from .budget import BudgetTracker, EnergyEstimator


def create_optimizer(model: nn.Module, lr: float, weight_decay: float,
                     betas: Tuple[float, float] = (0.9, 0.98),
                     eps: float = 1e-8) -> AdamW:
    """Create an AdamW optimizer that excludes norms and biases from weight decay."""
    no_decay = ["bias", "LayerNorm.weight", "layernorm.weight", "ln.weight"]
    params = [
        {
            "params": [p for n, p in model.named_parameters()
                       if p.requires_grad and not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if p.requires_grad and any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return AdamW(params, lr=lr, betas=betas, eps=eps)


def create_scheduler(optimizer, warmup_steps: int, max_steps: int,
                     lr_min_factor: float = 0.1, scheduler_type: str = "cosine"):
    """Create a learning-rate scheduler with linear warmup."""
    warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0,
                      total_iters=warmup_steps)
    main_steps = max(max_steps - warmup_steps, 1)  # guard against T_0 <= 0

    if scheduler_type == "cosine":
        main = CosineAnnealingWarmRestarts(
            optimizer, T_0=main_steps,
            T_mult=1, eta_min=lr_min_factor * optimizer.param_groups[0]["lr"]
        )
    elif scheduler_type == "linear":
        main = LinearLR(optimizer, start_factor=1.0,
                        end_factor=lr_min_factor,
                        total_iters=main_steps)
    else:  # constant LR after warmup
        main = LinearLR(optimizer, start_factor=1.0, end_factor=1.0,
                        total_iters=main_steps)

    return SequentialLR(optimizer, schedulers=[warmup, main],
                        milestones=[warmup_steps])


def compute_perplexity(logits: torch.Tensor, targets: torch.Tensor,
                       ignore_index: int = 0) -> float:
    """Compute perplexity, ignoring positions equal to ignore_index."""
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    return math.exp(loss.item())


class Trainer:
    """
    Budget-aware Q-TensorFormer trainer.

    Tracks:
    - Perplexity (primary metric)
    - Model size (parameters)
    - Latency estimates
    - Energy consumption (FLOPs proxy)
    - Quantum call statistics
    - Rank adaptation trajectories
    """

    def __init__(self, model: nn.Module, config: ExperimentConfig,
                 train_loader, val_loader=None, test_loader=None,
                 device: str = "cpu", output_dir: Optional[str] = None):
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = torch.device(device)
        self.output_dir = Path(output_dir or config.output_dir)

        self.model.to(self.device)

        total_steps = len(train_loader) * config.training.max_epochs
        self.optimizer = create_optimizer(
            model, config.training.learning_rate, config.training.weight_decay
        )
        self.scheduler = create_scheduler(
            self.optimizer,
            warmup_steps=config.training.warmup_steps,
            max_steps=total_steps,
            lr_min_factor=config.training.lr_min_factor,
            scheduler_type=config.training.lr_scheduler,
        )

        # Budget tracking
        self.budget_tracker = BudgetTracker(config.budget)
        self.energy_estimator = EnergyEstimator()

        # Logging
        self.metrics_history: List[Dict] = []
        self.grad_norms: List[float] = []

    def train_epoch(self, epoch: int) -> Dict:
        """Train for one epoch. Returns a metrics dict."""
        self.model.train()
        self.model.reset_schedulers()
        total_loss = 0.0
        total_tokens = 0
        start_time = time.time()

        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            logits, stats = self.model(inputs, return_stats=True)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,  # pad token
            )

            loss.backward()

            # Gradient monitoring
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            )
            self.grad_norms.append(grad_norm.item())

            # NaN check
            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                print(f"[WARN] NaN/Inf gradient at step {step}. Skipping update.")
                self.optimizer.zero_grad()
                continue

            self.optimizer.step()
            self.scheduler.step()

            # Token-weighted running loss (counts pad positions; rough epoch average)
            total_loss += loss.item() * inputs.size(0) * inputs.size(1)
            total_tokens += inputs.size(0) * inputs.size(1)

        elapsed = time.time() - start_time
        avg_loss = total_loss / max(total_tokens, 1)
        ppl = math.exp(min(avg_loss, 20.0))  # cap for numerical stability

        # Budget metrics
        latency_est = self.budget_tracker.estimate_latency(
            self.model, self.config.model.max_seq_len
        )
        energy_est = self.energy_estimator.estimate(self.model)

        # Mean gradient norm over this epoch's steps only
        epoch_grad_norms = self.grad_norms[-len(self.train_loader):]

        metrics = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": ppl,
            "lr": self.optimizer.param_groups[0]["lr"],
            "grad_norm_mean": sum(epoch_grad_norms) / max(len(epoch_grad_norms), 1),
            "total_params": sum(p.numel() for p in self.model.parameters()),
            "latency_ms": latency_est,
            "energy_uj": energy_est,
            "time_s": elapsed,
        }

        # Extract TT stats
        if hasattr(self.model, "stats"):
            metrics["model_stats"] = self.model.stats

        # Validation
        if self.val_loader is not None:
            val_metrics = self.validate()
            metrics.update(val_metrics)

        self.metrics_history.append(metrics)
        return metrics

    @torch.no_grad()
    def validate(self) -> Dict:
        """Run validation."""
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0

        for inputs, targets in self.val_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            logits = self.model(inputs)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += inputs.numel()

        avg_loss = total_loss / max(total_tokens, 1)
        return {
            "val_loss": avg_loss,
            "val_ppl": math.exp(min(avg_loss, 20.0)),
        }

    @torch.no_grad()
    def evaluate(self) -> Dict:
        """
        Full evaluation on the test set.
        Returns a comprehensive metrics dict.
        """
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        latency_samples = []

        for inputs, targets in self.test_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            t0 = time.time()
            logits = self.model(inputs)
            t1 = time.time()
            latency_samples.append((t1 - t0) * 1000 / inputs.size(0))  # ms per sample

            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += inputs.numel()

        avg_loss = total_loss / max(total_tokens, 1)

        return {
            "test_loss": avg_loss,
            "test_ppl": math.exp(min(avg_loss, 20.0)),
            "latency_ms_mean": sum(latency_samples) / max(len(latency_samples), 1),
            "total_params": self.model.total_params,
            "energy_uj": self.energy_estimator.estimate(self.model),
            "model_stats": getattr(self.model, "stats", {}),
        }

    def train(self) -> Dict:
        """Full training loop."""
        best_val_ppl = float("inf")

        for epoch in range(self.config.training.max_epochs):
            metrics = self.train_epoch(epoch)

            # Logging
            print(f"Epoch {epoch+1}/{self.config.training.max_epochs}: "
                  f"train_ppl={metrics['train_ppl']:.2f} "
                  f"val_ppl={metrics.get('val_ppl', 'N/A')} "
                  f"lr={metrics['lr']:.2e}")

            if metrics.get("val_ppl", float("inf")) < best_val_ppl:
                best_val_ppl = metrics["val_ppl"]
                self.save_checkpoint("best")

            # Early stopping on budget violation
            if self.budget_tracker.exceeds_budget(metrics, self.config.model):
                print("[BUDGET] Exceeded constraints. Stopping.")
                break

        self.save_checkpoint("last")
        self.save_metrics()
        return self.metrics_history[-1] if self.metrics_history else {}

    def save_checkpoint(self, tag: str = "checkpoint"):
        """Save model checkpoint with metadata."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / f"{tag}.pt"
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
            "metrics": self.metrics_history,
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, tag: str = "best"):
        """Load checkpoint."""
        path = self.output_dir / f"{tag}.pt"
        if not path.exists():
            print(f"Checkpoint {path} not found")
            return
        # weights_only=False: the checkpoint stores the config object, which
        # torch.load cannot unpickle under weights_only=True.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    def save_metrics(self):
        """Save metrics to JSON."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / "metrics.json"
        with open(path, "w") as f:
            json.dump(self.metrics_history, f, indent=2)
        print(f"Metrics saved to {path}")


class DistillationTrainer(Trainer):
    """
    Knowledge distillation trainer.

    Student = compressed Q-TensorFormer.
    Teacher = dense (or larger) model.
    """

    def __init__(self, student: nn.Module, teacher: nn.Module, *args,
                 alpha: float = 0.5, temperature: float = 3.0, **kwargs):
        """
        Args:
            student: Compressed Q-TensorFormer.
            teacher: Dense baseline (frozen).
            alpha: Weight on the distillation loss; the task loss is weighted 1 - alpha.
            temperature: Softmax temperature.
        """
        super().__init__(student, *args, **kwargs)
        self.teacher = teacher.to(self.device)
        self.teacher.eval()
        self.alpha = alpha
        self.temperature = temperature

        # Freeze teacher
        for p in self.teacher.parameters():
            p.requires_grad = False

    def train_epoch(self, epoch: int) -> Dict:
        self.model.train()
        total_loss = 0.0
        total_tokens = 0

        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            # Student forward
            logits, stats = self.model(inputs, return_stats=True)

            # Task loss
            task_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
            )

            # Distillation loss (scaled by T^2 to keep gradient magnitudes comparable)
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)

            distill_loss = F.kl_div(
                F.log_softmax(logits / self.temperature, dim=-1),
                F.softmax(teacher_logits / self.temperature, dim=-1),
                reduction="batchmean",
            ) * (self.temperature ** 2)

            loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss
            loss.backward()

            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            )
            self.optimizer.step()
            self.scheduler.step()

            # Report task-loss perplexity only; the distill term is excluded here
            total_loss += task_loss.item() * inputs.numel()
            total_tokens += inputs.numel()

        avg_loss = total_loss / max(total_tokens, 1)
        ppl = math.exp(min(avg_loss, 20.0))
        return {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": ppl,
            "lr": self.optimizer.param_groups[0]["lr"],
        }