# NOTE(review): the original paste began with hosting-UI residue ("Spaces:",
# "Runtime error" x2) that is not Python and broke parsing; preserved here as
# a comment. The runtime error it reported is consistent with the missing
# route decorator and mangled formatting fixed below.
# Module setup: imports grouped stdlib / third-party / project-local, then the
# FastAPI application object. (Original paste wrapped every line in markdown
# pipes, which made the file unparseable — formatting reconstructed.)
import shutil
from pathlib import Path

import torch
from fastapi import FastAPI, File, Query, UploadFile
from fastapi.responses import FileResponse

# Project-local modules — must be importable from this package.
from diffusion import Diffusion
from utils import get_audio_emb, get_id_frame, save_video

app = FastAPI()
# NOTE(review): the original had no route decorator, so FastAPI never exposed
# this handler (the "Runtime error" in the paste header is consistent with
# that). Confirm the intended path; "/generate-video" is a guess.
@app.post("/generate-video")
async def generate_video(
    id_frame_file: UploadFile = File(...),
    audio_file: UploadFile = File(...),
    gpu: bool = Query(True, description="Use GPU if available"),
    id_frame_random: bool = Query(False, description="Pick id_frame randomly from video"),
    inference_steps: int = Query(100, description="Number of inference diffusion steps"),
    output: str = Query(
        "/Users/a/Documents/Automations/git talking heads/output_video.mp4",
        description="Path to save the output video",
    ),
):
    """Generate a talking-head video from an identity frame and an audio clip.

    Loads the jit-scripted UNet and wraps it in ``Diffusion``, embeds the
    uploaded audio, samples a video conditioned on the uploaded identity
    frame, writes the result to ``output`` and returns it as a file response.

    NOTE(review): the model checkpoint is re-loaded from disk on every request,
    which is very slow — consider loading once at application startup. The
    checkpoint paths are hard-coded absolute paths; move them to configuration.
    The temp-file names are fixed, so concurrent requests would clobber each
    other — confirm whether this service ever handles parallel requests.
    """
    device = 'cuda' if gpu and torch.cuda.is_available() else 'cpu'
    print('Loading model...')
    unet = torch.jit.load("/Users/a/Documents/Automations/git talking heads/checkpoints/crema_script.pt")
    diffusion_args = {
        "in_channels": 3,
        "image_size": 128,
        "out_channels": 6,
        "n_timesteps": 1000,
    }
    diffusion = Diffusion(unet, device, **diffusion_args).to(device)
    # Re-space the diffusion schedule to the requested number of steps.
    diffusion.space(inference_steps)

    # Persist the uploads to disk because the downstream loaders take paths.
    id_frame_path = Path("temp_id_frame.jpg")
    audio_path = Path("temp_audio.mp3")
    try:
        with id_frame_path.open("wb") as buffer:
            shutil.copyfileobj(id_frame_file.file, buffer)
        with audio_path.open("wb") as buffer:
            shutil.copyfileobj(audio_file.file, buffer)

        id_frame = get_id_frame(
            str(id_frame_path),
            random=id_frame_random,
            resize=diffusion_args["image_size"],
        ).to(device)
        audio, audio_emb = get_audio_emb(
            str(audio_path),
            "/Users/a/Documents/Automations/git talking heads/checkpoints/audio_encoder.pt",
            device,
        )

        unet_args = {
            "n_audio_motion_embs": 2,
            "n_motion_frames": 2,
            "motion_channels": 3,
        }
        samples = diffusion.sample(id_frame, audio_emb.unsqueeze(0), **unet_args)
        save_video(output, samples, audio=audio, fps=25, audio_rate=16000)
        print(f'Results saved at {output}')
        return FileResponse(output)
    finally:
        # Bug fix: the original leaked the temp files on every request (and on
        # any failure above). `output` is a different file and is not touched.
        id_frame_path.unlink(missing_ok=True)
        audio_path.unlink(missing_ok=True)