# NOTE(review): the lines below are Hugging Face model-card/page residue that
# was pasted above the module source; commented out so the file parses as Python.
# Transformers
# English
# Japanese
# Spanish
# recursive_transformer
# Drjkedwards's picture
# Create Model.py
# 7c56176 verified
# coding=utf-8
# Copyright 2025 Dr. Josef Kurk Edwards (drQedwards / josefedwards). All rights reserved.
# Licensed under the MIT License (see LICENSE in the ERS repository).
# This file provides the official Hugging Face integration for the Recursive Transformer Model (RTM) + Enhanced Reconsideration System (ERS).
"""PyTorch Recursive Transformer Model (RTM) with Persistent Memory Logic Loops (PMLL) and ERS runtime.
This is the core modeling file for the Hugging Face repository.
It defines the full RTM architecture (PMLLLattice + reconsideration logic) and supports
`from_pretrained` / `save_pretrained` exactly like any other HF model.
"""
import hashlib
import json
from dataclasses import dataclass, fields
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel
class RecursiveTransformerConfig(PretrainedConfig):
    """Configuration for the Recursive Transformer Model (RTM).

    Holds the hyperparameters of the PMLL lattice and the ERS
    reconsideration loop. Integrates with ``from_pretrained`` /
    ``save_pretrained`` like any other Hugging Face config.
    """

    model_type = "recursive_transformer"

    def __init__(
        self,
        embedding_dim: int = 384,
        num_petals: int = 8,
        decay_alpha: float = 0.95,
        consensus_threshold: float = 0.75,
        contradiction_threshold: float = 0.65,
        max_recursive_passes: int = 3,
        **kwargs,
    ):
        # Generic HF options (name_or_path, etc.) are handled by the base class.
        super().__init__(**kwargs)
        # RTM-specific hyperparameters, stored verbatim on the config object.
        for attr_name, attr_value in (
            ("embedding_dim", embedding_dim),
            ("num_petals", num_petals),
            ("decay_alpha", decay_alpha),
            ("consensus_threshold", consensus_threshold),
            ("contradiction_threshold", contradiction_threshold),
            ("max_recursive_passes", max_recursive_passes),
        ):
            setattr(self, attr_name, attr_value)
@dataclass
class MemoryBlock:
    """Single persistent memory unit used by ERS.

    Attributes are plain JSON-friendly values except ``embedding``, which is
    an optional tensor and is deliberately excluded from serialization
    (see :meth:`to_dict`). Missing timestamps and the content hash are
    filled in automatically at construction time.
    """

    id: str
    text: str
    embedding: Optional[torch.Tensor] = None
    confidence: float = 1.0
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
    sha256_hash: Optional[str] = None
    kg_id: Optional[str] = None

    def __post_init__(self):
        # Fix: use a timezone-aware UTC timestamp. datetime.utcnow() is
        # deprecated (Python 3.12+) and returns a *naive* datetime, producing
        # ambiguous ISO strings without an offset.
        if self.created_at is None:
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.updated_at is None:
            self.updated_at = self.created_at
        if self.sha256_hash is None:
            # Content hash of the raw text, used as an integrity fingerprint.
            self.sha256_hash = hashlib.sha256(self.text.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict.

        The ``embedding`` tensor is intentionally omitted: it is not
        JSON-serializable and is recomputed/reattached by the caller.
        """
        return {
            "id": self.id,
            "text": self.text,
            "confidence": self.confidence,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "sha256_hash": self.sha256_hash,
            "kg_id": self.kg_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryBlock":
        """Rebuild a block from :meth:`to_dict` output.

        Unknown keys are ignored for forward compatibility, so memory files
        written by a newer schema still load instead of raising TypeError.
        """
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})
class PMLLLattice(nn.Module):
    """Persistent Memory Logic Loop (PMLL) lattice.

    Core tensor-routing / reconsideration engine: routes embeddings through
    ``num_petals`` parallel linear projections (the "flower" attention),
    averages them, and scores the result with a sigmoid consensus head.
    """

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim
        # One linear projection per petal, all dim-preserving.
        petals = (
            nn.Linear(config.embedding_dim, config.embedding_dim)
            for _ in range(config.num_petals)
        )
        self.petal_projections = nn.ModuleList(petals)
        # Maps the fused representation to a single consensus logit.
        self.consensus_head = nn.Linear(config.embedding_dim, 1)
        # Learnable decay base, initialized from the configured alpha.
        self.decay_param = nn.Parameter(torch.tensor(config.decay_alpha))

    def forward(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the multi-petal transformation and consensus scoring.

        Returns the petal-averaged embedding and a sigmoid consensus score
        in (0, 1) per input row.
        """
        stacked = torch.stack(
            [petal(embeddings) for petal in self.petal_projections], dim=0
        )
        fused = stacked.mean(dim=0)  # collapse the petal axis
        consensus = torch.sigmoid(self.consensus_head(fused))
        return fused, consensus

    def apply_temporal_decay(self, confidence: torch.Tensor, time_delta_days: float = 1.0) -> torch.Tensor:
        """Exponentially decay ``confidence`` by the learned base over elapsed days."""
        return confidence * self.decay_param ** time_delta_days
class RecursiveTransformerModel(PreTrainedModel):
    """
    Full Recursive Transformer Model with Enhanced Reconsideration System (ERS).
    This is the main class users will import with `from_pretrained`.
    """

    config_class = RecursiveTransformerConfig
    base_model_prefix = "recursive_transformer"
    supports_gradient_checkpointing = False

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__(config)
        self.config = config
        self.lattice = PMLLLattice(config)
        # Active memory slots; persisted to memory_state.json by save_pretrained().
        self.memory_line: List[MemoryBlock] = []

    def add_memory(self, text: str, embedding: Optional[torch.Tensor] = None, confidence: float = 1.0) -> MemoryBlock:
        """Add a new memory block (ERS ``add_memory``) and return it.

        NOTE(review): the id is derived from the current list length, so ids
        can repeat if memory was loaded from disk with non-sequential ids and
        then extended — confirm whether ids must be globally unique.
        """
        block = MemoryBlock(
            id=f"mem_{len(self.memory_line)}",
            text=text,
            embedding=embedding,
            confidence=confidence,
        )
        self.memory_line.append(block)
        return block

    def reconsider(self, passes: Optional[int] = None) -> List[MemoryBlock]:
        """Run the RTM recursive reconsideration loop and return the memory line.

        Each pass re-scores every memory block that carries an embedding,
        replacing its confidence with the lattice's mean consensus score.

        Bug fix: the previous ``passes = passes or self.config...`` treated an
        explicit ``passes=0`` as "use the default"; ``passes=0`` now correctly
        performs no passes.
        """
        if passes is None:
            passes = self.config.max_recursive_passes
        for i in range(passes):
            print(f"→ RTM Reconsideration pass {i+1}/{passes}")
            for block in self.memory_line:
                if block.embedding is not None:
                    # Re-score confidence from the lattice consensus head.
                    _, score = self.lattice(block.embedding.unsqueeze(0))
                    block.confidence = float(score.mean().item())
        # In a full production version this would also call contradiction detection + rewrite
        return self.memory_line

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model + lattice weights exactly like any HF model.

        Both the lattice weights and the saved memory state are loaded
        best-effort: a missing/unreadable file falls back to a freshly
        initialized lattice / empty memory line rather than raising.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RecursiveTransformerConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        # Load lattice weights if present; weights_only=True avoids unpickling
        # arbitrary objects from the checkpoint.
        try:
            state_dict = torch.load(
                f"{pretrained_model_name_or_path}/pytorch_model.bin",
                map_location="cpu",
                weights_only=True,
            )
            model.lattice.load_state_dict(state_dict, strict=False)
            print("✅ Loaded PMLLLattice weights from pytorch_model.bin")
        except Exception:
            print("⚠️ No pytorch_model.bin found – using freshly initialized lattice")
        # Optional: load saved memory state (deliberate best-effort — a missing
        # or malformed memory file leaves memory_line empty).
        try:
            with open(f"{pretrained_model_name_or_path}/memory_state.json", "r") as f:
                mem_data = json.load(f)
            model.memory_line = [MemoryBlock.from_dict(d) for d in mem_data]
            print(f"✅ Loaded {len(model.memory_line)} saved memory blocks")
        except Exception:
            pass
        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights + memory state.

        Writes the HF artifacts via the base class, then the lattice state
        dict to pytorch_model.bin and the serialized memory line (tensors
        excluded — see MemoryBlock.to_dict) to memory_state.json.
        """
        super().save_pretrained(save_directory, **kwargs)
        # Save lattice
        torch.save(self.lattice.state_dict(), f"{save_directory}/pytorch_model.bin")
        # Save memory line
        memory_data = [block.to_dict() for block in self.memory_line]
        with open(f"{save_directory}/memory_state.json", "w") as f:
            json.dump(memory_data, f, indent=2)
# Public API of this module (controls `from Model import *` and documents
# what downstream code is expected to import).
__all__ = ["RecursiveTransformerConfig", "RecursiveTransformerModel", "MemoryBlock", "PMLLLattice"]