import re from pathlib import Path from app.models.schemas import TranscriptSegment def seconds_to_srt_time(value: float) -> str: millis = int(round(value * 1000)) hours, remainder = divmod(millis, 3_600_000) minutes, remainder = divmod(remainder, 60_000) seconds, millis = divmod(remainder, 1000) return f"{hours:02}:{minutes:02}:{seconds:02},{millis:03}" def write_srt( path: Path, clip_start: float, clip_end: float, segments: list[TranscriptSegment] ) -> list[dict]: cues: list[dict] = [] rows: list[str] = [] index = 1 for segment in segments: if segment.end_seconds < clip_start or segment.start_seconds > clip_end: continue start = max(0.0, segment.start_seconds - clip_start) end = min(clip_end - clip_start, segment.end_seconds - clip_start) for cue in split_timed_caption(segment.text, start, max(end, start + 1.2)): rows.extend(_srt_row(index, cue["start_seconds"], cue["end_seconds"], cue["text"])) cues.append(cue) index += 1 if not rows: cues = [{"start_seconds": 0.0, "end_seconds": 3.0, "text": ""}] rows = _srt_row(1, 0.0, 3.0, "") path.write_text("\n".join(rows), encoding="utf-8") return cues def write_single_caption_srt(path: Path, duration: float, text: str) -> list[dict]: safe_duration = max(duration, 1.0) cues = split_timed_caption(text, 0.0, safe_duration) rows: list[str] = [] for index, cue in enumerate(cues, start=1): rows.extend(_srt_row(index, cue["start_seconds"], cue["end_seconds"], cue["text"])) if not rows: cues = [{"start_seconds": 0.0, "end_seconds": min(safe_duration, 3.0), "text": ""}] rows = _srt_row(1, cues[0]["start_seconds"], cues[0]["end_seconds"], "") path.write_text("\n".join(rows), encoding="utf-8") return cues def write_srt_from_cues(path: Path, cues: list) -> list[dict]: """Write SRT using user-supplied per-cue timing (preferred over auto-distribution). Accepts list of objects with .start_seconds / .end_seconds / .text attributes (Pydantic SubtitleCue) or dicts with the same keys. """ rows: list[str] = [] out_cues: list[dict] = [] index = 1 for cue in cues: start = float(getattr(cue, "start_seconds", None) or cue.get("start_seconds", 0)) end = float(getattr(cue, "end_seconds", None) or cue.get("end_seconds", 0)) text = str(getattr(cue, "text", None) or cue.get("text", "")) if end <= start: end = start + 1.0 clean_text = text.strip() if not clean_text: continue rows.extend(_srt_row(index, start, end, clean_text)) out_cues.append({"start_seconds": round(start, 3), "end_seconds": round(end, 3), "text": clean_text}) index += 1 if not rows: out_cues = [{"start_seconds": 0.0, "end_seconds": 3.0, "text": ""}] rows = _srt_row(1, 0.0, 3.0, "") path.write_text("\n".join(rows), encoding="utf-8") return out_cues def split_timed_caption(text: str, start: float, end: float) -> list[dict]: phrases = split_caption_text(text) if not phrases: return [] total_duration = max(end - start, 1.2) max_cues = max(1, int(total_duration / 1.2)) if len(phrases) > max_cues: phrases = _merge_phrases(phrases, max_cues) cue_duration = min(4.0, max(1.2, total_duration / len(phrases))) cues: list[dict] = [] cursor = start for index, phrase in enumerate(phrases): remaining = len(phrases) - index max_end = end - ((remaining - 1) * 1.2) cue_end = min(max_end, cursor + cue_duration) cue_end = max(cue_end, cursor + 1.2) if index == len(phrases) - 1: cue_end = end cues.append( { "start_seconds": round(cursor, 3), "end_seconds": round(max(cue_end, cursor + 0.8), 3), "text": phrase, } ) cursor = cue_end return cues def split_caption_text(text: str, max_chars: int = 42, max_words: int = 7) -> list[str]: clean = re.sub(r"\s+", " ", text.strip()) if not clean: return [] words = clean.split() if len(words) <= 1: return [clean[index : index + max_chars] for index in range(0, len(clean), max_chars)] phrases: list[str] = [] current: list[str] = [] for word in words: candidate = " ".join([*current, word]).strip() punctuation_break = bool(current and re.search(r"[,.!?;:]$", current[-1])) if current and (len(candidate) > max_chars or len(current) >= max_words or punctuation_break): phrases.append(" ".join(current)) current = [word] else: current.append(word) if current: phrases.append(" ".join(current)) return phrases def _merge_phrases(phrases: list[str], target_count: int) -> list[str]: if target_count <= 1: return [" ".join(phrases)] merged: list[str] = [] bucket_size = len(phrases) / target_count for index in range(target_count): start = round(index * bucket_size) end = round((index + 1) * bucket_size) merged.append(" ".join(phrases[start:end]).strip()) return [phrase for phrase in merged if phrase] def _srt_row(index: int, start: float, end: float, text: str) -> list[str]: return [ str(index), f"{seconds_to_srt_time(start)} --> {seconds_to_srt_time(end)}", text.strip(), "", ]