Ahmet Hakan DİNGER committed on
Commit
5a96819
·
1 Parent(s): b129ac6

application file

Browse files
Files changed (7) hide show
  1. README.md +3 -3
  2. app.py +760 -0
  3. arcface_onnx.py +104 -0
  4. models/det_10g.onnx +3 -0
  5. models/w600k_r50.onnx +3 -0
  6. requirements.txt +0 -0
  7. scrfd.py +338 -0
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: FaceDetection
3
- emoji: 🦀
4
- colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.49.1
@@ -10,4 +10,4 @@ pinned: false
10
  short_description: An AI application for automatic face detection from video fi
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: FaceDetection
3
+ emoji: 🎭
4
+ colorFrom: blue
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.49.1
 
10
  short_description: An AI application for automatic face detection from video fi
11
  ---
12
 
13
+
app.py ADDED
@@ -0,0 +1,760 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ from datetime import datetime
6
+ from scrfd import SCRFD
7
+ from arcface_onnx import ArcFaceONNX
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ from sklearn.cluster import DBSCAN
10
+ import time
11
+ from concurrent.futures import ThreadPoolExecutor
12
+ from dataclasses import dataclass
13
+ import logging
14
+ from typing import List, Tuple, Optional, Dict
15
+ import json
16
+ from pathlib import Path
17
+ import shutil
18
+ import requests
19
+ import tempfile
20
+ from urllib.parse import urlparse
21
+ import logging
22
+
23
+
24
# Module-wide logging: timestamped INFO-level messages to stderr.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# yt-dlp is optional; YouTube URLs are rejected gracefully when absent.
YOUTUBE_SUPPORT = False
try:
    import yt_dlp
except ImportError:
    logger.warning("Youtube desteği yüklü değil.")
else:
    YOUTUBE_SUPPORT = True
37
+
38
@dataclass
class FaceDetectionConfig:
    """Tunable parameters for video face detection and identity clustering.

    Defaults target CPU-only processing of short (<5 min) videos.
    """

    frame_skip: int = 30             # analyze every N-th frame only
    face_size_threshold: int = 1000  # minimum face area (px^2) to keep a detection
    clustering_eps: float = 0.5      # DBSCAN eps (cosine distance) for identity grouping
    min_samples: int = 2             # DBSCAN min_samples (faces per identity)
    resize_factor: float = 0.5       # frame downscale factor before detection
    chunk_size: int = 500            # sampled frames per worker chunk
    max_workers: int = 2             # thread-pool size for chunk processing
    use_gpu: bool = False            # reserved flag; models are loaded CPU-only here
48
+
49
+ class FaceDetector:
50
def __init__(self, config: FaceDetectionConfig):
    """Bind the detector to *config*; ONNX models are loaded lazily."""
    self.config = config
    self.models = None              # becomes (SCRFD, ArcFaceONNX) on first use
    self.progress_callback = None   # optional (percent, message) reporter
    self.temp_files = []            # downloaded files to remove afterwards
55
+
56
def set_progress_callback(self, callback):
    """Register *callback*(percent, message) used for progress reporting."""
    self.progress_callback = callback
58
+
59
def is_youtube_url(self, url: str) -> bool:
    """Return True when *url* points at a known YouTube host.

    Matches the exact domain or any subdomain (e.g. www.youtube.com),
    case-insensitively and ignoring an optional port. The previous
    substring test (`'youtube.com' in netloc`) also matched unrelated
    hosts such as 'notyoutube.com.example'.
    """
    youtube_domains = ('youtube.com', 'youtu.be', 'youtube-nocookie.com')
    host = urlparse(url).netloc.lower().split(':')[0]  # strip any :port
    return any(host == d or host.endswith('.' + d) for d in youtube_domains)
64
+
65
def download_youtube_video(self, url: str) -> str:
    """Download a YouTube video into the temp dir and return its path.

    Prefers a <=720p MP4 and remuxes to MP4 via ffmpeg when needed.
    The resulting file is registered in ``self.temp_files`` for cleanup.

    Raises:
        ValueError: when yt-dlp is not installed or the download fails.
    """
    if not YOUTUBE_SUPPORT:
        raise ValueError("YouTube desteği için paket kurulmalı")

    try:
        if self.progress_callback:
            self.progress_callback(0, "YouTube videosu indiriliyor...")

        # Unique target path without extension; yt-dlp appends one.
        temp_dir = tempfile.gettempdir()
        temp_filename = f"yt_{int(time.time())}_{np.random.randint(1000, 9999)}"
        temp_path_without_ext = os.path.join(temp_dir, temp_filename)

        ydl_opts = {
            'format': 'best[ext=mp4][height<=720]/best[height<=720]/best',
            'outtmpl': temp_path_without_ext + '.%(ext)s',
            'quiet': True,
            'no_warnings': True,
            'socket_timeout': 60,
            'retries': 3,
            'fragment_retries': 3,
            'keepvideo': True,
            'merge_output_format': 'mp4',
            'postprocessors': [{
                'key': 'FFmpegVideoConvertor',
                'preferedformat': 'mp4',
            }],
            'progress_hooks': [self._youtube_progress_hook],
        }

        logger.info(f"YouTube videosu indiriliyor: {url}")
        logger.info(f"Hedef dosya: {temp_path_without_ext}.mp4")

        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            logger.info(f"YouTube video başlığı: {info.get('title', 'video')}")

        final_path = temp_path_without_ext + '.mp4'

        # The convertor usually yields .mp4, but fall back to other
        # containers, then to any file carrying our unique prefix.
        if not os.path.exists(final_path):
            for ext in ('.mp4', '.webm', '.mkv'):
                alt_path = temp_path_without_ext + ext
                if os.path.exists(alt_path):
                    final_path = alt_path
                    logger.info(f"Video bulundu: {final_path}")
                    break

        if not os.path.exists(final_path):
            possible_files = [f for f in os.listdir(temp_dir) if f.startswith(temp_filename)]
            if not possible_files:
                raise ValueError(f"YouTube videosu indirilemedi! Beklenen: {final_path}")
            final_path = os.path.join(temp_dir, possible_files[0])
            logger.info(f"Alternatif dosya bulundu: {final_path}")

        file_size = os.path.getsize(final_path)
        if file_size == 0:
            raise ValueError("İndirilen YouTube videosu boş!")

        self.temp_files.append(final_path)
        logger.info(f"YouTube videosu başarıyla indirildi: {final_path} ({file_size / 1024 / 1024:.1f}MB)")

        if self.progress_callback:
            self.progress_callback(20, f"YouTube videosu indirildi ({file_size / 1024 / 1024:.1f}MB)")

        return final_path

    except Exception as e:
        # Wrap every failure (incl. our own ValueErrors) the same way,
        # matching what the Gradio layer expects to display.
        logger.error(f"YouTube indirme hatası: {e}", exc_info=True)
        raise ValueError(f"YouTube videosu indirilemedi: {str(e)}")
142
+
143
+ def _youtube_progress_hook(self, d):
144
+
145
+ if d['status'] == 'downloading':
146
+ if 'total_bytes' in d:
147
+ progress = (d['downloaded_bytes'] / d['total_bytes']) * 20
148
+ if self.progress_callback:
149
+ self.progress_callback(
150
+ progress,
151
+ f"YouTube indiriliyor: {d['downloaded_bytes'] / 1024 / 1024:.1f}MB / {d['total_bytes'] / 1024 / 1024:.1f}MB"
152
+ )
153
+ elif d['status'] == 'finished':
154
+ if self.progress_callback:
155
+ self.progress_callback(18, "YouTube videosu işleniyor...")
156
+
157
def download_video_from_url(self, url: str) -> str:
    """Download a video from *url* into a temp file and return its path.

    YouTube URLs are delegated to :meth:`download_youtube_video`; any
    other HTTP(S) URL is streamed in 64KB chunks with progress mapped
    onto the 0-20% band. The file is registered for later cleanup.

    Raises:
        ValueError: on invalid URL, timeout, HTTP error or empty file.
    """
    if self.is_youtube_url(url):
        return self.download_youtube_video(url)

    temp_path = None
    try:
        if self.progress_callback:
            self.progress_callback(0, "Video indiriliyor...")

        parsed = urlparse(url)
        if parsed.scheme not in ('http', 'https'):
            raise ValueError("Geçersiz URL! HTTP veya HTTPS protokolü kullanın.")

        # Keep the original extension when it looks like a video file.
        ext = os.path.splitext(parsed.path)[1]
        if ext not in ('.mp4', '.avi', '.mov', '.mkv', '.webm'):
            ext = '.mp4'

        temp_fd, temp_path = tempfile.mkstemp(suffix=ext, prefix='video_')
        os.close(temp_fd)  # only the path is needed; file is reopened below
        self.temp_files.append(temp_path)
        logger.info(f"Geçici dosya oluşturuldu: {temp_path}")

        # 'with' releases the HTTP connection even on write errors
        # (the previous version never closed the streaming response).
        with requests.get(url, stream=True, timeout=60,
                          headers={'User-Agent': 'Mozilla/5.0'}) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0

            with open(temp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=65536):  # 64KB chunks
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 0 and self.progress_callback:
                            progress = (downloaded / total_size) * 20
                            if downloaded % (1024 * 1024) < 65536:  # ~once per MB
                                self.progress_callback(
                                    progress,
                                    f"İndiriliyor: {downloaded / 1024 / 1024:.1f}MB / {total_size / 1024 / 1024:.1f}MB"
                                )

        if not os.path.exists(temp_path):
            raise ValueError("Video dosyası oluşturulamadı!")

        file_size = os.path.getsize(temp_path)
        if file_size == 0:
            raise ValueError("İndirilen video dosyası boş!")

        logger.info(f"Video başarıyla indirildi: {temp_path} ({file_size / 1024 / 1024:.1f}MB)")

        if self.progress_callback:
            self.progress_callback(20, f"Video indirildi ({file_size / 1024 / 1024:.1f}MB), işleme başlanıyor...")

        return temp_path

    except requests.exceptions.Timeout:
        self._discard_partial(temp_path)
        raise ValueError("Video indirme zaman aşımına uğradı. Lütfen tekrar deneyin.")
    except requests.exceptions.RequestException as e:
        self._discard_partial(temp_path)
        raise ValueError(f"Video indirme hatası: {str(e)}")
    except Exception as e:
        self._discard_partial(temp_path)
        raise ValueError(f"Beklenmeyen hata: {str(e)}")

def _discard_partial(self, temp_path):
    """Delete a partially-downloaded file, if any (best effort)."""
    if temp_path and os.path.exists(temp_path):
        os.unlink(temp_path)
233
+
234
def cleanup_temp_files(self):
    """Best-effort removal of every downloaded temp file, then reset the list."""
    for path in self.temp_files:
        try:
            if os.path.exists(path):
                os.unlink(path)
                logging.getLogger(__name__).info(f"Geçici dosya silindi: {path}")
        except Exception as exc:
            # A file we cannot delete is only worth a warning.
            logging.getLogger(__name__).warning(f"Geçici dosya silinemedi {path}: {exc}")
    self.temp_files = []
243
+
244
+
245
def _load_models(self) -> Tuple[SCRFD, ArcFaceONNX]:
    """Load the SCRFD detector and ArcFace recognizer as CPU-only ONNX sessions.

    Returns:
        ``(detector, recognizer)``; callers cache the pair in ``self.models``.

    Raises:
        FileNotFoundError: when the .onnx model files cannot be located.
    """
    try:
        logger.info("Modeller yükleniyor (CPU mode)...")
        current_dir = os.path.dirname(os.path.abspath(__file__))
        # This commit ships the models under ./models; the previously
        # hard-coded 'deploy/models' path is kept only as a fallback
        # for older checkouts.
        candidate_dirs = [
            os.path.join(current_dir, 'models'),
            os.path.join(current_dir, 'deploy', 'models'),
        ]

        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        ort.set_default_logger_severity(3)  # silence per-run ORT warnings
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 2

        providers = ['CPUExecutionProvider']  # CPU only by design

        det_model = arc_model = None
        for models_dir in candidate_dirs:
            det_candidate = os.path.join(models_dir, 'det_10g.onnx')
            arc_candidate = os.path.join(models_dir, 'w600k_r50.onnx')
            if os.path.exists(det_candidate) and os.path.exists(arc_candidate):
                det_model, arc_model = det_candidate, arc_candidate
                break

        if det_model is None:
            raise FileNotFoundError(f"Model dosyaları bulunamadı: {candidate_dirs}")

        # Replace each wrapper's default session with the tuned CPU one.
        detector = SCRFD(det_model)
        detector.session = ort.InferenceSession(det_model, sess_options, providers=providers)

        recognizer = ArcFaceONNX(arc_model)
        recognizer.session = ort.InferenceSession(arc_model, sess_options, providers=providers)

        logger.info(f"✅ CPU mode aktif: {recognizer.session.get_providers()}")
        return detector, recognizer
    except Exception as e:
        logger.error(f"Model yükleme hatası: {e}")
        raise
277
+
278
+
279
def create_output_directory(self, video_path: str, is_temp: bool = False) -> str:
    """Create and return a timestamped directory for the result images.

    Args:
        video_path: source video path; its directory and basename seed
            the output directory name when *is_temp* is False.
        is_temp: when True (URL/YouTube input) the directory is created
            under the system temp dir instead of next to the video.

    Returns:
        Absolute path of the newly created directory.
    """
    # NOTE: removed a leftover debug log that sat before the docstring
    # (which therefore was never an actual docstring).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if is_temp:
        output_dir = os.path.join(tempfile.gettempdir(), f"face_detection_{timestamp}")
    else:
        base_dir = os.path.dirname(video_path)
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        output_dir = os.path.join(base_dir, f"{video_name}_{timestamp}")

    os.makedirs(output_dir, exist_ok=True)
    logging.getLogger(__name__).info(f"Output dizini oluşturuldu: {output_dir}")
    return output_dir
297
+
298
def extract_embeddings(self, face_img: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Return ``(embedding, keypoints)`` for the single best face in *face_img*.

    Returns ``(None, None)`` when no face is found or extraction fails
    (including the case where models are not loaded yet).
    """
    try:
        detector, recognizer = self.models
        bboxes, kpss = detector.autodetect(face_img, max_num=1)
        if len(bboxes) == 0:
            return None, None
        kps = kpss[0]
        return recognizer.get(face_img, kps), kps
    except Exception as exc:
        logging.getLogger(__name__).error(f"Embedding çıkarma hatası: {exc}")
        return None, None
310
+
311
def calculate_face_quality(self, face_img: np.ndarray, face_size: float, kps: np.ndarray) -> float:
    """Heuristic quality score in [0, 8]: the sum of four capped components
    (size, eye spacing, sharpness, mouth symmetry).

    kps indexing follows the 5-point layout used upstream: 0/1 = eyes,
    3/4 = mouth corners — assumption based on usage here; confirm
    against the SCRFD keypoint order.
    """
    face_width = np.sqrt(face_size)

    # Larger crops carry more usable detail.
    size_score = min(face_size / 5000, 2.0)

    # Wide eye spacing relative to face width suggests a frontal pose.
    eye_distance = np.linalg.norm(kps[0] - kps[1])
    angle_score = min((eye_distance / face_width) * 3, 2.0)

    # Variance of the Laplacian: low values indicate blur.
    gray = cv2.cvtColor(face_img, cv2.COLOR_BGR2GRAY)
    blur_score = min(cv2.Laplacian(gray, cv2.CV_64F).var() / 500, 2.0)

    # Mouth-corner spread as a cheap symmetry/frontality proxy.
    mouth_distance = np.linalg.norm(kps[3] - kps[4])
    symmetry_score = min((mouth_distance / face_width) * 3, 2.0)

    return size_score + angle_score + blur_score + symmetry_score
331
+
332
def process_frame(self, frame: np.ndarray) -> List[Dict]:
    """Detect faces in one frame and return a dict per face above the
    size threshold, carrying embedding, crop, quality score and bbox
    (coordinates are in the resized frame's space)."""
    scale = self.config.resize_factor
    frame = cv2.resize(frame, (0, 0), fx=scale, fy=scale)
    detector, _ = self.models
    results = []

    try:
        bboxes, _ = detector.autodetect(frame)
        for x1, y1, x2, y2, _score in bboxes:
            area = (x2 - x1) * (y2 - y1)
            if area < self.config.face_size_threshold:
                continue  # too small to embed reliably

            crop = frame[int(y1):int(y2), int(x1):int(x2)]
            embedding, kps = self.extract_embeddings(crop)
            if embedding is None or kps is None:
                continue

            results.append({
                'embedding': embedding,
                'face_img': crop,
                'quality_score': self.calculate_face_quality(crop, area, kps),
                'bbox': [float(x1), float(y1), float(x2), float(y2)],
            })
    except Exception as exc:
        logging.getLogger(__name__).error(f"Frame işleme hatası: {exc}")

    return results
359
+
360
def process_video_chunk(self, frames: List[np.ndarray]) -> List[Dict]:
    """Run :meth:`process_frame` over *frames* and flatten the results."""
    return [face for frame in frames for face in self.process_frame(frame)]
366
+
367
def detect_faces(self, video_path: str, is_url: bool = False):
    """Run the full pipeline: (optionally) download, sample frames,
    detect and embed faces, cluster identities, and save one best crop
    per unique person.

    Args:
        video_path: local file path, or a URL when *is_url* is True.
        is_url: when True the video is downloaded first, and all temp
            files are cleaned up in the ``finally`` block.

    Returns:
        ``(output_dir, saved_face_paths, metadata_dict)``.

    Raises:
        ValueError: on missing/empty/unreadable video, or when no face
            is found at all.
    """
    start_time = time.time()
    original_path = video_path

    try:
        if is_url:
            video_path = self.download_video_from_url(video_path)
            logger.info(f"URL'den indirilen video kullanılıyor: {video_path}")

        if not os.path.exists(video_path):
            raise ValueError(f"Video dosyası bulunamadı: {video_path}")
        file_size = os.path.getsize(video_path)
        if file_size == 0:
            raise ValueError(f"Video dosyası boş: {video_path}")
        logger.info(f"Video dosyası kontrol edildi: {video_path} ({file_size / 1024 / 1024:.1f}MB)")

        if self.models is None:
            self.models = self._load_models()  # lazy, loaded exactly once

        output_dir = self.create_output_directory(
            video_path if not is_url else tempfile.gettempdir(), is_temp=is_url)
        metadata = {
            'video_path': original_path,
            'is_url': is_url,
            'processing_start': datetime.now().isoformat(),
            'config': vars(self.config),
            'faces': []
        }

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Video açılamadı: {video_path}. Dosya bozuk veya desteklenmeyen format olabilir.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep fps as float: the previous int() truncation turned
        # 29.97 into 29 and skewed the duration estimate.
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps if fps > 0 else 0

        logger.info(f"Video: {total_frames} frame, {fps} FPS, {duration:.1f} saniye")

        # URL runs reserve the 0-20% band for the download phase.
        progress_offset = 20 if is_url else 0
        max_progress = 80 if is_url else 100

        if self.progress_callback:
            self.progress_callback(progress_offset, f"Video açıldı: {total_frames} frame")

        all_faces_data = self._collect_faces(cap, total_frames, progress_offset, max_progress)
        cap.release()

        if not all_faces_data:
            raise ValueError("Hiç yüz bulunamadı!")

        clustering_progress = progress_offset + (max_progress - progress_offset) * 0.6
        if self.progress_callback:
            self.progress_callback(clustering_progress, f"{len(all_faces_data)} yüz tespit edildi, clustering yapılıyor...")

        # Group embeddings into identities; DBSCAN labels are 0..k-1
        # with -1 marking noise points.
        embeddings_array = np.array([face['embedding'] for face in all_faces_data])
        labels = DBSCAN(
            eps=self.config.clustering_eps,
            min_samples=self.config.min_samples,
            metric='cosine'
        ).fit(embeddings_array).labels_
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)

        saving_progress = progress_offset + (max_progress - progress_offset) * 0.8
        if self.progress_callback:
            self.progress_callback(saving_progress, f"{n_clusters} benzersiz kişi tespit edildi, yüzler kaydediliyor...")

        saved_faces = []
        for cluster_id in range(n_clusters):
            cluster_indices = np.where(labels == cluster_id)[0]
            cluster_faces = [all_faces_data[i] for i in cluster_indices]
            # One representative per person: the highest-quality crop.
            best_face = max(cluster_faces, key=lambda f: f['quality_score'])

            face_img_resized = cv2.resize(best_face['face_img'], (112, 112))
            face_file = f"person_{cluster_id}.jpg"
            face_path = os.path.join(output_dir, face_file)
            cv2.imwrite(face_path, face_img_resized, [cv2.IMWRITE_JPEG_QUALITY, 95])
            saved_faces.append(face_path)

            metadata['faces'].append({
                'cluster_id': cluster_id,
                'face_file': face_file,
                'quality_score': float(best_face['quality_score']),
                'bbox': best_face['bbox'],
                'cluster_size': len(cluster_indices)
            })

        elapsed_time = time.time() - start_time
        metadata['processing_end'] = datetime.now().isoformat()
        metadata['elapsed_time'] = elapsed_time
        metadata['total_frames'] = total_frames
        metadata['fps'] = fps
        metadata['duration'] = duration
        metadata['unique_persons'] = n_clusters

        metadata_path = os.path.join(output_dir, 'metadata.json')
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)

        if self.progress_callback:
            self.progress_callback(100, f"✅ Tamamlandı! {n_clusters} kişi bulundu ({elapsed_time:.1f}s)")

        return output_dir, saved_faces, metadata

    except Exception as e:
        logger.error(f"İşlem hatası: {e}")
        raise
    finally:
        if is_url:
            self.cleanup_temp_files()

def _collect_faces(self, cap, total_frames, progress_offset, max_progress):
    """Sample every ``config.frame_skip``-th frame and embed all faces.

    Chunks are submitted to a thread pool and the futures are drained
    *after* reading finishes, so video decode and inference overlap —
    the previous ``submit(...).result()`` pattern blocked immediately
    and made the pool pointless. Futures are resolved in submission
    order, so the output order matches the sequential version.
    """
    pending = []
    chunk = []
    frame_count = 0

    with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_count += 1
            if frame_count % self.config.frame_skip == 0:
                chunk.append(frame)
                if len(chunk) >= self.config.chunk_size:
                    pending.append(executor.submit(self.process_video_chunk, chunk))
                    chunk = []

            if frame_count % 500 == 0 and self.progress_callback:
                progress = (frame_count / total_frames) * 100
                adjusted_progress = progress_offset + (progress / 2) * ((max_progress - progress_offset) / 100)
                self.progress_callback(
                    adjusted_progress,
                    f"Frame işleniyor: {frame_count}/{total_frames} ({progress:.1f}%)"
                )

        if chunk:
            pending.append(executor.submit(self.process_video_chunk, chunk))

        faces = []
        for future in pending:
            faces.extend(future.result())
    return faces
516
+
517
+ detector_instance = None
518
+
519
def initialize_detector(frame_skip, face_threshold, clustering_eps, use_gpu=False):
    """(Re)create the global FaceDetector with the given settings.

    ``use_gpu`` now defaults to False: the settings form wires only
    three inputs into this callback, so UI clicks previously raised a
    missing-argument TypeError.
    """
    global detector_instance
    config = FaceDetectionConfig(
        frame_skip=frame_skip,
        face_size_threshold=face_threshold,
        clustering_eps=clustering_eps,
        use_gpu=use_gpu
    )
    detector_instance = FaceDetector(config)
    return "✅ Ayarlar kaydedildi!"
529
+
530
def process_video_gradio(video_file, video_url, progress=gr.Progress()):
    """Gradio entry point: choose the video source, run the pipeline and
    return ``(gallery_paths, markdown_report, status_line)``."""
    global detector_instance

    if detector_instance is None:
        detector_instance = FaceDetector(FaceDetectionConfig())

    # Forward pipeline progress (0-100) into Gradio's 0-1 scale.
    detector_instance.set_progress_callback(
        lambda value, message: progress(value / 100, desc=message))

    try:
        progress(0, desc="İşlem başlatılıyor...")

        # A non-empty URL takes precedence over an uploaded file.
        if video_url and video_url.strip():
            video_source = video_url.strip()
            is_url = True
            source_name = urlparse(video_url).path.split('/')[-1] or "video"
            logger.info(f"URL kullanılıyor: {video_url}")

            if detector_instance.is_youtube_url(video_url):
                if not YOUTUBE_SUPPORT:
                    return [], "❌ YouTube desteği için paket kurulmalı", "❌ paket kurulu değil"
                logger.info("YouTube URL tespit edildi")

        elif video_file:
            video_source = video_file
            is_url = False
            source_name = os.path.basename(video_file)
            logger.info(f"Yerel dosya kullanılıyor: {video_file}")
        else:
            return [], "❌ Lütfen bir video yükleyin veya URL girin!", "❌ Video bulunamadı"

        # Quick reachability probe for plain (non-YouTube) URLs.
        if is_url and not detector_instance.is_youtube_url(video_source):
            try:
                head_response = requests.head(video_source, timeout=10, allow_redirects=True)
                logger.info(f"URL test - Status: {head_response.status_code}, Content-Type: {head_response.headers.get('content-type', 'unknown')}")
                if head_response.status_code != 200:
                    return [], f"❌ URL erişilemez (HTTP {head_response.status_code})", "❌ URL hatası"
            except Exception as e:
                logger.warning(f"URL test başarısız: {e}, yine de deneniyor...")

        # Enforce the 5-minute CPU limit up front for local files.
        if not is_url:
            cap = cv2.VideoCapture(video_source)
            if cap.isOpened():
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                fps = int(cap.get(cv2.CAP_PROP_FPS))
                duration = total_frames / fps if fps > 0 else 0
                cap.release()

                if duration > 300:
                    return [], f"⚠️ Video çok uzun ({duration:.0f} saniye)! CPU modunda maksimum 5 dakika (300 saniye) desteklenir.", "❌ Süre limiti aşıldı"

        output_dir, saved_faces, metadata = detector_instance.detect_faces(video_source, is_url=is_url)

        # Downloaded videos can only be measured after processing.
        if is_url and metadata['duration'] > 300:
            return [], f"⚠️ Video çok uzun ({metadata['duration']:.0f} saniye)! CPU modunda maksimum 5 dakika desteklenir.", "❌ Süre limiti aşıldı"

        report = f"""
# 📊 İşlem Raporu

## Genel Bilgiler
- **Video**: {source_name}
- **Kaynak**: {'🌐 URL' if is_url else '📁 Yerel Dosya'}
- **Süre**: {metadata['duration']:.1f} saniye
- **FPS**: {metadata['fps']}
- **Toplam Frame**: {metadata['total_frames']}
- **İşlem Süresi**: {metadata['elapsed_time']:.1f} saniye

## Tespit Sonuçları
- **Benzersiz Kişi**: {metadata['unique_persons']}
- **Toplam Yüz Tespiti**: {sum(f['cluster_size'] for f in metadata['faces'])}

## Kişi Detayları
"""
        for face in metadata['faces']:
            report += f"\n### Kişi {face['cluster_id']}\n"
            report += f"- Kalite Skoru: {face['quality_score']:.2f}\n"
            report += f"- Görülme Sayısı: {face['cluster_size']}\n"

        return saved_faces, report, f"✅ Başarılı! Çıktı: {output_dir}"

    except Exception as e:
        error_msg = f"❌ Hata: {str(e)}"
        logger.error(error_msg)
        return [], error_msg, error_msg
620
+
621
def compare_two_faces(face1, face2):
    """Compare two face images and return a verdict string with the
    cosine-similarity percentage (>70% same, >50% likely same)."""
    global detector_instance

    if detector_instance is None:
        detector_instance = FaceDetector(FaceDetectionConfig())
        detector_instance.models = detector_instance._load_models()

    try:
        def _to_bgr(image):
            # Gradio hands over a filepath (str) or an RGB numpy array.
            if isinstance(image, str):
                return cv2.imread(image)
            return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

        emb1, _ = detector_instance.extract_embeddings(_to_bgr(face1))
        emb2, _ = detector_instance.extract_embeddings(_to_bgr(face2))

        if emb1 is None or emb2 is None:
            return "❌ Yüz tespit edilemedi!"

        percentage = cosine_similarity([emb1], [emb2])[0][0] * 100

        if percentage > 70:
            return f"✅ Aynı Kişi ({percentage:.1f}% benzerlik)"
        if percentage > 50:
            return f"⚠️ Muhtemelen Aynı Kişi ({percentage:.1f}% benzerlik)"
        return f"❌ Farklı Kişiler ({percentage:.1f}% benzerlik)"

    except Exception as e:
        return f"❌ Hata: {str(e)}"
652
+
653
# Gradio UI: three tabs — video processing, pairwise face comparison,
# and settings. `demo` is launched from the __main__ guard below.
with gr.Blocks(title="Yüz Tanıma Sistemi", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎭 Video Yüz Tanıma Sistemi
Video dosyalarından otomatik yüz tespiti ve tanıma yapın
⚠️ **CPU Modunda Çalışıyor**: İşlem süresi uzun olabilir (5 dk video = ~10-15 dk)
""")

    with gr.Tabs():
        with gr.Tab("📹 Video İşle"):
            gr.Markdown("### Video kaynağını seçin:")

            with gr.Row():
                with gr.Column():
                    # Either an uploaded file or a URL; the callback
                    # prefers the URL when both are given.
                    video_input = gr.Video(label="📁 Yerel Video Yükle")
                    gr.Markdown("**VEYA**")
                    url_input = gr.Textbox(
                        label="🌐 Video URL'si",
                        placeholder="https://example.com/video.mp4",
                        lines=1
                    )
                    gr.Markdown("*URL girilirse öncelikle o kullanılır*")

                    process_btn = gr.Button("🚀 İşlemi Başlat", variant="primary", size="lg")
                    status_text = gr.Textbox(label="Durum", interactive=False)

                with gr.Column():
                    gallery_output = gr.Gallery(label="Tespit Edilen Yüzler", columns=4, height=400)
                    report_output = gr.Markdown(label="Rapor")

            gr.Markdown("""
#### 💡 URL Örnekleri:
- **YouTube**: `https://www.youtube.com/watch?v=xxxxx` veya `https://youtu.be/xxxxx` veya Shorts
- **Doğrudan video**: `https://example.com/video.mp4`
- Google Drive paylaşım linki çalışmaz (direkt indirme linki gerekir)
- **Desteklenen formatlar**: MP4, AVI, MOV, MKV, WebM

⚠️ **YouTube için**: İlk kullanımda `pip install yt-dlp` komutu gereklidir
""")

            # Outputs map to (gallery, report, status) — the triple
            # returned by process_video_gradio.
            process_btn.click(
                fn=process_video_gradio,
                inputs=[video_input, url_input],
                outputs=[gallery_output, report_output, status_text]
            )

        with gr.Tab("🔍 Yüz Karşılaştır"):
            gr.Markdown("İki yüz görselini yükleyin ve benzerliklerini kontrol edin")
            with gr.Row():
                face1_input = gr.Image(label="Yüz 1", type="filepath")
                face2_input = gr.Image(label="Yüz 2", type="filepath")

            compare_btn = gr.Button("⚖️ Karşılaştır", variant="primary")
            compare_result = gr.Textbox(label="Sonuç", interactive=False)

            compare_btn.click(
                fn=compare_two_faces,
                inputs=[face1_input, face2_input],
                outputs=compare_result
            )

        with gr.Tab("⚙️ Ayarlar"):
            gr.Markdown("### Gelişmiş Ayarlar")

            frame_skip_slider = gr.Slider(20, 60, value=30, step=5,
                                          label="Frame Atlama (yüksek = daha hızlı)")
            face_threshold_slider = gr.Slider(600, 2000, value=1000, step=100,
                                              label="Minimum Yüz Boyutu (piksel)")
            clustering_slider = gr.Slider(0.3, 0.7, value=0.5, step=0.05,
                                          label="Clustering Hassasiyeti")

            save_settings_btn = gr.Button("💾 Ayarları Kaydet")
            settings_status = gr.Textbox(label="Durum", interactive=False)

            # NOTE(review): only three inputs are wired here although
            # initialize_detector declares a fourth (use_gpu) parameter —
            # confirm that parameter has a default, otherwise this click
            # raises a TypeError at runtime.
            save_settings_btn.click(
                fn=initialize_detector,
                inputs=[frame_skip_slider, face_threshold_slider, clustering_slider],
                outputs=settings_status
            )

            gr.Markdown("""
---
### 💡 İpuçları
- **Frame Atlama**: Daha hızlı işlem için artırın, daha fazla tespit için azaltın
- **Clustering**: Daha az kişi tespit ediyorsa artırın, fazla tespit ediyorsa azaltın
- **GPU**: Cuda destekli GPU varsa aktif edin
- **YouTube**: İlk kullanımda terminalde `pip install yt-dlp` çalıştırın
""")
741
+
742
if __name__ == "__main__":
    # Startup banner reporting whether optional YouTube support is on.
    print("\n" + "="*60)
    print("🎬 Video Yüz Tanıma Sistemi")
    print("="*60)

    if YOUTUBE_SUPPORT:
        print("✅ YouTube desteği: AKTİF")
        try:
            print(f" yt-dlp versiyon: {yt_dlp.version.__version__}")
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt — narrowed to Exception.
            print(" yt-dlp versiyon bilgisi alınamadı")
    else:
        print("⚠️ YouTube desteği: KAPALI")
        print(" Kurulum için: pip install yt-dlp")

    print("="*60 + "\n")

    # 0.0.0.0 binding is required for container / HF Spaces deployment.
    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
arcface_onnx.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Organization : insightface.ai
3
+ # @Author : Jia Guo
4
+ # @Time : 2021-05-04
5
+ # @Function :
6
+
7
+ import numpy as np
8
+ import cv2
9
+ import onnx
10
+ import onnxruntime
11
+ import face_align
12
+ import os
13
+
14
+ __all__ = [
15
+ 'ArcFaceONNX',
16
+ ]
17
+
18
+
19
class ArcFaceONNX:
    """ONNX wrapper around an ArcFace face-recognition model.

    Detects the preprocessing variant (mxnet-style vs normalized) from
    the graph's leading nodes, and exposes embedding extraction plus
    cosine similarity.
    """

    def __init__(self, model_file=None, session=None):
        """Load *model_file* (required); *session* optionally supplies a
        pre-built onnxruntime InferenceSession.

        Raises:
            ValueError: when no model file is given or the model does
                not have exactly one output.
            FileNotFoundError: when the model file is missing.
        """
        # Explicit raises instead of `assert` — asserts are stripped
        # under `python -O`, silently skipping this validation.
        if model_file is None:
            raise ValueError("Model dosyası belirtilmedi")
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model dosyası bulunamadı: {model_file}")

        self.model_file = model_file
        self.session = session
        self.taskname = 'recognition'

        # mxnet-exported ArcFace graphs embed mean/std normalization as
        # leading Sub/Mul nodes; detect that to pick the right input
        # preprocessing constants.
        find_sub = False
        find_mul = False
        model = onnx.load(self.model_file)
        for node in model.graph.node[:8]:
            if node.name.startswith(('Sub', '_minus')):
                find_sub = True
            if node.name.startswith(('Mul', '_mul')):
                find_mul = True
        if find_sub and find_mul:
            # mxnet arcface model: normalization happens inside the graph
            self.input_mean = 0.0
            self.input_std = 1.0
        else:
            self.input_mean = 127.5
            self.input_std = 127.5

        if self.session is None:
            sess_options = onnxruntime.SessionOptions()
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
            sess_options.intra_op_num_threads = 4
            self.session = onnxruntime.InferenceSession(
                self.model_file,
                sess_options=sess_options,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )

        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        self.input_size = tuple(input_shape[2:4][::-1])  # (width, height)
        self.input_shape = input_shape
        self.input_name = input_cfg.name
        outputs = self.session.get_outputs()
        self.output_names = [out.name for out in outputs]
        if len(self.output_names) != 1:
            raise ValueError("Tek çıkışlı bir ArcFace modeli bekleniyor")
        self.output_shape = outputs[0].shape

    def prepare(self, ctx_id, **kwargs):
        """Force CPU execution when *ctx_id* is negative (insightface API)."""
        if ctx_id < 0:
            self.session.set_providers(['CPUExecutionProvider'])

    def get(self, img, kps):
        """Return the flat embedding for the face at landmarks *kps* in *img*."""
        aimg = face_align.norm_crop(img, landmark=kps, image_size=self.input_size[0])
        return self.get_feat(aimg).flatten()

    def compute_sim(self, feat1, feat2):
        """Cosine similarity between two embedding vectors (any shape; flattened)."""
        from numpy.linalg import norm
        feat1 = feat1.ravel()
        feat2 = feat2.ravel()
        return np.dot(feat1, feat2) / (norm(feat1) * norm(feat2))

    def get_feat(self, imgs):
        """Run the network on one BGR image or a list; returns raw embeddings."""
        if not isinstance(imgs, list):
            imgs = [imgs]
        blob = cv2.dnn.blobFromImages(
            imgs, 1.0 / self.input_std, self.input_size,
            (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
        return self.session.run(self.output_names, {self.input_name: blob})[0]

    def forward(self, batch_data):
        """Run the network on an already-cropped NCHW batch (normalizes in-place)."""
        blob = (batch_data - self.input_mean) / self.input_std
        return self.session.run(self.output_names, {self.input_name: blob})[0]
103
+
104
+
models/det_10g.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
3
+ size 16923827
models/w600k_r50.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
3
+ size 174383860
requirements.txt ADDED
File without changes
scrfd.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from __future__ import division
4
+ import datetime
5
+ import numpy as np
6
+ #import onnx
7
+ import onnxruntime
8
+ import os
9
+ import os.path as osp
10
+ import cv2
11
+ import sys
12
+
13
def softmax(z):
    """Row-wise numerically stable softmax of a 2-D array."""
    assert len(z.shape) == 2
    # Subtract the per-row max before exp to avoid overflow.
    shifted = z - np.max(z, axis=1, keepdims=True)
    exp_z = np.exp(shifted)
    return exp_z / np.sum(exp_z, axis=1, keepdims=True)
21
+
22
def distance2bbox(points, distance, max_shape=None):
    """Decode distance predictions into (x1, y1, x2, y2) bounding boxes.

    Args:
        points (ndarray): Shape (n, 2), anchor centers [x, y].
        distance (ndarray): Shape (n, 4), offsets from each point to the
            left, top, right and bottom box edges.
        max_shape (tuple): Optional (height, width) used to clip boxes to
            the image.

    Returns:
        ndarray: Shape (n, 4) decoded boxes.
    """
    x1 = points[:, 0] - distance[:, 0]
    y1 = points[:, 1] - distance[:, 1]
    x2 = points[:, 0] + distance[:, 2]
    y2 = points[:, 1] + distance[:, 3]
    if max_shape is not None:
        # Bug fix: the original called Tensor.clamp on numpy arrays, which
        # raises AttributeError; np.clip is the numpy equivalent.
        x1 = np.clip(x1, 0, max_shape[1])
        y1 = np.clip(y1, 0, max_shape[0])
        x2 = np.clip(x2, 0, max_shape[1])
        y2 = np.clip(y2, 0, max_shape[0])
    return np.stack([x1, y1, x2, y2], axis=-1)
44
+
45
def distance2kps(points, distance, max_shape=None):
    """Decode distance predictions into facial keypoint coordinates.

    Args:
        points (ndarray): Shape (n, 2), anchor centers [x, y].
        distance (ndarray): Shape (n, 2k), per-keypoint [dx, dy] offsets
            relative to the anchor center.
        max_shape (tuple): Optional (height, width) used to clip keypoints
            to the image.

    Returns:
        ndarray: Shape (n, 2k) decoded keypoints, interleaved x/y.
    """
    preds = []
    for i in range(0, distance.shape[1], 2):
        # i is always even, so i % 2 == 0: x offsets add to points[:, 0],
        # y offsets to points[:, 1].
        px = points[:, i % 2] + distance[:, i]
        py = points[:, i % 2 + 1] + distance[:, i + 1]
        if max_shape is not None:
            # Bug fix: np.clip instead of the torch-only Tensor.clamp.
            px = np.clip(px, 0, max_shape[1])
            py = np.clip(py, 0, max_shape[0])
        preds.append(px)
        preds.append(py)
    return np.stack(preds, axis=-1)
67
+
68
class SCRFD:
    """ONNX wrapper for the SCRFD face detector.

    Decodes multi-level FPN head outputs into face bounding boxes and,
    for models that have a landmark branch, 5-point facial keypoints.
    """

    def __init__(self, model_file=None, session=None):
        """Create a detector from an ONNX file path or an existing session.

        Args:
            model_file (str): Path to the .onnx model (required when no
                session is given).
            session: An already-constructed onnxruntime.InferenceSession.
        """
        self.model_file = model_file
        self.session = session
        self.taskname = 'detection'
        self.batched = False
        if self.session is None:
            assert self.model_file is not None
            assert osp.exists(self.model_file), f"Model dosyası bulunamadı: {self.model_file}"
            sess_options = onnxruntime.SessionOptions()
            sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
            sess_options.intra_op_num_threads = 4
            # Prefer CUDA; onnxruntime falls back to CPU automatically.
            self.session = onnxruntime.InferenceSession(
                self.model_file,
                sess_options=sess_options,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
        # (height, width, stride) -> precomputed anchor centers.
        self.center_cache = {}
        self.nms_thresh = 0.4
        self.det_thresh = 0.5
        self._init_vars()

    def _init_vars(self):
        """Derive input/output layout and FPN configuration from the graph."""
        input_cfg = self.session.get_inputs()[0]
        input_shape = input_cfg.shape
        # A symbolic (string) spatial dim means the model accepts dynamic sizes.
        if isinstance(input_shape[2], str):
            self.input_size = None
        else:
            self.input_size = tuple(input_shape[2:4][::-1])
        self.input_name = input_cfg.name
        self.input_shape = input_shape
        outputs = self.session.get_outputs()
        # Rank-3 outputs carry a leading batch dimension.
        if len(outputs[0].shape) == 3:
            self.batched = True
        self.output_names = [o.name for o in outputs]
        self.input_mean = 127.5
        self.input_std = 128.0
        self.use_kps = False
        self._anchor_ratio = 1.0
        self._num_anchors = 1
        # The number of outputs determines the FPN variant: per stride there
        # is one score head + one bbox head (+ one kps head when present).
        num_outputs = len(outputs)
        if num_outputs == 6:
            self.fmc = 3
            self._feat_stride_fpn = [8, 16, 32]
            self._num_anchors = 2
        elif num_outputs == 9:
            self.fmc = 3
            self._feat_stride_fpn = [8, 16, 32]
            self._num_anchors = 2
            self.use_kps = True
        elif num_outputs == 10:
            self.fmc = 5
            self._feat_stride_fpn = [8, 16, 32, 64, 128]
            self._num_anchors = 1
        elif num_outputs == 15:
            self.fmc = 5
            self._feat_stride_fpn = [8, 16, 32, 64, 128]
            self._num_anchors = 1
            self.use_kps = True
        else:
            # The original left self.fmc undefined here and crashed later in
            # forward(); fail fast with a clear message instead.
            raise ValueError(f'Unsupported SCRFD model: {num_outputs} outputs')

    def prepare(self, ctx_id, **kwargs):
        """Configure runtime options.

        Args:
            ctx_id (int): Device id; a negative value forces CPU execution.
            **kwargs: Optional overrides: nms_thresh, det_thresh, input_size.
        """
        if ctx_id < 0:
            self.session.set_providers(['CPUExecutionProvider'])
        nms_thresh = kwargs.get('nms_thresh', None)
        if nms_thresh is not None:
            self.nms_thresh = nms_thresh
        det_thresh = kwargs.get('det_thresh', None)
        if det_thresh is not None:
            self.det_thresh = det_thresh
        input_size = kwargs.get('input_size', None)
        if input_size is not None:
            if self.input_size is not None:
                print('warning: det_size is already set in scrfd model, ignore')
            else:
                self.input_size = input_size

    def forward(self, img, threshold):
        """Run the network on a pre-sized image and decode per-level outputs.

        Args:
            img (ndarray): HxWx3 BGR image already matching the network size.
            threshold (float): Score threshold for keeping anchors.

        Returns:
            (scores_list, bboxes_list, kpss_list): one entry per FPN stride;
            boxes/keypoints are in input-image pixel coordinates.
        """
        scores_list = []
        bboxes_list = []
        kpss_list = []
        input_size = tuple(img.shape[0:2][::-1])
        blob = cv2.dnn.blobFromImage(img, 1.0 / self.input_std, input_size,
                                     (self.input_mean, self.input_mean, self.input_mean),
                                     swapRB=True)
        net_outs = self.session.run(self.output_names, {self.input_name: blob})

        input_height = blob.shape[2]
        input_width = blob.shape[3]
        fmc = self.fmc
        for idx, stride in enumerate(self._feat_stride_fpn):
            # Batched models carry a leading batch dim; take the first item.
            if self.batched:
                scores = net_outs[idx][0]
                bbox_preds = net_outs[idx + fmc][0] * stride
                if self.use_kps:
                    kps_preds = net_outs[idx + fmc * 2][0] * stride
            else:
                scores = net_outs[idx]
                bbox_preds = net_outs[idx + fmc] * stride
                if self.use_kps:
                    kps_preds = net_outs[idx + fmc * 2] * stride

            height = input_height // stride
            width = input_width // stride
            key = (height, width, stride)
            if key in self.center_cache:
                anchor_centers = self.center_cache[key]
            else:
                # Anchor centers on a stride-spaced grid, in (x, y) order.
                anchor_centers = np.stack(np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
                anchor_centers = (anchor_centers * stride).reshape((-1, 2))
                if self._num_anchors > 1:
                    # Repeat each center once per anchor at that location.
                    anchor_centers = np.stack(
                        [anchor_centers] * self._num_anchors, axis=1).reshape((-1, 2))
                if len(self.center_cache) < 100:  # keep the cache bounded
                    self.center_cache[key] = anchor_centers

            pos_inds = np.where(scores >= threshold)[0]
            bboxes = distance2bbox(anchor_centers, bbox_preds)
            scores_list.append(scores[pos_inds])
            bboxes_list.append(bboxes[pos_inds])
            if self.use_kps:
                kpss = distance2kps(anchor_centers, kps_preds)
                kpss = kpss.reshape((kpss.shape[0], -1, 2))
                kpss_list.append(kpss[pos_inds])
        return scores_list, bboxes_list, kpss_list

    def _top_k(self, det, kpss, img_shape, max_num, metric):
        """Keep the max_num best detections.

        Ranked by raw area ('max') or by area penalized by squared distance
        from the image center (any other metric).
        """
        area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
        img_center = img_shape[0] // 2, img_shape[1] // 2
        offsets = np.vstack([
            (det[:, 0] + det[:, 2]) / 2 - img_center[1],
            (det[:, 1] + det[:, 3]) / 2 - img_center[0]
        ])
        offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
        if metric == 'max':
            values = area
        else:
            values = area - offset_dist_squared * 2.0  # some extra weight on the centering
        bindex = np.argsort(values)[::-1][0:max_num]
        det = det[bindex, :]
        if kpss is not None:
            kpss = kpss[bindex, :]
        return det, kpss

    def detect(self, img, input_size=None, thresh=None, max_num=0, metric='default'):
        """Detect faces in a BGR image.

        Args:
            img (ndarray): HxWx3 uint8 image.
            input_size (tuple): (width, height) fed to the network; defaults
                to the model's fixed size when omitted.
            thresh (float): Score threshold; defaults to self.det_thresh.
            max_num (int): Keep at most this many detections (0 = all).
            metric (str): Ranking used when truncating to max_num.

        Returns:
            (det, kpss): det is (n, 5) rows of [x1, y1, x2, y2, score] in
            original-image coordinates; kpss is (n, 5, 2) landmarks, or
            None when the model has no keypoint branch.
        """
        assert input_size is not None or self.input_size is not None
        input_size = self.input_size if input_size is None else input_size

        # Letterbox: resize keeping aspect ratio, zero-pad bottom/right.
        im_ratio = float(img.shape[0]) / img.shape[1]
        model_ratio = float(input_size[1]) / input_size[0]
        if im_ratio > model_ratio:
            new_height = input_size[1]
            new_width = int(new_height / im_ratio)
        else:
            new_width = input_size[0]
            new_height = int(new_width * im_ratio)
        det_scale = float(new_height) / img.shape[0]
        resized_img = cv2.resize(img, (new_width, new_height))
        det_img = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        det_img[:new_height, :new_width, :] = resized_img
        det_thresh = thresh if thresh is not None else self.det_thresh

        scores_list, bboxes_list, kpss_list = self.forward(det_img, det_thresh)

        scores = np.vstack(scores_list)
        scores_ravel = scores.ravel()
        order = scores_ravel.argsort()[::-1]
        # Undo the letterbox scaling to get original-image coordinates.
        bboxes = np.vstack(bboxes_list) / det_scale
        if self.use_kps:
            kpss = np.vstack(kpss_list) / det_scale
        pre_det = np.hstack((bboxes, scores)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]
        if self.use_kps:
            kpss = kpss[order, :, :]
            kpss = kpss[keep, :, :]
        else:
            kpss = None
        if max_num > 0 and det.shape[0] > max_num:
            det, kpss = self._top_k(det, kpss, img.shape, max_num, metric)
        return det, kpss

    def autodetect(self, img, max_num=0, metric='max'):
        """Detect at two input resolutions and merge the results via NMS.

        The large input catches small faces; the small input catches faces
        that are too large relative to the 640x640 view.
        """
        bboxes, kpss = self.detect(img, input_size=(640, 640), thresh=0.5)
        bboxes2, kpss2 = self.detect(img, input_size=(128, 128), thresh=0.5)
        bboxes_all = np.concatenate([bboxes, bboxes2], axis=0)
        keep = self.nms(bboxes_all)
        det = bboxes_all[keep, :]
        if kpss is not None and kpss2 is not None:
            kpss_all = np.concatenate([kpss, kpss2], axis=0)
            kpss = kpss_all[keep, :]
        else:
            # Landmark-free model: the original crashed concatenating None.
            kpss = None
        if max_num > 0 and det.shape[0] > max_num:
            det, kpss = self._top_k(det, kpss, img.shape, max_num, metric)
        return det, kpss

    def nms(self, dets):
        """Greedy IoU-based non-maximum suppression.

        Args:
            dets (ndarray): (n, 5) rows of [x1, y1, x2, y2, score].

        Returns:
            list[int]: Indices of kept rows, in descending score order.
        """
        thresh = self.nms_thresh
        x1 = dets[:, 0]
        y1 = dets[:, 1]
        x2 = dets[:, 2]
        y2 = dets[:, 3]
        scores = dets[:, 4]

        # The +1 keeps the original pixel-inclusive area convention.
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            # Keep only boxes whose overlap with the current best is low.
            inds = np.where(ovr <= thresh)[0]
            order = order[inds + 1]

        return keep