| """MoE aggregation: per-token, per-field consensus across multiple models. |
| |
| Pure function. Empty disagreement list when only one model contributes. |
| """ |
| from __future__ import annotations |
|
|
| from collections import Counter |
| from dataclasses import dataclass, asdict |
| from typing import Any, Optional |
|
|
| from schemas import AnnotationSchema |
|
|
|
|
@dataclass
class DisagreementCell:
    """One (token, field) position where the contributing models disagreed."""

    # Position of the token inside the sentence's token list.
    token_idx: int
    # Field name, or "field.subfield" for a subfield of an object field.
    field_path: str
    # Raw value proposed by each model, keyed by model name.
    values_by_model: dict[str, Any]
    # Value that was selected for the consensus annotation.
    winner: Any
    # Agreement score reported by the aggregator that picked the winner.
    agreement_ratio: float

    def to_dict(self) -> dict:
        """Return the cell as a plain dict built by ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping
|
|
|
|
# Rank of each confidence label; a smaller rank means a less confident label.
CONFIDENCE_ORDER = dict(low=0, medium=1, high=2)
|
|
|
|
| def _vote(values: dict[str, Any], priority: list[str]) -> tuple[Any, float]: |
| """Plurality vote. Ties broken by priority order, then by string sort. |
| |
| Returns (winner, agreement_ratio in [0,1]). |
| """ |
| cleaned = {m: v for m, v in values.items() if v is not None} |
| if not cleaned: |
| return None, 1.0 |
| counts = Counter(cleaned.values()) |
| top_count = max(counts.values()) |
| candidates = [v for v, c in counts.items() if c == top_count] |
| if len(candidates) == 1: |
| return candidates[0], top_count / len(cleaned) |
| |
| for m in priority: |
| if m in cleaned and cleaned[m] in candidates: |
| return cleaned[m], top_count / len(cleaned) |
| return sorted(candidates, key=lambda x: str(x))[0], top_count / len(cleaned) |
|
|
|
|
| def _longest_common_substring(strings: list[str], min_len: int = 3) -> str: |
| """Find the longest substring common to all non-empty strings. |
| |
| Empty if no shared substring of length >= min_len. |
| """ |
| strings = [s for s in strings if s] |
| if not strings: |
| return "" |
| if len(strings) == 1: |
| return strings[0] |
| shortest = min(strings, key=len) |
| best = "" |
| n = len(shortest) |
| for i in range(n): |
| for j in range(i + len(best) + 1, n + 1): |
| candidate = shortest[i:j] |
| if all(candidate in s for s in strings): |
| if len(candidate) > len(best): |
| best = candidate |
| else: |
| break |
| return best if len(best) >= min_len else "" |
|
|
|
|
def _lcs_aggregate(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Aggregate string values through their longest common substring.

    Falsy values (``None``, "") are ignored. Resolution order:

    1. No usable value: ``(None, 1.0)``.
    2. All usable values identical: that value with ratio ``1.0``.
    3. A common substring exists (as decided by ``_longest_common_substring``):
       the substring, with the ratio being the mean fraction of each string
       that the substring covers.
    4. Otherwise the value of the first model in ``priority`` that supplied
       one (or the first usable value), with the ratio being the share of
       models that proposed exactly that value.

    Args:
        values: proposed string per model name.
        priority: model names in fallback order (earlier wins).

    Returns:
        ``(winner, agreement_ratio)``.
    """
    cleaned = {m: v for m, v in values.items() if v}
    if not cleaned:
        return None, 1.0
    strings = list(cleaned.values())

    counts = Counter(strings)
    if len(counts) == 1:
        # Unanimous: every model proposed the same string.
        return strings[0], 1.0

    lcs = _longest_common_substring(strings)
    if lcs:
        ratio = sum(len(lcs) / max(len(s), 1) for s in strings) / len(strings)
        return lcs, ratio

    # Nothing in common: trust the highest-priority model that answered.
    winner = next((cleaned[m] for m in priority if m in cleaned), strings[0])
    # Count every model that proposed the winner; a fixed 1/len(cleaned) would
    # understate agreement when several models share the fallback value.
    return winner, counts[winner] / len(cleaned)
|
|
|
|
def _min_confidence(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Pick the least confident label among the models' answers.

    Falsy values are ignored. Labels are ranked through ``CONFIDENCE_ORDER``;
    a label missing from that table is ranked 0. ``priority`` is accepted for
    signature parity with the other aggregators and is not used.

    Returns:
        ``(label, share_of_models_that_gave_that_label)``, or ``(None, 1.0)``
        when no model supplied a value.
    """
    present = {model: label for model, label in values.items() if label}
    if not present:
        return None, 1.0

    def rank(label: str) -> int:
        return CONFIDENCE_ORDER.get(label, 0)

    labels = present.values()
    lowest = min(labels, key=rank)
    tally = Counter(labels)
    return lowest, tally[lowest] / len(present)
|
|
|
|
| def _priority(values: dict[str, str], priority: list[str]) -> tuple[Any, float]: |
| cleaned = {m: v for m, v in values.items() if v} |
| if not cleaned: |
| return None, 1.0 |
| for m in priority: |
| if m in cleaned: |
| return cleaned[m], 1.0 |
| return next(iter(cleaned.values())), 1.0 |
|
|
|
|
def _aggregate_field(
    field_name: str,
    field_type: str,
    aggregator: str,
    subfields: list,
    values_by_model: dict[str, Any],
    priority: list[str],
    token_idx: int,
    disagreements: list[DisagreementCell],
) -> Any:
    """Resolve one field of one token and record any disagreement.

    Object fields are resolved subfield by subfield with a plurality vote and
    return a dict keyed by subfield name. Other fields are resolved by the
    aggregator named in ``aggregator`` ("lcs", "min", "priority"; anything
    else means plurality vote).

    A ``DisagreementCell`` is appended to ``disagreements`` (mutated in place)
    when more than one model contributed and the ratio is below 1.0, except
    for the "min" and "priority" aggregators, which never record one.

    Returns:
        The consensus value for the field.
    """
    several_models = len(values_by_model) > 1

    if field_type == "object":
        merged = {}
        for sub in subfields:
            key = sub.name
            # Models that did not return a dict contribute no vote.
            proposals = {
                model: value.get(key) if isinstance(value, dict) else None
                for model, value in values_by_model.items()
            }
            chosen, agreement = _vote(proposals, priority)
            merged[key] = chosen
            if agreement < 1.0 and several_models:
                cell = DisagreementCell(
                    token_idx=token_idx,
                    field_path=f"{field_name}.{key}",
                    values_by_model=proposals,
                    winner=chosen,
                    agreement_ratio=agreement,
                )
                disagreements.append(cell)
        return merged

    resolvers = {
        "lcs": _lcs_aggregate,
        "min": _min_confidence,
        "priority": _priority,
    }
    resolve = resolvers.get(aggregator, _vote)
    chosen, agreement = resolve(values_by_model, priority)

    if not (agreement < 1.0) or not several_models:
        return chosen
    if aggregator in {"min", "priority"}:
        # These aggregators are deliberately silent about disagreements.
        return chosen

    cell = DisagreementCell(
        token_idx=token_idx,
        field_path=field_name,
        values_by_model=values_by_model,
        winner=chosen,
        agreement_ratio=agreement,
    )
    disagreements.append(cell)
    return chosen
|
|
|
|
def aggregate(
    per_model: dict[str, dict],
    schema: AnnotationSchema,
    priority: Optional[list[str]] = None,
) -> tuple[dict, list[DisagreementCell]]:
    """Merge the annotations of several models into one consensus annotation.

    Args:
        per_model: ``{model_name -> {"sentence_id": ..., "language": ...,
            "tokens": [...]}}``.
        schema: supplies ``fields``; each field's ``name``, ``type``,
            ``aggregator`` and ``subfields`` drive how it is resolved.
        priority: model names in tie-breaking order. Defaults to the key order
            of ``per_model``.

    Returns:
        ``(consensus_annotation, disagreement_cells)``. The sentence id,
        language and token surfaces are taken from the first model in
        ``per_model``.

    Raises:
        ValueError: if ``per_model`` is empty or the models do not all have
            the same number of tokens.
    """
    if not per_model:
        raise ValueError("per_model is empty")
    if priority is None:
        priority = list(per_model.keys())

    # The first model acts as the reference for token count and metadata.
    reference = next(iter(per_model.values()))
    expected = len(reference.get("tokens", []))
    for model_name, annotation in per_model.items():
        found = len(annotation.get("tokens", []))
        if found != expected:
            raise ValueError(
                f"Token count mismatch in MoE input: {model_name} has {found}, expected {expected}."
            )

    cells: list[DisagreementCell] = []
    merged_tokens = []
    for idx in range(expected):
        merged: dict[str, Any] = {"surface": reference["tokens"][idx].get("surface", "")}
        for spec in schema.fields:
            proposals = {
                model_name: per_model[model_name]["tokens"][idx].get(spec.name)
                for model_name in per_model
            }
            merged[spec.name] = _aggregate_field(
                field_name=spec.name,
                field_type=spec.type,
                aggregator=spec.aggregator,
                subfields=spec.subfields,
                values_by_model=proposals,
                priority=priority,
                token_idx=idx,
                disagreements=cells,
            )
        merged_tokens.append(merged)

    consensus = {
        "sentence_id": reference.get("sentence_id", "s1"),
        "language": reference.get("language", ""),
        "tokens": merged_tokens,
    }
    return consensus, cells
|
|