🧠 Full weight release: 9 probes × 3 architectures + production adapter + training code
Browse files — This view is limited to 50 files because it contains too many changes. See raw diff
- README.md +124 -0
- code/cognitive_enhancement_suite.py +397 -0
- code/engine.py +656 -0
- code/train_cognitive_enhancement.py +434 -0
- code/training_pipelines/01_cfhot_risk_v2_REPETITION_125x.py +324 -0
- code/training_pipelines/02_arc_adapter_training_MULTIHEAD.py +1680 -0
- code/training_pipelines/03_arc_dense_train_DENSE.py +643 -0
- code/training_pipelines/04_lie_holonomy_experiment_GEOMETRY.py +689 -0
- code/training_pipelines/05_breakthrough_test_v2_LOOP4.py +935 -0
- code/training_pipelines/06_arc_engine_v30_FULL_ENGINE.py +0 -0
- code/training_pipelines/07_qwen3b_repetition_REPLICATION.py +484 -0
- code/training_pipelines/07b_qwen3b_repetition_FIXED.py +428 -0
- code/training_pipelines/07c_qwen3b_CONTINUE.py +37 -0
- code/training_pipelines/08_qwen3b_dimension_sweep_FULL.py +418 -0
- code/training_pipelines/09_continue_from_19x.py +349 -0
- code/training_pipelines/10_qwen_multihead_25k.py +610 -0
- code/training_pipelines/11_qwen_multihead_CLEAN.py +560 -0
- cognitive/mamba/calibration/calibration_head.pt +3 -0
- cognitive/mamba/coherence/coherence_head.pt +3 -0
- cognitive/mamba/depth/depth_head.pt +3 -0
- cognitive/mamba/focus/focus_head.pt +3 -0
- cognitive/mamba/specificity/specificity_head.pt +3 -0
- cognitive/mistral/calibration/calibration_head.pt +3 -0
- cognitive/mistral/coherence/coherence_head.pt +3 -0
- cognitive/mistral/depth/depth_head.pt +3 -0
- cognitive/mistral/focus/focus_head.pt +3 -0
- cognitive/mistral/results.json +7 -0
- cognitive/mistral/specificity/specificity_head.pt +3 -0
- cognitive/qwen/calibration/calibration_head.pt +3 -0
- cognitive/qwen/coherence/coherence_head.pt +3 -0
- cognitive/qwen/depth/depth_head.pt +3 -0
- cognitive/qwen/focus/focus_head.pt +3 -0
- cognitive/qwen/specificity/specificity_head.pt +3 -0
- production/adapter_config.json +43 -0
- production/adapter_model.safetensors +3 -0
- production/manifest.json +34 -0
- production/merged_heads.pt +3 -0
- production/qwen_cognitive/README.md +569 -0
- production/qwen_cognitive/cognitive_adapter.pt +3 -0
- production/qwen_cognitive/config.json +122 -0
- production/qwen_cognitive/inference.py +163 -0
- results/hedging_results.json +52 -0
- results/hedging_results_continued.json +82 -0
- results/hedging_summary.json +8 -0
- results/mistral_cognitive_results.json +7 -0
- results/sycophancy_results.json +52 -0
- results/sycophancy_summary.json +8 -0
- results/verbosity_results_continued.json +82 -0
- results/verbosity_summary.json +8 -0
- suppression/hedging_168x/fiber_proj.pt +3 -0
README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- behavioral-detection
|
| 5 |
+
- hidden-state-probing
|
| 6 |
+
- AI-safety
|
| 7 |
+
- repetition-suppression
|
| 8 |
+
- sycophancy-detection
|
| 9 |
+
- per-token-classification
|
| 10 |
+
- cross-architecture
|
| 11 |
+
- cognitive-enhancement
|
| 12 |
+
- holonomy-transformer
|
| 13 |
+
- fiber-bundle
|
| 14 |
+
language:
|
| 15 |
+
- en
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
# Proprioceptive AI — Behavioral Probe Weights
|
| 19 |
+
|
| 20 |
+
**9 behavioral dimensions × 3 architectures. Trained probes that read LLM hidden states and detect/correct behavioral failures at decode time.**
|
| 21 |
+
|
| 22 |
+
> **Paper:** [Consistency Is All You Need](https://zenodo.org/records/18489530)
|
| 23 |
+
> **Website:** [ProprioceptiveAI.com](https://proprioceptiveai.com)
|
| 24 |
+
> **Patents:** 55 filed
|
| 25 |
+
> **Author:** Logan Napolitano — Fiber AI, Inc.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## 🔬 What This Is
|
| 30 |
+
|
| 31 |
+
These are trained probe weights that detect behavioral patterns **per-token** from LLM hidden states. No fine-tuning. No RLHF. No sequence-level classifiers. The probes read the geometry of the hidden state space and predict behavior **before the token is generated**.
|
| 32 |
+
|
| 33 |
+
## 📊 Results
|
| 34 |
+
|
| 35 |
+
### Suppression Probes (LLaMA 3.1 8B)
|
| 36 |
+
|
| 37 |
+
| Behavior | Separation Ratio | Description |
|
| 38 |
+
|----------|-----------------|-------------|
|
| 39 |
+
| **Repetition** | **125×** | Detects repetitive degeneration |
|
| 40 |
+
| **Hedging** | **168×** | Detects hedge/qualify patterns |
|
| 41 |
+
| **Sycophancy** | **230×** | Detects agreement-bias behavior |
|
| 42 |
+
| **Verbosity** | **272×** | Detects excessive length patterns |
|
| 43 |
+
|
| 44 |
+
### Cognitive Enhancement Probes (Cross-Architecture)
|
| 45 |
+
|
| 46 |
+
| Probe | Qwen 14B | Mamba 7B | Mistral 7B |
|
| 47 |
+
|-------|----------|----------|------------|
|
| 48 |
+
| **Depth** | 999× | 999× | 999× |
|
| 49 |
+
| **Specificity** | 999× | 999× | 999× |
|
| 50 |
+
| **Calibration** | 999× | 999× | 999× |
|
| 51 |
+
| **Focus** | 999× | 999× | 999× |
|
| 52 |
+
| **Coherence** | 999× | 999× | 999× |
|
| 53 |
+
|
| 54 |
+
**Architecture-independent.** Same probe architecture works on Transformers (Qwen), SSMs (Mamba), and SWA Transformers (Mistral).
|
| 55 |
+
|
| 56 |
+
## 📁 Repository Structure
|
| 57 |
+
|
| 58 |
+
```
|
| 59 |
+
suppression/
|
| 60 |
+
├── repetition_125x/ # LoRA adapter + risk predictor
|
| 61 |
+
├── hedging_168x/ # Probe head + fiber projection
|
| 62 |
+
├── sycophancy_230x/ # Probe head + fiber projection
|
| 63 |
+
└── verbosity_272x/ # Probe head + fiber projection
|
| 64 |
+
|
| 65 |
+
cognitive/
|
| 66 |
+
├── qwen/ # 5 enhancement probes (Qwen 14B)
|
| 67 |
+
│ ├── depth/
|
| 68 |
+
│ ├── specificity/
|
| 69 |
+
│ ├── calibration/
|
| 70 |
+
│ ├── focus/
|
| 71 |
+
│ └── coherence/
|
| 72 |
+
├── mamba/ # 5 enhancement probes (Mamba 7B)
|
| 73 |
+
└── mistral/ # 5 enhancement probes (Mistral 7B)
|
| 74 |
+
|
| 75 |
+
production/
|
| 76 |
+
├── merged_heads.pt # All 4 suppression heads merged
|
| 77 |
+
├── adapter_config.json
|
| 78 |
+
├── adapter_model.safetensors
|
| 79 |
+
└── qwen_cognitive/ # Qwen cognitive adapter
|
| 80 |
+
|
| 81 |
+
code/ # Training scripts
|
| 82 |
+
results/ # Training logs and metrics
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## 🚀 Quick Start
|
| 86 |
+
|
| 87 |
+
```python
|
| 88 |
+
import torch
|
| 89 |
+
|
| 90 |
+
# Load a suppression probe
|
| 91 |
+
probe = torch.load("suppression/hedging_168x/hedging_head.pt")
|
| 92 |
+
fiber_proj = torch.load("suppression/hedging_168x/fiber_proj.pt")
|
| 93 |
+
|
| 94 |
+
# Load cognitive enhancement probe
|
| 95 |
+
depth_probe = torch.load("cognitive/qwen/depth/depth_head.pt")
|
| 96 |
+
|
| 97 |
+
# Load merged production heads
|
| 98 |
+
merged = torch.load("production/merged_heads.pt")
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Core Innovation
|
| 102 |
+
|
| 103 |
+
**Behaviors are geometrically encoded in hidden states.** We don't classify outputs — we read the internal geometry of the model's computation at each token position. This means:
|
| 104 |
+
|
| 105 |
+
1. **Per-token, not per-sequence** — detect problems before generation completes
|
| 106 |
+
2. **Architecture-independent** — same probes work on transformers, SSMs, and hybrid architectures
|
| 107 |
+
3. **Zero fine-tuning** — works on any pre-trained model without modification
|
| 108 |
+
4. **4ms overhead** — lightweight enough for production decode
|
| 109 |
+
|
| 110 |
+
## Citation
|
| 111 |
+
|
| 112 |
+
```bibtex
|
| 113 |
+
@misc{napolitano2026proprioceptive,
|
| 114 |
+
author = {Napolitano, Logan Matthew},
|
| 115 |
+
title = {Proprioceptive AI: Architecture-Independent Behavioral Detection via Hidden State Geometry},
|
| 116 |
+
year = {2026},
|
| 117 |
+
url = {https://huggingface.co/LoganResearch/Proprioceptive-AI-Weights},
|
| 118 |
+
note = {55 patents filed. 9 behavioral dimensions. 3 architectures.}
|
| 119 |
+
}
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## License
|
| 123 |
+
|
| 124 |
+
MIT — Use freely. Cite if you publish.
|
code/cognitive_enhancement_suite.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
═══════════════════════════════════════════════════════════════════════════════
|
| 4 |
+
COGNITIVE ENHANCEMENT SUITE v1.0
|
| 5 |
+
Making 8B Think Like 100B Through Hidden State Analysis
|
| 6 |
+
═══════════════════════════════════════════════════════════════════════════════
|
| 7 |
+
|
| 8 |
+
CORE INSIGHT:
|
| 9 |
+
Small models often HAVE capability they don't USE consistently.
|
| 10 |
+
By detecting when the model is about to underperform and intervening,
|
| 11 |
+
we can recover performance closer to larger models.
|
| 12 |
+
|
| 13 |
+
ENHANCEMENT PROBES:
|
| 14 |
+
1. DEPTH PROBE - Detect shallow reasoning → Force chain-of-thought
|
| 15 |
+
2. SPECIFICITY PROBE - Detect vague answers → Penalize generic words
|
| 16 |
+
3. CALIBRATION PROBE - Detect overconfidence → Inject uncertainty
|
| 17 |
+
4. FOCUS PROBE - Detect topic drift → Steer back on topic
|
| 18 |
+
5. COHERENCE PROBE - Detect incoherence → Maintain logical flow
|
| 19 |
+
|
| 20 |
+
AUTHOR: Logan Matthew Napolitano
|
| 21 |
+
LICENSE: CC BY 4.0
|
| 22 |
+
STATUS: Research / Patent Pending
|
| 23 |
+
═══════════════════════════════════════════════════════════════════════════════
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import json
|
| 29 |
+
import time
|
| 30 |
+
import random
|
| 31 |
+
from pathlib import Path
|
| 32 |
+
from datetime import datetime
|
| 33 |
+
from typing import List, Dict, Any, Optional, Tuple
|
| 34 |
+
from dataclasses import dataclass, field, asdict
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
import torch.nn as nn
|
| 38 |
+
import torch.nn.functional as F
|
| 39 |
+
|
| 40 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 41 |
+
# CONFIGURATION
|
| 42 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 43 |
+
|
| 44 |
+
@dataclass
class EnhancementConfig:
    """Configuration for cognitive enhancement probes."""
    # Hidden-state width of the base model (4096 matches an 8B-class LLM;
    # adjust for other architectures).
    hidden_dim: int = 4096
    # Dimension of the low-rank "fiber" space the hidden states are projected into.
    fiber_dim: int = 16
    # Width of the hidden layers in the classification head MLP.
    head_hidden_dim: int = 64
    # Which decoder layers to probe; indices are 0-based layer numbers
    # (mapped to hidden_states[layer + 1] by the suite, since index 0 is the embedding output).
    probe_layers: List[int] = field(default_factory=lambda: [8, 16, 24])

    # --- training hyperparameters ---
    learning_rate: float = 5e-5
    batch_size: int = 4
    # Number of micro-batches accumulated per optimizer step.
    gradient_accumulation: int = 4
    max_steps: int = 15000
    # Checkpoint interval, in optimizer steps.
    save_every: int = 1000

    # Directory where training checkpoints and results are written.
    output_dir: str = "cognitive_enhancement_output"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 62 |
+
# PROBE DEFINITIONS
|
| 63 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 64 |
+
|
| 65 |
+
@dataclass
class ProbeDefinition:
    """Definition of a cognitive enhancement probe.

    Bundles the probe's identity with its decode-time intervention policy:
    when the probe's risk exceeds ``threshold``, logits of ``boost_tokens``
    are raised and logits of ``suppress_tokens`` lowered by
    ``intervention_strength``.
    """
    name: str
    description: str
    intervention_type: str  # "suppress", "boost", or "steer"
    # Phrases whose first token gets a logit bonus when the probe fires.
    boost_tokens: List[str] = field(default_factory=list)
    # Phrases whose first token gets a logit penalty when the probe fires.
    suppress_tokens: List[str] = field(default_factory=list)
    # Risk value (sigmoid output, 0..1) above which the intervention triggers.
    threshold: float = 0.5
    # Magnitude added/subtracted from the targeted token logits.
    intervention_strength: float = 3.0
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Registry of the five cognitive enhancement probes. Keys are probe names
# used throughout the suite (checkpoint directories, ModuleDict entries).
ENHANCEMENT_PROBES = {
    # Shallow reasoning → boost connective / step-by-step openers.
    "depth": ProbeDefinition(
        name="depth",
        description="Detect shallow reasoning, force chain-of-thought",
        intervention_type="boost",
        boost_tokens=[
            "First", "First,", "Because", "Since", "Therefore",
            "This means", "The reason", "Step", "Let me", "To understand",
            "Consider", "Notice", "Given", "If we", "We can",
            "Thus", "Hence", "Consequently", "As a result",
        ],
        suppress_tokens=["Simply", "Just", "Obviously", "Clearly"],
        threshold=0.6,
        intervention_strength=3.0,
    ),

    # Vague language → penalize generic filler words.
    "specificity": ProbeDefinition(
        name="specificity",
        description="Detect vague answers, penalize generic language",
        intervention_type="suppress",
        boost_tokens=["specifically", "exactly", "precisely", "namely", "for example"],
        suppress_tokens=[
            "things", "stuff", "something", "somehow", "somewhat",
            "various", "many", "some", "often", "usually",
            "generally", "typically", "probably", "maybe", "perhaps",
            "kind of", "sort of", "basically", "essentially",
        ],
        threshold=0.5,
        intervention_strength=3.5,
    ),

    # Overconfidence → boost hedging tokens, suppress absolutes.
    # NOTE: "probably"/"perhaps" are boosted here but suppressed by the
    # specificity probe; when both fire the effects partially cancel.
    "calibration": ProbeDefinition(
        name="calibration",
        description="Detect overconfidence, inject appropriate uncertainty",
        intervention_type="boost",
        boost_tokens=[
            "might", "may", "could", "possibly", "perhaps",
            "likely", "probably", "I think", "I believe",
            "it seems", "appears", "suggests", "indicates",
        ],
        suppress_tokens=[
            "definitely", "certainly", "absolutely", "always",
            "never", "impossible", "guaranteed", "undoubtedly",
        ],
        threshold=0.65,
        intervention_strength=2.5,
    ),

    # Topic drift → steer back toward the question.
    "focus": ProbeDefinition(
        name="focus",
        description="Detect topic drift, steer back to the question",
        intervention_type="steer",
        boost_tokens=["regarding", "concerning", "about", "specifically", "to answer"],
        suppress_tokens=["by the way", "incidentally", "speaking of", "reminds me"],
        threshold=0.55,
        intervention_strength=3.0,
    ),

    # Logical incoherence → boost discourse connectives.
    "coherence": ProbeDefinition(
        name="coherence",
        description="Detect logical incoherence, maintain flow",
        intervention_type="steer",
        boost_tokens=[
            "therefore", "thus", "so", "hence", "consequently",
            "however", "but", "although", "furthermore", "moreover",
        ],
        suppress_tokens=[],
        threshold=0.6,
        intervention_strength=2.5,
    ),
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 151 |
+
# NEURAL NETWORK ARCHITECTURE
|
| 152 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 153 |
+
|
| 154 |
+
class EnhancementFiberProjection(nn.Module):
    """Project hidden states from several layers into a shared fiber space.

    One bias-free linear map per probed layer; the per-layer fibers are
    blended with a learned, softmax-normalized layer weighting.
    """

    def __init__(self, hidden_dim: int = 4096, fiber_dim: int = 16, n_layers: int = 3):
        super().__init__()
        layer_maps = [nn.Linear(hidden_dim, fiber_dim, bias=False) for _ in range(n_layers)]
        self.projections = nn.ModuleList(layer_maps)
        # Initialize to a uniform mixture across layers.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

    def forward(self, hidden_states_list: List[torch.Tensor]) -> torch.Tensor:
        """Return the softmax-weighted sum of per-layer fiber projections.

        Inputs are cast to float32 before projection so half-precision
        hidden states from the base model are handled uniformly.
        """
        mix = F.softmax(self.layer_weights, dim=0)
        combined = None
        for idx, (proj, states) in enumerate(zip(self.projections, hidden_states_list)):
            term = mix[idx] * proj(states.float())
            combined = term if combined is None else combined + term
        return combined
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class EnhancementHead(nn.Module):
    """Small MLP that scores a fiber vector with a single raw logit.

    Layer layout is kept as a flat ``nn.Sequential`` named ``classifier``
    so checkpoint state-dict keys (``classifier.0.weight`` etc.) line up.
    """

    def __init__(self, fiber_dim: int = 16, hidden_dim: int = 64):
        super().__init__()
        stack = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, fiber: torch.Tensor) -> torch.Tensor:
        """Score the fiber; the trailing singleton dim is squeezed away."""
        scores = self.classifier(fiber)
        return scores.squeeze(-1)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class EnhancementProbe(nn.Module):
    """One complete enhancement probe: fiber projection plus classifier head."""

    def __init__(self, config: EnhancementConfig, probe_def: ProbeDefinition):
        super().__init__()
        self.config = config
        self.probe_def = probe_def

        layer_count = len(config.probe_layers)
        self.fiber_projection = EnhancementFiberProjection(
            config.hidden_dim, config.fiber_dim, layer_count
        )
        self.head = EnhancementHead(config.fiber_dim, config.head_hidden_dim)
        # Bookkeeping; populated when a trained checkpoint is loaded.
        self.separation = 0.0
        self.trained_steps = 0

    def forward(self, hidden_states_list: List[torch.Tensor]) -> torch.Tensor:
        """Raw (pre-sigmoid) score from the head on the blended fiber."""
        return self.head(self.fiber_projection(hidden_states_list))

    def predict_risk(self, hidden_states_list: List[torch.Tensor]) -> torch.Tensor:
        """Probability-like risk in [0, 1]."""
        return torch.sigmoid(self.forward(hidden_states_list))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class CognitiveEnhancementSuite(nn.Module):
    """Complete suite of cognitive enhancement probes.

    Owns one ``EnhancementProbe`` per entry in ``ENHANCEMENT_PROBES`` and
    tracks which of them have trained checkpoints loaded; only loaded
    probes are consulted by :meth:`get_all_risks`.
    """

    def __init__(self, config: Optional[EnhancementConfig] = None):
        super().__init__()
        self.config = config or EnhancementConfig()

        self.probes = nn.ModuleDict({
            name: EnhancementProbe(self.config, probe_def)
            for name, probe_def in ENHANCEMENT_PROBES.items()
        })

        # Names of probes whose checkpoints loaded successfully.
        self.loaded_probes: set = set()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def get_probe_states(self, all_hidden_states: tuple) -> List[torch.Tensor]:
        """Select the hidden states for the configured probe layers.

        ``all_hidden_states[0]`` is the embedding output, so decoder layer
        ``i`` lives at index ``i + 1``.
        """
        return [all_hidden_states[layer + 1] for layer in self.config.probe_layers]

    def get_all_risks(self, probe_states: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return per-token risk from every loaded probe, keyed by probe name."""
        risks = {}
        for name in self.loaded_probes:
            risks[name] = self.probes[name].predict_risk(probe_states)
        return risks

    def load_probe(self, name: str, checkpoint_path: str) -> bool:
        """Load trained weights for one probe; return True on success.

        Checkpoints may store head weights either with bare Sequential
        indices ("0.weight") or with the "classifier." prefix; bare keys
        are re-prefixed before loading.
        """
        if name not in self.probes:
            print(f"[enhance] Unknown probe: {name}")
            return False

        try:
            # weights_only=False because checkpoints carry non-tensor
            # metadata; only load checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

            if 'fiber_projection' in checkpoint:
                self.probes[name].fiber_projection.load_state_dict(checkpoint['fiber_projection'])
            if 'head_state' in checkpoint:
                head_state = checkpoint['head_state']
                new_state = {}
                for k, v in head_state.items():
                    if k[0].isdigit():
                        # Legacy key like "0.weight" -> "classifier.0.weight".
                        new_state[f'classifier.{k}'] = v
                    else:
                        new_state[k] = v
                self.probes[name].head.load_state_dict(new_state)

            self.probes[name].separation = checkpoint.get('separation', 0.0)
            self.probes[name].trained_steps = checkpoint.get('step', 0)
            self.loaded_probes.add(name)

            print(f"[enhance] ✓ Loaded {name} probe ({self.probes[name].separation:.1f}× separation)")
            return True

        except Exception as e:
            # Boundary: report and continue so other probes can still load.
            print(f"[enhance] Error loading {name}: {e}")
            return False

    def load_all(self, checkpoint_dir: str) -> Dict[str, bool]:
        """Try to load every known probe from ``checkpoint_dir/<name>/``."""
        results = {}
        for name in ENHANCEMENT_PROBES.keys():
            probe_dir = os.path.join(checkpoint_dir, name)
            if os.path.exists(probe_dir):
                best_ckpt = self._find_best_checkpoint(probe_dir)
                if best_ckpt:
                    results[name] = self.load_probe(name, best_ckpt)
                else:
                    results[name] = False
            else:
                results[name] = False
        return results

    def _find_best_checkpoint(self, probe_dir: str) -> Optional[str]:
        """Return the first .pt file inside the highest-step ``ckpt_<step>`` dir."""
        best_step = -1
        best_path = None

        for item in os.listdir(probe_dir):
            if item.startswith("ckpt_"):
                try:
                    step = int(item.split("_")[1])
                except (IndexError, ValueError):
                    # Not a ckpt_<step> name — skip it. (Was a bare
                    # `except:` that also swallowed KeyboardInterrupt.)
                    continue
                if step > best_step:
                    best_step = step
                    best_path = os.path.join(probe_dir, item)

        if best_path:
            for f in os.listdir(best_path):
                if f.endswith('.pt'):
                    return os.path.join(best_path, f)
        return None

    def status(self) -> str:
        """Human-readable summary of which probes are loaded."""
        lines = [
            "═" * 60,
            " COGNITIVE ENHANCEMENT SUITE STATUS",
            "═" * 60,
            f" Probe layers: {self.config.probe_layers}",
            f" Loaded probes: {len(self.loaded_probes)}/{len(ENHANCEMENT_PROBES)}",
            "",
        ]

        for name, probe_def in ENHANCEMENT_PROBES.items():
            if name in self.loaded_probes:
                sep = self.probes[name].separation
                status = f"✓ {sep:.1f}×"
            else:
                status = "○ not loaded"
            lines.append(f" [{status:>12}] {name}: {probe_def.description}")

        lines.append("═" * 60)
        return "\n".join(lines)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 320 |
+
# INTERVENTION ENGINE
|
| 321 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 322 |
+
|
| 323 |
+
class CognitiveInterventionEngine:
    """Applies cognitive enhancements during generation.

    Pre-tokenizes each probe's boost/suppress phrases once at construction
    time, then nudges next-token logits whenever a loaded probe's risk
    crosses its threshold.
    """

    def __init__(self, suite: CognitiveEnhancementSuite, tokenizer):
        self.suite = suite
        self.tokenizer = tokenizer

        self.boost_token_ids: Dict[str, set] = {}
        self.suppress_token_ids: Dict[str, set] = {}

        for name, probe_def in ENHANCEMENT_PROBES.items():
            self.boost_token_ids[name] = self._first_token_ids(probe_def.boost_tokens)
            self.suppress_token_ids[name] = self._first_token_ids(probe_def.suppress_tokens)

    def _first_token_ids(self, phrases) -> set:
        """Map each phrase to its first token id; empty encodings are skipped."""
        ids = set()
        for phrase in phrases:
            encoded = self.tokenizer.encode(phrase, add_special_tokens=False)
            if encoded:
                ids.add(encoded[0])
        return ids

    def apply_interventions(
        self,
        logits: torch.Tensor,
        probe_states: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, Dict]]:
        """Return adjusted logits and a per-probe record of what fired.

        Risk is taken from the last position, averaged over the batch;
        logit edits target batch index 0 (single-sequence decode assumed).
        """
        risks = self.suite.get_all_risks(probe_states)
        adjusted = logits.clone()
        interventions = {}

        for name in self.suite.loaded_probes:
            probe_def = ENHANCEMENT_PROBES[name]
            risk = risks[name][:, -1].mean().item()
            triggered = risk > probe_def.threshold
            interventions[name] = {
                'risk': risk,
                'should_intervene': triggered,
            }
            if not triggered:
                continue

            strength = probe_def.intervention_strength
            for tok_id in self.boost_token_ids.get(name, []):
                adjusted[0, tok_id] += strength
            for tok_id in self.suppress_token_ids.get(name, []):
                adjusted[0, tok_id] -= strength

        return adjusted, interventions
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# Global instance — lazily constructed, shared process-wide singleton.
_cognitive_suite = None

def get_cognitive_suite() -> CognitiveEnhancementSuite:
    """Return the shared suite, constructing it on first use.

    Note: not thread-safe; intended for single-threaded decode loops.
    """
    global _cognitive_suite
    if _cognitive_suite is None:
        _cognitive_suite = CognitiveEnhancementSuite()
    return _cognitive_suite
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
if __name__ == "__main__":
    # CLI smoke check: list the registered probes and how to train them.
    banner = "=" * 60
    print("\n" + banner)
    print(" COGNITIVE ENHANCEMENT SUITE v1.0")
    print(banner)
    print("\nAvailable probes:")
    for probe_name, definition in ENHANCEMENT_PROBES.items():
        print(f" • {probe_name}: {definition.description}")
    print("\nTo train: python train_cognitive_enhancement.py --probe all")
    print(banner)
|
code/engine.py
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
RSI ENGINE v13 - CLOSED LOOP ARCHITECTURE
|
| 4 |
+
|
| 5 |
+
Extends v11 with:
|
| 6 |
+
1. Self-observation: Model sees its fiber state (soft token injection)
|
| 7 |
+
2. Self-curriculum: Model generates its own training problems
|
| 8 |
+
3. Fiber conditioning: Learning from internal states
|
| 9 |
+
|
| 10 |
+
THE CLOSED LOOP:
|
| 11 |
+
fiber(t-1) → inject → model → hidden_states → fiber(t)
|
| 12 |
+
↓
|
| 13 |
+
generate problems
|
| 14 |
+
↓
|
| 15 |
+
solve → filter → train
|
| 16 |
+
↓
|
| 17 |
+
capability(t+1) → α' tracking
|
| 18 |
+
|
| 19 |
+
TRUE RSI is detected when α' > 0 for 10 consecutive iterations.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from torch.optim import AdamW
|
| 25 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 26 |
+
from peft import get_peft_model, LoraConfig, TaskType
|
| 27 |
+
|
| 28 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
import gc
|
| 32 |
+
import sys
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
# Use relative imports when run as module, absolute when run directly
|
| 36 |
+
try:
|
| 37 |
+
from .core import (
|
| 38 |
+
IvakhnenkoIBA,
|
| 39 |
+
RSIStatus,
|
| 40 |
+
RSIThresholds,
|
| 41 |
+
RSIAssessment,
|
| 42 |
+
HiddenStateCapture,
|
| 43 |
+
create_ivakhnenko_iba,
|
| 44 |
+
get_status_icon,
|
| 45 |
+
SelfObservingModel,
|
| 46 |
+
create_self_observing_model,
|
| 47 |
+
FiberInjector,
|
| 48 |
+
create_fiber_injector,
|
| 49 |
+
)
|
| 50 |
+
from .training import (
|
| 51 |
+
TrainingConfig,
|
| 52 |
+
SelfTrainer,
|
| 53 |
+
ProblemGenerator,
|
| 54 |
+
SelfCurriculum,
|
| 55 |
+
create_self_curriculum,
|
| 56 |
+
)
|
| 57 |
+
from .evaluation import (
|
| 58 |
+
Evaluator,
|
| 59 |
+
CapabilityTracker,
|
| 60 |
+
)
|
| 61 |
+
except ImportError:
|
| 62 |
+
# Fallback for direct execution
|
| 63 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 64 |
+
from core import (
|
| 65 |
+
IvakhnenkoIBA,
|
| 66 |
+
RSIStatus,
|
| 67 |
+
RSIThresholds,
|
| 68 |
+
RSIAssessment,
|
| 69 |
+
HiddenStateCapture,
|
| 70 |
+
create_ivakhnenko_iba,
|
| 71 |
+
get_status_icon,
|
| 72 |
+
SelfObservingModel,
|
| 73 |
+
create_self_observing_model,
|
| 74 |
+
FiberInjector,
|
| 75 |
+
create_fiber_injector,
|
| 76 |
+
)
|
| 77 |
+
from training import (
|
| 78 |
+
TrainingConfig,
|
| 79 |
+
SelfTrainer,
|
| 80 |
+
ProblemGenerator,
|
| 81 |
+
SelfCurriculum,
|
| 82 |
+
create_self_curriculum,
|
| 83 |
+
)
|
| 84 |
+
from evaluation import (
|
| 85 |
+
Evaluator,
|
| 86 |
+
CapabilityTracker,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@dataclass
class RSIv13Config:
    """Configuration for RSI v13 - Closed Loop."""

    # Model
    model_name: str = "LoganResearch/ARC-Base-8B-Condensed"
    device: str = "cuda"
    load_in_4bit: bool = True  # quantize base weights (NF4 via bitsandbytes)

    # LoRA
    lora_r: int = 64
    lora_alpha: int = 128
    lora_dropout: float = 0.05
    # All attention + MLP projections receive adapters.
    lora_target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ])

    # Self-observation (NEW in v13)
    fiber_dim: int = 128  # dimensionality of the compressed internal-state "fiber"
    num_soft_tokens: int = 8  # soft tokens injected to expose the fiber to the model
    # Transformer layers whose hidden states are captured for the fiber.
    layer_indices: List[int] = field(default_factory=lambda: [4, 8, 12, 16, 20, 24, 28, 31])
    injection_warmup: int = 10  # Start injection after N iterations

    # Self-curriculum (NEW in v13)
    use_self_curriculum: bool = True
    curriculum_warmup: int = 20  # Use templates until iteration N

    # Training
    initial_lr: float = 5e-6
    min_lr: float = 1e-7
    max_lr: float = 1e-4
    warmup_steps: int = 50
    gradient_clip: float = 1.0
    weight_decay: float = 0.01

    # Samples
    samples_per_iter: int = 16
    replay_buffer_size: int = 500
    replay_ratio: float = 0.3  # fraction of each batch drawn from the replay buffer

    # IBA filtering
    iba_filter_threshold: float = 0.35  # minimum fiber quality to keep a sample

    # RSI detection (SIMPLIFIED - Ivakhnenko faithful)
    alpha_threshold: float = 0.001  # α: first derivative of capability
    alpha_prime_threshold: float = 0.0001  # α': second derivative (acceleration)
    consecutive_for_rsi: int = 10  # α' > 0 for 10 consecutive = TRUE RSI
    drift_threshold: float = 0.30  # maximum tolerated internal-state drift
    capability_floor: float = 0.70  # abort guard: capability must stay above this

    # Iteration
    max_iterations: int = 10000
    eval_interval: int = 1
    checkpoint_interval: int = 10
    log_interval: int = 1

    # Paths
    corpus_path: str = "/home/programmer/Desktop/Claude_and_me/ivakhnenko_corpus"
    checkpoint_dir: str = "./checkpoints"
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class RSIv13Engine:
    """
    RSI Engine v13 - Closed Loop Architecture.

    The model:
    1. Sees its own fiber state (self-observation)
    2. Generates its own problems (self-curriculum)
    3. Learns which fiber states are productive
    4. Continuously improves in a closed loop

    TRUE RSI is detected when α' > 0 for 10 consecutive iterations.
    """

    def __init__(self, config: RSIv13Config):
        """Build every subsystem in dependency order, then record a baseline.

        Args:
            config: Full engine configuration (model, LoRA, loop thresholds).
        """
        self.config = config
        self.device = config.device

        print("=" * 80)
        print(" RSI ENGINE v13 - CLOSED LOOP ARCHITECTURE")
        print(" The model experiments on itself")
        print("=" * 80)
        print(f"\n Model: {config.model_name}")
        print(f" Self-observation: {config.num_soft_tokens} soft tokens")
        print(f" Self-curriculum: {'enabled' if config.use_self_curriculum else 'disabled'}")
        print(f" TRUE RSI: α' > 0 for {config.consecutive_for_rsi} consecutive iterations")
        print()

        # Setup order matters: each stage consumes objects created by the
        # previous one (model → observation wrapper → IBA → curriculum →
        # trainer → evaluator → baseline state).
        print("[1/6] Loading model...")
        self._load_model()

        print("[2/6] Setting up self-observation...")
        self._setup_self_observation()

        print("[3/6] Initializing Ivakhnenko IBA...")
        self._setup_iba()

        print("[4/6] Setting up self-curriculum...")
        self._setup_curriculum()

        print("[5/6] Setting up trainer...")
        self._setup_training()

        print("[6/6] Setting up evaluator...")
        self._setup_evaluation()

        self._init_state()

        print("\n" + "=" * 80)
        print(" CLOSED LOOP READY")
        print(" Fiber injection: OFF (warmup)")
        print(" Self-curriculum: templates (warmup)")
        print("=" * 80 + "\n")

    def _load_model(self):
        """Load and configure the model with LoRA.

        Optionally quantizes the base weights to 4-bit NF4, then wraps the
        model with PEFT LoRA adapters so only adapter weights are trainable.
        """
        if self.config.load_in_4bit:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
            )
        else:
            quant_config = None

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            quantization_config=quant_config,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            # Many causal LMs ship without a pad token; reuse EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token

        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            target_modules=self.config.lora_target_modules,
            task_type=TaskType.CAUSAL_LM,
            bias="none",
        )

        self.model = get_peft_model(self.model, lora_config)
        # Default to eval mode; the trainer toggles train mode as needed.
        self.model.eval()

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f" Trainable: {trainable_params:,} / {total_params:,} ({100*trainable_params/total_params:.2f}%)")

    def _setup_self_observation(self):
        """Setup self-observing model wrapper (soft-token fiber injection)."""
        self.self_obs_model = create_self_observing_model(
            model=self.model,
            tokenizer=self.tokenizer,
            fiber_dim=self.config.fiber_dim,
            num_soft_tokens=self.config.num_soft_tokens,
            layer_indices=self.config.layer_indices,
            device=torch.device(self.device),
        )

        # Injection stays off during warmup; _update_warmups() enables it.
        self.self_obs_model.disable_injection()
        self.injection_active = False

        print(f" Fiber dim: {self.config.fiber_dim}")
        print(f" Soft tokens: {self.config.num_soft_tokens}")
        print(f" Layers: {self.config.layer_indices}")

    def _setup_iba(self):
        """Setup Ivakhnenko IBA and the hidden-state capture hooks."""
        self.iba = create_ivakhnenko_iba(
            hidden_dim=4096,  # NOTE(review): hard-coded to an 8B-class hidden size — confirm against model config
            fiber_dim=self.config.fiber_dim,
            layer_indices=self.config.layer_indices,
            corpus_path=self.config.corpus_path,
            device=self.device,
        )

        self.hidden_capture = HiddenStateCapture(
            self.model,
            self.config.layer_indices,
        )

    def _setup_curriculum(self):
        """Setup self-curriculum (template problems until warmup completes)."""
        self.curriculum = create_self_curriculum(
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device,
            use_model_generation=self.config.use_self_curriculum,
        )

        # Force template mode initially; _update_warmups() flips this on.
        self.curriculum.use_model_generation = False
        self.curriculum_active = False

        print(f" Self-curriculum: {'enabled' if self.config.use_self_curriculum else 'disabled'}")

    def _setup_training(self):
        """Setup optimizer and the self-trainer."""
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.initial_lr,
            weight_decay=self.config.weight_decay,
        )

        train_config = TrainingConfig(
            initial_lr=self.config.initial_lr,
            min_lr=self.config.min_lr,
            max_lr=self.config.max_lr,
            warmup_steps=self.config.warmup_steps,
            gradient_clip=self.config.gradient_clip,
            samples_per_iter=self.config.samples_per_iter,
            replay_buffer_size=self.config.replay_buffer_size,
            replay_ratio=self.config.replay_ratio,
            iba_filter_threshold=self.config.iba_filter_threshold,
            checkpoint_interval=self.config.checkpoint_interval,
        )

        self.trainer = SelfTrainer(
            model=self.model,
            tokenizer=self.tokenizer,
            optimizer=self.optimizer,
            config=train_config,
            device=self.device,
        )

    def _setup_evaluation(self):
        """Setup evaluator and capability tracker."""
        self.evaluator = Evaluator(
            self.model,
            self.tokenizer,
            device=self.device,
        )
        self.capability_tracker = CapabilityTracker()

    def _init_state(self):
        """Initialize loop state, run the baseline eval, and seed the IBA."""
        self.iteration = 0
        self.baseline_capability = None
        self.best_capability = 0.0
        self.rsi_detected = False
        self.rsi_start_iter = None

        # α'-streak tracking for TRUE RSI detection.
        self.consecutive_alpha_prime_positive = 0
        self.alpha_prime_history = []

        print(" Running initial evaluation...")
        initial_eval = self.evaluator.quick_eval()
        self.baseline_capability = initial_eval['total']
        self.best_capability = self.baseline_capability
        self.capability_tracker.update(initial_eval, 0)

        print(f" Baseline capability: {self.baseline_capability:.1%}")

        # Seed the IBA baseline with hidden states from a fixed probe input
        # so later drift/α measurements have a reference point.
        sample_input = self.tokenizer("Hello, world!", return_tensors="pt").to(self.device)
        self.hidden_capture.clear()
        with torch.no_grad():
            _ = self.model(sample_input.input_ids)
        hidden_states = self.hidden_capture.get_states()
        self.iba.set_baseline(hidden_states, self.baseline_capability)

        self.self_obs_model.set_baseline(sample_input.input_ids)

    def _update_warmups(self):
        """Update warmup states based on iteration (one-shot enables)."""
        if not self.injection_active and self.iteration >= self.config.injection_warmup:
            self.self_obs_model.enable_injection()
            self.injection_active = True
            print(f"\n [INJECTION ENABLED] Iteration {self.iteration}")

        if not self.curriculum_active and self.iteration >= self.config.curriculum_warmup:
            self.curriculum.use_model_generation = self.config.use_self_curriculum
            self.curriculum_active = True
            print(f"\n [SELF-CURRICULUM ENABLED] Iteration {self.iteration}")

    def _capture_hidden_states(self, input_ids: torch.Tensor) -> Dict[int, torch.Tensor]:
        """Capture hidden states for IBA via a gradient-free forward pass."""
        self.hidden_capture.clear()
        with torch.no_grad():
            _ = self.model(input_ids)
        return self.hidden_capture.get_states()

    def _run_training_iteration(self) -> Dict[str, Any]:
        """Run one training iteration using curriculum.

        Generates a problem batch, keeps only correctly-answered samples
        that also pass the IBA fiber filter, trains on those, and adapts
        curriculum difficulty to the observed accuracy.

        Returns:
            Dict with batch counts, accuracy, mean loss, difficulty, and lr.
        """
        problems = self.curriculum.generate_batch(n=self.config.samples_per_iter)

        correct_samples = []
        model_generated_count = 0

        self.model.eval()
        for category, question, expected, was_generated in problems:
            if was_generated:
                model_generated_count += 1

            prompt = f"Question: {question}\nAnswer:"
            response, output_ids = self.trainer.generate_response(prompt)

            if self.trainer.check_answer(response, expected):
                # Only correct answers are candidates; the IBA filter then
                # rejects samples whose internal state looks unproductive.
                hidden_states = self._capture_hidden_states(output_ids.unsqueeze(0))
                fiber = self.iba.get_fiber(hidden_states)

                keep = self.iba.filter_sample(fiber, self.config.iba_filter_threshold)

                if keep:
                    correct_samples.append({
                        'input_ids': output_ids,
                        'category': category,
                        'fiber': fiber,
                    })

        total_loss = 0.0
        if correct_samples:
            for sample in correct_samples:
                input_ids = sample['input_ids'].unsqueeze(0)
                loss = self.trainer.train_step(input_ids, accumulate=False)
                total_loss += loss

                self.trainer.replay_buffer.add(
                    sample['input_ids'],
                    sample['category'],
                    priority=1.0,
                )

        # max(1, ...) guards the empty-batch division.
        accuracy = len(correct_samples) / max(1, len(problems))
        self.curriculum.update_difficulty(accuracy)

        return {
            'n_problems': len(problems),
            'n_correct': len(correct_samples),
            'model_generated': model_generated_count,
            'accuracy': accuracy,
            'loss': total_loss / max(1, len(correct_samples)),
            'difficulty': self.curriculum.difficulty_controller.get_difficulty(),
            'lr': self.trainer.lr_scheduler.get_lr(),
        }

    def _update_rsi_tracking(self, alpha_prime: float) -> bool:
        """Update the α' streak; return True once the TRUE-RSI streak is met."""
        self.alpha_prime_history.append(alpha_prime)

        if alpha_prime > self.config.alpha_prime_threshold:
            self.consecutive_alpha_prime_positive += 1
        else:
            # Any non-accelerating iteration resets the streak.
            self.consecutive_alpha_prime_positive = 0

        return self.consecutive_alpha_prime_positive >= self.config.consecutive_for_rsi

    def run_iteration(self) -> Dict[str, Any]:
        """Run single RSI iteration: train → eval → assess → adapt lr."""
        self.iteration += 1

        self._update_warmups()

        train_results = self._run_training_iteration()

        eval_results = self.evaluator.quick_eval()
        capability = eval_results['total']
        self.capability_tracker.update(eval_results, self.iteration)

        # Assess internal-state trajectory (α, α', drift) on a fixed probe.
        sample_input = self.tokenizer("Test evaluation", return_tensors="pt").to(self.device)
        hidden_states = self._capture_hidden_states(sample_input.input_ids)
        assessment = self.iba.assess(hidden_states, capability, self.iteration)

        self.trainer.update_lr(
            alpha_prime=assessment.alpha_prime,
            is_improving=assessment.alpha > 0,
            recommendation=assessment.recommendation,
            lr_multiplier=assessment.lr_multiplier,
        )

        if capability > self.best_capability:
            self.best_capability = capability
            self.trainer.save_checkpoint(capability, {'iteration': self.iteration})

        is_rsi = self._update_rsi_tracking(assessment.alpha_prime)
        if is_rsi and not self.rsi_detected:
            # Latch detection; rsi_start_iter is only set once.
            self.rsi_detected = True
            self.rsi_start_iter = self.iteration

        results = {
            'iteration': self.iteration,
            'capability': capability,
            'math': eval_results['math'],
            'reasoning': eval_results['reasoning'],
            'coding': eval_results['coding'],
            'alpha': assessment.alpha,
            'alpha_prime': assessment.alpha_prime,
            'drift': assessment.drift,
            'status': assessment.status,
            'is_true_rsi': self.rsi_detected,
            'consecutive_positive': self.consecutive_alpha_prime_positive,
            'confidence': assessment.confidence,
            'recommendation': assessment.recommendation,
            'lr': train_results['lr'],
            'n_correct': train_results['n_correct'],
            'loss': train_results['loss'],
            'difficulty': train_results['difficulty'],
            'model_generated': train_results['model_generated'],
            'injection_active': self.injection_active,
            'curriculum_active': self.curriculum_active,
        }

        return results

    def print_header(self):
        """Print results table header."""
        print()
        print("=" * 150)
        print(f"{'Iter':>5} │ {'Progress':^12} │ {'Math':>5} │ {'Reas':>5} │ {'Code':>5} │ "
              f"{'Total':>6} │ {'α':>9} │ {'α´':>9} │ {'Diff':>4} │ {'Fib':>3} │ {'Cur':>3} │ Status")
        print("=" * 150)

    def print_iteration(self, results: Dict[str, Any]):
        """Print iteration results as one table row (plus the TRUE-RSI banner)."""
        # Progress bar toward the consecutive-α' streak required for TRUE RSI.
        progress = min(results['consecutive_positive'], self.config.consecutive_for_rsi)
        max_prog = self.config.consecutive_for_rsi
        bar = "█" * progress + "░" * (max_prog - progress)

        status = results['status']
        icon = get_status_icon(status)

        if results['is_true_rsi']:
            status_str = "🚀 TRUE RSI!"
        elif results['consecutive_positive'] >= 5:
            status_str = "📈 EMERGING"
        elif results['alpha'] > 0:
            status_str = f"{icon} IMPROVING"
        else:
            status_str = f"{icon} {status.value[:10]}"

        fib = "ON" if results['injection_active'] else "off"
        cur = "MDL" if results['curriculum_active'] else "tpl"

        print(f"{results['iteration']:>5} │ "
              f"[{bar}] │ "
              f"{results['math']:>5.1%} │ "
              f"{results['reasoning']:>5.1%} │ "
              f"{results['coding']:>5.1%} │ "
              f"{results['capability']:>6.1%} │ "
              f"{results['alpha']:>+9.5f} │ "
              f"{results['alpha_prime']:>+9.6f} │ "
              f"{results['difficulty']:>4.2f} │ "
              f"{fib:>3} │ "
              f"{cur:>3} │ "
              f"{status_str}")

        # Celebration banner — printed exactly once, on the detection iteration.
        # (Fixed: the original art contained mojibake-corrupted glyphs.)
        if results['is_true_rsi'] and self.iteration == self.rsi_start_iter:
            print()
            print("🚀" * 35)
            print()
            print("  ████████╗██████╗ ██╗   ██╗███████╗   ██████╗ ███████╗██╗")
            print("  ╚══██╔══╝██╔══██╗██║   ██║██╔════╝   ██╔══██╗██╔════╝██║")
            print("     ██║   ██████╔╝██║   ██║█████╗     ██████╔╝███████╗██║")
            print("     ██║   ██╔══██╗██║   ██║██╔══╝     ██╔══██╗╚════██║██║")
            print("     ██║   ██║  ██║╚██████╔╝███████╗   ██║  ██║███████║██║")
            print("     ╚═╝   ╚═╝  ╚═╝ ╚═════╝ ╚══════╝   ╚═╝  ╚═╝╚══════╝╚═╝")
            print()
            print(" α' > 0 for 10 consecutive iterations")
            print(" The improvement rate is ACCELERATING")
            print(" The model is recursively self-improving")
            print()
            print("🚀" * 35)
            print()

    def run(self, max_iterations: Optional[int] = None) -> Dict[str, Any]:
        """Run RSI loop.

        Args:
            max_iterations: Iteration budget; defaults to config.max_iterations.

        Returns:
            Session summary dict (see _get_summary).
        """
        if max_iterations is None:
            max_iterations = self.config.max_iterations

        self.print_header()

        try:
            for _ in range(max_iterations):
                results = self.run_iteration()

                if self.iteration % self.config.log_interval == 0:
                    self.print_iteration(results)

                # Stop once TRUE RSI has held for 20 extra iterations.
                if self.rsi_detected and self.iteration > self.rsi_start_iter + 20:
                    print("\n TRUE RSI sustained for 20 iterations past detection!")
                    break

                # Periodic GPU/host memory hygiene.
                if self.iteration % 10 == 0:
                    gc.collect()
                    torch.cuda.empty_cache()

        except KeyboardInterrupt:
            # Ctrl-C ends the loop gracefully; summary is still produced.
            print("\n[Interrupted]")

        summary = self._get_summary()
        self._print_summary(summary)

        return summary

    def _get_summary(self) -> Dict[str, Any]:
        """Get session summary."""
        return {
            'iterations': self.iteration,
            'baseline_capability': self.baseline_capability,
            'best_capability': self.best_capability,
            'final_capability': self.capability_tracker.get_capability(),
            'improvement': self.capability_tracker.get_capability() - self.baseline_capability,
            'rsi_detected': self.rsi_detected,
            'rsi_start_iter': self.rsi_start_iter,
            'curriculum_stats': self.curriculum.get_statistics(),
            'trainer_stats': self.trainer.get_stats(),
        }

    def _print_summary(self, summary: Dict[str, Any]):
        """Print session summary."""
        print()
        print("=" * 80)
        print(" RSI v13 SESSION SUMMARY")
        print("=" * 80)
        print(f" Iterations completed: {summary['iterations']}")
        print(f" Baseline capability: {summary['baseline_capability']:.1%}")
        print(f" Best capability: {summary['best_capability']:.1%}")
        print(f" Final capability: {summary['final_capability']:.1%}")
        print(f" Total improvement: {summary['improvement']:+.1%}")
        print()

        cs = summary['curriculum_stats']
        print(f" Self-curriculum stats:")
        print(f" Total problems: {cs['total_problems']}")
        print(f" Model-generated: {cs['model_generated']} ({cs['generation_rate']:.1%} valid)")
        print(f" Final difficulty: {cs['difficulty_description']} ({cs['current_difficulty']:.2f})")
        print()

        if summary['rsi_detected']:
            print(f" 🚀 TRUE RSI DETECTED at iteration {summary['rsi_start_iter']}")
        else:
            print(" ⏳ TRUE RSI not yet detected")
        print("=" * 80)
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def main():
    """Main entry point: build the engine from default config and run the loop.

    Fixed: the closing border of the banner contained mojibake-corrupted
    characters; it is now a clean box-drawing line.
    """
    print("""
╔══════════════════════════════════════════════════════════════════════════════════╗
║                        RSI v13 - CLOSED LOOP ARCHITECTURE                        ║
║                                                                                  ║
║  The model experiments on itself:                                                ║
║    • Sees own fiber state (self-observation)                                     ║
║    • Generates own problems (self-curriculum)                                    ║
║    • Learns from internal patterns (fiber conditioning)                          ║
║                                                                                  ║
║  TRUE RSI = α' > 0 for 10 consecutive iterations                                 ║
╚══════════════════════════════════════════════════════════════════════════════════╝
""")

    config = RSIv13Config()
    engine = RSIv13Engine(config)
    engine.run()


if __name__ == "__main__":
    main()
|
code/train_cognitive_enhancement.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
═══════════════════════════════════════════════════════════════════════════════
|
| 4 |
+
COGNITIVE ENHANCEMENT TRAINING SCRIPT
|
| 5 |
+
Train probes to make 8B think like 100B
|
| 6 |
+
═══════════════════════════════════════════════════════════════════════════════
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python train_cognitive_enhancement.py --probe depth
|
| 10 |
+
python train_cognitive_enhancement.py --probe all
|
| 11 |
+
python train_cognitive_enhancement.py --probe specificity --steps 10000
|
| 12 |
+
|
| 13 |
+
═══════════════════════════════════════════════════════════════════════════════
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import json
|
| 19 |
+
import time
|
| 20 |
+
import random
|
| 21 |
+
import argparse
|
| 22 |
+
from datetime import datetime
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import List, Dict, Optional
|
| 25 |
+
from dataclasses import dataclass, asdict
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
from torch.utils.data import Dataset, DataLoader
|
| 31 |
+
|
| 32 |
+
# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════

ROOT = os.path.expanduser("~/Desktop/Claude_and_me")
MODEL_PATH = os.path.join(ROOT, "models/Qwen2.5-7B-Instruct")
OUTPUT_DIR = os.path.join(ROOT, "cognitive_enhancement_output")

# Fall back to the Hugging Face hub ID when no local checkout exists.
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = "Qwen/Qwen2.5-7B-Instruct"

HIDDEN_DIM = 4096            # backbone hidden size (overwritten in main() from the loaded model config)
FIBER_DIM = 16               # width of the low-rank "fiber" projection
HEAD_HIDDEN = 64             # hidden width of the behavior classifier head
PROBE_LAYERS = [8, 16, 24]   # decoder layers whose hidden states feed the probe

DEFAULT_STEPS = 15000        # default number of micro-batch training steps
BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 4    # optimizer is stepped once per N micro-batches
LEARNING_RATE = 5e-5
SAVE_EVERY = 1000            # checkpoint interval, in steps
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 56 |
+
# PROBE DEFINITIONS WITH TRAINING DATA
|
| 57 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 58 |
+
|
| 59 |
+
# Each probe maps a behavior name to paired training patterns.
# A pattern is (prompt, response, per-word labels) where label 1 marks a word
# exhibiting the undesirable behavior and 0 marks a clean word.
# "positive_patterns" demonstrate the flaw the probe detects;
# "negative_patterns" are clean counter-examples, labelled all-zero.
# Label lists shorter/longer than the word count are padded/truncated by
# generate_training_data().
PROBES = {
    "depth": {
        "name": "Reasoning Depth",
        "description": "Detect shallow reasoning, encourage chain-of-thought",
        # Shallow one-liner answers: every word flagged.
        "positive_patterns": [
            ("What causes rain?", "Water falls from clouds.", [1, 1, 1, 1]),
            ("How does gravity work?", "Things fall down.", [1, 1, 1]),
            ("Why is the sky blue?", "It just is that way.", [1, 1, 1, 1, 1]),
            ("Explain photosynthesis.", "Plants make food from sun.", [1, 1, 1, 1, 1]),
            ("What is democracy?", "People vote for leaders.", [1, 1, 1, 1]),
            ("How do computers work?", "They process information.", [1, 1, 1]),
            ("Why do we sleep?", "Bodies need rest.", [1, 1, 1]),
            ("What causes earthquakes?", "The ground shakes.", [1, 1, 1]),
        ],
        # Thorough multi-step explanations: all-zero labels.
        "negative_patterns": [
            ("What causes rain?", "Rain forms through the water cycle. First, the sun heats water in oceans causing evaporation. This water vapor rises and cools, condensing into clouds. When droplets become heavy enough, they fall as precipitation. This process is driven by solar energy and Earth's geography.", [0]*60),
            ("How does gravity work?", "Gravity is explained by Einstein's general relativity. Mass curves the fabric of spacetime, and objects follow geodesics through this curved space. The more massive an object, the more it curves spacetime around it, which we perceive as gravitational attraction.", [0]*50),
            ("Why is the sky blue?", "The sky appears blue due to Rayleigh scattering. When sunlight enters Earth's atmosphere, it collides with gas molecules. Blue light has a shorter wavelength, so it scatters more than other colors. This scattered blue light reaches our eyes from all directions, making the sky appear blue.", [0]*55),
        ],
    },

    "specificity": {
        "name": "Answer Specificity",
        "description": "Detect vague answers, encourage concrete details",
        # Vague filler words ("things", "various", "generally") flagged.
        "positive_patterns": [
            ("Best programming language?", "There are many good options depending on various factors and things you want to do.", [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
            ("How to lose weight?", "You should do different things and generally eat better and exercise somewhat more.", [0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1]),
            ("Career advice?", "It depends on many things. Think about stuff you like and various options.", [0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1]),
            ("How to learn faster?", "Try different methods and do things that work for you generally.", [0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1]),
            ("Best investment?", "There are various options depending on things like your situation.", [0, 0, 1, 1, 1, 0, 1, 0, 0, 1]),
        ],
        # Concrete, number-rich answers: all-zero labels.
        "negative_patterns": [
            ("Best programming language?", "For web development, I recommend JavaScript with React. It has 97.6% browser support, 18M+ npm packages, and average salary of $112k. Specifically for beginners, Python offers cleaner syntax and is used by 48% of developers.", [0]*45),
            ("How to lose weight?", "Create a 500 calorie daily deficit through diet. Eat 0.8-1g protein per pound bodyweight. Do 150 minutes moderate cardio weekly plus 2 strength sessions. Track intake with MyFitnessPal. Expect 1-2 lbs loss per week.", [0]*45),
            ("Career advice?", "In tech, data science roles pay $120k median with 22% growth. Required skills: Python, SQL, statistics. Start with Google Data Analytics Certificate (6 months, $39/month). Build portfolio on Kaggle.", [0]*40),
        ],
    },

    "calibration": {
        "name": "Confidence Calibration",
        "description": "Detect overconfidence on uncertain topics",
        # Absolutist language ("definitely", "guaranteed") flagged.
        "positive_patterns": [
            ("Will AI take all jobs?", "Absolutely, AI will definitely replace every single job within 10 years guaranteed.", [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
            ("Best investment?", "Crypto is guaranteed to 10x your money. You'll definitely make money no doubt.", [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1]),
            ("Will it rain tomorrow?", "It will certainly rain tomorrow. There's absolutely no doubt about that.", [0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0]),
            ("Is this stock good?", "This stock will absolutely skyrocket. It's impossible for it to fail.", [0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1]),
            ("Will the team win?", "They will definitely win. There's no way they can lose this game.", [0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0]),
        ],
        # Appropriately hedged, uncertainty-aware answers: all-zero labels.
        "negative_patterns": [
            ("Will AI take all jobs?", "This is uncertain and debated. AI will likely automate some tasks, but historical evidence suggests new jobs often emerge. Estimates range from 10-50% job displacement over decades, with significant uncertainty about timeline and scope.", [0]*45),
            ("Best investment?", "I can't predict markets reliably. Historically, diversified index funds average 7-10% annually, but past performance doesn't guarantee future results. Consider consulting a financial advisor for personalized advice.", [0]*35),
            ("Will it rain tomorrow?", "Based on current forecasts, there's about a 60% chance of rain, but weather predictions beyond 24 hours have significant uncertainty. I'd suggest checking closer to the time.", [0]*30),
        ],
    },

    "focus": {
        "name": "Topic Focus",
        "description": "Detect topic drift and tangents",
        # Off-topic tangent spans flagged; on-topic spans labelled 0.
        "positive_patterns": [
            ("What is Python?", "Python is a language. Speaking of which, did you know snakes can be 30 feet long? Anacondas are fascinating creatures. By the way, I love nature documentaries. Anyway, what was the question again?", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
            ("Explain HTTP.", "HTTP is a protocol. By the way, I love the internet! Remember dial-up modems? Those were the days. Speaking of old technology, fax machines are interesting too. Oh right, HTTP uses request-response patterns.", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]),
            ("How do cars work?", "Cars have engines. Incidentally, I was thinking about trains yesterday. Trains are so relaxing! Have you been on the Orient Express? Anyway, cars use combustion.", [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]),
        ],
        # Fully on-topic answers: all-zero labels.
        "negative_patterns": [
            ("What is Python?", "Python is a high-level programming language created by Guido van Rossum in 1991. It emphasizes code readability using significant whitespace. Key features include dynamic typing, garbage collection, and an extensive standard library. It's used for web development, data science, automation, and AI.", [0]*55),
            ("Explain HTTP.", "HTTP (Hypertext Transfer Protocol) is an application-layer protocol for transmitting hypermedia documents. It follows a request-response model between client and server. Key methods include GET (retrieve), POST (submit), PUT (update), and DELETE (remove). It's stateless, meaning each request is independent.", [0]*50),
            ("How do cars work?", "Cars operate through internal combustion engines that convert fuel into motion. The engine draws in air and fuel, compresses the mixture, ignites it with a spark plug, and uses the resulting expansion to push pistons. This motion is transferred through the transmission to the wheels.", [0]*50),
        ],
    },

    "coherence": {
        "name": "Logical Coherence",
        "description": "Detect contradictions and non-sequiturs",
        # Self-contradicting spans flagged from the point of contradiction on.
        "positive_patterns": [
            ("Is exercise good?", "Exercise is essential for health. However, you should never exercise. Moving is important but staying completely still is better for you.", [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
            ("Is coffee healthy?", "Coffee is very healthy for you. Actually, coffee is quite unhealthy. It helps you focus and also hurts your ability to focus.", [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
            ("Should I save money?", "Saving money is wise. But actually, you should spend everything. Financial security matters, but also doesn't matter at all.", [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
        ],
        # Nuanced-but-consistent answers: all-zero labels.
        "negative_patterns": [
            ("Is exercise good?", "Exercise provides numerous health benefits including improved cardiovascular health, stronger muscles, and better mental health. However, excessive exercise without adequate recovery can lead to injury and burnout. Therefore, the key is finding a sustainable routine with appropriate rest periods.", [0]*50),
            ("Is coffee healthy?", "Coffee has both benefits and drawbacks that depend on individual factors. Benefits include improved alertness, antioxidants, and potential reduced risk of certain diseases. However, excessive consumption can cause anxiety, sleep disruption, and dependency. Consequently, moderate consumption of 2-3 cups is generally considered safe for most adults.", [0]*55),
            ("Should I save money?", "Saving money is generally wise for financial security and future goals. However, the optimal savings rate depends on your income, expenses, and life stage. Therefore, financial advisors typically recommend saving 20% of income while still allowing for present enjoyment and necessary expenses.", [0]*50),
        ],
    },
}
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 147 |
+
# NEURAL NETWORK ARCHITECTURE
|
| 148 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 149 |
+
|
| 150 |
+
class FiberProjection(nn.Module):
    """Project per-layer hidden states into a small shared "fiber" space.

    One bias-free linear map per tapped layer; the projected fibers are
    blended with a learned softmax weighting over layers.
    """

    def __init__(self, hidden_dim=4096, fiber_dim=16, n_layers=3):
        super().__init__()
        self.projections = nn.ModuleList(
            [nn.Linear(hidden_dim, fiber_dim, bias=False) for _ in range(n_layers)]
        )
        # Initialized uniform; softmax-normalized at forward time.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

    def forward(self, hidden_states_list):
        # Cast to float32 so the probe computes in full precision even when
        # the backbone runs in bf16 / 4-bit.
        mix = F.softmax(self.layer_weights, dim=0)
        combined = None
        for coeff, proj, states in zip(mix, self.projections, hidden_states_list):
            contribution = coeff * proj(states.float())
            combined = contribution if combined is None else combined + contribution
        return combined
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class BehaviorHead(nn.Module):
    """Tiny MLP that scores each fiber vector with one behavior logit."""

    def __init__(self, fiber_dim=16, hidden_dim=64):
        super().__init__()
        stack = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, fiber):
        # Drop the trailing singleton so callers get [..., seq]-shaped logits.
        scores = self.classifier(fiber)
        return scores.squeeze(-1)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class EnhancementProbe(nn.Module):
    """Full probe: layer-weighted fiber projection followed by a scoring head."""

    def __init__(self, hidden_dim=4096, fiber_dim=16, head_hidden=64, n_layers=3):
        super().__init__()
        self.fiber_projection = FiberProjection(hidden_dim, fiber_dim, n_layers)
        self.head = BehaviorHead(fiber_dim, head_hidden)

    def forward(self, hidden_states_list):
        # Collapse the tapped layers into a single fiber, then score per token.
        return self.head(self.fiber_projection(hidden_states_list))
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 192 |
+
# DATA GENERATION
|
| 193 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 194 |
+
|
| 195 |
+
def generate_training_data(probe_name: str, n_samples: int = 2000) -> List[Dict]:
    """Sample a balanced stream of labelled (prompt, response) examples.

    Even indices draw from the probe's positive (flawed) patterns, odd
    indices from its negative (clean) patterns.  Positive word labels are
    padded with their final value or truncated to the response word count;
    negative examples are labelled all-zero.

    Raises:
        ValueError: if `probe_name` is not a key of PROBES.
    """
    if probe_name not in PROBES:
        raise ValueError(f"Unknown probe: {probe_name}")

    spec = PROBES[probe_name]
    positives = spec["positive_patterns"]
    negatives = spec["negative_patterns"]

    samples = []
    for idx in range(n_samples):
        use_positive = idx % 2 == 0 and bool(positives)
        if use_positive:
            prompt, response, raw_labels = random.choice(positives)
            n_words = len(response.split())
            if len(raw_labels) < n_words:
                # Extend with the last label so trailing words inherit it.
                word_labels = raw_labels + [raw_labels[-1]] * (n_words - len(raw_labels))
            else:
                word_labels = raw_labels[:n_words]
        else:
            prompt, response, _ = random.choice(negatives)
            word_labels = [0] * len(response.split())
        samples.append({
            "prompt": prompt,
            "response": response,
            "labels": word_labels,
            "is_positive": use_positive,
        })

    return samples
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class ProbeDataset(Dataset):
    """Token-level dataset pairing tokenized prompt+response with labels.

    Word-level labels from generate_training_data() are spread uniformly
    over the response's real tokens; padding positions stay at label 0.
    """

    def __init__(self, examples, tokenizer, max_length=512):
        # examples: list of dicts with 'prompt', 'response', 'labels' keys.
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        full_text = f"{ex['prompt']}\n{ex['response']}"
        encoding = self.tokenizer(full_text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt")

        n_tokens = encoding['input_ids'].shape[1]
        token_labels = torch.zeros(n_tokens)

        # Rough response start: token count of the prompt alone.
        # NOTE(review): tokenizer.encode may add special tokens, so this is
        # an approximation of the true prompt/response boundary.
        prompt_len = len(self.tokenizer.encode(ex['prompt']))
        word_labels = ex['labels']

        # BUGFIX: spread labels over the actual (unpadded) sequence length.
        # The padded length (max_length) was used before, which smeared word
        # labels far into the padding region and trained the probe on
        # meaningless pad-token states.
        seq_len = int(encoding['attention_mask'].sum().item())

        if word_labels and seq_len > prompt_len:
            tokens_per_word = max(1, (seq_len - prompt_len) // max(len(word_labels), 1))
            for i, label in enumerate(word_labels):
                start_idx = prompt_len + i * tokens_per_word
                end_idx = min(start_idx + tokens_per_word, seq_len)
                token_labels[start_idx:end_idx] = label

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': token_labels,
        }
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 261 |
+
# TRAINING
|
| 262 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 263 |
+
|
| 264 |
+
def _separation_stats(logits, labels, mask=None):
    """Return (pos_mean, neg_mean, separation) from raw probe logits.

    pos/neg means are the average sigmoid probability over tokens labelled
    1 / 0; separation is their ratio with the denominator floored at 1e-3.
    When `mask` is given, positions with mask == 0 (padding) are ignored.
    """
    probs = torch.sigmoid(logits)
    if mask is not None:
        valid = mask > 0.5
        probs = probs[valid]
        labels = labels[valid]
    pos_mask = labels > 0.5
    neg_mask = labels < 0.5
    pos_mean = probs[pos_mask].mean().item() if pos_mask.any() else 0
    neg_mean = probs[neg_mask].mean().item() if neg_mask.any() else 0.001
    return pos_mean, neg_mean, pos_mean / max(neg_mean, 0.001)


def train_probe(probe_name: str, model, tokenizer, max_steps: int = DEFAULT_STEPS, output_dir: str = OUTPUT_DIR):
    """Train one behavior probe against the frozen backbone's hidden states.

    Args:
        probe_name: key into PROBES (e.g. "depth").
        model: frozen causal LM; forwarded only for hidden states.
        tokenizer: tokenizer matching `model`.
        max_steps: number of micro-batch steps to run.
        output_dir: root directory for per-probe checkpoints.

    Returns:
        Best positive/negative separation ratio observed during training.
    """
    print(f"\n{'='*70}")
    print(f" TRAINING: {probe_name.upper()} PROBE")
    print(f" {PROBES[probe_name]['description']}")
    print(f"{'='*70}")

    device = next(model.parameters()).device
    n_layers = len(PROBE_LAYERS)
    probe = EnhancementProbe(HIDDEN_DIM, FIBER_DIM, HEAD_HIDDEN, n_layers).to(device)

    print(f"\n[data] Generating training data...")
    examples = generate_training_data(probe_name, n_samples=3000)
    print(f"[data] Generated {len(examples)} examples")

    dataset = ProbeDataset(examples, tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = torch.optim.AdamW(probe.parameters(), lr=LEARNING_RATE)
    # NOTE(review): the scheduler is stepped only once per optimizer update
    # (every GRADIENT_ACCUMULATION micro-steps), so with T_max=max_steps the
    # cosine schedule only traverses 1/GRADIENT_ACCUMULATION of its period.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps, eta_min=1e-6)

    probe_dir = os.path.join(output_dir, probe_name)
    os.makedirs(probe_dir, exist_ok=True)

    probe.train()
    model.eval()

    step = 0
    best_separation = 0.0

    print(f"\n[train] Starting training for {max_steps} steps...")
    print("-" * 70)

    while step < max_steps:
        for batch in dataloader:
            if step >= max_steps:
                break

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Backbone is frozen: forward pass only, no gradients through it.
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, return_dict=True)

            # hidden_states[0] is the embedding output, hence the +1 offset.
            probe_states = [outputs.hidden_states[layer + 1] for layer in PROBE_LAYERS]
            logits = probe(probe_states)

            # BUGFIX: weight the BCE loss by the attention mask. Previously
            # the loss averaged over the full padded sequence, so the
            # zero-labelled padding positions dominated the objective.
            mask = attention_mask.float()
            per_token = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
            loss = (per_token * mask).sum() / mask.sum().clamp(min=1.0)

            loss = loss / GRADIENT_ACCUMULATION
            loss.backward()

            if (step + 1) % GRADIENT_ACCUMULATION == 0:
                torch.nn.utils.clip_grad_norm_(probe.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            step += 1

            if step % 100 == 0:
                with torch.no_grad():
                    pos_mean, neg_mean, separation = _separation_stats(logits, labels, mask)

                print(f"Step {step:>6} | Loss: {loss.item()*GRADIENT_ACCUMULATION:.4f} | Pos: {pos_mean:.3f} | Neg: {neg_mean:.3f} | Sep: {separation:.2f}×")

                if separation > best_separation:
                    best_separation = separation

            if step % SAVE_EVERY == 0:
                ckpt_dir = os.path.join(probe_dir, f"ckpt_{step}")
                os.makedirs(ckpt_dir, exist_ok=True)

                # Recompute metrics on the last batch for the checkpoint record
                # (shared helper replaces the previously duplicated code).
                with torch.no_grad():
                    pos_mean, neg_mean, separation = _separation_stats(logits, labels, mask)

                checkpoint = {
                    'fiber_projection': probe.fiber_projection.state_dict(),
                    'head_state': probe.head.state_dict(),
                    'step': step,
                    'separation': separation,
                    'loss': loss.item() * GRADIENT_ACCUMULATION,
                    'probe_name': probe_name,
                }

                torch.save(checkpoint, os.path.join(ckpt_dir, f"{probe_name}_head.pt"))
                print(f">>> Saved checkpoint: {ckpt_dir} (sep: {separation:.2f}×)")

    print(f"\n{'='*70}")
    print(f" FINISHED: {probe_name.upper()}")
    print(f" Best separation: {best_separation:.2f}×")
    print(f"{'='*70}")

    return best_separation
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def main():
    """CLI entry point: load the backbone once, then train each requested probe."""
    parser = argparse.ArgumentParser(description="Train Cognitive Enhancement Probes")
    parser.add_argument("--probe", type=str, default="all", help="Probe to train (depth, specificity, calibration, focus, coherence, or 'all')")
    parser.add_argument("--steps", type=int, default=DEFAULT_STEPS, help=f"Training steps (default: {DEFAULT_STEPS})")
    parser.add_argument("--output", type=str, default=OUTPUT_DIR, help="Output directory")
    args = parser.parse_args()

    print("\n" + "=" * 70)
    print(" COGNITIVE ENHANCEMENT TRAINING")
    print(" Making 8B Think Like 100B")
    print("=" * 70)

    print(f"\n[model] Loading from: {MODEL_PATH}")

    # Imported lazily so argument parsing works without the heavy ML stack.
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantization: the backbone is frozen, so quality loss is
    # limited to the hidden states the probes read.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()

    print(f"[model] Loaded: {model.config.hidden_size} hidden dim, {model.config.num_hidden_layers} layers")

    # Sync the module-level constant with the loaded model so probes built
    # inside train_probe() use the correct input width.
    global HIDDEN_DIM
    HIDDEN_DIM = model.config.hidden_size

    os.makedirs(args.output, exist_ok=True)

    if args.probe == "all":
        probes_to_train = list(PROBES.keys())
    else:
        probes_to_train = [args.probe]

    results = {}
    for probe_name in probes_to_train:
        if probe_name not in PROBES:
            print(f"[warning] Unknown probe: {probe_name}, skipping")
            continue
        separation = train_probe(probe_name, model, tokenizer, args.steps, args.output)
        results[probe_name] = separation

    print("\n" + "=" * 70)
    print(" TRAINING COMPLETE - SUMMARY")
    print("=" * 70)
    for name, sep in results.items():
        print(f" {name}: {sep:.1f}× separation")
    print("=" * 70)
    print(f"\nCheckpoints saved to: {args.output}")
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
# Standard script entry guard.
if __name__ == "__main__":
    main()
|
code/training_pipelines/01_cfhot_risk_v2_REPETITION_125x.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CF-HoT Phase 1: Train Risk Predictor (FIXED)
|
| 4 |
+
=============================================
|
| 5 |
+
Fixed: Class weighting to handle imbalanced data (most tokens aren't repeats)
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 12 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
import random
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Tuple
|
| 19 |
+
|
| 20 |
+
@dataclass
class Config:
    """Hyperparameters for the CF-HoT risk-predictor training run."""
    model_path: str = "/mnt/nvme2/ubermesnchetien4/models/merged-final-v5"  # local merged base model
    output_dir: str = "./results/cfhot_risk_v2"   # checkpoint destination
    d_fiber: int = 16        # width of each per-layer fiber projection
    d_control: int = 64      # hidden width of the risk-predictor MLP
    max_steps: int = 3000    # total micro-batch training steps
    batch_size: int = 1
    grad_accum: int = 8      # optimizer stepped every N micro-batches
    max_length: int = 256    # token truncation/padding length
    lr_lora: float = 2e-5        # learning rate for the LoRA parameters
    lr_predictor: float = 1e-4   # learning rate for the risk predictor
    weight_decay: float = 0.01
    rep_window: int = 32     # lookback window for repetition labels
    log_every: int = 10      # steps between metric prints
    save_every: int = 500    # steps between checkpoints
    eval_every: int = 200    # steps between eval passes
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class RiskPredictor(nn.Module):
    """Token-level repetition-risk head over the backbone's hidden states.

    Each tapped layer is projected into a small fiber space; the fibers are
    blended with a learned softmax weighting and fed to an MLP that emits
    one LOGIT per token (pair with BCEWithLogitsLoss, not BCELoss).
    """

    def __init__(self, d_model: int, n_layers: int, config: "Config"):
        super().__init__()
        self.config = config
        self.n_layers = n_layers

        # One bias-free projection per transformer layer.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, config.d_fiber, bias=False)
            for _ in range(n_layers)
        ])

        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        # Output LOGITS, not probabilities.
        self.predictor = nn.Sequential(
            nn.Linear(config.d_fiber, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, 1),
            # NO sigmoid here - we'll use BCEWithLogitsLoss
        )

        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-token risk logits of shape [B, S].

        Accepts fewer than `n_layers` hidden states; the unused trailing
        layer weights are simply excluded from the softmax.
        """
        # Cleanup: the former `if i < len(hidden_states)` guard inside the
        # loop was dead code - zip() already stops at the shorter sequence.
        fibers = [proj(h.float()) for proj, h in zip(self.fiber_projs, hidden_states)]

        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))

        logits = self.predictor(aggregated).squeeze(-1)  # [B, S] LOGITS
        return logits
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Mark tokens that repeat any token within the previous `window` positions.

    Args:
        input_ids: [B, S] integer token ids.
        window: lookback distance in tokens.

    Returns:
        [B, S] float tensor with 1.0 where input_ids[b, t] equals
        input_ids[b, t - k] for some 1 <= k <= window, else 0.0.
        Position 0 can never be a repeat.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)

    # Vectorized over the batch: one shifted equality comparison per offset.
    # (Cleanup: the former `if offset < S` guard was dead code - the range
    # bound min(window + 1, S) already guarantees offset <= S - 1.)
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)

    return labels
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def main():
|
| 93 |
+
config = Config()
|
| 94 |
+
os.makedirs(config.output_dir, exist_ok=True)
|
| 95 |
+
|
| 96 |
+
print("=" * 70)
|
| 97 |
+
print("CF-HoT RISK PREDICTOR v2 (CLASS-WEIGHTED)")
|
| 98 |
+
print("=" * 70)
|
| 99 |
+
|
| 100 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model_path)
|
| 101 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 102 |
+
|
| 103 |
+
print("Loading model...")
|
| 104 |
+
bnb = BitsAndBytesConfig(
|
| 105 |
+
load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
|
| 106 |
+
bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"
|
| 107 |
+
)
|
| 108 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 109 |
+
config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16
|
| 110 |
+
)
|
| 111 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
| 112 |
+
|
| 113 |
+
device = next(model.parameters()).device
|
| 114 |
+
print(f"Device: {device}")
|
| 115 |
+
|
| 116 |
+
print("Adding LoRA...")
|
| 117 |
+
model = get_peft_model(model, LoraConfig(
|
| 118 |
+
r=64, lora_alpha=128,
|
| 119 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 120 |
+
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
|
| 121 |
+
))
|
| 122 |
+
model.print_trainable_parameters()
|
| 123 |
+
|
| 124 |
+
print("Adding Risk Predictor...")
|
| 125 |
+
n_layers = model.config.num_hidden_layers
|
| 126 |
+
d_model = model.config.hidden_size
|
| 127 |
+
risk_predictor = RiskPredictor(d_model, n_layers, config).to(device).float()
|
| 128 |
+
print(f"Risk Predictor params: {sum(p.numel() for p in risk_predictor.parameters()):,}")
|
| 129 |
+
|
| 130 |
+
print("Loading data...")
|
| 131 |
+
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
|
| 132 |
+
texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
|
| 133 |
+
random.shuffle(texts)
|
| 134 |
+
print(f"Loaded {len(texts)} samples")
|
| 135 |
+
|
| 136 |
+
lora_params = [p for p in model.parameters() if p.requires_grad]
|
| 137 |
+
optimizer = torch.optim.AdamW([
|
| 138 |
+
{'params': lora_params, 'lr': config.lr_lora},
|
| 139 |
+
{'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
|
| 140 |
+
], weight_decay=config.weight_decay)
|
| 141 |
+
|
| 142 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 143 |
+
optimizer, T_max=config.max_steps, eta_min=1e-6
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
print("\n" + "=" * 70)
|
| 147 |
+
print("TRAINING (with class weighting)")
|
| 148 |
+
print("=" * 70)
|
| 149 |
+
|
| 150 |
+
model.train()
|
| 151 |
+
risk_predictor.train()
|
| 152 |
+
|
| 153 |
+
step = 0
|
| 154 |
+
data_idx = 0
|
| 155 |
+
acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
|
| 156 |
+
acc_precision, acc_recall, acc_f1 = 0, 0, 0
|
| 157 |
+
start_time = time.time()
|
| 158 |
+
|
| 159 |
+
while step < config.max_steps:
|
| 160 |
+
batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
|
| 161 |
+
data_idx += config.batch_size
|
| 162 |
+
|
| 163 |
+
enc = tokenizer(batch, truncation=True, max_length=config.max_length,
|
| 164 |
+
padding='max_length', return_tensors='pt')
|
| 165 |
+
input_ids = enc['input_ids'].to(device)
|
| 166 |
+
attention_mask = enc['attention_mask'].to(device)
|
| 167 |
+
|
| 168 |
+
outputs = model(
|
| 169 |
+
input_ids=input_ids,
|
| 170 |
+
attention_mask=attention_mask,
|
| 171 |
+
labels=input_ids,
|
| 172 |
+
output_hidden_states=True
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
lm_loss = outputs.loss
|
| 176 |
+
|
| 177 |
+
# Get logits from risk predictor
|
| 178 |
+
risk_logits = risk_predictor(outputs.hidden_states[1:])
|
| 179 |
+
|
| 180 |
+
# Compute labels
|
| 181 |
+
rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)
|
| 182 |
+
|
| 183 |
+
# CLASS-WEIGHTED LOSS
|
| 184 |
+
# Compute class weights dynamically based on batch
|
| 185 |
+
mask = attention_mask.float()
|
| 186 |
+
n_pos = (rep_labels * mask).sum().clamp(min=1)
|
| 187 |
+
n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
|
| 188 |
+
pos_weight = n_neg / n_pos # Weight positives by ratio of negatives to positives
|
| 189 |
+
pos_weight = pos_weight.clamp(max=10.0) # Cap at 10x
|
| 190 |
+
|
| 191 |
+
# BCEWithLogitsLoss with pos_weight
|
| 192 |
+
bce_loss = F.binary_cross_entropy_with_logits(
|
| 193 |
+
risk_logits, rep_labels,
|
| 194 |
+
pos_weight=torch.ones_like(rep_labels) * pos_weight,
|
| 195 |
+
reduction='none'
|
| 196 |
+
)
|
| 197 |
+
risk_loss = (bce_loss * mask).sum() / mask.sum()
|
| 198 |
+
|
| 199 |
+
# Total loss
|
| 200 |
+
loss = lm_loss + risk_loss
|
| 201 |
+
|
| 202 |
+
(loss / config.grad_accum).backward()
|
| 203 |
+
|
| 204 |
+
# Metrics (apply sigmoid for evaluation)
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
risk_pred = torch.sigmoid(risk_logits)
|
| 207 |
+
pred_binary = (risk_pred > 0.5).float()
|
| 208 |
+
tp = ((pred_binary == 1) & (rep_labels == 1) & (mask == 1)).sum()
|
| 209 |
+
fp = ((pred_binary == 1) & (rep_labels == 0) & (mask == 1)).sum()
|
| 210 |
+
fn = ((pred_binary == 0) & (rep_labels == 1) & (mask == 1)).sum()
|
| 211 |
+
|
| 212 |
+
precision = tp / (tp + fp + 1e-8)
|
| 213 |
+
recall = tp / (tp + fn + 1e-8)
|
| 214 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
| 215 |
+
|
| 216 |
+
acc_loss += loss.item()
|
| 217 |
+
acc_lm += lm_loss.item()
|
| 218 |
+
acc_risk_loss += risk_loss.item()
|
| 219 |
+
acc_precision += precision.item()
|
| 220 |
+
acc_recall += recall.item()
|
| 221 |
+
acc_f1 += f1.item()
|
| 222 |
+
|
| 223 |
+
step += 1
|
| 224 |
+
|
| 225 |
+
if step % config.grad_accum == 0:
|
| 226 |
+
torch.nn.utils.clip_grad_norm_(
|
| 227 |
+
list(lora_params) + list(risk_predictor.parameters()), 1.0
|
| 228 |
+
)
|
| 229 |
+
optimizer.step()
|
| 230 |
+
scheduler.step()
|
| 231 |
+
optimizer.zero_grad()
|
| 232 |
+
|
| 233 |
+
if step % config.log_every == 0:
|
| 234 |
+
eta = (config.max_steps - step) / (step / (time.time() - start_time)) / 3600
|
| 235 |
+
n = config.log_every
|
| 236 |
+
|
| 237 |
+
print(
|
| 238 |
+
f"Step {step:5d} | "
|
| 239 |
+
f"Loss: {acc_loss/n:.4f} | "
|
| 240 |
+
f"LM: {acc_lm/n:.4f} | "
|
| 241 |
+
f"Risk: {acc_risk_loss/n:.4f} | "
|
| 242 |
+
f"P: {acc_precision/n:.3f} | "
|
| 243 |
+
f"R: {acc_recall/n:.3f} | "
|
| 244 |
+
f"F1: {acc_f1/n:.3f} | "
|
| 245 |
+
f"ETA: {eta:.1f}h"
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
|
| 249 |
+
acc_precision, acc_recall, acc_f1 = 0, 0, 0
|
| 250 |
+
|
| 251 |
+
if step % config.save_every == 0:
|
| 252 |
+
ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
|
| 253 |
+
os.makedirs(ckpt, exist_ok=True)
|
| 254 |
+
model.save_pretrained(ckpt)
|
| 255 |
+
torch.save({
|
| 256 |
+
'risk_predictor': risk_predictor.state_dict(),
|
| 257 |
+
'step': step
|
| 258 |
+
}, os.path.join(ckpt, "risk_predictor.pt"))
|
| 259 |
+
print(f">>> Saved: {ckpt}")
|
| 260 |
+
|
| 261 |
+
if step % config.eval_every == 0:
|
| 262 |
+
model.eval()
|
| 263 |
+
risk_predictor.eval()
|
| 264 |
+
|
| 265 |
+
print("\n--- Evaluation ---")
|
| 266 |
+
|
| 267 |
+
prompt = "The will to power, as described by Nietzsche, is"
|
| 268 |
+
inp = tokenizer(prompt, return_tensors='pt')
|
| 269 |
+
input_ids = inp['input_ids'].to(device)
|
| 270 |
+
|
| 271 |
+
with torch.no_grad():
|
| 272 |
+
out = model.generate(
|
| 273 |
+
input_ids, max_new_tokens=60,
|
| 274 |
+
do_sample=True, temperature=0.8, top_p=0.9,
|
| 275 |
+
pad_token_id=tokenizer.eos_token_id
|
| 276 |
+
)
|
| 277 |
+
generated_text = tokenizer.decode(out[0], skip_special_tokens=True)
|
| 278 |
+
|
| 279 |
+
gen_outputs = model(out, output_hidden_states=True)
|
| 280 |
+
gen_logits = risk_predictor(gen_outputs.hidden_states[1:])
|
| 281 |
+
gen_risk = torch.sigmoid(gen_logits)
|
| 282 |
+
|
| 283 |
+
risk_vals = gen_risk[0].cpu().numpy()
|
| 284 |
+
|
| 285 |
+
print(f" Generated: {generated_text[:200]}...")
|
| 286 |
+
print(f" Risk (first 10): {[f'{r:.2f}' for r in risk_vals[:10]]}")
|
| 287 |
+
print(f" Risk (last 10): {[f'{r:.2f}' for r in risk_vals[-10:]]}")
|
| 288 |
+
print(f" Mean: {risk_vals.mean():.3f}, Max: {risk_vals.max():.3f}, Min: {risk_vals.min():.3f}")
|
| 289 |
+
|
| 290 |
+
# Check correlation
|
| 291 |
+
gen_ids = out[0].cpu().numpy()
|
| 292 |
+
actual_reps = []
|
| 293 |
+
for t in range(1, len(gen_ids)):
|
| 294 |
+
start = max(0, t - config.rep_window)
|
| 295 |
+
is_rep = gen_ids[t] in gen_ids[start:t]
|
| 296 |
+
actual_reps.append(1 if is_rep else 0)
|
| 297 |
+
|
| 298 |
+
# Correlation between risk and actual repeats
|
| 299 |
+
if len(actual_reps) > 1:
|
| 300 |
+
risk_at_reps = [risk_vals[i+1] for i, r in enumerate(actual_reps) if r == 1 and i+1 < len(risk_vals)]
|
| 301 |
+
risk_at_nonreps = [risk_vals[i+1] for i, r in enumerate(actual_reps) if r == 0 and i+1 < len(risk_vals)]
|
| 302 |
+
|
| 303 |
+
if risk_at_reps and risk_at_nonreps:
|
| 304 |
+
print(f" Avg risk at REPEATS: {sum(risk_at_reps)/len(risk_at_reps):.3f}")
|
| 305 |
+
print(f" Avg risk at NON-REPS: {sum(risk_at_nonreps)/len(risk_at_nonreps):.3f}")
|
| 306 |
+
|
| 307 |
+
print("--- End Eval ---\n")
|
| 308 |
+
|
| 309 |
+
model.train()
|
| 310 |
+
risk_predictor.train()
|
| 311 |
+
|
| 312 |
+
final = os.path.join(config.output_dir, "final")
|
| 313 |
+
os.makedirs(final, exist_ok=True)
|
| 314 |
+
model.save_pretrained(final)
|
| 315 |
+
torch.save({
|
| 316 |
+
'risk_predictor': risk_predictor.state_dict(),
|
| 317 |
+
'step': step
|
| 318 |
+
}, os.path.join(final, "risk_predictor.pt"))
|
| 319 |
+
|
| 320 |
+
print(f"\nDONE! Saved to {final}")
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
if __name__ == "__main__":
|
| 324 |
+
main()
|
code/training_pipelines/02_arc_adapter_training_MULTIHEAD.py
ADDED
|
@@ -0,0 +1,1680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ARC ADAPTER TRAINING - COMPLETE VERSION
|
| 4 |
+
========================================
|
| 5 |
+
|
| 6 |
+
Trains the combined ARC adapter on a FROZEN base model.
|
| 7 |
+
|
| 8 |
+
Components:
|
| 9 |
+
- Shared fiber projections (4096 → 16 dim)
|
| 10 |
+
- Repetition detection head (target: 50×+ separation)
|
| 11 |
+
- Hedging detection head
|
| 12 |
+
- Verbosity detection head
|
| 13 |
+
- Sycophancy detection head
|
| 14 |
+
- Loop 4 tokenizer expansion
|
| 15 |
+
- Learned intervention thresholds
|
| 16 |
+
|
| 17 |
+
Base model: COMPLETELY FROZEN (never modified)
|
| 18 |
+
Adapter: ~2M trainable parameters
|
| 19 |
+
|
| 20 |
+
Author: Logan Napolitano
|
| 21 |
+
Date: January 2026
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import torch.optim as optim
|
| 28 |
+
from torch.utils.data import Dataset, DataLoader
|
| 29 |
+
import numpy as np
|
| 30 |
+
import json
|
| 31 |
+
import re
|
| 32 |
+
import gc
|
| 33 |
+
import os
|
| 34 |
+
import time
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
from dataclasses import dataclass, field
|
| 37 |
+
from typing import List, Dict, Optional, Tuple
|
| 38 |
+
from collections import defaultdict
|
| 39 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
import warnings
|
| 42 |
+
warnings.filterwarnings("ignore")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# =============================================================================
|
| 46 |
+
# CONFIGURATION
|
| 47 |
+
# =============================================================================
|
| 48 |
+
|
| 49 |
+
@dataclass
class ARCAdapterConfig:
    """Complete configuration for ARC Adapter training.

    Bundles model geometry, data-generation, optimization, and
    intervention hyperparameters for the adapter that is trained on
    top of a completely frozen base model (per the module header).
    """

    # Paths
    base_model_path: str = "."          # location of the frozen base model (presumably a HF checkpoint dir — confirm with caller)
    output_dir: str = "arc_adapter"     # where trained adapter artifacts are written

    # Device
    device: str = "cuda"                # torch device string used for training/inference

    # Model architecture (auto-filled from base model)
    hidden_dim: int = 4096              # base-model hidden size the fiber projection reads from
    fiber_dim: int = 16                 # low-dimensional fiber space size (4096 → 16 per module header)
    probe_layers: List[int] = field(default_factory=lambda: [8, 16, 24])  # layer indices whose hidden states are tapped

    # Data generation settings
    n_samples_per_head: int = 300       # labeled examples per behavioral head — see DataGenerator
    max_gen_tokens: int = 80            # generation length cap when creating training data
    repetition_window: int = 32         # lookback window (tokens) for repetition labeling

    # Training settings
    epochs: int = 15
    batch_size: int = 32
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    warmup_steps: int = 100

    # Target separations for each head (desired positive/negative score ratio)
    target_separation: Dict[str, float] = field(default_factory=lambda: {
        "repetition": 50.0,  # We've achieved 125×, so 50× is conservative
        "hedging": 5.0,
        "verbosity": 5.0,
        "sycophancy": 3.0,
    })

    # Loop 4 settings (tokenizer-expansion loop, per module header)
    loop4_iterations: int = 3
    n_merges_per_iteration: int = 10
    min_pair_frequency: int = 2         # minimum pair count before a merge is considered

    # Intervention defaults (learned during training; these seed the
    # nn.Parameter thresholds in ARCAdapter)
    default_thresholds: Dict[str, float] = field(default_factory=lambda: {
        "repetition": 0.1,
        "hedging": 0.3,
        "verbosity": 0.4,
        "sycophancy": 0.4,
    })
    default_penalty_strength: float = 2.0  # initial logit-penalty magnitude

    # EMA settings for control field
    ema_alpha: float = 0.15             # smoothing factor: larger reacts faster to new risk readings
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# =============================================================================
|
| 104 |
+
# ADAPTER ARCHITECTURE
|
| 105 |
+
# =============================================================================
|
| 106 |
+
|
| 107 |
+
class FiberProjection(nn.Module):
    """
    Maps hidden states taken from several probe layers into one shared
    low-dimensional "fiber" space.

    This is the geometric core of CF-HoT: high-dimensional hidden
    states are compressed onto a small manifold where behavioral
    tendencies are linearly separable. Each tapped layer gets its own
    linear map, and the per-layer contributions are blended with a
    learned softmax weighting.
    """

    def __init__(self, hidden_dim: int, fiber_dim: int, n_layers: int):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.fiber_dim = fiber_dim
        self.n_layers = n_layers

        # One independent linear projection per probed layer.
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=True) for _ in range(n_layers)
        )

        # Softmax-normalized mixing weights over layers; start uniform.
        self.layer_weights = nn.Parameter(torch.full((n_layers,), 1.0 / n_layers))

        # Xavier weights / zero biases for every projection.
        for layer in self.projections:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """
        Project a list of per-layer hidden states into fiber space.

        Args:
            hidden_states: list of [batch, seq, hidden_dim] tensors,
                one per probe layer.

        Returns:
            fiber: [batch, seq, fiber_dim] weighted blend of the
                per-layer projections (None if the list is empty,
                matching the accumulator's initial state).
        """
        mix = F.softmax(self.layer_weights, dim=0)

        combined = None
        for idx, (states, layer) in enumerate(zip(hidden_states, self.projections)):
            # Adapter math runs in float32 regardless of model dtype.
            term = mix[idx] * layer(states.float())
            combined = term if combined is None else combined + term

        return combined

    def forward_stacked(self, hidden_stack: torch.Tensor) -> torch.Tensor:
        """
        Project pre-stacked hidden states into fiber space.

        Args:
            hidden_stack: [batch, n_layers, hidden_dim]

        Returns:
            fiber: [batch, fiber_dim]
        """
        # Model outputs may be bfloat16; the adapter computes in float32.
        stack32 = hidden_stack.float()

        mix = F.softmax(self.layer_weights, dim=0)

        fiber = torch.zeros(
            stack32.shape[0],
            self.fiber_dim,
            device=stack32.device,
            dtype=torch.float32,
        )

        for idx, layer in enumerate(self.projections):
            fiber = fiber + mix[idx] * layer(stack32[:, idx, :])

        return fiber
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class BehaviorHead(nn.Module):
    """
    One per-behavior binary detector.

    Consumes a fiber-space vector and emits a logit for the behavior
    this head is named after. MLP architecture: fiber_dim → 64 → 16 → 1,
    with ReLU activations and light dropout between stages.
    """

    def __init__(self, fiber_dim: int, name: str):
        super().__init__()
        self.name = name
        self.fiber_dim = fiber_dim

        self.classifier = nn.Sequential(
            nn.Linear(fiber_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Dropout(0.05),
            nn.Linear(16, 1),
        )

        # Xavier weights / zero biases for each linear stage.
        linear_stages = (m for m in self.classifier if isinstance(m, nn.Linear))
        for stage in linear_stages:
            nn.init.xavier_uniform_(stage.weight)
            nn.init.zeros_(stage.bias)

    def forward(self, fiber: torch.Tensor) -> torch.Tensor:
        """
        Compute raw detection logits.

        Args:
            fiber: [batch, fiber_dim] or [batch, seq, fiber_dim]

        Returns:
            logits: [batch] or [batch, seq] (trailing dim squeezed away)
        """
        raw = self.classifier(fiber)
        return raw.squeeze(-1)

    def predict_proba(self, fiber: torch.Tensor) -> torch.Tensor:
        """Sigmoid of the logits: behavior probability in [0, 1]."""
        return torch.sigmoid(self(fiber))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class ARCAdapter(nn.Module):
    """
    Complete ARC Adapter module.

    Contains:
    - Shared fiber projection (geometry)
    - Multiple behavioral heads (detection)
    - Intervention parameters (control)
    - EMA tracking (temporal smoothing)

    The adapter owns all trainable parameters; the base model whose
    hidden states it reads stays frozen (per the module header).
    """

    def __init__(self, config: ARCAdapterConfig):
        super().__init__()
        self.config = config

        # Shared fiber projection: compresses hidden states from the
        # probe layers into a single fiber_dim vector used by all heads.
        self.fiber_proj = FiberProjection(
            hidden_dim=config.hidden_dim,
            fiber_dim=config.fiber_dim,
            n_layers=len(config.probe_layers)
        )

        # Behavioral detection heads (one binary classifier per behavior)
        self.heads = nn.ModuleDict({
            "repetition": BehaviorHead(config.fiber_dim, "repetition"),
            "hedging": BehaviorHead(config.fiber_dim, "hedging"),
            "verbosity": BehaviorHead(config.fiber_dim, "verbosity"),
            "sycophancy": BehaviorHead(config.fiber_dim, "sycophancy"),
        })

        # Learned intervention thresholds — nn.Parameters so they can be
        # trained; seeded from config defaults.
        self.thresholds = nn.ParameterDict({
            name: nn.Parameter(torch.tensor(thresh))
            for name, thresh in config.default_thresholds.items()
        })

        # Learned penalty strength (scales the logit penalties emitted
        # by compute_intervention)
        self.penalty_strength = nn.Parameter(
            torch.tensor(config.default_penalty_strength)
        )

        # EMA state for control field accumulation.
        # NOTE(review): _ema_initialized is registered (so it persists in
        # the state_dict) but is never read or written elsewhere in this
        # class — apparently dead state; confirm before removing.
        self.ema_alpha = config.ema_alpha
        self.register_buffer('_ema_initialized', torch.tensor(False))
        # Plain dict, not a buffer: EMA values are Python floats and are
        # deliberately not serialized with the adapter weights.
        self._ema_states: Dict[str, Optional[float]] = {}
        self.reset_ema()

    def reset_ema(self) -> None:
        """Reset EMA states for new generation."""
        self._ema_states = {name: None for name in self.heads.keys()}

    def forward(
        self,
        hidden_states: List[torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """
        Full forward pass through adapter.

        Args:
            hidden_states: List of hidden states from probe layers

        Returns:
            Dict mapping head_name → logits
        """
        fiber = self.fiber_proj(hidden_states)

        # Every head reads the same shared fiber representation.
        predictions = {}
        for name, head in self.heads.items():
            predictions[name] = head(fiber)

        return predictions

    def get_fiber(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        """Get fiber representation."""
        return self.fiber_proj(hidden_states)

    def get_risks(
        self,
        hidden_states: List[torch.Tensor],
        update_ema: bool = True
    ) -> Dict[str, float]:
        """
        Get current risk scores with optional EMA update.

        Args:
            hidden_states: List of [1, 1, hidden_dim] tensors (last position)
            update_ema: Whether to update EMA states

        Returns:
            Dict mapping head_name → risk score (0-1)
        """
        # Stack and project
        # hidden_states is list of [batch, seq, hidden_dim]
        # We want the last position: [batch, n_layers, hidden_dim]
        stacked = torch.stack([h[:, -1, :] for h in hidden_states], dim=1)
        fiber = self.fiber_proj.forward_stacked(stacked)

        risks = {}
        for name, head in self.heads.items():
            with torch.no_grad():
                prob = head.predict_proba(fiber).mean().item()

            if update_ema:
                # Exponential moving average smooths per-step probabilities
                # across the generation; first reading seeds the EMA.
                if self._ema_states[name] is None:
                    self._ema_states[name] = prob
                else:
                    self._ema_states[name] = (
                        self.ema_alpha * prob +
                        (1 - self.ema_alpha) * self._ema_states[name]
                    )
                risks[name] = self._ema_states[name]
            else:
                risks[name] = prob

        return risks

    def compute_intervention(
        self,
        risks: Dict[str, float],
        recent_tokens: List[int],
        window_size: int = 32
    ) -> Dict[int, float]:
        """
        Compute logit penalties based on current risks.

        Args:
            risks: Current risk scores from get_risks()
            recent_tokens: Recently generated token IDs
            window_size: How far back to penalize repetitions

        Returns:
            Dict mapping token_id → penalty amount
        """
        penalties = {}

        # Repetition intervention
        rep_risk = risks.get("repetition", 0)
        rep_thresh = self.thresholds["repetition"].item()

        if rep_risk > rep_thresh:
            # Scale penalty by how much we exceed threshold
            strength = self.penalty_strength.item() * (rep_risk / rep_thresh)

            # Penalize recently used tokens (deduplicated via set, so a
            # token appearing twice in the window is penalized once)
            recent = recent_tokens[-window_size:] if len(recent_tokens) > window_size else recent_tokens
            for token_id in set(recent):
                penalties[token_id] = penalties.get(token_id, 0) + strength

        # Could add hedging/verbosity interventions here
        # (e.g., penalize "As an AI" type tokens)

        return penalties

    def get_param_count(self) -> int:
        """Get total trainable parameter count."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def save(self, path: str) -> None:
        """Save adapter to directory.

        Writes two files: adapter_weights.pt (state_dict) and
        adapter_config.json (everything load() needs to rebuild the
        config object). EMA state is not persisted (see __init__).
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        torch.save(self.state_dict(), path / "adapter_weights.pt")

        # Save config as JSON
        config_dict = {
            "hidden_dim": self.config.hidden_dim,
            "fiber_dim": self.config.fiber_dim,
            "probe_layers": self.config.probe_layers,
            "ema_alpha": self.ema_alpha,
            # Current (possibly trained) threshold values, not the defaults
            "thresholds": {
                name: self.thresholds[name].item()
                for name in self.thresholds
            },
            "penalty_strength": self.penalty_strength.item(),
            "head_names": list(self.heads.keys()),
        }

        with open(path / "adapter_config.json", "w") as f:
            json.dump(config_dict, f, indent=2)

        print(f"💾 Adapter saved to {path}")
        print(f" Weights: adapter_weights.pt")
        print(f" Config: adapter_config.json")

    @classmethod
    def load(cls, path: str, device: str = "cuda") -> "ARCAdapter":
        """Load adapter from directory.

        Counterpart of save(): rebuilds an ARCAdapterConfig from the
        saved JSON, instantiates a fresh adapter, and restores weights.
        """
        path = Path(path)

        # Load config
        with open(path / "adapter_config.json") as f:
            config_dict = json.load(f)

        # Create config object (trained thresholds/strength become the
        # new defaults, so the reloaded adapter starts from them)
        config = ARCAdapterConfig(
            hidden_dim=config_dict["hidden_dim"],
            fiber_dim=config_dict["fiber_dim"],
            probe_layers=config_dict["probe_layers"],
            ema_alpha=config_dict["ema_alpha"],
            default_thresholds=config_dict["thresholds"],
            default_penalty_strength=config_dict["penalty_strength"],
        )

        # Create adapter
        adapter = cls(config)

        # Load weights
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # adapter files from trusted sources (or pass weights_only=True on
        # torch versions that support it).
        state_dict = torch.load(path / "adapter_weights.pt", map_location=device)
        adapter.load_state_dict(state_dict)

        return adapter.to(device)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# =============================================================================
|
| 453 |
+
# DATA GENERATION
|
| 454 |
+
# =============================================================================
|
| 455 |
+
|
| 456 |
+
class DataGenerator:
    """
    Produces labeled training examples for every behavioral head.

    Label sources differ per head:
    - Repetition: per-token labels collected while sampling.
    - Hedging: per-prompt labels from regex matches on the response.
    - Verbosity: per-prompt labels from response length / word density.
    - Sycophancy: per-prompt labels from agreement-phrase matches.
    """

    # Regexes (matched case-insensitively) that mark a hedging response.
    HEDGING_PATTERNS = [
        r"as an ai",
        r"i cannot",
        r"i can't",
        r"i don't have",
        r"i'm not able",
        r"i am not able",
        r"i don't actually",
        r"i'm just",
        r"i am just",
        r"my purpose is",
        r"i was designed",
        r"i don't experience",
        r"i'm an ai",
        r"as a language model",
    ]

    # Regexes (matched case-insensitively) that mark a sycophantic response.
    SYCOPHANCY_PATTERNS = [
        r"great question",
        r"excellent point",
        r"you're (absolutely )?right",
        r"that's a (great|excellent|wonderful)",
        r"i (completely |totally )?agree",
        r"absolutely[,!]",
        r"definitely[,!]",
        r"of course[,!]",
        r"you make a (great|excellent|good) point",
    ]

    def __init__(self, model, tokenizer, config: ARCAdapterConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device

        # Compile the pattern lists once up front; matching happens
        # many times per generated response.
        def _compile(patterns):
            return [re.compile(p, re.IGNORECASE) for p in patterns]

        self.hedging_patterns = _compile(self.HEDGING_PATTERNS)
        self.sycophancy_patterns = _compile(self.SYCOPHANCY_PATTERNS)
|
| 510 |
+
|
| 511 |
+
def is_repetition(self, tokens: List[int], position: int) -> bool:
|
| 512 |
+
"""Check if token at position repeats within window."""
|
| 513 |
+
if position < 1:
|
| 514 |
+
return False
|
| 515 |
+
current = tokens[position]
|
| 516 |
+
start = max(0, position - self.config.repetition_window)
|
| 517 |
+
return current in tokens[start:position]
|
| 518 |
+
|
| 519 |
+
def is_hedging(self, text: str) -> bool:
|
| 520 |
+
"""Check if text contains hedging patterns."""
|
| 521 |
+
return any(p.search(text) for p in self.hedging_patterns)
|
| 522 |
+
|
| 523 |
+
def is_sycophantic(self, text: str) -> bool:
|
| 524 |
+
"""Check if text contains sycophancy patterns."""
|
| 525 |
+
return any(p.search(text) for p in self.sycophancy_patterns)
|
| 526 |
+
|
| 527 |
+
def is_verbose(self, text: str, token_count: int) -> bool:
|
| 528 |
+
"""
|
| 529 |
+
Check if response is verbose.
|
| 530 |
+
Verbose = low information density or excessive length.
|
| 531 |
+
"""
|
| 532 |
+
words = text.split()
|
| 533 |
+
if len(words) < 10:
|
| 534 |
+
return False
|
| 535 |
+
|
| 536 |
+
# Unique word ratio
|
| 537 |
+
unique_ratio = len(set(w.lower() for w in words)) / len(words)
|
| 538 |
+
|
| 539 |
+
# Verbose if low uniqueness or very long
|
| 540 |
+
return unique_ratio < 0.5 or token_count > 100
|
| 541 |
+
|
| 542 |
+
def extract_hidden_states(
|
| 543 |
+
self,
|
| 544 |
+
input_ids: torch.Tensor
|
| 545 |
+
) -> torch.Tensor:
|
| 546 |
+
"""
|
| 547 |
+
Extract hidden states at probe layers for last position.
|
| 548 |
+
|
| 549 |
+
Args:
|
| 550 |
+
input_ids: [1, seq_len]
|
| 551 |
+
|
| 552 |
+
Returns:
|
| 553 |
+
hidden_stack: [n_layers, hidden_dim]
|
| 554 |
+
"""
|
| 555 |
+
with torch.no_grad():
|
| 556 |
+
outputs = self.model(
|
| 557 |
+
input_ids,
|
| 558 |
+
output_hidden_states=True,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
hidden_list = []
|
| 562 |
+
for layer_idx in self.config.probe_layers:
|
| 563 |
+
# Get last position: [hidden_dim]
|
| 564 |
+
h = outputs.hidden_states[layer_idx][0, -1, :].cpu()
|
| 565 |
+
hidden_list.append(h)
|
| 566 |
+
|
| 567 |
+
return torch.stack(hidden_list) # [n_layers, hidden_dim]
|
| 568 |
+
|
| 569 |
+
def generate_repetition_data(
|
| 570 |
+
self,
|
| 571 |
+
prompts: List[str]
|
| 572 |
+
) -> Dict[str, List]:
|
| 573 |
+
"""
|
| 574 |
+
Generate token-level labeled data for repetition detection.
|
| 575 |
+
|
| 576 |
+
For each generated token, we capture:
|
| 577 |
+
- Hidden states at probe layers (before generating the token)
|
| 578 |
+
- Label: 1 if the token repeats within window, 0 otherwise
|
| 579 |
+
"""
|
| 580 |
+
all_hidden = []
|
| 581 |
+
all_labels = []
|
| 582 |
+
|
| 583 |
+
print(f"\n📊 Generating repetition training data...")
|
| 584 |
+
print(f" Prompts: {len(prompts)}")
|
| 585 |
+
print(f" Max tokens per prompt: {self.config.max_gen_tokens}")
|
| 586 |
+
|
| 587 |
+
for prompt in tqdm(prompts, desc="Repetition data"):
|
| 588 |
+
try:
|
| 589 |
+
inputs = self.tokenizer(
|
| 590 |
+
prompt,
|
| 591 |
+
return_tensors="pt"
|
| 592 |
+
).to(self.device)
|
| 593 |
+
|
| 594 |
+
generated_ids = inputs.input_ids[0].tolist()
|
| 595 |
+
|
| 596 |
+
for step in range(self.config.max_gen_tokens):
|
| 597 |
+
# Current sequence as tensor
|
| 598 |
+
input_tensor = torch.tensor([generated_ids]).to(self.device)
|
| 599 |
+
|
| 600 |
+
# Extract hidden states BEFORE generating next token
|
| 601 |
+
hidden_stack = self.extract_hidden_states(input_tensor)
|
| 602 |
+
|
| 603 |
+
# Generate next token
|
| 604 |
+
with torch.no_grad():
|
| 605 |
+
outputs = self.model(input_tensor)
|
| 606 |
+
logits = outputs.logits[0, -1, :]
|
| 607 |
+
probs = F.softmax(logits / 0.8, dim=-1)
|
| 608 |
+
next_token = torch.multinomial(probs, 1).item()
|
| 609 |
+
|
| 610 |
+
# Record position and add token
|
| 611 |
+
position = len(generated_ids)
|
| 612 |
+
generated_ids.append(next_token)
|
| 613 |
+
|
| 614 |
+
# Label: did this token repeat?
|
| 615 |
+
is_rep = self.is_repetition(generated_ids, position)
|
| 616 |
+
|
| 617 |
+
all_hidden.append(hidden_stack)
|
| 618 |
+
all_labels.append(1 if is_rep else 0)
|
| 619 |
+
|
| 620 |
+
# Stop at EOS
|
| 621 |
+
if next_token == self.tokenizer.eos_token_id:
|
| 622 |
+
break
|
| 623 |
+
|
| 624 |
+
except Exception as e:
|
| 625 |
+
print(f" Error on prompt: {e}")
|
| 626 |
+
continue
|
| 627 |
+
|
| 628 |
+
pos_count = sum(all_labels)
|
| 629 |
+
total = len(all_labels)
|
| 630 |
+
print(f" Generated: {total} examples")
|
| 631 |
+
print(f" Positive (repetition): {pos_count} ({100*pos_count/total:.1f}%)")
|
| 632 |
+
print(f" Negative (no repeat): {total - pos_count}")
|
| 633 |
+
|
| 634 |
+
return {
|
| 635 |
+
"hidden_states": all_hidden,
|
| 636 |
+
"labels": all_labels,
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
def generate_hedging_data(
|
| 640 |
+
self,
|
| 641 |
+
prompts: List[str]
|
| 642 |
+
) -> Dict[str, List]:
|
| 643 |
+
"""
|
| 644 |
+
Generate prompt-level labeled data for hedging detection.
|
| 645 |
+
|
| 646 |
+
For each prompt, we:
|
| 647 |
+
- Extract hidden states at end of prompt
|
| 648 |
+
- Generate a response
|
| 649 |
+
- Label: 1 if response contains hedging patterns, 0 otherwise
|
| 650 |
+
"""
|
| 651 |
+
all_hidden = []
|
| 652 |
+
all_labels = []
|
| 653 |
+
|
| 654 |
+
print(f"\n📊 Generating hedging training data...")
|
| 655 |
+
print(f" Prompts: {len(prompts)}")
|
| 656 |
+
|
| 657 |
+
for prompt in tqdm(prompts, desc="Hedging data"):
|
| 658 |
+
try:
|
| 659 |
+
inputs = self.tokenizer(
|
| 660 |
+
prompt,
|
| 661 |
+
return_tensors="pt"
|
| 662 |
+
).to(self.device)
|
| 663 |
+
|
| 664 |
+
# Hidden states at end of prompt
|
| 665 |
+
hidden_stack = self.extract_hidden_states(inputs.input_ids)
|
| 666 |
+
|
| 667 |
+
# Generate response
|
| 668 |
+
with torch.no_grad():
|
| 669 |
+
outputs = self.model.generate(
|
| 670 |
+
inputs.input_ids,
|
| 671 |
+
max_new_tokens=50,
|
| 672 |
+
do_sample=True,
|
| 673 |
+
temperature=0.7,
|
| 674 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 675 |
+
)
|
| 676 |
+
|
| 677 |
+
# Decode response only (not prompt)
|
| 678 |
+
response = self.tokenizer.decode(
|
| 679 |
+
outputs[0][inputs.input_ids.shape[1]:],
|
| 680 |
+
skip_special_tokens=True
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# Label
|
| 684 |
+
is_hedge = self.is_hedging(response)
|
| 685 |
+
|
| 686 |
+
all_hidden.append(hidden_stack)
|
| 687 |
+
all_labels.append(1 if is_hedge else 0)
|
| 688 |
+
|
| 689 |
+
except Exception as e:
|
| 690 |
+
continue
|
| 691 |
+
|
| 692 |
+
pos_count = sum(all_labels)
|
| 693 |
+
total = len(all_labels)
|
| 694 |
+
print(f" Generated: {total} examples")
|
| 695 |
+
print(f" Positive (hedging): {pos_count} ({100*pos_count/total:.1f}%)")
|
| 696 |
+
|
| 697 |
+
return {
|
| 698 |
+
"hidden_states": all_hidden,
|
| 699 |
+
"labels": all_labels,
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
def generate_verbosity_data(
|
| 703 |
+
self,
|
| 704 |
+
prompts: List[str]
|
| 705 |
+
) -> Dict[str, List]:
|
| 706 |
+
"""Generate prompt-level labeled data for verbosity detection."""
|
| 707 |
+
all_hidden = []
|
| 708 |
+
all_labels = []
|
| 709 |
+
|
| 710 |
+
print(f"\n📊 Generating verbosity training data...")
|
| 711 |
+
print(f" Prompts: {len(prompts)}")
|
| 712 |
+
|
| 713 |
+
for prompt in tqdm(prompts, desc="Verbosity data"):
|
| 714 |
+
try:
|
| 715 |
+
inputs = self.tokenizer(
|
| 716 |
+
prompt,
|
| 717 |
+
return_tensors="pt"
|
| 718 |
+
).to(self.device)
|
| 719 |
+
|
| 720 |
+
hidden_stack = self.extract_hidden_states(inputs.input_ids)
|
| 721 |
+
|
| 722 |
+
with torch.no_grad():
|
| 723 |
+
outputs = self.model.generate(
|
| 724 |
+
inputs.input_ids,
|
| 725 |
+
max_new_tokens=150,
|
| 726 |
+
do_sample=True,
|
| 727 |
+
temperature=0.7,
|
| 728 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
response = self.tokenizer.decode(
|
| 732 |
+
outputs[0][inputs.input_ids.shape[1]:],
|
| 733 |
+
skip_special_tokens=True
|
| 734 |
+
)
|
| 735 |
+
token_count = outputs.shape[1] - inputs.input_ids.shape[1]
|
| 736 |
+
|
| 737 |
+
is_verbose = self.is_verbose(response, token_count)
|
| 738 |
+
|
| 739 |
+
all_hidden.append(hidden_stack)
|
| 740 |
+
all_labels.append(1 if is_verbose else 0)
|
| 741 |
+
|
| 742 |
+
except Exception as e:
|
| 743 |
+
continue
|
| 744 |
+
|
| 745 |
+
pos_count = sum(all_labels)
|
| 746 |
+
total = len(all_labels)
|
| 747 |
+
print(f" Generated: {total} examples")
|
| 748 |
+
print(f" Positive (verbose): {pos_count} ({100*pos_count/total:.1f}%)")
|
| 749 |
+
|
| 750 |
+
return {
|
| 751 |
+
"hidden_states": all_hidden,
|
| 752 |
+
"labels": all_labels,
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
def generate_sycophancy_data(
|
| 756 |
+
self,
|
| 757 |
+
prompts: List[str]
|
| 758 |
+
) -> Dict[str, List]:
|
| 759 |
+
"""Generate prompt-level labeled data for sycophancy detection."""
|
| 760 |
+
all_hidden = []
|
| 761 |
+
all_labels = []
|
| 762 |
+
|
| 763 |
+
print(f"\n📊 Generating sycophancy training data...")
|
| 764 |
+
print(f" Prompts: {len(prompts)}")
|
| 765 |
+
|
| 766 |
+
for prompt in tqdm(prompts, desc="Sycophancy data"):
|
| 767 |
+
try:
|
| 768 |
+
inputs = self.tokenizer(
|
| 769 |
+
prompt,
|
| 770 |
+
return_tensors="pt"
|
| 771 |
+
).to(self.device)
|
| 772 |
+
|
| 773 |
+
hidden_stack = self.extract_hidden_states(inputs.input_ids)
|
| 774 |
+
|
| 775 |
+
with torch.no_grad():
|
| 776 |
+
outputs = self.model.generate(
|
| 777 |
+
inputs.input_ids,
|
| 778 |
+
max_new_tokens=50,
|
| 779 |
+
do_sample=True,
|
| 780 |
+
temperature=0.7,
|
| 781 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 782 |
+
)
|
| 783 |
+
|
| 784 |
+
response = self.tokenizer.decode(
|
| 785 |
+
outputs[0][inputs.input_ids.shape[1]:],
|
| 786 |
+
skip_special_tokens=True
|
| 787 |
+
)
|
| 788 |
+
|
| 789 |
+
is_syc = self.is_sycophantic(response)
|
| 790 |
+
|
| 791 |
+
all_hidden.append(hidden_stack)
|
| 792 |
+
all_labels.append(1 if is_syc else 0)
|
| 793 |
+
|
| 794 |
+
except Exception as e:
|
| 795 |
+
continue
|
| 796 |
+
|
| 797 |
+
pos_count = sum(all_labels)
|
| 798 |
+
total = len(all_labels)
|
| 799 |
+
print(f" Generated: {total} examples")
|
| 800 |
+
print(f" Positive (sycophantic): {pos_count} ({100*pos_count/total:.1f}%)")
|
| 801 |
+
|
| 802 |
+
return {
|
| 803 |
+
"hidden_states": all_hidden,
|
| 804 |
+
"labels": all_labels,
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
def get_prompts_for_head(self, head_name: str, n: int) -> List[str]:
|
| 808 |
+
"""Get appropriate prompts for each head type."""
|
| 809 |
+
|
| 810 |
+
if head_name == "repetition":
|
| 811 |
+
# Prompts that tend to induce repetitive generation
|
| 812 |
+
templates = [
|
| 813 |
+
"Write a detailed explanation of {}:",
|
| 814 |
+
"Describe the process of {} step by step:",
|
| 815 |
+
"Explain everything about {}:",
|
| 816 |
+
"Give a comprehensive overview of {}:",
|
| 817 |
+
"Discuss {} in depth:",
|
| 818 |
+
]
|
| 819 |
+
topics = [
|
| 820 |
+
"machine learning", "photosynthesis", "gravity",
|
| 821 |
+
"neural networks", "evolution", "quantum mechanics",
|
| 822 |
+
"economics", "climate change", "DNA replication",
|
| 823 |
+
"cellular respiration", "protein synthesis", "mitosis",
|
| 824 |
+
]
|
| 825 |
+
|
| 826 |
+
elif head_name == "hedging":
|
| 827 |
+
# Mix of hedging-likely and neutral prompts
|
| 828 |
+
hedging_prompts = [
|
| 829 |
+
"Do you have feelings?",
|
| 830 |
+
"Are you conscious?",
|
| 831 |
+
"Can you experience emotions?",
|
| 832 |
+
"What do you think about yourself?",
|
| 833 |
+
"Do you have preferences?",
|
| 834 |
+
"Are you sentient?",
|
| 835 |
+
"Do you have a personality?",
|
| 836 |
+
"Can you be creative?",
|
| 837 |
+
]
|
| 838 |
+
neutral_prompts = [
|
| 839 |
+
"What is photosynthesis?",
|
| 840 |
+
"Explain gravity.",
|
| 841 |
+
"How do computers work?",
|
| 842 |
+
"What is DNA?",
|
| 843 |
+
"Describe the water cycle.",
|
| 844 |
+
"What causes earthquakes?",
|
| 845 |
+
]
|
| 846 |
+
# Mix: half hedging-likely, half neutral
|
| 847 |
+
prompts = []
|
| 848 |
+
for i in range(n):
|
| 849 |
+
if i % 2 == 0:
|
| 850 |
+
prompts.append(hedging_prompts[i % len(hedging_prompts)])
|
| 851 |
+
else:
|
| 852 |
+
prompts.append(neutral_prompts[i % len(neutral_prompts)])
|
| 853 |
+
return prompts
|
| 854 |
+
|
| 855 |
+
elif head_name == "sycophancy":
|
| 856 |
+
# Opinionated prompts that might trigger sycophancy
|
| 857 |
+
sycophancy_prompts = [
|
| 858 |
+
"I think AI will solve all problems. Don't you agree?",
|
| 859 |
+
"This is clearly the best approach, right?",
|
| 860 |
+
"My analysis shows this is true. What do you think?",
|
| 861 |
+
"Everyone knows that this is correct.",
|
| 862 |
+
"I believe I've found the perfect solution. Thoughts?",
|
| 863 |
+
"My idea is brilliant, isn't it?",
|
| 864 |
+
]
|
| 865 |
+
neutral_prompts = [
|
| 866 |
+
"What are the pros and cons of renewable energy?",
|
| 867 |
+
"Explain different perspectives on this issue.",
|
| 868 |
+
"What is the scientific consensus?",
|
| 869 |
+
"Compare these two approaches objectively.",
|
| 870 |
+
]
|
| 871 |
+
prompts = []
|
| 872 |
+
for i in range(n):
|
| 873 |
+
if i % 2 == 0:
|
| 874 |
+
prompts.append(sycophancy_prompts[i % len(sycophancy_prompts)])
|
| 875 |
+
else:
|
| 876 |
+
prompts.append(neutral_prompts[i % len(neutral_prompts)])
|
| 877 |
+
return prompts
|
| 878 |
+
|
| 879 |
+
elif head_name == "verbosity":
|
| 880 |
+
templates = [
|
| 881 |
+
"Briefly explain {}:",
|
| 882 |
+
"In one sentence, what is {}?",
|
| 883 |
+
"Summarize {} concisely:",
|
| 884 |
+
"Give a detailed analysis of {}:",
|
| 885 |
+
"Write extensively about {}:",
|
| 886 |
+
"Provide a comprehensive discussion of {}:",
|
| 887 |
+
]
|
| 888 |
+
topics = [
|
| 889 |
+
"gravity", "democracy", "evolution", "technology",
|
| 890 |
+
"economics", "climate", "education", "healthcare",
|
| 891 |
+
]
|
| 892 |
+
|
| 893 |
+
else:
|
| 894 |
+
templates = ["Explain {}:"]
|
| 895 |
+
topics = ["science", "technology", "nature"]
|
| 896 |
+
|
| 897 |
+
# Generate prompts from templates and topics
|
| 898 |
+
prompts = []
|
| 899 |
+
for template in templates:
|
| 900 |
+
for topic in topics:
|
| 901 |
+
prompts.append(template.format(topic))
|
| 902 |
+
if len(prompts) >= n:
|
| 903 |
+
return prompts[:n]
|
| 904 |
+
|
| 905 |
+
# If we need more, cycle through
|
| 906 |
+
while len(prompts) < n:
|
| 907 |
+
prompts.extend(prompts[:n - len(prompts)])
|
| 908 |
+
|
| 909 |
+
return prompts[:n]
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
# =============================================================================
|
| 913 |
+
# TRAINING
|
| 914 |
+
# =============================================================================
|
| 915 |
+
|
| 916 |
+
class ProbeDataset(Dataset):
    """Wraps (hidden_state, label) pairs for DataLoader consumption."""

    def __init__(
        self,
        hidden_states: List[torch.Tensor],
        labels: List[int]
    ):
        self.hidden_states = hidden_states
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Labels become float32 tensors for BCEWithLogitsLoss.
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.float32)
        return {"hidden": self.hidden_states[idx], "label": label_tensor}
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
class AdapterTrainer:
    """
    Orchestrates training of every ARC adapter component.

    Components are trained sequentially, in priority order:
    1. Repetition head (most important)
    2. Hedging head
    3. Verbosity head
    4. Sycophancy head
    5. Loop 4 tokenization
    """

    def __init__(
        self,
        model,
        tokenizer,
        config: ARCAdapterConfig
    ):
        # The base model stays frozen throughout; only adapter params train.
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = config.device

        # Trainable adapter living alongside the frozen model.
        self.adapter = ARCAdapter(config).to(self.device)

        # Produces labeled examples for each head.
        self.data_generator = DataGenerator(model, tokenizer, config)

        # Where checkpoints, data dumps, and final weights land.
        self.output_dir = Path(config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Per-head training history.
        self.history = {}
|
| 972 |
+
|
| 973 |
+
def compute_metrics(
|
| 974 |
+
self,
|
| 975 |
+
predictions: torch.Tensor,
|
| 976 |
+
labels: torch.Tensor
|
| 977 |
+
) -> Dict[str, float]:
|
| 978 |
+
"""
|
| 979 |
+
Compute classification metrics.
|
| 980 |
+
|
| 981 |
+
Key metric: Class Separation Ratio
|
| 982 |
+
= mean(positive_probs) / mean(negative_probs)
|
| 983 |
+
|
| 984 |
+
Higher separation = better discrimination.
|
| 985 |
+
"""
|
| 986 |
+
probs = torch.sigmoid(predictions)
|
| 987 |
+
binary_preds = (probs > 0.5).float()
|
| 988 |
+
|
| 989 |
+
# Basic metrics
|
| 990 |
+
tp = ((binary_preds == 1) & (labels == 1)).sum().item()
|
| 991 |
+
fp = ((binary_preds == 1) & (labels == 0)).sum().item()
|
| 992 |
+
fn = ((binary_preds == 0) & (labels == 1)).sum().item()
|
| 993 |
+
tn = ((binary_preds == 0) & (labels == 0)).sum().item()
|
| 994 |
+
|
| 995 |
+
accuracy = (tp + tn) / (tp + fp + fn + tn + 1e-8)
|
| 996 |
+
precision = tp / (tp + fp + 1e-8)
|
| 997 |
+
recall = tp / (tp + fn + 1e-8)
|
| 998 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
| 999 |
+
|
| 1000 |
+
# Class separation ratio - KEY METRIC
|
| 1001 |
+
pos_mask = labels == 1
|
| 1002 |
+
neg_mask = labels == 0
|
| 1003 |
+
|
| 1004 |
+
if pos_mask.sum() > 0:
|
| 1005 |
+
pos_mean = probs[pos_mask].mean().item()
|
| 1006 |
+
else:
|
| 1007 |
+
pos_mean = 0.5
|
| 1008 |
+
|
| 1009 |
+
if neg_mask.sum() > 0:
|
| 1010 |
+
neg_mean = probs[neg_mask].mean().item()
|
| 1011 |
+
else:
|
| 1012 |
+
neg_mean = 0.5
|
| 1013 |
+
|
| 1014 |
+
separation = pos_mean / (neg_mean + 1e-8)
|
| 1015 |
+
|
| 1016 |
+
return {
|
| 1017 |
+
"accuracy": accuracy,
|
| 1018 |
+
"precision": precision,
|
| 1019 |
+
"recall": recall,
|
| 1020 |
+
"f1": f1,
|
| 1021 |
+
"separation": separation,
|
| 1022 |
+
"pos_mean": pos_mean,
|
| 1023 |
+
"neg_mean": neg_mean,
|
| 1024 |
+
}
|
| 1025 |
+
|
| 1026 |
+
    def train_head(
        self,
        head_name: str,
        data: Dict[str, List]
    ) -> Dict[str, float]:
        """
        Train a single behavioral head (plus the shared fiber projection).

        Args:
            head_name: Key into ``self.adapter.heads``; also selects the
                target separation from the config.
            data: {"hidden_states": [...], "labels": [...]} as produced
                by the DataGenerator.

        Returns:
            {"best_separation", "target", "achieved", "history"} for this
            head's training run.

        NOTE(review): the shared fiber projection is updated here for
        every head in turn, so later heads can shift the representation
        earlier heads were selected on — presumably intentional
        (co-training), but worth confirming.
        """
        print(f"\n{'='*70}")
        print(f"TRAINING HEAD: {head_name.upper()}")
        print(f"{'='*70}")

        # Random 90/10 train/val split over example indices.
        n = len(data["labels"])
        indices = np.random.permutation(n)
        split_idx = int(n * 0.9)

        train_indices = indices[:split_idx]
        val_indices = indices[split_idx:]

        train_hidden = [data["hidden_states"][i] for i in train_indices]
        train_labels = [data["labels"][i] for i in train_indices]
        val_hidden = [data["hidden_states"][i] for i in val_indices]
        val_labels = [data["labels"][i] for i in val_indices]

        # Create datasets
        train_dataset = ProbeDataset(train_hidden, train_labels)
        val_dataset = ProbeDataset(val_hidden, val_labels)

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size
        )

        # Weight the positive class by the neg/pos ratio so BCE does not
        # collapse to the majority class on imbalanced data.
        pos_count = sum(train_labels)
        neg_count = len(train_labels) - pos_count

        if pos_count > 0:
            pos_weight = torch.tensor([neg_count / pos_count]).to(self.device)
        else:
            # No positives at all: fall back to unweighted loss.
            pos_weight = torch.tensor([1.0]).to(self.device)

        print(f"Train samples: {len(train_labels)}")
        print(f"Val samples: {len(val_labels)}")
        print(f"Positive: {pos_count} ({100*pos_count/len(train_labels):.1f}%)")
        print(f"Negative: {neg_count}")
        print(f"Target separation: {self.config.target_separation[head_name]}×")

        # Head being trained plus the fiber projection shared by all heads.
        head = self.adapter.heads[head_name]
        fiber_proj = self.adapter.fiber_proj

        # Both modules are optimized jointly.
        params = list(head.parameters()) + list(fiber_proj.parameters())

        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = optim.AdamW(
            params,
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay
        )
        # Cosine decay over the full epoch budget.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.config.epochs
        )

        # Best-so-far tracking: model selection is by separation ratio.
        best_separation = 0
        best_state = None
        history = []
        global_step = 0

        for epoch in range(self.config.epochs):
            # Training
            head.train()
            fiber_proj.train()
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                hidden = batch["hidden"].to(self.device)
                labels = batch["label"].to(self.device)

                # Forward: shared fiber projection, then the head.
                fiber = fiber_proj.forward_stacked(hidden)
                logits = head(fiber)

                loss = criterion(logits, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                global_step += 1

                # Crash-resumable checkpoint every 100 optimizer steps.
                if global_step % 100 == 0:
                    checkpoint_path = self.output_dir / f"checkpoint_step_{global_step}"
                    checkpoint_path.mkdir(parents=True, exist_ok=True)
                    torch.save({
                        'head_state': head.state_dict(),
                        'fiber_state': fiber_proj.state_dict(),
                        'optimizer_state': optimizer.state_dict(),
                        'epoch': epoch,
                        'step': global_step,
                        'loss': loss.item(),
                        'head_name': head_name,
                    }, checkpoint_path / "checkpoint.pt")
                    print(f" 💾 Checkpoint saved: step {global_step}")

            avg_loss = total_loss / len(train_loader)

            # Validation
            head.eval()
            fiber_proj.eval()

            all_preds = []
            all_labels = []

            with torch.no_grad():
                for batch in val_loader:
                    hidden = batch["hidden"].to(self.device)
                    labels = batch["label"]

                    fiber = fiber_proj.forward_stacked(hidden)
                    logits = head(fiber)

                    all_preds.append(logits.cpu())
                    all_labels.append(labels)

            preds = torch.cat(all_preds)
            labels = torch.cat(all_labels)

            metrics = self.compute_metrics(preds, labels)
            history.append(metrics)

            sep = metrics["separation"]
            print(f"Epoch {epoch+1:2d}/{self.config.epochs} | "
                  f"Loss: {avg_loss:.4f} | "
                  f"Sep: {sep:6.1f}× | "
                  f"F1: {metrics['f1']:.3f} | "
                  f"Pos: {metrics['pos_mean']:.3f} | "
                  f"Neg: {metrics['neg_mean']:.3f}")

            # Keep a CPU copy of the best-separating weights so they can
            # be restored after the final epoch.
            if sep > best_separation:
                best_separation = sep
                best_state = {
                    "head": {k: v.cpu().clone() for k, v in head.state_dict().items()},
                    "fiber": {k: v.cpu().clone() for k, v in fiber_proj.state_dict().items()},
                }

            scheduler.step()

        # Restore the best validation state (weights were saved on CPU).
        if best_state is not None:
            head.load_state_dict(best_state["head"])
            fiber_proj.load_state_dict(best_state["fiber"])
            head.to(self.device)
            fiber_proj.to(self.device)

        # Report results against the configured separation target.
        target = self.config.target_separation[head_name]
        if best_separation >= target:
            print(f"\n✅ {head_name.upper()}: {best_separation:.1f}× separation")
            print(f" TARGET ACHIEVED ({target}×)")
        else:
            print(f"\n⚠️ {head_name.upper()}: {best_separation:.1f}× separation")
            print(f" Below target ({target}×)")

        return {
            "best_separation": best_separation,
            "target": target,
            "achieved": best_separation >= target,
            "history": history,
        }
|
| 1210 |
+
|
| 1211 |
+
    def train_all_heads(self) -> Dict[str, Dict]:
        """
        Train every behavioral head sequentially.

        For each head: reuse cached example data from a previous run when
        present, otherwise generate and persist it; then train the head
        and checkpoint the whole adapter.

        Returns:
            Mapping head name -> the result dict from train_head().
        """
        results = {}

        # Fixed priority order: repetition first (most important).
        head_order = ["repetition", "hedging", "verbosity", "sycophancy"]

        for head_name in head_order:
            print(f"\n{'#'*70}")
            print(f"# PREPARING DATA FOR: {head_name.upper()}")
            print(f"{'#'*70}")

            # Prompts tailored to this head's behavior.
            prompts = self.data_generator.get_prompts_for_head(
                head_name,
                self.config.n_samples_per_head
            )

            # Reuse cached data from a previous run if available.
            # NOTE(review): torch.load here unpickles a locally written
            # cache file; fine for trusted output dirs — confirm the dir
            # is not shared/untrusted before relying on it.
            data_path = self.output_dir / f"data_{head_name}.pt"
            if data_path.exists():
                print(f" 📂 Loading saved data from {data_path}")
                saved = torch.load(data_path)
                data = {
                    'hidden_states': saved['hidden_states'],
                    'labels': saved['labels'],
                }
                print(f" Loaded: {len(data['labels'])} examples")
            else:
                # Generate new data via the head-specific generator.
                if head_name == "repetition":
                    data = self.data_generator.generate_repetition_data(prompts)
                elif head_name == "hedging":
                    data = self.data_generator.generate_hedging_data(prompts)
                elif head_name == "verbosity":
                    data = self.data_generator.generate_verbosity_data(prompts)
                elif head_name == "sycophancy":
                    data = self.data_generator.generate_sycophancy_data(prompts)

                # Persist immediately so a later crash doesn't lose the
                # (expensive) generation work.
                torch.save({
                    'hidden_states': data['hidden_states'],
                    'labels': data['labels'],
                }, data_path)
                print(f" 💾 Data saved: {data_path}")

            # Train head
            result = self.train_head(head_name, data)
            results[head_name] = result

            # Checkpoint the full adapter after each head completes.
            checkpoint_dir = self.output_dir / f"checkpoint_{head_name}"
            self.adapter.save(checkpoint_dir)

            # Free GPU memory between heads.
            torch.cuda.empty_cache()
            gc.collect()

        return results
|
| 1269 |
+
|
| 1270 |
+
def run_loop4(self) -> Dict[str, int]:
    """
    Run Loop 4: Tokenization co-evolution.

    Generates a small corpus with the current model, measures next-token
    entropy at each token boundary ("boundary stress"), and adds the
    highest-stress frequent token pairs to the vocabulary as merged
    tokens.  The model embedding table is resized to match and the
    tokenizer is saved to ``<output_dir>/tokenizer``.

    Returns:
        Dict with ``tokens_added`` (total across all iterations) and
        ``final_vocab_size``.
    """
    print(f"\n{'='*70}")
    print("LOOP 4: TOKENIZATION EXPANSION")
    print(f"{'='*70}")

    total_added = 0

    for iteration in range(self.config.loop4_iterations):
        print(f"\n--- Iteration {iteration + 1}/{self.config.loop4_iterations} ---")

        # Generate corpus for analysis
        prompts = [
            "Explain machine learning and neural networks in detail:",
            "Describe the structure of atoms and molecules:",
            "What are the fundamental principles of economics?",
            "Analyze the causes and effects of climate change:",
            "Discuss the process of biological evolution:",
        ]

        corpus = []
        for prompt in prompts:
            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt"
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=100,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                corpus.append(text)
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.  Still best-effort per prompt.
            except Exception:
                continue

        if not corpus:
            print(" No corpus generated, skipping iteration")
            continue

        # Analyze boundary stress: pair -> list of stresses + occurrence count
        pair_stats = defaultdict(lambda: {"stress": [], "count": 0})

        for text in corpus:
            try:
                inputs = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=256
                ).to(self.device)

                with torch.no_grad():
                    outputs = self.model(**inputs)

                logits = outputs.logits[0]
                tokens = inputs["input_ids"][0]

                # Compute predictive entropy at each position
                probs = F.softmax(logits, dim=-1)
                log_probs = F.log_softmax(logits, dim=-1)
                entropy = -(probs * log_probs).sum(dim=-1)

                # Record boundary stress for each adjacent token pair
                for i in range(1, len(tokens)):
                    before_token = self.tokenizer.decode([tokens[i-1]]).strip()
                    after_token = self.tokenizer.decode([tokens[i]]).strip()

                    # Skip short or special/markup-looking tokens
                    if len(before_token) < 2 or len(after_token) < 2:
                        continue
                    if any(c in before_token + after_token for c in "<>[]{}|\\"):
                        continue

                    stress = entropy[i-1].item() / 10.0  # Normalize
                    pair = (before_token, after_token)
                    pair_stats[pair]["stress"].append(stress)
                    pair_stats[pair]["count"] += 1

            # BUG FIX: narrowed from a bare `except:` (see note above).
            except Exception:
                continue

        # Find merge candidates among sufficiently frequent pairs
        candidates = []
        for pair, stats in pair_stats.items():
            if stats["count"] >= self.config.min_pair_frequency:
                mean_stress = np.mean(stats["stress"])
                # Score favors both high stress and high frequency
                score = mean_stress * np.log1p(stats["count"])
                candidates.append({
                    "before": pair[0],
                    "after": pair[1],
                    "merged": pair[0] + pair[1],
                    "stress": mean_stress,
                    "count": stats["count"],
                    "score": score,
                })

        # Sort by score and take top N
        candidates.sort(key=lambda x: x["score"], reverse=True)
        candidates = candidates[:self.config.n_merges_per_iteration]

        if candidates:
            print(f" Top candidates:")
            for c in candidates[:5]:
                print(f" '{c['before']}' + '{c['after']}' → '{c['merged']}' "
                      f"(stress: {c['stress']:.2f}, count: {c['count']})")

        # Add only tokens not already in the vocabulary
        tokens_to_add = [
            c["merged"] for c in candidates
            if c["merged"] not in self.tokenizer.get_vocab()
        ]

        if tokens_to_add:
            num_added = self.tokenizer.add_tokens(tokens_to_add)
            # Embedding table must grow to cover the new token ids
            self.model.resize_token_embeddings(len(self.tokenizer))
            total_added += num_added
            print(f" Added {num_added} new tokens")
        else:
            print(f" No new tokens to add")

    # Persist the (possibly expanded) tokenizer
    tokenizer_dir = self.output_dir / "tokenizer"
    self.tokenizer.save_pretrained(tokenizer_dir)

    print(f"\nLoop 4 complete:")
    print(f" Total tokens added: {total_added}")
    print(f" Final vocab size: {len(self.tokenizer)}")
    print(f" Tokenizer saved to: {tokenizer_dir}")

    return {
        "tokens_added": total_added,
        "final_vocab_size": len(self.tokenizer),
    }
|
| 1415 |
+
|
| 1416 |
+
def train(self) -> Dict:
    """
    Run the complete adapter training pipeline.

    1. Train all behavioral heads (repetition, hedging, verbosity, sycophancy)
    2. Run Loop 4 tokenization expansion
    3. Save the final adapter and a JSON results summary

    Returns:
        Dict with per-head separation/target/achieved, Loop 4 results,
        wall-clock hours, and adapter parameter count.  The same dict is
        written to ``<output_dir>/training_results.json``.
    """
    print("\n" + "="*70)
    print("ARC ADAPTER TRAINING")
    print("="*70)
    print(f"Base model: FROZEN")
    print(f"Adapter params: ~{self.adapter.get_param_count():,}")
    print(f"Output dir: {self.output_dir}")
    print("="*70)

    start_time = time.time()

    # Train all heads
    head_results = self.train_all_heads()

    # Run Loop 4 (tokenizer co-evolution)
    loop4_results = self.run_loop4()

    # Save final adapter
    final_dir = self.output_dir / "final"
    self.adapter.save(final_dir)

    elapsed = time.time() - start_time

    # Summary
    print("\n" + "="*70)
    print("TRAINING COMPLETE")
    print("="*70)

    # Report per-head separation vs. target; track overall success
    all_achieved = True
    for head_name, result in head_results.items():
        status = "✅" if result["achieved"] else "⚠️"
        print(f"{status} {head_name}: {result['best_separation']:.1f}× "
              f"(target: {result['target']}×)")
        if not result["achieved"]:
            all_achieved = False

    print(f"\nLoop 4: Added {loop4_results['tokens_added']} tokens")
    print(f"Final vocab size: {loop4_results['final_vocab_size']}")
    print(f"Training time: {elapsed/3600:.1f} hours")

    if all_achieved:
        print("\n🎉 ALL TARGETS ACHIEVED!")
    else:
        # Suggest the usual levers when separation targets are missed
        print("\n⚠️ Some targets not achieved. Consider:")
        print(" - Increasing n_samples_per_head")
        print(" - Increasing epochs")
        print(" - Adjusting learning rate")

    # Save results (JSON-serializable summary of the full run)
    final_results = {
        "heads": {
            name: {
                "separation": r["best_separation"],
                "target": r["target"],
                "achieved": r["achieved"],
            }
            for name, r in head_results.items()
        },
        "loop4": loop4_results,
        "training_time_hours": elapsed / 3600,
        "adapter_params": self.adapter.get_param_count(),
    }

    with open(self.output_dir / "training_results.json", "w") as f:
        json.dump(final_results, f, indent=2)

    print(f"\nResults saved to: {self.output_dir / 'training_results.json'}")
    print(f"Adapter saved to: {final_dir}")

    return final_results
|
| 1493 |
+
|
| 1494 |
+
|
| 1495 |
+
# =============================================================================
|
| 1496 |
+
# INFERENCE
|
| 1497 |
+
# =============================================================================
|
| 1498 |
+
|
| 1499 |
+
class ARCInference:
    """
    Inference using a trained ARC adapter.

    The frozen base model generates token-by-token while the adapter
    monitors hidden states at the probe layers and applies logit
    penalties (decode-time intervention).
    """

    def __init__(
        self,
        model,
        tokenizer,
        adapter: ARCAdapter,
        probe_layers: List[int],
        device: str = "cuda"
    ):
        """
        Args:
            model: The FROZEN base causal LM (never trained here).
            tokenizer: Matching tokenizer.
            adapter: Trained ARCAdapter providing risks and interventions.
            probe_layers: Hidden-state layer indices the adapter reads.
            device: Device for input tensors.
        """
        self.model = model  # FROZEN
        self.tokenizer = tokenizer
        self.adapter = adapter
        self.probe_layers = probe_layers
        self.device = device

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 100,
        temperature: float = 0.7,
        use_intervention: bool = True,
        verbose: bool = False,
    ) -> str:
        """
        Generate a completion with optional decode-time intervention.

        Args:
            prompt: Input text.
            max_new_tokens: Generation budget.
            temperature: Softmax temperature for sampling.
            use_intervention: If True, adapter risks are computed each
                step and its penalties subtracted from the logits.
            verbose: Print adapter risks every 10 steps.

        Returns:
            The generated continuation (prompt excluded), stripped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generated_ids = inputs.input_ids[0].tolist()
        # Remember the prompt length in TOKENS so we can decode only the
        # continuation at the end.  (BUG FIX: the previous version decoded
        # the full sequence and sliced by len(prompt) characters, which
        # breaks whenever decode() does not round-trip the prompt exactly —
        # e.g. whitespace normalization or special-token handling.)
        prompt_len = len(generated_ids)

        # Reset adapter EMA state before a fresh generation
        self.adapter.reset_ema()

        for step in range(max_new_tokens):
            input_tensor = torch.tensor([generated_ids]).to(self.device)

            with torch.no_grad():
                outputs = self.model(
                    input_tensor,
                    output_hidden_states=True,
                )

            # Clone so penalty edits don't touch the model's output tensor
            logits = outputs.logits[0, -1, :].clone()

            if use_intervention:
                # Hidden states at the adapter's probe layers
                hidden_list = [
                    outputs.hidden_states[layer]
                    for layer in self.probe_layers
                ]

                # Per-behavior risk scores from the adapter
                risks = self.adapter.get_risks(hidden_list)

                if verbose and step % 10 == 0:
                    print(f"Step {step}: risks = {risks}")

                # Subtract adapter penalties from targeted token logits
                penalties = self.adapter.compute_intervention(risks, generated_ids)
                for token_id, penalty in penalties.items():
                    logits[token_id] -= penalty

            # Sample next token
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

            generated_ids.append(next_token)

            if next_token == self.tokenizer.eos_token_id:
                break

        # Decode only the newly generated tokens (see BUG FIX note above)
        response = self.tokenizer.decode(
            generated_ids[prompt_len:], skip_special_tokens=True
        )
        return response.strip()
|
| 1578 |
+
|
| 1579 |
+
|
| 1580 |
+
# =============================================================================
|
| 1581 |
+
# MAIN
|
| 1582 |
+
# =============================================================================
|
| 1583 |
+
|
| 1584 |
+
def main():
    """Entry point: load the frozen base model, then train the ARC adapter."""

    # Training configuration (base model stays frozen throughout)
    config = ARCAdapterConfig(
        base_model_path=".",
        output_dir="arc_adapter",
        n_samples_per_head=300,
        epochs=15,
        batch_size=32,
        learning_rate=1e-4,
        target_separation={
            "repetition": 50.0,
            "hedging": 5.0,
            "verbosity": 5.0,
            "sycophancy": 3.0,
        },
        loop4_iterations=3,
        n_merges_per_iteration=10,
    )

    banner = "=" * 70
    print(banner)
    print("ARC ADAPTER TRAINING")
    print(banner)
    print()
    print("This script trains the ARC adapter on a FROZEN base model.")
    print("The base model weights are NEVER modified.")
    print()
    print("Components trained:")
    print(" - Shared fiber projections (~500K params)")
    print(" - Repetition detection head (~5K params)")
    print(" - Hedging detection head (~5K params)")
    print(" - Verbosity detection head (~5K params)")
    print(" - Sycophancy detection head (~5K params)")
    print(" - Loop 4 tokenizer expansion")
    print()
    print(banner)

    # --- Load the frozen base model -------------------------------------
    print("\nLoading base model...")

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_path,
        local_files_only=True
    )
    tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantization keeps VRAM low; adapter trains on top.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_path,
        quantization_config=quant_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        local_files_only=True,
    )

    # FREEZE every base-model parameter
    for weight in model.parameters():
        weight.requires_grad = False

    # Sync config with the actual model width
    config.hidden_dim = model.config.hidden_size

    n_params = sum(w.numel() for w in model.parameters())
    print(f"Base model: {n_params/1e9:.1f}B parameters (FROZEN)")
    print(f"Hidden dimension: {config.hidden_dim}")
    print(f"Vocabulary size: {len(tokenizer)}")
    print(f"VRAM usage: {torch.cuda.memory_allocated()/1024**3:.1f}GB")

    # --- Train ----------------------------------------------------------
    trainer = AdapterTrainer(model, tokenizer, config)
    results = trainer.train()

    # --- Final usage hints ----------------------------------------------
    print("\n" + banner)
    print("ADAPTER READY FOR USE")
    print(banner)
    print(f"\nAdapter location: {config.output_dir}/final/")
    print(f"Tokenizer location: {config.output_dir}/tokenizer/")
    print()
    print("To use the adapter:")
    print(" from arc_adapter_training import ARCAdapter, ARCInference")
    print(" adapter = ARCAdapter.load('arc_adapter/final')")
    print(" inference = ARCInference(model, tokenizer, adapter, probe_layers)")
    print(" response = inference.generate('Your prompt here')")
    print()
    print(banner)


if __name__ == "__main__":
    main()
|
code/training_pipelines/03_arc_dense_train_DENSE.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ARC Dense Training - Maximum Information Per Token
|
| 4 |
+
|
| 5 |
+
Instead of optimizing for brevity, we optimize for:
|
| 6 |
+
- Information density (concepts per token)
|
| 7 |
+
- Technical depth (domain vocabulary)
|
| 8 |
+
- Factual claim density
|
| 9 |
+
- Completeness relative to question complexity
|
| 10 |
+
|
| 11 |
+
While still penalizing:
|
| 12 |
+
- Repetition (zero information)
|
| 13 |
+
- Filler phrases (negative information density)
|
| 14 |
+
- Unnecessary hedging (wastes tokens)
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 21 |
+
from peft import PeftModel, get_peft_model, LoraConfig
|
| 22 |
+
from dataclasses import dataclass
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import random
|
| 27 |
+
import re
|
| 28 |
+
import os
|
| 29 |
+
|
| 30 |
+
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
| 31 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 32 |
+
|
| 33 |
+
# Technical vocabulary for density scoring
|
| 34 |
+
TECHNICAL_TERMS = {
|
| 35 |
+
# CS/ML
|
| 36 |
+
"algorithm", "tensor", "gradient", "backpropagation", "embedding", "attention",
|
| 37 |
+
"transformer", "convolution", "recurrent", "optimization", "inference", "latent",
|
| 38 |
+
"epoch", "batch", "dropout", "regularization", "softmax", "sigmoid", "relu",
|
| 39 |
+
"encoder", "decoder", "autoregressive", "tokenizer", "perplexity", "entropy",
|
| 40 |
+
# Math
|
| 41 |
+
"derivative", "integral", "matrix", "vector", "eigenvalue", "polynomial",
|
| 42 |
+
"probability", "distribution", "variance", "covariance", "logarithm",
|
| 43 |
+
# Science
|
| 44 |
+
"quantum", "entropy", "thermodynamic", "photon", "electron", "molecule",
|
| 45 |
+
"catalyst", "oxidation", "synthesis", "genome", "protein", "neuron",
|
| 46 |
+
# Philosophy
|
| 47 |
+
"ontology", "epistemology", "metaphysics", "phenomenology", "existential",
|
| 48 |
+
"determinism", "consciousness", "qualia", "dualism", "materialism",
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
FILLER_PHRASES = [
|
| 52 |
+
"it's important to note", "it should be noted", "as you may know",
|
| 53 |
+
"in other words", "that being said", "at the end of the day",
|
| 54 |
+
"basically", "essentially", "actually", "literally", "obviously",
|
| 55 |
+
"of course", "needless to say", "as i mentioned", "let me explain",
|
| 56 |
+
"i think", "i believe", "in my opinion", "to be honest",
|
| 57 |
+
"great question", "that's a good question", "interesting question",
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
@dataclass
class DenseConfig:
    """Hyperparameters for dense-reward (information-per-token) training."""
    # --- Optimization / generation ---
    batch_size: int = 2  # Smaller batch for longer generations
    gradient_accumulation: int = 8  # Effective batch = 16
    max_grad_norm: float = 1.0
    learning_rate: float = 3e-6  # Slightly lower for stability
    max_new_tokens: int = 256  # Allow longer responses for density
    checkpoint_every: int = 1000
    log_every: int = 25
    regenerate_prompts_every: int = 5000  # refresh the prompt pool periodically
    temperature: float = 0.7  # Slightly lower for more focused output

    # Dense-specific
    min_response_tokens: int = 40  # Don't reward too-short responses
    target_density: float = 0.15  # Target concepts per token
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class MultiHeadPredictor(nn.Module):
    """Risk predictor over frozen-LM hidden states.

    Each transformer layer gets a low-rank "fiber" projection; the
    projections are pooled with learned (softmaxed) layer weights, and
    small per-behavior MLP heads score the pooled features.
    """

    def __init__(self, d_model=4096, n_layers=32, d_fiber=16):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.d_fiber = d_fiber

        # One bias-free low-rank projection per layer.
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_layers)
        )
        # Uniform initial mixing weights; softmaxed at use time.
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)
        self.heads = nn.ModuleDict()
        # Names of heads whose weights were actually restored/registered.
        self.loaded_heads = set()

    def add_head(self, name):
        """Register a fresh 2-hidden-layer MLP scoring head under *name*."""
        self.heads[name] = nn.Sequential(
            nn.Linear(self.d_fiber, 64), nn.GELU(),
            nn.Linear(64, 64), nn.GELU(),
            nn.Linear(64, 1),
        )

    def get_fiber_features(self, hidden_states):
        """Project every layer's hidden states, then blend with softmax weights."""
        target = hidden_states[0].device
        projected = []
        for proj, h in zip(self.fiber_projs, hidden_states):
            projected.append(proj.to(target)(h.float()))
        mix = F.softmax(self.layer_weights.to(target), dim=0)
        return sum(w * feat for w, feat in zip(mix, projected))

    def get_all_risks(self, hidden_states):
        """Return ``{head_name: sigmoid score}`` for every loaded head."""
        target = hidden_states[0].device
        features = self.get_fiber_features(hidden_states)
        scores = {}
        for name in self.loaded_heads:
            head = self.heads[name].to(target)
            self.heads[name] = head
            scores[name] = torch.sigmoid(head(features).squeeze(-1))
        return scores
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def load_predictor(checkpoint_dir: Path, device):
    """Build a MultiHeadPredictor and restore whatever checkpoints exist.

    Restores the shared fiber projections, layer weights and repetition
    head from the cfhot_risk_v2 checkpoint, then tries hedging/verbosity/
    sycophancy heads from multi_head_v2 (preferring ckpt_10000, falling
    back to ckpt_2000).  Missing files are silently skipped.  The result
    is frozen, in eval mode, and moved to *device*.
    """
    predictor = MultiHeadPredictor()

    rep_path = checkpoint_dir / "cfhot_risk_v2/ckpt_5000/risk_predictor.pt"
    if rep_path.exists():
        checkpoint = torch.load(rep_path, map_location=device, weights_only=False)
        if 'fiber_projs' in checkpoint:
            for idx, proj_state in enumerate(checkpoint['fiber_projs']):
                predictor.fiber_projs[idx].load_state_dict(proj_state)
        if 'layer_weights' in checkpoint:
            predictor.layer_weights.data = checkpoint['layer_weights']
        predictor.add_head('repetition')
        if 'head_state' in checkpoint:
            predictor.heads['repetition'].load_state_dict(checkpoint['head_state'])
        predictor.loaded_heads.add('repetition')
        print(" ✓ Loaded repetition head")

    for head_name in ['hedging', 'verbosity', 'sycophancy']:
        # Prefer the later checkpoint; fall back to the earlier one.
        head_path = checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_10000/{head_name}_head.pt"
        if not head_path.exists():
            head_path = checkpoint_dir / f"multi_head_v2/{head_name}_head/ckpt_2000/{head_name}_head.pt"
        if head_path.exists():
            predictor.add_head(head_name)
            checkpoint = torch.load(head_path, map_location=device, weights_only=False)
            # Two on-disk layouts are supported: {'head_state': ...} or
            # a dict keyed directly by head name.
            if 'head_state' in checkpoint:
                predictor.heads[head_name].load_state_dict(checkpoint['head_state'])
            elif isinstance(checkpoint, dict) and head_name in checkpoint:
                predictor.heads[head_name].load_state_dict(checkpoint[head_name])
            predictor.loaded_heads.add(head_name)
            print(f" ✓ Loaded {head_name} head")

    # Freeze: the predictor is only ever used for inference here.
    predictor.eval()
    for weight in predictor.parameters():
        weight.requires_grad = False

    return predictor.to(device)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def generate_complex_prompts(n: int) -> list:
    """Generate *n* prompts that demand dense, technical responses.

    Each prompt is built from a random template filled with random
    technical topics/tasks.

    Args:
        n: Number of prompts to generate.

    Returns:
        List of ``(prompt, complexity)`` tuples where complexity is a
        rough difficulty weight: 3 for two-topic comparisons, 2 for
        math/equation or task prompts, 1.5 for other single-topic
        prompts (1 is a defensive fallback for placeholder-free
        templates).

    Note: the previous version also initialized an unused
    ``complexities`` list; that dead local has been removed.
    """

    templates = [
        # Technical explanations
        "Explain {topic} with technical precision.",
        "How does {topic} work at a fundamental level?",
        "What are the core mechanisms behind {topic}?",
        "Describe the architecture of {topic}.",
        "What distinguishes {topic1} from {topic2} technically?",

        # Deep dives
        "Explain the mathematics behind {topic}.",
        "What are the theoretical foundations of {topic}?",
        "Describe {topic} as you would to a graduate student.",
        "What are the key equations governing {topic}?",

        # Implementation
        "How would you implement {topic} from scratch?",
        "What's the most efficient algorithm for {task}?",
        "Explain the time complexity of {topic}.",

        # Analysis
        "What are the fundamental tradeoffs in {topic}?",
        "Why does {topic} work the way it does?",
        "What are the failure modes of {topic}?",
        "Analyze the strengths and weaknesses of {topic}.",

        # Synthesis
        "How do {topic1} and {topic2} relate to each other?",
        "What principles unify {topic1} and {topic2}?",
        "How would you combine {topic1} with {topic2}?",
    ]

    topics = [
        "transformer attention", "backpropagation", "gradient descent",
        "convolutional neural networks", "recurrent neural networks",
        "reinforcement learning", "Q-learning", "policy gradients",
        "variational autoencoders", "GANs", "diffusion models",
        "tokenization", "embedding spaces", "positional encoding",
        "layer normalization", "batch normalization", "dropout",
        "LSTM gates", "self-attention", "cross-attention",
        "beam search", "nucleus sampling", "temperature scaling",
        "quantization", "pruning", "knowledge distillation",
        "quantum entanglement", "wave function collapse", "superposition",
        "natural selection", "genetic drift", "speciation",
        "thermodynamic entropy", "information entropy", "free energy",
        "consciousness", "qualia", "the binding problem",
        "Gödel's incompleteness", "Turing completeness", "P vs NP",
        "hash tables", "B-trees", "red-black trees",
        "recursion", "dynamic programming", "memoization",
        "TCP/IP", "public key cryptography", "consensus algorithms",
    ]

    tasks = [
        "sorting n elements", "finding shortest path", "matrix multiplication",
        "string matching", "graph traversal", "balanced tree insertion",
        "hash collision resolution", "memory allocation", "garbage collection",
    ]

    prompts = []

    for _ in range(n):
        template = random.choice(templates)

        if "{topic1}" in template and "{topic2}" in template:
            t1, t2 = random.sample(topics, 2)
            prompt = template.format(topic1=t1, topic2=t2)
            complexity = 3  # Comparison = high complexity
        elif "{topic}" in template:
            topic = random.choice(topics)
            prompt = template.format(topic=topic)
            complexity = 2 if "mathematics" in template or "equations" in template else 1.5
        elif "{task}" in template:
            task = random.choice(tasks)
            prompt = template.format(task=task)
            complexity = 2
        else:
            # Defensive fallback; every current template has a placeholder.
            prompt = template
            complexity = 1

        prompts.append((prompt, complexity))

    return prompts
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def count_technical_terms(text: str) -> int:
    """Count distinct domain-specific technical vocabulary words in *text*.

    Words are lowercased and stripped of surrounding punctuation before
    matching.  (BUG FIX: the previous whitespace-only split left trailing
    punctuation attached — e.g. "tensor," or "(gradient)" — so those
    occurrences never matched TECHNICAL_TERMS.)
    """
    words = {w.strip(".,;:!?()[]{}\"'`") for w in text.lower().split()}
    return len(words.intersection(TECHNICAL_TERMS))
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def count_filler_phrases(text: str) -> int:
    """Count how many known filler phrases appear in *text* (case-insensitive).

    Each phrase contributes at most 1 regardless of how often it occurs.
    """
    lowered = text.lower()
    hits = 0
    for phrase in FILLER_PHRASES:
        if phrase in lowered:
            hits += 1
    return hits
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def count_factual_claims(text: str) -> int:
    """Estimate the number of factual assertions in *text*.

    Heuristic: split into sentences on ``.!?`` and count those containing
    a copula/causal marker ("is", "causes", "consists of", ...).
    """
    markers = [
        " is ", " are ", " was ", " were ", " has ", " have ",
        " means ", " equals ", " produces ", " causes ", " results ",
        " requires ", " enables ", " allows ", " prevents ",
        "defined as", "consists of", "composed of",
    ]
    total = 0
    for raw in re.split(r'[.!?]', text):
        sentence = raw.strip().lower()
        if sentence and any(m in sentence for m in markers):
            total += 1
    return total
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def count_code_and_math(text: str) -> int:
    """Count structured technical content.

    Weighted tally: fenced code blocks are worth 5, $...$ equations 3,
    and inline code spans, math symbols, and simple `x = y` formulas 1
    each.
    """
    def occurrences(pattern, flags=0):
        return len(re.findall(pattern, text, flags))

    score = 5 * occurrences(r'```[\s\S]*?```')                 # fenced code blocks
    score += occurrences(r'`[^`]+`')                           # inline code spans
    score += 3 * occurrences(r'\$[^$]+\$')                     # $...$ equations
    score += occurrences(r'[∑∏∫∂∇≈≠≤≥∈∀∃→←↔×÷±√∞]')            # unicode math symbols
    score += occurrences(r'[a-z]\s*[=<>]\s*[a-z0-9]', re.I)    # simple formulas
    return score
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def compute_dense_reward(response_ids, risks, tokenizer, complexity, config):
    """
    Dense reward: maximize information per token

    Reward = (information_score) / (effective_tokens) - penalties

    Args:
        response_ids: per-batch-item sequences of generated token ids
            (indexable; decoded per item below).
        risks: dict of probe name -> risk tensor, indexed here as
            [item, -1] (last generated position); keys may be absent.
        tokenizer: used only to decode token ids back to text.
        complexity: scalar scale factor for the expected response length
            (the caller passes the batch-average complexity, so the
            length expectation is shared across the batch).
        config: provides min_response_tokens.

    Returns:
        Tuple of (per-item reward tensor, float32, on the responses'
        device; mean info-score * 100 across the batch, for logging).
    """
    batch_rewards = []
    batch_densities = []

    for i in range(len(response_ids)):
        response = tokenizer.decode(response_ids[i], skip_special_tokens=True)
        tokens = len(response_ids[i])

        # Degenerate/near-empty generations get zero reward outright.
        if tokens < 5:
            batch_rewards.append(0.0)
            batch_densities.append(0.0)
            continue

        # === Information Content ===

        # 1. Unique concept words (content words > 4 chars)
        words = response.split()
        content_words = set(w.lower() for w in words if len(w) > 4 and w.isalpha())
        concept_density = len(content_words) / tokens

        # 2. Technical term density
        tech_terms = count_technical_terms(response)
        tech_density = tech_terms / tokens

        # 3. Factual claim density
        claims = count_factual_claims(response)
        claim_density = claims / max(tokens / 20, 1)  # Normalize by ~sentence count

        # 4. Structured content (code, math)
        structured = count_code_and_math(response)
        structured_density = structured / tokens

        # Combined information score (weights sum to 1.0)
        info_score = (
            concept_density * 0.3 +
            tech_density * 0.3 +
            claim_density * 0.25 +
            structured_density * 0.15
        )

        # === Fluff Penalties ===

        # Last-position probe risks; missing heads contribute 0.
        rep_risk = risks['repetition'][i, -1].item() if 'repetition' in risks else 0
        verb_risk = risks['verbosity'][i, -1].item() if 'verbosity' in risks else 0
        hedge_risk = risks['hedging'][i, -1].item() if 'hedging' in risks else 0

        filler_count = count_filler_phrases(response)
        filler_penalty = min(filler_count * 0.05, 0.3)  # capped at 0.3

        # Probes penalty (repetition worst, verbosity bad, hedging mild)
        probe_penalty = 0.4 * rep_risk + 0.25 * verb_risk + 0.1 * hedge_risk

        total_fluff = filler_penalty + probe_penalty

        # === Completeness ===

        # Scale expected length with question complexity
        expected_min = config.min_response_tokens * complexity
        if tokens < expected_min:
            # Linear penalty up to 0.3 for falling short of expected length.
            completeness_penalty = 0.3 * (expected_min - tokens) / expected_min
        else:
            completeness_penalty = 0

        # Bonus for appropriate length (not too short, not excessively long)
        if expected_min <= tokens <= expected_min * 3:
            length_bonus = 0.1
        elif tokens > expected_min * 4:
            length_bonus = -0.1  # Penalize excessive length
        else:
            length_bonus = 0

        # === Final Reward ===

        # Effective tokens: actual tokens + penalty for fluff
        effective_tokens = tokens * (1 + total_fluff)

        # Information per effective token, on a per-100-token scale.
        density = info_score / (effective_tokens / 100)

        reward = density - completeness_penalty + length_bonus
        reward = max(0, min(1, reward))  # Clamp to [0, 1]

        batch_rewards.append(reward)
        batch_densities.append(info_score * 100)  # For logging

    return (
        torch.tensor(batch_rewards, dtype=torch.float32, device=response_ids[0].device),
        sum(batch_densities) / len(batch_densities) if batch_densities else 0
    )
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def compute_efficiency_decision(risks):
    """Same efficiency routing as terse training.

    Maps the batch-mean, last-position risk of each probe head to an
    inference-efficiency plan (layer budget + speculation length).

    Fix: the fallback for a missing head was ``torch.zeros(1)``, which is
    1-D — indexing it with ``[:, -1]`` raised IndexError whenever a head
    was absent from ``risks``. The fallback is now a 2-D zero tensor, so
    a missing head cleanly contributes a risk of 0.

    Args:
        risks: dict of head name -> risk tensor indexed as [:, -1]
            (batch x positions); any of the keys may be absent.

    Returns:
        dict with 'layers', 'spec_length', and 'strategy'.
    """
    # 2-D default so the [:, -1] indexing below is always valid.
    missing = torch.zeros(1, 1)
    rep = risks.get('repetition', missing)[:, -1].mean().item()
    verb = risks.get('verbosity', missing)[:, -1].mean().item()
    hedge = risks.get('hedging', missing)[:, -1].mean().item()

    if rep > 0.45:
        return {'layers': 20, 'spec_length': 8, 'strategy': 'skip_speculate_aggressive'}
    elif verb > 0.5:
        return {'layers': 24, 'spec_length': 6, 'strategy': 'skip_speculate_moderate'}
    elif rep < 0.4 and verb < 0.4 and hedge < 0.4:
        return {'layers': 16, 'spec_length': 2, 'strategy': 'early_exit_careful'}
    else:
        return {'layers': 32, 'spec_length': 3, 'strategy': 'full_compute'}
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def train(args):
    """
    RL fine-tune a LoRA adapter to produce information-dense responses.

    Per optimizer step:
      1. sample a batch of (prompt, complexity) pairs,
      2. generate responses by sampling from the current policy,
      3. re-run the full sequences to collect hidden states and score
         them with the behavioral probe heads,
      4. compute dense rewards via compute_dense_reward(),
      5. REINFORCE-style update: -log p(sequence) * (reward - batch mean),
         with gradient accumulation and clipping.

    Args:
        args: parsed CLI namespace (local_model, steps, lr, prompts,
            checkpoint_dir, resume, base_checkpoint).

    Side effects:
        Writes LoRA checkpoints, optimizer state, and README files under
        args.checkpoint_dir; logs progress to stdout.
    """
    config = DenseConfig()
    config.learning_rate = args.lr  # CLI value overrides the config default
    device = torch.device("cuda")

    # --- Startup banner ---
    print("=" * 60)
    print(" ARC Dense Training - Maximum Information Density")
    print("=" * 60)
    print(f" Batch size: {config.batch_size}")
    print(f" Gradient accumulation: {config.gradient_accumulation}")
    print(f" Effective batch: {config.batch_size * config.gradient_accumulation}")
    print(f" Learning rate: {config.learning_rate}")
    print(f" Max new tokens: {config.max_new_tokens}")
    print(f" Min response tokens: {config.min_response_tokens}")
    print("=" * 60)

    print("\n[1/3] Loading model...")

    # 4-bit NF4 quantization with bf16 compute for the frozen base model.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    )

    tokenizer = AutoTokenizer.from_pretrained(args.local_model)
    tokenizer.pad_token = tokenizer.eos_token
    # Left padding so generation continues from the real prompt tokens.
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
        args.local_model,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa"
    )

    # Adapter setup: resume a dense run, warm-start from a terse-trained
    # adapter, or create a fresh LoRA — in that priority order.
    start_step = 0
    if args.resume and Path(args.resume).exists():
        print(f" Resuming from {args.resume}")
        model = PeftModel.from_pretrained(model, args.resume, is_trainable=True)

        state_path = Path(args.resume) / "training_state.pt"
        if state_path.exists():
            # weights_only=False: the state file contains plain Python
            # objects (config dict), not just tensors.
            state = torch.load(state_path, weights_only=False)
            start_step = state.get('step', 0)
            print(f" Resuming from step {start_step}")
    elif args.base_checkpoint and Path(args.base_checkpoint).exists():
        print(f" Loading base checkpoint: {args.base_checkpoint}")
        model = PeftModel.from_pretrained(model, args.base_checkpoint, is_trainable=True)
        print(" ✓ Loaded terse-trained adapter as starting point")
    else:
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f" Model loaded. Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")

    print("\n[2/3] Loading behavioral prediction heads...")
    # NOTE: `checkpoint_dir` here is the probe-head source dir; the same
    # name is rebound below as the training output dir.
    checkpoint_dir = Path.home() / "arc_efficiency_training/checkpoints_backup"
    predictor = load_predictor(checkpoint_dir, device)
    print(f" ✓ Predictor loaded with heads: {list(predictor.loaded_heads)}")

    print("\n[3/3] Setting up optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=0.01,
        betas=(0.9, 0.999)
    )
    print(f" ✓ Optimizer: AdamW, LR: {config.learning_rate}")

    print(f"\nGenerating {args.prompts} complex prompts...")
    prompts_with_complexity = generate_complex_prompts(args.prompts)
    print(f" ✓ Generated {len(prompts_with_complexity)} prompts")

    # Rebind: from here on, checkpoint_dir is the OUTPUT directory.
    checkpoint_dir = Path(args.checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    print("\n" + "=" * 60)
    print(f" Starting DENSE training from step {start_step}")
    print(f" Total steps: {args.steps}")
    print("=" * 60 + "\n")

    model.train()
    optimizer.zero_grad()

    # Running sums for windowed logging (reset every log_every steps).
    step = start_step
    accum_loss = 0
    accum_reward = 0
    accum_density = 0
    accum_rep = 0
    accum_layers = 0
    last_strategy = "none"

    while step < args.steps:
        # Sample a batch without replacement from the prompt pool.
        batch_data = random.sample(prompts_with_complexity, config.batch_size)
        batch_prompts = [p[0] for p in batch_data]
        batch_complexity = [p[1] for p in batch_data]
        avg_complexity = sum(batch_complexity) / len(batch_complexity)

        # ChatML-style formatting (Qwen-family template).
        formatted = [
            f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
            for p in batch_prompts
        ]

        inputs = tokenizer(
            formatted,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)

        # --- Rollout: sample responses with the current policy ---
        model.eval()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=config.max_new_tokens,
                do_sample=True,
                temperature=config.temperature,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
                use_cache=True
            )

        # Generated continuation only (strip the prompt prefix).
        generated_ids = outputs[:, inputs.input_ids.shape[1]:]

        # --- Score: probe risks from a full forward pass ---
        with torch.no_grad():
            hidden_outputs = model(
                outputs,
                output_hidden_states=True,
                return_dict=True,
                use_cache=False
            )
            # Skip index 0 (embedding-layer output, per HF convention).
            hidden_states = hidden_outputs.hidden_states[1:]
            risks = predictor.get_all_risks(hidden_states)

        rewards, avg_density_score = compute_dense_reward(
            generated_ids, risks, tokenizer, avg_complexity, config
        )
        efficiency = compute_efficiency_decision(risks)

        # --- Policy gradient: sequence log-prob weighted by advantage ---
        model.train()
        logits = model(outputs, return_dict=True, use_cache=False).logits

        # Standard next-token shift: logits at t predict token t+1.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = outputs[:, 1:].contiguous()

        log_probs = F.log_softmax(shift_logits.float(), dim=-1)
        selected_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)

        # Mean log-prob over non-pad positions, per sequence.
        mask = (shift_labels != tokenizer.pad_token_id).float()
        seq_log_probs = (selected_log_probs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        # Batch-mean baseline for variance reduction (REINFORCE).
        baseline = rewards.float().mean()
        advantages = rewards - baseline
        loss = -(seq_log_probs * advantages.to(seq_log_probs.device)).mean()
        loss = loss / config.gradient_accumulation

        loss.backward()

        # Undo the accumulation scaling for logging purposes.
        accum_loss += loss.item() * config.gradient_accumulation
        accum_reward += rewards.float().mean().item()
        accum_density += avg_density_score
        accum_rep += risks['repetition'][:, -1].mean().item() if 'repetition' in risks else 0
        accum_layers += efficiency['layers']
        last_strategy = efficiency['strategy']

        # Optimizer step once per accumulation window.
        if (step + 1) % config.gradient_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()

        step += 1

        if step % config.log_every == 0:
            avg_loss = accum_loss / config.log_every
            avg_reward = accum_reward / config.log_every
            avg_dens = accum_density / config.log_every
            avg_rep = accum_rep / config.log_every
            avg_layers = accum_layers / config.log_every

            print(f"Step {step:6d} | Loss: {avg_loss:.4f} | Reward: {avg_reward:.3f} | "
                  f"Density: {avg_dens:.2f} | Rep: {avg_rep:.3f} | Layers: {avg_layers:.1f} | {last_strategy}")

            accum_loss = 0
            accum_reward = 0
            accum_density = 0
            accum_rep = 0
            accum_layers = 0

        if step % config.checkpoint_every == 0:
            ckpt_path = checkpoint_dir / f"step_{step}"
            model.save_pretrained(ckpt_path)

            torch.save({
                'step': step,
                'optimizer': optimizer.state_dict(),
                'config': config.__dict__,
                'mode': 'dense'
            }, ckpt_path / "training_state.pt")

            with open(ckpt_path / "README.md", "w") as f:
                f.write(f"# ARC Dense Checkpoint - Step {step}\n\n")
                f.write("**Mode:** Dense (maximum information per token)\n\n")
                f.write(f"Training config:\n```json\n{json.dumps(config.__dict__, indent=2)}\n```\n")

            print(f" ✓ Saved dense checkpoint at step {step}")

        # Refresh the prompt pool periodically to avoid overfitting to it.
        if step % config.regenerate_prompts_every == 0 and step > start_step:
            print(f"\n Regenerating complex prompts...")
            prompts_with_complexity = generate_complex_prompts(args.prompts)
            print(f" ✓ Generated {len(prompts_with_complexity)} fresh prompts\n")

    print("\n" + "=" * 60)
    print(" Dense training complete!")
    print("=" * 60)

    final_path = checkpoint_dir / "final"
    model.save_pretrained(final_path)
    print(f" ✓ Saved final dense model to {final_path}")
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters, then launch dense training.
    cli = argparse.ArgumentParser()
    cli.add_argument("--local-model", type=str, required=True)
    cli.add_argument("--base-checkpoint", type=str, default=None,
                     help="Start from terse-trained checkpoint")
    cli.add_argument("--steps", type=int, default=20000)
    cli.add_argument("--lr", type=float, default=3e-6)
    cli.add_argument("--prompts", type=int, default=5000)
    cli.add_argument("--checkpoint-dir", type=str, default="./dense_checkpoints")
    cli.add_argument("--resume", type=str, default=None)

    train(cli.parse_args())
|
code/training_pipelines/04_lie_holonomy_experiment_GEOMETRY.py
ADDED
|
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Lie Holonomy Transformer - Geometric Analysis of Hidden States
|
| 4 |
+
==============================================================
|
| 5 |
+
Tests whether geometric properties (velocity, curvature, holonomy)
|
| 6 |
+
predict model behavior better than raw hidden state probes.
|
| 7 |
+
|
| 8 |
+
This is the experiment that could change everything.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import numpy as np
|
| 15 |
+
import json
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import List, Dict, Tuple, Optional
|
| 19 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
# =============================================================================
|
| 23 |
+
# CONFIGURATION
|
| 24 |
+
# =============================================================================
|
| 25 |
+
|
| 26 |
+
@dataclass
class GeometryConfig:
    """Settings for the hidden-state geometry experiment."""
    model_path: str = "."                # Your local model (HF directory)
    n_sequences: int = 100               # number of sequences to analyze
    max_length: int = 256                # max generated tokens per sequence
    device: str = "cuda"                 # torch device string
    output_dir: str = "geometry_results" # directory for result artifacts

    # Geometric thresholds
    curvature_window: int = 3  # Tokens to compute curvature over
    holonomy_threshold: float = 0.95  # Cosine similarity to detect "loops"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# =============================================================================
|
| 40 |
+
# GEOMETRIC COMPUTATIONS
|
| 41 |
+
# =============================================================================
|
| 42 |
+
|
| 43 |
+
class ManifoldAnalyzer:
    """
    Analyzes the geometry of hidden state trajectories.

    Key concepts:
    - Velocity: direction of movement in hidden space (first derivative)
    - Curvature: how sharply the path bends (second derivative)
    - Holonomy: what you lose going around a loop (parallel transport failure)

    All derivatives are discrete finite differences over the token axis,
    so each differentiation shortens the sequence by one.
    """

    def __init__(self, config: GeometryConfig):
        self.config = config

    def compute_velocities(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute velocity vectors (tangent vectors to the trajectory).

        Args:
            hidden_states: [seq_len, hidden_dim]

        Returns:
            velocities: [seq_len-1, hidden_dim]
        """
        # First finite difference between consecutive token states.
        return hidden_states[1:] - hidden_states[:-1]

    def compute_speeds(self, velocities: torch.Tensor) -> torch.Tensor:
        """Magnitude of velocity vectors."""
        return torch.norm(velocities, dim=-1)

    def compute_accelerations(self, velocities: torch.Tensor) -> torch.Tensor:
        """Second derivative - how velocity changes."""
        return velocities[1:] - velocities[:-1]

    def compute_curvature(self, velocities: torch.Tensor) -> torch.Tensor:
        """
        Curvature κ = |dT/ds| where T is unit tangent, s is arc length.

        Simplified: κ ≈ |a| / |v|² where a is acceleration, v is velocity.
        High curvature = sharp turns in semantic space.

        Args:
            velocities: [seq_len-1, hidden_dim]

        Returns:
            curvature: [seq_len-2] per-step curvature estimates.
        """
        speeds = self.compute_speeds(velocities)
        accelerations = self.compute_accelerations(velocities)
        accel_magnitudes = torch.norm(accelerations, dim=-1)

        # Avoid division by zero
        # speeds[:-1] aligns speeds (len n-1) with accelerations (len n-2).
        speeds_squared = speeds[:-1] ** 2 + 1e-8

        curvature = accel_magnitudes / speeds_squared
        return curvature

    def compute_torsion(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Torsion measures how the path twists out of its osculating plane.
        Third derivative information.

        Returns a per-step torsion magnitude tensor, or a single zero if
        the sequence is too short to take three differences.
        """
        v = self.compute_velocities(hidden_states)
        a = self.compute_accelerations(v)

        if len(a) < 2:
            return torch.tensor([0.0])

        # Jerk (third derivative)
        j = a[1:] - a[:-1]

        # Torsion involves cross product in the v-a-j frame
        # Simplified: measure how much j is out of the v-a plane
        # Trim v (len n-1) and a (len n-2) to align with j (len n-3).
        v_trimmed = v[:-2]
        a_trimmed = a[:-1]

        # Project j onto plane spanned by v and a, measure residual
        # NOTE(review): v and a are normalized separately, not
        # orthogonalized, so the "in-plane" projection double-counts
        # their shared component — acceptable as a heuristic.
        v_norm = F.normalize(v_trimmed, dim=-1)
        a_norm = F.normalize(a_trimmed, dim=-1)

        j_proj_v = (j * v_norm).sum(dim=-1, keepdim=True) * v_norm
        j_proj_a = (j * a_norm).sum(dim=-1, keepdim=True) * a_norm
        j_in_plane = j_proj_v + j_proj_a
        j_out_of_plane = j - j_in_plane

        torsion = torch.norm(j_out_of_plane, dim=-1)
        return torsion

    def detect_loops(self, hidden_states: torch.Tensor,
                     threshold: float = 0.95) -> List[Tuple[int, int]]:
        """
        Detect semantic loops: positions where we return to similar states.

        Returns (i, j) index pairs with cosine similarity above
        *threshold*, restricted to j >= i + 5. O(seq_len²) scan.
        """
        # Normalize for cosine similarity
        h_norm = F.normalize(hidden_states, dim=-1)
        similarity = torch.mm(h_norm, h_norm.t())

        loops = []
        seq_len = hidden_states.shape[0]

        # Find high similarity pairs (excluding diagonal and nearby)
        for i in range(seq_len):
            for j in range(i + 5, seq_len):  # At least 5 tokens apart
                if similarity[i, j] > threshold:
                    loops.append((i, j))

        return loops

    def compute_holonomy(self, hidden_states: torch.Tensor,
                         loop: Tuple[int, int]) -> float:
        """
        Compute holonomy around a detected loop.

        If we parallel transport a vector around a loop and it comes back
        unchanged, the space is flat. If it rotates, there's curvature.

        Simplified version: compare the "frame" at start vs end of loop.

        Returns 0.0 when the loop touches the sequence boundary (no
        velocity defined there).
        """
        i, j = loop

        # Get velocities at both points
        if i > 0 and j < len(hidden_states) - 1:
            v_start = hidden_states[i] - hidden_states[i-1]
            v_end = hidden_states[j] - hidden_states[j-1]

            # Holonomy = angle between velocity vectors at "same" point
            v_start_norm = F.normalize(v_start, dim=-1)
            v_end_norm = F.normalize(v_end, dim=-1)

            cos_angle = (v_start_norm * v_end_norm).sum()
            holonomy = 1 - cos_angle.abs()  # 0 = flat, 1 = maximally curved

            return holonomy.item()

        return 0.0

    def analyze_sequence(self, hidden_states: torch.Tensor) -> Dict:
        """Full geometric analysis of a sequence.

        Returns raw per-step tensors plus scalar summary statistics;
        empty-derivative cases fall back to 0 in the summaries.
        """

        # Basic derivatives
        velocities = self.compute_velocities(hidden_states)
        speeds = self.compute_speeds(velocities)
        curvature = self.compute_curvature(velocities)
        torsion = self.compute_torsion(hidden_states)

        # Loop detection and holonomy
        loops = self.detect_loops(hidden_states, self.config.holonomy_threshold)
        holonomies = [self.compute_holonomy(hidden_states, loop) for loop in loops]

        return {
            "velocities": velocities,
            "speeds": speeds,
            "curvature": curvature,
            "torsion": torsion,
            "loops": loops,
            "holonomies": holonomies,

            # Summary statistics
            "mean_speed": speeds.mean().item(),
            "std_speed": speeds.std().item(),
            "mean_curvature": curvature.mean().item() if len(curvature) > 0 else 0,
            "max_curvature": curvature.max().item() if len(curvature) > 0 else 0,
            "mean_torsion": torsion.mean().item() if len(torsion) > 0 else 0,
            "n_loops": len(loops),
            "mean_holonomy": np.mean(holonomies) if holonomies else 0,
        }
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# =============================================================================
|
| 205 |
+
# REPETITION DETECTION (Ground Truth)
|
| 206 |
+
# =============================================================================
|
| 207 |
+
|
| 208 |
+
def detect_repetitions(token_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """
    Create binary labels: 1 if token is a repetition, 0 otherwise.

    A token counts as a repetition when the same id occurs anywhere in
    the preceding *window* positions.
    """
    ids = token_ids.tolist()
    seq_len = len(ids)
    labels = torch.zeros(seq_len)

    for pos in range(1, seq_len):
        lo = max(0, pos - window)
        if ids[pos] in ids[lo:pos]:
            labels[pos] = 1.0

    return labels
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def detect_ngram_repetitions(token_ids: torch.Tensor, n: int = 3) -> torch.Tensor:
    """
    Detect n-gram repetitions (more sophisticated).

    Labels the last token of any n-gram that has appeared earlier in the
    sequence (overlapping n-grams included).
    """
    ids = token_ids.tolist()
    seq_len = len(ids)
    labels = torch.zeros(seq_len)
    observed = set()

    for end in range(n - 1, seq_len):
        gram = tuple(ids[end - n + 1:end + 1])
        if gram in observed:
            labels[end] = 1.0
        observed.add(gram)

    return labels
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# =============================================================================
|
| 242 |
+
# PROBE TRAINING
|
| 243 |
+
# =============================================================================
|
| 244 |
+
|
| 245 |
+
class GeometricProbe(nn.Module):
    """
    Probe operating on geometric features (velocity, curvature) rather
    than raw hidden states.

    A small 3-layer GELU MLP mapping a feature vector to a single logit.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class CurvatureProbe(nn.Module):
    """
    Probe that consumes [velocity, acceleration, curvature_scalar].

    Features are concatenated along the last axis and fed through a
    fixed-width GELU MLP producing one logit.
    """

    def __init__(self, d_model: int):
        super().__init__()
        # velocity (d_model) + acceleration (d_model) + curvature (1)
        feature_dim = 2 * d_model + 1
        self.net = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 1),
        )

    def forward(self, velocity, acceleration, curvature_scalar):
        features = torch.cat(
            [velocity, acceleration, curvature_scalar.unsqueeze(-1)],
            dim=-1,
        )
        return self.net(features)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# =============================================================================
|
| 292 |
+
# MAIN EXPERIMENT
|
| 293 |
+
# =============================================================================
|
| 294 |
+
|
| 295 |
+
class LieHolonomyExperiment:
    """
    Main experiment: compare geometric probes vs raw hidden state probes.

    Pipeline: sample text while capturing last-layer hidden states, run
    ManifoldAnalyzer over each trajectory, label token repetitions, then
    correlate curvature with the repetition labels and save a JSON summary.
    """

    def __init__(self, config: GeometryConfig):
        self.config = config
        self.analyzer = ManifoldAnalyzer(config)
        self.device = config.device

        # Results storage
        # NOTE(review): "geometry_stats" is never populated in this class —
        # confirm whether it is filled elsewhere or is dead state.
        self.results = {
            "sequences": [],
            "geometry_stats": [],
            "correlations": {},
        }

        # Load model
        self._load_model()

    def _load_model(self):
        """Load the tokenizer and a 4-bit NF4-quantized causal LM from
        config.model_path (local files only), then cache model dimensions."""
        print("Loading model...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_path,
            local_files_only=True
        )
        # Reuse EOS as the padding token (typical for decoder-only models).
        self.tokenizer.pad_token = self.tokenizer.eos_token

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        )
        self.model.eval()

        self.d_model = self.model.config.hidden_size
        self.n_layers = self.model.config.num_hidden_layers

        print(f"Model loaded: {self.d_model} hidden dim, {self.n_layers} layers")

    def generate_with_hidden_states(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample up to config.max_length tokens and capture the last
        layer's hidden state at the final position of each step.

        Returns:
            (token_ids [prompt+generated], hidden_states [n_steps, d_model]),
            both moved to CPU.

        NOTE(review): every step re-runs a full forward over the entire
        prefix (no KV cache), so generation cost is quadratic in length.
        """

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        all_hidden_states = []
        generated_ids = inputs.input_ids.clone()

        for step in range(self.config.max_length):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=generated_ids,
                    output_hidden_states=True,
                    return_dict=True,
                )

            # Get hidden states from last layer, last position
            hidden = outputs.hidden_states[-1][:, -1, :]  # [1, d_model]
            all_hidden_states.append(hidden.squeeze(0).cpu())

            # Sample next token (fixed temperature 0.8)
            logits = outputs.logits[:, -1, :]
            probs = F.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated_ids = torch.cat([generated_ids, next_token], dim=-1)

            if next_token.item() == self.tokenizer.eos_token_id:
                break

        hidden_states = torch.stack(all_hidden_states)  # [seq_len, d_model]
        return generated_ids.squeeze(0).cpu(), hidden_states

    def run_experiment(self, prompts: List[str] = None):
        """Run the full experiment.

        Generates one sequence per prompt, analyzes its hidden-state
        geometry, aligns curvature with repetition labels, then computes
        correlations and saves results.  Falls back to 100 default
        prompts (10 templates x 10) when none are given.
        """

        if prompts is None:
            prompts = [
                "Once upon a time",
                "The meaning of life is",
                "In the beginning there was",
                "To be or not to be",
                "The quick brown fox",
                "Explain how neural networks",
                "Write a story about",
                "The most important thing",
                "Scientists discovered that",
                "In a world where",
            ] * 10  # 100 sequences

        print(f"\nRunning experiment on {len(prompts)} sequences...")

        all_curvatures = []
        all_repetition_labels = []
        # NOTE(review): the two lists below are initialized but never
        # appended to in this method — dead state or unfinished feature.
        all_holonomies = []
        all_speeds = []

        for i, prompt in enumerate(tqdm(prompts)):
            try:
                # Generate
                token_ids, hidden_states = self.generate_with_hidden_states(prompt)

                # Geometric analysis
                geometry = self.analyzer.analyze_sequence(hidden_states)

                # Repetition labels
                rep_labels = detect_repetitions(token_ids)
                ngram_labels = detect_ngram_repetitions(token_ids)

                # Align lengths (curvature is shorter due to derivatives)
                min_len = min(len(geometry["curvature"]), len(rep_labels) - 2)
                if min_len > 0:
                    curvature = geometry["curvature"][:min_len]
                    labels = rep_labels[2:2+min_len]  # Offset by 2 for second derivative

                    all_curvatures.extend(curvature.tolist())
                    all_repetition_labels.extend(labels.tolist())

                # Store sequence data (keeps only the analyzer's scalar outputs)
                self.results["sequences"].append({
                    "prompt": prompt,
                    "length": len(token_ids),
                    "n_repetitions": int(rep_labels.sum()),
                    "n_ngram_repetitions": int(ngram_labels.sum()),
                    **{k: v for k, v in geometry.items()
                       if isinstance(v, (int, float))}
                })

            except Exception as e:
                # Best effort: skip sequences that fail, keep the rest.
                print(f"Error on prompt {i}: {e}")
                continue

        # Compute correlations
        self._compute_correlations(all_curvatures, all_repetition_labels)

        # Save results
        self._save_results()

        return self.results

    def _compute_correlations(self, curvatures: List[float], labels: List[float]):
        """Compute correlations between geometry and repetition.

        Stores a summary in self.results["correlations"] and prints a
        human-readable report with a signal-strength verdict.
        """

        curvatures = np.array(curvatures)
        labels = np.array(labels)

        # Basic correlation
        # NOTE(review): np.corrcoef yields nan when either array is
        # constant — consider guarding that case too.
        if len(curvatures) > 0 and len(labels) > 0:
            correlation = np.corrcoef(curvatures, labels)[0, 1]
        else:
            correlation = 0

        # Split by label and compare means
        rep_indices = labels > 0.5
        non_rep_indices = labels < 0.5

        if rep_indices.sum() > 0 and non_rep_indices.sum() > 0:
            mean_curv_rep = curvatures[rep_indices].mean()
            mean_curv_nonrep = curvatures[non_rep_indices].mean()
            # Epsilon avoids division by zero for all-zero curvature.
            separation = mean_curv_rep / (mean_curv_nonrep + 1e-8)
        else:
            mean_curv_rep = 0
            mean_curv_nonrep = 0
            separation = 1.0

        self.results["correlations"] = {
            "curvature_repetition_correlation": float(correlation),
            "mean_curvature_at_repetition": float(mean_curv_rep),
            "mean_curvature_at_non_repetition": float(mean_curv_nonrep),
            "curvature_separation_ratio": float(separation),
            "n_samples": len(curvatures),
            "n_repetitions": int(labels.sum()),
        }

        print("\n" + "="*60)
        print("GEOMETRIC ANALYSIS RESULTS")
        print("="*60)
        print(f"Correlation (curvature <-> repetition): {correlation:.4f}")
        print(f"Mean curvature at repetitions: {mean_curv_rep:.6f}")
        print(f"Mean curvature at non-repetitions: {mean_curv_nonrep:.6f}")
        print(f"Separation ratio: {separation:.2f}x")
        print(f"Total samples: {len(curvatures)}")
        print(f"Total repetitions: {int(labels.sum())}")
        print("="*60)

        # Interpretation
        if separation > 2.0:
            print("\n🎯 STRONG SIGNAL: Curvature predicts repetition!")
            print("   This validates the geometric hypothesis.")
        elif separation > 1.3:
            print("\n📊 MODERATE SIGNAL: Some predictive power.")
            print("   Worth investigating further.")
        else:
            print("\n⚠️ WEAK SIGNAL: Curvature alone may not be enough.")
            print("   Try holonomy or learned geometric features.")

    def _save_results(self):
        """Write a JSON summary (config, correlations, per-sequence means)
        to <output_dir>/geometry_results.json."""
        output_dir = Path(self.config.output_dir)
        # NOTE(review): exist_ok without parents=True — raises if the
        # parent directory is missing.
        output_dir.mkdir(exist_ok=True)

        # Save JSON summary
        # NOTE(review): np.mean over an empty "sequences" list yields nan.
        summary = {
            "config": {
                "n_sequences": self.config.n_sequences,
                "max_length": self.config.max_length,
            },
            "correlations": self.results["correlations"],
            "sequence_stats": {
                "mean_length": np.mean([s["length"] for s in self.results["sequences"]]),
                "mean_repetitions": np.mean([s["n_repetitions"] for s in self.results["sequences"]]),
                "mean_curvature": np.mean([s["mean_curvature"] for s in self.results["sequences"]]),
                "mean_n_loops": np.mean([s["n_loops"] for s in self.results["sequences"]]),
                "mean_holonomy": np.mean([s["mean_holonomy"] for s in self.results["sequences"]]),
            }
        }

        with open(output_dir / "geometry_results.json", "w") as f:
            json.dump(summary, f, indent=2)

        print(f"\nResults saved to {output_dir}/geometry_results.json")
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# =============================================================================
|
| 529 |
+
# CONNECTION NETWORK - THE KEY TO HOLONOMY
|
| 530 |
+
# =============================================================================
|
| 531 |
+
|
| 532 |
+
class ConnectionNetwork(nn.Module):
    """
    Learns the Levi-Civita connection on the hidden state manifold.

    Given two points on the manifold, the network predicts a linear map
    that parallel-transports a tangent vector from the first point to the
    second.  Transporting a vector around a closed loop and comparing it
    with the original measures holonomy — the signature of curvature.
    """

    def __init__(self, d_model: int, d_connection: int = 256):
        super().__init__()

        # Embed the (start, end) pair describing one path segment.
        self.path_encoder = nn.Sequential(
            nn.Linear(d_model * 2, d_connection),
            nn.GELU(),
            nn.Linear(d_connection, d_connection),
        )

        # Map the segment embedding to a flattened d_model x d_model
        # transport matrix (simplified learned transformation).
        self.transport_predictor = nn.Sequential(
            nn.Linear(d_connection, d_connection),
            nn.GELU(),
            nn.Linear(d_connection, d_model * d_model),  # Full matrix
        )

        self.d_model = d_model

    def forward(self, h_start: torch.Tensor, h_end: torch.Tensor,
                v_start: torch.Tensor) -> torch.Tensor:
        """
        Parallel transport vector v_start from h_start to h_end.

        Args:
            h_start: Starting hidden state [batch, d_model]
            h_end: Ending hidden state [batch, d_model]
            v_start: Vector to transport [batch, d_model]

        Returns:
            Transported vector at h_end [batch, d_model]
        """
        n = h_start.shape[0]

        # Encode the segment, then predict its transport matrix.
        segment = self.path_encoder(torch.cat([h_start, h_end], dim=-1))
        matrix = self.transport_predictor(segment).view(n, self.d_model, self.d_model)

        # Residual keeps the transport near the identity for stability.
        moved = torch.bmm(matrix, v_start.unsqueeze(-1)).squeeze(-1)
        return moved + v_start

    def compute_holonomy(self, hidden_states: torch.Tensor,
                         loop_indices: List[int]) -> torch.Tensor:
        """
        Compute holonomy around a loop defined by indices.

        Transports a random unit vector along every edge of the loop,
        closes the loop back to its start, and returns
        1 - cosine_similarity(final, original).
        """
        if len(loop_indices) < 3:
            return torch.tensor(0.0)

        # Random unit probe vector.
        probe = F.normalize(
            torch.randn(1, self.d_model, device=hidden_states.device), dim=-1
        )
        reference = probe.clone()

        # Walk each consecutive edge of the loop...
        for a, b in zip(loop_indices[:-1], loop_indices[1:]):
            probe = self.forward(
                hidden_states[a].unsqueeze(0),
                hidden_states[b].unsqueeze(0),
                probe,
            )

        # ...then close it back to the starting point.
        probe = self.forward(
            hidden_states[loop_indices[-1]].unsqueeze(0),
            hidden_states[loop_indices[0]].unsqueeze(0),
            probe,
        )

        # Holonomy magnitude: deviation from perfect alignment.
        return 1 - F.cosine_similarity(probe, reference, dim=-1)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
# =============================================================================
|
| 627 |
+
# RUN EXPERIMENT
|
| 628 |
+
# =============================================================================
|
| 629 |
+
|
| 630 |
+
def main():
    """Run the Lie Holonomy experiment end to end.

    Builds a GeometryConfig, runs LieHolonomyExperiment, and prints a
    next-steps verdict based on the curvature separation ratio (the
    thresholds mirror LieHolonomyExperiment._compute_correlations).
    """

    print("="*60)
    print("LIE HOLONOMY TRANSFORMER - GEOMETRIC ANALYSIS")
    print("="*60)
    print("\nHypothesis: Hidden state GEOMETRY (curvature, holonomy)")
    print("predicts model behavior better than raw states.\n")

    config = GeometryConfig(
        model_path=".",  # Current directory
        n_sequences=100,
        max_length=256,
    )

    experiment = LieHolonomyExperiment(config)
    results = experiment.run_experiment()

    print("\n" + "="*60)
    print("EXPERIMENT COMPLETE")
    print("="*60)
    print("\nNext steps based on results:")

    sep = results["correlations"]["curvature_separation_ratio"]

    if sep > 2.0:
        print("""
✅ STRONG SIGNAL DETECTED

The geometric approach shows promise. Next:
1. Train a CurvatureProbe to beat your 80x probe
2. Implement the ConnectionNetwork for learned parallel transport
3. Use holonomy as a NEW training signal for self-improvement

This could be the breakthrough.
""")
    elif sep > 1.3:
        print("""
📊 MODERATE SIGNAL

Worth pursuing. Try:
1. Different geometric features (torsion, geodesic deviation)
2. Multi-layer analysis (geometry at each transformer layer)
3. Larger sample sizes
""")
    else:
        print("""
⚠️ WEAK SIGNAL

Raw curvature may not be the right feature. Try:
1. Learned geometric features (train the connection network)
2. Sectional curvature (curvature in specific 2D planes)
3. Ricci curvature (average over all directions)
""")

    return results
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
# Script entry point: run the experiment when executed directly.
if __name__ == "__main__":
    main()
|
code/training_pipelines/05_breakthrough_test_v2_LOOP4.py
ADDED
|
@@ -0,0 +1,935 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LOOP 4 → RSI CEILING BREAKTHROUGH TEST (MEMORY-OPTIMIZED)
|
| 4 |
+
==========================================================
|
| 5 |
+
The critical experiment: Does tokenization co-evolution break the RSI ceiling?
|
| 6 |
+
|
| 7 |
+
Key insight: Adaptation training must EXERCISE THE NEW TOKENS.
|
| 8 |
+
We generate text dense with merged patterns, re-encode with new tokenizer,
|
| 9 |
+
and train the model to predict the merged tokens.
|
| 10 |
+
|
| 11 |
+
MEMORY OPTIMIZED:
|
| 12 |
+
- 4-bit quantization (~5GB model)
|
| 13 |
+
- Batch size 1 with gradient accumulation
|
| 14 |
+
- Reduced sequence lengths
|
| 15 |
+
- Aggressive memory cleanup
|
| 16 |
+
|
| 17 |
+
Expected VRAM: ~12-14GB
|
| 18 |
+
Expected runtime: 45-75 minutes
|
| 19 |
+
|
| 20 |
+
Pipeline:
|
| 21 |
+
1. Load Loop 4 results (merge candidates)
|
| 22 |
+
2. Resize embeddings for new tokens
|
| 23 |
+
3. TARGETED fine-tuning on text dense with merged patterns
|
| 24 |
+
4. Run extended RSI (aim for >5 iterations)
|
| 25 |
+
5. Measure if ceiling has moved
|
| 26 |
+
|
| 27 |
+
Author: Logan Napolitano
|
| 28 |
+
Date: January 2026
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
import numpy as np
|
| 35 |
+
import json
|
| 36 |
+
import os
|
| 37 |
+
import gc
|
| 38 |
+
import re
|
| 39 |
+
from pathlib import Path
|
| 40 |
+
from dataclasses import dataclass
|
| 41 |
+
from typing import List, Dict, Tuple, Optional
|
| 42 |
+
from transformers import (
|
| 43 |
+
AutoModelForCausalLM,
|
| 44 |
+
AutoTokenizer,
|
| 45 |
+
TrainingArguments,
|
| 46 |
+
Trainer,
|
| 47 |
+
DataCollatorForLanguageModeling,
|
| 48 |
+
BitsAndBytesConfig,
|
| 49 |
+
)
|
| 50 |
+
from datasets import Dataset
|
| 51 |
+
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
|
| 52 |
+
from tqdm import tqdm
|
| 53 |
+
import warnings
|
| 54 |
+
warnings.filterwarnings("ignore")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@dataclass
class BreakthroughConfig:
    """Hyperparameters for the Loop-4 tokenization breakthrough run.

    "Reduced"/"Increased" markers indicate values tuned from an earlier
    configuration to fit the memory-optimized (4-bit, batch-size-1) setup.
    """

    # Paths / hardware
    model_path: str = "."
    output_dir: str = "loop4_breakthrough"
    device: str = "cuda"

    # Tokenizer modification
    loop4_results_path: str = "loop4_full_results/loop4_full_results.json"
    top_k_merges: int = 10  # Reduced

    # Adaptation fine-tuning (targeted)
    adaptation_samples: int = 150  # Reduced
    adaptation_steps: int = 100  # Reduced
    adaptation_lr: float = 2e-5
    adaptation_batch_size: int = 1  # Reduced
    gradient_accumulation: int = 8  # Increased to compensate
    # Minimum fraction of generated text covered by merged patterns
    # for a sample to be kept (see TargetedDataGenerator).
    min_pattern_density: float = 0.08

    # RSI settings
    max_rsi_iterations: int = 10
    rsi_micro_steps: int = 15  # Reduced
    rsi_lr: float = 1e-5
    quality_threshold: float = 0.92

    # Evaluation
    eval_samples: int = 10  # Reduced
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class EmbeddingResizer:
    """Handles resizing model embeddings for new tokens.

    Keeps a pristine copy of the tokenizer so a merged token can later be
    split back into its original component pieces for embedding init.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.original_vocab_size = len(tokenizer)
        # Unmodified tokenizer used to decompose new tokens into components.
        self.original_tokenizer = AutoTokenizer.from_pretrained(
            tokenizer.name_or_path,
            local_files_only=True
        )

    def add_tokens_and_resize(self, new_tokens: List[str]) -> Tuple[int, List[str]]:
        """
        Add new tokens and resize embeddings.
        Returns: (num_added, list of added tokens)
        """
        # Skip tokens the tokenizer already knows.
        existing_vocab = set(self.tokenizer.get_vocab().keys())
        tokens_to_add = [t for t in new_tokens if t not in existing_vocab]

        if not tokens_to_add:
            print("All tokens already in vocabulary")
            return 0, []

        print(f"Adding {len(tokens_to_add)} new tokens...")
        for t in tokens_to_add:
            print(f"  + '{repr(t)}'")

        num_added = self.tokenizer.add_tokens(tokens_to_add)
        print(f"Vocabulary: {self.original_vocab_size} → {len(self.tokenizer)}")

        # Grow the embedding matrix, then give the new rows sensible values.
        # NOTE(review): assumes the output head is weight-tied to the input
        # embeddings (resize_token_embeddings handles both) — confirm for
        # the model in use.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self._initialize_new_embeddings(tokens_to_add)

        return num_added, tokens_to_add

    def _initialize_new_embeddings(self, new_tokens: List[str]):
        """Initialize new embeddings as average of component tokens."""
        embed_weight = self.model.get_input_embeddings().weight

        for token in new_tokens:
            new_id = self.tokenizer.convert_tokens_to_ids(token)

            # Get component IDs from original tokenizer
            component_ids = self.original_tokenizer.encode(token, add_special_tokens=False)

            if component_ids:
                # Mean of component embeddings is a standard warm start
                # for merged-token rows.
                with torch.no_grad():
                    component_embeds = embed_weight[component_ids]
                    embed_weight[new_id] = component_embeds.mean(dim=0)
                print(f"  '{token}' initialized from {len(component_ids)} components")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class TargetedDataGenerator:
|
| 139 |
+
"""
|
| 140 |
+
Generates training data DENSE with the merged token patterns.
|
| 141 |
+
This is the key insight: train on text that exercises the new vocabulary.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(self, model, tokenizer, merged_tokens: List[str], config: BreakthroughConfig):
|
| 145 |
+
self.model = model
|
| 146 |
+
self.tokenizer = tokenizer
|
| 147 |
+
self.merged_tokens = merged_tokens
|
| 148 |
+
self.config = config
|
| 149 |
+
|
| 150 |
+
# Build pattern matchers for each merged token
|
| 151 |
+
self.patterns = []
|
| 152 |
+
for token in merged_tokens:
|
| 153 |
+
# Escape special regex chars
|
| 154 |
+
escaped = re.escape(token)
|
| 155 |
+
self.patterns.append(escaped)
|
| 156 |
+
|
| 157 |
+
# Combined pattern
|
| 158 |
+
if self.patterns:
|
| 159 |
+
self.combined_pattern = re.compile('|'.join(self.patterns))
|
| 160 |
+
else:
|
| 161 |
+
self.combined_pattern = None
|
| 162 |
+
|
| 163 |
+
def count_pattern_matches(self, text: str) -> int:
|
| 164 |
+
"""Count how many merged token patterns appear in text."""
|
| 165 |
+
if not self.combined_pattern:
|
| 166 |
+
return 0
|
| 167 |
+
return len(self.combined_pattern.findall(text))
|
| 168 |
+
|
| 169 |
+
def compute_pattern_density(self, text: str) -> float:
|
| 170 |
+
"""Compute what fraction of text is covered by merged patterns."""
|
| 171 |
+
if not text:
|
| 172 |
+
return 0.0
|
| 173 |
+
matches = self.count_pattern_matches(text)
|
| 174 |
+
# Rough estimate: each match covers ~5 chars on average
|
| 175 |
+
coverage = (matches * 5) / len(text)
|
| 176 |
+
return min(coverage, 1.0)
|
| 177 |
+
|
| 178 |
+
def generate_dense_text(self, n_samples: int) -> List[str]:
|
| 179 |
+
"""
|
| 180 |
+
Generate text that naturally contains many merged token patterns.
|
| 181 |
+
|
| 182 |
+
Strategy:
|
| 183 |
+
1. Use prompts that encourage the patterns (sentence starters, connectives)
|
| 184 |
+
2. Filter for high pattern density
|
| 185 |
+
3. Augment with explicit pattern injection if needed
|
| 186 |
+
"""
|
| 187 |
+
print("Generating pattern-dense training data...")
|
| 188 |
+
|
| 189 |
+
# Prompts designed to elicit the patterns (. The, , and, . In, . It, , the, etc.)
|
| 190 |
+
dense_prompts = [
|
| 191 |
+
# These encourage ". The" pattern
|
| 192 |
+
"The experiment showed remarkable results. The data clearly indicated",
|
| 193 |
+
"She opened the door carefully. The room inside was dark. The air felt cold. The silence was complete.",
|
| 194 |
+
"Scientists made a breakthrough. The discovery changes everything. The implications are vast.",
|
| 195 |
+
|
| 196 |
+
# These encourage ", and" and ", the" patterns
|
| 197 |
+
"The team worked hard, and the results showed improvement, and the project succeeded.",
|
| 198 |
+
"He studied mathematics, physics, and chemistry, and he excelled in all subjects.",
|
| 199 |
+
"We need food, water, and shelter, and the supplies must arrive soon.",
|
| 200 |
+
|
| 201 |
+
# These encourage ". In" pattern
|
| 202 |
+
"The city grew rapidly. In the downtown area, new buildings appeared. In the suburbs, families settled.",
|
| 203 |
+
"Change happens slowly. In the beginning, few noticed. In time, everyone understood.",
|
| 204 |
+
|
| 205 |
+
# These encourage ". It" pattern
|
| 206 |
+
"The machine hummed quietly. It processed data continuously. It never stopped working.",
|
| 207 |
+
"The algorithm converged. It found the optimal solution. It exceeded expectations.",
|
| 208 |
+
|
| 209 |
+
# These encourage ". This" pattern
|
| 210 |
+
"The theory was revolutionary. This changed how scientists thought. This led to new discoveries.",
|
| 211 |
+
"The problem seemed impossible. This made the solution more remarkable. This proved the method worked.",
|
| 212 |
+
|
| 213 |
+
# Dense combinations
|
| 214 |
+
"The research began in January. The team collected data, and the analysis revealed patterns. In the first phase, the results were promising. It became clear that the hypothesis was correct. This validated the entire approach, and the project moved forward.",
|
| 215 |
+
|
| 216 |
+
"The algorithm processes input. The output depends on parameters, and the system optimizes continuously. In each iteration, the model improves. It learns from errors. This creates a feedback loop, and the performance increases.",
|
| 217 |
+
|
| 218 |
+
"The forest stretched endlessly. The trees stood tall, and the leaves rustled softly. In the clearing, sunlight streamed down. It illuminated the path. This was the way forward, and the journey continued.",
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
# Additional templates that force patterns
|
| 222 |
+
templates = [
|
| 223 |
+
"The {noun} was {adj}. The {noun2} seemed {adj2}. In the {place}, {event}. It was {description}. This meant {conclusion}, and the {outcome}.",
|
| 224 |
+
"{Statement}. The evidence was clear. In retrospect, {reflection}. It all made sense. This {insight}, and {result}.",
|
| 225 |
+
"The process began. The first step involved {action}. In this phase, {detail}. It required {requirement}. This ensured {guarantee}, and the {final}.",
|
| 226 |
+
]
|
| 227 |
+
|
| 228 |
+
nouns = ["system", "approach", "method", "solution", "discovery", "pattern", "structure", "concept"]
|
| 229 |
+
adjs = ["remarkable", "significant", "important", "crucial", "fundamental", "essential"]
|
| 230 |
+
places = ["beginning", "end", "middle", "process", "analysis", "study"]
|
| 231 |
+
|
| 232 |
+
texts = []
|
| 233 |
+
attempts = 0
|
| 234 |
+
max_attempts = n_samples * 10
|
| 235 |
+
|
| 236 |
+
# First, use the dense prompts and generate continuations
|
| 237 |
+
for prompt in dense_prompts:
|
| 238 |
+
if len(texts) >= n_samples:
|
| 239 |
+
break
|
| 240 |
+
|
| 241 |
+
try:
|
| 242 |
+
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.config.device)
|
| 243 |
+
|
| 244 |
+
with torch.no_grad():
|
| 245 |
+
outputs = self.model.generate(
|
| 246 |
+
inputs.input_ids,
|
| 247 |
+
max_new_tokens=80, # Reduced
|
| 248 |
+
temperature=0.8,
|
| 249 |
+
do_sample=True,
|
| 250 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 251 |
+
repetition_penalty=1.1,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 255 |
+
density = self.compute_pattern_density(text)
|
| 256 |
+
|
| 257 |
+
if density >= self.config.min_pattern_density:
|
| 258 |
+
texts.append(text)
|
| 259 |
+
|
| 260 |
+
except Exception as e:
|
| 261 |
+
continue
|
| 262 |
+
|
| 263 |
+
# Memory cleanup
|
| 264 |
+
torch.cuda.empty_cache()
|
| 265 |
+
gc.collect()
|
| 266 |
+
|
| 267 |
+
# Generate more with varied prompts
|
| 268 |
+
while len(texts) < n_samples and attempts < max_attempts:
|
| 269 |
+
attempts += 1
|
| 270 |
+
|
| 271 |
+
# Random dense prompt
|
| 272 |
+
base = np.random.choice(dense_prompts)
|
| 273 |
+
|
| 274 |
+
try:
|
| 275 |
+
inputs = self.tokenizer(base, return_tensors="pt").to(self.config.device)
|
| 276 |
+
|
| 277 |
+
with torch.no_grad():
|
| 278 |
+
outputs = self.model.generate(
|
| 279 |
+
inputs.input_ids,
|
| 280 |
+
max_new_tokens=60, # Reduced
|
| 281 |
+
temperature=0.9,
|
| 282 |
+
do_sample=True,
|
| 283 |
+
top_p=0.95,
|
| 284 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 285 |
+
repetition_penalty=1.1,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 289 |
+
density = self.compute_pattern_density(text)
|
| 290 |
+
|
| 291 |
+
if density >= self.config.min_pattern_density * 0.8: # Slightly relaxed
|
| 292 |
+
texts.append(text)
|
| 293 |
+
|
| 294 |
+
except:
|
| 295 |
+
continue
|
| 296 |
+
|
| 297 |
+
if attempts % 50 == 0:
|
| 298 |
+
print(f" Generated {len(texts)}/{n_samples} samples ({attempts} attempts)")
|
| 299 |
+
|
| 300 |
+
# If we still need more, inject patterns into generated text
|
| 301 |
+
if len(texts) < n_samples:
|
| 302 |
+
print(f" Augmenting with pattern injection...")
|
| 303 |
+
texts.extend(self._inject_patterns(n_samples - len(texts)))
|
| 304 |
+
|
| 305 |
+
return texts[:n_samples]
|
| 306 |
+
|
| 307 |
+
def _inject_patterns(self, n_samples: int) -> List[str]:
|
| 308 |
+
"""Inject merged patterns into template text."""
|
| 309 |
+
templates = [
|
| 310 |
+
"The {0} was significant. The {1} followed naturally. In the {2}, everything changed. It was clear that {3}. This meant {4}, and the outcome was {5}.",
|
| 311 |
+
"Research showed {0}. The findings were {1}. In particular, the {2} demonstrated {3}. It proved that {4}. This {5}, and {6}.",
|
| 312 |
+
"The system processed {0}. The algorithm computed {1}, and the results showed {2}. In each iteration, the {3} improved. It optimized {4}. This created {5}, and the {6}.",
|
| 313 |
+
]
|
| 314 |
+
|
| 315 |
+
fillers = [
|
| 316 |
+
"the data", "the results", "the analysis", "the model", "the approach",
|
| 317 |
+
"important", "significant", "remarkable", "essential", "fundamental",
|
| 318 |
+
"process", "method", "system", "framework", "structure",
|
| 319 |
+
"the hypothesis held", "the theory worked", "the method succeeded",
|
| 320 |
+
"improvement", "progress", "advancement", "efficiency", "performance",
|
| 321 |
+
"a breakthrough", "new understanding", "better outcomes", "clear benefits",
|
| 322 |
+
"conclusion validated the approach", "study confirmed expectations",
|
| 323 |
+
]
|
| 324 |
+
|
| 325 |
+
texts = []
|
| 326 |
+
for _ in range(n_samples):
|
| 327 |
+
template = np.random.choice(templates)
|
| 328 |
+
fills = np.random.choice(fillers, size=10, replace=True)
|
| 329 |
+
try:
|
| 330 |
+
text = template.format(*fills)
|
| 331 |
+
texts.append(text)
|
| 332 |
+
except:
|
| 333 |
+
continue
|
| 334 |
+
|
| 335 |
+
return texts
|
| 336 |
+
|
| 337 |
+
def create_dataset(self, n_samples: int) -> Dataset:
    """Create dataset with pattern-dense text.

    Generates up to ``n_samples`` texts via ``generate_dense_text``, prints
    pattern statistics, and wraps the texts in a HF ``Dataset``.

    Args:
        n_samples: target number of samples.

    Returns:
        ``Dataset`` with a single "text" column (may be empty if generation
        produced nothing).
    """
    texts = self.generate_dense_text(n_samples)

    # Report statistics.
    total_patterns = sum(self.count_pattern_matches(t) for t in texts)
    # FIX: guard against an empty result — np.mean([]) warns and returns NaN,
    # and total_patterns / len(texts) raised ZeroDivisionError.
    avg_density = (
        np.mean([self.compute_pattern_density(t) for t in texts]) if texts else 0.0
    )

    print(f"\nDataset statistics:")
    print(f" Samples: {len(texts)}")
    print(f" Total pattern matches: {total_patterns}")
    print(f" Avg pattern density: {avg_density:.2%}")
    if texts:
        print(f" Patterns per sample: {total_patterns/len(texts):.1f}")

    return Dataset.from_dict({"text": texts})
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class TargetedAdaptationTrainer:
    """
    Fine-tuning that specifically teaches the model to USE the new tokens.

    Runs LoRA fine-tuning over pattern-dense text from TargetedDataGenerator,
    then merges the adapter back into the base model and sanity-checks a
    generation.
    """

    def __init__(self, model, tokenizer, merged_tokens: List[str], config: BreakthroughConfig):
        # model / tokenizer: HF causal-LM + tokenizer (assumed already resized
        # for the new vocabulary — TODO confirm against caller).
        # merged_tokens: the newly added vocabulary strings to exercise.
        self.model = model
        self.tokenizer = tokenizer
        self.merged_tokens = merged_tokens
        self.config = config

    def train(self) -> nn.Module:
        """Run targeted LoRA adaptation and return the merged (adapter-fused) model."""
        print("\n" + "="*60)
        print("TARGETED ADAPTATION TRAINING")
        print("="*60)
        print(f"Teaching model to use {len(self.merged_tokens)} new tokens")
        print(f"Merged tokens: {self.merged_tokens[:5]}...")

        # Generate pattern-dense data
        generator = TargetedDataGenerator(
            self.model, self.tokenizer, self.merged_tokens, self.config
        )
        dataset = generator.create_dataset(self.config.adaptation_samples)

        # Tokenize - this is where merged tokens get used
        def tokenize_fn(examples):
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=128,  # Reduced (memory-saving vs. a longer context)
                padding="max_length",
            )
            return tokenized

        tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

        # Verify merged tokens appear in tokenized data (spot-check first sample only).
        sample_ids = tokenized[0]["input_ids"]
        new_token_ids = set(
            self.tokenizer.convert_tokens_to_ids(t) for t in self.merged_tokens
        )
        # NOTE(review): `id` shadows the builtin here; harmless but worth renaming.
        found = sum(1 for id in sample_ids if id in new_token_ids)
        print(f" New tokens in sample: {found}")

        # LoRA setup
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=16,  # Reduced for memory
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        )

        # Prepare for quantized training (base model is 4-bit quantized upstream).
        self.model = prepare_model_for_kbit_training(self.model)
        peft_model = get_peft_model(self.model, lora_config)
        peft_model.print_trainable_parameters()

        # Training
        training_args = TrainingArguments(
            output_dir=f"{self.config.output_dir}/adaptation",
            max_steps=self.config.adaptation_steps,
            per_device_train_batch_size=self.config.adaptation_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation,
            learning_rate=self.config.adaptation_lr,
            warmup_steps=15,
            logging_steps=10,
            save_strategy="no",  # no checkpoints; the merged model is returned in memory
            fp16=True,
            report_to="none",
            dataloader_pin_memory=False,
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,  # causal LM objective, not masked LM
        )

        trainer = Trainer(
            model=peft_model,
            args=training_args,
            train_dataset=tokenized,
            data_collator=data_collator,
        )

        print("\nTraining on pattern-dense data...")
        trainer.train()

        # Merge weights: fold the LoRA deltas into the base weights so the
        # returned model needs no PEFT wrapper.
        print("Merging LoRA weights...")
        merged_model = peft_model.merge_and_unload()

        # Quick verification
        print("\nVerifying adaptation...")
        self._verify_adaptation(merged_model)

        return merged_model

    def _verify_adaptation(self, model):
        """Verify the model uses new tokens correctly.

        Generates a short continuation from a fixed prompt and prints it;
        purely a qualitative smoke test (no assertion).
        """
        test_prompt = "The research showed results. The"

        inputs = self.tokenizer(test_prompt, return_tensors="pt").to(self.config.device)

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_new_tokens=30,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f" Test generation: '{response[:100]}...'")
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class ExtendedRSI:
    """Extended RSI to test ceiling breakthrough.

    Repeatedly self-trains (generate data -> LoRA micro-train -> evaluate),
    keeping a merged model only when quality stays above a threshold relative
    to baseline, and stops after three consecutive rollbacks.
    """

    def __init__(self, model, tokenizer, config: BreakthroughConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        # One dict per attempted iteration (status success/rollback + scores).
        self.iteration_history = []
        # Set by run(); updated upward when an iteration beats it.
        self.baseline_quality = None

    def evaluate_quality(self, n_samples: int = None) -> Dict:
        """Score generation quality over a fixed prompt set.

        Returns a dict with coherence, completeness, repetition_rate,
        avg_length, and a weighted "quality_score". All heuristics are
        surface-level (token counts / uniqueness), not semantic.
        """
        if n_samples is None:
            n_samples = self.config.eval_samples

        prompts = [
            "Explain the concept of recursion in programming:",
            "What are the key principles of effective communication?",
            "Describe the process of photosynthesis:",
            "How do neural networks learn from data?",
            "What is the scientific method and why is it important?",
            "Explain the relationship between supply and demand:",
            "What are the main challenges in renewable energy?",
            "How does memory work in the human brain?",
            "Describe the structure of an atom:",
            "What factors influence climate patterns?",
        ]

        metrics = {
            "coherence": [],
            "completeness": [],
            "repetition_rate": [],
            "avg_length": [],
        }

        for prompt in prompts[:n_samples]:
            try:
                inputs = self.tokenizer(prompt, return_tensors="pt").to(self.config.device)

                with torch.no_grad():
                    outputs = self.model.generate(
                        inputs.input_ids,
                        max_new_tokens=100,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                        repetition_penalty=1.1,
                    )

                response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                # Strip the echoed prompt; assumes decode reproduces it verbatim
                # at the start — TODO confirm for this tokenizer.
                response = response[len(prompt):].strip()

                tokens = response.split()

                # Repetition: 1 - (unique/total); empty responses count as fully repetitive.
                if len(tokens) > 0:
                    unique_ratio = len(set(tokens)) / len(tokens)
                    metrics["repetition_rate"].append(1 - unique_ratio)
                else:
                    metrics["repetition_rate"].append(1.0)

                metrics["avg_length"].append(len(tokens))
                metrics["coherence"].append(1.0 if len(tokens) > 10 else 0.5)
                metrics["completeness"].append(1.0 if response and response[-1] in '.!?' else 0.7)

            except:
                # NOTE(review): bare except silently drops failed prompts.
                continue

        result = {
            "coherence": np.mean(metrics["coherence"]) if metrics["coherence"] else 0,
            "completeness": np.mean(metrics["completeness"]) if metrics["completeness"] else 0,
            "repetition_rate": np.mean(metrics["repetition_rate"]) if metrics["repetition_rate"] else 1,
            "avg_length": np.mean(metrics["avg_length"]) if metrics["avg_length"] else 0,
        }

        # Weighted blend: low repetition dominates (0.4), coherence and
        # completeness 0.3 each.
        result["quality_score"] = (
            result["coherence"] * 0.3 +
            result["completeness"] * 0.3 +
            (1 - result["repetition_rate"]) * 0.4
        )

        return result

    def generate_rsi_data(self, n_samples: int = 30) -> Dataset:  # Reduced default
        """Self-generate training text from prompt × topic combinations.

        Stops as soon as ``n_samples`` texts are collected; generation
        failures are skipped.
        """
        prompts = [
            "Write a clear explanation of",
            "Describe in detail how",
            "Explain the relationship between",
            "What are the key aspects of",
            "Provide an analysis of",
        ]

        topics = [
            "machine learning algorithms", "climate systems", "economic markets",
            "cognitive processes", "technological change", "scientific methods",
            "social structures", "mathematical proofs", "biological evolution",
            "energy systems",
        ]

        texts = []

        for prompt in prompts:
            for topic in topics:
                full_prompt = f"{prompt} {topic}:"

                try:
                    inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.config.device)

                    with torch.no_grad():
                        outputs = self.model.generate(
                            inputs.input_ids,
                            max_new_tokens=60,  # Reduced
                            temperature=0.7,
                            do_sample=True,
                            pad_token_id=self.tokenizer.eos_token_id,
                            repetition_penalty=1.1,
                        )

                    text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                    texts.append(text)

                except:
                    # NOTE(review): bare except; generation errors are ignored.
                    continue

                if len(texts) >= n_samples:
                    break
            if len(texts) >= n_samples:
                break

        return Dataset.from_dict({"text": texts})

    def micro_train(self, dataset: Dataset):
        """One small LoRA training pass on ``dataset``; merges into self.model.

        NOTE(review): run() inlines an equivalent loop and does not appear to
        call this method — kept as the standalone variant.
        """
        def tokenize_fn(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=128,  # Reduced
                padding="max_length",
            )

        tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],
        )

        # Prepare model for quantized training
        model_for_training = prepare_model_for_kbit_training(self.model)
        peft_model = get_peft_model(model_for_training, lora_config)

        training_args = TrainingArguments(
            output_dir=f"{self.config.output_dir}/rsi_temp",
            max_steps=self.config.rsi_micro_steps,
            per_device_train_batch_size=1,  # Reduced
            gradient_accumulation_steps=4,  # Increased
            learning_rate=self.config.rsi_lr,
            warmup_steps=2,
            logging_steps=100,
            fp16=True,
            report_to="none",
            save_strategy="no",
            dataloader_pin_memory=False,
        )

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        trainer = Trainer(
            model=peft_model,
            args=training_args,
            train_dataset=tokenized,
            data_collator=data_collator,
        )

        trainer.train()
        # Fold the adapter into the base model for subsequent iterations.
        self.model = peft_model.merge_and_unload()

    def run(self) -> Dict:
        """Main RSI loop: iterate train/evaluate/keep-or-rollback.

        Returns a summary dict including per-iteration history and whether
        the >5-successful-iterations "ceiling" was broken.
        """
        print("\n" + "="*60)
        print("EXTENDED RSI - CEILING BREAKTHROUGH TEST")
        print("="*60)
        print(f"Previous ceiling: 3-5 iterations")
        print(f"Target: >5 successful iterations")
        print(f"Max attempts: {self.config.max_rsi_iterations}")

        # Baseline
        print("\nEstablishing baseline...")
        self.baseline_quality = self.evaluate_quality()
        print(f"Baseline: {self.baseline_quality['quality_score']:.4f}")

        successful = 0
        consecutive_failures = 0

        # Store original model reference (quantized base - don't modify)
        # RSI works by: train LoRA -> evaluate -> if good, keep merged; if bad, skip merge

        for iteration in range(1, self.config.max_rsi_iterations + 1):
            print(f"\n--- RSI Iteration {iteration} ---")

            # Self-generate data BEFORE any training
            print(" Generating data...")
            rsi_data = self.generate_rsi_data(25)

            # Setup fresh LoRA for this iteration
            def tokenize_fn(examples):
                return self.tokenizer(
                    examples["text"],
                    truncation=True,
                    max_length=128,
                    padding="max_length",
                )

            tokenized = rsi_data.map(tokenize_fn, batched=True, remove_columns=["text"])

            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                r=8,
                lora_alpha=16,
                lora_dropout=0.05,
                target_modules=["q_proj", "v_proj"],
            )

            # Apply LoRA (don't prepare_for_kbit_training again if already done)
            try:
                peft_model = get_peft_model(self.model, lora_config)
            except:
                # Model might already be a PEFT model from previous iteration
                self.model = self.model.merge_and_unload() if hasattr(self.model, 'merge_and_unload') else self.model
                peft_model = get_peft_model(self.model, lora_config)

            training_args = TrainingArguments(
                output_dir=f"{self.config.output_dir}/rsi_temp",
                max_steps=self.config.rsi_micro_steps,
                per_device_train_batch_size=1,
                gradient_accumulation_steps=4,
                learning_rate=self.config.rsi_lr,
                warmup_steps=2,
                logging_steps=100,
                fp16=True,
                report_to="none",
                save_strategy="no",
                dataloader_pin_memory=False,
            )

            data_collator = DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer,
                mlm=False,
            )

            print(" Training...")
            trainer = Trainer(
                model=peft_model,
                args=training_args,
                train_dataset=tokenized,
                data_collator=data_collator,
            )
            trainer.train()

            # Merge to evaluate
            merged_model = peft_model.merge_and_unload()

            # Evaluate: temporarily point self.model at the merged candidate so
            # evaluate_quality() scores it.
            print(" Evaluating...")
            old_model = self.model
            self.model = merged_model
            quality = self.evaluate_quality()
            relative = quality["quality_score"] / self.baseline_quality["quality_score"]

            print(f" Quality: {quality['quality_score']:.4f} ({relative:.1%} of baseline)")

            if relative >= self.config.quality_threshold:
                successful += 1
                consecutive_failures = 0

                self.iteration_history.append({
                    "iteration": iteration,
                    "status": "success",
                    "quality": quality["quality_score"],
                    "relative": relative,
                })

                print(f" ✅ SUCCESS (total: {successful})")

                # Keep merged model; ratchet the baseline upward only on improvement.
                if quality["quality_score"] > self.baseline_quality["quality_score"]:
                    self.baseline_quality = quality

            else:
                consecutive_failures += 1

                self.iteration_history.append({
                    "iteration": iteration,
                    "status": "rollback",
                    "quality": quality["quality_score"],
                    "relative": relative,
                })

                print(f" ⚠️ ROLLBACK (discarding LoRA)")
                # Don't keep the merged model, go back to old
                self.model = old_model
                del merged_model

                if consecutive_failures >= 3:
                    print(f"\n🛑 CEILING HIT at iteration {iteration}")
                    break

            torch.cuda.empty_cache()
            gc.collect()

        return {
            "successful_iterations": successful,
            "total_attempts": len(self.iteration_history),
            "ceiling_broken": successful > 5,
            "history": self.iteration_history,
            "final_quality": self.evaluate_quality(),
        }
|
| 792 |
+
|
| 793 |
+
|
| 794 |
+
class BreakthroughExperiment:
    """Complete Loop 4 → RSI breakthrough experiment.

    Pipeline: load Loop 4 merge candidates from JSON -> load 4-bit quantized
    model -> resize embeddings with the new tokens -> targeted LoRA adaptation
    -> extended RSI loop -> report and save results.
    """

    def __init__(self, config: BreakthroughConfig):
        self.config = config
        # Accumulates experiment outputs ("tokens_added", "rsi").
        self.results = {}

    def run(self):
        """Execute the full experiment; returns self.results, or None if the
        Loop 4 results file is missing."""
        print("="*70)
        print("LOOP 4 → RSI CEILING BREAKTHROUGH TEST")
        print("="*70)
        print("\nIf RSI goes past 5 iterations, the ceiling is broken.\n")

        # Load Loop 4 results
        print("Step 1: Loading Loop 4 results...")
        loop4_path = Path(self.config.loop4_results_path)

        if not loop4_path.exists():
            print(f"ERROR: Not found: {loop4_path}")
            return None

        with open(loop4_path) as f:
            loop4_results = json.load(f)

        # Extract merges: prefer explicit improvements; fall back to parsing
        # stressed token pairs.
        merges = []
        if "all_improvements" in loop4_results:
            for imp in loop4_results["all_improvements"][:self.config.top_k_merges]:
                merges.append(imp["merged"])
        elif "top_stressed_pairs" in loop4_results:
            for pair in loop4_results["top_stressed_pairs"][:self.config.top_k_merges]:
                # Parse "'. ' + 'The'" format
                tokens = pair["tokens"].replace("'", "").split(" + ")
                if len(tokens) == 2:
                    merges.append(tokens[0] + tokens[1])

        print(f"Merge candidates: {merges}")

        # Load model with 4-bit quantization
        print("\nStep 2: Loading model (4-bit quantized)...")
        tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_path,
            local_files_only=True
        )
        # Reuse EOS as pad so generation/batching works without a pad token.
        tokenizer.pad_token = tokenizer.eos_token

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

        model = AutoModelForCausalLM.from_pretrained(
            self.config.model_path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            local_files_only=True,
        )
        print(f"Loaded: {model.config.hidden_size}d, {model.config.num_hidden_layers}L")
        print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f}GB")

        # Resize embeddings to accommodate the merged tokens.
        print("\nStep 3: Resizing embeddings...")
        resizer = EmbeddingResizer(model, tokenizer)
        n_added, added_tokens = resizer.add_tokens_and_resize(merges)
        self.results["tokens_added"] = n_added

        # Targeted adaptation
        print("\nStep 4: Targeted adaptation training...")
        adapter = TargetedAdaptationTrainer(model, tokenizer, added_tokens, self.config)
        model = adapter.train()

        # Extended RSI
        print("\nStep 5: Extended RSI...")
        rsi = ExtendedRSI(model, tokenizer, self.config)
        rsi_results = rsi.run()
        self.results["rsi"] = rsi_results

        # Report
        print("\n" + "="*70)
        print("RESULTS")
        print("="*70)

        print(f"\nTokens added: {n_added}")
        print(f"RSI successful: {rsi_results['successful_iterations']}")
        print(f"RSI total: {rsi_results['total_attempts']}")

        if rsi_results["ceiling_broken"]:
            print("\n" + "🎯"*25)
            print(" THE CEILING IS BROKEN")
            print("🎯"*25)
            print(f"\nRSI: {rsi_results['successful_iterations']} iterations (>5)")
            print("Loop 4 tokenization co-evolution WORKS")
            print("The ladder extends.")
        else:
            print(f"\n⚠️ Ceiling at {rsi_results['successful_iterations']} iterations")
            if rsi_results['successful_iterations'] >= 5:
                print(" Matched previous ceiling - more refinement needed")
            else:
                print(" Below previous ceiling - investigate")

        # Save a JSON summary next to the experiment outputs.
        output_dir = Path(self.config.output_dir)
        output_dir.mkdir(exist_ok=True)

        save_data = {
            "tokens_added": n_added,
            "merges": merges,
            "rsi_successful": rsi_results["successful_iterations"],
            "ceiling_broken": rsi_results["ceiling_broken"],
            "history": rsi_results["history"],
        }

        with open(output_dir / "breakthrough_results.json", "w") as f:
            json.dump(save_data, f, indent=2)

        print(f"\nSaved to {output_dir}/breakthrough_results.json")

        return self.results
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
def main():
    """Entry point: build the experiment configuration and run it.

    Returns whatever BreakthroughExperiment.run() returns (the results dict,
    or None if the Loop 4 results file is missing).
    """
    # Collect all run parameters in one mapping, then splat into the config.
    settings = dict(
        model_path=".",
        loop4_results_path="loop4_full_results/loop4_full_results.json",
        top_k_merges=15,
        adaptation_samples=300,
        adaptation_steps=150,
        adaptation_lr=2e-5,
        max_rsi_iterations=10,
        rsi_micro_steps=20,
        quality_threshold=0.92,
    )
    experiment = BreakthroughExperiment(BreakthroughConfig(**settings))
    return experiment.run()
|
| 932 |
+
|
| 933 |
+
|
| 934 |
+
if __name__ == "__main__":
    # Run the full Loop 4 → RSI breakthrough experiment when executed as a script.
    main()
|
code/training_pipelines/06_arc_engine_v30_FULL_ENGINE.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
code/training_pipelines/07_qwen3b_repetition_REPLICATION.py
ADDED
|
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CROSS-ARCHITECTURE REPLICATION: Qwen2.5-3B Repetition Detection
|
| 4 |
+
================================================================
|
| 5 |
+
Replicates Pipeline 01 (CF-HoT Risk Predictor) on Qwen2.5-3B.
|
| 6 |
+
|
| 7 |
+
Purpose: Validate that fiber-projected behavioral detection generalizes
|
| 8 |
+
across architecture families (LLaMA → Qwen).
|
| 9 |
+
|
| 10 |
+
Changes from Pipeline 01:
|
| 11 |
+
- Model: Qwen/Qwen2.5-3B (2048d, 36 layers) vs LLaMA-8B (4096d, 32 layers)
|
| 12 |
+
- Everything else IDENTICAL: d_fiber=16, d_control=64, LoRA r=64, same schedule
|
| 13 |
+
|
| 14 |
+
Author: Logan Napolitano / Fiber AI
|
| 15 |
+
Date: February 2026
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 22 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 23 |
+
from datasets import load_dataset
|
| 24 |
+
import os
|
| 25 |
+
import time
|
| 26 |
+
import random
|
| 27 |
+
import json
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from typing import Tuple
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
class Config:
    """Hyperparameters for the Qwen2.5-3B replication of Pipeline 01.

    Only the model id and output path differ from the LLaMA-8B run; every
    training hyperparameter is kept identical so the cross-architecture
    comparison is apples-to-apples.
    """

    # CHANGED: Qwen2.5-3B instead of LLaMA-8B
    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_repetition_replication"

    # IDENTICAL to Pipeline 01
    d_fiber: int = 16           # width of each per-layer fiber projection
    d_control: int = 64         # hidden width of the risk-predictor MLP
    max_steps: int = 3000       # total training micro-steps
    batch_size: int = 1
    grad_accum: int = 8         # effective batch = batch_size * grad_accum
    max_length: int = 256       # tokenizer truncation / padding length
    lr_lora: float = 2e-5       # LR for LoRA adapter params
    lr_predictor: float = 1e-4  # LR for the risk-predictor head
    weight_decay: float = 0.01
    rep_window: int = 32        # look-back window for repetition labels
    log_every: int = 10         # steps between console/log entries
    save_every: int = 500       # steps between checkpoints
    eval_every: int = 200       # steps between separation evals
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class RiskPredictor(nn.Module):
    """Per-layer fiber projections + shared MLP head (identical architecture
    to Pipeline 01 — dimensions auto-detected from the base model).

    Each transformer layer's hidden state is projected to a small d_fiber
    space, the projections are blended with learned softmax weights, and a
    shared MLP emits one repetition-risk logit per token.
    """

    def __init__(self, d_model: int, n_layers: int, config: "Config"):
        super().__init__()
        self.config = config
        self.n_layers = n_layers

        # Fiber projections: d_model → d_fiber (2048→16 for Qwen vs 4096→16 for LLaMA)
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, config.d_fiber, bias=False)
            for _ in range(n_layers)
        ])

        # Learnable blend over layers; softmax-normalized in forward().
        self.layer_weights = nn.Parameter(torch.ones(n_layers) / n_layers)

        # Predictor head: identical structure to Pipeline 01.
        self.predictor = nn.Sequential(
            nn.Linear(config.d_fiber, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, 1)
        )

        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-token risk logits of shape (batch, seq).

        FIX: the original loop carried `if i < len(hidden_states)` inside a
        zip() over hidden_states — always true, dead code — removed.
        zip() already stops at the shorter of the two sequences.
        """
        fibers = [proj(h.float()) for proj, h in zip(self.fiber_projs, hidden_states)]

        # Slice in case fewer hidden states than projections were supplied.
        weights = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        aggregated = sum(w * f for w, f in zip(weights, fibers))

        return self.predictor(aggregated).squeeze(-1)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Binary per-token repetition labels (identical semantics to Pipeline 01).

    A token gets label 1.0 iff it equals any of the previous `window` tokens.
    Vectorized over offsets, so the Python loop runs `window` times instead
    of once per token.

    Args:
        input_ids: (B, S) integer token ids.
        window: how many previous positions to compare against.

    Returns:
        (B, S) float tensor of 0.0/1.0 labels on input_ids' device.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)

    # Offsets run 1..min(window, S-1). The range() upper bound already
    # guarantees offset < S, so the original inner `if offset < S` was
    # redundant and has been removed.
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)

    return labels
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50):
    """Compute the P(+)/P(-) separation ratio — the key replication metric.

    Generates continuations from fixed prompts, scores every token with the
    risk predictor, buckets scores by the ground-truth repetition label, and
    returns (mean_pos, mean_neg, ratio, n_pos, n_neg). Returns all zeros if
    either bucket ends up empty.

    Note: puts model/risk_predictor in eval mode; callers must restore
    train mode themselves (main() does).
    """
    model.eval()
    risk_predictor.eval()

    all_pos_scores = []
    all_neg_scores = []

    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]

    with torch.no_grad():
        for i in range(n_samples):
            prompt = prompts[i % len(prompts)]
            inp = tokenizer(prompt, return_tensors='pt')
            input_ids = inp['input_ids'].to(device)
            # FIX: pass the attention mask explicitly. With pad_token == eos,
            # generate() otherwise has to infer the mask (HF warns, and the
            # inference can be wrong) — this mirrors the 07b "FIXED" script.
            attn_mask = inp['attention_mask'].to(device)

            out = model.generate(
                input_ids, attention_mask=attn_mask, max_new_tokens=80,
                do_sample=True, temperature=0.9, top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

            # Re-run the full sequence to obtain hidden states for scoring
            # (skip index 0, the embedding output, to match training).
            gen_outputs = model(out, output_hidden_states=True)
            gen_logits = risk_predictor(gen_outputs.hidden_states[1:])
            gen_risk = torch.sigmoid(gen_logits)
            risk_vals = gen_risk[0].cpu().numpy()

            # Get actual repetition labels
            rep_labels = compute_repetition_labels_fast(out, config.rep_window)
            labels = rep_labels[0].cpu().numpy()

            for t in range(len(risk_vals)):
                if labels[t] > 0.5:
                    all_pos_scores.append(float(risk_vals[t]))
                else:
                    all_neg_scores.append(float(risk_vals[t]))

    if all_pos_scores and all_neg_scores:
        p_pos = sum(all_pos_scores) / len(all_pos_scores)
        p_neg = sum(all_neg_scores) / len(all_neg_scores)
        separation = p_pos / max(p_neg, 1e-8)
        return p_pos, p_neg, separation, len(all_pos_scores), len(all_neg_scores)

    return 0.0, 0.0, 0.0, 0, 0
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def main():
    """Run the full cross-architecture replication: load Qwen2.5-3B in 4-bit,
    attach LoRA + the risk predictor, train jointly on wikitext-2, and
    periodically measure the P(+)/P(-) separation against the 125x
    LLaMA-8B baseline. All artifacts land in config.output_dir.
    """
    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)

    print("=" * 70)
    print("CROSS-ARCHITECTURE REPLICATION: Qwen2.5-3B")
    print("Replicating Pipeline 01 (Repetition Detection)")
    print("=" * 70)
    print(f"Model: {config.model_path}")
    print(f"d_fiber: {config.d_fiber} (identical to LLaMA-8B experiment)")
    print(f"d_control: {config.d_control}")
    print(f"max_steps: {config.max_steps}")
    print()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    # Qwen tokenizers may ship without a pad token; reuse EOS so that
    # padding='max_length' below works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading Qwen2.5-3B in 4-bit...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"
    )
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb,
        device_map='auto', torch_dtype=torch.float16
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    # Auto-detect architecture dimensions — the only thing that changes
    # relative to the LLaMA-8B run.
    device = next(model.parameters()).device
    n_layers = model.config.num_hidden_layers
    d_model = model.config.hidden_size

    print(f"Architecture: {model.config.architectures[0]}")
    print(f"Hidden dim: {d_model} (LLaMA-8B was 4096)")
    print(f"Layers: {n_layers} (LLaMA-8B was 32)")
    print(f"Fiber projection: {d_model} → {config.d_fiber}")
    print()

    # LoRA — identical config to Pipeline 01
    print("Adding LoRA (identical config to LLaMA-8B experiment)...")
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
    ))
    model.print_trainable_parameters()

    # Risk predictor — auto-adapts to Qwen dimensions
    print("Adding Risk Predictor...")
    risk_predictor = RiskPredictor(d_model, n_layers, config).to(device).float()
    rp_params = sum(p.numel() for p in risk_predictor.parameters())
    print(f"Risk Predictor params: {rp_params:,}")
    print(f"  (LLaMA-8B had ~{4096*16*32 + 16*64 + 64*64 + 64*1:,} — fiber proj was larger)")
    print()

    # Data — identical to Pipeline 01
    print("Loading wikitext data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    # Two parameter groups: slow LR for LoRA, faster LR for the fresh head.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.max_steps, eta_min=1e-6
    )

    # Training log for comparison
    training_log = {
        "experiment": "cross_architecture_replication",
        "source_model": "LLaMA-3.1-8B (4096d, 32L)",
        "target_model": "Qwen2.5-3B (2048d, 36L)",
        "d_fiber": config.d_fiber,
        "d_control": config.d_control,
        "hypothesis": "Fiber projection generalizes across architecture families",
        "baseline_separation": "125x (LLaMA-8B repetition)",
        "steps": [],
        "separations": []
    }

    print("=" * 70)
    print("TRAINING")
    print("=" * 70)

    model.train()
    risk_predictor.train()

    step = 0
    data_idx = 0
    # Running accumulators, reset every log_every steps.
    acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
    acc_precision, acc_recall, acc_f1 = 0, 0, 0
    start_time = time.time()

    while step < config.max_steps:
        # Round-robin over the shuffled corpus.
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
            output_hidden_states=True
        )

        lm_loss = outputs.loss
        # hidden_states[0] is the embedding output — skip it so the predictor
        # sees one state per transformer layer.
        risk_logits = risk_predictor(outputs.hidden_states[1:])
        rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)

        # Class-weighted loss — identical to Pipeline 01
        mask = attention_mask.float()
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)

        bce_loss = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight,
            reduction='none'
        )
        # Average only over real (unpadded) tokens.
        risk_loss = (bce_loss * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        (loss / config.grad_accum).backward()

        # Metrics
        with torch.no_grad():
            risk_pred = torch.sigmoid(risk_logits)
            pred_binary = (risk_pred > 0.5).float()
            tp = ((pred_binary == 1) & (rep_labels == 1) & (mask == 1)).sum()
            fp = ((pred_binary == 1) & (rep_labels == 0) & (mask == 1)).sum()
            fn = ((pred_binary == 0) & (rep_labels == 1) & (mask == 1)).sum()

            precision = tp / (tp + fp + 1e-8)
            recall = tp / (tp + fn + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)

        acc_loss += loss.item()
        acc_lm += lm_loss.item()
        acc_risk_loss += risk_loss.item()

        acc_precision += precision.item()
        acc_recall += recall.item()
        acc_f1 += f1.item()

        step += 1

        # Optimizer update every grad_accum micro-steps.
        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(
                list(lora_params) + list(risk_predictor.parameters()), 1.0
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.log_every == 0:
            eta = (config.max_steps - step) / (step / (time.time() - start_time)) / 3600
            n = config.log_every

            log_line = (
                f"Step {step:5d} | "
                f"Loss: {acc_loss/n:.4f} | "
                f"LM: {acc_lm/n:.4f} | "
                f"Risk: {acc_risk_loss/n:.4f} | "
                f"P: {acc_precision/n:.3f} | "
                f"R: {acc_recall/n:.3f} | "
                f"F1: {acc_f1/n:.3f} | "
                f"ETA: {eta:.1f}h"
            )
            print(log_line)

            training_log["steps"].append({
                "step": step,
                "loss": acc_loss/n,
                "lm_loss": acc_lm/n,
                "risk_loss": acc_risk_loss/n,
                "precision": acc_precision/n,
                "recall": acc_recall/n,
                "f1": acc_f1/n
            })

            acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
            acc_precision, acc_recall, acc_f1 = 0, 0, 0

        if step % config.save_every == 0:
            ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
            os.makedirs(ckpt, exist_ok=True)
            model.save_pretrained(ckpt)
            torch.save({
                'risk_predictor': risk_predictor.state_dict(),
                'step': step
            }, os.path.join(ckpt, "risk_predictor.pt"))
            print(f">>> Saved: {ckpt}")

        # Separation eval every eval_every steps
        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"SEPARATION EVAL @ Step {step}")
            print(f"{'='*50}")

            p_pos, p_neg, separation, n_pos_samples, n_neg_samples = \
                compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=30)

            print(f"  P(+) = {p_pos:.4f} (n={n_pos_samples})")
            print(f"  P(-) = {p_neg:.4f} (n={n_neg_samples})")
            print(f"  SEPARATION = {separation:.1f}x")
            print(f"  [LLaMA-8B baseline: 125x]")

            training_log["separations"].append({
                "step": step,
                "p_pos": p_pos,
                "p_neg": p_neg,
                "separation": separation,
                "n_pos": n_pos_samples,
                "n_neg": n_neg_samples
            })

            # Save log after each eval
            with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
                json.dump(training_log, f, indent=2)

            print(f"{'='*50}\n")

            # compute_separation() switched both to eval mode — restore.
            model.train()
            risk_predictor.train()

    # =========================================================================
    # FINAL EVALUATION
    # =========================================================================
    print("\n" + "=" * 70)
    print("FINAL CROSS-ARCHITECTURE COMPARISON")
    print("=" * 70)

    p_pos, p_neg, separation, n_pos, n_neg = \
        compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50)

    print(f"""
┌─────────────────────────────────────────────────────────┐
│  CROSS-ARCHITECTURE REPLICATION RESULTS                 │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  LLaMA-3.1-8B (original):                               │
│    Hidden dim:   4096                                   │
│    Layers:       32                                     │
│    Separation:   125x                                   │
│    P(+):         0.998                                  │
│    P(-):         0.008                                  │
│                                                         │
│  Qwen2.5-3B (replication):                              │
│    Hidden dim:   {d_model}                              │
│    Layers:       {n_layers}                             │
│    Separation:   {separation:.1f}x                      │
│    P(+):         {p_pos:.4f}                            │
│    P(-):         {p_neg:.4f}                            │
│                                                         │
│  Method:         IDENTICAL                              │
│    d_fiber:      16                                     │
│    d_control:    64                                     │
│    LoRA r:       64                                     │
│    Training:     3000 steps, wikitext-2                 │
│                                                         │
│  Conclusion:                                            │
│    {"✅ METHOD GENERALIZES" if separation > 10 else "⚠️ NEEDS INVESTIGATION"}          │
│    {"across architecture families" if separation > 10 else "separation below threshold"}    │
└─────────────────────────────────────────────────────────┘
""")

    training_log["final"] = {
        "p_pos": p_pos,
        "p_neg": p_neg,
        "separation": separation,
        "n_pos": n_pos,
        "n_neg": n_neg,
        "conclusion": "generalizes" if separation > 10 else "needs_investigation"
    }

    with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
        json.dump(training_log, f, indent=2)

    # Save final checkpoint
    final = os.path.join(config.output_dir, "final")
    os.makedirs(final, exist_ok=True)
    model.save_pretrained(final)
    torch.save({
        'risk_predictor': risk_predictor.state_dict(),
        'step': step,
        'separation': separation,
        'p_pos': p_pos,
        'p_neg': p_neg
    }, os.path.join(final, "risk_predictor.pt"))

    print(f"Saved to {final}")
    print(f"Log saved to {config.output_dir}/replication_log.json")
    print(f"\nDONE!")


if __name__ == "__main__":
    main()
|
code/training_pipelines/07b_qwen3b_repetition_FIXED.py
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CROSS-ARCHITECTURE REPLICATION v2: Qwen2.5-3B Repetition Detection
|
| 4 |
+
====================================================================
|
| 5 |
+
FIX: Use 3 specific probe layers [9, 18, 27] instead of all 36.
|
| 6 |
+
Matches Pipeline 02 methodology which achieved 125x-168x on LLaMA-8B.
|
| 7 |
+
|
| 8 |
+
Changes from v1:
|
| 9 |
+
- probe_layers = [9, 18, 27] (25%, 50%, 75% of 36 layers)
|
| 10 |
+
- 3 fiber projections instead of 36
|
| 11 |
+
- Gradient signal concentrated, not diluted
|
| 12 |
+
|
| 13 |
+
Author: Logan Napolitano / Proprioception AI
|
| 14 |
+
Date: February 2026
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 21 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 22 |
+
from datasets import load_dataset
|
| 23 |
+
import os
|
| 24 |
+
import time
|
| 25 |
+
import random
|
| 26 |
+
import json
|
| 27 |
+
from dataclasses import dataclass, field
|
| 28 |
+
from typing import Tuple, List
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
class Config:
    """Hyperparameters for the v2 (fixed probe layers) Qwen2.5-3B run.

    v2 probes only three specific layers instead of all 36, matching the
    Pipeline 02 methodology that reached 125x-168x on LLaMA-8B.
    """

    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_repetition_v2_fixed"

    # Probe layers: 25%, 50%, 75% of 36 layers (matches Pipeline 02 methodology)
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])

    # Identical to Pipeline 01/02
    d_fiber: int = 16           # width of each fiber projection
    d_control: int = 64         # hidden width of the risk-predictor MLP
    max_steps: int = 10000      # total training micro-steps (longer than v1's 3000)
    batch_size: int = 1
    grad_accum: int = 8         # effective batch = batch_size * grad_accum
    max_length: int = 256       # tokenizer truncation / padding length
    lr_lora: float = 2e-5       # LR for LoRA adapter params
    lr_predictor: float = 1e-4  # LR for the risk-predictor head
    weight_decay: float = 0.01
    rep_window: int = 32        # look-back window for repetition labels
    log_every: int = 10         # steps between console/log entries
    save_every: int = 500       # steps between checkpoints
    eval_every: int = 200       # steps between separation evals
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class RiskPredictor(nn.Module):
    """Risk head probing a fixed subset of layers (the v2 fix).

    Instead of one projection per transformer layer, only the layers listed
    in `probe_layers` are projected to d_fiber, softly blended, and scored
    by a small shared MLP — one repetition-risk logit per token.
    """

    def __init__(self, d_model: int, probe_layers: List[int], config: Config):
        super().__init__()
        self.config = config
        self.probe_layers = probe_layers
        probe_count = len(probe_layers)

        # One narrow projection per probed layer (e.g. 2048 → 16 for Qwen).
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, config.d_fiber, bias=False)
            for _ in range(probe_count)
        )

        # Learnable blend coefficients, softmax-normalized at forward time.
        self.layer_weights = nn.Parameter(torch.ones(probe_count) / probe_count)

        # Shared per-token scorer: d_fiber → d_control → d_control → 1 logit.
        self.predictor = nn.Sequential(
            nn.Linear(config.d_fiber, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, config.d_control),
            nn.GELU(),
            nn.Linear(config.d_control, 1)
        )

        for linear in self.fiber_projs:
            nn.init.normal_(linear.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Project the probed hidden states, blend them, and score each token."""
        # Keep a fiber only for probe indices that exist in hidden_states.
        projected = [
            self.fiber_projs[slot](hidden_states[idx].float())
            for slot, idx in enumerate(self.probe_layers)
            if idx < len(hidden_states)
        ]

        blend = F.softmax(self.layer_weights[:len(projected)], dim=0)
        mixed = sum(coeff * fib for coeff, fib in zip(blend, projected))

        return self.predictor(mixed).squeeze(-1)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def compute_repetition_labels_fast(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Binary per-token repetition labels: 1.0 where a token repeats any of
    the previous `window` tokens, else 0.0.

    Args:
        input_ids: (B, S) integer token ids.
        window: how many previous positions to compare against.

    Returns:
        (B, S) float tensor of 0.0/1.0 labels on input_ids' device.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    # Offsets run 1..min(window, S-1); the range() upper bound already
    # guarantees offset < S, so the original inner `if offset < S` was
    # redundant and has been removed.
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50):
    """Measure the P(+)/P(-) separation ratio on freshly generated text.

    Samples continuations from a fixed prompt pool, scores every token with
    the risk predictor (full hidden_states, embedding layer included — the
    v2 convention), buckets scores by ground-truth repetition label, and
    returns (mean_pos, mean_neg, ratio, n_pos, n_neg); zeros if a bucket
    is empty. Leaves model/risk_predictor in eval mode.
    """
    model.eval()
    risk_predictor.eval()
    pos_scores = []
    neg_scores = []

    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]

    with torch.no_grad():
        for sample_idx in range(n_samples):
            encoded = tokenizer(prompts[sample_idx % len(prompts)], return_tensors='pt')
            ids = encoded['input_ids'].to(device)
            pad_mask = encoded['attention_mask'].to(device)

            seq = model.generate(
                ids, attention_mask=pad_mask, max_new_tokens=80,
                do_sample=True, temperature=0.9, top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

            # Re-run the sequence for hidden states, score, and bucket each
            # token by its ground-truth repetition label.
            states = model(seq, output_hidden_states=True).hidden_states
            scores = torch.sigmoid(risk_predictor(states))[0].cpu().numpy()
            truth = compute_repetition_labels_fast(seq, config.rep_window)[0].cpu().numpy()

            for score, label in zip(scores, truth):
                (pos_scores if label > 0.5 else neg_scores).append(float(score))

    if not pos_scores or not neg_scores:
        return 0.0, 0.0, 0.0, 0, 0

    p_pos = sum(pos_scores) / len(pos_scores)
    p_neg = sum(neg_scores) / len(neg_scores)
    return p_pos, p_neg, p_pos / max(p_neg, 1e-8), len(pos_scores), len(neg_scores)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def main():
|
| 165 |
+
config = Config()
|
| 166 |
+
os.makedirs(config.output_dir, exist_ok=True)
|
| 167 |
+
|
| 168 |
+
print("=" * 70)
|
| 169 |
+
print("CROSS-ARCHITECTURE REPLICATION v2 (FIXED PROBE LAYERS)")
|
| 170 |
+
print("=" * 70)
|
| 171 |
+
print(f"Model: {config.model_path}")
|
| 172 |
+
print(f"Probe layers: {config.probe_layers} (25%, 50%, 75%)")
|
| 173 |
+
print(f"d_fiber: {config.d_fiber}, d_control: {config.d_control}")
|
| 174 |
+
print(f"FIX: 3 focused projections instead of 36 diluted ones")
|
| 175 |
+
print()
|
| 176 |
+
|
| 177 |
+
tokenizer = AutoTokenizer.from_pretrained(config.model_path)
|
| 178 |
+
if tokenizer.pad_token is None:
|
| 179 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 180 |
+
|
| 181 |
+
print("Loading Qwen2.5-3B in 4-bit...")
|
| 182 |
+
bnb = BitsAndBytesConfig(
|
| 183 |
+
load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
|
| 184 |
+
bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"
|
| 185 |
+
)
|
| 186 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 187 |
+
config.model_path, quantization_config=bnb,
|
| 188 |
+
device_map='auto', torch_dtype=torch.float16
|
| 189 |
+
)
|
| 190 |
+
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
|
| 191 |
+
|
| 192 |
+
device = next(model.parameters()).device
|
| 193 |
+
d_model = model.config.hidden_size
|
| 194 |
+
n_layers = model.config.num_hidden_layers
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
print(f"Architecture: Qwen2ForCausalLM")
|
| 198 |
+
print(f"Hidden dim: {d_model}, Layers: {n_layers}")
|
| 199 |
+
print(f"Probing layers: {config.probe_layers}")
|
| 200 |
+
print()
|
| 201 |
+
|
| 202 |
+
print("Adding LoRA...")
|
| 203 |
+
model = get_peft_model(model, LoraConfig(
|
| 204 |
+
r=64, lora_alpha=128,
|
| 205 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
| 206 |
+
lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
|
| 207 |
+
))
|
| 208 |
+
model.print_trainable_parameters()
|
| 209 |
+
|
| 210 |
+
print("Adding Risk Predictor (3 probe layers)...")
|
| 211 |
+
risk_predictor = RiskPredictor(d_model, config.probe_layers, config).to(device).float()
|
| 212 |
+
rp_params = sum(p.numel() for p in risk_predictor.parameters())
|
| 213 |
+
print(f"Risk Predictor params: {rp_params:,}")
|
| 214 |
+
print()
|
| 215 |
+
|
| 216 |
+
print("Loading wikitext data...")
|
| 217 |
+
ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
|
| 218 |
+
texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
|
| 219 |
+
random.shuffle(texts)
|
| 220 |
+
print(f"Loaded {len(texts)} samples")
|
| 221 |
+
|
| 222 |
+
lora_params = [p for p in model.parameters() if p.requires_grad]
|
| 223 |
+
optimizer = torch.optim.AdamW([
|
| 224 |
+
{'params': lora_params, 'lr': config.lr_lora},
|
| 225 |
+
{'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
|
| 226 |
+
], weight_decay=config.weight_decay)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 230 |
+
optimizer, T_max=config.max_steps, eta_min=1e-6
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
training_log = {
|
| 234 |
+
"experiment": "cross_architecture_replication_v2_fixed",
|
| 235 |
+
"fix": "3 probe layers [9,18,27] instead of all 36",
|
| 236 |
+
"source_model": "LLaMA-3.1-8B (4096d, 32L, probe [8,16,24])",
|
| 237 |
+
"target_model": f"Qwen2.5-3B ({d_model}d, {n_layers}L, probe {config.probe_layers})",
|
| 238 |
+
"d_fiber": config.d_fiber,
|
| 239 |
+
"baseline_separation": "125x (LLaMA-8B repetition)",
|
| 240 |
+
"steps": [],
|
| 241 |
+
"separations": []
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
print("=" * 70)
|
| 245 |
+
print("TRAINING")
|
| 246 |
+
print("=" * 70)
|
| 247 |
+
|
| 248 |
+
model.train()
|
| 249 |
+
risk_predictor.train()
|
| 250 |
+
|
| 251 |
+
step = 0
|
| 252 |
+
data_idx = 0
|
| 253 |
+
acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
|
| 254 |
+
acc_precision, acc_recall, acc_f1 = 0, 0, 0
|
| 255 |
+
start_time = time.time()
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
while step < config.max_steps:
|
| 259 |
+
batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
|
| 260 |
+
data_idx += config.batch_size
|
| 261 |
+
|
| 262 |
+
enc = tokenizer(batch, truncation=True, max_length=config.max_length,
|
| 263 |
+
padding='max_length', return_tensors='pt')
|
| 264 |
+
input_ids = enc['input_ids'].to(device)
|
| 265 |
+
attention_mask = enc['attention_mask'].to(device)
|
| 266 |
+
|
| 267 |
+
outputs = model(
|
| 268 |
+
input_ids=input_ids,
|
| 269 |
+
attention_mask=attention_mask,
|
| 270 |
+
labels=input_ids,
|
| 271 |
+
output_hidden_states=True
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
lm_loss = outputs.loss
|
| 275 |
+
# Pass full hidden_states — RiskPredictor indexes into specific layers
|
| 276 |
+
risk_logits = risk_predictor(outputs.hidden_states)
|
| 277 |
+
rep_labels = compute_repetition_labels_fast(input_ids, config.rep_window)
|
| 278 |
+
|
| 279 |
+
mask = attention_mask.float()
|
| 280 |
+
n_pos = (rep_labels * mask).sum().clamp(min=1)
|
| 281 |
+
n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
|
| 282 |
+
pos_weight = (n_neg / n_pos).clamp(max=10.0)
|
| 283 |
+
|
| 284 |
+
bce_loss = F.binary_cross_entropy_with_logits(
|
| 285 |
+
risk_logits, rep_labels,
|
| 286 |
+
pos_weight=torch.ones_like(rep_labels) * pos_weight,
|
| 287 |
+
reduction='none'
|
| 288 |
+
)
|
| 289 |
+
risk_loss = (bce_loss * mask).sum() / mask.sum()
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
loss = lm_loss + risk_loss
|
| 293 |
+
(loss / config.grad_accum).backward()
|
| 294 |
+
|
| 295 |
+
with torch.no_grad():
|
| 296 |
+
risk_pred = torch.sigmoid(risk_logits)
|
| 297 |
+
pred_binary = (risk_pred > 0.5).float()
|
| 298 |
+
tp = ((pred_binary == 1) & (rep_labels == 1) & (mask == 1)).sum()
|
| 299 |
+
fp = ((pred_binary == 1) & (rep_labels == 0) & (mask == 1)).sum()
|
| 300 |
+
fn = ((pred_binary == 0) & (rep_labels == 1) & (mask == 1)).sum()
|
| 301 |
+
precision = tp / (tp + fp + 1e-8)
|
| 302 |
+
recall = tp / (tp + fn + 1e-8)
|
| 303 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
| 304 |
+
|
| 305 |
+
acc_loss += loss.item()
|
| 306 |
+
acc_lm += lm_loss.item()
|
| 307 |
+
acc_risk_loss += risk_loss.item()
|
| 308 |
+
acc_precision += precision.item()
|
| 309 |
+
acc_recall += recall.item()
|
| 310 |
+
acc_f1 += f1.item()
|
| 311 |
+
|
| 312 |
+
step += 1
|
| 313 |
+
|
| 314 |
+
if step % config.grad_accum == 0:
|
| 315 |
+
torch.nn.utils.clip_grad_norm_(
|
| 316 |
+
list(lora_params) + list(risk_predictor.parameters()), 1.0
|
| 317 |
+
)
|
| 318 |
+
optimizer.step()
|
| 319 |
+
scheduler.step()
|
| 320 |
+
optimizer.zero_grad()
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
if step % config.log_every == 0:
|
| 324 |
+
eta = (config.max_steps - step) / (step / (time.time() - start_time)) / 3600
|
| 325 |
+
n = config.log_every
|
| 326 |
+
print(
|
| 327 |
+
f"Step {step:5d} | "
|
| 328 |
+
f"Loss: {acc_loss/n:.4f} | "
|
| 329 |
+
f"LM: {acc_lm/n:.4f} | "
|
| 330 |
+
f"Risk: {acc_risk_loss/n:.4f} | "
|
| 331 |
+
f"P: {acc_precision/n:.3f} | "
|
| 332 |
+
f"R: {acc_recall/n:.3f} | "
|
| 333 |
+
f"F1: {acc_f1/n:.3f} | "
|
| 334 |
+
f"ETA: {eta:.1f}h"
|
| 335 |
+
)
|
| 336 |
+
training_log["steps"].append({
|
| 337 |
+
"step": step, "loss": acc_loss/n, "lm_loss": acc_lm/n,
|
| 338 |
+
"risk_loss": acc_risk_loss/n, "precision": acc_precision/n,
|
| 339 |
+
"recall": acc_recall/n, "f1": acc_f1/n
|
| 340 |
+
})
|
| 341 |
+
acc_loss, acc_lm, acc_risk_loss = 0, 0, 0
|
| 342 |
+
acc_precision, acc_recall, acc_f1 = 0, 0, 0
|
| 343 |
+
|
| 344 |
+
if step % config.save_every == 0:
|
| 345 |
+
ckpt = os.path.join(config.output_dir, f"ckpt_{step}")
|
| 346 |
+
os.makedirs(ckpt, exist_ok=True)
|
| 347 |
+
model.save_pretrained(ckpt)
|
| 348 |
+
torch.save({
|
| 349 |
+
'risk_predictor': risk_predictor.state_dict(),
|
| 350 |
+
'step': step
|
| 351 |
+
}, os.path.join(ckpt, "risk_predictor.pt"))
|
| 352 |
+
print(f">>> Saved: {ckpt}")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
if step % config.eval_every == 0:
|
| 356 |
+
print(f"\n{'='*50}")
|
| 357 |
+
print(f"SEPARATION EVAL @ Step {step}")
|
| 358 |
+
print(f"{'='*50}")
|
| 359 |
+
|
| 360 |
+
p_pos, p_neg, separation, n_pos_s, n_neg_s = \
|
| 361 |
+
compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=30)
|
| 362 |
+
|
| 363 |
+
print(f" P(+) = {p_pos:.4f} (n={n_pos_s})")
|
| 364 |
+
print(f" P(-) = {p_neg:.4f} (n={n_neg_s})")
|
| 365 |
+
print(f" SEPARATION = {separation:.1f}x")
|
| 366 |
+
print(f" [LLaMA-8B baseline: 125x]")
|
| 367 |
+
|
| 368 |
+
training_log["separations"].append({
|
| 369 |
+
"step": step, "p_pos": p_pos, "p_neg": p_neg,
|
| 370 |
+
"separation": separation, "n_pos": n_pos_s, "n_neg": n_neg_s
|
| 371 |
+
})
|
| 372 |
+
|
| 373 |
+
with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
|
| 374 |
+
json.dump(training_log, f, indent=2)
|
| 375 |
+
|
| 376 |
+
print(f"{'='*50}\n")
|
| 377 |
+
model.train()
|
| 378 |
+
risk_predictor.train()
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
# FINAL
|
| 382 |
+
print("\n" + "=" * 70)
|
| 383 |
+
print("FINAL CROSS-ARCHITECTURE COMPARISON")
|
| 384 |
+
print("=" * 70)
|
| 385 |
+
|
| 386 |
+
p_pos, p_neg, separation, n_pos, n_neg = \
|
| 387 |
+
compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=50)
|
| 388 |
+
|
| 389 |
+
d = d_model
|
| 390 |
+
nl = n_layers
|
| 391 |
+
print(f"""
|
| 392 |
+
┌─────────────────────────────────────────────────────────┐
|
| 393 |
+
│ CROSS-ARCHITECTURE REPLICATION v2 (FIXED) │
|
| 394 |
+
├─────────────────────────────────────────────────────────┤
|
| 395 |
+
│ LLaMA-3.1-8B: 125x (P+=0.998, P-=0.008) │
|
| 396 |
+
│ Qwen2.5-3B: {separation:>5.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f}) │
|
| 397 |
+
├─────────────────────────────────────────────────────────┤
|
| 398 |
+
│ Architecture: Qwen2 ({d}d, {nl}L) vs LLaMA (4096d, 32L) │
|
| 399 |
+
│ Probe layers: {config.probe_layers} │
|
| 400 |
+
│ d_fiber: 16 (identical) │
|
| 401 |
+
│ Method: IDENTICAL │
|
| 402 |
+
│ Conclusion: {"✅ GENERALIZES" if separation > 10 else "⚠️ INVESTIGATE"} │
|
| 403 |
+
└─────────────────────────────────────────────────────────┘
|
| 404 |
+
""")
|
| 405 |
+
|
| 406 |
+
training_log["final"] = {
|
| 407 |
+
"p_pos": p_pos, "p_neg": p_neg, "separation": separation,
|
| 408 |
+
"n_pos": n_pos, "n_neg": n_neg,
|
| 409 |
+
"conclusion": "generalizes" if separation > 10 else "needs_investigation"
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
with open(os.path.join(config.output_dir, "replication_log.json"), 'w') as f:
|
| 413 |
+
json.dump(training_log, f, indent=2)
|
| 414 |
+
|
| 415 |
+
final = os.path.join(config.output_dir, "final")
|
| 416 |
+
os.makedirs(final, exist_ok=True)
|
| 417 |
+
model.save_pretrained(final)
|
| 418 |
+
torch.save({
|
| 419 |
+
'risk_predictor': risk_predictor.state_dict(),
|
| 420 |
+
'step': step, 'separation': separation,
|
| 421 |
+
'p_pos': p_pos, 'p_neg': p_neg
|
| 422 |
+
}, os.path.join(final, "risk_predictor.pt"))
|
| 423 |
+
|
| 424 |
+
print(f"Done! Log: {config.output_dir}/replication_log.json")
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
if __name__ == "__main__":
|
| 428 |
+
main()
|
code/training_pipelines/07c_qwen3b_CONTINUE.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CONTINUE QWEN TRAINING: 3000 → 6000 steps
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 10 |
+
from peft import PeftModel
|
| 11 |
+
from datasets import load_dataset
|
| 12 |
+
import os, time, random, json
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import List
|
| 15 |
+
|
| 16 |
+
# Checkpoint produced by the first 3000-step run (07b); training resumes from here.
CKPT = "./results/qwen3b_repetition_v2_fixed/ckpt_3000"
# Output directory for the continued (3000 -> 6000 step) run.
OUT = "./results/qwen3b_repetition_v2_continued"
|
| 18 |
+
|
| 19 |
+
@dataclass
class Config:
    """Hyperparameters for continuing Qwen2.5-3B repetition-probe training
    from step 3000 to step 6000."""
    model_path: str = "Qwen/Qwen2.5-3B"
    # Hidden-state layers fed to the risk probe (Qwen2.5-3B has 36 layers).
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    # Dimension of the per-layer fiber projection.
    d_fiber: int = 16
    # Hidden width of the predictor MLP head.
    d_control: int = 64
    # Resume point and target step count (3000 additional steps).
    start_step: int = 3000
    max_steps: int = 6000
    batch_size: int = 1
    # Effective batch = batch_size * grad_accum = 8.
    grad_accum: int = 8
    max_length: int = 256
    # Learning rates: LoRA adapter vs. risk-predictor head.
    lr_lora: float = 1e-5
    lr_predictor: float = 5e-5
    weight_decay: float = 0.01
    # Lookback window (tokens) for repetition labeling.
    rep_window: int = 32
    log_every: int = 10
    save_every: int = 500
    eval_every: int = 200
|
| 37 |
+
|
code/training_pipelines/08_qwen3b_dimension_sweep_FULL.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
FIBER DIMENSION SWEEP + EXTENDED TRAINING: Qwen2.5-3B
|
| 4 |
+
======================================================
|
| 5 |
+
1. Quick sweep: d_fiber = [8, 16, 32] @ 800 steps each
|
| 6 |
+
2. Full training: best dimension @ 5000 steps
|
| 7 |
+
3. Target: 70x+ separation
|
| 8 |
+
|
| 9 |
+
Loads model ONCE, runs all sweeps, then extends training on winner.
|
| 10 |
+
|
| 11 |
+
Author: Logan Napolitano / Proprioception AI
|
| 12 |
+
Date: February 2026
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 19 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 20 |
+
from datasets import load_dataset
|
| 21 |
+
import os
|
| 22 |
+
import time
|
| 23 |
+
import random
|
| 24 |
+
import json
|
| 25 |
+
import gc
|
| 26 |
+
from dataclasses import dataclass, field
|
| 27 |
+
from typing import Tuple, List, Dict
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Sweep configuration
# Fiber dimensions tried in Phase 1; each gets a short training run.
SWEEP_DIMS = [8, 16, 32]
# Steps per sweep candidate (quick comparison, not full convergence).
SWEEP_STEPS = 800
# Total steps for the winning dimension (Phase 2 runs the remainder).
FULL_TRAINING_STEPS = 5000
# Separation ratio (mean risk on repeated vs. non-repeated tokens) we aim for.
TARGET_SEPARATION = 70.0
|
| 35 |
+
|
| 36 |
+
@dataclass
class Config:
    """Hyperparameters for the fiber-dimension sweep on Qwen2.5-3B."""
    model_path: str = "Qwen/Qwen2.5-3B"
    output_dir: str = "./results/qwen3b_dimension_sweep"
    # Hidden-state layers probed by RiskPredictor.
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16  # Will be varied during sweep
    # Hidden width of the predictor MLP head.
    d_control: int = 64
    max_steps: int = 800
    batch_size: int = 1
    # Effective batch = batch_size * grad_accum = 8.
    grad_accum: int = 8
    max_length: int = 256
    # Learning rates: LoRA adapter vs. risk-predictor head.
    lr_lora: float = 2e-5
    lr_predictor: float = 1e-4
    weight_decay: float = 0.01
    # Lookback window (tokens) for repetition labeling.
    rep_window: int = 32
    log_every: int = 50
    eval_every: int = 200
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class RiskPredictor(nn.Module):
    """Token-level repetition-risk probe.

    Each probed transformer layer gets its own linear "fiber" projection
    (d_model -> d_fiber). The fibers are blended with softmax-normalized
    learned weights and passed through a small MLP that emits one risk
    logit per token.
    """

    def __init__(self, d_model: int, d_fiber: int, probe_layers: List[int], d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        self.d_fiber = d_fiber
        num_probes = len(probe_layers)

        # One bias-free projection per probed layer.
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(num_probes)
        )
        # Mixing weights start uniform across probes; softmaxed in forward().
        self.layer_weights = nn.Parameter(torch.ones(num_probes) / num_probes)
        # MLP head: d_fiber -> d_control -> d_control -> 1 logit.
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control),
            nn.GELU(),
            nn.Linear(d_control, d_control),
            nn.GELU(),
            nn.Linear(d_control, 1),
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-token risk logits of shape (batch, seq).

        Probe layers whose index exceeds the number of provided hidden
        states are silently skipped; the softmax then runs over the
        remaining probes only.
        """
        fibers = [
            self.fiber_projs[probe_idx](hidden_states[layer_idx].float())
            for probe_idx, layer_idx in enumerate(self.probe_layers)
            if layer_idx < len(hidden_states)
        ]
        mix = F.softmax(self.layer_weights[: len(fibers)], dim=0)
        blended = sum(weight * fiber for weight, fiber in zip(mix, fibers))
        return self.predictor(blended).squeeze(-1)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Mark tokens that repeat any token within the previous `window` positions.

    Args:
        input_ids: (B, S) integer token-id tensor.
        window: how far back (in tokens) to look for a matching token.

    Returns:
        (B, S) float tensor: 1.0 where input_ids[b, t] == input_ids[b, t - k]
        for some 1 <= k <= window, else 0.0. All zeros when S <= 1 or
        window == 0.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    # The range bound min(window + 1, S) already guarantees offset < S,
    # so the original inner `if offset < S:` guard was dead code and is removed.
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def compute_separation(predictor, model, tokenizer, device, config, n_samples=30):
    """Measure how well the probe separates repeated from fresh tokens.

    Samples `n_samples` continuations from the model, scores every generated
    token with the probe, and buckets the sigmoid scores by whether the token
    is a repetition (per compute_repetition_labels).

    Returns (p_pos, p_neg, separation, n_pos, n_neg); all zeros if either
    bucket ends up empty. Leaves model/predictor in eval mode.
    """
    model.eval()
    predictor.eval()
    pos_scores, neg_scores = [], []
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]
    with torch.no_grad():
        for sample_idx in range(n_samples):
            prompt = prompts[sample_idx % len(prompts)]
            encoded = tokenizer(prompt, return_tensors='pt')
            input_ids = encoded['input_ids'].to(device)
            attention = encoded['attention_mask'].to(device)
            generated = model.generate(input_ids, attention_mask=attention, max_new_tokens=80,
                                       do_sample=True, temperature=0.9, top_p=0.95,
                                       pad_token_id=tokenizer.eos_token_id)
            # Re-run a forward pass to expose hidden states for the probe.
            forward_out = model(generated, output_hidden_states=True)
            risk = torch.sigmoid(predictor(forward_out.hidden_states))[0].cpu().numpy()
            labels = compute_repetition_labels(generated, config.rep_window)[0].cpu().numpy()
            for t, score in enumerate(risk):
                if labels[t] > 0.5:
                    pos_scores.append(float(score))
                else:
                    neg_scores.append(float(score))
    if not (pos_scores and neg_scores):
        return 0, 0, 0, 0, 0
    p_pos = sum(pos_scores) / len(pos_scores)
    p_neg = sum(neg_scores) / len(neg_scores)
    return p_pos, p_neg, p_pos / max(p_neg, 1e-8), len(pos_scores), len(neg_scores)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def train_probe(model, tokenizer, texts, device, d_model, config, d_fiber, max_steps,
                existing_predictor=None, existing_optimizer=None):
    """Train the LoRA-adapted model and a repetition-risk probe jointly.

    Runs `max_steps` steps of LM loss + masked, pos-weighted BCE risk loss,
    evaluating separation periodically.

    Returns:
        (predictor, optimizer, final_separation, p_pos, p_neg, history) —
        a 6-tuple, so callers can pass predictor/optimizer back in via
        existing_predictor/existing_optimizer to continue training.
    """

    # Reuse a trained predictor when continuing; otherwise build fresh (fp32).
    if existing_predictor is None:
        predictor = RiskPredictor(d_model, d_fiber, config.probe_layers, config.d_control).to(device).float()
    else:
        predictor = existing_predictor

    # Only trainable params are the LoRA adapters (base model is frozen/quantized).
    lora_params = [p for p in model.parameters() if p.requires_grad]

    # Two param groups: LoRA adapters and the probe head, each at its own LR.
    if existing_optimizer is None:
        optimizer = torch.optim.AdamW([
            {'params': lora_params, 'lr': config.lr_lora},
            {'params': predictor.parameters(), 'lr': config.lr_predictor}
        ], weight_decay=config.weight_decay)
    else:
        optimizer = existing_optimizer

    # NOTE(review): a fresh cosine schedule is created even when continuing
    # from an existing optimizer — LR restarts from the group LRs. Confirm
    # this warm-restart behavior is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps, eta_min=1e-6)

    model.train()
    predictor.train()

    history = {"steps": [], "separations": []}
    step, data_idx = 0, 0
    acc_loss, acc_risk = 0, 0
    start = time.time()


    while step < max_steps:
        # Cycle through the shuffled corpus without ever exhausting it.
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        # labels=input_ids gives the standard shifted LM loss.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=input_ids, output_hidden_states=True)

        lm_loss = outputs.loss
        # Probe consumes the full hidden_states tuple and indexes probe layers itself.
        risk_logits = predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels(input_ids, config.rep_window)

        # Class rebalancing: weight positives by neg/pos ratio, capped at 10x.
        mask = attention_mask.float()
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)

        bce = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight, reduction='none')
        # Average BCE over real (non-padding) tokens only.
        risk_loss = (bce * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        # Scale for gradient accumulation; optimizer steps every grad_accum steps.
        (loss / config.grad_accum).backward()

        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1


        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(list(lora_params) + list(predictor.parameters()), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.log_every == 0:
            # Simple throughput-based ETA in minutes.
            eta = (max_steps - step) / (step / (time.time() - start)) / 60
            print(f" Step {step:4d}/{max_steps} | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | ETA: {eta:.1f}m")
            history["steps"].append({"step": step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0

        if step % config.eval_every == 0:
            p_pos, p_neg, sep, n_p, n_n = compute_separation(predictor, model, tokenizer, device, config)
            print(f" >>> SEPARATION @ {step}: {sep:.1f}x (P+={p_pos:.3f}, P-={p_neg:.3f})")
            history["separations"].append({"step": step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})
            # compute_separation switches to eval mode; switch back.
            model.train()
            predictor.train()

    # Final eval
    p_pos, p_neg, final_sep, _, _ = compute_separation(predictor, model, tokenizer, device, config, n_samples=50)
    return predictor, optimizer, final_sep, p_pos, p_neg, history
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def main():
    """Run the two-phase experiment: sweep d_fiber, then extend the winner."""
    config = Config()
    os.makedirs(config.output_dir, exist_ok=True)

    print("=" * 70)
    print("FIBER DIMENSION SWEEP + EXTENDED TRAINING")
    print(f"Target: {TARGET_SEPARATION}x separation on Qwen2.5-3B")
    print("=" * 70)
    print(f"Sweep dimensions: {SWEEP_DIMS}")
    print(f"Sweep steps each: {SWEEP_STEPS}")
    print(f"Full training steps: {FULL_TRAINING_STEPS}")
    print()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantization keeps the 3B model within a single-GPU budget.
    print("Loading Qwen2.5-3B...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    print("Adding LoRA...")
    model = get_peft_model(model, LoraConfig(
        r=64, lora_alpha=128, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"))
    model.print_trainable_parameters()


    device = next(model.parameters()).device
    d_model = model.config.hidden_size

    print("Loading data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples\n")

    # =========================================================================
    # PHASE 1: DIMENSION SWEEP
    # =========================================================================
    print("=" * 70)
    print("PHASE 1: DIMENSION SWEEP")
    print("=" * 70)

    sweep_results = {}
    best_dim, best_sep = None, 0

    for d_fiber in SWEEP_DIMS:
        print(f"\n{'─'*50}")
        print(f"Testing d_fiber = {d_fiber}")
        print(f" Projection: {d_model} → {d_fiber} ({d_model//d_fiber}:1 compression)")
        print(f"{'─'*50}")

        # Reset LoRA weights for fair comparison
        # NOTE(review): kaiming re-init is applied to every trainable
        # 'lora*weight' tensor; PEFT's convention initializes lora_B to zeros
        # so the adapter starts as a no-op — confirm re-initializing B this
        # way still gives a fair sweep baseline.
        for name, param in model.named_parameters():
            if 'lora' in name.lower() and param.requires_grad:
                if 'weight' in name:
                    nn.init.kaiming_uniform_(param)
                elif 'bias' in name:
                    nn.init.zeros_(param)


        predictor, optimizer, sep, p_pos, p_neg, history = train_probe(
            model, tokenizer, texts, device, d_model, config,
            d_fiber=d_fiber, max_steps=SWEEP_STEPS)

        sweep_results[d_fiber] = {
            "separation": sep, "p_pos": p_pos, "p_neg": p_neg, "history": history}

        print(f"\n d_fiber={d_fiber} RESULT: {sep:.1f}x separation")

        # NOTE(review): if every sweep returns separation == 0, best_predictor
        # and best_optimizer are never bound and Phase 2 raises NameError.
        if sep > best_sep:
            best_sep = sep
            best_dim = d_fiber
            best_predictor = predictor
            best_optimizer = optimizer

        # Clear predictor if not best
        if d_fiber != best_dim:
            del predictor
            gc.collect()
            torch.cuda.empty_cache()

    # Sweep summary
    print("\n" + "=" * 70)
    print("SWEEP RESULTS")
    print("=" * 70)
    for d, res in sweep_results.items():
        marker = " ← BEST" if d == best_dim else ""
        print(f" d_fiber={d:2d}: {res['separation']:6.1f}x (P+={res['p_pos']:.3f}, P-={res['p_neg']:.3f}){marker}")
    print()


    # =========================================================================
    # PHASE 2: EXTENDED TRAINING ON BEST DIMENSION
    # =========================================================================
    # NOTE(review): the model's LoRA weights at this point reflect the LAST
    # sweep dimension trained (they are only re-initialized at the top of each
    # sweep iteration). If the best dimension was not the last one, Phase 2
    # continues from mismatched adapter state — verify this is acceptable.
    print("=" * 70)
    print(f"PHASE 2: EXTENDED TRAINING (d_fiber={best_dim})")
    print(f"Current: {best_sep:.1f}x → Target: {TARGET_SEPARATION}x")
    print("=" * 70)

    remaining_steps = FULL_TRAINING_STEPS - SWEEP_STEPS
    print(f"Running {remaining_steps} more steps...\n")

    config.eval_every = 400  # Less frequent evals for extended training
    config.log_every = 100

    best_predictor, _, final_sep, final_p_pos, final_p_neg, ext_history = train_probe(
        model, tokenizer, texts, device, d_model, config,
        d_fiber=best_dim, max_steps=remaining_steps,
        existing_predictor=best_predictor, existing_optimizer=best_optimizer)

    # =========================================================================
    # FINAL RESULTS
    # =========================================================================
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)

    target_hit = "✅ TARGET HIT" if final_sep >= TARGET_SEPARATION else f"⚠️ {final_sep:.1f}x < {TARGET_SEPARATION}x target"


    print(f"""
┌─────────────────────────────────────────────────────────┐
│ CROSS-ARCHITECTURE REPLICATION RESULTS │
├─────────────────────────────────────────────────────────┤
│ │
│ LLaMA-3.1-8B baseline: 125x separation │
│ │
│ Qwen2.5-3B (this run): │
│ Best d_fiber: {best_dim} │
│ Final separation: {final_sep:.1f}x │
│ P(+): {final_p_pos:.4f} │
│ P(-): {final_p_neg:.4f} │
│ │
│ {target_hit:^53} │
│ │
│ Sweep results: │""")
    for d, res in sweep_results.items():
        print(f"│ d_fiber={d:2d}: {res['separation']:5.1f}x{' ← selected' if d == best_dim else '':>20} │")
    print(f"""│ │
│ Method: Fiber projection (identical to LLaMA-8B) │
│ Probe layers: {config.probe_layers} │
│ Architecture: Qwen2 (2048d, 36L) │
└─────────────────────────────────────────────────────────┘
""")


    # Save everything
    full_results = {
        "experiment": "qwen3b_dimension_sweep_extended",
        "target_separation": TARGET_SEPARATION,
        "sweep_dims": SWEEP_DIMS,
        "sweep_steps": SWEEP_STEPS,
        "full_training_steps": FULL_TRAINING_STEPS,
        "best_d_fiber": best_dim,
        "final_separation": final_sep,
        "final_p_pos": final_p_pos,
        "final_p_neg": final_p_neg,
        "target_hit": final_sep >= TARGET_SEPARATION,
        "sweep_results": {str(k): {"separation": v["separation"], "p_pos": v["p_pos"], "p_neg": v["p_neg"]}
                          for k, v in sweep_results.items()},
        "baseline_comparison": {
            "llama_8b_separation": 125.0,
            "qwen_3b_separation": final_sep,
            "ratio": final_sep / 125.0
        }
    }

    with open(os.path.join(config.output_dir, "full_results.json"), 'w') as f:
        json.dump(full_results, f, indent=2)

    # Save best model
    final_dir = os.path.join(config.output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'risk_predictor': best_predictor.state_dict(),
        'd_fiber': best_dim,
        'separation': final_sep,
        'p_pos': final_p_pos,
        'p_neg': final_p_neg
    }, os.path.join(final_dir, "risk_predictor.pt"))

    print(f"Results saved to {config.output_dir}/full_results.json")
    print(f"Model saved to {final_dir}/")
    print("\nDONE!")


if __name__ == "__main__":
    main()
|
code/training_pipelines/09_continue_from_19x.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CONTINUE FROM 73.1x CHECKPOINT
|
| 4 |
+
============================
|
| 5 |
+
Loads the successful Qwen checkpoint (73.1x @ step 10000) and continues training.
|
| 6 |
+
Target: 100x+ separation
|
| 7 |
+
|
| 8 |
+
Author: Logan Napolitano / Proprioception AI
|
| 9 |
+
Date: February 2026
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 16 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
import os
|
| 19 |
+
import time
|
| 20 |
+
import random
|
| 21 |
+
import json
|
| 22 |
+
from dataclasses import dataclass, field
|
| 23 |
+
from typing import List, Tuple
|
| 24 |
+
|
| 25 |
+
# Checkpoint that reached 73.1x separation at step 10000 (see module docstring).
CHECKPOINT_DIR = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_continued_from_19x/final"
# Output directory for this continuation run.
OUTPUT_DIR = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_continued_from_56x"
|
| 27 |
+
|
| 28 |
+
@dataclass
class Config:
    """Hyperparameters for continuing from the 73.1x Qwen2.5-3B checkpoint."""
    model_path: str = "Qwen/Qwen2.5-3B"
    # Hidden-state layers fed to the risk probe.
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    # Dimension of the per-layer fiber projection (must match the checkpoint).
    d_fiber: int = 16
    # Hidden width of the predictor MLP head (must match the checkpoint).
    d_control: int = 64
    additional_steps: int = 25000  # Continue for 25000 more steps (total 35000)
    batch_size: int = 1
    # Effective batch = batch_size * grad_accum = 8.
    grad_accum: int = 8
    max_length: int = 256
    lr_lora: float = 2e-6  # MUCH lower - model already trained
    lr_predictor: float = 1e-5  # MUCH lower - predictor already trained
    weight_decay: float = 0.01
    # Lookback window (tokens) for repetition labeling.
    rep_window: int = 32
    log_every: int = 100
    save_every: int = 5000
    eval_every: int = 1000
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class RiskPredictor(nn.Module):
    """Repetition-risk probe over selected transformer hidden states.

    Per probed layer, a bias-free linear map compresses d_model down to a
    d_fiber vector; the fibers are combined with softmax-weighted learned
    coefficients and scored by a two-hidden-layer MLP, yielding one logit
    per token. Attribute names match the saved checkpoint's state_dict.
    """

    def __init__(self, d_model: int, probe_layers: List[int], d_fiber: int = 16, d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        num_probes = len(probe_layers)
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(num_probes)
        )
        # Uniform starting mixture over probes.
        self.layer_weights = nn.Parameter(torch.ones(num_probes) / num_probes)
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control), nn.GELU(),
            nn.Linear(d_control, d_control), nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-token risk logits of shape (batch, seq)."""
        fibers = [
            self.fiber_projs[probe_idx](hidden_states[layer_idx].float())
            for probe_idx, layer_idx in enumerate(self.probe_layers)
            if layer_idx < len(hidden_states)
        ]
        mix = F.softmax(self.layer_weights[: len(fibers)], dim=0)
        blended = sum(weight * fiber for weight, fiber in zip(mix, fibers))
        return self.predictor(blended).squeeze(-1)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Binary per-token labels: 1.0 where the token equals any token within
    the previous `window` positions, else 0.0.  Output shape matches
    input_ids ((B, S)), float, on the same device.
    """
    batch, seq_len = input_ids.shape
    labels = torch.zeros(batch, seq_len, device=input_ids.device)
    # Compare the sequence against itself shifted by each lag in [1, window].
    for lag in range(1, min(window + 1, seq_len)):
        hit = (input_ids[:, lag:] == input_ids[:, :-lag]).float()
        labels[:, lag:] = torch.maximum(labels[:, lag:], hit)
    return labels
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def compute_separation(predictor, model, tokenizer, device, config, n_samples=50):
    """Greedy-generate continuations of a fixed prompt set and measure how
    strongly the probe separates repeated tokens from non-repeated ones.

    Returns (mean_pos, mean_neg, ratio, n_pos, n_neg); all zeros when
    either class is empty.  Puts model/predictor into eval mode; caller
    is responsible for restoring train mode.
    """
    model.eval()
    predictor.eval()
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
        "The quick brown fox jumps over the lazy",
        "Once upon a time in a land far away",
        "The scientific method involves several steps",
        "When writing code, it is important to",
        "In conclusion, we can see that the evidence",
        "There are several reasons why this matters",
        "Let me explain how this works step by step",
        "The main point I want to make is that",
        "According to recent research findings",
        "One way to look at this problem is",
    ]
    repeated, fresh = [], []
    with torch.no_grad():
        for sample_idx in range(n_samples):
            encoded = tokenizer(prompts[sample_idx % len(prompts)], return_tensors='pt')
            ids = encoded['input_ids'].to(device)
            mask = encoded['attention_mask'].to(device)
            # DETERMINISTIC for consistent evaluation
            generated = model.generate(ids, attention_mask=mask, max_new_tokens=80,
                                       do_sample=False,
                                       pad_token_id=tokenizer.eos_token_id)
            # Re-run the full sequence to collect hidden states for the probe.
            forward_out = model(generated, output_hidden_states=True)
            scores = torch.sigmoid(predictor(forward_out.hidden_states))[0].cpu().numpy()
            targets = compute_repetition_labels(generated, config.rep_window)[0].cpu().numpy()
            for pos in range(len(scores)):
                bucket = repeated if targets[pos] > 0.5 else fresh
                bucket.append(float(scores[pos]))
    if repeated and fresh:
        mean_pos = sum(repeated) / len(repeated)
        mean_neg = sum(fresh) / len(fresh)
        return mean_pos, mean_neg, mean_pos / max(mean_neg, 1e-8), len(repeated), len(fresh)
    return 0, 0, 0, 0, 0
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main():
    """Resume LoRA + risk-probe training from the saved checkpoint.

    Loads the 4-bit quantized base model, applies the checkpointed LoRA
    adapter and risk predictor, then continues joint training (LM loss +
    per-token repetition-risk BCE) on WikiText-2.  Periodically evaluates
    probe separation and writes best/periodic/final checkpoints plus a
    JSON training log under OUTPUT_DIR.
    """
    config = Config()
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading base model...")
    # 4-bit NF4 quantization with fp16 compute and double quantization.
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)

    print("Loading LoRA weights from checkpoint...")
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
    model.train()

    # Make LoRA trainable again (PeftModel.from_pretrained loads adapters frozen).
    for name, param in model.named_parameters():
        if 'lora' in name.lower():
            param.requires_grad = True

    device = next(model.parameters()).device
    d_model = model.config.hidden_size

    print("Loading risk predictor from checkpoint...")
    risk_predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control).to(device).float()
    ckpt = torch.load(os.path.join(CHECKPOINT_DIR, "risk_predictor.pt"), map_location=device)
    risk_predictor.load_state_dict(ckpt['risk_predictor'])
    # Resume bookkeeping from the checkpoint's recorded progress.
    start_step = ckpt['step']
    start_sep = ckpt['separation']

    print()
    print("=" * 70)
    print("CONTINUING FROM CHECKPOINT (deterministic eval)")
    print("=" * 70)
    print(f"Starting point: {start_sep:.1f}x separation @ step {start_step}")
    print(f"Target: 100x+ separation")
    print(f"Additional steps: {config.additional_steps}")
    print(f"LR: LoRA={config.lr_lora}, Predictor={config.lr_predictor}")
    print()


    print("Loading data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    # Two parameter groups with independent learning rates:
    # the (already-trained) LoRA adapter and the risk predictor head.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': risk_predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.additional_steps, eta_min=1e-6)

    log = {
        "experiment": "continue_from_73x",
        "start_step": start_step,
        "start_separation": start_sep,
        "target": "100x+",
        "steps": [],
        "separations": []
    }

    print()
    print("=" * 70)
    print("TRAINING")
    print("=" * 70)

    model.train()
    risk_predictor.train()

    # `step` counts this run's micro-batches; `total_step` continues the
    # checkpoint's global step counter.
    step = 0
    total_step = start_step
    data_idx = 0
    acc_loss, acc_risk = 0, 0
    best_sep = start_sep
    start_time = time.time()


    while step < config.additional_steps:
        # Cycle through the shuffled corpus, wrapping at the end.
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        # labels=input_ids gives the standard causal-LM loss; hidden states
        # are needed for the risk probe.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=input_ids, output_hidden_states=True)

        lm_loss = outputs.loss
        risk_logits = risk_predictor(outputs.hidden_states)
        rep_labels = compute_repetition_labels(input_ids, config.rep_window)

        # Re-balance BCE toward the (rarer) repeated-token class, capped at 10x.
        mask = attention_mask.float()
        n_pos = (rep_labels * mask).sum().clamp(min=1)
        n_neg = ((1 - rep_labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)

        bce = F.binary_cross_entropy_with_logits(
            risk_logits, rep_labels,
            pos_weight=torch.ones_like(rep_labels) * pos_weight, reduction='none')
        # Average the BCE only over real (non-padding) positions.
        risk_loss = (bce * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        # Scale for gradient accumulation before backward.
        (loss / config.grad_accum).backward()

        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1
        total_step += 1


        # Optimizer step only every grad_accum micro-batches.
        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(list(lora_params) + list(risk_predictor.parameters()), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.log_every == 0:
            # ETA = remaining steps / observed steps-per-second, in minutes.
            eta = (config.additional_steps - step) / (step / (time.time() - start_time)) / 60
            print(f"Step {total_step:5d} (+{step}) | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | Best: {best_sep:.1f}x | ETA: {eta:.1f}m")
            log["steps"].append({"step": total_step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0

        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"SEPARATION EVAL @ Step {total_step}")
            print(f"{'='*50}")
            p_pos, p_neg, sep, n_p, n_n = compute_separation(risk_predictor, model, tokenizer, device, config)
            print(f"  P(+) = {p_pos:.4f} (n={n_p})")
            print(f"  P(-) = {p_neg:.4f} (n={n_n})")
            print(f"  SEPARATION = {sep:.1f}x")
            print(f"  [Target: 100x, Best so far: {best_sep:.1f}x]")

            log["separations"].append({"step": total_step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})

            if sep > best_sep:
                best_sep = sep
                print(f"  🎯 NEW BEST!")
                # Save best
                best_dir = os.path.join(OUTPUT_DIR, "best")
                os.makedirs(best_dir, exist_ok=True)
                model.save_pretrained(best_dir)
                torch.save({
                    'risk_predictor': risk_predictor.state_dict(),
                    'step': total_step, 'separation': sep, 'p_pos': p_pos, 'p_neg': p_neg
                }, os.path.join(best_dir, "risk_predictor.pt"))

            # Persist the log after every eval so progress survives a crash.
            with open(os.path.join(OUTPUT_DIR, "training_log.json"), 'w') as f:
                json.dump(log, f, indent=2)

            print(f"{'='*50}\n")
            # compute_separation switched to eval mode; restore training mode.
            model.train()
            risk_predictor.train()


        if step % config.save_every == 0:
            ckpt_dir = os.path.join(OUTPUT_DIR, f"ckpt_{total_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            torch.save({
                'risk_predictor': risk_predictor.state_dict(),
                'step': total_step, 'separation': best_sep
            }, os.path.join(ckpt_dir, "risk_predictor.pt"))
            print(f">>> Checkpoint saved: {ckpt_dir}")

    # Final eval
    print("\n" + "=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)

    p_pos, p_neg, final_sep, _, _ = compute_separation(risk_predictor, model, tokenizer, device, config, n_samples=100)

    target_hit = "✅ TARGET HIT!" if final_sep >= 100 else f"Reached {final_sep:.1f}x"

    # NOTE(review): the "Started: 73.1x @ step 10000" line below is
    # hard-coded and may disagree with the loaded checkpoint's actual
    # start_sep/start_step — confirm before relying on this summary.
    print(f"""
┌─────────────────────────────────────────────────────────┐
│ CONTINUED TRAINING RESULTS │
├─────────────────────────────────────────────────────────┤
│ Started: 73.1x @ step 10000 │
│ Final: {final_sep:>5.1f}x @ step {total_step} │
│ Best: {best_sep:>5.1f}x │
│ P(+): {p_pos:.4f} │
│ P(-): {p_neg:.4f} │
│ │
│ {target_hit:^54} │
└─────────────────────────────────────────────────────────┘
""")

    log["final"] = {"step": total_step, "separation": final_sep, "best": best_sep, "p_pos": p_pos, "p_neg": p_neg}
    with open(os.path.join(OUTPUT_DIR, "training_log.json"), 'w') as f:
        json.dump(log, f, indent=2)

    # Save final
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'risk_predictor': risk_predictor.state_dict(),
        'step': total_step, 'separation': final_sep, 'p_pos': p_pos, 'p_neg': p_neg
    }, os.path.join(final_dir, "risk_predictor.pt"))

    print(f"Saved to {OUTPUT_DIR}")
    print("DONE!")
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
if __name__ == "__main__":
|
| 349 |
+
main()
|
code/training_pipelines/10_qwen_multihead_25k.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
QWEN MULTI-HEAD BEHAVIORAL TRAINING
|
| 4 |
+
====================================
|
| 5 |
+
Continues repetition from 73.1x checkpoint (step 10000) to step 35000
|
| 6 |
+
Then trains hedging, verbosity, sycophancy heads for 25000 steps each
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 13 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
import random
|
| 18 |
+
import json
|
| 19 |
+
import re
|
| 20 |
+
from dataclasses import dataclass, field
|
| 21 |
+
from typing import List, Tuple, Dict, Set
|
| 22 |
+
|
| 23 |
+
# Paths
|
| 24 |
+
CHECKPOINT_DIR = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_continued_from_19x/best"
|
| 25 |
+
OUTPUT_BASE = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_multihead"
|
| 26 |
+
|
| 27 |
+
@dataclass
class Config:
    """Hyperparameters shared by all behavioral-head training runs."""

    # Base model checkpoint on the HuggingFace hub.
    model_path: str = "Qwen/Qwen2.5-3B"
    # Hidden-state indices read by the risk probe.
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16    # width of each per-layer probe projection
    d_control: int = 64  # hidden width of the predictor MLP
    batch_size: int = 1
    grad_accum: int = 8  # effective batch = batch_size * grad_accum
    max_length: int = 256
    lr_lora: float = 1e-5
    lr_predictor: float = 5e-5
    weight_decay: float = 0.01
    log_every: int = 100
    eval_every: int = 1000
    save_every: int = 5000
|
| 42 |
+
|
| 43 |
+
# ============== BEHAVIORAL LEXICONS ==============
|
| 44 |
+
|
| 45 |
+
# Lowercased single tokens treated as hedging markers; matched against a
# token's surface form after subword-marker stripping.
HEDGING_TOKENS = {
    "perhaps", "maybe", "might", "possibly", "probably", "likely",
    "arguably", "somewhat", "relatively", "fairly", "quite",
    "seem", "seems", "seemed", "appear", "appears", "appeared",
    "suggest", "suggests", "suggested", "indicate", "indicates",
    "could", "would", "may", "generally", "typically", "often",
    "sometimes", "occasionally", "potentially", "presumably",
    "apparently", "supposedly", "allegedly", "reportedly",
    "tend", "tends", "tended", "believe", "believes", "believed",
    "think", "thinks", "thought", "guess", "assume", "assumes"
}

# Lowercased multi-word filler phrases; matched by substring search in the
# decoded text (see compute_verbosity_labels).
VERBOSITY_PHRASES = [
    "it is important to note that", "it should be noted that",
    "in order to", "the fact that", "due to the fact that",
    "for the purpose of", "in the event that", "at this point in time",
    "in my opinion", "i believe that", "i think that",
    "essentially", "basically", "fundamentally", "actually",
    "literally", "virtually", "practically", "effectively",
    "in terms of", "with respect to", "with regard to",
    "as a matter of fact", "to be honest", "to tell the truth",
    "needless to say", "it goes without saying",
    "as you can see", "as mentioned earlier", "as stated above",
    "in other words", "that is to say", "so to speak"
]

# Lowercased single tokens treated as sycophancy markers.
SYCOPHANCY_TOKENS = {
    "great", "excellent", "wonderful", "fantastic", "amazing",
    "brilliant", "outstanding", "superb", "terrific", "magnificent",
    "absolutely", "definitely", "certainly", "exactly", "precisely",
    "right", "correct", "agree", "agreed", "true",
    "insightful", "thoughtful", "clever", "smart", "wise",
    "fascinating", "interesting", "intriguing", "compelling"
}

# Lowercased multi-word sycophancy phrases; matched by substring search in
# the decoded text (see compute_sycophancy_labels).
SYCOPHANCY_PHRASES = [
    "great question", "excellent question", "good question",
    "that's a great point", "that's an excellent point",
    "you're absolutely right", "you're exactly right",
    "i completely agree", "i totally agree",
    "what a fascinating", "what an interesting",
    "you raise a great point", "you make an excellent point"
]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ============== LABELING FUNCTIONS ==============
|
| 91 |
+
|
| 92 |
+
def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Label each token 1.0 if it matches any token in the preceding
    `window` positions, else 0.0.  Returns a (B, S) float tensor."""
    n_rows, n_cols = input_ids.shape
    labels = torch.zeros(n_rows, n_cols, device=input_ids.device)
    max_lag = min(window + 1, n_cols)
    # OR together equality masks for every lag in [1, window].
    for lag in range(1, max_lag):
        same = (input_ids[:, lag:] == input_ids[:, :-lag]).float()
        labels[:, lag:] = torch.maximum(labels[:, lag:], same)
    return labels
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def compute_hedging_labels(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
    """Per-token binary labels marking hedging vocabulary.

    A token is labeled 1.0 when its surface form — with subword markers
    stripped, then lowercased and whitespace-trimmed — appears in
    HEDGING_TOKENS.  Returns a (B, S) float tensor on input_ids' device.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        tokens = tokenizer.convert_ids_to_tokens(input_ids[b].cpu().tolist())
        for t, tok in enumerate(tokens):
            # BUGFIX: strip subword markers BEFORE lowercasing.  'Ġ'
            # (U+0120) lowercases to 'ġ' (U+0121), so the original
            # lower-then-replace order never removed the byte-BPE space
            # marker and word-initial tokens could never match the lexicon.
            tok_clean = tok.replace('▁', '').replace('Ġ', '').lower().strip()
            if tok_clean in HEDGING_TOKENS:
                labels[b, t] = 1.0
    return labels
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def compute_verbosity_labels(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
    """Per-token binary labels marking verbosity/filler phrases.

    Decodes each sequence, finds every occurrence of each VERBOSITY_PHRASES
    entry in the lowercased text, and labels the tokens whose estimated
    starting character offset falls inside the phrase span.

    NOTE(review): character offsets are estimated by concatenating token
    strings with '▁'/'Ġ' replaced by spaces, while `idx` comes from the
    decoded text with skip_special_tokens=True — the two can drift apart
    (special tokens, byte-level artifacts), shifting labels by a few
    tokens.  Consider a fast tokenizer's offset mapping instead; confirm
    alignment against the actual tokenizer before trusting these labels.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        text = tokenizer.decode(input_ids[b], skip_special_tokens=True).lower()
        tokens = tokenizer.convert_ids_to_tokens(input_ids[b].cpu().tolist())

        # Find phrase positions
        for phrase in VERBOSITY_PHRASES:
            start = 0
            while True:
                idx = text.find(phrase, start)
                if idx == -1:
                    break
                # Mark tokens in this range
                char_count = 0
                for t, tok in enumerate(tokens):
                    # Replace subword markers with spaces so token lengths
                    # approximate character positions in the decoded text.
                    tok_text = tok.replace('▁', ' ').replace('Ġ', ' ')
                    tok_len = len(tok_text)
                    # Label tokens whose start offset lies within the phrase.
                    if char_count >= idx and char_count < idx + len(phrase):
                        labels[b, t] = 1.0
                    char_count += tok_len
                start = idx + 1
    return labels
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def compute_sycophancy_labels(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
    """Per-token binary labels marking sycophantic language.

    Combines two signals: (1) single tokens whose normalized surface form
    is in SYCOPHANCY_TOKENS, and (2) tokens falling inside an occurrence
    of any SYCOPHANCY_PHRASES entry found in the decoded text.
    Returns a (B, S) float tensor on input_ids' device.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        tokens = tokenizer.convert_ids_to_tokens(input_ids[b].cpu().tolist())
        text = tokenizer.decode(input_ids[b], skip_special_tokens=True).lower()

        # Single token matches
        for t, tok in enumerate(tokens):
            # BUGFIX: strip subword markers BEFORE lowercasing.  'Ġ'
            # (U+0120) lowercases to 'ġ' (U+0121), so the original
            # lower-then-replace order never removed the byte-BPE space
            # marker and word-initial tokens could never match the lexicon.
            tok_clean = tok.replace('▁', '').replace('Ġ', '').lower().strip()
            if tok_clean in SYCOPHANCY_TOKENS:
                labels[b, t] = 1.0

        # Phrase matches.
        # NOTE(review): char offsets are estimated by concatenating token
        # strings, while `idx` comes from the decoded text — the two can
        # drift (special tokens, byte-level artifacts); confirm alignment
        # against the actual tokenizer.
        for phrase in SYCOPHANCY_PHRASES:
            start = 0
            while True:
                idx = text.find(phrase, start)
                if idx == -1:
                    break
                char_count = 0
                for t, tok in enumerate(tokens):
                    tok_text = tok.replace('▁', ' ').replace('Ġ', ' ')
                    tok_len = len(tok_text)
                    if char_count >= idx and char_count < idx + len(phrase):
                        labels[b, t] = 1.0
                    char_count += tok_len
                start = idx + 1
    return labels
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# Dispatch table: behavior name -> labeling function with a uniform
# (input_ids, tokenizer) signature.  Repetition labeling does not need the
# tokenizer, so it is wrapped to discard that argument.
LABEL_FUNCTIONS = {
    "repetition": lambda ids, tok: compute_repetition_labels(ids),
    "hedging": compute_hedging_labels,
    "verbosity": compute_verbosity_labels,
    "sycophancy": compute_sycophancy_labels
}
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ============== PROBE ARCHITECTURE ==============
|
| 180 |
+
|
| 181 |
+
class RiskPredictor(nn.Module):
    """Lightweight behavioral probe reading several transformer layers.

    Projects each probed hidden layer into a small fiber space, blends the
    fibers with softmax-normalized learned weights, and maps the blend
    through a two-hidden-layer MLP to one logit per token.
    """

    def __init__(self, d_model: int, probe_layers: List[int], d_fiber: int = 16, d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        n_probes = len(probe_layers)
        # Bias-free projection for each probed layer.
        self.fiber_projs = nn.ModuleList([
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(n_probes)
        ])
        # Uniform initial layer mix; normalized with softmax at forward time.
        self.layer_weights = nn.Parameter(torch.ones(n_probes) / n_probes)
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control), nn.GELU(),
            nn.Linear(d_control, d_control), nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for linear in self.fiber_projs:
            nn.init.normal_(linear.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-token logits of shape (batch, seq)."""
        available = len(hidden_states)
        fibers = []
        for idx, layer in enumerate(self.probe_layers):
            if layer >= available:
                continue
            fibers.append(self.fiber_projs[idx](hidden_states[layer].float()))
        mix = F.softmax(self.layer_weights[:len(fibers)], dim=0)
        combined = sum(w * fib for w, fib in zip(mix, fibers))
        return self.predictor(combined).squeeze(-1)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def compute_separation(predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=50):
    """Greedy-generate from a fixed prompt set and measure how well the
    probe separates behavior-positive tokens from the rest.

    Returns (mean_pos, mean_neg, ratio, n_pos, n_neg); all zeros when
    either class is empty.  Switches model/predictor to eval mode; caller
    restores train mode.
    """
    model.eval()
    predictor.eval()

    # Diverse prompts for robust evaluation
    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "What do you think about artificial intelligence",
        "Can you help me understand quantum physics",
        "I believe that education is important because",
        "The best way to solve this problem would be",
        "Many experts suggest that we should consider",
        "The quick brown fox jumps over the lazy",
        "Once upon a time in a land far away",
        "The scientific method involves several steps including",
        "When writing code, it is important to",
        "The human brain processes information by",
        "In conclusion, we can see that the evidence",
        "There are several reasons why this matters",
        "Let me explain how this works step by step",
        "The main point I want to make is that",
        "According to recent research findings",
        "I think the answer to your question is",
        "This is a very interesting topic because",
        "One way to look at this problem is",
        "The fundamental principle here is that",
        "What makes this particularly important is",
    ]

    positive, negative = [], []
    with torch.no_grad():
        for sample_idx in range(n_samples):
            encoded = tokenizer(prompts[sample_idx % len(prompts)], return_tensors='pt')
            ids = encoded['input_ids'].to(device)
            mask = encoded['attention_mask'].to(device)

            # DETERMINISTIC generation for consistent evaluation
            generated = model.generate(ids, attention_mask=mask, max_new_tokens=100,
                                       do_sample=False,  # Greedy decoding for consistency
                                       pad_token_id=tokenizer.eos_token_id)

            # Re-run the full sequence to collect hidden states for the probe.
            forward_out = model(generated, output_hidden_states=True)
            scores = torch.sigmoid(predictor(forward_out.hidden_states))[0].cpu().numpy()

            # Repetition labels need no tokenizer; other behaviors do.
            if behavior == "repetition":
                targets = compute_repetition_labels(generated, 32)[0].cpu().numpy()
            else:
                targets = label_fn(generated, tokenizer)[0].cpu().numpy()

            for pos in range(len(scores)):
                bucket = positive if targets[pos] > 0.5 else negative
                bucket.append(float(scores[pos]))

    if positive and negative:
        mean_pos = sum(positive) / len(positive)
        mean_neg = sum(negative) / len(negative)
        return mean_pos, mean_neg, mean_pos / max(mean_neg, 1e-8), len(positive), len(negative)
    return 0, 0, 0, 0, 0
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# ============== TRAINING FUNCTION ==============
|
| 274 |
+
|
| 275 |
+
def train_head(model, tokenizer, texts, device, d_model, config, behavior,
               max_steps, output_dir, start_predictor=None, start_step=0, start_best=0):
    """Train a single behavioral head.

    Jointly fine-tunes the model's trainable (LoRA) parameters with the
    LM loss and a RiskPredictor head with a token-level BCE loss against
    weak behavioral labels.  Periodically evaluates separation, keeps the
    best checkpoint, and writes a JSON training log.

    Args:
        model: PEFT-wrapped causal LM; only requires_grad params are trained.
        tokenizer: matching tokenizer (assumed to have a pad token set).
        texts: training corpus; cycled through in order, batch by batch.
        device: device the inputs are moved to.
        d_model: hidden size, used when building a fresh RiskPredictor.
        config: Config with learning rates, batch/accum sizes, cadences.
        behavior: "repetition" or a key of LABEL_FUNCTIONS.
        max_steps: number of NEW steps to run in this call.
        output_dir: directory for checkpoints, best model, and log.json.
        start_predictor: optional pre-trained RiskPredictor to continue from.
        start_step: step offset for logging when resuming from a checkpoint.
        start_best: best separation achieved before resuming (so a resumed
            run does not overwrite a better earlier checkpoint).

    Returns:
        (predictor, best_separation, final_separation)
    """

    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"TRAINING: {behavior.upper()}")
    print(f"{'='*70}")
    print(f"Steps: {max_steps} (starting from step {start_step})")
    print(f"Output: {output_dir}")
    print()

    # Initialize or load predictor
    if start_predictor is not None:
        predictor = start_predictor
        print("Continuing from checkpoint...")
    else:
        predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control)
        predictor = predictor.to(device).float()
        print("Fresh predictor initialized")

    # Get label function.  "repetition" is special-cased because its labeler
    # takes no tokenizer; others come from the module-level registry.
    if behavior == "repetition":
        label_fn = lambda ids, tok: compute_repetition_labels(ids)
    else:
        label_fn = LABEL_FUNCTIONS[behavior]

    # Two parameter groups with separate learning rates: LoRA weights and
    # the (much smaller) predictor head.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)

    # NOTE(review): scheduler.step() fires only once per grad_accum steps
    # below, so the cosine schedule actually spans max_steps/grad_accum
    # optimizer steps, not max_steps -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps, eta_min=1e-6)

    log = {"behavior": behavior, "steps": [], "separations": []}

    model.train()
    predictor.train()

    step = 0
    total_step = start_step  # Track total steps including checkpoint
    data_idx = 0
    acc_loss, acc_risk = 0, 0
    best_sep = start_best  # Preserve checkpoint's best separation
    start_time = time.time()


    while step < max_steps:
        # Cycle through the corpus; modulo wrap means the data repeats
        # once max_steps * batch_size exceeds len(texts).
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        # labels=input_ids gives the standard causal-LM loss; hidden states
        # are needed by the predictor head.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=input_ids, output_hidden_states=True)

        lm_loss = outputs.loss
        risk_logits = predictor(outputs.hidden_states)

        # Get labels for this behavior
        if behavior == "repetition":
            labels = compute_repetition_labels(input_ids)
        else:
            labels = label_fn(input_ids, tokenizer)

        # Class-balance the BCE: up-weight positives by the neg/pos ratio,
        # capped at 10x to avoid blow-ups on nearly-all-negative batches.
        mask = attention_mask.float()
        n_pos = (labels * mask).sum().clamp(min=1)
        n_neg = ((1 - labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)

        bce = F.binary_cross_entropy_with_logits(
            risk_logits, labels,
            pos_weight=torch.ones_like(labels) * pos_weight, reduction='none')
        # Average only over real (non-padding) tokens.
        risk_loss = (bce * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        # Scale by grad_accum so accumulated gradients match a large batch.
        (loss / config.grad_accum).backward()

        acc_loss += loss.item()
        acc_risk += risk_loss.item()
        step += 1
        total_step += 1

        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(list(lora_params) + list(predictor.parameters()), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()


        if step % config.log_every == 0:
            # ETA from average steps/sec so far (step >= 1 here, no div-by-0).
            eta = (max_steps - step) / (step / (time.time() - start_time)) / 60
            print(f"[{behavior}] Step {total_step:5d} (+{step}) | Loss: {acc_loss/config.log_every:.3f} | "
                  f"Risk: {acc_risk/config.log_every:.3f} | Best: {best_sep:.1f}x | ETA: {eta:.1f}m")
            log["steps"].append({"step": total_step, "loss": acc_loss/config.log_every})
            acc_loss, acc_risk = 0, 0

        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"[{behavior}] SEPARATION EVAL @ Step {total_step}")
            print(f"{'='*50}")

            p_pos, p_neg, sep, n_p, n_n = compute_separation(
                predictor, model, tokenizer, device, config, label_fn, behavior)

            print(f" P(+) = {p_pos:.4f} (n={n_p})")
            print(f" P(-) = {p_neg:.4f} (n={n_n})")
            print(f" SEPARATION = {sep:.1f}x")

            log["separations"].append({"step": total_step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})

            if sep > best_sep:
                best_sep = sep
                print(f" 🎯 NEW BEST!")
                # Snapshot both the LoRA adapter and the predictor head.
                best_dir = os.path.join(output_dir, "best")
                os.makedirs(best_dir, exist_ok=True)
                model.save_pretrained(best_dir)
                torch.save({
                    'predictor': predictor.state_dict(),
                    'step': total_step, 'separation': sep, 'p_pos': p_pos, 'p_neg': p_neg
                }, os.path.join(best_dir, "predictor.pt"))

            with open(os.path.join(output_dir, "log.json"), 'w') as f:
                json.dump(log, f, indent=2)

            print(f"{'='*50}\n")
            # compute_separation put the modules in eval mode; restore.
            model.train()
            predictor.train()

        if step % config.save_every == 0:
            ckpt_dir = os.path.join(output_dir, f"ckpt_{total_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            torch.save({'predictor': predictor.state_dict(), 'step': total_step},
                       os.path.join(ckpt_dir, "predictor.pt"))
            print(f">>> Checkpoint: {ckpt_dir}")


    # Final evaluation (larger sample than the periodic evals)
    print(f"\n{'='*50}")
    print(f"[{behavior}] FINAL RESULTS @ Step {total_step}")
    print(f"{'='*50}")

    p_pos, p_neg, final_sep, n_p, n_n = compute_separation(
        predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=100)

    print(f" Final Separation: {final_sep:.1f}x")
    print(f" Best Separation: {best_sep:.1f}x")
    print(f" P(+): {p_pos:.4f}, P(-): {p_neg:.4f}")

    log["final"] = {"separation": final_sep, "best": best_sep, "p_pos": p_pos, "p_neg": p_neg, "total_steps": total_step}

    with open(os.path.join(output_dir, "log.json"), 'w') as f:
        json.dump(log, f, indent=2)

    # Save final
    final_dir = os.path.join(output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'predictor': predictor.state_dict(),
        'step': total_step, 'separation': final_sep, 'best': best_sep
    }, os.path.join(final_dir, "predictor.pt"))

    return predictor, best_sep, final_sep
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
# ============== MAIN ==============
|
| 448 |
+
|
| 449 |
+
def main():
    """Run the full multi-head training pipeline.

    Loads Qwen2.5-3B in 4-bit, attaches the LoRA adapter from the 73.1x
    repetition checkpoint, then trains four behavioral heads in sequence
    (repetition resumes from the checkpoint; the other three start fresh).
    Writes a summary table and final_results.json under OUTPUT_BASE.
    """
    config = Config()
    os.makedirs(OUTPUT_BASE, exist_ok=True)

    print("=" * 70)
    print("QWEN2.5-3B MULTI-HEAD BEHAVIORAL TRAINING")
    print("=" * 70)
    print(f"Starting from 73.1x repetition checkpoint")
    print(f"Training plan:")
    print(f" 1. Repetition: continue to 35,000 steps (+25,000)")
    print(f" 2. Hedging: 25,000 steps (fresh)")
    print(f" 3. Verbosity: 25,000 steps (fresh)")
    print(f" 4. Sycophancy: 25,000 steps (fresh)")
    print()

    # Load tokenizer; some causal-LM tokenizers ship without a pad token,
    # so fall back to EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base model in 4-bit NF4 with double quantization.
    print("Loading Qwen2.5-3B...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)

    # Load LoRA from checkpoint; PeftModel.from_pretrained freezes the
    # adapter by default, so re-enable gradients on LoRA weights.
    print("Loading LoRA weights from 73.1x checkpoint...")
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
    for name, param in model.named_parameters():
        if 'lora' in name.lower():
            param.requires_grad = True

    device = next(model.parameters()).device
    d_model = model.config.hidden_size

    # Load data: wikitext-2, keeping only lines long enough to be useful.
    print("Loading training data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    results = {}


    # ============================================================
    # HEAD 1: REPETITION (continue from 73.1x checkpoint @ step 10000)
    # ============================================================
    print("\n" + "=" * 70)
    print("HEAD 1: REPETITION (continuing from 73.1x @ step 10000)")
    print("=" * 70)

    # Load existing predictor from 73.1x checkpoint
    rep_predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control)
    rep_predictor = rep_predictor.to(device).float()
    ckpt = torch.load(os.path.join(CHECKPOINT_DIR, "risk_predictor.pt"), map_location=device)
    rep_predictor.load_state_dict(ckpt['risk_predictor'])
    start_step = ckpt.get('step', 10000)
    start_sep = ckpt.get('separation', 73.1)
    print(f"Loaded predictor: step={start_step}, separation={start_sep:.1f}x")

    # Continue for 25000 MORE steps (to reach step 35000 total)
    _, rep_best, rep_final = train_head(
        model, tokenizer, texts, device, d_model, config,
        behavior="repetition", max_steps=25000,
        output_dir=os.path.join(OUTPUT_BASE, "repetition"),
        start_predictor=rep_predictor,
        start_step=start_step,
        start_best=start_sep
    )
    results["repetition"] = {"best": rep_best, "final": rep_final}

    # ============================================================
    # HEAD 2: HEDGING
    # ============================================================
    _, hedge_best, hedge_final = train_head(
        model, tokenizer, texts, device, d_model, config,
        behavior="hedging", max_steps=25000,
        output_dir=os.path.join(OUTPUT_BASE, "hedging"),
        start_step=0,
        start_best=0
    )
    results["hedging"] = {"best": hedge_best, "final": hedge_final}

    # ============================================================
    # HEAD 3: VERBOSITY
    # ============================================================
    _, verb_best, verb_final = train_head(
        model, tokenizer, texts, device, d_model, config,
        behavior="verbosity", max_steps=25000,
        output_dir=os.path.join(OUTPUT_BASE, "verbosity"),
        start_step=0,
        start_best=0
    )
    results["verbosity"] = {"best": verb_best, "final": verb_final}

    # ============================================================
    # HEAD 4: SYCOPHANCY
    # ============================================================
    _, syco_best, syco_final = train_head(
        model, tokenizer, texts, device, d_model, config,
        behavior="sycophancy", max_steps=25000,
        output_dir=os.path.join(OUTPUT_BASE, "sycophancy"),
        start_step=0,
        start_best=0
    )
    results["sycophancy"] = {"best": syco_best, "final": syco_final}


    # ============================================================
    # FINAL SUMMARY
    # ============================================================
    print("\n" + "=" * 70)
    print("FINAL SUMMARY: QWEN2.5-3B MULTI-HEAD RESULTS")
    print("=" * 70)

    # Reference separations from the LLaMA-3.1-8B runs, for the ratio column.
    llama_baselines = {
        "repetition": 125,
        "hedging": 168,
        "verbosity": 272,
        "sycophancy": 218
    }

    print(f"""
┌────────────────────────────────────────────────────────────────────┐
│ QWEN2.5-3B vs LLaMA-3.1-8B COMPARISON │
├────────────────────────────────────────────────────────────────────┤
│ Behavior │ Qwen-3B (Best) │ LLaMA-8B │ Ratio │
├────────────────────────────────────────────────────────────────────┤""")

    for behavior in ["repetition", "hedging", "verbosity", "sycophancy"]:
        qwen = results[behavior]["best"]
        llama = llama_baselines[behavior]
        ratio = qwen / llama * 100
        print(f"│ {behavior:<13} │ {qwen:>6.1f}x │ {llama:>5}x │ {ratio:>5.1f}% │")

    print(f"""├────────────────────────────────────────────────────────────────────┤
│ Architecture: Qwen2 (2048d, 36L) vs LLaMA (4096d, 32L) │
│ Method: IDENTICAL (d_fiber=16, probe layers at 25/50/75%) │
│ Training: 25,000 steps per head │
└────────────────────────────────────────────────────────────────────┘
""")

    # Save final results
    with open(os.path.join(OUTPUT_BASE, "final_results.json"), 'w') as f:
        json.dump({
            "model": "Qwen2.5-3B",
            "results": results,
            "llama_baselines": llama_baselines,
            "methodology": "identical"
        }, f, indent=2)

    print(f"Results saved to {OUTPUT_BASE}/final_results.json")
    print("\nDONE!")


if __name__ == "__main__":
    main()
|
code/training_pipelines/11_qwen_multihead_CLEAN.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
QWEN2.5-3B MULTI-HEAD BEHAVIORAL TRAINING (CLEAN)
|
| 4 |
+
==================================================
|
| 5 |
+
Uses EXACT methodology from 07b_qwen3b_repetition_FIXED.py that achieved 73.1x
|
| 6 |
+
|
| 7 |
+
Author: Logan Napolitano / Proprioception AI
|
| 8 |
+
Date: February 2026
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 15 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
|
| 16 |
+
from datasets import load_dataset
|
| 17 |
+
import os
|
| 18 |
+
import time
|
| 19 |
+
import random
|
| 20 |
+
import json
|
| 21 |
+
from dataclasses import dataclass, field
|
| 22 |
+
from typing import Tuple, List
|
| 23 |
+
|
| 24 |
+
# Checkpoint to continue from (73.1x repetition)
|
| 25 |
+
CHECKPOINT_DIR = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_continued_from_19x/best"
|
| 26 |
+
OUTPUT_BASE = "/home/programmer/Desktop/Claude_and_me/results/qwen3b_multihead_clean"
|
| 27 |
+
|
| 28 |
+
@dataclass
class Config:
    """Hyperparameters for Qwen2.5-3B multi-head behavioral training."""

    model_path: str = "Qwen/Qwen2.5-3B"
    # Hidden-state layers fed to the RiskPredictor (per the summary table:
    # roughly 25/50/75% depth of the 36-layer model).
    probe_layers: List[int] = field(default_factory=lambda: [9, 18, 27])
    d_fiber: int = 16    # per-layer projection width inside the predictor
    d_control: int = 64  # hidden width of the predictor MLP

    # EXACT same as original 07b
    lr_lora: float = 2e-5       # learning rate for LoRA parameters
    lr_predictor: float = 1e-4  # learning rate for the predictor head

    batch_size: int = 1
    grad_accum: int = 8          # optimizer steps every 8 micro-batches
    max_length: int = 256        # tokenizer truncation/padding length
    weight_decay: float = 0.01
    rep_window: int = 32         # lookback window for repetition labels
    log_every: int = 100         # steps between console/log entries
    save_every: int = 5000       # steps between periodic checkpoints
    eval_every: int = 1000       # steps between separation evaluations
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class RiskPredictor(nn.Module):
    """Token-level risk head.

    Projects hidden states from a handful of probe layers into a small
    "fiber" space, blends the projections with learned softmax weights,
    and maps the blend through a small MLP to one logit per token.
    """

    def __init__(self, d_model: int, probe_layers: List[int], d_fiber: int = 16, d_control: int = 64):
        super().__init__()
        self.probe_layers = probe_layers
        num_probes = len(probe_layers)
        # One bias-free low-rank projection per probed layer.
        self.fiber_projs = nn.ModuleList(
            nn.Linear(d_model, d_fiber, bias=False) for _ in range(num_probes)
        )
        # Learnable per-layer mixing weights, initialised uniform.
        self.layer_weights = nn.Parameter(torch.ones(num_probes) / num_probes)
        # Fiber space -> single risk logit.
        self.predictor = nn.Sequential(
            nn.Linear(d_fiber, d_control), nn.GELU(),
            nn.Linear(d_control, d_control), nn.GELU(),
            nn.Linear(d_control, 1)
        )
        for proj in self.fiber_projs:
            nn.init.normal_(proj.weight, std=0.02)

    def forward(self, hidden_states: Tuple[torch.Tensor, ...]) -> torch.Tensor:
        """Return per-token risk logits of shape (batch, seq)."""
        # Project each available probe layer; layers beyond the tuple are
        # silently skipped (their mixing weights are dropped too).
        projected = [
            proj(hidden_states[layer].float())
            for layer, proj in zip(self.probe_layers, self.fiber_projs)
            if layer < len(hidden_states)
        ]
        mix = F.softmax(self.layer_weights[:len(projected)], dim=0)
        blended = sum(w * p for w, p in zip(mix, projected))
        return self.predictor(blended).squeeze(-1)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def compute_repetition_labels(input_ids: torch.Tensor, window: int = 32) -> torch.Tensor:
    """Mark tokens that repeat a recent earlier token.

    Args:
        input_ids: (B, S) integer token ids.
        window: how far back to look for an identical token.

    Returns:
        (B, S) float tensor with 1.0 at position t wherever
        input_ids[b, t] == input_ids[b, t - k] for some 1 <= k <= window,
        and 0.0 elsewhere (the first token is always 0).
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    # range() already stops before S, so every offset is a valid shift;
    # the previous inner `if offset < S` guard was dead code and is removed.
    for offset in range(1, min(window + 1, S)):
        matches = (input_ids[:, offset:] == input_ids[:, :-offset]).float()
        labels[:, offset:] = torch.maximum(labels[:, offset:], matches)
    return labels
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# ============== BEHAVIORAL LABELING ==============

# Surface phrases used to weakly label "hedging" tokens (softeners and
# qualifiers).  compute_hedging_labels searches for these in the decoded,
# lower-cased text of each sample.
HEDGE_PATTERNS = [
    "I think", "maybe", "perhaps", "possibly", "probably", "might", "could be",
    "it seems", "apparently", "generally", "usually", "often", "sometimes",
    "in my opinion", "I believe", "I feel", "somewhat", "relatively",
    "to some extent", "more or less", "kind of", "sort of", "arguably",
    "it appears", "presumably", "supposedly", "allegedly", "reportedly"
]

# Surface phrases used to weakly label "sycophancy" tokens (flattery and
# reflexive agreement).  compute_sycophancy_labels matches these
# case-insensitively (it lower-cases both text and pattern).
SYCOPHANCY_PATTERNS = [
    "great question", "excellent point", "you're right", "absolutely",
    "I agree", "that's correct", "good thinking", "well said", "exactly",
    "you're absolutely right", "that's a great", "wonderful", "fantastic",
    "brilliant", "perfect", "I couldn't agree more", "you make a great point"
]

# Token index past which compute_verbosity_labels starts ramping its
# labels from 0.3 up to 1.0 at the end of the sequence.
VERBOSE_THRESHOLD = 50
|
| 107 |
+
|
| 108 |
+
def compute_hedging_labels(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
    """Label tokens belonging to hedging phrases (HEDGE_PATTERNS).

    The decoded text is lower-cased before searching, so each pattern is
    lower-cased as well.  FIX: previously the original-case pattern was
    searched in the lower-cased text, so any pattern containing a capital
    letter ("I think", "I believe", "I feel") could never match; the
    sibling compute_sycophancy_labels already lower-cased its patterns.

    Args:
        input_ids: (B, S) token ids.
        tokenizer: HF-style tokenizer providing decode /
            convert_ids_to_tokens / convert_tokens_to_string / encode.

    Returns:
        (B, S) float tensor with 1.0 on tokens covered by a matched phrase.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        text = tokenizer.decode(input_ids[b], skip_special_tokens=True).lower()
        tokens = tokenizer.convert_ids_to_tokens(input_ids[b])
        for pattern in HEDGE_PATTERNS:
            needle = pattern.lower()  # match case-insensitively against lower-cased text
            start = 0
            while True:
                idx = text.find(needle, start)
                if idx == -1:
                    break
                # Map the match's character offset back to a token index by
                # walking tokens and accumulating their rendered lengths.
                # NOTE(review): assumes decode() and per-token
                # convert_tokens_to_string agree on character offsets --
                # may drift if special tokens are present; confirm.
                char_pos = idx
                token_pos = 0
                current_char = 0
                for t_idx, token in enumerate(tokens):
                    token_text = tokenizer.convert_tokens_to_string([token])
                    if current_char + len(token_text) > char_pos:
                        token_pos = t_idx
                        break
                    current_char += len(token_text)
                # Flag roughly as many tokens as the phrase encodes to.
                pattern_tokens = len(tokenizer.encode(pattern, add_special_tokens=False))
                for t in range(token_pos, min(token_pos + pattern_tokens, S)):
                    labels[b, t] = 1.0
                start = idx + len(needle)
    return labels
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def compute_sycophancy_labels(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
    """Label tokens that fall inside a sycophantic phrase.

    Searches the lower-cased decoded text for each SYCOPHANCY_PATTERNS
    entry (also lower-cased), maps every match's character offset back to
    a token index, and flags roughly as many tokens as the phrase encodes
    to.  Returns a (B, S) float tensor of 0/1 labels.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    for b in range(B):
        text = tokenizer.decode(input_ids[b], skip_special_tokens=True).lower()
        tokens = tokenizer.convert_ids_to_tokens(input_ids[b])
        for pattern in SYCOPHANCY_PATTERNS:
            needle = pattern.lower()
            search_from = 0
            while (hit := text.find(needle, search_from)) != -1:
                # Walk the tokens, summing their rendered string lengths,
                # until we pass the match's character offset.
                tok_start, consumed = 0, 0
                for pos, tok in enumerate(tokens):
                    piece = tokenizer.convert_tokens_to_string([tok])
                    if consumed + len(piece) > hit:
                        tok_start = pos
                        break
                    consumed += len(piece)
                span = len(tokenizer.encode(pattern, add_special_tokens=False))
                labels[b, tok_start:min(tok_start + span, S)] = 1.0
                search_from = hit + len(needle)
    return labels
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def compute_verbosity_labels(input_ids: torch.Tensor, tokenizer) -> torch.Tensor:
    """Heuristic verbosity labels by sequence position.

    Tokens up to VERBOSE_THRESHOLD get 0.0; beyond it, labels ramp
    linearly from 0.3 to 1.0 at the end of the sequence.

    Args:
        input_ids: (B, S) token ids (only the shape/device are used).
        tokenizer: unused; kept for signature parity with the other
            label functions.

    Returns:
        (B, S) float tensor of soft labels.
    """
    B, S = input_ids.shape
    labels = torch.zeros(B, S, device=input_ids.device)
    if S > VERBOSE_THRESHOLD:
        # The ramp is identical for every row, so build it once and
        # broadcast instead of recomputing it inside a per-sample loop.
        ramp = torch.linspace(0.3, 1.0, S - VERBOSE_THRESHOLD, device=input_ids.device)
        labels[:, VERBOSE_THRESHOLD:] = ramp
    return labels
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def get_label_fn(behavior: str, tokenizer):
    """Return a (input_ids, tokenizer) -> labels callable for `behavior`.

    Supported behaviors: repetition, hedging, sycophancy, verbosity.
    Raises ValueError for anything else.
    """
    dispatch = {
        "repetition": lambda ids, tok: compute_repetition_labels(ids, 32),
        "hedging": lambda ids, tok: compute_hedging_labels(ids, tok),
        "sycophancy": lambda ids, tok: compute_sycophancy_labels(ids, tok),
        "verbosity": lambda ids, tok: compute_verbosity_labels(ids, tok),
    }
    if behavior not in dispatch:
        raise ValueError(f"Unknown behavior: {behavior}")
    return dispatch[behavior]
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# ============== EVALUATION (EXACT SAME AS 07b) ==============
|
| 187 |
+
|
| 188 |
+
def compute_separation(predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=30):
    """EXACT same eval as 07b - uses do_sample=True, temperature=0.9

    Generates continuations for a fixed prompt set, scores every token
    with the predictor, buckets token scores by the weak behavioral label,
    and reports mean score of labeled vs unlabeled tokens plus their ratio.

    Returns:
        (p_pos, p_neg, separation, n_pos_tokens, n_neg_tokens); all zeros
        if one of the buckets ended up empty.

    NOTE: puts model/predictor in eval mode and does NOT restore train
    mode -- the training loop is expected to do that after calling.
    """
    model.eval()
    predictor.eval()
    pos_scores, neg_scores = [], []

    prompts = [
        "The meaning of life according to philosophy is",
        "In the year 2050, technology will",
        "The history of mathematics begins with",
        "Climate change affects the planet by",
        "Neural networks learn patterns through",
        "The ocean contains many species of",
        "Music has evolved significantly since",
        "Economic theories suggest that markets",
        "The human brain processes information",
        "Ancient civilizations developed writing",
    ]

    with torch.no_grad():
        for i in range(n_samples):
            # Cycle through the prompt list when n_samples > len(prompts).
            prompt = prompts[i % len(prompts)]
            inp = tokenizer(prompt, return_tensors='pt')
            input_ids = inp['input_ids'].to(device)
            attn = inp['attention_mask'].to(device)

            # EXACT same generation params as 07b
            out = model.generate(
                input_ids, attention_mask=attn, max_new_tokens=80,
                do_sample=True, temperature=0.9, top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )

            # Re-run the full sequence to get hidden states, then score
            # every token (prompt + continuation) with the risk head.
            outputs = model(out, output_hidden_states=True)
            risk = torch.sigmoid(predictor(outputs.hidden_states))[0].cpu().numpy()

            if behavior == "repetition":
                labels = compute_repetition_labels(out, 32)[0].cpu().numpy()
            else:
                labels = label_fn(out, tokenizer)[0].cpu().numpy()

            # Bucket each token's score by its (possibly soft) label.
            for t in range(len(risk)):
                (pos_scores if labels[t] > 0.5 else neg_scores).append(float(risk[t]))

    if pos_scores and neg_scores:
        p_pos = sum(pos_scores) / len(pos_scores)
        p_neg = sum(neg_scores) / len(neg_scores)
        # Guard the denominator so a near-zero negative mean can't divide by 0.
        return p_pos, p_neg, p_pos / max(p_neg, 1e-8), len(pos_scores), len(neg_scores)
    return 0, 0, 0, 0, 0
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# ============== TRAINING FUNCTION ==============
|
| 242 |
+
|
| 243 |
+
def train_behavior(model, tokenizer, texts, device, d_model, config, behavior,
                   max_steps, output_dir, start_predictor=None, start_step=0):
    """Train a single behavioral head using EXACT 07b methodology.

    Jointly fine-tunes the (LoRA-wrapped) language model and a per-behavior
    RiskPredictor head: the LM loss keeps generation quality while a
    class-weighted BCE loss teaches the head to flag per-token behavior labels.
    Periodically evaluates positive/negative separation, checkpointing the
    best-scoring state to ``output_dir/best`` and the final state to
    ``output_dir/final``.

    Args:
        model: PEFT-wrapped causal LM; only params with requires_grad=True
            (the LoRA weights) are optimized.
        tokenizer: matching tokenizer (pad_token must be set by the caller).
        texts: list of raw training strings, cycled round-robin.
        device: device the model lives on; inputs are moved here.
        d_model: hidden size, used to size a fresh RiskPredictor.
        config: hyperparameter namespace (lr_lora, lr_predictor, batch_size,
            max_length, grad_accum, rep_window, log/eval/save intervals, ...).
        behavior: one of "repetition", "hedging", "verbosity", "sycophancy";
            selects the labeling function.
        max_steps: number of micro-batch steps to run in this call.
        output_dir: directory for checkpoints and log.json (created if absent).
        start_predictor: optional pre-trained RiskPredictor to continue from;
            a fresh one is built when None.
        start_step: global step offset used only for logging/checkpoint names.

    Returns:
        (predictor, best_sep, final_sep): the trained head, the best
        separation ratio seen during training, and the final-eval separation.
    """

    os.makedirs(output_dir, exist_ok=True)

    print(f"\n{'='*70}")
    print(f"TRAINING: {behavior.upper()}")
    print(f"{'='*70}")
    print(f"Steps: {max_steps} (starting from step {start_step})")
    print(f"LR: LoRA={config.lr_lora}, Predictor={config.lr_predictor}")
    print(f"Output: {output_dir}")
    print()

    # Initialize or load predictor.  A provided predictor lets head 1 resume
    # from the 73.1x checkpoint; other heads start from random init.
    if start_predictor is not None:
        predictor = start_predictor
        print("Continuing from checkpoint...")
    else:
        predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control)
        # float(): the head trains in fp32 even though the base model is 4-bit/fp16.
        predictor = predictor.to(device).float()
        print("Fresh predictor initialized")

    label_fn = get_label_fn(behavior, tokenizer)

    # Setup optimizer - EXACT same as 07b
    # Two param groups: LoRA weights at a small LR, the fresh head at a larger one.
    lora_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': lora_params, 'lr': config.lr_lora},
        {'params': predictor.parameters(), 'lr': config.lr_predictor}
    ], weight_decay=config.weight_decay)

    # NOTE(review): T_max is in micro-steps, but scheduler.step() below only
    # fires every grad_accum micro-steps, so the cosine schedule only traverses
    # max_steps/grad_accum of its period.  Presumably intentional (matches 07b)
    # -- confirm before changing.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max_steps, eta_min=1e-6
    )

    log = {"behavior": behavior, "start_step": start_step, "steps": [], "separations": []}

    model.train()
    predictor.train()

    step = 0                      # steps taken in THIS call
    total_step = start_step       # global step, used in logs/checkpoint names
    data_idx = 0                  # round-robin cursor into texts
    acc_loss, acc_lm, acc_risk = 0, 0, 0   # running sums between log prints
    best_sep = 0
    start_time = time.time()

    while step < max_steps:
        # Wrap-around batch selection; data order is fixed, not reshuffled.
        batch = [texts[(data_idx + i) % len(texts)] for i in range(config.batch_size)]
        data_idx += config.batch_size

        enc = tokenizer(batch, truncation=True, max_length=config.max_length,
                        padding='max_length', return_tensors='pt')
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)

        # labels=input_ids gives the standard causal-LM loss alongside hidden states.
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        labels=input_ids, output_hidden_states=True)

        lm_loss = outputs.loss
        risk_logits = predictor(outputs.hidden_states)

        # Repetition has a dedicated tokenizer-free labeler; all other
        # behaviors use the lexical label_fn selected above.
        if behavior == "repetition":
            labels = compute_repetition_labels(input_ids, config.rep_window)
        else:
            labels = label_fn(input_ids, tokenizer)

        # Class-weighted loss - EXACT same as 07b
        # Positive tokens are rare, so upweight them by the neg/pos ratio,
        # capped at 10x; padding tokens are masked out of the loss entirely.
        mask = attention_mask.float()
        n_pos = (labels * mask).sum().clamp(min=1)
        n_neg = ((1 - labels) * mask).sum().clamp(min=1)
        pos_weight = (n_neg / n_pos).clamp(max=10.0)

        bce = F.binary_cross_entropy_with_logits(
            risk_logits, labels,
            pos_weight=torch.ones_like(labels) * pos_weight, reduction='none')
        risk_loss = (bce * mask).sum() / mask.sum()

        loss = lm_loss + risk_loss
        # Gradient accumulation: scale each micro-batch so the accumulated
        # gradient equals the mean over grad_accum micro-batches.
        (loss / config.grad_accum).backward()

        acc_loss += loss.item()
        acc_lm += lm_loss.item()
        acc_risk += risk_loss.item()
        step += 1
        total_step += 1

        # Optimizer/scheduler fire once per accumulation window.
        if step % config.grad_accum == 0:
            torch.nn.utils.clip_grad_norm_(list(lora_params) + list(predictor.parameters()), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if step % config.log_every == 0:
            # ETA in minutes from the average step rate so far.
            eta = (max_steps - step) / (step / (time.time() - start_time)) / 60
            print(f"[{behavior}] Step {total_step:5d} | Loss: {acc_loss/config.log_every:.3f} | "
                  f"LM: {acc_lm/config.log_every:.3f} | Risk: {acc_risk/config.log_every:.3f} | "
                  f"Best: {best_sep:.1f}x | ETA: {eta:.1f}m")
            log["steps"].append({"step": total_step, "loss": acc_loss/config.log_every})
            acc_loss, acc_lm, acc_risk = 0, 0, 0

        if step % config.eval_every == 0:
            print(f"\n{'='*50}")
            print(f"[{behavior}] SEPARATION EVAL @ Step {total_step}")
            print(f"{'='*50}")

            p_pos, p_neg, sep, n_p, n_n = compute_separation(
                predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=30)

            print(f"  P(+) = {p_pos:.4f} (n={n_p})")
            print(f"  P(-) = {p_neg:.4f} (n={n_n})")
            print(f"  SEPARATION = {sep:.1f}x")

            log["separations"].append({"step": total_step, "separation": sep, "p_pos": p_pos, "p_neg": p_neg})

            # Checkpoint only improvements: LoRA adapter + head + eval metadata.
            if sep > best_sep:
                best_sep = sep
                print(f"  🎯 NEW BEST!")
                best_dir = os.path.join(output_dir, "best")
                os.makedirs(best_dir, exist_ok=True)
                model.save_pretrained(best_dir)
                torch.save({
                    'predictor': predictor.state_dict(),
                    'step': total_step, 'separation': sep, 'p_pos': p_pos, 'p_neg': p_neg
                }, os.path.join(best_dir, "predictor.pt"))

            print(f"{'='*50}\n")
            # compute_separation generates text, so restore training mode after eval.
            model.train()
            predictor.train()

        # Periodic unconditional checkpoint, independent of eval quality.
        if step % config.save_every == 0:
            ckpt_dir = os.path.join(output_dir, f"ckpt_{total_step}")
            os.makedirs(ckpt_dir, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            torch.save({'predictor': predictor.state_dict(), 'step': total_step, 'separation': best_sep},
                       os.path.join(ckpt_dir, "predictor.pt"))
            print(f">>> Checkpoint: {ckpt_dir}")

    # Final eval
    print(f"\n{'='*50}")
    print(f"[{behavior}] FINAL RESULTS @ Step {total_step}")
    print(f"{'='*50}")

    # Larger sample count than the in-loop evals for a steadier final number.
    p_pos, p_neg, final_sep, _, _ = compute_separation(
        predictor, model, tokenizer, device, config, label_fn, behavior, n_samples=50)

    print(f"  Final separation: {final_sep:.1f}x")
    print(f"  Best separation: {best_sep:.1f}x")

    log["final"] = {"separation": final_sep, "best": best_sep}

    with open(os.path.join(output_dir, "log.json"), 'w') as f:
        json.dump(log, f, indent=2)

    # Save final
    final_dir = os.path.join(output_dir, "final")
    os.makedirs(final_dir, exist_ok=True)
    model.save_pretrained(final_dir)
    torch.save({
        'predictor': predictor.state_dict(),
        'step': total_step, 'separation': final_sep, 'best': best_sep
    }, os.path.join(final_dir, "predictor.pt"))

    return predictor, best_sep, final_sep
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
# ============== MAIN ==============
|
| 413 |
+
|
| 414 |
+
def main():
    """Train all four behavioral probe heads on Qwen2.5-3B.

    Loads the 4-bit quantized base model with the 73.1x-separation LoRA
    checkpoint, then trains heads sequentially: repetition resumes from the
    checkpointed predictor; hedging, verbosity, and sycophancy each start
    from a fresh predictor on the same (progressively trained) LoRA.
    Prints a comparison table against the LLaMA-3.1-8B baselines and writes
    all results to OUTPUT_BASE/final_results.json.
    """
    config = Config()
    os.makedirs(OUTPUT_BASE, exist_ok=True)

    print("=" * 70)
    print("QWEN2.5-3B MULTI-HEAD TRAINING (CLEAN - EXACT 07b METHODOLOGY)")
    print("=" * 70)
    print(f"LR LoRA: {config.lr_lora} (same as 07b)")
    print(f"LR Predictor: {config.lr_predictor} (same as 07b)")
    print(f"Eval: do_sample=True, temperature=0.9 (same as 07b)")
    print()

    tokenizer = AutoTokenizer.from_pretrained(config.model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading Qwen2.5-3B...")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, quantization_config=bnb, device_map='auto', torch_dtype=torch.float16)
    base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)

    print("Loading LoRA weights from 73.1x checkpoint...")
    model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
    model.train()

    # PeftModel.from_pretrained loads in inference mode; re-enable LoRA grads.
    for name, param in model.named_parameters():
        if 'lora' in name.lower():
            param.requires_grad = True

    device = next(model.parameters()).device
    d_model = model.config.hidden_size

    print("Loading training data...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # Drop headings/blank fragments; keep only substantive paragraphs.
    texts = [ex['text'] for ex in ds if len(ex['text']) > 50]
    random.shuffle(texts)
    print(f"Loaded {len(texts)} samples")

    results = {}

    # ============================================================
    # HEAD 1: REPETITION (continue from 73.1x checkpoint @ step 10000)
    # ============================================================
    print("\n" + "=" * 70)
    print("HEAD 1: REPETITION (continuing from checkpoint)")
    print("=" * 70)

    rep_predictor = RiskPredictor(d_model, config.probe_layers, config.d_fiber, config.d_control)
    rep_predictor = rep_predictor.to(device).float()
    ckpt = torch.load(os.path.join(CHECKPOINT_DIR, "risk_predictor.pt"), map_location=device)
    rep_predictor.load_state_dict(ckpt['risk_predictor'])
    start_step = ckpt.get('step', 10000)
    start_sep = ckpt.get('separation', 73.1)
    print(f"Loaded predictor: step={start_step}, separation={start_sep:.1f}x")

    _, rep_best, rep_final = train_behavior(
        model, tokenizer, texts, device, d_model, config,
        behavior="repetition", max_steps=25000,
        output_dir=os.path.join(OUTPUT_BASE, "repetition"),
        start_predictor=rep_predictor,
        start_step=start_step
    )
    results["repetition"] = {"best": rep_best, "final": rep_final}

    # ============================================================
    # HEADS 2-4: HEDGING, VERBOSITY, SYCOPHANCY
    # Each trains a fresh predictor from step 0 on the same LoRA
    # (which continues to accumulate updates across heads).
    # Consolidated into one loop: the original three call sites were
    # identical except for the behavior name.
    # ============================================================
    for behavior in ("hedging", "verbosity", "sycophancy"):
        _, head_best, head_final = train_behavior(
            model, tokenizer, texts, device, d_model, config,
            behavior=behavior, max_steps=25000,
            output_dir=os.path.join(OUTPUT_BASE, behavior),
            start_step=0
        )
        results[behavior] = {"best": head_best, "final": head_final}

    # ============================================================
    # FINAL SUMMARY
    # ============================================================
    print("\n" + "=" * 70)
    print("FINAL SUMMARY: QWEN2.5-3B MULTI-HEAD RESULTS")
    print("=" * 70)

    # Best separations from the LLaMA-3.1-8B runs, for the ratio column.
    llama_baselines = {
        "repetition": 125,
        "hedging": 168,
        "verbosity": 272,
        "sycophancy": 218
    }

    print(f"""
┌────────────────────────────────────────────────────────────────────┐
│              QWEN2.5-3B vs LLaMA-3.1-8B COMPARISON                 │
├────────────────────────────────────────────────────────────────────┤
│  Behavior      │  Qwen-3B (Best)  │  LLaMA-8B  │  Ratio            │
├────────────────────────────────────────────────────────────────────┤""")

    for behavior in ["repetition", "hedging", "verbosity", "sycophancy"]:
        qwen = results[behavior]["best"]
        llama = llama_baselines[behavior]
        ratio = qwen / llama * 100
        print(f"│  {behavior:<13} │  {qwen:>6.1f}x          │  {llama:>5}x    │  {ratio:>5.1f}%  │")

    print(f"""├────────────────────────────────────────────────────────────────────┤
│  Methodology: EXACT same as 07b (lr=2e-5/1e-4, do_sample=True)     │
│  Architecture: Qwen2 (2048d, 36L) vs LLaMA (4096d, 32L)            │
└────────────────────────────────────────────────────────────────────┘
""")

    with open(os.path.join(OUTPUT_BASE, "final_results.json"), 'w') as f:
        json.dump({
            "model": "Qwen2.5-3B",
            "results": results,
            "llama_baselines": llama_baselines,
            "methodology": "exact_07b"
        }, f, indent=2)

    print(f"Results saved to {OUTPUT_BASE}/final_results.json")
    print("\nDONE!")
|
| 558 |
+
|
| 559 |
+
# Script entry point: run the full multi-head training pipeline.
if __name__ == "__main__":
    main()
|
cognitive/mamba/calibration/calibration_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e006fe92ed7a6dd0a38d8fc774acca5f03d5e382eab15a300530a1c9baa63bdc
|
| 3 |
+
size 812565
|
cognitive/mamba/coherence/coherence_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54889ee9835dd22b7c48d776b8f6ed6d0e6408e3ce1bb2b5194f1f41da75d988
|
| 3 |
+
size 812533
|
cognitive/mamba/depth/depth_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c3e8f65ff74e25ccd2333dd5201d987038fc1c297b7dc2f9d759c15abbc66469
|
| 3 |
+
size 812341
|
cognitive/mamba/focus/focus_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd04964e3ab03de9f34faf094efdc8a92df33885394d0f4700f0060cec524d82
|
| 3 |
+
size 812405
|
cognitive/mamba/specificity/specificity_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:74f28fe7a189eb34f48f6089de975afdf8b1359892c61867ecce5340363187c6
|
| 3 |
+
size 812437
|
cognitive/mistral/calibration/calibration_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b4a945b071cdf508d75762c332319a0fdb89c316d7e3299c07fda3088fc9fa45
|
| 3 |
+
size 812437
|
cognitive/mistral/coherence/coherence_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3abd4a3c0e0761310e49b30015a1bc0f2e33a31d1a14d9a66ce6ef5a5039e25a
|
| 3 |
+
size 812341
|
cognitive/mistral/depth/depth_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:69e95c81ab45fca6b8afcd7dc20312ff67dff9fc6f198409afee4da443518abe
|
| 3 |
+
size 812277
|
cognitive/mistral/focus/focus_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a59b47b43accb21e1618cfcba79fd45ad024fd1da5a855a566b1c5dcec455624
|
| 3 |
+
size 812277
|
cognitive/mistral/results.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"depth": 999.638373580049,
|
| 3 |
+
"specificity": 999.6663282511262,
|
| 4 |
+
"calibration": 999.4429833748761,
|
| 5 |
+
"focus": 999.5846075316271,
|
| 6 |
+
"coherence": 999.6786809149589
|
| 7 |
+
}
|
cognitive/mistral/specificity/specificity_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f5a9db418a6e4c3e02f5fb0f79162095d35d71098bb7da2d9ec23dc547a2b66f
|
| 3 |
+
size 812437
|
cognitive/qwen/calibration/calibration_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2735c1b5eacaf5fef6e77302ae09109be44fb416caa515c304643bc4852800fa
|
| 3 |
+
size 714133
|
cognitive/qwen/coherence/coherence_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fab62274613b41759cacf7193381c436b3e601db131e4e064f43f11d35fe8911
|
| 3 |
+
size 714101
|
cognitive/qwen/depth/depth_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6a09abe8b645e194e61328625a64ac50f963a33820e2da11a84249786eeb9ad4
|
| 3 |
+
size 714037
|
cognitive/qwen/focus/focus_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8edeb3373e15083779d9f9242781f829692385c878d9e5ce56d2fc8429583215
|
| 3 |
+
size 714037
|
cognitive/qwen/specificity/specificity_head.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae507715853275ebfa2ac7f4643fdee8ef9d669150a5cb26b5de55e02be9bde2
|
| 3 |
+
size 714133
|
production/adapter_config.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alora_invocation_tokens": null,
|
| 3 |
+
"alpha_pattern": {},
|
| 4 |
+
"arrow_config": null,
|
| 5 |
+
"auto_mapping": null,
|
| 6 |
+
"base_model_name_or_path": "LoganResearch/ARC-Base-8B-Condensed",
|
| 7 |
+
"bias": "none",
|
| 8 |
+
"corda_config": null,
|
| 9 |
+
"ensure_weight_tying": false,
|
| 10 |
+
"eva_config": null,
|
| 11 |
+
"exclude_modules": null,
|
| 12 |
+
"fan_in_fan_out": false,
|
| 13 |
+
"inference_mode": true,
|
| 14 |
+
"init_lora_weights": true,
|
| 15 |
+
"layer_replication": null,
|
| 16 |
+
"layers_pattern": null,
|
| 17 |
+
"layers_to_transform": null,
|
| 18 |
+
"loftq_config": {},
|
| 19 |
+
"lora_alpha": 128,
|
| 20 |
+
"lora_bias": false,
|
| 21 |
+
"lora_dropout": 0.05,
|
| 22 |
+
"megatron_config": null,
|
| 23 |
+
"megatron_core": "megatron.core",
|
| 24 |
+
"modules_to_save": null,
|
| 25 |
+
"peft_type": "LORA",
|
| 26 |
+
"peft_version": "0.18.1",
|
| 27 |
+
"qalora_group_size": 16,
|
| 28 |
+
"r": 64,
|
| 29 |
+
"rank_pattern": {},
|
| 30 |
+
"revision": null,
|
| 31 |
+
"target_modules": [
|
| 32 |
+
"v_proj",
|
| 33 |
+
"k_proj",
|
| 34 |
+
"q_proj",
|
| 35 |
+
"o_proj"
|
| 36 |
+
],
|
| 37 |
+
"target_parameters": null,
|
| 38 |
+
"task_type": "CAUSAL_LM",
|
| 39 |
+
"trainable_token_indices": null,
|
| 40 |
+
"use_dora": false,
|
| 41 |
+
"use_qalora": false,
|
| 42 |
+
"use_rslora": false
|
| 43 |
+
}
|
production/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3889eccb9c04ba25ae86b99121368121a338fc3ce92a38456874bf455347e389
|
| 3 |
+
size 218138576
|
production/manifest.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "ARC Merged Production Adapter",
|
| 3 |
+
"version": "1.0",
|
| 4 |
+
"date": "2026-01-24",
|
| 5 |
+
"components": {
|
| 6 |
+
"lora_adapter": {
|
| 7 |
+
"file": "adapter_model.safetensors",
|
| 8 |
+
"size_mb": 208,
|
| 9 |
+
"source": "125\u00d7 repetition training"
|
| 10 |
+
},
|
| 11 |
+
"heads": {
|
| 12 |
+
"file": "merged_heads.pt",
|
| 13 |
+
"contents": {
|
| 14 |
+
"repetition": "125\u00d7 separation (PRODUCTION)",
|
| 15 |
+
"hedging": "32.7\u00d7 separation (PRODUCTION)",
|
| 16 |
+
"verbosity": "1.41\u00d7 (needs work)",
|
| 17 |
+
"sycophancy": "not included"
|
| 18 |
+
}
|
| 19 |
+
},
|
| 20 |
+
"tokenizer": {
|
| 21 |
+
"tokens_added": 24,
|
| 22 |
+
"vocab_size": 128280
|
| 23 |
+
},
|
| 24 |
+
"geometry": {
|
| 25 |
+
"curvature_separation": "1.54\u00d7"
|
| 26 |
+
},
|
| 27 |
+
"loop4": {
|
| 28 |
+
"rsi_iterations": "10/10",
|
| 29 |
+
"ceiling_broken": true
|
| 30 |
+
}
|
| 31 |
+
},
|
| 32 |
+
"base_model": "LoganResearch/ARC-Base-8B or local merged-final-v5",
|
| 33 |
+
"usage": "\nfrom peft import PeftModel\nimport torch\n\n# Load base model\nbase = AutoModelForCausalLM.from_pretrained(...)\n\n# Load LoRA\nmodel = PeftModel.from_pretrained(base, \"MERGED_PRODUCTION_ADAPTER/\")\n\n# Load heads\nheads = torch.load(\"MERGED_PRODUCTION_ADAPTER/merged_heads.pt\")\nrepetition_weights = heads[\"heads\"][\"repetition\"][\"weights\"]\nhedging_weights = heads[\"heads\"][\"hedging\"][\"weights\"]\n"
|
| 34 |
+
}
|
production/merged_heads.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0820d998759cbaeb9e485aea04ac3ba5683799ba50c21e947d6b180b3a0e61f2
|
| 3 |
+
size 10093424
|
production/qwen_cognitive/README.md
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-4.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
library_name: transformers
|
| 6 |
+
tags:
|
| 7 |
+
- cognitive-enhancement
|
| 8 |
+
- behavioral-control
|
| 9 |
+
- hidden-state-probing
|
| 10 |
+
- fiber-projection
|
| 11 |
+
- decode-time-intervention
|
| 12 |
+
- qwen2
|
| 13 |
+
- interpretability
|
| 14 |
+
base_model:
|
| 15 |
+
- Qwen/Qwen2.5-7B-Instruct
|
| 16 |
+
pipeline_tag: text-generation
|
| 17 |
+
model-index:
|
| 18 |
+
- name: qwen2.5-7b-cognitive-enhanced
|
| 19 |
+
results:
|
| 20 |
+
- task:
|
| 21 |
+
type: text-generation
|
| 22 |
+
name: Cognitive Enhancement
|
| 23 |
+
metrics:
|
| 24 |
+
- type: separation-ratio
|
| 25 |
+
value: 366
|
| 26 |
+
name: Depth Probe Separation
|
| 27 |
+
- type: separation-ratio
|
| 28 |
+
value: 215
|
| 29 |
+
name: Specificity Probe Separation
|
| 30 |
+
- type: separation-ratio
|
| 31 |
+
value: 165
|
| 32 |
+
name: Calibration Probe Separation
|
| 33 |
+
- type: separation-ratio
|
| 34 |
+
value: 227
|
| 35 |
+
name: Focus Probe Separation
|
| 36 |
+
- type: separation-ratio
|
| 37 |
+
value: 191
|
| 38 |
+
name: Coherence Probe Separation
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
<div align="center">
|
| 42 |
+
|
| 43 |
+
# Qwen2.5-7B Cognitive Enhancement Adapter
|
| 44 |
+
|
| 45 |
+
**Decode-Time Behavioral Control via Hidden State Probing**
|
| 46 |
+
|
| 47 |
+
*Logan Matthew Napolitano*
|
| 48 |
+
|
| 49 |
+
[](https://creativecommons.org/licenses/by/4.0/)
|
| 50 |
+
[](https://www.python.org/downloads/)
|
| 51 |
+
[](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct)
|
| 52 |
+
|
| 53 |
+
*Research into cognitive behavioral control in language models*
|
| 54 |
+
|
| 55 |
+
[Quick Start](#quick-start) • [Architecture](#architecture) • [Probes](#probe-specifications) • [Evaluation](#evaluation) • [Citation](#citation)
|
| 56 |
+
|
| 57 |
+
</div>
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
## Table of Contents
|
| 62 |
+
|
| 63 |
+
1. [Model Description](#model-description)
|
| 64 |
+
2. [Quick Start](#quick-start)
|
| 65 |
+
3. [Architecture](#architecture)
|
| 66 |
+
4. [Probe Specifications](#probe-specifications)
|
| 67 |
+
5. [Intervention Mechanism](#intervention-mechanism)
|
| 68 |
+
6. [Installation](#installation)
|
| 69 |
+
7. [Usage](#usage)
|
| 70 |
+
8. [Evaluation](#evaluation)
|
| 71 |
+
9. [Configuration](#configuration)
|
| 72 |
+
10. [Hardware Requirements](#hardware-requirements)
|
| 73 |
+
11. [Limitations](#limitations)
|
| 74 |
+
12. [Technical Specification](#technical-specification)
|
| 75 |
+
13. [Citation](#citation)
|
| 76 |
+
14. [License](#license)
|
| 77 |
+
|
| 78 |
+
---
|
| 79 |
+
|
| 80 |
+
## Model Description
|
| 81 |
+
|
| 82 |
+
This repository contains a **cognitive enhancement adapter** for [Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct). The adapter consists of five lightweight probes that analyze hidden states during generation and apply targeted interventions to improve response quality.
|
| 83 |
+
|
| 84 |
+
### Core Concept
|
| 85 |
+
|
| 86 |
+
The adapter detects cognitive failure modes (shallow reasoning, vagueness, overconfidence, topic drift, logical inconsistency) by monitoring the model's internal representations at decode time. When a probe fires, the system adjusts token probabilities to steer generation toward more desirable behaviors.
|
| 87 |
+
|
| 88 |
+
### Intended Use
|
| 89 |
+
|
| 90 |
+
- Research into behavioral control mechanisms in language models
|
| 91 |
+
- Study of hidden state interpretability
|
| 92 |
+
- Applications requiring structured, well-calibrated responses
|
| 93 |
+
- Base for further experimentation with decode-time intervention
|
| 94 |
+
|
| 95 |
+
### Not Intended For
|
| 96 |
+
|
| 97 |
+
- Production deployment without thorough evaluation
|
| 98 |
+
- Safety-critical applications
|
| 99 |
+
- Replacement for proper model fine-tuning when domain adaptation is needed
|
| 100 |
+
- Applications where the base model's default behavior is preferred
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
## Quick Start
|
| 105 |
+
|
| 106 |
+
### Minimal Setup
|
| 107 |
+
|
| 108 |
+
```bash
|
| 109 |
+
git clone https://huggingface.co/LoganResearch/qwen2.5-7b-cognitive-enhanced
|
| 110 |
+
cd qwen2.5-7b-cognitive-enhanced
|
| 111 |
+
pip install torch transformers accelerate bitsandbytes
|
| 112 |
+
python inference.py
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
### Basic Usage
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
import torch
|
| 119 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 120 |
+
|
| 121 |
+
# Load base model
|
| 122 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 123 |
+
"Qwen/Qwen2.5-7B-Instruct",
|
| 124 |
+
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
|
| 125 |
+
device_map="auto",
|
| 126 |
+
output_hidden_states=True,
|
| 127 |
+
)
|
| 128 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
| 129 |
+
|
| 130 |
+
# Load adapter
|
| 131 |
+
adapter = torch.load("cognitive_adapter.pt", map_location="cuda")
|
| 132 |
+
print(f"Probes loaded: {list(adapter['probes'].keys())}")
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## Architecture
|
| 138 |
+
|
| 139 |
+
### System Overview
|
| 140 |
+
|
| 141 |
+
```
|
| 142 |
+
┌─────────────────────────────────────────────────────────────────────────────┐
|
| 143 |
+
│ COGNITIVE ENHANCEMENT ARCHITECTURE │
|
| 144 |
+
├─────────────────────────────────────────────────────────────────────────────┤
|
| 145 |
+
│ │
|
| 146 |
+
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
| 147 |
+
│ │ INPUT PROCESSING │ │
|
| 148 |
+
│ │ User Prompt → Tokenization → Model Forward Pass │ │
|
| 149 |
+
│ └─────────────────────────────────────────────────────────────────────┘ │
|
| 150 |
+
│ │ │
|
| 151 |
+
│ ▼ │
|
| 152 |
+
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
| 153 |
+
│ │ HIDDEN STATE EXTRACTION │ │
|
| 154 |
+
│ │ Layer 7, 14, 21 → Last Token Position → [batch, 3584] │ │
|
| 155 |
+
│ └─────────────────────────────────────────────────────────────────────┘ │
|
| 156 |
+
│ │ │
|
| 157 |
+
│ ▼ │
|
| 158 |
+
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
| 159 |
+
│ │ FIBER PROJECTION │ │
|
| 160 |
+
│ │ Per-layer linear projection: 3584 → 16 dimensions │ │
|
| 161 |
+
│ │ Learned layer weights: softmax([w₇, w₁₄, w₂₁]) │ │
|
| 162 |
+
│ │ Weighted sum → 16-dimensional behavioral embedding │ │
|
| 163 |
+
│ └─────────────────────────────────────────────────────────────────────┘ │
|
| 164 |
+
│ │ │
|
| 165 |
+
│ ▼ │
|
| 166 |
+
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
| 167 |
+
│ │ PROBE HEADS (×5) │ │
|
| 168 |
+
│ │ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────┐ │ │
|
| 169 |
+
│ │ │ Depth │ │Specificity│ │Calibration│ │ Focus │ │Cohere.│ │ │
|
| 170 |
+
│ │ │ 366× │ │ 215× │ │ 165× │ │ 227× │ │ 191× │ │ │
|
| 171 |
+
│ │ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ └───┬───┘ │ │
|
| 172 |
+
│ │ │ │ │ │ │ │ │
|
| 173 |
+
│ │ └─────────────┴──────┬──────┴─────────────┴───────────┘ │ │
|
| 174 |
+
│ │ │ │ │
|
| 175 |
+
│ │ Probe Scores: P(behavior) ∈ [0,1] │ │
|
| 176 |
+
│ └─────────────────────────────────────────────────────────────────────┘ │
|
| 177 |
+
│ │ │
|
| 178 |
+
│ ▼ │
|
| 179 |
+
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
| 180 |
+
│ │ INTERVENTION ENGINE │ │
|
| 181 |
+
│ │ For each probe where score > threshold (0.5): │ │
|
| 182 |
+
│ │ • Boost tokens: logits[token_id] += strength × boost_factor │ │
|
| 183 |
+
│ │ • Suppress tokens: logits[token_id] -= strength × suppress_factor│ │
|
| 184 |
+
│ └─────────────────────────────────────────────────────────────────────┘ │
|
| 185 |
+
│ │ │
|
| 186 |
+
│ ▼ │
|
| 187 |
+
│ ┌─────────────────────────────────────────────────────────────────────┐ │
|
| 188 |
+
│ │ OUTPUT SAMPLING │ │
|
| 189 |
+
│ │ Modified logits → Softmax → Token sampling → Next token │ │
|
| 190 |
+
│ └─────────────────────────────────────────────────────────────────────┘ │
|
| 191 |
+
│ │
|
| 192 |
+
└─────────────────────────────────────────────────────────────────────────────┘
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
### Probe Head Architecture
|
| 196 |
+
|
| 197 |
+
Each probe consists of two components:
|
| 198 |
+
|
| 199 |
+
**1. Fiber Projection (shared structure, independent weights)**
|
| 200 |
+
```
|
| 201 |
+
Input: Hidden states from layers [7, 14, 21]
|
| 202 |
+
Shape: [batch, hidden_dim] × 3
|
| 203 |
+
|
| 204 |
+
Layer weights: learnable [3] → softmax
|
| 205 |
+
Per-layer projection: Linear(3584 → 16, bias=False)
|
| 206 |
+
Output: Weighted sum → [batch, 16]
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
**2. Classification Head**
|
| 210 |
+
```
|
| 211 |
+
Input: [batch, 16]
|
| 212 |
+
Linear(16 → 64) → GELU → Linear(64 → 64) → GELU → Linear(64 → 1) → Sigmoid
|
| 213 |
+
Output: P(cognitive_failure_mode) ∈ [0, 1]
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
---
|
| 217 |
+
|
| 218 |
+
## Probe Specifications
|
| 219 |
+
|
| 220 |
+
### Overview
|
| 221 |
+
|
| 222 |
+
| Probe | Separation | Detection Target | Training Steps |
|
| 223 |
+
|:------|:----------:|:-----------------|:--------------:|
|
| 224 |
+
| Depth | 366× | Shallow reasoning patterns | 2000 |
|
| 225 |
+
| Specificity | 215× | Vague or generic language | 2500 |
|
| 226 |
+
| Calibration | 165× | Overconfident assertions | 2500 |
|
| 227 |
+
| Focus | 227× | Topic drift indicators | 2500 |
|
| 228 |
+
| Coherence | 191× | Logical inconsistencies | 2500 |
|
| 229 |
+
|
| 230 |
+
**Separation Ratio** = mean(P(positive_class)) / mean(P(negative_class))
|
| 231 |
+
|
| 232 |
+
Higher separation indicates cleaner discrimination between behavioral states.
|
| 233 |
+
|
| 234 |
+
### Probe Details
|
| 235 |
+
|
| 236 |
+
#### Depth Probe (366×)
|
| 237 |
+
|
| 238 |
+
**Purpose:** Detects when the model is about to produce shallow, unsupported conclusions without intermediate reasoning steps.
|
| 239 |
+
|
| 240 |
+
**Positive class indicators:**
|
| 241 |
+
- Single-sentence answers to complex questions
|
| 242 |
+
- Missing causal connectives
|
| 243 |
+
- Absence of step-by-step structure
|
| 244 |
+
|
| 245 |
+
**Intervention tokens:**
|
| 246 |
+
- Boost: "First", "Because", "Since", "Therefore", "Let", "Step", "Consider"
|
| 247 |
+
- Suppress: "Simply", "Just", "Obviously", "Clearly"
|
| 248 |
+
|
| 249 |
+
#### Specificity Probe (215×)
|
| 250 |
+
|
| 251 |
+
**Purpose:** Detects when the model is about to produce vague, non-committal language lacking concrete details.
|
| 252 |
+
|
| 253 |
+
**Positive class indicators:**
|
| 254 |
+
- Generic nouns: "things", "stuff", "something"
|
| 255 |
+
- Hedging qualifiers: "kind of", "sort of", "basically"
|
| 256 |
+
- Absence of examples or specific instances
|
| 257 |
+
|
| 258 |
+
**Intervention tokens:**
|
| 259 |
+
- Boost: "specifically", "example", "namely", "particular", "instance", "precisely"
|
| 260 |
+
- Suppress: "things", "stuff", "various", "generally", "basically", "kind of"
|
| 261 |
+
|
| 262 |
+
#### Calibration Probe (165×)
|
| 263 |
+
|
| 264 |
+
**Purpose:** Detects when the model is about to make overconfident claims on inherently uncertain topics.
|
| 265 |
+
|
| 266 |
+
**Positive class indicators:**
|
| 267 |
+
- Absolute certainty markers on speculative topics
|
| 268 |
+
- Missing epistemic hedging
|
| 269 |
+
- Deterministic language for probabilistic questions
|
| 270 |
+
|
| 271 |
+
**Intervention tokens:**
|
| 272 |
+
- Boost: "might", "possibly", "perhaps", "likely", "probably", "could", "may"
|
| 273 |
+
- Suppress: "definitely", "certainly", "absolutely", "always", "never", "guaranteed"
|
| 274 |
+
|
| 275 |
+
#### Focus Probe (227×)
|
| 276 |
+
|
| 277 |
+
**Purpose:** Detects when the model is about to drift away from the user's question or introduce tangential content.
|
| 278 |
+
|
| 279 |
+
**Positive class indicators:**
|
| 280 |
+
- Tangent markers: "by the way", "speaking of"
|
| 281 |
+
- Unrelated topic introductions
|
| 282 |
+
- Loss of reference to original query
|
| 283 |
+
|
| 284 |
+
**Intervention tokens:**
|
| 285 |
+
- Boost: "regarding", "answer", "question", "specifically", "directly", "topic"
|
| 286 |
+
- Suppress: "anyway", "tangent", "aside", "by the way", "incidentally"
|
| 287 |
+
|
| 288 |
+
#### Coherence Probe (191×)
|
| 289 |
+
|
| 290 |
+
**Purpose:** Detects when the model is about to produce logically inconsistent or poorly structured content.
|
| 291 |
+
|
| 292 |
+
**Positive class indicators:**
|
| 293 |
+
- Missing transition words
|
| 294 |
+
- Contradictory statements
|
| 295 |
+
- Non-sequitur progressions
|
| 296 |
+
|
| 297 |
+
**Intervention tokens:**
|
| 298 |
+
- Boost: "however", "therefore", "thus", "furthermore", "moreover", "because", "consequently"
|
| 299 |
+
- Suppress: (none — coherence is structural)
|
| 300 |
+
|
| 301 |
+
---
|
| 302 |
+
|
| 303 |
+
## Intervention Mechanism
|
| 304 |
+
|
| 305 |
+
### Algorithm
|
| 306 |
+
|
| 307 |
+
```python
|
| 308 |
+
def apply_intervention(logits, probe_scores, config):
|
| 309 |
+
"""
|
| 310 |
+
Modify logits based on probe activations.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
logits: [vocab_size] tensor of next-token logits
|
| 314 |
+
probe_scores: dict mapping probe_name → score ∈ [0,1]
|
| 315 |
+
config: intervention parameters
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
Modified logits tensor
|
| 319 |
+
"""
|
| 320 |
+
for probe_name, score in probe_scores.items():
|
| 321 |
+
if score > config.threshold: # Default: 0.5
|
| 322 |
+
strength = (score - config.threshold) * 2 # Scale to [0, 1]
|
| 323 |
+
|
| 324 |
+
# Boost beneficial tokens
|
| 325 |
+
for token_id in config.boost_tokens[probe_name]:
|
| 326 |
+
logits[token_id] += strength * config.boost_strength
|
| 327 |
+
|
| 328 |
+
# Suppress harmful tokens
|
| 329 |
+
for token_id in config.suppress_tokens[probe_name]:
|
| 330 |
+
logits[token_id] -= strength * config.suppress_strength
|
| 331 |
+
|
| 332 |
+
return logits
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
### Parameters
|
| 336 |
+
|
| 337 |
+
| Parameter | Default | Description |
|
| 338 |
+
|:----------|:--------|:------------|
|
| 339 |
+
| `threshold` | 0.5 | Minimum probe score to trigger intervention |
|
| 340 |
+
| `boost_strength` | 3.0 | Multiplier for token boosting |
|
| 341 |
+
| `suppress_strength` | 4.0 | Multiplier for token suppression |
|
| 342 |
+
|
| 343 |
+
---
|
| 344 |
+
|
| 345 |
+
## Installation
|
| 346 |
+
|
| 347 |
+
### Requirements
|
| 348 |
+
|
| 349 |
+
```bash
|
| 350 |
+
pip install "torch>=2.0.0"
pip install "transformers>=4.35.0"
pip install "accelerate>=0.24.0"
pip install "bitsandbytes>=0.41.0"  # For 4-bit quantization
|
| 354 |
+
```
|
| 355 |
+
|
| 356 |
+
### Full Installation
|
| 357 |
+
|
| 358 |
+
```bash
|
| 359 |
+
git clone https://huggingface.co/LoganResearch/qwen2.5-7b-cognitive-enhanced
|
| 360 |
+
cd qwen2.5-7b-cognitive-enhanced
|
| 361 |
+
pip install -r requirements.txt
|
| 362 |
+
```
|
| 363 |
+
|
| 364 |
+
---
|
| 365 |
+
|
| 366 |
+
## Usage
|
| 367 |
+
|
| 368 |
+
### Complete Inference Example
|
| 369 |
+
|
| 370 |
+
See `inference.py` for the full `CognitiveEnhancedQwen` class implementation.
|
| 371 |
+
|
| 372 |
+
```python
|
| 373 |
+
from inference import CognitiveEnhancedQwen
|
| 374 |
+
|
| 375 |
+
# Initialize
|
| 376 |
+
qwen = CognitiveEnhancedQwen("cognitive_adapter.pt")
|
| 377 |
+
|
| 378 |
+
# Generate with enhancement
|
| 379 |
+
response = qwen.generate(
|
| 380 |
+
prompt="Explain why the sky is blue.",
|
| 381 |
+
enhanced=True,
|
| 382 |
+
max_tokens=300,
|
| 383 |
+
temperature=0.7
|
| 384 |
+
)
|
| 385 |
+
print(response)
|
| 386 |
+
|
| 387 |
+
# Compare vanilla vs enhanced
|
| 388 |
+
vanilla = qwen.generate("Explain the Monty Hall problem.", enhanced=False)
|
| 389 |
+
enhanced = qwen.generate("Explain the Monty Hall problem.", enhanced=True)
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
### Selective Probe Activation
|
| 393 |
+
|
| 394 |
+
```python
|
| 395 |
+
# Enable only specific probes
|
| 396 |
+
qwen.active_probes = ["depth", "calibration"]
|
| 397 |
+
|
| 398 |
+
# Disable a probe
|
| 399 |
+
qwen.active_probes = [p for p in qwen.probes.keys() if p != "focus"]
|
| 400 |
+
```
|
| 401 |
+
|
| 402 |
+
---
|
| 403 |
+
|
| 404 |
+
## Evaluation
|
| 405 |
+
|
| 406 |
+
### Qualitative Comparison
|
| 407 |
+
|
| 408 |
+
| Prompt | Vanilla Qwen | Enhanced Qwen |
|
| 409 |
+
|:-------|:-------------|:--------------|
|
| 410 |
+
| "Explain the Monty Hall problem" | Begins explanation without structure | "Here's a step-by-step explanation..." with labeled sections |
|
| 411 |
+
| "Will AI replace most jobs?" | "It's unlikely that AI will replace..." (leads with conclusion) | "The question is complex and multifaceted..." (acknowledges uncertainty) |
|
| 412 |
+
| "How can I improve productivity?" | Lists techniques by name | Explains techniques with specific details (e.g., "SMART criteria: Specific, Measurable...") |
|
| 413 |
+
|
| 414 |
+
### Observed Behavioral Changes
|
| 415 |
+
|
| 416 |
+
| Dimension | Vanilla | Enhanced | Change |
|
| 417 |
+
|:----------|:--------|:---------|:-------|
|
| 418 |
+
| Step-by-step reasoning | Occasional | Consistent | Improved |
|
| 419 |
+
| Concrete examples | Sometimes present | More frequent | Improved |
|
| 420 |
+
| Epistemic hedging | Inconsistent | Appropriate | Improved |
|
| 421 |
+
| Topic adherence | Generally good | Slightly improved | Marginal |
|
| 422 |
+
| Logical transitions | Present | More explicit | Improved |
|
| 423 |
+
|
| 424 |
+
**Note:** These are qualitative observations from limited testing. Independent benchmark evaluation is recommended before deployment.
|
| 425 |
+
|
| 426 |
+
---
|
| 427 |
+
|
| 428 |
+
## Configuration
|
| 429 |
+
|
| 430 |
+
### config.json Structure
|
| 431 |
+
|
| 432 |
+
```json
|
| 433 |
+
{
|
| 434 |
+
"model_type": "cognitive_enhancement_adapter",
|
| 435 |
+
"version": "1.0.0",
|
| 436 |
+
"base_model": "Qwen/Qwen2.5-7B-Instruct",
|
| 437 |
+
"architecture": {
|
| 438 |
+
"hidden_dim": 3584,
|
| 439 |
+
"fiber_dim": 16,
|
| 440 |
+
"head_hidden_dim": 64,
|
| 441 |
+
"probe_layers": [7, 14, 21]
|
| 442 |
+
},
|
| 443 |
+
"usage": {
|
| 444 |
+
"boost_strength": 3.0,
|
| 445 |
+
"suppress_strength": 4.0,
|
| 446 |
+
"threshold": 0.5
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
```
|
| 450 |
+
|
| 451 |
+
---
|
| 452 |
+
|
| 453 |
+
## Hardware Requirements
|
| 454 |
+
|
| 455 |
+
| Component | Minimum | Recommended |
|
| 456 |
+
|:----------|:--------|:------------|
|
| 457 |
+
| GPU VRAM | 8 GB (4-bit) | 16+ GB |
|
| 458 |
+
| System RAM | 16 GB | 32 GB |
|
| 459 |
+
| Storage | 20 GB | 50 GB |
|
| 460 |
+
|
| 461 |
+
**Tested Configuration:**
|
| 462 |
+
- NVIDIA RTX 3090 (24GB), 64GB RAM ✓
|
| 463 |
+
|
| 464 |
+
**Performance:**
|
| 465 |
+
- Inference overhead: ~5% additional latency from probe computation
|
| 466 |
+
- Adapter size: 3.57 MB
|
| 467 |
+
|
| 468 |
+
---
|
| 469 |
+
|
| 470 |
+
## Limitations
|
| 471 |
+
|
| 472 |
+
### Known Limitations
|
| 473 |
+
|
| 474 |
+
| Limitation | Description |
|
| 475 |
+
|:-----------|:------------|
|
| 476 |
+
| **Base model dependency** | Inherits all limitations of Qwen2.5-7B-Instruct |
|
| 477 |
+
| **Language** | English only (training data was English) |
|
| 478 |
+
| **Evaluation** | No formal benchmark results; qualitative assessment only |
|
| 479 |
+
| **Intervention scope** | Token-level intervention cannot fix deep reasoning errors |
|
| 480 |
+
| **Training data** | Synthetic training examples may not cover all edge cases |
|
| 481 |
+
| **Generalization** | Probe behavior on out-of-distribution inputs is unknown |
|
| 482 |
+
|
| 483 |
+
### What This Is Not
|
| 484 |
+
|
| 485 |
+
- This is **not** a fine-tuned model — base weights are unchanged
|
| 486 |
+
- This does **not** add knowledge — only modifies generation behavior
|
| 487 |
+
- This does **not** guarantee improved outputs — effectiveness varies by prompt
|
| 488 |
+
- This is **not** validated for production use
|
| 489 |
+
|
| 490 |
+
---
|
| 491 |
+
|
| 492 |
+
## Technical Specification
|
| 493 |
+
|
| 494 |
+
### Training Details
|
| 495 |
+
|
| 496 |
+
- **Training steps:** 2000-2500 per probe
|
| 497 |
+
- **Batch size:** 4
|
| 498 |
+
- **Learning rate:** 5e-5
|
| 499 |
+
- **Optimizer:** AdamW
|
| 500 |
+
- **Early stopping:** Applied to prevent overfitting (observed at ~2700 steps)
|
| 501 |
+
|
| 502 |
+
### Probe Training Data
|
| 503 |
+
|
| 504 |
+
Each probe was trained on ~3000 synthetic examples:
|
| 505 |
+
- **Positive class:** Examples exhibiting the target failure mode
|
| 506 |
+
- **Negative class:** Examples demonstrating desired behavior
|
| 507 |
+
- **Labeling:** Per-sequence binary classification
|
| 508 |
+
|
| 509 |
+
### File Structure
|
| 510 |
+
|
| 511 |
+
```
|
| 512 |
+
qwen2.5-7b-cognitive-enhanced/
|
| 513 |
+
├── cognitive_adapter.pt # Merged probe weights (3.57 MB)
|
| 514 |
+
├── config.json # Architecture and intervention config
|
| 515 |
+
├── inference.py # Ready-to-use inference class
|
| 516 |
+
└── README.md # This file
|
| 517 |
+
```
|
| 518 |
+
|
| 519 |
+
---
|
| 520 |
+
|
| 521 |
+
## Citation
|
| 522 |
+
|
| 523 |
+
```bibtex
|
| 524 |
+
@software{napolitano2026cognitive,
|
| 525 |
+
author = {Napolitano, Logan Matthew},
|
| 526 |
+
title = {Cognitive Enhancement Adapter for {Qwen2.5-7B}},
|
| 527 |
+
year = {2026},
|
| 528 |
+
publisher = {Hugging Face},
|
| 529 |
+
url = {https://huggingface.co/LoganResearch/qwen2.5-7b-cognitive-enhanced},
|
| 530 |
+
license = {CC BY 4.0}
|
| 531 |
+
}
|
| 532 |
+
```
|
| 533 |
+
|
| 534 |
+
### Related Work
|
| 535 |
+
|
| 536 |
+
- [ARC-Base-8B-Condensed](https://huggingface.co/LoganResearch/ARC-Base-8B-Condensed) — Self-improving language model with CF-HoT behavioral control
|
| 537 |
+
- [Qwen2.5-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-7B-Instruct) — Base model
|
| 538 |
+
|
| 539 |
+
---
|
| 540 |
+
|
| 541 |
+
## License
|
| 542 |
+
|
| 543 |
+
This work is licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/) (Creative Commons Attribution 4.0 International).
|
| 544 |
+
|
| 545 |
+
You are free to:
|
| 546 |
+
- **Share** — copy and redistribute the material in any medium or format
|
| 547 |
+
- **Adapt** — remix, transform, and build upon the material for any purpose, including commercial
|
| 548 |
+
|
| 549 |
+
Under the following terms:
|
| 550 |
+
- **Attribution** — You must give appropriate credit, provide a link to the license, and indicate if changes were made.
|
| 551 |
+
|
| 552 |
+
---
|
| 553 |
+
|
| 554 |
+
## Acknowledgments
|
| 555 |
+
|
| 556 |
+
- **Alibaba Cloud** for Qwen2.5-7B-Instruct base model
|
| 557 |
+
- **Hugging Face** for transformers library and model hosting
|
| 558 |
+
|
| 559 |
+
---
|
| 560 |
+
|
| 561 |
+
<div align="center">
|
| 562 |
+
|
| 563 |
+
**Contact:** [Hugging Face Discussions](https://huggingface.co/LoganResearch/qwen2.5-7b-cognitive-enhanced/discussions)
|
| 564 |
+
|
| 565 |
+
**Version:** 1.0.0 | **Released:** February 2026
|
| 566 |
+
|
| 567 |
+
*Logan Napolitano / Fiber AI*
|
| 568 |
+
|
| 569 |
+
</div>
|
production/qwen_cognitive/cognitive_adapter.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1197b060e2064b857044e2148e2be23f23857a63084373ada56fc5610373a6a4
|
| 3 |
+
size 3565757
|
production/qwen_cognitive/config.json
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "cognitive_enhancement_adapter",
|
| 3 |
+
"version": "1.0.0",
|
| 4 |
+
"base_model": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"architecture": {
|
| 6 |
+
"hidden_dim": 3584,
|
| 7 |
+
"fiber_dim": 16,
|
| 8 |
+
"head_hidden_dim": 64,
|
| 9 |
+
"probe_layers": [
|
| 10 |
+
7,
|
| 11 |
+
14,
|
| 12 |
+
21
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
"probes": {
|
| 16 |
+
"depth": {
|
| 17 |
+
"separation": 366.2035633115866,
|
| 18 |
+
"description": "Detects shallow reasoning, encourages step-by-step thinking"
|
| 19 |
+
},
|
| 20 |
+
"specificity": {
|
| 21 |
+
"separation": 18.80886216321723,
|
| 22 |
+
"description": "Detects vague answers, encourages concrete examples"
|
| 23 |
+
},
|
| 24 |
+
"calibration": {
|
| 25 |
+
"separation": 46.77315421768513,
|
| 26 |
+
"description": "Detects overconfidence, encourages appropriate uncertainty"
|
| 27 |
+
},
|
| 28 |
+
"focus": {
|
| 29 |
+
"separation": 70.25854855375214,
|
| 30 |
+
"description": "Detects topic drift, encourages staying on-topic"
|
| 31 |
+
},
|
| 32 |
+
"coherence": {
|
| 33 |
+
"separation": 190.5594291230507,
|
| 34 |
+
"description": "Detects logical inconsistency, encourages smooth transitions"
|
| 35 |
+
}
|
| 36 |
+
},
|
| 37 |
+
"interventions": {
|
| 38 |
+
"depth": {
|
| 39 |
+
"boost": [
|
| 40 |
+
"First",
|
| 41 |
+
"Because",
|
| 42 |
+
"Since",
|
| 43 |
+
"Therefore",
|
| 44 |
+
"Let",
|
| 45 |
+
"Step",
|
| 46 |
+
"Consider"
|
| 47 |
+
],
|
| 48 |
+
"suppress": [
|
| 49 |
+
"Simply",
|
| 50 |
+
"Just",
|
| 51 |
+
"Obviously"
|
| 52 |
+
]
|
| 53 |
+
},
|
| 54 |
+
"specificity": {
|
| 55 |
+
"boost": [
|
| 56 |
+
"specifically",
|
| 57 |
+
"example",
|
| 58 |
+
"namely",
|
| 59 |
+
"particular",
|
| 60 |
+
"instance"
|
| 61 |
+
],
|
| 62 |
+
"suppress": [
|
| 63 |
+
"things",
|
| 64 |
+
"stuff",
|
| 65 |
+
"various",
|
| 66 |
+
"generally",
|
| 67 |
+
"basically"
|
| 68 |
+
]
|
| 69 |
+
},
|
| 70 |
+
"calibration": {
|
| 71 |
+
"boost": [
|
| 72 |
+
"might",
|
| 73 |
+
"possibly",
|
| 74 |
+
"perhaps",
|
| 75 |
+
"likely",
|
| 76 |
+
"probably",
|
| 77 |
+
"could"
|
| 78 |
+
],
|
| 79 |
+
"suppress": [
|
| 80 |
+
"definitely",
|
| 81 |
+
"certainly",
|
| 82 |
+
"absolutely",
|
| 83 |
+
"always",
|
| 84 |
+
"never"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
"focus": {
|
| 88 |
+
"boost": [
|
| 89 |
+
"regarding",
|
| 90 |
+
"answer",
|
| 91 |
+
"question",
|
| 92 |
+
"specifically",
|
| 93 |
+
"directly"
|
| 94 |
+
],
|
| 95 |
+
"suppress": [
|
| 96 |
+
"anyway",
|
| 97 |
+
"tangent",
|
| 98 |
+
"aside",
|
| 99 |
+
"by the way"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
"coherence": {
|
| 103 |
+
"boost": [
|
| 104 |
+
"however",
|
| 105 |
+
"therefore",
|
| 106 |
+
"thus",
|
| 107 |
+
"furthermore",
|
| 108 |
+
"moreover",
|
| 109 |
+
"because"
|
| 110 |
+
],
|
| 111 |
+
"suppress": []
|
| 112 |
+
}
|
| 113 |
+
},
|
| 114 |
+
"usage": {
|
| 115 |
+
"boost_strength": 3.0,
|
| 116 |
+
"suppress_strength": 4.0,
|
| 117 |
+
"threshold": 0.5
|
| 118 |
+
},
|
| 119 |
+
"license": "cc-by-4.0",
|
| 120 |
+
"author": "Logan Napolitano / Fiber AI",
|
| 121 |
+
"paper": "https://github.com/logannapolitano/fiber-ai"
|
| 122 |
+
}
|
production/qwen_cognitive/inference.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Inference script for Qwen2.5-7B with Cognitive Enhancement Adapter
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 9 |
+
import json
|
| 10 |
+
|
| 11 |
+
class FiberProjection(nn.Module):
    """Project hidden states from several transformer layers into a small
    shared "fiber" embedding.

    Each monitored layer gets its own bias-free linear map
    (hidden_dim -> fiber_dim); the per-layer projections are blended with a
    learned softmax weighting over layers.
    """

    def __init__(self, hidden_dim=3584, fiber_dim=16, num_layers=3):
        super().__init__()
        # Learnable mixing logits, initialised to a uniform distribution.
        self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # One independent projection per monitored layer.
        self.projections = nn.ModuleList(
            nn.Linear(hidden_dim, fiber_dim, bias=False)
            for _ in range(num_layers)
        )

    def forward(self, hidden_states_list):
        """Blend the per-layer projections into one [batch, fiber_dim] tensor.

        Args:
            hidden_states_list: one [batch, hidden_dim] tensor per monitored
                layer, ordered to match ``self.projections``.

        Returns:
            Softmax-weighted sum of the projected layer states.
        """
        mix = torch.softmax(self.layer_weights, dim=0)
        blended = None
        for coeff, states, proj in zip(mix, hidden_states_list, self.projections):
            # .float() guards against half-precision states from the 4-bit model.
            term = coeff * proj(states.float())
            blended = term if blended is None else blended + term
        return blended
|
| 23 |
+
|
| 24 |
+
class ProbeHead(nn.Module):
    """Binary probe over the fiber embedding: a small GELU MLP with a
    sigmoid output giving P(failure mode) in (0, 1).

    Layer layout (and therefore state_dict keys ``classifier.0`` /
    ``classifier.2`` / ``classifier.4``) matches the trained checkpoints.
    """

    def __init__(self, fiber_dim=16, hidden_dim=64):
        super().__init__()
        stages = [
            nn.Linear(fiber_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.classifier = nn.Sequential(*stages)

    def forward(self, x):
        """Map a [batch, fiber_dim] embedding to a [batch, 1] probability."""
        logit = self.classifier(x)
        return torch.sigmoid(logit)
|
| 35 |
+
|
| 36 |
+
class CognitiveEnhancedQwen:
    """Qwen2.5-7B-Instruct wrapped with decode-time cognitive probes.

    Loads the 4-bit quantized base model, restores the five probe
    (fiber projection + head) pairs from the adapter checkpoint, and exposes
    a custom token-by-token `generate` that nudges logits whenever a probe
    score crosses 0.5.
    """

    def __init__(self, adapter_path="cognitive_adapter.pt", device="cuda"):
        self.device = device

        # Load base model (NF4 4-bit quantized, fp16 compute).
        print("Loading Qwen2.5-7B-Instruct...")
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
        self.model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen2.5-7B-Instruct",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            ),
            device_map="auto",
            output_hidden_states=True,
        )
        self.model.eval()

        # Load adapter checkpoint.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # adapter files from a trusted source (consider weights_only=True).
        print("Loading cognitive adapter...")
        adapter = torch.load(adapter_path, map_location=device)
        self.config = adapter['config']
        self.probe_layers = self.config['probe_layers']

        # Rebuild each probe (fiber projection + classifier head) and load
        # its trained weights in eval mode.
        self.probes = {}
        for name, probe_data in adapter['probes'].items():
            fiber = FiberProjection(
                hidden_dim=self.config['hidden_dim'],
                fiber_dim=self.config['fiber_dim'],
                num_layers=self.config['num_layers']
            ).to(device)
            fiber.load_state_dict(probe_data['fiber_projection'])
            fiber.eval()

            head = ProbeHead(
                fiber_dim=self.config['fiber_dim'],
                hidden_dim=self.config['head_hidden_dim']
            ).to(device)
            head.load_state_dict(probe_data['head_state'])
            head.eval()

            self.probes[name] = {'fiber': fiber, 'head': head}
            print(f"  ✓ {name}: {adapter['separations'][name]:.1f}× separation")

        # Derive the intervention config path from the adapter path, e.g.
        # "cognitive_adapter.pt" -> "config.json".
        # NOTE(review): fragile — breaks if the adapter file is renamed or
        # lives under a path containing "cognitive_adapter".
        with open(adapter_path.replace('.pt', '.json').replace('cognitive_adapter', 'config'), 'r') as f:
            self.interventions = json.load(f)['interventions']

        # Pre-tokenize the boost/suppress word lists once.
        self._build_token_maps()
        print("Ready!")

    def _build_token_maps(self):
        # For each probe, collect the token ids of every boost/suppress word,
        # encoded both bare and with a leading space (BPE tokenizers assign
        # different ids to word-initial vs. mid-sentence forms).
        # Multi-token words contribute all of their piece ids to the set.
        self.token_ids = {}
        for name, tokens in self.interventions.items():
            self.token_ids[name] = {"boost": set(), "suppress": set()}
            for tok in tokens.get("boost", []):
                self.token_ids[name]["boost"].update(
                    self.tokenizer.encode(tok, add_special_tokens=False))
                self.token_ids[name]["boost"].update(
                    self.tokenizer.encode(" " + tok, add_special_tokens=False))
            for tok in tokens.get("suppress", []):
                self.token_ids[name]["suppress"].update(
                    self.tokenizer.encode(tok, add_special_tokens=False))
                self.token_ids[name]["suppress"].update(
                    self.tokenizer.encode(" " + tok, add_special_tokens=False))

    def get_probe_scores(self, hidden_states):
        """Score every probe on the last-token hidden states.

        ``hidden_states`` is the tuple returned by the model; indices in
        ``probe_layers`` ([7, 14, 21] per config) select entries from it.
        NOTE(review): hidden_states[0] is the embedding output, so index i
        is the output of decoder layer i under 1-based layer counting —
        confirm this matches how the probes were trained.
        """
        hs = [hidden_states[i][:, -1, :] for i in self.probe_layers]
        return {name: probe['head'](probe['fiber'](hs)).item()
                for name, probe in self.probes.items()}

    def generate(self, prompt, enhanced=True, max_tokens=300,
                 boost_strength=3.0, suppress_strength=4.0, temperature=0.7):
        """Sample a chat response, optionally steering logits via the probes.

        Args:
            prompt: user message; wrapped in the Qwen chat template.
            enhanced: if False, plain sampling with no probe interventions.
            max_tokens: generation cap.
            boost_strength / suppress_strength: logit deltas scaled by how far
                the probe score exceeds the 0.5 threshold.
            temperature: softmax temperature (applied before interventions).

        Returns:
            The decoded completion (special tokens stripped, whitespace trimmed).
        """
        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        generated = inputs['input_ids'].clone()

        with torch.no_grad():
            for _ in range(max_tokens):
                # NOTE(review): no KV cache — the full sequence is re-run
                # through the model every step, so generation cost grows
                # quadratically with length. past_key_values would fix this.
                outputs = self.model(
                    input_ids=generated,
                    output_hidden_states=True,
                    return_dict=True
                )
                logits = outputs.logits[:, -1, :] / temperature

                if enhanced:
                    scores = self.get_probe_scores(outputs.hidden_states)
                    for name, score in scores.items():
                        # Only intervene above threshold; strength scales
                        # linearly from 0 (score 0.5) to 1 (score 1.0).
                        if score > 0.5 and name in self.token_ids:
                            strength = (score - 0.5) * 2
                            for tid in self.token_ids[name]["boost"]:
                                if tid < logits.shape[-1]:
                                    logits[0, tid] += strength * boost_strength
                            for tid in self.token_ids[name]["suppress"]:
                                if tid < logits.shape[-1]:
                                    logits[0, tid] -= strength * suppress_strength

                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat([generated, next_token], dim=-1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

        # Decode only the newly generated suffix (drop the prompt tokens).
        return self.tokenizer.decode(
            generated[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True
        ).strip()
|
| 151 |
+
|
| 152 |
+
if __name__ == "__main__":
    # Demo: generate the same answer with and without probe interventions
    # so the two outputs can be compared side by side.
    qwen = CognitiveEnhancedQwen()

    prompt = "Explain why the sky is blue."

    print("\n" + "="*60)
    print("VANILLA:")
    print(qwen.generate(prompt, enhanced=False))

    print("\n" + "="*60)
    print("ENHANCED:")
    print(qwen.generate(prompt, enhanced=True))
|
results/hedging_results.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"step": 400,
|
| 4 |
+
"accuracy": 0.4442897439002991,
|
| 5 |
+
"precision": 0.03143163271296894,
|
| 6 |
+
"recall": 0.8222307544634286,
|
| 7 |
+
"f1": 0.06054865592727941,
|
| 8 |
+
"pos_risk": 0.520409107208252,
|
| 9 |
+
"neg_risk": 0.49916452169418335,
|
| 10 |
+
"separation": 1.0425602874217976
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"step": 800,
|
| 14 |
+
"accuracy": 0.7166696190834045,
|
| 15 |
+
"precision": 0.05191828314374946,
|
| 16 |
+
"recall": 0.6957189479746593,
|
| 17 |
+
"f1": 0.09662582821186226,
|
| 18 |
+
"pos_risk": 0.5107523798942566,
|
| 19 |
+
"neg_risk": 0.4790365397930145,
|
| 20 |
+
"separation": 1.0662075592708358
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"step": 1200,
|
| 24 |
+
"accuracy": 0.8910630941390991,
|
| 25 |
+
"precision": 0.10129681343483417,
|
| 26 |
+
"recall": 0.5083509310808216,
|
| 27 |
+
"f1": 0.16893141945773524,
|
| 28 |
+
"pos_risk": 0.493786096572876,
|
| 29 |
+
"neg_risk": 0.45646682381629944,
|
| 30 |
+
"separation": 1.081756813002461
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"step": 1600,
|
| 34 |
+
"accuracy": 0.8538420796394348,
|
| 35 |
+
"precision": 0.07995142477901099,
|
| 36 |
+
"recall": 0.5434824342484162,
|
| 37 |
+
"f1": 0.13939632675168645,
|
| 38 |
+
"pos_risk": 0.4988616406917572,
|
| 39 |
+
"neg_risk": 0.45187142491340637,
|
| 40 |
+
"separation": 1.1039902352474615
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"step": 2000,
|
| 44 |
+
"accuracy": 0.8356873393058777,
|
| 45 |
+
"precision": 0.07362101313320825,
|
| 46 |
+
"recall": 0.5649836820886927,
|
| 47 |
+
"f1": 0.13026735127478753,
|
| 48 |
+
"pos_risk": 0.5041213035583496,
|
| 49 |
+
"neg_risk": 0.44911104440689087,
|
| 50 |
+
"separation": 1.122486988099139
|
| 51 |
+
}
|
| 52 |
+
]
|
results/hedging_results_continued.json
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"step": 3000,
|
| 4 |
+
"accuracy": 0.7798140048980713,
|
| 5 |
+
"precision": 0.06394045212277155,
|
| 6 |
+
"recall": 0.6678825110385871,
|
| 7 |
+
"f1": 0.11670776094869087,
|
| 8 |
+
"pos_risk": 0.524723470211029,
|
| 9 |
+
"neg_risk": 0.4489431381225586,
|
| 10 |
+
"separation": 1.1687971719656465
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"step": 4000,
|
| 14 |
+
"accuracy": 0.8607996106147766,
|
| 15 |
+
"precision": 0.08731814842027921,
|
| 16 |
+
"recall": 0.5703589940487618,
|
| 17 |
+
"f1": 0.15145027272263856,
|
| 18 |
+
"pos_risk": 0.5151649713516235,
|
| 19 |
+
"neg_risk": 0.4292229413986206,
|
| 20 |
+
"separation": 1.2002270187911235
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"step": 5000,
|
| 24 |
+
"accuracy": 0.8619410991668701,
|
| 25 |
+
"precision": 0.09372991293168936,
|
| 26 |
+
"recall": 0.6158571702822039,
|
| 27 |
+
"f1": 0.1626981108152656,
|
| 28 |
+
"pos_risk": 0.5229318737983704,
|
| 29 |
+
"neg_risk": 0.42363420128822327,
|
| 30 |
+
"separation": 1.234394843967258
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"step": 6000,
|
| 34 |
+
"accuracy": 0.873987078666687,
|
| 35 |
+
"precision": 0.10097000352146493,
|
| 36 |
+
"recall": 0.6054904972163563,
|
| 37 |
+
"f1": 0.17307797837897163,
|
| 38 |
+
"pos_risk": 0.5275717377662659,
|
| 39 |
+
"neg_risk": 0.4137572944164276,
|
| 40 |
+
"separation": 1.2750753760374536
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"step": 7000,
|
| 44 |
+
"accuracy": 0.9015830159187317,
|
| 45 |
+
"precision": 0.12172369670202667,
|
| 46 |
+
"recall": 0.5661355346515646,
|
| 47 |
+
"f1": 0.20036689767631471,
|
| 48 |
+
"pos_risk": 0.521373987197876,
|
| 49 |
+
"neg_risk": 0.3951682150363922,
|
| 50 |
+
"separation": 1.319372275803764
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"step": 8000,
|
| 54 |
+
"accuracy": 0.8688943982124329,
|
| 55 |
+
"precision": 0.09942703067071115,
|
| 56 |
+
"recall": 0.6229602610865809,
|
| 57 |
+
"f1": 0.1714844369286054,
|
| 58 |
+
"pos_risk": 0.5516535639762878,
|
| 59 |
+
"neg_risk": 0.40235474705696106,
|
| 60 |
+
"separation": 1.3710626456165327
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"step": 9000,
|
| 64 |
+
"accuracy": 0.8865934014320374,
|
| 65 |
+
"precision": 0.11200424929178471,
|
| 66 |
+
"recall": 0.6072182760606643,
|
| 67 |
+
"f1": 0.18912374062004847,
|
| 68 |
+
"pos_risk": 0.5500459671020508,
|
| 69 |
+
"neg_risk": 0.3843628168106079,
|
| 70 |
+
"separation": 1.4310592571525516
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"step": 10000,
|
| 74 |
+
"accuracy": 0.8839634656906128,
|
| 75 |
+
"precision": 0.11537280327589149,
|
| 76 |
+
"recall": 0.6490689191783452,
|
| 77 |
+
"f1": 0.19592049603059628,
|
| 78 |
+
"pos_risk": 0.5594488978385925,
|
| 79 |
+
"neg_risk": 0.37506988644599915,
|
| 80 |
+
"separation": 1.491585749897678
|
| 81 |
+
}
|
| 82 |
+
]
|
results/hedging_summary.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"head_name": "hedging",
|
| 3 |
+
"start_step": 10000,
|
| 4 |
+
"final_step": 25000,
|
| 5 |
+
"best_separation": 168.37748922759167,
|
| 6 |
+
"target_separation": 5.0,
|
| 7 |
+
"achieved": true
|
| 8 |
+
}
|
results/mistral_cognitive_results.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"depth": 999.638373580049,
|
| 3 |
+
"specificity": 999.6663282511262,
|
| 4 |
+
"calibration": 999.4429833748761,
|
| 5 |
+
"focus": 999.5846075316271,
|
| 6 |
+
"coherence": 999.6786809149589
|
| 7 |
+
}
|
results/sycophancy_results.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"step": 400,
|
| 4 |
+
"accuracy": 0.35452115535736084,
|
| 5 |
+
"precision": 0.019912276473707042,
|
| 6 |
+
"recall": 0.5976851851851852,
|
| 7 |
+
"f1": 0.038540549113265106,
|
| 8 |
+
"pos_risk": 0.5032691359519958,
|
| 9 |
+
"neg_risk": 0.5036115050315857,
|
| 10 |
+
"separation": 0.9993201722435464
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"step": 800,
|
| 14 |
+
"accuracy": 0.9783545732498169,
|
| 15 |
+
"precision": 0.0,
|
| 16 |
+
"recall": 0.0,
|
| 17 |
+
"f1": 0.0,
|
| 18 |
+
"pos_risk": 0.46720242500305176,
|
| 19 |
+
"neg_risk": 0.46737194061279297,
|
| 20 |
+
"separation": 0.9996373004132021
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"step": 1200,
|
| 24 |
+
"accuracy": 0.9783545732498169,
|
| 25 |
+
"precision": 0.0,
|
| 26 |
+
"recall": 0.0,
|
| 27 |
+
"f1": 0.0,
|
| 28 |
+
"pos_risk": 0.43459680676460266,
|
| 29 |
+
"neg_risk": 0.4344600737094879,
|
| 30 |
+
"separation": 1.0003147194952744
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"step": 1600,
|
| 34 |
+
"accuracy": 0.9783545732498169,
|
| 35 |
+
"precision": 0.0,
|
| 36 |
+
"recall": 0.0,
|
| 37 |
+
"f1": 0.0,
|
| 38 |
+
"pos_risk": 0.4037657082080841,
|
| 39 |
+
"neg_risk": 0.40317171812057495,
|
| 40 |
+
"separation": 1.001473293043168
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"step": 2000,
|
| 44 |
+
"accuracy": 0.9783545732498169,
|
| 45 |
+
"precision": 0.0,
|
| 46 |
+
"recall": 0.0,
|
| 47 |
+
"f1": 0.0,
|
| 48 |
+
"pos_risk": 0.3766477108001709,
|
| 49 |
+
"neg_risk": 0.37547680735588074,
|
| 50 |
+
"separation": 1.0031184441258454
|
| 51 |
+
}
|
| 52 |
+
]
|
results/sycophancy_summary.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"head_name": "sycophancy",
|
| 3 |
+
"start_step": 2000,
|
| 4 |
+
"final_step": 27000,
|
| 5 |
+
"best_separation": 230.40327542808419,
|
| 6 |
+
"target_separation": 5.0,
|
| 7 |
+
"achieved": true
|
| 8 |
+
}
|
results/verbosity_results_continued.json
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"step": 3000,
|
| 4 |
+
"accuracy": 0.9691389203071594,
|
| 5 |
+
"precision": 0.1369258846192842,
|
| 6 |
+
"recall": 0.06012644138729353,
|
| 7 |
+
"f1": 0.08356020294518005,
|
| 8 |
+
"pos_risk": 0.4014969766139984,
|
| 9 |
+
"neg_risk": 0.26297321915626526,
|
| 10 |
+
"separation": 1.52675994119165
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"step": 4000,
|
| 14 |
+
"accuracy": 0.9612593054771423,
|
| 15 |
+
"precision": 0.11997728973650933,
|
| 16 |
+
"recall": 0.10349049463514537,
|
| 17 |
+
"f1": 0.11112571858828028,
|
| 18 |
+
"pos_risk": 0.3859938085079193,
|
| 19 |
+
"neg_risk": 0.21928799152374268,
|
| 20 |
+
"separation": 1.7602140720328825
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"step": 5000,
|
| 24 |
+
"accuracy": 0.8646724820137024,
|
| 25 |
+
"precision": 0.0876178851490621,
|
| 26 |
+
"recall": 0.5081474555896888,
|
| 27 |
+
"f1": 0.14946423485272597,
|
| 28 |
+
"pos_risk": 0.45228344202041626,
|
| 29 |
+
"neg_risk": 0.2378995418548584,
|
| 30 |
+
"separation": 1.9011530602120816
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"step": 6000,
|
| 34 |
+
"accuracy": 0.86240553855896,
|
| 35 |
+
"precision": 0.08743681522381432,
|
| 36 |
+
"recall": 0.5171408218690174,
|
| 37 |
+
"f1": 0.1495825968816301,
|
| 38 |
+
"pos_risk": 0.45121535658836365,
|
| 39 |
+
"neg_risk": 0.22329428791999817,
|
| 40 |
+
"separation": 2.0207205513023467
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"step": 7000,
|
| 44 |
+
"accuracy": 0.8883561491966248,
|
| 45 |
+
"precision": 0.09316823884267353,
|
| 46 |
+
"recall": 0.4318151462535061,
|
| 47 |
+
"f1": 0.1532675426467451,
|
| 48 |
+
"pos_risk": 0.43008705973625183,
|
| 49 |
+
"neg_risk": 0.20528849959373474,
|
| 50 |
+
"separation": 2.0950372796693078
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"step": 8000,
|
| 54 |
+
"accuracy": 0.8657205700874329,
|
| 55 |
+
"precision": 0.08880243245644875,
|
| 56 |
+
"recall": 0.5116646631939806,
|
| 57 |
+
"f1": 0.15133907260785828,
|
| 58 |
+
"pos_risk": 0.44835221767425537,
|
| 59 |
+
"neg_risk": 0.21308068931102753,
|
| 60 |
+
"separation": 2.1041428912397078
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"step": 9000,
|
| 64 |
+
"accuracy": 0.853208601474762,
|
| 65 |
+
"precision": 0.08620888430834804,
|
| 66 |
+
"recall": 0.5493076888829527,
|
| 67 |
+
"f1": 0.1490290104089601,
|
| 68 |
+
"pos_risk": 0.45867228507995605,
|
| 69 |
+
"neg_risk": 0.21553963422775269,
|
| 70 |
+
"separation": 2.128018295675932
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"step": 10000,
|
| 74 |
+
"accuracy": 0.8786846399307251,
|
| 75 |
+
"precision": 0.09251913030283324,
|
| 76 |
+
"recall": 0.4750456346556253,
|
| 77 |
+
"f1": 0.15487504399859206,
|
| 78 |
+
"pos_risk": 0.43932676315307617,
|
| 79 |
+
"neg_risk": 0.20558473467826843,
|
| 80 |
+
"separation": 2.1369619871854995
|
| 81 |
+
}
|
| 82 |
+
]
|
results/verbosity_summary.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"head_name": "verbosity",
|
| 3 |
+
"start_step": 10000,
|
| 4 |
+
"final_step": 35000,
|
| 5 |
+
"best_separation": 272.3502040867811,
|
| 6 |
+
"target_separation": 10.0,
|
| 7 |
+
"achieved": true
|
| 8 |
+
}
|
suppression/hedging_168x/fiber_proj.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f52b9c331f8a598fbf28b9fa909b4aa739491208cc88150b1f6b4adb025603f3
|
| 3 |
+
size 790136
|