Commit 34ecf0d by tcooper-xx · 0 parent(s)

Initial Commit

.api_keys.json.swp ADDED
Binary file (12.3 kB)
.gitattributes ADDED
@@ -0,0 +1,3 @@
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.ts filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
+ api_keys.json
+ .admin_key
+ __pycache__/
+ *.pyc
+ *.pyo
+ .env
+ .venv/
Dockerfile ADDED
@@ -0,0 +1,29 @@
+ FROM python:3.12-slim
+
+ # System libraries required by OpenCV and PyTorch
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     libglib2.0-0 \
+     libgl1 \
+     libsm6 \
+     libxrender1 \
+     libxext6 \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ # Install Python dependencies before copying the rest of the code
+ # so this layer is cached as long as requirements.txt doesn't change.
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code and assets
+ COPY . .
+
+ # HF Spaces runs containers as uid 1000 (non-root)
+ RUN useradd -m -u 1000 appuser \
+     && chown -R appuser /app
+ USER appuser
+
+ EXPOSE 7860
+
+ CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,66 @@
+ ---
+ title: SW Fish Identifier
+ emoji: 🐟
+ colorFrom: blue
+ colorTo: teal
+ sdk: docker
+ app_port: 7860
+ pinned: false
+ ---
+
+ # SW Fish Identifier
+
+ Upload a photo and SWClassifier will detect every fish in it, segment their outline,
+ and identify the species — returning both the common name and scientific name.
+
+ ## Pipeline
+
+ | Step | Model | Notes |
+ |---|---|---|
+ | Detection | YOLO v8 nano | Bounding boxes |
+ | Segmentation | FPN ResNet-18 | Per-fish polygon mask |
+ | Classification | BEiT v2 Base + FAISS kNN | 775 species |
+
+ ## API
+
+ The Space exposes a REST API alongside the web UI.
+
+ ### Identify fish in an image
+
+ ```
+ POST /api/v1/predict
+ X-API-Key: <your-key>
+ Content-Type: multipart/form-data
+
+ file=<image>
+ ```
+
+ **Response**
+
+ ```json
+ {
+   "detections": [
+     {
+       "bbox": { "x1": 120, "y1": 45, "x2": 380, "y2": 290, "confidence": 0.91 },
+       "polygon": [[120, 180], [135, 170], "..."],
+       "predictions": [
+         { "name": "Wahoo", "taxon": "Acanthocybium solandri", "accuracy": 0.87, "species_id": "..." }
+       ]
+     }
+   ],
+   "image_size": { "width": 1280, "height": 720 },
+   "timing": { "detect_ms": 210, "segment_ms": 85, "classify_ms": 430, "total_ms": 730 }
+ }
+ ```
+
+ Full interactive docs available at `/docs`.
+
+ ## Configuration (Space Secrets)
+
+ | Secret | Description |
+ |---|---|
+ | `SW_API_KEYS` | Comma-separated list of valid API keys |
+ | `SW_ADMIN_KEY` | Key required to create / revoke API keys via `/api/v1/keys` |
+
+ Set these under **Settings → Variables and Secrets** in the Space dashboard.
+ If not set, keys are auto-generated at startup (lost on container restart).
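The request format described in the README can be sketched in client code. The block below is a minimal illustration using only the Python standard library, not part of this commit: `build_predict_request` is a hypothetical helper, and the base URL is a placeholder you would replace with your own Space URL. It sends a single form field named `file` with an `image/*` part type, since the server code in this commit rejects uploads whose content type does not start with `image/`.

```python
import uuid
import urllib.request


def build_predict_request(base_url: str, api_key: str, image_bytes: bytes,
                          filename: str = "fish.jpg",
                          mime_type: str = "image/jpeg") -> urllib.request.Request:
    """Build (but do not send) a multipart POST for /api/v1/predict."""
    boundary = uuid.uuid4().hex
    # One form field named "file", as in the README's request example.
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="file"; filename="{filename}"\r\n'
        f"Content-Type: {mime_type}\r\n\r\n"
    ).encode("utf-8")
    tail = f"\r\n--{boundary}--\r\n".encode("utf-8")
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/api/v1/predict",
        data=head + image_bytes + tail,
        headers={
            "X-API-Key": api_key,  # header name taken from the README
            "Content-Type": f"multipart/form-data; boundary={boundary}",
        },
        method="POST",
    )
```

To send it, pass the returned object to `urllib.request.urlopen` and decode the body with `json.load`; the result should follow the response shape shown in the README.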
app.py ADDED
@@ -0,0 +1,413 @@
+ """
+ SW Identifier — FastAPI server
+
+ Routes
+ ------
+ GET    /                     SPA frontend
+ POST   /predict              Internal SPA endpoint (no auth)
+ POST   /api/v1/predict       Public API (requires X-API-Key header)
+ GET    /api/v1/keys          List API keys (requires X-Admin-Key header)
+ POST   /api/v1/keys          Create API key (requires X-Admin-Key header)
+ DELETE /api/v1/keys/{key}    Revoke API key (requires X-Admin-Key header)
+ GET    /docs                 OpenAPI / Swagger UI
+ """
+ import csv
+ import io
+ import json
+ import logging
+ import os
+ import secrets
+ import sys
+ import time
+ from contextlib import asynccontextmanager
+ from datetime import datetime, timezone
+ from typing import List, Optional
+
+ import numpy as np
+ from PIL import Image
+ from fastapi import APIRouter, Depends, FastAPI, File, HTTPException, Security, UploadFile
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import HTMLResponse
+ from fastapi.security import APIKeyHeader
+ from fastapi.staticfiles import StaticFiles
+ from pydantic import BaseModel
+
+ # ── paths ─────────────────────────────────────────────────────────────────────
+ BASE = os.path.dirname(os.path.abspath(__file__))
+ DETECTOR_PATH = os.path.join(BASE, "detector", "model.pt")
+ SEGMENTATOR_PATH = os.path.join(BASE, "segmentator", "model.ts")
+ CLASSIFIER_CKPT = os.path.join(BASE, "classification_model", "model.ckpt")
+ DATABASE_PATH = os.path.join(BASE, "classification_model", "database.pt")
+ STATIC_DIR = os.path.join(BASE, "static")
+ TAXONS_CSV = os.path.join(BASE, "taxons.csv")
+ KEYS_FILE = os.path.join(BASE, "api_keys.json")
+
+ sys.path.insert(0, BASE)
+
+ # ── logging ───────────────────────────────────────────────────────────────────
+ logging.basicConfig(level=logging.WARNING)
+ log = logging.getLogger("sw.app")
+
+ # ── common name lookup ────────────────────────────────────────────────────────
+ def _load_common_names(path: str) -> dict:
+     mapping = {}
+     with open(path, newline="", encoding="utf-8") as f:
+         for row in csv.DictReader(f):
+             taxon = row["taxon"].strip()
+             common = row["common_name"].strip()
+             if taxon:
+                 mapping[taxon] = common or taxon
+     return mapping
+
+ COMMON_NAMES: dict = _load_common_names(TAXONS_CSV)
+
+ # ── API key store ─────────────────────────────────────────────────────────────
+ def _load_keys() -> list:
+     if os.path.exists(KEYS_FILE):
+         with open(KEYS_FILE, encoding="utf-8") as f:
+             return json.load(f)
+     return []
+
+ def _save_keys(keys: list) -> None:
+     with open(KEYS_FILE, "w", encoding="utf-8") as f:
+         json.dump(keys, f, indent=2)
+
+ def _valid_key_set() -> set:
+     # Prefer env var (comma-separated) — required for stateless deployments
+     # like HF Spaces where the filesystem is ephemeral.
+     env = os.environ.get("SW_API_KEYS", "").strip()
+     if env:
+         return {k.strip() for k in env.split(",") if k.strip()}
+     return {k["key"] for k in _load_keys()}
+
+ def _new_key(name: str) -> dict:
+     return {
+         "key": "fsh_" + secrets.token_urlsafe(32),
+         "name": name,
+         "created_at": datetime.now(timezone.utc).isoformat(),
+     }
+
+ # Ensure at least one key exists on startup; print it once to console.
+ def _bootstrap_keys() -> None:
+     # Skip file-based bootstrap when keys are supplied via env var.
+     if os.environ.get("SW_API_KEYS", "").strip():
+         return
+     keys = _load_keys()
+     if not keys:
+         k = _new_key("default")
+         _save_keys([k])
+         print("\n" + "═" * 60)
+         print(" No API keys found — generated a default key:")
+         print(f" {k['key']}")
+         print(" Store this somewhere safe; it won't be shown again.")
+         print("═" * 60 + "\n")
+
+ # Admin key — set SW_ADMIN_KEY env var, or one is auto-generated once.
+ _ADMIN_KEY_FILE = os.path.join(BASE, ".admin_key")
+
+ def _get_admin_key() -> str:
+     env = os.environ.get("SW_ADMIN_KEY")
+     if env:
+         return env
+     if os.path.exists(_ADMIN_KEY_FILE):
+         with open(_ADMIN_KEY_FILE) as f:
+             return f.read().strip()
+     key = "fadm_" + secrets.token_urlsafe(32)
+     with open(_ADMIN_KEY_FILE, "w") as f:
+         f.write(key)
+     print("\n" + "═" * 60)
+     print(" Admin key (manage API keys):")
+     print(f" {key}")
+     print(" Stored in .admin_key — keep it out of version control.")
+     print("═" * 60 + "\n")
+     return key
+
+ ADMIN_KEY: str = ""  # set during lifespan
+
+ # ── model globals ─────────────────────────────────────────────────────────────
+ detector = None
+ segmentator = None
+ classifier = None
+
+ # ── lifespan ──────────────────────────────────────────────────────────────────
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     global detector, segmentator, classifier, ADMIN_KEY
+
+     _bootstrap_keys()
+     ADMIN_KEY = _get_admin_key()
+
+     from ultralytics import YOLO
+     log.warning("Loading detector …")
+     detector = YOLO(DETECTOR_PATH)
+
+     log.warning("Loading segmentator …")
+     from segmentator.inference import Inference as Segmentator
+     segmentator = Segmentator(SEGMENTATOR_PATH)
+
+     log.warning("Loading classifier …")
+     from classification_model.inference import EmbeddingClassifier
+     classifier = EmbeddingClassifier({
+         "log_level": "WARNING",
+         "dataset": {"path": DATABASE_PATH},
+         "model": {
+             "checkpoint_path": CLASSIFIER_CKPT,
+             "backbone_model_name": "beitv2_base_patch16_224.in1k_ft_in22k_in1k",
+             "embedding_dim": 512,
+             "num_classes": 775,
+             "arcface_s": 64.0,
+             "arcface_m": 0.2,
+             "pooling_type": "attention",
+             "device": "cpu",
+         },
+         "use_knn": True,
+         "arcface_min_score": 0.1,
+         "centroid_fallback_score": 0.1,
+         "topk_centroid": 5,
+         "topk_neighbors": 10,
+         "topk_arcface": 5,
+         "centroid_threshold": 0.7,
+         "neighbor_threshold": 0.8,
+         "use_albumentations": False,
+     })
+
+     log.warning("All models ready.")
+     yield
+     log.warning("Shutting down.")
+
+ # ── Pydantic response models ──────────────────────────────────────────────────
+ class BoundingBox(BaseModel):
+     x1: int
+     y1: int
+     x2: int
+     y2: int
+     confidence: float
+
+ class Prediction(BaseModel):
+     name: str  # common name
+     taxon: str  # scientific name
+     accuracy: float  # confidence 0–1
+     species_id: str
+
+ class Detection(BaseModel):
+     bbox: BoundingBox
+     polygon: Optional[List[List[int]]]  # [[x,y], ...] in original image coords
+     predictions: List[Prediction]
+
+ class ImageSize(BaseModel):
+     width: int
+     height: int
+
+ class Timing(BaseModel):
+     detect_ms: int
+     segment_ms: int
+     classify_ms: int
+     total_ms: int
+
+ class PredictResponse(BaseModel):
+     detections: List[Detection]
+     image_size: ImageSize
+     timing: Timing
+
+ # ── shared pipeline ───────────────────────────────────────────────────────────
+ async def _run_pipeline(raw: bytes) -> PredictResponse:
+     try:
+         image_rgb = np.array(Image.open(io.BytesIO(raw)).convert("RGB"))
+     except Exception as exc:
+         raise HTTPException(status_code=400, detail=f"Cannot decode image: {exc}")
+
+     h, w = image_rgb.shape[:2]
+     t_start = time.perf_counter()
+
+     # 1. Detection
+     t0 = time.perf_counter()
+     yolo_out = detector.predict(
+         source=image_rgb, imgsz=640, conf=0.25, iou=0.45,
+         device="cpu", verbose=False, save=False,
+     )
+     detect_ms = (time.perf_counter() - t0) * 1000
+     boxes_raw = yolo_out[0].boxes.data.cpu().numpy() if yolo_out else []
+
+     detections: List[Detection] = []
+     seg_ms_total = 0.0
+     cls_ms_total = 0.0
+
+     for box in boxes_raw:
+         x1 = max(0, int(box[0])); y1 = max(0, int(box[1]))
+         x2 = min(w, int(box[2])); y2 = min(h, int(box[3]))
+         confidence = float(box[4])
+         if x2 <= x1 or y2 <= y1:
+             continue
+
+         crop_rgb = image_rgb[y1:y2, x1:x2]
+
+         # 2. Segmentation
+         polygon_coords = None
+         masked_crop = crop_rgb
+         t0 = time.perf_counter()
+         try:
+             seg_results = segmentator.predict(crop_rgb)
+             if seg_results:
+                 poly = seg_results[0]
+                 polygon_coords = [[int(px) + x1, int(py) + y1] for px, py in poly.points]
+                 masked_crop = poly.mask_polygon(crop_rgb)
+         except Exception as exc:
+             log.warning("Segmentator error: %s", exc)
+         seg_ms_total += (time.perf_counter() - t0) * 1000
+
+         # 3. Classification
+         pred_list: List[Prediction] = []
+         t0 = time.perf_counter()
+         try:
+             preds = classifier(masked_crop)
+             for p in (preds or [])[:3]:
+                 pred_list.append(Prediction(
+                     name = COMMON_NAMES.get(p.name, p.name),
+                     taxon = p.name,
+                     accuracy = round(float(p.accuracy), 4),
+                     species_id = str(p.species_id),
+                 ))
+         except Exception as exc:
+             log.warning("Classifier error: %s", exc)
+         cls_ms_total += (time.perf_counter() - t0) * 1000
+
+         detections.append(Detection(
+             bbox = BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2,
+                                confidence=round(confidence, 3)),
+             polygon = polygon_coords,
+             predictions = pred_list,
+         ))
+
+     total_ms = (time.perf_counter() - t_start) * 1000
+     return PredictResponse(
+         detections = detections,
+         image_size = ImageSize(width=w, height=h),
+         timing = Timing(
+             detect_ms = round(detect_ms),
+             segment_ms = round(seg_ms_total),
+             classify_ms = round(cls_ms_total),
+             total_ms = round(total_ms),
+         ),
+     )
+
+ # ── auth dependencies ─────────────────────────────────────────────────────────
+ _api_key_header = APIKeyHeader(name="X-API-Key", auto_error=True)
+ _admin_key_header = APIKeyHeader(name="X-Admin-Key", auto_error=True)
+
+ def _require_api_key(key: str = Security(_api_key_header)):
+     if key not in _valid_key_set():
+         raise HTTPException(status_code=401, detail="Invalid or missing API key.")
+     return key
+
+ def _require_admin_key(key: str = Security(_admin_key_header)):
+     if key != ADMIN_KEY:
+         raise HTTPException(status_code=401, detail="Invalid admin key.")
+     return key
+
+ # ── app & middleware ──────────────────────────────────────────────────────────
+ app = FastAPI(
+     title = "SW Identifier API",
+     description = "Fish detection, segmentation, and species classification.",
+     version = "1.0.0",
+     lifespan = lifespan,
+ )
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins = ["*"],
+     allow_methods = ["GET", "POST", "DELETE"],
+     allow_headers = ["*"],
+ )
+
+ app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
+
+ # ── SPA routes ────────────────────────────────────────────────────────────────
+ @app.get("/", response_class=HTMLResponse, include_in_schema=False)
+ async def root():
+     with open(os.path.join(STATIC_DIR, "index.html"), encoding="utf-8") as fh:
+         return fh.read()
+
+ @app.post("/predict", include_in_schema=False)
+ async def predict_spa(file: UploadFile = File(...)):
+     """Internal endpoint used by the SPA — no auth required."""
+     if not file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="Upload must be an image file.")
+     return await _run_pipeline(await file.read())
+
+ # ── public API v1 ─────────────────────────────────────────────────────────────
+ api = APIRouter(prefix="/api/v1", tags=["SW Identifier API"])
+
+ @api.post(
+     "/predict",
+     response_model = PredictResponse,
+     summary = "Identify fish in an image",
+     description = (
+         "Upload an image. Returns every detected fish with its bounding box, "
+         "segmentation polygon, and ranked species predictions.\n\n"
+         "Requires an `X-API-Key` header."
+     ),
+ )
+ async def predict_api(
+     file: UploadFile = File(..., description="Image file (JPEG, PNG, WEBP, …)"),
+     _key: str = Depends(_require_api_key),
+ ):
+     if not file.content_type.startswith("image/"):
+         raise HTTPException(status_code=400, detail="Upload must be an image file.")
+     return await _run_pipeline(await file.read())
+
+
+ # ── key management ────────────────────────────────────────────────────────────
+ class KeyRecord(BaseModel):
+     key: str
+     name: str
+     created_at: str
+
+ class CreateKeyRequest(BaseModel):
+     name: str = "unnamed"
+
+ @api.get(
+     "/keys",
+     response_model = List[KeyRecord],
+     summary = "List API keys",
+     description = "Requires `X-Admin-Key` header.",
+ )
+ async def list_keys(_admin: str = Depends(_require_admin_key)):
+     return _load_keys()
+
+ @api.post(
+     "/keys",
+     response_model = KeyRecord,
+     status_code = 201,
+     summary = "Create a new API key",
+     description = "Requires `X-Admin-Key` header.",
+ )
+ async def create_key(
+     body: CreateKeyRequest = CreateKeyRequest(),
+     _admin: str = Depends(_require_admin_key),
+ ):
+     keys = _load_keys()
+     k = _new_key(body.name)
+     keys.append(k)
+     _save_keys(keys)
+     return k
+
+ @api.delete(
+     "/keys/{key}",
+     status_code = 204,
+     summary = "Revoke an API key",
+     description = "Requires `X-Admin-Key` header.",
+ )
+ async def revoke_key(key: str, _admin: str = Depends(_require_admin_key)):
+     keys = _load_keys()
+     remaining = [k for k in keys if k["key"] != key]
+     if len(remaining) == len(keys):
+         raise HTTPException(status_code=404, detail="Key not found.")
+     _save_keys(remaining)
+
+ app.include_router(api)
+
+ # ── entry point ───────────────────────────────────────────────────────────────
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.environ.get("PORT", 7860))
+     uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)
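The key lookup order in `app.py` (environment variable first, `api_keys.json` otherwise) can be restated as a pure function, which makes the precedence easier to see and to test. This is a sketch rather than code from the commit: `resolve_valid_keys` is a hypothetical name, and it takes the environment value and the stored records as arguments instead of reading them itself.

```python
from typing import Iterable, Mapping, Optional, Set


def resolve_valid_keys(env_value: Optional[str],
                       stored_records: Iterable[Mapping[str, str]]) -> Set[str]:
    """Return the accepted API keys: env var first, stored records otherwise."""
    cleaned = (env_value or "").strip()
    if cleaned:
        # Comma-separated list; blank entries and surrounding spaces are dropped.
        return {part.strip() for part in cleaned.split(",") if part.strip()}
    # No env value: fall back to the "key" field of each stored record.
    return {record["key"] for record in stored_records}
```

One consequence, if this reading of the code is right, is that a key created through `POST /api/v1/keys` is written to `api_keys.json` but would not be accepted while `SW_API_KEYS` is set, because the stored records are only consulted when the variable is empty.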
classification_model/__MACOSX/._database.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06af3c007f4209f0aaf82b962a2f10ed05f4b91d38839358ebf4dcf7d92adaf8
+ size 212
classification_model/__MACOSX/._inference.py ADDED
Binary file (212 Bytes)
classification_model/__MACOSX/._info.json ADDED
Binary file (268 Bytes)
classification_model/__MACOSX/._model.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:06af3c007f4209f0aaf82b962a2f10ed05f4b91d38839358ebf4dcf7d92adaf8
+ size 212
classification_model/__MACOSX/._requirements.txt ADDED
Binary file (212 Bytes)
classification_model/database.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e027ddba847b91c2ecfba617c604722b3a0fbd19d064e7fd09448d4e228082c0
+ size 143400638
classification_model/inference.py ADDED
@@ -0,0 +1,2029 @@
+ # -*- coding: utf-8 -*-
+ """
+ Standalone Interpreter for Lightning-trained Fish Classification Models.
+
+ This module provides a self-contained classifier for loading and using models
+ trained with lightning_train.py. All necessary classes are included in this file
+ to enable standalone deployment without additional module dependencies.
+
+ Features:
+     - Load PyTorch Lightning checkpoints
+     - Support for both ViT and CNN backbones
+     - Multiple pooling strategies (Attention, GeM, Hybrid)
+     - FAISS-based nearest neighbor search (can be disabled)
+     - Centroid-based class filtering
+     - Automatic input size detection
+     - Robust error handling and validation
+     - Configurable kNN classifier (enable/disable)
+
+ Usage:
+     config = {
+         'log_level': 'INFO',
+         'dataset': {'path': 'path/to/embeddings.pt'},
+         'model': {
+             'checkpoint_path': 'path/to/model.ckpt',
+             'backbone_model_name': 'maxvit_base_tf_224',
+             'embedding_dim': 512,
+             'num_classes': 639,
+             'arcface_s': 64.0,
+             'arcface_m': 0.2,
+             'pooling_type': 'attention',
+             'input_size': 224,  # Optional, auto-detected if not provided
+             'device': 'cuda'
+         },
+         # Optional inference parameters
+         'use_knn': True,  # Enable/disable kNN classifier (default: True)
+         'use_albumentations': False,  # Use albumentations transforms (default: False, uses torchvision)
+         'arcface_min_score': 0.1,
+         'centroid_fallback_score': 0.1,
+         'topk_centroid': 5,
+         'topk_neighbors': 10,
+         'topk_arcface': 5,
+         'centroid_threshold': 0.7,
+         'neighbor_threshold': 0.8
+     }
+
+     # Initialize classifier
+     classifier = EmbeddingClassifier(config)
+
+     # Optional: warmup for stable performance
+     classifier.warmup(num_iterations=5)
+
+     # Single image inference
+     results = classifier(image_array)  # np.ndarray [H, W, 3]
+
+     # Batch inference
+     results = classifier([img1, img2, img3])  # List[np.ndarray]
+
+     # Get model information
+     info = classifier.get_model_info()
+
+     # Context manager usage (recommended)
+     with EmbeddingClassifier(config) as classifier:
+         results = classifier(image_array)
+     # Auto cleanup on exit
+
+ Security Warning:
+     This module uses torch.load() which relies on pickle and can execute arbitrary code.
+     Only load checkpoints from trusted sources. The module attempts to use weights_only=True
+     first for safety, but falls back to weights_only=False if needed. Always verify checksums
+     and only load files from trusted sources in production environments.
+
+ Performance Notes:
+     - Memory usage scales with number of classes and database size
+     - Expected inference time: ~10-50ms per image (depending on backbone and device)
+     - FAISS indices are pre-built for faster search but require memory
+     - Large batches are automatically split into chunks (MAX_BATCH_SIZE) to prevent OOM errors
+     - For optimal performance, keep batch sizes <= 32 images
+ """
+
+ import logging
+ import time
+ import math
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List, Tuple, Union, Optional, Literal
+
+ import faiss
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+ from scipy.stats import entropy
+ from sklearn.metrics import pairwise_distances
+ from torchvision import transforms
+ import timm
+ from timm.models.vision_transformer import VisionTransformer
+
+ # Optional: Albumentations support (install with: pip install albumentations)
+ try:
+     import albumentations as A
+     from albumentations.pytorch import ToTensorV2
+     ALBUMENTATIONS_AVAILABLE = True
+ except ImportError:
+     ALBUMENTATIONS_AVAILABLE = False
+     A = None
+     ToTensorV2 = None
+
+
+ # Constants
+ SUPPORTED_VIT_BACKBONES = ['vit', 'beit', 'deit', 'maxvit', 'maxxvit', 'eva', 'dino', 'swin']
+ DEFAULT_IMAGE_SIZE = 224
+ GEM_POOLING_DEFAULT_P = 3.0
+ ATTENTION_HIDDEN_DIVISOR = 4
+ ATTENTION_HIDDEN_MIN = 128
+ NUMERICAL_EPSILON = 1e-6
+ WEIGHT_NORMALIZATION_EPSILON = 1e-10
+ MAX_BATCH_SIZE = 32  # Maximum batch size to prevent OOM
+ DEFAULT_WARMUP_ITERATIONS = 5
+ DEFAULT_ARCFACE_MIN_SCORE = 0.1
+ DEFAULT_CENTROID_FALLBACK_SCORE = 0.1
+ DEFAULT_TOPK_CENTROID = 5
+ DEFAULT_TOPK_NEIGHBORS = 10
+ DEFAULT_TOPK_ARCFACE = 5
+ DEFAULT_CENTROID_THRESHOLD = 0.7
+ DEFAULT_NEIGHBOR_THRESHOLD = 0.8
+ DEFAULT_USE_KNN = True
+ DEFAULT_RERANK_MODE = 'hybrid'  # 'hybrid', 'weighted_fusion', or 'rrf'
+ DEFAULT_ARCFACE_WEIGHT = 0.6  # Weight for ArcFace in weighted fusion
+ DEFAULT_KNN_WEIGHT = 0.4  # Weight for kNN in weighted fusion
+ DEFAULT_RRF_K = 60  # Constant for Reciprocal Rank Fusion
+ DEFAULT_USE_ALBUMENTATIONS = False  # Use albumentations for transforms (if available)
+
+
+ # Setup Logger
+ logger = logging.getLogger("EmbeddingClassifier")
+ if not logger.handlers:
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter(
+         "[%(asctime)s] [%(levelname)s] - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+     )
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+
+
+ @dataclass
+ class PredictionResult:
+     """Result of a single prediction."""
+     name: str
+     species_id: int
+     distance: float
+     accuracy: float  # Average similarity score (kept for backward compatibility)
+     image_id: Optional[str]
+     annotation_id: Optional[str]
+     drawn_fish_id: Optional[str]
+
+     @property
+     def average_similarity(self) -> float:
+         """Alias for accuracy field (which is actually average similarity)."""
+         return self.accuracy
+
+
+ # =============================================================================
+ # Pooling Layers
+ # =============================================================================
+
+ class GeMPooling(nn.Module):
+     """
+     Generalized Mean Pooling (GeM).
+
+     Popular in image retrieval tasks. Provides a learnable pooling between
+     average pooling (p=1) and max pooling (p→∞).
+
+     Reference: "Fine-tuning CNN Image Retrieval with No Human Annotation" (Radenović et al.)
+     """
+     def __init__(self, p: float = GEM_POOLING_DEFAULT_P, eps: float = NUMERICAL_EPSILON, learnable: bool = True):
+         super().__init__()
+         if learnable:
+             self.p = nn.Parameter(torch.ones(1) * p)
+         else:
+             self.register_buffer('p', torch.ones(1) * p)
+         self.eps = eps
+         self.learnable = learnable
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x: Feature map [B, C, H, W]
+         Returns:
+             Pooled features [B, C]
+         """
+         # Clamp both min and max for numerical stability
+         x_clamped = x.clamp(min=self.eps, max=1e4)
+         return F.adaptive_avg_pool2d(
+             x_clamped.pow(self.p),
+             1
+         ).pow(1.0 / self.p.clamp(min=1e-2)).squeeze(-1).squeeze(-1)
+
+     def __repr__(self):
+         return f"GeMPooling(p={self.p.item():.2f}, learnable={self.learnable})"
+
+
+ class ViTAttentionPooling(nn.Module):
+     """
+     Attention Pooling for Vision Transformer output of shape [B, N, D].
+     Computes a weighted sum of patch embeddings based on learned attention.
+     """
+     def __init__(self, in_features: int, hidden_features: Optional[int] = None):
+         super().__init__()
+         if hidden_features is None:
+             hidden_features = max(in_features // ATTENTION_HIDDEN_DIVISOR, ATTENTION_HIDDEN_MIN)
+
+         self.attention_net = nn.Sequential(
+             nn.Linear(in_features, hidden_features),
+             nn.Tanh(),
+             nn.Linear(hidden_features, 1)
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         object_mask: Optional[torch.Tensor] = None,
+         return_attention_map: bool = False
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         """
+         Args:
+             x: ViT output [B, N, D]
+             object_mask: Not used for ViT, kept for interface compatibility
+             return_attention_map: Whether to return attention weights
+
+         Returns:
+             pooled: Pooled features [B, D]
+             weights: Optional attention weights [B, N, 1]
+         """
+         attention_scores = self.attention_net(x)  # [B, N, 1]
+         weights = F.softmax(attention_scores, dim=1)  # [B, N, 1]
+         pooled = (x * weights).sum(dim=1)  # [B, D]
+
+         if return_attention_map:
+             return pooled, weights
+         return pooled, None
+
+
245
+ class AttentionPooling(nn.Module):
246
+ """
247
+ Attention-based pooling for CNN feature maps.
248
+ Weighs spatial features based on learned attention, optionally focusing
249
+ on regions within a provided object mask.
250
+ """
251
+ def __init__(self, in_channels: int, hidden_channels: Optional[int] = None):
252
+ super().__init__()
253
+ if hidden_channels is None:
254
+ hidden_channels = max(in_channels // ATTENTION_HIDDEN_DIVISOR, 32)
255
+
256
+ self.attention_conv = nn.Sequential(
257
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=False),
258
+ nn.ReLU(inplace=True),
259
+ nn.Conv2d(hidden_channels, 1, kernel_size=1, bias=False)
260
+ )
261
+
262
+ def forward(
263
+ self,
264
+ x: torch.Tensor,
265
+ object_mask: Optional[torch.Tensor] = None,
266
+ return_attention_map: bool = False
267
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
268
+ """
269
+ Args:
270
+ x: Feature map [B, C, H, W]
271
+ object_mask: Optional binary mask [B, 1, H', W'] or [B, H', W']
272
+ return_attention_map: Whether to return attention weights
273
+
274
+ Returns:
275
+ pooled: Pooled features [B, C]
276
+ weights: Optional attention map [B, 1, H, W]
277
+ """
278
+ x_for_attn = x
279
+
280
+ if object_mask is not None:
281
+ B, _, H_feat, W_feat = x.shape
282
+ object_mask_for_x = object_mask.float().to(x.device)
283
+ if object_mask_for_x.ndim == 3:
284
+ object_mask_for_x = object_mask_for_x.unsqueeze(1)
285
+
286
+ if object_mask_for_x.shape[2] != H_feat or object_mask_for_x.shape[3] != W_feat:
287
+ object_mask_for_x_resized = F.interpolate(
288
+ object_mask_for_x, size=(H_feat, W_feat), mode='nearest'
289
+ )
290
+ else:
291
+ object_mask_for_x_resized = object_mask_for_x
292
+
293
+ x_for_attn = x * object_mask_for_x_resized
294
+
295
+ attention_scores = self.attention_conv(x_for_attn)
296
+ weights = torch.sigmoid(attention_scores)
297
+
298
+ final_weights_for_pooling = weights
299
+ if object_mask is not None:
300
+ B_w, _, H_attn, W_attn = weights.shape
301
+ object_mask_for_weights = object_mask.float().to(weights.device)
302
+ if object_mask_for_weights.ndim == 3:
303
+ object_mask_for_weights = object_mask_for_weights.unsqueeze(1)
304
+ mask_downsampled_for_weights = F.interpolate(
305
+ object_mask_for_weights, size=(H_attn, W_attn), mode='nearest'
306
+ )
307
+ final_weights_for_pooling = weights * mask_downsampled_for_weights
308
+
309
+ weighted_features = x * final_weights_for_pooling
310
+ sum_weighted_features = weighted_features.sum(dim=(2, 3))
311
+ sum_weights = final_weights_for_pooling.sum(dim=(2, 3)).clamp(min=NUMERICAL_EPSILON)
312
+ pooled = sum_weighted_features / sum_weights
313
+
314
+ if return_attention_map:
315
+ return pooled, final_weights_for_pooling
316
+ return pooled, None
317
+
318
+
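The masked, normalized weighting in `AttentionPooling.forward` can be sketched for a single channel in plain Python. This is a simplified illustration: `masked_attention_pool` is a hypothetical helper, the attention scores are passed in directly instead of coming from the 1x1 convolutions, and the default `eps` stands in for `NUMERICAL_EPSILON`, whose value is not shown here.

```python
import math


def masked_attention_pool(features, scores, mask, eps=1e-6):
    """Sigmoid-weighted average of one channel, restricted to mask == 1."""
    # Sigmoid attention weight per location, zeroed outside the mask.
    weights = [m / (1.0 + math.exp(-s)) for s, m in zip(scores, mask)]
    # Clamp the normalizer so an empty mask does not divide by zero.
    total = max(sum(weights), eps)
    return sum(f * w for f, w in zip(features, weights)) / total


features = [1.0, 2.0, 10.0]
scores = [0.0, 0.0, 5.0]
inside = masked_attention_pool(features, scores, mask=[1, 1, 0])
empty = masked_attention_pool(features, scores, mask=[0, 0, 0])
```

The third location has the highest score but lies outside the mask, so it contributes nothing to `inside`. An all-zero mask yields `0.0` because the clamped denominator keeps the division finite.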
319
+ class HybridPooling(nn.Module):
320
+ """
321
+ Hybrid pooling combining GeM and Attention pooling.
322
+ Combines GeM-pooled and attention-pooled features, either by concatenation ('concat') or by a learned weighted sum ('add').
323
+ """
324
+ def __init__(
325
+ self,
326
+ in_channels: int,
327
+ gem_p: float = GEM_POOLING_DEFAULT_P,
328
+ attention_hidden: Optional[int] = None,
329
+ output_mode: Literal['concat', 'add'] = 'concat',
330
+ ):
331
+ super().__init__()
332
+ self.gem = GeMPooling(p=gem_p, learnable=True)
333
+ self.attention = AttentionPooling(in_channels, attention_hidden)
334
+ self.output_mode = output_mode
335
+
336
+ if output_mode == 'add':
337
+ # Learnable weights for combining
338
+ self.gem_weight = nn.Parameter(torch.tensor(0.5))
339
+ self.attn_weight = nn.Parameter(torch.tensor(0.5))
340
+
341
+ def forward(
342
+ self,
343
+ x: torch.Tensor,
344
+ object_mask: Optional[torch.Tensor] = None,
345
+ return_attention_map: bool = False
346
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
347
+ gem_out = self.gem(x)
348
+ attn_out, attn_map = self.attention(x, object_mask, return_attention_map=True)
349
+
350
+ if self.output_mode == 'concat':
351
+ pooled = torch.cat([gem_out, attn_out], dim=1)
352
+ else:
353
+ w_gem = torch.sigmoid(self.gem_weight)
354
+ w_attn = torch.sigmoid(self.attn_weight)
355
+ pooled = w_gem * gem_out + w_attn * attn_out
356
+
357
+ if return_attention_map:
358
+ return pooled, attn_map
359
+ return pooled, None
360
+
361
+ @property
362
+ def output_features(self) -> int:
363
+ """Returns output feature dimension multiplier."""
364
+ return 2 if self.output_mode == 'concat' else 1
365
+
366
+
367
+ # =============================================================================
368
+ # ArcFace Head
369
+ # =============================================================================
370
+
371
+ class ArcFaceHead(nn.Module):
372
+ """
373
+ ArcFace loss head for metric learning.
374
+ Implements the additive angular margin penalty.
375
+
376
+ Reference: "ArcFace: Additive Angular Margin Loss for Deep Face Recognition"
377
+ """
378
+ def __init__(
379
+ self,
380
+ embedding_dim: int,
381
+ num_classes: int,
382
+ s: float = 32.0,
383
+ m: float = 0.10
384
+ ):
385
+ super().__init__()
386
+ self.embedding_dim = embedding_dim
387
+ self.num_classes = num_classes
388
+ self.s = s
389
+ self.m = m
390
+
391
+ self.weight = nn.Parameter(torch.empty(num_classes, embedding_dim))
392
+ nn.init.xavier_uniform_(self.weight)
393
+
394
+ # Buffers for constants
395
+ self.register_buffer('cos_m', torch.tensor(math.cos(m)))
396
+ self.register_buffer('sin_m', torch.tensor(math.sin(m)))
397
+ self.register_buffer('th', torch.tensor(math.cos(math.pi - m)))
398
+ self.register_buffer('mm', torch.tensor(math.sin(math.pi - m) * m))
399
+ self.register_buffer('eps', torch.tensor(NUMERICAL_EPSILON))
400
+
401
+ def set_margin(self, new_m: float):
402
+ """Dynamically update the margin 'm' and its related constants."""
403
+ self.m = new_m
404
+ self.cos_m.data = torch.tensor(math.cos(new_m), device=self.cos_m.device)
405
+ self.sin_m.data = torch.tensor(math.sin(new_m), device=self.sin_m.device)
406
+ self.th.data = torch.tensor(math.cos(math.pi - new_m), device=self.th.device)
407
+ self.mm.data = torch.tensor(math.sin(math.pi - new_m) * new_m, device=self.mm.device)
408
+
409
+ def forward(
410
+ self,
411
+ normalized_emb: torch.Tensor,
412
+ labels: Optional[torch.Tensor] = None
413
+ ) -> torch.Tensor:
414
+ """
415
+ Args:
416
+ normalized_emb: L2-normalized embeddings [B, D]
417
+ labels: Optional class labels [B] (required during training)
418
+
419
+ Returns:
420
+ Scaled logits [B, num_classes]
421
+ """
422
+ normalized_w = F.normalize(self.weight, dim=1)
423
+ cosine = F.linear(normalized_emb, normalized_w)
424
+
425
+ if labels is not None:
426
+ cosine_sq = cosine ** 2
427
+ sine = torch.sqrt((1.0 - cosine_sq).clamp(min=self.eps.item()))
428
+ phi = cosine * self.cos_m - sine * self.sin_m
429
+ phi = torch.where(cosine > self.th, phi, cosine - self.mm)
430
+
431
+ output = cosine.clone()
432
+ idx = labels.to(dtype=torch.long, device=cosine.device).view(-1, 1)
433
+ src = phi.gather(1, idx).to(dtype=output.dtype)
434
+ output.scatter_(1, idx, src)
435
+ output *= self.s
436
+ else:
437
+ output = cosine * self.s
438
+
439
+ return output
440
+
441
+
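A scalar sketch of the margin logic in `ArcFaceHead.forward`, using only the standard library. The helper name `arcface_target_logit` is an assumption for illustration, and the sine is clamped at `0.0` here rather than at `NUMERICAL_EPSILON`.

```python
import math


def arcface_target_logit(cosine, s=32.0, m=0.10):
    """Scaled logit for the ground-truth class after the angular margin."""
    sine = math.sqrt(max(1.0 - cosine * cosine, 0.0))
    # cos(theta + m) via the angle-addition identity.
    phi = cosine * math.cos(m) - sine * math.sin(m)
    th = math.cos(math.pi - m)
    mm = math.sin(math.pi - m) * m
    # When theta + m would exceed pi, fall back to a linear penalty.
    if not cosine > th:
        phi = cosine - mm
    return s * phi


easy = arcface_target_logit(0.8)
wrapped = arcface_target_logit(-0.999)
```

For `cosine = 0.8` the result equals `s * cos(acos(0.8) + m)`, which is lower than the unpenalized `s * 0.8`. For `cosine = -0.999` the angle is too close to pi, so the linear fallback `cosine - mm` is used instead.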
442
+ # =============================================================================
443
+ # Model Classes
444
+ # =============================================================================
445
+
446
+ class StableEmbeddingModelViT(nn.Module):
447
+ """
448
+ Embedding model for Vision Transformer backbones.
449
+
450
+ Supports various ViT architectures from timm including:
451
+ - BEiT v2, DeiT, ViT
452
+ - MaxViT, MaxxViT
453
+ - EVA, DINOv2
454
+ - Swin Transformer
455
+ """
456
+ def __init__(
457
+ self,
458
+ embedding_dim: int = 128,
459
+ num_classes: int = 1000,
460
+ pretrained_backbone: bool = True,
461
+ freeze_backbone_initially: bool = False,
462
+ backbone_model_name: str = 'beitv2_base_patch16_224.in1k_ft_in22k_in1k',
463
+ custom_backbone: Optional[VisionTransformer] = None,
464
+ attention_hidden_channels: Optional[int] = None,
465
+ arcface_s: float = 64.0,
466
+ arcface_m: float = 0.5,
467
+ add_bn_to_embedding: bool = False,
468
+ embedding_dropout_rate: float = 0.11,
469
+ pooling_type: str = 'attention',
470
+ ):
471
+ super().__init__()
472
+ self.embedding_dim = embedding_dim
473
+ self.num_classes = num_classes
474
+ self.pooling_type = pooling_type
475
+
476
+ if custom_backbone:
477
+ self.backbone = custom_backbone
478
+ logger.info("Using custom ViT backbone.")
479
+ else:
480
+ logger.info(f"Loading ViT backbone: {backbone_model_name}")
481
+ self.backbone: VisionTransformer = timm.create_model(
482
+ backbone_model_name,
483
+ pretrained=pretrained_backbone,
484
+ num_classes=0
485
+ )
486
+
487
+ self.backbone_out_features = self._infer_backbone_embedding_dim()
488
+ self.backbone_feature_extractor = self.backbone.forward_features
489
+
490
+ if freeze_backbone_initially:
491
+ self.freeze_backbone()
492
+
493
+ # Pooling layer
494
+ if pooling_type == 'attention':
495
+ self.pooling = ViTAttentionPooling(
496
+ in_features=self.backbone_out_features,
497
+ hidden_features=attention_hidden_channels
498
+ )
499
+ else:
500
+ # For ViT, we'll use global average pooling
501
+ self.pooling = None
502
+
503
+ # Embedding layers
504
+ embedding_layers = [nn.Linear(self.backbone_out_features, embedding_dim)]
505
+ if add_bn_to_embedding:
506
+ embedding_layers.append(nn.BatchNorm1d(embedding_dim))
507
+ if embedding_dropout_rate > 0.0:
508
+ embedding_layers.append(nn.Dropout(embedding_dropout_rate))
509
+
510
+ self.embedding_fc = nn.Sequential(*embedding_layers)
511
+ self.arcface_head = ArcFaceHead(embedding_dim, num_classes, s=arcface_s, m=arcface_m)
512
+
513
+ logger.info("StableEmbeddingModelViT initialized")
514
+ logger.info(f" Embedding Dim: {embedding_dim}, Num Classes: {num_classes}")
515
+ logger.info(f" ArcFace s: {arcface_s}, m: {arcface_m}")
516
+ logger.info(f" Backbone out features: {self.backbone_out_features}")
517
+ logger.info(f" Pooling type: {pooling_type}")
518
+
519
+ def _tokens_and_grid_from_features(self, features: torch.Tensor):
520
+ """Normalize backbone features into token tensor [B, N, D] + optional grid."""
521
+ if features.ndim == 4:
522
+ B, C, H, W = features.shape
523
+ tokens = features.flatten(2).transpose(1, 2).contiguous()
524
+ return tokens, (H, W)
525
+
526
+ if features.ndim == 3:
527
+ tokens = features
528
+ if hasattr(self.backbone, "cls_token") and tokens.shape[1] > 1:
529
+ tokens = tokens[:, 1:, :]
530
+
531
+ if hasattr(self.backbone, "patch_embed") and hasattr(self.backbone.patch_embed, "grid_size"):
532
+ gs = self.backbone.patch_embed.grid_size
533
+ if isinstance(gs, (tuple, list)) and len(gs) == 2 and int(gs[0]) * int(gs[1]) == tokens.shape[1]:
534
+ return tokens, (int(gs[0]), int(gs[1]))
535
+
536
+ N = tokens.shape[1]
537
+ s = int(round(math.sqrt(N)))
538
+ if s * s == N:
539
+ return tokens, (s, s)
540
+
541
+ return tokens, None
542
+
543
+ raise ValueError(f"Unsupported backbone output shape: {tuple(features.shape)}")
544
+
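The square-grid fallback at the end of `_tokens_and_grid_from_features` can be sketched as a small standalone function. `infer_square_grid` is a hypothetical name, and it uses `math.isqrt` where the method above uses `round(math.sqrt(N))`; both accept only perfect squares.

```python
import math


def infer_square_grid(num_tokens, has_cls_token):
    """Guess an (H, W) patch grid from a token count, or None if not square."""
    # Drop the leading class token before counting patches.
    patches = num_tokens - 1 if has_cls_token and num_tokens > 1 else num_tokens
    side = math.isqrt(patches)
    return (side, side) if side * side == patches else None


vit_base_224 = infer_square_grid(197, has_cls_token=True)   # 196 patches
no_cls = infer_square_grid(196, has_cls_token=False)
irregular = infer_square_grid(50, has_cls_token=False)
```

A 224x224 input with 16x16 patches produces 196 patch tokens plus one class token, which maps to a 14x14 grid. A count that is not a perfect square returns `None`, in which case the attention map cannot be reshaped for visualization.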
545
+ def freeze_backbone(self):
546
+ """Freeze all backbone parameters."""
547
+ logger.info("Freezing backbone parameters.")
548
+ for param in self.backbone.parameters():
549
+ param.requires_grad = False
550
+
551
+ def unfreeze_backbone(self, specific_layer_keywords=None, verbose=False):
552
+ """Unfreeze backbone parameters, optionally filtering by keywords."""
553
+ logger.info(f"Unfreezing backbone parameters... (Keywords: {specific_layer_keywords})")
554
+ unfrozen_count = 0
555
+ for name, param in self.backbone.named_parameters():
556
+ if specific_layer_keywords is None or any(kw in name for kw in specific_layer_keywords):
557
+ param.requires_grad = True
558
+ unfrozen_count += 1
559
+ if verbose:
560
+ logger.info(f" Unfroze: {name}")
561
+ logger.info(f"Total parameters unfrozen: {unfrozen_count}")
562
+
563
+ def _infer_backbone_embedding_dim(self) -> int:
564
+ """Infer backbone output dimension."""
565
+ for attr in ("num_features", "embed_dim"):
566
+ v = getattr(self.backbone, attr, None)
567
+ if isinstance(v, int) and v > 0:
568
+ return int(v)
569
+
570
+ def _infer_input_hw() -> int:
571
+ cfg = getattr(self.backbone, "default_cfg", None) or {}
572
+ inp = cfg.get("input_size", None)
573
+ if isinstance(inp, (tuple, list)) and len(inp) == 3:
574
+ return int(inp[1])
575
+ name = str(getattr(self.backbone, "name", "") or "")
576
+ for s in (512, 384, 256, 224):
577
+ if name.endswith(f"_{s}"):
578
+ return s
579
+ return 224
580
+
581
+ self.backbone.eval()
582
+ orig_device = next(self.backbone.parameters()).device
583
+ self.backbone.to("cpu")
584
+ with torch.no_grad():
585
+ hw = _infer_input_hw()
586
+ dummy = torch.randn(1, 3, hw, hw)
587
+ features = self.backbone.forward_features(dummy)
588
+ self.backbone.to(orig_device)
589
+
590
+ if features.ndim == 4:
591
+ return int(features.shape[1])
592
+ if features.ndim == 3:
593
+ return int(features.shape[-1])
594
+ raise ValueError(f"Unsupported output shape: {tuple(features.shape)}")
595
+
596
+ def forward(
597
+ self,
598
+ x: torch.Tensor,
599
+ labels: Optional[torch.Tensor] = None,
600
+ object_mask: Optional[torch.Tensor] = None,
601
+ return_softmax: bool = False,
602
+ return_attention_map: bool = True
603
+ ):
604
+ """
605
+ Forward pass.
606
+
607
+ Args:
608
+ x: Input images [B, 3, H, W]
609
+ labels: Optional class labels [B]
610
+ object_mask: Optional object mask (ignored for ViT)
611
+ return_softmax: Return softmax probabilities instead of logits
612
+ return_attention_map: Return attention visualization map
613
+
614
+ Returns:
615
+ emb_norm: L2-normalized embeddings [B, D]
616
+ logits/probs: Class logits or probabilities [B, num_classes]
617
+ attn_map: Optional attention map for visualization
618
+ """
619
+ features = self.backbone_feature_extractor(x)
620
+ tokens, grid = self._tokens_and_grid_from_features(features)
621
+
622
+ if self.pooling is not None:
623
+ pooled, attn_weights = self.pooling(tokens, object_mask=object_mask, return_attention_map=True)
624
+ else:
625
+ # Global average pooling
626
+ pooled = tokens.mean(dim=1)
627
+ attn_weights = None
628
+
629
+ emb_raw = self.embedding_fc(pooled)
630
+ emb_norm = F.normalize(emb_raw, p=2, dim=1)
631
+ logits = self.arcface_head(emb_norm, labels)
632
+
633
+ vis_attn_map = None
634
+ if return_attention_map and attn_weights is not None and grid is not None:
635
+ try:
636
+ B, N, _ = attn_weights.shape
637
+ H, W = grid
638
+ if H * W == N:
639
+ vis_attn_map = attn_weights.permute(0, 2, 1).reshape(B, 1, H, W)
640
+ except Exception:
641
+ vis_attn_map = None
642
+
643
+ output_attn = vis_attn_map if return_attention_map else None
644
+
645
+ if return_softmax:
646
+ probabilities = F.softmax(logits, dim=1)
647
+ return emb_norm, probabilities, output_attn
648
+ return emb_norm, logits, output_attn
649
+
650
+
651
+ class StableEmbeddingModel(nn.Module):
652
+ """
653
+ Embedding model for CNN backbones (ConvNeXt, EfficientNet, ResNet, etc.).
654
+ """
655
+ def __init__(
656
+ self,
657
+ embedding_dim: int = 256,
658
+ num_classes: int = 1000,
659
+ pretrained_backbone: bool = True,
660
+ freeze_backbone_initially: bool = False,
661
+ backbone_model_name: str = 'convnext_tiny',
662
+ custom_backbone=None,
663
+ backbone_out_features: int = 768,
664
+ attention_hidden_channels: Optional[int] = None,
665
+ arcface_s: float = 32.0,
666
+ arcface_m: float = 0.11,
667
+ add_bn_to_embedding: bool = True,
668
+ embedding_dropout_rate: float = 0.0,
669
+ pooling_type: str = 'attention',
670
+ ):
671
+ super().__init__()
672
+ self.embedding_dim = embedding_dim
673
+ self.num_classes = num_classes
674
+ self.backbone_out_features = backbone_out_features
675
+ self.pooling_type = pooling_type
676
+
677
+ if custom_backbone:
678
+ self.backbone = custom_backbone
679
+ self.backbone_feature_extractor = self.backbone
680
+ logger.info("Using custom backbone.")
681
+ elif 'convnext' in backbone_model_name:
682
+ logger.info(f"Loading backbone from timm: {backbone_model_name}")
683
+ self.backbone = timm.create_model(
684
+ backbone_model_name,
685
+ pretrained=pretrained_backbone,
686
+ features_only=True,
687
+ out_indices=(-1,)
688
+ )
689
+ self.backbone_feature_extractor = lambda x: self.backbone(x)[-1]
690
+
691
+ dummy_input = torch.randn(1, 3, 224, 224)
692
+ with torch.no_grad():
693
+ out = self.backbone_feature_extractor(dummy_input)
694
+ self.backbone_out_features = out.shape[1]
695
+ logger.info(f" Detected backbone output channels: {self.backbone_out_features}")
696
+
697
+ else:
698
+ try:
699
+ logger.info(f"Attempting to load generic backbone from timm: {backbone_model_name}")
700
+ self.backbone = timm.create_model(
701
+ backbone_model_name,
702
+ pretrained=pretrained_backbone,
703
+ num_classes=0,
704
+ global_pool=''
705
+ )
706
+ self.backbone_feature_extractor = self.backbone.forward_features
707
+
708
+ dummy_input = torch.randn(1, 3, 224, 224)
709
+ with torch.no_grad():
710
+ out = self.backbone_feature_extractor(dummy_input)
711
+ self.backbone_out_features = out.shape[1]
712
+ logger.info(f" Detected backbone output channels: {self.backbone_out_features}")
713
+ except Exception as e:
714
+ raise ValueError(f"Unsupported backbone: {backbone_model_name}. Error: {e}")
715
+
716
+ if freeze_backbone_initially:
717
+ self.freeze_backbone()
718
+
719
+ # Pooling layer
720
+ if pooling_type == 'attention':
721
+ self.pooling = AttentionPooling(
722
+ in_channels=self.backbone_out_features,
723
+ hidden_channels=attention_hidden_channels
724
+ )
725
+ pooling_out_features = self.backbone_out_features
726
+ elif pooling_type == 'gem':
727
+ self.pooling = GeMPooling(p=3.0, learnable=True)
728
+ pooling_out_features = self.backbone_out_features
729
+ elif pooling_type == 'hybrid':
730
+ self.pooling = HybridPooling(
731
+ in_channels=self.backbone_out_features,
732
+ attention_hidden=attention_hidden_channels,
733
+ output_mode='concat'
734
+ )
735
+ pooling_out_features = self.backbone_out_features * 2
736
+ else: # 'avg'
737
+ self.pooling = nn.AdaptiveAvgPool2d(1)
738
+ pooling_out_features = self.backbone_out_features
739
+
740
+ # Embedding layers
741
+ embedding_layers = [nn.Linear(pooling_out_features, embedding_dim)]
742
+ if add_bn_to_embedding:
743
+ embedding_layers.append(nn.BatchNorm1d(embedding_dim))
744
+ if embedding_dropout_rate > 0.0:
745
+ embedding_layers.append(nn.Dropout(embedding_dropout_rate))
746
+
747
+ self.embedding_fc = nn.Sequential(*embedding_layers)
748
+ self.arcface_head = ArcFaceHead(embedding_dim, num_classes, s=arcface_s, m=arcface_m)
749
+
750
+ logger.info("StableEmbeddingModel initialized")
751
+ logger.info(f" Embedding Dim: {embedding_dim}, Num Classes: {num_classes}")
752
+ logger.info(f" ArcFace s: {arcface_s}, m: {arcface_m}")
753
+ logger.info(f" Backbone out features: {self.backbone_out_features}")
754
+ logger.info(f" Pooling type: {pooling_type}")
755
+
756
+ def freeze_backbone(self):
757
+ """Freeze all backbone parameters."""
758
+ logger.info("Freezing backbone parameters.")
759
+ for param in self.backbone.parameters():
760
+ param.requires_grad = False
761
+
762
+ def unfreeze_backbone(self, specific_layer_keywords=None, verbose=False):
763
+ """Unfreeze backbone parameters."""
764
+ logger.info(f"Unfreezing backbone parameters... (Keywords: {specific_layer_keywords})")
765
+ unfrozen_count = 0
766
+ for name, param in self.backbone.named_parameters():
767
+ if specific_layer_keywords is None or any(kw in name for kw in specific_layer_keywords):
768
+ param.requires_grad = True
769
+ unfrozen_count += 1
770
+ if verbose:
771
+ logger.info(f" Unfroze: {name}")
772
+ logger.info(f"Total parameters unfrozen: {unfrozen_count}")
773
+
774
+ def forward(
775
+ self,
776
+ x: torch.Tensor,
777
+ labels: Optional[torch.Tensor] = None,
778
+ object_mask: Optional[torch.Tensor] = None,
779
+ return_softmax: bool = False,
780
+ return_attention_map: bool = True
781
+ ):
782
+ """
783
+ Forward pass.
784
+
785
+ Args:
786
+ x: Input images [B, 3, H, W]
787
+ labels: Optional class labels [B]
788
+ object_mask: Optional object mask for attention guidance
789
+ return_softmax: Return softmax probabilities instead of logits
790
+ return_attention_map: Return attention visualization map
791
+
792
+ Returns:
793
+ emb_norm: L2-normalized embeddings [B, D]
794
+ logits/probs: Class logits or probabilities [B, num_classes]
795
+ attn_map: Optional attention map for visualization
796
+ """
797
+ features = self.backbone_feature_extractor(x)
798
+
799
+ attn_map = None
800
+ if self.pooling_type == 'attention':
801
+ pooled, attn_map = self.pooling(features, object_mask=object_mask, return_attention_map=return_attention_map)
802
+ elif self.pooling_type == 'hybrid':
803
+ pooled, attn_map = self.pooling(features, object_mask=object_mask, return_attention_map=return_attention_map)
804
+ elif self.pooling_type == 'gem':
805
+ pooled = self.pooling(features)
806
+ else: # avg
807
+ pooled = self.pooling(features).squeeze(-1).squeeze(-1)
808
+
809
+ emb_raw = self.embedding_fc(pooled)
810
+ emb_norm = F.normalize(emb_raw, p=2, dim=1)
811
+ logits = self.arcface_head(emb_norm, labels)
812
+
813
+ output_attn = attn_map if return_attention_map else None
814
+
815
+ if return_softmax:
816
+ probabilities = F.softmax(logits, dim=1)
817
+ return emb_norm, probabilities, output_attn
818
+ return emb_norm, logits, output_attn
819
+
820
+
821
+ # =============================================================================
822
+ # Embedding Classifier
823
+ # =============================================================================
824
+
825
+ class EmbeddingClassifier:
826
+ """
827
+ Main classifier for inference using an embedding-based approach.
828
+
829
+ This classifier loads a trained model and uses FAISS for fast nearest neighbor search
830
+ combined with centroid-based filtering for efficient classification.
831
+
832
+ Configuration example:
833
+ config = {
834
+ 'log_level': 'INFO',
835
+ 'dataset': {'path': 'embeddings.pt'},
836
+ 'model': {
837
+ 'checkpoint_path': 'model.ckpt',
838
+ 'backbone_model_name': 'maxvit_base_tf_224',
839
+ 'embedding_dim': 512,
840
+ 'num_classes': 639,
841
+ 'arcface_s': 64.0,
842
+ 'arcface_m': 0.2,
843
+ 'pooling_type': 'attention',
844
+ 'device': 'cuda'
845
+ },
846
+ 'use_knn': True # Enable/disable kNN classifier (default: True)
847
+ }
848
+ """
849
+
850
+ def __init__(self, config: Dict):
851
+ # Validate configuration
852
+ self._validate_config(config)
853
+
854
+ logger.setLevel(getattr(logging, config.get('log_level', 'INFO').upper()))
855
+
856
+ # Load dataset
857
+ self._load_data(config["dataset"]["path"])
858
+ self.dim = self.db_embeddings.shape[1]
859
+ self._prepare_centroids()
860
+
861
+ logger.info("Initializing EmbeddingClassifier...")
862
+
863
+ # Setup device
864
+ self.device = config["model"].get("device", "cpu")
865
+
866
+ # Load inference configuration
867
+ self.use_knn = config.get('use_knn', DEFAULT_USE_KNN)
868
+ self.arcface_min_score = config.get('arcface_min_score', DEFAULT_ARCFACE_MIN_SCORE)
869
+ self.centroid_fallback_score = config.get('centroid_fallback_score', DEFAULT_CENTROID_FALLBACK_SCORE)
870
+ self.default_topk_centroid = config.get('topk_centroid', DEFAULT_TOPK_CENTROID)
871
+ self.default_topk_neighbors = config.get('topk_neighbors', DEFAULT_TOPK_NEIGHBORS)
872
+ self.default_centroid_threshold = config.get('centroid_threshold', DEFAULT_CENTROID_THRESHOLD)
873
+ self.default_neighbor_threshold = config.get('neighbor_threshold', DEFAULT_NEIGHBOR_THRESHOLD)
874
+ self.default_topk_arcface = config.get('topk_arcface', DEFAULT_TOPK_ARCFACE)
875
+
876
+ # Reranking configuration
877
+ self.rerank_mode = config.get('rerank_mode', DEFAULT_RERANK_MODE)
878
+ self.arcface_weight = config.get('arcface_weight', DEFAULT_ARCFACE_WEIGHT)
879
+ self.knn_weight = config.get('knn_weight', DEFAULT_KNN_WEIGHT)
880
+ self.rrf_k = config.get('rrf_k', DEFAULT_RRF_K)
881
+
882
+ # Transform configuration
883
+ self.use_albumentations = config.get('use_albumentations', DEFAULT_USE_ALBUMENTATIONS)
884
+
885
+ logger.info(f"Inference config: use_knn={self.use_knn}, "
886
+ f"arcface_min={self.arcface_min_score}, "
887
+ f"centroid_fallback={self.centroid_fallback_score}, "
888
+ f"topk_centroid={self.default_topk_centroid}, "
889
+ f"topk_neighbors={self.default_topk_neighbors}, "
890
+ f"topk_arcface={self.default_topk_arcface}")
891
+ logger.info(f"Reranking config: mode={self.rerank_mode}, "
892
+ f"arcface_weight={self.arcface_weight}, "
893
+ f"knn_weight={self.knn_weight}, "
894
+ f"rrf_k={self.rrf_k}")
895
+
896
+ # Load model
897
+ self._load_model(config["model"])
898
+
899
+ # Validate embedding dimensions match
900
+ model_embedding_dim = config["model"]["embedding_dim"]
901
+ if self.dim != model_embedding_dim:
902
+ raise ValueError(
903
+ f"Embedding dimension mismatch: dataset has {self.dim}, "
904
+ f"but model expects {model_embedding_dim}"
905
+ )
906
+
907
+ # Infer input size from model or use config/default
908
+ self.input_size = self._get_input_size(config["model"])
909
+
910
+ # Setup transforms based on configuration
911
+ self.transform = self._create_transforms()
912
+
913
+ logger.info(f"Using {'Albumentations' if self.use_albumentations else 'torchvision'} transforms")
914
+
915
+ # Create ID to label mapping
916
+ self.id_to_label = {internal_id: self.keys[internal_id]['label'] for internal_id in self.keys}
917
+
918
+ # Pre-build FAISS indices for better performance (only if kNN is enabled)
919
+ if self.use_knn:
920
+ self._prepare_faiss_indices()
921
+ else:
922
+ logger.info("kNN classifier is disabled - skipping FAISS index creation")
923
+
924
+ logger.info("EmbeddingClassifier initialized successfully.")
925
+
926
+ def _create_transforms(self):
927
+ """Create image transforms based on configuration.
928
+
929
+ Returns:
930
+ Transform pipeline (Albumentations or torchvision)
931
+ """
932
+ if self.use_albumentations:
933
+ if not ALBUMENTATIONS_AVAILABLE:
934
+ logger.warning("Albumentations requested but not installed. Falling back to torchvision.")
935
+ logger.warning("Install with: pip install albumentations")
936
+ self.use_albumentations = False
937
+ else:
938
+ logger.info("Creating Albumentations transform pipeline")
939
+ return A.Compose([
940
+ A.Resize(self.input_size, self.input_size),
941
+ A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
942
+ ToTensorV2(),
943
+ ])
944
+
945
+ # Default: torchvision transforms
946
+ logger.info("Creating torchvision transform pipeline")
947
+ return transforms.Compose([
948
+ transforms.Resize((self.input_size, self.input_size), Image.Resampling.BILINEAR),
949
+ transforms.ToTensor(),
950
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
951
+ ])
952
+
953
+ @staticmethod
954
+ def _safe_int_to_str(value) -> str:
955
+ """Safely convert value to string, handling tensors, numpy arrays, UUIDs, etc.
956
+
957
+ Args:
958
+ value: Any value (tensor, numpy array, int, float, string/UUID, etc.)
959
+
960
+ Returns:
961
+ String representation of the value
962
+ """
963
+ # Handle scalar tensors and NumPy scalars/arrays (both expose .item())
964
+ if hasattr(value, 'item'):
965
+ value = value.item()
966
+ # Handle other array-likes that only expose .tolist()
967
+ elif hasattr(value, 'tolist'):
968
+ value = value.tolist()
969
+
970
+ # If already a string, return as is
971
+ if isinstance(value, str):
972
+ return value
973
+
974
+ # Try to convert to int, fallback to str if it fails (e.g., UUIDs)
975
+ try:
976
+ return str(int(value))
977
+ except (ValueError, TypeError):
978
+ return str(value)
979
+
980
+ def _validate_config(self, config: Dict) -> None:
981
+ """Validate configuration structure and required fields."""
982
+ if not isinstance(config, dict):
983
+ raise TypeError(f"Config must be a dictionary, got {type(config)}")
984
+
985
+ # Check required keys
986
+ if "dataset" not in config:
987
+ raise ValueError("Config must contain 'dataset' key")
988
+ if "path" not in config["dataset"]:
989
+ raise ValueError("Config['dataset'] must contain 'path' key")
990
+ if "model" not in config:
991
+ raise ValueError("Config must contain 'model' key")
992
+
993
+ required_model_keys = ["checkpoint_path", "backbone_model_name", "embedding_dim", "num_classes"]
994
+ for key in required_model_keys:
995
+ if key not in config["model"]:
996
+ raise ValueError(f"Config['model'] must contain '{key}' key")
997
+
998
+ # Validate numeric parameters
999
+ if config["model"]["embedding_dim"] <= 0:
1000
+ raise ValueError(f"embedding_dim must be positive, got {config['model']['embedding_dim']}")
1001
+ if config["model"]["num_classes"] <= 0:
1002
+ raise ValueError(f"num_classes must be positive, got {config['model']['num_classes']}")
1003
+
1004
+ # Validate optional thresholds if present
1005
+ for param in ["arcface_min_score", "centroid_fallback_score", "centroid_threshold", "neighbor_threshold"]:
1006
+ if param in config and (config[param] < 0 or config[param] > 1):
1007
+ raise ValueError(f"{param} must be between 0 and 1, got {config[param]}")
1008
+
1009
+ logger.info("Configuration validated successfully")
1010
+
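As a rough illustration of what `_validate_config` expects, the sketch below lists the required keys and reports which ones a given config is missing. `missing_config_keys` is a hypothetical helper and not part of `EmbeddingClassifier`. The sample values are taken from the class docstring, except `centroid_threshold`, which is an assumed value.

```python
REQUIRED_MODEL_KEYS = (
    "checkpoint_path",
    "backbone_model_name",
    "embedding_dim",
    "num_classes",
)


def missing_config_keys(config):
    """Return dotted names of required keys that are absent from config."""
    missing = []
    if "path" not in config.get("dataset", {}):
        missing.append("dataset.path")
    model = config.get("model", {})
    missing.extend(f"model.{k}" for k in REQUIRED_MODEL_KEYS if k not in model)
    return missing


config = {
    "dataset": {"path": "embeddings.pt"},
    "model": {
        "checkpoint_path": "model.ckpt",
        "backbone_model_name": "maxvit_base_tf_224",
        "embedding_dim": 512,
        "num_classes": 639,
    },
    "centroid_threshold": 0.5,  # assumed value; thresholds must lie in [0, 1]
}
complete = missing_config_keys(config)
incomplete = missing_config_keys({"dataset": {}})
```

Unlike `_validate_config`, which raises on the first problem, this helper collects every missing key so a caller can report them all at once.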
1011
+ def _get_input_size(self, model_config: Dict) -> int:
1012
+ """Infer input size from model config or backbone."""
1013
+ # Check if explicitly provided in config
1014
+ if "input_size" in model_config:
1015
+ return model_config["input_size"]
1016
+
1017
+ # Try to infer from backbone name
1018
+ backbone_name = model_config.get("backbone_model_name", "")
1019
+
1020
+ # Check for common size patterns in backbone name
1021
+ for size in [512, 384, 256, 224]:
1022
+ if str(size) in backbone_name:
1023
+ logger.info(f"Inferred input size {size} from backbone name")
1024
+ return size
1025
+
1026
+ # Try to get from model's default config
1027
+ if hasattr(self.model, 'backbone') and hasattr(self.model.backbone, 'default_cfg'):
1028
+ cfg = self.model.backbone.default_cfg
1029
+ if 'input_size' in cfg:
1030
+ input_size = cfg['input_size']
1031
+ if isinstance(input_size, (tuple, list)) and len(input_size) == 3:
1032
+ size = input_size[1] # Get height
1033
+ logger.info(f"Using input size {size} from model config")
1034
+ return size
1035
+
1036
+ # Default fallback
1037
+ logger.info(f"Using default input size {DEFAULT_IMAGE_SIZE}")
1038
+ return DEFAULT_IMAGE_SIZE
1039
+
1040
+ def _load_model(self, model_config: Dict):
1041
+ """Load model from Lightning checkpoint or regular PyTorch checkpoint."""
1042
+ checkpoint_path = model_config["checkpoint_path"]
1043
+
1044
+ # Validate checkpoint exists
1045
+ if not Path(checkpoint_path).exists():
1046
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
1047
+
1048
+ backbone_name = model_config.get("backbone_model_name", "maxvit_base_tf_224")
1049
+ embedding_dim = model_config.get("embedding_dim", 512)
1050
+ num_classes = model_config.get("num_classes", 639)
1051
+ arcface_s = model_config.get("arcface_s", 64.0)
1052
+ arcface_m = model_config.get("arcface_m", 0.2)
1053
+ pooling_type = model_config.get("pooling_type", "attention")
1054
+
1055
+ # Determine model class based on backbone
1056
+ is_vit = any(x in backbone_name.lower() for x in SUPPORTED_VIT_BACKBONES)
1057
+
1058
+ model_cls = StableEmbeddingModelViT if is_vit else StableEmbeddingModel
1059
+
1060
+ # Create model
1061
+ # Both model classes take the same constructor arguments
+ self.model = model_cls(
+ embedding_dim=embedding_dim,
+ num_classes=num_classes,
+ backbone_model_name=backbone_name,
+ arcface_s=arcface_s,
+ arcface_m=arcface_m,
+ pooling_type=pooling_type,
+ pretrained_backbone=False, # We'll load from checkpoint
+ )
1081
+
1082
+ # Load checkpoint
1083
+ # WARNING: torch.load uses pickle which can execute arbitrary code.
1084
+ # Only load checkpoints from trusted sources!
1085
+ # TODO: Add checksum verification for production use
1086
+ logger.warning(f"Loading checkpoint (tries weights_only=True first; the weights_only=False fallback is a security risk). "
1087
+ f"Only load from trusted sources: {checkpoint_path}")
1088
+ try:
1089
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=True)
1090
+ except Exception as e:
1091
+ logger.warning(f"Failed to load with weights_only=True: {e}. Falling back to weights_only=False")
1092
+ checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)
1093
+
1094
+ # Handle Lightning checkpoint format
1095
+ if 'state_dict' in checkpoint:
1096
+ state_dict = checkpoint['state_dict']
1097
+ # Remove 'model.' prefix if present (from Lightning)
1098
+ new_state_dict = {}
1099
+ for k, v in state_dict.items():
1100
+ if k.startswith('model.'):
1101
+ new_state_dict[k[6:]] = v
1102
+ else:
1103
+ new_state_dict[k] = v
1104
+ state_dict = new_state_dict
1105
+ else:
1106
+ state_dict = checkpoint
1107
+
1108
+ # Load state dict with error handling
1109
+ try:
1110
+ self.model.load_state_dict(state_dict, strict=True)
1111
+ logger.info(f"Model loaded successfully from {checkpoint_path}")
1112
+ except RuntimeError as e:
1113
+ logger.warning(f"Strict loading failed: {str(e)[:200]}")
1114
+ result = self.model.load_state_dict(state_dict, strict=False)
1115
+ if result.missing_keys:
1116
+ logger.warning(f"Missing keys in checkpoint: {result.missing_keys[:5]}")
1117
+ if result.unexpected_keys:
1118
+ logger.warning(f"Unexpected keys in checkpoint: {result.unexpected_keys[:5]}")
1119
+ logger.info(f"Model loaded with strict=False from {checkpoint_path}")
1120
+
1121
+ self.model.to(self.device)
1122
+ self.model.eval()
1123
+
1124
+ # Log model info
1125
+ total_params = sum(p.numel() for p in self.model.parameters())
1126
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
1127
+ logger.info(f"Model loaded and moved to {self.device}")
1128
+ logger.info(f"Total parameters: {total_params:,}, Trainable: {trainable_params:,}")
1129
+
1130
+ return self.model
1131
+
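Note on the Lightning prefix handling in `_load_model`: the key-renaming loop can be written as a small standalone helper. This is a minimal sketch, not code from this commit. `strip_prefix` is a hypothetical name and the example keys are invented for illustration.

```python
from typing import Any, Dict


def strip_prefix(state_dict: Dict[str, Any], prefix: str = "model.") -> Dict[str, Any]:
    """Return a copy of state_dict with one leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged, which mirrors
    the if/else inside the loop in `_load_model`.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Invented keys: two wrapped by a Lightning module, one already bare.
example = {"model.backbone.weight": 1, "model.head.bias": 2, "arc.weight": 3}
cleaned = strip_prefix(example)
```

Using `len(prefix)` instead of the literal `6` keeps the slice correct if the prefix ever changes.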
1132
+ def _load_data(self, dataset_path: str) -> None:
1133
+ """Load embeddings database."""
1134
+ # Validate dataset file exists
1135
+ if not Path(dataset_path).exists():
1136
+ raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
1137
+
1138
+ try:
1139
+ logger.info(f"Loading dataset from {dataset_path}")
1140
+ try:
1141
+ data = torch.load(dataset_path, map_location="cpu", weights_only=True)
1142
+ except Exception as e:
1143
+ logger.warning(f"Failed to load dataset with weights_only=True: {e}. Using weights_only=False")
1144
+ data = torch.load(dataset_path, map_location="cpu", weights_only=False)
1145
+ except Exception as e:
1146
+ raise RuntimeError(f"Failed to load dataset from {dataset_path}: {e}") from e
1147
+
1148
+ # Validate required keys
1149
+ required_keys = ['embeddings', 'labels', 'image_ids', 'annotation_ids', 'drawn_fish_ids', 'labels_keys']
1150
+ for key in required_keys:
1151
+ if key not in data:
1152
+ raise ValueError(f"Dataset missing required key: '{key}'")
1153
+
1154
+ # Optimize: direct conversion to float32 numpy array
1155
+ self.db_embeddings = np.asarray(data['embeddings'], dtype=np.float32)
1156
+
1157
+ self.db_labels = np.array(data['labels'])
1158
+ self.image_ids = data['image_ids']
1159
+ self.annotation_ids = data['annotation_ids']
1160
+ self.drawn_fish_ids = data['drawn_fish_ids']
1161
+ self.keys = data['labels_keys']
1162
+
1163
+ # Validate array lengths match
1164
+ n_embeddings = len(self.db_embeddings)
1165
+ if not (len(self.db_labels) == len(self.image_ids) == len(self.annotation_ids) == len(self.drawn_fish_ids) == n_embeddings):
1166
+ raise ValueError(
1167
+ f"Array length mismatch: embeddings={n_embeddings}, labels={len(self.db_labels)}, "
1168
+ f"image_ids={len(self.image_ids)}, annotation_ids={len(self.annotation_ids)}, "
1169
+ f"drawn_fish_ids={len(self.drawn_fish_ids)}"
1170
+ )
1171
+
1172
+ self.label_to_species_id = {
1173
+ v['label']: v['species_id'] for v in self.keys.values()
1174
+ }
1175
+
1176
+ # Calculate memory usage
1177
+ embeddings_size_mb = self.db_embeddings.nbytes / (1024 * 1024)
1178
+
1179
+ logger.info(f"Dataset loaded from {dataset_path}")
1180
+ logger.info(f" Embeddings shape: {self.db_embeddings.shape}")
1181
+ logger.info(f" Embeddings memory: {embeddings_size_mb:.2f} MB")
1182
+ logger.info(f" Unique labels: {len(np.unique(self.db_labels))}")
1183
+
1184
+ def __call__(self, img: Union[np.ndarray, List[np.ndarray]]):
1185
+ """
1186
+ Perform inference on image(s).
1187
+
1188
+ Args:
1189
+ img: Single image as np.ndarray or list of images
1190
+
1191
+ Returns:
1192
+ List of prediction results for each image
1193
+ """
1194
+ if isinstance(img, np.ndarray):
1195
+ return self.inference_numpy(img)
1196
+ elif isinstance(img, list) and all(isinstance(i, np.ndarray) for i in img):
1197
+ return self.inference_numpy_batch(img)
1198
+ else:
1199
+ raise TypeError("Input must be np.ndarray or List[np.ndarray].")
1200
+
1201
+ def _preprocess_image(self, img: np.ndarray, img_index: int = 0) -> np.ndarray:
1202
+ """Preprocess a single image to RGB uint8 format.
1203
+
1204
+ Args:
1205
+ img: Input image array
1206
+ img_index: Index of image in batch (for error messages)
1207
+
1208
+ Returns:
1209
+ Preprocessed RGB image as uint8 array
1210
+ """
1211
+ # Validate input
1212
+ if img.ndim not in [2, 3]:
1213
+ raise ValueError(f"Image {img_index} must be 2D or 3D array, got shape {img.shape}")
1214
+ if img.ndim == 3 and img.shape[2] not in [1, 3, 4]:
1215
+ raise ValueError(f"Image {img_index} must have 1, 3, or 4 channels, got {img.shape[2]}")
1216
+
1217
+ # Check for empty/invalid images
1218
+ if img.size == 0 or min(img.shape[:2]) == 0:
1219
+ raise ValueError(f"Image {img_index} has invalid dimensions: {img.shape}")
1220
+
1221
+ # Convert grayscale to RGB if needed
1222
+ if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
1223
+ img = np.stack([img.squeeze()] * 3, axis=-1)
1224
+ elif img.shape[2] == 4: # RGBA
1225
+ img = img[:, :, :3]
1226
+
1227
+ # Ensure correct dtype and range
1228
+ if img.dtype != np.uint8:
1229
+ max_val = img.max()
1230
+ if max_val == 0:
1231
+ logger.warning(f"Image {img_index} is completely black (all zeros)")
1232
+ img = np.zeros(img.shape, dtype=np.uint8)
1233
+ elif max_val <= 1.0:
1234
+ img = (img * 255).astype(np.uint8)
1235
+ else:
1236
+ img = np.clip(img, 0, 255).astype(np.uint8) # Clip first so out-of-range values saturate instead of wrapping
1237
+
1238
+ return img
1239
+
1240
+ def inference_numpy(self, img: np.ndarray):
1241
+ """Inference on a single numpy image."""
1242
+ try:
1243
+ img = self._preprocess_image(img, img_index=0)
1244
+
1245
+ # Apply transforms based on type
1246
+ if self.use_albumentations and ALBUMENTATIONS_AVAILABLE:
1247
+ # Albumentations expects numpy array in HWC format
1248
+ transformed = self.transform(image=img)
1249
+ tensor = transformed['image'].unsqueeze(0).to(self.device)
1250
+ else:
1251
+ # torchvision expects PIL Image
1252
+ pil_img = Image.fromarray(img)
1253
+ tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
1254
+
1255
+ return self._inference_batch_tensor(tensor)[0]
1256
+ except Exception as e:
1257
+ logger.error(f"Failed to process image: {e}", exc_info=True)
1258
+ raise RuntimeError(f"Image processing failed: {e}") from e
1259
+
1260
+ def inference_numpy_batch(self, imgs: List[np.ndarray]):
1261
+ """Inference on a batch of numpy images."""
1262
+ if not imgs:
1263
+ raise ValueError("Empty image list provided")
1264
+
1265
+ if len(imgs) > MAX_BATCH_SIZE:
1266
+ logger.info(f"Large batch detected ({len(imgs)} images). "
1267
+ f"Will be processed in chunks of {MAX_BATCH_SIZE}.")
1268
+
1269
+ try:
1270
+ processed_tensors = []
1271
+ for i, img in enumerate(imgs):
1272
+ img = self._preprocess_image(img, img_index=i)
1273
+
1274
+ # Apply transforms based on type
1275
+ if self.use_albumentations and ALBUMENTATIONS_AVAILABLE:
1276
+ # Albumentations expects numpy array
1277
+ transformed = self.transform(image=img)
1278
+ processed_tensors.append(transformed['image'])
1279
+ else:
1280
+ # torchvision expects PIL Image
1281
+ pil_img = Image.fromarray(img)
1282
+ processed_tensors.append(self.transform(pil_img))
1283
+
1284
+ tensors = torch.stack(processed_tensors).to(self.device)
1285
+ return self._inference_batch_tensor(tensors)
1286
+ except Exception as e:
1287
+ logger.error(f"Failed to process image batch: {e}", exc_info=True)
1288
+ raise RuntimeError(f"Batch image processing failed: {e}") from e
1289
+
1290
+ def _inference_batch_tensor(self, tensors: torch.Tensor):
1291
+ """Internal inference on tensor batch."""
1292
+ batch_size = tensors.shape[0]
1293
+
1294
+ # Validate batch size to prevent OOM
1295
+ if batch_size > MAX_BATCH_SIZE:
1296
+ logger.warning(f"Batch size {batch_size} exceeds MAX_BATCH_SIZE={MAX_BATCH_SIZE}. "
1297
+ f"Processing in chunks to prevent OOM.")
1298
+ # Process in chunks
1299
+ all_results = []
1300
+ for i in range(0, batch_size, MAX_BATCH_SIZE):
1301
+ chunk = tensors[i:i + MAX_BATCH_SIZE]
1302
+ chunk_results = self._inference_batch_tensor(chunk)
1303
+ all_results.extend(chunk_results)
1304
+ return all_results
1305
+
1306
+ with torch.no_grad():
1307
+ embeddings, archead_logits, _ = self.model(tensors, return_softmax=False)
1308
+
1309
+ # Get top-k ArcFace logits (k = min(5, num_classes)); these are raw logits, softmax is applied later in _postprocess_hybrid
1310
+ k_arcface = min(5, archead_logits.shape[1])
1311
+ top_probabilities, top_indices = torch.topk(archead_logits, k_arcface)
1312
+
1313
+ # Store top-5 ArcFace predictions with their scores
1314
+ topk_arcface = []
1315
+ for i in range(len(top_indices)):
1316
+ batch_top5 = []
1317
+ for rank in range(k_arcface):
1318
+ pred_id = top_indices[i][rank].item()
1319
+ pred_score = top_probabilities[i][rank].item()
1320
+ batch_top5.append((pred_id, pred_score, rank))
1321
+ topk_arcface.append(batch_top5)
1322
+
1323
+ # Use kNN search if enabled
1324
+ if self.use_knn:
1325
+ knn_output = self.get_top_neighbors_from_embeddings(embeddings)
1326
+
1327
+ # Log summary instead of full output (only if debug enabled)
1328
+ if logger.isEnabledFor(logging.DEBUG):
1329
+ logger.debug(f"Inference: {len(knn_output)} predictions generated (kNN enabled)")
1330
+ else:
1331
+ # kNN disabled - use empty results
1332
+ knn_output = [{} for _ in range(len(top_indices))]
1333
+
1334
+ if logger.isEnabledFor(logging.DEBUG):
1335
+ logger.debug("Inference: kNN disabled, using only ArcFace predictions")
1336
+
1337
+ return self._postprocess_hybrid(knn_output, topk_arcface)
1338
+
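Note on the chunking in `_inference_batch_tensor`: the method splits an oversized batch by calling itself on each slice. The same splitting can be done with a plain generator. The sketch below is illustrative only; `iter_chunks` is a hypothetical helper and a list of integers stands in for the tensor batch.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iter_chunks(items: Sequence[T], max_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items`, each holding at most `max_size` elements."""
    if max_size <= 0:
        raise ValueError(f"max_size must be positive, got {max_size}")
    for start in range(0, len(items), max_size):
        yield items[start:start + max_size]


# Seven items with a limit of three give chunks of sizes 3, 3 and 1.
chunks = [list(chunk) for chunk in iter_chunks(list(range(7)), 3)]
```

Slicing keeps the original order, so results gathered per chunk can be concatenated directly, as the method does with `all_results.extend(...)`.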
1339
+ def _rerank_predictions(
1340
+ self,
1341
+ arcface_predictions: List[Tuple[int, float, int]],
1342
+ knn_predictions: Dict,
1343
+ mode: str = 'weighted_fusion'
1344
+ ) -> List[Tuple[int, float, str]]:
1345
+ """
1346
+ Rerank predictions using different fusion strategies.
1347
+
1348
+ Args:
1349
+ arcface_predictions: List of (label_id, score, rank) from ArcFace
1350
+ knn_predictions: Dict of {label_id: data} from kNN
1351
+ mode: Reranking mode ('weighted_fusion', 'rrf', 'hybrid')
1352
+
1353
+ Returns:
1354
+ List of (label_id, final_score, source) tuples, sorted by final_score
1355
+ """
1356
+ combined_scores = {}
1357
+
1358
+ if mode == 'weighted_fusion':
1359
+ # Weighted Fusion: combine normalized scores with weights
1360
+ # ArcFace scores arrive as softmax probabilities over the top-k logits, in [0, 1]
1361
+ for label_id, prob, rank in arcface_predictions:
1362
+ combined_scores[label_id] = {
1363
+ 'arcface_score': prob,
1364
+ 'arcface_rank': rank,
1365
+ 'knn_score': 0.0,
1366
+ 'knn_rank': None
1367
+ }
1368
+
1369
+ # Add kNN scores (average cosine similarity per class, clamped to [0, 1] below)
1370
+ for idx, (label_id, data) in enumerate(
1371
+ sorted(knn_predictions.items(),
1372
+ key=lambda x: x[1]['similarity'] / x[1]['times'],
1373
+ reverse=True)
1374
+ ):
1375
+ knn_score = data['similarity'] / data['times']
1376
+ knn_score = max(0.0, min(1.0, knn_score)) # Clamp to [0, 1]
1377
+
1378
+ if isinstance(label_id, (int, np.integer)):
1379
+ label_id_int = int(label_id)
1380
+ else:
1381
+ # Find corresponding ID for string label
1382
+ label_id_int = None
1383
+ for k, v in self.id_to_label.items():
1384
+ if v == str(label_id):
1385
+ label_id_int = k
1386
+ break
1387
+ if label_id_int is None:
1388
+ continue
1389
+
1390
+ if label_id_int not in combined_scores:
1391
+ combined_scores[label_id_int] = {
1392
+ 'arcface_score': 0.0,
1393
+ 'arcface_rank': None,
1394
+ 'knn_score': knn_score,
1395
+ 'knn_rank': idx
1396
+ }
1397
+ else:
1398
+ combined_scores[label_id_int]['knn_score'] = knn_score
1399
+ combined_scores[label_id_int]['knn_rank'] = idx
1400
+
1401
+ # Calculate weighted final scores
1402
+ final_scores = []
1403
+ for label_id, scores in combined_scores.items():
1404
+ final_score = (
1405
+ self.arcface_weight * scores['arcface_score'] +
1406
+ self.knn_weight * scores['knn_score']
1407
+ )
1408
+
1409
+ # Determine source
1410
+ if scores['arcface_rank'] is not None and scores['knn_rank'] is not None:
1411
+ source = 'both'
1412
+ elif scores['arcface_rank'] is not None:
1413
+ source = 'arcface'
1414
+ else:
1415
+ source = 'knn'
1416
+
1417
+ final_scores.append((label_id, final_score, source))
1418
+
1419
+ elif mode == 'rrf':
1420
+ # Reciprocal Rank Fusion
1421
+ for label_id, prob, rank in arcface_predictions:
1422
+ rrf_score = 1.0 / (self.rrf_k + rank)
1423
+ combined_scores[label_id] = {
1424
+ 'rrf_score': rrf_score,
1425
+ 'arcface_rank': rank
1426
+ }
1427
+
1428
+ # Add kNN RRF scores
1429
+ for idx, (label_id, data) in enumerate(
1430
+ sorted(knn_predictions.items(),
1431
+ key=lambda x: x[1]['similarity'] / x[1]['times'],
1432
+ reverse=True)
1433
+ ):
1434
+ if isinstance(label_id, (int, np.integer)):
1435
+ label_id_int = int(label_id)
1436
+ else:
1437
+ label_id_int = None
1438
+ for k, v in self.id_to_label.items():
1439
+ if v == str(label_id):
1440
+ label_id_int = k
1441
+ break
1442
+ if label_id_int is None:
1443
+ continue
1444
+
1445
+ knn_rrf = 1.0 / (self.rrf_k + idx)
1446
+
1447
+ if label_id_int not in combined_scores:
1448
+ combined_scores[label_id_int] = {
1449
+ 'rrf_score': knn_rrf,
1450
+ 'knn_rank': idx
1451
+ }
1452
+ else:
1453
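+ combined_scores[label_id_int]['rrf_score'] += knn_rrf
+ combined_scores[label_id_int]['knn_rank'] = idx # Record the kNN rank so the source resolves to 'both'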
1454
+
1455
+ final_scores = [
1456
+ (label_id, scores['rrf_score'],
1457
+ 'both' if 'arcface_rank' in scores and 'knn_rank' in scores else
1458
+ 'arcface' if 'arcface_rank' in scores else 'knn')
1459
+ for label_id, scores in combined_scores.items()
1460
+ ]
1461
+
1462
+ else: # 'hybrid' - original behavior
1463
+ # Top-5 ArcFace first, then top-5 unique kNN
1464
+ return None # Will be handled separately
1465
+
1466
+ # Sort by final score (descending)
1467
+ final_scores.sort(key=lambda x: x[1], reverse=True)
1468
+ return final_scores
1469
+
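Note on the `'rrf'` mode in `_rerank_predictions`: Reciprocal Rank Fusion adds `1 / (k + rank)` for every ranking in which a label appears, using 0-based ranks as the method does. The sketch below is a simplified standalone version, not the code from this commit. `rrf_fuse` is a hypothetical name, the label IDs are invented, and `k=60` is a commonly used constant; the value of `self.rrf_k` is not shown in this section.

```python
from typing import Dict, List, Set, Tuple


def rrf_fuse(rankings: Dict[str, List[int]], k: int = 60) -> List[Tuple[int, float, str]]:
    """Fuse ranked label lists; returns (label, score, source) sorted by score, descending."""
    scores: Dict[int, float] = {}
    sources: Dict[int, Set[str]] = {}
    for source, ranked_labels in rankings.items():
        for rank, label in enumerate(ranked_labels):  # rank is 0-based
            scores[label] = scores.get(label, 0.0) + 1.0 / (k + rank)
            sources.setdefault(label, set()).add(source)
    fused = [
        (label, score, "both" if len(sources[label]) > 1 else next(iter(sources[label])))
        for label, score in scores.items()
    ]
    fused.sort(key=lambda item: item[1], reverse=True)
    return fused


# Invented label IDs: 3 and 7 appear in both rankings, 9 and 5 in only one.
fused = rrf_fuse({"arcface": [7, 3, 9], "knn": [3, 5, 7]})
```

Tracking the set of contributing sources per label makes the `'both'` tag depend only on where the label was actually seen.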
1470
+ def _postprocess_hybrid(self, knn_results, topk_arcface) -> List[PredictionResult]:
1471
+ """Combine top-5 ArcFace and top-5 unique kNN predictions.
1472
+
1473
+ Args:
1474
+ knn_results: kNN prediction results (list of dicts)
1475
+ topk_arcface: List of lists with (label_id, score, rank) tuples for top-5 ArcFace
1476
+
1477
+ Returns:
1478
+ List of PredictionResult objects:
1479
+ - Positions 1-5: Top-5 ArcFace predictions (with softmax probabilities)
1480
+ - Positions 6-10: Top-5 unique kNN predictions (not in ArcFace top-5)
1481
+ """
1482
+ results = []
1483
+
1484
+ for batch_idx in range(len(knn_results)):
1485
+ arcface_top5 = topk_arcface[batch_idx]
1486
+ knn_dict = knn_results[batch_idx]
1487
+
1488
+ # Step 1: Apply softmax over the top-k ArcFace logits only (probabilities are relative to the top-k, not to all classes)
1489
+ arcface_scores = torch.tensor([score for _, score, _ in arcface_top5])
1490
+ arcface_probs = F.softmax(arcface_scores, dim=0).cpu().numpy()
1491
+
1492
+ # Update arcface_top5 with probabilities
1493
+ arcface_top5_with_probs = [
1494
+ (label_id, float(arcface_probs[idx]), rank)
1495
+ for idx, (label_id, score, rank) in enumerate(arcface_top5)
1496
+ ]
1497
+
1498
+ # Step 2: Rerank predictions based on mode
1499
+ if self.rerank_mode in ['weighted_fusion', 'rrf']:
1500
+ reranked = self._rerank_predictions(
1501
+ arcface_top5_with_probs,
1502
+ knn_dict,
1503
+ mode=self.rerank_mode
1504
+ )
1505
+
1506
+ # Convert reranked results to PredictionResult objects
1507
+ final_predictions = []
1508
+ for label_id, final_score, source in reranked[:10]: # Top-10
1509
+ label = self.id_to_label.get(label_id, str(label_id))
1510
+ species_id = self.label_to_species_id.get(label, -1)
1511
+
1512
+ # Get additional info from kNN if available
1513
+ image_id = None
1514
+ annotation_id = None
1515
+ drawn_fish_id = None
1516
+
1517
+ if label_id in [int(k) if isinstance(k, (int, np.integer)) else None
1518
+ for k in knn_dict.keys()]:
1519
+ for k, data in knn_dict.items():
1520
+ k_int = int(k) if isinstance(k, (int, np.integer)) else None
1521
+ if k_int == label_id and data.get('index') is not None:
1522
+ idx = data['index']
1523
+ try:
1524
+ if 0 <= idx < len(self.image_ids):
1525
+ # Convert to string, handling tensors/numpy
1526
+ image_id = self._safe_int_to_str(self.image_ids[idx])
1527
+ annotation_id = self._safe_int_to_str(self.annotation_ids[idx])
1528
+ drawn_fish_id = self._safe_int_to_str(self.drawn_fish_ids[idx])
1529
+ except (IndexError, KeyError):
1530
+ pass
1531
+ break
1532
+
1533
+ final_predictions.append(PredictionResult(
1534
+ name=label,
1535
+ species_id=species_id,
1536
+ distance=final_score,
1537
+ accuracy=final_score,
1538
+ image_id=image_id,
1539
+ annotation_id=annotation_id,
1540
+ drawn_fish_id=drawn_fish_id,
1541
+ ))
1542
+
1543
+ results.append(final_predictions)
1544
+ continue
1545
+
1546
+ # Step 3: Hybrid mode - original behavior (top-5 ArcFace + top-5 unique kNN)
1547
+ arcface_predictions = []
1548
+ arcface_label_ids = set()
1549
+
1550
+ for idx, (label_id, score, rank) in enumerate(arcface_top5):
1551
+ label = self.id_to_label.get(label_id, str(label_id))
1552
+ arcface_label_ids.add(label_id)
1553
+
1554
+ species_id = self.label_to_species_id.get(label)
1555
+ if species_id is None:
1556
+ species_id = -1
1557
+
1558
+ probability = float(arcface_probs[idx]) # Softmax probability [0, 1]
1559
+
1560
+ arcface_predictions.append(PredictionResult(
1561
+ name=label,
1562
+ species_id=species_id,
1563
+ distance=score, # Keep raw logit for reference
1564
+ accuracy=probability, # Use softmax probability
1565
+ image_id=None,
1566
+ annotation_id=None,
1567
+ drawn_fish_id=None,
1568
+ ))
1569
+
1570
+ # Step 3 (continued): Create kNN predictions (exclude those already in ArcFace top-5)
1571
+ knn_predictions = []
1572
+
1573
+ for label_id, data in knn_dict.items():
1574
+ # Handle label conversion
1575
+ if isinstance(label_id, (int, np.integer)):
1576
+ label = self.id_to_label.get(int(label_id), str(label_id))
1577
+ label_id_int = int(label_id)
1578
+ else:
1579
+ # Already a string label name
1580
+ label = str(label_id)
1581
+ # Try to find corresponding ID
1582
+ label_id_int = None
1583
+ for k, v in self.id_to_label.items():
1584
+ if v == label:
1585
+ label_id_int = k
1586
+ break
1587
+
1588
+ # Skip if this label is already in ArcFace top-5
1589
+ if label_id_int in arcface_label_ids:
1590
+ continue
1591
+
1592
+ index = data.get("index")
1593
+
1594
+ # Safely access arrays with bounds checking
1595
+ image_id = None
1596
+ annotation_id = None
1597
+ drawn_fish_id = None
1598
+
1599
+ if index is not None:
1600
+ try:
1601
+ if 0 <= index < len(self.image_ids):
1602
+ # Convert to string, handling tensors/numpy
1603
+ image_id = self._safe_int_to_str(self.image_ids[index])
1604
+ annotation_id = self._safe_int_to_str(self.annotation_ids[index])
1605
+ drawn_fish_id = self._safe_int_to_str(self.drawn_fish_ids[index])
1606
+ except (IndexError, KeyError) as e:
1607
+ logger.warning(f"Error accessing index {index}: {e}")
1608
+
1609
+ species_id = self.label_to_species_id.get(label)
1610
+ if species_id is None:
1611
+ species_id = -1
1612
+
1613
+ # Calculate average cosine similarity; neighbors below the non-negative threshold were dropped, so this is expected to lie in [0, 1]
1614
+ avg_similarity = data['similarity'] / data['times']
1615
+ # Clamp to [0, 1] for safety
1616
+ avg_similarity = max(0.0, min(1.0, avg_similarity))
1617
+
1618
+ knn_predictions.append(PredictionResult(
1619
+ name=label,
1620
+ species_id=species_id,
1621
+ distance=data['similarity'],
1622
+ accuracy=avg_similarity, # Normalized similarity score
1623
+ image_id=image_id,
1624
+ annotation_id=annotation_id,
1625
+ drawn_fish_id=drawn_fish_id,
1626
+ ))
1627
+
1628
+ # Step 4: Sort kNN predictions by average similarity (descending) and take top-5
1629
+ knn_predictions.sort(key=lambda x: x.accuracy, reverse=True)
1630
+ top5_knn = knn_predictions[:5]
1631
+
1632
+ # Step 5: Combine: ArcFace top-5 first, then unique kNN top-5
1633
+ final_predictions = arcface_predictions + top5_knn
1634
+
1635
+ results.append(final_predictions)
1636
+
1637
+ return results
1638
+
1639
+ def _postprocess(self, class_results, top1_arcface) -> List[PredictionResult]:
1640
+ """Convert raw results to PredictionResult objects with custom sorting.
1641
+
1642
+ Args:
1643
+ class_results: Raw prediction results
1644
+ top1_arcface: List of (label_id, score) tuples for top-1 ArcFace predictions
1645
+
1646
+ Returns:
1647
+ List of sorted PredictionResult objects
1648
+ """
1649
+ results = []
1650
+ for batch_idx, single_fish in enumerate(class_results):
1651
+ fish_results = []
1652
+ top1_result = None
1653
+ top1_label_id = top1_arcface[batch_idx][0]
1654
+
1655
+ for label_id, data in single_fish.items():
1656
+ # Handle label conversion - label_id can be int or string
1657
+ if isinstance(label_id, (int, np.integer)):
1658
+ label = self.id_to_label.get(int(label_id), str(label_id))
1659
+ label_id_int = int(label_id)
1660
+ else:
1661
+ # Already a string label name
1662
+ label = str(label_id)
1663
+ # Try to find corresponding ID for comparison
1664
+ label_id_int = None
1665
+ for k, v in self.id_to_label.items():
1666
+ if v == label:
1667
+ label_id_int = k
1668
+ break
1669
+
1670
+ index = data["index"]
1671
+
1672
+ # Safely access arrays with bounds checking
1673
+ image_id = None
1674
+ annotation_id = None
1675
+ drawn_fish_id = None
1676
+
1677
+ if index is not None:
1678
+ try:
1679
+ if 0 <= index < len(self.image_ids):
1680
+ # Convert to string, handling tensors/numpy
1681
+ image_id = self._safe_int_to_str(self.image_ids[index])
1682
+ annotation_id = self._safe_int_to_str(self.annotation_ids[index])
1683
+ drawn_fish_id = self._safe_int_to_str(self.drawn_fish_ids[index])
1684
+ else:
1685
+ logger.warning(f"Index {index} out of bounds for arrays of length {len(self.image_ids)}")
1686
+ except (IndexError, KeyError) as e:
1687
+ logger.warning(f"Error accessing index {index}: {e}")
1688
+
1689
+ species_id = self.label_to_species_id.get(label)
1690
+ if species_id is None:
1691
+ logger.warning(f"Unknown label '{label}' not found in label_to_species_id mapping")
1692
+ species_id = -1 # Fallback for backward compatibility
1693
+
1694
+ # Calculate average similarity score
1695
+ avg_similarity = data['similarity'] / data['times']
1696
+
1697
+ result = PredictionResult(
1698
+ name=label,
1699
+ species_id=species_id,
1700
+ distance=data['similarity'],
1701
+ accuracy=avg_similarity, # Average similarity score
1702
+ image_id=image_id,
1703
+ annotation_id=annotation_id,
1704
+ drawn_fish_id=drawn_fish_id,
1705
+ )
1706
+
1707
+ # Check if this is the top-1 ArcFace prediction
1708
+ is_arcface_top1 = (
1709
+ (label_id_int is not None and label_id_int == top1_label_id) or
1710
+ (data.get('source') == 'arcface' and data.get('arcface_rank') == 0)
1711
+ )
1712
+
1713
+ if is_arcface_top1:
1714
+ top1_result = result
1715
+ else:
1716
+ fish_results.append(result)
1717
+
1718
+ # Sort remaining results by average similarity (descending)
1719
+ fish_results.sort(key=lambda x: x.accuracy, reverse=True)
1720
+
1721
+ # Place top-1 ArcFace prediction first, then kNN results
1722
+ if top1_result is not None:
1723
+ final_results = [top1_result] + fish_results
1724
+ else:
1725
+ final_results = fish_results
1726
+ if logger.isEnabledFor(logging.WARNING):
1727
+ logger.warning(f"Top-1 ArcFace prediction not found in results for batch {batch_idx}")
1728
+
1729
+ results.append(final_results)
1730
+ return results
1731
+
1732
+ def _prepare_centroids(self) -> None:
1733
+ """Compute class centroids for efficient filtering."""
1734
+ unique_labels = np.unique(self.db_labels)
1735
+ self.label_to_centroid = {}
1736
+ skipped_labels = []
1737
+
1738
+ for label in unique_labels:
1739
+ class_embs = self.db_embeddings[self.db_labels == label]
1740
+ if len(class_embs) == 0:
1741
+ logger.warning(f"Label {label} has no embeddings, skipping")
1742
+ skipped_labels.append(label)
1743
+ continue
1744
+
1745
+ centroid = np.mean(class_embs, axis=0)
1746
+ norm = np.linalg.norm(centroid)
1747
+
1748
+ if norm < NUMERICAL_EPSILON:
1749
+ logger.warning(f"Label {label} has zero-norm centroid, using unnormalized")
1750
+ self.label_to_centroid[label] = centroid
1751
+ else:
1752
+ self.label_to_centroid[label] = centroid / norm
1753
+
1754
+ self.centroid_matrix = np.stack([self.label_to_centroid[label] for label in self.label_to_centroid])
1755
+ self.centroid_labels = list(self.label_to_centroid.keys())
1756
+
1757
+ if skipped_labels:
1758
+ logger.warning(f"Skipped {len(skipped_labels)} labels with no embeddings")
1759
+ logger.info(f"Prepared {len(self.centroid_labels)} class centroids")
1760
+
1761
+ def _prepare_faiss_indices(self) -> None:
1762
+ """Pre-build FAISS indices for each class for faster search."""
1763
+ logger.info("Building FAISS indices for each class...")
1764
+ self.class_indices = {}
1765
+ unique_labels = np.unique(self.db_labels)
1766
+
1767
+ for label in unique_labels:
1768
+ # Use np.where directly to get indices (more memory efficient)
1769
+ global_indices = np.where(self.db_labels == label)[0]
1770
+ class_embs = self.db_embeddings[global_indices]
1771
+
1772
+ if len(class_embs) > 0:
1773
+ # Create FAISS index for this class
1774
+ index = faiss.IndexFlatIP(self.dim)
1775
+ index.add(class_embs)
1776
+
1777
+ self.class_indices[label] = {
1778
+ 'index': index,
1779
+ 'global_indices': global_indices,
1780
+ 'size': len(class_embs)
1781
+ }
1782
+
1783
+ logger.info(f"Built FAISS indices for {len(self.class_indices)} classes")
1784
+
1785
+ def get_top_neighbors_from_embeddings(
1786
+ self,
1787
+ query_embeddings: Union[np.ndarray, torch.Tensor],
1788
+ topk_centroid: Optional[int] = None,
1789
+ topk_neighbors: Optional[int] = None,
1790
+ centroid_threshold: Optional[float] = None,
1791
+ neighbor_threshold: Optional[float] = None
1792
+ ) -> List[Dict[str, Dict[str, Union[float, int, None]]]]:
1793
+ """
1794
+ Find top neighbors using centroid filtering + FAISS search.
1795
+
1796
+ Args:
1797
+ query_embeddings: Query embeddings [B, D]
1798
+ topk_centroid: Number of top centroids to consider (None = use default)
1799
+ topk_neighbors: Number of neighbors to retrieve (None = use default)
1800
+ centroid_threshold: Minimum similarity to centroid (None = use default)
1801
+ neighbor_threshold: Minimum similarity to neighbor (None = use default)
1802
+
1803
+ Returns:
1804
+ List of dictionaries mapping labels to similarity scores
1805
+ """
1806
+ # Use default values if not specified
1807
+ topk_centroid = self.default_topk_centroid if topk_centroid is None else topk_centroid
1808
+ topk_neighbors = self.default_topk_neighbors if topk_neighbors is None else topk_neighbors
1809
+ centroid_threshold = self.default_centroid_threshold if centroid_threshold is None else centroid_threshold
1810
+ neighbor_threshold = self.default_neighbor_threshold if neighbor_threshold is None else neighbor_threshold
1811
+
1812
+ # Validate parameters
1813
+ if topk_centroid <= 0:
1814
+ raise ValueError(f"topk_centroid must be positive, got {topk_centroid}")
1815
+ if topk_neighbors <= 0:
1816
+ raise ValueError(f"topk_neighbors must be positive, got {topk_neighbors}")
1817
+ if not 0 <= centroid_threshold <= 1:
1818
+ raise ValueError(f"centroid_threshold must be in [0, 1], got {centroid_threshold}")
1819
+ if not 0 <= neighbor_threshold <= 1:
1820
+ raise ValueError(f"neighbor_threshold must be in [0, 1], got {neighbor_threshold}")
1821
+
1822
+ start_time = time.time()
1823
+ if logger.isEnabledFor(logging.DEBUG):
1824
+ logger.debug(f"Starting search over {len(query_embeddings)} embeddings")
1825
+
1826
+ if isinstance(query_embeddings, torch.Tensor):
1827
+ query_embeddings = query_embeddings.cpu().numpy().astype("float32")
1828
+
1829
+ # Timing breakdown
1830
+ timing = {'centroid': 0, 'faiss': 0, 'aggregation': 0}
1831
+
1832
+ # Step 1: Vectorized centroid similarity computation for all queries
1833
+ t0 = time.time()
1834
+ # Embeddings are already L2-normalized, use matrix multiplication for cosine similarity
1835
+ all_centroid_sims = np.dot(query_embeddings, self.centroid_matrix.T) # [B, num_centroids]
1836
+ timing['centroid'] = time.time() - t0
1837
+
1838
+ results = []
1839
+ for query_idx, query_emb in enumerate(query_embeddings):
1840
+ centroid_sims = all_centroid_sims[query_idx]
1841
+ top_centroid_indices = np.argsort(-centroid_sims)[:topk_centroid]
1842
+
1843
+ centroid_scores = {
1844
+ self.centroid_labels[idx]: centroid_sims[idx]
1845
+ for idx in top_centroid_indices if centroid_sims[idx] >= centroid_threshold
1846
+ }
1847
+ selected_classes = set(centroid_scores.keys())
1848
+
1849
+ if not selected_classes:
1850
+ if logger.isEnabledFor(logging.DEBUG):
1851
+ max_sim = centroid_sims[top_centroid_indices[0]] if len(top_centroid_indices) > 0 else 0
1852
+ logger.debug(f"Query {query_idx}: No classes passed centroid threshold "
1853
+ f"(max similarity: {max_sim:.3f}, threshold: {centroid_threshold})")
1854
+ results.append({})
1855
+ continue
1856
+
1857
+ # Step 2: FAISS search using pre-built indices
1858
+ t0 = time.time()
1859
+ score_map = defaultdict(lambda: {'index': None, 'similarity': 0.0, 'times': 0, 'source': 'knn'})
1860
+
1861
+ for label in selected_classes:
1862
+ if label not in self.class_indices:
1863
+ if logger.isEnabledFor(logging.DEBUG):
1864
+ logger.debug(f"Label {label} not found in class_indices, skipping")
1865
+ continue
1866
+
1867
+ class_data = self.class_indices[label]
1868
+ class_index = class_data['index']
1869
+ global_indices = class_data['global_indices']
1870
+
1871
+ # Search within this class
1872
+ k = min(topk_neighbors, class_data['size'])
1873
+ distances, indices = class_index.search(query_emb.reshape(1, -1), k)
1874
+
1875
+ # Aggregate results for this class
1876
+ for rank, idx in enumerate(indices[0]):
1877
+ sim = float(distances[0][rank])
1878
+ if sim >= neighbor_threshold:
1879
+ original_idx = int(global_indices[idx])
1880
+ score_map[label]['similarity'] += sim
1881
+ score_map[label]['times'] += 1
1882
+ score_map[label]['source'] = 'knn'
1883
+ if score_map[label]['index'] is None:
1884
+ score_map[label]['index'] = original_idx
1885
+
1886
+ timing['faiss'] += time.time() - t0
1887
+
1888
+ # Step 3: Add centroid-only predictions for classes without neighbors
1889
+ t0 = time.time()
1890
+ for label, sim in centroid_scores.items():
1891
+ if label not in score_map:
1892
+ # Use actual centroid similarity instead of fixed fallback score
1893
+ centroid_sim = max(float(sim), self.centroid_fallback_score)
1894
+ score_map[label] = {
1895
+ 'index': None,
1896
+ 'similarity': centroid_sim,
1897
+ 'times': 1,
1898
+ 'source': 'knn'
1899
+ }
1900
+ timing['aggregation'] += time.time() - t0
1901
+
1902
+ results.append(dict(score_map))
1903
+
1904
+ total_time = time.time() - start_time
1905
+ if logger.isEnabledFor(logging.DEBUG):
1906
+ logger.debug(f"Search completed in {total_time:.3f}s "
1907
+ f"(centroid: {timing['centroid']:.3f}s, "
1908
+ f"faiss: {timing['faiss']:.3f}s, "
1909
+ f"aggregation: {timing['aggregation']:.3f}s)")
1910
+
1911
+ # Log performance metrics for production monitoring (only for larger batches)
1912
+ if len(query_embeddings) > 5:
1913
+ throughput = len(query_embeddings) / total_time if total_time > 0 else 0
1914
+ logger.info(f"Batch search: {len(query_embeddings)} queries in {total_time:.3f}s "
1915
+ f"({throughput:.1f} queries/s)")
1916
+
1917
+ return results
1918
+
1919
+ def get_model_info(self) -> Dict:
1920
+ """Return model configuration and statistics.
1921
+
1922
+ Returns:
1923
+ Dictionary with model information
1924
+ """
1925
+ info = {
1926
+ 'embedding_dim': self.dim,
1927
+ 'num_classes': len(self.keys),
1928
+ 'num_embeddings': len(self.db_embeddings),
1929
+ 'device': str(self.device),
1930
+ 'input_size': self.input_size,
1931
+ 'num_centroid_classes': len(self.centroid_labels) if self.use_knn else 0,
1932
+ 'inference_config': {
1933
+ 'use_knn': self.use_knn,
1934
+ 'arcface_min_score': self.arcface_min_score,
1935
+ 'centroid_fallback_score': self.centroid_fallback_score,
1936
+ 'topk_centroid': self.default_topk_centroid,
1937
+ 'topk_neighbors': self.default_topk_neighbors,
1938
+ 'topk_arcface': self.default_topk_arcface,
1939
+ 'centroid_threshold': self.default_centroid_threshold,
1940
+ 'neighbor_threshold': self.default_neighbor_threshold,
1941
+ }
1942
+ }
1943
+
1944
+ if hasattr(self, 'model') and hasattr(self.model, 'backbone'):
1945
+ info['backbone'] = self.model.backbone.__class__.__name__
1946
+
1947
+ return info
1948
+
1949
+ def warmup(self, num_iterations: int = DEFAULT_WARMUP_ITERATIONS) -> float:
1950
+ """Warmup model with dummy data for stable performance.
1951
+
1952
+ Args:
1953
+ num_iterations: Number of warmup iterations
1954
+
1955
+ Returns:
1956
+ Average warmup time per iteration in seconds
1957
+ """
1958
+ logger.info(f"Warming up model with {num_iterations} iterations...")
1959
+ dummy = torch.randn(1, 3, self.input_size, self.input_size).to(self.device)
1960
+
1961
+ # Warmup iterations
1962
+ times = []
1963
+ for i in range(num_iterations):
1964
+ start = time.time()
1965
+ with torch.no_grad():
1966
+ self.model(dummy, return_softmax=False)
1967
+ times.append(time.time() - start)
1968
+
1969
+ avg_time = np.mean(times)
1970
+ logger.info(f"Warmup completed: avg={avg_time*1000:.2f}ms, "
1971
+ f"min={min(times)*1000:.2f}ms, max={max(times)*1000:.2f}ms")
1972
+ return avg_time
1973
+
1974
+ def __enter__(self):
1975
+ """Context manager entry."""
1976
+ return self
1977
+
1978
+ def __exit__(self, exc_type, exc_val, exc_tb):
1979
+ """Context manager exit with cleanup."""
1980
+ self.cleanup()
1981
+ return False # Don't suppress exceptions
1982
+
1983
+ def cleanup(self) -> None:
1984
+ """Release resources and cleanup."""
1985
+ logger.info("Cleaning up resources...")
1986
+
1987
+ # Clear FAISS indices with error handling (only if kNN was enabled)
1988
+ if self.use_knn and hasattr(self, 'class_indices'):
1989
+ for label, data in self.class_indices.items():
1990
+ try:
1991
+ if 'index' in data and data['index'] is not None:
1992
+ data['index'].reset()
1993
+ except Exception as e:
1994
+ logger.warning(f"Failed to reset FAISS index for label {label}: {e}")
1995
+ try:
1996
+ self.class_indices.clear()
1997
+ except Exception as e:
1998
+ logger.warning(f"Failed to clear class_indices: {e}")
1999
+
2000
+ # Move model to CPU and clear cache
2001
+ if hasattr(self, 'model'):
2002
+ try:
2003
+ self.model.cpu()
2004
+ except Exception as e:
2005
+ logger.warning(f"Failed to move model to CPU: {e}")
2006
+
2007
+ try:
2008
+ if torch.cuda.is_available():
2009
+ torch.cuda.empty_cache()
2010
+ except Exception as e:
2011
+ logger.warning(f"Failed to empty CUDA cache: {e}")
2012
+
2013
+ logger.info("Cleanup completed")
2014
+
2015
+ def __del__(self):
2016
+ """Destructor - logs warning if cleanup wasn't called.
2017
+
2018
+ Note: Do not rely on __del__ for cleanup. Always use context manager
2019
+ or explicitly call cleanup().
2020
+ """
2021
+ try:
2022
+ # Check if resources are still allocated (only relevant if kNN was enabled)
2023
+ if hasattr(self, 'use_knn') and self.use_knn:
2024
+ if hasattr(self, 'class_indices') and self.class_indices:
2025
+ logger.warning("EmbeddingClassifier destroyed without cleanup(). "
2026
+ "Use context manager or call cleanup() explicitly.")
2027
+ except Exception:
2028
+ # Silently ignore errors in destructor during interpreter shutdown
2029
+ pass
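The search method above narrows the candidate classes by centroid similarity, runs a per-class neighbour search, and falls back to the centroid score for classes with no accepted neighbour. The sketch below illustrates that flow in plain Python under stated assumptions: `two_stage_search`, `normalize`, and `dot` are hypothetical names, a brute-force cosine search stands in for the FAISS indices, and the default thresholds are illustrative values, not the repository's settings.

```python
from collections import defaultdict
from math import sqrt


def normalize(vec):
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    norm = sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def two_stage_search(query, centroids, class_db, topk_centroid=2, topk_neighbors=3,
                     centroid_threshold=0.5, neighbor_threshold=0.6, fallback_score=0.1):
    """Hypothetical stand-in for the centroid -> per-class kNN search."""
    query = normalize(query)

    # Step 1: keep the top-k classes whose centroid is similar enough.
    centroid_scores = {label: dot(query, normalize(c)) for label, c in centroids.items()}
    ranked = sorted(centroid_scores, key=centroid_scores.get, reverse=True)[:topk_centroid]
    selected = {l: centroid_scores[l] for l in ranked if centroid_scores[l] >= centroid_threshold}
    if not selected:
        return {}

    # Step 2: brute-force neighbour search inside each selected class
    # (assumption: this replaces the pre-built FAISS index of the original).
    score_map = defaultdict(lambda: {"similarity": 0.0, "times": 0})
    for label in selected:
        sims = sorted((dot(query, normalize(e)) for e in class_db[label]), reverse=True)
        for sim in sims[:topk_neighbors]:
            if sim >= neighbor_threshold:
                score_map[label]["similarity"] += sim
                score_map[label]["times"] += 1

    # Step 3: classes without an accepted neighbour keep their centroid score,
    # floored at the fallback value.
    for label, sim in selected.items():
        if label not in score_map:
            score_map[label] = {"similarity": max(sim, fallback_score), "times": 1}
    return dict(score_map)
```

A class that passes the centroid threshold always appears in the result, either with summed neighbour similarities or with a single centroid-based score.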
classification_model/info.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "version": 10.0,
3
+ "num_of_class": 755,
4
+ "date": "29-01-2026",
5
+ "author": "Codahead@Andrew",
6
+ "database_size": 64445,
7
+ "samples_per_class": 100,
8
+ "val_samples": 18654,
9
+ "accuracy_arcface": 0.9498,
10
+ "accuracy_knn": 0.9552,
11
+ "top5_accuracy_arcface": 0.9922
12
+ }
classification_model/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f562a378ac98328fed6847662d03963485ef3b46c6c0a67b6d1eaf469607b295
3
+ size 346882574
classification_model/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch>=1.10.2
2
+ torchvision>=0.11.3
3
+ numpy>=1.19.2
4
+ opencv-python
5
+ # logging is part of the Python standard library and must not be installed from PyPI
6
+ Pillow>=8.4.0
7
+ faiss-cpu
8
+ scipy
9
+ joblib
10
+ timm==1.0.15
11
+ albumentations==2.0.8
detector/__MACOSX/._inference.py ADDED
Binary file (368 Bytes). View file
 
detector/__MACOSX/._info.json ADDED
Binary file (368 Bytes). View file
 
detector/__MACOSX/._model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:27538dceef5d2857509477fa7ec98d74c1dc372523ca10a6069b25f530d05dfa
3
+ size 641
detector/inference.py ADDED
@@ -0,0 +1,96 @@
1
+ import time
2
+ from typing import List, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from ultralytics import YOLO
7
+
8
+ logger = app_logger.getChild("models.detector.ultralytics")  # NOTE: app_logger, BaseInference and YOLOResult are not imported in this file
9
+
10
+ class YOLOInference(BaseInference):
11
+ def __init__(self, model_path: str, imsz: int = 640,
12
+ conf_threshold: float = 0.25, nms_threshold: float = 0.45,
13
+ device: str = "cpu"):
14
+ """
15
+ Initialize the YOLO detector using the official Ultralytics package.
16
+
17
+ Args:
18
+ model_path: Path to the model file (.pt, .onnx, or .torchscript).
19
+ imsz: Input image size for the model.
20
+ conf_threshold: Confidence threshold to filter out low-confidence boxes.
21
+ nms_threshold: IoU threshold for Non-Maximum Suppression.
22
+ device: Computing device ('cpu' or 'cuda').
23
+ """
24
+ super().__init__(config={"device": device})
25
+
26
+ self.model_path = model_path
27
+ self.imsz = imsz
28
+ self.conf_threshold = conf_threshold
29
+ self.nms_threshold = nms_threshold
30
+
31
+ self.load_model(model_path)
32
+
33
+ def load_model(self, model_path: str):
34
+ """
35
+ Loads the model into memory. Ultralytics handles the supported formats automatically.
36
+ """
37
+ logger.info(f"[load] Loading Ultralytics model from {model_path} on {self.device}")
38
+ # The YOLO class automatically handles weights and architecture configuration
39
+ self.model = YOLO(model_path)
40
+ self.model.to(self.device)
41
+
42
+ def predict(self, im_bgr: Union[np.ndarray, List[np.ndarray]]) -> List[List[YOLOResult]]:
43
+ """
44
+ Performs end-to-end inference including preprocessing, model forward pass, and NMS.
45
+
46
+ Args:
47
+ im_bgr: A single image or a list of images in BGR format (numpy arrays).
48
+
49
+ Returns:
50
+ A list of lists containing YOLOResult objects for each input image.
51
+ """
52
+ if isinstance(im_bgr, np.ndarray):
53
+ im_bgr = [im_bgr]
54
+
55
+ start_time = time.time()
56
+ logger.debug(f"[infer] Starting detector inference on {len(im_bgr)} frame(s)")
57
+
58
+ final_results = []
59
+
60
+ try:
61
+ # Ultralytics .predict() handles letterboxing, normalization, and NMS internally.
62
+ # It also automatically scales coordinates back to the original image size.
63
+ results = self.model.predict(
64
+ source=im_bgr,
65
+ imgsz=self.imsz,
66
+ conf=self.conf_threshold,
67
+ iou=self.nms_threshold,
68
+ device=self.device,
69
+ verbose=False,
70
+ save=False
71
+ )
72
+
73
+ for i, res in enumerate(results):
74
+ # res.boxes.data contains [x1, y1, x2, y2, confidence, class_id]
75
+ boxes_data = res.boxes.data.cpu().numpy()
76
+
77
+ frame_results = []
78
+ for box in boxes_data:
79
+ # box[:5] extracts [x1, y1, x2, y2, confidence]; the class_id is dropped
80
+ # Pass the scaled coordinates and the original image to the YOLOResult wrapper
81
+ frame_results.append(YOLOResult(box[:5], im_bgr[i]))
82
+
83
+ final_results.append(frame_results)
84
+
85
+ return final_results
86
+
87
+ except Exception as e:
88
+ logger.error(f"Inference error occurred: {e}")
89
+ # Return empty lists to prevent the pipeline from breaking
90
+ return [[] for _ in range(len(im_bgr))]
91
+
92
+ finally:
93
+ logger.info(
94
+ f"[infer] Detector inference completed in {(time.time() - start_time) * 1000:.2f} ms "
95
+ f"for {len(im_bgr)} frame(s)"
96
+ )
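The detector above reads each row of `res.boxes.data` as `[x1, y1, x2, y2, confidence, class_id]` and keeps only the first five values. As a rough illustration of that unpacking step, the sketch below converts such rows into small records. `Detection` and `rows_to_detections` are hypothetical names used only for this example; the repository's own wrapper is `YOLOResult`, whose implementation is not shown here.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Detection:
    """Hypothetical record for one box; not the repository's YOLOResult."""
    x1: float
    y1: float
    x2: float
    y2: float
    confidence: float

    @property
    def area(self) -> float:
        # Clamp at zero so a malformed box never yields a negative area.
        return max(0.0, self.x2 - self.x1) * max(0.0, self.y2 - self.y1)


def rows_to_detections(rows: Sequence[Sequence[float]]) -> List[Detection]:
    """Convert [x1, y1, x2, y2, confidence, class_id] rows, dropping class_id."""
    return [Detection(*map(float, row[:5])) for row in rows]
```

An empty input yields an empty list, which matches the detector's behaviour of returning one (possibly empty) list per frame.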
detector/info.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "version": 3,
3
+ "model": "YOLO26 nano",
4
+ "input_img_size": 640,
5
+ "date": "11-02-2025",
6
+ "author": "Codahead@Andrew"
7
+ }
detector/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b786b334355fdb0c3faa9d375d70de24f3535ee6f11e9c3e26ec2e90810c03f
3
+ size 131193007
requirements.txt ADDED
@@ -0,0 +1,18 @@
1
+ # Web server
2
+ fastapi>=0.110.0
3
+ uvicorn[standard]>=0.29.0
4
+ python-multipart>=0.0.9
5
+
6
+ # ML / vision (already present in classification_model/requirements.txt)
7
+ torch>=1.10.2
8
+ torchvision>=0.11.3
9
+ numpy>=1.19.2
10
+ opencv-python
11
+ Pillow>=8.4.0
12
+ ultralytics
13
+ faiss-cpu
14
+ scipy
15
+ scikit-learn
16
+ timm==1.0.15
17
+ shapely
18
+ albumentations==2.0.8
segmentator/inference.py ADDED
@@ -0,0 +1,402 @@
1
+ import time
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from PIL import Image
8
+ from shapely.geometry import LinearRing, MultiPolygon, Polygon
9
+ from torchvision import transforms
10
+
11
+
12
+ class Inference:
13
+ def __init__(self, model_path, image_size=416, threshold=0.5, poly_dict = True, max_points = 250):
14
+ self.model = torch.jit.load(model_path)
15
+ self.model.eval()
16
+ self.model.cpu()
17
+
18
+
19
+ self.max_points = max_points
20
+ self.IMAGE_SIZE = image_size
21
+ self.THRESHOLD = threshold
22
+
23
+ self.loader = transforms.Compose([
24
+ transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE), Image.BILINEAR),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
27
+ ])
28
+
29
+ def preprocess(self, image: np.ndarray):
30
+ # Converting an image to a tensor and normalizing
31
+ pil_image = Image.fromarray(image)
32
+ input_tensor = self.loader(pil_image)
33
+
34
+ return input_tensor
35
+
36
+ def postprocess(self, logit, src_size):
37
+ height, width = src_size
38
+ #(1, img_size, img_size) -> (img_size, img_size)
39
+ pr_mask = logit[0].numpy()
40
+ pr_mask = resize_logits_mask_pil(pr_mask, width, height)
41
+ pr_mask = pr_mask > self.THRESHOLD
42
+ contours = bitmap_to_polygon(pr_mask)
43
+ poly, valid_state = full_fix_contour(contours)
44
+ poly = np.asarray(poly, dtype=np.int32)  # full_fix_contour returns a plain list when no contour survives
45
+
46
+ return poly, valid_state
47
+
48
+ def predict(self, images):
49
+ # Checking the type of the input argument and casting to a list
50
+ if isinstance(images, np.ndarray):
51
+ images = [images]
52
+
53
+ # Guard against an empty input list
54
+ if len(images) == 0: return []
55
+
56
+ # Preprocessing images and saving their sizes
57
+ _input = [self.preprocess(image) for image in images]
58
+ src_sizes = [image.shape[:2] for image in images] # HEIGHT - WIDTH
59
+
60
+ _input = torch.stack(_input)
61
+
62
+ # Processing a batch of images
63
+ return self.predict_batch(_input, src_sizes)
64
+
65
+
66
+ def predict_batch(self, _input, src_sizes):
67
+ results = []
68
+ start_time = time.time()
69
+
70
+ with torch.no_grad():
71
+ logits = self.model(_input).sigmoid()
72
+
73
+ for idx, src_size in enumerate(src_sizes):
74
+ logit = logits[idx]
75
+ poly, valid_state = self.postprocess(logit, src_size)
76
+
77
+ if len(poly) != 0:
78
+ poly = approximate_to_max_point_cnt(poly, max_points=self.max_points)
79
+ else:
80
+ poly = [(0,0), (src_size[1],0), (src_size[1], src_size[0]), (0, src_size[0]), (0,0)]
81
+
82
+ results.append(FishPolygon(poly))
83
+
84
+ duration = time.time() - start_time
85
+
86
+ return results
87
+
88
+
89
+ def bitmap_to_polygon(bitmap):
90
+ """Convert masks from the form of bitmaps to polygons.
91
+
92
+ Args:
93
+ bitmap (ndarray): masks in bitmap representation.
94
+
95
+ Returns:
96
+ list[ndarray]: the contours of the mask as (N, 2) point arrays,
97
+ sorted by number of points, longest first.
98
+ """
99
+ bitmap = np.ascontiguousarray(bitmap).astype(np.uint8)
100
+ # cv2.RETR_CCOMP: retrieves all of the contours and organizes them
101
+ # into a two-level hierarchy. At the top level, there are external
102
+ # boundaries of the components. At the second level, there are
103
+ # boundaries of the holes. If there is another contour inside a hole
104
+ # of a connected component, it is still put at the top level.
105
+ # cv2.CHAIN_APPROX_NONE: stores absolutely all the contour points.
106
+ outs = cv2.findContours(bitmap, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
107
+ contours = outs[-2]
108
+ hierarchy = outs[-1]
109
+ if hierarchy is None:
110
+ return []
111
+ # hierarchy[i]: 4 elements, for the indexes of next, previous,
112
+ # parent, or nested contours. If there is no corresponding contour,
113
+ # it will be -1.
114
+ contours = [c.reshape(-1, 2) for c in contours]
115
+ return sorted(contours, key=len, reverse = True)
116
+
117
+ def poly_array_to_dict(polygon):
118
+ """
119
+ Converts an array of polygon points into a dictionary with labeled coordinates.
120
+
121
+ Args:
122
+ polygon (ndarray): An array of points representing the polygon. Each point is an array [x, y].
123
+
124
+ Returns:
125
+ dict: A dictionary where keys are labeled coordinates ('x1', 'y1', 'x2', 'y2', etc.)
126
+ and values are the corresponding x and y coordinates from the input array.
127
+ """
128
+ polygons_dict = {}
129
+
130
+ for i, point in enumerate(polygon):
131
+ # Add x coordinate with label 'x{i+1}'
132
+ polygons_dict[f"x{i + 1}"] = int(point[0])
133
+
134
+ # Add y coordinate with label 'y{i+1}'
135
+ polygons_dict[f"y{i + 1}"] = int(point[1])
136
+
137
+ return polygons_dict
138
+
139
+ def is_contour_valid(contour):
140
+ """
141
+ Checks if a contour is valid (i.e., its lines do not intersect).
142
+
143
+ Args:
144
+ contour (ndarray): The contour represented as an array of points.
145
+
146
+ Returns:
147
+ bool: True if the contour is valid, False otherwise.
148
+ """
149
+ if len(contour) < 3:
150
+ # A contour must contain at least three points to be a polygon
151
+ return False
152
+
153
+ polygon = Polygon(contour)
154
+
155
+ # Check for self-intersection
156
+ if not polygon.is_valid:
157
+ return False
158
+
159
+ # Check for intersection between the start and end points (to close the contour)
160
+ ring = LinearRing(contour)
161
+ if not ring.is_simple:
162
+ return False
163
+
164
+ return True
165
+
166
+ def fix_contour(contour):
167
+ """
168
+ Fixes a damaged contour (removes self-intersections).
169
+
170
+ Args:
171
+ contour (ndarray): The contour represented as an array of points.
172
+
173
+ Returns:
174
+ ndarray: The fixed contour.
175
+ """
176
+ polygon = Polygon(contour)
177
+ if polygon.is_valid:
178
+ return contour
179
+
180
+ # Fix the contour using buffer(0)
181
+ fixed_polygon = polygon.buffer(0)
182
+
183
+ if fixed_polygon.is_empty:
184
+ return np.array([]) # Return an empty array if the contour cannot be fixed
185
+
186
+ # Check the type of the returned object
187
+ if isinstance(fixed_polygon, Polygon):
188
+ fixed_contour = np.array(fixed_polygon.exterior.coords)
189
+ elif isinstance(fixed_polygon, MultiPolygon):
190
+ # If it's a MultiPolygon, choose the polygon with the largest area
191
+ largest_polygon = max(fixed_polygon.geoms, key=lambda p: p.area)
192
+ fixed_contour = np.array(largest_polygon.exterior.coords)
193
+
194
+ return fixed_contour
195
+
196
+ def full_fix_contour(poly):
197
+ """
198
+ Attempts to validate and fix a polygon contour. If the contour is valid, it returns the contour.
199
+ If the contour is invalid, it tries to fix it. If the fix is successful, it returns the fixed contour.
200
+
201
+ Args:
202
+ poly (ndarray): An array of polygons, where each polygon is represented as an array of points.
203
+
204
+ Returns:
205
+ tuple: A tuple containing the following:
206
+ - ndarray: The valid or fixed contour. If the contour cannot be fixed, an empty array is returned.
207
+ - str or None: The status of the contour: None if it was already valid, otherwise "Empty Contour", "Fixed Contour", or "Can't fix".
208
+ """
209
+ if len(poly) == 0 or len(poly[0]) < 10:
210
+ return [], "Empty Contour"
211
+
212
+ contour = poly[0]
213
+
214
+ # Check the validity of the contour
215
+ if is_contour_valid(contour):
216
+ return contour, None
217
+ else:
218
+ # Attempt to fix the contour
219
+ fixed_contour = fix_contour(contour)
220
+ if fixed_contour.size > 0 and is_contour_valid(fixed_contour):
221
+ return fixed_contour, "Fixed Contour"
222
+ else:
223
+ return [], "Can't fix"
224
+
225
+ def resize_logits_mask_pil(logits_mask, width, height):
226
+ """
227
+ Resize a logits mask to the specified output shape using PIL.
228
+
229
+ Parameters:
230
+ logits_mask (np.array): Input logits mask.
231
+ width (int): Desired width of the output shape.
232
+ height (int): Desired height of the output shape.
233
+
234
+ Returns:
235
+ np.array: Resized logits mask.
236
+ """
237
+ # Convert logits mask to float32 for PIL compatibility
238
+ mask_float32 = logits_mask.astype(np.float32)
239
+
240
+ # Create PIL image from the numpy array
241
+ pil_image = Image.fromarray(mask_float32)
242
+
243
+ # Resize the image
244
+ resized_pil_image = pil_image.resize((width, height), Image.BILINEAR)
245
+
246
+ # Convert back to numpy array
247
+ resized_mask = np.array(resized_pil_image)
248
+
249
+ return resized_mask
250
+
251
+ def approximate_to_max_point_cnt(poly, epsilon=0.08, max_points = 400):
252
+ while(True):
253
+ approximations = cv2.approxPolyDP(poly, epsilon, False)
254
+
255
+ if len(approximations) > max_points:
256
+ epsilon += 0.05
257
+ else:
258
+ break
259
+ approximations = np.reshape(approximations, (-1, 2))
260
+ return approximations
261
+
262
+ def convert_local_polygons_to_global(outputs, list_of_boxes):
263
+ for box_id, box in enumerate(list_of_boxes):
264
+ x, y = box[:2]
265
+ outputs[box_id] = [(point[0] + x, point[1] + y) for point in outputs[box_id]]
266
+
267
+
268
+ class FishPolygon:
269
+ def __init__(self, points):
270
+ """
271
+ Initializes the Polygon.
272
+
273
+ Args:
274
+ points (list): List of tuples representing the polygon points.
275
+ """
276
+ self.points = points
277
+ self.width, self.height = self.calculate_dimensions()
278
+
279
+ def calculate_dimensions(self):
280
+ """
281
+ Calculates the width and height of the polygon's bounding box.
282
+
283
+ Returns:
284
+ tuple: Width and height of the polygon's bounding box.
285
+ """
286
+ x_coords = [p[0] for p in self.points]
287
+ y_coords = [p[1] for p in self.points]
288
+ width = max(x_coords) - min(x_coords)
289
+ height = max(y_coords) - min(y_coords)
290
+ return width, height
291
+
292
+ def get_area(self):
293
+ """
294
+ Calculates the area of the polygon using the Shoelace formula.
295
+
296
+ Returns:
297
+ float: Area of the polygon.
298
+ """
299
+ x = [p[0] for p in self.points]
300
+ y = [p[1] for p in self.points]
301
+ return 0.5 * abs(sum(x[i] * y[i+1] - y[i] * x[i+1] for i in range(-1, len(self.points)-1)))
302
+
303
+ def get_centroid(self):
304
+ """
305
+ Calculates the centroid of the polygon.
306
+
307
+ Returns:
308
+ tuple: Coordinates of the centroid (x, y).
309
+ """
310
+ x = [p[0] for p in self.points]
311
+ y = [p[1] for p in self.points]
312
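+ area = 0.5 * sum(x[i] * y[i+1] - x[i+1] * y[i] for i in range(-1, len(self.points)-1))  # signed area: the centroid formula needs the orientation sign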
313
+ cx = sum((x[i] + x[i+1]) * (x[i] * y[i+1] - x[i+1] * y[i]) for i in range(-1, len(self.points)-1)) / (6 * area)
314
+ cy = sum((y[i] + y[i+1]) * (x[i] * y[i+1] - x[i+1] * y[i]) for i in range(-1, len(self.points)-1)) / (6 * area)
315
+ return (cx, cy)
316
+
317
+ def draw_polygon(self, image, color=(0, 255, 0), thickness=2):
318
+ """
319
+ Draws the polygon on the image.
320
+
321
+ Args:
322
+ image (numpy.ndarray): Image on which the polygon will be drawn.
323
+ color (tuple): Color of the polygon in (B, G, R) format.
324
+ thickness (int): Thickness of the polygon lines.
325
+ """
326
+ pts = np.array(self.points, np.int32)
327
+ pts = pts.reshape((-1, 1, 2))
328
+ cv2.polylines(image, [pts], isClosed=True, color=color, thickness=thickness)
329
+
330
+ def get_mask(self):
331
+ """
332
+ Creates a mask for the polygon.
333
+
334
+ Note:
335
+ The mask is sized to the polygon's bounding box (height, width); no image shape argument is taken.
336
+
337
+ Returns:
338
+ numpy.ndarray: Mask of the polygon.
339
+ """
340
+ mask = np.zeros((self.height, self.width), dtype=np.uint8)
341
+ pts = np.array(self.points, np.int32)
342
+ pts = pts.reshape((-1, 1, 2))
343
+ cv2.fillPoly(mask, [pts], 255)
344
+ return mask
345
+
346
+ def mask_polygon(self, image):
347
+ """
348
+ Applies a mask to the polygon on the image.
349
+
350
+ Args:
351
+ image (numpy.ndarray): Image on which the mask will be applied.
352
+
353
+ Returns:
354
+ numpy.ndarray: Image with the polygon area masked.
355
+ """
356
+ mask = np.zeros_like(image)
357
+ pts = np.array(self.points, dtype=np.int32)
358
+ cv2.fillPoly(mask, [pts], (255, 255, 255))
359
+ masked_image = cv2.bitwise_and(image, mask)
360
+ return masked_image
361
+
362
+ def move_to(self, x, y):
363
+ """
364
+ Translates the polygon by the offset (x, y).
365
+
366
+ Args:
367
+ x (float): Offset added to every x-coordinate.
368
+ y (float): Offset added to every y-coordinate.
369
+ """
370
+ self.points = [(px + x, py + y) for px, py in self.points]
371
+ self.width, self.height = self.calculate_dimensions()
372
+
373
+ def to_points_dict(self):
374
+ """
375
+ Converts the polygon points to a dictionary format.
376
+
377
+ Returns:
378
+ dict: Dictionary with keys as 'x1', 'y1', 'x2', 'y2', etc.
379
+ """
380
+ points_dict = {}
381
+ for i, (x, y) in enumerate(self.points, start=1):
382
+ points_dict[f'x{i}'] = x
383
+ points_dict[f'y{i}'] = y
384
+ return points_dict
385
+
386
+ def __repr__(self):
387
+ return f"Polygon(points={self.points}, width={self.width}, height={self.height})"
388
+
389
+ def to_dict(self):
390
+ """
391
+ Converts the object to a dictionary.
392
+
393
+ Returns:
394
+ dict: Dictionary with the keys 'points', 'width', 'height', 'area', and 'centroid'.
395
+ """
396
+ return {
397
+ 'points': self.points,
398
+ 'width': self.width,
399
+ 'height': self.height,
400
+ 'area': self.get_area(),
401
+ 'centroid': self.get_centroid()
402
+ }
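The `FishPolygon` methods above rely on the Shoelace formula for the area and on the matching polygon centroid formula. The standard centroid formula divides by the signed area, so the vertex order (clockwise or counter-clockwise) must carry through the calculation. The following is a minimal standalone sketch of both formulas; `signed_area` and `centroid` are hypothetical helper names, not functions from the repository.

```python
def signed_area(points):
    """Shoelace formula: positive for counter-clockwise vertex order, negative for clockwise."""
    n = len(points)
    return 0.5 * sum(
        points[i][0] * points[(i + 1) % n][1] - points[(i + 1) % n][0] * points[i][1]
        for i in range(n)
    )


def centroid(points):
    """Area-weighted centroid of a simple polygon given as (x, y) tuples."""
    n = len(points)
    area = signed_area(points)
    if area == 0:
        raise ValueError("degenerate polygon has no area-weighted centroid")
    cx = cy = 0.0
    for i in range(n):
        x0, y0 = points[i]
        x1, y1 = points[(i + 1) % n]
        cross = x0 * y1 - x1 * y0
        cx += (x0 + x1) * cross
        cy += (y0 + y1) * cross
    # Dividing by the signed area cancels the sign of the cross terms,
    # so both vertex orders give the same centroid.
    return cx / (6 * area), cy / (6 * area)
```

For a 4 x 2 axis-aligned rectangle with one corner at the origin, both vertex orders give the centroid (2.0, 1.0), while the signed area is 8.0 in one direction and -8.0 in the other. Taking the absolute value gives the geometric area.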
segmentator/info.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "version": 1,
3
+ "model": "FPN_resnet_18",
4
+ "input_img_size": 416,
5
+ "date": "31-07-2024",
6
+ "author": "Codahead@Andrew"
7
+ }
segmentator/model.ts ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b1ef999b2096905b217a5bcd61801d5d3e6f9141ab141d68c85be5429f559f2e
3
+ size 52491755
static/index.html ADDED
@@ -0,0 +1,707 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <title>SW — Fish Identifier</title>
7
+ <style>
8
+ *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
9
+
10
+ :root {
11
+ --bg: #0f1117;
12
+ --surface: #1a1d27;
13
+ --border: #2d3147;
14
+ --accent: #3b82f6;
15
+ --accent2: #10b981;
16
+ --warn: #f59e0b;
17
+ --text: #e2e8f0;
18
+ --muted: #64748b;
19
+ --radius: 12px;
20
+ --shadow: 0 4px 24px rgba(0,0,0,.45);
21
+ }
22
+
23
+ body {
24
+ background: var(--bg);
25
+ color: var(--text);
26
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
27
+ min-height: 100vh;
28
+ display: flex;
29
+ flex-direction: column;
30
+ }
31
+
32
+ header {
33
+ padding: 1rem 1.5rem;
34
+ border-bottom: 1px solid var(--border);
35
+ display: flex;
36
+ align-items: center;
37
+ gap: .75rem;
38
+ background: var(--surface);
39
+ }
40
+ header h1 {
41
+ font-size: 1.2rem;
42
+ font-weight: 700;
43
+ letter-spacing: -.02em;
44
+ }
45
+ header .badge {
46
+ font-size: .7rem;
47
+ background: var(--accent);
48
+ color: #fff;
49
+ border-radius: 4px;
50
+ padding: 2px 6px;
51
+ text-transform: uppercase;
52
+ letter-spacing: .06em;
53
+ }
54
+
55
+ main {
56
+ flex: 1;
57
+ display: grid;
58
+ grid-template-columns: 1fr 360px;
59
+ gap: 1rem;
60
+ padding: 1rem;
61
+ max-width: 1400px;
62
+ width: 100%;
63
+ margin: 0 auto;
64
+ }
65
+
66
+ /* ── Drop / canvas panel ── */
67
+ .canvas-panel {
68
+ background: var(--surface);
69
+ border: 1px solid var(--border);
70
+ border-radius: var(--radius);
71
+ overflow: hidden;
72
+ position: relative;
73
+ display: flex;
74
+ align-items: center;
75
+ justify-content: center;
76
+ min-height: 480px;
77
+ }
78
+
79
+ #drop-zone {
80
+ position: absolute;
81
+ inset: 0;
82
+ display: flex;
83
+ flex-direction: column;
84
+ align-items: center;
85
+ justify-content: center;
86
+ gap: 1rem;
87
+ cursor: pointer;
88
+ transition: background .2s;
89
+ z-index: 1;
90
+ }
91
+ #drop-zone.drag-over { background: rgba(59,130,246,.08); }
92
+ #drop-zone.hidden { display: none; }
93
+
94
+ .drop-icon {
95
+ width: 72px;
96
+ height: 72px;
97
+ border-radius: 50%;
98
+ background: rgba(59,130,246,.12);
99
+ display: flex;
100
+ align-items: center;
101
+ justify-content: center;
102
+ }
103
+ .drop-icon svg { color: var(--accent); }
104
+
105
+ #drop-zone p { color: var(--muted); font-size: .9rem; }
106
+ #drop-zone b { color: var(--text); }
107
+
108
+ #file-input { display: none; }
109
+
110
+ /* canvas lives here */
111
+ #canvas-wrap {
112
+ position: relative;
113
+ display: none;
114
+ width: 100%;
115
+ height: 100%;
116
+ }
117
+ #canvas-wrap.visible { display: block; }
118
+
119
+ #base-canvas, #overlay-canvas {
120
+ position: absolute;
121
+ top: 50%;
122
+ left: 50%;
123
+ transform: translate(-50%, -50%);
124
+ }
125
+ #overlay-canvas { pointer-events: none; }
126
+
127
+ .canvas-toolbar {
128
+ position: absolute;
129
+ top: .75rem;
130
+ right: .75rem;
131
+ display: flex;
132
+ gap: .5rem;
133
+ z-index: 10;
134
+ }
135
+ .canvas-toolbar button {
136
+ padding: .35rem .75rem;
137
+ border-radius: 6px;
138
+ border: 1px solid var(--border);
139
+ background: rgba(15,17,23,.8);
140
+ color: var(--text);
141
+ font-size: .8rem;
142
+ cursor: pointer;
143
+ backdrop-filter: blur(4px);
144
+ transition: border-color .15s;
145
+ }
146
+ .canvas-toolbar button:hover { border-color: var(--accent); }
147
+
148
+ /* ── Spinner overlay ── */
149
+ #spinner {
150
+ position: absolute;
151
+ inset: 0;
152
+ background: rgba(15,17,23,.7);
153
+ display: none;
154
+ flex-direction: column;
155
+ align-items: center;
156
+ justify-content: center;
157
+ gap: 1rem;
158
+ z-index: 20;
159
+ border-radius: var(--radius);
160
+ }
161
+ #spinner.active { display: flex; }
162
+ .spin-ring {
163
+ width: 48px;
164
+ height: 48px;
165
+ border: 3px solid var(--border);
166
+ border-top-color: var(--accent);
167
+ border-radius: 50%;
168
+ animation: spin .8s linear infinite;
169
+ }
170
+ @keyframes spin { to { transform: rotate(360deg); } }
171
+
172
+ /* ── Results panel ── */
173
+ .results-panel {
174
+ display: flex;
175
+ flex-direction: column;
176
+ gap: .75rem;
177
+ overflow-y: auto;
178
+ max-height: calc(100vh - 100px);
179
+ }
180
+
181
+ .results-header {
182
+ background: var(--surface);
183
+ border: 1px solid var(--border);
184
+ border-radius: var(--radius);
185
+ padding: .75rem 1rem;
186
+ display: flex;
187
+ justify-content: space-between;
188
+ align-items: center;
189
+ }
190
+ .results-header h2 { font-size: .95rem; font-weight: 600; }
191
+
192
+ .timing-bar {
193
+ background: var(--surface);
194
+ border: 1px solid var(--border);
195
+ border-radius: var(--radius);
196
+ padding: .65rem 1rem;
197
+ display: none;
198
+ gap: 1rem;
199
+ flex-wrap: wrap;
200
+ }
201
+ .timing-bar.visible { display: flex; }
202
+ .timing-item { display: flex; flex-direction: column; gap: 2px; }
203
+ .timing-label { font-size: .65rem; color: var(--muted); text-transform: uppercase; letter-spacing: .06em; }
204
+ .timing-value { font-size: .85rem; font-weight: 600; font-variant-numeric: tabular-nums; }
205
+
206
+ .no-results {
207
+ background: var(--surface);
208
+ border: 1px solid var(--border);
209
+ border-radius: var(--radius);
210
+ padding: 2rem 1rem;
211
+ text-align: center;
212
+ color: var(--muted);
213
+ font-size: .9rem;
214
+ }
215
+
216
+ .fish-card {
217
+ background: var(--surface);
218
+ border: 1px solid var(--border);
219
+ border-radius: var(--radius);
220
+ overflow: hidden;
221
+ transition: border-color .15s;
222
+ }
223
+ .fish-card:hover { border-color: var(--accent); }
224
+ .fish-card.highlighted { border-color: var(--accent2); }
225
+
226
+ .fish-card-header {
227
+ padding: .6rem .9rem;
228
+ display: flex;
229
+ align-items: center;
230
+ gap: .5rem;
231
+ cursor: pointer;
232
+ background: rgba(255,255,255,.02);
233
+ }
234
+ .fish-number {
235
+ width: 22px;
236
+ height: 22px;
237
+ border-radius: 50%;
238
+ background: var(--accent);
239
+ color: #fff;
240
+ font-size: .7rem;
241
+ font-weight: 700;
242
+ display: flex;
243
+ align-items: center;
244
+ justify-content: center;
245
+ flex-shrink: 0;
246
+ }
247
+ .fish-card-header h3 { font-size: .85rem; font-weight: 600; flex: 1; }
248
+ .conf-badge {
249
+ font-size: .7rem;
250
+ padding: 2px 6px;
251
+ border-radius: 4px;
252
+ background: rgba(59,130,246,.15);
253
+ color: var(--accent);
254
+ }
255
+
256
+ .fish-card-body { padding: .6rem .9rem .9rem; }
257
+
258
+ .prediction-row {
259
+ display: flex;
260
+ align-items: center;
261
+ gap: .5rem;
262
+ padding: .4rem 0;
263
+ border-bottom: 1px solid var(--border);
264
+ }
265
+ .prediction-row:last-child { border-bottom: none; }
266
+
267
+ .pred-rank {
268
+ width: 16px;
269
+ font-size: .7rem;
270
+ color: var(--muted);
271
+ text-align: center;
272
+ flex-shrink: 0;
273
+ }
274
+ .pred-name { flex: 1; font-size: .83rem; display: flex; flex-direction: column; gap: 1px; }
275
+ .pred-taxon { font-size: .7rem; color: var(--muted); font-style: italic; }
276
+ .pred-bar-wrap { width: 72px; background: var(--bg); border-radius: 4px; height: 6px; flex-shrink: 0; }
277
+ .pred-bar { height: 100%; border-radius: 4px; background: var(--accent2); }
278
+ .pred-pct { font-size: .75rem; color: var(--muted); width: 38px; text-align: right; flex-shrink: 0; }
279
+
280
+ .bbox-info {
281
+ margin-top: .5rem;
282
+ font-size: .72rem;
283
+ color: var(--muted);
284
+ }
285
+
286
+ /* Legend */
287
+ .legend {
288
+ background: var(--surface);
289
+ border: 1px solid var(--border);
290
+ border-radius: var(--radius);
291
+ padding: .65rem 1rem;
292
+ display: flex;
293
+ gap: 1rem;
294
+ align-items: center;
295
+ font-size: .78rem;
296
+ color: var(--muted);
297
+ flex-wrap: wrap;
298
+ }
299
+ .legend-item { display: flex; align-items: center; gap: .4rem; }
300
+ .legend-swatch { width: 14px; height: 4px; border-radius: 2px; }
301
+
302
+ @media (max-width: 800px) {
303
+ main { grid-template-columns: 1fr; }
304
+ .results-panel { max-height: none; }
305
+ }
306
+ </style>
307
+ </head>
308
+ <body>
309
+
310
+ <header>
311
+ <svg width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#3b82f6" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
312
+ <path d="M6.5 12c.94-3.46 4.94-6 10.5-6s9.56 2.54 10.5 6c-.94 3.46-4.94 6-10.5 6S7.44 15.46 6.5 12z"/>
313
+ <circle cx="17" cy="12" r="2"/>
314
+ </svg>
315
+ <h1>SW Identifier</h1>
316
+ <span class="badge">Fish ID</span>
317
+ </header>
318
+
319
+ <main>
320
+ <!-- ── Left: canvas ── -->
321
+ <div class="canvas-panel" id="canvas-panel">
322
+ <div id="drop-zone">
323
+ <div class="drop-icon">
324
+ <svg width="32" height="32" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round">
325
+ <polyline points="16 16 12 12 8 16"/>
326
+ <line x1="12" y1="12" x2="12" y2="21"/>
327
+ <path d="M20.39 18.39A5 5 0 0 0 18 9h-1.26A8 8 0 1 0 3 16.3"/>
328
+ </svg>
329
+ </div>
330
+ <p><b>Drop an image</b> here or <b id="browse-link" style="color:var(--accent);cursor:pointer">browse</b></p>
331
+ <p>JPG, PNG, WEBP — any resolution</p>
332
+ </div>
333
+
334
+ <div id="canvas-wrap">
335
+ <canvas id="base-canvas"></canvas>
336
+ <canvas id="overlay-canvas"></canvas>
337
+ <div class="canvas-toolbar">
338
+ <button id="toggle-overlay" title="Toggle detection overlay">Overlay</button>
339
+ <button id="new-image-btn">New image</button>
340
+ </div>
341
+ </div>
342
+
343
+ <div id="spinner">
344
+ <div class="spin-ring"></div>
345
+ <span style="color:var(--muted);font-size:.85rem">Analysing…</span>
346
+ </div>
347
+ </div>
348
+
349
+ <!-- ── Right: results ── -->
350
+ <div class="results-panel">
351
+ <div class="results-header">
352
+ <h2>Detections</h2>
353
+ <span id="detection-count" style="font-size:.8rem;color:var(--muted)">—</span>
354
+ </div>
355
+
356
+ <div class="timing-bar" id="timing-bar">
357
+ <div class="timing-item">
358
+ <span class="timing-label">Detect</span>
359
+ <span class="timing-value" id="t-detect">—</span>
360
+ </div>
361
+ <div class="timing-item">
362
+ <span class="timing-label">Segment</span>
363
+ <span class="timing-value" id="t-segment">—</span>
364
+ </div>
365
+ <div class="timing-item">
366
+ <span class="timing-label">Classify</span>
367
+ <span class="timing-value" id="t-classify">—</span>
368
+ </div>
369
+ <div class="timing-item">
370
+ <span class="timing-label">Total</span>
371
+ <span class="timing-value" id="t-total">—</span>
372
+ </div>
373
+ </div>
374
+
375
+ <div id="results-body">
376
+ <div class="no-results">Drop an image to begin</div>
377
+ </div>
378
+
379
+ <div class="legend" id="legend" style="display:none">
380
+ <div class="legend-item">
381
+ <div class="legend-swatch" style="background:#3b82f6;height:2px;border:1px solid #3b82f6"></div>
382
+ Bounding box
383
+ </div>
384
+ <div class="legend-item">
385
+ <div class="legend-swatch" style="background:rgba(16,185,129,.5)"></div>
386
+ Segmentation
387
+ </div>
388
+ </div>
389
+ </div>
390
+ </main>
391
+
392
+ <input type="file" id="file-input" accept="image/*" />
393
+
394
+ <script>
395
+ (() => {
396
+ const dropZone = document.getElementById('drop-zone');
397
+ const canvasWrap = document.getElementById('canvas-wrap');
398
+ const baseCanvas = document.getElementById('base-canvas');
399
+ const overlayCanvas = document.getElementById('overlay-canvas');
400
+ const spinner = document.getElementById('spinner');
401
+ const fileInput = document.getElementById('file-input');
402
+ const resultsBody = document.getElementById('results-body');
403
+ const timingBar = document.getElementById('timing-bar');
404
+ const countEl = document.getElementById('detection-count');
405
+ const legend = document.getElementById('legend');
406
+ const panel = document.getElementById('canvas-panel');
407
+
408
+ const baseCtx = baseCanvas.getContext('2d');
409
+ const overlayCtx = overlayCanvas.getContext('2d');
410
+
411
+ let currentDetections = [];
412
+ let overlayVisible = true;
413
+ let imgNaturalW = 0, imgNaturalH = 0;
414
+ let displayW = 0, displayH = 0;
415
+ let scaleX = 1, scaleY = 1;
416
+
417
+ // ── Drag & drop ──────────────────────────────────────────────────────────
418
+ dropZone.addEventListener('dragover', e => {
419
+ e.preventDefault();
420
+ dropZone.classList.add('drag-over');
421
+ });
422
+ dropZone.addEventListener('dragleave', () => dropZone.classList.remove('drag-over'));
423
+ dropZone.addEventListener('drop', e => {
424
+ e.preventDefault();
425
+ dropZone.classList.remove('drag-over');
426
+ const file = e.dataTransfer.files[0];
427
+ if (file && file.type.startsWith('image/')) processFile(file);
428
+ });
429
+
430
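+ document.getElementById('browse-link').addEventListener('click', e => { e.stopPropagation(); fileInput.click(); }); // stop bubbling so the drop-zone handler does not open the picker a second time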
431
+ dropZone.addEventListener('click', () => fileInput.click());
432
+ fileInput.addEventListener('change', () => {
433
+ if (fileInput.files[0]) processFile(fileInput.files[0]);
434
+ });
435
+
436
+ document.getElementById('toggle-overlay').addEventListener('click', () => {
437
+ overlayVisible = !overlayVisible;
438
+ overlayCanvas.style.display = overlayVisible ? '' : 'none';
439
+ });
440
+
441
+ document.getElementById('new-image-btn').addEventListener('click', reset);
442
+
443
+ // Paste support
444
+ document.addEventListener('paste', e => {
445
+ const items = e.clipboardData?.items;
446
+ if (!items) return;
447
+ for (const item of items) {
448
+ if (item.type.startsWith('image/')) {
449
+ const pasted = item.getAsFile(); // may be null if the item is not a file
+ if (pasted) processFile(pasted);
450
+ break;
451
+ }
452
+ }
453
+ });
454
+
455
+ // ── Core flow ─────────────────────────────────────────────────────────────
456
+ async function processFile(file) {
457
+ reset();
458
+ const objectUrl = URL.createObjectURL(file);
459
+ const img = new Image();
460
+ img.onload = async () => {
461
+ renderImage(img);
462
+ URL.revokeObjectURL(objectUrl);
463
+ await runPipeline(file);
464
+ };
465
+ img.src = objectUrl;
466
+ }
467
+
468
+ function renderImage(img) {
469
+ imgNaturalW = img.naturalWidth;
470
+ imgNaturalH = img.naturalHeight;
471
+
472
+ // Fit image into panel keeping aspect ratio
473
+ const panelW = panel.clientWidth - 2; // minus borders
474
+ const panelH = panel.clientHeight - 2;
475
+ const ratio = Math.min(panelW / imgNaturalW, panelH / imgNaturalH, 1);
476
+ displayW = Math.round(imgNaturalW * ratio);
477
+ displayH = Math.round(imgNaturalH * ratio);
478
+ scaleX = displayW / imgNaturalW;
479
+ scaleY = displayH / imgNaturalH;
480
+
481
+ for (const c of [baseCanvas, overlayCanvas]) {
482
+ c.width = displayW;
483
+ c.height = displayH;
484
+ c.style.width = displayW + 'px';
485
+ c.style.height = displayH + 'px';
486
+ }
487
+
488
+ baseCtx.drawImage(img, 0, 0, displayW, displayH);
489
+ dropZone.classList.add('hidden');
490
+ canvasWrap.classList.add('visible');
491
+ }
492
+
493
+ async function runPipeline(file) {
494
+ spinner.classList.add('active');
495
+
496
+ const form = new FormData();
497
+ form.append('file', file);
498
+
499
+ try {
500
+ const resp = await fetch('/predict', { method: 'POST', body: form });
501
+ if (!resp.ok) {
502
+ const err = await resp.json().catch(() => ({ detail: resp.statusText }));
503
+ throw new Error(err.detail || resp.statusText);
504
+ }
505
+ const data = await resp.json();
506
+ currentDetections = data.detections;
507
+ renderOverlay(data.detections);
508
+ renderResults(data);
509
+ } catch (err) {
510
+ resultsBody.innerHTML = `<div class="no-results" style="color:#f87171">Error: ${esc(err.message)}</div>`;
511
+ } finally {
512
+ spinner.classList.remove('active');
513
+ }
514
+ }
515
+
516
+ // ── Canvas overlay ────────────────────────────────────────────────────────
517
+ function renderOverlay(detections) {
518
+ overlayCtx.clearRect(0, 0, overlayCanvas.width, overlayCanvas.height);
519
+
520
+ detections.forEach((det, idx) => {
521
+ const { bbox, polygon } = det;
522
+ const color = hue(idx);
523
+
524
+ // Segmentation polygon fill
525
+ if (polygon && polygon.length > 2) {
526
+ overlayCtx.beginPath();
527
+ overlayCtx.moveTo(polygon[0][0] * scaleX, polygon[0][1] * scaleY);
528
+ for (let i = 1; i < polygon.length; i++) {
529
+ overlayCtx.lineTo(polygon[i][0] * scaleX, polygon[i][1] * scaleY);
530
+ }
531
+ overlayCtx.closePath();
532
+ overlayCtx.fillStyle = color.fill;
533
+ overlayCtx.fill();
534
+ overlayCtx.strokeStyle = color.stroke;
535
+ overlayCtx.lineWidth = 1.5;
536
+ overlayCtx.stroke();
537
+ }
538
+
539
+ // Bounding box
540
+ const bx1 = bbox.x1 * scaleX;
541
+ const by1 = bbox.y1 * scaleY;
542
+ const bw = (bbox.x2 - bbox.x1) * scaleX;
543
+ const bh = (bbox.y2 - bbox.y1) * scaleY;
544
+
545
+ overlayCtx.strokeStyle = '#3b82f6';
546
+ overlayCtx.lineWidth = 2;
547
+ overlayCtx.strokeRect(bx1, by1, bw, bh);
548
+
549
+ // Label chip
550
+ const topName = det.predictions[0]?.name || '?';
551
+ const topConf = det.predictions[0]?.accuracy ?? 0;
552
+ const label = `#${idx + 1} ${topName} ${(topConf * 100).toFixed(0)}%`;
553
+ const fontSize = Math.max(10, Math.round(11 * Math.min(scaleX, scaleY)));
554
+ overlayCtx.font = `600 ${fontSize}px -apple-system, sans-serif`;
555
+ const tw = overlayCtx.measureText(label).width;
556
+ const pad = 4;
557
+ const chipH = fontSize + pad * 2;
558
+ const chipY = Math.max(0, by1 - chipH - 2);
559
+
560
+ overlayCtx.fillStyle = '#3b82f6';
561
+ roundRect(overlayCtx, bx1, chipY, tw + pad * 2, chipH, 4);
562
+ overlayCtx.fill();
563
+
564
+ overlayCtx.fillStyle = '#fff';
565
+ overlayCtx.fillText(label, bx1 + pad, chipY + chipH - pad - 1);
566
+ });
567
+ }
568
+
569
+ function roundRect(ctx, x, y, w, h, r) {
570
+ ctx.beginPath();
571
+ ctx.moveTo(x + r, y);
572
+ ctx.lineTo(x + w - r, y);
573
+ ctx.quadraticCurveTo(x + w, y, x + w, y + r);
574
+ ctx.lineTo(x + w, y + h - r);
575
+ ctx.quadraticCurveTo(x + w, y + h, x + w - r, y + h);
576
+ ctx.lineTo(x + r, y + h);
577
+ ctx.quadraticCurveTo(x, y + h, x, y + h - r);
578
+ ctx.lineTo(x, y + r);
579
+ ctx.quadraticCurveTo(x, y, x + r, y);
580
+ ctx.closePath();
581
+ }
582
+
583
+ const PALETTE = [
584
+ { fill: 'rgba(16,185,129,.25)', stroke: '#10b981' },
585
+ { fill: 'rgba(245,158,11,.25)', stroke: '#f59e0b' },
586
+ { fill: 'rgba(239,68,68,.25)', stroke: '#ef4444' },
587
+ { fill: 'rgba(168,85,247,.25)', stroke: '#a855f7' },
588
+ { fill: 'rgba(236,72,153,.25)', stroke: '#ec4899' },
589
+ ];
590
+ function hue(i) { return PALETTE[i % PALETTE.length]; }
591
+
592
+ // ── Results panel ─────────────────────────────────────────────────────────
593
+ function renderResults(data) {
594
+ const { detections, timing } = data;
595
+
596
+ // Timing bar
597
+ document.getElementById('t-detect').textContent = timing.detect_ms + ' ms';
598
+ document.getElementById('t-segment').textContent = timing.segment_ms + ' ms';
599
+ document.getElementById('t-classify').textContent = timing.classify_ms + ' ms';
600
+ document.getElementById('t-total').textContent = timing.total_ms + ' ms';
601
+ timingBar.classList.add('visible');
602
+
603
+ countEl.textContent = detections.length
604
+ ? `${detections.length} fish found`
605
+ : 'No fish detected';
606
+
607
+ legend.style.display = detections.length ? '' : 'none';
608
+
609
+ if (!detections.length) {
610
+ resultsBody.innerHTML = '<div class="no-results">No fish detected in this image</div>';
611
+ return;
612
+ }
613
+
614
+ resultsBody.innerHTML = '';
615
+ detections.forEach((det, idx) => {
616
+ const card = document.createElement('div');
617
+ card.className = 'fish-card';
618
+ card.dataset.idx = idx;
619
+
620
+ const topName = det.predictions[0]?.name || 'Unknown';
621
+ const detConf = (det.bbox.confidence * 100).toFixed(0);
622
+
623
+ card.innerHTML = `
624
+ <div class="fish-card-header">
625
+ <div class="fish-number">${idx + 1}</div>
626
+ <h3>${esc(topName)}</h3>
627
+ <span class="conf-badge">det ${detConf}%</span>
628
+ </div>
629
+ <div class="fish-card-body">
630
+ ${det.predictions.length
631
+ ? det.predictions.map((p, r) => predRow(p, r)).join('')
632
+ : '<span style="color:var(--muted);font-size:.8rem">No classification</span>'
633
+ }
634
+ <div class="bbox-info">
635
+ Box: ${det.bbox.x1},${det.bbox.y1} → ${det.bbox.x2},${det.bbox.y2}
636
+ &nbsp;·&nbsp;
637
+ ${det.bbox.x2 - det.bbox.x1}×${det.bbox.y2 - det.bbox.y1} px
638
+ ${det.polygon && det.polygon.length > 2 ? '&nbsp;·&nbsp;seg ✓' : ''}
639
+ </div>
640
+ </div>`;
641
+
642
+ card.querySelector('.fish-card-header').addEventListener('mouseenter', () => {
643
+ card.classList.add('highlighted');
644
+ highlightDetection(idx);
645
+ });
646
+ card.querySelector('.fish-card-header').addEventListener('mouseleave', () => {
647
+ card.classList.remove('highlighted');
648
+ renderOverlay(currentDetections);
649
+ });
650
+
651
+ resultsBody.appendChild(card);
652
+ });
653
+ }
654
+
655
+ function predRow(p, rank) {
656
+ const pct = (p.accuracy * 100).toFixed(1);
657
+ const bar = Math.round(p.accuracy * 100);
658
+ return `<div class="prediction-row">
659
+ <span class="pred-rank">${rank + 1}</span>
660
+ <span class="pred-name">${esc(p.name)}<span class="pred-taxon">${esc(p.taxon)}</span></span>
661
+ <div class="pred-bar-wrap"><div class="pred-bar" style="width:${bar}%"></div></div>
662
+ <span class="pred-pct">${pct}%</span>
663
+ </div>`;
664
+ }
665
+
666
+ function highlightDetection(idx) {
667
+ renderOverlay(currentDetections);
668
+ // draw a brighter ring around the selected detection
669
+ const det = currentDetections[idx];
670
+ if (!det) return;
671
+ const { bbox } = det;
672
+ overlayCtx.strokeStyle = '#facc15';
673
+ overlayCtx.lineWidth = 3;
674
+ overlayCtx.strokeRect(
675
+ bbox.x1 * scaleX - 2,
676
+ bbox.y1 * scaleY - 2,
677
+ (bbox.x2 - bbox.x1) * scaleX + 4,
678
+ (bbox.y2 - bbox.y1) * scaleY + 4,
679
+ );
680
+ }
681
+
682
+ function reset() {
683
+ currentDetections = [];
684
+ overlayVisible = true;
685
+ overlayCanvas.style.display = '';
686
+ overlayCtx.clearRect(0, 0, overlayCanvas.width, overlayCanvas.height);
687
+ baseCtx.clearRect(0, 0, baseCanvas.width, baseCanvas.height);
688
+ canvasWrap.classList.remove('visible');
689
+ dropZone.classList.remove('hidden');
690
+ resultsBody.innerHTML = '<div class="no-results">Drop an image to begin</div>';
691
+ timingBar.classList.remove('visible');
692
+ countEl.textContent = '—';
693
+ legend.style.display = 'none';
694
+ fileInput.value = '';
695
+ }
696
+
697
+ function esc(s) {
698
+ return String(s)
699
+ .replace(/&/g, '&amp;')
700
+ .replace(/</g, '&lt;')
701
+ .replace(/>/g, '&gt;')
702
+ .replace(/"/g, '&quot;');
703
+ }
704
+ })();
705
+ </script>
706
+ </body>
707
+ </html>
taxons.csv ADDED
@@ -0,0 +1,776 @@
1
+ common_name,taxon
2
+ Flat needlefish,Ablennes hians
3
+ Bream,Abramis brama
4
+ Sergeant major,Abudefduf saxatilis
5
+ Blackspot sergeant,Abudefduf sordidus
6
+ Mud sunfish,Acantharchus pomotis
7
+ white-cheeked blenny,Acanthemblemaria johnsoni
8
+ Spinyhead blenny,Acanthemblemaria spinosa
9
+ Wahoo,Acanthocybium solandri
10
+ Australian sea bream,Acanthopagrus australis
11
+ Bream,Acanthopagrus butcheri
12
+ Yellowfin seabream,Acanthopagrus latus
13
+ Scrawled cowfish,Acanthostracion quadricornis
14
+ Doctorfish,Acanthurus chirurgus
15
+ blue tang,Acanthurus coeruleus
16
+ Palani,Acanthurus dussumieri
17
+ Eastern Blue Groper,Achoerodus viridis
18
+ lake sturgeon,Acipenser fulvescens
19
+ European sturgeon,Acipenser sturio
20
+ white sturgeon,Acipenser transmontanus
21
+ Lesser guitarfish,Acroteriobatus annulatus
22
+ Spotted eagle ray,Aetobatus narinari
23
+ Bonefish,Albula vulpes
24
+ Bleak,Alburnus alburnus
25
+ Yellow-eye mullet,Aldrichetta forsteri
26
+ African pompano,Alectis ciliaris
27
+ Thresher shark,Alopias vulpinus
28
+ Skipjack herring,Alosa chrysochloris
29
+ Hickory shad,Alosa mediocris
30
+ Alewife,Alosa pseudoharengus
31
+ American shad,Alosa sapidissima
32
+ leatherjacket,Aluterus monoceros
33
+ Orange filefish,Aluterus schoepfii
34
+ Scrawled filefish,Aluterus scriptus
35
+ Shadow bass,Ambloplites ariommus
36
+ Rock bass,Ambloplites rupestris
37
+ White bullhead,Ameiurus catus
38
+ Black bullhead,Ameiurus melas
39
+ Yellow bullhead,Ameiurus natalis
40
+ Brown bullhead,Ameiurus nebulosus
41
+ Bowfin,Amia calva
42
+ Barred grunter,Amniataba percoides
43
+ Midas cichlid,Amphilophus citrinellus
44
+ Clown anemonefish,Amphiprion ocellaris
45
+ Orange clownfish,Amphiprion percula
46
+ Barred surfperch,Amphistichus argenteus
47
+ Redtail surfperch,Amphistichus rhodoterus
48
+ Bay anchovy,Anchoa mitchilli
49
+ European eel,Anguilla anguilla
50
+ Speckled longfin eel,Anguilla reinhardtii
51
+ American eel,Anguilla rostrata
52
+ sargo,Anisotremus davidsonii
53
+ Black margate,Anisotremus surinamensis
54
+ Porkfish,Anisotremus virginicus
55
+ freshwater drum,Aplodinotus grunniens
56
+ Green jobfish,Aprion virescens
57
+ Stripe eel,Aprognathodon platyventris
58
+ sheepshead,Archosargus probatocephalus
59
+ Sea bream,Archosargus rhomboidalis
60
+ Coron meagre,Argyrosomus coronus
61
+ Japanese meagre,Argyrosomus japonicus
62
+ Hardhead catfish,Ariopsis felis
63
+ White-spotted puffer,Arothron hispidus
64
+ Star puffer,Arothron stellatus
65
+ Australian salmon,Arripis trutta
66
+ Bay trout,Arripis truttacea
67
+ Oscar,Astronotus ocellatus
68
+ topsmelt,Atherinops affinis
69
+ jacksmelt,Atherinopsis californiensis
70
+ White seabass,Atractoscion nobilis
71
+ Alligator gar,Atractosteus spatula
72
+ Yellowtail scad,Atule mate
73
+ Trumpetfish,Aulostomus maculatus
74
+ Gafftopsail catfish,Bagre marinus
75
+ Silver perch,Bairdiella chrysoura
76
+ Orange-lined triggerfish,Balistapus undulatus
77
+ Gray triggerfish,Balistes capriscus
78
+ Queen triggerfish,Balistes vetula
79
+ titan triggerfish,Balistoides viridescens
80
+ Java barb,Barbonymus gonionotus
81
+ Barbel,Barbus barbus
82
+ Garfish,Belone belone
83
+ Silver perch,Bidyanus bidyanus
84
+ Silver bream,Blicca bjoerkna
85
+ Spanish hogfish,Bodianus rufus
86
+ Bogue,Boops boops
87
+ Blind shark,Brachaelurus waddi
88
+ Gulf menhaden,Brevoortia patronus
89
+ Atlantic menhaden,Brevoortia tyrannus
90
+ Jolthead porgy,Calamus bajonado
91
+ Central stoneroller,Campostoma anomalum
92
+ Orangespotted filefish,Cantherhines pullus
93
+ Sharpnose puffer,Canthigaster rostrata
94
+ Yellow jack,Caranx bartholomaei
95
+ Pacific crevalle jack,Caranx caninus
96
+ Blue runner,Caranx crysos
97
+ Crevalle jack,Caranx hippos
98
+ Giant trevally,Caranx ignobilis
99
+ horse-eye jack,Caranx latus
100
+ black jack,Caranx lugubris
101
+ Bluefin trevally,Caranx melampygus
102
+ Brassy trevally,Caranx papuensis
103
+ Bar jack,Caranx ruber
104
+ Bigeye trevally,Caranx sexfasciatus
105
+ Goldfish,Carassius auratus
106
+ Crucian carp,Carassius carassius
107
+ Prussian carp,Carassius gibelio
108
+ Blacknose shark,Carcharhinus acronotus
109
+ Narrowtooth shark,Carcharhinus brachyurus
110
+ Spinner shark,Carcharhinus brevipinna
111
+ Silky shark,Carcharhinus falciformis
112
+ Finetooth shark,Carcharhinus isodon
113
+ Bull shark,Carcharhinus leucas
114
+ Blacktip shark,Carcharhinus limbatus
115
+ Blacktip reef shark,Carcharhinus melanopterus
116
+ Dusky shark,Carcharhinus obscurus
117
+ Reef shark,Carcharhinus perezii
118
+ sandbar shark,Carcharhinus plumbeus
119
+ Spottail shark,Carcharhinus sorrah
120
+ Sand tiger,Carcharias taurus
121
+ White shark,Carcharodon carcharias
122
+ river carpsucker,Carpiodes carpio
123
+ Quillback,Carpiodes cyprinus
124
+ Highfin carpsucker,Carpiodes velifer
125
+ Longnose sucker,Catostomus catostomus
126
+ White sucker,Catostomus commersonii
127
+ Goldface tilefish,Caulolatilus chrysops
128
+ Monkeyface prickleback,Cebidichthys violaceus
129
+ Flier,Centrarchus macropterus
130
+ Fat snook,Centropomus parallelus
131
+ Common snook,Centropomus undecimalis
132
+ Black sea bass,Centropristis striata
133
+ Bluespotted grouper,Cephalopholis argus
134
+ Graysby,Cephalopholis cruentata
135
+ Coney,Cephalopholis fulva
136
+ coral grouper,Cephalopholis miniata
137
+ African hind,Cephalopholis taeniops
138
+ Yellowface pikeblenny,Chaenopsis limbaughi
139
+ Atlantic spadefish,Chaetodipterus faber
140
+ Foureye butterflyfish,Chaetodon capistratus
141
+ Saddle butterflyfish,Chaetodon ephippium
142
+ snakehead,Channa argus
143
+ Goldline snakehead,Channa aurolineata
144
+ emperor snakehead,Channa marulioides
145
+ Giant snakehead,Channa marulius
146
+ Giant snakehead,Channa micropeltes
147
+ Orangespotted Snakehead,Channa pseudomarulius
148
+ Chevron snakehead,Channa striata
149
+ Broadbanded moray,Channomuraena vittata
150
+ Milkfish,Chanos chanos
151
+ Tripletail wrasse,Cheilinus trilobatus
152
+ Humphead wrasse,Cheilinus undulatus
153
+ cigar wrasse,Cheilio inermis
154
+ Bluefin gurnard,Chelidonichthys kumu
155
+ Tub gurnard,Chelidonichthys lucerna
156
+ Striped burrfish,Chilomycterus schoepfii
157
+ Clown featherback,Chitala ornata
158
+ Atlantic bumper,Chloroscombrus chrysurus
159
+ Daisy parrotfish,Chlorurus sordidus
160
+ Damselfish,Chromis chromis
161
+ Green chromis,Chromis viridis
162
+ Roman seabream,Chrysoblephus laticeps
163
+ Tucanare peacock bass,Cichla monoculus
164
+ Peacock cichlid,Cichla ocellaris
165
+ Speckled pavon,Cichla temensis
166
+ Stocky hawkfish,Cirrhitus pinnulatus
167
+ Walking catfish,Clarias batrachus
168
+ Sharptooth catfish,Clarias gariepinus
169
+ Blunt-toothed African catfish,Clarias ngamensis
170
+ creole wrasse,Clepticus parrae
171
+ Woolly sculpin,Clinocottus analis
172
+ Atlantic herring,Clupea harengus
173
+ Pacific herring,Clupea pallasii
174
+ Cobbler,Cnidoglanis macrocephalus
175
+ European conger,Conger conger
176
+ Redbreast tilapia,Coptodon rendalli
177
+ Redbelly tilapia,Coptodon zillii
178
+ Cisco,Coregonus artedi
179
+ Lake whitefish,Coregonus clupeaformis
180
+ Rainbow wrasse,Coris julis
181
+ Common dolphinfish,Coryphaena hippurus
182
+ colon goby,Coryphopterus dicrus
183
+ Mottled sculpin,Cottus bairdii
184
+ Goldsinny-wrasse,Ctenolabrus rupestris
185
+ Grass carp,Ctenopharyngodon idella
186
+ Shiner perch,Cymatogaster aggregata
187
+ Sand seatrout,Cynoscion arenarius
188
+ Spotted seatrout,Cynoscion nebulosus
189
+ Weakfish,Cynoscion regalis
190
+ Red shiner,Cyprinella lutrensis
191
+ Spotfin shiner,Cyprinella spiloptera
192
+ Blacktail shiner,Cyprinella venusta
193
+ Common carp,Cyprinus carpio
194
+ Common carp,Cyprinus carpio carpio
195
+ Koi,Cyprinus rubrofuscus
196
+ Flying gurnard,Dactylopterus volitans
197
+ Common stingray,Dasyatis pastinaca
198
+ Mackerel scad,Decapterus macarellus
199
+ Dentex,Dentex dentex
200
+ Painted sweetlips,Diagramma pictum
201
+ European bass,Dicentrarchus labrax
202
+ Spotted seabass,Dicentrarchus punctatus
203
+ Galjoen,Dichistius capensis
204
+ Balloonfish,Diodon holocanthus
205
+ Porcupinefish,Diodon hystrix
206
+ Sand perch,Diplectrum formosum
207
+ annular sea bream,Diplodus annularis
208
+ Blacktail,Diplodus capensis
209
+ Zebra seabream,Diplodus cervinus
210
+ Spottail pinfish,Diplodus holbrookii
211
+ Puntazzo,Diplodus puntazzo
212
+ white seabream,Diplodus sargus
213
+ Twoband bream,Diplodus vulgaris
214
+ gizzard shad,Dorosoma cepedianum
215
+ Threadfin shad,Dorosoma petenense
216
+ Sharksucker,Echeneis naucrates
217
+ chain moray,Echidna catenata
218
+ Spotted spoon-nose eel,Echiophis intertinctus
219
+ Rainbow runner,Elagatis bipinnulata
220
+ blind tassel-fish,Eleutheronema tetradactylum
221
+ Squaretail mullet,Ellochelon vaigiensis
222
+ Ladyfish,Elops saurus
223
+ Black perch,Embiotoca jacksoni
224
+ Striped seaperch,Embiotoca lateralis
225
+ Blackbanded sunfish,Enneacanthus chaetodon
226
+ Bluespotted sunfish,Enneacanthus gloriosus
227
+ Banded sunfish,Enneacanthus obesus
228
+ Globefish,Ephippion guttifer
229
+ Rock hind,Epinephelus adscensionis
230
+ Spotted cabrilla,Epinephelus analogus
231
+ Orange-spotted grouper,Epinephelus coioides
232
+ Blacktip grouper,Epinephelus fasciatus
233
+ Brown-marbled grouper,Epinephelus fuscoguttatus
234
+ Red hind,Epinephelus guttatus
235
+ Goliath grouper,Epinephelus itajara
236
+ flag cabrilla,Epinephelus labriformis
237
+ Giant grouper,Epinephelus lanceolatus
238
+ Malabar grouper,Epinephelus malabaricus
239
+ Dusky grouper,Epinephelus marginatus
240
+ Honeycomb grouper,Epinephelus merra
241
+ Red grouper,Epinephelus morio
242
+ Nassau grouper,Epinephelus striatus
243
+ greasy grouper,Epinephelus tauvina
244
+ Potato grouper,Epinephelus tukula
245
+ Redfin pickerel,Esox americanus
246
+ Redfin pickerel,Esox americanus vermiculatus
247
+ Northern pike,Esox lucius
248
+ Muskellunge,Esox masquinongy
249
+ Tiger Musky,Esox masquinongy X Esox lucius
250
+ Muskellunge,Esox masquinongy punctulatus
251
+ Chain pickerel,Esox niger
252
+ Queen snapper,Etelis oculatus
253
+ Rainbow darter,Etheostoma caeruleum
254
+ Fringed flounder,Etropus crossotus
255
+ Silver jenny,Eucinostomus gula
256
+ Kawakawa,Euthynnus affinis
257
+ little tunny,Euthynnus alletteratus
258
+ Golden topminnow,Fundulus chrysotus
259
+ Banded killifish,Fundulus diaphanus
260
+ Blackspotted topminnow,Fundulus olivaceus
261
+ Atlantic cod,Gadus morhua
262
+ Tiger shark,Galeocerdo cuvier
263
+ Tope shark,Galeorhinus galeus
264
+ Western mosquitofish,Gambusia affinis
265
+ Threespine stickleback,Gasterosteus aculeatus
266
+ White croaker,Genyonemus lineatus
267
+ Yellowfin mojarra,Gerres cinereus
268
+ nurse shark,Ginglymostoma cirratum
269
+ Black bream,Girella elevata
270
+ Blackfish,Girella tricuspidata
271
+ Dhufish,Glaucosoma hebraicum
272
+ Golden trevally,Gnathanodon speciosus
273
+ California clingfish,Gobiesox rhessodon
274
+ Gudgeon,Gobio gobio
275
+ Quillfin blenny,Gobioclinus filamentosus
276
+ Rock goby,Gobius paganellus
277
+ Ruffe,Gymnocephalus cernua
278
+ Dogtooth tuna,Gymnosarda unicolor
279
+ Green moray,Gymnothorax funebris
280
+ Goldentail moray,Gymnothorax miliaris
281
+ Spotted moray,Gymnothorax moringa
282
+ Margate,Haemulon album
283
+ Tomtate,Haemulon aurolineatum
284
+ smallmouth grunt,Haemulon chrysargyreum
285
+ French grunt,Haemulon flavolineatum
286
+ Sailors choice,Haemulon parra
287
+ White grunt,Haemulon plumierii
288
+ Bluestriped grunt,Haemulon sciurus
289
+ Slippery dick,Halichoeres bivittatus
290
+ Yellowhead wrasse,Halichoeres garnoti
291
+ Clown wrasse,Halichoeres maculipinna
292
+ Puddingwife,Halichoeres radiatus
293
+ Rock wrasse,Halichoeres semicinctus
294
+ Carp,Hampala macrolepidota
295
+ Blackeye thicklip,Hemigymnus melapterus
296
+ Black bream,Hephaestus fuliginosus
297
+ Rio Grande cichlid,Herichthys cyanoguttatus
298
+ Garden eel,Heteroconger longissimus
299
+ Horn shark,Heterodontus francisci
300
+ Port Jackson shark,Heterodontus portusjacksoni
301
+ Kelp greenling,Hexagrammos decagrammus
302
+ Whitespotted greenling,Hexagrammos stelleri
303
+ Goldeye,Hiodon alosoides
304
+ mooneye,Hiodon tergisus
305
+ Flathead sole,Hippoglossoides elassodon
306
+ Pacific halibut,Hippoglossus stenolepis
307
+ angelfish,Holacanthus bermudensis
308
+ Queen angelfish,Holacanthus ciliaris
309
+ Squirrelfish,Holocentrus adscensionis
310
+ Trahira,Hoplias malabaricus
311
+ Greenbar snapper,Hoplopagrus guentherii
312
+ Brown hoplo,Hoplosternum littorale
313
+ Brassy minnow,Hybognathus hankinsoni
314
+ Tigerfish,Hydrocynus vittatus
315
+ Southern stingray,Hypanus americanus
316
+ Atlantic stingray,Hypanus sabinus
317
+ Northern hog sucker,Hypentelium nigricans
318
+ Surf smelt,Hypomesus pretiosus
319
+ Silver carp,Hypophthalmichthys molitrix
320
+ Bighead carp,Hypophthalmichthys nobilis
321
+ shy hamlet,Hypoplectrus guttavarius
322
+ Snowy grouper,Hyporthodus niveatus
323
+ Garibaldi,Hypsypops rubicundus
324
+ Blue catfish,Ictalurus furcatus
325
+ Channel catfish,Ictalurus punctatus
326
+ Smallmouth buffalo,Ictiobus bubalus
327
+ Bigmouth buffalo,Ictiobus cyprinellus
328
+ Black buffalo,Ictiobus niger
329
+ Black marlin,Istiompax indica
330
+ Atlantic sailfish,Istiophorus albicans
331
+ Sailfish,Istiophorus platypterus
332
+ Shortfin mako,Isurus oxyrinchus
333
+ Striped marlin,Kajikia audax
334
+ Skipjack tuna,Katsuwonus pelamis
335
+ Rock flagtail,Kuhlia rupestris
336
+ Bermuda sea chub,Kyphosus sectatrix
337
+ Blue-bronze chub,Kyphosus vaigiensis
338
+ Orange river mudfish,Labeo capensis
339
+ Smallmouth yellowfish,Labeobarbus aeneus
340
+ Largescale yellowfish,Labeobarbus marequensis
341
+ Bluestreak cleaner wrasse,Labroides dimidiatus
342
+ Ballan wrasse,Labrus bergylta
343
+ Cuckoo wrasse,Labrus mixtus
344
+ Hogfish,Lachnolaimus maximus
345
+ Smooth puffer,Lagocephalus laevigatus
346
+ Oceanic puffer,Lagocephalus lagocephalus
347
+ Pinfish,Lagodon rhomboides
348
+ Opah,Lampris guttatus
349
+ Barramundi perch,Lates calcarifer
350
+ Spangled perch,Leiopotherapon unicolor
351
+ Spot,Leiostomus xanthurus
352
+ Spotted gar,Lepisosteus oculatus
353
+ Longnose gar,Lepisosteus osseus
354
+ Shortnose gar,Lepisosteus platostomus
355
+ Florida gar,Lepisosteus platyrhincus
356
+ Redbreast sunfish,Lepomis auritus
357
+ Green sunfish,Lepomis cyanellus
358
+ Pumpkinseed,Lepomis gibbosus
359
+ Warmouth,Lepomis gulosus
360
+ Orangespotted sunfish,Lepomis humilis
361
+ Bluegill,Lepomis macrochirus
362
+ Dollar sunfish,Lepomis marginatus
363
+ Longear sunfish,Lepomis megalotis
364
+ Redear sunfish,Lepomis microlophus
365
+ Redspotted sunfish,Lepomis miniatus
366
+ Northern sunfish,Lepomis peltastes
367
+ Spotted sunfish,Lepomis punctatus
368
+ Bantam sunfish,Lepomis symmetricus
369
+ Pacific staghorn sculpin,Leptocottus armatus
370
+ Pink ear emperor,Lethrinus lentjan
371
+ Spangled emperor,Lethrinus nebulosus
372
+ Orange-striped emperor,Lethrinus obsoletus
373
+ Asp,Leuciscus aspius
374
+ Ide,Leuciscus idus
375
+ Eurasian dace,Leuciscus leuciscus
376
+ Leerfish,Lichia amia
377
+ Dab,Limanda limanda
378
+ Shanny,Lipophrys pholis
379
+ White steenbras,Lithognathus lithognathus
380
+ Sand steenbras,Lithognathus mormyrus
381
+ Tripletail,Lobotes surinamensis
382
+ Cape Hope squid,Loligo vulgaris
383
+ Burbot,Lota lota
384
+ Mutton snapper,Lutjanus analis
385
+ Schoolmaster,Lutjanus apodus
386
+ Mangrove red snapper,Lutjanus argentimaculatus
387
+ Amarillo snapper,Lutjanus argentiventris
388
+ Twospot snapper,Lutjanus bohar
389
+ Blackfin snapper,Lutjanus buccanella
390
+ Red snapper,Lutjanus campechanus
391
+ Spanish flag,Lutjanus carponotatus
392
+ Cubera snapper,Lutjanus cyanopterus
393
+ Checkered snapper,Lutjanus decussatus
394
+ Blackspot snapper,Lutjanus ehrenbergii
395
+ Blackspot snapper,Lutjanus fulviflamma
396
+ Blacktail snapper,Lutjanus fulvus
397
+ Humpback snapper,Lutjanus gibbus
398
+ Gray snapper,Lutjanus griseus
399
+ Dog snapper,Lutjanus jocu
400
+ John's snapper,Lutjanus johnii
401
+ Bluestriped snapper,Lutjanus kasmira
402
+ Mahogany snapper,Lutjanus mahogoni
403
+ Onespot snapper,Lutjanus monostigma
404
+ Dog snapper,Lutjanus novemfasciatus
405
+ Caribbean red snapper,Lutjanus purpureus
406
+ Blubberlip snapper,Lutjanus rivulatus
407
+ Russell's snapper,Lutjanus russellii
408
+ Emperor snapper,Lutjanus sebae
409
+ Lane snapper,Lutjanus synagris
410
+ Silk snapper,Lutjanus vivanus
411
+ Striped shiner,Luxilus chrysocephalus
412
+ Common shiner,Luxilus cornutus
413
+ Trout cod,Maccullochella macquariensis
414
+ Murray cod,Maccullochella peelii
415
+ Golden perch,Macquaria ambigua
416
+ Blue marlin,Makaira nigricans
417
+ Mayan cichlid,Mayaheros urophthalmus
418
+ Tarpon,Megalops atlanticus
419
+ Oxeye,Megalops cyprinoides
420
+ Haddock,Melanogrammus aeglefinus
421
+ Black durgon,Melichthys niger
422
+ Atlantic silverside,Menidia menidia
423
+ Southern kingcroaker,Menticirrhus americanus
424
+ Gulf kingcroaker,Menticirrhus littoralis
425
+ Northern kingfish,Menticirrhus saxatilis
426
+ California corbina,Menticirrhus undulatus
427
+ Whiting,Merlangius merlangus
428
+ Six-spined leatherjacket,Meuschenia freycineti
429
+ Atlantic croaker,Micropogonias undulatus
430
+ Shoal bass,Micropterus cataractae
431
+ Redeye bass,Micropterus coosae
432
+ Smallmouth bass,Micropterus dolomieu
433
+ Florida bass,Micropterus floridanus
434
+ Alabama bass,Micropterus henshalli
435
+ Largemouth bass,Micropterus nigricans
436
+ Suwannee bass,Micropterus notius
437
+ Spotted bass,Micropterus punctulatus
438
+ Guadalupe bass,Micropterus treculii
439
+ Spotted sucker,Minytrema melanops
440
+ Ocean sunfish,Mola mola
441
+ Centreboard leatherjacket,Monacanthus chinensis
442
+ Diamond moonfish,Monodactylus argenteus
443
+ White perch,Morone americana
444
+ White bass,Morone chrysops
445
+ Wiper,Morone chrysops X Morone saxatilis
446
+ Yellow bass,Morone mississippiensis
447
+ Striped bass,Morone saxatilis
448
+ Silver redhorse,Moxostoma anisurum
449
+ River redhorse,Moxostoma carinatum
450
+ Black redhorse,Moxostoma duquesnei
451
+ Golden redhorse,Moxostoma erythrurum
452
+ Shorthead redhorse,Moxostoma macrolepidotum
453
+ Greater redhorse,Moxostoma valenciennesi
454
+ Flathead grey mullet,Mugil cephalus
455
+ White mullet,Mugil curema
456
+ Yellowstripe goatfish,Mulloidichthys flavolineatus
457
+ Yellow goatfish,Mulloidichthys martinicus
458
+ Red mullet,Mullus surmuletus
459
+ Gummy shark,Mustelus antarcticus
460
+ Starry smooth-hound,Mustelus asterias
461
+ Smooth dogfish,Mustelus canis
462
+ Spotted estuary smooth-hound,Mustelus lenticulatus
463
+ Smooth-hound,Mustelus mustelus
464
+ Black grouper,Mycteroperca bonaci
465
+ Finescale rockfish,Mycteroperca microlepis
466
+ Scamp,Mycteroperca phenax
467
+ Comb grouper,Mycteroperca rubra
468
+ Tiger grouper,Mycteroperca tigris
469
+ Yellowfin grouper,Mycteroperca venenosa
470
+ Common eagle ray,Myliobatis aquila
471
+ Bat ray,Myliobatis californica
472
+ Sharptail eel,Myrichthys breviceps
473
+ Goldspotted eel,Myrichthys ocellatus
474
+ Blotcheye soldierfish,Myripristis berndti
475
+ Bluespine unicornfish,Naso unicornis
476
+ Tawny nurse shark,Nebrius ferrugineus
477
+ Lemon shark,Negaprion brevirostris
478
+ Roosterfish,Nematistius pectoralis
479
+ Round goby,Neogobius melanostomus
480
+ Hornyhead chub,Nocomis biguttatus
481
+ Bluehead chub,Nocomis leptocephalus
482
+ River chub,Nocomis micropogon
483
+ Golden shiner,Notemigonus crysoleucas
484
+ Spotty,Notolabrus celidotus
485
+ Banded parrotfish,Notolabrus fucicola
486
+ Crimson banded wrasse,Notolabrus gymnogenis
487
+ Blue-throated parrotfish,Notolabrus tetricus
488
+ Broadnose sevengill shark,Notorynchus cepedianus
489
+ Emerald shiner,Notropis atherinoides
490
+ Spottail shiner,Notropis hudsonius
491
+ Sand shiner,Notropis stramineus
492
+ Stonecat,Noturus flavus
493
+ Saddle bream,Oblada melanurus
494
+ Yellowtail snapper,Ocyurus chrysurus
495
+ Leatherjacket,Oligoplites saurus
496
+ Golden trout,Oncorhynchus aguabonita
497
+ Cutthroat trout,Oncorhynchus clarkii
498
+ Cutthroat trout,Oncorhynchus clarkii clarkii
499
+ Pink salmon,Oncorhynchus gorbuscha
500
+ Chum salmon,Oncorhynchus keta
501
+ Coho salmon,Oncorhynchus kisutch
502
+ Rainbow trout,Oncorhynchus mykiss
503
+ Cutbow,Oncorhynchus mykiss X Color variant2
504
+ Cutbow,Oncorhynchus mykiss X Oncorhynchus clarkii
505
+ Sockeye salmon,Oncorhynchus nerka
506
+ Chinook salmon,Oncorhynchus tshawytscha
507
+ Spotted snake eel,Ophichthus ophis
508
+ Lingcod,Ophiodon elongatus
509
+ Atlantic thread herring,Opisthonema oglinum
510
+ Oyster toadfish,Opsanus tau
511
+ Threespot tilapia,Oreochromis andersonii
512
+ Blue tilapia,Oreochromis aureus
513
+ Longfin tilapia,Oreochromis macrochir
514
+ Mozambique tilapia,Oreochromis mossambicus
515
+ Nile tilapia,Oreochromis niloticus
516
+ Pigfish,Orthopristis chrysoptera
517
+ Hottentot seabream,Pachymetopon blochii
518
+ Axillary seabream,Pagellus acarne
519
+ Pandora,Pagellus erythrinus
520
+ Squirefish,Pagrus auratus
521
+ Red porgy,Pagrus pagrus
522
+ Bermuda lobster,Panulirus argus
523
+ Palette surgeonfish,Paracanthurus hepatus
524
+ Guapote,Parachromis dovii
525
+ Jaguar guapote,Parachromis managuensis
526
+ Kelp bass,Paralabrax clathratus
527
+ Spotted sand bass,Paralabrax maculatofasciatus
528
+ Barred sand bass,Paralabrax nebulifer
529
+ Gulf flounder,Paralichthys albigutta
530
+ California halibut,Paralichthys californicus
531
+ Summer flounder,Paralichthys dentatus
532
+ Southern flounder,Paralichthys lethostigma
533
+ Blue cod,Parapercis colias
534
+ Blue goatfish,Parupeneus cyclostomus
535
+ Indian goatfish,Parupeneus indicus
536
+ Manybar goatfish,Parupeneus multifasciatus
537
+ Blackspot goatfish,Parupeneus spilurus
538
+ Spotted tilapia,Pelmatolapia mariae
539
+ Butterfish,Peprilus triacanthus
540
+ Yellow perch,Perca flavescens
541
+ Eurasian perch,Perca fluviatilis
542
+ Australian bass,Percalates novemaculeatus
543
+ Logperch,Percina caprodes
544
+ Trout-perch,Percopsis omiscomaycus
545
+ Eurasian minnow,Phoxinus phoxinus
546
+ Redtail catfish,Phractocephalus hemioliopterus
547
+ Pirapitinga,Piaractus brachypomus
548
+ Congo barbaso,Pimelodus maculatus
549
+ Bluntnose minnow,Pimephales notatus
550
+ Fathead minnow,Pimephales promelas
551
+ Bullhead minnow,Pimephales vigilax
552
+ Golden spadefish,Platax boersii
553
+ Longfin batfish,Platax teira
554
+ European flounder,Platichthys flesus
555
+ Starry flounder,Platichthys stellatus
556
+ Bay flathead,Platycephalus bassensis
557
+ Blue-spotted flathead,Platycephalus caeruleopunctatus
558
+ Black flathead,Platycephalus fuscus
559
+ Bartail flathead,Platycephalus indicus
560
+ Thornback,Platyrhinoidis triseriata
561
+ Blacksaddled coralgrouper,Plectropomus laevis
562
+ Leopard coralgrouper,Plectropomus leopardus
563
+ Spotted coralgrouper,Plectropomus maculatus
564
+ Plaice,Pleuronectes platessa
565
+ Striped eel-catfish,Plotosus lineatus
566
+ Guppy,Poecilia reticulata
567
+ Black drum,Pogonias cromis
568
+ Pollack,Pollachius pollachius
569
+ Pollock,Pollachius virens
570
+ Paddlefish,Polyodon spathula
571
+ Wreckfish,Polyprion americanus
572
+ Gray angelfish,Pomacanthus arcuatus
573
+ French angelfish,Pomacanthus paru
574
+ Goldbelly damsel,Pomacentrus auriventris
575
+ Smallspotted grunt,Pomadasys commersonnii
576
+ Javelin grunter,Pomadasys kaakan
577
+ Bluefish,Pomatomus saltatrix
578
+ Sand goby,Pomatoschistus minutus
579
+ White crappie,Pomoxis annularis
580
+ Black crappie,Pomoxis nigromaculatus
581
+ Striped catshark,Poroderma africanum
582
+ Black river stingray,Potamotrygon motoro
583
+ Rusty goby,Priolepis hipoliti
584
+ Blue shark,Prionace glauca
585
+ Northern searobin,Prionotus carolinus
586
+ Striped searobin,Prionotus evolans
587
+ Mountain whitefish,Prosopium williamsoni
588
+ Shovelnose guitarfish,Pseudobatos productus
589
+ Günther's wrasse,Pseudolabrus guentheri
590
+ Barred sorubim,Pseudoplatystoma fasciatum
591
+ Winter flounder,Pseudopleuronectes americanus
592
+ Spotted goatfish,Pseudupeneus maculatus
593
+ Lionfish,Pterois volitans
594
+ Leopard pleco,Pterygoplichthys gibbiceps
595
+ Sacramento pikeminnow,Ptychocheilus grandis
596
+ Columbia River dace,Ptychocheilus oregonensis
597
+ Red piranha,Pygocentrus nattereri
598
+ Flathead catfish,Pylodictis olivaris
599
+ Blackspotted snake eel,Quassiremus ascensionis
600
+ Cobia,Rachycentron canadum
601
+ Roker,Raja clavata
602
+ Indian mackerel,Rastrelliger kanagurta
603
+ Cape stumpnose,Rhabdosargus holubi
604
+ Goldlined seabream,Rhabdosargus sarba
605
+ Catfish,Rhamdia quelen
606
+ Whale shark,Rhincodon typus
607
+ Patchy triggerfish,Rhinecanthus rectangulus
608
+ Blacknose dace,Rhinichthys atratulus
609
+ Longnose dace,Rhinichthys cataractae
610
+ Western blacknose dace,Rhinichthys obtusus
611
+ Cownose ray,Rhinoptera bonasus
612
+ Atlantic sharpnose shark,Rhizoprionodon terraenovae
613
+ Vermilion snapper,Rhomboplites aurorubens
614
+ Roach,Rutilus rutilus
615
+ Greater soapfish,Rypticus saponaceus
616
+ Dorado,Salminus brasiliensis
617
+ Marble trout,Salmo marmoratus
618
+ Atlantic salmon,Salmo salar
619
+ Brown trout,Salmo trutta
620
+ Tiger trout,Salmo trutta X Salvelinus fontinalis
621
+ Arctic char,Salvelinus alpinus
622
+ Bull trout,Salvelinus confluentus
623
+ Brook trout,Salvelinus fontinalis
624
+ Dolly Varden,Salvelinus malma
625
+ Lake trout,Salvelinus namaycush
626
+ Sauger,Sander canadensis
627
+ Zander,Sander lucioperca
628
+ Walleye,Sander vitreus
629
+ Saugeye,Sander vitreus X Sander canadensis
630
+ Striped bonito,Sarda orientalis
631
+ Atlantic bonito,Sarda sarda
632
+ Blackchin tilapia,Sarotherodon melanotheron
633
+ Salpa,Sarpa salpa
634
+ Shovelnose sturgeon,Scaphirhynchus platorynchus
635
+ Rudd,Scardinius erythrophthalmus
636
+ Blue parrotfish,Scarus coeruleus
637
+ Blue-barred parrotfish,Scarus ghobban
638
+ Rainbow parrotfish,Scarus guacamaia
639
+ Striped parrotfish,Scarus iseri
640
+ Dusky parrotfish,Scarus niger
641
+ Common parrotfish,Scarus psittacus
642
+ Princess parrotfish,Scarus taeniopterus
643
+ Queen parrotfish,Scarus vetula
644
+ Spotted scat,Scatophagus argus
645
+ Red drum,Sciaenops ocellatus
646
+ Spotted mackerel,Scomber australasicus
647
+ Atlantic chub mackerel,Scomber colias
648
+ Chub mackerel,Scomber japonicus
649
+ Atlantic mackerel,Scomber scombrus
650
+ Talang queenfish,Scomberoides commersonnianus
651
+ Doublespotted queenfish,Scomberoides lysan
652
+ King mackerel,Scomberomorus cavalla
653
+ Narrowbarred mackerel,Scomberomorus commerson
654
+ Spanish mackerel,Scomberomorus maculatus
655
+ Cero,Scomberomorus regalis
656
+ Pacific sierra,Scomberomorus sierra
657
+ California scorpionfish,Scorpaena guttata
658
+ Smallscaled scorpionfish,Scorpaena porcus
659
+ Cabezon,Scorpaenichthys marmoratus
660
+ Silver sweep,Scorpis lineolata
661
+ Small-spotted catshark,Scyliorhinus canicula
662
+ Greater spotted dogfish,Scyliorhinus stellaris
663
+ Brown rockfish,Sebastes auriculatus
664
+ Copper rockfish,Sebastes caurinus
665
+ Quillback rockfish,Sebastes maliger
666
+ Black rockfish,Sebastes melanops
667
+ Vermilion rockfish,Sebastes miniatus
668
+ Blue rockfish,Sebastes mystinus
669
+ Yelloweye rockfish,Sebastes ruberrimus
670
+ Bigeye scad,Selar crumenophthalmus
671
+ Atlantic moonfish,Selene setapinnis
672
+ Lookdown,Selene vomer
673
+ California sheephead,Semicossyphus pulcher
674
+ Creek chub,Semotilus atromaculatus
675
+ Fallfish,Semotilus corporalis
676
+ Greater amberjack,Seriola dumerili
677
+ Samsonfish,Seriola hippos
678
+ Yellowtail amberjack,Seriola lalandi
679
+ Almaco jack,Seriola rivoliana
680
+ Yellow-belly bream,Serranochromis robustus
681
+ Comber,Serranus cabrilla
682
+ Painted comber,Serranus scriba
683
+ Redeye piranha,Serrasalmus rhombeus
684
+ Goldlined spinefoot,Siganus guttatus
685
+ Foxface,Siganus vulpinus
686
+ Spotted whiting,Sillaginodes punctatus
687
+ Sand sillago,Sillago ciliata
688
+ Wels,Silurus glanis
689
+ Greenblotch parrotfish,Sparisoma atomarium
690
+ Redband parrotfish,Sparisoma aurofrenatum
691
+ Redtail parrotfish,Sparisoma chrysopterum
692
+ Parrotfish,Sparisoma cretense
693
+ Redfin parrotfish,Sparisoma rubripinne
694
+ Stoplight parrotfish,Sparisoma viride
695
+ White musselcracker,Sparodon durbanensis
696
+ Gilthead bream,Sparus aurata
697
+ Northern puffer,Sphoeroides maculatus
698
+ Southern puffer,Sphoeroides nephelus
699
+ Bandtail puffer,Sphoeroides spengleri
700
+ Checkered puffer,Sphoeroides testudineus
701
+ Pacific barracuda,Sphyraena argentea
702
+ Great barracuda,Sphyraena barracuda
703
+ Pickhandle barracuda,Sphyraena jello
704
+ Australian barracuda,Sphyraena novaehollandiae
705
+ Yellowstriped barracuda,Sphyraena obtusata
706
+ Blackfin barracuda,Sphyraena qenie
707
+ European barracuda,Sphyraena sphyraena
708
+ Barracuda,Sphyraena viridensis
709
+ Scalloped hammerhead,Sphyrna lewini
710
+ Great hammerhead,Sphyrna mokarran
711
+ Bonnethead,Sphyrna tiburo
712
+ Smooth hammerhead,Sphyrna zygaena
713
+ Black seabream,Spondyliosoma cantharus
714
+ Spiny dogfish,Squalus acanthias
715
+ Dusky damselfish,Stegastes adustus
716
+ Longfin damselfish,Stegastes diencaeus
717
+ Beaugregory,Stegastes leucostictus
718
+ Bicolor damselfish,Stegastes partitus
719
+ Cocoa damselfish,Stegastes variabilis
720
+ Scup,Stenotomus chrysops
721
+ Giant sea bass,Stereolepis gigas
722
+ Atlantic needlefish,Strongylura marina
723
+ Grey wrasse,Symphodus cinereus
724
+ Corkwing wrasse,Symphodus melops
725
+ East Atlantic peacock wrasse,Symphodus tinca
726
+ Mandarinfish,Synchiropus splendidus
727
+ Inshore lizardfish,Synodus foetens
728
+ Atlantic lizardfish,Synodus saurus
729
+ Blue-spotted fantail ray,Taeniura lymma
730
+ Eel-tailed catfish,Tandanus tandanus
731
+ Tautog,Tautoga onitis
732
+ Cunner,Tautogolabrus adspersus
733
+ Jarbua terapon,Terapon jarbua
734
+ Bluehead,Thalassoma bifasciatum
735
+ Saddle wrasse,Thalassoma duperrey
736
+ Moon wrasse,Thalassoma lunare
737
+ Ornate wrasse,Thalassoma pavo
738
+ Surge wrasse,Thalassoma purpureum
739
+ Albacore,Thunnus alalunga
740
+ Yellowfin tuna,Thunnus albacares
741
+ Blackfin tuna,Thunnus atlanticus
742
+ Pacific bluefin tuna,Thunnus orientalis
743
+ Bluefin tuna,Thunnus thynnus
744
+ Arctic grayling,Thymallus arcticus
745
+ Grayling,Thymallus thymallus
746
+ Snoek,Thyrsites atun
747
+ Banded tilapia,Tilapia sparrmanii
748
+ Tench,Tinca tinca
749
+ Banded toadfish,Torquigener pleurogramma
750
+ Archerfish,Toxotes jaculatrix
751
+ Smallspotted dart,Trachinotus baillonii
752
+ Snubnose pompano,Trachinotus blochii
753
+ Florida pompano,Trachinotus carolinus
754
+ Swallowtail dart,Trachinotus coppingeri
755
+ Permit,Trachinotus falcatus
756
+ Palometa,Trachinotus goodei
757
+ Derbio,Trachinotus ovatus
758
+ Mediterranean scad,Trachurus mediterraneus
759
+ Yellowtail horse mackerel,Trachurus novaezelandiae
760
+ Jack mackerel,Trachurus symmetricus
761
+ European horse mackerel,Trachurus trachurus
762
+ Whitetip reef shark,Triaenodon obesus
763
+ Leopard shark,Triakis semifasciata
764
+ Atlantic cutlass fish,Trichiurus lepturus
765
+ Hogchoker,Trinectes maculatus
766
+ Bib,Trisopterus luscus
767
+ Eastern fiddler ray,Trygonorrhina fasciata
768
+ Houndfish,Tylosurus crocodilus
769
+ Central mudminnow,Umbra limi
770
+ Lunartail grouper,Variola louti
771
+ Wallago,Wallago attu
772
+ Sargassum triggerfish,Xanthichthys ringens
773
+ Swordfish,Xiphias gladius
774
+ Pearly razorfish,Xyrichtys novacula
775
+ Yellow tang,Zebrasoma flavescens
776
+ Common cuttlefish,Sepia officinalis