# studiox-reel-cutter/tests/test_generate.py
# rajank18 - b664a24
# pyright: reportMissingImports=false
from pathlib import Path
from starlette.testclient import TestClient
import app
from downloader import YouTubeDownloadError
from utils import ReelOutput, Segment
# Module-level test client wrapping the application object exposed by the
# `app` module; every test below issues its requests through this instance.
client = TestClient(app.app)
def _patch_pipeline_for_success(monkeypatch, tmp_path: Path):
    """Swap the pipeline stage functions on the ``app`` module for fast stubs.

    Each stub returns a small canned result (and, where a file is expected,
    writes a few placeholder bytes) so a request can run through the whole
    flow without doing real media work. ``tmp_path`` is accepted for call-site
    symmetry but is not used by this helper.
    """

    def stub_probe_video(_video_path):
        # Fixed metadata for a two-minute 1080p video.
        metadata = {"duration": 120.0, "width": 1920, "height": 1080}
        return metadata

    def stub_extract_audio(_video_path, audio_dir):
        # Write a placeholder file into the directory the caller supplies.
        wav_path = audio_dir / "audio.wav"
        wav_path.write_bytes(b"audio")
        return wav_path

    def stub_transcribe_audio(_audio_path, progress_cb=None):
        # Report completion once when a progress callback was provided.
        if progress_cb:
            progress_cb("transcribing", 100)
        transcript = {"words": [], "sentences": [], "chapters": []}
        return transcript

    def stub_detect_highlights(**_kwargs):
        # A single highlight covering 5s-20s, regardless of the arguments.
        highlight = Segment(index=1, start=5.0, end=20.0, reason="hook", score=9.0)
        return [highlight]

    def stub_process_all_segments(**kwargs):
        # Produce one placeholder reel file for the first requested segment.
        output_file = kwargs["reels_dir"] / "reel_01_5s-20s.mp4"
        output_file.write_bytes(b"mp4")
        first_segment = kwargs["segments"][0]
        reel = ReelOutput(
            index=1,
            path=output_file,
            segment=first_segment,
            file_size=output_file.stat().st_size,
        )
        return [reel]

    def stub_package_reels(_reel_outputs, output_dir):
        # The "archive" is just the bytes b"zip"; tests compare against them.
        archive = output_dir / "StudioX_Reels_test.zip"
        archive.write_bytes(b"zip")
        return archive

    replacements = (
        ("probe_video", stub_probe_video),
        ("extract_audio", stub_extract_audio),
        ("transcribe_audio", stub_transcribe_audio),
        ("detect_highlights", stub_detect_highlights),
        ("process_all_segments", stub_process_all_segments),
        ("package_reels", stub_package_reels),
    )
    for attr_name, stub in replacements:
        monkeypatch.setattr(app, attr_name, stub)
def test_generate_with_yt_url_success(monkeypatch, tmp_path):
    """POSTing a ``yt_url`` returns the packaged zip with the job id echoed.

    The downloader is replaced by a stub that writes a placeholder video into
    ``tmp_path``; the remaining pipeline stages are stubbed by
    ``_patch_pipeline_for_success``.
    """
    _patch_pipeline_for_success(monkeypatch, tmp_path)

    def stub_download(**_kwargs):
        downloaded = tmp_path / "video.mp4"
        downloaded.write_bytes(b"video")
        return downloaded

    monkeypatch.setattr(app, "download_youtube_video", stub_download)

    form = {"yt_url": "https://youtu.be/TT4OhtH3CFY", "job_id": "job-yt-success"}
    response = client.post("/generate", data=form)

    assert response.status_code == 200
    response_headers = response.headers
    assert response_headers.get("x-job-id") == "job-yt-success"
    assert response_headers.get("content-type", "").startswith("application/zip")
    # b"zip" is the payload written by the stubbed package_reels stage.
    assert response.content == b"zip"
def test_generate_with_yt_url_extraction_failure(monkeypatch):
    """A ``YouTubeDownloadError`` from the downloader yields a 500 response.

    The JSON body is expected to carry ``HF_UPSTREAM_FAILURE`` as the error
    code and to echo the submitted job id under ``detail``.
    """

    def failing_download(**_kwargs):
        raise YouTubeDownloadError("Unable to fetch video from upstream")

    monkeypatch.setattr(app, "download_youtube_video", failing_download)

    form = {"yt_url": "https://youtu.be/TT4OhtH3CFY", "job_id": "job-yt-fail"}
    response = client.post("/generate", data=form)

    assert response.status_code == 500
    detail = response.json()["detail"]
    assert detail["error"] == "HF_UPSTREAM_FAILURE"
    assert detail["job_id"] == "job-yt-fail"
def test_generate_with_video_file_success(monkeypatch, tmp_path):
    """Uploading a ``video_file`` returns the zip without touching the downloader.

    The downloader is replaced by a stub that raises ``AssertionError`` if it
    is ever invoked, so any call to it in the upload flow surfaces as a failure.
    """
    _patch_pipeline_for_success(monkeypatch, tmp_path)

    def forbidden_download(**_kwargs):
        raise AssertionError("download_youtube_video should not be called for video_file flow")

    monkeypatch.setattr(app, "download_youtube_video", forbidden_download)

    upload = {"video_file": ("input.mp4", b"video-bytes", "video/mp4")}
    response = client.post(
        "/generate",
        data={"job_id": "job-upload-success"},
        files=upload,
    )

    assert response.status_code == 200
    assert response.headers.get("x-job-id") == "job-upload-success"
    # b"zip" is the payload written by the stubbed package_reels stage.
    assert response.content == b"zip"
def test_generate_invalid_input_combinations_return_422():
    """Requests with neither or both of ``yt_url``/``video_file`` get a 422."""
    # Case 1: no video source at all, only a job id.
    missing_source = client.post("/generate", data={"job_id": "job-none"})
    assert missing_source.status_code == 422

    # Case 2: a URL and an uploaded file in the same request.
    form = {"yt_url": "https://youtu.be/TT4OhtH3CFY", "job_id": "job-both"}
    upload = {"video_file": ("input.mp4", b"video-bytes", "video/mp4")}
    conflicting_sources = client.post("/generate", data=form, files=upload)
    assert conflicting_sources.status_code == 422