# coding=utf-8
# Copyright 2025 Dr. Josef Kurk Edwards (drQedwards / josefedwards). All rights reserved.
# Licensed under the MIT License (see LICENSE in the ERS repository).
# This file provides the official Hugging Face integration for the Recursive
# Transformer Model (RTM) + Enhanced Reconsideration System (ERS).
"""PyTorch Recursive Transformer Model (RTM) with Persistent Memory Logic Loops (PMLL) and ERS runtime.

This is the core modeling file for the Hugging Face repository. It defines the
full RTM architecture (PMLLLattice + reconsideration logic) and supports
`from_pretrained` / `save_pretrained` exactly like any other HF model.
"""

import hashlib
import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel

logger = logging.getLogger(__name__)


class RecursiveTransformerConfig(PretrainedConfig):
    """Configuration class for the Recursive Transformer Model (RTM).

    Args:
        embedding_dim: Dimensionality of memory-block embeddings fed to the lattice.
        num_petals: Number of parallel "petal" projections in the PMLL lattice.
        decay_alpha: Initial value for the learnable temporal-decay base.
        consensus_threshold: Consensus cutoff (stored for ERS logic; not read here).
        contradiction_threshold: Contradiction cutoff (stored for ERS logic; not read here).
        max_recursive_passes: Default number of reconsideration passes.
    """

    model_type = "recursive_transformer"

    def __init__(
        self,
        embedding_dim: int = 384,
        num_petals: int = 8,
        decay_alpha: float = 0.95,
        consensus_threshold: float = 0.75,
        contradiction_threshold: float = 0.65,
        max_recursive_passes: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_petals = num_petals
        self.decay_alpha = decay_alpha
        self.consensus_threshold = consensus_threshold
        self.contradiction_threshold = contradiction_threshold
        self.max_recursive_passes = max_recursive_passes


@dataclass
class MemoryBlock:
    """Single persistent memory unit used by ERS."""

    id: str
    text: str
    embedding: Optional[torch.Tensor] = None
    confidence: float = 1.0
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
    sha256_hash: Optional[str] = None
    kg_id: Optional[str] = None

    def __post_init__(self) -> None:
        # Timestamps default to "now" in UTC. Timezone-aware datetime is used
        # because datetime.utcnow() is deprecated and returns naive values.
        if self.created_at is None:
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.updated_at is None:
            self.updated_at = self.created_at
        # Content hash allows ERS to detect tampering / identify identical texts.
        if self.sha256_hash is None:
            self.sha256_hash = hashlib.sha256(self.text.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-safe dict.

        The `embedding` tensor is intentionally omitted — only textual/metadata
        state is persisted; embeddings must be recomputed after reload.
        """
        return {
            "id": self.id,
            "text": self.text,
            "confidence": self.confidence,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "sha256_hash": self.sha256_hash,
            "kg_id": self.kg_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryBlock":
        """Inverse of `to_dict`; restored blocks have `embedding=None`."""
        return cls(**data)


class PMLLLattice(nn.Module):
    """Persistent Memory Logic Loop (PMLL) lattice.

    The core tensor routing and reconsideration engine: a bank of per-petal
    linear projections whose averaged output is scored by a consensus head.
    """

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim
        # Multi-petal attention projections (simulates the "flower" attention
        # from the paper): one independent linear map per petal.
        self.petal_projections = nn.ModuleList(
            [nn.Linear(config.embedding_dim, config.embedding_dim) for _ in range(config.num_petals)]
        )
        self.consensus_head = nn.Linear(config.embedding_dim, 1)
        # Learnable base of the exponential temporal decay.
        self.decay_param = nn.Parameter(torch.tensor(config.decay_alpha))

    def forward(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply multi-petal transformation + consensus scoring.

        Args:
            embeddings: Input tensor whose last dim is `embedding_dim`.

        Returns:
            Tuple of (petal-averaged embeddings, sigmoid consensus scores in [0, 1]).
        """
        petal_outputs = [proj(embeddings) for proj in self.petal_projections]
        combined = torch.stack(petal_outputs, dim=0).mean(dim=0)  # average across petals
        consensus_score = torch.sigmoid(self.consensus_head(combined))
        return combined, consensus_score

    def apply_temporal_decay(self, confidence: torch.Tensor, time_delta_days: float = 1.0) -> torch.Tensor:
        """Adaptive temporal decay (core of RTM reconsideration).

        Multiplies `confidence` by `decay_param ** time_delta_days`.
        """
        return confidence * torch.pow(self.decay_param, time_delta_days)


class RecursiveTransformerModel(PreTrainedModel):
    """Full Recursive Transformer Model with Enhanced Reconsideration System (ERS).

    This is the main class users will import with `from_pretrained`.
    """

    config_class = RecursiveTransformerConfig
    base_model_prefix = "recursive_transformer"
    supports_gradient_checkpointing = False

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__(config)
        self.config = config
        self.lattice = PMLLLattice(config)
        self.memory_line: List[MemoryBlock] = []  # active memory slots

    def add_memory(
        self,
        text: str,
        embedding: Optional[torch.Tensor] = None,
        confidence: float = 1.0,
    ) -> MemoryBlock:
        """Add a new memory block (ERS `add_memory`) and return it."""
        block = MemoryBlock(
            id=f"mem_{len(self.memory_line)}",
            text=text,
            embedding=embedding,
            confidence=confidence,
        )
        self.memory_line.append(block)
        return block

    def reconsider(self, passes: Optional[int] = None) -> List[MemoryBlock]:
        """Run the RTM recursive reconsideration loop.

        Each pass re-scores every memory block that carries an embedding via
        the lattice's consensus head and overwrites its `confidence`.
        (Temporal decay and contradiction detection/rewrite belong to the full
        production ERS and are not applied here.)

        Args:
            passes: Number of passes; defaults to `config.max_recursive_passes`.
                Note: `passes=0` is honored as "no passes" (previously a falsy
                0 was silently replaced by the config default).

        Returns:
            The (mutated in place) list of memory blocks.
        """
        if passes is None:
            passes = self.config.max_recursive_passes
        for i in range(passes):
            logger.info("RTM reconsideration pass %d/%d", i + 1, passes)
            for block in self.memory_line:
                if block.embedding is not None:
                    _, score = self.lattice(block.embedding.unsqueeze(0))
                    block.confidence = float(score.mean().item())
        return self.memory_line

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model + lattice weights exactly like any HF model.

        Both the lattice weights (`pytorch_model.bin`) and the saved memory
        state (`memory_state.json`) are optional: missing files leave a
        freshly initialized lattice / empty memory line. Corrupt weight files
        now raise instead of being silently ignored.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RecursiveTransformerConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)

        # Load lattice weights if present. weights_only=True restricts
        # deserialization to tensors (avoids arbitrary-pickle execution).
        weights_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        if os.path.isfile(weights_path):
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.lattice.load_state_dict(state_dict, strict=False)
            logger.info("Loaded PMLLLattice weights from %s", weights_path)
        else:
            logger.warning("No pytorch_model.bin found - using freshly initialized lattice")

        # Optional: load saved memory state (best-effort; a missing or
        # malformed file is tolerated, anything else propagates).
        memory_path = os.path.join(pretrained_model_name_or_path, "memory_state.json")
        try:
            with open(memory_path, "r", encoding="utf-8") as f:
                mem_data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            pass
        else:
            model.memory_line = [MemoryBlock.from_dict(d) for d in mem_data]
            logger.info("Loaded %d saved memory blocks", len(model.memory_line))

        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights + memory state.

        Writes the standard HF artifacts via `super().save_pretrained`, then
        the lattice state dict (`pytorch_model.bin`) and the serialized memory
        line (`memory_state.json`).
        """
        super().save_pretrained(save_directory, **kwargs)
        # Save lattice
        torch.save(self.lattice.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        # Save memory line (embeddings are dropped by MemoryBlock.to_dict)
        memory_data = [block.to_dict() for block in self.memory_line]
        with open(os.path.join(save_directory, "memory_state.json"), "w", encoding="utf-8") as f:
            json.dump(memory_data, f, indent=2)


# For easy importing from the repo
__all__ = ["RecursiveTransformerConfig", "RecursiveTransformerModel", "MemoryBlock", "PMLLLattice"]