camdog920 commited on
Commit
93f6542
·
verified ·
1 Parent(s): 315c661

Upload aether_train.py

Browse files
Files changed (1) hide show
  1. aether_train.py +201 -0
aether_train.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AETHER Training Script.
4
+ Integrates TRL GRPO for agent training with custom rewards,
5
+ smolagents for multi-agent orchestration,
6
+ neuro-symbolic reasoning, and evolutionary optimization.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import logging
13
+ import argparse
14
+ from typing import List
15
+
16
+ import torch
17
+
18
+ from datasets import load_dataset
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer
20
+ from trl import GRPOTrainer, GRPOConfig
21
+ from trl.rewards import accuracy_reward, think_format_reward
22
+
23
+ from aether.core import AetherCore, AetherConfig
24
+ from aether.knowledge import KnowledgeGraphEngine
25
+
26
+ logging.basicConfig(
27
+ level=logging.INFO,
28
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
29
+ )
30
+ logger = logging.getLogger("AETHER.Train")
31
+
32
+
33
+ def aether_reward(completions: List[str], **kwargs) -> List[float]:
34
+ """AETHER neuro-symbolic reward combining reasoning structure and knowledge coherence."""
35
+ rewards = []
36
+ for completion in completions:
37
+ score = 0.0
38
+ text = completion if isinstance(completion, str) else str(completion)
39
+
40
+ if "<think>" in text and "</think>" in text:
41
+ score += 0.3
42
+
43
+ steps = sum(1 for s in text.split("\n") if any(s.strip().startswith(p) for p in ["1.", "2.", "3.", "4.", "5.", "Step", "Phase"]))
44
+ score += min(steps * 0.05, 0.25)
45
+
46
+ if any(kw in text.lower() for kw in ["therefore", "because", "implies", "consequently"]):
47
+ score += 0.2
48
+
49
+ if any(kw in text.lower() for kw in ["sub-goal", "blueprint", "plan", "phase"]):
50
+ score += 0.15
51
+
52
+ if any(kw in text.lower() for kw in ["reflect", "evaluate", "improve", "evolve"]):
53
+ score += 0.1
54
+
55
+ rewards.append(min(score, 1.0))
56
+ return rewards
57
+
58
+
59
+ def main():
60
+ MODEL_NAME = os.environ.get("AETHER_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
61
+ OUTPUT_DIR = os.environ.get("AETHER_OUTPUT", "./aether-output")
62
+
63
+ trackio_space_id = os.environ.get("TRACKIO_SPACE_ID")
64
+ trackio_project = os.environ.get("TRACKIO_PROJECT", "aether-evolution")
65
+
66
+ logger.info("=" * 60)
67
+ logger.info("AETHER TRAINING - GRPO with Neuro-Symbolic Rewards")
68
+ logger.info("=" * 60)
69
+ logger.info(f"Model: {MODEL_NAME}")
70
+ logger.info(f"Output: {OUTPUT_DIR}")
71
+
72
+ device = "cuda" if torch.cuda.is_available() else "cpu"
73
+ logger.info(f"Device: {device}")
74
+
75
+ logger.info("Loading model...")
76
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ MODEL_NAME,
79
+ torch_dtype=dtype,
80
+ device_map="auto" if torch.cuda.is_available() else None,
81
+ trust_remote_code=True,
82
+ )
83
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
84
+ if tokenizer.pad_token is None:
85
+ tokenizer.pad_token = tokenizer.eos_token
86
+
87
+ logger.info("Loading dataset...")
88
+ try:
89
+ dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
90
+ logger.info(f"Loaded DeepMath-103K: {len(dataset)} examples")
91
+ except Exception as e:
92
+ logger.warning(f"DeepMath failed: {e}")
93
+ try:
94
+ dataset = load_dataset("trl-lib/Capybara", split="train")
95
+ logger.info(f"Loaded Capybara: {len(dataset)} examples")
96
+ except Exception as e2:
97
+ logger.warning(f"Capybara failed: {e2}")
98
+ from datasets import Dataset
99
+ prompts = [
100
+ {"prompt": "Think step by step and solve: If a train travels 240 km in 3 hours, what is its average speed?"},
101
+ {"prompt": "Plan and reason: You have 5 shelves and need to store 150 books evenly. How many per shelf?"},
102
+ {"prompt": "Analyze and explain: Why does recursive self-improvement require safety constraints?"},
103
+ {"prompt": "Break down into phases: How would you build a self-evolving AI system?"},
104
+ {"prompt": "Reflect and improve: A previous solution had an error in step 3. How would you fix it?"},
105
+ {"prompt": "Think about this: What are the trade-offs between symbolic and neural reasoning?"},
106
+ {"prompt": "Plan a hierarchy: Design a multi-agent system with a manager and workers."},
107
+ {"prompt": "Evolve this solution: Start with a simple sorting algorithm and improve it iteratively."},
108
+ {"prompt": "Knowledge reasoning: Given that all birds can fly and penguins are birds, what can you conclude?"},
109
+ {"prompt": "Meta-cognitive analysis: Evaluate your own reasoning process and identify biases."},
110
+ ] * 100
111
+ dataset = Dataset.from_list(prompts)
112
+ logger.info(f"Created fallback dataset: {len(dataset)} examples")
113
+
114
+ if "prompt" not in dataset.column_names:
115
+ if "text" in dataset.column_names:
116
+ dataset = dataset.rename_column("text", "prompt")
117
+ elif "messages" in dataset.column_names:
118
+ def extract_prompt(examples):
119
+ prompts = []
120
+ for msgs in examples["messages"]:
121
+ for msg in msgs:
122
+ if msg.get("role") == "user":
123
+ prompts.append(msg.get("content", ""))
124
+ break
125
+ else:
126
+ prompts.append(str(msgs))
127
+ return {"prompt": prompts}
128
+ dataset = dataset.map(extract_prompt, batched=True, remove_columns=dataset.column_names)
129
+ elif "question" in dataset.column_names:
130
+ dataset = dataset.rename_column("question", "prompt")
131
+
132
+ dataset = dataset.train_test_split(test_size=0.1)
133
+ train_ds = dataset["train"]
134
+ eval_ds = dataset["test"]
135
+ logger.info(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")
136
+
137
+ training_args = GRPOConfig(
138
+ output_dir=OUTPUT_DIR,
139
+ num_train_epochs=1,
140
+ per_device_train_batch_size=1,
141
+ per_device_eval_batch_size=1,
142
+ gradient_accumulation_steps=8,
143
+ learning_rate=2e-5,
144
+ logging_steps=10,
145
+ save_steps=100,
146
+ eval_strategy="steps",
147
+ eval_steps=50,
148
+ bf16=torch.cuda.is_available(),
149
+ max_completion_length=512,
150
+ num_generations=4,
151
+ report_to="trackio" if trackio_space_id else [],
152
+ run_name=f"aether-grpo-{MODEL_NAME.split('/')[-1]}",
153
+ project=trackio_project,
154
+ trackio_space_id=trackio_space_id,
155
+ disable_tqdm=True,
156
+ logging_first_step=True,
157
+ push_to_hub=True,
158
+ hub_model_id=f"camdog920/aether-{MODEL_NAME.split('/')[-1]}-grpo",
159
+ )
160
+
161
+ reward_funcs = [
162
+ aether_reward,
163
+ accuracy_reward,
164
+ think_format_reward,
165
+ ]
166
+
167
+ logger.info("Initializing GRPO Trainer...")
168
+ trainer = GRPOTrainer(
169
+ model=model,
170
+ reward_funcs=reward_funcs,
171
+ args=training_args,
172
+ train_dataset=train_ds,
173
+ eval_dataset=eval_ds,
174
+ )
175
+
176
+ logger.info("Starting training...")
177
+ trainer.train()
178
+
179
+ logger.info("Saving model...")
180
+ trainer.save_model(OUTPUT_DIR)
181
+ tokenizer.save_pretrained(OUTPUT_DIR)
182
+
183
+ metadata = {
184
+ "aether_version": "0.1.0",
185
+ "training_method": "GRPO",
186
+ "model_name": MODEL_NAME,
187
+ "reward_functions": ["aether_reward", "accuracy_reward", "think_format_reward"],
188
+ }
189
+ with open(os.path.join(OUTPUT_DIR, "aether_metadata.json"), "w") as f:
190
+ json.dump(metadata, f, indent=2)
191
+
192
+ logger.info("=" * 60)
193
+ logger.info("Training complete!")
194
+ logger.info(f"Model: https://huggingface.co/{training_args.hub_model_id}")
195
+ if trackio_space_id:
196
+ logger.info(f"Dashboard: https://huggingface.co/spaces/{trackio_space_id}")
197
+ logger.info("=" * 60)
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()