jsflow / pic_npz.py
xiangzai's picture
Add files using upload-large-folder tool
b65e56d verified
#!/usr/bin/env python3
"""
将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
"""
import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import glob
def main():
"""
主函数:解析命令行参数并执行图片到npz的转换
"""
parser = argparse.ArgumentParser(
description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
使用示例:
python pic_npz.py /path/to/image/folder
python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
"""
)
parser.add_argument(
"--image_folder",
type=str,
default="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/reg_xl_2400000_sde_100/checkpoints/SiT-XL-2-0040000-size-256-vae-ema-cfg-1.0-seed-0-ode-0.85-1.0",
help="包含PNG或JPG图片文件的文件夹路径"
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="自定义输出目录(默认为输入文件夹的父级目录)"
)
args = parser.parse_args()
try:
# 仅支持从图片文件夹生成 npz
image_folder_path = os.path.abspath(args.image_folder)
if args.output_dir:
# 如果指定了输出目录,修改生成逻辑
folder_name = os.path.basename(image_folder_path.rstrip('/'))
custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
# 创建输出目录(如果不存在)
os.makedirs(args.output_dir, exist_ok=True)
npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
else:
npz_path = create_npz_from_image_folder(image_folder_path)
print(f"转换完成!NPZ文件已保存至: {npz_path}")
except Exception as e:
print(f"错误: {e}")
return 1
return 0
def create_npz_from_image_folder_custom(image_folder_path, output_path):
"""
从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
Args:
image_folder_path (str): 包含图片文件的文件夹路径
output_path (str): 输出npz文件的完整路径
Returns:
str: 生成的 npz 文件路径
"""
# 确保路径存在
if not os.path.exists(image_folder_path):
raise ValueError(f"文件夹路径不存在: {image_folder_path}")
# 获取所有支持的图片文件
supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
image_files = []
for extension in supported_extensions:
pattern = os.path.join(image_folder_path, extension)
image_files.extend(glob.glob(pattern))
# 按文件名排序确保一致性
image_files.sort()
if len(image_files) == 0:
raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
print(f"找到 {len(image_files)} 张图片文件")
# 读取所有图片
samples = []
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
try:
# 打开图片并转换为RGB格式(确保一致性)
with Image.open(img_path) as img:
# 转换为RGB,确保所有图片都是3通道
if img.mode != 'RGB':
img = img.convert('RGB')
# 将图片resize到512x512
img = img.resize((512, 512), Image.LANCZOS)
sample_np = np.asarray(img).astype(np.uint8)
# 确保图片是3通道
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
continue
samples.append(sample_np)
except Exception as e:
print(f"警告: 无法读取图片 {img_path}: {e}")
continue
if len(samples) == 0:
raise ValueError("没有成功读取任何有效的图片文件")
# 转换为numpy数组
samples = np.stack(samples)
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
# 验证数据形状
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
# 保存为npz文件
np.savez(output_path, arr_0=samples)
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
return output_path
def create_npz_from_image_folder(image_folder_path):
"""
从图片文件夹构建 .npz,输出到该文件夹的父目录,文件名为 <文件夹名>.npz
"""
parent_dir = os.path.dirname(os.path.abspath(image_folder_path))
folder_name = os.path.basename(os.path.abspath(image_folder_path).rstrip("/"))
output_path = os.path.join(parent_dir, f"{folder_name}.npz")
return create_npz_from_image_folder_custom(image_folder_path, output_path)
if __name__ == "__main__":
exit(main())