Spaces:
Running
Running
| """ | |
| AIFinder Dataset Evaluator with Server | |
| Runs the Flask server, then allows interactive dataset input. | |
| """ | |
| import os | |
| import sys | |
| import time | |
| import argparse | |
| import random | |
| import threading | |
| import requests | |
| from collections import defaultdict | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| from config import MODEL_DIR | |
| from inference import AIFinder | |
| HF_TOKEN = os.environ.get("HF_TOKEN") | |
| SERVER_URL = "http://localhost:7860" | |
def start_server():
    """Load the models and run the Flask app (blocking; intended for a daemon thread)."""
    # Make relative paths inside the app resolve against this file's directory.
    here = os.path.dirname(os.path.abspath(__file__))
    os.chdir(here)
    from app import app, load_models

    load_models()
    print("Server started on http://localhost:7860")
    app.run(host="0.0.0.0", port=7860, debug=False, use_reloader=False)
def wait_for_server(timeout=30):
    """Poll the status endpoint until the server answers or *timeout* seconds elapse.

    Returns True as soon as GET /api/status responds with HTTP 200,
    False if the deadline passes first.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            if requests.get(f"{SERVER_URL}/api/status", timeout=2).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Server not up yet -- keep polling.
            pass
        time.sleep(1)
    return False
| def _parse_msg(msg): | |
| """Parse a message that may be a dict or a JSON string.""" | |
| if isinstance(msg, dict): | |
| return msg | |
| if isinstance(msg, str): | |
| try: | |
| import json | |
| parsed = json.loads(msg) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except (ValueError, Exception): | |
| pass | |
| return {} | |
| def _extract_response_only(content): | |
| """Extract only the final response, stripping CoT blocks.""" | |
| import re | |
| if not content: | |
| return "" | |
| think_match = re.search(r"</?think(?:ing)?>(.*)$", content, re.DOTALL) | |
| if think_match: | |
| response = think_match.group(1).strip() | |
| if response: | |
| return response | |
| return content | |
def extract_texts_from_dataset(dataset_id, max_samples=None):
    """Extract assistant response texts from a HuggingFace dataset.

    Tries ``load_dataset`` first; on failure, falls back to fetching the first
    train parquet shard directly. Messages may use OpenAI-style keys
    (``role``/``content``) or ShareGPT-style keys (``from``/``value``).
    Returns at most *max_samples* texts, subsampled with a fixed seed.
    """
    print(f"\nLoading dataset: {dataset_id}")
    load_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
    rows = []
    try:
        ds = load_dataset(dataset_id, split="train", **load_kwargs)
        rows = list(ds)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        try:
            import pandas as pd

            url = f"https://huggingface.co/api/datasets/{dataset_id}/parquet/default/train/0.parquet"
            df = pd.read_parquet(url)
            rows = df.to_dict(orient="records")
        except Exception as e2:
            print(f"Parquet fallback also failed: {e2}")
            return []
    texts = []
    for row in rows:
        convos = row.get("conversations") or row.get("messages") or []
        if not convos:
            continue
        for msg in convos:
            msg = _parse_msg(msg)
            # Support both OpenAI-style (role/content) and ShareGPT-style
            # (from/value) schemas -- the "gpt" role below is ShareGPT's
            # assistant marker, so those messages must be readable too.
            role = msg.get("role") or msg.get("from") or ""
            content = msg.get("content") or msg.get("value") or ""
            if role in ("assistant", "gpt", "model") and content:
                response_only = _extract_response_only(content)
                # Skip trivially short responses; 50 chars is a noise floor.
                if response_only and len(response_only) > 50:
                    texts.append(response_only)
    if max_samples and len(texts) > max_samples:
        random.seed(42)  # deterministic subsample across runs
        texts = random.sample(texts, max_samples)
    return texts
def evaluate_dataset(texts):
    """POST each text to the classify endpoint and tally per-provider results.

    Returns a dict with the sample total, a count of predictions per provider,
    and the list of confidences (0..1) observed for each provider.
    """
    results = {
        "total": len(texts),
        "provider_counts": defaultdict(int),
        "confidences": defaultdict(list),
    }
    counts = results["provider_counts"]
    confs = results["confidences"]
    endpoint = f"{SERVER_URL}/api/classify"
    for sample in tqdm(texts, desc="Evaluating"):
        try:
            resp = requests.post(
                endpoint,
                json={"text": sample, "top_n": 5},
                timeout=30,
            )
            if resp.status_code != 200:
                continue
            payload = resp.json()
            provider = payload.get("provider")
            # Server reports confidence as a percentage; normalize to 0..1.
            confidence = payload.get("confidence", 0) / 100.0
            if provider:
                counts[provider] += 1
                confs[provider].append(confidence)
        except Exception as e:
            print(f"Error: {e}")
            continue
    return results
def print_results(results):
    """Pretty-print the aggregated classification results to stdout."""
    total = results["total"]
    counts = results["provider_counts"]
    confs = results["confidences"]
    bar = "=" * 60
    print("\n" + bar)
    print(f"EVALUATION RESULTS ({total} samples)")
    print(bar)
    print("\n--- Predicted Provider Distribution ---")
    # Providers ordered by prediction count, descending.
    for provider, count in sorted(counts.items(), key=lambda kv: -kv[1]):
        share = (count / total) * 100
        conf_list = confs[provider]
        mean_conf = sum(conf_list) / len(conf_list)
        print(
            f" {provider}: {count} ({share:.1f}%) - Avg confidence: {mean_conf * 100:.1f}%"
        )
    if confs:
        print("\n--- Top Providers (by cumulative confidence) ---")
        # Score = mean confidence weighted by how often the provider won.
        scores = {
            provider: (sum(c) / len(c)) * counts[provider]
            for provider, c in confs.items()
            if c
        }
        for provider, score in sorted(scores.items(), key=lambda kv: -kv[1])[:3]:
            print(f" {provider}: {score:.2f}")
    print("\n" + bar)
def main():
    """CLI entry point: boot the server, then run the interactive prompt loop."""
    parser = argparse.ArgumentParser(
        description="AIFinder Dataset Evaluator with Server"
    )
    parser.add_argument(
        "--max-samples", type=int, default=None, help="Max samples to test"
    )
    args = parser.parse_args()

    print("Starting AIFinder server...")
    threading.Thread(target=start_server, daemon=True).start()

    print("Waiting for server...")
    if not wait_for_server():
        print("Server failed to start!")
        sys.exit(1)

    banner = "=" * 60
    print("\n" + banner)
    print("AIFinder Server Ready!")
    print(banner)
    print(f"Server running at: {SERVER_URL}")
    print("Enter a HuggingFace dataset ID to evaluate.")
    print("Examples: ianncity/Hunter-Alpha-SFT-300000x")
    print("Type 'quit' or 'exit' to stop.")
    print(banner + "\n")

    while True:
        # The whole iteration sits in one try so Ctrl-C exits cleanly and any
        # per-dataset failure is reported without killing the loop.
        try:
            dataset_id = input("Dataset ID: ").strip()
            if dataset_id.lower() in ("quit", "exit", "q"):
                print("Goodbye!")
                break
            if not dataset_id:
                continue
            texts = extract_texts_from_dataset(dataset_id, args.max_samples)
            if not texts:
                print("No valid texts found in dataset.")
                continue
            print(f"Testing {len(texts)} responses...")
            print_results(evaluate_dataset(texts))
        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")


if __name__ == "__main__":
    main()