File size: 1,313 Bytes
5aba719
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Single-video prediction: brain → verdict.
"""

import json
from pathlib import Path
from typing import Dict
import numpy as np

from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas
from .signals import network_to_ux_signals
from .features import extract_features
from .classifier import ViralityClassifier


class ViralityPredictor:
    """End-to-end: video file → verdict + explanation.

    Chains the TRIBE v2 brain-response model, the Yeo network atlas, the
    UX-signal and feature extraction steps, and a trained
    ``ViralityClassifier`` into a single ``predict`` call.
    """

    # Display strings for the verdict labels understood by ``verdict_badge``.
    # Built once at class creation rather than on every call.
    _VERDICT_BADGES: Dict[str, str] = {
        "good":  "🟢 GOOD — ship it",
        "okish": "🟡 OKISH — might hover at baseline",
        "bad":   "🔴 BAD — recut or scrap",
    }

    def __init__(self, model_path: str, tribe_model_id: str = "facebook/tribev2"):
        """Load the classifier, the TRIBE v2 wrapper and the Yeo atlas.

        Args:
            model_path: Location passed to ``ViralityClassifier.load``.
            tribe_model_id: Model identifier forwarded to ``TribeV2Wrapper``.
        """
        self.clf = ViralityClassifier.load(model_path)
        self.tribe = TribeV2Wrapper(model_id=tribe_model_id)
        # Loaded once here and reused by every ``predict`` call.
        self.atlas = load_yeo_atlas()

    def predict(self, video_path: str, tr: float = 1.5, *, top_n: int = 3) -> Dict:
        """Run the full pipeline on one video and return the classifier's explanation.

        Pipeline: predicted brain response → per-network time series (Yeo
        atlas) → UX signals → features → ``ViralityClassifier.explain``.

        Args:
            video_path: Video file handed to ``TribeV2Wrapper.predict_brain``.
            tr: Sampling interval used both for brain prediction and feature
                extraction (presumably seconds per sample — confirm against
                ``TribeV2Wrapper``). Must be positive.
            top_n: Number of top contributing factors requested from
                ``explain``. Defaults to 3, the previously hard-coded value.

        Returns:
            Whatever ``ViralityClassifier.explain`` returns. Presumably a dict
            containing the verdict and its explanation (TODO confirm against
            ``ViralityClassifier``).

        Raises:
            ValueError: If ``tr`` is not a positive number (NaN included).
        """
        # Fail fast before any (expensive) model inference; the negated
        # comparison also rejects NaN.
        if not (tr > 0):
            raise ValueError(f"tr must be a positive number, got {tr!r}")

        brain = self.tribe.predict_brain(video_path, tr=tr)
        network_ts = self.tribe.network_means(brain, self.atlas)
        signals = network_to_ux_signals(network_ts)
        feats = extract_features(signals, tr=tr)
        return self.clf.explain(feats, top_n=top_n)

    @staticmethod
    def verdict_badge(verdict: str) -> str:
        """Return a human-readable badge for a verdict label.

        Known labels are ``"good"``, ``"okish"`` and ``"bad"``. Any other
        value is returned unchanged, so callers always get something to show.
        """
        return ViralityPredictor._VERDICT_BADGES.get(verdict, verdict)