lterriel committed on
Commit
834036c
·
1 Parent(s): cdf702a

fix ICL pool rules for version and example counter

Browse files
Files changed (4) hide show
  1. app.py +62 -29
  2. prompts.py +25 -1
  3. provider.py +17 -4
  4. static/app.js +109 -49
app.py CHANGED
@@ -5,6 +5,7 @@ file exposes a small REST API and a tiny in-memory session store. State is
5
  ephemeral and per-process; perfect for a single-user demo or HF Space.
6
  """
7
  from __future__ import annotations
 
8
 
9
  import asyncio
10
  import os
@@ -362,22 +363,68 @@ def reset_all():
362
 
363
 
364
  # --- token edit ------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
  @app.post("/api/sentence/{idx}/token/{tidx}")
367
  def update_token(idx: int, tidx: int, req: TokenUpdateReq):
368
  sents = SESSION["sentences"]
 
369
  if idx < 0 or idx >= len(sents):
370
  raise HTTPException(404, "Bad sentence idx")
371
  if tidx < 0 or tidx >= len(sents[idx]["tokens"]):
372
  raise HTTPException(404, "Bad token idx")
373
- # Preserve surface (never editable)
374
- surface = sents[idx]["tokens"][tidx]["surface"]
 
 
 
375
  new_tok = {**req.token, "surface": surface}
376
- sents[idx]["tokens"][tidx] = new_tok
377
- # Remove this token from disagreement list if it was there
378
- sents[idx]["disagreements"] = [d for d in sents[idx]["disagreements"] if d["token_idx"] != tidx]
379
- sents[idx]["n_disagreements"] = len(sents[idx]["disagreements"])
380
- return sents[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
 
383
  @app.post("/api/bulk_similar")
@@ -449,30 +496,16 @@ def bulk_update(idx: int, payload: dict):
449
 
450
 
451
  # --- ICL pool --------------------------------------------------------------
452
-
453
  @app.post("/api/sentence/{idx}/add_to_icl")
454
  def add_sentence_to_icl(idx: int):
455
- sents = SESSION["sentences"]
456
- if idx < 0 or idx >= len(sents):
457
- raise HTTPException(404, "Bad sentence idx")
458
- sent = sents[idx]
459
- schema_obj = schema_from_dict(SESSION["schema"])
460
- pool: ICLPool = SESSION["icl_pool"]
461
- ann = {
462
- "sentence_id": sent["id"],
463
- "language": sent["language"] or SESSION["language"],
464
- "tokens": sent["tokens"],
465
- }
466
- pool.add(ICLExample(
467
- language=sent["language"] or SESSION["language"] or "",
468
- schema_hash=schema_obj.hash(),
469
- tokens=[t["surface"] for t in sent["tokens"]],
470
- gold_annotation=ann,
471
- source="corrected",
472
- ))
473
- # Adding to ICL implies the user accepts this annotation as gold → mark validated.
474
- sent["validated"] = True
475
- return _public_state()
476
 
477
 
478
  @app.post("/api/sentence/{idx}/sent_score")
 
5
  ephemeral and per-process; perfect for a single-user demo or HF Space.
6
  """
7
  from __future__ import annotations
8
+ from copy import deepcopy
9
 
10
  import asyncio
11
  import os
 
363
 
364
 
365
  # --- token edit ------------------------------------------------------------
366
+ def _add_or_update_sentence_in_icl(idx: int) -> str:
367
+ sents = SESSION["sentences"]
368
+ if idx < 0 or idx >= len(sents):
369
+ raise HTTPException(404, "Bad sentence idx")
370
+ sent = sents[idx]
371
+ schema_obj = schema_from_dict(SESSION["schema"])
372
+ pool: ICLPool = SESSION["icl_pool"]
373
+ tokens_snapshot = deepcopy(sent["tokens"])
374
+ ann = {
375
+ "sentence_id": sent["id"],
376
+ "language": sent["language"] or SESSION["language"],
377
+ "tokens": tokens_snapshot,
378
+ }
379
+
380
+ result = pool.add(ICLExample(
381
+ language=sent["language"] or SESSION["language"] or "",
382
+ schema_hash=schema_obj.hash(),
383
+ tokens=[t["surface"] for t in tokens_snapshot],
384
+ gold_annotation=ann,
385
+ source="corrected",
386
+ ))
387
+
388
+ sent["validated"] = True
389
+ return result
390
+
391
 
392
  @app.post("/api/sentence/{idx}/token/{tidx}")
393
  def update_token(idx: int, tidx: int, req: TokenUpdateReq):
394
  sents = SESSION["sentences"]
395
+
396
  if idx < 0 or idx >= len(sents):
397
  raise HTTPException(404, "Bad sentence idx")
398
  if tidx < 0 or tidx >= len(sents[idx]["tokens"]):
399
  raise HTTPException(404, "Bad token idx")
400
+
401
+ sent = sents[idx]
402
+ was_validated = bool(sent.get("validated"))
403
+
404
+ surface = sent["tokens"][tidx]["surface"]
405
  new_tok = {**req.token, "surface": surface}
406
+ sent["tokens"][tidx] = new_tok
407
+
408
+ sent["disagreements"] = [
409
+ d for d in sent["disagreements"]
410
+ if d["token_idx"] != tidx
411
+ ]
412
+ sent["n_disagreements"] = len(sent["disagreements"])
413
+
414
+ icl_result = None
415
+
416
+ # Only sentences that were already validated before this edit are re-synced: pool.add() replaces the matching ICL entry, or inserts one if none exists, so the pool tracks user corrections.
417
+ if was_validated:
418
+ icl_result = _add_or_update_sentence_in_icl(idx)
419
+
420
+ state = _public_state()
421
+ state["updated_sentence_idx"] = idx
422
+ state["icl_add_result"] = icl_result
423
+ state["icl_duplicate"] = icl_result == "unchanged"
424
+ state["icl_updated"] = icl_result == "updated"
425
+ state["icl_inserted"] = icl_result == "inserted"
426
+
427
+ return state
428
 
429
 
430
  @app.post("/api/bulk_similar")
 
496
 
497
 
498
  # --- ICL pool --------------------------------------------------------------
 
499
  @app.post("/api/sentence/{idx}/add_to_icl")
500
  def add_sentence_to_icl(idx: int):
501
+ result = _add_or_update_sentence_in_icl(idx)
502
+
503
+ state = _public_state()
504
+ state["icl_add_result"] = result
505
+ state["icl_duplicate"] = result == "unchanged"
506
+ state["icl_updated"] = result == "updated"
507
+ state["icl_inserted"] = result == "inserted"
508
+ return state
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
 
511
  @app.post("/api/sentence/{idx}/sent_score")
prompts.py CHANGED
@@ -5,6 +5,7 @@ written material. ICLPool keeps a session-scoped, filterable bank of validated
5
  or corrected examples.
6
  """
7
  from __future__ import annotations
 
8
 
9
  import json
10
  import random
@@ -30,6 +31,8 @@ class ICLExample:
30
  note: str = ""
31
 
32
 
 
 
33
  @dataclass
34
  class ICLPool:
35
  """Session-scoped pool of in-context examples.
@@ -40,9 +43,30 @@ class ICLPool:
40
  entries: list[ICLExample] = field(default_factory=list)
41
  version: int = 0
42
 
43
- def add(self, ex: ICLExample) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  self.entries.append(ex)
45
  self.version += 1
 
46
 
47
  def filter(self, language: str = "", schema_hash: str = "") -> list[ICLExample]:
48
  out = self.entries
 
5
  or corrected examples.
6
  """
7
  from __future__ import annotations
8
+ from copy import deepcopy
9
 
10
  import json
11
  import random
 
31
  note: str = ""
32
 
33
 
34
+
35
+
36
  @dataclass
37
  class ICLPool:
38
  """Session-scoped pool of in-context examples.
 
43
  entries: list[ICLExample] = field(default_factory=list)
44
  version: int = 0
45
 
46
+ def _key(self, ex: ICLExample) -> tuple[str, str, tuple[str, ...]]:
47
+ return (
48
+ ex.language or "",
49
+ ex.schema_hash or "",
50
+ tuple(ex.tokens or []),
51
+ )
52
+
53
+ def _same_content(self, a: ICLExample, b: ICLExample) -> bool:
54
+ return a.gold_annotation == b.gold_annotation
55
+
56
+ def add(self, ex: ICLExample) -> str:
57
+ ex = deepcopy(ex)
58
+ key = self._key(ex)
59
+
60
+ for i, existing in enumerate(self.entries):
61
+ if self._key(existing) == key:
62
+ if self._same_content(existing, ex):
63
+ return "unchanged"
64
+ self.entries[i] = ex
65
+ self.version += 1
66
+ return "updated"
67
  self.entries.append(ex)
68
  self.version += 1
69
+ return "inserted"
70
 
71
  def filter(self, language: str = "", schema_hash: str = "") -> list[ICLExample]:
72
  out = self.entries
provider.py CHANGED
@@ -130,6 +130,7 @@ class LLMClient:
130
  timeout: float = DEFAULT_TIMEOUT,
131
  ) -> ModelResult:
132
  """Call one model, validate JSON. One retry on schema-validation failure."""
 
133
  json_schema = to_json_schema(schema)
134
  start = time.time()
135
  msgs = [{"role": "system", "content": system}, {"role": "user", "content": user}]
@@ -149,13 +150,17 @@ class LLMClient:
149
  raw_text = await self._call(client, msgs, json_schema, model, temperature)
150
  ann, err = self._parse_and_validate(raw_text, schema)
151
  if err:
 
 
152
  return ModelResult(model=model, ok=False, annotation=None, latency_s=time.time() - start, error=err,
153
  raw=raw_text)
 
154
  return ModelResult(model=model, ok=True, annotation=ann, latency_s=time.time() - start, raw=raw_text)
155
  finally:
156
  if close_after:
157
  await client.aclose()
158
  except Exception as e:
 
159
  return ModelResult(model=model, ok=False, annotation=None, latency_s=time.time() - start, error=str(e))
160
 
161
  async def annotate_many(
@@ -180,17 +185,25 @@ class LLMClient:
180
  temperature: float) -> str:
181
  # Strict json_schema works on OpenAI and most OpenRouter models. For Mistral and
182
  # for some open-source models routed via OpenRouter, fall back to json_object.
183
- if self.provider == "mistral":
184
  payload = {
185
- "model": model, "messages": msgs, "temperature": temperature,
 
 
186
  "response_format": {"type": "json_object"},
187
  }
188
  else:
189
  payload = {
190
- "model": model, "messages": msgs, "temperature": temperature,
 
 
191
  "response_format": {
192
  "type": "json_schema",
193
- "json_schema": {"name": "annotation", "strict": True, "schema": json_schema},
 
 
 
 
194
  },
195
  }
196
  resp = await client.post(self.endpoint, headers=self.headers, json=payload)
 
130
  timeout: float = DEFAULT_TIMEOUT,
131
  ) -> ModelResult:
132
  """Call one model, validate JSON. One retry on schema-validation failure."""
133
+ print(f"[LLM] start provider={self.provider} model={model}")
134
  json_schema = to_json_schema(schema)
135
  start = time.time()
136
  msgs = [{"role": "system", "content": system}, {"role": "user", "content": user}]
 
150
  raw_text = await self._call(client, msgs, json_schema, model, temperature)
151
  ann, err = self._parse_and_validate(raw_text, schema)
152
  if err:
153
+ print(  # validation failed without raising: log `err`; no exception `e` is bound on this path
154
+ f"[LLM] error provider={self.provider} model={model} latency={time.time() - start:.2f}s error={err}")
155
  return ModelResult(model=model, ok=False, annotation=None, latency_s=time.time() - start, error=err,
156
  raw=raw_text)
157
+ print(f"[LLM] done provider={self.provider} model={model} latency={time.time() - start:.2f}s")
158
  return ModelResult(model=model, ok=True, annotation=ann, latency_s=time.time() - start, raw=raw_text)
159
  finally:
160
  if close_after:
161
  await client.aclose()
162
  except Exception as e:
163
+ print(f"[LLM] error provider={self.provider} model={model} latency={time.time() - start:.2f}s error={e}")
164
  return ModelResult(model=model, ok=False, annotation=None, latency_s=time.time() - start, error=str(e))
165
 
166
  async def annotate_many(
 
185
  temperature: float) -> str:
186
  # Strict json_schema works on OpenAI and most OpenRouter models. For Mistral and
187
  # for some open-source models routed via OpenRouter, fall back to json_object.
188
+ if self.provider in {"mistral", "ilaas"}:
189
  payload = {
190
+ "model": model,
191
+ "messages": msgs,
192
+ "temperature": temperature,
193
  "response_format": {"type": "json_object"},
194
  }
195
  else:
196
  payload = {
197
+ "model": model,
198
+ "messages": msgs,
199
+ "temperature": temperature,
200
  "response_format": {
201
  "type": "json_schema",
202
+ "json_schema": {
203
+ "name": "annotation",
204
+ "strict": True,
205
+ "schema": json_schema,
206
+ },
207
  },
208
  }
209
  resp = await client.post(self.endpoint, headers=self.headers, json=payload)
static/app.js CHANGED
@@ -769,8 +769,31 @@ function annotator() {
769
 
770
  async addSentenceToIcl(sidx) {
771
  const r = await fetch(`/api/sentence/${sidx}/add_to_icl`, {method: 'POST'});
772
- this.applyState(await r.json());
773
- this.toast(`Added to ICL pool (v${this.state.icl_pool.version}, ${this.state.icl_pool.size} entries).`, 'ok');
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  },
775
 
776
  async setValidated(sidx, value) {
@@ -889,60 +912,97 @@ function annotator() {
889
  },
890
 
891
  async saveToken() {
892
- const sidx = this.editor.sidx, tidx = this.editor.tidx;
893
- const surface = this.editor.tok.surface;
894
- const changes = this.fieldChanges();
895
- const wantPropagate = this.editor.propagateToSimilar && Object.keys(changes).length > 0 && this.matchingTokenCount() > 0;
896
-
897
- this.editor.tok._corrected = true;
898
- const r = await fetch(`/api/sentence/${sidx}/token/${tidx}`, {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
899
  method: 'POST',
900
  headers: {'Content-Type': 'application/json'},
901
- body: JSON.stringify({token: this.editor.tok})
 
 
 
 
902
  });
903
- if (!r.ok) {
904
- this.toast('Save failed.', 'error');
905
- return;
906
- }
907
- const sent = await r.json();
908
- this.replaceSentence(sidx, sent);
909
 
910
- let propagatedCount = 0;
911
- if (wantPropagate) {
912
- try {
913
- const r2 = await fetch('/api/bulk_similar', {
914
- method: 'POST',
915
- headers: {'Content-Type': 'application/json'},
916
- body: JSON.stringify({
917
- surface,
918
- updates: changes,
919
- exclude: [{s: sidx, t: tidx}],
920
- }),
921
- });
922
- if (r2.ok) {
923
- const j = await r2.json();
924
- for (const item of (j.sentences || [])) {
925
- this.replaceSentence(item.idx, item.sentence);
926
- }
927
- propagatedCount = (j.affected || []).length;
928
- }
929
- } catch (e) {
930
- this.toast('Propagation failed: ' + e.message, 'error');
931
- }
932
- }
933
 
934
- // auto-advance
935
- if (this.editor.autoAdvance) {
936
- const next = this.findNextDisagreement(sidx, tidx);
937
- if (next) {
938
- this.openTokenEditor(next.s, next.t);
939
- if (propagatedCount > 0) this.toast(`✓ Saved + propagated to ${propagatedCount} other "${surface}".`, 'ok');
940
- return;
941
  }
 
 
942
  }
943
- this.closeModal();
944
- this.toast(propagatedCount > 0 ? `✓ Saved + propagated to ${propagatedCount} other "${surface}".` : '✓ Saved.', 'ok');
945
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
946
 
947
  findNextDisagreement(sidx, tidx) {
948
  const sents = this.state.sentences;
 
769
 
770
  async addSentenceToIcl(sidx) {
771
  const r = await fetch(`/api/sentence/${sidx}/add_to_icl`, {method: 'POST'});
772
+
773
+ if (!r.ok) {
774
+ this.toast('Could not add to ICL pool.', 'error');
775
+ return;
776
+ }
777
+
778
+ const data = await r.json();
779
+ this.applyState(data);
780
+
781
+ if (data.icl_add_result === 'unchanged') {
782
+ this.toast(
783
+ `Already in ICL pool — unchanged (v${this.state.icl_pool.version}, ${this.state.icl_pool.size} entries).`,
784
+ 'warn'
785
+ );
786
+ } else if (data.icl_add_result === 'updated') {
787
+ this.toast(
788
+ `Updated existing ICL example after correction (v${this.state.icl_pool.version}, ${this.state.icl_pool.size} entries).`,
789
+ 'ok'
790
+ );
791
+ } else {
792
+ this.toast(
793
+ `Added to ICL pool (v${this.state.icl_pool.version}, ${this.state.icl_pool.size} entries).`,
794
+ 'ok'
795
+ );
796
+ }
797
  },
798
 
799
  async setValidated(sidx, value) {
 
912
  },
913
 
914
  async saveToken() {
915
+ const sidx = this.editor.sidx;
916
+ const tidx = this.editor.tidx;
917
+ const surface = this.editor.tok.surface;
918
+ const changes = this.fieldChanges();
919
+ const wantPropagate =
920
+ this.editor.propagateToSimilar &&
921
+ Object.keys(changes).length > 0 &&
922
+ this.matchingTokenCount() > 0;
923
+
924
+ this.editor.tok._corrected = true;
925
+
926
+ const r = await fetch(`/api/sentence/${sidx}/token/${tidx}`, {
927
+ method: 'POST',
928
+ headers: {'Content-Type': 'application/json'},
929
+ body: JSON.stringify({token: this.editor.tok})
930
+ });
931
+
932
+ if (!r.ok) {
933
+ this.toast('Save failed.', 'error');
934
+ return;
935
+ }
936
+
937
+ // returns full state to ensure consistency
938
+ const data = await r.json();
939
+ this.applyState(data);
940
+
941
+ let propagatedCount = 0;
942
+
943
+ if (wantPropagate) {
944
+ try {
945
+ const r2 = await fetch('/api/bulk_similar', {
946
  method: 'POST',
947
  headers: {'Content-Type': 'application/json'},
948
+ body: JSON.stringify({
949
+ surface,
950
+ updates: changes,
951
+ exclude: [{s: sidx, t: tidx}],
952
+ }),
953
  });
 
 
 
 
 
 
954
 
955
+ if (r2.ok) {
956
+ const j = await r2.json();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
957
 
958
+ for (const item of (j.sentences || [])) {
959
+ this.replaceSentence(item.idx, item.sentence);
 
 
 
 
 
960
  }
961
+
962
+ propagatedCount = (j.affected || []).length;
963
  }
964
+ } catch (e) {
965
+ this.toast('Propagation failed: ' + e.message, 'error');
966
+ }
967
+ }
968
+
969
+ let iclMsg = '';
970
+
971
+ if (data.icl_add_result === 'updated') {
972
+ iclMsg = ` + updated ICL v${this.state.icl_pool.version}`;
973
+ } else if (data.icl_add_result === 'inserted') {
974
+ iclMsg = ` + added to ICL v${this.state.icl_pool.version}`;
975
+ } else if (data.icl_add_result === 'unchanged') {
976
+ iclMsg = ` + ICL unchanged`;
977
+ }
978
+
979
+ // auto-advance
980
+ if (this.editor.autoAdvance) {
981
+ const next = this.findNextDisagreement(sidx, tidx);
982
+
983
+ if (next) {
984
+ this.openTokenEditor(next.s, next.t);
985
+
986
+ this.toast(
987
+ propagatedCount > 0
988
+ ? `✓ Saved + propagated to ${propagatedCount} other "${surface}"${iclMsg}.`
989
+ : `✓ Saved${iclMsg}.`,
990
+ 'ok'
991
+ );
992
+
993
+ return;
994
+ }
995
+ }
996
+
997
+ this.closeModal();
998
+
999
+ this.toast(
1000
+ propagatedCount > 0
1001
+ ? `✓ Saved + propagated to ${propagatedCount} other "${surface}"${iclMsg}.`
1002
+ : `✓ Saved${iclMsg}.`,
1003
+ 'ok'
1004
+ );
1005
+ },
1006
 
1007
  findNextDisagreement(sidx, tidx) {
1008
  const sents = this.state.sentences;