lrec2026-llm-annotator / io_utils.py
dhuser's picture
Initial LREC LLM-as-Annotator app
a918698
raw
history blame
8.15 kB
"""Readers + writers: TSV (sandbox corpus + round-trip), JSON, CoNLL-U, JSONL.
Also: tokenization helpers and the alignment validator that enforces the
"don't add / remove / split / merge / reorder tokens" rule from the system
prompt.
"""
from __future__ import annotations
import csv
import io
import json
import unicodedata
from pathlib import Path
from typing import Iterable
from schemas import AnnotationSchema
# ---------------------------------------------------------------------------
# Tokenization
# ---------------------------------------------------------------------------
def tokenize(text: str, strategy: str = "whitespace") -> list[str]:
    """Split *text* into tokens according to *strategy*.

    Strategies:
      - ``"whitespace"``: split on any run of whitespace.
      - ``"newline"``: one token per non-blank line; lines are returned
        verbatim (surrounding spaces are not stripped).
      - ``"as_is"``: treat the text as already tokenized. If the stripped
        text spans several lines, each non-blank line is a token (the
        preferred one-token-per-line layout); otherwise fall back to
        whitespace splitting.

    Raises:
        ValueError: if *strategy* is not one of the names above.
    """
    def _nonblank_lines(src: str) -> list[str]:
        # Keep a line only if it contains something besides whitespace.
        return [ln for ln in src.splitlines() if ln.strip()]

    if strategy == "whitespace":
        return text.split()
    if strategy == "newline":
        return _nonblank_lines(text)
    if strategy != "as_is":
        raise ValueError(f"Unknown tokenize strategy: {strategy}")
    is_multiline = "\n" in text.strip()
    return _nonblank_lines(text) if is_multiline else text.split()
def validate_alignment(input_tokens: list[str], output_tokens: list[str]) -> list[str]:
    """Strictly compare two token sequences position by position.

    Kept for callers that want hard equality: no Unicode normalization,
    case folding or punctuation tolerance is applied.

    Returns:
        A list of human-readable error strings; empty when both sequences
        are identical. A length mismatch is reported on its own, without any
        per-token comparison.
    """
    n_in, n_out = len(input_tokens), len(output_tokens)
    if n_in != n_out:
        return [f"Token count mismatch: input has {n_in}, output has {n_out}."]
    return [
        f"Token {idx}: input={src!r} but output surface={got!r}"
        for idx, (src, got) in enumerate(zip(input_tokens, output_tokens))
        if src != got
    ]
# Characters ignored at the right edge of a token when align_or_warn()
# decides whether two surfaces differ only by trailing punctuation. Used as
# an rstrip() character set, so order and duplicates do not matter.
_PUNCT_TRAILING = ",.;:·•—–-«»\"'!?)]"


def align_or_warn(input_tokens: list[str], output_tokens: list[str]) -> tuple[str, list[str]]:
    """Lenient alignment. Returns (status, messages).

    status ∈ { 'ok', 'length_mismatch', 'drift' }.
    - 'ok' → every surface is identical after NFC normalization (pure
      Unicode-form differences are tolerated silently).
    - 'length_mismatch' → token count differs (caller should reject).
    - 'drift' → same length but some surfaces differ beyond Unicode form
      (case / trailing punctuation / other). Caller should accept
      positionally and force input surfaces; messages list the offenders.

    Per-token classification, first match wins:
    - equal after NFC: silent;
    - equal after ``str.casefold()``: "case differs";
    - equal after stripping ``_PUNCT_TRAILING`` from the right:
      "trailing-punct differs";
    - otherwise: "surface differs".
    """
    if len(input_tokens) != len(output_tokens):
        return 'length_mismatch', [
            f"Token count mismatch: input has {len(input_tokens)}, output has {len(output_tokens)}."
        ]
    messages: list[str] = []
    for i, (a, b) in enumerate(zip(input_tokens, output_tokens)):
        if a == b:
            continue
        a_n = unicodedata.normalize('NFC', a)
        b_n = unicodedata.normalize('NFC', b)
        if a_n == b_n:
            continue  # silent: pure unicode-form drift
        # casefold() (not lower()) implements full Unicode caseless matching,
        # e.g. 'Straße' vs 'STRASSE', which lower() reports as different.
        if a_n.casefold() == b_n.casefold():
            messages.append(f"token {i}: case differs ({a!r} vs {b!r})")
            continue
        if a_n.rstrip(_PUNCT_TRAILING) == b_n.rstrip(_PUNCT_TRAILING):
            messages.append(f"token {i}: trailing-punct differs ({a!r} vs {b!r})")
            continue
        messages.append(f"token {i}: surface differs ({a!r} vs {b!r})")
    return ('drift' if messages else 'ok'), messages
# ---------------------------------------------------------------------------
# Sandbox TSV reader (matches EACL2026-historical-languages format)
# ---------------------------------------------------------------------------
def read_sandbox_tsv(path: Path, max_rows: int = -1) -> list[dict]:
    """Read a text_form/form/lemma/pos TSV file into a list of row dicts.

    The first line is the header; every following line becomes a dict keyed
    by the header names (``csv.DictReader`` semantics).

    Args:
        path: TSV file to read (UTF-8).
        max_rows: if > 0, stop after this many data rows; 0 or a negative
            value reads the whole file.

    Returns:
        The row dicts, in file order.
    """
    rows: list[dict] = []
    # newline="" is what the csv module expects for files it parses.
    with open(path, encoding="utf-8", newline="") as f:
        # QUOTE_NONE: corpus tokens can be a bare `"`. With the default
        # quoting a lone quote opens a quoted field and swallows the following
        # tabs/lines. It also matches export_tsv(), which writes cells unquoted.
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            rows.append(row)
            if 0 < max_rows <= len(rows):
                break
    return rows
def sandbox_sentence(rows: list[dict], start: int, n_tokens: int) -> tuple[list[str], list[dict]]:
    """Take up to *n_tokens* consecutive rows beginning at index *start*.

    Returns:
        ``(surface_tokens, gold_token_dicts)``. ``surface_tokens`` are the
        rows' ``text_form`` values. Each gold dict is
        ``{"surface": text_form, "lemma": lemma, "pos": pos}``. A window
        that runs past the end of *rows* is simply shorter.
    """
    selected = rows[start:start + n_tokens]
    surfaces = [row["text_form"] for row in selected]
    gold = [
        {"surface": surface, "lemma": row["lemma"], "pos": row["pos"]}
        for surface, row in zip(surfaces, selected)
    ]
    return surfaces, gold
# ---------------------------------------------------------------------------
# Exports
# ---------------------------------------------------------------------------
def _tsv_cell(value: object) -> str:
    """Render one value as a single TSV cell.

    ``None`` becomes ``""``. Dicts and lists are JSON-encoded so structured
    values stay machine-readable. Tabs and line breaks are replaced by a
    space so a value can never shift columns or split a row.
    """
    if value is None:
        return ""
    if isinstance(value, (dict, list)):
        text = json.dumps(value, ensure_ascii=False)
    else:
        text = str(value)
    return text.replace("\t", " ").replace("\r", " ").replace("\n", " ")


def export_tsv(annotation: dict, schema: AnnotationSchema) -> str:
    """Round-trip-friendly TSV.

    If the schema has a ``lemma`` field and a ``pos`` (or ``upos``) field,
    the columns are ``text_form, form, lemma, pos``. ``form`` is the
    lower-cased surface, and an empty or ``None`` lemma/pos becomes ``""``.
    Otherwise the columns are ``text_form`` followed by each top-level schema
    field in order. Dict/list values are JSON-encoded and missing values
    become ``""``.

    Cells are written unquoted, one token per line, with a header row. The
    output always ends with a newline.
    """
    tokens = annotation.get("tokens", [])
    field_names = [f.name for f in schema.fields]
    has_lemma = "lemma" in field_names
    pos_key = next((k for k in ("pos", "upos") if k in field_names), None)

    if has_lemma and pos_key:
        rows = [["text_form", "form", "lemma", "pos"]]
        for t in tokens:
            # `or ""` also guards a None surface, which used to crash the join.
            surf = _tsv_cell(t.get("surface") or "")
            rows.append([
                surf,
                surf.lower(),
                _tsv_cell(t.get("lemma") or ""),
                _tsv_cell(t.get(pos_key) or ""),
            ])
    else:
        # generic
        rows = [["text_form"] + field_names]
        for t in tokens:
            rows.append(
                [_tsv_cell(t.get("surface") or "")]
                + [_tsv_cell(t.get(fname)) for fname in field_names]
            )
    return "\n".join("\t".join(r) for r in rows) + "\n"
def export_json(annotation: dict) -> str:
    """Serialize *annotation* as pretty-printed JSON.

    Uses a two-space indent and keeps non-ASCII characters literal
    (``ensure_ascii=False``) so historical-script text stays readable.
    """
    return json.dumps(annotation, indent=2, ensure_ascii=False)
def _one_line(text: str) -> str:
    """Replace tabs and line breaks with spaces so *text* stays within one CoNLL-U line/column."""
    return text.replace("\t", " ").replace("\r", " ").replace("\n", " ")


def _conllu_cell(value: object) -> str:
    """Render a value as one CoNLL-U column: falsy values become ``_``."""
    return _one_line(str(value)) if value else "_"


def _conllu_feats(feats: object) -> str:
    """Format a feature mapping as a CoNLL-U FEATS value.

    Features with an empty value are dropped. Pairs are sorted by feature
    name (case-insensitively), and list/tuple/set values become a sorted,
    comma-separated value list (e.g. ``PronType=Int,Rel``), as CoNLL-U
    requires. Returns ``_`` when there is nothing to emit or *feats* is not
    a dict.
    """
    if not isinstance(feats, dict):
        return "_"
    pairs: list[tuple[str, str]] = []
    for name, value in feats.items():
        if isinstance(value, (list, tuple, set, frozenset)):
            value = ",".join(sorted((str(v) for v in value if v), key=str.lower))
        if not value:
            continue
        pairs.append((_one_line(str(name)), _one_line(str(value))))
    if not pairs:
        return "_"
    pairs.sort(key=lambda kv: kv[0].lower())
    return "|".join(f"{name}={value}" for name, value in pairs)


def export_conllu(annotation: dict, schema: AnnotationSchema) -> str:
    """Emit CoNLL-U: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.

    FORM comes from each token's ``surface``. LEMMA comes from the ``lemma``
    field, UPOS from ``upos`` (or ``pos``), and FEATS from ``features`` (or
    ``feats``), when the schema has them. Unknown or empty fields use `_`;
    XPOS, HEAD, DEPREL, DEPS and MISC are always `_`. Tabs and line breaks
    inside values are turned into spaces so they cannot break the 10-column
    layout.

    Multi-sentence input is not supported (one sentence per annotation); the
    caller can concatenate. The output ends with the blank line that
    terminates a CoNLL-U sentence.
    """
    field_names = [f.name for f in schema.fields]
    lemma_key = "lemma" if "lemma" in field_names else None
    upos_key = next((k for k in ("upos", "pos") if k in field_names), None)
    feats_key = next((k for k in ("features", "feats") if k in field_names), None)
    tokens = annotation.get("tokens", [])

    lines = [f"# sent_id = {_one_line(str(annotation.get('sentence_id', 's1')))}"]
    lang = annotation.get("language", "")
    if lang:
        lines.append(f"# language = {_one_line(str(lang))}")
    # `or ""` tolerates a None surface, which used to crash the join.
    text = " ".join(str(t.get("surface") or "") for t in tokens)
    lines.append(f"# text = {_one_line(text)}")

    for i, t in enumerate(tokens, 1):
        form = _conllu_cell(t.get("surface"))
        lemma = _conllu_cell(t.get(lemma_key)) if lemma_key else "_"
        upos = _conllu_cell(t.get(upos_key)) if upos_key else "_"
        feats = _conllu_feats(t.get(feats_key)) if feats_key else "_"
        lines.append("\t".join([str(i), form, lemma, upos, "_", feats, "_", "_", "_", "_"]))
    return "\n".join(lines) + "\n\n"
def export_jsonl_finetune(
    annotation: dict,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """Serialize one training example as a single JSONL line.

    The record has the chat fine-tuning shape ``{"messages": [...]}``, with
    a system turn, a user turn and an assistant turn. The assistant turn's
    content is *annotation* itself, encoded as a JSON string. Non-ASCII text
    is kept literal (``ensure_ascii=False``), and the line ends with a
    newline.
    """
    target = json.dumps(annotation, ensure_ascii=False)
    turns = (
        ("system", system_prompt),
        ("user", user_prompt),
        ("assistant", target),
    )
    record = {"messages": [{"role": role, "content": content} for role, content in turns]}
    return json.dumps(record, ensure_ascii=False) + "\n"
def write_temp(content: str, suffix: str) -> str:
    """Write *content* to a new temporary file and return its path.

    Used by gr.File downloads. The file is created with
    ``tempfile.mkstemp`` (prefix ``lrec_annot_``, the given *suffix*) and is
    not deleted here.

    ``newline=""`` disables newline translation, so the bytes on disk are
    exactly *content* encoded as UTF-8. TSV/CoNLL-U exports therefore keep
    LF line endings on every platform; plain text mode would write CRLF on
    Windows.
    """
    import tempfile

    fd, name = tempfile.mkstemp(suffix=suffix, prefix="lrec_annot_")
    # open() takes ownership of fd and closes it when the with-block exits.
    with open(fd, "w", encoding="utf-8", newline="") as f:
        f.write(content)
    return name