File size: 8,147 Bytes
a918698
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
"""Readers + writers: TSV (sandbox corpus + round-trip), JSON, CoNLL-U, JSONL.

Also: tokenization helpers and the alignment validator that enforces the
"don't add / remove / split / merge / reorder tokens" rule from the system
prompt.
"""
from __future__ import annotations

import csv
import io
import json
import unicodedata
from pathlib import Path
from typing import Iterable

from schemas import AnnotationSchema


# ---------------------------------------------------------------------------
# Tokenization
# ---------------------------------------------------------------------------

def tokenize(text: str, strategy: str = "whitespace") -> list[str]:
    """Split *text* into tokens according to *strategy*.

    Strategies:
      - ``"whitespace"``: split on any run of whitespace.
      - ``"newline"``: one token per non-blank line; lines are kept verbatim
        (surrounding spaces are not stripped).
      - ``"as_is"``: text is treated as already tokenized. If the stripped
        text spans several lines, one token per non-blank line (preferred);
        otherwise fall back to whitespace splitting.

    Raises:
        ValueError: if *strategy* is not one of the names above.
    """
    def _non_blank_lines(raw: str) -> list[str]:
        return [piece for piece in raw.splitlines() if piece.strip()]

    if strategy == "whitespace":
        return text.split()
    if strategy == "newline":
        return _non_blank_lines(text)
    if strategy == "as_is":
        spans_lines = "\n" in text.strip()
        return _non_blank_lines(text) if spans_lines else text.split()
    raise ValueError(f"Unknown tokenize strategy: {strategy}")


def validate_alignment(input_tokens: list[str], output_tokens: list[str]) -> list[str]:
    """Strict alignment check (kept for callers that want hard equality).

    Returns a list of human-readable error strings; an empty list means the
    two token sequences are identical. A length mismatch is reported alone
    (no per-token comparison is attempted). Otherwise there is one message for
    every position whose surfaces differ, in position order.
    """
    n_in, n_out = len(input_tokens), len(output_tokens)
    if n_in != n_out:
        return [f"Token count mismatch: input has {n_in}, output has {n_out}."]
    return [
        f"Token {idx}: input={src!r} but output surface={got!r}"
        for idx, (src, got) in enumerate(zip(input_tokens, output_tokens))
        if src != got
    ]


_PUNCT_TRAILING = ",.;:·•—–-«»\"'!?)]"


def align_or_warn(input_tokens: list[str], output_tokens: list[str]) -> tuple[str, list[str]]:
    """Lenient alignment. Returns (status, messages).

    status ∈ { 'ok', 'length_mismatch', 'drift' }.
    - 'ok'              → identical surfaces (after NFC normalization).
    - 'length_mismatch' → token count differs (caller should reject).
    - 'drift'           → same length but surfaces differ (case / unicode form /
                          punctuation). Caller should accept positionally and
                          force input surfaces; messages list the offenders.

    Each differing position yields at most one message. The first matching
    class wins:
      1. NFC-equal (pure Unicode-form drift) → no message;
      2. case only → "case differs";
      3. trailing punctuation only → "trailing-punct differs";
      4. both case and trailing punctuation → "case and trailing-punct differ";
      5. anything else → "surface differs".
    Case comparison uses ``str.casefold`` rather than ``str.lower`` so that
    caseless-equivalent spellings (e.g. 'ß'/'SS', final 'ς' vs medial 'σ')
    are classified as case drift.
    """
    if len(input_tokens) != len(output_tokens):
        return 'length_mismatch', [
            f"Token count mismatch: input has {len(input_tokens)}, output has {len(output_tokens)}."
        ]
    messages: list[str] = []
    for i, (a, b) in enumerate(zip(input_tokens, output_tokens)):
        if a == b:
            continue
        a_n = unicodedata.normalize('NFC', a)
        b_n = unicodedata.normalize('NFC', b)
        if a_n == b_n:
            continue  # silent: pure unicode-form drift
        a_fold = a_n.casefold()
        b_fold = b_n.casefold()
        if a_fold == b_fold:
            messages.append(f"token {i}: case differs ({a!r} vs {b!r})")
        elif a_n.rstrip(_PUNCT_TRAILING) == b_n.rstrip(_PUNCT_TRAILING):
            messages.append(f"token {i}: trailing-punct differs ({a!r} vs {b!r})")
        elif a_fold.rstrip(_PUNCT_TRAILING) == b_fold.rstrip(_PUNCT_TRAILING):
            # Previously fell through to "surface differs" although it is
            # still just case + punctuation drift.
            messages.append(f"token {i}: case and trailing-punct differ ({a!r} vs {b!r})")
        else:
            messages.append(f"token {i}: surface differs ({a!r} vs {b!r})")
    return ('drift' if messages else 'ok'), messages


# ---------------------------------------------------------------------------
# Sandbox TSV reader (matches EACL2026-historical-languages format)
# ---------------------------------------------------------------------------

def read_sandbox_tsv(path: Path, max_rows: int = -1) -> list[dict]:
    """Read a text_form/form/lemma/pos TSV into a list of row dicts.

    Each dict maps the header's column names to that row's cell strings, as
    produced by ``csv.DictReader``. When ``max_rows`` is positive, reading
    stops after that many rows; zero or a negative value reads the whole file.

    Cells are read verbatim: quote characters are NOT special (``QUOTE_NONE``),
    so a token such as ``"`` survives intact instead of opening a quoted field
    that swallows the rest of the file. This matches ``export_tsv``, which
    never quotes. A leading UTF-8 BOM is tolerated (``utf-8-sig``) so the first
    header is still ``text_form``.
    """
    rows: list[dict] = []
    # newline="" lets the csv module handle line endings itself, as its docs require.
    with open(path, encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            rows.append(row)
            if max_rows > 0 and len(rows) >= max_rows:
                break
    return rows


def sandbox_sentence(rows: list[dict], start: int, n_tokens: int) -> tuple[list[str], list[dict]]:
    """Take the window ``rows[start : start + n_tokens]`` as one pseudo-sentence.

    Returns a pair ``(surface_tokens, gold_token_dicts)``:
      - ``surface_tokens`` holds each row's ``text_form``;
      - each gold dict is ``{"surface": text_form, "lemma": lemma, "pos": pos}``.
    Standard slice semantics apply, so a window running past the end is
    truncated and one starting past the end yields two empty lists.
    """
    end = start + n_tokens
    picked = rows[start:end]
    surface_forms = [row["text_form"] for row in picked]
    gold_tokens: list[dict] = []
    for form, row in zip(surface_forms, picked):
        gold_tokens.append({"surface": form, "lemma": row["lemma"], "pos": row["pos"]})
    return surface_forms, gold_tokens


# ---------------------------------------------------------------------------
# Exports
# ---------------------------------------------------------------------------

def export_tsv(annotation: dict, schema: AnnotationSchema) -> str:
    """Round-trip-friendly TSV.

    Columns: ``text_form, form, lemma, pos`` when the schema has a ``lemma``
    field and a ``pos`` (or ``upos``) field. ``form`` is the lower-cased
    surface, and ``pos`` is filled from ``pos`` when present, else ``upos``.
    Otherwise the columns are ``text_form`` followed by each top-level schema
    field, in schema order.

    Cell rules:
      - missing / ``None`` values become empty cells (in the lemma/pos layout,
        any falsy lemma/pos value does too);
      - dict and list values are JSON-encoded (non-ASCII kept);
      - tab, CR and LF inside a value are replaced by a space, so every row
        keeps the header's column count and stays on one line.
    The output is header + one row per token, each terminated by ``\\n``.
    """
    tokens = annotation.get("tokens", [])
    field_names = [f.name for f in schema.fields]
    has_lemma = "lemma" in field_names
    pos_key = next((k for k in ("pos", "upos") if k in field_names), None)
    # Characters that would break TSV structure if written raw.
    one_line = {ord("\t"): " ", ord("\r"): " ", ord("\n"): " "}

    def cell(value) -> str:
        if value is None:
            return ""
        if isinstance(value, (dict, list)):
            value = json.dumps(value, ensure_ascii=False)
        return str(value).translate(one_line)

    if has_lemma and pos_key:
        rows = [["text_form", "form", "lemma", "pos"]]
        for t in tokens:
            surf = cell(t.get("surface"))
            rows.append([surf, surf.lower(), cell(t.get("lemma") or ""), cell(t.get(pos_key) or "")])
    else:
        # generic
        rows = [["text_form"] + field_names]
        for t in tokens:
            rows.append([cell(t.get("surface"))] + [cell(t.get(fname, "")) for fname in field_names])
    return "\n".join("\t".join(r) for r in rows) + "\n"


def export_json(annotation: dict) -> str:
    """Serialize *annotation* as pretty-printed JSON.

    Two-space indentation; non-ASCII characters are written as-is rather than
    ``\\u`` escaped; key order follows the dict's insertion order.
    """
    encoder = json.JSONEncoder(ensure_ascii=False, indent=2)
    return encoder.encode(annotation)


def export_conllu(annotation: dict, schema: AnnotationSchema) -> str:
    """Emit CoNLL-U: ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC.

    Only FORM, LEMMA, UPOS and FEATS are filled from the annotation. Every
    other column, and any column whose schema field is absent or whose value
    is empty, is written as ``_``. The sources are:
      - LEMMA from a ``lemma`` field;
      - UPOS from ``upos`` (preferred) or ``pos``;
      - FEATS from a dict-valued ``features`` field.
    Features with falsy values are dropped. The rest are written as
    ``Name=Value`` pairs sorted by name, case-insensitively, as CoNLL-U
    requires. List/tuple/set values become a sorted, comma-separated value
    list.

    Tab, CR and LF inside any emitted value are replaced by a space so the
    column/line structure stays valid. Comment lines emitted: ``# sent_id``
    (default ``s1``), ``# language`` (only if set), and ``# text`` (surfaces
    joined by spaces).

    Multi-sentence input not supported (one sentence per annotation); the
    caller can concatenate. Returns the sentence block terminated by a blank
    line.
    """
    field_names = [f.name for f in schema.fields]
    lemma_key = "lemma" if "lemma" in field_names else None
    upos_key = next((k for k in ("upos", "pos") if k in field_names), None)
    feats_key = "features" if "features" in field_names else None
    one_line = {ord("\t"): " ", ord("\r"): " ", ord("\n"): " "}

    def clean(value) -> str:
        return str(value).translate(one_line)

    def format_feats(feats) -> str:
        if not isinstance(feats, dict):
            return "_"
        pairs = []
        for name, value in feats.items():
            if not value:
                continue
            if isinstance(value, (list, tuple, set)):
                value = ",".join(sorted((str(v) for v in value if v), key=str.lower))
                if not value:
                    continue
            pairs.append((clean(name), clean(value)))
        pairs.sort(key=lambda pair: pair[0].lower())
        return "|".join(f"{name}={value}" for name, value in pairs) or "_"

    tokens = annotation.get("tokens", [])
    lines = []
    sid = annotation.get("sentence_id", "s1")
    lang = annotation.get("language", "")
    lines.append(f"# sent_id = {clean(sid)}")
    if lang:
        lines.append(f"# language = {clean(lang)}")
    surfaces = [t.get("surface") or "" for t in tokens]
    lines.append(f"# text = {clean(' '.join(surfaces))}")
    for i, t in enumerate(tokens, 1):
        form = clean(t.get("surface") or "_")
        lemma = clean(t.get(lemma_key) or "_") if lemma_key else "_"
        upos = clean(t.get(upos_key) or "_") if upos_key else "_"
        feats = format_feats(t.get(feats_key)) if feats_key else "_"
        lines.append("\t".join([str(i), form, lemma, upos, "_", feats, "_", "_", "_", "_"]))
    return "\n".join(lines) + "\n\n"


def export_jsonl_finetune(
    annotation: dict,
    system_prompt: str,
    user_prompt: str,
) -> str:
    """Build one JSONL line shaped for chat-style fine-tuning APIs.

    The record is ``{"messages": [system, user, assistant]}``. Each message
    is ``{"role": ..., "content": ...}``, and the assistant content is the
    annotation serialized as compact JSON. Non-ASCII is kept unescaped in both
    layers. The returned string ends with exactly one newline.
    """
    assistant_reply = json.dumps(annotation, ensure_ascii=False)
    turns = (
        ("system", system_prompt),
        ("user", user_prompt),
        ("assistant", assistant_reply),
    )
    record = {"messages": [{"role": role, "content": content} for role, content in turns]}
    return json.dumps(record, ensure_ascii=False) + "\n"


def write_temp(content: str, suffix: str) -> str:
    """Write to a temp file and return path (used by gr.File downloads).

    The file is created via ``tempfile.mkstemp`` with prefix ``lrec_annot_``
    and the given ``suffix``, encoded as UTF-8, and left on disk. The caller
    owns its lifetime.

    Newlines are written untranslated (``newline=""``), so the bytes on disk
    match ``content`` exactly on every platform. Without this, Windows would
    turn ``\\n`` into ``\\r\\n``, and CoNLL-U requires LF line breaks.
    """
    import tempfile
    fd, name = tempfile.mkstemp(suffix=suffix, prefix="lrec_annot_")
    # open() takes ownership of fd and closes it when the with-block exits.
    with open(fd, "w", encoding="utf-8", newline="") as f:
        f.write(content)
    return name