# aether-core / aether_train.py
# Uploaded by camdog920 ("Upload aether_train.py"), commit 93f6542 (verified).
#!/usr/bin/env python3
"""
AETHER Training Script.
Integrates TRL GRPO for agent training with custom rewards,
smolagents for multi-agent orchestration,
neuro-symbolic reasoning, and evolutionary optimization.
"""
import os
import sys
import json
import logging
import argparse
from typing import List
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward, think_format_reward
from aether.core import AetherCore, AetherConfig
from aether.knowledge import KnowledgeGraphEngine
# Script-wide logging: INFO and above, each record prefixed with a timestamp,
# the emitting logger's name and the level.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(name="AETHER.Train")
# Line prefixes that count as an explicit reasoning step.
_STEP_PREFIXES = ("1.", "2.", "3.", "4.", "5.", "Step", "Phase")
_CAUSAL_KEYWORDS = ("therefore", "because", "implies", "consequently")
_PLANNING_KEYWORDS = ("sub-goal", "blueprint", "plan", "phase")
_REFLECTION_KEYWORDS = ("reflect", "evaluate", "improve", "evolve")


def _completion_to_text(completion) -> str:
    """Return the generated text of a single completion.

    TRL hands reward functions plain strings for "standard" datasets and a
    list of message dicts (``[{"role": "assistant", "content": ...}]``) for
    conversational ones. ``str()`` on the latter produces the list's repr,
    where real newlines show up as the two characters ``\\n`` and dict keys
    and quotes pollute keyword matching, so the message contents are
    extracted instead.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        completion = [completion]
    if isinstance(completion, (list, tuple)) and all(isinstance(m, dict) for m in completion):
        return "\n".join(str(m.get("content") or "") for m in completion)
    return str(completion)


def aether_reward(completions: list, **kwargs) -> List[float]:
    """Score completions on surface markers of structured reasoning.

    A text-only heuristic (no knowledge graph is consulted):

    * +0.30 if the text contains both ``<think>`` and ``</think>``;
    * +0.05 per line that starts (after stripping) with ``1.``-``5.``,
      ``Step`` or ``Phase``, capped at +0.25;
    * +0.20 if a causal connective appears (``therefore``, ``because``, ...);
    * +0.15 if a planning word appears (``sub-goal``, ``blueprint``, ...);
    * +0.10 if a reflection word appears (``reflect``, ``evaluate``, ...).

    Keyword checks are case-insensitive substring matches (so "planet" also
    counts as "plan"). The total is clipped to 1.0.

    Args:
        completions: One entry per sampled completion: a plain string, or a
            list of message dicts for conversational datasets.
        **kwargs: Extra dataset columns TRL forwards to reward functions;
            ignored here.

    Returns:
        One score in ``[0.0, 1.0]`` per completion, in input order.
    """
    rewards = []
    for completion in completions:
        text = _completion_to_text(completion)
        lowered = text.lower()
        score = 0.0
        if "<think>" in text and "</think>" in text:
            score += 0.3
        steps = sum(1 for line in text.split("\n") if line.strip().startswith(_STEP_PREFIXES))
        score += min(steps * 0.05, 0.25)
        if any(kw in lowered for kw in _CAUSAL_KEYWORDS):
            score += 0.2
        if any(kw in lowered for kw in _PLANNING_KEYWORDS):
            score += 0.15
        if any(kw in lowered for kw in _REFLECTION_KEYWORDS):
            score += 0.1
        rewards.append(min(score, 1.0))
    return rewards
# Hand-written prompts used when no hub dataset can be downloaded.
_FALLBACK_PROMPTS = [
    "Think step by step and solve: If a train travels 240 km in 3 hours, what is its average speed?",
    "Plan and reason: You have 5 shelves and need to store 150 books evenly. How many per shelf?",
    "Analyze and explain: Why does recursive self-improvement require safety constraints?",
    "Break down into phases: How would you build a self-evolving AI system?",
    "Reflect and improve: A previous solution had an error in step 3. How would you fix it?",
    "Think about this: What are the trade-offs between symbolic and neural reasoning?",
    "Plan a hierarchy: Design a multi-agent system with a manager and workers.",
    "Evolve this solution: Start with a simple sorting algorithm and improve it iteratively.",
    "Knowledge reasoning: Given that all birds can fly and penguins are birds, what can you conclude?",
    "Meta-cognitive analysis: Evaluate your own reasoning process and identify biases.",
]

# Evaluation generates num_generations completions per prompt every
# eval_steps; a 10% split of a ~100K-row dataset would dominate training time.
_DEFAULT_MAX_EVAL_EXAMPLES = 200


def _load_model_and_tokenizer(model_name: str):
    """Load the causal LM and its tokenizer, configured for GRPO generation.

    Returns:
        ``(model, tokenizer)``; the tokenizer pads on the left and always has
        a pad token.
    """
    use_cuda = torch.cuda.is_available()
    dtype = torch.bfloat16 if use_cuda else torch.float32
    # NOTE(review): trust_remote_code runs Python code shipped with the model
    # repo; only point AETHER_MODEL at repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto" if use_cuda else None,
        trust_remote_code=True,
    )
    # GRPOTrainer generates from batched prompts and requires left padding.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True, padding_side="left"
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _load_prompt_dataset():
    """Load training prompts: DeepMath-103K, else Capybara, else built-in prompts.

    Hub download failures are logged and the next source is tried; the
    built-in fallback (``_FALLBACK_PROMPTS`` repeated 100 times) cannot fail.
    """
    try:
        dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
        logger.info(f"Loaded DeepMath-103K: {len(dataset)} examples")
        return dataset
    except Exception as e:  # best-effort: network / hub errors vary widely
        logger.warning(f"DeepMath failed: {e}")
    try:
        dataset = load_dataset("trl-lib/Capybara", split="train")
        logger.info(f"Loaded Capybara: {len(dataset)} examples")
        return dataset
    except Exception as e2:  # best-effort: network / hub errors vary widely
        logger.warning(f"Capybara failed: {e2}")
    from datasets import Dataset
    dataset = Dataset.from_list([{"prompt": p} for p in _FALLBACK_PROMPTS] * 100)
    logger.info(f"Created fallback dataset: {len(dataset)} examples")
    return dataset


def _first_user_messages(batch):
    """Batched ``map`` function: take the first user turn of each ``messages`` list."""
    prompts = []
    for msgs in batch["messages"]:
        for msg in msgs:
            if msg.get("role") == "user":
                prompts.append(msg.get("content", ""))
                break
        else:
            # No user turn at all: keep a textual trace instead of dropping the row.
            prompts.append(str(msgs))
    return {"prompt": prompts}


def _to_user_turn(example):
    """Wrap a plain-string prompt as a single-turn conversational prompt."""
    return {"prompt": [{"role": "user", "content": example["prompt"]}]}


def _normalize_prompts(dataset):
    """Return ``dataset`` with a conversational ``prompt`` column.

    ``text`` / ``question`` columns are renamed, ``messages`` is reduced to
    its first user turn, and string prompts are wrapped as
    ``[{"role": "user", "content": ...}]``. TRL's ``accuracy_reward`` and
    ``think_format_reward`` read ``completion[0]["content"]``, which only
    works when prompts are conversational.

    Raises:
        ValueError: if no prompt-like column exists or the dataset is empty.
    """
    columns = dataset.column_names
    if "prompt" not in columns:
        if "text" in columns:
            dataset = dataset.rename_column("text", "prompt")
        elif "messages" in columns:
            dataset = dataset.map(_first_user_messages, batched=True, remove_columns=columns)
        elif "question" in columns:
            dataset = dataset.rename_column("question", "prompt")
        else:
            raise ValueError(f"Dataset has no usable prompt column (columns: {columns})")
    if len(dataset) == 0:
        raise ValueError("Prompt dataset is empty")
    if isinstance(dataset[0]["prompt"], str):
        # remove_columns drops the old column before the new "prompt" is written.
        dataset = dataset.map(_to_user_turn, remove_columns=["prompt"])
    return dataset


def _build_reward_funcs(column_names):
    """Return the reward functions usable with a dataset having ``column_names``.

    ``accuracy_reward`` needs the ground-truth ``solution`` column, so it is
    only included when that column exists.
    """
    reward_funcs = [aether_reward]
    if "solution" in column_names:
        reward_funcs.append(accuracy_reward)
    else:
        logger.warning("Dataset has no 'solution' column; skipping accuracy_reward.")
    reward_funcs.append(think_format_reward)
    return reward_funcs


def _write_metadata(output_dir, model_name, reward_funcs):
    """Write ``aether_metadata.json`` describing this run into ``output_dir``."""
    os.makedirs(output_dir, exist_ok=True)
    metadata = {
        "aether_version": "0.1.0",
        "training_method": "GRPO",
        "model_name": model_name,
        "reward_functions": [getattr(f, "__name__", repr(f)) for f in reward_funcs],
    }
    with open(os.path.join(output_dir, "aether_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)


def main():
    """Train a causal LM with GRPO on reasoning prompts and save/push the result.

    Environment variables:
        AETHER_MODEL: base model id (default ``Qwen/Qwen2.5-0.5B-Instruct``).
        AETHER_OUTPUT: output directory (default ``./aether-output``).
        AETHER_HUB_MODEL_ID: Hub repo to push to
            (default ``camdog920/aether-<model>-grpo``).
        AETHER_PUSH_TO_HUB: set to ``0``/``false``/``no`` to skip the Hub push.
        AETHER_MAX_EVAL_EXAMPLES: cap on the eval split size (default 200).
        TRACKIO_SPACE_ID / TRACKIO_PROJECT: optional trackio reporting.
    """
    MODEL_NAME = os.environ.get("AETHER_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    OUTPUT_DIR = os.environ.get("AETHER_OUTPUT", "./aether-output")
    trackio_space_id = os.environ.get("TRACKIO_SPACE_ID")
    trackio_project = os.environ.get("TRACKIO_PROJECT", "aether-evolution")
    model_short_name = MODEL_NAME.split("/")[-1]
    hub_model_id = os.environ.get(
        "AETHER_HUB_MODEL_ID", f"camdog920/aether-{model_short_name}-grpo"
    )
    push_to_hub = os.environ.get("AETHER_PUSH_TO_HUB", "1").strip().lower() not in ("0", "false", "no")
    max_eval_examples = int(
        os.environ.get("AETHER_MAX_EVAL_EXAMPLES", str(_DEFAULT_MAX_EVAL_EXAMPLES))
    )
    num_generations = 4

    logger.info("=" * 60)
    logger.info("AETHER TRAINING - GRPO with Neuro-Symbolic Rewards")
    logger.info("=" * 60)
    logger.info(f"Model: {MODEL_NAME}")
    logger.info(f"Output: {OUTPUT_DIR}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Device: {device}")

    logger.info("Loading model...")
    model, tokenizer = _load_model_and_tokenizer(MODEL_NAME)

    logger.info("Loading dataset...")
    dataset = _normalize_prompts(_load_prompt_dataset())
    reward_funcs = _build_reward_funcs(dataset.column_names)

    # 10% held out, capped so periodic evaluation stays cheap; seeded for
    # reproducible splits.
    eval_size = max(1, min(int(len(dataset) * 0.1), max_eval_examples))
    split = dataset.train_test_split(test_size=eval_size, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]
    logger.info(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        # With evaluation enabled, TRL requires the global eval batch size to
        # be divisible by num_generations; a batch of 1 fails that check.
        per_device_eval_batch_size=num_generations,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        logging_steps=10,
        save_steps=100,
        eval_strategy="steps",
        eval_steps=50,
        bf16=torch.cuda.is_available(),
        max_completion_length=512,
        num_generations=num_generations,
        report_to="trackio" if trackio_space_id else [],
        run_name=f"aether-grpo-{model_short_name}",
        project=trackio_project,
        trackio_space_id=trackio_space_id,
        disable_tqdm=True,
        logging_first_step=True,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
    )

    logger.info("Initializing GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        # Use the left-padded tokenizer configured above.
        processing_class=tokenizer,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Saving model...")
    # Write tokenizer files and metadata BEFORE save_model(): with
    # push_to_hub enabled, save_model() uploads OUTPUT_DIR, so files written
    # afterwards would never reach the Hub.
    tokenizer.save_pretrained(OUTPUT_DIR)
    _write_metadata(OUTPUT_DIR, MODEL_NAME, reward_funcs)
    trainer.save_model(OUTPUT_DIR)

    logger.info("=" * 60)
    logger.info("Training complete!")
    if push_to_hub:
        logger.info(f"Model: https://huggingface.co/{training_args.hub_model_id}")
    else:
        logger.info(f"Model: {OUTPUT_DIR}")
    if trackio_space_id:
        logger.info(f"Dashboard: https://huggingface.co/spaces/{trackio_space_id}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()