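"""
Evaluate an LLM-backed tool-selection policy against a local ToolUseEnv server.

For each episode, gpt-4o-mini picks one of use_calculator, use_search, or
answer_directly for the query served at http://localhost:8000. Rewards are
aggregated overall and per difficulty level (easy/medium/hard), a few sample
cases are printed, and failures are split into tool failures vs. wrong tool
decisions. Run as a script to evaluate 50 episodes.
"""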
# --- Earlier rule-based baseline (kept commented out for reference) ---
# from tool_use_env.client import ToolUseEnv
# from tool_use_env.models import ToolUseAction
# import random
#
# def rule_based_policy(query: str):
#     query = query.lower()
#
#     # --- Introduce slight imperfection ---
#     if random.random() < 0.1:
#         return "answer_directly"
#
#     if "what is" in query and any(op in query for op in ["+", "-", "*", "/"]):
#         return "use_calculator"
#     if "capital" in query or "who is" in query:
#         return "use_search"
#     return "answer_directly"
#
#
# def run_single_episode(env):
#     result = env.reset()
#     obs = result.observation
#     query = obs.query
#
#     action_type = rule_based_policy(query)
#     action = ToolUseAction(action_type=action_type)
#
#     result = env.step(action)
#     obs = result.observation
#
#     return {
#         "query": query,
#         "action": action_type,
#         "reward": result.reward,
#         "message": obs.message
#     }
#
#
# def run_evaluation(num_episodes=20):
#     results = []
#     difficulty_scores = {
#         "easy": [],
#         "medium": [],
#         "hard": []
#     }
#     total_score = 0
#
#     with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
#         for _ in range(num_episodes):
#             result = env.reset()
#             obs = result.observation
#             query = obs.query
#
#             state = env.state()
#             difficulty = state.difficulty
#
#             action_type = rule_based_policy(query)
#             action = ToolUseAction(action_type=action_type)
#
#             result = env.step(action)
#             score = result.reward
#             total_score += score
#             difficulty_scores[difficulty].append(score)
#
#             results.append({
#                 "query": query,
#                 "difficulty": difficulty,
#                 "action": action_type,
#                 "score": score,
#                 "message": result.observation.message
#             })
#
#     avg_score = total_score / num_episodes
#     print("\n=== OVERALL PERFORMANCE ===")
#     print(f"Average Score: {avg_score:.2f}")
#
#     print("\n=== DIFFICULTY BREAKDOWN ===")
#     for level in difficulty_scores:
#         if difficulty_scores[level]:
#             avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
#             print(f"{level.capitalize()}: {avg:.2f}")
#
#     print("\n=== SAMPLE CASES ===")
#     for r in results[:5]:
#         print(f"\nQuery: {r['query']}")
#         print(f"Action: {r['action']}")
#         print(f"Score: {r['score']:.2f}")
#         print(f"Details: {r['message']}")
#
#     return results
#
#
# def analyze_failures(results):
#     wrong_decisions = 0
#     tool_failures = 0
#     total = len(results)
#
#     for r in results:
#         msg = r["message"]
#         if "Correct: False" in msg:
#             if "use_" in msg:
#                 tool_failures += 1
#             else:
#                 wrong_decisions += 1
#
#     print("\n=== FAILURE ANALYSIS ===")
#     print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
#     print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
#
#
# if __name__ == "__main__":
#     results = run_evaluation(50)
#     analyze_failures(results)
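
# --- Active version: LLM-backed policy (replaces the rule-based baseline above) ---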
import random
from collections import defaultdict

from dotenv import load_dotenv
from openai import OpenAI

from tool_use_env.client import ToolUseEnv
from tool_use_env.models import ToolUseAction

# --- Load environment variables (expects OPENAI_API_KEY, e.g. from .env) ---
load_dotenv()

# --- Initialize OpenAI client ---
client = OpenAI()

# --- Reproducibility ---
random.seed(42)

# 🧠 LLM Policy (CORE)
def llm_policy(query: str):
    prompt = f"""
You are an AI agent choosing the best tool.
Available actions:
- use_calculator (for math problems)
- use_search (for factual questions)
- answer_directly (if neither tool is needed)
Query: {query}
Respond with ONLY one of:
use_calculator
use_search
answer_directly
"""
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        action = response.choices[0].message.content.strip()

        # --- Safety check ---
        if action not in ["use_calculator", "use_search", "answer_directly"]:
            return "answer_directly"
        return action
    except Exception as e:
        print(f"[ERROR] LLM call failed: {e}")
        return "answer_directly"

# 🧪 Evaluation Loop
def run_evaluation(num_episodes=50):
    results = []
    total_score = 0
    difficulty_scores = defaultdict(list)

    with ToolUseEnv(base_url="http://localhost:8000").sync() as env:
        for _ in range(num_episodes):
            # --- Reset ---
            result = env.reset()
            obs = result.observation
            query = obs.query

            # --- Get difficulty ---
            state = env.state()
            difficulty = state.difficulty

            # --- LLM decides action ---
            action_type = llm_policy(query)
            action = ToolUseAction(action_type=action_type)

            # --- Step ---
            result = env.step(action)
            obs = result.observation
            score = result.reward
            total_score += score
            difficulty_scores[difficulty].append(score)

            results.append({
                "query": query,
                "difficulty": difficulty,
                "action": action_type,
                "score": score,
                "message": obs.message
            })

            print(f"Score: {score:.2f}")

    # --- Overall ---
    avg_score = total_score / num_episodes
    print("\n=== OVERALL PERFORMANCE ===")
    print(f"Average Score: {avg_score:.2f}")

    # --- Breakdown ---
    print("\n=== DIFFICULTY BREAKDOWN ===")
    for level in ["easy", "medium", "hard"]:
        if difficulty_scores[level]:
            avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
            print(f"{level.capitalize()}: {avg:.2f}")

    # --- Sample Cases ---
    print("\n=== SAMPLE CASES ===")
    for r in results[:5]:
        print(f"\nQuery: {r['query']}")
        print(f"Action: {r['action']}")
        print(f"Score: {r['score']:.2f}")
        print(f"Details: {r['message']}")

    return results

# 📊 Failure Analysis
def analyze_failures(results):
    total = len(results)
    tool_failures = 0
    wrong_decisions = 0

    for r in results:
        msg = r["message"]
        # Classify each incorrect episode from the env's message text:
        # a tool name ("use_...") in the message means the tool path failed,
        # otherwise the policy chose the wrong action.
        if "Correct: False" in msg:
            if "use_" in msg:
                tool_failures += 1
            else:
                wrong_decisions += 1

    print("\n=== FAILURE ANALYSIS ===")
    print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
    print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")

# 🚀 Main
if __name__ == "__main__":
    results = run_evaluation(50)
    analyze_failures(results)