| """ |
| MediGen-1.3B - Medical Video Generation |
| |
| Fine-tuned Wan2.1-T2V-1.3B on MedVideoCap-55K dataset for generating |
| medical-domain videos from text descriptions. |
| |
| Usage: |
| # Single video generation |
| python inference.py --prompt "A doctor examining a patient" --output exam.mp4 |
| |
| # Batch generation from JSON file (list of strings or objects with "prompt" key) |
| python inference.py --batch prompts.json --output_dir results/ |
| |
| # Use a specific GPU |
| python inference.py --prompt "..." --gpu 1 --output result.mp4 |
| |
| # Custom resolution and seed |
| python inference.py --prompt "..." --height 480 --width 832 --seed 123 |
| """ |
|
|
| import torch |
| import os |
| import json |
| import argparse |
| import gc |
| from pathlib import Path |
| from safetensors.torch import load_file |
|
|
| |
| |
| |
# Resolve all paths relative to this script's location so the tool works
# regardless of the caller's current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
MODELS_DIR = SCRIPT_DIR / "models"  # expected home of medigen-1.3b.safetensors


# Shared negative prompt applied to every generation to steer the model
# away from common video-generation artifacts (blur, watermarks, anatomy errors).
NEGATIVE_PROMPT = (
    "Distorted, blurry, low quality, watermark, text overlay, "
    "static image, worst quality, JPEG artifacts, deformed, "
    "extra limbs, bad anatomy"
)
|
|
|
|
def load_pipeline(device="cuda"):
    """Load the full MediGen-1.3B pipeline.

    This involves three steps:
    1. Load the base Wan2.1-T2V-1.3B model (DIT + T5 text encoder + VAE)
    2. Load the UMT5-XXL tokenizer for text encoding
    3. Apply the fine-tuned DIT weights on top of the base model

    Args:
        device: Target device, default "cuda".

    Returns:
        WanVideoPipeline ready for inference.

    Raises:
        FileNotFoundError: If the fine-tuned checkpoint is not present in
            the local ``models/`` directory.
    """
    from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

    print("Loading MediGen-1.3B...")

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="diffusion_pytorch_model.safetensors",
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="Wan2.1_VAE.pth",
            ),
        ],
        tokenizer_config=ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B",
            origin_file_pattern="google/umt5-xxl/",
        ),
    )

    # Fail fast with an actionable message instead of letting safetensors
    # raise an opaque error when the fine-tuned weights are missing.
    ckpt_path = MODELS_DIR / "medigen-1.3b.safetensors"
    if not ckpt_path.is_file():
        raise FileNotFoundError(
            f"Fine-tuned checkpoint not found at {ckpt_path}. "
            "Place medigen-1.3b.safetensors in the 'models/' directory "
            "next to this script."
        )
    state_dict = load_file(str(ckpt_path))

    # strict=False because the checkpoint contains only the fine-tuned DIT
    # weights; surface any key mismatch instead of silently ignoring it.
    result = pipe.dit.load_state_dict(state_dict, strict=False)
    if result.unexpected_keys:
        print(f"Warning: ignored {len(result.unexpected_keys)} unexpected keys")
    if result.missing_keys:
        print(f"Warning: {len(result.missing_keys)} DIT keys not in checkpoint")

    # Free the CPU-side copy of the weights before inference.
    del state_dict
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("MediGen-1.3B ready.")
    return pipe
|
|
|
|
def generate_video(pipe, prompt, output_path, seed=42, height=480, width=832):
    """Generate a single video from a text prompt.

    Args:
        pipe: Loaded WanVideoPipeline instance.
        prompt: Text description of the desired medical video.
        output_path: Path to save the output .mp4 file. Parent directories
            are created if they do not exist.
        seed: Random seed for reproducibility.
        height: Video height in pixels (default 480).
        width: Video width in pixels (default 832).
    """
    from diffsynth.utils.data import save_video

    print(f"Generating: {prompt[:80]}...")

    # Create the destination directory up front so a missing folder does not
    # cause save_video to fail AFTER the (slow) generation has completed.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    video = pipe(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        seed=seed,
        height=height,
        width=width,
        tiled=True,  # tiled VAE decoding keeps peak VRAM usage down
    )

    save_video(video, output_path, fps=15, quality=5)
    print(f"Saved: {output_path}")
|
|
|
|
def main():
    """CLI entry point: parse arguments, load the pipeline, generate video(s)."""
    parser = argparse.ArgumentParser(
        description="MediGen-1.3B Medical Video Generation"
    )
    parser.add_argument("--prompt", type=str, help="Text prompt for generation")
    parser.add_argument("--batch", type=str, help="JSON file with prompts")
    parser.add_argument("--output", type=str, default="output.mp4",
                        help="Output path for single video (default: output.mp4)")
    parser.add_argument("--output_dir", type=str, default="outputs",
                        help="Output directory for batch mode (default: outputs/)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility (default: 42)")
    parser.add_argument("--height", type=int, default=480,
                        help="Video height in pixels (default: 480)")
    parser.add_argument("--width", type=int, default=832,
                        help="Video width in pixels (default: 832)")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU device ID (default: 0)")
    args = parser.parse_args()

    # Validate the mode BEFORE loading the pipeline: loading pulls a
    # multi-gigabyte model, so a bad invocation must fail fast (and with a
    # nonzero exit code, which parser.error provides).
    if not args.prompt and not args.batch:
        parser.error("provide --prompt or --batch")

    # Must be set before the first CUDA context is created by the pipeline.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    pipe = load_pipeline()

    if args.batch:
        # Batch file is a JSON list of strings or of {"prompt": ...} objects.
        with open(args.batch) as f:
            prompts = json.load(f)
        os.makedirs(args.output_dir, exist_ok=True)

        for i, item in enumerate(prompts):
            prompt = item if isinstance(item, str) else item.get("prompt", "")
            if not prompt:
                # Don't waste minutes of GPU time generating from nothing.
                print(f"Skipping entry {i}: empty prompt")
                continue
            out_path = os.path.join(args.output_dir, f"{i:03d}.mp4")
            generate_video(
                pipe, prompt, out_path,
                seed=args.seed + i,  # distinct seed per video, still reproducible
                height=args.height, width=args.width,
            )
    else:
        generate_video(
            pipe, args.prompt, args.output,
            seed=args.seed, height=args.height, width=args.width,
        )
|
|
|
|
# Standard script guard: only run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
|