# Source: cernenv/training/evaluate.py
# Last commit: 9c00159 "feat: interactive Gradio demo at /demo" (author: anugrahhu, verified)
"""Evaluate an LLM (with optional LoRA adapters) on CERNenv.
Usage:
python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
--difficulty easy --episodes 16 --tag pre_train \\
--out training/runs/eval_pre_train.jsonl
python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
--adapter_dir training/runs/unsloth-grpo --difficulty easy \\
--episodes 16 --tag post_train --out training/runs/eval_post_train.jsonl
"""
from __future__ import annotations
import argparse
import json
import logging
import os
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, List, Optional
# Configure root logging at import time: INFO level with timestamped records.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module-scoped logger used by the CLI below.
logger = logging.getLogger(name=__name__)
def _build_generate_fn(
    *,
    model_name: str,
    adapter_dir: Optional[str],
    use_unsloth: bool,
    max_seq_length: int,
):
    """Load a causal LM (optionally with LoRA adapters) and return prompt/generation callables.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        adapter_dir: Optional directory of LoRA adapters loaded on top of the
            base model (``model.load_adapter`` with Unsloth, ``PeftModel`` otherwise).
        use_unsloth: When True, load through Unsloth's ``FastLanguageModel`` in
            4-bit; otherwise use plain ``transformers`` (fp16 on CUDA, fp32 on CPU).
        max_seq_length: Maximum sequence length handed to the Unsloth loader.
            Not used by the plain ``transformers`` path.

    Returns:
        ``(prompt_fn, generate_fn)``. ``prompt_fn(chat)`` renders a list of chat
        messages into a prompt string via the tokenizer's chat template.
        ``generate_fn(prompt, config)`` returns only the newly generated text.

    Raises:
        ImportError: if the selected backend (``unsloth``, or ``torch`` /
            ``transformers`` / ``peft``) is not installed. ``main`` relies on this
            to fall back from Unsloth to ``transformers``.
    """
    if use_unsloth:
        from unsloth import FastLanguageModel  # type: ignore

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            load_in_4bit=True,
            # fast_inference requires vLLM, which is not in requirements; plain
            # transformers generation is used instead. Re-enable after pinning
            # vllm in space/training/requirements.txt.
            fast_inference=False,
        )
        if adapter_dir:
            model.load_adapter(adapter_dir)
        FastLanguageModel.for_inference(model)
    else:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        has_cuda = torch.cuda.is_available()
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            device_map="auto" if has_cuda else None,
        )
        if adapter_dir:
            from peft import PeftModel  # type: ignore

            model = PeftModel.from_pretrained(model, adapter_dir)

    # generate() needs a pad id; fall back to EOS for tokenizers that define none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def prompt_fn(chat: List[Dict[str, str]]) -> str:
        """Render chat messages into a prompt ending with the assistant-turn header."""
        return tokenizer.apply_chat_template(
            chat, add_generation_prompt=True, tokenize=False
        )

    def generate_fn(prompt: str, config) -> str:
        """Generate a completion for ``prompt`` and return only the new text.

        ``config`` must expose ``max_new_tokens``, ``temperature`` and ``top_p``.
        A ``temperature`` of zero or below selects greedy decoding
        (``do_sample=False``), because transformers' temperature warper requires
        a strictly positive value. Any other temperature samples with
        ``temperature``/``top_p`` as before.
        """
        # The chat template already renders BOS/special tokens into ``prompt``;
        # re-adding them here would duplicate BOS for templates that emit one.
        inputs = tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(model.device)

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": config.max_new_tokens,
            "pad_token_id": tokenizer.pad_token_id,
        }
        temperature = config.temperature
        if temperature is not None and temperature <= 0:
            gen_kwargs["do_sample"] = False
        else:
            gen_kwargs.update(
                do_sample=True, temperature=temperature, top_p=config.top_p
            )

        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the echoed prompt tokens; decode only the continuation.
        prompt_len = inputs["input_ids"].shape[1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    return prompt_fn, generate_fn
def main() -> None:  # pragma: no cover
    """Run the evaluation CLI.

    The steps are:

    1. Parse command-line arguments.
    2. Load the model via ``_build_generate_fn``. If Unsloth was requested but
       cannot be imported, fall back to plain ``transformers``.
    3. Roll out ``--episodes`` episodes in ``CERNCollisionEnvironment``, with
       episode ``i`` using seed ``--seed + i``.
    4. Write the episode records to ``--out`` as JSONL.
    5. Log the mean cumulative reward and the fraction of episodes whose
       ``discovered`` flag is set.
    """
    parser = argparse.ArgumentParser(description="Evaluate an LLM agent on CERNenv.")
    parser.add_argument("--model_name", required=True, help="Base model id or local path.")
    parser.add_argument(
        "--adapter_dir", default=None,
        help="Optional LoRA adapter directory loaded on top of the base model.",
    )
    parser.add_argument(
        "--scenario", default=None, help="Scenario forwarded to collect_episode."
    )
    parser.add_argument(
        "--difficulty", choices=["easy", "medium", "hard"], default="easy",
        help="Difficulty forwarded to collect_episode.",
    )
    parser.add_argument(
        "--episodes", type=int, default=16,
        help="Number of episodes to roll out (must be >= 1).",
    )
    parser.add_argument(
        "--seed", type=int, default=1000,
        help="Seed of the first episode; episode i uses seed + i.",
    )
    parser.add_argument(
        "--max_steps", type=int, default=18,
        help="max_steps passed to CERNCollisionEnvironment.",
    )
    parser.add_argument(
        "--max_seq_length", type=int, default=2048,
        help="Maximum sequence length for the Unsloth loader.",
    )
    parser.add_argument(
        "--no_unsloth", action="store_true",
        help="Load with plain transformers instead of Unsloth.",
    )
    parser.add_argument("--tag", default="eval", help="Label prefixed to log lines.")
    parser.add_argument("--out", required=True, help="Output JSONL path for episode records.")
    args = parser.parse_args()

    if args.episodes < 1:
        # Reject before the slow model load. Zero episodes would also make the
        # summary statistics below divide by zero.
        parser.error("--episodes must be >= 1")

    # Project imports are deferred so `--help` and argument errors stay fast.
    from server.environment import CERNCollisionEnvironment
    from training.llm_agent import LLMAgentConfig
    from training.rollouts import collect_episode, save_episodes_jsonl

    use_unsloth = not args.no_unsloth
    try:
        prompt_fn, generate_fn = _build_generate_fn(
            model_name=args.model_name,
            adapter_dir=args.adapter_dir,
            use_unsloth=use_unsloth,
            max_seq_length=args.max_seq_length,
        )
    except ImportError as exc:
        if not use_unsloth:
            # Already on the transformers backend: there is nothing to fall back
            # to, and retrying the identical load would fail the same way.
            raise
        logger.warning("Unsloth not available (%s); falling back to transformers.", exc)
        prompt_fn, generate_fn = _build_generate_fn(
            model_name=args.model_name,
            adapter_dir=args.adapter_dir,
            use_unsloth=False,
            max_seq_length=args.max_seq_length,
        )

    env = CERNCollisionEnvironment(max_steps=args.max_steps)
    cfg = LLMAgentConfig()
    episodes = []
    for ep in range(args.episodes):
        rec = collect_episode(
            env=env,
            seed=args.seed + ep,
            scenario=args.scenario,
            difficulty=args.difficulty,
            prompt_fn=prompt_fn,
            generate_fn=generate_fn,
            config=cfg,
        )
        episodes.append(rec)
        logger.info(
            "[%s][%d/%d] reward=%+.3f discovered=%s mass=%s channel=%s",
            args.tag, ep + 1, args.episodes,
            rec.cumulative_reward, rec.discovered, rec.correct_mass, rec.correct_channel,
        )

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    save_episodes_jsonl(episodes, args.out)

    n_episodes = len(episodes)  # >= 1, enforced right after argument parsing
    mean_reward = sum(e.cumulative_reward for e in episodes) / n_episodes
    success_rate = sum(1 for e in episodes if e.discovered) / n_episodes
    logger.info(
        "[%s] mean_reward=%.3f success_rate=%.2f", args.tag, mean_reward, success_rate
    )


if __name__ == "__main__":  # pragma: no cover
    main()