| |
| """ |
| 将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件 |
| 基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进 |
| 支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录 |
| 支持从 metadata.jsonl 文件读取图片路径 |
| """ |
|
|
| import os |
| import argparse |
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
| import glob |
| import json |
|
|
|
|
def create_npz_from_metadata(metadata_jsonl_path, output_path=None, image_size=(512, 512)):
    """
    Read image paths from a metadata.jsonl file and pack the images into one .npz file.

    Each non-empty line of the JSONL file is parsed as JSON. Lines that are JSON
    objects with a non-empty string ``file_name`` field contribute one image path,
    resolved relative to the directory that contains the JSONL file. Each image is
    converted to RGB and resized to ``image_size`` with Lanczos resampling. The
    results are stacked into one uint8 array and saved under the key ``arr_0``.
    Unreadable lines and images are skipped with a printed warning.

    Args:
        metadata_jsonl_path (str): Path to the metadata.jsonl file.
        output_path (str, optional): Destination .npz path. Defaults to
            ``<jsonl stem>.npz`` in the same directory as the JSONL file. A missing
            ``.npz`` suffix is appended (numpy would append it anyway). The parent
            directory is created if needed.
        image_size (tuple[int, int]): ``(width, height)`` passed to
            ``PIL.Image.resize``. Defaults to ``(512, 512)``.

    Returns:
        str: The path the .npz file was actually written to.

    Raises:
        ValueError: If ``metadata_jsonl_path`` is not an existing file, if it yields
            no usable image paths, or if none of the listed images could be read.
    """
    # isfile (not exists) so that a directory is rejected here, rather than later
    # crashing inside open() with an unrelated OS error.
    if not os.path.isfile(metadata_jsonl_path):
        raise ValueError(f"metadata.jsonl 文件不存在: {metadata_jsonl_path}")

    base_dir = os.path.dirname(metadata_jsonl_path)

    image_files = []
    # utf-8-sig transparently strips a leading BOM. Otherwise the first record
    # would fail to parse and be dropped. Decoding of BOM-less UTF-8 is unchanged.
    with open(metadata_jsonl_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"警告: 跳过无效的 JSON 行: {e}")
                continue
            # Valid JSON need not be an object (e.g. a list or a number). Calling
            # .get() on it would raise AttributeError and abort the whole run.
            if not isinstance(data, dict):
                print(f"警告: 跳过非对象的 JSON 行: {line[:80]}")
                continue
            file_name = data.get('file_name')
            # Non-string values would make os.path.join raise TypeError.
            if isinstance(file_name, str) and file_name:
                image_files.append(os.path.join(base_dir, file_name))

    if not image_files:
        raise ValueError(f"在 {metadata_jsonl_path} 中未找到任何有效的图片路径")

    print(f"从 metadata.jsonl 读取到 {len(image_files)} 张图片路径")

    target_size = tuple(image_size)
    samples = []
    for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
        try:
            with Image.open(img_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img = img.resize(target_size, Image.LANCZOS)
                sample_np = np.asarray(img).astype(np.uint8)

                # Defensive: after convert('RGB') the array should already be (H, W, 3).
                if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
                    print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
                    continue

                samples.append(sample_np)
        except Exception as e:
            # Best effort: a single corrupt/missing image must not abort the batch.
            print(f"警告: 无法读取图片 {img_path}: {e}")
            continue

    if not samples:
        raise ValueError("没有成功读取任何有效的图片文件")

    samples = np.stack(samples)
    print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")

    # Internal invariants (every sample is a resized RGB image), not input validation.
    assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
    assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"

    if output_path is None:
        base_name = os.path.splitext(os.path.basename(metadata_jsonl_path))[0]
        output_path = os.path.join(base_dir, f"{base_name}.npz")

    # np.savez appends ".npz" to string paths lacking it. Mirror that, so the
    # returned path is the file that actually exists on disk.
    output_path = os.fspath(output_path)
    if not output_path.endswith('.npz'):
        output_path += '.npz'

    # Create the destination directory up front, so the (possibly long) image
    # processing is not lost to a FileNotFoundError at save time.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    np.savez(output_path, arr_0=samples)
    print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")

    return output_path
|
|
|
|
def main():
    """
    Parse command-line arguments and convert images into a single .npz file.

    Two input modes are supported:

    * ``--metadata_jsonl PATH``: image paths come from the ``file_name`` field of
      each JSON line, resolved relative to the JSONL file's directory.
    * Otherwise, every PNG/JPG file directly inside ``--image_folder`` is used.

    The output file is named after the JSONL file's stem or the image folder's name.
    It is written to ``--output-dir`` when given (the directory is created if
    needed). Otherwise it goes next to the JSONL file, or into the parent directory
    of the image folder.

    Returns:
        int: 0 on success, 1 if any error occurred (the error message is printed).
    """
    parser = argparse.ArgumentParser(
        description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The folder is an option (--image_folder), not a positional argument; the
        # previous examples would have been rejected by argparse.
        epilog="""
使用示例:
  python pic_npz.py --image_folder /path/to/image/folder
  python pic_npz.py --image_folder /path/to/image/folder --output-dir /custom/output/path
  python pic_npz.py --metadata_jsonl /path/to/metadata.jsonl
""",
    )

    parser.add_argument(
        "--image_folder",
        type=str,
        default="/gemini/space/gzy_new/models/Sida/sd3_rectified_samples",
        help="包含PNG或JPG图片文件的文件夹路径",
    )

    # args.metadata_jsonl is read below but was never registered, so every run
    # previously failed with AttributeError before doing any work.
    parser.add_argument(
        "--metadata_jsonl",
        type=str,
        default=None,
        help="metadata.jsonl 文件路径(指定后从其 file_name 字段读取图片路径,忽略 --image_folder)",
    )

    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="自定义输出目录(默认为输入文件夹的父级目录或 metadata.jsonl 所在目录)",
    )

    args = parser.parse_args()

    try:
        if args.metadata_jsonl:
            # An explicitly requested JSONL that does not exist is reported as an
            # error by create_npz_from_metadata. It no longer silently falls back
            # to the hard-coded default image folder.
            metadata_path = os.path.abspath(args.metadata_jsonl)
            base_name = os.path.splitext(os.path.basename(metadata_path))[0]
            output_dir = args.output_dir or os.path.dirname(metadata_path)
            if args.output_dir:
                os.makedirs(args.output_dir, exist_ok=True)
            npz_path = create_npz_from_metadata(
                metadata_path, os.path.join(output_dir, f"{base_name}.npz")
            )
        else:
            # abspath() normalizes away trailing separators, so basename() is the
            # folder's own name and dirname() is its parent directory.
            image_folder_path = os.path.abspath(args.image_folder)
            folder_name = os.path.basename(image_folder_path)
            # Default to the parent directory of the image folder, as the help text
            # states. The old code called create_npz_from_image_folder(), which is
            # not defined anywhere in this file, so this path raised NameError.
            output_dir = args.output_dir or os.path.dirname(image_folder_path)
            if args.output_dir:
                os.makedirs(args.output_dir, exist_ok=True)
            npz_path = create_npz_from_image_folder_custom(
                image_folder_path, os.path.join(output_dir, f"{folder_name}.npz")
            )

        print(f"转换完成!NPZ文件已保存至: {npz_path}")

    except Exception as e:
        # Top-level CLI boundary: report the error and signal failure via exit code.
        print(f"错误: {e}")
        return 1

    return 0
|
|
|
|
def create_npz_from_image_folder_custom(image_folder_path, output_path, image_size=(512, 512)):
    """
    Pack every PNG/JPG image directly inside a folder into one .npz file.

    Files are selected by extension (``.png``, ``.jpg``, ``.jpeg``, matched
    case-insensitively) and processed in sorted path order. Hidden dot-files and
    subdirectories are ignored. Each image is converted to RGB and resized to
    ``image_size`` with Lanczos resampling. The results are stacked into one uint8
    array and saved under the key ``arr_0``. Unreadable images are skipped with a
    printed warning.

    Args:
        image_folder_path (str): Folder containing the image files (not searched
            recursively).
        output_path (str): Destination .npz path. A missing ``.npz`` suffix is
            appended (numpy would append it anyway). The parent directory is
            created if needed.
        image_size (tuple[int, int]): ``(width, height)`` passed to
            ``PIL.Image.resize``. Defaults to ``(512, 512)``.

    Returns:
        str: The path the .npz file was actually written to.

    Raises:
        ValueError: If ``image_folder_path`` is not an existing directory, if it
            contains no PNG/JPG files, or if none of them could be read.
    """
    if not os.path.isdir(image_folder_path):
        raise ValueError(f"文件夹路径不存在: {image_folder_path}")

    # One directory listing with a case-insensitive extension check. The previous
    # per-case glob patterns (*.png, *.PNG, ...) matched each file twice where glob
    # matching is case-insensitive (e.g. Windows), duplicating samples. They also
    # missed mixed-case extensions such as ".Png" elsewhere.
    supported_extensions = {'.png', '.jpg', '.jpeg'}
    image_files = []
    for name in os.listdir(image_folder_path):
        # glob's "*" never matched dot-files; keep skipping them (e.g. "._x.png").
        if name.startswith('.'):
            continue
        if os.path.splitext(name)[1].lower() not in supported_extensions:
            continue
        full_path = os.path.join(image_folder_path, name)
        # A directory named "x.png" is not an image.
        if os.path.isfile(full_path):
            image_files.append(full_path)
    image_files.sort()

    if not image_files:
        raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")

    print(f"找到 {len(image_files)} 张图片文件")

    target_size = tuple(image_size)
    samples = []
    for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
        try:
            with Image.open(img_path) as img:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img = img.resize(target_size, Image.LANCZOS)
                sample_np = np.asarray(img).astype(np.uint8)

                # Defensive: after convert('RGB') the array should already be (H, W, 3).
                if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
                    print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
                    continue

                samples.append(sample_np)
        except Exception as e:
            # Best effort: a single corrupt image must not abort the batch.
            print(f"警告: 无法读取图片 {img_path}: {e}")
            continue

    if not samples:
        raise ValueError("没有成功读取任何有效的图片文件")

    samples = np.stack(samples)
    print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")

    # Internal invariants (every sample is a resized RGB image), not input validation.
    assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
    assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"

    # np.savez appends ".npz" to string paths lacking it. Mirror that, so the
    # returned path is the file that actually exists on disk.
    output_path = os.fspath(output_path)
    if not output_path.endswith('.npz'):
        output_path += '.npz'

    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    np.savez(output_path, arr_0=samples)
    print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")

    return output_path
|
|
|
|
if __name__ == "__main__":
    # `exit` is injected by the `site` module for interactive use and is missing
    # under `python -S` or in some frozen/embedded interpreters. Raising
    # SystemExit directly is the portable equivalent of sys.exit(main()).
    raise SystemExit(main())