Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| from collections import defaultdict | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from tool_use_env.client import ToolUseEnv | |
| from tool_use_env.models import ToolUseAction | |
# --- Environment configuration ---
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_MODEL = os.getenv("HF_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")

# --- Hugging Face inference client (OpenAI-compatible router) ---
hf_client = OpenAI(
    base_url="https://router.huggingface.co/v1",
    api_key=HF_TOKEN,
)

# --- Reproducibility: fixed seed for the simulated-noise draws ---
random.seed(42)

# --- Flipped to False permanently after the first HF API failure ---
HF_AVAILABLE = True
| # π§ Rule-based (correct logic) | |
| def rule_based_policy(query: str): | |
| q = query.lower() | |
| if any(op in q for op in ["+", "-", "*", "/"]): | |
| return "use_calculator" | |
| if "capital" in q or "who is" in q or "ceo" in q: | |
| return "use_search" | |
| return "use_search" | |
| # π§ Noisy fallback (simulate LLM mistakes) | |
| def noisy_rule_policy(query: str): | |
| correct = rule_based_policy(query) | |
| if random.random() < 0.08: # 8% noise | |
| action = random.choice([ | |
| "use_calculator", | |
| "use_search", | |
| "answer_directly" | |
| ]) | |
| return correct | |
| # π§ LLM + fallback policy | |
| def llm_policy(query: str): | |
| global HF_AVAILABLE | |
| prompt = f""" | |
| You are an AI agent. | |
| Choose EXACTLY one action: | |
| - use_calculator | |
| - use_search | |
| - answer_directly | |
| Query: {query} | |
| ONLY output one action. | |
| """ | |
| # --- Try HF only if still available --- | |
| if HF_AVAILABLE: | |
| try: | |
| response = hf_client.chat.completions.create( | |
| model=HF_MODEL, | |
| messages=[{"role": "user", "content": prompt}], | |
| temperature=0 | |
| ) | |
| action = response.choices[0].message.content.strip() | |
| if random.random() < 0.08: | |
| action = random.choice([ | |
| "use_calculator", | |
| "use_search", | |
| "answer_directly" | |
| ]) | |
| if action in ["use_calculator", "use_search", "answer_directly"]: | |
| print("[HF] Used") | |
| return action | |
| except Exception as e: | |
| print("[HF FAILED β switching to fallback permanently]") | |
| HF_AVAILABLE = False | |
| # --- Fallback --- | |
| return noisy_rule_policy(query) | |
| # π§ͺ Evaluation | |
| def run_evaluation(num_episodes=50): | |
| results = [] | |
| total_score = 0 | |
| difficulty_scores = defaultdict(list) | |
| with ToolUseEnv(base_url="https://clove25-tool-use-openenv.hf.space").sync() as env: | |
| for _ in range(num_episodes): | |
| result = env.reset() | |
| obs = result.observation | |
| query = obs.query | |
| state = env.state() | |
| difficulty = state.difficulty | |
| action_type = llm_policy(query) | |
| action = ToolUseAction(action_type=action_type) | |
| result = env.step(action) | |
| obs = result.observation | |
| score = result.reward | |
| total_score += score | |
| difficulty_scores[difficulty].append(score) | |
| results.append({ | |
| "query": query, | |
| "difficulty": difficulty, | |
| "action": action_type, | |
| "score": score, | |
| "message": obs.message | |
| }) | |
| print(f"Score: {score:.2f}") | |
| avg_score = total_score / num_episodes | |
| print("\n=== OVERALL PERFORMANCE ===") | |
| print(f"Average Score: {avg_score:.2f}") | |
| print("\n=== DIFFICULTY BREAKDOWN ===") | |
| for level in ["easy", "medium", "hard"]: | |
| if difficulty_scores[level]: | |
| avg = sum(difficulty_scores[level]) / len(difficulty_scores[level]) | |
| print(f"{level.capitalize()}: {avg:.2f}") | |
| print("\n=== SAMPLE CASES ===") | |
| for r in results[:5]: | |
| print(f"\nQuery: {r['query']}") | |
| print(f"Action: {r['action']}") | |
| print(f"Score: {r['score']:.2f}") | |
| print(f"Details: {r['message']}") | |
| return results | |
| # π Failure analysis (FIXED VERSION) | |
| def analyze_failures(results): | |
| total = len(results) | |
| tool_failures = 0 | |
| wrong_decisions = 0 | |
| for r in results: | |
| score = r["score"] | |
| action = r["action"] | |
| if score < 0.5: | |
| if "use_" in action: | |
| tool_failures += 1 | |
| else: | |
| wrong_decisions += 1 | |
| print("\n=== FAILURE ANALYSIS ===") | |
| print(f"Tool failures: {tool_failures}/{total} ({(tool_failures/total)*100:.1f}%)") | |
| print(f"Wrong decisions: {wrong_decisions}/{total} ({(wrong_decisions/total)*100:.1f}%)") | |
| # π Run | |
| if __name__ == "__main__": | |
| results = run_evaluation(50) | |
| analyze_failures(results) |