| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from datasets import load_dataset |
| from transformers import AutoTokenizer, PretrainedConfig, AutoConfig, AutoModel, PreTrainedModel |
| from torch.optim import AdamW |
| import os |
| import time |
| import numpy as np |
| import json |
| |
class BucketMemoryConfig(PretrainedConfig):
    """Configuration for the bucket-memory transformer.

    Combines standard transformer hyper-parameters (vocab size, model width,
    depth, heads, dropout) with the dynamic-memory settings: how many memory
    buckets each layer keeps and the smallest/largest slot count a bucket
    may have.
    """

    model_type = "bucket-memory-model3"

    def __init__(self, vocab_size=30000, d_model=512, num_layers=6, num_buckets=8,
                 min_bucket_size=1, max_bucket_size=32, max_seq_length=1024, dropout=0.1,
                 use_flash_attention=True, num_attention_heads=8, **kwargs):
        super().__init__(**kwargs)
        # Core transformer dimensions.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.max_seq_length = max_seq_length
        self.dropout = dropout
        # Bucket-memory settings.
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        # Prefer torch's fused scaled_dot_product_attention when available.
        self.use_flash_attention = use_flash_attention
|
|
class DynamicBucketMemory(nn.Module):
    """Multi-scale external memory with attention-based reads and EMA writes.

    Keeps ``num_buckets`` per-batch memory tensors whose slot counts grow
    log-linearly from ``min_bucket_size`` to ``max_bucket_size``.  The read
    path attends each query over every active bucket (discounted by slot
    age); the write path (training only) blends a strided sample of the
    input into each routed bucket via an exponential moving average.

    Fixes vs. the previous revision:
      * input-rank normalization no longer loops forever on a non-singleton
        leading dim (and the unreachable ``dim == 4`` branch is gone);
      * memory is actually reallocated when the batch size or device
        changes (the old ``_initialize_memory`` guard made the batch-size
        check in ``forward`` a no-op);
      * a dead ``key_proj`` computation in the write path is removed.
    """

    def __init__(self, embedding_dim=512, num_buckets=8, min_bucket_size=1, max_bucket_size=32,
                 compression_factor=0.8, decay_rate=0.05):
        # NOTE(review): ``compression_factor`` is accepted for interface
        # compatibility but is currently unused by this class.
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_buckets = num_buckets
        self.min_bucket_size = min_bucket_size
        self.max_bucket_size = max_bucket_size
        self.decay_rate = decay_rate

        # Log-spaced bucket sizes, clamped below by min_bucket_size.
        sizes = np.logspace(np.log10(min_bucket_size), np.log10(max_bucket_size), num_buckets).astype(int)
        self.bucket_sizes = np.maximum(sizes, min_bucket_size).tolist()

        # Lazily allocated per-batch state (plain tensors, not parameters;
        # they are rebuilt whenever batch size or device changes).
        self.memory_buckets = None
        self.memory_age = None
        self.bucket_importance = nn.Parameter(torch.ones(num_buckets))

        # Read/write projections and normalization.
        self.query_proj = nn.Linear(embedding_dim, embedding_dim)
        self.key_proj = nn.Linear(embedding_dim, embedding_dim)
        self.value_proj = nn.Linear(embedding_dim, embedding_dim)
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.input_norm = nn.LayerNorm(embedding_dim)
        self.output_norm = nn.LayerNorm(embedding_dim)

        # Soft routing of the mean-pooled input over buckets (sums to 1).
        self.bucket_selector = nn.Sequential(
            nn.Linear(embedding_dim, num_buckets * 2),
            nn.GELU(),
            nn.Linear(num_buckets * 2, num_buckets),
            nn.Softmax(dim=-1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Standard transformer init: N(0, 0.02) linears, identity LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _initialize_memory(self, batch_size, device):
        """(Re)allocate zeroed memory slots and ages for ``batch_size``.

        Bug fix: the old ``if self.memory_buckets is None`` guard here meant
        a batch-size change never reallocated, crashing the bmm in forward.
        """
        self.memory_buckets = [torch.zeros(batch_size, size, self.embedding_dim, device=device)
                               for size in self.bucket_sizes]
        self.memory_age = [torch.zeros(batch_size, size, device=device) for size in self.bucket_sizes]

    def forward(self, input_data, memory_update=True):
        """Read from (and, in training, write to) the bucket memory.

        Args:
            input_data: tensor normalized to (batch, seq, embedding_dim).
                A 2-D (batch, seq) input is lifted by broadcasting along the
                embedding axis; singleton leading dims beyond rank 3 are
                squeezed away.
            memory_update: when True and ``self.training``, blend the current
                values into the memory (outside the autograd graph).

        Returns:
            (batch, seq, embedding_dim) tensor: LayerNorm(input + memory read).
        """
        # --- Normalize input rank. ---
        # Bug fix: the original ``while dim > 3: squeeze(0)`` spun forever
        # when the leading dim was not size 1, and its ``dim == 4`` follow-up
        # was unreachable.  Squeeze only singleton dims and fail loudly.
        while input_data.dim() > 3 and input_data.size(0) == 1:
            input_data = input_data.squeeze(0)
        if input_data.dim() > 3:
            raise ValueError(
                f"expected input of rank <= 3, got shape {tuple(input_data.shape)}"
            )
        if input_data.dim() == 2:
            input_data = input_data.unsqueeze(-1)
            if self.embedding_dim > 1:
                input_data = input_data.expand(-1, -1, self.embedding_dim)

        batch_size, seq_len, _ = input_data.size()
        device = input_data.device

        normalized_input = self.input_norm(input_data)

        # Rebuild memory when batch size or device changes (robustness: the
        # old code crashed after a ``.to(device)`` move of the module).
        if (self.memory_buckets is None
                or self.memory_buckets[0].size(0) != batch_size
                or self.memory_buckets[0].device != device):
            self._initialize_memory(batch_size, device)

        # Soft bucket routing from the mean-pooled sequence.
        avg_input_features = normalized_input.mean(dim=1)
        bucket_weights = self.bucket_selector(avg_input_features)

        projected_query = self.query_proj(normalized_input)
        outputs = torch.zeros(batch_size, seq_len, self.embedding_dim, device=device)

        for b in range(self.num_buckets):
            # Skip buckets no sample in the batch routes to (cheap sparsity).
            if bucket_weights[:, b].max() < 0.05:
                continue

            # Scaled dot-product relevance of each query against the slots.
            relevance = torch.bmm(
                projected_query,
                self.memory_buckets[b].transpose(1, 2)
            ) / (self.embedding_dim ** 0.5)

            # Exponentially discount stale slots by age.
            age_penalty = torch.exp(-self.memory_age[b] * 0.7).unsqueeze(1)
            relevance = relevance * age_penalty

            retrieval_weights = F.softmax(relevance, dim=-1)
            retrieved_values = torch.bmm(retrieval_weights, self.memory_buckets[b])

            # Blend by learned per-bucket importance and routing weight.
            importance_scale = torch.sigmoid(self.bucket_importance[b])
            outputs += retrieved_values * importance_scale * bucket_weights[:, b].view(batch_size, 1, 1)

        memory_output = self.output_proj(outputs)

        # --- Write path (training only, excluded from autograd). ---
        if memory_update and self.training:
            with torch.no_grad():
                # Bug fix: the original also computed ``self.key_proj(...)``
                # here but never used the result; the dead work is removed.
                values = self.value_proj(normalized_input)

                for b in range(self.num_buckets):
                    bucket_size = self.bucket_sizes[b]
                    # Write only for samples routed here with weight > 0.1.
                    bucket_mask = (bucket_weights[:, b] > 0.1).float().view(-1, 1, 1)

                    if seq_len > bucket_size:
                        # Strided subsample down to the bucket size.
                        stride = max(1, seq_len // bucket_size)
                        indices = torch.arange(0, seq_len, stride, device=device)[:bucket_size]
                        selected_values = values[:, indices]
                    else:
                        # Zero-pad up to the bucket size.
                        padding = bucket_size - seq_len
                        selected_values = F.pad(values, (0, 0, 0, padding))

                    # EMA retention; later (larger) buckets retain more.
                    alpha = torch.sigmoid(self.bucket_importance[b]) * (0.8 if b > self.num_buckets // 2 else 0.2)

                    update = alpha * self.memory_buckets[b] + (1 - alpha) * selected_values
                    self.memory_buckets[b] = self.memory_buckets[b] * (1 - bucket_mask) + update * bucket_mask

                    # Freshly written slots reset toward decay_rate;
                    # untouched slots keep aging.
                    age_mask = (1 - bucket_mask.squeeze(-1))
                    self.memory_age[b] = self.memory_age[b] * age_mask + self.decay_rate

        return self.output_norm(input_data + memory_output)
|
|
| |
class BucketMemoryTransformerLayer(nn.Module):
    """Pre-norm transformer layer with an extra bucket-memory sub-layer.

    Sub-layer order: multi-head self-attention, dynamic bucket memory,
    position-wise feed-forward — each wrapped in a pre-LayerNorm residual.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=0.4, num_buckets=8,
                 min_bucket_size=1, max_bucket_size=32, use_flash_attention=True,
                 num_heads=8):
        super().__init__()
        self.use_flash_attention = use_flash_attention
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Multi-head attention projections.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # External multi-scale memory sub-layer.
        self.bucket_memory = DynamicBucketMemory(
            embedding_dim=d_model, num_buckets=num_buckets,
            min_bucket_size=min_bucket_size, max_bucket_size=max_bucket_size
        )

        # One pre-norm per sub-layer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (B, S, D) -> (B, H, S, D/H) for per-head attention."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attention_mask=None):
        """Run the three sub-layers over ``x`` of shape (batch, seq, d_model).

        ``attention_mask``, if given, is assumed to be a (batch, seq) tensor
        with 1 = attend / 0 = padded — TODO confirm against callers.
        """
        batch_size, seq_len, _ = x.shape
        normed = self.norm1(x)

        # Project and split into heads.
        q = self._split_heads(self.q_proj(normed), batch_size, seq_len)
        k = self._split_heads(self.k_proj(normed), batch_size, seq_len)
        v = self._split_heads(self.v_proj(normed), batch_size, seq_len)

        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # Fused-kernel path: convert the 0/1 mask to an additive bias.
            additive_mask = None
            if attention_mask is not None:
                additive_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                additive_mask = (1.0 - additive_mask) * -10000.0

            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=additive_mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=False
            )
        else:
            # Manual attention fallback.
            scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
            if attention_mask is not None:
                scores = scores.masked_fill(attention_mask.unsqueeze(1).unsqueeze(2) == 0, -1e9)
            weights = self.dropout(F.softmax(scores, dim=-1))
            attn_output = torch.matmul(weights, v)

        # Merge heads and close the attention residual.
        merged = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        x = x + self.dropout(self.out_proj(merged))

        # Bucket-memory residual.
        x = x + self.dropout(self.bucket_memory(self.norm2(x)))

        # Feed-forward residual.
        return x + self.dropout(self.ff(self.norm3(x)))
|
|
|
|
| |
class BucketMemoryModel(PreTrainedModel):
    """Transformer language model whose layers carry dynamic bucket memory.

    Embeds tokens, adds a content-dependent "tape" positional signal plus a
    fixed sinusoidal table, runs ``num_layers`` BucketMemoryTransformerLayer
    blocks, and projects to vocabulary logits.
    """
    config_class = BucketMemoryConfig
    # NOTE(review): this prefix disagrees with the config's model_type
    # ("bucket-memory-model3") and hyphens make it unusable as an attribute
    # prefix; left unchanged because renaming breaks existing checkpoints.
    base_model_prefix = "bucket-memory-model2"

    def __init__(self, config, adapter_kwargs=None):
        super().__init__(config)
        self.d_model = config.d_model
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Learned, content-dependent positional signal (applied to the
        # scaled embeddings before the fixed table below).
        self.tape_position_encoder = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.ReLU(),
            nn.Linear(config.d_model, config.d_model)
        )

        # Sinusoidal table stored as a parameter (so it is trainable and
        # saved with the checkpoint), initialized below.
        self.pos_encoding = nn.Parameter(torch.zeros(1, config.max_seq_length, config.d_model))
        self._init_positional_encoding(config.max_seq_length, config.d_model)

        num_heads = max(1, getattr(config, 'num_attention_heads', config.d_model // 64))
        self.layers = nn.ModuleList([
            BucketMemoryTransformerLayer(
                d_model=config.d_model,
                d_ff=4 * config.d_model,
                dropout=config.dropout,
                num_buckets=config.num_buckets,
                min_bucket_size=config.min_bucket_size,
                max_bucket_size=config.max_bucket_size,
                use_flash_attention=getattr(config, 'use_flash_attention', True),
                num_heads=num_heads
            ) for _ in range(config.num_layers)
        ])

        self.norm = nn.LayerNorm(config.d_model)
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def _init_positional_encoding(self, max_len, d_model):
        """Fill ``pos_encoding`` with the standard sin/cos table.

        NOTE(review): assumes ``d_model`` is even; an odd value makes the
        cosine slice one column wider than ``div_term`` — confirm upstream.
        """
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pos_enc = torch.zeros(1, max_len, d_model)
        pos_enc[0, :, 0::2] = torch.sin(position * div_term)
        pos_enc[0, :, 1::2] = torch.cos(position * div_term)
        self.pos_encoding.data.copy_(pos_enc)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute logits (and loss when ``labels`` is given).

        Returns:
            A lightweight object with ``.loss`` and ``.logits`` when labels
            are provided, otherwise the raw (batch, seq, vocab) logits.
        """
        batch_size, seq_len = input_ids.size()
        # Robustness fix: fail with a clear message instead of a broadcast
        # error when the input outgrows the positional table.
        if seq_len > self.pos_encoding.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pos_encoding.size(1)}"
            )
        x = self.token_embedding(input_ids) * np.sqrt(self.d_model)
        tape_pos = self.tape_position_encoder(x)
        x = x + tape_pos
        x = x + self.pos_encoding[:, :seq_len]
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.output_proj(x)

        if labels is not None:
            # NOTE(review): the loss applies no causal shift and does not
            # mask padding — confirm labels are pre-shifted/-masked upstream.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
            # Lightweight stand-in for transformers' ModelOutput.
            return type('ModelOutput', (), {'loss': loss, 'logits': logits})
        return logits

    def generate(self, input_ids, max_length=50):
        """Greedily decode up to ``max_length`` new tokens.

        Bug fixes vs. the previous revision: decoding now runs under
        ``torch.no_grad`` (no autograd graph accumulation across steps), and
        the EOS check no longer crashes when ``eos_token_id`` is unset or
        when the batch holds more than one sequence (``.item()`` on a
        multi-element tensor raised).  Stops once every sequence in the
        batch has just emitted EOS.
        """
        generated_tokens = input_ids
        eos_id = getattr(self.config, 'eos_token_id', None)
        with torch.no_grad():
            for _ in range(max_length):
                logits = self.forward(generated_tokens)
                # forward returns an object with .logits when labels were
                # given; here it is always the raw tensor, but keep the
                # defensive check for overridden forwards.
                if hasattr(logits, 'logits'):
                    next_token_logits = logits.logits[:, -1, :]
                else:
                    next_token_logits = logits[:, -1, :]

                next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
                generated_tokens = torch.cat((generated_tokens, next_token_id), dim=1)

                if eos_id is not None and bool((next_token_id == eos_id).all()):
                    break
        return generated_tokens
# Register the custom classes with the transformers Auto* factories so that
# AutoConfig/AutoModel.from_pretrained can resolve model_type
# "bucket-memory-model3", and mark both classes for auto_map export so the
# class definitions travel with saved checkpoints (trust_remote_code loading).
AutoConfig.register("bucket-memory-model3", BucketMemoryConfig)
AutoModel.register(BucketMemoryConfig, BucketMemoryModel)
BucketMemoryConfig.register_for_auto_class()
BucketMemoryModel.register_for_auto_class("AutoModel")
|
|