DerivedFunction1 committed on
Commit
3b3f566
·
1 Parent(s): a42debc

add cache

Browse files
Files changed (1) hide show
  1. fleurs_cache.py +323 -0
fleurs_cache.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import random
5
+ import unicodedata
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import pandas as pd
11
+ from huggingface_hub import HfApi, hf_hub_download
12
+
13
+ from language import ALL_LANGS, LANG_ISO2_TO_ISO3, canonical_lang
14
+
15
+
16
# Hugging Face dataset repository that holds the FLEURS TSV metadata.
FLEURS_DATASET = "google/fleurs"
# Local cache layout: data/fleurs/{fleurs_text_only.parquet, downloads/...}.
FLEURS_CACHE_DIR = Path(__file__).with_name("data") / "fleurs"
FLEURS_PARQUET_PATH = FLEURS_CACHE_DIR / "fleurs_text_only.parquet"
FLEURS_DOWNLOAD_DIR = FLEURS_CACHE_DIR / "downloads"
# Column order of the upstream TSV files (no header row guaranteed).
FLEURS_TSV_COLUMNS = [
    "id",
    "file_name",
    "source_sentence",
    "transcription",
    "tokens",
    "num_samples",
    "gender",
]
# Preference order when deduplicating: train beats validation beats test.
FLEURS_SPLIT_ORDER = {"train": 0, "validation": 1, "test": 2}
# The only columns persisted to the lean parquet cache.
FLEURS_LEAN_COLUMNS = ["id", "text", "source_lang", "model_lang", "split"]
31
+
32
+
33
def _normalize_model_lang(source_lang: str) -> str:
    """Reduce a FLEURS locale (e.g. ``am_et``) to its model language code."""
    prefix, _, _ = source_lang.partition("_")
    return canonical_lang(prefix.strip().lower())
37
+
38
+
39
def _discover_tsv_files() -> list[str]:
    """List every FLEURS TSV metadata file, preferring the on-disk cache."""
    cached = sorted((FLEURS_DOWNLOAD_DIR / "data").rglob("*.tsv"))
    if cached:
        return [str(p.relative_to(FLEURS_DOWNLOAD_DIR)) for p in cached]

    api = HfApi()
    try:
        repo_files = api.list_repo_files(repo_id=FLEURS_DATASET, repo_type="dataset")
    except TypeError:
        # Older hub clients take the repo id positionally.
        repo_files = api.list_repo_files(FLEURS_DATASET, repo_type="dataset")

    tsv_files = sorted(
        name
        for name in repo_files
        if name.startswith("data/") and name.endswith(".tsv")
    )
    if not tsv_files:
        raise RuntimeError("Could not find any FLEURS TSV metadata files.")
    return tsv_files
60
+
61
+
62
+ def _normalize_split_name(file_name: str) -> str:
63
+ stem = Path(file_name).stem.lower()
64
+ if stem == "dev":
65
+ return "validation"
66
+ return stem
67
+
68
+
69
+ def _normalize_text_key(text: str) -> str:
70
+ """Normalize text for deduping while keeping the original text intact."""
71
+ normalized = unicodedata.normalize("NFKC", text)
72
+ normalized = " ".join(normalized.split())
73
+ return normalized.casefold().strip()
74
+
75
+
76
def _download_tsv(file_path: str) -> Path:
    """Fetch one TSV from the hub unless it is already on disk."""
    on_disk = FLEURS_DOWNLOAD_DIR / file_path
    if on_disk.exists():
        return on_disk

    FLEURS_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
    common = {
        "repo_id": FLEURS_DATASET,
        "repo_type": "dataset",
        "filename": file_path,
    }
    try:
        fetched = hf_hub_download(local_dir=str(FLEURS_DOWNLOAD_DIR), **common)
    except TypeError:
        # Older hub clients predate `local_dir`; fall back to `cache_dir`.
        fetched = hf_hub_download(cache_dir=str(FLEURS_DOWNLOAD_DIR), **common)
    return Path(fetched)
97
+
98
+
99
def _frame_from_tsv(tsv_path: Path, source_lang: str) -> pd.DataFrame:
    """Parse one FLEURS TSV into a dataframe of text rows.

    Args:
        tsv_path: Local path to a ``train``/``dev``/``test`` TSV file.
        source_lang: FLEURS locale (the directory name), e.g. ``am_et``.

    Returns:
        One row per entry with a non-empty text; an empty dataframe when the
        file holds no usable rows.
    """
    records: list[dict[str, Any]] = []
    header_seen = False
    header_markers = {name.lower() for name in FLEURS_TSV_COLUMNS}

    with tsv_path.open("r", encoding="utf-8", newline="") as handle:
        for line in handle:
            # newline="" disables newline translation, so CRLF files keep a
            # trailing "\r"; strip both terminators or the last column ends
            # up with a stray carriage return.
            line = line.rstrip("\r\n")
            if not line.strip():
                continue

            # maxsplit keeps tabs inside the final column intact and also
            # caps the part count at len(FLEURS_TSV_COLUMNS), so no overflow
            # handling is needed after the split.
            parts = line.split("\t", len(FLEURS_TSV_COLUMNS) - 1)
            if not header_seen:
                header_candidate = [part.strip().lower() for part in parts]
                if header_markers.issubset(set(header_candidate)):
                    header_seen = True
                    continue
                header_seen = True

            if len(parts) < len(FLEURS_TSV_COLUMNS):
                # Pad short rows so every record carries all columns.
                parts.extend([""] * (len(FLEURS_TSV_COLUMNS) - len(parts)))

            records.append(dict(zip(FLEURS_TSV_COLUMNS, parts, strict=True)))

    if not records:
        return pd.DataFrame()

    frame = pd.DataFrame.from_records(records).fillna("")
    for column in ("source_sentence", "transcription", "tokens"):
        frame[column] = frame[column].astype(str).str.strip()

    # Prefer the normalized transcription; fall back to the raw sentence.
    frame["text"] = frame["transcription"].where(frame["transcription"].ne(""), frame["source_sentence"])
    frame["raw_text"] = frame["source_sentence"].where(frame["source_sentence"].ne(""), frame["transcription"])
    frame["source"] = "fleurs"
    frame["source_lang"] = source_lang
    frame["model_lang"] = _normalize_model_lang(source_lang)
    frame["split"] = _normalize_split_name(tsv_path.name)
    frame["lang_iso3"] = frame["model_lang"].map(lambda lang: LANG_ISO2_TO_ISO3.get(lang, ""))
    frame["language_name"] = source_lang
    # Blank texts become NA so they can be dropped in one pass.
    frame["text"] = frame["text"].astype(str).str.strip().replace("", pd.NA)
    frame["raw_text"] = frame["raw_text"].astype(str).str.strip()
    return frame[frame["text"].notna()].reset_index(drop=True)
147
+
148
+
149
def _post_process_fleurs_frame(frame: pd.DataFrame) -> pd.DataFrame:
    """Deduplicate per-language texts and reduce to the lean demo columns."""
    if frame.empty:
        return frame

    work = frame.copy()
    # Rank splits so train rows win over validation/test when texts collide.
    work["split_rank"] = work["split"].map(lambda split: FLEURS_SPLIT_ORDER.get(str(split), 99))
    work["text_key"] = work["text"].astype(str).map(_normalize_text_key)
    work["id_sort"] = pd.to_numeric(work["id"], errors="coerce").fillna(10**18)

    work = work[work["text_key"].ne("")]
    work = work.sort_values(
        by=["source_lang", "text_key", "split_rank", "id_sort"],
        kind="stable",
    )
    work = work.drop_duplicates(subset=["source_lang", "text_key"], keep="first")

    lean = work.loc[:, [col for col in FLEURS_LEAN_COLUMNS if col in work.columns]].copy()
    lean["text"] = work["text"].astype(str).values
    lean["source_lang"] = work["source_lang"].astype(str).values
    lean["model_lang"] = work["model_lang"].astype(str).values
    lean["split"] = work["split"].astype(str).values
    lean["id"] = pd.to_numeric(work["id"], errors="coerce").fillna(-1).astype(int).values
    return lean[lean["text"].astype(str).str.strip().ne("")].reset_index(drop=True)
173
+
174
+
175
def build_fleurs_text_parquet(
    parquet_path: str | Path = FLEURS_PARQUET_PATH,
) -> Path:
    """Download FLEURS TSV metadata and persist a text-only parquet cache."""
    out_path = Path(parquet_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    frames: list[pd.DataFrame] = []
    for repo_path in _discover_tsv_files():
        # The locale is the parent directory name, e.g. data/am_et/train.tsv.
        locale = Path(repo_path).parent.name
        frame = _frame_from_tsv(_download_tsv(repo_path), locale)
        if not frame.empty:
            frames.append(frame)

    if not frames:
        raise RuntimeError("No rows were loaded from the FLEURS TSV metadata files.")

    combined = pd.concat(frames, ignore_index=True)
    raw_count = len(combined)
    combined = _post_process_fleurs_frame(combined)
    combined.to_parquet(out_path, index=False)
    print(
        f"Built lean FLEURS parquet with {len(combined):,} rows "
        f"from {raw_count:,} raw rows and {len(combined.columns)} columns."
    )
    return out_path
202
+
203
+
204
@lru_cache(maxsize=1)
def load_fleurs_table(parquet_path: str | Path = FLEURS_PARQUET_PATH) -> pd.DataFrame:
    """Load the cached FLEURS text-only parquet into memory (memoized)."""
    path = Path(parquet_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Missing FLEURS cache at {path}. "
            "Run `./.venv/bin/python fleurs_cache.py` once while online to build it."
        )

    table = pd.read_parquet(path)
    if "text" not in table.columns:
        raise RuntimeError("FLEURS parquet cache is missing the text column.")
    return table
218
+
219
+
220
+ def _pick_random_rows(frame: pd.DataFrame, *, count: int) -> pd.DataFrame:
221
+ if frame.empty:
222
+ raise RuntimeError("FLEURS cache has no rows.")
223
+ sample_size = min(count, len(frame))
224
+ return frame.sample(n=sample_size)
225
+
226
+
227
def _row_to_sentence(row: pd.Series) -> dict[str, Any]:
    """Convert one cached FLEURS row into the sentence payload dict."""
    source_lang = str(row.get("source_lang", "")).strip()
    model_lang = str(row.get("model_lang", "")).strip()
    # Derive the ISO2 code from the locale only when model_lang is missing.
    lang_iso2 = model_lang if model_lang else _normalize_model_lang(source_lang)
    text = str(row.get("text", "")).strip()

    raw_id = str(row.get("id", "-1")).strip()
    fleurs_id = int(row.get("id", -1)) if raw_id.lstrip("-").isdigit() else -1

    return {
        "text": text,
        "raw_text": text,
        "source": "fleurs",
        "source_lang": source_lang,
        "model_lang": model_lang or lang_iso2,
        "lang_iso2": lang_iso2,
        "lang_iso3": LANG_ISO2_TO_ISO3.get(lang_iso2, ""),
        "language": str(row.get("language_name", source_lang)).strip(),
        "split": str(row.get("split", "")).strip(),
        "fleurs_id": fleurs_id,
    }
245
+
246
+
247
def fetch_random_fleurs_sentence(
    *,
    attempts: int = 8,
    parquet_path: str | Path = FLEURS_PARQUET_PATH,
) -> dict[str, Any]:
    """Fetch one random text sample from the cached FLEURS parquet."""
    table = load_fleurs_table(parquet_path)
    candidates = table[table["text"].astype(str).str.strip().ne("")]

    # Prefer model-supported languages whenever any such rows exist.
    supported = candidates[candidates["model_lang"].isin(ALL_LANGS)]
    if not supported.empty:
        candidates = supported

    for _ in range(max(1, attempts)):
        picked = _pick_random_rows(candidates, count=1).iloc[0]
        sentence = _row_to_sentence(picked)
        if sentence["text"]:
            return sentence

    raise RuntimeError("Unable to sample a random FLEURS sentence.")
267
+
268
+
269
def fetch_random_fleurs_sentence_mix(
    *,
    min_sentences: int = 2,
    max_sentences: int = 3,
    parquet_path: str | Path = FLEURS_PARQUET_PATH,
) -> dict[str, Any]:
    """Fetch 2-3 random FLEURS sentences from distinct languages and concatenate them."""
    table = load_fleurs_table(parquet_path)
    candidates = table[table["text"].astype(str).str.strip().ne("")]
    supported = candidates[candidates["model_lang"].isin(ALL_LANGS)]
    if not supported.empty:
        candidates = supported

    low = max(1, min_sentences)
    high = max(low, max_sentences)
    wanted = random.randint(low, high)

    languages = [lang for lang in candidates["model_lang"].dropna().unique().tolist() if lang]
    if not languages:
        raise RuntimeError("No usable FLEURS languages were found in the cache.")

    # One sentence per distinct language, capped by availability.
    random.shuffle(languages)
    picked_rows = [
        _pick_random_rows(candidates[candidates["model_lang"] == lang], count=1).iloc[0]
        for lang in languages[: min(wanted, len(languages))]
    ]

    sentences = [_row_to_sentence(row) for row in picked_rows]
    combined_text = "\n\n".join(s["text"] for s in sentences if s["text"])
    return {
        "text": combined_text,
        "sentences": sentences,
        "lang_count": len(sentences),
        "langs": [s["lang_iso2"] for s in sentences],
        "lang_iso3s": [s["lang_iso3"] for s in sentences],
        "source": "fleurs-mix",
    }
308
+
309
+
310
def main() -> None:
    """CLI entry point: build the FLEURS text-only parquet cache."""
    parser = argparse.ArgumentParser(description="Build the cached text-only FLEURS parquet.")
    parser.add_argument(
        "--output",
        default=str(FLEURS_PARQUET_PATH),
        help="Output parquet path for the cached FLEURS text rows.",
    )
    cli_args = parser.parse_args()
    written = build_fleurs_text_parquet(cli_args.output)
    print(f"Wrote FLEURS text cache to {written}")
320
+
321
+
322
# Build the cache when invoked directly as a script.
if __name__ == "__main__":
    main()