moonlantern1 commited on
Commit
5aba719
·
verified ·
1 Parent(s): e7b3bfa

Upload brain_virality_predictor/predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. brain_virality_predictor/predict.py +39 -0
brain_virality_predictor/predict.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Single-video prediction: brain → verdict.
3
+ """
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Dict
8
+ import numpy as np
9
+
10
+ from .tribe_wrapper import TribeV2Wrapper, load_yeo_atlas
11
+ from .signals import network_to_ux_signals
12
+ from .features import extract_features
13
+ from .classifier import ViralityClassifier
14
+
15
+
16
+ class ViralityPredictor:
17
+ """End-to-end: video file → verdict + explanation."""
18
+
19
+ def __init__(self, model_path: str, tribe_model_id: str = "facebook/tribev2"):
20
+ self.clf = ViralityClassifier.load(model_path)
21
+ self.tribe = TribeV2Wrapper(model_id=tribe_model_id)
22
+ self.atlas = load_yeo_atlas()
23
+
24
+ def predict(self, video_path: str, tr: float = 1.5) -> Dict:
25
+ brain = self.tribe.predict_brain(video_path, tr=tr)
26
+ network_ts = self.tribe.network_means(brain, self.atlas)
27
+ signals = network_to_ux_signals(network_ts)
28
+ feats = extract_features(signals, tr=tr)
29
+ explanation = self.clf.explain(feats, top_n=3)
30
+ return explanation
31
+
32
+ @staticmethod
33
+ def verdict_badge(verdict: str) -> str:
34
+ badges = {
35
+ "good": "🟢 GOOD — ship it",
36
+ "okish": "🟡 OKISH — might hover at baseline",
37
+ "bad": "🔴 BAD — recut or scrap",
38
+ }
39
+ return badges.get(verdict, verdict)