#!/usr/bin/env python3
"""
AETHER Training Script.
Integrates TRL GRPO for agent training with custom rewards,
smolagents for multi-agent orchestration,
neuro-symbolic reasoning, and evolutionary optimization.
"""

import os
import json
import logging
from typing import List

import torch

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward, think_format_reward

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("AETHER.Train")


def aether_reward(completions: List[str], **kwargs) -> List[float]:
    """AETHER neuro-symbolic reward combining reasoning structure and knowledge coherence."""
    rewards = []
    for completion in completions:
        score = 0.0
        # Completions may be plain strings (standard format) or lists of
        # message dicts (conversational format); normalize to plain text.
        if isinstance(completion, str):
            text = completion
        elif isinstance(completion, list):
            text = " ".join(str(m.get("content", "")) for m in completion if isinstance(m, dict))
        else:
            text = str(completion)
        
        if "<think>" in text and "</think>" in text:
            score += 0.3
        
        steps = sum(1 for s in text.split("\n") if any(s.strip().startswith(p) for p in ["1.", "2.", "3.", "4.", "5.", "Step", "Phase"]))
        score += min(steps * 0.05, 0.25)
        
        if any(kw in text.lower() for kw in ["therefore", "because", "implies", "consequently"]):
            score += 0.2
        
        if any(kw in text.lower() for kw in ["sub-goal", "blueprint", "plan", "phase"]):
            score += 0.15
        
        if any(kw in text.lower() for kw in ["reflect", "evaluate", "improve", "evolve"]):
            score += 0.1
        
        rewards.append(min(score, 1.0))
    return rewards


def main():
    MODEL_NAME = os.environ.get("AETHER_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    OUTPUT_DIR = os.environ.get("AETHER_OUTPUT", "./aether-output")
    
    # The Trackio Trainer integration is configured through environment
    # variables; default the project name here so runs group consistently.
    trackio_space_id = os.environ.get("TRACKIO_SPACE_ID")
    os.environ.setdefault("TRACKIO_PROJECT", "aether-evolution")
    
    logger.info("=" * 60)
    logger.info("AETHER TRAINING - GRPO with Neuro-Symbolic Rewards")
    logger.info("=" * 60)
    logger.info(f"Model: {MODEL_NAME}")
    logger.info(f"Output: {OUTPUT_DIR}")
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Device: {device}")
    
    logger.info("Loading model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Ensure a pad token exists; some checkpoints only define EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    logger.info("Loading dataset...")
    try:
        dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
        logger.info(f"Loaded DeepMath-103K: {len(dataset)} examples")
    except Exception as e:
        logger.warning(f"DeepMath failed: {e}")
        try:
            dataset = load_dataset("trl-lib/Capybara", split="train")
            logger.info(f"Loaded Capybara: {len(dataset)} examples")
        except Exception as e2:
            logger.warning(f"Capybara failed: {e2}")
            from datasets import Dataset
            prompts = [
                {"prompt": "Think step by step and solve: If a train travels 240 km in 3 hours, what is its average speed?"},
                {"prompt": "Plan and reason: You have 5 shelves and need to store 150 books evenly. How many per shelf?"},
                {"prompt": "Analyze and explain: Why does recursive self-improvement require safety constraints?"},
                {"prompt": "Break down into phases: How would you build a self-evolving AI system?"},
                {"prompt": "Reflect and improve: A previous solution had an error in step 3. How would you fix it?"},
                {"prompt": "Think about this: What are the trade-offs between symbolic and neural reasoning?"},
                {"prompt": "Plan a hierarchy: Design a multi-agent system with a manager and workers."},
                {"prompt": "Evolve this solution: Start with a simple sorting algorithm and improve it iteratively."},
                {"prompt": "Knowledge reasoning: Given that all birds can fly and penguins are birds, what can you conclude?"},
                {"prompt": "Meta-cognitive analysis: Evaluate your own reasoning process and identify biases."},
            ] * 100
            dataset = Dataset.from_list(prompts)
            logger.info(f"Created fallback dataset: {len(dataset)} examples")
    
    if "prompt" not in dataset.column_names:
        if "text" in dataset.column_names:
            dataset = dataset.rename_column("text", "prompt")
        elif "messages" in dataset.column_names:
            def extract_prompt(examples):
                prompts = []
                for msgs in examples["messages"]:
                    for msg in msgs:
                        if msg.get("role") == "user":
                            prompts.append(msg.get("content", ""))
                            break
                    else:
                        prompts.append(str(msgs))
                return {"prompt": prompts}
            dataset = dataset.map(extract_prompt, batched=True, remove_columns=dataset.column_names)
        elif "question" in dataset.column_names:
            dataset = dataset.rename_column("question", "prompt")
    
    dataset = dataset.train_test_split(test_size=0.1)
    train_ds = dataset["train"]
    eval_ds = dataset["test"]
    logger.info(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
    
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        # The global eval batch size must be divisible by num_generations,
        # so keep this a multiple of 4.
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        logging_steps=10,
        save_steps=100,
        eval_strategy="steps",
        eval_steps=50,
        bf16=torch.cuda.is_available(),
        max_completion_length=512,
        num_generations=4,
        # Trackio picks up TRACKIO_PROJECT / TRACKIO_SPACE_ID from the
        # environment when report_to="trackio".
        report_to="trackio" if trackio_space_id else "none",
        run_name=f"aether-grpo-{MODEL_NAME.split('/')[-1]}",
        disable_tqdm=True,
        logging_first_step=True,
        push_to_hub=True,
        hub_model_id=f"camdog920/aether-{MODEL_NAME.split('/')[-1]}-grpo",
    )
    
    # accuracy_reward compares completions against a "solution" column
    # (via math-verify), so only include it when the dataset has one.
    reward_funcs = [aether_reward, think_format_reward]
    if "solution" in train_ds.column_names:
        reward_funcs.append(accuracy_reward)
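    # GRPOTrainer sums the reward functions, with equal weights unless
    # GRPOConfig.reward_weights says otherwise.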
    
    logger.info("Initializing GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        # Pass the tokenizer explicitly so the pad-token fix above is used.
        processing_class=tokenizer,
    )
    
    logger.info("Starting training...")
    trainer.train()
    
    logger.info("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    
    # Record training provenance alongside the weights.
    metadata = {
        "aether_version": "0.1.0",
        "training_method": "GRPO",
        "model_name": MODEL_NAME,
        "reward_functions": ["aether_reward", "accuracy_reward", "think_format_reward"],
    }
    with open(os.path.join(OUTPUT_DIR, "aether_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)
    
    logger.info("=" * 60)
    logger.info("Training complete!")
    logger.info(f"Model: https://huggingface.co/{training_args.hub_model_id}")
    if trackio_space_id:
        logger.info(f"Dashboard: https://huggingface.co/spaces/{trackio_space_id}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()