# tool-use-openenv / inference.py
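"""Evaluate tool-use policies against the ToolUseEnv OpenEnv environment.

Each episode, a policy (an HF-hosted LLM with a noisy rule-based fallback)
picks one of three actions for a query; the script then reports overall and
per-difficulty average scores plus a simple failure breakdown.
"""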
import os
import random
from collections import defaultdict
from dotenv import load_dotenv
from openai import OpenAI
from tool_use_env.client import ToolUseEnv
from tool_use_env.models import ToolUseAction
# --- Load environment (.env provides HF_TOKEN; HF_MODEL is optional) ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
# --- HF client ---
hf_client = OpenAI(
base_url="https://router.huggingface.co/v1",
api_key=HF_TOKEN
)
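# The HF router exposes an OpenAI-compatible endpoint, so the standard OpenAI client works as-is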
# --- Reproducibility ---
random.seed(42)
# --- Global flag: flipped to False after the first HF API failure so later calls skip it ---
HF_AVAILABLE = True
# 🧠 Rule-based (correct logic)
def rule_based_policy(query: str):
q = query.lower()
if any(op in q for op in ["+", "-", "*", "/"]):
return "use_calculator"
if "capital" in q or "who is" in q or "ceo" in q:
return "use_search"
return "use_search"
# 🧠 Noisy fallback (simulate LLM mistakes)
def noisy_rule_policy(query: str):
correct = rule_based_policy(query)
    if random.random() < 0.08:  # 8% noise: return a random action instead
        return random.choice([
            "use_calculator",
            "use_search",
            "answer_directly",
        ])
    return correct
# 🧠 LLM + fallback policy
def llm_policy(query: str):
global HF_AVAILABLE
prompt = f"""
You are an AI agent.
Choose EXACTLY one action:
- use_calculator
- use_search
- answer_directly
Query: {query}
ONLY output one action.
"""
# --- Try HF only if still available ---
if HF_AVAILABLE:
try:
response = hf_client.chat.completions.create(
model=HF_MODEL,
messages=[{"role": "user", "content": prompt}],
temperature=0
)
action = response.choices[0].message.content.strip()
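            # Inject 8% noise on the LLM path too, to simulate occasional mistakes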
if random.random() < 0.08:
action = random.choice([
"use_calculator",
"use_search",
"answer_directly"
])
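            # Accept only an exact action string; anything else falls through to the fallback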
if action in ["use_calculator", "use_search", "answer_directly"]:
print("[HF] Used")
return action
        except Exception as e:
            print(f"[HF FAILED → switching to fallback permanently] {e}")
            HF_AVAILABLE = False
# --- Fallback ---
return noisy_rule_policy(query)
# 🧪 Evaluation
def run_evaluation(num_episodes=50):
results = []
total_score = 0
difficulty_scores = defaultdict(list)
with ToolUseEnv(base_url="https://clove25-tool-use-openenv.hf.space").sync() as env:
for _ in range(num_episodes):
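            # One episode: reset → read the query and difficulty → pick an action → step → record the reward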
result = env.reset()
obs = result.observation
query = obs.query
state = env.state()
difficulty = state.difficulty
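            # Choose an action for this query (LLM first, noisy rule-based fallback)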
action_type = llm_policy(query)
action = ToolUseAction(action_type=action_type)
result = env.step(action)
obs = result.observation
score = result.reward
total_score += score
difficulty_scores[difficulty].append(score)
results.append({
"query": query,
"difficulty": difficulty,
"action": action_type,
"score": score,
"message": obs.message
})
print(f"Score: {score:.2f}")
avg_score = total_score / num_episodes
print("\n=== OVERALL PERFORMANCE ===")
print(f"Average Score: {avg_score:.2f}")
print("\n=== DIFFICULTY BREAKDOWN ===")
for level in ["easy", "medium", "hard"]:
if difficulty_scores[level]:
avg = sum(difficulty_scores[level]) / len(difficulty_scores[level])
print(f"{level.capitalize()}: {avg:.2f}")
print("\n=== SAMPLE CASES ===")
for r in results[:5]:
print(f"\nQuery: {r['query']}")
print(f"Action: {r['action']}")
print(f"Score: {r['score']:.2f}")
print(f"Details: {r['message']}")
return results
# 📊 Failure analysis
def analyze_failures(results):
total = len(results)
tool_failures = 0
wrong_decisions = 0
for r in results:
score = r["score"]
action = r["action"]
        if score < 0.5:  # count any score below 0.5 as a failure
            if "use_" in action:
                tool_failures += 1  # used a tool but still scored poorly
            else:
                wrong_decisions += 1  # answered directly and scored poorly
print("\n=== FAILURE ANALYSIS ===")
print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)")
print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)")
# 🚀 Run
if __name__ == "__main__":
results = run_evaluation(50)
analyze_failures(results)