# clipforge: src/humeo/clip_selector.py
# Last commit: 0997589 "Add OpenRouter backup key failover" (moonlantern1, verified)
"""
Step 2 - Clip Selection: Gemini-only LLM for viral clip identification.
Uses the unified ``google-genai`` SDK (``from google import genai``). See:
https://github.com/googleapis/python-genai
"""
from __future__ import annotations
import json
import logging
import re
import time
from pathlib import Path
from typing import Callable, TypeVar
from google import genai
from openai import OpenAI
from humeo.gemini_generate import gemini_generate_config
from humeo_core.schemas import Clip, ClipPlan
from humeo.config import (
GEMINI_MODEL,
MAX_CLIP_DURATION_SEC,
MIN_CLIP_DURATION_SEC,
TEXT_AXIS_WEIGHTS,
TARGET_CLIP_COUNT,
)
from humeo.env import (
OPENROUTER_BASE_URL,
model_name_for_provider,
openrouter_default_headers,
resolve_gemini_api_key,
resolve_llm_provider,
resolve_openrouter_api_keys,
)
from humeo.hook_library import (
format_hook_examples,
retrieve_hook_examples,
)
from humeo.prompt_loader import clip_selection_prompts
logger = logging.getLogger(__name__)

# Generic result type for ``_retry_llm``.
T = TypeVar("T")

# LLM retry policy: at most LLM_MAX_ATTEMPTS tries, sleeping
# LLM_RETRY_DELAY_SEC * <attempt number> seconds after each failed try.
LLM_MAX_ATTEMPTS = 4
LLM_RETRY_DELAY_SEC = 2.0

# Over-generation defaults. ``select_clips`` takes each of these as a keyword
# argument, so callers can override them per run without touching code.
#
# - Request a *pool* of ~12 candidates at temperature 0.7 so the model samples
#   a wider slice of the transcript instead of locking onto the first few
#   obvious moments. A bigger pool raises the odds the real gold nugget is in it.
# - ``rank_and_filter_clips`` then keeps every candidate at or above the
#   quality threshold, but never fewer than ``DEFAULT_MIN_KEPT`` (a weak
#   transcript still ships its best clips, even if none clears the bar) and
#   never more than ``DEFAULT_MAX_KEPT`` (a rich transcript can ship more than
#   the target count without flooding the output).
DEFAULT_CANDIDATE_COUNT = 12
DEFAULT_QUALITY_THRESHOLD = 0.70
DEFAULT_MIN_KEPT = TARGET_CLIP_COUNT
DEFAULT_MAX_KEPT = 8

# Deliberately above the previous 0.3 so the pool differs meaningfully from
# "the same five most-obvious clips every run", yet well below 1.0 so clip IDs
# and timestamps stay coherent.
DEFAULT_CANDIDATE_TEMPERATURE = 0.7
_TITLE_SMALL_WORDS = {
"a",
"an",
"and",
"as",
"at",
"by",
"for",
"from",
"in",
"of",
"on",
"or",
"the",
"to",
"vs",
"with",
}
_TITLE_DROP_WORDS = {
"actually",
"entirely",
"just",
"next",
"really",
"still",
"that",
"their",
"these",
"this",
"those",
"very",
"will",
"your",
}
_TITLE_BLAND_WORDS = {
"big",
"future",
"important",
"lesson",
"matter",
"matters",
"opportunity",
"reason",
"soon",
"story",
"thing",
}
_GENERIC_TITLE_PATTERNS = (
"big opportunity",
"future of",
"important lesson",
"start a business with ai",
"why this matters",
"what this means",
)
_TITLE_TOKEN_REPLACEMENTS = {
"ai": "AI",
"agi": "AGI",
"api": "API",
"btc": "BTC",
"ev": "EV",
"evs": "EVs",
"us": "US",
}
_POWER_TITLE_TOKENS = {"$", "%", "under", "beats", "fewer", "more", "less", "vs"}
_FILLER_OPENERS = {
"actually",
"basically",
"i",
"kind",
"look",
"listen",
"now",
"okay",
"ok",
"right",
"so",
"sort",
"well",
"yeah",
"you",
}
_FILLER_OPENING_PHRASES = {
"i mean",
"kind of",
"sort of",
"you know",
}
_PREFERRED_MAX_DURATION_SEC = 72.0
def _has_valid_duration(clip: Clip) -> bool:
    """Check the clip's length against the inclusive [MIN, MAX] duration bounds."""
    duration = clip.duration_sec
    return duration >= MIN_CLIP_DURATION_SEC and duration <= MAX_CLIP_DURATION_SEC
def _text_composite_score(clip: Clip) -> float:
    """Return the TEXT_AXIS_WEIGHTS-weighted sum of ``score_breakdown``.

    Falls back to ``virality_score`` when the breakdown is empty or contains
    none of the expected axes. The second case covers legacy caches, which
    stored rule-name keys instead of the three-axis rubric. If only some axes
    are present, the missing ones count as 0.0 and a warning is logged.
    """
    breakdown = clip.score_breakdown
    if not breakdown or not any(axis in breakdown for axis in TEXT_AXIS_WEIGHTS):
        return clip.virality_score
    absent = [axis for axis in TEXT_AXIS_WEIGHTS if breakdown.get(axis) is None]
    if absent:
        logger.warning(
            "Clip %s score_breakdown missing axis(es) %s; treating as 0.0.",
            clip.clip_id,
            ", ".join(absent),
        )
    return sum(
        (
            breakdown[axis] * weight
            for axis, weight in TEXT_AXIS_WEIGHTS.items()
            if breakdown.get(axis) is not None
        ),
        0.0,
    )
def _title_quality_penalty(clip: Clip) -> float:
    """Return a ranking penalty (capped at 0.22) for a weak overlay title.

    The suggested title is first tightened with ``_tighten_overlay_title_text``.
    An empty result carries no penalty. Otherwise the following stack:

    - +0.18 when the title looks generic (``_looks_generic_title``);
    - +0.05 when it has fewer than 2 or more than 6 normalized tokens;
    - +0.03 when it has no "power" signal, i.e. no digit, no symbol from
      ``_POWER_TITLE_TOKENS`` (``$``, ``%``) and no power *word* as a whole token.

    Power words are matched as whole tokens, not substrings. Otherwise
    "Understanding" would count as "under", "unless" as "less", and so on.
    """
    title = _tighten_overlay_title_text(clip.suggested_overlay_title or "")
    if not title:
        return 0.0
    penalty = 0.0
    if _looks_generic_title(title):
        penalty += 0.18
    tokens = _normalized_title(title).split()
    if len(tokens) < 2 or len(tokens) > 6:
        penalty += 0.05
    # Symbols stay glued to numbers after normalization ("$5", "40%"), so they
    # are matched as substrings; alphabetic power words must be whole tokens.
    symbols = {tok for tok in _POWER_TITLE_TOKENS if not tok.isalnum()}
    power_words = _POWER_TITLE_TOKENS - symbols
    has_power = (
        any(ch.isdigit() for ch in title)
        or any(sym in title for sym in symbols)
        or any(tok in power_words for tok in tokens)
    )
    if not has_power:
        penalty += 0.03
    return min(0.22, penalty)
def _hook_quality_penalty(clip: Clip) -> float:
    """Return a ranking penalty (capped at 0.24) for a weak clip opening.

    The opener text is ``clip.viral_hook``, falling back to ``clip.transcript``.
    Penalties stack:

    - late hook: when ``hook_start_sec`` is after 5s, 0.06 plus 0.025 per extra
      second, capped at 0.18;
    - +0.14 when the first word is a spoken filler (``_FILLER_OPENERS``);
    - +0.06 when the first two words form a filler phrase
      (``_FILLER_OPENING_PHRASES``);
    - +0.03 when the opener is 12 or more words long.

    Leading and trailing punctuation is stripped from each word before the
    filler checks, so "So," / "Look..." / "You know," are recognised.
    """
    penalty = 0.0
    if clip.hook_start_sec is not None and clip.hook_start_sec > 5.0:
        penalty += min(0.18, 0.06 + (clip.hook_start_sec - 5.0) * 0.025)
    raw_words = (clip.viral_hook or clip.transcript or "").lower().split()
    if raw_words:
        # Trim punctuation at word edges only (inner apostrophes like "i'm" stay).
        words = [w for w in (re.sub(r"^\W+|\W+$", "", raw) for raw in raw_words) if w]
        first_word = words[0] if words else ""
        opening_phrase = " ".join(words[:2])
        if first_word in _FILLER_OPENERS:
            penalty += 0.14
        if opening_phrase in _FILLER_OPENING_PHRASES:
            penalty += 0.06
        # Length is measured on the raw split so punctuation-only tokens still count.
        if len(raw_words) >= 12:
            penalty += 0.03
    return min(0.24, penalty)
def _duration_quality_penalty(clip: Clip) -> float:
    """Penalize clips longer than ``_PREFERRED_MAX_DURATION_SEC``.

    The penalty is 0.03 plus 0.01 per second over the preferred length, capped
    at 0.14. Clips at or under the preferred length get 0.0.
    """
    over_by = clip.duration_sec - _PREFERRED_MAX_DURATION_SEC
    return 0.0 if clip.duration_sec <= _PREFERRED_MAX_DURATION_SEC else min(0.14, 0.03 + over_by * 0.01)
def clip_quality_penalty(clip: Clip) -> float:
    """Return the combined title + hook + duration penalty, capped at 0.42."""
    components = (
        _title_quality_penalty(clip),
        _hook_quality_penalty(clip),
        _duration_quality_penalty(clip),
    )
    return min(0.42, sum(components))
def clip_quality_priority_score(clip: Clip) -> float:
    """Return the ranking priority for ``clip``.

    The priority is the text composite score minus ``clip_quality_penalty``,
    with a further 0.5 deducted when the clip is flagged ``needs_review``.
    """
    score = _text_composite_score(clip)
    if clip.needs_review:
        score -= 0.5
    return score - clip_quality_penalty(clip)
def renumber_clips_dense(clips: list[Clip]) -> list[Clip]:
    """Return ``clips`` with ids rewritten to "001", "002", ... in list order.

    A clip whose id already matches its position is reused as-is. Any other
    clip is copied via ``model_copy``, so the input objects are never mutated.
    """

    def _with_position_id(position: int, item: Clip) -> Clip:
        wanted = f"{position:03d}"
        if item.clip_id == wanted:
            return item
        return item.model_copy(update={"clip_id": wanted})

    return [_with_position_id(pos, item) for pos, item in enumerate(clips, start=1)]
def _openai_message_text(content: object) -> str:
"""Normalize OpenAI-compatible message content into plain text."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return ""
def _retry_llm(name: str, fn: Callable[[], T], attempts: int = LLM_MAX_ATTEMPTS) -> T:
    """Call ``fn`` until it succeeds, trying at most ``attempts`` times.

    Any ``Exception`` is treated as retryable. After each failed try except
    the last, a warning is logged and the call sleeps for
    ``LLM_RETRY_DELAY_SEC * <attempt number>`` seconds (linear backoff). The
    final attempt runs outside the ``try``, so its exception propagates
    unchanged.

    Raises:
        ValueError: if ``attempts`` is less than 1. Previously this surfaced as
            an ``AssertionError``, or as a ``TypeError`` from ``raise None``
            under ``python -O``.
    """
    if attempts < 1:
        raise ValueError(f"{name}: attempts must be >= 1, got {attempts}")
    for attempt in range(1, attempts):
        try:
            return fn()
        except Exception as e:
            logger.warning("%s attempt %d/%d failed: %s", name, attempt, attempts, e)
            time.sleep(LLM_RETRY_DELAY_SEC * attempt)
    # Last try: let any exception escape to the caller.
    return fn()
def _headline_case_title(text: str) -> str:
    """Headline-case ``text`` for use as an overlay title.

    Each whitespace-separated word is handled as follows:

    - words containing a digit, or starting with ``$``, are kept verbatim;
    - known acronyms (``_TITLE_TOKEN_REPLACEMENTS``) get their canonical casing;
    - ``_TITLE_SMALL_WORDS`` are lowercased unless they are the first or last word;
    - words that already carry internal capitals (``GPT``, ``OpenAI``,
      ``iPhone``) are kept verbatim, because ``str.capitalize`` would flatten
      them to ``Gpt``/``Openai``/``Iphone``. This does not apply when the whole
      title is upper-case, which is treated as shouting and normalized;
    - everything else is capitalized.

    Casing only touches the alphabetic core of each word; leading and trailing
    punctuation is preserved.
    """
    words = text.split()
    if not words:
        return ""
    # An all-caps title is shouting, not a run of acronyms, so normalize it.
    shouting = text == text.upper()
    last_idx = len(words) - 1
    out: list[str] = []
    for idx, word in enumerate(words):
        if any(ch.isdigit() for ch in word) or word.startswith("$"):
            out.append(word)
            continue
        raw = re.sub(r"^[^A-Za-z]+|[^A-Za-z]+$", "", word)
        lower = raw.lower()
        if lower in _TITLE_TOKEN_REPLACEMENTS:
            out.append(word.replace(raw, _TITLE_TOKEN_REPLACEMENTS[lower]))
            continue
        if idx not in (0, last_idx) and lower in _TITLE_SMALL_WORDS:
            out.append(word.replace(raw, lower))
            continue
        if not shouting and any(ch.isupper() for ch in raw[1:]):
            # Deliberate mixed/upper casing (acronym or brand): keep it.
            out.append(word)
            continue
        out.append(word.replace(raw, raw.capitalize()))
    return " ".join(out)
def _normalized_title(text: str) -> str:
return re.sub(r"\s+", " ", re.sub(r"[^a-z0-9$% ]+", " ", (text or "").lower())).strip()
def _looks_generic_title(text: str) -> bool:
    """Decide whether ``text`` is too vague to use as an overlay title.

    True when the normalized title is empty, contains any
    ``_GENERIC_TITLE_PATTERNS`` substring, or uses two or more
    ``_TITLE_BLAND_WORDS``.
    """
    normalized = _normalized_title(text)
    if not normalized or any(pattern in normalized for pattern in _GENERIC_TITLE_PATTERNS):
        return True
    bland_hits = [word for word in normalized.split() if word in _TITLE_BLAND_WORDS]
    return len(bland_hits) >= 2
def _tighten_overlay_title_text(text: str) -> str:
    """Compress ``text`` into a short, headline-cased overlay title (max 6 words).

    Steps:

    1. Collapse whitespace, turn em dashes into hyphens, strip edge punctuation.
    2. Apply phrase rewrites ("less than" -> "under", "is creating" -> "creates",
       ...) and remove a few padding phrases.
    3. If the title is longer than 6 words, drop every ``_TITLE_DROP_WORDS`` word.
    4. If it is longer than 4 words, drop "your" / "next".
    5. If it is still longer than 6 words, drop a leading "why" / "how" / "when".
    6. Truncate to 6 words. When truncation happens, also trim trailing
       ``_TITLE_SMALL_WORDS`` so the title does not end on a dangling
       "the" / "of" / "and" (at least one word is kept).
    7. Headline-case the result with ``_headline_case_title``.
    """
    title = " ".join((text or "").replace("—", "-").split()).strip(" .,!?:;-")
    if not title:
        return ""
    # Phrase rewrites that shorten common LLM title padding.
    title = re.sub(r"\bwill cost less than\b", "under", title, flags=re.IGNORECASE)
    title = re.sub(r"\bless than\b", "under", title, flags=re.IGNORECASE)
    title = re.sub(r"\bmade your\b", "", title, flags=re.IGNORECASE)
    title = re.sub(r"\bis still\b", "is", title, flags=re.IGNORECASE)
    title = re.sub(r"\bis creating\b", "creates", title, flags=re.IGNORECASE)
    title = re.sub(r"\bthere are\b", "", title, flags=re.IGNORECASE)
    title = re.sub(r"\bentirely\b", "", title, flags=re.IGNORECASE)
    words = title.split()
    if len(words) > 6:
        # One pass removes every drop word; repeating it cannot shrink further.
        words = [word for word in words if word.lower() not in _TITLE_DROP_WORDS]
    if len(words) > 4:
        words = [word for word in words if word.lower() not in {"your", "next"}]
    if len(words) > 6 and words[0].lower() in {"why", "how", "when"}:
        words = words[1:]
    if len(words) > 6:
        words = words[:6]
        # A hard cut can end mid-phrase ("... Beats the"); _headline_case_title
        # would then capitalize that dangling connector as the final word.
        while len(words) > 1 and words[-1].lower().strip(" .,!?:;-") in _TITLE_SMALL_WORDS:
            words.pop()
    return _headline_case_title(" ".join(words).strip(" .,!?:;-"))
def _polish_overlay_title(clip: Clip) -> str:
    """Choose the overlay title to show for ``clip``.

    The tightened LLM-suggested title wins if it is non-empty and not generic.
    Otherwise the tightened ``viral_hook`` and then ``topic`` are tried. If none
    of them qualifies, the tightened suggestion is returned (possibly ``""``).
    """
    suggested = _tighten_overlay_title_text(clip.suggested_overlay_title or "")
    if suggested and not _looks_generic_title(suggested):
        return suggested
    fallbacks = (
        _tighten_overlay_title_text(source)
        for source in (clip.viral_hook or "", clip.topic or "")
    )
    return next(
        (title for title in fallbacks if title and not _looks_generic_title(title)),
        suggested,
    )
def _polish_clip_metadata(clip: Clip) -> Clip:
    """Return ``clip`` with its polished overlay title.

    A copy is made only when a non-empty polished title differs from the
    current one; otherwise the same object is returned.
    """
    polished = _polish_overlay_title(clip)
    changed = bool(polished) and polished != clip.suggested_overlay_title
    return clip.model_copy(update={"suggested_overlay_title": polished}) if changed else clip
def build_prompt(
    transcript: dict,
    *,
    candidate_count: int = DEFAULT_CANDIDATE_COUNT,
    steering_notes: list[str] | None = None,
    hook_library_path: Path | None = None,
) -> tuple[str, str]:
    """Build the ``(system_prompt, user_message)`` pair for the clip-selection call.

    Each entry in ``transcript["segments"]`` becomes one line of the form
    ``[<start>s - <end>s] <text>`` with one-decimal timestamps. Up to 8 hook
    examples are retrieved from the hook library using the first 8000
    characters of that text. Everything is then rendered by
    ``clip_selection_prompts``.

    ``candidate_count`` is the size of the candidate pool requested from the
    model; ``rank_and_filter_clips`` later narrows that pool down.
    """

    def _render_segment(segment: dict) -> str:
        seg_start = segment.get("start", 0)
        seg_end = segment.get("end", 0)
        seg_text = segment.get("text", "").strip()
        return f"[{seg_start:.1f}s - {seg_end:.1f}s] {seg_text}"

    transcript_text = "\n".join(map(_render_segment, transcript.get("segments", [])))
    retrieved = retrieve_hook_examples(
        transcript_text[:8000],
        path=hook_library_path,
        limit=8,
    )
    system_prompt, user_message = clip_selection_prompts(
        transcript_text=transcript_text,
        min_dur=MIN_CLIP_DURATION_SEC,
        max_dur=MAX_CLIP_DURATION_SEC,
        count=candidate_count,
        steering_notes=steering_notes,
        hook_examples=format_hook_examples(retrieved),
    )
    return system_prompt, user_message
def rank_and_filter_clips(
    clips: list[Clip],
    *,
    threshold: float = DEFAULT_QUALITY_THRESHOLD,
    min_kept: int = DEFAULT_MIN_KEPT,
    max_kept: int = DEFAULT_MAX_KEPT,
) -> list[Clip]:
    """Rank candidate clips and apply the threshold + floor + cap policy.

    Rules, in order:

    1. Drop clips whose duration falls outside
       [MIN_CLIP_DURATION_SEC, MAX_CLIP_DURATION_SEC]. They are logged and
       never used for backfill.
    2. Sort the remaining clips by priority, highest first. Priority is
       :func:`clip_quality_priority_score`: the text composite (or legacy
       ``virality_score``) minus the title/hook/duration penalties, minus 0.5
       more for ``needs_review`` clips. Ties break on the raw text composite.
    3. Keep every clip whose *priority* is ``>= threshold`` and that is not
       flagged ``needs_review``.
    4. If fewer than ``min_kept`` clips survive, backfill from the remaining
       valid clips in priority order until ``min_kept`` is reached or the pool
       runs out. ``needs_review`` clips may be used here.
    5. Cap the result at ``max_kept`` entries. A negative cap is treated as 0.
    6. Renumber ``clip_id`` to ``001``, ``002``, ... in rank order so
       downstream artifacts (keyframes, subtitles, filenames) stay dense.

    Quality comes first, but a weak transcript still ships up to ``min_kept``
    shorts instead of zero.
    """
    if not clips:
        return []

    # Score every candidate once. Keyed by identity because the same Clip
    # objects flow through sorting, filtering and backfill below.
    score_signal = {id(c): _text_composite_score(c) for c in clips}
    priority_signal = {id(c): clip_quality_priority_score(c) for c in clips}

    def _priority(c: Clip) -> tuple[float, float]:
        # Primary key: penalized priority. Tie-break: raw text composite.
        return (priority_signal[id(c)], score_signal[id(c)])

    valid: list[Clip] = []
    invalid: list[Clip] = []
    for clip in clips:
        if _has_valid_duration(clip):
            valid.append(clip)
            continue
        invalid.append(clip)
        logger.warning(
            "Clip %s dropped before ranking: duration %.1fs is outside [%ds, %ds] - %s",
            clip.clip_id,
            clip.duration_sec,
            MIN_CLIP_DURATION_SEC,
            MAX_CLIP_DURATION_SEC,
            clip.topic,
        )
    if not valid:
        logger.warning(
            "Clip ranking: 0 valid candidates remain after duration filtering (dropped=%d).",
            len(invalid),
        )
        return []

    ordered = sorted(valid, key=_priority, reverse=True)
    kept = [c for c in ordered if priority_signal[id(c)] >= threshold and not c.needs_review]
    if len(kept) < min_kept:
        # Floor: top up from the best remaining valid clips in rank order.
        backfill = [c for c in ordered if c not in kept]
        kept.extend(backfill[: min_kept - len(kept)])
        if len(kept) < min_kept:
            logger.warning(
                "Clip ranking: only %d valid candidates remain after duration filtering; "
                "cannot satisfy min_kept=%d without invalid clips.",
                len(kept),
                min_kept,
            )
    # Clamp the cap at 0: a negative slice bound would keep "all but the last N".
    kept = kept[: max(max_kept, 0)]

    # Renumber clip_ids so consumers (filenames, layout vision, subtitles)
    # always see 001..NNN in rank order regardless of what the LLM returned.
    renumbered = renumber_clips_dense(kept)
    dropped = len(valid) - len(kept) + len(invalid)
    logger.info(
        "Clip ranking: kept %d / %d candidates (threshold=%.2f, min=%d, max=%d, dropped=%d).",
        len(renumbered),
        len(clips),
        threshold,
        min_kept,
        max_kept,
        dropped,
    )
    for c in renumbered:
        logger.info(
            "  [%s] score=%.2f priority=%.2f penalty=%.2f %s %s",
            c.clip_id,
            c.virality_score,
            clip_quality_priority_score(c),
            clip_quality_penalty(c),
            "(review)" if c.needs_review else "",
            c.topic,
        )
    return renumbered
def _gemini_clip_completion(
    model_name: str, system_prompt: str, user_text: str, temperature: float
) -> str:
    """Send one clip-selection request via the google-genai SDK and return its text.

    Raises:
        RuntimeError: if Gemini returns an empty response text.
    """
    client = genai.Client(api_key=resolve_gemini_api_key())
    response = client.models.generate_content(
        model=model_name,
        contents=user_text,
        config=gemini_generate_config(
            system_instruction=system_prompt,
            temperature=temperature,
            response_mime_type="application/json",
        ),
    )
    if not response.text:
        raise RuntimeError("Gemini returned empty response text")
    return response.text


def _openrouter_clip_completion(
    model_name: str, system_prompt: str, user_text: str, temperature: float
) -> str:
    """Send one clip-selection request via OpenRouter, failing over across API keys.

    Keys from ``resolve_openrouter_api_keys()`` are tried in order and the first
    non-empty reply wins. An empty reply counts as a failure and moves on to
    the next key. If every key fails, the last key's exception is re-raised.

    Raises:
        RuntimeError: if no OpenRouter API keys are configured.
    """
    keys = list(resolve_openrouter_api_keys())
    if not keys:
        # Without this, the failover loop below never runs and the function
        # would die on an assert (or ``raise None`` under -O).
        raise RuntimeError("No OpenRouter API keys configured for clip selection")
    last_error: Exception | None = None
    for key_idx, api_key in enumerate(keys, start=1):
        try:
            client = OpenAI(
                api_key=api_key,
                base_url=OPENROUTER_BASE_URL,
                default_headers=openrouter_default_headers(),
            )
            response = client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_text},
                ],
                temperature=temperature,
                response_format={"type": "json_object"},
            )
            text = _openai_message_text(response.choices[0].message.content)
            if not text:
                raise RuntimeError("OpenRouter returned empty response text")
            if key_idx > 1:
                logger.info("OpenRouter clip selection succeeded with fallback key %d/%d", key_idx, len(keys))
            return text
        except Exception as exc:
            last_error = exc
            if key_idx < len(keys):
                logger.warning(
                    "OpenRouter clip selection failed with key %d/%d: %s; trying fallback",
                    key_idx,
                    len(keys),
                    exc,
                )
    # ``keys`` is non-empty and every iteration failed, so last_error is set.
    raise last_error if last_error is not None else RuntimeError("OpenRouter clip selection failed")


def select_clips(
    transcript: dict,
    *,
    gemini_model: str | None = None,
    hook_library_path: Path | None = None,
    candidate_count: int = DEFAULT_CANDIDATE_COUNT,
    quality_threshold: float = DEFAULT_QUALITY_THRESHOLD,
    min_kept: int = DEFAULT_MIN_KEPT,
    max_kept: int = DEFAULT_MAX_KEPT,
    temperature: float = DEFAULT_CANDIDATE_TEMPERATURE,
    steering_notes: list[str] | None = None,
) -> tuple[list[Clip], str]:
    """Ask the LLM for a candidate clip pool, then rank and filter it.

    Returns ``(clips, raw_json)``. ``clips`` has already been ranked and
    filtered by :func:`rank_and_filter_clips`. ``raw_json`` is the untouched
    LLM response, so a cached copy reflects the whole candidate pool for audit
    or re-ranking without another LLM call.

    The provider comes from ``resolve_llm_provider()``. ``"google"`` calls
    Gemini directly through ``google.genai.Client``; any other value routes
    the request through OpenRouter's OpenAI-compatible API, with API-key
    failover. The whole call is retried via :func:`_retry_llm`.
    """
    provider = resolve_llm_provider()
    model_name = model_name_for_provider((gemini_model or GEMINI_MODEL).strip(), provider)
    system_prompt, user_text = build_prompt(
        transcript,
        candidate_count=candidate_count,
        steering_notes=steering_notes,
        hook_library_path=hook_library_path,
    )

    def _call() -> str:
        logger.info(
            "%s clip selection (model=%s, candidate_pool=%d, temp=%.2f)...",
            provider,
            model_name,
            candidate_count,
            temperature,
        )
        if provider == "google":
            return _gemini_clip_completion(model_name, system_prompt, user_text, temperature)
        return _openrouter_clip_completion(model_name, system_prompt, user_text, temperature)

    raw = _retry_llm("Gemini clip selection", _call)
    candidates = _parse_clips(raw)
    # The ranker can only backfill from the pool Gemini returned. If Gemini
    # under-delivered (e.g. returned 2 of a requested 12), the min_kept floor
    # cannot be met, so warn loudly rather than silently shipping fewer
    # shorts than the caller expected.
    if len(candidates) < min_kept:
        logger.warning(
            "Clip selection: Gemini returned only %d candidates (requested %d, floor %d). "
            "Output will be capped at %d shorts -- check prompt or transcript length.",
            len(candidates),
            candidate_count,
            min_kept,
            len(candidates),
        )
    elif len(candidates) < candidate_count:
        logger.info(
            "Clip selection: Gemini returned %d of %d requested candidates "
            "(pool still >= floor of %d).",
            len(candidates),
            candidate_count,
            min_kept,
        )
    clips = rank_and_filter_clips(
        candidates,
        threshold=quality_threshold,
        min_kept=min_kept,
        max_kept=max_kept,
    )
    return clips, raw
def _strip_json_code_fence(text: str) -> str:
    """Return ``text`` without a wrapping Markdown code fence (```json ... ```), if any."""
    match = re.fullmatch(r"\s*```[A-Za-z]*\s*(.*?)\s*```\s*", text, flags=re.DOTALL)
    return match.group(1) if match else text


def _parse_clips(raw_json: str) -> list[Clip]:
    """Parse and validate the LLM's JSON response into polished Clip objects.

    Accepts either a bare JSON array of clip objects or an object holding a
    ``"clips"`` array. A reply wrapped entirely in a Markdown code fence (which
    OpenAI-compatible models can emit even in JSON mode) is unwrapped first.

    A model-stated ``duration_sec`` is removed before validation, because the
    clip window is authoritative. It is only compared with
    ``end_time_sec - start_time_sec`` to log a mismatch warning. A stated
    duration that is not numeric is ignored instead of aborting the parse.

    Raises:
        json.JSONDecodeError: if the text is not valid JSON.
        ValueError: if the payload is not a list of clips, or an object with a
            ``"clips"`` list. Item validation errors from ``Clip.model_validate``
            also propagate; presumably these are pydantic ``ValidationError``s.
    """
    data = json.loads(_strip_json_code_fence(raw_json))
    clips_data = data.get("clips", data) if isinstance(data, dict) else data
    if not isinstance(clips_data, list):
        # Iterating a dict here would walk its keys and fail with a confusing error.
        raise ValueError(
            "Expected a JSON list of clips (or an object with a 'clips' list), "
            f"got {type(clips_data).__name__}"
        )
    clips: list[Clip] = []
    for item in clips_data:
        payload = dict(item)
        payload.pop("duration_sec", None)
        clip = _polish_clip_metadata(Clip.model_validate(payload))
        actual_dur = clip.end_time_sec - clip.start_time_sec
        stated_raw = item.get("duration_sec")
        try:
            stated_dur = float(stated_raw) if stated_raw is not None else None
        except (TypeError, ValueError):
            # Diagnostic-only field (e.g. "45s"): never fail the parse over it.
            stated_dur = None
        if stated_dur is not None and abs(actual_dur - stated_dur) > 1.0:
            logger.warning(
                "Clip %s: stated duration %.1fs doesn't match (%.1f-%.1f = %.1f).",
                clip.clip_id, stated_dur,
                clip.start_time_sec, clip.end_time_sec, actual_dur,
            )
        clips.append(clip)
    logger.info("Parsed %d clips from LLM response", len(clips))
    return clips
def save_clips(clips: list[Clip], output_path: Path) -> Path:
    """Write ``clips`` to ``output_path`` as an indented ``ClipPlan`` JSON document.

    The plan's ``source_path`` is left empty. Returns ``output_path`` unchanged.
    """
    clip_plan = ClipPlan(source_path="", clips=list(clips))
    with open(output_path, "w", encoding="utf-8") as handle:
        serialized = clip_plan.model_dump_json(indent=2)
        handle.write(serialized)
    logger.info("Saved %d clips to %s", len(clips), output_path)
    return output_path
def load_clips(clips_path: Path) -> list[Clip]:
    """Read clips back from a JSON file.

    Accepts either an object with a ``"clips"`` list (as written by
    ``save_clips``) or a bare list of clip objects.
    """
    with open(clips_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    has_wrapper = isinstance(payload, dict) and "clips" in payload
    records = payload["clips"] if has_wrapper else payload
    return [Clip.model_validate(record) for record in records]