"""Probe extra per-frame columns for one episode and merge them onto an existing dense CSV.

Only columns that the base CSVs do not already contain are added; a
``verification.json`` record reports whether the pre-existing columns and
metrics were left untouched and whether the merged columns agree with the
per-frame debug payloads.
"""

import argparse
import json
import pickle
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import pandas as pd

PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (
    _aggregate_summary,
    _annotate_phase_columns,
    _derive_templates,
    _json_safe,
    _load_demo,
)
from scripts.recompute_oven_episode_parallel import (
    _chunk_frame_indices,
    _collect_debug_rows,
    _collect_rows,
    _launch_xvfb,
    _spawn_frame_batch_job,
    _stop_process,
)

# Column suffixes that make up one flattened pose (position + quaternion).
_POSE_SUFFIXES = ("x", "y", "z", "qx", "qy", "qz", "qw")
# Maximum absolute per-component difference tolerated between CSV and debug poses.
_POSE_TOLERANCE = 1e-9


def _merge_new_columns(
    base_df: pd.DataFrame, probe_df: pd.DataFrame
) -> Tuple[pd.DataFrame, List[str]]:
    """Left-join onto ``base_df`` the ``probe_df`` columns it does not have yet.

    Args:
        base_df: Frame whose existing columns must be preserved; must contain
            ``frame_index`` when there is anything to merge.
        probe_df: Frame with candidate columns, keyed by ``frame_index``.

    Returns:
        ``(merged, new_columns)``. When ``probe_df`` contributes nothing new,
        ``merged`` is a copy of ``base_df`` and ``new_columns`` is empty.
    """
    new_columns = [
        column
        for column in probe_df.columns
        if column != "frame_index" and column not in base_df.columns
    ]
    if not new_columns:
        return base_df.copy(), []
    merged = base_df.merge(
        probe_df[["frame_index", *new_columns]],
        on="frame_index",
        how="left",
        sort=False,
    )
    return merged, new_columns


def _pose_columns_match(row: pd.Series, prefix: str, expected: Sequence[float]) -> bool:
    """Return whether the ``{prefix}_{x..qw}`` cells of ``row`` agree with ``expected``.

    * An empty/absent ``expected`` pose matches only when every cell is NaN.
    * Otherwise every cell must be within ``_POSE_TOLERANCE`` of the matching
      ``expected`` component. A NaN cell or an ``expected`` pose with fewer
      components than there are columns is a mismatch (a plain ``zip`` would
      silently truncate, and ``abs(nan - b) > tol`` is always False).
    """
    actual = [row[f"{prefix}_{suffix}"] for suffix in _POSE_SUFFIXES]
    if not expected:
        return not any(pd.notna(value) for value in actual)
    if len(expected) < len(actual):
        return False
    # ``<=`` inside ``all`` makes NaN differences count as mismatches.
    return all(
        abs(float(cell) - float(target)) <= _POSE_TOLERANCE
        for cell, target in zip(actual, expected)
    )


def _verification_record(
    base_df: pd.DataFrame,
    merged_df: pd.DataFrame,
    base_key_df: Optional[pd.DataFrame],
    merged_key_df: Optional[pd.DataFrame],
    base_metrics: Optional[Dict[str, object]],
    output_metrics: Optional[Dict[str, object]],
    debug_rows: List[Dict[str, object]],
) -> Dict[str, object]:
    """Build the verification summary written to ``verification.json``.

    Checks that the pre-existing dense/keyframe columns and the metrics are
    unchanged, and that pose / goal-count columns present in ``merged_df``
    agree with the per-frame debug rows (matched on ``frame_index``).

    Args:
        base_df: Dense frame as loaded from the base CSV.
        merged_df: Dense frame after new columns were merged in.
        base_key_df: Base keyframe frame, or ``None`` when not provided.
        merged_key_df: Merged keyframe frame, or ``None`` when not provided.
        base_metrics: Metrics loaded from the base JSON, or ``None``.
        output_metrics: Metrics about to be written, or ``None``.
        debug_rows: Debug payloads; each must carry ``frame_index``, ``state``
            and (when ``p_pre_*`` columns exist) ``p_pre``.

    Returns:
        A dict of boolean check results and row counts.
    """
    dense_equal = base_df.equals(merged_df[base_df.columns])

    key_equal = True
    if base_key_df is not None and merged_key_df is not None:
        key_equal = base_key_df.equals(merged_key_df[base_key_df.columns])

    metrics_equal = True
    if base_metrics is not None and output_metrics is not None:
        metrics_equal = base_metrics == output_metrics

    has_arm_pose = "left_arm_pose_x" in merged_df.columns
    has_num_goal = "p_pre_num_goal_poses" in merged_df.columns
    has_best_pose = "p_pre_best_target_pose_x" in merged_df.columns

    dense_pose_consistent = True
    best_pose_consistent = True
    num_goal_consistent = True
    debug_by_frame = {int(row["frame_index"]): row for row in debug_rows}

    for _, row in merged_df.iterrows():
        debug = debug_by_frame.get(int(row["frame_index"]))
        if debug is None:
            # A dense row without a debug payload cannot be verified at all.
            dense_pose_consistent = False
            best_pose_consistent = False
            num_goal_consistent = False
            continue

        state = debug["state"]
        if has_arm_pose and not _pose_columns_match(
            row, "left_arm_pose", state.get("left_arm_pose") or []
        ):
            dense_pose_consistent = False

        if has_num_goal:
            raw_count = row["p_pre_num_goal_poses"]
            expected_count = int(debug["p_pre"].get("num_goal_poses", 0))
            # NaN (e.g. no probe row for this frame) is a mismatch, not a crash.
            if pd.isna(raw_count) or int(round(float(raw_count))) != expected_count:
                num_goal_consistent = False

        if has_best_pose and not _pose_columns_match(
            row, "p_pre_best_target_pose", debug["p_pre"].get("best_goal_pose") or []
        ):
            best_pose_consistent = False

    return {
        "dense_existing_columns_unchanged": bool(dense_equal),
        "keyframe_existing_columns_unchanged": bool(key_equal),
        "metrics_json_preserved": bool(metrics_equal),
        "debug_row_count": int(len(debug_rows)),
        "dense_row_count": int(len(merged_df)),
        "dense_pose_columns_match_debug_state": bool(dense_pose_consistent),
        "best_target_pose_columns_match_debug": bool(best_pose_consistent),
        "num_goal_pose_columns_match_debug": bool(num_goal_consistent),
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-root", required=True)
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--base-dense-csv", required=True)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument("--base-display", type=int, default=700)
    parser.add_argument("--template-episode-dir")
    parser.add_argument("--stagger-seconds", type=float, default=0.15)
    parser.add_argument("--base-keyframes-csv")
    parser.add_argument("--base-metrics-json")
    parser.add_argument("--base-summary-json")
    parser.add_argument("--keep-frame-json", action="store_true")
    return parser


def _missing_frame_outputs(frame_json_dir: Path, frame_chunk: Sequence[int]) -> List[int]:
    """Return the frame indices lacking either their row JSON or debug JSON file."""
    return [
        frame_index
        for frame_index in frame_chunk
        if not frame_json_dir.joinpath(f"frame_{frame_index:04d}.json").exists()
        or not frame_json_dir.joinpath(f"frame_{frame_index:04d}.debug.json").exists()
    ]


def _run_frame_jobs(
    *,
    episode_dir: Path,
    templates_pkl: Path,
    frame_chunks: Sequence[Sequence[int]],
    frame_json_dir: Path,
    log_dir: Path,
    base_display: int,
    checkpoint_stride: int,
    stagger_seconds: float,
) -> None:
    """Run one batch job per frame chunk, each on its own Xvfb display.

    Display numbers are ``base_display``, ``base_display + 1``, ... Jobs are
    polled once per second until all have exited. All still-running jobs and
    every Xvfb process are stopped on the way out, whether or not an error
    occurred.

    Raises:
        RuntimeError: If a job exits non-zero or leaves frame files missing.
    """
    displays = [base_display + index for index in range(len(frame_chunks))]
    xvfb_procs: List[object] = []
    active: Dict[int, Tuple[Sequence[int], object]] = {}
    try:
        for display_num in displays:
            xvfb_procs.append(
                _launch_xvfb(display_num, log_dir.joinpath(f"xvfb_{display_num}.log"))
            )
        for display_num, frame_chunk in zip(displays, frame_chunks):
            process = _spawn_frame_batch_job(
                display_num=display_num,
                episode_dir=episode_dir,
                templates_pkl=templates_pkl,
                frame_indices=frame_chunk,
                checkpoint_stride=checkpoint_stride,
                output_dir=frame_json_dir,
            )
            active[display_num] = (frame_chunk, process)
            if stagger_seconds > 0:
                time.sleep(stagger_seconds)

        while active:
            time.sleep(1.0)
            finished: List[int] = []
            for display_num, (frame_chunk, process) in active.items():
                return_code = process.poll()
                if return_code is None:
                    continue
                missing = _missing_frame_outputs(frame_json_dir, frame_chunk)
                if return_code != 0 or missing:
                    raise RuntimeError(
                        f"display :{display_num} failed for frames {list(frame_chunk)[:3]}...; missing={missing[:8]}"
                    )
                finished.append(display_num)
            # Removal is deferred so the dict is not mutated while iterating.
            for display_num in finished:
                active.pop(display_num)
    finally:
        for _, process in list(active.values()):
            _stop_process(process)
        for xvfb in xvfb_procs:
            _stop_process(xvfb)


def main() -> int:
    """Probe every frame of an episode and write merged CSV/JSON outputs.

    Steps: load the base artifacts, derive and persist templates, run the
    per-frame batch jobs, merge new (and phase-annotation) columns onto the
    base dense/keyframe frames, write the outputs plus ``verification.json``,
    and finally delete the per-frame JSON files unless ``--keep-frame-json``.

    Returns:
        ``0`` on success.

    Raises:
        ValueError: If the base dense CSV row count differs from the demo length.
        RuntimeError: If a frame batch job fails (see ``_run_frame_jobs``).
    """
    args = _build_arg_parser().parse_args()

    dataset_root = Path(args.dataset_root)
    episode_dir = Path(args.episode_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    base_dense_csv = Path(args.base_dense_csv)
    base_df = pd.read_csv(base_dense_csv)
    base_key_df = None
    if args.base_keyframes_csv:
        base_key_df = pd.read_csv(args.base_keyframes_csv)
    base_metrics = None
    if args.base_metrics_json:
        base_metrics = json.loads(
            Path(args.base_metrics_json).read_text(encoding="utf-8")
        )
    base_summary = None
    if args.base_summary_json:
        base_summary = json.loads(
            Path(args.base_summary_json).read_text(encoding="utf-8")
        )

    demo = _load_demo(episode_dir)
    num_frames = len(demo)
    if len(base_df) != num_frames:
        raise ValueError(
            f"base dense rows {len(base_df)} do not match demo length {num_frames} for {episode_dir}"
        )

    template_episode_dir = (
        Path(args.template_episode_dir) if args.template_episode_dir else episode_dir
    )
    templates, template_frames = _derive_templates(dataset_root, template_episode_dir)
    templates_pkl = output_dir.joinpath("templates.pkl")
    with templates_pkl.open("wb") as handle:
        pickle.dump(templates, handle)
    with output_dir.joinpath("templates.json").open("w", encoding="utf-8") as handle:
        json.dump(
            {
                "template_mode": "per_episode",
                "template_episode": template_episode_dir.name,
                "template_frames": template_frames,
                "templates": templates.to_json(),
                "preserve_base_dense_csv": str(base_dense_csv),
            },
            handle,
            indent=2,
        )

    frame_json_dir = output_dir.joinpath("frame_rows")
    frame_json_dir.mkdir(parents=True, exist_ok=True)
    frame_chunks = _chunk_frame_indices(list(range(num_frames)), args.num_workers)
    _run_frame_jobs(
        episode_dir=episode_dir,
        templates_pkl=templates_pkl,
        frame_chunks=frame_chunks,
        frame_json_dir=frame_json_dir,
        log_dir=output_dir,
        base_display=args.base_display,
        checkpoint_stride=args.checkpoint_stride,
        stagger_seconds=args.stagger_seconds,
    )

    probe_df = _collect_rows(frame_json_dir, num_frames)
    debug_rows = _collect_debug_rows(frame_json_dir, num_frames)

    merged_df, new_columns = _merge_new_columns(base_df, probe_df)
    annotated_df = _annotate_phase_columns(merged_df.copy())
    phase_new_columns = [
        column for column in annotated_df.columns if column not in merged_df.columns
    ]
    if phase_new_columns:
        merged_df = merged_df.merge(
            annotated_df[["frame_index", *phase_new_columns]],
            on="frame_index",
            how="left",
            sort=False,
        )
        new_columns.extend(phase_new_columns)

    merged_key_df = None
    if base_key_df is not None:
        merged_key_df, _ = _merge_new_columns(base_key_df, probe_df)
        # Skip phase columns the keyframe frame already has: merging them again
        # would create ``_x``/``_y`` suffixed duplicates and drop the original
        # column name that the verification step selects.
        key_phase_columns = [
            column
            for column in phase_new_columns
            if column not in merged_key_df.columns
        ]
        if key_phase_columns:
            merged_key_df = merged_key_df.merge(
                annotated_df[["frame_index", *key_phase_columns]],
                on="frame_index",
                how="left",
                sort=False,
            )

    # Metrics are passed through untouched ("preserve_base_metrics" mode).
    output_metrics = base_metrics
    if base_summary is not None:
        output_summary = base_summary
    elif output_metrics is not None:
        output_summary = _aggregate_summary([output_metrics])
    else:
        output_summary = None

    merged_df.to_csv(output_dir.joinpath(f"{episode_dir.name}.dense.csv"), index=False)
    if merged_key_df is not None:
        merged_key_df.to_csv(
            output_dir.joinpath(f"{episode_dir.name}.keyframes.csv"), index=False
        )
    elif args.base_keyframes_csv:
        # Defensive invariant check; not expected to trigger.
        raise RuntimeError("base keyframes csv was provided but merged keyframes are missing")
    if output_metrics is not None:
        with output_dir.joinpath(f"{episode_dir.name}.metrics.json").open(
            "w", encoding="utf-8"
        ) as handle:
            json.dump(output_metrics, handle, indent=2)
    if output_summary is not None:
        with output_dir.joinpath("summary.json").open("w", encoding="utf-8") as handle:
            json.dump(output_summary, handle, indent=2)
    with output_dir.joinpath(f"{episode_dir.name}.debug.jsonl").open(
        "w", encoding="utf-8"
    ) as handle:
        for row in debug_rows:
            handle.write(json.dumps(_json_safe(row)))
            handle.write("\n")

    verification = _verification_record(
        base_df=base_df,
        merged_df=merged_df,
        base_key_df=base_key_df,
        merged_key_df=merged_key_df,
        base_metrics=base_metrics,
        output_metrics=output_metrics,
        debug_rows=debug_rows,
    )
    verification["new_columns_added"] = new_columns
    verification["phase_new_columns_added"] = phase_new_columns
    verification["probe_mode"] = "preserve_base_metrics"
    with output_dir.joinpath("verification.json").open("w", encoding="utf-8") as handle:
        json.dump(_json_safe(verification), handle, indent=2)

    if not args.keep_frame_json:
        for row_path in frame_json_dir.glob("frame_*.json*"):
            row_path.unlink()
        frame_json_dir.rmdir()

    print(json.dumps(_json_safe(verification), indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())