# ElevenClip-AI — backend/app/services/subtitles.py
# Author: JakgritB
# feat(editor): subtitle-first editor + AI subtitle pipeline (commit 89e1dc4)
import re
from pathlib import Path
from app.models.schemas import TranscriptSegment
def seconds_to_srt_time(value: float) -> str:
    """Format *value* (seconds) as an SRT timestamp ``HH:MM:SS,mmm``."""
    total_ms = int(round(value * 1000))
    hh, rest = divmod(total_ms, 3_600_000)
    mm, rest = divmod(rest, 60_000)
    ss, ms = divmod(rest, 1000)
    return f"{hh:02}:{mm:02}:{ss:02},{ms:03}"
def write_srt(
    path: Path, clip_start: float, clip_end: float, segments: list[TranscriptSegment]
) -> list[dict]:
    """Write an SRT file for the transcript segments overlapping the clip window.

    Segment times are shifted to be relative to ``clip_start``; segments fully
    outside [clip_start, clip_end] are skipped. If nothing overlaps, a single
    empty 3-second placeholder cue is written so the file is never empty.
    Returns the cue dicts that were written.
    """
    collected: list[dict] = []
    clip_length = clip_end - clip_start
    for segment in segments:
        # Ignore segments entirely before or after the clip window.
        if segment.end_seconds < clip_start or segment.start_seconds > clip_end:
            continue
        local_start = max(0.0, segment.start_seconds - clip_start)
        local_end = min(clip_length, segment.end_seconds - clip_start)
        # Guarantee at least 1.2s of window for the splitter to work with.
        collected.extend(
            split_timed_caption(segment.text, local_start, max(local_end, local_start + 1.2))
        )
    if not collected:
        collected = [{"start_seconds": 0.0, "end_seconds": 3.0, "text": ""}]
    lines: list[str] = []
    for number, cue in enumerate(collected, start=1):
        lines.extend(_srt_row(number, cue["start_seconds"], cue["end_seconds"], cue["text"]))
    path.write_text("\n".join(lines), encoding="utf-8")
    return collected
def write_single_caption_srt(path: Path, duration: float, text: str) -> list[dict]:
    """Write an SRT file spreading *text* across [0, duration] seconds.

    The duration is clamped to at least 1 second. Blank text produces a single
    empty cue (capped at 3s) so the file is never empty. Returns the cues.
    """
    clamped = max(duration, 1.0)
    cues = split_timed_caption(text, 0.0, clamped)
    if not cues:
        # Placeholder cue for blank input keeps downstream SRT consumers happy.
        cues = [{"start_seconds": 0.0, "end_seconds": min(clamped, 3.0), "text": ""}]
    rows: list[str] = []
    for number, cue in enumerate(cues, start=1):
        rows.extend(_srt_row(number, cue["start_seconds"], cue["end_seconds"], cue["text"]))
    path.write_text("\n".join(rows), encoding="utf-8")
    return cues
def write_srt_from_cues(path: Path, cues: list) -> list[dict]:
    """Write SRT using user-supplied per-cue timing (preferred over auto-distribution).

    Accepts list of objects with .start_seconds / .end_seconds / .text attributes
    (Pydantic SubtitleCue) or dicts with the same keys. Cues with blank text are
    dropped; a cue whose end is not after its start gets a 1-second duration.
    If nothing remains, a single empty 3-second placeholder cue is written.
    Returns the cue dicts (times rounded to ms) that were written.
    """

    def _field(cue: object, key: str, default: object) -> object:
        # BUG FIX: the old `getattr(cue, key, None) or cue.get(key, ...)` pattern
        # fell through to `.get` whenever the attribute value was falsy (0.0
        # start time, empty text), raising AttributeError for non-dict cues.
        # Dispatch on the container type instead of on truthiness.
        if isinstance(cue, dict):
            return cue.get(key, default)
        return getattr(cue, key, default)

    rows: list[str] = []
    out_cues: list[dict] = []
    index = 1
    for cue in cues:
        # `or` fallbacks here only guard against an explicit None value.
        start = float(_field(cue, "start_seconds", 0.0) or 0.0)
        end = float(_field(cue, "end_seconds", 0.0) or 0.0)
        text = str(_field(cue, "text", "") or "")
        if end <= start:
            # Degenerate timing: give the cue a minimal 1-second duration.
            end = start + 1.0
        clean_text = text.strip()
        if not clean_text:
            continue
        rows.extend(_srt_row(index, start, end, clean_text))
        out_cues.append({"start_seconds": round(start, 3), "end_seconds": round(end, 3), "text": clean_text})
        index += 1
    if not rows:
        # Never write an empty SRT file; emit one blank placeholder cue.
        out_cues = [{"start_seconds": 0.0, "end_seconds": 3.0, "text": ""}]
        rows = _srt_row(1, 0.0, 3.0, "")
    path.write_text("\n".join(rows), encoding="utf-8")
    return out_cues
def split_timed_caption(text: str, start: float, end: float) -> list[dict]:
    """Split *text* into caption cues distributed across the [start, end] window.

    Phrases come from split_caption_text(); if there are more phrases than fit
    at ~1.2s each, they are merged down with _merge_phrases(). Each returned
    dict has "start_seconds" / "end_seconds" (rounded to ms) and "text".
    Returns [] when the text is blank.
    """
    phrases = split_caption_text(text)
    if not phrases:
        return []
    # Guarantee at least 1.2s of total screen time to distribute.
    total_duration = max(end - start, 1.2)
    # Cap cue count so every cue can still get the ~1.2s minimum duration.
    max_cues = max(1, int(total_duration / 1.2))
    if len(phrases) > max_cues:
        phrases = _merge_phrases(phrases, max_cues)
    # Nominal per-cue duration, clamped to the [1.2, 4.0] second range.
    cue_duration = min(4.0, max(1.2, total_duration / len(phrases)))
    cues: list[dict] = []
    cursor = start
    for index, phrase in enumerate(phrases):
        remaining = len(phrases) - index
        # Reserve at least 1.2s for each cue still to come after this one.
        max_end = end - ((remaining - 1) * 1.2)
        cue_end = min(max_end, cursor + cue_duration)
        cue_end = max(cue_end, cursor + 1.2)
        if index == len(phrases) - 1:
            # The final cue always runs exactly to the end of the window.
            cue_end = end
        cues.append(
            {
                "start_seconds": round(cursor, 3),
                # 0.8s floor protects the last cue when `end` is very close to cursor.
                "end_seconds": round(max(cue_end, cursor + 0.8), 3),
                "text": phrase,
            }
        )
        cursor = cue_end
    return cues
def split_caption_text(text: str, max_chars: int = 42, max_words: int = 7) -> list[str]:
    """Break *text* into caption-sized phrases.

    A phrase is closed when adding the next word would exceed *max_chars*,
    when it already holds *max_words* words, or right after a word ending in
    clause punctuation. A single oversized token is hard-wrapped into
    *max_chars*-character slices. Blank text yields [].
    """
    normalized = re.sub(r"\s+", " ", text.strip())
    if not normalized:
        return []
    words = normalized.split()
    if len(words) <= 1:
        # One giant token: chunk it purely by character count.
        return [normalized[pos : pos + max_chars] for pos in range(0, len(normalized), max_chars)]
    chunks: list[str] = []
    buffer: list[str] = []
    for word in words:
        if buffer:
            tentative = " ".join(buffer + [word])
            ends_clause = re.search(r"[,.!?;:]$", buffer[-1]) is not None
            if len(tentative) > max_chars or len(buffer) >= max_words or ends_clause:
                chunks.append(" ".join(buffer))
                buffer = [word]
                continue
        buffer.append(word)
    if buffer:
        chunks.append(" ".join(buffer))
    return chunks
def _merge_phrases(phrases: list[str], target_count: int) -> list[str]:
if target_count <= 1:
return [" ".join(phrases)]
merged: list[str] = []
bucket_size = len(phrases) / target_count
for index in range(target_count):
start = round(index * bucket_size)
end = round((index + 1) * bucket_size)
merged.append(" ".join(phrases[start:end]).strip())
return [phrase for phrase in merged if phrase]
def _srt_row(index: int, start: float, end: float, text: str) -> list[str]:
    """Render one SRT cue as its four lines: counter, timing, text, blank separator."""
    timing = f"{seconds_to_srt_time(start)} --> {seconds_to_srt_time(end)}"
    return [str(index), timing, text.strip(), ""]