
BokehFlow Code Audit — Issues Found and Fixed

CRITICAL BUGS (Will cause training failure or incorrect results)

1. ❌ Compositing double-multiplication bug (render_bokeh, line ~825)

Problem: output = output + blurred * visible / (blurred_mask + 1e-6) * visible
This multiplies by visible twice, so each layer is weighted by visible² instead of visible. For visible in [0, 1], partially occluded layers are under-weighted in the composite.
Fix: output = output + blurred / (blurred_mask + 1e-6) * visible
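The effect of the extra factor can be checked with a small scalar sketch. It assumes that blurred is a premultiplied layer (color × blurred_mask), so dividing by blurred_mask recovers the color; the function names and toy values are illustrative and not taken from the BokehFlow code. The same expressions work elementwise on tensors.

```python
def composite_buggy(output, blurred, blurred_mask, visible, eps=1e-6):
    # Original line: `visible` is applied twice.
    return output + blurred * visible / (blurred_mask + eps) * visible


def composite_fixed(output, blurred, blurred_mask, visible, eps=1e-6):
    # Un-premultiply the blurred layer, then weight it by visibility once.
    return output + blurred / (blurred_mask + eps) * visible


# Assumed toy values: a layer of color 0.8, mask coverage 0.5, half visible.
color, blurred_mask, visible = 0.8, 0.5, 0.5
blurred = color * blurred_mask  # premultiplied layer (assumption)

buggy = composite_buggy(0.0, blurred, blurred_mask, visible)
fixed = composite_fixed(0.0, blurred, blurred_mask, visible)
```

Here fixed is approximately color × visible, while buggy is approximately color × visible².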

2. ❌ CoC map computation doesn't handle focus distance == depth correctly

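Problem: When D == S₁ (pixel at focus distance), CoC should be exactly 0. The formula computes abs(D - S₁), which is correct. However, S1.clamp(min=f+1.0) makes the clamp bound depend on f. When f is a learnable parameter and the clamp is active, gradients flow into f through the bound rather than through the lens formula, which can destabilize training.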
Fix: Detach f from the clamp or use a fixed minimum.
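A minimal sketch of the detach variant, assuming the standard thin-lens CoC formula c = (f/N) · f/(S₁ − f) · |D − S₁|/D with all lengths in millimetres. The audit does not show the actual formula used in BokehFlow, so coc_map and its argument names are hypothetical.

```python
import torch


def coc_map(depth_mm, focus_mm, focal_mm, f_number):
    """Thin-lens circle-of-confusion diameter in mm (assumed formula)."""
    # Floor the focus distance with a detached copy of the focal length, so
    # the clamp bound itself never feeds gradient back into focal_mm.
    s1 = torch.maximum(focus_mm, focal_mm.detach() + 1.0)
    aperture_mm = focal_mm / f_number
    return aperture_mm * focal_mm / (s1 - focal_mm) * (depth_mm - s1).abs() / depth_mm


focal_mm = torch.tensor(50.0, requires_grad=True)  # learnable focal length
depth_mm = torch.tensor([500.0, 2000.0, 8000.0])   # middle pixel is in focus
coc = coc_map(depth_mm, torch.tensor(2000.0), focal_mm, f_number=2.0)
coc.sum().backward()
```

The in-focus pixel gets a CoC of exactly 0, and focal_mm.grad stays finite because s1 - focal_mm is at least 1 mm.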

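3. ❌ BatchNorm in ConvStem is unreliable at batch_size=1

Problem: In training mode, nn.BatchNorm2d normalizes with the statistics of the current batch. At batch_size=1 these are single-image statistics, which are noisy and differ from the running statistics used in eval mode. If a channel has only one value (a 1×1 feature map), PyTorch raises "Expected more than 1 value per channel when training". Inference is only safe if model.eval() is always called first.
Fix: Use nn.GroupNorm(num_groups=8, num_channels=...) (num_groups must divide num_channels) or nn.InstanceNorm2d instead.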
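A sketch of what a GroupNorm-based stem could look like. ConvStemSketch, its channel counts, kernel size and activation are assumptions for illustration; the real ConvStem layout is not shown in this audit.

```python
import torch
import torch.nn as nn


class ConvStemSketch(nn.Module):
    """Hypothetical stem: conv -> GroupNorm -> GELU."""

    def __init__(self, in_ch=3, out_ch=64, num_groups=8):
        super().__init__()
        # num_groups must divide out_ch (64 / 8 = 8 channels per group).
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


stem = ConvStemSketch()
x = torch.randn(1, 3, 32, 32)   # batch_size=1
out_train = stem.train()(x)     # statistics are computed per sample
out_eval = stem.eval()(x)       # identical result, no running stats involved
```

Because GroupNorm normalizes each sample independently, the output does not depend on the batch size or on the train/eval flag.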

STABILITY ISSUES (May cause NaN/Inf during training)

4. ⚠️ No gradient clipping mentioned in training config

Problem: The GatedDeltaNet recurrence compounds matrix operations. Without gradient clipping, gradients can explode.
Fix: Add max_grad_norm=1.0 to training config.
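In a plain PyTorch loop, max_grad_norm=1.0 would typically be applied with torch.nn.utils.clip_grad_norm_ between backward() and optimizer.step(). The sketch below uses a stand-in linear model and an artificially scaled loss; how BokehFlow's own trainer consumes the config value is not shown in the audit.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)  # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_grad_norm = 1.0

x, target = torch.randn(8, 16), torch.randn(8, 4)
# Loss scaled up on purpose so the raw gradients are large.
loss = 1e3 * nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
loss.backward()
# Rescales all gradients in place; returns the total norm before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
clipped_norm = torch.linalg.vector_norm(
    torch.stack([p.grad.norm() for p in model.parameters()])
)
optimizer.step()
```

Logging total_norm per step is a cheap way to see whether clipping is active often enough to warrant a lower learning rate.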

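5. ⚠️ Key L2-normalization: default epsilon is too small for half precision

Problem: F.normalize(k, p=2, dim=-1) already guards against a zero norm with its default eps=1e-12, so an all-zero k is safe in float32. In float16 that eps rounds to zero, and an all-zero k then produces 0/0 = NaN.
Fix: Pass a larger eps explicitly: k = F.normalize(k, p=2, dim=-1, eps=1e-6). Note that 1e-8 is still below the smallest positive float16 value (about 6e-8), so it does not help when k is a float16 tensor.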
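A short check of the eps choice. It uses eps=1e-6 rather than 1e-8 because 1e-8 is not representable in float16; the constant name EPS and the tensor shapes are illustrative.

```python
import torch
import torch.nn.functional as F

EPS = 1e-6  # assumed value; still nonzero when stored as float16

k = torch.zeros(2, 4, 32)  # worst case: all-zero keys
k_hat = F.normalize(k, p=2, dim=-1, eps=EPS)  # 0 / max(0, EPS) = 0, no NaN

# The smallest positive float16 value is about 6e-8, so smaller eps values
# round to zero when stored in half precision.
eps_1e8_in_fp16 = torch.tensor(1e-8, dtype=torch.float16)
eps_1e6_in_fp16 = torch.tensor(EPS, dtype=torch.float16)
```

eps_1e8_in_fp16 is exactly zero, while eps_1e6_in_fp16 stays positive.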

6. ⚠️ State explosion risk

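Problem: The state update state = a_t * (state - b_t * (state @ kk_t)) + b_t * vk_t multiplies the state by a per-step matrix and adds a new write at every step. With α≈1 there is almost no decay, so the state norm can keep growing over long sequences.
Fix: Add periodic state normalization every 256 steps: state = state / (state.norm() + 1e-6).clamp(min=0.1). Note that state.norm() without a dim argument is one norm over the whole tensor, batch included; a per-head norm over the last two dimensions avoids coupling samples.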
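A sketch of the recurrence with periodic renormalization, using the same divisor as the fix above but computed per (batch, head) matrix. The tensor layout (time-major, state of shape (B, H, d, d)) and the names run_recurrence and renormalize_state are assumptions; the real GatedDeltaNet kernel is likely chunked and structured differently.

```python
import torch
import torch.nn.functional as F


def renormalize_state(state, eps=1e-6, min_norm=0.1):
    # Frobenius norm of each (d, d) matrix, kept as shape (B, H, 1, 1).
    norm = torch.linalg.matrix_norm(state, keepdim=True)
    return state / (norm + eps).clamp(min=min_norm)


def run_recurrence(a, b, kk, vk, state, renorm_every=256):
    # a, b: (T, B, H, 1, 1); kk, vk: (T, B, H, d, d); state: (B, H, d, d)
    for t in range(a.shape[0]):
        state = a[t] * (state - b[t] * (state @ kk[t])) + b[t] * vk[t]
        if (t + 1) % renorm_every == 0:
            state = renormalize_state(state)
    return state


torch.manual_seed(0)
T, B, H, d = 512, 1, 2, 4
k = F.normalize(torch.randn(T, B, H, d), dim=-1)
v = torch.randn(T, B, H, d)
kk = k.unsqueeze(-1) * k.unsqueeze(-2)  # outer product k k^T
vk = v.unsqueeze(-1) * k.unsqueeze(-2)  # outer product v k^T
a = torch.full((T, B, H, 1, 1), 0.999)  # alpha close to 1: almost no decay
b = torch.full((T, B, H, 1, 1), 0.5)
final_state = run_recurrence(a, b, kk, vk, torch.zeros(B, H, d, d))
```

Dividing by the norm discards the absolute scale of the state, so this changes the model's behaviour at every renormalization step; whether that is acceptable needs to be checked against validation metrics.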

7. ⚠️ Softplus depth output has no upper bound

Problem: nn.Softplus() can output arbitrarily large values, causing CoC explosion.
Fix: depth = F.softplus(raw_depth).clamp(max=100.0) (100 meters max).
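A minimal sketch of the bounded depth head. One side effect worth knowing is that clamp passes no gradient for values above the cap, so pixels predicted beyond 100 m stop receiving a learning signal from this branch. The names depth_head and MAX_DEPTH_M are illustrative.

```python
import torch
import torch.nn.functional as F

MAX_DEPTH_M = 100.0  # cap from the fix above, in metres


def depth_head(raw_depth):
    # Softplus keeps depth positive; clamp bounds it from above.
    return F.softplus(raw_depth).clamp(max=MAX_DEPTH_M)


raw = torch.tensor([-5.0, 0.0, 500.0], requires_grad=True)
depth = depth_head(raw)
depth.sum().backward()
# raw.grad is zero at the last entry, where the clamp is active.
```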

LOGICAL ISSUES

8. ⚠️ embed_dim vs. inner_dim for base variant (consistent, but heavy)

Problem: num_heads=6 and head_dim=32 give inner_dim = 6 × 32 = 192, which equals embed_dim=192. The linear projection to_qkv therefore maps 192→3×192=576, and the output gate maps 192→192. The shapes are consistent, so this is not a bug, but the projections are heavy for the base variant.

9. ⚠️ Direction fusion uses outputs before normalization

Problem: The adaptive direction fusion softmax(W_γ · [o_→;...]) operates on raw scan outputs, then the result is LayerNorm'd. The softmax inputs can have different scales per direction, potentially making one direction always dominate.
Fix: Apply LayerNorm to each scan output BEFORE fusion, or use a temperature in the softmax.
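A sketch of fusion with per-direction LayerNorm and a softmax temperature. The module name, the number of directions (4 here), and the token layout (B, N, dim) are assumptions; the gate layer plays the role of W_γ.

```python
import torch
import torch.nn as nn


class DirectionFusionSketch(nn.Module):
    """Hypothetical fusion: normalize each direction, then gate with softmax."""

    def __init__(self, dim, num_dirs=4, temperature=1.0):
        super().__init__()
        self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(num_dirs)])
        self.gate = nn.Linear(num_dirs * dim, num_dirs)  # role of W_gamma
        self.temperature = temperature

    def forward(self, outs):
        # outs: list of num_dirs tensors, each of shape (B, N, dim)
        normed = [norm(o) for norm, o in zip(self.norms, outs)]
        logits = self.gate(torch.cat(normed, dim=-1)) / self.temperature
        weights = logits.softmax(dim=-1)               # (B, N, num_dirs)
        stacked = torch.stack(normed, dim=-1)          # (B, N, dim, num_dirs)
        fused = (stacked * weights.unsqueeze(-2)).sum(dim=-1)
        return fused, weights


torch.manual_seed(0)
fusion = DirectionFusionSketch(dim=16, num_dirs=4)
# Scan outputs with deliberately different scales per direction.
outs = [torch.randn(2, 5, 16) * s for s in (1.0, 100.0, 0.01, 10.0)]
fused, weights = fusion(outs)
```

With the LayerNorm in front, the gate sees inputs of comparable scale regardless of how large each raw scan output is; a temperature above 1 further flattens the weights.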

10. ⚠️ TSP state shape mismatch

Problem: self.S_init has shape (1, num_heads, head_dim, head_dim) but BiGDR returns a list of states (one per scan direction), not a single state. The propagate function iterates over block_states which are lists, not tensors.
Fix: S_init should match the per-direction state shape, and propagation should handle the list structure properly.

DATASET CONFIRMED COMPATIBLE ✅

RealBokeh has paired data:

  • Input: {split}/in/{id}_f22.JPG (sharp, f/22)
  • GT: {split}/gt/{id}/{id}_f{fstop}.JPG (variable bokeh)
  • Metadata: {split}/metadata/{id}.json with focal_length, focus_plane_distance, target_avs

This maps perfectly to BokehFlow's inputs: image, f_number, focal_length_mm, focus_distance_m.
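The layout above can be turned into paths with a small helper. This is a sketch under assumptions: the function names are hypothetical, the f-stop is passed as a string because its exact formatting in file names is not specified here, and the units of the metadata fields are not converted since the audit does not state them.

```python
import json
from pathlib import Path


def realbokeh_paths(root, split, sample_id, fstop):
    """Build the input, ground-truth and metadata paths for one sample."""
    root = Path(root)
    return {
        "input": root / split / "in" / f"{sample_id}_f22.JPG",
        "target": root / split / "gt" / sample_id / f"{sample_id}_f{fstop}.JPG",
        "metadata": root / split / "metadata" / f"{sample_id}.json",
    }


def load_lens_metadata(metadata_path):
    """Read only the three fields named in the audit; raises KeyError if absent."""
    with open(metadata_path, "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    keys = ("focal_length", "focus_plane_distance", "target_avs")
    return {key: meta[key] for key in keys}


# Example with made-up identifiers.
paths = realbokeh_paths("data", "train", "0001", "2.0")
```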