"""Recompute the replay-derived columns of a dense per-frame CSV for one episode.

The script replays the episode demo frame by frame, overwrites the recomputed
columns in a copy of the input table, re-annotates phases, and writes the
templates, dense table, keyframe table, per-episode metrics, and summary into
the output directory.
"""

from pathlib import Path
import argparse
import json
import sys
from typing import Dict, Optional

import pandas as pd


# Make the repository root importable so `rr_label_study` resolves when this
# file is executed directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (  # noqa: E402  (needs the sys.path tweak above)
    BimanualTakeTrayOutOfOven,
    ReplayCache,
    _aggregate_summary,
    _annotate_phase_columns,
    _analyze_episode,
    _derive_templates,
    _episode_metrics_from_frames,
    _keyframe_subset,
    _keypoint_discovery,
    _launch_replay_env,
    _load_demo,
    _load_descriptions,
    _pregrasp_progress_and_distance,
    _pregrasp_score_and_success,
    _frame_metrics,
)


def _recompute_columns(
    episode_dir: Path,
    templates,
    checkpoint_stride: int,
    base_df: pd.DataFrame,
) -> pd.DataFrame:
    """Replay the episode and overwrite the recomputed per-frame columns.

    Args:
        episode_dir: Episode directory; passed to ``_load_demo`` and
            ``_frame_metrics`` and used (by name) in progress messages.
        templates: Template object forwarded unchanged to the metric helpers.
        checkpoint_stride: Forwarded to ``ReplayCache`` as ``checkpoint_stride``.
        base_df: Existing dense table. Only its first
            ``min(len(demo), len(base_df))`` rows are kept; it is not mutated.

    Returns:
        A new DataFrame with one row per replayed frame, indexed
        ``0..num_frames-1``, where the recomputed columns have been
        overwritten and all other columns are carried over from ``base_df``.

    The replay environment is always shut down, even if a step raises.
    """
    demo = _load_demo(episode_dir)
    num_frames = min(len(demo), len(base_df))

    # `.at` below is label-based while `frame_index` is a position. Renumber
    # the rows so label == position; otherwise a `base_df` with a non-default
    # index would get rows appended (or the wrong rows updated) instead of
    # having frame `i` written into row `i`.
    frame_df = base_df.iloc[:num_frames].copy().reset_index(drop=True)

    # Loop-invariant denominator; max(1, ...) avoids 0/0 for a single frame.
    time_denominator = max(1, num_frames - 1)

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=checkpoint_stride)
        cache.reset()

        for frame_index in range(num_frames):
            cache.step_to(frame_index)
            state = cache.current_state()

            visibility, _ = _frame_metrics(episode_dir, demo, state, templates)
            pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance(
                state.left_gripper_pose,
                state.tray_pose,
                templates,
            )
            p_pre, y_pre, _ = _pregrasp_score_and_success(task, templates)
            y_pre_value = float(bool(y_pre))

            # Insertion order is the write order, which also fixes the order
            # in which any missing columns get created.
            recomputed = {
                "frame_index": frame_index,
                "time_norm": frame_index / time_denominator,
                "door_angle": state.door_angle,
                "right_gripper_open": state.right_gripper_open,
                "left_gripper_open": state.left_gripper_open,
                "pregrasp_progress": pregrasp_progress,
                "pregrasp_distance": pregrasp_distance,
                "p_pre": p_pre,
                # Both columns receive the same freshly computed flag.
                "y_pre_raw": y_pre_value,
                "y_pre": y_pre_value,
            }
            for column, value in recomputed.items():
                frame_df.at[frame_index, column] = value

            # Visibility entries are written last, keyed by their own names.
            for key, value in visibility.items():
                frame_df.at[frame_index, key] = value

            done = frame_index + 1
            if done % 25 == 0 or done == num_frames:
                print(
                    f"[{episode_dir.name}] recomputed {frame_index + 1}/{num_frames} dense frames",
                    flush=True,
                )

        return frame_df
    finally:
        env.shutdown()


def main() -> int:
    """Command-line entry point.

    Reads the dense CSV given by ``--input-dense-csv``, recomputes its
    replay-derived columns for ``--episode-dir``, and writes into
    ``--output-dir``: ``templates.json``, ``<episode>.dense.csv``,
    ``<episode>.keyframes.csv``, ``<episode>.metrics.json`` and
    ``summary.json``. The summary is also printed to stdout.

    Returns:
        ``0`` on success.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-root", required=True)
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--input-dense-csv", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--template-episode-dir")
    args = parser.parse_args()

    dataset_root = Path(args.dataset_root)
    episode_dir = Path(args.episode_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    base_df = pd.read_csv(args.input_dense_csv)
    demo = _load_demo(episode_dir)
    descriptions = _load_descriptions(episode_dir)

    # Templates come from the episode itself unless another one is specified.
    template_episode_dir = (
        Path(args.template_episode_dir) if args.template_episode_dir else episode_dir
    )
    templates, template_frames = _derive_templates(dataset_root, template_episode_dir)
    with output_dir.joinpath("templates.json").open("w", encoding="utf-8") as handle:
        json.dump(
            {
                "templates": templates.to_json(),
                "template_episode": template_episode_dir.name,
                "template_frames": template_frames,
            },
            handle,
            indent=2,
        )

    frame_df = _recompute_columns(
        episode_dir=episode_dir,
        templates=templates,
        checkpoint_stride=args.checkpoint_stride,
        base_df=base_df,
    )
    frame_df = _annotate_phase_columns(frame_df)

    # Drop keyframes that fall beyond the (possibly truncated) dense table.
    keyframes = [index for index in _keypoint_discovery(demo) if index < len(frame_df)]
    key_df = _keyframe_subset(frame_df, keyframes)
    metrics = _episode_metrics_from_frames(
        frame_df=frame_df,
        key_df=key_df,
        episode_name=episode_dir.name,
        description=descriptions[0],
        interventions={},
    )

    frame_df.to_csv(output_dir.joinpath(f"{episode_dir.name}.dense.csv"), index=False)
    key_df.to_csv(output_dir.joinpath(f"{episode_dir.name}.keyframes.csv"), index=False)
    with output_dir.joinpath(f"{episode_dir.name}.metrics.json").open(
        "w", encoding="utf-8"
    ) as handle:
        json.dump(metrics, handle, indent=2)

    summary = _aggregate_summary([metrics])
    with output_dir.joinpath("summary.json").open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)

    print(json.dumps(summary, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())