| """ |
| Single-video prediction: brain → verdict. |
| """ |
|
|
| import json |
| from pathlib import Path |
| from typing import Dict |
| import numpy as np |
|
|
| from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas |
| from .signals import network_to_ux_signals |
| from .features import extract_features |
| from .classifier import ViralityClassifier |
|
|
|
|
class ViralityPredictor:
    """End-to-end pipeline: video file -> classifier explanation.

    Chains the TRIBE wrapper, the network-to-UX-signal mapping, feature
    extraction and the trained classifier so that callers only have to supply
    a path to a video.
    """

    def __init__(self, model_path: str, tribe_model_id: str = "facebook/tribev2"):
        """Load every component the pipeline needs.

        Args:
            model_path: Location handed to ``ViralityClassifier.load``.
            tribe_model_id: Model identifier forwarded to ``TribeV2Wrapper``.
        """
        self.clf = ViralityClassifier.load(model_path)
        self.tribe = TribeV2Wrapper(model_id=tribe_model_id)
        self.atlas = load_yeo_atlas()

    def predict(self, video_path: str, tr: float = 1.5) -> Dict:
        """Run the whole pipeline on one video and return the explanation.

        Args:
            video_path: Path of the video passed to ``predict_brain``.
            tr: Forwarded unchanged to both ``predict_brain`` and
                ``extract_features``. Presumably the fMRI repetition time in
                seconds -- TODO confirm against the wrapper.

        Returns:
            Whatever ``self.clf.explain(feats, top_n=3)`` produces. Presumably
            this dict carries the verdict together with its top three
            contributing features -- verify against ``ViralityClassifier``.
        """
        brain = self.tribe.predict_brain(video_path, tr=tr)
        network_ts = self.tribe.network_means(brain, self.atlas)
        signals = network_to_ux_signals(network_ts)
        feats = extract_features(signals, tr=tr)
        return self.clf.explain(feats, top_n=3)

    @staticmethod
    def verdict_badge(verdict: str) -> str:
        """Return the human-readable badge for a verdict label.

        Args:
            verdict: One of ``"good"``, ``"okish"`` or ``"bad"``. The lookup
                is an exact, case-sensitive match.

        Returns:
            The decorated badge string, or ``verdict`` itself unchanged when
            the label is not one of the known three.
        """
        # The literals are spelled with escape sequences so the source stays
        # pure ASCII: these strings had previously been corrupted by a lossy
        # re-encoding of the file, and escapes cannot be damaged that way.
        badges = {
            # U+1F7E2 LARGE GREEN CIRCLE, U+2014 EM DASH
            "good": "\U0001F7E2 GOOD \u2014 ship it",
            # U+1F7E1 LARGE YELLOW CIRCLE
            "okish": "\U0001F7E1 OKISH \u2014 might hover at baseline",
            # U+1F534 LARGE RED CIRCLE
            "bad": "\U0001F534 BAD \u2014 recut or scrap",
        }
        return badges.get(verdict, verdict)
|
|