| |
| """Lightweight done-detection check for the toy VLAC dataset.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import base64 |
| import json |
| import sys |
| import time |
| from tqdm import tqdm |
| from pathlib import Path |
| from typing import Dict, Iterable, List, Optional |
|
|
| import requests |
| from PIL import Image |
|
|
| |
| |
| |
|
|
|
|
def read_manifest(dataset_dir: Path, json_name: str) -> List[Dict]:
    """Load the done-detection manifest and resolve image paths.

    Args:
        dataset_dir: Directory that contains ``images/`` and the manifest JSON.
        json_name: Filename of the manifest inside ``dataset_dir``.

    Returns:
        Entries that have at least one well-formed sample. Each sample's
        ``initial``/``prev``/``curr`` values are rewritten to absolute-ish
        string paths under ``images/``, and ``reference`` is normalized to a
        (possibly empty) list of such paths. Malformed samples are dropped;
        entries left with no samples are dropped entirely.

    Raises:
        FileNotFoundError: If the manifest file or the images directory is missing.
    """
    manifest_path = dataset_dir / json_name
    images_dir = dataset_dir / "images"
    if not manifest_path.is_file():
        raise FileNotFoundError(f"Metadata JSON not found: {manifest_path}")
    if not images_dir.is_dir():
        raise FileNotFoundError(f"Images directory not found: {images_dir}")

    with manifest_path.open("r", encoding="utf-8") as f:
        raw_entries = json.load(f)

    entries: List[Dict] = []
    for entry in raw_entries:
        samples = entry.get("samples") or []
        if not samples:
            continue
        resolved_samples = []
        for sample in samples:
            try:
                resolved_samples.append(
                    {
                        "label": int(sample["label"]),
                        "initial": str(images_dir / sample["initial"]),
                        "prev": str(images_dir / sample["prev"]),
                        "curr": str(images_dir / sample["curr"]),
                    }
                )
            except (KeyError, TypeError, ValueError):
                # Skip malformed samples rather than failing the whole manifest.
                continue
        if not resolved_samples:
            continue
        entry["samples"] = resolved_samples
        # Fix: `.get("reference", [])` returned None when the key was present
        # with a JSON null, crashing this comprehension with TypeError.
        # Use `or []` — the same normalization applied to `samples` above.
        entry["reference"] = [str(images_dir / rel) for rel in (entry.get("reference") or [])]
        entries.append(entry)
    return entries
|
|
|
|
def encode_image(path: Path) -> str:
    """Return the image at *path* re-encoded as JPEG, base64-encoded as ASCII."""
    from io import BytesIO

    with Image.open(path) as img:
        # Force RGB so grayscale/RGBA sources can be written as JPEG.
        rgb = img.convert("RGB")
        buffer = BytesIO()
        rgb.save(buffer, format="JPEG", quality=95)
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
|
def encode_images(paths: Iterable[str]) -> List[str]:
    """Base64-encode every image referenced in *paths*, preserving order."""
    encoded: List[str] = []
    for raw_path in paths:
        encoded.append(encode_image(Path(raw_path)))
    return encoded
|
|
|
|
def call_done(
    session: requests.Session,
    base_url: str,
    task: str,
    first_frame: str,
    prev_frame: str,
    curr_frame: str,
    reference: Optional[List[str]],
    timeout: float,
) -> Dict:
    """POST a single done-detection query to the VLAC service.

    Args:
        session: Reusable HTTP session (keeps connection pooling across calls).
        base_url: Service root, e.g. ``http://localhost:8111``; trailing slashes
            are tolerated.
        task: Task description string forwarded to the service.
        first_frame: Base64-encoded initial frame.
        prev_frame: Base64-encoded previous frame.
        curr_frame: Base64-encoded current frame.
        reference: Optional list of base64-encoded reference frames, or None.
        timeout: Per-request HTTP timeout in seconds.

    Returns:
        The service's JSON response dict, augmented with a ``latency_sec`` key
        measuring the round-trip time of this call.

    Raises:
        requests.RequestException: On connection failure or a non-2xx status
            (via ``raise_for_status``).
    """
    payload = {
        "task": task,
        "first_frame": first_frame,
        "prev_frame": prev_frame,
        "curr_frame": curr_frame,
        "reference": reference,
    }
    # Fix: use the monotonic perf_counter for latency. time.time() is a wall
    # clock and can jump (NTP adjustment, DST) mid-request, producing wrong or
    # even negative latencies.
    start = time.perf_counter()
    resp = session.post(f"{base_url.rstrip('/')}/done", json=payload, timeout=timeout)
    resp.raise_for_status()
    result = resp.json()
    result["latency_sec"] = time.perf_counter() - start
    return result
|
|
|
|
| |
| |
| |
|
|
|
|
def evaluate(entries: List[Dict], base_url: str, timeout: float, scenario: str) -> Dict[str, float]:
    """Score every sample against the /done endpoint and aggregate metrics.

    Args:
        entries: Manifest entries from ``read_manifest``.
        base_url: VLAC service base URL.
        timeout: Per-request HTTP timeout in seconds.
        scenario: ``"with_ref"`` to attach each entry's reference frames to the
            request; any other value sends no reference.

    Returns:
        Dict with overall ``accuracy``, number of scored ``samples``, mean
        ``latency`` in seconds, and ``per_class_accuracy`` keyed by label
        (0/1). Accuracies are NaN when no sample of that kind was scored.
    """
    session = requests.Session()
    total = 0
    correct = 0
    latencies: List[float] = []
    class_totals = {0: 0, 1: 0}
    class_correct = {0: 0, 1: 0}

    for entry in tqdm(entries):
        task = entry.get("task", "")
        try:
            reference_b64 = (
                encode_images(entry["reference"])
                if scenario == "with_ref" and entry["reference"]
                else None
            )
        except FileNotFoundError as exc:
            # Best effort: fall back to the no-reference request rather than
            # dropping the whole demo.
            print(f"[skip] missing reference frame: {exc}")
            reference_b64 = None

        for sample in entry["samples"]:
            label = int(sample["label"])

            try:
                initial_b64 = encode_image(Path(sample["initial"]))
                prev_b64 = encode_image(Path(sample["prev"]))
                curr_b64 = encode_image(Path(sample["curr"]))
            except FileNotFoundError as exc:
                print(f"[skip] missing frame: {exc}")
                continue

            try:
                result = call_done(
                    session,
                    base_url,
                    task,
                    initial_b64,
                    prev_b64,
                    curr_b64,
                    reference_b64,
                    timeout,
                )
            except requests.RequestException as exc:
                print(f"[warn] request failed for demo {entry.get('demo_id')}: {exc}")
                continue

            # Fix: count the sample only after it was actually scored.
            # Previously class_totals was incremented before the encode/request
            # steps, so skipped samples inflated the per-class denominators and
            # made per_class_accuracy inconsistent with the overall accuracy.
            total += 1
            class_totals[label] += 1
            latencies.append(result.get("latency_sec", 0.0))
            prediction = bool(result.get("done"))
            if prediction == bool(label):
                correct += 1
                class_correct[label] += 1

    accuracy = correct / total if total else float("nan")
    avg_latency = sum(latencies) / len(latencies) if latencies else float("nan")

    per_class_accuracy = {
        label: (class_correct[label] / class_totals[label]) if class_totals[label] else float("nan")
        for label in class_totals
    }

    return {
        "accuracy": accuracy,
        "samples": total,
        "latency": avg_latency,
        "per_class_accuracy": per_class_accuracy,
    }
|
|
|
|
| |
| |
| |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this sanity check."""
    parser = argparse.ArgumentParser(description="VLAC done detection sanity check")
    # Table-driven registration keeps flags, defaults, and help text in one place.
    options = (
        ("--dataset-dir", {"required": True, "help": "Directory containing images/ and dataset JSON"}),
        ("--json-name", {"default": "dataset_task_done.json", "help": "Manifest filename"}),
        ("--base-url", {"default": "http://localhost:8111", "help": "VLAC service base URL"}),
        ("--timeout", {"type": float, "default": 20.0, "help": "HTTP timeout in seconds"}),
        ("--max-demos", {"type": int, "default": None, "help": "Evaluate only the first N demos"}),
        ("--skip-reference", {"action": "store_true", "help": "Only evaluate the no-reference scenario"}),
    )
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
|
|
|
|
def main() -> int:
    """CLI entry point: load the manifest, run both scenarios, print a summary.

    Returns:
        Process exit code — 0 on success, 1 when the dataset cannot be loaded
        or contains no demos.
    """
    args = parse_args()
    dataset_dir = Path(args.dataset_dir)
    try:
        entries = read_manifest(dataset_dir, args.json_name)
    except FileNotFoundError as exc:
        print(exc)
        return 1

    if args.max_demos is not None:
        entries = entries[: args.max_demos]

    if not entries:
        print("No demos found in the manifest."
              " Regenerate the dataset with testing/prepare_vlac_test_data.py")
        return 1

    print(f"Loaded {len(entries)} demos from {dataset_dir}")

    def is_nan(value: object) -> bool:
        # NaN is the only float that compares unequal to itself.
        return isinstance(value, float) and value != value

    def report(header: str, res: Dict[str, float]) -> None:
        # Shared pretty-printer for one scenario's aggregate metrics.
        print(f"{header} -> accuracy: {res['accuracy']:.3f}"
              f" | samples: {res['samples']} | avg latency: {res['latency']:.2f}s")
        for label, acc in sorted(res.get("per_class_accuracy", {}).items()):
            label_name = "done=1" if label == 1 else "done=0"
            print(f"  {label_name} accuracy: {acc:.3f}")

    res_no_ref = evaluate(entries, args.base_url, args.timeout, scenario="no_ref")
    report("\nNo reference", res_no_ref)

    if args.skip_reference:
        return 0

    res_with_ref = evaluate(entries, args.base_url, args.timeout, scenario="with_ref")
    report("With reference", res_with_ref)

    # Only compare scenarios when both overall accuracies are real numbers.
    if not (is_nan(res_no_ref["accuracy"]) or is_nan(res_with_ref["accuracy"])):
        delta = res_with_ref["accuracy"] - res_no_ref["accuracy"]
        print(f"\nΔ accuracy (with - without): {delta:+.3f}")
        for label in sorted(res_no_ref.get("per_class_accuracy", {})):
            acc_no = res_no_ref["per_class_accuracy"].get(label)
            acc_ref = res_with_ref["per_class_accuracy"].get(label)
            if is_nan(acc_no) or is_nan(acc_ref):
                continue
            label_name = "done=1" if label == 1 else "done=0"
            print(f"  Δ {label_name}: {acc_ref - acc_no:+.3f}")

    return 0
|
|
|
|
# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
|
|
|
|
|
|