VLAdaptorBench / code /scripts /refresh_saved_oven_study.py
lsnu's picture
Update refined oven metric validation and smoke artifacts
0bcd290 verified
import argparse
import json
from pathlib import Path
import sys
from typing import Optional
import pandas as pd
# The directory one level above this script's own directory. It is put at the
# front of ``sys.path`` (only if not already listed) so that the
# ``rr_label_study`` import further down can be resolved when this file is run
# directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(PROJECT_ROOT)) == 0:
    sys.path.insert(0, str(PROJECT_ROOT))
from rr_label_study.oven_study import ( # noqa: E402
MotionTemplates,
_aggregate_summary,
_annotate_phase_columns,
_episode_metrics_from_frames,
_interventional_validity,
_load_demo,
)
# Metric names that are copied (as floats) from an episode's previously saved
# metrics file when the interventions are not recomputed from a dataset.
# Keys missing from the saved file are simply skipped by the consumer.
INTERVENTION_KEYS = [
    # --- keys with the "pre_ready_" prefix ---
    "pre_ready_open_more_increases_pext",
    "pre_ready_open_more_trials",
    "pre_ready_hold_open_increases_pext",
    "pre_ready_hold_open_trials",
    "pre_ready_extract_success",
    "pre_ready_extract_trials",
    "pre_ready_wait_extract_success",
    "pre_ready_wait_trials",
    # --- keys with the "post_ready_" prefix ---
    "post_ready_extract_success",
    "post_ready_extract_trials",
    "post_ready_open_more_low_gain",
    "post_ready_open_more_trials",
    "post_ready_hold_open_low_gain",
    "post_ready_hold_open_trials",
]
def _load_templates(result_dir: Path) -> MotionTemplates:
    """Rebuild the motion templates stored next to a study's results.

    Reads ``templates.json`` inside ``result_dir`` (UTF-8) and passes the
    mapping found under its ``"templates"`` key to ``MotionTemplates`` as
    keyword arguments.

    Args:
        result_dir: Directory that contains ``templates.json``.

    Returns:
        The ``MotionTemplates`` instance built from the stored fields.
    """
    templates_file = result_dir / "templates.json"
    document = json.loads(templates_file.read_text(encoding="utf-8"))
    template_fields = document["templates"]
    return MotionTemplates(**template_fields)
def _refresh_episode(
    result_dir: Path,
    episode_name: str,
    dataset_root: Optional[Path],
    checkpoint_stride: int,
) -> dict:
    """Recompute one episode's saved artifacts and overwrite them in place.

    The episode's ``<name>.dense.csv`` is reloaded and passed through
    ``_annotate_phase_columns``. The keyframe table is then rebuilt from the
    refreshed dense rows whose ``frame_index`` appears in the existing
    ``<name>.keyframes.csv``. Finally the metrics are recomputed and all three
    files (dense CSV, keyframes CSV, metrics JSON) are written back.

    Args:
        result_dir: Directory holding the episode's saved files.
        episode_name: File-name prefix identifying the episode.
        dataset_root: When given, interventions are recomputed via
            ``_interventional_validity`` using the demo loaded from
            ``<dataset_root>/all_variations/episodes/<episode_name>``. When
            ``None``, the values listed in ``INTERVENTION_KEYS`` are carried
            over from the previously saved metrics file instead.
        checkpoint_stride: Forwarded to ``_interventional_validity``; unused
            when ``dataset_root`` is ``None``.

    Returns:
        The freshly computed metrics mapping (the same object that is dumped
        to the metrics JSON file).
    """
    dense_csv = result_dir / f"{episode_name}.dense.csv"
    keyframes_csv = result_dir / f"{episode_name}.keyframes.csv"
    metrics_json = result_dir / f"{episode_name}.metrics.json"

    # Re-annotate the dense per-frame table first; everything else derives
    # from this refreshed frame.
    frames = _annotate_phase_columns(pd.read_csv(dense_csv))

    # Keep the previously chosen keyframe indices, but take the row contents
    # from the refreshed dense table.
    saved_keyframes = pd.read_csv(keyframes_csv)
    wanted_frames = saved_keyframes["frame_index"].astype(int).tolist()
    is_keyframe = frames["frame_index"].isin(wanted_frames)
    keyframes = (
        frames[is_keyframe]
        .copy()
        .sort_values("frame_index")
        .reset_index(drop=True)
    )
    keyframes["keyframe_ordinal"] = range(len(keyframes))

    previous = json.loads(metrics_json.read_text(encoding="utf-8"))

    if dataset_root is not None:
        # Dataset available: rerun the interventional checks from the demo.
        interventions = _interventional_validity(
            demo=_load_demo(
                dataset_root / "all_variations" / "episodes" / episode_name
            ),
            templates=_load_templates(result_dir),
            frame_df=frames,
            checkpoint_stride=checkpoint_stride,
        )
    else:
        # No dataset: reuse whichever intervention values were saved before.
        interventions = {}
        for key in INTERVENTION_KEYS:
            if key in previous:
                interventions[key] = float(previous[key])

    refreshed = _episode_metrics_from_frames(
        frame_df=frames,
        key_df=keyframes,
        episode_name=episode_name,
        description=str(previous.get("description", "")),
        interventions=interventions,
    )

    # Persist the refreshed artifacts over the originals.
    frames.to_csv(dense_csv, index=False)
    keyframes.to_csv(keyframes_csv, index=False)
    with metrics_json.open("w", encoding="utf-8") as sink:
        json.dump(refreshed, sink, indent=2)
    return refreshed
def main(argv=None) -> int:
    """Refresh every selected episode in a result directory and its summary.

    Command-line options:
        --result-dir (required): directory holding the saved episode files.
        --dataset-root: optional dataset location; forwarded to
            ``_refresh_episode`` (``None`` when omitted or empty).
        --checkpoint-stride: integer forwarded to ``_refresh_episode``
            (default 16).
        --episodes: explicit episode names. When omitted or empty, every
            ``episode*.metrics.json`` in the result directory is used, in
            sorted name order.

    Each episode is refreshed in turn, the per-episode metrics are combined
    with ``_aggregate_summary``, and the result is written to
    ``summary.json`` in the result directory and printed to stdout.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use the process
            arguments.

    Returns:
        Always ``0``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--result-dir", required=True)
    parser.add_argument("--dataset-root")
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--episodes", nargs="*")
    options = parser.parse_args(argv)

    result_dir = Path(options.result_dir)
    dataset_root = None if not options.dataset_root else Path(options.dataset_root)

    # Fall back to discovering episodes from their metrics files when none
    # were named explicitly.
    episode_names = options.episodes or sorted(
        metrics_file.stem.replace(".metrics", "")
        for metrics_file in result_dir.glob("episode*.metrics.json")
    )

    episode_metrics = [
        _refresh_episode(
            result_dir=result_dir,
            episode_name=name,
            dataset_root=dataset_root,
            checkpoint_stride=options.checkpoint_stride,
        )
        for name in episode_names
    ]

    summary = _aggregate_summary(episode_metrics)
    summary_file = result_dir.joinpath("summary.json")
    with summary_file.open("w", encoding="utf-8") as sink:
        json.dump(summary, sink, indent=2)
    print(json.dumps(summary, indent=2))
    return 0
# Script entry point: run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())