# Provenance (residue from the Hugging Face Hub file page): uploaded by moonlantern1,
# "Upload brain_virality_predictor/predict.py with huggingface_hub", commit 5aba719 (verified).
"""
Single-video prediction: predicted brain response → virality verdict.
"""
import json
from pathlib import Path
from typing import Dict
import numpy as np
from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas
from .signals import network_to_ux_signals
from .features import extract_features
from .classifier import ViralityClassifier
class ViralityPredictor:
    """End-to-end pipeline: video file → verdict + explanation.

    Chains the TRIBE v2 brain-response wrapper, a Yeo network atlas, UX-signal
    and feature extraction, and a trained ``ViralityClassifier``.
    """

    # Display strings for the verdict labels (presumably the labels the
    # classifier emits — verify against ViralityClassifier). Any other label
    # is passed through unchanged by ``verdict_badge``. Kept at class level so
    # the table is built once instead of on every call.
    _VERDICT_BADGES = {
        "good": "🟢 GOOD — ship it",
        "okish": "🟡 OKISH — might hover at baseline",
        "bad": "🔴 BAD — recut or scrap",
    }

    def __init__(self, model_path: str, tribe_model_id: str = "facebook/tribev2"):
        """Load the classifier, the TRIBE v2 wrapper and the Yeo atlas.

        Args:
            model_path: Location of a saved classifier, loaded with
                ``ViralityClassifier.load``.
            tribe_model_id: Model identifier forwarded to ``TribeV2Wrapper``
                (presumably a Hugging Face Hub repo id — confirm).
        """
        self.clf = ViralityClassifier.load(model_path)
        self.tribe = TribeV2Wrapper(model_id=tribe_model_id)
        self.atlas = load_yeo_atlas()

    def predict(self, video_path: str, tr: float = 1.5, top_n: int = 3) -> Dict:
        """Run the full pipeline on one video and return the classifier's explanation.

        Steps: predicted brain response → per-network means over the Yeo
        atlas → UX signals → features → ``clf.explain``.

        Args:
            video_path: Path of the video file handed to ``predict_brain``.
            tr: Sampling interval forwarded to both ``predict_brain`` and
                ``extract_features`` (presumably the fMRI repetition time in
                seconds — TODO confirm).
            top_n: Number of top contributing items requested from
                ``clf.explain``. The default of 3 matches the previous
                hard-coded value.

        Returns:
            Whatever ``ViralityClassifier.explain`` returns (annotated as a
            dict; its keys are defined by the classifier, not here).
        """
        brain = self.tribe.predict_brain(video_path, tr=tr)
        network_ts = self.tribe.network_means(brain, self.atlas)
        signals = network_to_ux_signals(network_ts)
        feats = extract_features(signals, tr=tr)
        return self.clf.explain(feats, top_n=top_n)

    @staticmethod
    def verdict_badge(verdict: str) -> str:
        """Return a human-readable badge for ``verdict``.

        "good", "okish" and "bad" map to emoji-prefixed display strings. Any
        other value is returned unchanged.
        """
        return ViralityPredictor._VERDICT_BADGES.get(verdict, verdict)