Transformers
English
Japanese
Spanish
recursive_transformer
Drjkedwards committed on
Commit
7c56176
·
verified ·
1 Parent(s): e183a48

Create Model.py

Browse files
Files changed (1) hide show
  1. Model.py +191 -0
Model.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Dr. Josef Kurk Edwards (drQedwards / josefedwards). All rights reserved.
3
+ # Licensed under the MIT License (see LICENSE in the ERS repository).
4
+ # This file provides the official Hugging Face integration for the Recursive Transformer Model (RTM) + Enhanced Reconsideration System (ERS).
5
+
6
+ """PyTorch Recursive Transformer Model (RTM) with Persistent Memory Logic Loops (PMLL) and ERS runtime.
7
+
8
+ This is the core modeling file for the Hugging Face repository.
9
+ It defines the full RTM architecture (PMLLLattice + reconsideration logic) and supports
10
+ `from_pretrained` / `save_pretrained` exactly like any other HF model.
11
+ """
12
+
13
import hashlib
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel
21
+
22
+
23
class RecursiveTransformerConfig(PretrainedConfig):
    """Configuration for the Recursive Transformer Model (RTM).

    Holds the hyperparameters that drive the PMLL lattice and the ERS
    reconsideration loop; integrates with `from_pretrained` /
    `save_pretrained` like any other Hugging Face config.
    """

    model_type = "recursive_transformer"

    def __init__(
        self,
        embedding_dim: int = 384,
        num_petals: int = 8,
        decay_alpha: float = 0.95,
        consensus_threshold: float = 0.75,
        contradiction_threshold: float = 0.65,
        max_recursive_passes: int = 3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Width of the shared embedding space used by the lattice.
        self.embedding_dim = embedding_dim
        # Number of parallel "petal" projections in the lattice.
        self.num_petals = num_petals
        # Base of the (per-day) temporal confidence decay.
        self.decay_alpha = decay_alpha
        # Score thresholds consumed by the ERS reconsideration logic.
        self.consensus_threshold = consensus_threshold
        self.contradiction_threshold = contradiction_threshold
        # Upper bound on reconsideration passes in `reconsider()`.
        self.max_recursive_passes = max_recursive_passes
42
+
43
+
44
@dataclass
class MemoryBlock:
    """Single persistent memory unit used by ERS.

    Timestamps and the content hash are derived lazily in
    ``__post_init__`` when not supplied by the caller, so a block can be
    created from just an ``id`` and ``text``.
    """

    id: str
    text: str
    embedding: Optional[torch.Tensor] = None
    confidence: float = 1.0
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
    sha256_hash: Optional[str] = None
    kg_id: Optional[str] = None

    def __post_init__(self):
        """Fill in derived fields (timestamps, content hash) when absent."""
        if self.created_at is None:
            # Fix: datetime.utcnow() is deprecated (Python 3.12+) and
            # returns a naive datetime; use an explicit UTC-aware timestamp.
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.updated_at is None:
            # A fresh block has never been rewritten.
            self.updated_at = self.created_at
        if self.sha256_hash is None:
            # Content digest used for integrity / change detection.
            self.sha256_hash = hashlib.sha256(self.text.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-safe dict.

        The ``embedding`` tensor is intentionally omitted — only plain
        metadata is persisted (see `memory_state.json` handling).
        """
        return {
            "id": self.id,
            "text": self.text,
            "confidence": self.confidence,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "sha256_hash": self.sha256_hash,
            "kg_id": self.kg_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryBlock":
        """Inverse of :meth:`to_dict`; the embedding is restored as ``None``."""
        return cls(**data)
78
+
79
+
80
class PMLLLattice(nn.Module):
    """Persistent Memory Logic Loop (PMLL) lattice.

    The core tensor-routing and reconsideration engine: a bank of
    per-petal linear projections whose averaged output is scored by a
    scalar consensus head, plus a learnable temporal-decay base.
    """

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim

        # One projection per "petal" (the flower-attention analogy from the paper).
        petals = [
            nn.Linear(config.embedding_dim, config.embedding_dim)
            for _ in range(config.num_petals)
        ]
        self.petal_projections = nn.ModuleList(petals)

        # Maps a combined embedding to a single (pre-sigmoid) consensus logit.
        self.consensus_head = nn.Linear(config.embedding_dim, 1)
        # Learnable base of the temporal decay, initialized from the config.
        self.decay_param = nn.Parameter(torch.tensor(config.decay_alpha))

    def forward(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run every petal projection, average them, and score consensus.

        Returns the petal-averaged embedding and a sigmoid consensus
        score in ``(0, 1)``.
        """
        stacked = torch.stack(
            [petal(embeddings) for petal in self.petal_projections], dim=0
        )
        combined = stacked.mean(dim=0)
        consensus_score = torch.sigmoid(self.consensus_head(combined))
        return combined, consensus_score

    def apply_temporal_decay(self, confidence: torch.Tensor, time_delta_days: float = 1.0) -> torch.Tensor:
        """Adaptive temporal decay (core of RTM reconsideration).

        Scales *confidence* by ``decay_param ** time_delta_days``.
        """
        decay = torch.pow(self.decay_param, time_delta_days)
        return confidence * decay
107
+
108
+
109
class RecursiveTransformerModel(PreTrainedModel):
    """
    Full Recursive Transformer Model with Enhanced Reconsideration System (ERS).
    This is the main class users will import with `from_pretrained`.

    Wraps a :class:`PMLLLattice` plus an in-memory list of
    :class:`MemoryBlock` objects that the reconsideration loop re-scores.
    """

    config_class = RecursiveTransformerConfig
    base_model_prefix = "recursive_transformer"
    supports_gradient_checkpointing = False

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__(config)
        self.config = config
        self.lattice = PMLLLattice(config)
        # Active memory slots, in insertion order.
        self.memory_line: List[MemoryBlock] = []

    def add_memory(self, text: str, embedding: Optional[torch.Tensor] = None, confidence: float = 1.0) -> MemoryBlock:
        """Add a new memory block (ERS `add_memory`).

        The block id is derived from the current slot count, so ids are
        sequential (``mem_0``, ``mem_1``, ...).
        """
        block = MemoryBlock(
            id=f"mem_{len(self.memory_line)}",
            text=text,
            embedding=embedding,
            confidence=confidence,
        )
        self.memory_line.append(block)
        return block

    def reconsider(self, passes: Optional[int] = None) -> List[MemoryBlock]:
        """Run full RTM recursive reconsideration loop (temporal decay → consensus → contradiction).

        Each pass re-scores the confidence of every memory block that has
        an embedding, using the lattice's consensus head.

        Fix: the original `passes = passes or self.config.max_recursive_passes`
        treated an explicit ``passes=0`` as "use the default" (0 is falsy);
        compare against None so 0 means "run no passes".
        """
        if passes is None:
            passes = self.config.max_recursive_passes
        for i in range(passes):
            print(f"→ RTM Reconsideration pass {i+1}/{passes}")
            for block in self.memory_line:
                if block.embedding is not None:
                    _, score = self.lattice(block.embedding.unsqueeze(0))
                    block.confidence = float(score.mean().item())
        # In a full production version this would also call contradiction detection + rewrite
        return self.memory_line

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model + lattice weights exactly like any HF model.

        Both the lattice weights and the saved memory state are loaded on a
        best-effort basis: a missing/unreadable file falls back to a fresh
        state instead of raising.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RecursiveTransformerConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        # Load lattice weights if present. Narrowed from `except Exception`
        # so genuine programming errors are not silently swallowed.
        try:
            state_dict = torch.load(
                f"{pretrained_model_name_or_path}/pytorch_model.bin",
                map_location="cpu",
                weights_only=True,
            )
            model.lattice.load_state_dict(state_dict, strict=False)
            print("✅ Loaded PMLLLattice weights from pytorch_model.bin")
        except (OSError, RuntimeError):
            print("⚠️ No pytorch_model.bin found – using freshly initialized lattice")

        # Optional: load saved memory state (silently skipped when absent or corrupt).
        try:
            with open(f"{pretrained_model_name_or_path}/memory_state.json", "r") as f:
                mem_data = json.load(f)
            model.memory_line = [MemoryBlock.from_dict(d) for d in mem_data]
            print(f"✅ Loaded {len(model.memory_line)} saved memory blocks")
        except (OSError, json.JSONDecodeError, TypeError):
            pass

        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights + memory state.

        Delegates to the HF base implementation, then writes the lattice
        state dict and the JSON-serialized memory line alongside it.
        NOTE(review): this overwrites `pytorch_model.bin` with only the
        lattice weights, mirroring what `from_pretrained` reads back.
        """
        super().save_pretrained(save_directory, **kwargs)
        # Save lattice
        torch.save(self.lattice.state_dict(), f"{save_directory}/pytorch_model.bin")
        # Save memory line (embeddings are dropped by MemoryBlock.to_dict)
        memory_data = [block.to_dict() for block in self.memory_line]
        with open(f"{save_directory}/memory_state.json", "w") as f:
            json.dump(memory_data, f, indent=2)
188
+
189
+
190
# Public API of this module, for easy importing from the repo.
__all__ = [
    "RecursiveTransformerConfig",
    "RecursiveTransformerModel",
    "MemoryBlock",
    "PMLLLattice",
]