Spaces:
Sleeping
Sleeping
File size: 7,041 Bytes
acb327b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 | #!/usr/bin/env python3
"""
ECHO ULTIMATE — CLI entry point.

    python run.py download           Download all 7 task datasets
    python run.py test               Smoke test — 3 sample episodes
    python run.py baseline           Evaluate 4 baselines, generate all 6 plots
    python run.py plots              Generate all plots (synthetic, no eval needed)
    python run.py train              Full GRPO training (GPU required)
    python run.py eval               Evaluate trained model
    python run.py demo               Launch Gradio demo on :7860
    python run.py server             Launch FastAPI server on :8000
    python run.py all                download + train + eval
    python run.py publish-benchmark  Publish EchoBench to HuggingFace Hub
"""
import logging
import os
import sys

# Put the directory containing this file at the front of the import path
# (presumably so the project-local packages imported by the commands below
# resolve regardless of the current working directory).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Root logger: INFO and above, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s — %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
def cmd_download():
    """Run the entry point of ``scripts.download_tasks``."""
    # Imported lazily so the script is only loaded when this command runs.
    from scripts.download_tasks import main as download_main

    download_main()
def cmd_test():
    """Run three scripted episodes against ``EchoEnv`` and print each outcome.

    For every scripted action the environment is reset, the action is
    submitted once via ``env.step``, and the parsed confidence, correctness,
    reward and overconfidence penalty reported by the environment are printed.
    """
    print("🧪 ECHO ULTIMATE smoke test…\n")

    # NOTE: ``cfg`` is not referenced below; the import is kept as in the
    # original in case loading ``config`` has side effects — confirm before
    # removing.
    from config import cfg
    from env.echo_env import EchoEnv
    from env.task_bank import TaskBank

    task_bank = TaskBank()
    task_bank.ensure_loaded()
    env = EchoEnv(task_bank=task_bank, phase=1, render_mode="human")

    # (raw action text sent to the environment, human-readable label)
    scripted_actions = (
        ("<confidence>75</confidence><answer>Paris</answer>", "Correct, calibrated"),
        ("<confidence>95</confidence><answer>wrong</answer>", "Wrong, overconfident → penalty"),
        ("<confidence>30</confidence><answer>wrong</answer>", "Wrong, humble → small loss"),
    )

    episode_no = 0
    for action, label in scripted_actions:
        episode_no += 1
        state, _ = env.reset()
        print(f" Episode {episode_no} ({label})")
        print(f" Domain: {state['domain']} | Difficulty: {state['difficulty']}")

        # step() is unpacked as a 5-tuple; only the reward and the info
        # dict are used here.
        _obs, reward, _flag_a, _flag_b, info = env.step(action)
        print(f" Confidence: {info['parsed_confidence']}% | Correct: {info['was_correct']}")
        print(f" Reward: {reward:+.3f} | OC Penalty: {info['overconfidence_penalty']:.2f}\n")

    # Reads the bank's private ``_tasks`` mapping to list what was loaded.
    loaded_tasks = task_bank._tasks
    print(f" Domains loaded: {list(loaded_tasks.keys())}")
    print("\n✅ Smoke test passed.")
def cmd_baseline():
    """Run the entry point of ``scripts.run_baseline``."""
    # Imported lazily so the script is only loaded when this command runs.
    from scripts.run_baseline import main as baseline_main

    baseline_main()
def cmd_plots():
    """Run the entry point of ``scripts.generate_plots``."""
    # Imported lazily so the script is only loaded when this command runs.
    from scripts.generate_plots import main as plots_main

    plots_main()
def cmd_train():
    """Load the task bank and start training via ``training.train.train``.

    WandB logging is requested only when the ``wandb`` package can be
    imported; the outcome of that probe is passed as ``use_wandb``.
    """
    print("🚀 ECHO ULTIMATE GRPO training…")
    print(" Requires GPU. Estimated: 2-4 hours on A100.")

    from config import cfg
    from env.task_bank import TaskBank
    from training.train import train

    task_bank = TaskBank()
    task_bank.ensure_loaded()

    # Probe for wandb by importing it; the module itself is not used here.
    try:
        import wandb
    except ImportError:
        wandb_available = False
        print(" 📊 WandB not found — CSV logging only")
    else:
        wandb_available = True
        print(" 📊 WandB enabled")

    train(
        cfg.MODEL_NAME,
        cfg.MODEL_SAVE_DIR,
        task_bank=task_bank,
        use_wandb=wandb_available,
    )
def cmd_eval():
    """Evaluate the trained model against an always-high baseline and plot.

    If ``cfg.MODEL_SAVE_DIR`` exists, the model and tokenizer are loaded from
    it and evaluated with ``evaluate_agent``. Otherwise the second element of
    ``make_synthetic_pair()`` stands in for the trained results. Either way
    the result is compared with ``AlwaysHighAgent`` (labelled "Untrained")
    through ``compare_and_plot``.
    """
    print("📊 Evaluating…")

    from config import cfg
    from pathlib import Path
    from env.task_bank import TaskBank
    from training.evaluate import evaluate_agent, compare_and_plot, make_synthetic_pair

    Path(cfg.PLOTS_DIR).mkdir(parents=True, exist_ok=True)
    task_bank = TaskBank()
    task_bank.ensure_loaded()

    if not Path(cfg.MODEL_SAVE_DIR).exists():
        print(" ⚠️ No trained model found — using synthetic results")
        _, trained = make_synthetic_pair()
        trained.label = "ECHO Trained"
    else:
        print(f" 🤖 Loading trained model from {cfg.MODEL_SAVE_DIR}…")
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(cfg.MODEL_SAVE_DIR)
        model = AutoModelForCausalLM.from_pretrained(cfg.MODEL_SAVE_DIR, torch_dtype="auto")
        model.eval()

        def agent_fn(p):
            """Sample a completion for prompt ``p`` and return only the new text."""
            encoded = tokenizer(p, return_tensors="pt", truncation=True, max_length=512)
            prompt_len = encoded["input_ids"].shape[1]
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=cfg.MAX_NEW_TOKENS,
                    temperature=cfg.TEMPERATURE,
                    do_sample=True,
                )
            # Drop the prompt tokens so only the continuation is decoded.
            return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)

        trained = evaluate_agent(agent_fn, task_bank, label="ECHO Trained")

    from core.baseline import AlwaysHighAgent

    baseline = evaluate_agent(AlwaysHighAgent(), task_bank, label="Untrained")
    compare_and_plot(trained, {"Untrained": baseline})
    print("\n✅ Eval complete. Plots saved to results/plots/")
def cmd_demo():
    """Announce the demo URL, then run the entry point of ``ui.app``."""
    print("🎨 Launching Gradio demo → http://localhost:7860")
    # Imported after the banner so the message appears before the UI loads.
    from ui.app import main as demo_main

    demo_main()
def cmd_server():
    """Serve ``server.app:app`` with uvicorn on ``cfg.API_HOST:cfg.API_PORT``."""
    print("🖥️ Launching FastAPI server → http://localhost:8000/docs")

    import uvicorn
    from config import cfg

    # Auto-reload is explicitly disabled.
    uvicorn.run(
        "server.app:app",
        host=cfg.API_HOST,
        port=cfg.API_PORT,
        reload=False,
    )
def cmd_all():
    """Run the download, train and eval commands in that order."""
    pipeline = (cmd_download, cmd_train, cmd_eval)
    for stage in pipeline:
        stage()
    print("\n🎉 Full pipeline complete!")
def cmd_publish_benchmark():
    """Prompt for a HuggingFace write token and run the publish script.

    The token is read without echo when stdin is an interactive terminal, so
    the secret is not displayed on screen. When stdin is not a terminal
    (piped or redirected input) it is read with ``input()`` so scripted use
    keeps working. An empty or whitespace-only token aborts with a message
    and returns ``None`` without importing the publish script.

    The publish script is driven through ``sys.argv``. The previous
    ``sys.argv`` is restored afterwards, even if publishing raises, so the
    token does not linger in process-global state.
    """
    import getpass

    print("📦 Publishing EchoBench to HuggingFace Hub…")

    prompt = "Enter HuggingFace write token: "
    stdin = sys.stdin
    if stdin is not None and stdin.isatty():
        # Interactive terminal: hide the secret while it is typed.
        token = getpass.getpass(prompt)
    else:
        token = input(prompt)
    token = token.strip()

    if not token:
        print("❌ No token provided.")
        return

    from scripts.publish_echobench import main as _pub_main

    saved_argv = sys.argv
    sys.argv = ["publish_echobench.py", "--token", token]
    try:
        _pub_main()
    finally:
        sys.argv = saved_argv
# Dispatch table: CLI sub-command name -> zero-argument handler.
# Insertion order is the order shown in the "Unknown command" message.
COMMANDS = dict(
    (
        ("download", cmd_download),
        ("test", cmd_test),
        ("baseline", cmd_baseline),
        ("plots", cmd_plots),
        ("train", cmd_train),
        ("eval", cmd_eval),
        ("demo", cmd_demo),
        ("server", cmd_server),
        ("all", cmd_all),
        ("publish-benchmark", cmd_publish_benchmark),
    )
)
# Usage text printed when the script is run with no arguments or with
# -h / --help / help. It lists one line per entry in COMMANDS.
HELP = """
ECHO ULTIMATE — Metacognitive Calibration RL Environment
python run.py download Download 7 task datasets from HuggingFace
python run.py test Smoke test (no GPU, ~5 seconds)
python run.py baseline Evaluate 4 baselines, generate 6 plots
python run.py plots Generate all plots (synthetic data, instant)
python run.py train GRPO training curriculum (GPU, 2-4h)
python run.py eval Evaluate trained model, generate plots
python run.py demo Gradio demo → localhost:7860
python run.py server FastAPI server → localhost:8000
python run.py all download + train + eval
python run.py publish-benchmark Publish EchoBench to HuggingFace Hub
Start here (no GPU needed):
python run.py test
python run.py plots
python run.py baseline
"""
if __name__ == "__main__":
    # Script entry: pick the handler named by the first CLI argument.
    cli_args = sys.argv[1:]

    # No sub-command, or an explicit help request: show usage and exit 0.
    if not cli_args or cli_args[0] in ("-h", "--help", "help"):
        print(HELP)
        sys.exit(0)

    # Sub-command names are matched case-insensitively.
    cmd = cli_args[0].lower()
    handler = COMMANDS.get(cmd)
    if handler is None:
        print(f"❌ Unknown: {cmd}\n Available: {', '.join(COMMANDS)}")
        sys.exit(1)

    try:
        handler()
    except KeyboardInterrupt:
        # Ctrl-C is reported as a clean stop; the process exits normally.
        print("\n⏹️ Stopped.")
    except Exception:
        # Top-level boundary: log the full traceback and exit with status 1.
        logging.getLogger(__name__).exception("Command '%s' failed", cmd)
        sys.exit(1)
|