Spaces:
Running
Running
"""
app.py — Algospeak Classifier demo

Streamlit UI for the dual BERTweet model.
Type a social media post and see the predicted class + confidence scores.
Predictions are logged to a private HF dataset repo via CommitScheduler.
"""
import sys
from pathlib import Path

# Make the PoC sources importable before the local `inference` import below.
sys.path.insert(0, str(Path(__file__).parent / "poc" / "src"))

import csv
import shutil
from datetime import datetime, timezone

import emoji
import numpy as np
import streamlit as st
import torch
import yaml
from huggingface_hub import CommitScheduler, hf_hub_download
from transformers import AutoTokenizer

from inference import load_unsupervised_encoder, classify_text
# ── Paths and repo identifiers ───────────────────────────────────────
BASE_DIR = Path(__file__).parent
MODEL_REPO = "timagonch/algospeak-classifier-model"
LOG_REPO = "timagonch/algospeak-logs"
LOG_DIR = BASE_DIR / "logs"
LOG_FILE = LOG_DIR / "predictions.csv"

# Column order of the prediction log CSV.
LOG_COLS = [
    "text",
    "predicted_label",
    "score_allowed",
    "score_offensive",
    "score_mature",
    "score_algospeak",
    "timestamp",
]

# Streamlit markdown color tag used to render each predicted class.
CLASS_COLORS = {
    "Allowed": "green",
    "Offensive Language": "red",
    "Mature Content": "orange",
    "Algospeak": "violet",
}
def load_model():
    """Load config, encoder checkpoint, prototypes, and tokenizer.

    Streamlit re-executes this script on every interaction, so without
    caching the checkpoint and prototypes would be re-downloaded and the
    encoder rebuilt on every "Classify" click. The loaded bundle is
    therefore cached in ``st.session_state`` and reused for the rest of
    the browser session.

    Returns:
        tuple: (encoder, prototypes, tokenizer, cfg, device)
    """
    if "model_bundle" not in st.session_state:
        with open(BASE_DIR / "poc" / "config.yaml") as f:
            cfg = yaml.safe_load(f)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Model artifacts are fetched from the dedicated HF model repo.
        checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename="best_model.pt")
        prototypes_path = hf_hub_download(repo_id=MODEL_REPO, filename="prototypes.npy")
        encoder = load_unsupervised_encoder(checkpoint_path, cfg, device)
        prototypes = np.load(prototypes_path)
        # use_fast=False — presumably the fast BERTweet tokenizer does not
        # match the one used in training; confirm against the training code.
        tokenizer = AutoTokenizer.from_pretrained(cfg["model_name"], use_fast=False)
        st.session_state["model_bundle"] = (encoder, prototypes, tokenizer, cfg, device)
    return st.session_state["model_bundle"]
def get_scheduler():
    """Return the background CommitScheduler, created once per session.

    Creating a scheduler on every call (the original behavior, since this
    is invoked from ``log_prediction`` on each click) spawns a new
    background commit thread each time and re-copies the remote log over
    the local file — discarding rows appended since the last push. The
    instance is therefore cached in ``st.session_state``.

    Returns:
        CommitScheduler: scheduler that periodically pushes LOG_DIR to
        the HF dataset repo.
    """
    if "commit_scheduler" not in st.session_state:
        LOG_DIR.mkdir(exist_ok=True)
        # Seed the local log with the remote copy (if any) so new rows
        # append to the existing history instead of overwriting it.
        try:
            existing = hf_hub_download(
                repo_id=LOG_REPO,
                filename="logs/predictions.csv",
                repo_type="dataset",
            )
            shutil.copy(existing, LOG_FILE)
        except Exception:
            pass  # no remote log yet — start fresh
        st.session_state["commit_scheduler"] = CommitScheduler(
            repo_id=LOG_REPO,
            repo_type="dataset",
            folder_path=LOG_DIR,
            path_in_repo="logs",
            every=5,  # minutes between background commits
        )
    return st.session_state["commit_scheduler"]
def log_prediction(text, result):
    """Append one prediction row to the local CSV under the scheduler lock.

    The CommitScheduler pushes LOG_DIR to the HF dataset repo in the
    background; holding ``scheduler.lock`` prevents a commit from reading
    a half-written row.

    Args:
        text: Raw post text that was classified.
        result: Dict with keys "predicted_label" and "scores" (per-class
            similarity scores keyed by class name).
    """
    scheduler = get_scheduler()
    scores = result["scores"]
    row = {
        "text": text,
        "predicted_label": result["predicted_label"],
        "score_allowed": round(scores["Allowed"], 4),
        "score_offensive": round(scores["Offensive Language"], 4),
        "score_mature": round(scores["Mature Content"], 4),
        "score_algospeak": round(scores["Algospeak"], 4),
        # Timezone-aware UTC; datetime.utcnow() is deprecated since 3.12.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    with scheduler.lock:
        write_header = not LOG_FILE.exists()
        with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=LOG_COLS)
            if write_header:
                writer.writeheader()
            writer.writerow(row)
# ─────────────────────────────────────────────────────────────────────
# UI — one text box plus a classify button (script reruns on each click)
# ─────────────────────────────────────────────────────────────────────
st.title("Algospeak Classifier")
# Fixed mojibake in the caption ("Β·" → "·").
st.caption("Dual BERTweet model · type a social media post to classify it.")

text = st.text_area("Post text", height=120, placeholder="Type something here...")

if st.button("Classify", type="primary") and text.strip():
    encoder, prototypes, tokenizer, cfg, device = load_model()
    result = classify_text(text, encoder, prototypes, tokenizer, cfg["max_length"], device)

    label = result["predicted_label"]
    color = CLASS_COLORS[label]
    st.markdown(f"## :{color}[{label}]")

    st.divider()
    st.write("**Similarity scores:**")
    # Per-class similarity bars, highest score first.
    for name, score in sorted(result["scores"].items(), key=lambda x: -x[1]):
        st.progress(float(score), text=name)

    log_prediction(text, result)