| import argparse |
| import json |
| from pathlib import Path |
| import sys |
| from typing import Optional |
|
|
| import pandas as pd |
|
|
|
|
# Repository root: the directory two levels above this (resolved) script file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent

# Put the repository root at the front of the import path so the
# `rr_label_study` package imported below can be found when this file is
# run directly (presumably as `python <subdir>/<script>.py`); skip the insert
# if the root is already present.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from rr_label_study.oven_study import ( |
| MotionTemplates, |
| _aggregate_summary, |
| _annotate_phase_columns, |
| _episode_metrics_from_frames, |
| _interventional_validity, |
| _load_demo, |
| ) |
|
|
|
|
# Metric keys produced by the interventional-validity probes. Each outcome key
# is immediately followed by the matching ``*_trials`` count key. When no
# dataset root is supplied, these are the keys copied (as floats) from an
# episode's existing metrics file.
INTERVENTION_KEYS = [
    # Probes whose keys carry the "pre_ready_" prefix.
    "pre_ready_open_more_increases_pext",
    "pre_ready_open_more_trials",
    "pre_ready_hold_open_increases_pext",
    "pre_ready_hold_open_trials",
    "pre_ready_extract_success",
    "pre_ready_extract_trials",
    "pre_ready_wait_extract_success",
    "pre_ready_wait_trials",
] + [
    # Probes whose keys carry the "post_ready_" prefix.
    "post_ready_extract_success",
    "post_ready_extract_trials",
    "post_ready_open_more_low_gain",
    "post_ready_open_more_trials",
    "post_ready_hold_open_low_gain",
    "post_ready_hold_open_trials",
]
|
|
|
|
def _load_templates(result_dir: Path) -> MotionTemplates:
    """Rebuild the ``MotionTemplates`` stored in ``result_dir/templates.json``.

    The JSON document must contain a top-level ``"templates"`` object whose
    entries are passed as keyword arguments to ``MotionTemplates``.
    """
    templates_file = result_dir / "templates.json"
    document = json.loads(templates_file.read_text(encoding="utf-8"))
    constructor_kwargs = document["templates"]
    return MotionTemplates(**constructor_kwargs)
|
|
|
|
def _refresh_episode(
    result_dir: Path,
    episode_name: str,
    dataset_root: Optional[Path],
    checkpoint_stride: int,
) -> dict:
    """Recompute one episode's annotations and metrics, rewriting its files in place.

    Reads ``<episode>.dense.csv``, ``<episode>.keyframes.csv`` and
    ``<episode>.metrics.json`` from ``result_dir``. It annotates the dense table
    with ``_annotate_phase_columns``, re-selects the keyframe rows (matched by
    ``frame_index``) from that refreshed table, and recomputes the episode
    metrics with ``_episode_metrics_from_frames``. All three files are then
    overwritten.

    Intervention values come from one of two sources:

    * ``dataset_root is None``: copied (as floats) from the existing metrics
      file, for every key in ``INTERVENTION_KEYS`` that is present there.
    * otherwise: recomputed by ``_interventional_validity``, using the demo at
      ``<dataset_root>/all_variations/episodes/<episode_name>`` and the
      templates in ``result_dir/templates.json``.

    Args:
        result_dir: Directory containing the episode's CSV/JSON artifacts.
        episode_name: Base name shared by the episode's artifact files.
        dataset_root: Dataset root used to recompute interventions, or ``None``
            to reuse the values stored in the old metrics file.
        checkpoint_stride: Forwarded to ``_interventional_validity``; only used
            when ``dataset_root`` is given.

    Returns:
        The newly computed metrics dict (also written to ``<episode>.metrics.json``).
    """
    dense_path = result_dir / f"{episode_name}.dense.csv"
    keyframes_path = result_dir / f"{episode_name}.keyframes.csv"
    metrics_path = result_dir / f"{episode_name}.metrics.json"

    dense_df = pd.read_csv(dense_path)
    dense_df = _annotate_phase_columns(dense_df)

    # The old keyframes file only decides *which* frames are keyframes. The
    # rows themselves are re-taken from the refreshed dense table so they
    # reflect whatever _annotate_phase_columns produced. Keyframe indices that
    # are absent from the dense table are dropped by the isin() filter.
    old_key_df = pd.read_csv(keyframes_path)
    keyframe_indices = old_key_df["frame_index"].astype(int).tolist()
    key_df = dense_df[dense_df["frame_index"].isin(keyframe_indices)].copy()
    key_df = key_df.sort_values("frame_index").reset_index(drop=True)
    key_df["keyframe_ordinal"] = range(len(key_df))

    with metrics_path.open("r", encoding="utf-8") as handle:
        old_metrics = json.load(handle)

    if dataset_root is None:
        # Reuse the stored intervention results; keys missing from the old
        # metrics file are simply omitted.
        interventions = {
            key: float(old_metrics[key]) for key in INTERVENTION_KEYS if key in old_metrics
        }
    else:
        episode_dir = dataset_root / "all_variations" / "episodes" / episode_name
        demo = _load_demo(episode_dir)
        templates = _load_templates(result_dir)
        interventions = _interventional_validity(
            demo=demo,
            templates=templates,
            frame_df=dense_df,
            checkpoint_stride=checkpoint_stride,
        )

    metrics = _episode_metrics_from_frames(
        frame_df=dense_df,
        key_df=key_df,
        episode_name=episode_name,
        description=str(old_metrics.get("description", "")),
        interventions=interventions,
    )

    # Serialize before overwriting anything. json.dump streams its output into
    # an already-truncated file. If a value is not JSON-encodable (numpy
    # integers, for instance), it would raise mid-write and leave a corrupt
    # metrics file after the CSVs had already been replaced. Serializing first
    # leaves all three files untouched on such a failure.
    metrics_text = json.dumps(metrics, indent=2)

    dense_df.to_csv(dense_path, index=False)
    key_df.to_csv(keyframes_path, index=False)
    metrics_path.write_text(metrics_text, encoding="utf-8")
    return metrics
|
|
|
|
def main(argv=None) -> int:
    """Refresh every selected episode in a result directory and rewrite its summary.

    Each episode is refreshed in place by ``_refresh_episode``. The resulting
    metrics are combined by ``_aggregate_summary``, written to
    ``<result-dir>/summary.json``, and echoed to stdout.

    Args:
        argv: Optional argument list; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 when no episodes were found to
        refresh (nothing is written in that case).
    """
    metrics_suffix = ".metrics.json"

    parser = argparse.ArgumentParser(
        description="Recompute episode artifacts and the aggregate summary in a result directory."
    )
    parser.add_argument(
        "--result-dir",
        required=True,
        help="Directory holding the episode*.dense.csv / .keyframes.csv / .metrics.json files.",
    )
    parser.add_argument(
        "--dataset-root",
        help=(
            "If given, recompute interventions from the raw demos under "
            "<dataset-root>/all_variations/episodes; otherwise reuse the values "
            "stored in each metrics file."
        ),
    )
    parser.add_argument(
        "--checkpoint-stride",
        type=int,
        default=16,
        help="Checkpoint stride forwarded to the intervention recomputation (default: 16).",
    )
    parser.add_argument(
        "--episodes",
        nargs="*",
        help="Episode names to refresh (default, or when empty: every episode*.metrics.json).",
    )
    args = parser.parse_args(argv)

    result_dir = Path(args.result_dir)
    dataset_root = Path(args.dataset_root) if args.dataset_root else None

    if args.episodes:
        episode_names = list(args.episodes)
    else:
        # Strip only the trailing ".metrics.json". str.replace(".metrics", "")
        # would also remove a ".metrics" occurring elsewhere in the name.
        episode_names = sorted(
            path.name[: -len(metrics_suffix)]
            for path in result_dir.glob(f"episode*{metrics_suffix}")
        )

    if not episode_names:
        # Fail loudly instead of aggregating an empty list and writing a
        # meaningless summary.json (e.g. when --result-dir points at the wrong place).
        print(f"No episode*{metrics_suffix} files found in {result_dir}", file=sys.stderr)
        return 1

    episode_metrics = [
        _refresh_episode(
            result_dir=result_dir,
            episode_name=episode_name,
            dataset_root=dataset_root,
            checkpoint_stride=args.checkpoint_stride,
        )
        for episode_name in episode_names
    ]

    summary = _aggregate_summary(episode_metrics)
    # Serialize once, then write and print the same text. A serialization error
    # therefore cannot leave a truncated summary.json behind.
    summary_text = json.dumps(summary, indent=2)
    result_dir.joinpath("summary.json").write_text(summary_text, encoding="utf-8")
    print(summary_text)
    return 0
|
|
|
|
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
|
|