"""
MediGen-1.3B - Medical Video Generation

Fine-tuned Wan2.1-T2V-1.3B on the MedVideoCap-55K dataset for generating
medical-domain videos from text descriptions.

Usage:
    # Single video generation
    python inference.py --prompt "A doctor examining a patient" --output exam.mp4

    # Batch generation from JSON file (list of strings or objects with "prompt" key)
    python inference.py --batch prompts.json --output_dir results/

    # Use a specific GPU
    python inference.py --prompt "..." --gpu 1 --output result.mp4

    # Custom resolution and seed
    python inference.py --prompt "..." --height 480 --width 832 --seed 123
"""

import argparse
import gc
import json
import os
from pathlib import Path

import torch
from safetensors.torch import load_file

# ---------------------------------------------------------------------------
# Paths - all model weights are stored under models/ relative to this script
# ---------------------------------------------------------------------------
SCRIPT_DIR = Path(__file__).resolve().parent
MODELS_DIR = SCRIPT_DIR / "models"

# Default negative prompt to suppress common artifacts in generated videos
NEGATIVE_PROMPT = (
    "Distorted, blurry, low quality, watermark, text overlay, "
    "static image, worst quality, JPEG artifacts, deformed, "
    "extra limbs, bad anatomy"
)


def load_pipeline(device="cuda"):
    """Load the full MediGen-1.3B pipeline.

    This involves three steps:
      1. Load the base Wan2.1-T2V-1.3B model (DIT + T5 text encoder + VAE)
      2. Load the UMT5-XXL tokenizer for text encoding
      3. Apply the fine-tuned DIT weights on top of the base model

    Args:
        device: Target device, default "cuda".

    Returns:
        WanVideoPipeline ready for inference.

    Raises:
        FileNotFoundError: If the fine-tuned checkpoint is missing from
            the models/ directory.
    """
    # Deferred import: diffsynth is heavy and only needed at inference time
    from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

    print("Loading MediGen-1.3B...")

    # Step 1: Load base pipeline components
    # - DIT (Diffusion Transformer): single file for 1.3B model
    # - T5 text encoder: converts text prompts to embeddings
    # - VAE: decodes latent representations into video frames
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="diffusion_pytorch_model.safetensors",
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="Wan2.1_VAE.pth",
            ),
        ],
        tokenizer_config=ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B",
            origin_file_pattern="google/umt5-xxl/",
        ),
    )

    # Step 2: Apply fine-tuned DIT weights
    # Only the DIT component was fine-tuned; T5 and VAE remain unchanged
    ckpt_path = MODELS_DIR / "medigen-1.3b.safetensors"
    # Fail fast with an actionable message instead of the opaque error
    # load_file() would raise on a missing path.
    if not ckpt_path.is_file():
        raise FileNotFoundError(
            f"Fine-tuned checkpoint not found: {ckpt_path}. "
            "Place medigen-1.3b.safetensors under the models/ directory."
        )
    state_dict = load_file(str(ckpt_path))
    # strict=False: the checkpoint contains only the fine-tuned DIT subset
    pipe.dit.load_state_dict(state_dict, strict=False)

    # Free checkpoint memory after loading
    del state_dict
    gc.collect()
    torch.cuda.empty_cache()

    print("MediGen-1.3B ready.")
    return pipe


def generate_video(pipe, prompt, output_path, seed=42, height=480, width=832):
    """Generate a single video from a text prompt.

    Args:
        pipe: Loaded WanVideoPipeline instance.
        prompt: Text description of the desired medical video.
        output_path: Path to save the output .mp4 file.
        seed: Random seed for reproducibility.
        height: Video height in pixels (default 480).
        width: Video width in pixels (default 832).
    """
    from diffsynth.utils.data import save_video

    print(f"Generating: {prompt[:80]}...")

    # Run diffusion inference with 50 denoising steps (default)
    # tiled=True enables tiled VAE decoding to reduce VRAM usage
    video = pipe(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        seed=seed,
        height=height,
        width=width,
        tiled=True,
    )

    # Save as MP4 at 15fps with quality level 5
    save_video(video, output_path, fps=15, quality=5)
    print(f"Saved: {output_path}")


def main():
    """CLI entry point: parse arguments, load the model, run generation.

    Raises:
        ValueError: If the --batch JSON file does not contain a list.
    """
    parser = argparse.ArgumentParser(
        description="MediGen-1.3B Medical Video Generation"
    )
    parser.add_argument("--prompt", type=str, help="Text prompt for generation")
    parser.add_argument("--batch", type=str, help="JSON file with prompts")
    parser.add_argument("--output", type=str, default="output.mp4",
                        help="Output path for single video (default: output.mp4)")
    parser.add_argument("--output_dir", type=str, default="outputs",
                        help="Output directory for batch mode (default: outputs/)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility (default: 42)")
    parser.add_argument("--height", type=int, default=480,
                        help="Video height in pixels (default: 480)")
    parser.add_argument("--width", type=int, default=832,
                        help="Video width in pixels (default: 832)")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU device ID (default: 0)")
    args = parser.parse_args()

    # Set visible GPU before any CUDA operations. NOTE: torch is already
    # imported, but this still works because torch initializes CUDA lazily
    # (no CUDA context exists yet at this point).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Load model (takes ~1-2 minutes for 1.3B model)
    pipe = load_pipeline()

    if args.batch:
        # Batch mode: read JSON file containing a list of prompts
        # Accepts either ["prompt1", "prompt2", ...] or [{"prompt": "..."}, ...]
        with open(args.batch) as f:
            prompts = json.load(f)
        # Fail fast on a malformed file rather than iterating a dict/string
        if not isinstance(prompts, list):
            raise ValueError(
                f"{args.batch} must contain a JSON list of prompts"
            )
        os.makedirs(args.output_dir, exist_ok=True)
        for i, item in enumerate(prompts):
            prompt = item if isinstance(item, str) else item.get("prompt", "")
            if not prompt:
                # Don't waste GPU time generating from an empty prompt
                # (e.g. an object missing the "prompt" key)
                print(f"Skipping entry {i}: empty prompt")
                continue
            out_path = os.path.join(args.output_dir, f"{i:03d}.mp4")
            generate_video(
                pipe,
                prompt,
                out_path,
                seed=args.seed + i,  # Different seed per video
                height=args.height,
                width=args.width,
            )
    elif args.prompt:
        # Single video mode
        generate_video(
            pipe,
            args.prompt,
            args.output,
            seed=args.seed,
            height=args.height,
            width=args.width,
        )
    else:
        print("Error: provide --prompt or --batch")
        parser.print_help()


if __name__ == "__main__":
    main()