| """
|
| Circuit Transformer: Minimal transformer for semantic circuitry experiments.
|
|
|
| Follows patterns from shimmer/lira/gpt.py with extension hooks for future work.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import math
|
|
|
| from .config import CircuitConfig
|
| from .layers import RMSNorm, RotaryEmbedding, CausalAttention, SwiGLU, WordPositionRoPE
|
|
|
|
|
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer with two residual sub-layers.

    Data flow: ``x + attn(attn_norm(x))`` followed by ``h + ffn(ffn_norm(h))``.
    All constructor arguments except ``hidden_size``-only ones are handed
    straight to ``CausalAttention``; the feed-forward is ``SwiGLU(hidden_size)``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        # Registration order (attn_norm, attn, ffn_norm, ffn) determines
        # parameters() order; keep it stable for optimizer/checkpoint state.
        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(
            hidden_size,
            num_heads,
            num_kv_heads,
            max_seq_len,
            dropout,
            window_size,
            word_rope_dims=word_rope_dims,
            word_rope_base=word_rope_base,
        )
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size)

    def forward(
        self,
        x: torch.Tensor,
        use_cache: bool = False,
        past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: hidden states entering the block (presumably [B, L, hidden_size]
                — shape is not checked here).
            use_cache: forwarded unchanged to the attention layer.
            past_kv: this layer's cached keys/values, forwarded to attention.
            word_positions: optional tensor for word-position RoPE, forwarded
                to attention.

        Returns:
            ``(hidden_states, new_kv)`` where ``new_kv`` is whatever the
            attention layer returned as its cache.
        """
        normed = self.attn_norm(x)
        attn_out, new_kv = self.attn(normed, use_cache, past_kv, word_positions=word_positions)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out, new_kv
|
|
|
|
|
class CircuitTransformer(nn.Module):
    """
    Minimal transformer for semantic circuitry experiments.

    Features:
    - Standard GPT-style architecture (RMSNorm, RoPE, SwiGLU, causal attention)
    - Weight tying (embed = lm_head) whenever the embedding and head widths match
    - Optional factorized embedding (config.embed_dim) and output head
      (config.head_dim), each with a SiLU between its two projections
    - Optional auxiliary skip-ahead loss (config.aux_skip_k, config.aux_skip_weight)
    - Extension hooks for future work:
      - freeze_layers() / unfreeze_layers() for progressive training
      - get_layer_outputs() for interpretability
      - window_size param for sliding window attention (config.window_size)
    """

    def __init__(self, config: CircuitConfig):
        super().__init__()
        self.config = config

        # Optional factorized embedding / output-head widths; 0 disables them.
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)

        # A factorized embedding without an explicit head_dim gets a head of
        # the same width, so the two weights can still be tied below.
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim

        if embed_dim > 0:
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
        else:
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
        self.embed_scale = math.sqrt(config.hidden_size)

        # Sliding-window size for attention. A missing or falsy (0) config value
        # maps to None, which is TransformerBlock's default.
        window_size = getattr(config, 'window_size', None) or None

        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_heads,
                getattr(config, 'num_kv_heads', None),
                config.max_seq_len,
                config.dropout,
                window_size=window_size,
                word_rope_dims=getattr(config, 'word_rope_dims', 0),
                word_rope_base=getattr(config, 'word_rope_base', 10.0),
            )
            for _ in range(config.num_layers)
        ])

        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
        else:
            self.head_down = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie input and output embeddings only when their widths agree: both
        # weights are [vocab_size, width], so the shapes then match exactly.
        embed_width = embed_dim if embed_dim > 0 else config.hidden_size
        head_width = head_dim if head_dim > 0 else config.hidden_size
        if embed_width == head_width:
            self.lm_head.weight = self.embed.weight

        # Auxiliary head predicting the token aux_skip_k positions ahead.
        self.skip_head = None
        self.skip_head_down = None
        aux_skip_k = getattr(config, 'aux_skip_k', 0)
        if aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Indices of layers whose parameters are currently frozen.
        self._frozen_layers: set[int] = set()

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map token IDs to scaled hidden states (optionally via the factorized projection)."""
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            x = F.silu(self.embed_proj(x))
        return x * self.embed_scale

    @staticmethod
    def _apply_head(x: torch.Tensor, down: nn.Module | None, head: nn.Module) -> torch.Tensor:
        """Project hidden states to vocab logits, through a SiLU bottleneck when `down` is set."""
        if down is not None:
            x = F.silu(down(x))
        return head(x)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        """
        Forward pass.

        Args:
            input_ids: [B, L] token IDs
            labels: [B, L] target token IDs aligned with input_ids (for loss
                computation). Logits at position t are scored against
                labels[:, t + 1]; the auxiliary head against labels[:, t + k].
                Entries equal to -100 are ignored.
            use_cache: Whether to return KV cache for generation
            past_kv: Previous KV cache (indexed per layer)
            word_positions: [B, L] position within word (from compute_word_positions)

        Returns:
            dict with 'logits', optionally 'past_kv', and, when labels are
            given, 'loss' (plus 'aux_loss' when the auxiliary skip head exists)
        """
        _, seq_len = input_ids.shape  # also rejects non-2-D input early
        x = self._embed_tokens(input_ids)

        new_kv = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            layer_past = past_kv[i] if past_kv is not None else None
            x, kv = layer(x, use_cache, layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)

        x = self.norm(x)
        logits = self._apply_head(x, self.head_down, self.lm_head)

        result = {"logits": logits}
        if use_cache:
            result["past_kv"] = new_kv

        if labels is not None:
            vocab_size = self.config.vocab_size
            # Next-token objective.
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, vocab_size),
                labels[:, 1:].reshape(-1),
                ignore_index=-100,
            )

            if self.skip_head is not None:
                skip_k = getattr(self.config, 'aux_skip_k', 0)
                skip_weight = getattr(self.config, 'aux_skip_weight', 0.1)
                if 0 < skip_k < seq_len:
                    skip_logits = self._apply_head(x, self.skip_head_down, self.skip_head)
                    aux_loss = F.cross_entropy(
                        skip_logits[:, :-skip_k, :].reshape(-1, vocab_size),
                        labels[:, skip_k:].reshape(-1),
                        ignore_index=-100,
                    )
                else:
                    # No (t, t + k) pairs exist in this sequence: slicing would
                    # yield empty (NaN loss) or mismatched tensors, so the
                    # auxiliary term contributes nothing.
                    aux_loss = loss.new_zeros(())
                result["aux_loss"] = aux_loss
                loss = loss + skip_weight * aux_loss

            result["loss"] = loss

        return result

    def freeze_layers(self, indices: list[int]) -> None:
        """Freeze specific layers (set requires_grad=False); out-of-range indices are ignored."""
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = False
                self._frozen_layers.add(idx)

    def unfreeze_layers(self, indices: list[int] | None = None) -> None:
        """Unfreeze specific layers, or every frozen layer if indices is None.

        Out-of-range indices are ignored.
        """
        if indices is None:
            # Copy: the set is mutated inside the loop.
            indices = list(self._frozen_layers)
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = True
                self._frozen_layers.discard(idx)

    def get_layer_outputs(
        self,
        input_ids: torch.Tensor,
        word_positions: torch.Tensor | None = None,
    ) -> list[torch.Tensor]:
        """Get intermediate outputs from each layer for interpretability.

        Args:
            input_ids: [B, L] token IDs
            word_positions: optional [B, L] word positions, as accepted by forward()

        Returns:
            One cloned hidden-state tensor per layer (before the final norm).
        """
        outputs = []
        x = self._embed_tokens(input_ids)
        for layer in self.layers:
            x, _ = layer(x, use_cache=False, past_kv=None, word_positions=word_positions)
            outputs.append(x.clone())
        return outputs

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Apply top-k then nucleus (top-p) filtering to [B, V] logits by masking with -inf."""
        if top_k > 0:
            kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # Shift right so the token that crosses top_p is kept, and always
            # keep the most likely token.
            sorted_to_remove = cumsum_probs > top_p
            sorted_to_remove[:, 1:] = sorted_to_remove[:, :-1].clone()
            sorted_to_remove[:, 0] = False
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float("-inf"))

        return logits

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        use_cache: bool = True,
        word_start_table: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV caching.

        The model is switched to eval mode for generation and its previous
        train/eval mode is restored afterwards.

        Args:
            prompt_ids: [B, L] prompt token IDs
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature; <= 0 means greedy decoding
            top_k: Top-k filtering (0 disables)
            top_p: Nucleus sampling threshold (>= 1.0 disables)
            use_cache: Use KV cache for faster generation
            word_start_table: [vocab_size] bool tensor for word-position RoPE

        Returns:
            [B, L + n] generated token IDs with n <= max_new_tokens; generation
            stops early once the sequence length reaches config.max_seq_len.
        """
        from .layers import compute_word_positions

        was_training = self.training
        self.eval()

        generated = prompt_ids.clone()
        past_kv = None
        # Per-sequence word position of the last token in `generated`, [B, 1];
        # maintained incrementally for cached decoding steps.
        word_pos = None
        start_table = None
        if word_start_table is not None:
            start_table = word_start_table.to(generated.device)

        try:
            for _ in range(max_new_tokens):
                if use_cache and past_kv is not None:
                    input_ids = generated[:, -1:]
                    if start_table is not None:
                        # A word-start token resets the position; any other
                        # token continues the current word.
                        is_start = start_table[input_ids].bool()
                        word_pos = torch.where(is_start, torch.zeros_like(word_pos), word_pos + 1)
                        word_positions = word_pos
                    else:
                        word_positions = None
                else:
                    input_ids = generated
                    if word_start_table is not None:
                        word_positions = compute_word_positions(input_ids, word_start_table)
                        # Seed incremental tracking from the last prompt token
                        # so the first generated token continues its word.
                        word_pos = word_positions[:, -1:].to(
                            device=generated.device, dtype=torch.float32
                        )
                    else:
                        word_positions = None

                output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
                logits = output["logits"][:, -1, :]
                if use_cache:
                    past_kv = output["past_kv"]

                if temperature > 0:
                    logits = self._filter_logits(logits / temperature, top_k, top_p)
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = logits.argmax(dim=-1, keepdim=True)

                generated = torch.cat([generated, next_token], dim=1)

                if generated.size(1) >= self.config.max_seq_len:
                    break
        finally:
            # Don't leave a model that was training stuck in eval mode.
            self.train(was_training)

        return generated
|
|
|
|
|
def count_parameters(model: CircuitTransformer) -> int:
    """Return the number of trainable (requires_grad=True) parameter elements.

    ``nn.Module.parameters()`` yields each shared Parameter object once, so a
    tied embedding / lm_head weight is counted a single time.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
|