| """Readers + writers: TSV (sandbox corpus + round-trip), JSON, CoNLL-U, JSONL. |
| |
| Also: tokenization helpers and the alignment validator that enforces the |
| "don't add / remove / split / merge / reorder tokens" rule from the system |
| prompt. |
| """ |
| from __future__ import annotations |
|
|
| import csv |
| import io |
| import json |
| import unicodedata |
| from pathlib import Path |
| from typing import Iterable |
|
|
| from schemas import AnnotationSchema |
|
|
|
|
| |
| |
| |
|
|
def tokenize(text: str, strategy: str = "whitespace") -> list[str]:
    """Split ``text`` into surface tokens according to ``strategy``.

    Strategies:
      - ``"whitespace"``: split on any run of whitespace.
      - ``"newline"``: one token per non-blank line.
      - ``"as_is"``: one token per non-blank line when the stripped text
        still contains a newline; otherwise split on whitespace.

    Raises:
        ValueError: if ``strategy`` is not one of the names above.
    """

    def non_blank_lines():
        # Lines are kept verbatim (not stripped); only blank ones are dropped.
        return [ln for ln in text.splitlines() if ln.strip()]

    if strategy == "whitespace":
        return text.split()
    if strategy == "newline":
        return non_blank_lines()
    if strategy == "as_is":
        is_multiline = "\n" in text.strip()
        return non_blank_lines() if is_multiline else text.split()
    raise ValueError(f"Unknown tokenize strategy: {strategy}")
|
|
|
|
def validate_alignment(input_tokens: list[str], output_tokens: list[str]) -> list[str]:
    """Check two token sequences for exact, position-by-position equality.

    Returns a list of human-readable error strings; an empty list means the
    sequences are identical. When the lengths differ, only the single
    count-mismatch error is returned and no per-token comparison is made.
    """
    n_in, n_out = len(input_tokens), len(output_tokens)
    if n_in != n_out:
        return [f"Token count mismatch: input has {n_in}, output has {n_out}."]
    return [
        f"Token {idx}: input={src!r} but output surface={dst!r}"
        for idx, (src, dst) in enumerate(zip(input_tokens, output_tokens))
        if src != dst
    ]
|
|
|
|
# Characters stripped from the right end of both surfaces before the
# trailing-punctuation comparison in align_or_warn.
_PUNCT_TRAILING = ",.;:·•—–-«»\"'!?)]"


def align_or_warn(input_tokens: list[str], output_tokens: list[str]) -> tuple[str, list[str]]:
    """Lenient alignment check. Returns ``(status, messages)``.

    status is one of:
      - 'ok'              : every pair is equal, either literally or after
                            NFC normalization; messages is empty.
      - 'length_mismatch' : the token counts differ; messages holds the single
                            count error (caller should reject).
      - 'drift'           : same length, but at least one pair differs beyond
                            NFC form. Each offender gets one message, labelled
                            by the first test it passes: case-only difference,
                            trailing-punctuation-only difference, or a generic
                            surface difference. Caller should accept
                            positionally and force the input surfaces.
    """
    n_in, n_out = len(input_tokens), len(output_tokens)
    if n_in != n_out:
        return 'length_mismatch', [
            f"Token count mismatch: input has {n_in}, output has {n_out}."
        ]

    notes: list[str] = []
    for pos, (src, out) in enumerate(zip(input_tokens, output_tokens)):
        if src == out:
            continue
        src_nfc = unicodedata.normalize('NFC', src)
        out_nfc = unicodedata.normalize('NFC', out)
        if src_nfc == out_nfc:
            # Differs only in Unicode composition form: not reported.
            continue
        # Checks run from most to least specific; the first match labels it.
        if src_nfc.lower() == out_nfc.lower():
            label = "case differs"
        elif src_nfc.rstrip(_PUNCT_TRAILING) == out_nfc.rstrip(_PUNCT_TRAILING):
            label = "trailing-punct differs"
        else:
            label = "surface differs"
        notes.append(f"token {pos}: {label} ({src!r} vs {out!r})")

    status = 'drift' if notes else 'ok'
    return status, notes
|
|
|
|
| |
| |
| |
|
|
def read_sandbox_tsv(path: Path, max_rows: int = -1) -> list[dict]:
    """Read a tab-separated file with a header row into a list of row dicts.

    The sandbox corpus is expected to carry text_form/form/lemma/pos columns,
    but whatever names appear in the header row become the dict keys.

    Args:
        path: Location of the UTF-8 encoded TSV file.
        max_rows: Maximum number of data rows to return. Any value <= 0
            (including the default -1) means "no limit".

    Returns:
        One dict per data row, in file order.
    """
    from itertools import islice

    row_limit = max_rows if max_rows > 0 else None
    # newline="" hands raw line endings to the csv module, as its documentation
    # requires; otherwise universal-newline translation rewrites "\r\n" inside
    # quoted fields.
    # NOTE(review): default csv quoting is still active, so a field that starts
    # with '"' is parsed as a quoted field. Confirm the corpus never contains a
    # bare double-quote token, or switch to csv.QUOTE_NONE.
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f, delimiter="\t")
        return list(islice(reader, row_limit))
|
|
|
|
def sandbox_sentence(rows: list[dict], start: int, n_tokens: int) -> tuple[list[str], list[dict]]:
    """Take ``n_tokens`` consecutive sandbox rows beginning at index ``start``.

    Returns ``(surface_tokens, gold_token_dicts)``: the first is the list of
    ``text_form`` values, the second holds one
    ``{"surface": text_form, "lemma": lemma, "pos": pos}`` dict per row.
    The window is a plain slice, so it is shorter than ``n_tokens`` when it
    runs past the end of ``rows``.
    """

    def to_gold(row: dict) -> dict:
        return {"surface": row["text_form"], "lemma": row["lemma"], "pos": row["pos"]}

    stop = start + n_tokens
    selected = rows[start:stop]
    return [row["text_form"] for row in selected], [to_gold(row) for row in selected]
|
|
|
|
| |
| |
| |
|
|
def export_tsv(annotation: dict, schema: AnnotationSchema) -> str:
    """Render an annotation as round-trip-friendly TSV text.

    Layout:
      - If the schema has a ``lemma`` field and a ``pos`` or ``upos`` field,
        the columns are ``text_form, form, lemma, pos`` where ``form`` is the
        lower-cased surface.
      - Otherwise the columns are ``text_form`` followed by every schema field
        name in schema order.

    Missing or ``None`` values (including a ``None`` surface or a ``None``
    token list) become empty cells instead of raising. In the generic layout,
    dict and list values are serialized as JSON.

    Args:
        annotation: Mapping with a ``"tokens"`` list of per-token dicts.
        schema: Object whose ``fields`` items each expose a ``name``.

    Returns:
        The TSV text: header row plus one row per token, newline-terminated.
    """

    def cell(value) -> str:
        # Every cell must be a str for "\t".join; None means "no value".
        if value is None:
            return ""
        if isinstance(value, (dict, list)):
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    tokens = annotation.get("tokens") or []
    field_names = [f.name for f in schema.fields]
    pos_key = next((k for k in ("pos", "upos") if k in field_names), None)

    if "lemma" in field_names and pos_key:
        rows = [["text_form", "form", "lemma", "pos"]]
        for t in tokens:
            surf = cell(t.get("surface"))
            rows.append([
                surf,
                surf.lower(),
                str(t.get("lemma", "") or ""),
                str(t.get(pos_key, "") or ""),
            ])
    else:
        rows = [["text_form"] + field_names]
        for t in tokens:
            rows.append(
                [cell(t.get("surface"))] + [cell(t.get(name, "")) for name in field_names]
            )

    return "\n".join("\t".join(r) for r in rows) + "\n"
|
|
|
|
def export_json(annotation: dict) -> str:
    """Serialize ``annotation`` as pretty-printed JSON.

    Uses a 2-space indent and keeps non-ASCII characters verbatim rather than
    escaping them.
    """
    dump_options = {"ensure_ascii": False, "indent": 2}
    return json.dumps(annotation, **dump_options)
|
|
|
|
def export_conllu(annotation: dict, schema: AnnotationSchema) -> str:
    """Emit one sentence as CoNLL-U.

    Columns: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.

    Only FORM, LEMMA, UPOS and FEATS are filled; the other columns, and any
    value that is missing or falsy, are written as ``_``. LEMMA comes from the
    ``lemma`` schema field, UPOS from ``upos`` (preferred) or ``pos``, and
    FEATS from a dict-valued ``features`` field. Feature pairs with a falsy
    value are dropped, and the rest are sorted case-insensitively by attribute
    name, as the CoNLL-U format requires.

    Multi-sentence input is not supported (one sentence per annotation); the
    caller can concatenate the results.

    Args:
        annotation: Mapping with ``"tokens"`` and optional ``"sentence_id"``
            (defaults to ``"s1"``) and ``"language"`` entries.
        schema: Object whose ``fields`` items each expose a ``name``.

    Returns:
        The comment lines and token lines, terminated by a blank line.
    """
    field_names = [f.name for f in schema.fields]
    lemma_key = "lemma" if "lemma" in field_names else None
    upos_key = next((k for k in ("upos", "pos") if k in field_names), None)
    feats_key = "features" if "features" in field_names else None

    tokens = annotation.get("tokens") or []
    sid = annotation.get("sentence_id", "s1")
    lang = annotation.get("language", "")

    lines = [f"# sent_id = {sid}"]
    if lang:
        lines.append(f"# language = {lang}")
    # A None surface must not break the join; the FORM column tolerates it too.
    surfaces = [str(t.get("surface") or "") for t in tokens]
    lines.append(f"# text = {' '.join(surfaces)}")

    for i, t in enumerate(tokens, 1):
        form = str(t.get("surface") or "_")
        lemma = (t.get(lemma_key) or "_") if lemma_key else "_"
        upos = (t.get(upos_key) or "_") if upos_key else "_"

        feats = "_"
        raw_feats = t.get(feats_key) if feats_key else None
        if isinstance(raw_feats, dict):
            pairs = sorted(
                ((str(k), v) for k, v in raw_feats.items() if v),
                key=lambda kv: kv[0].lower(),
            )
            if pairs:
                feats = "|".join(f"{k}={v}" for k, v in pairs)

        lines.append(
            "\t".join([str(i), form, str(lemma), str(upos), "_", feats, "_", "_", "_", "_"])
        )
    return "\n".join(lines) + "\n\n"
|
|
|
|
def export_jsonl_finetune(
    annotation: dict,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """Build one JSONL line shaped for chat-style fine-tuning APIs.

    The record has a single ``"messages"`` list with a system, a user and an
    assistant turn, in that order. The assistant content is ``annotation``
    serialized as compact JSON. Non-ASCII characters are kept verbatim, and
    the returned string ends with a newline.
    """
    assistant_reply = json.dumps(annotation, ensure_ascii=False)
    turns = (
        ("system", system_prompt),
        ("user", user_prompt),
        ("assistant", assistant_reply),
    )
    messages = [{"role": role, "content": content} for role, content in turns]
    return json.dumps({"messages": messages}, ensure_ascii=False) + "\n"
|
|
|
|
def write_temp(content: str, suffix: str) -> str:
    """Write ``content`` to a new temporary file and return its path.

    The file is created with ``tempfile.mkstemp`` using the ``lrec_annot_``
    prefix and the given ``suffix``, and is written as UTF-8. It is not
    deleted here; cleanup is left to the caller. (The original note says this
    is used for gr.File downloads.)
    """
    import tempfile

    handle, tmp_path = tempfile.mkstemp(prefix="lrec_annot_", suffix=suffix)
    # Opening the descriptor directly means the with-block also closes it.
    with open(handle, mode="w", encoding="utf-8") as out:
        out.write(content)
    return tmp_path
|
|