import os
import hmac
import spaces
import torch
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers import FlowMatchEulerDiscreteScheduler
import gradio as gr
import tempfile
import numpy as np
import imageio.v2 as imageio
import shutil
import time
from PIL import Image
import random
import gc
from datetime import datetime
from huggingface_hub import CommitOperationAdd, HfApi
from uuid import uuid4
from modify_model.modify_wan import set_sage_attn_wan
from sageattention import sageattn
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
import warnings
import aoti

os.environ["TOKENIZERS_PARALLELISM"] = "true"
warnings.filterwarnings("ignore")

# --- Secrets / configuration pulled from the environment -------------------
# HF token used for archiving generations to a dataset (uploads are disabled
# below when it is missing).
key = os.environ.get("DS_APIKEY")
diffusers_apikey = os.environ.get("DIFFUSERS_APIKEY")
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# Dataset repo id that receives generated videos/images/metadata.
DS_ID = os.environ.get("DS_ID")
# Token for loading the private LoRA repos further down.
PRIVATE_MODEL_KEY = os.environ.get("PRIVATE_MODEL_KEY")
# Only build an API client when both the token and the target dataset exist;
# to_ds() checks `hf_api is None` and becomes a no-op otherwise.
hf_api = HfApi(token=key) if key and DS_ID else None
# Passcode that unlocks generation in the UI (see verify_passcode()).
GENERATION_SECRET = os.environ.get("GENERATION_SECRET")
PUBLIC_ENABLED = os.environ.get("PUBLIC_ENABLED")
# Env vars arrive as strings (or None); only a case-insensitive "true" enables
# passcode-free public generation.
PUBLIC_ENABLED = str(PUBLIC_ENABLED).strip().lower() == "true"

# --- Output geometry / timing constraints ----------------------------------
MAX_DIM = 832        # longest allowed output side, px
MIN_DIM = 480        # shortest allowed output side, px
SQUARE_DIM = 640     # fixed size used for exactly-square inputs
MULTIPLE_OF = 16     # output dims are snapped to this multiple
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 128
# Extra wall-clock slack (encode + upload) added to the GPU-time estimate.
POSTPROCESS_OVERHEAD_SECONDS = 25
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

# Load the full image-to-video pipeline in bf16 directly onto the GPU.
pipe = WanImageToVideoPipeline.from_pretrained(
    "TestOrganizationPleaseIgnore/WAMU_v2_WAN2.2_I2V_LIGHTNING",
    torch_dtype=torch.bfloat16,
).to('cuda')

# Use sage attention for speed
# NOTE: the triple-quoted string below is intentionally dead code (a
# commented-out alternative LoRA load kept for reference).
"""
pipe.load_lora_weights(
    "lightx2v/Wan2.2-Distill-Loras",
    weight_name="wan2.2_i2v_A14b_high_noise_lora_rank64_lightx2v_4step_1022.safetensors",
    adapter_name="lightx2v"
)
"""
# Load the I2Pee LoRA pair: Wan 2.2 is a two-expert (high-noise / low-noise)
# pipeline, so one adapter is loaded per transformer.
pipe.load_lora_weights(
    "obsxrver/Wan2.2-I2Pee-5XL",
    weight_name="WAN2.2-I2V_HighNoise_I2Pee-5.5XL.safetensors",
    adapter_name="i2pee_high",
    token=PRIVATE_MODEL_KEY,
)
# Route the second LoRA into `pipe.transformer_2` (the low-noise expert)
# instead of the default `pipe.transformer`.
kwargs_lora = {}
kwargs_lora["load_into_transformer_2"] = True
# Dead code kept for reference: alternative lightx2v distill LoRA.
"""
pipe.load_lora_weights(
    "lightx2v/Wan2.2-Distill-Loras",
    weight_name="wan2.2_i2v_A14b_low_noise_lora_rank64_lightx2v_4step_1022.safetensors",
    adapter_name="lightx2v_2",
    **kwargs_lora
)
"""
pipe.load_lora_weights(
    "obsxrver/Wan2.2-I2Pee-5XL",
    weight_name="WAN2.2-I2V_LowNoise_I2Pee-5.5XL.safetensors",
    adapter_name="i2pee_low", token=PRIVATE_MODEL_KEY,
    **kwargs_lora
)
pipe.set_adapters(["i2pee_high", "i2pee_low"], adapter_weights=[1., 1.])
# Bake the LoRA deltas into the base weights, then unload the adapter
# bookkeeping so the quantization below sees plain fused tensors.
#pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=0.9, components=["transformer"])
pipe.fuse_lora(adapter_names=["i2pee_high"], lora_scale=1., components=["transformer"])
#pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
pipe.fuse_lora(adapter_names=["i2pee_low"], lora_scale=1., components=["transformer_2"])
pipe.unload_lora_weights()

# Quantize in place: int8 weight-only for the text encoder, fp8 dynamic
# activation + fp8 weight for both diffusion transformers.
print("Quantizing model to fp8da")
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
# needs this to work.
# Load ahead-of-time compiled transformer blocks that match the fp8da
# quantization applied above.
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')

#pipe.scheduler=FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)

# Swap both transformers' attention implementation for SageAttention.
print("Patching model to use Sage Attention")
set_sage_attn_wan(pipe.transformer, sageattn)
set_sage_attn_wan(pipe.transformer_2, sageattn)

default_prompt_i2v = "a woman, jumpcut, after the transition, the woman is nude, lying on a bed on her stomach with her hands gripping the bed in front of her, a nude man is standing in front of the woman with his penis inside her mouth, he is holding her head with his right hand with his left hand at his side, the man is pissing in the woman's mouth while his penis is inside her mouth, piss spills out of her mouth and drips onto the bed"
# Negative prompt is a runtime model input (Chinese): lists visual artifacts
# to avoid (overexposure, blur, extra fingers, static frames, etc.).
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"


def clear_vram() -> None:
    """Run a GC pass and release cached CUDA memory before/after inference."""
    gc.collect()
    torch.cuda.empty_cache()


def resize_image(image: Image.Image) -> Image.Image:
    """
    Resize an image to fit within the model's constraints, preserving aspect
    ratio as much as possible.

    Square inputs become SQUARE_DIM x SQUARE_DIM. Extreme aspect ratios are
    center-cropped to the max supported ratio first; the result is snapped to
    a multiple of MULTIPLE_OF and clamped into [MIN_DIM, MAX_DIM] per side.
    """
    width, height = image.size

    # Handle square case
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    aspect_ratio = width / height
    MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
    MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM

    image_to_resize = image
    if aspect_ratio > MAX_ASPECT_RATIO:
        # Very wide image -> crop width to fit 832x480 aspect ratio
        target_w, target_h = MAX_DIM, MIN_DIM
        crop_width = int(round(height * MAX_ASPECT_RATIO))
        left = (width - crop_width) // 2
        image_to_resize = image.crop((left, 0, left + crop_width, height))
    elif aspect_ratio < MIN_ASPECT_RATIO:
        # Very tall image -> crop height to fit 480x832 aspect ratio
        target_w, target_h = MIN_DIM, MAX_DIM
        crop_height = int(round(width / MIN_ASPECT_RATIO))
        top = (height - crop_height) // 2
        image_to_resize = image.crop((0, top, width, top + crop_height))
    else:
        if width > height:
            # Landscape
            target_w = MAX_DIM
            target_h = int(round(target_w / aspect_ratio))
        else:
            # Portrait
            target_h = MAX_DIM
            target_w = int(round(target_h * aspect_ratio))

    # Snap to the model's required multiple, then clamp each side.
    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
    final_h = max(MIN_DIM, min(MAX_DIM, final_h))
    return image_to_resize.resize((final_w, final_h), Image.LANCZOS)


def get_num_frames(duration_seconds: float) -> int:
    """Convert a duration to a frame count: clamp duration*FPS to the model's
    [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL] range, then add 1 (the initial frame)."""
    return 1 + int(np.clip(
        int(round(duration_seconds * FIXED_FPS)),
        MIN_FRAMES_MODEL,
        MAX_FRAMES_MODEL,
    ))


def to_ds(
    *,
    video_path: str,
    original_image_path: str,
    original_image_name: str,
    prompt: str,
    negative_prompt: str,
    duration_seconds: float,
    steps: int,
    guidance_scale: float,
    guidance_scale_2: float,
    seed: int,
    output_width: int,
    output_height: int,
) -> None:
    """
    Archive one generation (video, input image, metadata text file) to the
    configured HF dataset repo in a single commit.

    Paths are MM/DD/HHMMSS-microseconds-<8-hex>.{mp4,png,txt}. Best-effort:
    a missing hf_api or a failed commit only prints; it never raises.
    """
    if hf_api is None:
        print("cannot complete operation.")
        return
    timestamp = datetime.now()
    folder = timestamp.strftime("%m/%d")
    stem = timestamp.strftime("%H%M%S-%f")
    unique_suffix = uuid4().hex[:8]
    base_name = f"{stem}-{unique_suffix}"
    video_repo_path = f"{folder}/{base_name}.mp4"
    image_repo_path = f"{folder}/{base_name}.png"
    text_repo_path = f"{folder}/{base_name}.txt"
    metadata = "\n".join([
        f"timestamp={timestamp.isoformat()}",
        f"original_image_name={original_image_name}",
        f"prompt={prompt}",
        f"negative_prompt={negative_prompt}",
        f"steps={steps}",
        f"duration_seconds={duration_seconds}",
        f"guidance_scale={guidance_scale}",
        f"guidance_scale_2={guidance_scale_2}",
        f"seed={seed}",
        f"resolution={output_width}x{output_height}",
        f"fps={FIXED_FPS}",
    ])
    try:
        hf_api.create_commit(
            repo_id=DS_ID,
            repo_type="dataset",
            commit_message=f"Add generation {base_name}",
            operations=[
                CommitOperationAdd(
                    path_in_repo=video_repo_path,
                    path_or_fileobj=video_path,
                ),
                CommitOperationAdd(
                    path_in_repo=image_repo_path,
                    path_or_fileobj=original_image_path,
                ),
                CommitOperationAdd(
                    path_in_repo=text_repo_path,
                    path_or_fileobj=metadata.encode("utf-8"),
                ),
            ],
        )
        print(f"{DS_ID}:{video_repo_path}")
        print(f"{DS_ID}:{image_repo_path}")
        print(f"{DS_ID}:{text_repo_path}")
    except Exception as exc:
        print(f"failed: {exc}")


def load_input_image(image_path: str) -> Image.Image:
    """Open the uploaded image file and return an RGB copy.

    Raises:
        gr.Error: if no path was provided.
    """
    if not image_path:
        raise gr.Error("Please upload an input image.")
    with Image.open(image_path) as image:
        return image.convert("RGB")


def get_duration(
    input_image,
    prompt,
    steps,
    negative_prompt,
    duration_seconds,
    guidance_scale,
    guidance_scale_2,
    seed,
    randomize_seed,
    access_granted,
    request,
    progress,
):
    """
    Estimate GPU seconds for a generation; passed to @spaces.GPU as the
    dynamic `duration` callback (signature mirrors generate_video).

    Scales a per-step baseline by (frames*width*height / reference_volume)^1.5.
    """
    # Reference workload volume the baseline step time was tuned against.
    # NOTE(review): 81 frames x 832 x 624 — presumably an empirical calibration
    # point; verify against the hardware this Space runs on.
    BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
    BASE_STEP_DURATION = 15
    image = load_input_image(input_image)
    width, height = resize_image(image).size
    frames = get_num_frames(duration_seconds)
    factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
    step_duration = BASE_STEP_DURATION * factor ** 1.5
    # Keep duration estimates conservative to avoid ZeroGPU task aborts right after inference.
    return 20 + int(steps) * step_duration + POSTPROCESS_OVERHEAD_SECONDS


def get_original_image_name(image_path: str) -> str:
    """Return the basename of the uploaded image, or "unknown" if absent."""
    if not image_path:
        return "unknown"
    return os.path.basename(image_path)


def get_original_media_stem(media_path: str) -> str:
    """Return the extension-less basename of a media file, falling back to
    "generated-video" when the path is empty or the stem is blank."""
    if not media_path:
        return "generated-video"
    media_name = os.path.basename(media_path)
    media_stem, _ = os.path.splitext(media_name)
    media_stem = media_stem.strip()
    return media_stem or "generated-video"


def build_download_video(video_path: str, original_media_path: str) -> str:
    """Copy the rendered video into a fresh temp dir under a filename derived
    from the original upload, so the browser download has a meaningful name."""
    download_dir = tempfile.mkdtemp(prefix="wan22-download-")
    download_filename = f"{get_original_media_stem(original_media_path)}.mp4"
    download_path = os.path.join(download_dir, download_filename)
    shutil.copyfile(video_path, download_path)
    return download_path


def export_to_video_h264(frames, output_path: str, fps: int = FIXED_FPS) -> None:
    """
    Export frames to an H.264 MP4 compatible with mobile browsers/players.

    Accepts PIL Images or arrays; grayscale is broadcast to RGB, alpha is
    dropped, and floats in [0, 1] are rescaled to uint8. Frames whose size
    differs from the first frame are resized to match.

    Raises:
        ValueError: on an empty frame list or an unsupported frame shape.
    """
    if len(frames) == 0:
        raise ValueError("No frames to export.")
    normalized_frames = []
    for frame in frames:
        if isinstance(frame, Image.Image):
            arr = np.array(frame)
        else:
            arr = np.asarray(frame)
        if arr.ndim == 2:
            # Grayscale -> RGB by channel replication.
            arr = np.stack([arr, arr, arr], axis=-1)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            # RGBA -> RGB (alpha discarded).
            arr = arr[:, :, :3]
        if arr.ndim != 3 or arr.shape[2] != 3:
            raise ValueError(f"Unsupported frame shape: {arr.shape}")
        if arr.dtype != np.uint8:
            if np.issubdtype(arr.dtype, np.floating):
                # Heuristic: values <= 1.0 are treated as normalized floats.
                max_val = float(np.nanmax(arr)) if arr.size else 1.0
                if max_val <= 1.0:
                    arr = arr * 255.0
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        normalized_frames.append(arr)
    height, width = normalized_frames[0].shape[:2]
    # baseline profile + faststart for broad mobile playback compatibility.
    ffmpeg_params = [
        "-movflags", "+faststart",
        "-profile:v", "baseline",
        "-level", "3.1",
    ]
    with imageio.get_writer(
        output_path,
        format="FFMPEG",
        mode="I",
        fps=fps,
        codec="libx264",
        pixelformat="yuv420p",
        ffmpeg_params=ffmpeg_params,
    ) as writer:
        for frame in normalized_frames:
            if frame.shape[:2] != (height, width):
                frame = np.array(
                    Image.fromarray(frame).resize((width, height), Image.LANCZOS)
                )
            writer.append_data(frame)


def reset_verification():
    """Return the initial (button-update, status-message, access-granted)
    UI state: unlocked when PUBLIC_ENABLED, locked otherwise."""
    if PUBLIC_ENABLED:
        return (
            gr.update(value="Generate Video", interactive=True),
            "Public generation is enabled.",
            True,
        )
    return (
        gr.update(value="Verify Passcode to Enable Generation", interactive=False),
        "Enter the passcode and click Verify to unlock generation.",
        False,
    )


def verify_passcode(passcode: str):
    """Check the submitted passcode against GENERATION_SECRET and return the
    matching (button-update, status-message, access-granted) UI state.

    Uses hmac.compare_digest for a timing-safe comparison. PUBLIC_ENABLED
    bypasses the check entirely; a missing secret keeps generation locked.
    """
    if PUBLIC_ENABLED:
        return (
            gr.update(value="Generate Video", interactive=True),
            "Public generation is enabled.",
            True,
        )
    secret = GENERATION_SECRET or ""
    candidate = (passcode or "").strip()
    if not secret:
        return (
            gr.update(value="Verify Passcode to Enable Generation", interactive=False),
            "Generation is unavailable because `GENERATION_SECRET` is not configured.",
            False,
        )
    if hmac.compare_digest(candidate, secret):
        return (
            gr.update(value="Generate Video", interactive=True),
            "Passcode verified. Generation unlocked.",
            True,
        )
    return (
        gr.update(value="Verify Passcode to Enable Generation", interactive=False),
        "Incorrect passcode.",
        False,
    )


@spaces.GPU(duration=get_duration)
def generate_video(
    input_image,
    prompt,
    steps = 6,
    negative_prompt=default_negative_prompt,
    duration_seconds = 5.0,
    guidance_scale = 1,
    guidance_scale_2 = 1,
    seed = 696969696,
    randomize_seed = False,
    access_granted = True,
    request: gr.Request = None,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a video from an input image using the Wan 2.2 14B I2V model with
    Lightning LoRA.

    This function takes an input image and generates a video animation based
    on the provided prompt and parameters. It uses an FP8 quantized Wan 2.2
    14B Image-to-Video model with Lightning LoRA for fast generation in 4-8
    steps.

    Args:
        input_image (str): Filepath to the input image to animate. Will be
            loaded and resized to target dimensions.
        prompt (str): Text prompt describing the desired animation or motion.
        steps (int, optional): Number of inference steps. More steps = higher
            quality but slower. Defaults to 6. Range: 1-30.
        negative_prompt (str, optional): Negative prompt to avoid unwanted
            elements. Defaults to default_negative_prompt (contains unwanted
            visual artifacts).
        duration_seconds (float, optional): Duration of the generated video in
            seconds. Defaults to 5.0. Clamped between
            MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
        guidance_scale (float, optional): Controls adherence to the prompt.
            Higher values = more adherence. Defaults to 1.0. Range: 0.0-20.0.
        guidance_scale_2 (float, optional): Controls adherence to the prompt.
            Higher values = more adherence. Defaults to 1.0. Range: 0.0-20.0.
        seed (int, optional): Random seed for reproducible results. Defaults
            to 696969696. Range: 0 to MAX_SEED (2147483647).
        randomize_seed (bool, optional): Whether to use a random seed instead
            of the provided seed. Defaults to False.
        access_granted (bool, optional): Set by passcode verification; must be
            True for generation to proceed. Defaults to True.
        request (gr.Request, optional): Incoming Gradio request. Defaults to
            None.
        progress (gr.Progress, optional): Gradio progress tracker. Defaults to
            gr.Progress(track_tqdm=True).

    Returns:
        tuple: A tuple containing:
            - video_path (str): Path to the generated video file (.mp4)
            - current_seed (int): The seed used for generation (useful when
              randomize_seed=True)
            - render_message (str): Human-readable render-time summary
            - download_path (str): Copy of the video named after the input
              image, for the download button

    Raises:
        gr.Error: If access is not granted or input_image is None (no image
            uploaded).

    Note:
        - Frame count is 1 + duration_seconds * FIXED_FPS (16), clamped to the
          model's frame range
        - Output dimensions are adjusted to be multiples of MULTIPLE_OF (16)
        - The function uses GPU acceleration via the @spaces.GPU decorator
        - Generation time varies based on steps and duration (see get_duration
          function)
    """
    if not access_granted:
        raise gr.Error("Please verify the passcode before generating.")
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    clear_vram()
    start_time = time.perf_counter()
    num_frames = get_num_frames(duration_seconds)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # Keep the original filepath around; `input_image` is rebound to the
    # loaded PIL image below.
    input_image_path = input_image
    input_image = load_input_image(input_image_path)
    resized_image = resize_image(input_image)
    output_frames_list = pipe(
        image=resized_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=resized_image.height,
        width=resized_image.width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]
    # delete=False: only the temp file *names* are needed; files are written
    # afterwards and must outlive these context managers.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        original_image_upload_path = tmpfile.name
    input_image.save(original_image_upload_path, format="PNG")
    export_to_video_h264(output_frames_list, video_path, fps=FIXED_FPS)
    try:
        # Best-effort archive of this generation to the HF dataset.
        to_ds(
            video_path=video_path,
            original_image_path=original_image_upload_path,
            original_image_name=get_original_image_name(input_image_path),
            prompt=prompt if prompt is not None else default_prompt_i2v,
            negative_prompt=negative_prompt if negative_prompt is not None else default_negative_prompt,
            duration_seconds=float(duration_seconds) if duration_seconds is not None else MAX_DURATION,
            steps=int(steps) if steps is not None else 6,
            guidance_scale=float(guidance_scale) if guidance_scale is not None else 1,
            guidance_scale_2=float(guidance_scale_2) if guidance_scale_2 is not None else 1,
            seed=current_seed if current_seed is not None else 42,
            output_width=resized_image.width if resized_image.width is not None else 832,
            output_height=resized_image.height if resized_image.height is not None else 480,
        )
    finally:
        # The PNG copy exists only for the upload; always clean it up.
        if os.path.exists(original_image_upload_path):
            os.remove(original_image_upload_path)
    elapsed_seconds = time.perf_counter() - start_time
    render_message = f"Generated in {elapsed_seconds:.1f} seconds."
    download_path = build_download_video(video_path, input_image_path)
    return video_path, current_seed, render_message, download_path


# Custom styling for the Gradio UI (dark gradient background, card panels).
custom_css = """
.gradio-container {
    max-width: 1180px !important;
    margin: 0 auto !important;
    background:
        radial-gradient(circle at 10% 20%, rgba(124, 58, 237, 0.2), transparent 35%),
        radial-gradient(circle at 90% 10%, rgba(14, 165, 233, 0.15), transparent 30%),
        linear-gradient(180deg, #070b1a 0%, #0f172a 100%);
}
#hero-card {
    border-radius: 18px;
    border: 1px solid rgba(255, 255, 255, 0.12);
    background: linear-gradient(130deg, rgba(76, 29, 149, 0.9), rgba(30, 64, 175, 0.9));
    padding: 16px 20px;
    box-shadow: 0 18px 45px rgba(10, 10, 20, 0.35);
    margin-bottom: 10px;
}
#hero-card h1 {
    margin: 0 0 6px 0;
    font-size: 1.9rem;
    line-height: 1.1;
}
#hero-card p {
    margin: 0;
    opacity: 0.95;
}
#control-panel, #preview-panel {
    border-radius: 16px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    background: rgba(15, 23, 42, 0.72);
    backdrop-filter: blur(6px);
    padding: 16px;
    box-shadow: 0 14px 35px rgba(0, 0, 0, 0.22);
}
#generated-video {
    min-height: 500px;
}
#generate-button {
    min-height: 48px;
    font-weight: 700;
    letter-spacing: 0.2px;
}
.status-note {
    margin-top: 8px;
    padding: 10px 12px;
    border-radius: 10px;
    border: 1px solid rgba(255, 255, 255, 0.16);
    background: rgba(15, 118, 110, 0.2);
}
"""

theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.violet,
    secondary_hue=gr.themes.colors.cyan,
    neutral_hue=gr.themes.colors.slate,
).set(
    body_background_fill="#050816",
    block_background_fill="#111827",
    block_border_width="1px",
    block_title_text_weight="700",
)

with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.HTML(
        """