| |
| |
| |
| |
|
|
| """PyTorch Recursive Transformer Model (RTM) with Persistent Memory Logic Loops (PMLL) and ERS runtime. |
| |
| This is the core modeling file for the Hugging Face repository. |
| It defines the full RTM architecture (PMLLLattice + reconsideration logic) and supports |
| `from_pretrained` / `save_pretrained` exactly like any other HF model. |
| """ |
|
|
import hashlib
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
|
class RecursiveTransformerConfig(PretrainedConfig):
    """Configuration for the Recursive Transformer Model (RTM).

    Args:
        embedding_dim: Width of the embedding space routed through the lattice.
        num_petals: Number of parallel "petal" projections in the PMLL lattice.
        decay_alpha: Initial base for the learnable temporal-decay parameter.
        consensus_threshold: Score above which a memory is considered agreed-upon.
        contradiction_threshold: Score below which a memory is flagged as contradictory.
        max_recursive_passes: Default number of reconsideration passes.
        **kwargs: Forwarded verbatim to :class:`transformers.PretrainedConfig`.
    """

    model_type = "recursive_transformer"

    def __init__(
        self,
        embedding_dim: int = 384,
        num_petals: int = 8,
        decay_alpha: float = 0.95,
        consensus_threshold: float = 0.75,
        contradiction_threshold: float = 0.65,
        max_recursive_passes: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Lattice geometry.
        self.embedding_dim = embedding_dim
        self.num_petals = num_petals
        # Reconsideration dynamics.
        self.decay_alpha = decay_alpha
        self.consensus_threshold = consensus_threshold
        self.contradiction_threshold = contradiction_threshold
        self.max_recursive_passes = max_recursive_passes
|
|
|
|
@dataclass
class MemoryBlock:
    """Single persistent memory unit used by the Enhanced Reconsideration System (ERS).

    Timestamps are ISO-8601 strings; the content hash is SHA-256 over the UTF-8
    encoded ``text`` and is filled in automatically when not supplied.
    """

    id: str
    text: str
    embedding: Optional[torch.Tensor] = None  # not serialized by to_dict (tensors aren't JSON-safe)
    confidence: float = 1.0
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
    sha256_hash: Optional[str] = None
    kg_id: Optional[str] = None  # optional knowledge-graph identifier

    def __post_init__(self) -> None:
        """Fill in timestamps and the content hash for fields left as None."""
        if self.created_at is None:
            # Fix: datetime.utcnow() is deprecated (Python 3.12) and returns a
            # naive datetime; use an explicitly timezone-aware UTC timestamp.
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.updated_at is None:
            # A freshly created block has never been updated.
            self.updated_at = self.created_at
        if self.sha256_hash is None:
            self.sha256_hash = hashlib.sha256(self.text.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable view of this block (embedding deliberately omitted)."""
        return {
            "id": self.id,
            "text": self.text,
            "confidence": self.confidence,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "sha256_hash": self.sha256_hash,
            "kg_id": self.kg_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryBlock":
        """Rebuild a block from :meth:`to_dict` output (embedding stays None)."""
        return cls(**data)
|
|
|
|
class PMLLLattice(nn.Module):
    """Persistent Memory Logic Loop (PMLL) lattice — the core tensor routing and reconsideration engine.

    Routes an embedding through ``num_petals`` independent linear projections,
    fuses them by averaging, and scores the fused result with a sigmoid head.
    """

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim

        # One independent linear "petal" per routing direction.
        petals = [
            nn.Linear(config.embedding_dim, config.embedding_dim)
            for _ in range(config.num_petals)
        ]
        self.petal_projections = nn.ModuleList(petals)

        # Scalar consensus score per input row.
        self.consensus_head = nn.Linear(config.embedding_dim, 1)
        # Learnable decay base, initialised from the configured alpha.
        self.decay_param = nn.Parameter(torch.tensor(config.decay_alpha))

    def forward(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply multi-petal transformation + consensus scoring.

        Returns the petal-averaged embedding and a sigmoid consensus score in (0, 1).
        """
        stacked = torch.stack(
            [petal(embeddings) for petal in self.petal_projections], dim=0
        )
        fused = stacked.mean(dim=0)
        score = torch.sigmoid(self.consensus_head(fused))
        return fused, score

    def apply_temporal_decay(self, confidence: torch.Tensor, time_delta_days: float = 1.0) -> torch.Tensor:
        """Adaptive temporal decay (core of RTM reconsideration): confidence * alpha^days."""
        decay = torch.pow(self.decay_param, time_delta_days)
        return confidence * decay
|
|
|
|
class RecursiveTransformerModel(PreTrainedModel):
    """
    Full Recursive Transformer Model with Enhanced Reconsideration System (ERS).
    This is the main class users will import with `from_pretrained`.
    """

    config_class = RecursiveTransformerConfig
    base_model_prefix = "recursive_transformer"
    supports_gradient_checkpointing = False

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__(config)
        self.config = config
        self.lattice = PMLLLattice(config)
        # Ordered line of persistent memory blocks managed by ERS.
        self.memory_line: List[MemoryBlock] = []

    def add_memory(self, text: str, embedding: Optional[torch.Tensor] = None, confidence: float = 1.0) -> MemoryBlock:
        """Add a new memory block (ERS `add_memory`) and return it.

        Ids are positional (`mem_<index>`), so they are only unique while
        blocks are never removed from the line.
        """
        block = MemoryBlock(
            id=f"mem_{len(self.memory_line)}",
            text=text,
            embedding=embedding,
            confidence=confidence,
        )
        self.memory_line.append(block)
        return block

    def reconsider(self, passes: Optional[int] = None) -> List[MemoryBlock]:
        """Run full RTM recursive reconsideration loop (temporal decay → consensus → contradiction).

        Args:
            passes: Number of passes; defaults to ``config.max_recursive_passes``.
                An explicit 0 performs no passes.

        Returns:
            The (mutated in place) memory line with refreshed confidences.
        """
        # Fix: `passes or default` silently replaced an explicit passes=0
        # with the configured default; only None should mean "use default".
        if passes is None:
            passes = self.config.max_recursive_passes
        for i in range(passes):
            print(f"→ RTM Reconsideration pass {i+1}/{passes}")
            for block in self.memory_line:
                # Blocks without embeddings cannot be scored by the lattice.
                if block.embedding is not None:
                    _, score = self.lattice(block.embedding.unsqueeze(0))
                    block.confidence = float(score.mean().item())
        # NOTE(review): the docstring mentions temporal decay, but
        # `lattice.apply_temporal_decay` is never invoked here — confirm intent.
        return self.memory_line

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model + lattice weights exactly like any HF model.

        Both the lattice checkpoint and the memory state are best-effort:
        a missing file falls back to fresh initialization / an empty line.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RecursiveTransformerConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        # Fix: the original `except Exception` reported ANY failure (e.g. a
        # corrupt checkpoint) as "file not found"; only tolerate missing files
        # and let real load errors surface.
        try:
            state_dict = torch.load(
                f"{pretrained_model_name_or_path}/pytorch_model.bin",
                map_location="cpu",
                weights_only=True,
            )
        except FileNotFoundError:
            print("⚠️ No pytorch_model.bin found – using freshly initialized lattice")
        else:
            model.lattice.load_state_dict(state_dict, strict=False)
            print("✅ Loaded PMLLLattice weights from pytorch_model.bin")

        # Fix: narrowed from a silent `except Exception: pass` so malformed
        # JSON or bad block data no longer disappears without a trace.
        try:
            with open(f"{pretrained_model_name_or_path}/memory_state.json", "r") as f:
                mem_data = json.load(f)
        except FileNotFoundError:
            pass  # no persisted memory state; start with an empty line
        else:
            model.memory_line = [MemoryBlock.from_dict(d) for d in mem_data]
            print(f"✅ Loaded {len(model.memory_line)} saved memory blocks")

        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights + memory state.

        Writes the standard HF artifacts via ``super()``, then overwrites
        `pytorch_model.bin` with the lattice-only state dict that
        :meth:`from_pretrained` expects, plus `memory_state.json`.
        """
        super().save_pretrained(save_directory, **kwargs)

        torch.save(self.lattice.state_dict(), f"{save_directory}/pytorch_model.bin")

        # Embeddings are intentionally dropped by MemoryBlock.to_dict.
        memory_data = [block.to_dict() for block in self.memory_line]
        with open(f"{save_directory}/memory_state.json", "w") as f:
            json.dump(memory_data, f, indent=2)
|
|
|
|
| |
| __all__ = ["RecursiveTransformerConfig", "RecursiveTransformerModel", "MemoryBlock", "PMLLLattice"] |