# Q-TensorFormer/src/blocks.py
"""
Hybrid Transformer Block: Tensor + Quantum + Adaptive.
v3 modular design β€” block can be configured as:
- TT-FFN only (pure tensor)
- Quantum only
- Hybrid (both)
- Standard MLP-FFN (baseline)
Each block contains:
- Multi-Head Attention (with entropy monitoring)
- RankScheduler (entropy β†’ TT rank)
- QuantumRouter (selective quantum activation)
- TTFeedForward (tensor-decomposed FFN)
"""
from typing import Optional

import torch
import torch.nn as nn

from .attention import MultiHeadAttention, HybridQAttention
from .tensor_layers import TTFeedForward
from .scheduler import RankScheduler, BudgetAwareScheduler
from .router import QuantumRouter
class HybridBlock(nn.Module):
    """
    A single Q-TensorFormer block.

    Flow:
        x → LayerNorm → Attention + Entropy
          → RankScheduler: adjust TT ranks
          → LayerNorm → QuantumRouter (gate)
          → TTFeedForward (tensor-decomposed)
          → residual connection
    """
    def __init__(self, d_model: int = 128, n_heads: int = 4,
                 ff_multiplier: int = 4, tt_rank: int = 8,
                 tt_min_rank: int = 2, use_quantum: bool = True,
                 n_qubits: int = 4, n_quantum_layers: int = 2,
                 quantum_sparsity: float = 0.7, rank_alpha: float = 2.0,
                 rank_smoothing: float = 0.9, dropout: float = 0.1,
                 max_seq_len: int = 128):
        super().__init__()
        self.d_model = d_model
        self.use_quantum = use_quantum
        self.is_hybrid = use_quantum  # Flag for model-level detection

        # Attention
        self.attention = MultiHeadAttention(
            d_model, n_heads, dropout, max_seq_len,
            use_quantum_kernel=False
        )

        # Layer norms
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Rank scheduler
        self.rank_scheduler = RankScheduler(
            r_min=tt_min_rank, r_max=tt_rank,
            alpha=rank_alpha, smoothing=rank_smoothing
        )

        # Quantum router
        if use_quantum:
            self.quantum_router = QuantumRouter(
                d_model=d_model,
                q_input_dim=n_qubits,
                target_sparsity=quantum_sparsity,
            )
        else:
            self.quantum_router = None

        # Tensor-Train FFN
        self.tt_ffn = TTFeedForward(
            hidden_dim=d_model,
            ff_multiplier=ff_multiplier,
            rank=tt_rank,
        )

        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len) optional padding mask

        Returns:
            output: (batch, seq_len, d_model)
            stats: dict with entropy, rank, quantum_usage
        """
        stats = {}

        # Attention sublayer
        attn_out, entropy = self.attention(
            self.ln1(x), mask=mask, return_entropy=True
        )
        x = x + self.dropout(attn_out)

        # Schedule rank from attention entropy
        mean_entropy = entropy.mean() if entropy.dim() > 0 else entropy
        new_rank = self.rank_scheduler(mean_entropy, seq_len=x.shape[1])
        self.tt_ffn.set_rank(new_rank)
        stats["entropy"] = mean_entropy.item()
        stats["rank"] = new_rank

        # FFN sublayer
        normed = self.ln2(x)

        # Quantum routing
        quantum_out = torch.zeros_like(normed)
        if self.quantum_router is not None:
            quantum_out, q_mask = self.quantum_router(normed)
            stats["quantum_usage"] = self.quantum_router.usage_percent
            stats["quantum_sparsity"] = self.quantum_router.sparsity

        # TT feed-forward
        ffn_out = self.tt_ffn(normed)

        # Combine: the quantum signal is added to the TT-FFN output, and the
        # normed activations are folded into the residual update as well
        combined = normed + self.dropout(ffn_out + quantum_out)
        x = x + combined

        return x, stats
    def set_rank(self, rank: int):
        """Manually override rank."""
        self.tt_ffn.set_rank(rank)

    def reset_scheduler(self):
        self.rank_scheduler.reset()
        if self.quantum_router is not None:
            self.quantum_router.reset_stats()

    @property
    def total_params(self) -> int:
        return sum(p.numel() for p in self.parameters())
    def flops_estimate(self, batch_size: int = 1, seq_len: int = 128) -> dict:
        """Estimate FLOPs for this block."""
        attn_flops = self.attention.flops(batch_size, seq_len)["total"]
        ffn_flops = self.tt_ffn.flops(batch_size)
        return {
            "attention": attn_flops,
            "tt_ffn": ffn_flops,
            "total": attn_flops + ffn_flops,
        }
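

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API: it assumes the
    # sibling modules (.attention, .tensor_layers, .scheduler, .router) expose
    # the interfaces used above, and that this file is run as a module from the
    # package root (e.g. `python -m src.blocks`) so the relative imports resolve.
    block = HybridBlock(d_model=128, n_heads=4, use_quantum=False)
    x = torch.randn(2, 16, 128)  # (batch, seq_len, d_model)
    out, stats = block(x)
    print(out.shape)             # expected: torch.Size([2, 16, 128])
    print("entropy:", stats["entropy"], "rank:", stats["rank"])
    print("params:", block.total_params)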