"""
Agent Zero Orchestrator — Gradio Space App
===========================================
Fully autonomous self-healing training on the free CPU tier.
Auto-resumes across Space sleeps. Live dashboard.
"""
import sys, json, time, threading, traceback
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any


import gradio as gr
import plotly.graph_objects as go


# self_healing.py is expected to live one directory above this Space app.
sys.path.insert(0, str(Path(__file__).parent.parent))
from self_healing import SelfHealingTrainer, HealingConfig


training_thread: Optional[threading.Thread] = None
stop_event = threading.Event()
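# Shared dashboard state: the worker thread writes, the Gradio pollers read.
# Single-key dict operations are atomic under CPython's GIL, which is enough
# for display-only data.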
state: Dict[str, Any] = {"running": False, "step": 0, "loss": None,
                         "recoveries": 0, "zclip_clips": 0, "start_time": None,
                         "logs": [], "recovery_history": [], "status": "idle"}
# NOTE: a Space's container filesystem is ephemeral; these paths only survive
# sleeps/restarts if persistent storage is mounted at the same location.
STATE_FILE = Path("/app/training_state.json")
CKPT_DIR = Path("/app/checkpoints")


def _log(msg: str):
    ts = datetime.now().strftime("%H:%M:%S")
    entry = f"[{ts}] {msg}"
    state["logs"].append(entry)
    print(entry, flush=True)
    if len(state["logs"]) > 500: state["logs"] = state["logs"][-500:]


def save_state():
    try:
        with open(STATE_FILE, "w") as f:
            json.dump({k: v for k, v in state.items() if k != "logs"}, f, default=str)
    except OSError: pass


def load_state():
    if STATE_FILE.exists():
        try:
            with open(STATE_FILE) as f: state.update(json.load(f))
        except (OSError, json.JSONDecodeError): pass
    # A restart kills any worker thread, so never trust a persisted "running".
    state["running"] = False
load_state()


def worker(model_id: str, dataset_id: str, max_steps: int, lr: float,
           batch_size: int, hub_user: str, push_hub: bool):
    # Heavy imports are deferred into the worker thread so the UI starts fast.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import load_dataset
    from trl import SFTConfig, SFTTrainer


    state["running"] = True; state["status"] = "loading"
    state["start_time"] = time.time(); stop_event.clear()
    state["logs"] = []; state["step"] = 0


    try:
        _log(f"Loading {model_id}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.float32, device_map="cpu", low_cpu_mem_usage=True)
        tok = AutoTokenizer.from_pretrained(model_id)
        if tok.pad_token is None: tok.pad_token = tok.eos_token


        _log(f"Loading dataset {dataset_id}...")
        ds = load_dataset(dataset_id, split="train[:500]")


        state["status"] = "training"
        args = SFTConfig(
            output_dir=str(CKPT_DIR), per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=4, learning_rate=lr, max_steps=max_steps,
            logging_steps=1, logging_strategy="steps", logging_first_step=True,
            save_steps=10, save_total_limit=5, use_cpu=True,
            report_to="none", disable_tqdm=True,
            push_to_hub=push_hub,
            hub_model_id=f"{hub_user}/agent-zero-model" if push_hub else None)


        # Note: trl >= 0.12 renames SFTTrainer's `tokenizer` argument to
        # `processing_class`; keep this in sync with the pinned trl version.
        trainer = SFTTrainer(model=model, args=args, train_dataset=ds, tokenizer=tok)
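
        # Bridge between trainer and dashboard: publish step/loss on every log
        # event and honor the UI stop button at step boundaries, using the
        # standard transformers TrainerCallback hooks. Assumption: the local
        # SelfHealingTrainer drives this wrapped trainer's train() loop, so
        # callbacks registered here still fire.
        from transformers import TrainerCallback

        class DashboardCallback(TrainerCallback):
            # `tstate` avoids shadowing the module-level `state` dict.
            def on_log(self, args, tstate, control, logs=None, **kwargs):
                state["step"] = tstate.global_step
                if logs and "loss" in logs: state["loss"] = logs["loss"]

            def on_step_end(self, args, tstate, control, **kwargs):
                if stop_event.is_set(): control.should_training_stop = True
                return control

        trainer.add_callback(DashboardCallback())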


        hcfg = HealingConfig(nan_patience=2, loss_spike_factor=5.0,
                             divergence_patience=30, grad_explosion_threshold=50.0,
                             zclip_enabled=True, zclip_z_threshold=3.0,
                             max_recovery_attempts=5, max_lr_reductions=3,
                             max_batch_reductions=2, postmortem_path="/app/postmortem.json")
        sh = SelfHealingTrainer(trainer, hcfg)


        resume = None
        if CKPT_DIR.exists():
            # Numeric sort: a plain lexicographic sort would rank
            # checkpoint-100 before checkpoint-20 and resume the wrong step.
            cks = sorted(CKPT_DIR.glob("checkpoint-*"),
                         key=lambda p: int(p.name.split("-")[-1]))
            if cks: resume = str(cks[-1]); _log(f"Resuming from {resume}")


        _log("Dry-run...")
        sh.dry_run(num_steps=2)
        _log("Starting training!")


        sh.train(resume_from_checkpoint=resume)
        state["status"] = "stopped" if stop_event.is_set() else "completed"
        rpt = sh.get_report()
        state["recoveries"] = rpt["total_recoveries"]
        state["zclip_clips"] = rpt["zclip_total_clips"]
        _log(f"Done! Recoveries: {rpt['total_recoveries']}")
        if push_hub: _log(f"Pushed to {hub_user}/agent-zero-model")
    except Exception as e:
        state["status"] = f"error: {type(e).__name__}"
        _log(f"ERROR: {e}"); traceback.print_exc()
    finally:
        state["running"] = False; save_state()
        _log("Thread ended.")


def start(model_id, dataset_id, max_steps, lr, batch_size, hub_user, push_hub):
    global training_thread
    if state["running"]: return "Already running!"
    state["logs"] = []; state["step"] = 0; state["recoveries"] = 0; state["zclip_clips"] = 0
    training_thread = threading.Thread(target=worker, daemon=True,
        args=(model_id, dataset_id, int(max_steps), float(lr), int(batch_size), hub_user, push_hub))
    training_thread.start()
    return "Training started!"


def stop():
    # Cooperative stop: the DashboardCallback halts training at the next step
    # boundary, and the worker's finally block flips "running" back to False.
    stop_event.set(); state["status"] = "stopping"
    save_state(); return "Stop signal sent."


def get_logs(): return "\n".join(state["logs"][-50:])


def get_status():
    el = f" | {int(time.time()-state['start_time'])}s" if state["start_time"] else ""
    return f"Status: {state['status']} | Step: {state['step']} | Rec: {state['recoveries']} | ZClip: {state['zclip_clips']}{el}"


def get_pm():
    p = Path("/app/postmortem.json")
    return json.dumps(json.loads(p.read_text()), indent=2) if p.exists() else "No postmortem yet."


def get_plot():
    try:
        # trainer_state.json lives inside each checkpoint-N directory, not at
        # the output_dir root; read the newest one.
        cks = sorted(CKPT_DIR.glob("checkpoint-*/trainer_state.json"),
                     key=lambda p: int(p.parent.name.split("-")[-1]))
        if cks:
            data = json.loads(cks[-1].read_text())
            hist = [e for e in data.get("log_history", []) if "loss" in e]
            if hist:
                fig = go.Figure()
                fig.add_trace(go.Scatter(x=[e.get("step", i) for i, e in enumerate(hist)],
                                         y=[e["loss"] for e in hist], mode="lines", name="Loss"))
                fig.update_layout(title="Training Loss", xaxis_title="Step", yaxis_title="Loss", template="plotly_dark")
                return fig
    except (OSError, json.JSONDecodeError, ValueError): pass
    fig = go.Figure(); fig.update_layout(title="Loss (no data)", template="plotly_dark")
    return fig


with gr.Blocks(title="Agent Zero Orchestrator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔄 Agent Zero Orchestrator\n**Self-healing ML training. Free CPU. Auto-resume. Zero credits.**")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Config")
            m = gr.Textbox(value="HuggingFaceTB/SmolLM2-135M", label="Model")
            d = gr.Textbox(value="trl-lib/Capybara", label="Dataset")
            s = gr.Number(value=100, label="Max Steps", minimum=10)
            l = gr.Number(value=2e-5, label="LR")
            b = gr.Number(value=1, label="Batch Size", minimum=1)
            u = gr.Textbox(value="ScottzillaSystems", label="Hub User")
            p = gr.Checkbox(value=False, label="Push to Hub")
            with gr.Row():
                start_btn = gr.Button("🚀 Start", variant="primary")
                stop_btn = gr.Button("⏹ Stop", variant="stop")
            msg = gr.Textbox(label="Message", interactive=False)
            start_btn.click(start, [m, d, s, l, b, u, p], [msg])
            stop_btn.click(stop, outputs=[msg])
        with gr.Column(scale=2):
            gr.Markdown("### Dashboard")
            gr.Textbox(value=get_status, label="Status", every=2, interactive=False)
            gr.Plot(value=get_plot, label="Loss", every=5)
    with gr.Row():
        gr.Textbox(value=get_logs, label="Logs", lines=20, every=2, interactive=False)
        gr.Textbox(value=get_pm, label="Postmortem", lines=20, every=10, interactive=False)
    gr.Markdown("Papers: Unicron arxiv:2401.00134 | ZClip arxiv:2504.02507 | Pioneer Agent arxiv:2604.09791")


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)