# ------------------------------------------------------------------
# Previous version: rule-based baseline (kept for reference).
# A hand-written keyword policy with a deliberate 10% error rate,
# evaluated with the same loop used for the LLM policy below.
# ------------------------------------------------------------------
# from tool_use_env.client import ToolUseEnv
# from tool_use_env.models import ToolUseAction
# import random
#
# def rule_based_policy(query: str):
#     query = query.lower()
#     # --- Introduce slight imperfection ---
#     if random.random() < 0.1:
#         return "answer_directly"
#     if "what is" in query and any(op in query for op in ["+", "-", "*", "/"]):
#         return "use_calculator"
#     if "capital" in query or "who is" in query:
#         return "use_search"
#     return "answer_directly"
#
# def run_single_episode(env):
#     result = env.reset()
#     obs = result.observation
#     query = obs.query
#     action_type = rule_based_policy(query)
#     action = ToolUseAction(action_type=action_type)
#     result = env.step(action)
#     obs = result.observation
#     return {
#         "query": query,
#         "action": action_type,
#         "reward": result.reward,
#         "message": obs.message,
#     }
#
# def run_evaluation(num_episodes=20):
#     results = []
#     difficulty_scores = {
#         "easy": [],
#         "medium": [],
#         "hard": [],
#     }
#     total_score = 0
#     with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
#         for _ in range(num_episodes):
#             result = env.reset()
#             obs = result.observation
#             query = obs.query
#             state = env.state()
#             difficulty = state.difficulty
#             action_type = rule_based_policy(query)
#             action = ToolUseAction(action_type=action_type)
#             result = env.step(action)
#             score = result.reward
#             total_score += score
#             difficulty_scores[difficulty].append(score)
#             results.append({
#                 "query": query,
#                 "difficulty": difficulty,
#                 "action": action_type,
#                 "score": score,
#                 "message": result.observation.message,
#             })
#     avg_score = total_score / num_episodes
#     print("\n=== OVERALL PERFORMANCE ===")
#     print(f"Average Score: {avg_score:.2f}")
#     print("\n=== DIFFICULTY BREAKDOWN ===")
#     for level in difficulty_scores:
#         if difficulty_scores[level]:
#             avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
#             print(f"{level.capitalize()}: {avg:.2f}")
#     print("\n=== SAMPLE CASES ===")
#     for r in results[:5]:
#         print(f"\nQuery: {r['query']}")
#         print(f"Action: {r['action']}")
#         print(f"Score: {r['score']:.2f}")
#         print(f"Details: {r['message']}")
#     return results
#
# def analyze_failures(results):
#     wrong_decisions = 0
#     tool_failures = 0
#     total = len(results)
#     for r in results:
#         msg = r["message"]
#         if "Correct: False" in msg:
#             if "use_" in msg:
#                 tool_failures += 1
#             else:
#                 wrong_decisions += 1
#     print("\n=== FAILURE ANALYSIS ===")
#     print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
#     print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
#
# if __name__ == "__main__":
#     results = run_evaluation(50)
#     analyze_failures(results)
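# Example (with the baseline uncommented): "What is 12 * 8?" matches the
# "what is" + operator rule, so rule_based_policy returns "use_calculator"
# roughly 90% of the time; the remaining 10% it falls through to
# "answer_directly" because of the injected randomness. The query text here
# is illustrative, not taken from the environment's dataset.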
import random
from collections import defaultdict

from dotenv import load_dotenv
from openai import OpenAI

from tool_use_env.client import ToolUseEnv
from tool_use_env.models import ToolUseAction

# --- Load environment variables (OPENAI_API_KEY from .env) ---
load_dotenv()

# --- Initialize OpenAI client ---
client = OpenAI()

# --- Reproducibility (only affects the rule-based baseline above, if re-enabled) ---
random.seed(42)
# --- LLM policy (core) ---
def llm_policy(query: str):
    """Ask the model to pick one of the three environment actions."""
    prompt = f"""
You are an AI agent choosing the best tool.

Available actions:
- use_calculator (for math problems)
- use_search (for factual questions)
- answer_directly (if neither tool is needed)

Query: {query}

Respond with ONLY one of:
use_calculator
use_search
answer_directly
"""
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        # Normalize casing so e.g. "Use_calculator" still passes the check.
        action = response.choices[0].message.content.strip().lower()
        # --- Safety check: fall back if the model strays from the menu ---
        if action not in ["use_calculator", "use_search", "answer_directly"]:
            return "answer_directly"
        return action
    except Exception as e:
        print(f"[ERROR] LLM call failed: {e}")
        return "answer_directly"
# --- Evaluation loop ---
def run_evaluation(num_episodes=50):
    results = []
    total_score = 0
    difficulty_scores = defaultdict(list)

    with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
        for _ in range(num_episodes):
            # --- Reset ---
            result = env.reset()
            obs = result.observation
            query = obs.query

            # --- Get difficulty ---
            state = env.state()
            difficulty = state.difficulty

            # --- LLM decides action ---
            action_type = llm_policy(query)
            action = ToolUseAction(action_type=action_type)

            # --- Step ---
            result = env.step(action)
            obs = result.observation
            score = result.reward

            total_score += score
            difficulty_scores[difficulty].append(score)
            results.append({
                "query": query,
                "difficulty": difficulty,
                "action": action_type,
                "score": score,
                "message": obs.message,
            })
            print(f"Score: {score:.2f}")

    # --- Overall ---
    avg_score = total_score / num_episodes
    print("\n=== OVERALL PERFORMANCE ===")
    print(f"Average Score: {avg_score:.2f}")

    # --- Breakdown ---
    print("\n=== DIFFICULTY BREAKDOWN ===")
    for level in ["easy", "medium", "hard"]:
        if difficulty_scores[level]:
            avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
            print(f"{level.capitalize()}: {avg:.2f}")

    # --- Sample cases ---
    print("\n=== SAMPLE CASES ===")
    for r in results[:5]:
        print(f"\nQuery: {r['query']}")
        print(f"Action: {r['action']}")
        print(f"Score: {r['score']:.2f}")
        print(f"Details: {r['message']}")

    return results
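# Note: each episode makes one chat-completion call, so the default 50-episode
# run is 50 API calls; passing a small num_episodes (e.g. 5) gives a cheap
# smoke test before a full evaluation.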
# --- Failure analysis ---
def analyze_failures(results):
    total = len(results)
    if total == 0:
        return

    tool_failures = 0
    wrong_decisions = 0
    for r in results:
        msg = r["message"]
        if "Correct: False" in msg:
            # Crude split: if the episode message mentions a tool action,
            # count it as a tool failure; otherwise as a wrong decision.
            if "use_" in msg:
                tool_failures += 1
            else:
                wrong_decisions += 1

    print("\n=== FAILURE ANALYSIS ===")
    print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
    print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
# --- Main ---
if __name__ == "__main__":
    results = run_evaluation(50)
    analyze_failures(results)