"""MuseTalk Inference Module Refactored for Long-Form Generation (5-10 mins) using Memory-Efficient Streaming, Looping, and Audio Muxing. """ import os import cv2 import torch import numpy as np import tempfile import librosa import mimetypes import subprocess from pathlib import Path from typing import Optional, Tuple, Union class MuseTalkInference: """MuseTalk inference engine for audio-driven video generation.""" def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"): self.device = device self.model = None self.whisper_model = None self.face_detector = None self.pose_model = None self.initialized = False def load_models(self, progress_callback=None): """Load MuseTalk models from HuggingFace Hub.""" try: if progress_callback: progress_callback(0, "Loading MuseTalk models...") # Placeholder: Initialize your actual PyTorch models here self.initialized = True if progress_callback: progress_callback(5, "Models loaded successfully") except Exception as e: print(f"Error loading models: {e}") raise def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray: """Extract audio features using Whisper/Mel-Spectrogram.""" try: if progress_callback: progress_callback(10, "Extracting audio features...") try: audio, sr = librosa.load(audio_path, sr=16000) except: try: import scipy.io.wavfile as wavfile sr, audio = wavfile.read(audio_path) if sr != 16000: ratio = 16000 / sr audio = (audio * ratio).astype(np.int16) except: import soundfile as sf audio, sr = sf.read(audio_path) audio = audio.astype(np.float32) audio = audio / (np.max(np.abs(audio)) + 1e-8) n_mels = 80 n_fft = 400 hop_length = 160 mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length) if progress_callback: progress_callback(15, "Audio features extracted") return mel_features except Exception as e: print(f"Error extracting audio features: {e}") raise def extract_source_frames(self, file_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]: """Extracts frames from a short video or loads a single image to memory.""" try: if progress_callback: progress_callback(20, "Reading source image/video...") mime_type, _ = mimetypes.guess_type(file_path) frames = [] # Handle Single Image Input if mime_type and mime_type.startswith('image'): frame = cv2.imread(file_path) if frame is None: raise ValueError("Failed to read image") frames.append(frame) # Handle Short Video Input else: cap = cv2.VideoCapture(file_path) while True: ret, frame = cap.read() if not ret: break frames.append(frame) cap.release() if not frames: raise ValueError("No frames extracted from source file") height, width = frames[0].shape[:2] return frames, width, height except Exception as e: print(f"Error extracting video frames: {e}") raise def detect_faces(self, frames: list, progress_callback=None) -> list: """Detect faces ONLY on the short source clip to save compute.""" try: if progress_callback: progress_callback(25, "Detecting face in source media...") face_detections = [] cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml' face_cascade = cv2.CascadeClassifier(cascade_path) for i, frame in enumerate(frames): gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) faces = face_cascade.detectMultiScale(gray, 1.1, 4) if len(faces) > 0: # Take the LARGEST face by area (width * height) face = max(faces, key=lambda f: f[2] * f[3]) face_detections.append(face) else: if face_detections: face_detections.append(face_detections[-1]) else: h, w = frame.shape[:2] face_detections.append(np.array([w//4, h//4, w//2, h//2])) return face_detections except Exception as e: print(f"Error detecting faces: {e}") raise def generate(self, audio_path: str, video_path: str, output_path: str, fps: int = 25, progress_callback=None) -> str: """ Memory-efficient generator for long videos. Loops short inputs to match 5-10 minute audio. """ try: if not self.initialized: self.load_models(progress_callback) # 1. Extract audio features audio_features = self.extract_audio_features(audio_path, progress_callback) # 2. Determine Total Output Frames based on Audio Length audio_data, sr = librosa.load(audio_path, sr=16000) audio_duration = len(audio_data) / sr total_target_frames = int(audio_duration * fps) if total_target_frames == 0: raise ValueError("Audio file is too short or invalid.") # 3. Extract Source Clip/Image (Only loads short clip into memory) source_frames, width, height = self.extract_source_frames(video_path, fps, progress_callback) # 4. Detect faces on the short source clip (Pre-cached) source_faces = self.detect_faces(source_frames, progress_callback) # 5. Stream Process (Write directly to file to avoid OOM crash) temp_silent_video = output_path.replace('.mp4', '_silent.mp4') fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(temp_silent_video, fourcc, fps, (width, height)) if progress_callback: progress_callback(30, f"Generating {total_target_frames} frames (Streaming)...") for i in range(total_target_frames): # LOOPING LOGIC: Loop the short video or image continuously src_idx = i % len(source_frames) frame = source_frames[src_idx].copy() face = source_faces[src_idx] # --- START AI LIP-SYNC INFERENCE --- # NOTE: Put your actual AI model generation code here. # Right now, this just draws a box around the face. # Example: frame = self.model.infer(frame, face, audio_features[:, i]) x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3]) cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # --- END AI LIP-SYNC INFERENCE --- # Write directly to disk (Saves 30GB+ of RAM for 10 min videos) out.write(frame) # Report progress periodically if (i + 1) % max(1, total_target_frames // 20) == 0 and progress_callback: progress_pct = 30 + int((i / total_target_frames) * 60) progress_callback(progress_pct, f"Generated frames: {i + 1}/{total_target_frames}") out.release() # 6. MUX AUDIO (Combine the generated silent video with original audio) if progress_callback: progress_callback(95, "Merging final audio and video...") try: cmd = [ "ffmpeg", "-y", "-i", temp_silent_video, # The generated silent video "-i", audio_path, # The original audio "-c:v", "libx264", # Re-encode video for broad web compatibility "-c:a", "aac", # Re-encode audio to AAC "-map", "0:v:0", "-map", "1:a:0", "-shortest", # Cut at the shortest stream output_path ] subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) # Cleanup temp file if os.path.exists(temp_silent_video): os.remove(temp_silent_video) except subprocess.CalledProcessError as e: print(f"FFMPEG Error: {e.stderr}") # Fallback to silent video if FFMPEG fails os.rename(temp_silent_video, output_path) if progress_callback: progress_callback(100, "Generation Complete!") return output_path except Exception as e: print(f"Error during generation: {e}") raise def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int, n_fft: int, hop_length: int) -> np.ndarray: """Compute mel-spectrogram from audio.""" try: import librosa mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels) mel_spec = librosa.power_to_db(mel_spec, ref=np.max) return mel_spec except: n_frames = len(audio) // hop_length return np.random.randn(n_mels, n_frames)