# ElevenClip-AI: backend/app/services/multimodal.py
# Qwen2-VL visual analysis with ROCm support.
import os
import subprocess
import tempfile
from app.core.config import Settings
from app.models.schemas import ClipCandidate
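# Canned (note, score) pairs cycled through in demo mode so the pipeline can be
# exercised without a GPU or a model download.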
_DEMO_VISUALS = [
("High-energy scene with strong visual contrast and clear subject focus.", 88.0),
("Close-up with expressive reactions — excellent engagement framing.", 92.0),
("Dynamic motion sequence; subject well-lit with clean background.", 84.0),
("Text-overlay-friendly composition with natural colour grading.", 79.0),
("Wide establishing shot; strong emotional beat in middle frames.", 81.0),
]
class QwenVisualAnalyzer:
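    """Enrich clip candidates with a visual note and score.

    In demo mode, canned annotations from _DEMO_VISUALS are attached; otherwise
    frames are sampled with ffmpeg and described/scored by Qwen2-VL.
    """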
def __init__(self, settings: Settings) -> None:
self.settings = settings
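        # The model and processor are loaded lazily on the first real analysis.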
self._model = None
self._processor = None
def enrich(self, video_path: str, clips: list[ClipCandidate]) -> list[ClipCandidate]:
if self.settings.demo_mode:
return self._demo_enrich(clips)
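        # Enrichment is best-effort: if model loading or inference fails outright,
        # the clips are returned unchanged.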
try:
return self._qwen_enrich(video_path, clips)
except Exception:
return clips
# ------------------------------------------------------------------
# Demo mode
# ------------------------------------------------------------------
def _demo_enrich(self, clips: list[ClipCandidate]) -> list[ClipCandidate]:
enriched = []
for i, clip in enumerate(clips):
note, vscore = _DEMO_VISUALS[i % len(_DEMO_VISUALS)]
enriched.append(
clip.model_copy(
update={
"metadata": {
**clip.metadata,
"visual_model": "demo",
"visual_note": note,
"visual_score": vscore,
}
}
)
)
return enriched
# ------------------------------------------------------------------
# Production mode — Qwen2-VL on ROCm
# ------------------------------------------------------------------
def _load_model(self) -> None:
try:
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
except ImportError as exc:
raise RuntimeError("transformers + ROCm PyTorch are required for Qwen2-VL") from exc
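        # Fall back to bfloat16 if the configured dtype name is not a torch attribute.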
dtype = getattr(torch, self.settings.preferred_torch_dtype, torch.bfloat16)
self._model = Qwen2VLForConditionalGeneration.from_pretrained(
self.settings.qwen_vl_model_id,
torch_dtype=dtype,
device_map="auto",
trust_remote_code=True,
token=self.settings.hf_token or None,
)
self._processor = AutoProcessor.from_pretrained(
self.settings.qwen_vl_model_id,
trust_remote_code=True,
token=self.settings.hf_token or None,
)
def _qwen_enrich(self, video_path: str, clips: list[ClipCandidate]) -> list[ClipCandidate]:
if self._model is None:
self._load_model()
enriched = []
for clip in clips:
try:
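                # Sample a handful of frames across the clip window; if none could be
                # extracted, pass the clip through untouched.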
frames = _sample_frames(video_path, clip.start_seconds, clip.end_seconds, self.settings.ffmpeg_binary)
if not frames:
enriched.append(clip)
continue
note, vscore = self._analyze(frames, clip.title)
enriched.append(
clip.model_copy(
update={
"metadata": {
**clip.metadata,
"visual_model": self.settings.qwen_vl_model_id,
"visual_note": note,
"visual_score": vscore,
}
}
)
)
except Exception:
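                # Keep the clip, but mark that visual analysis failed for it.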
enriched.append(
clip.model_copy(
update={
"metadata": {
**clip.metadata,
"visual_model": self.settings.qwen_vl_model_id,
"visual_status": "analysis_failed",
}
}
)
)
return enriched
def _analyze(self, frames: list, title: str) -> tuple[str, float]:
import torch
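        # One user turn: every sampled frame followed by the scoring instruction.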
messages = [
{
"role": "user",
"content": [
*[{"type": "image", "image": f} for f in frames],
{
"type": "text",
"text": (
f'These frames are from a clip titled "{title}". '
"Describe the visual quality and short-form engagement potential in 1-2 sentences. "
"Then output exactly: SCORE: <integer 0-100>"
),
},
],
}
]
text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = self._processor(text=[text], images=frames, return_tensors="pt").to(self._model.device)
with torch.no_grad():
ids = self._model.generate(**inputs, max_new_tokens=140)
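        # Decode only the newly generated tokens, dropping the echoed prompt.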
reply = self._processor.batch_decode(
ids[:, inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)[0].strip()
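        # Parse the trailing "SCORE: <n>" line; keep a neutral default if it is
        # missing or malformed.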
vscore = 75.0
for line in reversed(reply.splitlines()):
upper = line.strip().upper()
if upper.startswith("SCORE:"):
try:
vscore = float(upper.split(":", 1)[1].strip())
except ValueError:
pass
break
        # Strip the score line from the note, matching "SCORE:" case-insensitively
        # to stay consistent with the parsing loop above.
        idx = reply.upper().find("SCORE:")
        note = (reply[:idx].strip() if idx != -1 else reply) or reply
        return note, min(max(vscore, 0.0), 100.0)
# ------------------------------------------------------------------
# Frame extraction helper
# ------------------------------------------------------------------
def _sample_frames(video_path: str, start: float, end: float, ffmpeg: str, n: int = 4) -> list:
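    """Extract up to n evenly spaced frames from [start, end] with ffmpeg.

    Returns PIL images converted to RGB; returns an empty list when Pillow is
    unavailable or no frame could be extracted.
    """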
try:
from PIL import Image
except ImportError:
return []
duration = max(end - start, 1.0)
timestamps = [start + duration * i / max(n - 1, 1) for i in range(n)]
frames = []
tmp_files = []
try:
for ts in timestamps:
fd, tmp = tempfile.mkstemp(suffix=".jpg")
os.close(fd)
tmp_files.append(tmp)
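            # Seek before the input (-ss before -i) for a fast keyframe seek, then
            # grab a single high-quality JPEG frame at the timestamp.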
result = subprocess.run(
[
ffmpeg,
"-ss", f"{ts:.3f}",
"-i", video_path,
"-vframes", "1",
"-q:v", "2",
"-y", tmp,
],
capture_output=True,
timeout=15,
)
if result.returncode == 0:
try:
frames.append(Image.open(tmp).convert("RGB"))
except Exception:
pass
finally:
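        # Always clean up the temporary JPEGs, even if extraction failed part-way.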
for tmp in tmp_files:
try:
os.unlink(tmp)
except OSError:
pass
return frames
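# ------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the service wiring): shows how
# the analyzer might be driven directly. Settings fields mirror the ones read
# above; the video path and candidate list are made up for the example.
# ------------------------------------------------------------------
# if __name__ == "__main__":
#     settings = Settings()  # assumes demo_mode, ffmpeg_binary, etc. resolve from env/defaults
#     analyzer = QwenVisualAnalyzer(settings)
#     candidates = [...]  # ClipCandidate objects from the upstream scoring stage
#     for clip in analyzer.enrich("example.mp4", candidates):
#         print(clip.metadata.get("visual_note"), clip.metadata.get("visual_score"))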