Somin-Aggarwal committed on
Commit
01a014b
·
verified ·
1 Parent(s): f84ee46

Upload 3 files

Browse files
Files changed (3) hide show
  1. server/corruption.py +251 -0
  2. server/environment.py +499 -0
  3. server/grader.py +148 -0
server/corruption.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Annotation corruption strategies for the Annotation QA Environment.
3
+
4
+ Takes gold-standard COCO annotations and systematically corrupts them to create
5
+ data with known errors. The corruption is deterministic given a seed.
6
+
7
+ Corruption types by difficulty:
8
+ - Task 1 ("spurious", easy): adds 2-4 entirely fake boxes; real boxes are left untouched
9
+ - Task 2 ("classes", medium): class-label swaps (similar and unrelated classes) plus 1-2 fake boxes
10
+ - Task 3 ("missing", hard): deletes 15-20% of annotations and swaps some of the remaining class labels
11
+ """
12
+
13
+ import copy
14
+ import random
15
+ from typing import Dict, List, Tuple
16
+
17
+ # ──────────────────────────────────────────────
18
+ # COCO 80 categories
19
+ # ──────────────────────────────────────────────
20
+
21
# Every class label this module may assign. It is the pool that spurious
# annotations and random class swaps are drawn from, so the order of the
# entries influences which label a given seed produces.
ALL_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep",
    "cow", "elephant", "bear", "zebra", "giraffe",
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana",
    "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair",
    "couch", "potted plant", "bed", "dining table",
    "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]
41
+
42
+ # Class confusion maps — COCO-specific similar category pairs
43
# Maps a class label to the labels considered easy to confuse with it.
# corrupt_annotations() looks a label up here when it wants a "plausible"
# wrong class; a label without an entry falls back to a random different
# class from ALL_CLASSES.
SIMILAR_CLASSES: Dict[str, List[str]] = {
    "car": ["truck", "bus"],
    "truck": ["car", "bus"],
    "bus": ["truck", "car"],
    "motorcycle": ["bicycle"],
    "bicycle": ["motorcycle"],
    "dog": ["cat", "horse"],
    "cat": ["dog"],
    "horse": ["cow", "dog"],
    "cow": ["horse", "sheep"],
    "sheep": ["cow"],
    "elephant": ["bear"],
    "bear": ["elephant"],
    "zebra": ["giraffe", "horse"],
    "giraffe": ["zebra"],
    "bird": ["airplane", "kite"],
    "airplane": ["bird", "kite"],
    "chair": ["couch", "bench"],
    "couch": ["chair", "bed"],
    "bed": ["couch"],
    "bench": ["chair"],
    "dining table": ["bed"],
    "bottle": ["cup", "wine glass", "vase"],
    "cup": ["bottle", "wine glass", "bowl"],
    "wine glass": ["cup", "bottle"],
    "bowl": ["cup"],
    "fork": ["knife", "spoon"],
    "knife": ["fork", "spoon", "scissors"],
    "spoon": ["fork", "knife"],
    "scissors": ["knife"],
    "banana": ["hot dog"],
    "hot dog": ["banana", "sandwich"],
    "pizza": ["cake", "donut"],
    "donut": ["pizza", "cake", "apple", "orange"],
    "cake": ["pizza", "donut"],
    "apple": ["orange", "donut", "sports ball"],
    "orange": ["apple", "donut", "sports ball"],
    "sandwich": ["hot dog", "pizza"],
    "broccoli": ["potted plant"],
    "carrot": ["banana"],
    "potted plant": ["broccoli", "vase"],
    "tv": ["laptop", "microwave"],
    "laptop": ["tv", "keyboard"],
    "keyboard": ["laptop", "remote"],
    "remote": ["cell phone", "keyboard"],
    "cell phone": ["remote"],
    "mouse": ["remote"],
    "microwave": ["oven", "tv"],
    "oven": ["microwave", "refrigerator"],
    "toaster": ["microwave"],
    "refrigerator": ["oven"],
    "sink": ["toilet", "bowl"],
    "toilet": ["sink", "chair"],
    "book": ["laptop", "cell phone"],
    "clock": ["sports ball"],
    "vase": ["bottle", "cup"],
    "backpack": ["suitcase", "handbag"],
    "handbag": ["backpack", "suitcase"],
    "suitcase": ["backpack", "handbag"],
    "umbrella": ["kite"],
    "tie": ["person"],
    "frisbee": ["sports ball", "kite"],
    "sports ball": ["frisbee", "apple", "orange"],
    "kite": ["bird", "umbrella", "frisbee"],
    "baseball bat": ["tennis racket", "surfboard"],
    "baseball glove": ["backpack"],
    "skateboard": ["surfboard", "snowboard"],
    "surfboard": ["skateboard", "snowboard"],
    "snowboard": ["skateboard", "surfboard", "skis"],
    "skis": ["snowboard"],
    "teddy bear": ["person", "dog"],
    "hair drier": ["toothbrush"],
    "toothbrush": ["hair drier"],
    "person": ["teddy bear"],
    "train": ["bus", "truck"],
    "boat": ["surfboard"],
    "traffic light": ["fire hydrant", "parking meter", "stop sign"],
    "fire hydrant": ["traffic light", "parking meter"],
    "stop sign": ["traffic light", "parking meter"],
    "parking meter": ["fire hydrant", "stop sign"],
}
124
+
125
+
126
def generate_spurious_annotation(
    existing_bboxes: List[List[float]], rng: random.Random
) -> Dict:
    """Generate a fake annotation that overlaps little with existing boxes.

    Up to 20 candidate boxes are sampled. Each has a width and height drawn
    from [0.05, 0.20] and is positioned so the unrounded box lies inside the
    unit square. The first candidate whose IoU with every box in
    ``existing_bboxes`` is below 0.3 is returned with a random class label.
    If no candidate qualifies, a fixed-size 0.1 x 0.1 box is returned
    regardless of overlap.

    Args:
        existing_bboxes: Boxes ``[x, y, w, h]`` the new box should avoid.
        rng: Random source; all randomness is drawn from it so the result
            is reproducible for a given generator state.

    Returns:
        A dict with ``"bbox"`` (values rounded to 4 decimals) and
        ``"class_label"``. The caller is expected to assign the ``"id"``.
    """
    # Imported once per call instead of once per retry; kept function-local
    # as in the original so module import order is unaffected.
    from .grader import compute_iou

    for _ in range(20):  # try up to 20 times
        w = rng.uniform(0.05, 0.20)
        h = rng.uniform(0.05, 0.20)
        x = rng.uniform(0.0, 1.0 - w)
        y = rng.uniform(0.0, 1.0 - h)
        bbox = [round(x, 4), round(y, 4), round(w, 4), round(h, 4)]

        # Accept the candidate only if it does not overlap an existing box
        # too much; with no existing boxes the default makes it pass.
        max_iou = max(
            (compute_iou(bbox, eb) for eb in existing_bboxes), default=0.0
        )
        if max_iou < 0.3:
            cls = rng.choice(ALL_CLASSES)
            return {"bbox": bbox, "class_label": cls}

    # Fallback: place it anyway
    return {
        "bbox": [round(rng.uniform(0.0, 0.8), 4), round(rng.uniform(0.0, 0.8), 4), 0.1, 0.1],
        "class_label": rng.choice(ALL_CLASSES),
    }
152
+
153
+
154
def _any_other_class(current: str, rng: random.Random) -> str:
    """Pick a random class label from ALL_CLASSES that differs from ``current``."""
    return rng.choice([name for name in ALL_CLASSES if name != current])


def _append_spurious(
    annotations: List[Dict], log: List[str], rng: random.Random, low: int, high: int
) -> None:
    """Append between ``low`` and ``high`` fake annotations in place, logging each."""
    taken = [entry["bbox"] for entry in annotations]
    how_many = rng.randint(low, high)
    first_id = max((entry["id"] for entry in annotations), default=0) + 1
    for offset in range(how_many):
        fake = generate_spurious_annotation(taken, rng)
        fake["id"] = first_id + offset
        annotations.append(fake)
        # Later fakes must also avoid the ones just added.
        taken.append(fake["bbox"])
        log.append(f"Added spurious ann {fake['id']} ({fake['class_label']})")


def corrupt_annotations(
    gold_annotations: List[Dict],
    difficulty: str,
    seed: int,
) -> Tuple[List[Dict], List[str]]:
    """
    Produce a corrupted copy of the gold annotations plus a change log.

    The input list is never modified; a deep copy is corrupted instead.
    All randomness comes from ``random.Random(seed)``, so the same inputs
    give the same output. Bounding boxes of real annotations are not moved.

    Difficulties:
    - "spurious": appends 2-4 fake boxes.
    - "classes": relabels about 30% of annotations (at least 2 when
      available) with a similar or an unrelated class, then appends
      1-2 fake boxes.
    - "missing": removes 15-20% of annotations (at least 1 when available),
      then relabels about 20% of the survivors with a random other class.

    Any other difficulty value returns the untouched copy and an empty log.

    Returns:
        ``(corrupted_annotations, log_messages)``.
    """
    rng = random.Random(seed)
    corrupted = copy.deepcopy(gold_annotations)
    log: List[str] = []

    if difficulty == "spurious":
        # Task 1: only fake boxes are introduced.
        _append_spurious(corrupted, log, rng, 2, 4)
        return corrupted, log

    if difficulty == "classes":
        # Task 2: relabel a shuffled subset, then sprinkle in a few fakes.
        quota = max(2, int(len(corrupted) * 0.30))
        order = list(range(len(corrupted)))
        rng.shuffle(order)

        for position in order[:quota]:
            mode = rng.choice(["wrong_similar_class", "wrong_different_class"])
            target = corrupted[position]
            previous = target["class_label"]

            if mode == "wrong_similar_class":
                lookalikes = SIMILAR_CLASSES.get(previous, [])
                if lookalikes:
                    target["class_label"] = rng.choice(lookalikes)
                    tag = "similar"
                else:
                    # No confusable class is known; use any other class.
                    target["class_label"] = _any_other_class(previous, rng)
                    tag = "fallback"
            else:
                target["class_label"] = _any_other_class(previous, rng)
                tag = "different"

            log.append(
                f"Changed ann {target['id']} class: {previous} → {target['class_label']} ({tag})"
            )

        _append_spurious(corrupted, log, rng, 1, 2)
        return corrupted, log

    if difficulty == "missing":
        # Task 3: drop a random share of annotations entirely.
        share = rng.uniform(0.15, 0.20)
        quota = max(1, int(len(corrupted) * share))
        order = list(range(len(corrupted)))
        rng.shuffle(order)

        for position in order[:quota]:
            victim = corrupted[position]
            log.append(
                f"Missing Obj Created: Removed ann {victim['id']} ({victim['class_label']})"
            )
            corrupted[position] = None  # tombstone, filtered out below

        corrupted = [entry for entry in corrupted if entry is not None]

        # Then add a little class confusion among the survivors.
        quota = max(1, int(len(corrupted) * 0.20))
        order = list(range(len(corrupted)))
        rng.shuffle(order)
        for position in order[:quota]:
            target = corrupted[position]
            previous = target["class_label"]
            target["class_label"] = _any_other_class(previous, rng)
            log.append(f"Changed class: {previous} -> {target['class_label']}")

    return corrupted, log
server/environment.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Annotation QA Environment — Core Environment Logic.
3
+
4
+ Implements the OpenEnv 3-method interface:
5
+ - reset(task_id) → Observation
6
+ - step(action) → Observation
7
+ - state → State
8
+
9
+ The agent reviews intentionally-flawed annotations on real COCO val2017 images
10
+ and must correct bounding boxes, fix class labels, add missing annotations,
11
+ or remove spurious ones. Dense reward is provided at every step.
12
+ """
13
+
14
+ import copy
15
+ import json
16
+ import os
17
+ import random
18
+ from pathlib import Path
19
+ from typing import Any, Dict, List, Optional
20
+ from uuid import uuid4
21
+
22
+ try:
23
+ from openenv.core.env_server.types import Action, Observation, State
24
+ except ImportError:
25
+ # Fallback for standalone
26
+ pass
27
+
28
+ try:
29
+ from ..models import (
30
+ Annotation,
31
+ AnnotationQAAction,
32
+ AnnotationQAObservation,
33
+ AnnotationQAState,
34
+ )
35
+ except ImportError:
36
+ from models import (
37
+ Annotation,
38
+ AnnotationQAAction,
39
+ AnnotationQAObservation,
40
+ AnnotationQAState,
41
+ )
42
+ from .corruption import ALL_CLASSES, corrupt_annotations
43
+ from .grader import (
44
+ compute_annotation_quality,
45
+ compute_step_reward,
46
+ grade_episode,
47
+ )
48
+
49
+
50
+ # ──────────────────────────────────────────────
51
+ # Task definitions
52
+ # ──────────────────────────────────────────────
53
+
54
# Registry of tasks, keyed by the task id accepted by reset().
# Per task:
#   description - text surfaced to the agent as the task description
#   difficulty  - corruption mode handed to corrupt_annotations()
#   max_steps   - step budget after which the episode is force-graded
#   data_file   - sample file, relative to the environment's data/tasks dir
TASK_CONFIGS = {
    "remove_spurious": {
        "description": (
            "Spurious Box Removal Task. Fake bounding boxes have been randomly drawn. "
            "Identify and remove any annotations that do not strictly bound a real object."
        ),
        "difficulty": "spurious",
        "max_steps": 15,
        "data_file": "task1_remove_spurious/samples.json",
    },
    "fix_classes": {
        "description": (
            "Class Identification Task. Some bounding boxes have incorrect class labels, "
            "and some are completely fake (spurious). Fix class labels using "
            "CHANGE_CLASS and REMOVE spurious labels."
        ),
        "difficulty": "classes",
        "max_steps": 20,
        "data_file": "task2_fix_classes/samples.json",
    },
    "find_missing": {
        "description": (
            "Contextual Object Detection Task. Bounding boxes for key objects have been "
            "entirely removed from the image. You must meticulously identify what object classes "
            "are completely missing from the drawn bounding boxes and flag them."
        ),
        "difficulty": "missing",
        "max_steps": 30,
        "data_file": "task3_find_missing/samples.json",
    },
}
85
+
86
+
87
class AnnotationQAEnvironment:
    """
    Annotation QA Environment following the OpenEnv pattern.

    An episode starts with ``reset()``, which picks a sample, corrupts its
    gold annotations and exposes the corrupted set to the agent. The agent
    then edits the annotations through ``step()`` actions until it submits
    or the task's step budget is exhausted, at which point the episode is
    graded against the gold annotations.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        self._state = AnnotationQAState()
        self._gold_annotations: List[Dict] = []
        self._initial_annotations: List[Dict] = []
        self._current_annotations: List[Dict] = []
        self._scene_data: Dict[str, Any] = {}
        self._task_config: Dict[str, Any] = {}
        self._corrections_made: int = 0
        self._done: bool = False
        self._data_cache: Dict[str, Any] = {}
        self._next_ann_id: int = 0

        # Task sample files live in <package root>/data/tasks.
        self._data_dir = Path(__file__).parent.parent / "data" / "tasks"

    def _load_task_data(self, task_id: str) -> List[Dict]:
        """Load the samples for ``task_id`` from disk, caching them per instance.

        Raises:
            KeyError: If ``task_id`` is not a key of ``TASK_CONFIGS``.
            FileNotFoundError: If the task's data file does not exist.
        """
        if task_id in self._data_cache:
            return self._data_cache[task_id]

        config = TASK_CONFIGS[task_id]
        data_file = self._data_dir / config["data_file"]

        if not data_file.exists():
            raise FileNotFoundError(
                f"Task data file not found: {data_file}. "
                f"Run 'python -m data.prepare_coco' to generate the COCO dataset."
            )

        # Explicit encoding so JSON loading does not depend on the locale.
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        self._data_cache[task_id] = data
        return data

    def _first_free_annotation_id(self) -> int:
        """Return an id larger than every current AND every gold annotation id.

        Gold ids must be considered as well: corruption can delete the
        annotation holding the highest id, and the grader matches
        annotations to gold by id. Reissuing a deleted gold id to an
        agent-created annotation would make that annotation count as the
        deleted gold object.
        """
        known_ids = [a["id"] for a in self._current_annotations]
        known_ids.extend(a["id"] for a in self._gold_annotations)
        return max(known_ids, default=-1) + 1

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task: Optional[str] = None,
        **kwargs: Any,
    ) -> AnnotationQAObservation:
        """
        Start a new episode.

        Args:
            seed: Random seed for reproducible sample selection; when None
                an unseeded generator is used.
            episode_id: Optional episode ID; a UUID4 string is generated
                when omitted.
            task: Task ID, one of the keys of ``TASK_CONFIGS``
                ("remove_spurious", "fix_classes", "find_missing").
                ``task_id`` in ``kwargs`` is used when ``task`` is not given.
                Unknown or absent ids fall back to "remove_spurious".

        Returns:
            The initial observation (its reward is ``None``).
        """
        task_id = task or kwargs.get("task_id", "remove_spurious")
        if task_id not in TASK_CONFIGS:
            task_id = "remove_spurious"

        self._task_config = TASK_CONFIGS[task_id]
        data = self._load_task_data(task_id)

        # Select a random sample
        rng = random.Random(seed) if seed is not None else random.Random()

        scene = rng.choice(data)
        # A sample may pin its own corruption seed; otherwise draw one.
        sample_seed = scene.get("seed", rng.randint(0, 99999))

        # Store gold annotations
        self._gold_annotations = copy.deepcopy(scene["gold_annotations"])
        self._scene_data = scene

        # Create corrupted annotations (the corruption log is not used here)
        corrupted, corruption_log = corrupt_annotations(
            self._gold_annotations,
            self._task_config["difficulty"],
            sample_seed,
        )
        self._initial_annotations = copy.deepcopy(corrupted)
        self._current_annotations = copy.deepcopy(corrupted)
        self._corrections_made = 0
        self._done = False

        # Ids handed to agent-created annotations must never collide with a
        # gold id, including gold ids removed by the corruption step.
        self._next_ann_id = self._first_free_annotation_id()

        # Compute initial quality
        initial_quality = compute_annotation_quality(
            self._initial_annotations, self._gold_annotations
        )

        self._state = AnnotationQAState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_id=task_id,
            sample_id=scene.get("scene_id", "unknown"),
            initial_quality=round(initial_quality, 4),
            current_quality=round(initial_quality, 4),
            corrections_made=0,
        )

        return self._build_observation(
            reward=None,
            message=(
                f"Review the annotations for this COCO image. "
                f"There are {len(self._current_annotations)} annotations. "
                f"Some may have incorrect bounding boxes, wrong class labels, "
                f"or be entirely spurious. Some objects may be missing annotations. "
                f"You have {self._task_config['max_steps']} steps to fix them."
            ),
        )

    def step(
        self,
        action: AnnotationQAAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> AnnotationQAObservation:
        """Execute one action and return the updated observation with reward.

        A finished episode returns a zero-reward observation. ``submit``
        ends the episode immediately with the final grade as reward. Any
        other action is applied (or rejected with an error message), gets
        a per-step reward, and ends the episode with the final grade once
        the task's ``max_steps`` is reached. ``timeout_s`` and ``kwargs``
        are accepted but not used.
        """
        if self._done:
            return self._build_observation(
                reward=0.0,
                message="Episode is already done. Call reset() to start a new episode.",
            )

        self._state.step_count += 1
        error_msg = None

        # Save pre-action state for reward computation
        old_annotations = copy.deepcopy(self._current_annotations)

        # Process action; handlers return None on success or an error string.
        try:
            if action.action_type == "adjust_bbox":
                error_msg = self._handle_adjust_bbox(action)
            elif action.action_type == "change_class":
                error_msg = self._handle_change_class(action)
            elif action.action_type == "add_annotation":
                error_msg = self._handle_add_annotation(action)
            elif action.action_type == "remove_annotation":
                error_msg = self._handle_remove_annotation(action)
            elif action.action_type == "submit":
                return self._handle_submit()
            elif action.action_type == "flag_safety":
                error_msg = self._handle_flag_safety(action)
            elif action.action_type == "change_attribute":
                error_msg = self._handle_change_attribute(action)
            elif action.action_type == "flag_missing":
                error_msg = self._handle_flag_missing(action)
            else:
                error_msg = f"Unknown action_type: {action.action_type}"
        except Exception as e:
            # Any handler failure is reported to the agent as an action
            # error instead of crashing the episode.
            error_msg = f"Error processing action: {str(e)}"

        if error_msg is None:
            self._corrections_made += 1
            self._state.corrections_made = self._corrections_made

        # Compute reward.
        # NOTE(review): the three flat rewards below are granted whenever
        # the handler reports no error, regardless of whether the flag or
        # attribute is correct - confirm this is intended.
        if action.action_type == "flag_safety" and not error_msg:
            reward = 0.20
        elif action.action_type == "change_attribute" and not error_msg:
            reward = 0.15
        elif action.action_type == "flag_missing" and not error_msg:
            reward = 0.25
        else:
            reward = compute_step_reward(
                old_annotations,
                self._current_annotations,
                self._gold_annotations,
                action.action_type,
            )

        # Update quality tracking
        current_quality = compute_annotation_quality(
            self._current_annotations, self._gold_annotations
        )
        self._state.current_quality = round(current_quality, 4)

        # Check if max steps reached; the final grade replaces the step reward.
        if self._state.step_count >= self._task_config["max_steps"]:
            self._done = True
            final_score = grade_episode(
                self._initial_annotations,
                self._current_annotations,
                self._gold_annotations,
            )
            return self._build_observation(
                reward=final_score,
                message=f"Max steps reached. Final score: {final_score:.3f}",
                error=error_msg,
            )

        return self._build_observation(
            reward=reward,
            message=(
                f"{'Error: ' + error_msg if error_msg else 'Correction applied.'} "
                f"Quality: {current_quality:.3f} "
                f"(was {self._state.initial_quality:.3f}). "
                f"Steps remaining: {self._task_config['max_steps'] - self._state.step_count}"
            ),
            error=error_msg,
        )

    @property
    def state(self) -> AnnotationQAState:
        """Get current episode state."""
        return self._state

    def close(self) -> None:
        """Clean up environment resources (nothing to release)."""
        pass

    async def reset_async(self, **kwargs) -> AnnotationQAObservation:
        """Async wrapper for reset (required by OpenEnv server interface)."""
        return self.reset(**kwargs)

    async def step_async(self, action: AnnotationQAAction, **kwargs) -> AnnotationQAObservation:
        """Async wrapper for step (required by OpenEnv server interface)."""
        return self.step(action, **kwargs)

    # ──────────────────────────────────────────
    # Action handlers
    # Each returns None on success or an error message string.
    # ──────────────────────────────────────────

    def _handle_adjust_bbox(self, action: AnnotationQAAction) -> Optional[str]:
        """Replace the bounding box of an existing annotation."""
        if action.annotation_id is None:
            return "annotation_id is required for adjust_bbox"
        if action.new_bbox is None:
            return "new_bbox is required for adjust_bbox"
        if len(action.new_bbox) != 4:
            return "new_bbox must have exactly 4 values [x, y, w, h]"

        ann = self._find_annotation(action.annotation_id)
        if ann is None:
            return f"Annotation {action.annotation_id} not found"

        # Validate bbox values
        for v in action.new_bbox:
            if not (0.0 <= v <= 1.0):
                return "All bbox values must be between 0.0 and 1.0"

        ann["bbox"] = [round(v, 4) for v in action.new_bbox]
        return None

    def _handle_change_class(self, action: AnnotationQAAction) -> Optional[str]:
        """Change the class label of an existing annotation."""
        if action.annotation_id is None:
            return "annotation_id is required for change_class"
        if action.new_class is None:
            return "new_class is required for change_class"
        if action.new_class not in ALL_CLASSES:
            return f"Invalid class '{action.new_class}'. Valid: {ALL_CLASSES}"

        ann = self._find_annotation(action.annotation_id)
        if ann is None:
            return f"Annotation {action.annotation_id} not found"

        ann["class_label"] = action.new_class
        return None

    def _handle_add_annotation(self, action: AnnotationQAAction) -> Optional[str]:
        """Add a new annotation with a freshly allocated id."""
        if action.new_bbox is None:
            return "new_bbox is required for add_annotation"
        if action.new_class is None:
            return "new_class is required for add_annotation"
        if len(action.new_bbox) != 4:
            return "new_bbox must have exactly 4 values [x, y, w, h]"
        if action.new_class not in ALL_CLASSES:
            return f"Invalid class '{action.new_class}'. Valid: {ALL_CLASSES}"

        for v in action.new_bbox:
            if not (0.0 <= v <= 1.0):
                return "All bbox values must be between 0.0 and 1.0"

        new_ann = {
            "id": self._next_ann_id,
            "bbox": [round(v, 4) for v in action.new_bbox],
            "class_label": action.new_class,
        }
        self._current_annotations.append(new_ann)
        self._next_ann_id += 1
        return None

    def _handle_remove_annotation(self, action: AnnotationQAAction) -> Optional[str]:
        """Remove an annotation by id."""
        if action.annotation_id is None:
            return "annotation_id is required for remove_annotation"

        idx = self._find_annotation_index(action.annotation_id)
        if idx is None:
            return f"Annotation {action.annotation_id} not found"

        self._current_annotations.pop(idx)
        return None

    def _handle_submit(self) -> AnnotationQAObservation:
        """End the episode and return the final grade as the reward."""
        self._done = True
        final_score = grade_episode(
            self._initial_annotations,
            self._current_annotations,
            self._gold_annotations,
        )

        return self._build_observation(
            reward=final_score,
            message=(
                f"Corrections submitted! "
                f"Final score: {final_score:.3f}. "
                f"Quality went from {self._state.initial_quality:.3f} "
                f"to {self._state.current_quality:.3f} over "
                f"{self._state.step_count} steps."
            ),
        )

    def _handle_flag_safety(self, action: AnnotationQAAction) -> Optional[str]:
        """Mark an existing annotation as safety-flagged."""
        if action.annotation_id is None:
            return "annotation_id is required for flag_safety"
        ann = self._find_annotation(action.annotation_id)
        if ann is None:
            return "Annotation not found"
        # Bbox and class stay untouched; only a metadata key is added to
        # the annotation dict.
        ann["safety_flagged"] = True
        return None

    def _handle_change_attribute(self, action: AnnotationQAAction) -> Optional[str]:
        """Overwrite an annotation's class label with ``new_attribute``.

        NOTE(review): unlike change_class, the value is not validated
        against ALL_CLASSES - confirm this is intended.
        """
        if action.annotation_id is None:
            return "annotation_id is required for change_attribute"
        if not action.new_attribute:
            return "new_attribute is required"
        ann = self._find_annotation(action.annotation_id)
        if ann is None:
            return "Annotation not found"
        ann["class_label"] = action.new_attribute
        return None

    def _handle_flag_missing(self, action: AnnotationQAAction) -> Optional[str]:
        """Record that an object of ``missing_class`` lacks an annotation."""
        if not action.missing_class:
            return "missing_class is required for flag_missing"
        # The flag is stored as a placeholder annotation: an empty box whose
        # label carries the "missing_" prefix the grader looks for.
        self._current_annotations.append({
            "id": self._next_ann_id,
            "bbox": [0, 0, 0, 0],
            "class_label": f"missing_{action.missing_class}"
        })
        self._next_ann_id += 1
        return None

    # ──────────────────────────────────────────
    # Helpers
    # ──────────────────────────────────────────

    def _find_annotation(self, ann_id: int) -> Optional[Dict]:
        """Return the current annotation with ``ann_id``, or None."""
        for ann in self._current_annotations:
            if ann["id"] == ann_id:
                return ann
        return None

    def _find_annotation_index(self, ann_id: int) -> Optional[int]:
        """Return the list index of the annotation with ``ann_id``, or None."""
        for i, ann in enumerate(self._current_annotations):
            if ann["id"] == ann_id:
                return i
        return None

    def _build_observation(
        self,
        reward: Optional[float],
        message: str,
        error: Optional[str] = None,
    ) -> AnnotationQAObservation:
        """Build an observation from current state."""
        return AnnotationQAObservation(
            done=self._done,
            reward=reward,
            # Image info from COCO
            image_url=self._scene_data.get("image_url"),
            image_width=self._scene_data.get("image_width", 0),
            image_height=self._scene_data.get("image_height", 0),
            # Scene info
            scene_description=self._scene_data.get("scene_description", ""),
            scene_objects=[
                {
                    "id": obj["id"],
                    "class_label": obj["class_label"],
                    "position": obj.get("position", ""),
                    "bbox": obj["bbox"],
                }
                for obj in self._scene_data.get("objects", [])
            ],
            annotations=[
                Annotation(
                    id=ann["id"],
                    bbox=ann["bbox"],
                    class_label=ann["class_label"],
                )
                for ann in self._current_annotations
            ],
            available_classes=ALL_CLASSES,
            task_id=self._state.task_id,
            task_description=self._task_config.get("description", ""),
            corrections_made=self._corrections_made,
            step_count=self._state.step_count,
            max_steps=self._task_config.get("max_steps", 20),
            message=message,
            last_action_error=error,
        )
server/grader.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Grading utilities for the Annotation QA Environment.
3
+
4
+ Provides deterministic scoring (0.0-1.0) based on:
5
+ - IoU (Intersection over Union) of bounding boxes
6
+ - Class label accuracy
7
+ - Precision (penalizes spurious annotations)
8
+ - Recall (penalizes missed annotations)
9
+
10
+ Predicted annotations are paired with gold annotations by annotation id.
11
+ """
12
+
13
+ from typing import Dict, List, Tuple
14
+
15
+
16
def compute_iou(box_a: List[float], box_b: List[float]) -> float:
    """
    Return the Intersection over Union of two ``[x, y, w, h]`` boxes.

    ``x, y`` is the top-left corner and ``w, h`` the extent. The result is
    0.0 when the boxes do not overlap or when their union area is
    effectively zero (below 1e-8).
    """
    left_a, top_a, width_a, height_a = box_a
    left_b, top_b, width_b, height_b = box_b

    # Overlap extent per axis: nearest far edge minus furthest near edge,
    # clamped at zero for disjoint boxes.
    overlap_w = max(0, min(left_a + width_a, left_b + width_b) - max(left_a, left_b))
    overlap_h = max(0, min(top_a + height_a, top_b + height_b) - max(top_a, top_b))
    overlap = overlap_w * overlap_h

    union = width_a * height_a + width_b * height_b - overlap
    return 0.0 if union < 1e-8 else overlap / union
47
+
48
+
49
def compute_annotation_quality(
    annotations: List[Dict],
    gold_annotations: List[Dict],
) -> float:
    """
    Score a set of annotations against the gold set (0.0-1.0).

    Annotations are paired with gold annotations by ``id``. Annotations
    whose ``class_label`` starts with ``"missing_"`` are flag placeholders,
    not boxes. The score is a weighted sum of:

    - Spurious Precision (35%): share of real (non-placeholder) annotations
      whose id exists in the gold set.
    - Class Match Accuracy (35%): share of id-matched annotations whose
      class label equals the gold label.
    - Missing Flag Recall (30%): share of missing gold objects (gold ids
      with no real annotation) covered by a ``missing_<class>`` flag of the
      same class. Each flag covers at most one missing object. The value is
      1.0 when nothing is missing.

    Missing objects are determined by gold id, not by comparing class-label
    counts, so a mislabelled box can neither hide a genuinely missing
    object nor make its own true class look missing.

    With an empty gold set the result is 1.0 for no annotations and 0.5
    otherwise.
    """
    from collections import Counter

    if not gold_annotations:
        return 1.0 if not annotations else 0.5

    gold_by_id = {g["id"]: g for g in gold_annotations}

    # Separate flag placeholders from real boxes.
    flagged_classes = []
    real = []
    for ann in annotations:
        label = ann.get("class_label", "")
        if label.startswith("missing_"):
            flagged_classes.append(label.replace("missing_", "", 1))
        else:
            real.append(ann)

    matched = [a for a in real if a["id"] in gold_by_id]

    # 1. Spurious precision
    precision = len(matched) / len(real) if real else 0.0

    # 2. Class match accuracy over id-matched boxes
    if matched:
        correct = sum(
            1
            for a in matched
            if a.get("class_label", "") == gold_by_id[a["id"]].get("class_label", "")
        )
        class_acc = correct / len(matched)
    else:
        class_acc = 0.0

    # 3. Missing-object flag recall: a gold object is missing when no real
    #    annotation carries its id.
    present_ids = {a["id"] for a in matched}
    missing_classes = [
        g.get("class_label", "")
        for g in gold_annotations
        if g["id"] not in present_ids
    ]

    if not missing_classes:
        missing_acc = 1.0
    else:
        available_flags = Counter(flagged_classes)
        caught = 0
        for cls in missing_classes:
            if available_flags[cls] > 0:
                caught += 1
                available_flags[cls] -= 1  # one flag covers one object
        missing_acc = caught / len(missing_classes)

    quality = 0.35 * class_acc + 0.35 * precision + 0.30 * missing_acc
    return max(0.0, min(1.0, quality))
110
+
111
+
112
def grade_episode(
    initial_annotations: List[Dict],
    final_annotations: List[Dict],
    gold_annotations: List[Dict],
) -> float:
    """
    Compute the episode grade (0.0–1.0).

    The grade is the fraction of the available quality headroom
    (1.0 minus the initial quality) that the final annotations recovered,
    clamped to [0, 1]. When there was practically no headroom (< 0.01),
    the grade is 1.0 if quality did not drop by more than 0.01, else 0.5.
    """
    before = compute_annotation_quality(initial_annotations, gold_annotations)
    after = compute_annotation_quality(final_annotations, gold_annotations)

    headroom = 1.0 - before
    if headroom < 0.01:
        # Nothing to improve: only check that the agent did no real damage.
        if after >= before - 0.01:
            return 1.0
        return 0.5

    gained = after - before
    return max(0.0, min(1.0, gained / headroom))
130
+
131
+
132
def compute_step_reward(
    old_annotations: List[Dict],
    new_annotations: List[Dict],
    gold_annotations: List[Dict],
    action_type: str,
) -> float:
    """
    Compute the dense per-step reward from the change in quality.

    The reward is twice the quality delta between the old and new
    annotations, minus a 0.01 step penalty, plus a 0.05 bonus when
    ``action_type`` is ``"submit"``. The result is rounded to 4 decimals.
    """
    before = compute_annotation_quality(old_annotations, gold_annotations)
    after = compute_annotation_quality(new_annotations, gold_annotations)

    reward = (after - before) * 2.0  # quality improvement, doubled
    reward -= 0.01  # every step costs a little
    if action_type == "submit":
        reward += 0.05
    return round(reward, 4)