sida / pic_npz copy.py
xiangzai's picture
Add files using upload-large-folder tool
7803bdf verified
#!/usr/bin/env python3
"""
将文件夹下所有PNG或JPG文件读取并生成对应NPZ文件
基于 sample_ddp_new.py 中的 create_npz_from_sample_folder 函数改进
支持自动检测图片数量,支持PNG和JPG格式,输出到父级目录
支持从 metadata.jsonl 文件读取图片路径
"""
import os
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
import glob
import json
def create_npz_from_metadata(metadata_jsonl_path, output_path=None):
"""
从 metadata.jsonl 文件读取图片路径并构建 .npz 文件
Args:
metadata_jsonl_path (str): metadata.jsonl 文件路径
output_path (str, optional): 输出 npz 文件路径,默认在 metadata.jsonl 同目录下生成
Returns:
str: 生成的 npz 文件路径
"""
# 确保 metadata.jsonl 存在
if not os.path.exists(metadata_jsonl_path):
raise ValueError(f"metadata.jsonl 文件不存在: {metadata_jsonl_path}")
# 获取基础目录
base_dir = os.path.dirname(metadata_jsonl_path)
# 读取 metadata.jsonl
image_files = []
with open(metadata_jsonl_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
try:
data = json.loads(line)
file_name = data.get('file_name')
if file_name:
full_path = os.path.join(base_dir, file_name)
image_files.append(full_path)
except json.JSONDecodeError as e:
print(f"警告: 跳过无效的 JSON 行: {e}")
continue
if len(image_files) == 0:
raise ValueError(f"在 {metadata_jsonl_path} 中未找到任何有效的图片路径")
print(f"从 metadata.jsonl 读取到 {len(image_files)} 张图片路径")
# 读取所有图片
samples = []
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
try:
# 打开图片并转换为RGB格式(确保一致性)
with Image.open(img_path) as img:
# 转换为RGB,确保所有图片都是3通道
if img.mode != 'RGB':
img = img.convert('RGB')
# 将图片resize到512x512
img = img.resize((512, 512), Image.LANCZOS)
sample_np = np.asarray(img).astype(np.uint8)
# 确保图片是3通道
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
continue
samples.append(sample_np)
except Exception as e:
print(f"警告: 无法读取图片 {img_path}: {e}")
continue
if len(samples) == 0:
raise ValueError("没有成功读取任何有效的图片文件")
# 转换为numpy数组
samples = np.stack(samples)
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
# 验证数据形状
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
# 生成输出路径
if output_path is None:
base_name = os.path.splitext(os.path.basename(metadata_jsonl_path))[0]
output_path = os.path.join(base_dir, f"{base_name}.npz")
# 保存为npz文件
np.savez(output_path, arr_0=samples)
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
return output_path
def main():
"""
主函数:解析命令行参数并执行图片到npz的转换
"""
parser = argparse.ArgumentParser(
description="将文件夹下所有PNG或JPG文件转换为NPZ格式",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
使用示例:
python pic_npz.py /path/to/image/folder
python pic_npz.py /path/to/image/folder --output-dir /custom/output/path
"""
)
parser.add_argument(
"--image_folder",
type=str,
default="/gemini/space/gzy_new/models/Sida/sd3_rectified_samples",
help="包含PNG或JPG图片文件的文件夹路径"
)
# parser.add_argument(
# "--metadata_jsonl",
# type=str,
# default="/gemini/space/hsd/project/dataset/cc3m-wds/validation/metadata.jsonl",
# help="metadata.jsonl 文件路径,用于从 JSONL 文件读取图片路径"
# )
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="自定义输出目录(默认为输入文件夹的父级目录或 metadata.jsonl 所在目录)"
)
args = parser.parse_args()
try:
if args.metadata_jsonl and os.path.exists(args.metadata_jsonl):
# 使用 metadata.jsonl
metadata_path = os.path.abspath(args.metadata_jsonl)
base_dir = os.path.dirname(metadata_path)
base_name = os.path.splitext(os.path.basename(metadata_path))[0]
if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)
output_path = os.path.join(args.output_dir, f"{base_name}.npz")
else:
output_path = os.path.join(base_dir, f"{base_name}.npz")
npz_path = create_npz_from_metadata(metadata_path, output_path)
else:
# 使用图片文件夹
image_folder_path = os.path.abspath(args.image_folder)
if args.output_dir:
# 如果指定了输出目录,修改生成逻辑
folder_name = os.path.basename(image_folder_path.rstrip('/'))
custom_output_path = os.path.join(args.output_dir, f"{folder_name}.npz")
# 创建输出目录(如果不存在)
os.makedirs(args.output_dir, exist_ok=True)
# 临时修改函数以支持自定义输出路径
npz_path = create_npz_from_image_folder_custom(image_folder_path, custom_output_path)
else:
npz_path = create_npz_from_image_folder(image_folder_path)
print(f"转换完成!NPZ文件已保存至: {npz_path}")
except Exception as e:
print(f"错误: {e}")
return 1
return 0
def create_npz_from_image_folder_custom(image_folder_path, output_path):
"""
从包含图片的文件夹构建单个 .npz 文件(自定义输出路径版本)
Args:
image_folder_path (str): 包含图片文件的文件夹路径
output_path (str): 输出npz文件的完整路径
Returns:
str: 生成的 npz 文件路径
"""
# 确保路径存在
if not os.path.exists(image_folder_path):
raise ValueError(f"文件夹路径不存在: {image_folder_path}")
# 获取所有支持的图片文件
supported_extensions = ['*.png', '*.PNG', '*.jpg', '*.JPG', '*.jpeg', '*.JPEG']
image_files = []
for extension in supported_extensions:
pattern = os.path.join(image_folder_path, extension)
image_files.extend(glob.glob(pattern))
# 按文件名排序确保一致性
image_files.sort()
if len(image_files) == 0:
raise ValueError(f"在文件夹 {image_folder_path} 中未找到任何PNG或JPG图片文件")
print(f"找到 {len(image_files)} 张图片文件")
# 读取所有图片
samples = []
for img_path in tqdm(image_files, desc="读取图片并转换为numpy数组"):
try:
# 打开图片并转换为RGB格式(确保一致性)
with Image.open(img_path) as img:
# 转换为RGB,确保所有图片都是3通道
if img.mode != 'RGB':
img = img.convert('RGB')
# 将图片resize到512x512
img = img.resize((512, 512), Image.LANCZOS)
sample_np = np.asarray(img).astype(np.uint8)
# 确保图片是3通道
if len(sample_np.shape) != 3 or sample_np.shape[2] != 3:
print(f"警告: 跳过非3通道图片 {img_path}, 形状: {sample_np.shape}")
continue
samples.append(sample_np)
except Exception as e:
print(f"警告: 无法读取图片 {img_path}: {e}")
continue
if len(samples) == 0:
raise ValueError("没有成功读取任何有效的图片文件")
# 转换为numpy数组
samples = np.stack(samples)
print(f"成功读取 {len(samples)} 张图片,形状: {samples.shape}")
# 验证数据形状
assert len(samples.shape) == 4, f"期望4维数组,得到形状: {samples.shape}"
assert samples.shape[3] == 3, f"期望3通道图片,得到: {samples.shape[3]}通道"
# 保存为npz文件
np.savez(output_path, arr_0=samples)
print(f"已保存 .npz 文件到 {output_path} [形状={samples.shape}]")
return output_path
if __name__ == "__main__":
exit(main())