#!/usr/bin/env python3
"""
AETHER Training Script.

Integrates TRL GRPO for agent training with custom rewards, smolagents for
multi-agent orchestration, neuro-symbolic reasoning, and evolutionary
optimization.
"""

import os
import sys
import json
import logging
import argparse
from typing import List

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig
from trl.rewards import accuracy_reward, think_format_reward

from aether.core import AetherCore, AetherConfig
from aether.knowledge import KnowledgeGraphEngine

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("AETHER.Train")


def aether_reward(completions: List[str], **kwargs) -> List[float]:
    """AETHER neuro-symbolic reward combining reasoning structure and knowledge coherence."""
    rewards = []
    for completion in completions:
        score = 0.0
        text = completion if isinstance(completion, str) else str(completion)

        # Reward explicit <think>...</think> reasoning blocks.
        if "<think>" in text and "</think>" in text:
            score += 0.3

        # Reward enumerated, step-wise structure (capped at 0.25).
        steps = sum(
            1
            for s in text.split("\n")
            if any(s.strip().startswith(p) for p in ["1.", "2.", "3.", "4.", "5.", "Step", "Phase"])
        )
        score += min(steps * 0.05, 0.25)

        # Reward logical connectives.
        if any(kw in text.lower() for kw in ["therefore", "because", "implies", "consequently"]):
            score += 0.2

        # Reward planning vocabulary.
        if any(kw in text.lower() for kw in ["sub-goal", "blueprint", "plan", "phase"]):
            score += 0.15

        # Reward self-reflection vocabulary.
        if any(kw in text.lower() for kw in ["reflect", "evaluate", "improve", "evolve"]):
            score += 0.1

        rewards.append(min(score, 1.0))
    return rewards


def main():
    MODEL_NAME = os.environ.get("AETHER_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    OUTPUT_DIR = os.environ.get("AETHER_OUTPUT", "./aether-output")
    trackio_space_id = os.environ.get("TRACKIO_SPACE_ID")
    trackio_project = os.environ.get("TRACKIO_PROJECT", "aether-evolution")

    logger.info("=" * 60)
    logger.info("AETHER TRAINING - GRPO with Neuro-Symbolic Rewards")
    logger.info("=" * 60)
    logger.info(f"Model: {MODEL_NAME}")
    logger.info(f"Output: {OUTPUT_DIR}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Device: {device}")

    logger.info("Loading model...")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info("Loading dataset...")
    try:
        dataset = load_dataset("trl-lib/DeepMath-103K", split="train")
        logger.info(f"Loaded DeepMath-103K: {len(dataset)} examples")
    except Exception as e:
        logger.warning(f"DeepMath failed: {e}")
        try:
            dataset = load_dataset("trl-lib/Capybara", split="train")
            logger.info(f"Loaded Capybara: {len(dataset)} examples")
        except Exception as e2:
            logger.warning(f"Capybara failed: {e2}")
            from datasets import Dataset

            # Last-resort fallback: a small synthetic prompt set, repeated.
            prompts = [
                {"prompt": "Think step by step and solve: If a train travels 240 km in 3 hours, what is its average speed?"},
                {"prompt": "Plan and reason: You have 5 shelves and need to store 150 books evenly. How many per shelf?"},
                {"prompt": "Analyze and explain: Why does recursive self-improvement require safety constraints?"},
                {"prompt": "Break down into phases: How would you build a self-evolving AI system?"},
                {"prompt": "Reflect and improve: A previous solution had an error in step 3. How would you fix it?"},
                {"prompt": "Think about this: What are the trade-offs between symbolic and neural reasoning?"},
                {"prompt": "Plan a hierarchy: Design a multi-agent system with a manager and workers."},
                {"prompt": "Evolve this solution: Start with a simple sorting algorithm and improve it iteratively."},
                {"prompt": "Knowledge reasoning: Given that all birds can fly and penguins are birds, what can you conclude?"},
                {"prompt": "Meta-cognitive analysis: Evaluate your own reasoning process and identify biases."},
            ] * 100
            dataset = Dataset.from_list(prompts)
            logger.info(f"Created fallback dataset: {len(dataset)} examples")

    # Normalize whatever dataset we got to a single "prompt" column.
    if "prompt" not in dataset.column_names:
        if "text" in dataset.column_names:
            dataset = dataset.rename_column("text", "prompt")
        elif "messages" in dataset.column_names:
            def extract_prompt(examples):
                prompts = []
                for msgs in examples["messages"]:
                    for msg in msgs:
                        if msg.get("role") == "user":
                            prompts.append(msg.get("content", ""))
                            break
                    else:
                        # No user turn found; keep the raw messages as a string.
                        prompts.append(str(msgs))
                return {"prompt": prompts}

            dataset = dataset.map(extract_prompt, batched=True, remove_columns=dataset.column_names)
        elif "question" in dataset.column_names:
            dataset = dataset.rename_column("question", "prompt")

    dataset = dataset.train_test_split(test_size=0.1)
    train_ds = dataset["train"]
    eval_ds = dataset["test"]
    logger.info(f"Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=2e-5,
        logging_steps=10,
        save_steps=100,
        eval_strategy="steps",
        eval_steps=50,
        bf16=torch.cuda.is_available(),
        max_completion_length=512,
        num_generations=4,
        report_to="trackio" if trackio_space_id else [],
        run_name=f"aether-grpo-{MODEL_NAME.split('/')[-1]}",
        project=trackio_project,
        trackio_space_id=trackio_space_id,
        disable_tqdm=True,
        logging_first_step=True,
        push_to_hub=True,
        hub_model_id=f"camdog920/aether-{MODEL_NAME.split('/')[-1]}-grpo",
    )

    reward_funcs = [
        aether_reward,
        accuracy_reward,
        think_format_reward,
    ]

    logger.info("Initializing GRPO Trainer...")
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
    )

    logger.info("Starting training...")
    trainer.train()

    logger.info("Saving model...")
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)

    metadata = {
        "aether_version": "0.1.0",
        "training_method": "GRPO",
        "model_name": MODEL_NAME,
        "reward_functions": ["aether_reward", "accuracy_reward", "think_format_reward"],
    }
    with open(os.path.join(OUTPUT_DIR, "aether_metadata.json"), "w") as f:
        json.dump(metadata, f, indent=2)

    logger.info("=" * 60)
    logger.info("Training complete!")
    logger.info(f"Model: https://huggingface.co/{training_args.hub_model_id}")
    if trackio_space_id:
        logger.info(f"Dashboard: https://huggingface.co/spaces/{trackio_space_id}")
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
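
# Example invocation (a minimal sketch: the script filename `train_aether.py` and the
# Trackio space name are assumptions, not defined by this file; only the environment
# variables below are actually read by the script):
#
#   AETHER_MODEL="Qwen/Qwen2.5-0.5B-Instruct" \
#   AETHER_OUTPUT="./aether-output" \
#   TRACKIO_PROJECT="aether-evolution" \
#   TRACKIO_SPACE_ID="<username>/<trackio-space>" \
#   python train_aether.py
#
# Note: push_to_hub=True requires Hugging Face authentication (e.g. `huggingface-cli login`
# or an HF_TOKEN in the environment); leave TRACKIO_SPACE_ID unset to disable Trackio reporting.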