File size: 5,201 Bytes
b6c1b75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
"""Generate or refresh data/metadata.csv entries with provenance details."""

from __future__ import annotations

import argparse
import csv
import hashlib
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, Optional

import yaml


DEFAULT_EXTENSIONS = {".wav", ".mp3", ".m4a"}


@dataclass
class MetadataRow:
    """A single provenance record destined for data/metadata.csv."""

    path: Path      # absolute path to the audio file on disk
    device: str     # recording device, taken from the first path component
    source: str     # where the recording came from (may be empty if unknown)
    license: str    # license string (may be empty if unknown)
    split: str      # dataset split, e.g. "train"
    sha256: str     # hex SHA-256 digest of the file contents

    def as_dict(self, root: Path) -> Dict[str, str]:
        """Serialize for csv.DictWriter, rewriting `path` relative to *root*."""
        return dict(
            path=self.path.relative_to(root).as_posix(),
            device=self.device,
            source=self.source,
            license=self.license,
            split=self.split,
            sha256=self.sha256,
        )


def parse_args() -> argparse.Namespace:
    """Parse command-line options; all of them are optional overrides."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--config",
        default="configs/base.yaml",
        help="YAML config that defines data root and defaults.",
    )
    parser.add_argument(
        "--output",
        help="Override output metadata CSV path. Defaults to the config value.",
    )
    parser.add_argument(
        "--extensions",
        nargs="*",
        help="File extensions to include (e.g., .wav .mp3 .m4a). Defaults to built-ins.",
    )
    return parser.parse_args()


def load_config(path: Path) -> dict:
    """Load the YAML config at *path* and require it to define a `data` section.

    Raises SystemExit when the file is missing or the `data` key is absent.
    An empty YAML document is treated as an empty mapping.
    """
    if not path.exists():
        raise SystemExit(f"Config not found: {path}")
    with path.open("r", encoding="utf-8") as fh:
        loaded = yaml.safe_load(fh)
    cfg = loaded if loaded else {}
    if "data" not in cfg:
        raise SystemExit("Config is missing a `data` section.")
    return cfg


def read_existing_metadata(path: Path) -> Dict[str, dict]:
    """Return prior CSV rows keyed by their `path` column; {} if no file exists.

    Rows without a `path` column are skipped; duplicate paths keep the last row.
    """
    if not path.exists():
        return {}
    rows: Dict[str, dict] = {}
    with path.open("r", encoding="utf-8", newline="") as fh:
        for record in csv.DictReader(fh):
            if "path" in record:
                rows[record["path"]] = record
    return rows


def compute_sha256(path: Path) -> str:
    """Return the hex SHA-256 digest of *path*, streamed in 8 KiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block = stream.read(8192)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()


def gather_files(root: Path, extensions: Iterable[str]) -> Iterable[Path]:
    """Yield every file under *root* whose suffix matches *extensions*.

    *extensions* may be any iterable of suffixes (e.g. ".wav"); it is
    normalized to a lowercase set once up front, so one-shot iterables
    (generators) are safe and each membership test is O(1).  Suffix
    matching is case-insensitive.  Traversal order follows Path.rglob,
    which is not guaranteed to be sorted — callers sort if they need to.
    """
    # Normalize once: the original tested `suffix in extensions` directly,
    # which is O(n) for list inputs and consumes (breaks on) generators.
    wanted = {ext.lower() for ext in extensions}
    for file_path in root.rglob("*"):
        if file_path.is_file() and file_path.suffix.lower() in wanted:
            yield file_path


def build_rows(
    files: Iterable[Path],
    existing_rows: Dict[str, dict],
    root: Path,
    device_defaults: Optional[dict],
    include_devices: Optional[set[str]],
) -> Iterable[MetadataRow]:
    """Yield one MetadataRow per file, merging prior CSV data over defaults.

    The device name is the first path component relative to *root*; files
    whose device is excluded by *include_devices* (when non-empty) are
    skipped.  Values already present in *existing_rows* win over the
    per-device defaults; `split` falls back to "train" (device defaults
    are not consulted for it).  Warns on stderr when source or license
    cannot be resolved — those fields are then written as empty strings.
    """
    fallback = device_defaults or {}
    for file_path in files:
        rel = file_path.relative_to(root)
        if not rel.parts:
            continue
        device = rel.parts[0]
        if include_devices and device not in include_devices:
            continue

        rel_key = rel.as_posix()
        prior = existing_rows.get(rel_key, {})
        per_device = fallback.get(device, {})

        source = prior.get("source") or per_device.get("source")
        license_value = prior.get("license") or per_device.get("license")
        split = prior.get("split") or "train"

        if not (source and license_value):
            sys.stderr.write(f"[warn] Missing source/license for {rel_key}; fill these in manually.\n")

        yield MetadataRow(
            path=file_path,
            device=device,
            source=source or "",
            license=license_value or "",
            split=split,
            sha256=compute_sha256(file_path),
        )


def main() -> None:
    """Scan the data root, merge with any existing metadata, and rewrite the CSV.

    CLI arguments override config values; config values override built-in
    defaults.  After writing, entries that were in the old CSV but no longer
    correspond to a file on disk are reported on stderr as orphaned.
    """
    args = parse_args()
    config = load_config(Path(args.config))

    data_section = config["data"]
    root = Path(data_section.get("root", "data")).resolve()
    metadata_path = Path(args.output or data_section.get("metadata", root / "metadata.csv")).resolve()
    wanted_exts = {ext.lower() for ext in (args.extensions or data_section.get("extensions", DEFAULT_EXTENSIONS))}

    if not root.exists():
        raise SystemExit(f"Data root does not exist: {root}")

    previous = read_existing_metadata(metadata_path)
    defaults_by_device = data_section.get("device_defaults", {})
    device_filter = set(data_section.get("include_devices", []) or [])

    candidates = sorted(gather_files(root, wanted_exts))
    # An empty filter set means "no filtering", which build_rows expects as None.
    rows = sorted(
        build_rows(candidates, previous, root, defaults_by_device, device_filter or None),
        key=lambda row: row.path.relative_to(root).as_posix(),
    )

    metadata_path.parent.mkdir(parents=True, exist_ok=True)
    with metadata_path.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["path", "device", "source", "license", "split", "sha256"])
        writer.writeheader()
        writer.writerows(row.as_dict(root) for row in rows)

    current_keys = {row.path.relative_to(root).as_posix() for row in rows}
    orphaned = sorted(set(previous) - current_keys)
    if orphaned:
        sys.stderr.write(f"[warn] Orphaned metadata entries (files missing): {len(orphaned)}\n")
        for item in orphaned:
            sys.stderr.write(f"  - {item}\n")

    print(f"Wrote {len(rows)} rows to {metadata_path}")


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()