Add run_eval.py and virtual_api_server.py
pipeline/run_eval.py
ADDED
@@ -0,0 +1,101 @@
"""Main evaluation runner for P2P StableToolBench experiments.

Orchestrates: load queries -> load P2P data -> run ReAct inference -> save results.

Usage:
    python run_eval.py --condition baseline --groups G1_instruction --max_queries 5
    python run_eval.py --condition all --groups all
    python run_eval.py --condition p2p_desc --p2p_desc_dir /path/to/your/descriptions
"""
import os
import json
import argparse

from config import (
    TOOL_ROOT_DIR, SOLVABLE_QUERIES_DIR, OUTPUT_DIR,
    P2P_DESCRIPTIONS_DIR, P2P_EXAMPLES_DIR,
    TASK_MODEL, TEMPERATURE, ALL_GROUPS, CONDITION_NAMES,
    API_SERVER_URL, API_SERVER_PORT,
)
from tool_utils import load_query_data, load_p2p_descriptions, load_p2p_examples
from prompt_builder import build_initial_messages, get_condition_config, gather_examples_for_query
from llm_client import LLMClient
from react_loop import ReActRunner
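
# Condition names come from config.CONDITION_NAMES. As wired up in main()
# below: "p2p_desc" and "p2p_full" load P2P tool descriptions, while
# "p2p_demo" and "p2p_full" load P2P in-context examples; "baseline" uses
# the stock tool documentation only.
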
def run_condition(condition, group, llm, max_queries=None, output_dir=OUTPUT_DIR,
                  p2p_descriptions=None, p2p_examples=None, service_url=None):
    """Run one (condition, group) pass; write per-query results and a summary."""
    config = get_condition_config(condition, p2p_descriptions, p2p_examples)
    query_path = os.path.join(SOLVABLE_QUERIES_DIR, f"{group}.json")
    if not os.path.exists(query_path):
        print(f"Query file not found: {query_path}")
        return {}
    custom_descs = config["custom_descriptions"] if config["use_custom_descriptions"] else None
    queries = load_query_data(query_path, TOOL_ROOT_DIR, custom_descriptions=custom_descs)
    if max_queries:
        queries = queries[:max_queries]
    print(f"\n{'='*60}\nCondition: {condition} | Group: {group} | Queries: {len(queries)}\n{'='*60}")
    condition_dir = os.path.join(output_dir, condition, group)
    os.makedirs(condition_dir, exist_ok=True)

    results = []
    for i, query_data in enumerate(queries):
        query_id = query_data["query_id"]
        output_file = os.path.join(condition_dir, f"{query_id}_CoT@1.json")
        # Resume support: reuse any result file that already exists.
        if os.path.exists(output_file):
            print(f"  [{i+1}/{len(queries)}] Query {query_id}: already done, skipping")
            with open(output_file) as f:
                results.append(json.load(f))
            continue
        print(f"  [{i+1}/{len(queries)}] Query {query_id}: {query_data['query'][:80]}...")

        # Demo conditions attach per-tool in-context examples to the prompt.
        examples = None
        if config["use_examples"] and config["examples"]:
            examples = gather_examples_for_query(
                query_data["tool_names"], query_data["api_name_reflect"],
                query_data["functions"], config["examples"], max_per_tool=1)
        messages = build_initial_messages(
            query=query_data["query"],
            tool_descriptions=query_data["tool_descriptions"],
            examples=examples)
        runner = ReActRunner(
            llm=llm, functions=query_data["functions"],
            tool_descriptions=query_data["tool_descriptions"],
            api_name_reflect=query_data["api_name_reflect"],
            tool_names=query_data["tool_names"],
            cate_names=query_data["cate_names"],
            service_url=service_url)
        result = runner.run(messages)
        result["query"], result["query_id"] = query_data["query"], query_id
        result["condition"], result["group"] = condition, group

        save_data = {
            "query_id": query_id, "query": query_data["query"],
            "condition": condition, "group": group,
            "success": result["success"], "final_answer": result["final_answer"],
            "give_up": result["give_up"], "total_tokens": result["total_tokens"],
            "query_count": result["query_count"], "steps": result["steps"],
            "trajectory": result["trajectory"],
            "available_tools": [f_["function"]["name"] for f_ in query_data["functions"]]}
        with open(output_file, "w") as f:
            json.dump(save_data, f, indent=2)
        results.append(result)
        status = "solved" if result["success"] else ("gave up" if result["give_up"] else "failed")
        print(f"    {status} | steps={result['steps']} | tokens={result['total_tokens']}")

    total = len(results)
    solved = sum(1 for r in results if r.get("success", False))
    gave_up = sum(1 for r in results if r.get("give_up", False))
    summary = {
        "condition": condition, "group": group,
        "total_queries": total, "solved": solved, "gave_up": gave_up,
        "failed": total - solved - gave_up,
        "pass_rate": solved / total * 100 if total > 0 else 0}
    with open(os.path.join(condition_dir, "summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\n  Summary: {solved}/{total} solved ({summary['pass_rate']:.1f}% pass rate)")
    return summary
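
# Output layout produced by run_condition (illustrative):
#   <output_dir>/<condition>/<group>/<query_id>_CoT@1.json   # one file per query
#   <output_dir>/<condition>/<group>/summary.json            # per-run aggregate
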
def main():
    parser = argparse.ArgumentParser(description="P2P StableToolBench Evaluation")
    parser.add_argument("--condition", type=str, default="baseline", choices=CONDITION_NAMES + ["all"])
    parser.add_argument("--groups", type=str, nargs="+", default=["G1_instruction"])
    parser.add_argument("--max_queries", type=int, default=None)
    parser.add_argument("--model", type=str, default=TASK_MODEL)
    parser.add_argument("--vllm_url", type=str, default="http://localhost:8000/v1")
    parser.add_argument("--api_server_url", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=TEMPERATURE)
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR)
    parser.add_argument("--p2p_desc_dir", type=str, default=P2P_DESCRIPTIONS_DIR)
    parser.add_argument("--p2p_examples_dir", type=str, default=P2P_EXAMPLES_DIR)
    args = parser.parse_args()

    groups = ALL_GROUPS if "all" in args.groups else args.groups
    conditions = CONDITION_NAMES if args.condition == "all" else [args.condition]
    # Default to the configured virtual API server unless one is given explicitly.
    service_url = args.api_server_url or f"{API_SERVER_URL}:{API_SERVER_PORT}/virtual"

    print(f"Connecting to vLLM at {args.vllm_url}\nModel: {args.model}")
    llm = LLMClient(model=args.model, base_url=args.vllm_url, temperature=args.temperature)

    # P2P artifacts are loaded once and shared across all conditions that need them.
    p2p_descriptions, p2p_examples = None, None
    if any(c in conditions for c in ["p2p_desc", "p2p_full"]):
        p2p_descriptions = load_p2p_descriptions(args.p2p_desc_dir)
        print(f"  Loaded {len(p2p_descriptions)} descriptions")
    if any(c in conditions for c in ["p2p_demo", "p2p_full"]):
        p2p_examples = load_p2p_examples(args.p2p_examples_dir)
        print(f"  Loaded examples for {len(p2p_examples)} tools")

    # Run the full condition x group grid and collect per-run summaries.
    all_summaries = []
    for condition in conditions:
        for group in groups:
            summary = run_condition(
                condition=condition, group=group, llm=llm,
                max_queries=args.max_queries, output_dir=args.output_dir,
                p2p_descriptions=p2p_descriptions, p2p_examples=p2p_examples,
                service_url=service_url)
            all_summaries.append(summary)

    print("\n" + "=" * 80 + "\nFINAL RESULTS\n" + "=" * 80)
    print(f"{'Condition':<15} {'Group':<20} {'Solved':>8} {'Total':>8} {'Pass Rate':>10}")
    print("-" * 65)
    for s in all_summaries:
        if s:
            print(f"{s['condition']:<15} {s['group']:<20} {s['solved']:>8} {s['total_queries']:>8} {s['pass_rate']:>9.1f}%")

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "all_summaries.json"), "w") as f:
        json.dump(all_summaries, f, indent=2)
    print(f"\nResults saved to {args.output_dir}")


if __name__ == "__main__":
    main()
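
A minimal sketch of the interfaces this script assumes from llm_client and react_loop, reconstructed from the call sites above. The stubs are illustrative only; method bodies, type hints, and any defaults are assumptions, not this repo's actual implementations:

    # Hypothetical stubs inferred from run_eval.py's call sites.
    class LLMClient:
        def __init__(self, model: str, base_url: str, temperature: float): ...

    class ReActRunner:
        def __init__(self, llm, functions, tool_descriptions, api_name_reflect,
                     tool_names, cate_names, service_url): ...

        def run(self, messages: list) -> dict:
            # Must return at least the keys run_eval.py reads: "success",
            # "final_answer", "give_up", "total_tokens", "query_count",
            # "steps", "trajectory".
            ...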