# Page header captured with this file: dhuser · "Update empty tag" · 74f75da
"""MoE aggregation: per-token, per-field consensus across multiple models.
Pure function. Empty disagreement list when only one model contributes.
"""
from __future__ import annotations
from collections import Counter
from dataclasses import dataclass, asdict
from typing import Any, Optional
from schemas import AnnotationSchema
@dataclass
class DisagreementCell:
    """One (token, field) position where the contributing models did not fully agree."""

    # Index of the token inside the sentence's token list.
    token_idx: int
    # Field name, or "<field>.<subfield>" for a member of an object field.
    field_path: str
    # Value each model proposed, keyed by model name.
    values_by_model: dict[str, Any]
    # Value the aggregator settled on.
    winner: Any
    # Agreement score the aggregator reported alongside the winner.
    agreement_ratio: float

    def to_dict(self) -> dict:
        """Return the cell as a plain dict built by ``dataclasses.asdict``."""
        as_plain = asdict(self)
        return as_plain
# Rank of each confidence label, listed from lowest rank to highest.
CONFIDENCE_ORDER = {
    label: rank for rank, label in enumerate(("low", "medium", "high"))
}
def _vote(values: dict[str, Any], priority: list[str]) -> tuple[Any, float]:
    """Pick the plurality value among the models' non-``None`` answers.

    Values are grouped by equality rather than by hash, so unhashable answers
    (lists, dicts) are tallied instead of raising ``TypeError``.

    Args:
        values: Proposed value per model name; ``None`` entries abstain.
        priority: Model names, most trusted first; used only to break ties.

    Returns:
        ``(winner, agreement_ratio)``. The ratio is the winner's vote count
        divided by the number of non-abstaining models, so it lies in [0, 1].
        ``(None, 1.0)`` when every model abstained.
    """
    ballots = {model: value for model, value in values.items() if value is not None}
    if not ballots:
        return None, 1.0

    # [representative value, vote count] pairs in first-seen order. A linear
    # equality scan stands in for Counter so that unhashable values can be
    # counted; the scan is quadratic only in the number of distinct values.
    tallies: list[list[Any]] = []
    for value in ballots.values():
        for entry in tallies:
            if entry[0] is value or entry[0] == value:
                entry[1] += 1
                break
        else:
            tallies.append([value, 1])

    top_count = max(count for _, count in tallies)
    ratio = top_count / len(ballots)
    candidates = [value for value, count in tallies if count == top_count]
    if len(candidates) == 1:
        return candidates[0], ratio

    # Tie: the earliest model in `priority` that voted for a tied value decides.
    for model in priority:
        if model in ballots and ballots[model] in candidates:
            return ballots[model], ratio
    # No listed model backs a tied value: fall back to string order so the
    # result stays deterministic.
    return min(candidates, key=str), ratio
def _longest_common_substring(strings: list[str], min_len: int = 3) -> str:
    """Return the longest substring shared by every non-empty input string.

    Empty inputs are ignored. With no non-empty input the result is ``""``;
    with exactly one, that string is returned unchanged. Otherwise the result
    is ``""`` unless the shared substring has at least ``min_len`` characters.
    Among equally long shared substrings, the one starting earliest in the
    shortest input wins.
    """
    pool = [text for text in strings if text]
    if not pool:
        return ""
    if len(pool) == 1:
        return pool[0]

    # Any common substring must occur in the shortest string, so scan only that.
    probe = min(pool, key=len)
    limit = len(probe)
    longest = ""
    for start in range(limit):
        # Only windows strictly longer than the current best are worth testing.
        end = start + len(longest) + 1
        while end <= limit:
            piece = probe[start:end]
            if not all(piece in text for text in pool):
                # A longer window from this start cannot be common either.
                break
            longest = piece
            end += 1

    if len(longest) < min_len:
        return ""
    return longest
def _lcs_aggregate(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Merge free-text answers through their longest common substring.

    Falsy answers (``None``, ``""``) are ignored. Outcomes, in order:

    * no answers left: ``(None, 1.0)``;
    * all answers identical: that answer with ratio ``1.0``;
    * a shared substring exists: the substring, with the ratio being the mean
      of ``len(substring) / len(answer)`` over the answers;
    * otherwise: the answer of the first model in ``priority`` that answered
      (or the first answer seen), with ratio ``1 / number_of_answers``.
    """
    answered = {model: text for model, text in values.items() if text}
    if not answered:
        return None, 1.0

    texts = list(answered.values())
    # Exact agreement needs no substring search.
    if len(set(texts)) == 1:
        return texts[0], 1.0

    shared = _longest_common_substring(texts)
    if shared:
        # Average fraction of each answer that the shared part covers.
        coverage = [len(shared) / max(len(text), 1) for text in texts]
        return shared, sum(coverage) / len(texts)

    # Nothing in common: defer to the most trusted model that answered.
    chosen = next(
        (answered[model] for model in priority if model in answered),
        texts[0],
    )
    return chosen, 1.0 / len(answered)
def _min_confidence(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Return the lowest-ranked confidence label any model reported.

    Falsy labels are ignored; labels missing from ``CONFIDENCE_ORDER`` rank 0.
    When several labels share the lowest rank, the first one encountered is
    kept. The ratio is the share of reporting models that gave exactly the
    chosen label. ``priority`` is accepted for signature parity with the other
    aggregators and is not used.
    """
    reported = {model: label for model, label in values.items() if label}
    if not reported:
        return None, 1.0

    def rank(label: str) -> int:
        return CONFIDENCE_ORDER.get(label, 0)

    labels = list(reported.values())
    weakest = min(labels, key=rank)
    return weakest, labels.count(weakest) / len(labels)
def _priority(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Return the answer of the first model in ``priority`` that gave one.

    Falsy answers are ignored. If no listed model answered, the first
    remaining answer in ``values`` order is used. The ratio is always ``1.0``;
    ``(None, 1.0)`` is returned when nobody answered.
    """
    present = {model: value for model, value in values.items() if value}
    if not present:
        return None, 1.0

    first_seen = next(iter(present.values()))
    winner = next(
        (present[model] for model in priority if model in present),
        first_seen,
    )
    return winner, 1.0
def _aggregate_field(
    field_name: str,
    field_type: str,
    aggregator: str,
    subfields: list,
    values_by_model: dict[str, Any],
    priority: list[str],
    token_idx: int,
    disagreements: list[DisagreementCell],
) -> Any:
    """Compute the consensus value of one field of one token.

    For ``field_type == "object"`` every entry of ``subfields`` is resolved
    independently by plurality vote on ``sub.name``; a model whose value is
    not a dict contributes ``None`` for each subfield. The result is a dict
    keyed by subfield name.

    For any other field type, ``aggregator`` selects the strategy: ``"lcs"``,
    ``"min"`` or ``"priority"``; every other value means plurality vote.

    A ``DisagreementCell`` is appended to ``disagreements`` (mutated in place)
    whenever more than one model was supplied and the agreement ratio is below
    1.0, except for the ``"min"`` and ``"priority"`` aggregators.
    """
    several_models = len(values_by_model) > 1

    if field_type == "object":
        merged = {}
        for sub in subfields:
            proposals = {
                model: value.get(sub.name) if isinstance(value, dict) else None
                for model, value in values_by_model.items()
            }
            winner, ratio = _vote(proposals, priority)
            merged[sub.name] = winner
            if several_models and ratio < 1.0:
                cell = DisagreementCell(
                    token_idx=token_idx,
                    field_path=f"{field_name}.{sub.name}",
                    values_by_model=proposals,
                    winner=winner,
                    agreement_ratio=ratio,
                )
                disagreements.append(cell)
        return merged

    strategies = {
        "lcs": _lcs_aggregate,
        "min": _min_confidence,
        "priority": _priority,
    }
    combine = strategies.get(aggregator, _vote)  # anything else: plurality vote
    winner, ratio = combine(values_by_model, priority)

    # Confidence ("min") and free-text metadata ("priority") fields are expected
    # to differ between models; reporting them would bury the disagreements
    # that matter for the task.
    reportable = aggregator not in {"min", "priority"}
    if several_models and ratio < 1.0 and reportable:
        cell = DisagreementCell(
            token_idx=token_idx,
            field_path=field_name,
            values_by_model=values_by_model,
            winner=winner,
            agreement_ratio=ratio,
        )
        disagreements.append(cell)
    return winner
def aggregate(
    per_model: dict[str, dict],
    schema: AnnotationSchema,
    priority: Optional[list[str]] = None,
) -> tuple[dict, list[DisagreementCell]]:
    """Merge per-model annotations of one sentence into a consensus annotation.

    Args:
        per_model: ``{model_name -> {"sentence_id": ..., "language": ...,
            "tokens": [...]}}``.
        schema: Supplies ``fields``; each field's ``name``, ``type``,
            ``aggregator`` and ``subfields`` drive the per-field aggregation.
        priority: Model names, most trusted first. Defaults to the key order
            of ``per_model``.

    Returns:
        ``(consensus_annotation, disagreement_cells)``. ``sentence_id``,
        ``language`` and each token's ``surface`` are copied from the first
        model in ``per_model``.

    Raises:
        ValueError: If ``per_model`` is empty or the models disagree on the
            number of tokens.
    """
    if not per_model:
        raise ValueError("per_model is empty")
    if priority is None:
        priority = list(per_model)

    # Every model must deliver the same number of tokens (upstream validation
    # should already guarantee it); the first model serves as the reference.
    reference = next(iter(per_model.values()))
    expected = len(reference.get("tokens", []))
    for model, annotation in per_model.items():
        found = len(annotation.get("tokens", []))
        if found != expected:
            raise ValueError(
                f"Token count mismatch in MoE input: {model} has {found}, expected {expected}."
            )

    disagreements: list[DisagreementCell] = []
    consensus_tokens = []
    for idx, ref_token in enumerate(reference.get("tokens", [])):
        merged: dict[str, Any] = {"surface": ref_token.get("surface", "")}
        for spec in schema.fields:
            proposals = {
                model: per_model[model]["tokens"][idx].get(spec.name)
                for model in per_model
            }
            merged[spec.name] = _aggregate_field(
                field_name=spec.name,
                field_type=spec.type,
                aggregator=spec.aggregator,
                subfields=spec.subfields,
                values_by_model=proposals,
                priority=priority,
                token_idx=idx,
                disagreements=disagreements,
            )
        consensus_tokens.append(merged)

    consensus = {
        "sentence_id": reference.get("sentence_id", "s1"),
        "language": reference.get("language", ""),
        "tokens": consensus_tokens,
    }
    return consensus, disagreements