Transformers
English
Japanese
Spanish
recursive_transformer
File size: 7,744 Bytes
7c56176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# coding=utf-8
# Copyright 2025 Dr. Josef Kurk Edwards (drQedwards / josefedwards). All rights reserved.
# Licensed under the MIT License (see LICENSE in the ERS repository).
# This file provides the official Hugging Face integration for the Recursive Transformer Model (RTM) + Enhanced Reconsideration System (ERS).

"""PyTorch Recursive Transformer Model (RTM) with Persistent Memory Logic Loops (PMLL) and ERS runtime.

This is the core modeling file for the Hugging Face repository.
It defines the full RTM architecture (PMLLLattice + reconsideration logic) and supports
`from_pretrained` / `save_pretrained` exactly like any other HF model.
"""

import hashlib
import json
import os
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class RecursiveTransformerConfig(PretrainedConfig):
    """Hugging Face configuration for the Recursive Transformer Model (RTM).

    Holds the PMLL-lattice hyperparameters and the ERS reconsideration
    settings so the model round-trips through ``from_pretrained`` /
    ``save_pretrained`` like any other HF config.
    """

    model_type = "recursive_transformer"

    def __init__(
        self,
        embedding_dim: int = 384,
        num_petals: int = 8,
        decay_alpha: float = 0.95,
        consensus_threshold: float = 0.75,
        contradiction_threshold: float = 0.65,
        max_recursive_passes: int = 3,
        **kwargs,
    ):
        # Forward everything HF-generic (name_or_path, etc.) to the base config.
        super().__init__(**kwargs)
        # Width of the embedding vectors routed through the lattice.
        self.embedding_dim = embedding_dim
        # Number of parallel "petal" projections in the PMLL lattice.
        self.num_petals = num_petals
        # Base of the per-day multiplicative confidence decay.
        self.decay_alpha = decay_alpha
        # Consensus/contradiction cutoffs — stored for downstream ERS logic;
        # not read elsewhere in this file.
        self.consensus_threshold = consensus_threshold
        self.contradiction_threshold = contradiction_threshold
        # Default pass count used by ``RecursiveTransformerModel.reconsider``.
        self.max_recursive_passes = max_recursive_passes


@dataclass
class MemoryBlock:
    """Single persistent memory unit used by ERS.

    Attributes:
        id: Stable identifier for this block (e.g. ``mem_0``).
        text: The memorized text; hashed into ``sha256_hash``.
        embedding: Optional dense representation of ``text``. Not included in
            ``to_dict``, so it is lost across serialization round-trips.
        confidence: Belief score, refreshed by reconsideration passes.
        created_at: ISO-8601 UTC creation timestamp (auto-filled if omitted).
        updated_at: ISO-8601 UTC update timestamp (defaults to ``created_at``).
        sha256_hash: Integrity digest of ``text`` (auto-filled if omitted).
        kg_id: Optional knowledge-graph identifier.
    """
    id: str
    text: str
    embedding: Optional[torch.Tensor] = None
    confidence: float = 1.0
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
    sha256_hash: Optional[str] = None
    kg_id: Optional[str] = None

    def __post_init__(self):
        # Only fill derived fields the caller did not supply, so that
        # ``from_dict`` round-trips preserve the original metadata.
        if self.created_at is None:
            # ``datetime.utcnow()`` is deprecated (returns a naive datetime);
            # use a timezone-aware UTC timestamp instead.
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.updated_at is None:
            self.updated_at = self.created_at
        if self.sha256_hash is None:
            self.sha256_hash = hashlib.sha256(self.text.encode("utf-8")).hexdigest()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-safe dict (the tensor ``embedding`` is omitted)."""
        return {
            "id": self.id,
            "text": self.text,
            "confidence": self.confidence,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "sha256_hash": self.sha256_hash,
            "kg_id": self.kg_id,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MemoryBlock":
        """Rebuild a block from ``to_dict`` output (``embedding`` stays ``None``)."""
        return cls(**data)


class PMLLLattice(nn.Module):
    """Persistent Memory Logic Loop (PMLL) lattice – the core tensor routing and reconsideration engine."""

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__()
        self.config = config
        self.embedding_dim = config.embedding_dim

        # One linear projection per "petal" of the flower-style attention;
        # their mean forms the combined representation in ``forward``.
        petals = [
            nn.Linear(config.embedding_dim, config.embedding_dim)
            for _ in range(config.num_petals)
        ]
        self.petal_projections = nn.ModuleList(petals)

        # Scalar head scoring how strongly the petal views agree.
        self.consensus_head = nn.Linear(config.embedding_dim, 1)
        # Learnable decay base, initialized from the configured alpha.
        self.decay_param = nn.Parameter(torch.tensor(config.decay_alpha))

    def forward(self, embeddings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Route ``embeddings`` through every petal and score the consensus.

        Returns the petal-averaged representation together with a sigmoid
        consensus score in (0, 1).
        """
        stacked = torch.stack(
            [petal(embeddings) for petal in self.petal_projections], dim=0
        )
        fused = stacked.mean(dim=0)  # average across petals
        agreement = torch.sigmoid(self.consensus_head(fused))
        return fused, agreement

    def apply_temporal_decay(self, confidence: torch.Tensor, time_delta_days: float = 1.0) -> torch.Tensor:
        """Scale ``confidence`` by ``decay_param ** time_delta_days`` (RTM adaptive temporal decay)."""
        return torch.pow(self.decay_param, time_delta_days) * confidence


class RecursiveTransformerModel(PreTrainedModel):
    """
    Full Recursive Transformer Model with Enhanced Reconsideration System (ERS).
    This is the main class users will import with `from_pretrained`.
    """
    config_class = RecursiveTransformerConfig
    base_model_prefix = "recursive_transformer"
    supports_gradient_checkpointing = False

    def __init__(self, config: RecursiveTransformerConfig):
        super().__init__(config)
        self.config = config
        # The lattice holds all trainable parameters of this model.
        self.lattice = PMLLLattice(config)
        # Active ERS memory slots (plain Python objects, not nn.Parameters).
        self.memory_line: List[MemoryBlock] = []

    def add_memory(self, text: str, embedding: Optional[torch.Tensor] = None, confidence: float = 1.0) -> MemoryBlock:
        """Add a new memory block (ERS `add_memory`) and return it.

        Ids are derived from the current slot count, so they are only unique
        as long as memories are never removed from ``memory_line``.
        """
        block = MemoryBlock(
            id=f"mem_{len(self.memory_line)}",
            text=text,
            embedding=embedding,
            confidence=confidence,
        )
        self.memory_line.append(block)
        return block

    def reconsider(self, passes: Optional[int] = None) -> List[MemoryBlock]:
        """Run full RTM recursive reconsideration loop (temporal decay → consensus → contradiction).

        Args:
            passes: Number of passes to run. Defaults to
                ``config.max_recursive_passes``. Checked with ``is None`` so
                an explicit ``passes=0`` is a no-op instead of silently
                falling back to the default (the old ``passes or default``
                idiom treated 0 as "use default").

        Returns:
            The (mutated in place) memory line with refreshed confidences.
        """
        if passes is None:
            passes = self.config.max_recursive_passes
        for i in range(passes):
            print(f"→ RTM Reconsideration pass {i+1}/{passes}")
            for block in self.memory_line:
                # Blocks restored from memory_state.json carry no embedding; skip them.
                if block.embedding is not None:
                    _, score = self.lattice(block.embedding.unsqueeze(0))
                    block.confidence = float(score.mean().item())
                    # In a full production version this would also call contradiction detection + rewrite
        return self.memory_line

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model + lattice weights exactly like any HF model.

        NOTE(review): the extra artifacts (``pytorch_model.bin`` and
        ``memory_state.json``) are only looked up on the local filesystem, so
        hub ids fall back to a fresh lattice and an empty memory line.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = RecursiveTransformerConfig.from_pretrained(pretrained_model_name_or_path)

        model = cls(config)

        # Load lattice weights if present (os.path.join keeps paths portable).
        weights_path = os.path.join(str(pretrained_model_name_or_path), "pytorch_model.bin")
        try:
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.lattice.load_state_dict(state_dict, strict=False)
            print("✅ Loaded PMLLLattice weights from pytorch_model.bin")
        except FileNotFoundError:
            # Expected for fresh checkpoints; keep the original best-effort behavior.
            print("⚠️  No pytorch_model.bin found – using freshly initialized lattice")
        except Exception as err:
            # Corrupt/incompatible checkpoint: stay best-effort but surface the cause.
            print(f"⚠️  Failed to load pytorch_model.bin ({err}) – using freshly initialized lattice")

        # Optional: load saved memory state (best-effort).
        memory_path = os.path.join(str(pretrained_model_name_or_path), "memory_state.json")
        try:
            with open(memory_path, "r") as f:
                mem_data = json.load(f)
            model.memory_line = [MemoryBlock.from_dict(d) for d in mem_data]
            print(f"✅ Loaded {len(model.memory_line)} saved memory blocks")
        except FileNotFoundError:
            pass  # no saved memory state – start with an empty memory line
        except Exception as err:
            # A present-but-unreadable state file is worth surfacing, unlike a missing one.
            print(f"⚠️  Could not load memory_state.json: {err}")

        return model

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model weights + memory state to ``save_directory``."""
        super().save_pretrained(save_directory, **kwargs)
        # Save the lattice-only state dict under the name from_pretrained expects.
        torch.save(self.lattice.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        # Save the JSON-serializable memory line (embeddings are not persisted).
        memory_data = [block.to_dict() for block in self.memory_line]
        with open(os.path.join(save_directory, "memory_state.json"), "w") as f:
            json.dump(memory_data, f, indent=2)


# Public API of this module – controls what `from recursive_transformer import *` exposes.
__all__ = ["RecursiveTransformerConfig", "RecursiveTransformerModel", "MemoryBlock", "PMLLLattice"]