Chunk relational model for consistent CUDA performance
Run rect_proj, combined_proj (BatchNorm), and the transformer encoder
in fixed chunk_size batches so every kernel invocation sees the same
tensor shape. This prevents cuDNN from re-benchmarking on every new
batch size, eliminating ~460ms latency spikes.
Also compute cdist per-image on actual regions instead of padding all
images to k_max, reducing wasted compute on the quadratic distance
matrix.
Made-with: Cursor
nemotron-ocr/src/nemotron_ocr/inference/models/relational.py
CHANGED
|
@@ -111,6 +111,33 @@ class GlobalRelationalModel(nn.Module):
|
|
| 111 |
nn.Linear(dim, 3),
|
| 112 |
)
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
def get_target_rects(
|
| 115 |
self,
|
| 116 |
quads: torch.Tensor,
|
|
@@ -254,14 +281,15 @@ class GlobalRelationalModel(nn.Module):
|
|
| 254 |
original_quads = original_quads / self.quad_downscale
|
| 255 |
mid_pts = original_quads.detach().mean(dim=1, dtype=torch.float32)
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
dim=2, dtype=torch.float32
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
recog_encoding = self.recog_tx(recog_features.detach()).mean(dim=1, dtype=torch.float32)
|
| 263 |
|
| 264 |
-
semantic_encoding = self.
|
|
|
|
|
|
|
| 265 |
|
| 266 |
h1 = original_quads[:, 3] - original_quads[:, 0]
|
| 267 |
h2 = original_quads[:, 2] - original_quads[:, 1]
|
|
@@ -320,11 +348,13 @@ class GlobalRelationalModel(nn.Module):
|
|
| 320 |
|
| 321 |
counts_list = region_counts.tolist() if region_counts.dim() > 0 else [int(region_counts.item())]
|
| 322 |
batch_size = len(counts_list)
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
-
if
|
| 326 |
-
device = proj_rects.device
|
| 327 |
-
dtype = proj_rects.dtype
|
| 328 |
return {
|
| 329 |
"words": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
|
| 330 |
"lines": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
|
|
@@ -333,111 +363,67 @@ class GlobalRelationalModel(nn.Module):
|
|
| 333 |
],
|
| 334 |
}
|
| 335 |
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
neighbor_rects = neighbor_rects.masked_fill(~neighbor_valid.unsqueeze(-1), 0)
|
| 386 |
-
topk_dists = topk_dists.masked_fill(~neighbor_valid, 0)
|
| 387 |
-
directions = directions.masked_fill(~neighbor_valid, 0)
|
| 388 |
-
|
| 389 |
-
null_rects = torch.zeros(batch_size, k_max, 1, rects_b.shape[-1], **options(rects_b))
|
| 390 |
-
null_dists = torch.full((batch_size, k_max, 1), -1, **options(rects_b))
|
| 391 |
-
null_dirs = torch.full((batch_size, k_max, 1), -2, **options(rects_b))
|
| 392 |
-
|
| 393 |
-
to_rects = torch.cat(
|
| 394 |
-
(
|
| 395 |
-
torch.cat((null_rects, neighbor_rects), dim=2),
|
| 396 |
-
torch.cat((null_dists, topk_dists), dim=2).unsqueeze(-1),
|
| 397 |
-
torch.cat((null_dirs, directions), dim=2).unsqueeze(-1),
|
| 398 |
-
),
|
| 399 |
-
dim=-1,
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
-
key_padding_mask = torch.ones(batch_size, k_max, z_max + 1, dtype=torch.bool, device=device)
|
| 403 |
-
key_padding_mask[:, :, 0] = False
|
| 404 |
-
key_padding_mask[:, :, 1:] = ~neighbor_valid
|
| 405 |
-
|
| 406 |
-
invalid_target_idx = torch.full_like(topk_idxs, k_max)
|
| 407 |
-
target_idxs = torch.where(neighbor_valid, topk_idxs + 1, invalid_target_idx)
|
| 408 |
-
closest_other_idxs = torch.cat(
|
| 409 |
-
(torch.zeros(batch_size, k_max, 1, dtype=torch.long, device=device), target_idxs),
|
| 410 |
-
dim=2,
|
| 411 |
-
)
|
| 412 |
-
|
| 413 |
-
from_rects = rects_b.unsqueeze(2).expand(-1, -1, to_rects.shape[2], -1)
|
| 414 |
-
enc_input = torch.cat((from_rects, to_rects), dim=3)
|
| 415 |
-
enc_flat = enc_input.reshape(batch_size * k_max, enc_input.shape[2], enc_input.shape[3])
|
| 416 |
-
mask_flat = key_padding_mask.reshape(batch_size * k_max, key_padding_mask.shape[2])
|
| 417 |
-
|
| 418 |
-
if enc_flat.shape[0]:
|
| 419 |
-
dots = self.encoder[0](enc_flat, src_key_padding_mask=mask_flat)
|
| 420 |
-
dots = self.encoder[1](dots)
|
| 421 |
-
else:
|
| 422 |
-
dots = torch.empty(0, 1, 3, dtype=enc_input.dtype, device=device)
|
| 423 |
-
|
| 424 |
-
dots = dots.reshape(batch_size, k_max, enc_input.shape[2], 3).permute(0, 3, 1, 2)
|
| 425 |
-
dots = self.prohibit_self_connection(dots, closest_other_idxs)
|
| 426 |
-
|
| 427 |
all_dots = dict(words=[], lines=[], line_log_var_unc=[])
|
| 428 |
-
for i,
|
| 429 |
-
if
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
all_dots["words"].append(
|
| 439 |
-
all_dots["lines"].append(
|
| 440 |
-
all_dots["line_log_var_unc"].append(
|
| 441 |
|
| 442 |
return {
|
| 443 |
"words": all_dots["words"],
|
|
|
|
| 111 |
nn.Linear(dim, 3),
|
| 112 |
)
|
| 113 |
|
| 114 |
+
def _chunked_forward(self, fn, x, *extra, pad_extra_ones=False):
|
| 115 |
+
"""Run *fn* in fixed ``chunk_size`` batches along dim-0.
|
| 116 |
+
|
| 117 |
+
Pads the last chunk so every call sees the same tensor shape,
|
| 118 |
+
preventing cuDNN autotuning on varying batch sizes.
|
| 119 |
+
``extra`` tensors are sliced/padded in sync with ``x``.
|
| 120 |
+
"""
|
| 121 |
+
n = x.shape[0]
|
| 122 |
+
cs = max(1, self.chunk_size)
|
| 123 |
+
if n == 0:
|
| 124 |
+
return fn(x, *extra)
|
| 125 |
+
parts = []
|
| 126 |
+
for start in range(0, n, cs):
|
| 127 |
+
end = min(start + cs, n)
|
| 128 |
+
real_n = end - start
|
| 129 |
+
xc = x[start:end]
|
| 130 |
+
ec = [e[start:end] for e in extra]
|
| 131 |
+
if real_n < cs:
|
| 132 |
+
xc = torch.cat((xc, xc[:1].expand(cs - real_n, *[-1] * (xc.ndim - 1))), dim=0)
|
| 133 |
+
for i, e in enumerate(ec):
|
| 134 |
+
if pad_extra_ones and e.dtype == torch.bool:
|
| 135 |
+
ec[i] = torch.cat((e, torch.ones(cs - real_n, *e.shape[1:], dtype=torch.bool, device=e.device)), dim=0)
|
| 136 |
+
else:
|
| 137 |
+
ec[i] = torch.cat((e, e[:1].expand(cs - real_n, *[-1] * (e.ndim - 1))), dim=0)
|
| 138 |
+
parts.append(fn(xc, *ec)[:real_n])
|
| 139 |
+
return torch.cat(parts, dim=0)
|
| 140 |
+
|
| 141 |
def get_target_rects(
|
| 142 |
self,
|
| 143 |
quads: torch.Tensor,
|
|
|
|
| 281 |
original_quads = original_quads / self.quad_downscale
|
| 282 |
mid_pts = original_quads.detach().mean(dim=1, dtype=torch.float32)
|
| 283 |
|
| 284 |
+
def _input_enc_nn(rq, rf, nvp):
|
| 285 |
+
rq = self.rect_proj(rq)
|
| 286 |
+
avg = rq.flatten(2).sum(dim=2, dtype=torch.float32) / nvp.unsqueeze(1)
|
| 287 |
+
rec = self.recog_tx(rf.detach()).mean(dim=1, dtype=torch.float32)
|
| 288 |
+
return self.combined_proj(torch.cat((avg, rec), dim=1))
|
|
|
|
| 289 |
|
| 290 |
+
semantic_encoding = self._chunked_forward(
|
| 291 |
+
_input_enc_nn, rectified_quads, recog_features, num_valid_pos,
|
| 292 |
+
)
|
| 293 |
|
| 294 |
h1 = original_quads[:, 3] - original_quads[:, 0]
|
| 295 |
h2 = original_quads[:, 2] - original_quads[:, 1]
|
|
|
|
| 348 |
|
| 349 |
counts_list = region_counts.tolist() if region_counts.dim() > 0 else [int(region_counts.item())]
|
| 350 |
batch_size = len(counts_list)
|
| 351 |
+
device = proj_rects.device
|
| 352 |
+
dtype = proj_rects.dtype
|
| 353 |
+
feat_dim = proj_rects.shape[1]
|
| 354 |
+
z = self.k - 1
|
| 355 |
+
seq_len = z + 1
|
| 356 |
|
| 357 |
+
if max(counts_list, default=0) == 0:
|
|
|
|
|
|
|
| 358 |
return {
|
| 359 |
"words": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
|
| 360 |
"lines": [torch.empty(0, 1, dtype=dtype, device=device) for _ in range(batch_size)],
|
|
|
|
| 363 |
],
|
| 364 |
}
|
| 365 |
|
| 366 |
+
# Per-image cdist + topk, then concatenate into flat [N_total, seq_len, ...]
|
| 367 |
+
offsets = [0]
|
| 368 |
+
for c in counts_list:
|
| 369 |
+
offsets.append(offsets[-1] + c)
|
| 370 |
+
n_total = offsets[-1]
|
| 371 |
+
|
| 372 |
+
enc_input_flat = torch.zeros(n_total, seq_len, 2 * feat_dim + 2, dtype=dtype, device=device)
|
| 373 |
+
mask_flat = torch.ones(n_total, seq_len, dtype=torch.bool, device=device)
|
| 374 |
+
closest_flat = torch.zeros(n_total, seq_len, dtype=torch.long, device=device)
|
| 375 |
+
|
| 376 |
+
for i, n_i in enumerate(counts_list):
|
| 377 |
+
if n_i == 0:
|
| 378 |
+
continue
|
| 379 |
+
s, e = offsets[i], offsets[i + 1]
|
| 380 |
+
rects_i = proj_rects[s:e]
|
| 381 |
+
centers_i = mid_pts[s:e]
|
| 382 |
+
quads_i = quads[s:e]
|
| 383 |
+
z_i = min(n_i - 1, z)
|
| 384 |
+
|
| 385 |
+
from_r = rects_i.unsqueeze(1).expand(-1, seq_len, -1)
|
| 386 |
+
enc_input_flat[s:e, 0, :feat_dim] = rects_i
|
| 387 |
+
enc_input_flat[s:e, 0, 2 * feat_dim] = -1
|
| 388 |
+
enc_input_flat[s:e, 0, 2 * feat_dim + 1] = -2
|
| 389 |
+
mask_flat[s:e, 0] = False
|
| 390 |
+
|
| 391 |
+
if z_i > 0:
|
| 392 |
+
dists_i = get_cdist(quads_i, centers_i)
|
| 393 |
+
topk_d, topk_idx = torch.topk(dists_i, k=z_i, dim=1, largest=False, sorted=False)
|
| 394 |
+
nb_r = torch.gather(rects_i.unsqueeze(0).expand(n_i, -1, -1), 1, topk_idx.unsqueeze(2).expand(-1, -1, feat_dim))
|
| 395 |
+
nb_c = torch.gather(centers_i.unsqueeze(0).expand(n_i, -1, -1), 1, topk_idx.unsqueeze(2).expand(-1, -1, 2))
|
| 396 |
+
dirs_i = get_directions(quads_i, nb_c)
|
| 397 |
+
|
| 398 |
+
enc_input_flat[s:e, 1:z_i + 1, :feat_dim] = from_r[:, 1:z_i + 1]
|
| 399 |
+
enc_input_flat[s:e, 1:z_i + 1, feat_dim:2 * feat_dim] = nb_r
|
| 400 |
+
enc_input_flat[s:e, 1:z_i + 1, 2 * feat_dim] = topk_d
|
| 401 |
+
enc_input_flat[s:e, 1:z_i + 1, 2 * feat_dim + 1] = dirs_i
|
| 402 |
+
mask_flat[s:e, 1:z_i + 1] = False
|
| 403 |
+
closest_flat[s:e, 1:z_i + 1] = topk_idx + 1
|
| 404 |
+
|
| 405 |
+
# Chunked encoder on flat regions — always sees [chunk_size, seq_len, dim]
|
| 406 |
+
def _run_encoder(enc, mask):
|
| 407 |
+
out = self.encoder[0](enc, src_key_padding_mask=mask)
|
| 408 |
+
return self.encoder[1](out)
|
| 409 |
+
|
| 410 |
+
dots_flat = self._chunked_forward(_run_encoder, enc_input_flat, mask_flat, pad_extra_ones=True)
|
| 411 |
+
|
| 412 |
+
# Per-image: scatter encoder output into full relation matrices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
all_dots = dict(words=[], lines=[], line_log_var_unc=[])
|
| 414 |
+
for i, n_i in enumerate(counts_list):
|
| 415 |
+
if n_i == 0:
|
| 416 |
+
all_dots["words"].append(torch.empty(0, 1, dtype=torch.float32, device=device))
|
| 417 |
+
all_dots["lines"].append(torch.empty(0, 1, dtype=torch.float32, device=device))
|
| 418 |
+
all_dots["line_log_var_unc"].append(torch.empty(0, 1, dtype=torch.float32, device=device))
|
| 419 |
+
continue
|
| 420 |
+
s, e = offsets[i], offsets[i + 1]
|
| 421 |
+
dots_i = dots_flat[s:e].unsqueeze(0).permute(0, 3, 1, 2)
|
| 422 |
+
cidx_i = closest_flat[s:e].unsqueeze(0)
|
| 423 |
+
dots_i = self.prohibit_self_connection(dots_i, cidx_i)
|
| 424 |
+
all_dots["words"].append(dots_i[0, 0, :n_i, :n_i + 1])
|
| 425 |
+
all_dots["lines"].append(dots_i[0, 1, :n_i, :n_i + 1])
|
| 426 |
+
all_dots["line_log_var_unc"].append(dots_i[0, 2, :n_i, :n_i + 1])
|
| 427 |
|
| 428 |
return {
|
| 429 |
"words": all_dots["words"],
|