stevenkhan committed on
Commit
3af60ea
·
verified ·
1 Parent(s): 8fe9595

Upload clashcr/models/evidence_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. clashcr/models/evidence_model.py +298 -0
clashcr/models/evidence_model.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Card evidence model: detect troops, buildings, spells, evolutions, heroes.
2
+
3
+ Outputs evidence, not just card names:
4
+ - detected units/effects,
5
+ - bounding boxes/masks,
6
+ - spawn time,
7
+ - location,
8
+ - side,
9
+ - possible cards,
10
+ - confidence,
11
+ - ambiguity reason.
12
+
13
+ For normal live view, we use a YOLO-based unit detector (inspired by KataCR)
14
+ plus heuristic spell/effect detectors for transient effects.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ import time
20
+ from dataclasses import dataclass, field
21
+ from pathlib import Path
22
+ from typing import Dict, List, Optional, Tuple
23
+
24
+ import cv2
25
+ import numpy as np
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
@dataclass
class UnitEvidence:
    """One unit detection reported by the unit detector for a single frame.

    The four ``is_*`` flags default to ``False`` and are set by whoever
    builds the record from the detected class name.
    """

    # Detected class name.
    unit_name: str
    # Bounding box stored as (x, y, w, h), i.e. top-left corner plus size.
    bbox: Tuple[int, int, int, int]
    # Detection confidence.
    confidence: float
    # Side label supplied by the caller.
    side: str
    # Index of the frame the detection belongs to.
    frame_idx: int
    # Time value recorded when the evidence object was created.
    timestamp: float
    # Category flags for the detected class.
    is_evolution: bool = False
    is_hero: bool = False
    is_building: bool = False
    is_spell_effect: bool = False
42
+
43
+
44
@dataclass
class SpellEvidence:
    """One heuristic spell/effect detection for a single frame.

    ``effect_mask`` is left out of both ``repr`` and equality. Keeping an
    ``np.ndarray`` in the generated ``__eq__`` makes ``a == b`` raise
    ``ValueError`` (ambiguous truth value of an array) whenever the two
    instances hold different mask objects, so equality is decided by the
    remaining scalar/tuple fields only.
    """

    # Name of the spell whose signature matched.
    spell_name: str
    # Binary mask cropped to ``bbox``; excluded from repr and comparison.
    effect_mask: np.ndarray = field(repr=False, compare=False)
    # Bounding box stored as (x, y, w, h).
    bbox: Tuple[int, int, int, int]
    # Detection confidence.
    confidence: float
    # Side label supplied by the caller.
    side: str
    # Index of the frame the detection belongs to.
    frame_idx: int
    # Time value recorded when the evidence object was created.
    timestamp: float
    # Free-form explanation of why the detection may be ambiguous.
    ambiguity_reason: str = ""
54
+
55
+
56
@dataclass
class EvidenceBundle:
    """Everything extracted from one frame: unit and spell evidence plus a summary.

    Every list field uses ``default_factory`` so each bundle gets its own
    fresh list instead of sharing one mutable default.
    """

    # Time value recorded when the bundle was created.
    timestamp: float
    # Index of the frame the bundle describes.
    frame_idx: int
    # Side label supplied by the caller.
    side: str
    # Unit detections for this frame.
    units: List[UnitEvidence] = field(default_factory=list)
    # Spell/effect detections for this frame.
    spells: List[SpellEvidence] = field(default_factory=list)
    # Candidate card names derived from the evidence.
    possible_cards: List[str] = field(default_factory=list)
    # Overall confidence for the bundle.
    confidence: float = 0.0
    # Empty when unambiguous, otherwise a short reason string.
    ambiguity_reason: str = ""
66
+
67
+
68
class EvidenceModel:
    """Wraps a YOLO unit detector and adds spell/heuristic effect detection.

    Args:
        model_path: Path to YOLO .pt file (e.g., KataCR-trained model).
        img_size: Inference size.
        conf_threshold: Minimum confidence for unit detections.
        spell_enabled: Whether to run heuristic spell detectors.
        device: Device string forwarded to the YOLO call.
    """

    # Heuristic spell effect signatures. Each "hsv_ranges" value is a
    # (lower, upper) pair of HSV bounds handed to cv2.inRange on a frame that
    # was converted with cv2.COLOR_BGR2HSV. "min_area" is the smallest
    # connected-component area (in pixels) that counts as a hit and
    # "max_duration_frames" is the look-back window used to suppress repeats.
    SPELL_SIGNATURES = {
        "zap": {
            "hsv_ranges": [(90, 50, 200), (130, 255, 255)],  # bright blue-white flash
            "min_area": 300,
            "max_duration_frames": 8,
        },
        "fireball": {
            "hsv_ranges": [(0, 100, 200), (20, 255, 255)],  # orange-red explosion
            "min_area": 500,
            "max_duration_frames": 12,
        },
        "arrows": {
            "hsv_ranges": [(0, 0, 180), (180, 30, 255)],  # white/grey streaks
            "min_area": 200,
            "max_duration_frames": 6,
        },
        "poison": {
            "hsv_ranges": [(35, 50, 50), (85, 255, 200)],  # green cloud
            "min_area": 400,
            "max_duration_frames": 20,
        },
        "freeze": {
            "hsv_ranges": [(80, 30, 200), (120, 100, 255)],  # icy blue-white
            "min_area": 300,
            "max_duration_frames": 15,
        },
        "rage": {
            "hsv_ranges": [(150, 100, 150), (180, 255, 255)],  # purple-pink
            "min_area": 400,
            "max_duration_frames": 18,
        },
        "tornado": {
            "hsv_ranges": [(0, 0, 100), (180, 50, 200)],  # grey swirl
            "min_area": 500,
            "max_duration_frames": 15,
        },
        "earthquake": {
            "hsv_ranges": [(10, 50, 50), (30, 200, 150)],  # brown cracks
            "min_area": 600,
            "max_duration_frames": 15,
        },
        "log": {
            "hsv_ranges": [(10, 50, 80), (30, 200, 180)],  # brown rolling log
            "min_area": 400,
            "max_duration_frames": 10,
        },
        "barbarian-barrel": {
            "hsv_ranges": [(10, 50, 80), (30, 200, 180)],  # similar to log
            "min_area": 400,
            "max_duration_frames": 10,
        },
        "vines": {
            "hsv_ranges": [(35, 80, 80), (75, 255, 200)],  # green tangling vines
            "min_area": 300,
            "max_duration_frames": 20,
        },
        "void": {
            "hsv_ranges": [(120, 50, 20), (160, 255, 80)],  # dark purple/black hole
            "min_area": 400,
            "max_duration_frames": 15,
        },
    }

    # Class-name lookup tables for the UnitEvidence flags. They are built once
    # at class-creation time instead of once per detected box.
    _HERO_CLASSES = frozenset({
        "skeleton-king", "golden-knight", "archer-queen",
        "monk", "mighty-miner", "little-prince", "royal-guardian",
    })
    _BUILDING_CLASSES = frozenset({
        "cannon", "tesla", "inferno-tower", "bomb-tower",
        "mortar", "x-bow", "elixir-collector", "furnace",
        "goblin-hut", "barbarian-hut", "tombstone",
    })
    # NOTE(review): this table uses "the-log" while SPELL_SIGNATURES uses "log";
    # confirm which spelling the downstream card matcher expects.
    _SPELL_CLASSES = frozenset({
        "fireball", "zap", "arrows", "poison", "freeze",
        "rage", "tornado", "earthquake", "the-log",
        "barbarian-barrel", "clone", "mirror", "royal-delivery",
        "giant-snowball", "lightning", "rocket", "graveyard",
    })

    def __init__(self,
                 model_path: Optional[str] = None,
                 img_size: int = 640,
                 conf_threshold: float = 0.5,
                 spell_enabled: bool = True,
                 device: str = "cpu"):
        """Store the configuration and load the YOLO model if the file exists.

        When ``model_path`` is falsy or does not point at an existing file, a
        warning is logged and unit detection stays disabled.
        """
        self.model_path = model_path
        self.img_size = img_size
        self.conf_threshold = conf_threshold
        self.spell_enabled = spell_enabled
        self.device = device
        self._model = None
        # spell_name -> list of frame indices at which the spell was seen
        self._spell_history: Dict[str, List[int]] = {}
        self._frame_idx = 0

        if model_path and Path(model_path).exists():
            self._load_yolo()
        else:
            logger.warning("YOLO model not found at %s; unit detection disabled.", model_path)

    def _load_yolo(self) -> None:
        """Import ultralytics lazily and load the model; on failure log and disable."""
        try:
            from ultralytics import YOLO
            self._model = YOLO(self.model_path)
            logger.info("Loaded YOLO model from %s", self.model_path)
        except Exception as e:
            # Best effort: a missing package or bad weights file must not
            # prevent the heuristic spell detectors from running.
            logger.error("Failed to load YOLO model: %s", e)
            self._model = None

    @staticmethod
    def _is_empty_frame(frame) -> bool:
        """Return True when ``frame`` is None or reports zero elements."""
        return frame is None or getattr(frame, "size", 1) == 0

    @classmethod
    def _infer_flags(cls, cls_name: str) -> Tuple[bool, bool, bool, bool]:
        """Return (is_evolution, is_hero, is_building, is_spell) for a class name."""
        return (
            "evolution" in cls_name,
            cls_name in cls._HERO_CLASSES,
            cls_name in cls._BUILDING_CLASSES,
            cls_name in cls._SPELL_CLASSES,
        )

    def detect_units(self, frame: np.ndarray, side: str) -> List[UnitEvidence]:
        """Run YOLO unit detection on the frame.

        Returns an empty list when no model is loaded, when the frame is
        None/empty, or when the model yields no result or no boxes. Boxes
        below ``conf_threshold`` are dropped. All evidence from one call
        shares a single timestamp.
        """
        if self._model is None or self._is_empty_frame(frame):
            return []

        results = self._model(frame, imgsz=self.img_size, verbose=False, device=self.device)
        if not results:
            return []
        boxes = results[0].boxes
        if boxes is None:
            return []

        names = self._model.names
        now = time.monotonic()
        evidence = []
        for box in boxes:
            conf = float(box.conf)
            if conf < self.conf_threshold:
                continue
            cls_name = names[int(box.cls)]
            x1, y1, x2, y2 = box.xyxy[0].tolist()
            # Convert corner coordinates to (x, y, w, h).
            bbox = (int(x1), int(y1), int(x2 - x1), int(y2 - y1))

            is_evolution, is_hero, is_building, is_spell = self._infer_flags(cls_name)
            evidence.append(UnitEvidence(
                unit_name=cls_name,
                bbox=bbox,
                confidence=conf,
                side=side,
                frame_idx=self._frame_idx,
                timestamp=now,
                is_evolution=is_evolution,
                is_hero=is_hero,
                is_building=is_building,
                is_spell_effect=is_spell,
            ))
        return evidence

    def detect_spells(self, frame: np.ndarray, side: str) -> List[SpellEvidence]:
        """Heuristic spell effect detection based on color signatures.

        For every entry in ``SPELL_SIGNATURES`` the frame is thresholded in
        HSV, cleaned with a morphological open/close, and split into connected
        components. A component of at least ``min_area`` pixels is reported
        unless the same spell name was already seen within the last
        ``max_duration_frames`` frames (including earlier components of the
        current frame).

        Returns an empty list when spell detection is disabled or the frame is
        None/empty.
        """
        if not self.spell_enabled or self._is_empty_frame(frame):
            return []

        hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
        # The structuring element does not depend on the spell; build it once.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
        now = time.monotonic()
        evidence = []
        for spell_name, sig in self.SPELL_SIGNATURES.items():
            lower, upper = sig["hsv_ranges"]
            min_area = sig["min_area"]
            max_age = sig["max_duration_frames"]

            mask = cv2.inRange(hsv, np.array(lower), np.array(upper))
            # Morphological cleanup
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

            num_labels, _labels, stats, _centroids = cv2.connectedComponentsWithStats(
                mask, connectivity=8)
            # Label 0 is the background component, so start at 1.
            for i in range(1, num_labels):
                # Cast NumPy scalars to built-in ints so bbox really is a
                # Tuple[int, ...] like the one produced by detect_units.
                area = int(stats[i, cv2.CC_STAT_AREA])
                if area < min_area:
                    continue
                x = int(stats[i, cv2.CC_STAT_LEFT])
                y = int(stats[i, cv2.CC_STAT_TOP])
                w = int(stats[i, cv2.CC_STAT_WIDTH])
                h = int(stats[i, cv2.CC_STAT_HEIGHT])

                # Track duration to avoid re-detecting same spell: record this
                # sighting, drop sightings older than the window, and report
                # only when this is the sole sighting left.
                history = self._spell_history.setdefault(spell_name, [])
                history.append(self._frame_idx)
                history[:] = [f for f in history if self._frame_idx - f <= max_age]
                if len(history) > 1:
                    # Already detected recently
                    continue

                evidence.append(SpellEvidence(
                    spell_name=spell_name,
                    effect_mask=mask[y:y + h, x:x + w].copy(),
                    bbox=(x, y, w, h),
                    # Confidence saturates at 1.0 once the area is 3x min_area.
                    confidence=min(area / (min_area * 3), 1.0),
                    side=side,
                    frame_idx=self._frame_idx,
                    timestamp=now,
                    ambiguity_reason="heuristic_color",
                ))
        return evidence

    def process(self, frame: np.ndarray, side: str = "opponent") -> EvidenceBundle:
        """Run full evidence extraction on a frame crop.

        Combines unit and spell evidence, derives the sorted set of candidate
        card names, and flags the bundle as ambiguous when there is no
        evidence at all or more than three distinct candidates. The internal
        frame counter advances by one per call.
        """
        units = self.detect_units(frame, side)
        spells = self.detect_spells(frame, side)

        # Build possible cards from evidence
        possible = {u.unit_name for u in units}
        possible.update(s.spell_name for s in spells)

        ambiguity = ""
        if not units and not spells:
            ambiguity = "no_evidence"
        elif len(possible) > 3:
            ambiguity = f"too_many_candidates:{len(possible)}"

        bundle = EvidenceBundle(
            timestamp=time.monotonic(),
            frame_idx=self._frame_idx,
            side=side,
            units=units,
            spells=spells,
            possible_cards=sorted(possible),
            confidence=0.0 if ambiguity else 0.7,
            ambiguity_reason=ambiguity,
        )
        self._frame_idx += 1
        return bundle

    def reset(self) -> None:
        """Forget all spell sightings and restart frame numbering at zero."""
        self._spell_history.clear()
        self._frame_idx = 0