"""MoE aggregation: per-token, per-field consensus across multiple models.

Pure function. Empty disagreement list when only one model contributes.
"""
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, asdict
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Only referenced in annotations, which are lazy strings thanks to the
    # ``__future__`` import above, so no runtime import is required.
    from schemas import AnnotationSchema


@dataclass
class DisagreementCell:
    """One (token, field) cell on which the contributing models did not agree.

    Attributes:
        token_idx: Index of the token inside the sentence.
        field_path: Field name, or ``"field.subfield"`` for object sub-fields.
        values_by_model: Raw value proposed by each model (may contain None).
        winner: Value selected for the consensus annotation.
        agreement_ratio: Agreement score reported by the aggregator.
    """

    token_idx: int
    field_path: str
    values_by_model: dict[str, Any]
    winner: Any
    agreement_ratio: float

    def to_dict(self) -> dict:
        """Return the cell as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)


# Rank of each confidence label; labels missing from the table rank as 0.
CONFIDENCE_ORDER = {"low": 0, "medium": 1, "high": 2}


def _hashable_key(value: Any) -> Any:
    """Return a hashable stand-in for *value* suitable for counting votes.

    Hashable values are returned unchanged, so they are counted exactly as a
    plain ``Counter`` would count them. Unhashable containers (lists, dicts,
    sets, tuples holding them) are converted recursively into tagged immutable
    structures; dict key order is irrelevant. Anything else falls back to
    ``repr``.
    """
    try:
        hash(value)
    except TypeError:
        pass
    else:
        return value
    if isinstance(value, dict):
        return (dict, frozenset((k, _hashable_key(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return (type(value), tuple(_hashable_key(v) for v in value))
    if isinstance(value, (set, frozenset)):
        return (set, frozenset(_hashable_key(v) for v in value))
    return (type(value), repr(value))


def _vote(values: dict[str, Any], priority: list[str]) -> tuple[Any, float]:
    """Plurality vote over the non-None values.

    Ties are broken by the earliest model in *priority* that voted for one of
    the tied values, then by string sort of the tied values.

    Values may be unhashable (lists, dicts); they are counted through
    ``_hashable_key`` and the original object is returned as the winner.

    Returns:
        ``(winner, agreement_ratio)`` where the ratio is the winning count
        divided by the number of non-None votes. ``(None, 1.0)`` when no model
        supplied a value.
    """
    cleaned = {m: v for m, v in values.items() if v is not None}
    if not cleaned:
        return None, 1.0

    keys = {m: _hashable_key(v) for m, v in cleaned.items()}
    counts = Counter(keys.values())
    top_count = max(counts.values())
    ratio = top_count / len(cleaned)
    candidate_keys = [k for k, c in counts.items() if c == top_count]

    # Map every key back to the first original value that produced it.
    originals: dict[Any, Any] = {}
    for m, k in keys.items():
        originals.setdefault(k, cleaned[m])

    if len(candidate_keys) == 1:
        return originals[candidate_keys[0]], ratio

    # Tie-break: model with earliest priority among models voting for a candidate.
    tied = set(candidate_keys)
    for m in priority:
        if m in keys and keys[m] in tied:
            return cleaned[m], ratio

    # No prioritised model voted for a tied value: deterministic string order.
    return min((originals[k] for k in candidate_keys), key=str), ratio


def _longest_common_substring(strings: list[str], min_len: int = 3) -> str:
    """Find the longest substring common to all non-empty strings.

    Empty strings are ignored. Returns ``""`` when nothing is left or when the
    longest shared substring is shorter than *min_len*. When exactly one
    non-empty string remains it is returned as-is, regardless of *min_len*.
    """
    strings = [s for s in strings if s]
    if not strings:
        return ""
    if len(strings) == 1:
        return strings[0]

    # Any common substring must occur in the shortest string, so scan that one.
    shortest = min(strings, key=len)
    best = ""
    n = len(shortest)
    for i in range(n):
        # Only candidates strictly longer than the current best are useful.
        for j in range(i + len(best) + 1, n + 1):
            candidate = shortest[i:j]
            if all(candidate in s for s in strings):
                if len(candidate) > len(best):
                    best = candidate
            else:
                # Extending a non-common prefix can never make it common.
                break
    return best if len(best) >= min_len else ""


def _lcs_aggregate(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Aggregate free-text values through their longest common substring.

    Falsy values are ignored. Resolution order:

    1. All remaining strings identical: that string, ratio 1.0.
    2. A common substring exists: that substring; the ratio is the mean of
       ``len(substring) / len(string)`` over the strings.
    3. Otherwise the value of the first model in *priority* that supplied one
       (or the first value seen), with ratio ``1 / number_of_values``.

    Returns ``(None, 1.0)`` when no model supplied a value.
    """
    cleaned = {m: v for m, v in values.items() if v}
    if not cleaned:
        return None, 1.0

    strings = list(cleaned.values())

    # Exact agreement -> ratio = 1.
    counts = Counter(strings)
    top = counts.most_common(1)[0]
    if top[1] == len(strings):
        return top[0], 1.0

    lcs = _longest_common_substring(strings)
    if lcs:
        # Ratio = average overlap fraction.
        ratio = sum(len(lcs) / max(len(s), 1) for s in strings) / len(strings)
        return lcs, ratio

    # Fall back to priority.
    for m in priority:
        if m in cleaned:
            return cleaned[m], 1.0 / len(cleaned)
    return strings[0], 1.0 / len(cleaned)


def _min_confidence(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Pick the lowest confidence label among the non-falsy values.

    Labels are ranked with ``CONFIDENCE_ORDER``; unknown labels rank as 0.
    *priority* is accepted for signature parity with the other aggregators and
    is not used. The ratio is the share of models that reported the picked
    label. Returns ``(None, 1.0)`` when no model supplied a value.
    """
    cleaned = {m: v for m, v in values.items() if v}
    if not cleaned:
        return None, 1.0
    pick = min(cleaned.values(), key=lambda v: CONFIDENCE_ORDER.get(v, 0))
    counts = Counter(cleaned.values())
    return pick, counts[pick] / len(cleaned)


def _priority(values: dict[str, str], priority: list[str]) -> tuple[Any, float]:
    """Return the value of the first model in *priority* that supplied one.

    Falsy values are ignored. If no prioritised model supplied a value, the
    first remaining value is used. The ratio is always 1.0. Returns
    ``(None, 1.0)`` when no model supplied a value.
    """
    cleaned = {m: v for m, v in values.items() if v}
    if not cleaned:
        return None, 1.0
    for m in priority:
        if m in cleaned:
            return cleaned[m], 1.0
    return next(iter(cleaned.values())), 1.0


def _aggregate_field(
    field_name: str,
    field_type: str,
    aggregator: str,
    subfields: Optional[list],
    values_by_model: dict[str, Any],
    priority: list[str],
    token_idx: int,
    disagreements: list[DisagreementCell],
) -> Any:
    """Aggregate one field of one token and record any disagreement.

    Args:
        field_name: Name of the field; used as the disagreement path.
        field_type: ``"object"`` triggers per-sub-field voting; any other
            value aggregates the field as a whole.
        aggregator: ``"lcs"``, ``"min"`` or ``"priority"``; anything else
            falls back to plurality vote.
        subfields: Sub-field descriptors (objects with a ``name`` attribute)
            for object fields. ``None`` is treated as no sub-fields.
        values_by_model: Value proposed by each model for this field.
        priority: Model names in tie-break order.
        token_idx: Index of the token being aggregated.
        disagreements: Output list; new ``DisagreementCell`` items are
            appended to it in place.

    Returns:
        The consensus value (a dict of sub-field winners for object fields).
    """
    if field_type == "object":
        out = {}
        for sub in subfields or []:
            # A model whose value is not a dict contributes no vote here.
            sub_vals = {
                m: (v.get(sub.name) if isinstance(v, dict) else None)
                for m, v in values_by_model.items()
            }
            winner, ratio = _vote(sub_vals, priority)
            out[sub.name] = winner
            if ratio < 1.0 and len(values_by_model) > 1:
                disagreements.append(
                    DisagreementCell(
                        token_idx=token_idx,
                        field_path=f"{field_name}.{sub.name}",
                        values_by_model=sub_vals,
                        winner=winner,
                        agreement_ratio=ratio,
                    )
                )
        return out

    if aggregator == "lcs":
        winner, ratio = _lcs_aggregate(values_by_model, priority)
    elif aggregator == "min":
        winner, ratio = _min_confidence(values_by_model, priority)
    elif aggregator == "priority":
        winner, ratio = _priority(values_by_model, priority)
    else:  # vote
        winner, ratio = _vote(values_by_model, priority)

    # Only flag *task-meaningful* disagreements. `min` aggregator (confidence) and
    # `priority` aggregator (comment / free-text metadata) always differ across
    # models — they would drown the user in noise.
    quiet_aggregators = {"min", "priority"}
    if ratio < 1.0 and len(values_by_model) > 1 and aggregator not in quiet_aggregators:
        disagreements.append(
            DisagreementCell(
                token_idx=token_idx,
                field_path=field_name,
                values_by_model=values_by_model,
                winner=winner,
                agreement_ratio=ratio,
            )
        )
    return winner


def aggregate(
    per_model: dict[str, dict],
    schema: AnnotationSchema,
    priority: Optional[list[str]] = None,
) -> tuple[dict, list[DisagreementCell]]:
    """Aggregate per-model annotations into a single consensus annotation.

    Args:
        per_model: ``{model_name -> {"sentence_id": ..., "language": ...,
            "tokens": [...]}}``. A missing or ``None`` token list counts as
            empty.
        schema: Object exposing ``fields``; each field provides ``name``,
            ``type``, ``aggregator`` and ``subfields``.
        priority: Model names in tie-break order. Defaults to the iteration
            order of *per_model*.

    Returns:
        ``(consensus_annotation, disagreement_cells)``. Sentence id, language
        and token surfaces are taken from the first model in *per_model*.

    Raises:
        ValueError: If *per_model* is empty or the models disagree on the
            number of tokens.
    """
    if not per_model:
        raise ValueError("per_model is empty")
    if priority is None:
        priority = list(per_model.keys())

    # Normalise once so that a missing or None token list behaves as empty.
    tokens_by_model = {m: (ann.get("tokens") or []) for m, ann in per_model.items()}

    # All models must agree on token count (upstream validation should ensure it).
    sample = next(iter(per_model.values()))
    sample_tokens = next(iter(tokens_by_model.values()))
    n_tokens = len(sample_tokens)
    for m, model_tokens in tokens_by_model.items():
        if len(model_tokens) != n_tokens:
            raise ValueError(
                f"Token count mismatch in MoE input: {m} has {len(model_tokens)}, expected {n_tokens}."
            )

    disagreements: list[DisagreementCell] = []
    consensus_tokens = []
    for i in range(n_tokens):
        token: dict[str, Any] = {"surface": sample_tokens[i].get("surface", "")}
        for f in schema.fields:
            values_by_model = {
                m: model_tokens[i].get(f.name)
                for m, model_tokens in tokens_by_model.items()
            }
            token[f.name] = _aggregate_field(
                field_name=f.name,
                field_type=f.type,
                aggregator=f.aggregator,
                subfields=f.subfields,
                values_by_model=values_by_model,
                priority=priority,
                token_idx=i,
                disagreements=disagreements,
            )
        consensus_tokens.append(token)

    consensus = {
        "sentence_id": sample.get("sentence_id", "s1"),
        "language": sample.get("language", ""),
        "tokens": consensus_tokens,
    }
    return consensus, disagreements