File size: 8,147 Bytes
a918698 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 | """Readers + writers: TSV (sandbox corpus + round-trip), JSON, CoNLL-U, JSONL.
Also: tokenization helpers and the alignment validator that enforces the
"don't add / remove / split / merge / reorder tokens" rule from the system
prompt.
"""
from __future__ import annotations
import csv
import io
import json
import unicodedata
from pathlib import Path
from typing import Iterable
from schemas import AnnotationSchema
# ---------------------------------------------------------------------------
# Tokenization
# ---------------------------------------------------------------------------
def tokenize(text: str, strategy: str = "whitespace") -> list[str]:
    """Split ``text`` into surface tokens according to ``strategy``.

    Strategies:
      - ``"whitespace"``: split on any run of whitespace.
      - ``"newline"``: one token per non-blank line.
      - ``"as_is"``: treat the text as already tokenized. If it spans several
        lines (after trimming), emit one token per non-blank line (preferred);
        otherwise fall back to whitespace splitting.

    Line-based tokens are stripped of surrounding whitespace. This matches the
    whitespace fallback and keeps stray indentation or trailing spaces out of
    the surfaces that the alignment checks compare.

    Raises:
        ValueError: if ``strategy`` is not one of the names above.
    """
    if strategy == "whitespace":
        return text.split()
    # "as_is" prefers line-based tokens whenever the text is genuinely multi-line.
    line_based = strategy == "newline" or (strategy == "as_is" and "\n" in text.strip())
    if line_based:
        return [line.strip() for line in text.splitlines() if line.strip()]
    if strategy == "as_is":
        return text.split()
    raise ValueError(f"Unknown tokenize strategy: {strategy}")
def validate_alignment(input_tokens: list[str], output_tokens: list[str]) -> list[str]:
    """Compare token sequences for exact, position-by-position equality.

    This is the strict variant, kept for callers that want hard equality. A
    length difference yields a single count-mismatch error and no per-token
    comparison. Otherwise there is one error per position whose surfaces
    differ. An empty list means the sequences are identical.
    """
    n_in, n_out = len(input_tokens), len(output_tokens)
    if n_in != n_out:
        return [f"Token count mismatch: input has {n_in}, output has {n_out}."]
    return [
        f"Token {idx}: input={expected!r} but output surface={got!r}"
        for idx, (expected, got) in enumerate(zip(input_tokens, output_tokens))
        if expected != got
    ]
# Characters that align_or_warn strips from the right end of both surfaces
# when deciding whether two tokens differ only in trailing punctuation.
_PUNCT_TRAILING = ",.;:·•—–-«»\"'!?)]"


def _drift_message(i: int, a: str, b: str) -> str | None:
    """Describe how output surface ``b`` differs from input surface ``a`` at index ``i``.

    Returns None when the two are equal, either byte-for-byte or after NFC
    normalization (pure Unicode-form drift is deliberately not reported).
    Otherwise the first matching category is reported: case-only difference,
    then trailing-punctuation-only difference, then a generic surface mismatch.
    """
    if a == b:
        return None
    a_nfc, b_nfc = (unicodedata.normalize('NFC', s) for s in (a, b))
    if a_nfc == b_nfc:
        return None
    if a_nfc.lower() == b_nfc.lower():
        return f"token {i}: case differs ({a!r} vs {b!r})"
    if a_nfc.rstrip(_PUNCT_TRAILING) == b_nfc.rstrip(_PUNCT_TRAILING):
        return f"token {i}: trailing-punct differs ({a!r} vs {b!r})"
    return f"token {i}: surface differs ({a!r} vs {b!r})"


def align_or_warn(input_tokens: list[str], output_tokens: list[str]) -> tuple[str, list[str]]:
    """Lenient alignment check. Returns ``(status, messages)``.

    ``status`` is one of:
      - ``'ok'``: every surface matches after NFC normalization; ``messages`` is empty.
      - ``'length_mismatch'``: the token counts differ (caller should reject);
        ``messages`` holds the single count-mismatch explanation.
      - ``'drift'``: same length, but some surfaces differ (case, trailing
        punctuation, or otherwise). Caller should accept positionally and
        force input surfaces. ``messages`` lists each offending position.
    """
    n_in, n_out = len(input_tokens), len(output_tokens)
    if n_in != n_out:
        return 'length_mismatch', [
            f"Token count mismatch: input has {n_in}, output has {n_out}."
        ]
    candidates = (
        _drift_message(i, a, b)
        for i, (a, b) in enumerate(zip(input_tokens, output_tokens))
    )
    messages: list[str] = [m for m in candidates if m is not None]
    status = 'drift' if messages else 'ok'
    return status, messages
# ---------------------------------------------------------------------------
# Sandbox TSV reader (matches EACL2026-historical-languages format)
# ---------------------------------------------------------------------------
def read_sandbox_tsv(path: Path, max_rows: int = -1) -> list[dict]:
    """Read a text_form/form/lemma/pos TSV (with header row) into row dicts.

    Args:
        path: tab-separated file whose first line names the columns.
        max_rows: stop after this many data rows; ``<= 0`` reads everything.

    Returns:
        One dict per data row, keyed by the header names (as produced by
        ``csv.DictReader``).
    """
    rows: list[dict] = []
    # newline="" lets the csv module handle line endings itself, as its docs
    # require. QUOTE_NONE is needed because this is a token-per-row corpus: a
    # bare `"` token (or one starting with `"`) is data, not a CSV quote.
    # Default quoting would swallow following tabs/rows into one field. Cells
    # written by this module's TSV exporter are unquoted as well.
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            rows.append(row)
            if max_rows > 0 and len(rows) >= max_rows:
                break
    return rows
def sandbox_sentence(rows: list[dict], start: int, n_tokens: int) -> tuple[list[str], list[dict]]:
    """Take up to ``n_tokens`` consecutive sandbox rows beginning at index ``start``.

    Returns a pair ``(surface_tokens, gold_token_dicts)``. The first element
    holds each row's ``text_form``. The second holds one
    ``{"surface": text_form, "lemma": lemma, "pos": pos}`` dict per row.
    The window is an ordinary list slice, so it may be shorter than
    ``n_tokens`` near the end of ``rows``.
    """
    # (output key, source column) pairs, in output-key order.
    gold_columns = (("surface", "text_form"), ("lemma", "lemma"), ("pos", "pos"))
    selected = rows[start : start + n_tokens]
    surfaces = [row["text_form"] for row in selected]
    gold = [{dst: row[src] for dst, src in gold_columns} for row in selected]
    return surfaces, gold
# ---------------------------------------------------------------------------
# Exports
# ---------------------------------------------------------------------------
def _tsv_cell(value: object) -> str:
    """Render one TSV cell.

    None becomes ""; dicts and lists become JSON (non-ASCII kept literal);
    anything else becomes ``str(value)``. Tabs and line breaks are replaced by
    spaces so a cell can never split a column or a row.
    """
    if value is None:
        return ""
    if isinstance(value, (dict, list)):
        value = json.dumps(value, ensure_ascii=False)
    return str(value).replace("\t", " ").replace("\r", " ").replace("\n", " ")


def export_tsv(annotation: dict, schema: AnnotationSchema) -> str:
    """Round-trip-friendly, newline-terminated TSV with a header row.

    Columns are text_form, form, lemma, pos when the schema has a ``lemma``
    field and a ``pos`` or ``upos`` field. The pos column is filled from
    ``pos`` if present, else ``upos``, and form is the lowercased surface.
    Otherwise the columns are text_form followed by each top-level schema
    field in order. Cells are written unquoted.
    """
    tokens = annotation.get("tokens", [])
    field_names = [f.name for f in schema.fields]
    has_lemma = "lemma" in field_names
    pos_key = next((k for k in ("pos", "upos") if k in field_names), None)

    if has_lemma and pos_key:
        rows = [["text_form", "form", "lemma", "pos"]]
        for t in tokens:
            # `or ""` blanks falsy values in this layout; it also keeps a None
            # surface from crashing `.lower()`.
            surf = _tsv_cell(t.get("surface") or "")
            rows.append([
                surf,
                surf.lower(),
                _tsv_cell(t.get("lemma") or ""),
                _tsv_cell(t.get(pos_key) or ""),
            ])
    else:
        # generic: one column per schema field, nested values JSON-encoded
        rows = [["text_form"] + field_names]
        for t in tokens:
            row = [_tsv_cell(t.get("surface"))]
            row.extend(_tsv_cell(t.get(fname)) for fname in field_names)
            rows.append(row)
    return "\n".join("\t".join(r) for r in rows) + "\n"
def export_json(annotation: dict) -> str:
    """Serialize ``annotation`` as JSON indented by two spaces, keeping non-ASCII characters literal."""
    return json.dumps(annotation, indent=2, ensure_ascii=False)
# Control whitespace that would break CoNLL-U's tab-separated, one-line-per-token layout.
_CONLLU_BREAKS = str.maketrans({"\t": " ", "\r": " ", "\n": " "})


def _conllu_field(value: object, empty: str = "_") -> str:
    """Render a value for a CoNLL-U column or comment line.

    Falsy values (None, "", 0, empty containers) and values that are blank
    after cleaning become ``empty``. Tabs and line breaks become spaces, and
    surrounding whitespace is stripped.
    """
    if not value:
        return empty
    return str(value).translate(_CONLLU_BREAKS).strip() or empty


def _conllu_feats(feats: object) -> str:
    """Render a features dict as ``Key=Value|...``, sorted case-insensitively by key.

    Entries with falsy values are dropped. A non-dict, or a dict with nothing
    left to emit, yields ``_``.
    """
    if not isinstance(feats, dict):
        return "_"
    pairs = sorted(
        ((str(k), _conllu_field(v, empty="")) for k, v in feats.items() if v),
        # CoNLL-U requires FEATS sorted alphabetically by feature name.
        key=lambda kv: kv[0].lower(),
    )
    return "|".join(f"{k}={v}" for k, v in pairs if v) or "_"


def export_conllu(annotation: dict, schema: AnnotationSchema) -> str:
    """Emit CoNLL-U: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.

    Only FORM (token "surface"), LEMMA (schema field "lemma"), UPOS (schema
    field "upos", else "pos") and FEATS (schema field "features", a dict) are
    filled. Every other column, and any missing or empty value, is ``_``.
    Multi-sentence input is not supported (one sentence per annotation); the
    caller can concatenate. The result ends with the blank line that
    terminates a CoNLL-U sentence.
    """
    field_names = [f.name for f in schema.fields]
    lemma_key = "lemma" if "lemma" in field_names else None
    upos_key = next((k for k in ("upos", "pos") if k in field_names), None)
    feats_key = "features" if "features" in field_names else None
    tokens = annotation.get("tokens", [])

    lines = [f"# sent_id = {annotation.get('sentence_id', 's1')}"]
    lang = annotation.get("language", "")
    if lang:
        lines.append(f"# language = {lang}")
    # Surfaces are cleaned like FORM; a missing/None surface contributes ""
    # instead of crashing the join.
    text = " ".join(_conllu_field(t.get("surface"), empty="") for t in tokens)
    lines.append(f"# text = {text}")
    for i, t in enumerate(tokens, 1):
        form = _conllu_field(t.get("surface"))
        lemma = _conllu_field(t.get(lemma_key)) if lemma_key else "_"
        upos = _conllu_field(t.get(upos_key)) if upos_key else "_"
        feats = _conllu_feats(t.get(feats_key)) if feats_key else "_"
        lines.append("\t".join([str(i), form, lemma, upos, "_", feats, "_", "_", "_", "_"]))
    return "\n".join(lines) + "\n\n"
def export_jsonl_finetune(
    annotation: dict,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """Build one newline-terminated JSONL record for chat-style fine-tuning.

    The record is ``{"messages": [system, user, assistant]}``. Each message is
    a ``{"role", "content"}`` dict, and the assistant content is the
    annotation serialized as a JSON string. Non-ASCII text is kept literal.
    """
    assistant_reply = json.dumps(annotation, ensure_ascii=False)
    turns = (
        ("system", system_prompt),
        ("user", user_prompt),
        ("assistant", assistant_reply),
    )
    record = {"messages": [{"role": role, "content": content} for role, content in turns]}
    return f"{json.dumps(record, ensure_ascii=False)}\n"
def write_temp(content: str, suffix: str) -> str:
    """Write ``content`` to a new UTF-8 temp file and return its path (used by gr.File downloads).

    The file name starts with ``lrec_annot_`` and ends with ``suffix``. The
    file is not deleted here; removing it is left to the caller.
    """
    import tempfile
    fd, name = tempfile.mkstemp(suffix=suffix, prefix="lrec_annot_")
    # newline="" writes "\n" verbatim. Default text mode would translate it to
    # os.linesep (CRLF on Windows), and CoNLL-U requires LF line breaks.
    # open() takes ownership of fd and closes it with the file object.
    with open(fd, "w", encoding="utf-8", newline="") as f:
        f.write(content)
    return name
|