"""Readers + writers: TSV (sandbox corpus + round-trip), JSON, CoNLL-U, JSONL.

Also: tokenization helpers and the alignment validator that enforces the
"don't add / remove / split / merge / reorder tokens" rule from the system
prompt.
"""
from __future__ import annotations

import csv
import io
import json
import unicodedata
from pathlib import Path
from typing import Iterable

from schemas import AnnotationSchema


# ---------------------------------------------------------------------------
# Tokenization
# ---------------------------------------------------------------------------

def tokenize(text: str, strategy: str = "whitespace") -> list[str]:
    """Split ``text`` into surface tokens.

    Strategies:
      - ``"whitespace"``: ``str.split()`` on any run of whitespace.
      - ``"newline"``: one token per non-blank line.
      - ``"as_is"``: the text is already tokenized. If it spans several lines,
        one token per non-blank line (preferred); otherwise whitespace split.

    Line-based strategies strip surrounding whitespace from each line so stray
    spaces / tabs in pasted input do not end up inside a token surface.

    Raises:
        ValueError: if ``strategy`` is not one of the names above.
    """
    if strategy == "whitespace":
        return text.split()
    if strategy == "newline":
        return _nonblank_lines(text)
    if strategy == "as_is":
        # treat the text as already-tokenized, one token per line (preferred)
        if "\n" in text.strip():
            return _nonblank_lines(text)
        return text.split()
    raise ValueError(f"Unknown tokenize strategy: {strategy}")


def _nonblank_lines(text: str) -> list[str]:
    """Return the whitespace-stripped, non-empty lines of ``text``."""
    return [s for s in (line.strip() for line in text.splitlines()) if s]


def validate_alignment(input_tokens: list[str], output_tokens: list[str]) -> list[str]:
    """Strict alignment check (kept for callers that want hard equality).

    Returns a list of human-readable error strings; an empty list means the
    two token lists are identical. A length mismatch short-circuits with a
    single error and no per-token comparison.
    """
    errors = []
    if len(input_tokens) != len(output_tokens):
        errors.append(
            f"Token count mismatch: input has {len(input_tokens)}, output has {len(output_tokens)}."
        )
        return errors
    for i, (a, b) in enumerate(zip(input_tokens, output_tokens)):
        if a != b:
            errors.append(f"Token {i}: input={a!r} but output surface={b!r}")
    return errors


# Characters ignored at the end of a token when classifying surface drift.
_PUNCT_TRAILING = ",.;:·•—–-«»\"'!?)]"


def align_or_warn(input_tokens: list[str], output_tokens: list[str]) -> tuple[str, list[str]]:
    """Lenient alignment. Returns (status, messages).

    status ∈ { 'ok', 'length_mismatch', 'drift' }.
      - 'ok'              → identical surfaces (after NFC normalization).
      - 'length_mismatch' → token count differs (caller should reject).
      - 'drift'           → same length but surfaces differ (case / unicode
                            form / punctuation). Caller should accept
                            positionally and force input surfaces; messages
                            list the offenders.
    Differences that vanish under NFC normalization are not reported.
    """
    if len(input_tokens) != len(output_tokens):
        return 'length_mismatch', [
            f"Token count mismatch: input has {len(input_tokens)}, output has {len(output_tokens)}."
        ]
    messages: list[str] = []
    for i, (a, b) in enumerate(zip(input_tokens, output_tokens)):
        if a == b:
            continue
        a_n = unicodedata.normalize('NFC', a)
        b_n = unicodedata.normalize('NFC', b)
        if a_n == b_n:
            continue  # silent: pure unicode-form drift
        if a_n.lower() == b_n.lower():
            messages.append(f"token {i}: case differs ({a!r} vs {b!r})")
            continue
        if a_n.rstrip(_PUNCT_TRAILING) == b_n.rstrip(_PUNCT_TRAILING):
            messages.append(f"token {i}: trailing-punct differs ({a!r} vs {b!r})")
            continue
        messages.append(f"token {i}: surface differs ({a!r} vs {b!r})")
    return ('drift' if messages else 'ok'), messages


# ---------------------------------------------------------------------------
# Sandbox TSV reader (matches EACL2026-historical-languages format)
# ---------------------------------------------------------------------------

def read_sandbox_tsv(path: Path, max_rows: int = -1) -> list[dict]:
    """Read text_form/form/lemma/pos TSV. Returns list of row dicts.

    The first line is the header; each following line becomes a dict keyed by
    header names. ``max_rows > 0`` caps the number of rows read; any other
    value reads the whole file.

    The file is parsed as plain TSV (no quote handling): a token that is itself
    a double quote must not open a quoted field and swallow the rows after it.
    This also matches ``export_tsv``, which writes cells unquoted.
    """
    rows: list[dict] = []
    # newline="" is what the csv module expects for file objects.
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            rows.append(row)
            if 0 < max_rows <= len(rows):
                break
    return rows


def sandbox_sentence(rows: list[dict], start: int, n_tokens: int) -> tuple[list[str], list[dict]]:
    """Slice a window of `n_tokens` from sandbox rows.

    Returns (surface_tokens, gold_token_dicts) where each gold dict has
    {"surface": text_form, "lemma": lemma, "pos": pos}. Raises KeyError if a
    row lacks any of those columns.
    """
    window = rows[start : start + n_tokens]
    surfaces = [r["text_form"] for r in window]
    gold = [{"surface": r["text_form"], "lemma": r["lemma"], "pos": r["pos"]} for r in window]
    return surfaces, gold


# ---------------------------------------------------------------------------
# Exports
# ---------------------------------------------------------------------------

# Tabs and line breaks cannot appear inside a TSV or CoNLL-U cell.
_CELL_BREAKS = str.maketrans({"\t": " ", "\r": " ", "\n": " "})


def _cell(value: object) -> str:
    """Render ``value`` as a single-line TSV / CoNLL-U cell.

    ``None`` becomes ``""``; dicts and lists are JSON-encoded (rather than a
    Python repr); anything else goes through ``str``. Tabs and line breaks are
    replaced by spaces so the value cannot split its row or column.
    """
    if value is None:
        return ""
    if isinstance(value, (dict, list)):
        value = json.dumps(value, ensure_ascii=False)
    return str(value).translate(_CELL_BREAKS)


def export_tsv(annotation: dict, schema: AnnotationSchema) -> str:
    """Round-trip-friendly TSV.

    Columns: text_form, form, lemma, pos when the schema has a ``lemma`` field
    and a ``pos``/``upos`` field (``form`` is the lowercased surface);
    otherwise text_form + each top-level schema field in order. Missing or
    ``None`` values become empty cells; dict / list values are JSON-encoded.
    """
    tokens = annotation.get("tokens", [])
    field_names = [f.name for f in schema.fields]
    pos_key = next((k for k in ("pos", "upos") if k in field_names), None)

    if "lemma" in field_names and pos_key:
        rows = [["text_form", "form", "lemma", "pos"]]
        for t in tokens:
            surf = _cell(t.get("surface"))
            rows.append([
                surf,
                surf.lower(),
                _cell(t.get("lemma") or ""),
                _cell(t.get(pos_key) or ""),
            ])
    else:
        # generic
        rows = [["text_form"] + field_names]
        for t in tokens:
            rows.append([_cell(t.get("surface"))] + [_cell(t.get(name)) for name in field_names])
    return "\n".join("\t".join(r) for r in rows) + "\n"


def export_json(annotation: dict) -> str:
    """Pretty-print the annotation as UTF-8-friendly JSON (non-ASCII kept)."""
    return json.dumps(annotation, ensure_ascii=False, indent=2)


def _conllu_value(value: object) -> str:
    """Render one CoNLL-U column; falsy values become the placeholder ``_``."""
    return _cell(value) if value else "_"


def _conllu_feats(feats: object) -> str:
    """Render a features dict as a CoNLL-U FEATS cell.

    Pairs with falsy values are dropped; list/tuple values become the
    comma-separated multi-value form (values sorted); pairs are sorted by
    feature name case-insensitively, as the CoNLL-U format requires. Returns
    ``_`` when nothing remains or ``feats`` is not a dict.
    """
    if not isinstance(feats, dict):
        return "_"
    pairs: list[tuple[str, str]] = []
    for name, value in feats.items():
        if isinstance(value, (list, tuple)):
            value = ",".join(sorted((str(v) for v in value if v), key=str.lower))
        if value:
            pairs.append((str(name), _cell(value)))
    pairs.sort(key=lambda kv: kv[0].lower())
    return "|".join(f"{k}={v}" for k, v in pairs) or "_"


def export_conllu(annotation: dict, schema: AnnotationSchema) -> str:
    """Emit CoNLL-U: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.

    LEMMA comes from a ``lemma`` field, UPOS from ``upos`` (else ``pos``), and
    FEATS from a ``features`` dict when the schema has those fields; every
    other column, and any empty value, is ``_``. Multi-sentence input not
    supported (one sentence per annotation); the caller can concatenate.
    """
    field_names = [f.name for f in schema.fields]
    lemma_key = "lemma" if "lemma" in field_names else None
    upos_key = next((k for k in ("upos", "pos") if k in field_names), None)
    feats_key = "features" if "features" in field_names else None
    tokens = annotation.get("tokens", [])

    lines = [f"# sent_id = {annotation.get('sentence_id', 's1')}"]
    lang = annotation.get("language", "")
    if lang:
        lines.append(f"# language = {lang}")
    lines.append(f"# text = {' '.join(_cell(t.get('surface')) for t in tokens)}")

    for i, t in enumerate(tokens, 1):
        form = _conllu_value(t.get("surface"))
        lemma = _conllu_value(t.get(lemma_key)) if lemma_key else "_"
        upos = _conllu_value(t.get(upos_key)) if upos_key else "_"
        feats = _conllu_feats(t.get(feats_key)) if feats_key else "_"
        lines.append("\t".join([str(i), form, lemma, upos, "_", feats, "_", "_", "_", "_"]))
    # Trailing blank line terminates the sentence block.
    return "\n".join(lines) + "\n\n"


def export_jsonl_finetune(
    annotation: dict,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """One-line JSON shaped for chat-style fine-tuning APIs.

    The assistant turn carries the annotation serialized as a JSON string.
    """
    record = {
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": json.dumps(annotation, ensure_ascii=False)},
        ]
    }
    return json.dumps(record, ensure_ascii=False) + "\n"


def write_temp(content: str, suffix: str) -> str:
    """Write to a temp file and return path (used by gr.File downloads).

    The file is not deleted automatically. ``newline=""`` writes ``\\n``
    verbatim on every platform (CoNLL-U requires LF line breaks).
    """
    import tempfile
    fd, name = tempfile.mkstemp(suffix=suffix, prefix="lrec_annot_")
    with open(fd, "w", encoding="utf-8", newline="") as f:
        f.write(content)
    return name