sida / download_sd3_models.py
xiangzai's picture
Add files using upload-large-folder tool
7803bdf verified
#!/usr/bin/env python
# coding: utf-8
"""
只负责“按 train_lora_sd3.py 相同的方式”下载 SD3 及相关组件到默认 HF cache:
通过依次调用 `from_pretrained(..., subfolder=...)` 来触发下载。
会下载的子目录(与 train_lora_sd3.py 一致):
tokenizer, tokenizer_2, tokenizer_3,
text_encoder, text_encoder_2, text_encoder_3,
scheduler, vae, transformer
用法:
python download_sd3_models.py --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers
下载完成后训练可离线:
export HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1
python train_lora_sd3.py --pretrained_model_name_or_path stabilityai/stable-diffusion-3-medium-diffusers ...
或直接指向你已经下载好的本地 repo 目录。
"""
import argparse
import gc
from typing import Optional
import torch
from diffusers import StableDiffusion3Pipeline
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Download SD3 via a single from_pretrained call (cache warmup only).")
p.add_argument(
"--pretrained_model_name_or_path",
type=str,
default="stabilityai/stable-diffusion-3-medium-diffusers",
help="模型 repo id 或本地路径(与 train_lora_sd3.py 参数一致",
)
p.add_argument("--revision", type=str, default=None, help="可选:下载特定 revision/branch/tag")
p.add_argument("--variant", type=str, default=None, help="可选:如 fp16 等 variant(若仓库提供)")
p.add_argument("--cache_dir", type=str, default=None, help="可选:自定义 HF cache_dir;默认用系统/用户默认")
return p.parse_args()
def main() -> None:
args = parse_args()
model = args.pretrained_model_name_or_path
# 最简单:直接加载整条 pipeline,触发把其依赖的全部组件下载进默认 HF cache
# 注意:from_pretrained 下载的是 pipeline 会用到的文件;本脚本目的就是让训练时不再联网。
pipe = StableDiffusion3Pipeline.from_pretrained(
model,
revision=args.revision,
variant=args.variant,
cache_dir=args.cache_dir,
low_cpu_mem_usage=True,
)
# 释放内存(下载已完成)
del pipe
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("下载/缓存预热完成。后续训练可设置离线环境变量避免联网:")
print(" export HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1")
if __name__ == "__main__":
main()