"""Refresh saved per-episode study outputs and rebuild the aggregate summary.

For every selected episode in ``--result-dir`` this script:

* re-annotates the phase columns of ``<episode>.dense.csv``,
* re-selects the rows listed in ``<episode>.keyframes.csv`` from the
  re-annotated dense frames,
* recomputes ``<episode>.metrics.json``,

and finally rewrites ``summary.json`` from the refreshed per-episode metrics.
Intervention metrics are recomputed from the raw demo when ``--dataset-root``
is given; otherwise the values already stored in each metrics file are reused.
All per-episode files are overwritten in place.
"""

import argparse
import json
from pathlib import Path
import sys
from typing import List, Optional

import pandas as pd

# Make the project package importable when this file is run directly as a
# script: parents[1] is the parent of the directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (  # noqa: E402
    MotionTemplates,
    _aggregate_summary,
    _annotate_phase_columns,
    _episode_metrics_from_frames,
    _interventional_validity,
    _load_demo,
)

# Intervention metric keys carried over from the previous metrics.json when
# the demo is not re-evaluated (i.e. no --dataset-root was given).
INTERVENTION_KEYS = [
    "pre_ready_open_more_increases_pext",
    "pre_ready_open_more_trials",
    "pre_ready_hold_open_increases_pext",
    "pre_ready_hold_open_trials",
    "pre_ready_extract_success",
    "pre_ready_extract_trials",
    "pre_ready_wait_extract_success",
    "pre_ready_wait_trials",
    "post_ready_extract_success",
    "post_ready_extract_trials",
    "post_ready_open_more_low_gain",
    "post_ready_open_more_trials",
    "post_ready_hold_open_low_gain",
    "post_ready_hold_open_trials",
]

_METRICS_SUFFIX = ".metrics.json"


def _load_templates(result_dir: Path) -> MotionTemplates:
    """Build MotionTemplates from the ``templates`` entry of ``templates.json``."""
    with result_dir.joinpath("templates.json").open("r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return MotionTemplates(**payload["templates"])


def _select_keyframes(
    dense_df: pd.DataFrame,
    keyframe_indices: List[int],
    episode_name: str,
) -> pd.DataFrame:
    """Return the dense rows that are keyframes, sorted and renumbered.

    Keyframe indices with no matching dense frame cannot be carried over.
    They are reported on stderr rather than dropped silently, because the
    caller overwrites the keyframes file with the returned table.
    """
    present = set(dense_df["frame_index"].tolist())
    missing = sorted({idx for idx in keyframe_indices if idx not in present})
    if missing:
        print(
            f"warning: {episode_name}: {len(missing)} keyframe index(es) not "
            f"found in dense frames and will be dropped: {missing}",
            file=sys.stderr,
        )

    key_df = dense_df[dense_df["frame_index"].isin(keyframe_indices)].copy()
    key_df = key_df.sort_values("frame_index").reset_index(drop=True)
    key_df["keyframe_ordinal"] = range(len(key_df))
    return key_df


def _refresh_episode(
    result_dir: Path,
    episode_name: str,
    dataset_root: Optional[Path],
    checkpoint_stride: int,
    templates: Optional[MotionTemplates] = None,
) -> dict:
    """Recompute and overwrite the saved outputs of one episode.

    Args:
        result_dir: Directory holding ``<episode>.dense.csv``,
            ``<episode>.keyframes.csv`` and ``<episode>.metrics.json``.
        episode_name: Episode file-name prefix inside ``result_dir``.
        dataset_root: When given, interventions are recomputed from the demo
            under ``dataset_root/all_variations/episodes/<episode_name>``.
            When ``None``, the intervention values stored in the old
            metrics file are reused.
        checkpoint_stride: Forwarded to ``_interventional_validity``.
        templates: Pre-loaded motion templates. When ``None`` and
            ``dataset_root`` is given, they are read from
            ``result_dir/templates.json``.

    Returns:
        The freshly computed metrics dict, which is also written to
        ``<episode>.metrics.json``.
    """
    dense_path = result_dir / f"{episode_name}.dense.csv"
    keyframes_path = result_dir / f"{episode_name}.keyframes.csv"
    metrics_path = result_dir / f"{episode_name}.metrics.json"

    dense_df = _annotate_phase_columns(pd.read_csv(dense_path))

    old_key_df = pd.read_csv(keyframes_path)
    keyframe_indices = old_key_df["frame_index"].astype(int).tolist()
    key_df = _select_keyframes(dense_df, keyframe_indices, episode_name)

    with metrics_path.open("r", encoding="utf-8") as handle:
        old_metrics = json.load(handle)

    if dataset_root is None:
        # Reuse previously computed intervention results; absent keys are skipped.
        interventions = {
            key: float(old_metrics[key])
            for key in INTERVENTION_KEYS
            if key in old_metrics
        }
    else:
        episode_dir = dataset_root / "all_variations" / "episodes" / episode_name
        demo = _load_demo(episode_dir)
        if templates is None:
            templates = _load_templates(result_dir)
        interventions = _interventional_validity(
            demo=demo,
            templates=templates,
            frame_df=dense_df,
            checkpoint_stride=checkpoint_stride,
        )

    metrics = _episode_metrics_from_frames(
        frame_df=dense_df,
        key_df=key_df,
        episode_name=episode_name,
        description=str(old_metrics.get("description", "")),
        interventions=interventions,
    )

    # All inputs have been consumed at this point; overwrite them in place.
    dense_df.to_csv(dense_path, index=False)
    key_df.to_csv(keyframes_path, index=False)
    with metrics_path.open("w", encoding="utf-8") as handle:
        json.dump(metrics, handle, indent=2)
    return metrics


def _discover_episodes(result_dir: Path) -> List[str]:
    """List episode names that have an ``episode*.metrics.json`` file, sorted."""
    # Slice off the exact suffix instead of replacing the substring anywhere
    # in the name.
    return sorted(
        path.name[: -len(_METRICS_SUFFIX)]
        for path in result_dir.glob(f"episode*{_METRICS_SUFFIX}")
    )


def main(argv=None) -> int:
    """Refresh the selected episodes, then write and print ``summary.json``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--result-dir", required=True)
    parser.add_argument("--dataset-root")
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--episodes", nargs="*")
    args = parser.parse_args(argv)

    result_dir = Path(args.result_dir)
    dataset_root = Path(args.dataset_root) if args.dataset_root else None

    # An explicit but empty --episodes list falls back to discovery.
    episode_names = args.episodes or _discover_episodes(result_dir)
    if not episode_names:
        # Refuse to overwrite summary.json with an empty aggregate.
        parser.error(f"no episode*{_METRICS_SUFFIX} files found in {result_dir}")

    # Templates are only needed when interventions are recomputed; load once.
    templates = _load_templates(result_dir) if dataset_root is not None else None

    episode_metrics = [
        _refresh_episode(
            result_dir=result_dir,
            episode_name=episode_name,
            dataset_root=dataset_root,
            checkpoint_stride=args.checkpoint_stride,
            templates=templates,
        )
        for episode_name in episode_names
    ]

    summary = _aggregate_summary(episode_metrics)
    with result_dir.joinpath("summary.json").open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())