| |
| """ |
| 将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件 |
| 基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进 |
| 支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录 |
| """ |
|
|
import argparse
import glob
import os
import sys

import numpy as np
from PIL import Image
from tqdm import tqdm
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and convert an image folder into an .npz file.

    Args:
        argv: Optional list of argument strings, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, so
            existing ``main()`` callers behave exactly as before.

    Returns:
        int: Process exit status: 0 on success, 1 if the conversion failed.
    """
    parser = argparse.ArgumentParser(
        description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The input folder is an option, not a positional argument, so the
        # examples must pass it via --image_folder (or --image-folder).
        epilog="""
使用示例:
    python pic_npz.py --image_folder /path/to/image/folder
    python pic_npz.py --image_folder /path/to/image/folder --output-dir /custom/output/path
""",
    )

    parser.add_argument(
        "--image_folder",
        "--image-folder",  # hyphenated alias, consistent with --output-dir
        type=str,
        # NOTE(review): this default looks like a developer-specific path;
        # kept for backward compatibility.
        default="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/reg_xl_2400000_sde_100/checkpoints/SiT-XL-2-0040000-size-256-vae-ema-cfg-1.0-seed-0-ode-0.85-1.0",
        help="包含PNG或JPG图片文件的文件夹路径",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="自定义输出目录(默认为输入文件夹的父级目录)",
    )

    args = parser.parse_args(argv)

    # Top-level boundary: report any failure and turn it into a non-zero exit status.
    try:
        # abspath() already removes trailing separators, so basename() below
        # yields the folder's own name.
        image_folder_path = os.path.abspath(args.image_folder)

        if args.output_dir:
            folder_name = os.path.basename(image_folder_path)
            custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
            os.makedirs(args.output_dir, exist_ok=True)
            npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
        else:
            npz_path = create_npz_from_image_folder(image_folder_path)

        print(f"转换完成!NPZ文件已保存至: {npz_path}")
    except Exception as e:
        # Diagnostics belong on stderr so stdout stays clean for piping.
        print(f"错误: {e}", file=sys.stderr)
        return 1

    return 0
|
|
|
|
def _list_image_files(image_folder_path):
    """Return the sorted full paths of PNG/JPG/JPEG files directly inside a folder.

    Extensions are compared case-insensitively on the lowercased suffix, so each
    file is listed exactly once. Globbing both "*.png" and "*.PNG" lists every
    file twice on platforms where glob matching is case-insensitive (Windows).
    Names starting with "." are skipped, which mirrors glob's "*" semantics.
    Entries that are not regular files, such as directories, are ignored.
    """
    supported_extensions = (".png", ".jpg", ".jpeg")
    image_files = []
    with os.scandir(image_folder_path) as entries:
        for entry in entries:
            if entry.name.startswith("."):
                continue
            if os.path.splitext(entry.name)[1].lower() not in supported_extensions:
                continue
            if entry.is_file():
                image_files.append(entry.path)
    image_files.sort()
    return image_files


def _load_rgb_array(img_path, image_size):
    """Open one image, convert it to RGB, optionally resize it, and return it as uint8.

    Args:
        img_path (str): Path of the image file.
        image_size (int | None): Output side length for a square LANCZOS resize,
            or None to keep the native resolution.

    Returns:
        numpy.ndarray: The pixel data, as a uint8 array.
    """
    with Image.open(img_path) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")
        if image_size is not None:
            img = img.resize((image_size, image_size), Image.LANCZOS)
        # Materialise the pixels while the source file is still open.
        return np.asarray(img, dtype=np.uint8)


def create_npz_from_image_folder_custom(image_folder_path, output_path, image_size=512):
    """Build a single .npz file from the images in a folder (custom output path).

    All PNG/JPG/JPEG files directly inside ``image_folder_path`` are read in
    sorted path order and converted to RGB. By default each image is resized
    to ``image_size`` x ``image_size`` with LANCZOS resampling. Images that
    cannot be read, or whose shape differs from the first successfully read
    image, are skipped with a warning. The result is stored under the key
    ``arr_0`` as a uint8 array of shape (N, H, W, 3).

    Args:
        image_folder_path (str): Folder containing the image files.
        output_path (str | os.PathLike): Destination file. ".npz" is appended
            if missing, matching what ``np.savez`` does, and the returned path
            reflects that.
        image_size (int | None): Side length to resize every image to.
            Defaults to 512, the previous hard-coded behaviour. Pass None to
            keep native resolutions; resampling changes pixel statistics, so
            this is preferable when the images already have a uniform size.

    Returns:
        str: Path of the written .npz file.

    Raises:
        ValueError: If ``image_size`` is invalid, the folder does not exist or
            is not a directory, it contains no supported images, or none of
            them could be read.
    """
    if image_size is not None and (not isinstance(image_size, (int, np.integer)) or image_size <= 0):
        raise ValueError(f"image_size 必须为正整数或 None, 得到: {image_size!r}")

    if not os.path.isdir(image_folder_path):
        raise ValueError(f"文件夹路径不存在或不是目录: {image_folder_path}")

    image_files = _list_image_files(image_folder_path)
    if not image_files:
        raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")

    print(f"找到 {len(image_files)} 张图片文件")

    # Fill a preallocated array instead of list + np.stack. Stacking briefly
    # holds two full copies of the data, which matters for large sample sets.
    samples = None
    count = 0
    for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
        try:
            sample_np = _load_rgb_array(img_path, image_size)
        except Exception as e:
            # Deliberate best-effort: a corrupt or unreadable file is reported
            # and skipped rather than aborting the whole conversion.
            print(f"警告: 无法读取图片 {img_path}: {e}")
            continue

        # Defensive check; RGB conversion should always give 3 channels.
        if sample_np.ndim != 3 or sample_np.shape[2] != 3:
            print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
            continue

        if samples is None:
            # The first valid image fixes the per-sample shape. np.empty only
            # reserves memory, so slots left unused by skipped files cost
            # little until they are written.
            samples = np.empty((len(image_files),) + sample_np.shape, dtype=np.uint8)
        elif sample_np.shape != samples.shape[1:]:
            # Only possible with image_size=None and mixed native sizes.
            print(f"警告: 跳过尺寸不一致的图片 {img_path}, 形状: {sample_np.shape}, 期望: {samples.shape[1:]}")
            continue

        samples[count] = sample_np
        count += 1

    if count == 0:
        raise ValueError("没有成功读取任何有效的图片文件")

    # Leading-axis slice of a C-contiguous array: a view, no copy.
    samples = samples[:count]
    print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")

    # np.savez appends ".npz" to string paths that lack it; do the same here
    # so the returned path is the file that actually exists.
    output_path = os.fspath(output_path)
    if not output_path.endswith(".npz"):
        output_path += ".npz"

    np.savez(output_path, arr_0=samples)
    print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")

    return output_path
|
|
|
|
def create_npz_from_image_folder(image_folder_path):
    """Build an .npz next to the image folder, named after the folder.

    For a folder ``/a/b/images`` the archive is written to ``/a/b/images.npz``.
    The actual reading and saving is delegated to
    ``create_npz_from_image_folder_custom``.

    Args:
        image_folder_path (str): Folder containing the image files.

    Returns:
        str: Path of the written .npz file.
    """
    absolute_folder = os.path.abspath(image_folder_path)
    npz_filename = "{}.npz".format(os.path.basename(absolute_folder.rstrip("/")))
    destination = os.path.join(os.path.dirname(absolute_folder), npz_filename)
    return create_npz_from_image_folder_custom(image_folder_path, destination)
|
|
|
|
if __name__ == "__main__":
    # Use sys.exit rather than the site-provided exit() builtin. That builtin
    # is meant for the interactive shell and is absent under `python -S` or
    # in some embedded/frozen interpreters.
    sys.exit(main())