# sida / run_sd3_lora_sampling.sh
# xiangzai's picture
# Add files using upload-large-folder tool
# 7803bdf verified
#!/bin/bash
# SD3 LoRA model sampling script.
# Example script that samples images for the captions listed in a JSONL file.
# Usage: ./run_sd3_lora_sampling.sh
#
# Strict mode: abort on any failing command, unset variable, or failing
# pipeline stage, so a failed step is never followed by the "done" messages
# at the end of the script.
set -euo pipefail
# GPU devices to use (4 GPUs: 0,1,2,3).
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# Memory optimization for the PyTorch CUDA caching allocator
# (see the PyTorch CUDA memory-management docs on expandable_segments).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# ---------------------------------------------------------------------------
# Configuration. Each value is a default that may be overridden from the
# environment without editing this file, e.g.:
#   LORA_CHECKPOINT_PATH=/path/to/checkpoint-250000 bash run_sd3_lora_sampling.sh
# ---------------------------------------------------------------------------
# Base model and LoRA paths.
PRETRAINED_MODEL="${PRETRAINED_MODEL:-/gemini/space/hsd/project/pretrained_model/huggingface/hub/models--stabilityai--stable-diffusion-3-medium-diffusers/snapshots/ea42f8cef0f178587cf766dc8129abd379c90671}"
# LoRA checkpoint path - an accelerator checkpoint directory.
LORA_CHECKPOINT_PATH="${LORA_CHECKPOINT_PATH:-/gemini/space/gzy_new/models/Sida/sd3-lora-finetuned-batch-4/checkpoint-500000}"
# LoRA rank (must match the rank used during training).
LORA_RANK="${LORA_RANK:-32}"
# Sampling parameters.
NUM_INFERENCE_STEPS="${NUM_INFERENCE_STEPS:-40}"
GUIDANCE_SCALE="${GUIDANCE_SCALE:-7.0}"
HEIGHT="${HEIGHT:-512}"
WIDTH="${WIDTH:-512}"
# Per-GPU batch size. Start from 1: SD3 is large, and 1 avoids running out of memory.
PER_PROC_BATCH_SIZE="${PER_PROC_BATCH_SIZE:-1}"
# Upper limit on the number of samples.
# NOTE: in this script the --max_samples argument is commented out, so this
# value is currently not passed to the sampler and has no effect.
MAX_SAMPLES="${MAX_SAMPLES:-30000}"
# Prompt configuration (negative prompt currently disabled).
#NEGATIVE_PROMPT="blurry, low quality, distorted, ugly, bad anatomy"
# Caption file configuration.
CAPTIONS_JSONL="${CAPTIONS_JSONL:-/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl}" # JSONL file path
IMAGES_PER_CAPTION="${IMAGES_PER_CAPTION:-3}" # number of images generated per caption
# Output configuration.
SAMPLE_DIR="${SAMPLE_DIR:-./sd3_lora_samples_3w}"
GLOBAL_SEED="${GLOBAL_SEED:-42}"
# Print a summary of the effective sampling configuration, one item per line.
printf '%s\n' \
  "开始SD3 LoRA采样(从checkpoint加载)..." \
  "模型: $PRETRAINED_MODEL" \
  "LoRA Checkpoint路径: $LORA_CHECKPOINT_PATH" \
  "LoRA Rank: $LORA_RANK" \
  "Caption文件: $CAPTIONS_JSONL" \
  "每个caption生成图片数: $IMAGES_PER_CAPTION" \
  "图像尺寸: ${HEIGHT}x${WIDTH}" \
  "引导尺度: $GUIDANCE_SCALE" \
  "推理步数: $NUM_INFERENCE_STEPS"
# Check required inputs and fail fast, before any torchrun workers are launched.
# Diagnostics go to stderr so they are not mixed into normal output.
if [[ ! -f "$CAPTIONS_JSONL" ]]; then
  printf '%s\n' "错误: Caption文件 $CAPTIONS_JSONL 不存在" >&2
  exit 1
fi
if [[ ! -d "$LORA_CHECKPOINT_PATH" ]]; then
  printf '%s\n' "错误: LoRA checkpoint目录 $LORA_CHECKPOINT_PATH 不存在" >&2
  exit 1
fi
# Assemble the sampler's command-line arguments as an array, so values that
# contain spaces survive intact when expanded with "${CMD_ARGS[@]}".
CMD_ARGS=()
# Base model and LoRA weights.
CMD_ARGS+=("--pretrained_model_name_or_path=$PRETRAINED_MODEL")
CMD_ARGS+=("--lora_checkpoint_path=$LORA_CHECKPOINT_PATH")
CMD_ARGS+=("--lora_rank=$LORA_RANK")
# Diffusion sampling settings.
CMD_ARGS+=(
  "--num_inference_steps=$NUM_INFERENCE_STEPS"
  "--guidance_scale=$GUIDANCE_SCALE"
  "--height=$HEIGHT"
  "--width=$WIDTH"
  "--per_proc_batch_size=$PER_PROC_BATCH_SIZE"
)
# Caption input, output location and seed.
CMD_ARGS+=(
  "--captions_jsonl=$CAPTIONS_JSONL"
  "--images_per_caption=$IMAGES_PER_CAPTION"
  "--sample_dir=$SAMPLE_DIR"
  "--global_seed=$GLOBAL_SEED"
)
# Disabled: sample-count cap (uncomment to pass MAX_SAMPLES to the sampler).
#CMD_ARGS+=("--max_samples=$MAX_SAMPLES")
# Use fp16 to reduce memory usage.
CMD_ARGS+=("--mixed_precision=fp16")
# NOTE: the sampler is documented (not verified here) to disable CPU offload
# automatically when world_size > 1, because offload is unsupported in
# distributed mode. With several GPUs this flag is therefore expected to be
# ignored.
CMD_ARGS+=("--enable_cpu_offload")
# # Append the negative-prompt argument if one is configured
# # (the ${...:-} form keeps this safe under `set -u`).
# if [[ -n "${NEGATIVE_PROMPT:-}" ]]; then
# CMD_ARGS+=("--negative_prompt" "$NEGATIVE_PROMPT")
# fi
# Run distributed sampling with one torchrun worker per visible GPU.
# The worker count is derived from CUDA_VISIBLE_DEVICES, so it cannot drift
# from the GPU list. It falls back to 4 (the previous hard-coded value) when
# the list is empty or unset.
IFS=',' read -r -a visible_gpus <<< "${CUDA_VISIBLE_DEVICES:-}"
nproc_per_node=${#visible_gpus[@]}
if (( nproc_per_node == 0 )); then
  nproc_per_node=4
fi
# Capture torchrun's exit status instead of ignoring it, so a failed run does
# not fall through to the "sampling finished" messages that follow.
sampling_status=0
torchrun --nproc_per_node="$nproc_per_node" --master_port=25900 \
  sample_sd3_lora_checkpoint_ddp.py "${CMD_ARGS[@]}" || sampling_status=$?
if (( sampling_status != 0 )); then
  printf '错误: 采样失败 (torchrun 退出码 %s)\n' "$sampling_status" >&2
  exit "$sampling_status"
fi
# Report where the sampling outputs were written. The captions.txt path is
# printed literally as a pattern (quoted, so the '*' is not glob-expanded).
printf '%s\n' \
  '采样完成!' \
  "结果保存在: $SAMPLE_DIR" \
  "Caption信息保存在: $SAMPLE_DIR/*/captions.txt" \
  'NPZ文件已生成用于FID评估'
# Run in the background with all output captured in a log file:
# nohup bash run_sd3_lora_sampling.sh > run_sd3_lora_sampling.log 2>&1 &