File size: 4,985 Bytes
c6dfc69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 | #!/usr/bin/env python3
"""
Build a small merge-debug subset tree under AVSBench/v2/ (subset name: v2):

    AVSBench/v2/avss_index/metadata.csv   # combined metadata for all three modalities
    AVSBench/v2/v1s/<20 uids>/
    AVSBench/v2/v1m/<20 uids>/
    AVSBench/v2/v2/<20 uids>/             # v2-protocol clips extracted from ~/Downloads/v2.zip

Each modality gets 20 clips: 16 train + 4 test rows. The v1s/v1m clips (and
their metadata rows) are copied from the full AVSBench tree.
"""
from __future__ import annotations
import csv
import os
import shutil
import zipfile
import pandas as pd
# Workspace root: two directory levels above the directory holding this script.
_WORKSPACE = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)

# Full AVSBench tree: source of the full metadata and of the v1s/v1m clips.
AVSBENCH = os.path.join(_WORKSPACE, "AVSBench")
# Output root of the debug subset; used as data_root for the merge run.
SUBSET_ROOT = os.path.join(AVSBENCH, "v2")
FULL_META = os.path.join(AVSBENCH, "avss_index", "metadata.csv")
# Archive providing the v2-protocol clips.
ZIP_PATH = os.path.expanduser("~/Downloads/v2.zip")

# Per-modality split sizes; PER_LABEL is the total clip count per modality.
TRAIN_N = 16
TEST_N = 4
PER_LABEL = TRAIN_N + TEST_N
def _pick_rows_from_full_metadata(label: str) -> list[dict]:
    """Select the first TRAIN_N train and TEST_N test rows of *label* whose clip exists.

    Rows are scanned in FULL_META file order; a row is eligible only when the
    directory AVSBENCH/<label>/<uid> exists on disk.

    Args:
        label: Modality label to select (main() uses "v1s" and "v1m").

    Returns:
        The picked rows (train rows first, then test rows), each a dict of
        column name -> cell text exactly as it appears in FULL_META.

    Raises:
        SystemExit: if either split has fewer eligible rows than required.
    """
    # Read every column as text and keep empty cells as "": the default parser
    # would turn numeric-looking uids such as "007" into 7 (breaking the
    # on-disk lookup) and write empty cells back out as "nan".
    df = pd.read_csv(FULL_META, dtype=str, keep_default_na=False)
    picked: list[dict] = []
    for split, need in (("train", TRAIN_N), ("test", TEST_N)):
        got = 0
        sub = df[(df["split"] == split) & (df["label"] == label)]
        for row in sub.to_dict("records"):
            # Check the quota first so need == 0 selects nothing.
            if got >= need:
                break
            uid = row["uid"]
            # An empty uid would make the path resolve to the label directory itself.
            if not uid or not os.path.isdir(os.path.join(AVSBENCH, label, uid)):
                continue
            picked.append(row)
            got += 1
        if got < need:
            raise SystemExit(
                f"not enough existing {label}/{split} under {AVSBENCH}: need {need}, got {got}"
            )
    return picked
def _list_v2_uids(z: zipfile.ZipFile) -> list[str]:
    """Return the distinct clip ids stored under the archive's top-level ``v2/`` folder.

    A clip id is the path component right after ``v2/``. It is counted only
    when some file member (not a directory entry) lives below it, i.e.
    ``v2/<uid>/<something>``; empty components are ignored.

    Returns:
        The ids in ascending sorted order.
    """

    def clip_id(member: str) -> str:
        # "" marks members too shallow to name a clip directory.
        pieces = member.split("/")
        return pieces[1] if len(pieces) >= 3 else ""

    file_members = (
        m for m in z.namelist() if m.startswith("v2/") and not m.endswith("/")
    )
    return sorted({cid for cid in map(clip_id, file_members) if cid})
def _is_wanted_v2_member(name: str) -> bool:
    """Return True for zip members the subset keeps: frames/*.jpg, labels_rgb/*.png, audio.wav."""
    if "/labels_semantic/" in name:
        return False
    return (
        ("/frames/" in name and name.endswith(".jpg"))
        or ("/labels_rgb/" in name and name.endswith(".png"))
        or name.endswith("/audio.wav")
    )


def _extract_v2_from_zip(uids: list[str]) -> None:
    """Extract the selected v2 clips from ZIP_PATH into SUBSET_ROOT/v2/<uid>/.

    Only members under ``v2/<uid>/`` with ``uid`` in *uids* are considered.
    Of those, only ``frames/*.jpg``, ``labels_rgb/*.png`` and ``audio.wav``
    are written; anything under ``labels_semantic/`` and all other members
    are skipped. Relative paths below ``v2/`` are preserved.

    Args:
        uids: Clip ids (second path component inside the zip) to extract.

    Raises:
        SystemExit: if a selected member's path would resolve outside
            SUBSET_ROOT/v2 (e.g. via ``..`` components).
    """
    allowed = set(uids)
    out_root = os.path.abspath(os.path.join(SUBSET_ROOT, "v2"))
    with zipfile.ZipFile(ZIP_PATH, "r") as z:
        for info in z.infolist():
            name = info.filename
            if not name.startswith("v2/"):
                continue
            parts = name.split("/")
            if len(parts) < 3 or parts[1] not in allowed:
                continue
            if not _is_wanted_v2_member(name):
                continue
            dest = os.path.abspath(os.path.join(out_root, *parts[1:]))
            # "Zip slip" guard: member names are not sanitized here, so a
            # crafted name could otherwise write outside the subset tree.
            if os.path.commonpath([out_root, dest]) != out_root:
                raise SystemExit(
                    f"refusing to extract {name!r}: resolves outside {out_root}"
                )
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            with z.open(info, "r") as src, open(dest, "wb") as out:
                shutil.copyfileobj(src, out)
def _copy_clip(label: str, uid: str) -> None:
    """Mirror one clip directory from the full dataset into the subset tree.

    AVSBENCH/<label>/<uid>/ is copied recursively to SUBSET_ROOT/<label>/<uid>/.
    Any directory already at the destination is removed first, so no stale
    files from an earlier copy remain.
    """
    source, target = (os.path.join(root, label, uid) for root in (AVSBENCH, SUBSET_ROOT))
    if os.path.isdir(target):
        shutil.rmtree(target)
    shutil.copytree(source, target)
# Column order of the generated avss_index/metadata.csv.
_METADATA_FIELDS = ["vid", "uid", "s_min", "s_sec", "a_obj", "split", "label"]


def _v2_metadata_rows(v2_uids: list[str]) -> list[dict]:
    """Build metadata rows for the extracted v2 clips.

    The first TRAIN_N uids are marked "train", the rest "test". ``vid`` is the
    uid with its last two "_"-separated fields removed when it contains at
    least two underscores, otherwise the uid itself. ``s_min``/``s_sec``/
    ``a_obj`` are fixed filler values (nothing is read from the zip for them).
    """
    return [
        {
            "vid": uid.rsplit("_", 2)[0] if uid.count("_") >= 2 else uid,
            "uid": uid,
            "s_min": 0,
            "s_sec": 0,
            "a_obj": "v2subset",
            "split": "train" if i < TRAIN_N else "test",
            "label": "v2",
        }
        for i, uid in enumerate(v2_uids)
    ]


def _write_metadata(path: str, rows: list[dict]) -> None:
    """Write *rows* to *path* as UTF-8 CSV with _METADATA_FIELDS columns.

    Keys outside _METADATA_FIELDS are dropped; missing keys are written as "".
    """
    # Explicit encoding: the platform default is locale-dependent.
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=_METADATA_FIELDS)
        writer.writeheader()
        writer.writerows({k: r.get(k, "") for k in _METADATA_FIELDS} for r in rows)


def main() -> None:
    """Rebuild the merge-debug subset tree under SUBSET_ROOT from scratch.

    Steps:
      1. Check that FULL_META and ZIP_PATH exist.
      2. Select PER_LABEL v1s and v1m rows, and the first PER_LABEL v2 uids.
      3. Delete and recreate SUBSET_ROOT, copy the v1s/v1m clips and
         extract the v2 clips.
      4. Write the combined SUBSET_ROOT/avss_index/metadata.csv.

    All selection happens before SUBSET_ROOT is deleted, so a shortage of
    clips aborts without destroying the previous subset.

    Raises:
        SystemExit: if an input is missing or too few clips are available.
    """
    if not os.path.isfile(FULL_META):
        raise SystemExit(f"missing full metadata: {FULL_META}")
    if not os.path.isfile(ZIP_PATH):
        raise SystemExit(f"missing v2 zip: {ZIP_PATH}")

    # --- selection (no filesystem writes yet) ---
    v1_rows = {label: _pick_rows_from_full_metadata(label) for label in ("v1s", "v1m")}
    for label, rows in v1_rows.items():
        # Explicit check rather than `assert`, which `python -O` strips.
        if len(rows) != PER_LABEL:
            raise SystemExit(
                f"{label}: picked {len(rows)} rows, expected PER_LABEL={PER_LABEL}"
            )
    with zipfile.ZipFile(ZIP_PATH, "r") as z:
        uids_all = _list_v2_uids(z)
    if len(uids_all) < PER_LABEL:
        raise SystemExit(f"v2.zip has only {len(uids_all)} clips")
    # _list_v2_uids returns sorted ids, so this choice is deterministic.
    v2_uids = uids_all[:PER_LABEL]

    # --- rebuild the tree ---
    # Only a missing directory is tolerated; other deletion failures propagate
    # instead of silently leaving stale clips mixed into the new subset.
    if os.path.exists(SUBSET_ROOT):
        shutil.rmtree(SUBSET_ROOT)
    for sub in ("avss_index", "v1s", "v1m", "v2"):
        os.makedirs(os.path.join(SUBSET_ROOT, sub), exist_ok=True)

    all_rows: list[dict] = []
    for label, rows in v1_rows.items():
        for r in rows:
            _copy_clip(label, str(r["uid"]))
        all_rows.extend(rows)
    _extract_v2_from_zip(v2_uids)
    all_rows.extend(_v2_metadata_rows(v2_uids))

    meta_out = os.path.join(SUBSET_ROOT, "avss_index", "metadata.csv")
    _write_metadata(meta_out, all_rows)

    print("subset data_root:", SUBSET_ROOT)
    print("metadata:", meta_out)
    print("v2 uids:", v2_uids)
    print("rows:", len(all_rows))
# Script entry point: build the subset only when run directly, not on import.
if __name__ == "__main__":
    main()
|