# Hugging Face model card metadata (page residue, kept as a comment):
# Tags: Diffusers · Safetensors · English · video generation
# Repo path: MedGen-1.3B / inference.py
# Uploaded by wangrongsheng via huggingface_hub (commit 40811e8, verified)
"""
MediGen-1.3B - Medical Video Generation
Fine-tuned Wan2.1-T2V-1.3B on MedVideoCap-55K dataset for generating
medical-domain videos from text descriptions.
Usage:
# Single video generation
python inference.py --prompt "A doctor examining a patient" --output exam.mp4
# Batch generation from JSON file (list of strings or objects with "prompt" key)
python inference.py --batch prompts.json --output_dir results/
# Use a specific GPU
python inference.py --prompt "..." --gpu 1 --output result.mp4
# Custom resolution and seed
python inference.py --prompt "..." --height 480 --width 832 --seed 123
"""
import torch
import os
import json
import argparse
import gc
from pathlib import Path
from safetensors.torch import load_file
# ---------------------------------------------------------------------------
# Paths - all model weights are stored under models/ relative to this script
# ---------------------------------------------------------------------------
# Anchoring on __file__ keeps the script working regardless of the caller's
# current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
MODELS_DIR = SCRIPT_DIR / "models"

# Default negative prompt to suppress common artifacts in generated videos.
# Passed to the pipeline on every generation call (see generate_video).
NEGATIVE_PROMPT = (
    "Distorted, blurry, low quality, watermark, text overlay, "
    "static image, worst quality, JPEG artifacts, deformed, "
    "extra limbs, bad anatomy"
)
def load_pipeline(device="cuda"):
    """Load the full MediGen-1.3B pipeline.

    This involves three steps:
      1. Load the base Wan2.1-T2V-1.3B model (DIT + T5 text encoder + VAE)
      2. Load the UMT5-XXL tokenizer for text encoding
      3. Apply the fine-tuned DIT weights on top of the base model

    Args:
        device: Target device, default "cuda".

    Returns:
        WanVideoPipeline ready for inference.

    Raises:
        FileNotFoundError: If the fine-tuned checkpoint is not present
            under models/.
    """
    from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

    print("Loading MediGen-1.3B...")

    # Fail fast with a clear message instead of an opaque safetensors error
    # (and before spending minutes downloading/loading the base model) if
    # the fine-tuned checkpoint is missing.
    ckpt_path = MODELS_DIR / "medigen-1.3b.safetensors"
    if not ckpt_path.is_file():
        raise FileNotFoundError(
            f"Fine-tuned checkpoint not found: {ckpt_path}. "
            "Expected medigen-1.3b.safetensors under the models/ directory."
        )

    # Step 1: Load base pipeline components
    # - DIT (Diffusion Transformer): single file for 1.3B model
    # - T5 text encoder: converts text prompts to embeddings
    # - VAE: decodes latent representations into video frames
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device=device,
        model_configs=[
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="diffusion_pytorch_model.safetensors",
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="Wan2.1_VAE.pth",
            ),
        ],
        tokenizer_config=ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B",
            origin_file_pattern="google/umt5-xxl/",
        ),
    )

    # Step 2: Apply fine-tuned DIT weights
    # Only the DIT component was fine-tuned; T5 and VAE remain unchanged.
    state_dict = load_file(str(ckpt_path))
    # strict=False tolerates benign key mismatches, but surface them so a
    # wrong or corrupt checkpoint does not silently fall back to base weights.
    result = pipe.dit.load_state_dict(state_dict, strict=False)
    missing = getattr(result, "missing_keys", [])
    unexpected = getattr(result, "unexpected_keys", [])
    if missing or unexpected:
        print(
            f"Warning: {len(missing)} missing and {len(unexpected)} unexpected "
            "keys while applying fine-tuned DIT weights."
        )

    # Free checkpoint memory after loading
    del state_dict
    gc.collect()
    torch.cuda.empty_cache()

    print("MediGen-1.3B ready.")
    return pipe
def generate_video(pipe, prompt, output_path, seed=42, height=480, width=832):
    """Generate a single video from a text prompt.

    Args:
        pipe: Loaded WanVideoPipeline instance.
        prompt: Text description of the desired medical video.
        output_path: Path to save the output .mp4 file. Parent directories
            are created if needed.
        seed: Random seed for reproducibility.
        height: Video height in pixels (default 480).
        width: Video width in pixels (default 832).
    """
    from diffsynth.utils.data import save_video

    print(f"Generating: {prompt[:80]}...")

    # Create the destination directory up front so save_video does not fail
    # when the caller passes a path like "results/exam.mp4" into a directory
    # that does not exist yet.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Run diffusion inference with the pipeline's default denoising schedule.
    # tiled=True enables tiled VAE decoding to reduce VRAM usage.
    video = pipe(
        prompt=prompt,
        negative_prompt=NEGATIVE_PROMPT,
        seed=seed,
        height=height,
        width=width,
        tiled=True,
    )

    # Save as MP4 at 15fps with quality level 5
    save_video(video, output_path, fps=15, quality=5)
    print(f"Saved: {output_path}")
def main():
    """Command-line entry point: parse arguments, load the model, and run
    single-prompt or batch generation."""
    parser = argparse.ArgumentParser(
        description="MediGen-1.3B Medical Video Generation"
    )
    parser.add_argument("--prompt", type=str, help="Text prompt for generation")
    parser.add_argument("--batch", type=str, help="JSON file with prompts")
    parser.add_argument("--output", type=str, default="output.mp4",
                        help="Output path for single video (default: output.mp4)")
    parser.add_argument("--output_dir", type=str, default="outputs",
                        help="Output directory for batch mode (default: outputs/)")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility (default: 42)")
    parser.add_argument("--height", type=int, default=480,
                        help="Video height in pixels (default: 480)")
    parser.add_argument("--width", type=int, default=832,
                        help="Video width in pixels (default: 832)")
    parser.add_argument("--gpu", type=int, default=0,
                        help="GPU device ID (default: 0)")
    args = parser.parse_args()

    # Set visible GPU before any CUDA operations
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Load model (takes ~1-2 minutes for 1.3B model)
    pipe = load_pipeline()

    if args.batch:
        # Batch mode: read JSON file containing a list of prompts.
        # Accepts either ["prompt1", "prompt2", ...] or [{"prompt": "..."}, ...]
        with open(args.batch) as f:
            prompts = json.load(f)
        if not isinstance(prompts, list):
            raise ValueError(f"{args.batch} must contain a JSON list of prompts")
        os.makedirs(args.output_dir, exist_ok=True)
        for i, item in enumerate(prompts):
            prompt = item if isinstance(item, str) else item.get("prompt", "")
            if not prompt:
                # Skip malformed entries instead of generating from empty text.
                print(f"Warning: entry {i} has no prompt, skipping")
                continue
            out_path = os.path.join(args.output_dir, f"{i:03d}.mp4")
            generate_video(
                pipe, prompt, out_path,
                seed=args.seed + i,  # Different seed per video
                height=args.height, width=args.width,
            )
    elif args.prompt:
        # Single video mode
        generate_video(
            pipe, args.prompt, args.output,
            seed=args.seed, height=args.height, width=args.width,
        )
    else:
        print("Error: provide --prompt or --batch")
        parser.print_help()


if __name__ == "__main__":
    main()