"""Compute pre-grasp labels for selected frames of one demo episode.

For every requested frame index whose output file does not exist yet, a
``ReplayCache`` is stepped to that frame and a JSON row with the pre-grasp
progress/distance and score/success values is written to
``<output-dir>/frame_XXXX.json``. Frames that already have an output file are
skipped, so re-running the script resumes an interrupted job.
"""

from pathlib import Path
import argparse
import json
import pickle
import sys

import numpy as np

# Make the project root importable when this script is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (  # noqa: E402  (needs the sys.path tweak above)
    BimanualTakeTrayOutOfOven,
    ReplayCache,
    _launch_replay_env,
    _load_demo,
    _pregrasp_progress_and_distance,
    _pregrasp_score_and_success,
)


def _row_path(output_dir: Path, frame_index: int) -> Path:
    """Return the output JSON path used for ``frame_index``."""
    return output_dir.joinpath(f"frame_{frame_index:04d}.json")


def _write_json_atomic(path: Path, payload: dict) -> None:
    """Write ``payload`` as JSON to ``path`` through a temporary file.

    The data is written to ``<name>.json.tmp`` first and then moved over
    ``path`` with ``Path.replace``, so an interrupted write does not leave a
    truncated row that the resume check in ``main`` would treat as finished.
    If writing or renaming fails, the temporary file is removed and the
    exception is re-raised.
    """
    tmp_path = path.with_suffix(".json.tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle)
        tmp_path.replace(path)
    except BaseException:
        # Clean up the partial temp file (also on Ctrl-C), then propagate.
        tmp_path.unlink(missing_ok=True)
        raise


def main() -> int:
    """Label the requested frames of one episode and write one JSON per frame.

    Command-line arguments:
        --episode-dir: episode directory passed to ``_load_demo``.
        --templates-pkl: pickle file with the templates used for scoring.
        --frame-indices: one or more non-negative frame indices to label.
        --checkpoint-stride: forwarded to ``ReplayCache`` (default 16).
        --output-dir: directory receiving ``frame_XXXX.json`` files.

    Returns:
        0 when all pending frames were processed (or none were pending).
        Invalid arguments exit via ``argparse`` with status 2; errors from the
        replay environment propagate after the environment is shut down.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--output-dir", required=True)
    args = parser.parse_args()

    # Fail fast, before the (expensive) environment launch, on indices that
    # would yield malformed names such as "frame_-001.json".
    if any(index < 0 for index in args.frame_indices):
        parser.error("--frame-indices must be non-negative")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Deduplicate, visit frames in ascending order, and skip frames whose
    # output already exists so that reruns resume where a previous run stopped.
    pending_frame_indices = [
        frame_index
        for frame_index in sorted(set(args.frame_indices))
        if not _row_path(output_dir, frame_index).exists()
    ]
    if not pending_frame_indices:
        return 0

    episode_dir = Path(args.episode_dir)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only pass a --templates-pkl produced by a trusted source.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()

        total = len(pending_frame_indices)
        for completed, frame_index in enumerate(pending_frame_indices, start=1):
            cache.step_to(frame_index)
            state = cache.current_state()
            pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance(
                np.asarray(state.left_gripper_pose, dtype=np.float64),
                np.asarray(state.tray_pose, dtype=np.float64),
                templates,
            )
            p_pre, y_pre, _ = _pregrasp_score_and_success(task, templates)
            row = {
                "frame_index": int(frame_index),
                "pregrasp_progress": float(pregrasp_progress),
                "pregrasp_distance": float(pregrasp_distance),
                "p_pre": float(p_pre),
                # NOTE(review): "y_pre_raw" and "y_pre" currently hold the same
                # value; both keys are kept since downstream readers may
                # expect them — confirm whether they are meant to differ.
                "y_pre_raw": float(bool(y_pre)),
                "y_pre": float(bool(y_pre)),
            }
            _write_json_atomic(_row_path(output_dir, frame_index), row)

            # Emit a machine-readable progress line every 8 frames and at the end.
            if completed == total or completed % 8 == 0:
                progress = {
                    "done": completed,
                    "total": total,
                    "frame_index": int(frame_index),
                }
                print(json.dumps(progress), flush=True)
    finally:
        env.shutdown()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())