| from __future__ import annotations |
|
|
| import importlib.util |
| import json |
| import os |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| import cv2 |
| import numpy as np |
|
|
|
|
def _load_local_miner_class():
    """Load ``miner.py`` from this file's directory and return its ``Miner`` class.

    The file is imported under the private module name
    ``manako_bridge_local_miner``. The module object is registered in
    ``sys.modules`` *before* its code runs, following the importlib
    "import a source file directly" recipe. Without that registration, code
    inside ``miner.py`` that looks up its own module via ``sys.modules`` fails;
    examples are ``@dataclass`` with string annotations (``from __future__
    import annotations``) and pickling of its classes. If executing the module
    raises, the partially initialised module is unregistered again and the
    error propagates.

    Returns:
        The ``Miner`` attribute of the loaded module.

    Raises:
        RuntimeError: if no import spec/loader can be built for ``miner.py``,
            or the module does not define ``Miner``.
    """
    module_name = "manako_bridge_local_miner"
    miner_path = Path(__file__).resolve().parent / "miner.py"
    spec = importlib.util.spec_from_file_location(module_name, str(miner_path))
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not load miner module from {miner_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-executed module importable under this name.
        if sys.modules.get(module_name) is module:
            del sys.modules[module_name]
        raise
    miner_class = getattr(module, "Miner", None)
    if miner_class is None:
        raise RuntimeError(f"miner.py does not export Miner in {miner_path}")
    return miner_class
|
|
|
|
# Resolved once, at import time: importing this module loads and executes
# miner.py from this file's directory (see _load_local_miner_class).
Miner = _load_local_miner_class()
|
|
|
|
# Human-readable label for each integer ``cls_id`` (list index -> name).
CLASS_NAMES = ["football", "player", "pitch"]
# Model-type tag reported in the descriptor dict returned by ``load_model``.
MODEL_TYPE = "ultralytics-yolo"
|
|
|
|
def _to_dict(value: Any) -> dict[str, Any]:
    """Coerce a result object into a plain ``dict``.

    Resolution order:

    1. A ``dict`` is returned as-is (the same object, not a copy).
    2. An object with a callable ``model_dump()`` (for example a pydantic v2
       model) yields that call's result, but only if the result is a ``dict``.
    3. An object with a ``__dict__`` yields a shallow copy of it.
    4. Anything else yields an empty dict.
    """
    if isinstance(value, dict):
        return value

    dump = getattr(value, "model_dump", None)
    if callable(dump):
        candidate = dump()
        if isinstance(candidate, dict):
            return candidate

    return dict(value.__dict__) if hasattr(value, "__dict__") else {}
|
|
|
|
def _extract_boxes(frame_result: Any) -> list[Any]:
    """Return the box entries attached to a single frame result.

    ``frame_result`` is coerced with ``_to_dict`` and its ``"boxes"`` entry is
    inspected:

    * a ``list`` is returned unchanged (same object);
    * a ``tuple`` is converted to a new ``list``, so its detections are not
      silently discarded;
    * a missing key or any other value type yields ``[]``.
    """
    boxes = _to_dict(frame_result).get("boxes")
    if isinstance(boxes, list):
        return boxes
    if isinstance(boxes, tuple):
        return list(boxes)
    return []
|
|
|
|
def _to_detection(box: Any) -> dict[str, Any]:
    """Convert one corner-format box into a center-format detection dict.

    The box (dict or object, coerced via ``_to_dict``) is read for ``cls_id``,
    ``x1``, ``y1``, ``x2``, ``y2`` and ``conf``; any missing field defaults to
    zero. Width and height are ``x2 - x1`` and ``y2 - y1`` clamped at zero,
    and the center is ``(x1 + width / 2, y1 + height / 2)``. ``class`` is the
    matching entry of ``CLASS_NAMES`` when ``class_id`` is a valid index,
    otherwise the id rendered as a string.
    """
    fields = _to_dict(box)

    def _coord(key: str) -> float:
        return float(fields.get(key, 0.0))

    class_id = int(fields.get("cls_id", 0))
    left, top = _coord("x1"), _coord("y1")
    right, bottom = _coord("x2"), _coord("y2")

    box_w = max(0.0, right - left)
    box_h = max(0.0, bottom - top)
    confidence = float(fields.get("conf", 0.0))

    if class_id in range(len(CLASS_NAMES)):
        label = CLASS_NAMES[class_id]
    else:
        label = str(class_id)

    return {
        "x": left + box_w / 2.0,
        "y": top + box_h / 2.0,
        "width": box_w,
        "height": box_h,
        "confidence": confidence,
        "class_id": class_id,
        "class": label,
    }
|
|
|
|
| def _normalize_image_for_miner(image: Any) -> Any: |
| if image is None or hasattr(image, "shape"): |
| return image |
| if isinstance(image, (bytes, bytearray, memoryview)): |
| try: |
| buffer = np.frombuffer(bytes(image), dtype=np.uint8) |
| decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR) |
| if decoded is not None: |
| return decoded |
| except Exception: |
| return image |
| if hasattr(image, "convert") and callable(image.convert): |
| try: |
| rgb = image.convert("RGB") |
| array = np.array(rgb) |
| if getattr(array, "ndim", 0) == 3 and array.shape[-1] == 3: |
| return cv2.cvtColor(array, cv2.COLOR_RGB2BGR) |
| return array |
| except Exception: |
| return image |
| try: |
| array = np.asarray(image) |
| if getattr(array, "shape", None): |
| return array |
| except Exception: |
| return image |
| return image |
|
|
|
|
def load_model(onnx_path: str | None = None, data_dir: str | None = None):
    """Instantiate the local ``Miner`` and wrap it in a model descriptor.

    Args:
        onnx_path: Accepted for interface compatibility; ignored.
        data_dir: Directory passed to ``Miner``. When falsy, the directory
            containing this file is used.

    Returns:
        ``{"miner": <Miner instance>, "model_type": MODEL_TYPE,
        "class_names": <copy of CLASS_NAMES>}``.
    """
    del onnx_path  # unused; kept so existing callers can still pass it
    repo_dir = Path(data_dir) if data_dir else Path(__file__).resolve().parent
    miner = Miner(repo_dir)
    return {
        "miner": miner,
        "model_type": MODEL_TYPE,
        # Copy so a caller mutating the returned list cannot change the
        # module-level CLASS_NAMES used to label detections.
        "class_names": list(CLASS_NAMES),
    }
|
|
|
|
def _candidate_keypoint_counts(miner: Any) -> list[int]:
    """List the ``n_keypoints`` values to try, in priority order, without repeats.

    The order is ``0`` first, then every positive ``int`` found on ``miner``
    under ``n_keypoints``, ``num_keypoints``, ``keypoint_count`` or
    ``num_joints`` (in that attribute order), then ``32`` as a last resort.
    """
    attr_names = ("n_keypoints", "num_keypoints", "keypoint_count", "num_joints")
    discovered = [
        value
        for value in (getattr(miner, name, None) for name in attr_names)
        if isinstance(value, int) and value > 0
    ]
    # dict preserves insertion order, so fromkeys() drops duplicates while
    # keeping the first occurrence of each count.
    return list(dict.fromkeys([0, *discovered, 32]))
|
|
|
|
def _predict_batch_with_fallbacks(miner: Any, image: Any) -> list[Any]:
    """Run ``miner.predict_batch`` on one image, retrying over keypoint counts.

    The image is normalized with ``_normalize_image_for_miner``. Each value
    from ``_candidate_keypoint_counts`` is then passed as ``n_keypoints``, in
    order. The first call that does not raise has its result returned.

    Raises:
        RuntimeError: if every candidate fails. The message lists each
            attempt's error. The last underlying exception is chained as
            ``__cause__`` so its traceback is not lost.
    """
    normalized_image = _normalize_image_for_miner(image)
    errors: list[str] = []
    last_error: Exception | None = None
    for n_keypoints in _candidate_keypoint_counts(miner):
        try:
            return miner.predict_batch([normalized_image], offset=0, n_keypoints=n_keypoints)
        except Exception as exc:  # any failure means "try the next count"
            errors.append(f"n_keypoints={n_keypoints} -> {exc}")
            last_error = exc
    raise RuntimeError(
        "predict_batch failed for all keypoint candidates: " + " | ".join(errors)
    ) from last_error
|
|
|
|
def run_model(model: Any, image: Any = None, onnx_path: str | None = None, data_dir: str | None = None):
    """Run detection on one image and return ``[detections]`` for that frame.

    Two call forms are accepted:

    * ``run_model(model, image)``, where ``model`` is the dict returned by
      ``load_model``;
    * ``run_model(image)``: when ``image`` is omitted (``None``), the first
      argument is treated as the image and a model is loaded from
      ``data_dir``.

    ``onnx_path`` is ignored. Returns a one-element list holding the
    detection dicts for the first result frame, or ``[[]]`` when the miner
    returns no results.
    """
    del onnx_path
    if image is None:
        # Single-argument form: the positional ``model`` is really the image.
        image, model = model, load_model(data_dir=data_dir)

    frames = _predict_batch_with_fallbacks(model["miner"], image)
    if not frames:
        return [[]]
    return [[_to_detection(box) for box in _extract_boxes(frames[0])]]
|
|
|
|
def main() -> None:
    """CLI entry point: detect objects in the image at ``argv[1]`` and print JSON.

    Exits with status 1, after writing a message to stderr, when no path is
    given or OpenCV cannot read the file.
    """
    args = sys.argv[1:]
    if not args:
        print("Usage: main.py <image_path>", file=sys.stderr)
        raise SystemExit(1)

    image_path = args[0]
    frame = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if frame is None:
        print(f"Could not read image: {image_path}", file=sys.stderr)
        raise SystemExit(1)

    here = os.path.dirname(os.path.abspath(__file__))
    detections = run_model(load_model(data_dir=here), frame)
    print(json.dumps(detections, indent=2))
|
|
|
|
# Allow running this bridge directly as a script: main() reads the image
# path from argv[1] and prints the detections as JSON.
if __name__ == "__main__":
    main()
|
|