import argparse
import json
import os
import pickle
import sys
from pathlib import Path
from typing import Optional, Sequence

import numpy as np
|
|
|
|
# Make the parent of this script's directory importable so the
# `rr_label_study` import below resolves when the script is run directly.
# (Assumes the package lives in that directory — confirm against repo layout.)
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from rr_label_study.oven_study import ( |
| BimanualTakeTrayOutOfOven, |
| ReplayCache, |
| _launch_replay_env, |
| _load_demo, |
| _pregrasp_progress_and_distance, |
| _pregrasp_score_and_success, |
| ) |
|
|
|
|
# Emit a progress line after this many completed frames (and always after the last one).
_PROGRESS_EVERY = 8


def _frame_path(output_dir: Path, frame_index: int) -> Path:
    """Return the per-frame JSON path, shared by the resume check and the writer."""
    return output_dir / f"frame_{frame_index:04d}.json"


def _write_json_atomic(path: Path, payload: dict) -> None:
    """Write ``payload`` as JSON to ``path`` through a temp file and rename.

    The temp file is flushed and fsync'ed before the rename. Without that, a
    crash shortly after the rename can leave an empty or truncated row on some
    filesystems. ``main`` only checks that a row file *exists* when resuming,
    so such a row would never be recomputed. On any failure the temp file is
    removed and the exception is re-raised.
    """
    tmp_path = path.with_suffix(".json.tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle)
            handle.flush()
            os.fsync(handle.fileno())
        tmp_path.replace(path)
    except BaseException:
        # Don't leave a half-written temp file behind; propagate the error.
        tmp_path.unlink(missing_ok=True)
        raise


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--episode-dir", required=True, help="episode directory passed to _load_demo"
    )
    parser.add_argument(
        "--templates-pkl",
        required=True,
        help="pickled templates file (unpickled with pickle.load; trusted input only)",
    )
    parser.add_argument(
        "--frame-indices",
        nargs="+",
        type=int,
        required=True,
        help="frame indices to evaluate; duplicates are ignored",
    )
    parser.add_argument(
        "--checkpoint-stride",
        type=int,
        default=16,
        help="checkpoint_stride passed to ReplayCache",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="directory for per-frame JSON rows; frames with an existing row are skipped",
    )
    return parser.parse_args(argv)


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Compute pre-grasp metrics for selected frames of one replayed episode.

    Each requested frame may lack a ``frame_XXXX.json`` row in
    ``--output-dir``. For each such frame, the demo is replayed to that frame
    and one JSON row is written atomically. The row holds ``frame_index``,
    ``pregrasp_progress``, ``pregrasp_distance``, ``p_pre``, ``y_pre_raw``
    and ``y_pre``. Existing rows are left untouched, so re-running the script
    resumes an interrupted job.

    Args:
        argv: Command-line arguments; defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status (0 on success; errors propagate as exceptions).
    """
    args = _parse_args(argv)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Visit frames in ascending order. Presumably this lets the replay cache
    # step forward rather than rewind (TODO confirm against ReplayCache).
    frame_indices = sorted(set(args.frame_indices))
    pending_frame_indices = [
        frame_index
        for frame_index in frame_indices
        if not _frame_path(output_dir, frame_index).exists()
    ]
    if not pending_frame_indices:
        # Nothing to do: skip loading templates/demo and launching the simulator.
        return 0

    episode_dir = Path(args.episode_dir)
    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # pass template files produced by a trusted pipeline.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()
        total = len(pending_frame_indices)
        for completed, frame_index in enumerate(pending_frame_indices, start=1):
            cache.step_to(frame_index)
            state = cache.current_state()
            pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance(
                np.asarray(state.left_gripper_pose, dtype=np.float64),
                np.asarray(state.tray_pose, dtype=np.float64),
                templates,
            )
            p_pre, y_pre, _ = _pregrasp_score_and_success(task, templates)
            y_pre_value = float(bool(y_pre))
            row = {
                "frame_index": int(frame_index),
                "pregrasp_progress": float(pregrasp_progress),
                "pregrasp_distance": float(pregrasp_distance),
                "p_pre": float(p_pre),
                # NOTE(review): y_pre_raw and y_pre hold the same value here;
                # kept as separate keys for the output schema (confirm intent).
                "y_pre_raw": y_pre_value,
                "y_pre": y_pre_value,
            }
            _write_json_atomic(_frame_path(output_dir, frame_index), row)
            if completed == total or completed % _PROGRESS_EVERY == 0:
                progress = {
                    "done": completed,
                    "total": total,
                    "frame_index": int(frame_index),
                }
                print(json.dumps(progress), flush=True)
    finally:
        # Always release the simulator, even if a frame fails.
        env.shutdown()
    return 0
|
|
|
|
if __name__ == "__main__":
    # sys.exit raises SystemExit, so main()'s return value becomes the exit status.
    sys.exit(main())
|
|