# SwinFace / source / pic_crop_LFW.py
# Uploaded by MurmanskY ("upload model source code", commit a238b5e).
'''
Test step 1: detect and crop faces in the "50 wild" LFW test-template images
with MTCNN, producing the "50 cropped" dataset.
'''
import os
import torch
from facenet_pytorch import MTCNN
from PIL import Image
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
# Build the MTCNN face detector, placed on the GPU when CUDA is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# keep_all=False is intended to restrict extraction to a single face per image.
# NOTE(review): this flag appears to govern only MTCNN's __call__; detect(),
# which the cropping code calls, presumably still returns every box — confirm.
mtcnn = MTCNN(device=device, keep_all=False)

# Input / output locations (relative paths, resolved against the current
# working directory the script is launched from).
data_dir = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_wild'  # source LFW images
save_dir = '../../../datasets/classification/LFWPairs/lfw-py/lfw_test_template_50_cropped'  # cropped faces are written here
error_log_path = '../../../datasets/classification/LFWPairs/lfw-py/lfw_error_log_selected_50.txt'  # failures are appended here

# Make sure the output root exists before any image is processed.
os.makedirs(save_dir, exist_ok=True)
# Face-cropping function.
def _append_error_log(message):
    """Append ``message`` followed by a newline to the file at ``error_log_path``."""
    # Explicit UTF-8: the messages contain Chinese text, which would otherwise
    # be encoded with the platform's locale default (and can raise
    # UnicodeEncodeError under e.g. a cp1252 Windows locale).
    with open(error_log_path, 'a', encoding='utf-8') as f:
        f.write(f"{message}\n")


def crop_and_save_faces(image_path, save_path):
    """Detect a face in ``image_path`` with MTCNN and save a single crop of it.

    Only one face is saved per image: the first box returned by
    ``mtcnn.detect`` that has positive width and height. Problems are written
    to the error log instead of being raised, so a batch run keeps going:

    * no face detected     -> "未检测到人脸: <path>"
    * no usable box        -> "未检测到有效人脸框: <path>"
    * any other exception  -> "处理 <path> 时出错: <error>"

    Args:
        image_path: Path of the source image.
        save_path: Destination path of the cropped face; missing parent
            directories are created.
    """
    try:
        # Load as RGB. The ``with`` closes the source file handle, which a bare
        # ``Image.open`` would leave open until garbage collection.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        # NOTE(review): detect() appears to return *all* boxes even with
        # keep_all=False, presumably ordered largest-first by facenet_pytorch's
        # default select_largest=True — confirm against the installed version.
        boxes, _ = mtcnn.detect(image)
        if boxes is None:
            # No face detected: record the image.
            _append_error_log(f"未检测到人脸: {image_path}")
            return
        for box in boxes:
            x1, y1, x2, y2 = map(int, box)
            if x2 > x1 and y2 > y1:  # skip degenerate (empty) crop boxes
                face = image.crop((x1, y1, x2, y2))
                save_parent = os.path.dirname(save_path)
                if save_parent:  # os.makedirs('') would raise
                    os.makedirs(save_parent, exist_ok=True)
                face.save(save_path)
                # Stop after the first usable box: every box shares the same
                # save_path, so continuing would overwrite this crop with a
                # later (possibly smaller, unrelated) face.
                return
        # Faces were detected but every box was degenerate; record it rather
        # than skipping the image silently.
        _append_error_log(f"未检测到有效人脸框: {image_path}")
    except Exception as e:
        # Best-effort batch processing: record the failure and move on.
        _append_error_log(f"处理 {image_path} 时出错: {e}")
# Traverse the LFW dataset and extract faces.
def _iter_image_paths(root_dir):
    """Yield every .jpg/.jpeg/.png file under ``root_dir`` (recursively) whose
    directory path contains 'test' or 'template'.

    NOTE: the substring test is applied to the full directory path, which
    includes ``root_dir`` itself. The configured ``data_dir`` contains both
    'test' and 'template', so currently every image under it is selected.
    """
    for root, _dirs, files in os.walk(root_dir):
        if 'test' not in root and 'template' not in root:
            continue
        for file in files:
            # Match real extensions only (including the dot), case-insensitively.
            if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                yield os.path.join(root, file)


def _crop_to_mirrored_path(image_path):
    """Crop ``image_path`` and save it under ``save_dir`` at the same relative
    location it has under ``data_dir``."""
    relative_path = os.path.relpath(image_path, data_dir)
    crop_and_save_faces(image_path, os.path.join(save_dir, relative_path))


# Collect all images first, then process them with ONE thread pool. That lets
# the work actually run concurrently and gives a single overall tqdm progress
# bar. A pool per image with a one-item task list would run the images one at
# a time and print a separate bar for each.
image_paths = list(_iter_image_paths(data_dir))
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
    list(tqdm(executor.map(_crop_to_mirrored_path, image_paths), total=len(image_paths)))
print("所有人脸提取完成并保存到: ", save_dir)
print("错误日志已保存到: ", error_log_path)