# MuseTalk / inference.py
# Provenance: Hugging Face Hub upload by "trymonolith" ("Create inference.py", commit 7f36f80)
"""MuseTalk Inference Module
This module provides the core inference functionality for MuseTalk,
enabling audio-driven lip-sync video generation.
"""
import os
import cv2
import torch
import numpy as np
import tempfile
from pathlib import Path
from typing import Optional, Tuple, Union
import subprocess
class MuseTalkInference:
    """MuseTalk inference engine for audio-driven video generation.

    Models are loaded lazily; ``generate_lipsync`` is currently a
    placeholder that visualizes detected face regions rather than running
    the actual lip-sync model.
    """

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        """Initialize MuseTalk inference engine.

        Args:
            device: torch device to use ('cuda' or 'cpu')
        """
        self.device = device
        self.model = None          # lip-sync generation model (not yet loaded)
        self.whisper_model = None  # audio feature model (not yet loaded)
        self.face_detector = None  # face detection model (not yet loaded)
        self.pose_model = None     # pose model (not yet loaded)
        self.initialized = False   # True once load_models() has run

    def load_models(self, progress_callback=None):
        """Load MuseTalk models from HuggingFace Hub.

        Currently a no-op that marks the engine initialized; model weights
        are intended to be loaded lazily during inference.

        Args:
            progress_callback: Optional callback(percent, message) used to
                report loading progress.
        """
        try:
            if progress_callback:
                progress_callback(0, "Loading MuseTalk models...")
            # For now, return success - models will be loaded lazily during inference
            self.initialized = True
            if progress_callback:
                progress_callback(100, "Models loaded successfully")
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def _load_audio(self, audio_path: str):
        """Load audio at 16 kHz, falling back across optional backends.

        Tries librosa (resamples for us), then scipy.io.wavfile (manual
        resample), then soundfile (no resampling performed).

        Args:
            audio_path: Path to audio file

        Returns:
            Tuple of (audio samples as 1-D float ndarray, sample rate)
        """
        try:
            import librosa
            audio, sr = librosa.load(audio_path, sr=16000)
            return audio, sr
        except Exception:
            # Best-effort fallback chain: librosa missing or failed to decode.
            pass
        try:
            import scipy.io.wavfile as wavfile
            sr, audio = wavfile.read(audio_path)
            audio = audio.astype(np.float32)
            # Collapse stereo/multi-channel to mono before resampling.
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            if sr != 16000:
                # Linear-interpolation resample to 16 kHz. (The previous code
                # multiplied sample *values* by the rate ratio, which changes
                # amplitude, not sample rate — a genuine bug.)
                n_out = int(round(len(audio) * 16000 / sr))
                audio = np.interp(
                    np.linspace(0.0, len(audio), n_out, endpoint=False),
                    np.arange(len(audio)),
                    audio,
                )
                sr = 16000
            return audio, sr
        except Exception:
            # Last resort: soundfile. NOTE(review): this path does not
            # resample, so sr may differ from 16000 here.
            import soundfile as sf
            audio, sr = sf.read(audio_path)
            return audio, sr

    def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray:
        """Extract audio features (mel-spectrogram) from an audio file.

        Args:
            audio_path: Path to audio file
            progress_callback: Optional progress callback

        Returns:
            Mel-spectrogram array of shape (n_mels, n_frames)

        Raises:
            Exception: re-raised after logging if loading/feature extraction fails.
        """
        try:
            if progress_callback:
                progress_callback(10, "Extracting audio features...")
            audio, sr = self._load_audio(audio_path)
            # Collapse to mono in case the backend returned multi-channel data.
            audio = np.asarray(audio, dtype=np.float32)
            if audio.ndim > 1:
                audio = audio.mean(axis=1)
            # Peak-normalize to [-1, 1]; epsilon guards silent (all-zero) input.
            audio = audio / (np.max(np.abs(audio)) + 1e-8)
            # Whisper-style mel parameters for 16 kHz audio.
            n_mels = 80
            n_fft = 400
            hop_length = 160
            mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length)
            if progress_callback:
                progress_callback(30, "Audio features extracted")
            return mel_features
        except Exception as e:
            print(f"Error extracting audio features: {e}")
            raise

    def extract_video_frames(self, video_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]:
        """Extract frames from video file.

        NOTE(review): ``fps`` is currently unused — frames are returned at the
        source video's native rate; TODO resample if a fixed rate is required.

        Args:
            video_path: Path to video file
            fps: Target fps for extraction (currently ignored)
            progress_callback: Optional progress callback

        Returns:
            Tuple of (frames list, width, height)

        Raises:
            ValueError: if no frames could be read from the video.
        """
        try:
            if progress_callback:
                progress_callback(10, "Extracting video frames...")
            cap = cv2.VideoCapture(video_path)
            frames = []
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(frame)
            finally:
                # Release the capture handle even if a read raises.
                cap.release()
            if not frames:
                raise ValueError("No frames extracted from video")
            height, width = frames[0].shape[:2]
            if progress_callback:
                progress_callback(30, f"Extracted {len(frames)} frames")
            return frames, width, height
        except Exception as e:
            print(f"Error extracting video frames: {e}")
            raise

    def detect_faces(self, frames: list, progress_callback=None) -> list:
        """Detect faces in video frames using OpenCV's Haar cascade.

        Args:
            frames: List of video frames (BGR ndarrays)
            progress_callback: Optional progress callback

        Returns:
            List of face bounding boxes (x, y, w, h), one per frame. Frames
            with no detection reuse the previous box, or a centered default.
        """
        try:
            if progress_callback:
                progress_callback(40, "Detecting faces in frames...")
            face_detections = []
            cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            face_cascade = cv2.CascadeClassifier(cascade_path)
            for i, frame in enumerate(frames):
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                faces = face_cascade.detectMultiScale(gray, 1.1, 4)
                if len(faces) > 0:
                    # Keep only the largest detection (by area).
                    face = max(faces, key=lambda f: f[2] * f[3])
                    face_detections.append(face)
                elif face_detections:
                    # No detection: carry the previous frame's box forward.
                    face_detections.append(face_detections[-1])
                else:
                    # No detection yet at all: assume a centered half-frame box.
                    h, w = frame.shape[:2]
                    face_detections.append(np.array([w // 4, h // 4, w // 2, h // 2]))
                if (i + 1) % max(1, len(frames) // 10) == 0 and progress_callback:
                    progress_callback(40 + int((i + 1) / len(frames) * 20), f"Detected faces: {i + 1}/{len(frames)}")
            return face_detections
        except Exception as e:
            print(f"Error detecting faces: {e}")
            raise

    def generate_lipsync(self, frames: list, audio_features: np.ndarray, face_detections: list,
                         progress_callback=None) -> list:
        """Generate lip-sync frames.

        Placeholder implementation: copies each frame and draws a green
        rectangle around the detected face region; ``audio_features`` is not
        consumed until the real model inference is wired in.

        Args:
            frames: List of original video frames
            audio_features: Audio feature array (currently unused)
            face_detections: List of face bounding boxes
            progress_callback: Optional progress callback

        Returns:
            List of lip-synced frames
        """
        try:
            if progress_callback:
                progress_callback(60, "Generating lip-sync...")
            lipsync_frames = []
            for i, frame in enumerate(frames):
                output_frame = frame.copy()
                if i < len(face_detections):
                    face = face_detections[i]
                    x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3])
                    # Visualize the detected face region (placeholder for inference).
                    cv2.rectangle(output_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                lipsync_frames.append(output_frame)
                if (i + 1) % max(1, len(frames) // 10) == 0 and progress_callback:
                    progress_callback(60 + int((i + 1) / len(frames) * 20), f"Lip-sync frames: {i + 1}/{len(frames)}")
            return lipsync_frames
        except Exception as e:
            print(f"Error generating lip-sync: {e}")
            raise

    def save_output_video(self, frames: list, output_path: str, fps: int = 25, progress_callback=None) -> str:
        """Save generated frames as video file.

        NOTE(review): audio is not muxed into the output here — the result is
        video-only; combine with the source audio (e.g. via ffmpeg) if needed.

        Args:
            frames: List of output frames
            output_path: Path to save output video
            fps: Frames per second for output video
            progress_callback: Optional progress callback

        Returns:
            Path to saved video file

        Raises:
            ValueError: if ``frames`` is empty.
        """
        try:
            if progress_callback:
                progress_callback(80, "Encoding video...")
            if not frames:
                raise ValueError("No frames to save")
            height, width = frames[0].shape[:2]
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            try:
                for i, frame in enumerate(frames):
                    out.write(frame)
                    if (i + 1) % max(1, len(frames) // 10) == 0 and progress_callback:
                        progress_callback(80 + int((i + 1) / len(frames) * 15), f"Encoding: {i + 1}/{len(frames)}")
            finally:
                # Finalize the container even if a write raises.
                out.release()
            if progress_callback:
                progress_callback(95, "Video encoding complete")
            return output_path
        except Exception as e:
            print(f"Error saving video: {e}")
            raise

    def generate(self, audio_path: str, video_path: str, output_path: str,
                 fps: int = 25, progress_callback=None) -> str:
        """Generate lip-synced video from audio and video.

        Pipeline: load models → audio features → video frames → face
        detection → lip-sync generation → encode output.

        Args:
            audio_path: Path to input audio file
            video_path: Path to input video file
            output_path: Path to save output video
            fps: Target fps for output
            progress_callback: Optional progress callback

        Returns:
            Path to generated video
        """
        try:
            if not self.initialized:
                self.load_models(progress_callback)
            audio_features = self.extract_audio_features(audio_path, progress_callback)
            frames, width, height = self.extract_video_frames(video_path, fps, progress_callback)
            face_detections = self.detect_faces(frames, progress_callback)
            output_frames = self.generate_lipsync(frames, audio_features, face_detections, progress_callback)
            result_path = self.save_output_video(output_frames, output_path, fps, progress_callback)
            if progress_callback:
                progress_callback(100, "Lip-sync generation complete!")
            return result_path
        except Exception as e:
            print(f"Error during generation: {e}")
            raise

    def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int,
                                 n_fft: int, hop_length: int) -> np.ndarray:
        """Compute mel-spectrogram from audio.

        Args:
            audio: 1-D audio signal
            sr: Sample rate
            n_mels: Number of mel bins
            n_fft: FFT window size
            hop_length: Hop length

        Returns:
            Mel-spectrogram array of shape (n_mels, n_frames)
        """
        try:
            import librosa
            mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft,
                                                      hop_length=hop_length, n_mels=n_mels)
            mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            return mel_spec
        except Exception:
            # Fallback when librosa is unavailable: a deterministic all-zero
            # feature array. (Previously random, which made output
            # non-reproducible across runs.)
            n_frames = max(1, len(audio) // hop_length)
            return np.zeros((n_mels, n_frames), dtype=np.float32)