Andy0811 commited on
Commit
53c1ff6
·
verified ·
1 Parent(s): ea8444e

Upload 5 files

Browse files
Files changed (5) hide show
  1. main.py +346 -0
  2. model/best.pt +3 -0
  3. model/labels.txt +3 -0
  4. requirements.txt +18 -0
  5. verdict_logic.py +327 -0
main.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for YOLO PyTorch image classification
3
+ Accepts image uploads and returns verdict (green/red) based on detected classes
4
+ Includes GPU detection and automatic device selection
5
+ """
6
+
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydantic import BaseModel
11
+ from typing import List, Dict, Optional
12
+ import numpy as np
13
+ import cv2
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ import logging
17
+ import base64
18
+
19
+ from yolo_inference_pytorch import YOLOInference
20
+ from verdict_logic import VerdictEngine
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ app = FastAPI(
26
+ title="YOLO PyTorch Image Classifier - Engineer Selfie",
27
+ description="Upload an image to get classification verdict (green/red) based on detected objects. Supports GPU acceleration.",
28
+ version="2.0.0"
29
+ )
30
+
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["*"],
34
+ allow_credentials=True,
35
+ allow_methods=["*"],
36
+ allow_headers=["*"],
37
+ )
38
+
39
+ class Base64ImageRequest(BaseModel):
40
+ image: str
41
+
42
+ class Config:
43
+ json_schema_extra = {
44
+ "example": {
45
+ "image": "data:image/jpeg;base64,/9j/4AAQSkZJRg..."
46
+ }
47
+ }
48
+
49
+ class Detection(BaseModel):
50
+ class_id: int
51
+ class_name: str
52
+ confidence: float
53
+ bbox: List[float]
54
+
55
+ class VerdictResponse(BaseModel):
56
+ verdict: str
57
+ confidence: float
58
+ detections: List[Detection]
59
+ message: str
60
+ image_size: Dict[str, int]
61
+
62
+ yolo_model = None
63
+ verdict_engine = None
64
+
65
+ @app.on_event("startup")
66
+ async def startup_event():
67
+ """Initialize models on startup"""
68
+ global yolo_model, verdict_engine
69
+
70
+ try:
71
+ import os
72
+ model_path = "model/best.pt"
73
+
74
+ if not os.path.exists(model_path):
75
+ logger.warning(f"⚠️ YOLO model not found at {model_path}")
76
+ logger.warning("⚠️ The API will start but /predict endpoint will not work until you add the model")
77
+ logger.warning("⚠️ Please place your best.pt file in the model/ directory")
78
+
79
+ verdict_engine = VerdictEngine(
80
+ green_classes=[],
81
+ red_classes=[],
82
+ rules_path="config/verdict_rules.json"
83
+ )
84
+ logger.info("Verdict engine initialized successfully")
85
+ return
86
+
87
+ logger.info("Loading YOLO PyTorch model with GPU detection...")
88
+ yolo_model = YOLOInference(
89
+ model_path=model_path,
90
+ labels_path="model/labels.txt",
91
+ conf_threshold=0.5,
92
+ iou_threshold=0.45,
93
+ device=None
94
+ )
95
+ logger.info("YOLO model loaded successfully")
96
+ logger.info(f"Using confidence threshold: 0.5 (50%)")
97
+
98
+ logger.info("Initializing verdict engine...")
99
+ verdict_engine = VerdictEngine(
100
+ green_classes=[],
101
+ red_classes=[],
102
+ rules_path="config/verdict_rules.json"
103
+ )
104
+ logger.info("Verdict engine initialized successfully")
105
+
106
+ except Exception as e:
107
+ logger.error(f"Error during startup: {str(e)}")
108
+ raise
109
+
110
+ @app.get("/")
111
+ async def root():
112
+ """Root endpoint with API information"""
113
+ return {
114
+ "message": "YOLO PyTorch Image Classifier API - Engineer Selfie",
115
+ "version": "2.0.0",
116
+ "description": "Upload images to check for object detection and get approval verdicts. GPU-accelerated when available.",
117
+ "endpoints": {
118
+ "check_image_base64": "/check-image (POST with base64)",
119
+ "check_image_upload": "/check-image/upload (POST with file)",
120
+ "check_images_batch": "/check-images/batch (POST with multiple files)",
121
+ "health": "/health (GET)",
122
+ "model_info": "/model/info (GET)",
123
+ "device_info": "/device/info (GET)"
124
+ },
125
+ "usage": {
126
+ "base64": "POST /check-image with JSON body: {\"image\": \"base64_string\"}",
127
+ "file_upload": "POST /check-image/upload with multipart/form-data",
128
+ "batch": "POST /check-images/batch with multiple files"
129
+ }
130
+ }
131
+
132
+ @app.get("/health")
133
+ async def health_check():
134
+ """Health check endpoint"""
135
+ return {
136
+ "status": "healthy",
137
+ "model_loaded": yolo_model is not None,
138
+ "verdict_engine_loaded": verdict_engine is not None,
139
+ "device": yolo_model.device if yolo_model else "not loaded"
140
+ }
141
+
142
+ @app.get("/device/info")
143
+ async def device_info():
144
+ """Get GPU/device information"""
145
+ if yolo_model is None:
146
+ raise HTTPException(status_code=503, detail="Model not loaded")
147
+
148
+ return yolo_model.get_device_info()
149
+
150
+ @app.get("/model/info")
151
+ async def model_info():
152
+ """Get model information"""
153
+ if yolo_model is None:
154
+ raise HTTPException(status_code=503, detail="Model not loaded")
155
+
156
+ return {
157
+ "model_type": "YOLO PyTorch",
158
+ "model_path": yolo_model.model_path,
159
+ "classes": yolo_model.get_class_names(),
160
+ "num_classes": len(yolo_model.get_class_names()),
161
+ "input_size": yolo_model.get_input_size(),
162
+ "confidence_threshold": yolo_model.conf_threshold,
163
+ "iou_threshold": yolo_model.iou_threshold,
164
+ "device_info": yolo_model.get_device_info()
165
+ }
166
+
167
+ def process_image_to_numpy(image_data: bytes) -> tuple[np.ndarray, int, int]:
168
+ """Convert image bytes to numpy array for processing"""
169
+ image = Image.open(BytesIO(image_data))
170
+
171
+ if image.mode != 'RGB':
172
+ image = image.convert('RGB')
173
+
174
+ image_np = np.array(image)
175
+
176
+ if len(image_np.shape) == 3 and image_np.shape[2] == 3:
177
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
178
+ elif len(image_np.shape) == 2:
179
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2BGR)
180
+
181
+ original_height, original_width = image_np.shape[:2]
182
+ return image_np, original_width, original_height
183
+
184
+
185
+ def create_verdict_response(detections: List[Dict], verdict_result: Dict, width: int, height: int) -> VerdictResponse:
186
+ """Create standardized verdict response"""
187
+ formatted_detections = [
188
+ Detection(
189
+ class_id=det["class_id"],
190
+ class_name=det["class_name"],
191
+ confidence=float(det["confidence"]),
192
+ bbox=[float(x) for x in det["bbox"]]
193
+ )
194
+ for det in detections
195
+ ]
196
+
197
+ return VerdictResponse(
198
+ verdict=verdict_result["verdict"],
199
+ confidence=verdict_result["confidence"],
200
+ detections=formatted_detections,
201
+ message=verdict_result["message"],
202
+ image_size={
203
+ "width": width,
204
+ "height": height
205
+ }
206
+ )
207
+
208
+
209
+ @app.post("/check-image", response_model=VerdictResponse)
210
+ async def check_image_base64(request: Base64ImageRequest):
211
+ """
212
+ CHECK IMAGE (Base64 Format)
213
+
214
+ Send a base64 encoded image and get a verdict (GREEN/RED) based on detected objects.
215
+ """
216
+ if yolo_model is None or verdict_engine is None:
217
+ raise HTTPException(status_code=503, detail="Model not loaded")
218
+
219
+ try:
220
+ base64_string = request.image
221
+
222
+ if "base64," in base64_string:
223
+ base64_string = base64_string.split("base64,")[1]
224
+
225
+ image_bytes = base64.b64decode(base64_string)
226
+ logger.info("Processing base64 encoded image")
227
+
228
+ image_np, original_width, original_height = process_image_to_numpy(image_bytes)
229
+ logger.info(f"Image size: {original_width}x{original_height}")
230
+
231
+ detections = yolo_model.predict(image_np)
232
+ logger.info(f"Found {len(detections)} detections")
233
+
234
+ verdict_result = verdict_engine.get_verdict(detections)
235
+
236
+ response = create_verdict_response(detections, verdict_result, original_width, original_height)
237
+
238
+ logger.info(f"Verdict: {verdict_result['verdict']} (confidence: {verdict_result['confidence']:.2f})")
239
+ return response
240
+
241
+ except base64.binascii.Error:
242
+ raise HTTPException(status_code=400, detail="Invalid base64 image data")
243
+ except Exception as e:
244
+ logger.error(f"Error processing image: {str(e)}")
245
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
246
+
247
+
248
+ @app.post("/check-image/upload", response_model=VerdictResponse)
249
+ async def check_image_upload(file: UploadFile = File(...)):
250
+ """
251
+ CHECK IMAGE (File Upload)
252
+
253
+ Upload an image file and get a verdict (GREEN/RED) based on detected objects.
254
+ """
255
+ if yolo_model is None or verdict_engine is None:
256
+ raise HTTPException(status_code=503, detail="Model not loaded")
257
+
258
+ if not file.content_type.startswith("image/"):
259
+ raise HTTPException(
260
+ status_code=400,
261
+ detail=f"Invalid file type: {file.content_type}. Please upload an image."
262
+ )
263
+
264
+ try:
265
+ logger.info(f"Processing image: {file.filename}")
266
+ contents = await file.read()
267
+
268
+ image_np, original_width, original_height = process_image_to_numpy(contents)
269
+ logger.info(f"Image size: {original_width}x{original_height}")
270
+
271
+ detections = yolo_model.predict(image_np)
272
+ logger.info(f"Found {len(detections)} detections")
273
+
274
+ verdict_result = verdict_engine.get_verdict(detections)
275
+
276
+ response = create_verdict_response(detections, verdict_result, original_width, original_height)
277
+
278
+ logger.info(f"Verdict: {verdict_result['verdict']} (confidence: {verdict_result['confidence']:.2f})")
279
+ return response
280
+
281
+ except Exception as e:
282
+ logger.error(f"Error processing image: {str(e)}")
283
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
284
+
285
+
286
+ @app.post("/predict", response_model=VerdictResponse)
287
+ async def predict(file: UploadFile = File(...)):
288
+ """
289
+ [LEGACY] Upload an image and get classification verdict
290
+
291
+ Note: Use /check-image/upload instead for clearer API naming
292
+ """
293
+ return await check_image_upload(file)
294
+
295
+ @app.post("/check-images/batch")
296
+ async def check_images_batch(files: List[UploadFile] = File(...)):
297
+ """
298
+ CHECK MULTIPLE IMAGES (Batch Upload)
299
+
300
+ Upload multiple image files (max 10) and get verdicts for each.
301
+ """
302
+ if len(files) > 10:
303
+ raise HTTPException(
304
+ status_code=400,
305
+ detail="Maximum 10 images allowed per batch"
306
+ )
307
+
308
+ results = []
309
+ for file in files:
310
+ try:
311
+ result = await check_image_upload(file)
312
+ results.append({
313
+ "filename": file.filename,
314
+ "result": result
315
+ })
316
+ except Exception as e:
317
+ results.append({
318
+ "filename": file.filename,
319
+ "error": str(e)
320
+ })
321
+
322
+ return {"results": results}
323
+
324
+
325
+ @app.post("/predict/batch")
326
+ async def predict_batch(files: List[UploadFile] = File(...)):
327
+ """
328
+ [LEGACY] Upload multiple images and get classification verdicts
329
+
330
+ Note: Use /check-images/batch instead for clearer API naming
331
+ """
332
+ return await check_images_batch(files)
333
+
334
+ if __name__ == "__main__":
335
+ import uvicorn
336
+ import sys
337
+
338
+ port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
339
+
340
+ print(f"\n{'='*70}")
341
+ print(f"🚀 Starting Engineer Selfie API on http://0.0.0.0:{port}")
342
+ print(f"📖 API docs available at http://localhost:{port}/docs")
343
+ print(f"🖥️ GPU acceleration will be used if available")
344
+ print(f"{'='*70}\n")
345
+
346
+ uvicorn.run(app, host="0.0.0.0", port=port)
model/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e9194db6fc94790b6902f48e8a5e9c04e025b47ca745bd0b3569d3814a1509c
3
+ size 6223722
model/labels.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ jio jersey
2
+ jio logo
3
+ person
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI and server dependencies
2
+ fastapi==0.109.0
3
+ uvicorn[standard]==0.27.0
4
+ python-multipart==0.0.6
5
+
6
+ # Machine Learning and Image Processing
7
+ opencv-python>=4.10.0
8
+ numpy>=1.24.0
9
+ Pillow==10.2.0
10
+ torch>=2.0.0
11
+ torchvision>=0.15.0
12
+ ultralytics>=8.0.0
13
+
14
+ # Data validation and utilities
15
+ pydantic==2.5.3
16
+
17
+ # Logging and monitoring (optional)
18
+ python-json-logger==2.0.7
verdict_logic.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Verdict Logic Engine
3
+ Determines green/red verdict based on detected classes
4
+ """
5
+
6
+ import json
7
+ from typing import List, Dict, Optional
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class VerdictEngine:
14
+ """
15
+ Engine to determine verdict (green/red) based on detected objects
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ green_classes: Optional[List[str]] = None,
21
+ red_classes: Optional[List[str]] = None,
22
+ rules_path: Optional[str] = None
23
+ ):
24
+ """
25
+ Initialize verdict engine
26
+
27
+ Args:
28
+ green_classes: List of class names that trigger green verdict
29
+ red_classes: List of class names that trigger red verdict
30
+ rules_path: Optional path to JSON file with custom rules
31
+ """
32
+ self.green_classes = green_classes or []
33
+ self.red_classes = red_classes or []
34
+ self.custom_rules = {}
35
+
36
+ # Load custom rules if provided
37
+ if rules_path:
38
+ try:
39
+ self.load_rules(rules_path)
40
+ except FileNotFoundError:
41
+ logger.warning(f"Rules file not found: {rules_path}")
42
+
43
+ # Default configuration if no rules specified
44
+ if not self.green_classes and not self.red_classes and not self.custom_rules:
45
+ logger.info("No verdict rules specified. Using default configuration.")
46
+ self._setup_default_rules()
47
+
48
+ def _setup_default_rules(self):
49
+ """
50
+ Setup default verdict rules
51
+ Customize this based on your specific use case
52
+ """
53
+ # Example default rules:
54
+ # - If all 3 classes detected: GREEN
55
+ # - If only specific classes detected: RED
56
+ # - If no classes detected: RED
57
+
58
+ self.custom_rules = {
59
+ "all_classes_required": False, # Set to True if all 3 classes must be present for GREEN
60
+ "min_detections_for_green": 1, # Minimum detections needed for GREEN
61
+ "min_confidence_for_green": 0.5, # Minimum confidence for GREEN verdict
62
+ "priority": "red" # If both green and red classes detected, which takes priority
63
+ }
64
+
65
+ logger.info("Default verdict rules configured")
66
+
67
+ def load_rules(self, rules_path: str):
68
+ """Load verdict rules from JSON file"""
69
+ with open(rules_path, 'r') as f:
70
+ rules = json.load(f)
71
+
72
+ self.green_classes = rules.get("green_classes", self.green_classes)
73
+ self.red_classes = rules.get("red_classes", self.red_classes)
74
+ self.custom_rules = rules.get("custom_rules", self.custom_rules)
75
+
76
+ logger.info(f"Loaded rules from {rules_path}")
77
+ logger.info(f"Green classes: {self.green_classes}")
78
+ logger.info(f"Red classes: {self.red_classes}")
79
+
80
+ def get_verdict(self, detections: List[Dict]) -> Dict:
81
+ """
82
+ Determine verdict based on detections
83
+
84
+ New Logic:
85
+ - GREEN: All 3 labels present (jio jersey, jio logo, person) AND exactly 1 person
86
+ - RED: Any condition not met
87
+
88
+ Args:
89
+ detections: List of detection dictionaries from YOLO
90
+
91
+ Returns:
92
+ Dictionary with verdict, confidence, and message
93
+ """
94
+ if len(detections) == 0:
95
+ return {
96
+ "verdict": "red",
97
+ "confidence": 1.0,
98
+ "message": "Image rejected. No objects detected in the image. Please ensure the image contains: Jio jersey, Jio logo, and exactly one person."
99
+ }
100
+
101
+ # Extract detected class names and confidences
102
+ detected_classes = [det["class_name"] for det in detections]
103
+ confidences = [det["confidence"] for det in detections]
104
+ avg_confidence = sum(confidences) / len(confidences)
105
+
106
+ # Check which required labels are present
107
+ unique_classes = set(detected_classes)
108
+
109
+ # Count number of persons detected
110
+ person_count = detected_classes.count("person")
111
+
112
+ # Check conditions
113
+ has_jio_jersey = "jio jersey" in unique_classes
114
+ has_jio_logo = "jio logo" in unique_classes
115
+ has_person = "person" in unique_classes
116
+ has_exactly_one_person = person_count == 1
117
+
118
+ # All conditions must be satisfied for GREEN
119
+ all_labels_present = has_jio_jersey and has_jio_logo and has_person
120
+ all_conditions_met = all_labels_present and has_exactly_one_person
121
+
122
+ logger.info(f"Detected classes: {detected_classes}")
123
+ logger.info(f"Unique classes: {unique_classes}")
124
+ logger.info(f"Person count: {person_count}")
125
+ logger.info(f"All conditions met: {all_conditions_met}")
126
+
127
+ # Build user-friendly message
128
+ if all_conditions_met:
129
+ message = "Image approved! All requirements met: Jio jersey found, Jio logo found, and exactly one person detected."
130
+ verdict = "green"
131
+ else:
132
+ # Build detailed explanation of what's missing or wrong
133
+ issues = []
134
+
135
+ if not has_jio_jersey:
136
+ issues.append("Jio jersey not detected")
137
+ if not has_jio_logo:
138
+ issues.append("Jio logo not detected")
139
+ if not has_person:
140
+ issues.append("No person detected")
141
+ elif person_count == 0:
142
+ issues.append("No person detected")
143
+ elif person_count > 1:
144
+ issues.append(f"Multiple people detected ({person_count} found, need exactly 1)")
145
+
146
+ # Build the message
147
+ if issues:
148
+ message = f"Image rejected. {'. '.join(issues)}. Requirements: Jio jersey, Jio logo, and exactly one person must be present."
149
+ else:
150
+ message = "Image rejected. Please ensure all requirements are met: Jio jersey, Jio logo, and exactly one person."
151
+
152
+ verdict = "red"
153
+
154
+ return {
155
+ "verdict": verdict,
156
+ "confidence": avg_confidence,
157
+ "message": message
158
+ }
159
+
160
+ def _apply_verdict_logic(
161
+ self,
162
+ detected_classes: List[str],
163
+ unique_classes: set,
164
+ num_unique_classes: int,
165
+ max_confidence: float,
166
+ avg_confidence: float,
167
+ num_detections: int
168
+ ) -> tuple:
169
+ """
170
+ Apply verdict logic based on detected classes
171
+
172
+ Returns:
173
+ Tuple of (verdict, message)
174
+ """
175
+ # Strategy 1: Check if specific green/red classes are defined
176
+ if self.green_classes or self.red_classes:
177
+ return self._apply_class_based_logic(
178
+ detected_classes, unique_classes, avg_confidence
179
+ )
180
+
181
+ # Strategy 2: Use custom rules
182
+ if self.custom_rules:
183
+ return self._apply_custom_rules(
184
+ num_unique_classes, num_detections, avg_confidence
185
+ )
186
+
187
+ # Strategy 3: Default logic (customize based on your needs)
188
+ return self._apply_default_logic(
189
+ num_unique_classes, num_detections, avg_confidence
190
+ )
191
+
192
+ def _apply_class_based_logic(
193
+ self,
194
+ detected_classes: List[str],
195
+ unique_classes: set,
196
+ avg_confidence: float
197
+ ) -> tuple:
198
+ """Apply logic based on specific green/red class lists"""
199
+
200
+ # Check for red classes first (if priority is red)
201
+ priority = self.custom_rules.get("priority", "red")
202
+
203
+ if priority == "red" and self.red_classes:
204
+ red_detected = any(cls in self.red_classes for cls in detected_classes)
205
+ if red_detected:
206
+ return (
207
+ "red",
208
+ f"Detected prohibited object(s): {', '.join(unique_classes & set(self.red_classes))}"
209
+ )
210
+
211
+ # Check for green classes
212
+ if self.green_classes:
213
+ green_detected = any(cls in self.green_classes for cls in detected_classes)
214
+ if green_detected:
215
+ return (
216
+ "green",
217
+ f"Detected approved object(s): {', '.join(unique_classes & set(self.green_classes))}"
218
+ )
219
+
220
+ # Check for red classes (if priority is green)
221
+ if priority == "green" and self.red_classes:
222
+ red_detected = any(cls in self.red_classes for cls in detected_classes)
223
+ if red_detected:
224
+ return (
225
+ "red",
226
+ f"Detected prohibited object(s): {', '.join(unique_classes & set(self.red_classes))}"
227
+ )
228
+
229
+ # Default to red if no matching classes
230
+ return (
231
+ "red",
232
+ f"Detected objects do not match approved criteria: {', '.join(unique_classes)}"
233
+ )
234
+
235
+ def _apply_custom_rules(
236
+ self,
237
+ num_unique_classes: int,
238
+ num_detections: int,
239
+ avg_confidence: float
240
+ ) -> tuple:
241
+ """Apply custom rules from configuration"""
242
+
243
+ all_classes_required = self.custom_rules.get("all_classes_required", False)
244
+ min_detections = self.custom_rules.get("min_detections_for_green", 1)
245
+ min_confidence = self.custom_rules.get("min_confidence_for_green", 0.5)
246
+
247
+ # Check if all 3 classes are required
248
+ if all_classes_required:
249
+ if num_unique_classes == 3:
250
+ if avg_confidence >= min_confidence:
251
+ return (
252
+ "green",
253
+ f"All required classes detected with confidence {avg_confidence:.2%}"
254
+ )
255
+ else:
256
+ return (
257
+ "red",
258
+ f"All classes detected but confidence too low: {avg_confidence:.2%}"
259
+ )
260
+ else:
261
+ return (
262
+ "red",
263
+ f"Only {num_unique_classes}/3 required classes detected"
264
+ )
265
+
266
+ # Check minimum detections
267
+ if num_detections >= min_detections and avg_confidence >= min_confidence:
268
+ return (
269
+ "green",
270
+ f"Detected {num_detections} object(s) with confidence {avg_confidence:.2%}"
271
+ )
272
+
273
+ return (
274
+ "red",
275
+ f"Insufficient detections ({num_detections}) or confidence ({avg_confidence:.2%})"
276
+ )
277
+
278
+ def _apply_default_logic(
279
+ self,
280
+ num_unique_classes: int,
281
+ num_detections: int,
282
+ avg_confidence: float
283
+ ) -> tuple:
284
+ """
285
+ Apply default verdict logic
286
+
287
+ Default strategy:
288
+ - GREEN: If at least 1 class detected with good confidence
289
+ - RED: If no detections or low confidence
290
+
291
+ Customize this method based on your specific requirements
292
+ """
293
+
294
+ # Example 1: Simple presence-based logic
295
+ if num_detections >= 1 and avg_confidence >= 0.5:
296
+ return (
297
+ "green",
298
+ f"Detected {num_detections} object(s) across {num_unique_classes} class(es)"
299
+ )
300
+
301
+ # Example 2: All classes must be present
302
+ # if num_unique_classes == 3 and avg_confidence >= 0.6:
303
+ # return ("green", "All 3 classes detected with high confidence")
304
+
305
+ # Example 3: Specific number of detections
306
+ # if num_detections >= 2 and avg_confidence >= 0.5:
307
+ # return ("green", f"Multiple objects detected ({num_detections})")
308
+
309
+ return (
310
+ "red",
311
+ f"Criteria not met: {num_detections} detection(s), {num_unique_classes} class(es)"
312
+ )
313
+
314
+ def set_green_classes(self, classes: List[str]):
315
+ """Set classes that trigger green verdict"""
316
+ self.green_classes = classes
317
+ logger.info(f"Green classes set to: {classes}")
318
+
319
+ def set_red_classes(self, classes: List[str]):
320
+ """Set classes that trigger red verdict"""
321
+ self.red_classes = classes
322
+ logger.info(f"Red classes set to: {classes}")
323
+
324
+ def update_rules(self, rules: Dict):
325
+ """Update custom rules"""
326
+ self.custom_rules.update(rules)
327
+ logger.info(f"Custom rules updated: {self.custom_rules}")