File size: 1,313 Bytes
"""
Single-video prediction pipeline: video → predicted brain response → verdict.
"""
import json
from pathlib import Path
from typing import Dict
import numpy as np
from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas
from .signals import network_to_ux_signals
from .features import extract_features
from .classifier import ViralityClassifier
class ViralityPredictor:
    """End-to-end pipeline: video file → verdict + explanation.

    ``predict`` performs these steps in order:

    1. Run the video through the TRIBE v2 wrapper to get a brain response.
    2. Average that response per network of the Yeo atlas.
    3. Turn the network time series into UX signals and extract features.
    4. Return the trained classifier's explanation for those features.
    """

    # Display badge for each known verdict label. The table is built once
    # here rather than on every ``verdict_badge`` call.
    _VERDICT_BADGES: Dict[str, str] = {
        "good": "🟢 GOOD — ship it",
        "okish": "🟡 OKISH — might hover at baseline",
        "bad": "🔴 BAD — recut or scrap",
    }

    def __init__(self, model_path: str, tribe_model_id: str = "facebook/tribev2"):
        """Load the classifier, the TRIBE v2 wrapper and the Yeo atlas.

        Args:
            model_path: Location passed to ``ViralityClassifier.load``.
            tribe_model_id: Model identifier passed to ``TribeV2Wrapper``.
        """
        self.clf = ViralityClassifier.load(model_path)
        self.tribe = TribeV2Wrapper(model_id=tribe_model_id)
        self.atlas = load_yeo_atlas()

    def predict(self, video_path: str, tr: float = 1.5) -> Dict:
        """Predict a verdict for one video and return the classifier's explanation.

        Args:
            video_path: Video file to analyse.
            tr: Sampling interval forwarded to both ``predict_brain`` and
                ``extract_features``. Presumably this is the fMRI repetition
                time in seconds; confirm against ``TribeV2Wrapper``.
                Must be a positive number.

        Returns:
            The result of ``self.clf.explain(features, top_n=3)``.

        Raises:
            ValueError: If ``tr`` is zero, negative or NaN.
        """
        # Fail fast, before the expensive brain-model inference. The negated
        # comparison ``not tr > 0`` also rejects NaN, because every
        # comparison with NaN is False.
        if not tr > 0:
            raise ValueError(f"tr must be a positive number, got {tr!r}")
        brain = self.tribe.predict_brain(video_path, tr=tr)
        network_ts = self.tribe.network_means(brain, self.atlas)
        signals = network_to_ux_signals(network_ts)
        feats = extract_features(signals, tr=tr)
        return self.clf.explain(feats, top_n=3)

    @staticmethod
    def verdict_badge(verdict: str) -> str:
        """Return the display badge for *verdict*.

        Unknown labels are returned unchanged.
        """
        return ViralityPredictor._VERDICT_BADGES.get(verdict, verdict)
|