# Source: Hugging Face Space upload (author: DerivedFunction, commit 84e2dc1)
#!/usr/bin/env python3
"""Gradio demo for the multilingual token-classification language ID model."""
from __future__ import annotations
from collections import Counter, defaultdict
from functools import lru_cache
import random
import os
import re
from typing import Any
import pandas as pd
import gradio as gr
import pycountry
import fasttext
import numpy as np
from huggingface_hub import hf_hub_download
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
from fleurs_cache import fetch_random_fleurs_sentence, fetch_random_fleurs_sentence_mix
from language import ALL_LANGS, LANG_ALIASES, LANG_ISO2_TO_ISO3, canonical_lang, canonical_lang_family
from sib200_cache import fetch_random_sib200_sentence, fetch_random_sib200_sentence_mix
from tatoeba import fetch_random_tatoeba_sentence, fetch_random_tatoeba_sentence_mix
# Token-classification model used for per-span language tagging.
MODEL_CHECKPOINT = "DerivedFunction/polyglot-tagger-v2.1"
# Reference fastText language-ID model used for side-by-side comparison.
FASTTEXT_MODEL_REPO = "facebook/fasttext-language-identification"
FASTTEXT_MODEL_FILENAME = "model.bin"
# fastText predictions below this probability are dropped from the comparison view.
FASTTEXT_MIN_CONFIDENCE = 0.15
# Spans shorter than MIN_ARTIFACT_SPAN_CHARS *and* below MIN_ARTIFACT_CONFIDENCE
# are treated as likely tagging noise (see is_artifact_span).
MIN_ARTIFACT_SPAN_CHARS = 4
MIN_ARTIFACT_CONFIDENCE = 0.5
# Artifact spans are down-weighted (not discarded) when aggregating coverage.
ARTIFACT_SPAN_WEIGHT = 0.35
# Cached-dataset samplers tried in random order by the "random example" helpers.
RANDOM_SENTENCE_SAMPLERS = (
    fetch_random_fleurs_sentence,
    fetch_random_tatoeba_sentence,
    fetch_random_sib200_sentence,
)
RANDOM_MIX_SAMPLERS = (
    fetch_random_fleurs_sentence_mix,
    fetch_random_tatoeba_sentence_mix,
    fetch_random_sib200_sentence_mix,
)
@lru_cache(maxsize=1)
def get_pipeline():
    """Build (once) the token-classification pipeline for the language tagger."""
    tagger = AutoModelForTokenClassification.from_pretrained(MODEL_CHECKPOINT)
    tagger.eval()
    # The checkpoint was trained on top of the xlm-roberta-base tokenizer.
    tok = AutoTokenizer.from_pretrained("xlm-roberta-base")
    return pipeline(
        "token-classification",
        model=tagger,
        tokenizer=tok,
        aggregation_strategy="simple",
    )
@lru_cache(maxsize=1)
def get_fasttext_model():
    """Download (if needed) and load the reference fastText language ID model once."""
    local_path = hf_hub_download(
        repo_id=FASTTEXT_MODEL_REPO,
        filename=FASTTEXT_MODEL_FILENAME,
    )
    return fasttext.load_model(local_path)
def normalize_label(label: str) -> str:
    """Strip any BIO prefix and collapse a raw tag into its canonical family."""
    bare = label[2:] if label.startswith(("B-", "I-")) else label
    return canonical_lang_family(bare.lower())
def build_lang_stats(
    entities: list[dict[str, Any]],
) -> tuple[dict[str, dict[str, float | int]], int]:
    """Aggregate merged entity spans into per-language coverage stats.

    Returns a ``(stats, ignored_artifacts)`` pair: ``stats`` maps each
    language family to coverage/confidence metrics, ``ignored_artifacts``
    counts spans that were down-weighted as likely noise.
    """
    weighted_chars: defaultdict[str, float] = defaultdict(float)
    weighted_conf: defaultdict[str, float] = defaultdict(float)
    span_counts: defaultdict[str, int] = defaultdict(int)
    tagged_total = 0.0
    artifact_count = 0
    for span in entities:
        family = normalize_label(span.get("entity_group", span.get("entity", "O")))
        if family == "o":
            continue
        begin = span.get("start")
        finish = span.get("end")
        if begin is None or finish is None:
            continue
        length = max(int(finish) - int(begin), 1)
        confidence = float(span.get("score", 0.0))
        # Tiny low-confidence spans are down-weighted rather than dropped.
        weight = ARTIFACT_SPAN_WEIGHT if is_artifact_span(length, confidence) else 1.0
        if weight < 1.0:
            artifact_count += 1
        weighted_len = length * weight
        weighted_chars[family] += weighted_len
        weighted_conf[family] += weighted_len * confidence
        span_counts[family] += 1
        tagged_total += weighted_len
    if tagged_total == 0:
        return {}, artifact_count
    stats: dict[str, dict[str, float | int]] = {}
    for family, chars in weighted_chars.items():
        mean_conf = weighted_conf[family] / chars if chars else 0.0
        share = chars / tagged_total
        stats[family] = {
            "char_coverage": chars,
            "coverage_pct": share,
            "avg_confidence": mean_conf,
            "entity_count": span_counts[family],
            "rank_score": share * mean_conf,
        }
    return stats, artifact_count
def to_classifier_scores(lang_stats: dict[str, dict[str, float | int]]) -> dict[str, float]:
"""Normalize coverage-confidence weights into a classifier-like distribution."""
raw = {
lang: float(stat["coverage_pct"]) * float(stat["avg_confidence"])
for lang, stat in lang_stats.items()
}
total = sum(raw.values())
if total == 0:
return {}
return dict(sorted(((lang, weight / total) for lang, weight in raw.items()), key=lambda item: item[1], reverse=True))
def make_lang_chip_label(lang: str, stat: dict[str, float | int], score: float) -> str:
"""Render a compact label for a clickable language chip."""
return f"{lang.upper()} {score:.0%}"
def build_chip_button_updates(
    ranked: list[tuple[str, dict[str, float | int]]],
    classifier_scores: dict[str, float],
    fasttext_scores: dict[str, float] | None = None,
    max_chips: int = 6,
) -> list[dict[str, Any]]:
    """Return button updates for the top-ranked languages.

    Model-ranked languages come first, then fastText-only ones; unused
    chip slots are hidden.
    """
    ft_scores = fasttext_scores or {}
    ft_sorted = sorted(ft_scores.items(), key=lambda pair: pair[1], reverse=True)
    ft_position = {lang: pos for pos, (lang, _) in enumerate(ft_sorted)}
    model_conf = {lang: float(stat["avg_confidence"]) for lang, stat in ranked}
    model_order = [lang for lang, _ in ranked]
    known = set(model_order)
    ft_only = [lang for lang, _ in ft_sorted if lang not in known]
    display_order = model_order + ft_only
    updates: list[dict[str, Any]] = []
    for slot in range(max_chips):
        if slot >= len(display_order):
            updates.append(gr.update(value="", visible=False))
            continue
        lang = display_order[slot]
        derived = classifier_scores.get(lang, 0.0)
        label_score = model_conf.get(lang, derived)
        ft_score = ft_scores.get(lang, 0.0)
        has_ft = lang in ft_scores
        has_model = derived > 0.0
        # Variant encodes provenance: both models, fastText-only, or model-only.
        if has_model and has_ft:
            variant = "primary"
        elif has_ft:
            variant = "secondary"
        else:
            variant = "stop"
        pos = ft_position.get(lang)
        rank_suffix = f" #{pos + 1}" if pos is not None else ""
        updates.append(
            gr.update(
                value=f"{lang.upper()} M {label_score:.0%} | FT {ft_score:.0%}{rank_suffix}",
                visible=True,
                variant=variant,
            )
        )
    return updates
def build_ui_state(
*,
text: str,
lang_stats: dict[str, dict[str, float | int]],
classifier_scores: dict[str, float],
fasttext_result: dict[str, Any] | None,
dominant_lang: str,
overall_confidence: float,
ignored_artifacts: int,
) -> dict[str, Any]:
"""Package the bits the interactive chips need to redraw the card."""
fasttext_scores = {}
if fasttext_result:
fasttext_scores = {item["lang"]: float(item["score"]) for item in fasttext_result.get("predictions", [])}
model_ranked = [lang for lang, _ in sorted(lang_stats.items(), key=lambda item: classifier_scores.get(item[0], 0.0), reverse=True)]
fasttext_ranked = sorted(fasttext_scores, key=lambda lang: fasttext_scores.get(lang, 0.0), reverse=True)
chip_langs = []
seen = set()
for lang in model_ranked:
chip_langs.append({"lang": lang, "source": "model"})
seen.add(lang)
for lang in fasttext_ranked:
if lang not in seen:
chip_langs.append({"lang": lang, "source": "fasttext"})
seen.add(lang)
return {
"text": text,
"lang_stats": lang_stats,
"classifier_scores": classifier_scores,
"fasttext": fasttext_result,
"dominant_lang": dominant_lang,
"selected_lang": dominant_lang,
"overall_confidence": overall_confidence,
"ignored_artifacts": ignored_artifacts,
"ranked_langs": sorted(lang_stats.keys(), key=lambda lang: classifier_scores.get(lang, 0.0), reverse=True),
"chip_langs": chip_langs,
}
def build_example_validation(
    classifier_scores: dict[str, float],
    reference_scores: dict[str, float] | None,
    expected_langs: list[str],
) -> dict[str, Any]:
    """Compare derived scores against known source languages.

    Computes precision/recall/F1 for the model's predicted language set and,
    separately, for the fastText reference set against the expected languages.
    """

    def _counts(predicted: set[str], expected: set[str]) -> tuple[int, int, int]:
        # (true positives, false positives, false negatives)
        return (
            len(expected & predicted),
            len(predicted - expected),
            len(expected - predicted),
        )

    def _prf(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
        # Standard precision / recall / F1 with zero-division guards.
        prec = tp / (tp + fp) if (tp + fp) else 0.0
        rec = tp / (tp + fn) if (tp + fn) else 0.0
        f1 = (2 * prec * rec / (prec + rec)) if (prec + rec) else 0.0
        return prec, rec, f1

    expected_langs = [canonical_lang_family(lang) for lang in expected_langs if lang]
    expected_set = set(expected_langs)
    ranked = sorted(classifier_scores.items(), key=lambda pair: pair[1], reverse=True)
    top_lang = ranked[0][0] if ranked else None
    top_score = float(ranked[0][1]) if ranked else 0.0
    predicted_langs = [lang for lang, value in ranked if value > 0.0]
    if not predicted_langs and ranked:
        # All scores were zero: fall back to the single best-ranked language.
        predicted_langs = [ranked[0][0]]
    predicted_set = set(predicted_langs)
    tp, fp, fn = _counts(predicted_set, expected_set)
    precision, recall, validation_score = _prf(tp, fp, fn)
    ref_scores = reference_scores or {}
    reference_predicted = sorted(
        (lang for lang, value in ref_scores.items() if value > 0.0),
        key=lambda lang: ref_scores.get(lang, 0.0),
        reverse=True,
    )
    ref_tp, ref_fp, ref_fn = _counts(set(reference_predicted), expected_set)
    ref_precision, ref_recall, ref_f1 = _prf(ref_tp, ref_fp, ref_fn)
    return {
        "expected_langs": expected_langs,
        "predicted_langs": predicted_langs,
        "reference_langs": reference_predicted,
        "top_lang": top_lang,
        "top_score": top_score,
        "true_positive": tp,
        "false_positive": fp,
        "false_negative": fn,
        "expected_count": len(expected_set),
        "predicted_count": len(predicted_set),
        "precision": precision,
        "recall": recall,
        "top_match": fp == 0 and fn == 0,
        "validation_score": validation_score,
        "reference_true_positive": ref_tp,
        "reference_false_positive": ref_fp,
        "reference_false_negative": ref_fn,
        "reference_precision": ref_precision,
        "reference_recall": ref_recall,
        "reference_score": ref_f1,
    }
def render_validation_html(validation: dict[str, Any], *, source_label: str) -> str:
    """Render a compact validation card for a labeled example source.

    Args:
        validation: Payload produced by ``build_example_validation``.
        source_label: Display name of the dataset source (e.g. "Tatoeba").

    Returns:
        An HTML snippet, or ``""`` when there is nothing to render.
    """
    if not validation:
        return ""
    # Fall back to "n/a" when a language list is empty.
    expected_langs = ", ".join(lang.upper() for lang in validation.get("expected_langs", [])) or "n/a"
    predicted_langs = ", ".join(lang.upper() for lang in validation.get("predicted_langs", [])) or "n/a"
    top_lang = validation.get("top_lang") or "n/a"
    top_score = float(validation.get("top_score", 0.0))
    expected_count = int(validation.get("expected_count", 0))
    predicted_count = int(validation.get("predicted_count", 0))
    true_positive = int(validation.get("true_positive", 0))
    false_positive = int(validation.get("false_positive", 0))
    false_negative = int(validation.get("false_negative", 0))
    validation_score = float(validation.get("validation_score", 0.0))
    precision = float(validation.get("precision", 0.0))
    recall = float(validation.get("recall", 0.0))
    reference_score = float(validation.get("reference_score", 0.0))
    reference_precision = float(validation.get("reference_precision", 0.0))
    reference_recall = float(validation.get("reference_recall", 0.0))
    top_match = bool(validation.get("top_match"))
    status_label = "Match" if top_match else "Mismatch"
    status_class = "validation-pass" if top_match else "validation-warn"
    # Color tier follows the F1-style validation score.
    if validation_score >= 0.8:
        tier_class = "validation-high"
    elif validation_score >= 0.45:
        tier_class = "validation-mid"
    else:
        tier_class = "validation-low"
    return f"""
    <div class="validation-strip {tier_class}">
      <div class="validation-kicker">{source_label} validation</div>
      <div class="validation-main">{validation_score:.1%} <span class="validation-vs">vs {reference_score:.1%}</span></div>
      <div class="validation-status {status_class}">{status_label}</div>
      <div class="validation-grid">
        <div class="validation-chip">
          <span class="validation-chip-label">Expected</span>
          <span class="validation-chip-value">{expected_langs}</span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Predicted</span>
          <span class="validation-chip-value">{predicted_langs}</span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Top</span>
          <span class="validation-chip-value">{top_lang.upper()} <span class="validation-chip-muted">{top_score:.1%}</span></span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Counts</span>
          <span class="validation-chip-value">TP {true_positive} / FP {false_positive} / FN {false_negative}</span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Precision</span>
          <span class="validation-chip-value">{precision:.1%} <span class="validation-chip-muted">({predicted_count} predicted)</span></span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Recall</span>
          <span class="validation-chip-value">{recall:.1%} <span class="validation-chip-muted">({expected_count} expected)</span></span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Reference precision</span>
          <span class="validation-chip-value">{reference_precision:.1%}</span>
        </div>
        <div class="validation-chip">
          <span class="validation-chip-label">Reference recall</span>
          <span class="validation-chip-value">{reference_recall:.1%}</span>
        </div>
      </div>
    </div>
    """
def build_tatoeba_validation(
    classifier_scores: dict[str, float],
    expected_langs: list[str],
) -> dict[str, Any]:
    """Backward-compatible wrapper for existing Tatoeba callers.

    Bug fix: ``build_example_validation`` takes ``(classifier_scores,
    reference_scores, expected_langs)``. The previous call passed
    ``expected_langs`` into the ``reference_scores`` slot and omitted the
    third argument entirely, so every call raised ``TypeError``. Pass
    ``None`` for the (unavailable) reference scores instead.
    """
    return build_example_validation(classifier_scores, None, expected_langs)
def render_tatoeba_validation_html(validation: dict[str, Any]) -> str:
    """Backward-compatible wrapper for existing Tatoeba callers."""
    card = render_validation_html(validation, source_label="Tatoeba")
    return card
def _source_key(source: str) -> str:
return (source or "").strip().split("-", 1)[0].lower()
def _source_label(source: str) -> str:
    """Human-readable display label for a dataset source key."""
    key = _source_key(source)
    known = {"fleurs": "FLEURS", "tatoeba": "Tatoeba", "sib200": "SIB-200"}
    if key in known:
        return known[key]
    # Unknown sources are upper-cased; an empty key becomes a generic label.
    return key.upper() or "Example"
def _validation_key(source: str) -> str:
    """Dict key under which a source's validation payload is stored."""
    base = _source_key(source)
    if not base:
        base = "example"
    return f"{base}_validation"
def _sentence_id_keys(sentence: dict[str, Any]) -> list[str]:
keys = []
for candidate in ("fleurs_id", "sentence_id", "sib200_id", "id"):
value = sentence.get(candidate)
if value is not None:
keys.append(value)
return keys
def _language_name(lang_code: str) -> str:
    """Best-effort human readable language name for a code."""
    code = (lang_code or "").strip()
    if not code:
        return "Unknown"
    entry = pycountry.languages.get(alpha_2=code)
    if entry is None:
        # Fall back to the ISO-3 mapping when the two-letter lookup misses.
        entry = pycountry.languages.get(alpha_3=LANG_ISO2_TO_ISO3.get(code, ""))
    if entry is None:
        return code.upper()
    return getattr(entry, "name", None) or code.upper()
def render_language_reference_html() -> str:
    """Render a clickable footer that expands to code-to-name mappings."""
    items = [
        f"<li><span class='lang-code'>{code}</span><span class='lang-name'>{_language_name(code)}</span></li>"
        for code in sorted(LANG_ISO2_TO_ISO3)
    ]
    listing = "".join(items)
    return f"""
    <details class="footer-note footer-langs">
      <summary>Supported model languages: {len(LANG_ISO2_TO_ISO3)}. Click to view code-to-name mapping.</summary>
      <div class="footer-langs-body">
        <ul class="footer-lang-list">{listing}</ul>
      </div>
    </details>
    """
def _split_sentences_for_fasttext(text: str) -> list[str]:
blocks = re.split(r"\n\s*\n+", text)
sentences: list[str] = []
for block in blocks:
block = block.strip()
if not block:
continue
chunks = re.split(r"(?<=[.!?。!?])\s+|\n+", block)
sentences.extend(chunk.strip() for chunk in chunks if chunk and chunk.strip())
return sentences
def predict_fasttext(text: str, k: int = 5, mode: str = "full") -> dict[str, Any]:
    """Return fastText language predictions for comparison.

    Args:
        text: Raw input text.
        k: Number of labels requested from fastText per prediction call.
        mode: "full" scores the flattened text once; "sentences" scores each
            rough sentence and aggregates score-weighted votes per language.

    Returns:
        A dict with the reference model id, filtered predictions, and
        top-label metadata used by the UI cards.
    """
    model = get_fasttext_model()
    original_array = np.array

    # NOTE(review): this shim strips a ``copy=False`` kwarg before delegating
    # to the real np.array — presumably working around the fastText wrapper
    # passing copy=False, which newer NumPy releases reject. Confirm against
    # the installed fasttext/numpy versions.
    def _array_compat(obj, *args, **kwargs):
        if kwargs.get("copy") is False:
            kwargs = {**kwargs}
            kwargs.pop("copy", None)
        return original_array(obj, *args, **kwargs)

    def _predict_one(sample: str) -> tuple[list[str], list[float]]:
        # Patch np.array only around the predict call, restoring it even on error.
        np.array = _array_compat
        try:
            labels, scores = model.predict(sample, k=k)
        finally:
            np.array = original_array
        return list(labels), [float(score) for score in scores]

    def _normalize_predictions(labels: list[str], scores: list[float], *, keep_best: bool = False) -> list[dict[str, Any]]:
        # Map "__label__xxx_Scr" labels into our ISO-2 space; either keep only
        # the single best hit or drop anything under the confidence floor.
        predictions = [
            {
                "raw_label": label.removeprefix("__label__"),
                "lang": fasttext_label_to_iso2(label.removeprefix("__label__")),
                "score": float(score),
            }
            for label, score in zip(labels, scores)
        ]
        if keep_best and predictions:
            return [predictions[0]]
        return [item for item in predictions if item["score"] >= FASTTEXT_MIN_CONFIDENCE]

    if mode == "sentences":
        sentence_predictions: list[dict[str, Any]] = []
        for sentence in _split_sentences_for_fasttext(text):
            labels, scores = _predict_one(sentence)
            if not labels:
                continue
            sentence_predictions.append(
                {
                    "sentence": sentence,
                    "top_raw_label": labels[0].removeprefix("__label__"),
                    "top_family": fasttext_label_to_iso2(labels[0].removeprefix("__label__")),
                    "top_score": float(scores[0]) if scores else 0.0,
                    "predictions": _normalize_predictions(labels, scores, keep_best=True),
                }
            )
        # Score-weighted vote: each sentence contributes its top family's score.
        votes: defaultdict[str, float] = defaultdict(float)
        for item in sentence_predictions:
            top_lang = item["top_family"]
            top_score = float(item["top_score"])
            votes[top_lang] += top_score
        predictions = [
            {"lang": lang, "score": score / max(len(sentence_predictions), 1)}
            for lang, score in sorted(votes.items(), key=lambda item: item[1], reverse=True)
            if score > 0.0
        ]
        top_raw_label = sentence_predictions[0]["top_raw_label"] if sentence_predictions else None
        top_family = sentence_predictions[0]["top_family"] if sentence_predictions else None
        # Warn when any sentence's top label went through an alias/proxy mapping.
        variant_warning = any(fasttext_label_is_proxy(item["top_raw_label"]) for item in sentence_predictions)
        return {
            "model": FASTTEXT_MODEL_REPO,
            "mode": mode,
            "sentences": sentence_predictions,
            "sentence_count": len(sentence_predictions),
            "predictions": predictions,
            "top_lang": predictions[0]["lang"] if predictions else None,
            "top_score": predictions[0]["score"] if predictions else 0.0,
            "top_raw_label": top_raw_label,
            "top_family": top_family,
            "variant_warning": variant_warning,
        }
    # "full" mode: flatten newlines — fastText expects a single line of input.
    line = " ".join(part.strip() for part in text.splitlines() if part.strip())
    labels, scores = _predict_one(line)
    predictions = _normalize_predictions(labels, scores)
    top_raw_label = labels[0].removeprefix("__label__") if labels else None
    top_family = fasttext_label_to_iso2(top_raw_label) if top_raw_label else None
    variant_warning = any(fasttext_label_is_proxy(item["raw_label"]) for item in predictions)
    return {
        "model": FASTTEXT_MODEL_REPO,
        "mode": mode,
        "predictions": predictions,
        "sentence_count": 1,
        "top_lang": predictions[0]["lang"] if predictions else None,
        "top_score": predictions[0]["score"] if predictions else 0.0,
        "top_raw_label": top_raw_label,
        "top_family": top_family,
        "variant_warning": variant_warning,
    }
def fasttext_label_to_iso2(label: str) -> str:
    """Convert fastText labels like `bos_Latn` or `eng` into our ISO-2 space."""
    stem = label.split("_", 1)[0].lower()
    stem = canonical_lang_family(LANG_ALIASES.get(stem, stem))
    # The canonicalized family may itself be an alias; resolve one more hop.
    if stem in LANG_ALIASES:
        stem = canonical_lang_family(LANG_ALIASES[stem])
    if len(stem) == 2:
        return stem
    entry = pycountry.languages.get(alpha_3=stem)
    if entry is None:
        return stem
    two_letter = getattr(entry, "alpha_2", None)
    if not two_letter:
        return stem
    return canonical_lang(two_letter.lower())
def fasttext_label_is_proxy(label: str) -> bool:
    """Return True when a fastText label maps through an explicit alias/proxy."""
    stem = label.split("_", 1)[0].lower()
    if stem not in LANG_ALIASES:
        return False
    return LANG_ALIASES[stem] != stem
def fasttext_alias_hint_for_lang(fasttext_result: dict[str, Any] | None, lang: str) -> str | None:
    """Return the raw fastText label when the selected language was reached via an explicit alias."""
    if not fasttext_result or not lang:
        return None

    def _proxy_label(record: dict[str, Any], lang_key: str, raw_key: str) -> str | None:
        # Match on the mapped language and require an alias/proxy raw label.
        if record.get(lang_key) == lang and fasttext_label_is_proxy(str(record.get(raw_key, ""))):
            return str(record.get(raw_key))
        return None

    for record in fasttext_result.get("predictions", []):
        hint = _proxy_label(record, "lang", "raw_label")
        if hint is not None:
            return hint
    for sentence in fasttext_result.get("sentences", []):
        hint = _proxy_label(sentence, "top_family", "top_raw_label")
        if hint is not None:
            return hint
        for record in sentence.get("predictions", []):
            hint = _proxy_label(record, "lang", "raw_label")
            if hint is not None:
                return hint
    return None
def fetch_random_cached_sentence() -> dict[str, Any]:
    """Randomly sample a sentence from either cached source."""
    pool = list(RANDOM_SENTENCE_SAMPLERS)
    random.shuffle(pool)
    missing_cache: FileNotFoundError | None = None
    for fetch in pool:
        try:
            return fetch()
        except FileNotFoundError as exc:
            # Remember the failure but keep trying the remaining sources.
            missing_cache = exc
    if missing_cache is not None:
        raise missing_cache
    raise RuntimeError("No cached sentence samplers are registered.")
def fetch_random_cached_sentence_mix() -> dict[str, Any]:
    """Randomly sample a mixed-language example from either cached source."""
    pool = list(RANDOM_MIX_SAMPLERS)
    random.shuffle(pool)
    missing_cache: FileNotFoundError | None = None
    for fetch in pool:
        try:
            return fetch()
        except FileNotFoundError as exc:
            # Remember the failure but keep trying the remaining sources.
            missing_cache = exc
    if missing_cache is not None:
        raise missing_cache
    raise RuntimeError("No cached mix samplers are registered.")
def render_prediction_summary(
    *,
    text: str,
    selected_lang: str,
    dominant_lang: str,
    lang_stats: dict[str, dict[str, float | int]],
    classifier_scores: dict[str, float],
    fasttext_result: dict[str, Any] | None,
    overall_confidence: float,
    ignored_artifacts: int,
) -> str:
    """Render the prediction card for a selected language.

    Raises:
        KeyError: if ``selected_lang`` is not present in ``lang_stats``.
    """
    stat = lang_stats[selected_lang]
    iso3 = LANG_ISO2_TO_ISO3.get(selected_lang, "n/a")
    selected_score = classifier_scores.get(selected_lang, 0.0)
    # Total (artifact-weighted) tagged characters across all languages.
    tagged_chars = sum(float(s["char_coverage"]) for s in lang_stats.values())
    label = "Dominant" if selected_lang == dominant_lang else "Selected"
    warnings = []
    if overall_confidence < 0.75:
        warnings.append(f"Low confidence overall: {overall_confidence:.2f}")
    if selected_lang != dominant_lang:
        warnings.append(f"Top prediction: {dominant_lang.upper()}")
    # Surface when fastText reached this language via an alias/proxy label.
    alias_hint = fasttext_alias_hint_for_lang(fasttext_result, selected_lang)
    if alias_hint:
        warnings.append(f"fastText proxy hint: {alias_hint} -> {selected_lang.upper()}")
    warning_html = "".join(f"<div class='ambiguity-warning'>{note}</div>" for note in warnings)
    return f"""
    <div class="summary-card">
      <div class="summary-kicker">Prediction</div>
      <div class="summary-main">{selected_lang.upper()}</div>
      <div class="summary-note">{label} view · derived score {selected_score:.1%}</div>
      {warning_html}
      <div class="metric-grid">
        <div class="metric">
          <span class="metric-label">ISO-3</span>
          <span class="metric-value">{iso3}</span>
        </div>
        <div class="metric">
          <span class="metric-label">Derived score</span>
          <span class="metric-value"><strong>{selected_score:.1%}</strong></span>
        </div>
        <div class="metric">
          <span class="metric-label">Coverage</span>
          <span class="metric-value">{float(stat['coverage_pct']):.1%}</span>
        </div>
        <div class="metric">
          <span class="metric-label">Avg confidence</span>
          <span class="metric-value">{float(stat['avg_confidence']):.3f}</span>
        </div>
        <div class="metric" style="grid-column: 1 / -1; display: flex; justify-content: space-between; align-items: baseline; gap: 12px;">
          <span class="metric-label" style="margin: 0;">Tagged chars</span>
          <span class="metric-value">{tagged_chars:.0f} / {len(text)}</span>
        </div>
      </div>
      <div class="meter">
        <div class="meter-head">
          <span class="metric-label">Overall confidence</span>
          <span class="meter-value">{overall_confidence:.3f}</span>
        </div>
        <div class="meter-track">
          <div class="meter-fill" style="width: {max(0.0, min(100.0, overall_confidence * 100.0)):.1f}%"></div>
        </div>
      </div>
      <div class="summary-note">Ignored artifacts: {ignored_artifacts}</div>
    </div>
    """
def fasttext_mode_from_choice(choice: str | None) -> str:
choice = (choice or "").strip().lower()
return "sentences" if choice in {"sentences", "sentence by sentence"} else "full"
def render_selected_language_summary(ui_state: dict[str, Any], selected_lang: str) -> str:
    """Redraw the summary card for a clicked language chip."""
    placeholder = """
    <div class="summary-card">
      <div class="summary-kicker">Prediction</div>
      <div class="summary-main">No language selected</div>
    </div>
    """
    if not ui_state:
        return placeholder
    lang_stats = ui_state.get("lang_stats", {})
    if selected_lang not in lang_stats:
        # Unknown chip: fall back to the currently-selected or dominant language.
        selected_lang = ui_state.get("selected_lang") or ui_state.get("dominant_lang") or ""
    if not selected_lang:
        return placeholder
    merged = {**ui_state, "selected_lang": selected_lang}
    return render_prediction_summary(
        text=merged.get("text", ""),
        selected_lang=selected_lang,
        dominant_lang=merged.get("dominant_lang", selected_lang),
        lang_stats=lang_stats,
        classifier_scores=merged.get("classifier_scores", {}),
        fasttext_result=merged.get("fasttext"),
        overall_confidence=float(merged.get("overall_confidence", 0.0)),
        ignored_artifacts=int(merged.get("ignored_artifacts", 0)),
    )
def render_fasttext_summary(ui_state: dict[str, Any], selected_lang: str) -> str:
    """Render a summary card for the fastText reference view."""
    fasttext_result = ui_state.get("fasttext") or {}
    predictions = fasttext_result.get("predictions", [])
    fasttext_scores = {item["lang"]: float(item["score"]) for item in predictions}
    score = fasttext_scores.get(selected_lang, 0.0)
    # Fall back to the selected language when fastText has no top prediction.
    top_lang = fasttext_result.get("top_lang") or selected_lang
    top_score = float(fasttext_result.get("top_score", 0.0))
    mode = fasttext_result.get("mode", "full")
    mode_label = "sentence" if mode == "sentences" else "full text"
    # NOTE(review): top_raw_label is computed but never rendered below.
    top_raw_label = fasttext_result.get("top_raw_label") or "n/a"
    alias_hint = fasttext_alias_hint_for_lang(fasttext_result, selected_lang)
    warnings = []
    if alias_hint:
        warnings.append(f"Proxy hint: {alias_hint} -> {selected_lang.upper()}")
    warning_html = "".join(f"<div class='ambiguity-warning'>{note}</div>" for note in warnings)
    return f"""
    <div class="summary-card">
      <div class="summary-kicker">fastText reference</div>
      <div class="summary-main">{selected_lang.upper()}</div>
      <div class="summary-note">Mode: {mode_label} · top prediction {top_lang.upper()} {top_score:.1%}</div>
      {warning_html}
      <div class="metric-grid">
        <div class="metric">
          <span class="metric-label">Reference score</span>
          <span class="metric-value"><strong>{score:.1%}</strong></span>
        </div>
        <div class="metric">
          <span class="metric-label">Top score</span>
          <span class="metric-value">{top_score:.1%}</span>
        </div>
      </div>
      <div class="summary-note">Ignored artifacts: 0</div>
    </div>
    """
def select_language_from_chip(chip_index: int, ui_state: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Pick a ranked language chip and redraw the summary for that language."""
    if not ui_state:
        return "<div class='empty-state'>Run a prediction first.</div>", {}
    chips = ui_state.get("chip_langs", [])
    if not chips:
        return "<div class='empty-state'>No ranked languages available.</div>", ui_state
    # Clamp the index so stale button clicks cannot go out of range.
    position = max(0, min(int(chip_index), len(chips) - 1))
    chip = chips[position]
    lang = chip["lang"]
    source = chip.get("source", "model")
    if source == "fasttext":
        summary = render_fasttext_summary(ui_state, lang)
    else:
        summary = render_selected_language_summary(ui_state, lang)
    refreshed = {**ui_state, "selected_lang": lang, "selected_source": source}
    return summary, refreshed
def is_artifact_span(span_len: int, score: float) -> bool:
    """Identify tiny, low-confidence spans that are likely trailing noise."""
    tiny = span_len < MIN_ARTIFACT_SPAN_CHARS
    unsure = score < MIN_ARTIFACT_CONFIDENCE
    return tiny and unsure
def predict(
    text: str, fasttext_mode: str = "full"
) -> tuple[str, pd.DataFrame, dict[str, Any], dict[str, Any], str, dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Run the tagger and the fastText reference over *text* and build UI payloads.

    Returns an 11-tuple: summary HTML, span DataFrame, raw debug dict,
    UI state dict, a validation-HTML placeholder (always ""), then six
    ``gr.update`` dicts for the language chip buttons.
    """
    text = (text or "").strip()
    if not text:
        # Empty input: blank card, empty table, and six hidden chip buttons.
        empty = pd.DataFrame(columns=["token", "language", "score", "start", "end"])
        hidden = {}
        hidden_buttons = [gr.update(value="", visible=False) for _ in range(6)]
        return (
            "<div class='empty-state'>Paste some text to see the model's language signal.</div>",
            empty,
            {},
            hidden,
            "",
            *hidden_buttons,
        )
    nlp = get_pipeline()
    entities = nlp(text)
    fasttext_result = predict_fasttext(text, mode=fasttext_mode_from_choice(fasttext_mode))
    fasttext_scores = {item["lang"]: item["score"] for item in fasttext_result.get("predictions", [])}
    # One table row per tagged (non-"o") span, plus per-language token counts.
    rows: list[dict[str, Any]] = []
    token_counts: Counter[str] = Counter()
    for entity in entities:
        label = normalize_label(entity.get("entity_group", entity.get("entity", "O")))
        if label == "o":
            continue
        token_counts[label] += 1
        rows.append(
            {
                "token": entity.get("word", ""),
                "language": label,
                "score": round(float(entity.get("score", 0.0)), 4),
                "start": entity.get("start", None),
                "end": entity.get("end", None),
            }
        )
    spans = pd.DataFrame(rows, columns=["token", "language", "score", "start", "end"])
    spans = spans.sort_values(by=["start", "end"], na_position="last") if not spans.empty else spans
    lang_stats, ignored_artifacts = build_lang_stats(entities)
    if lang_stats:
        ranked = sorted(lang_stats.items(), key=lambda item: item[1]["rank_score"], reverse=True)
        classifier_scores = to_classifier_scores(lang_stats)
        dominant_lang = ranked[0][0]
        # Overall confidence = coverage-weighted mean of per-language confidence.
        tagged_chars = sum(float(stat["char_coverage"]) for stat in lang_stats.values())
        overall_confidence = (
            sum(float(stat["char_coverage"]) * float(stat["avg_confidence"]) for stat in lang_stats.values())
            / tagged_chars
            if tagged_chars
            else 0.0
        )
        summary = render_prediction_summary(
            text=text,
            selected_lang=dominant_lang,
            dominant_lang=dominant_lang,
            lang_stats=lang_stats,
            classifier_scores=classifier_scores,
            fasttext_result=fasttext_result,
            overall_confidence=overall_confidence,
            ignored_artifacts=ignored_artifacts,
        )
        ui_state = build_ui_state(
            text=text,
            lang_stats=lang_stats,
            classifier_scores=classifier_scores,
            fasttext_result=fasttext_result,
            dominant_lang=dominant_lang,
            overall_confidence=overall_confidence,
            ignored_artifacts=ignored_artifacts,
        )
    else:
        # Nothing tagged: empty stats plus an explanatory card.
        classifier_scores = {}
        ranked = []
        overall_confidence = 0.0
        dominant_lang = ""
        ui_state = {}
        summary = """
        <div class="summary-card">
          <div class="summary-kicker">Prediction</div>
          <div class="summary-main">No language spans detected</div>
          <div class="summary-subtitle">Try a longer sample or a cleaner single-language paragraph.</div>
        </div>
        """
    # Raw debug payload surfaced in the JSON viewer (floats pre-formatted).
    raw = {
        "model": MODEL_CHECKPOINT,
        "languages_supported": len(ALL_LANGS),
        "top_predictions": token_counts.most_common(10),
        "classifier_scores": classifier_scores if lang_stats else {},
        "overall_confidence": f"{overall_confidence:.3f}" if lang_stats else "0.000",
        "ignored_artifacts": ignored_artifacts,
        "lang_stats": {
            lang: {
                **stat,
                "coverage_pct": f"{float(stat['coverage_pct']):.3f}",
                "avg_confidence": f"{float(stat['avg_confidence']):.3f}",
                "rank_score": f"{float(stat['rank_score']):.3f}",
            }
            for lang, stat in ranked
        }
        if lang_stats
        else {},
        "entities": entities,
        "selected_lang": dominant_lang,
        "ranked_langs": [lang for lang, _ in ranked],
        "fasttext": fasttext_result,
        "fasttext_scores": fasttext_scores,
        "text": text,
    }
    chip_updates = build_chip_button_updates(ranked, classifier_scores, fasttext_scores) if lang_stats else [gr.update(value="", visible=False) for _ in range(6)]
    return summary, spans, raw, ui_state, "", *chip_updates
def load_random_tatoeba_example(fasttext_mode: str = "full") -> tuple[Any, ...]:
    """Load a random Tatoeba sentence, run the tagger, and build all UI outputs.

    Returns ``(text, summary_html, spans_df, raw_json, ui_state,
    validation_html, *chip_updates)`` — six fixed outputs followed by six
    ``gr.update`` objects for the language chip buttons.  (The previous
    annotation listed only the six fixed outputs.)
    """
    sentence = fetch_random_tatoeba_sentence()
    text = sentence["text"]
    # predict() renders a summary too, but it is re-rendered below so the
    # selected/dominant language reflects the freshly built ui_state.
    summary, spans, raw, ui_state, _, *chip_updates = predict(text, fasttext_mode=fasttext_mode)
    # A sample may bundle several sentence rows; fall back to the sample itself.
    sentence_rows = sentence.get("sentences") or [sentence]
    sentence_ids = _sentence_id_keys(sentence)
    sentence_langs = [item.get("lang_iso2", "") for item in sentence_rows]
    sentence_lang_iso3s = [item.get("lang_iso3", "") for item in sentence_rows]
    # Compare classifier + fastText scores against the dataset's ground truth.
    validation = build_example_validation(
        raw.get("classifier_scores", {}),
        raw.get("fasttext_scores", {}),
        sentence_langs,
    )
    # Enrich the raw JSON payload with example provenance for the UI panel.
    raw = {
        **raw,
        "source": "tatoeba",
        "sentence_id": sentence_ids[0] if sentence_ids else sentence.get("sentence_id", sentence.get("id")),
        "sentence_ids": sentence_ids,
        "lang_count": sentence.get("lang_count", len(sentence_rows)),
        "sentence_langs": sentence_langs,
        "sentence_lang_iso3s": sentence_lang_iso3s,
        "sentences": sentence_rows,
        "sentence_lang": sentence.get("source_lang", sentence.get("lang")),
        "sentence_lang_iso2": sentence.get("lang_iso2", sentence.get("source_lang")),
        "sentence_lang_iso3": sentence.get("lang_iso3", ""),
        _validation_key(sentence.get("source", "tatoeba")): validation,
    }
    validation_html = render_validation_html(validation, source_label=_source_label(sentence.get("source", "tatoeba")))
    summary = render_prediction_summary(
        text=text,
        selected_lang=ui_state.get("selected_lang", raw.get("selected_lang", "")),
        dominant_lang=ui_state.get("dominant_lang", raw.get("selected_lang", "")),
        lang_stats=ui_state.get("lang_stats", {}),
        classifier_scores=ui_state.get("classifier_scores", {}),
        fasttext_result=raw.get("fasttext"),
        overall_confidence=float(ui_state.get("overall_confidence", 0.0)),
        ignored_artifacts=int(ui_state.get("ignored_artifacts", 0)),
    )
    return text, summary, spans, raw, ui_state, validation_html, *chip_updates
def load_random_tatoeba_mix_example(fasttext_mode: str = "full") -> tuple[Any, ...]:
    """Load a random multi-language Tatoeba mix and run the tagger over it.

    Returns ``(text, summary_html, spans_df, raw_json, ui_state,
    validation_html, *chip_updates)`` — six fixed outputs followed by six
    ``gr.update`` objects for the chip buttons.  (The previous annotation
    listed only the six fixed outputs.)
    """
    mix = fetch_random_tatoeba_sentence_mix()
    text = mix["text"]
    # The summary from predict() is superseded below once ui_state is known.
    summary, spans, raw, ui_state, _, *chip_updates = predict(text, fasttext_mode=fasttext_mode)
    # Validate the predictions against the ground-truth languages of the mix.
    validation = build_example_validation(
        raw.get("classifier_scores", {}),
        raw.get("fasttext_scores", {}),
        mix.get("langs", []),
    )
    # Attach mix provenance so the raw JSON panel shows the expected languages.
    raw = {
        **raw,
        "source": "tatoeba-mix",
        "lang_count": mix["lang_count"],
        "sentence_langs": mix["langs"],
        "sentence_lang_iso3s": mix["lang_iso3s"],
        "sentences": mix["sentences"],
        _validation_key(mix.get("source", "tatoeba-mix")): validation,
    }
    validation_html = render_validation_html(validation, source_label=_source_label(mix.get("source", "tatoeba-mix")))
    summary = render_prediction_summary(
        text=text,
        selected_lang=ui_state.get("selected_lang", raw.get("selected_lang", "")),
        dominant_lang=ui_state.get("dominant_lang", raw.get("selected_lang", "")),
        lang_stats=ui_state.get("lang_stats", {}),
        classifier_scores=ui_state.get("classifier_scores", {}),
        fasttext_result=raw.get("fasttext"),
        overall_confidence=float(ui_state.get("overall_confidence", 0.0)),
        ignored_artifacts=int(ui_state.get("ignored_artifacts", 0)),
    )
    return text, summary, spans, raw, ui_state, validation_html, *chip_updates
def load_random_fleurs_example(fasttext_mode: str = "full") -> tuple[Any, ...]:
    """Load a random cached FLEURS sentence and run the tagger over it.

    Returns ``(text, summary_html, spans_df, raw_json, ui_state,
    validation_html, *chip_updates)`` — six fixed outputs followed by six
    ``gr.update`` objects for the chip buttons.  (The previous annotation
    listed only the six fixed outputs.)
    """
    try:
        # NOTE(review): fetch_random_cached_sentence is not among this file's
        # imports — presumably defined earlier in the module; verify.
        sentence = fetch_random_cached_sentence()
    except FileNotFoundError as exc:
        # Cache missing: return an empty result set plus an inline notice.
        empty = pd.DataFrame(columns=["token", "language", "score", "start", "end"])
        import html  # local import: only needed on this rare error path
        # Escape the exception text so path characters cannot break the HTML.
        message = f"<div class='empty-state'>{html.escape(str(exc))}</div>"
        return "", message, empty, {}, {}, "", *[gr.update(value="", visible=False) for _ in range(6)]
    text = sentence["text"]
    # The summary from predict() is superseded below once ui_state is known.
    summary, spans, raw, ui_state, _, *chip_updates = predict(text, fasttext_mode=fasttext_mode)
    # A sample may bundle several sentence rows; fall back to the sample itself.
    sentence_rows = sentence.get("sentences") or [sentence]
    sentence_id_values = _sentence_id_keys(sentence)
    sentence_langs = [item.get("lang_iso2", "") for item in sentence_rows]
    sentence_lang_iso3s = [item.get("lang_iso3", "") for item in sentence_rows]
    # Validate predictions against the cached sample's ground-truth languages.
    validation = build_example_validation(
        raw.get("classifier_scores", {}),
        raw.get("fasttext_scores", {}),
        sentence_langs,
    )
    # Attach cache provenance for the raw JSON panel.
    raw = {
        **raw,
        "source": sentence.get("source", "fleurs"),
        "cached_sentence_id": sentence_id_values[0] if sentence_id_values else None,
        "cached_sentence_ids": [_sentence_id_keys(item)[0] if _sentence_id_keys(item) else None for item in sentence_rows],
        "lang_count": sentence.get("lang_count", len(sentence_rows)),
        "cached_split": sentence.get("split"),
        "cached_source_lang": sentence.get("source_lang"),
        "cached_model_lang": sentence.get("model_lang", sentence.get("lang_iso2")),
        "cached_language": sentence.get("language"),
        "sentence_langs": sentence_langs,
        "sentence_lang_iso3s": sentence_lang_iso3s,
        "sentences": sentence_rows,
        _validation_key(sentence.get("source", "fleurs")): validation,
    }
    source_label = _source_label(sentence.get("source", "fleurs"))
    validation_html = render_validation_html(validation, source_label=source_label)
    summary = render_prediction_summary(
        text=text,
        selected_lang=ui_state.get("selected_lang", raw.get("selected_lang", "")),
        dominant_lang=ui_state.get("dominant_lang", raw.get("selected_lang", "")),
        lang_stats=ui_state.get("lang_stats", {}),
        classifier_scores=ui_state.get("classifier_scores", {}),
        fasttext_result=raw.get("fasttext"),
        overall_confidence=float(ui_state.get("overall_confidence", 0.0)),
        ignored_artifacts=int(ui_state.get("ignored_artifacts", 0)),
    )
    return text, summary, spans, raw, ui_state, validation_html, *chip_updates
def load_random_fleurs_mix_example(fasttext_mode: str = "full") -> tuple[Any, ...]:
    """Load a random cached FLEURS multi-language mix and run the tagger.

    Returns ``(text, summary_html, spans_df, raw_json, ui_state,
    validation_html, *chip_updates)`` — six fixed outputs followed by six
    ``gr.update`` objects for the chip buttons.  (The previous annotation
    listed only the six fixed outputs.)
    """
    try:
        # NOTE(review): fetch_random_cached_sentence_mix is not among this
        # file's imports — presumably defined earlier in the module; verify.
        mix = fetch_random_cached_sentence_mix()
    except FileNotFoundError as exc:
        # Cache missing: return an empty result set plus an inline notice.
        empty = pd.DataFrame(columns=["token", "language", "score", "start", "end"])
        import html  # local import: only needed on this rare error path
        # Escape the exception text so path characters cannot break the HTML.
        message = f"<div class='empty-state'>{html.escape(str(exc))}</div>"
        return "", message, empty, {}, {}, "", *[gr.update(value="", visible=False) for _ in range(6)]
    text = mix["text"]
    # The summary from predict() is superseded below once ui_state is known.
    summary, spans, raw, ui_state, _, *chip_updates = predict(text, fasttext_mode=fasttext_mode)
    # Validate predictions against the ground-truth languages of the mix.
    validation = build_example_validation(
        raw.get("classifier_scores", {}),
        raw.get("fasttext_scores", {}),
        mix.get("langs", []),
    )
    # Attach mix provenance for the raw JSON panel.
    raw = {
        **raw,
        "source": mix.get("source", "fleurs-mix"),
        "lang_count": mix["lang_count"],
        "sentence_langs": mix["langs"],
        "sentence_lang_iso3s": mix["lang_iso3s"],
        "sentences": mix["sentences"],
        _validation_key(mix.get("source", "fleurs-mix")): validation,
    }
    source_label = _source_label(mix.get("source", "fleurs-mix"))
    validation_html = render_validation_html(validation, source_label=source_label)
    summary = render_prediction_summary(
        text=text,
        selected_lang=ui_state.get("selected_lang", raw.get("selected_lang", "")),
        dominant_lang=ui_state.get("dominant_lang", raw.get("selected_lang", "")),
        lang_stats=ui_state.get("lang_stats", {}),
        classifier_scores=ui_state.get("classifier_scores", {}),
        fasttext_result=raw.get("fasttext"),
        overall_confidence=float(ui_state.get("overall_confidence", 0.0)),
        ignored_artifacts=int(ui_state.get("ignored_artifacts", 0)),
    )
    return text, summary, spans, raw, ui_state, validation_html, *chip_updates
CSS = """
:root {
--bg-1: #06111f;
--bg-2: #0b1f33;
--card: rgba(10, 20, 33, 0.72);
--card-border: rgba(255, 255, 255, 0.12);
--text: #f4f7fb;
--muted: #b7c3d6;
--accent: #7dd3fc;
--accent-2: #f59e0b;
}
body {
background:
radial-gradient(circle at top left, rgba(125, 211, 252, 0.22), transparent 28%),
radial-gradient(circle at top right, rgba(245, 158, 11, 0.16), transparent 24%),
linear-gradient(135deg, var(--bg-1), var(--bg-2));
}
.wrap {
max-width: 1180px;
margin: 0 auto;
}
.hero {
padding: 28px 28px 22px;
border: 1px solid var(--card-border);
border-radius: 24px;
background: linear-gradient(180deg, rgba(255,255,255,0.08), rgba(255,255,255,0.03));
box-shadow: 0 24px 80px rgba(0, 0, 0, 0.28);
backdrop-filter: blur(14px);
}
.eyebrow {
text-transform: uppercase;
letter-spacing: 0.22em;
color: var(--accent);
font-size: 12px;
font-weight: 700;
margin-bottom: 8px;
}
.title {
font-size: clamp(32px, 5vw, 56px);
line-height: 1.02;
margin: 0;
color: var(--text);
font-weight: 800;
}
.subtitle {
margin-top: 12px;
color: var(--muted);
font-size: 16px;
max-width: 820px;
}
.summary-card {
border: 1px solid var(--card-border);
border-radius: 22px;
padding: 22px;
background: rgba(7, 13, 24, 0.7);
color: var(--text);
min-height: 240px;
display: flex;
flex-direction: column;
gap: 10px;
}
.summary-kicker {
color: var(--accent);
text-transform: uppercase;
letter-spacing: 0.18em;
font-size: 11px;
font-weight: 700;
}
.summary-main {
font-size: 48px;
font-weight: 900;
margin-top: 8px;
color: white;
line-height: 0.95;
letter-spacing: -0.03em;
}
.summary-note {
color: var(--muted);
margin-top: 2px;
line-height: 1.45;
}
.metric-grid {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 8px;
}
.metric {
border: 1px solid rgba(255, 255, 255, 0.10);
background: rgba(255, 255, 255, 0.03);
border-radius: 16px;
padding: 10px 12px;
}
.metric-label {
display: block;
color: var(--muted);
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.14em;
margin-bottom: 6px;
}
.metric-value {
display: block;
color: var(--text);
font-size: 16px;
font-weight: 700;
}
.metric-value strong {
color: white;
}
.summary-row {
display: flex;
flex-wrap: wrap;
gap: 10px;
align-items: flex-start;
}
.meter {
margin-top: 2px;
}
.meter-head {
display: flex;
justify-content: space-between;
align-items: baseline;
gap: 10px;
margin-bottom: 8px;
}
.meter-value {
color: var(--text);
font-size: 16px;
font-weight: 800;
}
.meter-track {
width: 100%;
height: 10px;
border-radius: 999px;
overflow: hidden;
background: rgba(255, 255, 255, 0.08);
}
.meter-fill {
height: 100%;
border-radius: 999px;
background: linear-gradient(90deg, var(--accent), #60a5fa);
}
.chip-row {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 2px;
}
.chip {
border: 1px solid rgba(125, 211, 252, 0.25);
background: rgba(125, 211, 252, 0.08);
color: var(--text);
padding: 8px 10px;
border-radius: 999px;
font-size: 12px;
min-width: 140px;
white-space: nowrap;
}
.chip strong {
margin-left: 4px;
color: white;
}
.chip-conf {
display: block;
color: var(--muted);
font-size: 11px;
margin-top: 2px;
}
.ambiguity-warning {
margin-top: 10px;
padding: 10px 12px;
border-radius: 14px;
border: 1px solid rgba(245, 158, 11, 0.35);
background: rgba(245, 158, 11, 0.12);
color: #fbd38d;
font-size: 13px;
font-weight: 600;
}
.empty-state {
padding: 18px 20px;
border-radius: 18px;
border: 1px dashed rgba(255,255,255,0.16);
color: var(--muted);
background: rgba(255,255,255,0.03);
}
.gradio-container .gr-textbox textarea {
font-size: 15px !important;
}
.footer-note {
color: var(--muted);
font-size: 13px;
margin-top: 8px;
}
.footer-langs {
cursor: pointer;
padding: 14px 16px;
border-radius: 14px;
border: 1px solid var(--card-border);
background: rgba(255, 255, 255, 0.04);
}
.footer-langs summary {
list-style: none;
font-weight: 700;
}
.footer-langs summary::-webkit-details-marker {
display: none;
}
.footer-langs-body {
margin-top: 10px;
max-height: 240px;
overflow: auto;
padding-right: 4px;
}
.footer-lang-list {
margin: 0;
padding: 0;
list-style: none;
display: grid;
grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
gap: 8px 12px;
}
.footer-lang-list li {
display: flex;
gap: 8px;
align-items: baseline;
}
.lang-code {
font-family: monospace;
color: var(--accent);
min-width: 2.2rem;
}
.lang-name {
color: var(--text);
}
.validation-strip {
border-radius: 18px;
padding: 12px 14px;
margin-top: 10px;
}
.validation-high {
border: 1px solid rgba(34, 197, 94, 0.30);
background: rgba(34, 197, 94, 0.10);
}
.validation-mid {
border: 1px solid rgba(245, 158, 11, 0.35);
background: rgba(245, 158, 11, 0.10);
}
.validation-low {
border: 1px solid rgba(239, 68, 68, 0.35);
background: rgba(239, 68, 68, 0.10);
}
.validation-kicker {
color: #86efac;
text-transform: uppercase;
letter-spacing: 0.18em;
font-size: 11px;
font-weight: 700;
}
.validation-main {
color: white;
font-size: 28px;
font-weight: 900;
margin-top: 4px;
}
.validation-vs {
color: var(--muted);
font-size: 16px;
font-weight: 700;
margin-left: 8px;
}
.validation-status {
margin-top: 4px;
font-size: 13px;
font-weight: 700;
}
.validation-pass {
color: #86efac;
}
.validation-warn {
color: #fbbf24;
}
.validation-subtitle {
color: var(--muted);
margin-top: 6px;
font-size: 13px;
line-height: 1.4;
}
.validation-grid {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 8px;
margin-top: 10px;
}
.validation-chip {
border: 1px solid rgba(255, 255, 255, 0.10);
background: rgba(255, 255, 255, 0.04);
border-radius: 14px;
padding: 10px 12px;
display: flex;
flex-direction: column;
gap: 4px;
}
.validation-chip-label {
color: var(--muted);
text-transform: uppercase;
letter-spacing: 0.14em;
font-size: 10px;
font-weight: 700;
}
.validation-chip-value {
color: white;
font-size: 14px;
font-weight: 700;
line-height: 1.35;
}
.validation-chip-muted {
color: var(--muted);
font-size: 12px;
font-weight: 600;
}
.validation-grid .validation-chip:first-child,
.validation-grid .validation-chip:nth-child(2) {
grid-column: span 1;
}
.validation-grid .validation-chip:nth-child(4) {
grid-column: span 2;
}
.validation-grid .validation-chip:nth-child(7),
.validation-grid .validation-chip:nth-child(8) {
grid-column: span 1;
}
.chip-strip {
display: flex !important;
flex-wrap: wrap;
gap: 10px;
margin-top: -6px;
padding: 0 2px;
}
.chip-btn {
min-width: 0;
flex: 1 1 calc((100% - 20px) / 3);
}
.chip-btn button {
width: 100%;
min-height: 44px;
border-radius: 14px !important;
padding: 8px 12px;
white-space: nowrap;
line-height: 1.1;
text-align: center;
font-family: monospace !important;
font-size: 12px !important;
background: rgba(125, 211, 252, 0.06) !important;
border: 1px solid rgba(125, 211, 252, 0.2) !important;
}
.chip-btn button:hover {
background: rgba(125, 211, 252, 0.14) !important;
border-color: rgba(125, 211, 252, 0.45) !important;
}
.results-shell {
gap: 18px;
align-items: start;
}
.results-grid {
gap: 14px;
align-items: stretch;
}
.results-panel {
min-width: 0 !important;
}
.results-panel .gr-panel {
height: 100%;
}
.results-panel .gr-dataframe,
.results-panel .gr-json {
min-height: 280px;
max-height: 420px;
overflow-y: auto;
}
.gradio-container .gr-dataframe table {
table-layout: fixed !important;
width: 100% !important;
}
.gradio-container .gr-dataframe th:nth-child(1),
.gradio-container .gr-dataframe td:nth-child(1) {
width: 42% !important;
}
.gradio-container .gr-dataframe th:nth-child(2),
.gradio-container .gr-dataframe td:nth-child(2) {
width: 12% !important;
}
.gradio-container .gr-dataframe th:nth-child(3),
.gradio-container .gr-dataframe td:nth-child(3) {
width: 14% !important;
}
.gradio-container .gr-dataframe th:nth-child(4),
.gradio-container .gr-dataframe td:nth-child(4) {
width: 16% !important;
}
.gradio-container .gr-dataframe th:nth-child(5),
.gradio-container .gr-dataframe td:nth-child(5) {
width: 16% !important;
}
.gradio-container .gr-dataframe td:nth-child(1) {
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
@media (max-width: 900px) {
.chip-strip {
flex-basis: 100%;
}
}
@media (max-width: 640px) {
.chip-strip {
flex-basis: 100%;
}
.chip-btn {
flex-basis: 100%;
}
}
.action-btn {
width: 100%;
}
.action-btn button {
width: 100%;
min-height: 56px;
padding: 0 16px;
white-space: normal;
line-height: 1.15;
display: flex;
align-items: center;
justify-content: center;
text-align: center;
}
.action-primary button {
min-height: 58px;
font-weight: 800;
}
.action-secondary button {
min-height: 58px;
}
.action-clear button {
min-height: 48px;
opacity: 0.9;
}
.action-strip {
gap: 12px;
}
.action-strip > .gr-column {
min-width: 0 !important;
}
.action-stack {
margin-top: 10px;
gap: 10px;
}
"""
# UI layout and event wiring.  The stylesheet is supplied here via
# gr.Blocks(css=...): Blocks.launch() has no `css` parameter, so the previous
# `demo.launch(css=CSS, ...)` call would raise a TypeError at startup.
with gr.Blocks(title="Polyglot Tagger Studio", css=CSS) as demo:
    gr.HTML(
        """
        <div class="wrap hero">
          <div class="eyebrow">Multilingual Language ID</div>
          <h1 class="title">Polyglot Tagger Studio</h1>
          <div class="subtitle">
            A Gradio demo for the token-classification model behind this repo. Paste a sentence or paragraph,
            and the app will surface the dominant language signal, token-level spans, and raw predictions. Note that this is experimental and does not replace a text classifier: be prepared for unexpected results.
          </div>
        </div>
        """
    )
    with gr.Row():
        # Left column: text input plus fastText granularity controls.
        with gr.Column(scale=5):
            input_text = gr.Textbox(
                label="Text",
                lines=12,
                placeholder="Paste a sentence or a short paragraph here...",
                value="",
            )
            fasttext_mode = gr.Radio(
                choices=["Full text", "Sentence by sentence"],
                value="Full text",
                label="fastText mode",
                info="Choose whether fastText sees the whole input at once or one sentence at a time.",
            )
            gr.Markdown("Sentence-by-sentence mode splits on double newlines first, then sentence punctuation inside each block.")
            validation_strip = gr.HTML()
            gr.Markdown(
                "Use the buttons for fresh examples, or paste your own text."
            )
            with gr.Row(elem_classes=["action-strip"]):
                with gr.Column(scale=1, min_width=0):
                    analyze_btn = gr.Button("Analyze", variant="primary", elem_classes=["action-btn", "action-primary"])
                with gr.Column(scale=1, min_width=0):
                    clear_btn = gr.Button("Clear", elem_classes=["action-btn", "action-clear"])
            with gr.Row(elem_classes=["action-strip", "action-stack"]):
                with gr.Column(scale=1, min_width=0):
                    random_btn = gr.Button("Random sentence", elem_classes=["action-btn", "action-secondary"])
                with gr.Column(scale=1, min_width=0):
                    random_mix_btn = gr.Button("Random mix", elem_classes=["action-btn", "action-secondary"])
        # Right column: rendered summary plus six clickable language chips.
        with gr.Column(scale=7):
            summary = gr.HTML()
            # Holds the last prediction's ui_state so chip clicks can re-render.
            prediction_state = gr.State({})
            with gr.Row(elem_classes=["chip-strip"]):
                chip_0 = gr.Button("", visible=True, elem_classes=["chip-btn"])
                chip_1 = gr.Button("", visible=True, elem_classes=["chip-btn"])
                chip_2 = gr.Button("", visible=True, elem_classes=["chip-btn"])
                chip_3 = gr.Button("", visible=True, elem_classes=["chip-btn"])
                chip_4 = gr.Button("", visible=True, elem_classes=["chip-btn"])
                chip_5 = gr.Button("", visible=True, elem_classes=["chip-btn"])
    with gr.Row(elem_classes=["results-shell"]):
        with gr.Column(scale=7, min_width=0, elem_classes=["results-panel"]):
            spans = gr.Dataframe(
                headers=["token", "language", "score", "start", "end"],
                datatype=["str", "str", "number", "number", "number"],
                label="Token-level spans",
                interactive=False,
                wrap=True,
            )
        with gr.Column(scale=5, min_width=0, elem_classes=["results-panel"]):
            raw = gr.JSON(label="Raw output")
    # All analysis paths share the same 11-component output list.
    analyze_btn.click(
        fn=predict,
        inputs=[input_text, fasttext_mode],
        outputs=[summary, spans, raw, prediction_state, validation_strip, chip_0, chip_1, chip_2, chip_3, chip_4, chip_5],
        api_name="analyze",
    )
    random_btn.click(
        fn=load_random_fleurs_example,
        inputs=fasttext_mode,
        outputs=[input_text, summary, spans, raw, prediction_state, validation_strip, chip_0, chip_1, chip_2, chip_3, chip_4, chip_5],
        api_name="random_fleurs_sentence",
    )
    random_mix_btn.click(
        fn=load_random_fleurs_mix_example,
        inputs=fasttext_mode,
        outputs=[input_text, summary, spans, raw, prediction_state, validation_strip, chip_0, chip_1, chip_2, chip_3, chip_4, chip_5],
        api_name="random_fleurs_mix",
    )
    # Pressing Enter in the textbox behaves like the Analyze button.
    input_text.submit(
        fn=predict,
        inputs=[input_text, fasttext_mode],
        outputs=[summary, spans, raw, prediction_state, validation_strip, chip_0, chip_1, chip_2, chip_3, chip_4, chip_5],
        api_name="analyze_text",
    )
    # Clear resets every output and hides the six chips.
    clear_btn.click(
        fn=lambda: (
            "",
            pd.DataFrame(columns=["token", "language", "score", "start", "end"]),
            {},
            {},
            "",
            *[gr.update(value="", visible=False) for _ in range(6)],
        ),
        inputs=None,
        outputs=[summary, spans, raw, prediction_state, validation_strip, chip_0, chip_1, chip_2, chip_3, chip_4, chip_5],
        api_name="clear",
    )
    # Each chip re-renders the summary for the language at its index; the
    # index is bound as a lambda default to avoid late-binding surprises.
    chip_0.click(fn=lambda state: select_language_from_chip(0, state), inputs=prediction_state, outputs=[summary, prediction_state], api_name="select_chip_0")
    chip_1.click(fn=lambda state: select_language_from_chip(1, state), inputs=prediction_state, outputs=[summary, prediction_state], api_name="select_chip_1")
    chip_2.click(fn=lambda state: select_language_from_chip(2, state), inputs=prediction_state, outputs=[summary, prediction_state], api_name="select_chip_2")
    chip_3.click(fn=lambda state: select_language_from_chip(3, state), inputs=prediction_state, outputs=[summary, prediction_state], api_name="select_chip_3")
    chip_4.click(fn=lambda state: select_language_from_chip(4, state), inputs=prediction_state, outputs=[summary, prediction_state], api_name="select_chip_4")
    chip_5.click(fn=lambda state: select_language_from_chip(5, state), inputs=prediction_state, outputs=[summary, prediction_state], api_name="select_chip_5")
    gr.HTML(render_language_reference_html())
if __name__ == "__main__":
    demo.queue()
    # Sharing defaults to on; set GRADIO_SHARE=0 to disable the public link.
    demo.launch(share=os.getenv("GRADIO_SHARE", "1") != "0")