| import cv2 |
| import numpy as np |
| from insightface.app import FaceAnalysis |
| import argparse |
| import os |
| from tqdm import tqdm |
|
|
class FaceRecognitionPipeline:
    """Face detection, embedding extraction and 1:1 / 1:N matching on top of ``FaceAnalysis``."""

    def __init__(self, model_path='./models/axmodel'):
        """
        Initialize the face recognition pipeline.

        :param model_path: model directory, forwarded to ``FaceAnalysis`` as ``root``
        """
        self.app = FaceAnalysis(root=model_path)
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def extract_features(self, image):
        """
        Run face analysis on an image and return every detected face.

        :param image: BGR image (H, W, 3); ``None`` (e.g. a failed ``cv2.imread``) is tolerated
        :return: whatever ``self.app.get`` returns for the image (one entry per face; the
                 rest of this file reads each entry's ``'bbox'`` and ``'embedding'``), or an
                 empty list when ``image`` is ``None``
        """
        # An unreadable file must look like "no faces" instead of crashing the detector.
        if image is None:
            return []
        return self.app.get(image)

    def compare(self, emb1, emb2, threshold=0.25):
        """
        Compute the cosine similarity of two embedding vectors.

        :param emb1, emb2: 1-D embedding vectors of equal length (512-d in the original notes)
        :param threshold: similarity threshold (original author's note: antelopev2
                          recommends 0.35, buffalo_l recommends 0.25)
        :return: (similarity, is_same); ``(0.0, False)`` when either vector has zero norm
        """
        norm_product = np.linalg.norm(emb1) * np.linalg.norm(emb2)
        # A zero vector has no direction: report "no similarity" rather than NaN.
        if norm_product == 0:
            return 0.0, False
        similarity = np.dot(emb1, emb2) / norm_product
        is_same = similarity > threshold
        return similarity, is_same

    def recognize(self, query_image, gallery_embeddings, gallery_names, threshold=0.25):
        """
        Identify every face of the query image against a gallery.

        :param query_image: BGR image
        :param gallery_embeddings: sequence of gallery embeddings (n entries)
        :param gallery_names: sequence of names aligned with ``gallery_embeddings``
        :param threshold: minimum similarity for a gallery name to be accepted
        :return: [{'name', 'similarity', 'face'}, ...] with one dict per detected face;
                 ``name`` is "Unknown" when the best match does not exceed ``threshold``,
                 and ``similarity`` stays at -1 when the gallery is empty
        """
        results = []
        for face in self.extract_features(query_image):
            best_sim = -1
            best_name = "Unknown"
            for emb, name in zip(gallery_embeddings, gallery_names):
                # Forward the caller's threshold so compare() and this method agree.
                sim, is_same = self.compare(face['embedding'], emb, threshold=threshold)
                if sim > best_sim:
                    best_sim = sim
                    best_name = name if is_same else "Unknown"
            results.append({
                'name': best_name,
                'similarity': best_sim,
                'face': face,
            })
        return results

    def draw_results(self, image, results):
        """
        Draw recognition results on a copy of the image.

        :param image: BGR image; it is copied, not modified
        :param results: list of dicts as produced by ``recognize``
        :return: annotated copy of ``image``
        """
        img_draw = image.copy()
        for res in results:
            # Face overlay is delegated to the analysis app's own drawing helper.
            img_draw = self.app.draw_on(img_draw, [res['face']])

            x1, _, _, y2 = res['face']['bbox'].astype(int)
            # Green for a recognized identity, red for "Unknown" (BGR order).
            color = (0, 255, 0) if res['name'] != "Unknown" else (0, 0, 255)
            # Keep the label inside the image when the face touches the bottom edge.
            text_y = min(y2 + 15, img_draw.shape[0] - 5)
            cv2.putText(img_draw, f"{res['name']}: {res['similarity']:.2f}",
                        (x1, text_y), cv2.FONT_HERSHEY_COMPLEX, 0.7, color, 1)
        return img_draw
|
|
| |
|
|
if __name__ == "__main__":
    # Command-line entry point with two modes:
    #   --type 0: 1v1 comparison of one gallery image against one query image
    #   --type 1: 1vN recognition; build a gallery from a directory, then query interactively
    parser = argparse.ArgumentParser(description="Face Recognition Pipeline Example")
    parser.add_argument("--model_path", "-m", type=str, default="./models/buffalo_l",
                        help="Path to the model directory")
    parser.add_argument("--type", "-t", type=int, default=0, choices=[0, 1],
                        help="Type of operation: 0: 1v1 compare, 1: 1vN recognize")
    parser.add_argument("--gallery_path", "-g", type=str, default=None,
                        help="Path to the gallery image (type 0) or gallery image directory (type 1)")
    parser.add_argument("--query_path", "-q", type=str, default=None, help="Path to the query image")
    parser.add_argument("--threshold", type=float, default=0.25,
                        help="Cosine similarity threshold for declaring a match")
    parser.add_argument("--draw", "-d", action='store_true', help="Whether to draw results on the image")
    args = parser.parse_args()

    # Validate arguments before the (slow) model load; parser.error also survives `python -O`,
    # unlike the assert statements it replaces.
    if args.type == 0 and (args.gallery_path is None or args.query_path is None):
        parser.error("请提供 gallery_path 和 query_path")
    if args.type == 1:
        if args.gallery_path is None:
            parser.error("请提供 gallery_path")
        if not os.path.isdir(args.gallery_path):
            parser.error(f"gallery_path 不是有效的目录: {args.gallery_path}")

    pipeline = FaceRecognitionPipeline(model_path=args.model_path)

    if args.type == 0:
        # cv2.imread yields None for unreadable files; treat that the same as "no face found".
        gallery_img = cv2.imread(args.gallery_path)
        faces1 = pipeline.extract_features(gallery_img) if gallery_img is not None else []
        if not faces1:
            print(f"警告: {args.gallery_path} 未检测到人脸")
            raise SystemExit(0)

        query_img = cv2.imread(args.query_path)
        faces2 = pipeline.extract_features(query_img) if query_img is not None else []
        if not faces2:
            print(f"警告: {args.query_path} 未检测到人脸")
            raise SystemExit(0)

        # Only the first detected face of each image takes part in the comparison.
        sim, is_same = pipeline.compare(faces1[0]['embedding'], faces2[0]['embedding'],
                                        threshold=args.threshold)
        print(f"相似度: {sim:.4f}, 是否同一人: {is_same}")

        if args.draw:
            os.makedirs("./output", exist_ok=True)
            output_img = pipeline.app.draw_on(query_img, [faces2[0]])
            cv2.imwrite(f"./output/{os.path.basename(args.query_path)}", output_img)
            print(f"结果已保存到 ./output/{os.path.basename(args.query_path)}")

    else:
        # Build the gallery: one embedding per file, named after the file stem.
        gallery_names = []
        gallery_embeddings = []
        # Sorted so that ties between gallery entries resolve the same way on every run.
        for fname in tqdm(sorted(os.listdir(args.gallery_path))):
            gallery_img = cv2.imread(os.path.join(args.gallery_path, fname))
            if gallery_img is None:
                # Sub-directories and non-image files must not abort the whole build.
                print(f"警告: {fname} 无法读取为图像,已跳过")
                continue
            faces = pipeline.extract_features(gallery_img)
            if not faces:
                print(f"警告: {fname} 未检测到人脸")
                continue
            gallery_names.append(os.path.splitext(fname)[0])
            gallery_embeddings.append(faces[0]['embedding'])

        print("特征库构建完成,包含以下人员:", gallery_names)
        if args.draw:
            os.makedirs("./output", exist_ok=True)

        # Interactive query loop; ends on 'exit' or when stdin is closed.
        while True:
            print("请输入查询图像路径 (输入 'exit' 退出): ")
            try:
                user_input = input().strip()
            except EOFError:
                break
            if user_input.lower() == 'exit':
                break
            if not os.path.isfile(user_input):
                print("输入的路径不是有效的文件,请重新输入。")
                continue

            query_img = cv2.imread(user_input)
            if query_img is None:
                # The path exists but is not a decodable image.
                print(f"{user_input} 未检测到人脸")
                continue
            results = pipeline.recognize(query_img, gallery_embeddings, gallery_names,
                                         threshold=args.threshold)
            if not results:
                print(f"{user_input} 未检测到人脸")
                continue

            for res in results:
                print(f"识别结果: {res['name']}, 相似度(0-1): {res['similarity']:.4f}")
            if args.draw:
                output_img = pipeline.draw_results(query_img, results)
                cv2.imwrite(f"./output/{os.path.basename(user_input)}", output_img)
                print(f"结果已保存到 ./output/{os.path.basename(user_input)}")