"""
Hybrid Transformer Block: Tensor + Quantum + Adaptive.
v3 modular design: each block can be configured as:
- TT-FFN only (pure tensor)
- Quantum only
- Hybrid (both)
- Standard MLP-FFN (baseline)
Each block contains:
- Multi-Head Attention (with entropy monitoring)
- RankScheduler (entropy → TT rank)
- QuantumRouter (selective quantum activation)
- TTFeedForward (tensor-decomposed FFN)
"""
from typing import Optional

import torch
import torch.nn as nn

from .attention import MultiHeadAttention, HybridQAttention
from .tensor_layers import TTFeedForward
from .scheduler import RankScheduler, BudgetAwareScheduler
from .router import QuantumRouter


class HybridBlock(nn.Module):
    """
    A single Q-TensorFormer block.

    Flow:
        x → LayerNorm → Attention + Entropy
          → RankScheduler: adjust TT ranks
          → LayerNorm → QuantumRouter (gate)
          → TTFeedForward (tensor-decomposed)
          → residual connection
    """
    def __init__(self, d_model: int = 128, n_heads: int = 4,
                 ff_multiplier: int = 4, tt_rank: int = 8,
                 tt_min_rank: int = 2, use_quantum: bool = True,
                 n_qubits: int = 4, n_quantum_layers: int = 2,
                 quantum_sparsity: float = 0.7, rank_alpha: float = 2.0,
                 rank_smoothing: float = 0.9, dropout: float = 0.1,
                 max_seq_len: int = 128):
        super().__init__()
        self.d_model = d_model
        self.use_quantum = use_quantum
        self.is_hybrid = use_quantum  # Flag for model-level detection

        # Attention
        self.attention = MultiHeadAttention(
            d_model, n_heads, dropout, max_seq_len,
            use_quantum_kernel=False
        )

        # Layer norms
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Rank scheduler
        self.rank_scheduler = RankScheduler(
            r_min=tt_min_rank, r_max=tt_rank,
            alpha=rank_alpha, smoothing=rank_smoothing
        )

        # Quantum router
        if use_quantum:
            self.quantum_router = QuantumRouter(
                d_model=d_model,
                q_input_dim=n_qubits,
                target_sparsity=quantum_sparsity,
            )
        else:
            self.quantum_router = None

        # Tensor-Train FFN
        self.tt_ffn = TTFeedForward(
            hidden_dim=d_model,
            ff_multiplier=ff_multiplier,
            rank=tt_rank,
        )

        self.dropout = nn.Dropout(dropout)
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len) optional padding mask

        Returns:
            output: (batch, seq_len, d_model)
            stats: dict with entropy, rank, quantum_usage
        """
        stats = {}

        # Attention sublayer (pre-norm), with entropy monitoring
        attn_out, entropy = self.attention(
            self.ln1(x), mask=mask, return_entropy=True
        )
        x = x + self.dropout(attn_out)

        # Schedule the TT rank from the mean attention entropy
        mean_entropy = entropy.mean() if entropy.dim() > 0 else entropy
        new_rank = self.rank_scheduler(mean_entropy, seq_len=x.shape[1])
        self.tt_ffn.set_rank(new_rank)
        stats["entropy"] = mean_entropy.item()
        stats["rank"] = new_rank

        # FFN sublayer (pre-norm)
        normed = self.ln2(x)

        # Quantum routing: sparse quantum correction to the FFN path
        quantum_out = torch.zeros_like(normed)
        if self.quantum_router is not None:
            quantum_out, _q_mask = self.quantum_router(normed)
            stats["quantum_usage"] = self.quantum_router.usage_percent
            stats["quantum_sparsity"] = self.quantum_router.sparsity

        # TT feed-forward
        ffn_out = self.tt_ffn(normed)

        # Combine: the quantum signal is added to the TT-FFN output, and the
        # normalized input is kept as an extra skip path inside the residual
        combined = normed + self.dropout(ffn_out + quantum_out)
        x = x + combined

        return x, stats
    def set_rank(self, rank: int):
        """Manually override the TT rank."""
        self.tt_ffn.set_rank(rank)

    def reset_scheduler(self):
        """Reset the rank scheduler and the quantum router's usage statistics."""
        self.rank_scheduler.reset()
        if self.quantum_router is not None:
            self.quantum_router.reset_stats()
    @property
    def total_params(self) -> int:
        """Total parameter count of the block."""
        return sum(p.numel() for p in self.parameters())

    def flops_estimate(self, batch_size: int = 1, seq_len: int = 128) -> dict:
        """Estimate FLOPs for this block (attention + TT-FFN)."""
        attn_flops = self.attention.flops(batch_size, seq_len)["total"]
        ffn_flops = self.tt_ffn.flops(batch_size)
        return {
            "attention": attn_flops,
            "tt_ffn": ffn_flops,
            "total": attn_flops + ffn_flops,
        }
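

# Minimal smoke-test sketch (illustrative, not part of the public API): assumes
# this file is executed as a module inside its package so the relative imports
# above resolve, and that the default constructor arguments are mutually
# consistent. use_quantum=False keeps the check independent of any quantum backend.
if __name__ == "__main__":
    block = HybridBlock(d_model=128, n_heads=4, use_quantum=False)
    dummy = torch.randn(2, 16, 128)   # (batch, seq_len, d_model)
    out, stats = block(dummy)         # returns (output, per-block stats dict)
    print(out.shape, stats)           # expected: torch.Size([2, 16, 128]) plus entropy/rank stats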