"""
Hybrid Transformer Block: Tensor + Quantum + Adaptive.

v3 modular design; each block can be configured as:
  - TT-FFN only (pure tensor)
  - Quantum only
  - Hybrid (both)
  - Standard MLP-FFN (baseline)

Each block contains:
  - Multi-Head Attention (with entropy monitoring)
  - RankScheduler (entropy → TT rank)
  - QuantumRouter (selective quantum activation)
  - TTFeedForward (tensor-decomposed FFN)
"""

import torch
import torch.nn as nn
from .attention import MultiHeadAttention, HybridQAttention
from .tensor_layers import TTFeedForward
from .scheduler import RankScheduler, BudgetAwareScheduler
from .router import QuantumRouter


class HybridBlock(nn.Module):
    """
    A single Q-TensorFormer block.

    Flow:
        x → LayerNorm → Attention + Entropy
          → RankScheduler: adjust TT ranks
          → LayerNorm → QuantumRouter (gate)
          → TTFeedForward (tensor-decomposed)
          → residual connection
    """

    def __init__(self, d_model: int = 128, n_heads: int = 4,
                 ff_multiplier: int = 4, tt_rank: int = 8,
                 tt_min_rank: int = 2, use_quantum: bool = True,
                 n_qubits: int = 4, n_quantum_layers: int = 2,
                 quantum_sparsity: float = 0.7, rank_alpha: float = 2.0,
                 rank_smoothing: float = 0.9, dropout: float = 0.1,
                 max_seq_len: int = 128):
        super().__init__()

        self.d_model = d_model
        self.use_quantum = use_quantum
        self.is_hybrid = use_quantum  # Flag for model-level detection

        # Attention
        self.attention = MultiHeadAttention(
            d_model, n_heads, dropout, max_seq_len,
            use_quantum_kernel=False
        )

        # Layer norms
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        # Rank scheduler
        self.rank_scheduler = RankScheduler(
            r_min=tt_min_rank, r_max=tt_rank,
            alpha=rank_alpha, smoothing=rank_smoothing
        )

        # Quantum router
        if use_quantum:
            self.quantum_router = QuantumRouter(
                d_model=d_model,
                q_input_dim=n_qubits,
                target_sparsity=quantum_sparsity,
            )
        else:
            self.quantum_router = None

        # Tensor-Train FFN
        self.tt_ffn = TTFeedForward(
            hidden_dim=d_model,
            ff_multiplier=ff_multiplier,
            rank=tt_rank,
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: (batch, seq_len) optional padding mask

        Returns:
            output: (batch, seq_len, d_model)
            stats: dict with entropy, rank, and (when quantum routing is
                enabled) quantum_usage and quantum_sparsity
        """
        stats = {}

        # Attention sublayer
        attn_out, entropy = self.attention(
            self.ln1(x), mask=mask, return_entropy=True
        )
        x = x + self.dropout(attn_out)

        # Schedule rank from attention entropy
        mean_entropy = entropy.mean() if entropy.dim() > 0 else entropy
        new_rank = self.rank_scheduler(mean_entropy, seq_len=x.shape[1])
        self.tt_ffn.set_rank(new_rank)
        stats["entropy"] = mean_entropy.item()
        stats["rank"] = new_rank

        # FFN sublayer
        normed = self.ln2(x)

        # Quantum routing
        quantum_out = torch.zeros_like(normed)
        if self.quantum_router is not None:
            quantum_out, q_mask = self.quantum_router(normed)
            stats["quantum_usage"] = self.quantum_router.usage_percent
            stats["quantum_sparsity"] = self.quantum_router.sparsity

        # TT feed-forward
        ffn_out = self.tt_ffn(normed)

        # Combine: the quantum output is added to the TT-FFN output, and the
        # sum is applied through the residual connection (pre-norm style)
        x = x + self.dropout(ffn_out + quantum_out)

        return x, stats

    def set_rank(self, rank: int):
        """Manually override rank."""
        self.tt_ffn.set_rank(rank)

    def reset_scheduler(self):
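        """Reset the rank scheduler and the quantum router's usage statistics."""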
        self.rank_scheduler.reset()
        if self.quantum_router is not None:
            self.quantum_router.reset_stats()

    @property
    def total_params(self) -> int:
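        """Total number of parameters in this block."""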
        return sum(p.numel() for p in self.parameters())

    def flops_estimate(self, batch_size: int = 1, seq_len: int = 128) -> dict:
        """Estimate FLOPs for this block."""
        attn_flops = self.attention.flops(batch_size, seq_len)["total"]
        ffn_flops = self.tt_ffn.flops(batch_size)
        return {
            "attention": attn_flops,
            "tt_ffn": ffn_flops,
            "total": attn_flops + ffn_flops,
        }
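

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the block itself).
# It relies solely on the signatures defined above and assumes the sibling
# modules (.attention, .tensor_layers, .scheduler, .router) are importable
# and behave as the calls in this file imply. Because of the relative
# imports, run it as a module (python -m <package>.<this_module>) rather
# than as a standalone script.
if __name__ == "__main__":
    torch.manual_seed(0)

    block = HybridBlock(d_model=128, n_heads=4, use_quantum=True)

    # Dummy batch: 2 sequences of length 16 in a 128-dimensional model space.
    x = torch.randn(2, 16, 128)

    out, stats = block(x)  # forward returns (output, stats)
    print(out.shape)       # expected: torch.Size([2, 16, 128])
    print(stats)           # entropy, rank, quantum_usage, quantum_sparsity

    # Parameter count and per-block FLOPs estimate, as exposed above.
    print(block.total_params)
    print(block.flops_estimate(batch_size=2, seq_len=16))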