emelryan committed on
Commit
352ae81
·
1 Parent(s): dbbb1b6

Chunk relational model for consistent CUDA performance

Browse files

Run rect_proj, combined_proj (BatchNorm), and the transformer encoder
in fixed chunk_size batches so every kernel invocation sees the same
tensor shape. This prevents cuDNN from re-benchmarking on every new
batch size, eliminating ~460ms latency spikes.

Also compute cdist per-image on actual regions instead of padding all
images to k_max, reducing wasted compute on the quadratic distance
matrix.

Made-with: Cursor

nemotron-ocr/src/nemotron_ocr/inference/models/relational.py CHANGED
@@ -111,6 +111,33 @@ class GlobalRelationalModel(nn.Module):
111
  nn.Linear(dim, 3),
112
  )
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def get_target_rects(
115
  self,
116
  quads: torch.Tensor,
@@ -254,14 +281,15 @@ class GlobalRelationalModel(nn.Module):
254
  original_quads = original_quads / self.quad_downscale
255
  mid_pts = original_quads.detach().mean(dim=1, dtype=torch.float32)
256
 
257
- rectified_quads = self.rect_proj(rectified_quads)
258
- avg_rects = rectified_quads.flatten(2).sum(
259
- dim=2, dtype=torch.float32
260
- ) / num_valid_pos.unsqueeze(1)
261
-
262
- recog_encoding = self.recog_tx(recog_features.detach()).mean(dim=1, dtype=torch.float32)
263
 
264
- semantic_encoding = self.combined_proj(torch.cat((avg_rects, recog_encoding), dim=1))
 
 
265
 
266
  h1 = original_quads[:, 3] - original_quads[:, 0]
267
  h2 = original_quads[:, 2] - original_quads[:, 1]
@@ -320,11 +348,13 @@ class GlobalRelationalModel(nn.Module):
320
 
321
  counts_list = region_counts.tolist() if region_counts.dim() > 0 else [int(region_counts.item())]
322
  batch_size = len(counts_list)
323
- k_max_raw = max(counts_list) if counts_list else 0
 
 
 
 
324
 
325
- if k_max_raw == 0:
326
- device = proj_rects.device
327
- dtype = proj_rects.dtype
328
  return {
329
  "words": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
330
  "lines": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
@@ -333,111 +363,67 @@ class GlobalRelationalModel(nn.Module):
333
  ],
334
  }
335
 
336
- chunk_size = max(1, int(getattr(self, "chunk_size", DEFAULT_CHUNK)))
337
- k_max = ((k_max_raw + chunk_size - 1) // chunk_size) * chunk_size
338
- device = proj_rects.device
339
- counts_gpu = region_counts.to(device=device, dtype=torch.long)
340
-
341
- rects_b = _pad_flat_to_batched(proj_rects, region_counts, k_max, pad_value=0.0)
342
- centers_b = _pad_flat_to_batched(mid_pts, region_counts, k_max, pad_value=0.0)
343
- quads_b = _pad_flat_to_batched(quads, region_counts, k_max, pad_value=0.0)
344
-
345
- arange_k = torch.arange(k_max, device=device)
346
- valid_rows = arange_k.unsqueeze(0) < counts_gpu.unsqueeze(1)
347
- k_per_image = (counts_gpu - 1).clamp_min(0).clamp_max(self.k - 1)
348
- z_max = int(k_per_image.max().item())
349
-
350
- if z_max == 0:
351
- to_rects = torch.zeros(batch_size, k_max, 1, proj_rects.shape[1] + 2, **options(proj_rects))
352
- closest_other_idxs = torch.zeros(batch_size, k_max, 1, dtype=torch.long, device=device)
353
- key_padding_mask = torch.zeros(batch_size, k_max, 1, dtype=torch.bool, device=device)
354
- else:
355
- all_dists = get_cdist_batched(quads_b, counts_gpu)
356
- invalid_rows = ~valid_rows
357
- all_dists.masked_fill_(invalid_rows.unsqueeze(1), float("inf"))
358
- all_dists.masked_fill_(invalid_rows.unsqueeze(2), float("inf"))
359
- all_dists.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
360
-
361
- topk_dists, topk_idxs = torch.topk(
362
- all_dists, k=z_max, dim=2, largest=False, sorted=False
363
- )
364
-
365
- gather_idx = topk_idxs.unsqueeze(-1).expand(-1, -1, -1, rects_b.shape[-1])
366
- neighbor_rects = torch.gather(
367
- rects_b.unsqueeze(1).expand(-1, k_max, -1, -1),
368
- dim=2,
369
- index=gather_idx,
370
- )
371
-
372
- gather_idx_2 = topk_idxs.unsqueeze(-1).expand(-1, -1, -1, 2)
373
- neighbor_centers = torch.gather(
374
- centers_b.unsqueeze(1).expand(-1, k_max, -1, -1),
375
- dim=2,
376
- index=gather_idx_2,
377
- )
378
-
379
- directions = _batched_directions(quads_b, neighbor_centers)
380
-
381
- neighbor_valid = valid_rows.unsqueeze(2) & (
382
- torch.arange(z_max, device=device).view(1, 1, -1) < k_per_image.view(batch_size, 1, 1)
383
- )
384
-
385
- neighbor_rects = neighbor_rects.masked_fill(~neighbor_valid.unsqueeze(-1), 0)
386
- topk_dists = topk_dists.masked_fill(~neighbor_valid, 0)
387
- directions = directions.masked_fill(~neighbor_valid, 0)
388
-
389
- null_rects = torch.zeros(batch_size, k_max, 1, rects_b.shape[-1], **options(rects_b))
390
- null_dists = torch.full((batch_size, k_max, 1), -1, **options(rects_b))
391
- null_dirs = torch.full((batch_size, k_max, 1), -2, **options(rects_b))
392
-
393
- to_rects = torch.cat(
394
- (
395
- torch.cat((null_rects, neighbor_rects), dim=2),
396
- torch.cat((null_dists, topk_dists), dim=2).unsqueeze(-1),
397
- torch.cat((null_dirs, directions), dim=2).unsqueeze(-1),
398
- ),
399
- dim=-1,
400
- )
401
-
402
- key_padding_mask = torch.ones(batch_size, k_max, z_max + 1, dtype=torch.bool, device=device)
403
- key_padding_mask[:, :, 0] = False
404
- key_padding_mask[:, :, 1:] = ~neighbor_valid
405
-
406
- invalid_target_idx = torch.full_like(topk_idxs, k_max)
407
- target_idxs = torch.where(neighbor_valid, topk_idxs + 1, invalid_target_idx)
408
- closest_other_idxs = torch.cat(
409
- (torch.zeros(batch_size, k_max, 1, dtype=torch.long, device=device), target_idxs),
410
- dim=2,
411
- )
412
-
413
- from_rects = rects_b.unsqueeze(2).expand(-1, -1, to_rects.shape[2], -1)
414
- enc_input = torch.cat((from_rects, to_rects), dim=3)
415
- enc_flat = enc_input.reshape(batch_size * k_max, enc_input.shape[2], enc_input.shape[3])
416
- mask_flat = key_padding_mask.reshape(batch_size * k_max, key_padding_mask.shape[2])
417
-
418
- if enc_flat.shape[0]:
419
- dots = self.encoder[0](enc_flat, src_key_padding_mask=mask_flat)
420
- dots = self.encoder[1](dots)
421
- else:
422
- dots = torch.empty(0, 1, 3, dtype=enc_input.dtype, device=device)
423
-
424
- dots = dots.reshape(batch_size, k_max, enc_input.shape[2], 3).permute(0, 3, 1, 2)
425
- dots = self.prohibit_self_connection(dots, closest_other_idxs)
426
-
427
  all_dots = dict(words=[], lines=[], line_log_var_unc=[])
428
- for i, region_count in enumerate(counts_list):
429
- if region_count == 0:
430
- word_pred = torch.empty(0, 1, dtype=dots.dtype, device=device)
431
- line_pred = torch.empty(0, 1, dtype=dots.dtype, device=device)
432
- line_log_var_pred = torch.empty(0, 1, dtype=dots.dtype, device=device)
433
- else:
434
- word_pred = dots[i, 0, :region_count, : region_count + 1]
435
- line_pred = dots[i, 1, :region_count, : region_count + 1]
436
- line_log_var_pred = dots[i, 2, :region_count, : region_count + 1]
437
-
438
- all_dots["words"].append(word_pred)
439
- all_dots["lines"].append(line_pred)
440
- all_dots["line_log_var_unc"].append(line_log_var_pred)
441
 
442
  return {
443
  "words": all_dots["words"],
 
111
  nn.Linear(dim, 3),
112
  )
113
 
114
+ def _chunked_forward(self, fn, x, *extra, pad_extra_ones=False):
115
+ """Run *fn* in fixed ``chunk_size`` batches along dim-0.
116
+
117
+ Pads the last chunk so every call sees the same tensor shape,
118
+ preventing cuDNN autotuning on varying batch sizes.
119
+ ``extra`` tensors are sliced/padded in sync with ``x``.
120
+ """
121
+ n = x.shape[0]
122
+ cs = max(1, self.chunk_size)
123
+ if n == 0:
124
+ return fn(x, *extra)
125
+ parts = []
126
+ for start in range(0, n, cs):
127
+ end = min(start + cs, n)
128
+ real_n = end - start
129
+ xc = x[start:end]
130
+ ec = [e[start:end] for e in extra]
131
+ if real_n < cs:
132
+ xc = torch.cat((xc, xc[:1].expand(cs - real_n, *[-1] * (xc.ndim - 1))), dim=0)
133
+ for i, e in enumerate(ec):
134
+ if pad_extra_ones and e.dtype == torch.bool:
135
+ ec[i] = torch.cat((e, torch.ones(cs - real_n, *e.shape[1:], dtype=torch.bool, device=e.device)), dim=0)
136
+ else:
137
+ ec[i] = torch.cat((e, e[:1].expand(cs - real_n, *[-1] * (e.ndim - 1))), dim=0)
138
+ parts.append(fn(xc, *ec)[:real_n])
139
+ return torch.cat(parts, dim=0)
140
+
141
  def get_target_rects(
142
  self,
143
  quads: torch.Tensor,
 
281
  original_quads = original_quads / self.quad_downscale
282
  mid_pts = original_quads.detach().mean(dim=1, dtype=torch.float32)
283
 
284
+ def _input_enc_nn(rq, rf, nvp):
285
+ rq = self.rect_proj(rq)
286
+ avg = rq.flatten(2).sum(dim=2, dtype=torch.float32) / nvp.unsqueeze(1)
287
+ rec = self.recog_tx(rf.detach()).mean(dim=1, dtype=torch.float32)
288
+ return self.combined_proj(torch.cat((avg, rec), dim=1))
 
289
 
290
+ semantic_encoding = self._chunked_forward(
291
+ _input_enc_nn, rectified_quads, recog_features, num_valid_pos,
292
+ )
293
 
294
  h1 = original_quads[:, 3] - original_quads[:, 0]
295
  h2 = original_quads[:, 2] - original_quads[:, 1]
 
348
 
349
  counts_list = region_counts.tolist() if region_counts.dim() > 0 else [int(region_counts.item())]
350
  batch_size = len(counts_list)
351
+ device = proj_rects.device
352
+ dtype = proj_rects.dtype
353
+ feat_dim = proj_rects.shape[1]
354
+ z = self.k - 1
355
+ seq_len = z + 1
356
 
357
+ if max(counts_list, default=0) == 0:
 
 
358
  return {
359
  "words": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
360
  "lines": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
 
363
  ],
364
  }
365
 
366
+ # Per-image cdist + topk, then concatenate into flat [N_total, seq_len, ...]
367
+ offsets = [0]
368
+ for c in counts_list:
369
+ offsets.append(offsets[-1] + c)
370
+ n_total = offsets[-1]
371
+
372
+ enc_input_flat = torch.zeros(n_total, seq_len, 2 * feat_dim + 2, dtype=dtype, device=device)
373
+ mask_flat = torch.ones(n_total, seq_len, dtype=torch.bool, device=device)
374
+ closest_flat = torch.zeros(n_total, seq_len, dtype=torch.long, device=device)
375
+
376
+ for i, n_i in enumerate(counts_list):
377
+ if n_i == 0:
378
+ continue
379
+ s, e = offsets[i], offsets[i + 1]
380
+ rects_i = proj_rects[s:e]
381
+ centers_i = mid_pts[s:e]
382
+ quads_i = quads[s:e]
383
+ z_i = min(n_i - 1, z)
384
+
385
+ from_r = rects_i.unsqueeze(1).expand(-1, seq_len, -1)
386
+ enc_input_flat[s:e, 0, :feat_dim] = rects_i
387
+ enc_input_flat[s:e, 0, 2 * feat_dim] = -1
388
+ enc_input_flat[s:e, 0, 2 * feat_dim + 1] = -2
389
+ mask_flat[s:e, 0] = False
390
+
391
+ if z_i > 0:
392
+ dists_i = get_cdist(quads_i, centers_i)
393
+ topk_d, topk_idx = torch.topk(dists_i, k=z_i, dim=1, largest=False, sorted=False)
394
+ nb_r = torch.gather(rects_i.unsqueeze(0).expand(n_i, -1, -1), 1, topk_idx.unsqueeze(2).expand(-1, -1, feat_dim))
395
+ nb_c = torch.gather(centers_i.unsqueeze(0).expand(n_i, -1, -1), 1, topk_idx.unsqueeze(2).expand(-1, -1, 2))
396
+ dirs_i = get_directions(quads_i, nb_c)
397
+
398
+ enc_input_flat[s:e, 1:z_i + 1, :feat_dim] = from_r[:, 1:z_i + 1]
399
+ enc_input_flat[s:e, 1:z_i + 1, feat_dim:2 * feat_dim] = nb_r
400
+ enc_input_flat[s:e, 1:z_i + 1, 2 * feat_dim] = topk_d
401
+ enc_input_flat[s:e, 1:z_i + 1, 2 * feat_dim + 1] = dirs_i
402
+ mask_flat[s:e, 1:z_i + 1] = False
403
+ closest_flat[s:e, 1:z_i + 1] = topk_idx + 1
404
+
405
+ # Chunked encoder on flat regions — always sees [chunk_size, seq_len, dim]
406
+ def _run_encoder(enc, mask):
407
+ out = self.encoder[0](enc, src_key_padding_mask=mask)
408
+ return self.encoder[1](out)
409
+
410
+ dots_flat = self._chunked_forward(_run_encoder, enc_input_flat, mask_flat, pad_extra_ones=True)
411
+
412
+ # Per-image: scatter encoder output into full relation matrices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  all_dots = dict(words=[], lines=[], line_log_var_unc=[])
414
+ for i, n_i in enumerate(counts_list):
415
+ if n_i == 0:
416
+ all_dots["words"].append(torch.empty(0, 1, dtype=torch.float32, device=device))
417
+ all_dots["lines"].append(torch.empty(0, 1, dtype=torch.float32, device=device))
418
+ all_dots["line_log_var_unc"].append(torch.empty(0, 1, dtype=torch.float32, device=device))
419
+ continue
420
+ s, e = offsets[i], offsets[i + 1]
421
+ dots_i = dots_flat[s:e].unsqueeze(0).permute(0, 3, 1, 2)
422
+ cidx_i = closest_flat[s:e].unsqueeze(0)
423
+ dots_i = self.prohibit_self_connection(dots_i, cidx_i)
424
+ all_dots["words"].append(dots_i[0, 0, :n_i, :n_i + 1])
425
+ all_dots["lines"].append(dots_i[0, 1, :n_i, :n_i + 1])
426
+ all_dots["line_log_var_unc"].append(dots_i[0, 2, :n_i, :n_i + 1])
427
 
428
  return {
429
  "words": all_dots["words"],