Jonttup committed on
Commit
3ff7322
·
verified ·
1 Parent(s): 771c44a

Upload models/scope_pooler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/scope_pooler.py +350 -0
models/scope_pooler.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Scope-Aware Pooler
3
+
4
+ Extracts semantic regions from palette using scope markers (0=START, 1=END).
5
+ Implements exact scope matching via stack-based algorithm.
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing import List, Tuple, NamedTuple
12
+ from dataclasses import dataclass
13
+
14
+
15
class RegionMetadata(NamedTuple):
    """
    Per-sample description of the semantic regions found in one palette.

    All list fields are parallel: entry ``i`` of each one describes region ``i``.

    Fields:
    - masks: BoolTensor[R, H, W] - one spatial mask per region
    - starts: List[int] - flattened index where each region begins
    - ends: List[int] - flattened index where each region ends (inclusive)
    - depths: List[int] - nesting depth of each region
    - types: List[str] - region type hint for each region
    """

    masks: torch.Tensor
    starts: List[int]
    ends: List[int]
    depths: List[int]
    types: List[str]
31
+
32
+
33
class ScopeImbalanceError(Exception):
    """Signals that START/END scope markers are too unbalanced to match."""
36
+
37
+
38
class ScopePooler(nn.Module):
    """
    Extract semantic regions from a palette using scope markers.

    A region is the span of flattened palette positions between a
    START_OF_SCOPE (0) token and its matching END_OF_SCOPE (1) token,
    both inclusive. Each region's features are pooled into one vector.

    Algorithm:
    1. Flatten palette to a 1D sequence (row-major)
    2. Stack-based matching of scope markers
    3. Slice the flattened features for each matched region
    4. Pool via concat[mean, max] -> linear projection -> L2 normalize

    Edge Cases Handled:
    - Mildly unbalanced scopes (summary warning + best-effort matching)
    - Severely unbalanced scopes (fallback to a uniform grid of regions)
    - Nested scopes (via stack depth tracking)
    - No usable scopes (the whole sequence becomes a single region)
    - Regions smaller than ``min_region_size`` (skipped)
    """

    # Palette indices that open and close a scope.
    START_OF_SCOPE = 0
    END_OF_SCOPE = 1

    # |starts - ends| above this fraction of the larger count is "severe".
    _SEVERE_IMBALANCE_RATIO = 0.5

    # Module logger: a library should not log through the root logger.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        hidden_dim: int = 768,
        min_region_size: int = 2,
        fallback_grid_size: int = 4
    ):
        """
        Args:
            hidden_dim: Feature dimension
            min_region_size: Minimum tokens per region
            fallback_grid_size: Number of equal cells used by the fallback
                grid when scope markers are severely unbalanced
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.min_region_size = min_region_size
        self.fallback_grid_size = fallback_grid_size

        # Learned pooling projection:
        # concat [mean, max] (2 * hidden_dim) is projected back to hidden_dim.
        self.pool_proj = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(
        self,
        features: torch.Tensor,  # (B, H, W, D)
        palette: torch.Tensor    # (B, H, W)
    ) -> Tuple[torch.Tensor, List[RegionMetadata]]:
        """
        Extract semantic regions and pool features for a batch.

        Args:
            features: (B, H, W, D) - per-position features (e.g. ViT output)
            palette: (B, H, W) - palette indices

        Returns:
            regions: (B, R_max, D) - per-region pooled features, where R_max
                is the largest region count in the batch. Samples with fewer
                regions are padded with all-zero rows at the end.
            metadata: List[RegionMetadata] - one per batch item;
                ``len(metadata[b].starts)`` is the number of real
                (non-padding) rows for sample ``b``.

        Raises:
            AssertionError: if ``palette`` does not match ``features`` or
                ``D`` differs from ``hidden_dim``.
            ValueError: if the batch or the spatial grid is empty.

        Guarantees:
            - R_max >= 1 always (at least one region per sample)
            - All regions non-empty
            - Real (non-padding) rows are L2-normalized
        """
        B, H, W, D = features.shape
        assert palette.shape == (B, H, W), f"Shape mismatch: features{features.shape} vs palette{palette.shape}"
        assert D == self.hidden_dim, f"Hidden dim mismatch: {D} != {self.hidden_dim}"

        if B == 0 or H * W == 0:
            # Nothing to pool; fail with a clear message instead of an
            # opaque error from max()/pooling on an empty tensor.
            raise ValueError(
                f"ScopePooler needs a non-empty batch and grid, got features{tuple(features.shape)}"
            )

        all_regions = []
        all_metadata = []
        for feat_b, pal_b in zip(features, palette):
            regions_b, meta_b = self._extract_regions_single(feat_b, pal_b, H, W)
            all_regions.append(regions_b)  # (R_b, D)
            all_metadata.append(meta_b)

        # Zero-pad every sample to the max region count -> (B, R_max, D).
        batched_regions = nn.utils.rnn.pad_sequence(all_regions, batch_first=True)

        return batched_regions, all_metadata

    def _extract_regions_single(
        self,
        features: torch.Tensor,  # (H, W, D)
        palette: torch.Tensor,   # (H, W)
        H: int,
        W: int
    ) -> Tuple[torch.Tensor, RegionMetadata]:
        """
        Extract and pool the regions of a single sample.

        Args:
            features: (H, W, D) features of one sample
            palette: (H, W) palette indices of one sample
            H, W: palette dimensions

        Returns:
            regions: (R, D) - pooled, L2-normalized features (R >= 1)
            metadata: RegionMetadata describing the R regions
        """
        n_tokens = H * W

        # 1. Flatten to sequence. reshape (not view) so that non-contiguous
        #    inputs, e.g. produced by permute(), are accepted as well.
        seq = palette.reshape(-1)                                   # (H*W,)
        features_flat = features.reshape(n_tokens, self.hidden_dim)  # (H*W, D)

        # 2. Match scopes
        try:
            scope_pairs, depths = self._match_scopes(seq)
        except ScopeImbalanceError as err:
            # Critical error - scopes too broken to recover
            self._log.warning("%s. Using fallback uniform grid.", err)
            scope_pairs, depths = self._fallback_uniform_grid(H, W)

        # 3. Filter regions that are too small. The floor of 1 ensures an
        #    empty span can never reach the pooling step.
        min_size = max(1, self.min_region_size)
        kept = [
            (pair, depth)
            for pair, depth in zip(scope_pairs, depths)
            if pair[1] - pair[0] + 1 >= min_size
        ]
        if len(kept) < len(scope_pairs):
            self._log.debug(
                "Skipped %d region(s) smaller than %d tokens",
                len(scope_pairs) - len(kept), min_size,
            )
        if not kept:
            # No valid regions - use full sequence
            kept = [((0, n_tokens - 1), 0)]

        valid_pairs = [pair for pair, _ in kept]
        valid_depths = [depth for _, depth in kept]
        starts = [start for start, _ in valid_pairs]
        ends = [end for _, end in valid_pairs]

        # 4. Pool each region: concat [mean, max] over its tokens.
        pooled_inputs = []
        for start, end in valid_pairs:
            region_feat = features_flat[start:end + 1]  # (L, D)
            pooled_inputs.append(torch.cat(
                [region_feat.mean(dim=0), region_feat.max(dim=0)[0]], dim=0
            ))  # (2D,)

        # Project and normalize all regions in one batched call.
        regions = self.pool_proj(torch.stack(pooled_inputs, dim=0))  # (R, D)
        regions = torch.nn.functional.normalize(regions, dim=-1)

        # 5. Build all masks at once: position p belongs to region r
        #    iff starts[r] <= p <= ends[r].
        device = palette.device
        positions = torch.arange(n_tokens, device=device)         # (H*W,)
        starts_t = torch.tensor(starts, device=device).unsqueeze(1)  # (R, 1)
        ends_t = torch.tensor(ends, device=device).unsqueeze(1)      # (R, 1)
        masks = ((positions >= starts_t) & (positions <= ends_t)).reshape(
            len(valid_pairs), H, W
        )  # (R, H, W), bool

        metadata = RegionMetadata(
            masks=masks,
            starts=starts,
            ends=ends,
            depths=valid_depths,
            types=['scope'] * len(valid_pairs)  # Generic type for now
        )

        return regions, metadata

    def _match_scopes(
        self,
        seq: torch.Tensor  # (N,)
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """
        Stack-based scope matching.

        Args:
            seq: (N,) flattened palette indices

        Returns:
            pairs: List of (start_idx, end_idx) tuples, both inclusive.
                Matched pairs come first, in the order their END token
                appears; scopes closed at sequence end follow, outermost
                first.
            depths: Nesting depth of each pair (0 = outermost)

        Raises:
            ScopeImbalanceError: if the START and END counts differ by more
                than 50% of the larger count.

        Algorithm:
            - Maintain stack of open scope start indices
            - When seeing START (0), push index
            - When seeing END (1), pop and create pair
            - Depth of a scope = its position in the stack

        Edge Cases:
            - Unmatched START: closed at sequence end (one summary warning)
            - Unmatched END: skipped (one summary warning)
            - No scopes: return empty lists (caller handles)
        """
        # One transfer to host, then iterate over plain Python numbers.
        tokens = seq.detach().cpu().tolist()

        open_scopes = []  # start indices; list position == nesting depth
        pairs = []
        depths = []
        num_starts = 0
        num_ends = 0
        unmatched_ends = 0

        for i, token in enumerate(tokens):
            if token == self.START_OF_SCOPE:
                num_starts += 1
                open_scopes.append(i)
            elif token == self.END_OF_SCOPE:
                num_ends += 1
                if open_scopes:
                    depths.append(len(open_scopes) - 1)
                    pairs.append((open_scopes.pop(), i))
                else:
                    unmatched_ends += 1

        # Check for severe imbalance first: on failure the caller discards
        # the matching, so per-token warnings would only be noise.
        if abs(num_starts - num_ends) > max(num_starts, num_ends) * self._SEVERE_IMBALANCE_RATIO:
            raise ScopeImbalanceError(
                f"Severe scope imbalance: {num_starts} starts vs {num_ends} ends"
            )

        if unmatched_ends:
            self._log.warning("%d unmatched END_OF_SCOPE tokens skipped", unmatched_ends)

        if open_scopes:
            self._log.warning("%d unmatched START_OF_SCOPE tokens", len(open_scopes))
            # Close them at sequence end
            last_idx = len(tokens) - 1
            for depth, start_idx in enumerate(open_scopes):
                pairs.append((start_idx, last_idx))
                depths.append(depth)

        return pairs, depths

    def _fallback_uniform_grid(
        self,
        H: int,
        W: int
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        """
        Fallback when scope matching fails.

        Splits the flattened sequence into equal contiguous cells; the last
        cell absorbs the remainder.

        Args:
            H, W: palette dimensions

        Returns:
            pairs: List of (start, end) for grid cells, both inclusive
            depths: All depth=0 (flat)
        """
        total = H * W
        # Never use more cells than tokens (cells would be empty) and never
        # fewer than one (avoids division by zero for grid sizes <= 0).
        num_cells = max(1, min(self.fallback_grid_size, total))
        cell_size = total // num_cells

        pairs = [
            (i * cell_size, (i + 1) * cell_size - 1)
            for i in range(num_cells - 1)
        ]
        pairs.append(((num_cells - 1) * cell_size, total - 1))

        depths = [0] * num_cells

        return pairs, depths

    def visualize_regions(
        self,
        palette: torch.Tensor,  # (H, W)
        metadata: RegionMetadata
    ) -> str:
        """
        Generate human-readable visualization of regions.

        Args:
            palette: (H, W) palette; currently unused, kept so existing
                callers keep working
            metadata: RegionMetadata for one sample

        Returns:
            Multi-line string: a header followed by one line per region,
            indented by nesting depth.
        """
        output = [f"Detected {len(metadata.starts)} regions:"]

        for i, (start, end, depth) in enumerate(zip(
            metadata.starts,
            metadata.ends,
            metadata.depths
        )):
            region_size = end - start + 1
            indent = " " * depth
            output.append(
                f"{indent}Region {i}: [{start:4d}, {end:4d}] "
                f"(size={region_size:3d}, depth={depth})"
            )

        return "\n".join(output)