"""Generation utilities for Fast-dDrive. This module provides the three inference paths exposed by the canonical paper release: * ``mdm_sample_deep_scaffold`` — Section Diffusion (SD): iterative MDM denoising over a pre-filled JSON scaffold, no AR verification. * ``scaffold_speculative_sample`` — Scaffold Spec (SS): scaffold-aware self-speculative decoding (MDM draft + AR verify per block). * ``scaffold_spec_with_ss_multi_traj`` — SS with shared-prefix multi-trajectory rollouts (the test-time inference-scaling path). All three are attached as bound methods on :class:`Fast_dDriveForConditionalGeneration` when this module is imported (see ``modeling.py`` for the import hook). """ import os import re import sys import math import torch import types import numpy as np from transformers.cache_utils import DynamicCache def _crop_cache(past_key_values, max_length: int): """Crop a DynamicCache to max_length tokens, compatible with Qwen cache layout.""" new_past_key_values = [] for layer_num in range(len(past_key_values)): layer_past_key_values = () for kv_idx in range(len(past_key_values[layer_num])): layer_past_key_values += (past_key_values[layer_num][kv_idx][:, :, :max_length, :],) new_past_key_values.append(layer_past_key_values) return DynamicCache(new_past_key_values) def _sample_from_logits(logits, temperature=0.0): """Sample token ids from logits with optional temperature scaling. When temperature <= 0, falls back to argmax (greedy). 
""" if temperature <= 0: return logits.argmax(dim=-1) scaled = logits / temperature probs = torch.softmax(scaled, dim=-1) original_shape = probs.shape[:-1] flat_probs = probs.reshape(-1, probs.shape[-1]) sampled = torch.multinomial(flat_probs, num_samples=1).squeeze(-1) return sampled.reshape(original_shape) # --------------------------------------------------------------------------- # mdm_sample_deep_scaffold — Section Diffusion (SD) # --------------------------------------------------------------------------- def mdm_sample_deep_scaffold( self, input_ids, tokenizer, max_tokens=512, pixel_values=None, image_grid_thw=None, mask_id=151665, null_id=151666, threshold=0.9, stop_token=151645, explanation_block_size=32, explanation_max_blocks=6, block_size=32, return_stats=False, use_kv_cache=True, temperature=0.0, ): """ Deep scaffold MDM generation with train-consistent hybrid block causal mask. Pre-fills the entire JSON scaffold (including sub-keys for critical_objects, future_meta_behavior, trajectory) with MASK tokens at value positions only. Then denoises each section's value tokens via iterative unmasking. The attention mask matches training: prompt tokens use causal attention, response tokens use block-causal attention where each section's denoise steps form separate blocks. Block i can see all prompt tokens and blocks 0..i, but NOT blocks i+1..N (which still contain MASK tokens). For explanation (variable length), NULL tokens in the output signal that the section content is complete — trailing NULLs are stripped. KV-cache path (``use_kv_cache=True``, default): Prompt K/V is computed once with vision embedding scatter, then each response block, once fully denoised, gets its K/V appended to the cache. Subsequent blocks' iterative unmasking only forwards their own ~block_size tokens against the cache (plus prior committed blocks), avoiding O(seqlen^2) recomputation of the prompt + prior blocks every iteration. 
Correctness is preserved because block-causal attention means block k only attends to prompt + blocks 0..k, which is exactly what the cache provides. """ import math import os as _os from .section_utils import ( build_deep_json_scaffold, SECTION_KEYS, NULL_TOKEN_ID, ) # Env override for A/B testing the KV cache path without editing code. _kv_env = _os.environ.get("MDM_DS_USE_KV_CACHE") if _kv_env is not None: use_kv_cache = _kv_env not in ("0", "false", "False", "") scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold( tokenizer, mask_id=mask_id, null_id=null_id, explanation_block_size=explanation_block_size, explanation_max_blocks=explanation_max_blocks, ) tokens_per_step = [] original_input_length = input_ids.shape[1] # Phase 1: Build sequence with scaffold appended scaffold_tensor = torch.tensor(scaffold_tokens, device=self.device, dtype=torch.long).unsqueeze(0) x_t = torch.cat([input_ids, scaffold_tensor], dim=1) seqlen = x_t.shape[1] # Track scaffold (frozen) vs value (to denoise) positions in scaffold region scaffold_frozen = torch.tensor(scaffold_mask_list, device=self.device, dtype=torch.bool) # ── Build response_block_idx matching training's compute_section_block_idx_deep_static ── response_block_idx = torch.full((seqlen,), -1, device=self.device, dtype=torch.long) current_block = 0 assigned = set() for section_name in SECTION_KEYS: if section_name not in section_ranges: continue sec_start, sec_end = section_ranges[section_name] # Find value positions (non-scaffold) in this section value_positions = [] for i in range(sec_start, sec_end): if not scaffold_mask_list[i]: # 0 = value token value_positions.append(original_input_length + i) if not value_positions: current_block += 1 continue # Block assignment MUST match training's # compute_section_block_idx_deep_static: n_blocks = ceil(value/block_size) # for every section. 
Previously non-explanation sections were forced to # a single block; that broke attention alignment for trajectory # (70 value tokens → training 3 blocks vs inference 1 block), causing # trajectory over-extrapolation. CO (12) and FMB (6) still resolve # to 1 block since their value counts are < block_size. tokens_per_step_sec = block_size n_steps = max(1, math.ceil(len(value_positions) / tokens_per_step_sec)) # Assign block indices to value tokens for vi, abs_pos in enumerate(value_positions): block_in_section = min(vi // tokens_per_step_sec, n_steps - 1) response_block_idx[abs_pos] = current_block + block_in_section assigned.add(abs_pos) # Assign scaffold tokens to nearest value token's block for i in range(sec_start, sec_end): abs_pos = original_input_length + i if scaffold_mask_list[i] and abs_pos not in assigned: best_block = -1 for delta in range(1, sec_end - sec_start + 10): for cand in [abs_pos + delta, abs_pos - delta]: if cand in assigned: best_block = response_block_idx[cand].item() break if best_block >= 0: break if best_block >= 0: response_block_idx[abs_pos] = best_block assigned.add(abs_pos) current_block += n_steps # Assign any remaining unassigned scaffold tokens (e.g. 
top-level separators) for i in range(len(scaffold_tokens)): abs_pos = original_input_length + i if abs_pos not in assigned: # Find nearest assigned position best_block = -1 for delta in range(1, seqlen): for cand in [abs_pos + delta, abs_pos - delta]: if 0 <= cand < seqlen and cand in assigned: best_block = response_block_idx[cand].item() break if best_block >= 0: break if best_block >= 0: response_block_idx[abs_pos] = best_block assigned.add(abs_pos) # ── Build hybrid block causal mask (computed once, reused for all forward passes) ── attention_mask = self.model.eval_hybrid_mask(seqlen, response_block_idx).to(self.device) # Section-MoE-LoRA: set section_ids before language model forward set_section_ids = lambda *a, **kw: None # noqa: E731 (Section-MoE-LoRA disabled in release) # Map block indices to section IDs (0=CO, 1=Exp, 2=FMB, 3=Traj, 4=Other/Prompt) _sec_ids = torch.full((seqlen,), 4, device=self.device, dtype=torch.long) for section_name, (sec_start, sec_end) in section_ranges.items(): abs_start = original_input_length + sec_start abs_end = original_input_length + sec_end if section_name == "critical_objects": _sec_ids[abs_start:abs_end] = 0 elif section_name == "explanation": _sec_ids[abs_start:abs_end] = 1 elif section_name == "future_meta_behavior": _sec_ids[abs_start:abs_end] = 2 elif section_name == "trajectory": _sec_ids[abs_start:abs_end] = 3 # Add batch dimension _sec_ids_batch = _sec_ids.unsqueeze(0) set_section_ids(_sec_ids_batch) # ── Precompute vision embeddings and position_ids once ── # BUG FIX: Previously pixel_values was only passed on the first forward # (step==0) but with use_cache=False every forward is independent, so all # subsequent forwards lost vision information entirely. 
_embed_fn = self.model.get_input_embeddings() _cached_image_embeds = None _cached_image_mask = None if pixel_values is not None: _cached_image_embeds = self.model.get_image_features(pixel_values, image_grid_thw) _cached_image_embeds = torch.cat(_cached_image_embeds, dim=0).to( self.device, _embed_fn.weight.dtype ) _tmp_embeds = _embed_fn(x_t) _cached_image_mask, _ = self.model.get_placeholder_mask( x_t, inputs_embeds=_tmp_embeds, image_features=_cached_image_embeds ) # Compute position_ids once with correct image_grid_thw (3D RoPE) _position_ids, _rope_deltas = self.model.get_rope_index( x_t, image_grid_thw, None ) self.model.rope_deltas = _rope_deltas # ── Compute contiguous block ranges in the response region ── # Each block's absolute [start, end) range in x_t is the maximal # contiguous span of positions sharing the same response_block_idx. # Blocks are ordered by block_idx and cover the entire response. _block_ranges = [] # list of (block_idx, abs_start, abs_end) _cur_bi = None _cur_start = None for _p in range(seqlen): _bi = int(response_block_idx[_p].item()) if _bi < 0: if _cur_bi is not None: _block_ranges.append((_cur_bi, _cur_start, _p)) _cur_bi, _cur_start = None, None continue if _cur_bi is None: _cur_bi, _cur_start = _bi, _p elif _bi != _cur_bi: _block_ranges.append((_cur_bi, _cur_start, _p)) _cur_bi, _cur_start = _bi, _p if _cur_bi is not None: _block_ranges.append((_cur_bi, _cur_start, seqlen)) # Map block_idx -> section_name for downstream logic (section-specific # behaviors like explanation NULL handling can still be scoped). 
_block_idx_to_section = {} for _sname, (_sstart, _send) in section_ranges.items(): _sabs_start = original_input_length + _sstart _sabs_end = original_input_length + _send for _bi, _bs, _be in _block_ranges: # Assign section by whether the block's range overlaps the section if _bs < _sabs_end and _be > _sabs_start: _block_idx_to_section.setdefault(_bi, _sname) # ── Phase 2: Denoise block-by-block with optional KV cache ── # Without cache (fallback): each forward replays the entire sequence. # With cache: prompt K/V computed once; each block's finalized K/V is # appended after denoising, so later blocks only forward their own # ~block_size tokens against the cache. step = 0 past_kv = None prev_last_logit = None # logit at the position just before the next block if use_kv_cache: # Phase 0: prompt prefill. Includes vision scatter; cache becomes # the reusable foundation for every scaffold block. prompt_tokens = x_t[:, :original_input_length] prompt_embeds = _embed_fn(prompt_tokens) if _cached_image_embeds is not None: prompt_image_mask = _cached_image_mask[:, :original_input_length] prompt_embeds = prompt_embeds.masked_scatter( prompt_image_mask, _cached_image_embeds ) prompt_position_ids = _position_ids[..., :original_input_length] # Causal over prompt (matches training's prompt-side attention). # When attention_mask=None, the model's eval_mask auto-builds causal # because use_block_causal_mask=True and update_kv_cache=True. prompt_out = self.forward( inputs_embeds=prompt_embeds, position_ids=prompt_position_ids, attention_mask=None, past_key_values=None, use_cache=True, update_kv_cache=True, ) past_kv = prompt_out.past_key_values # Logit at position (original_input_length - 1); used to predict # the first token of the first response block via causal shift. 
prev_last_logit = prompt_out.logits[:, -1:, :] # ── Iterate blocks in order ── for _block_idx, block_abs_start, block_abs_end in _block_ranges: B = block_abs_end - block_abs_start section_name = _block_idx_to_section.get(_block_idx, None) # Count MASK tokens in this block block_slice = x_t[0, block_abs_start:block_abs_end] n_masks_in_block = int((block_slice == mask_id).sum().item()) # ── Iterative unmasking within this block (if any MASKs) ── if n_masks_in_block > 0: max_iter = n_masks_in_block + 5 # safety limit for _ in range(max_iter): current_block_masks = (x_t[:, block_abs_start:block_abs_end] == mask_id) if current_block_masks.sum() == 0: break if use_kv_cache: # Feed only this block; past_kv covers prompt + prior blocks. block_tokens = x_t[:, block_abs_start:block_abs_end] block_embeds = _embed_fn(block_tokens) block_position_ids = _position_ids[..., block_abs_start:block_abs_end] L_cached = past_kv.get_seq_length() if past_kv is not None else 0 # Block-causal + bidirectional-within-block ⇒ this # block's queries attend to all cached KV plus all # fresh block KV ⇒ all-True mask of shape [B, L+B]. block_attn = torch.ones( B, L_cached + B, device=self.device, dtype=torch.bool ) output = self.forward( inputs_embeds=block_embeds, attention_mask=block_attn, position_ids=block_position_ids, past_key_values=past_kv, use_cache=True, update_kv_cache=False, # read-only during iteration ) logits = output.logits # [1, B, V] # Shift: pred for abs_pos uses logit at abs_pos-1. # logit at block_abs_start-1 is prev_last_logit; the # rest come from this forward's earlier positions. 
sec_logits = torch.cat([prev_last_logit, logits[:, :-1, :]], dim=1) else: # Full-sequence forward (fallback path, same as before) _cur_embeds = _embed_fn(x_t) if _cached_image_embeds is not None: _cur_embeds = _cur_embeds.masked_scatter( _cached_image_mask, _cached_image_embeds ) output = self.forward( input_ids=x_t, inputs_embeds=_cur_embeds, attention_mask=attention_mask, position_ids=_position_ids, use_cache=False, ) logits = output.logits sec_logits = logits[:, block_abs_start:block_abs_end, :] sec_logits = torch.cat( [logits[:, block_abs_start - 1:block_abs_start, :], sec_logits[:, :-1, :]], dim=1 ) if temperature > 0: # Temperature sampling for diverse generation (e.g. GRPO rollouts) sampling_probs = torch.softmax(sec_logits / temperature, dim=-1) x_1 = torch.multinomial( sampling_probs.view(-1, sampling_probs.shape[-1]), num_samples=1 ).view(sampling_probs.shape[:-1]) else: # Greedy (default, backward compatible) x_1 = sec_logits.argmax(dim=-1) probs = torch.softmax(sec_logits, dim=-1) x1_p = torch.gather(probs, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1) # Only consider currently-masked positions in this block x1_p = torch.where(current_block_masks, x1_p, -torch.inf) unmask_idx = (x1_p > threshold) if unmask_idx.sum() > 0: x_t[:, block_abs_start:block_abs_end][unmask_idx] = x_1[unmask_idx] tokens_per_step.append(int(unmask_idx.sum())) else: # Fallback: unmask highest-confidence token pos = x1_p.argmax() row = 0 col = pos.item() x_t[:, block_abs_start:block_abs_end][row, col] = x_1[row, col] tokens_per_step.append(1) step += 1 if step > max_tokens: break # ── Commit this block's K/V to the cache ── # Run one final forward at block's fully-denoised state with # update_kv_cache=True so future blocks can attend to it via cache. # prev_last_logit is refreshed to the logit at the last position # of this block for the NEXT block's first-position prediction. 
if use_kv_cache: block_tokens = x_t[:, block_abs_start:block_abs_end] block_embeds = _embed_fn(block_tokens) block_position_ids = _position_ids[..., block_abs_start:block_abs_end] L_cached = past_kv.get_seq_length() if past_kv is not None else 0 block_attn = torch.ones( B, L_cached + B, device=self.device, dtype=torch.bool ) commit_out = self.forward( inputs_embeds=block_embeds, attention_mask=block_attn, position_ids=block_position_ids, past_key_values=past_kv, use_cache=True, update_kv_cache=True, ) past_kv = commit_out.past_key_values prev_last_logit = commit_out.logits[:, -1:, :] # NOTE: a previous null_ratio>0.3 early-stopping heuristic was # removed. It computed the ratio globally across the whole # explanation and, when tripped, force-filled every remaining # MASK with NULL — including MASKs in middle positions that # should have held real text — which cut short explanations # mid-sentence. Training always produces 192 value tokens # (real text + <|NULL|> padding at the tail) and the model # learned to emit NULL cleanly at the tail, so the final # NULL-strip below is sufficient. Cost: every sample now # denoises all 6 explanation blocks. 
# Post-process: strip NULL tokens from the output gen_tokens = x_t[0, original_input_length:].tolist() cleaned = [t for t in gen_tokens if t != null_id and t != mask_id] x_t = torch.cat([ input_ids, torch.tensor([cleaned], device=self.device, dtype=torch.long) ], dim=1) gen_length = x_t.shape[1] - original_input_length if return_stats: stats = { "tokens_per_step": tokens_per_step, "total_steps": step, "gen_length": gen_length, "null_tokens_stripped": len(gen_tokens) - len(cleaned), "block_size": block_size, } return x_t, stats return x_t @torch.no_grad() # --------------------------------------------------------------------------- # scaffold_speculative_sample — Scaffold Spec (SS) # --------------------------------------------------------------------------- def scaffold_speculative_sample( self, input_ids, tokenizer, block_size=32, max_tokens=1024, pixel_values=None, image_grid_thw=None, mask_id=151665, null_id=151666, threshold=0.9, stop_token=151645, explanation_block_size=32, explanation_max_blocks=6, return_stats=False, draft_temperature=0.0, verify_temperature=0.0, ): """ Scaffold-aware self-speculative decoding. Minimal modification of standard self-spec (speculative_block_causal_sample_cache): scaffold (structural JSON) tokens are pre-filled in the draft block instead of MASK and auto-accepted during causal verification. Key design: uses *exactly the same* attention patterns as standard self-spec (block-diff for draft, **causal** for verify via auto eval_mask). Only the draft block content differs — scaffold positions carry known tokens instead of MASK, giving the draft better context while scaffold tokens are "free" during acceptance. 
""" from .section_utils import ( build_deep_json_scaffold, NULL_TOKEN_ID, ) scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold( tokenizer, mask_id=mask_id, null_id=null_id, explanation_block_size=explanation_block_size, explanation_max_blocks=explanation_max_blocks, ) scaffold_len = len(scaffold_tokens) original_input_length = input_ids.shape[1] tokens_per_step = [] self.model.bd_size = block_size _ss_profile = bool(os.environ.get("SS_PROFILE")) _ss_traj_start = section_ranges.get("trajectory", (None, None))[0] if _ss_profile: import time as _time torch.cuda.synchronize() _ss_t = {"start": _time.perf_counter()} _ss_marked_traj_start = False _ss_n_fwd_prefix = 0 _ss_n_fwd_traj = 0 # Pre-convert to tensors for vectorized operations in the loop scaffold_tok_t = torch.tensor( scaffold_tokens, device=self.device, dtype=torch.long ) scaffold_is_fixed = torch.tensor( scaffold_mask_list, device=self.device, dtype=torch.bool ) # ── Phase 1: Prefill prompt (identical to standard self-spec) ── output = self.forward( input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, use_cache=True, update_kv_cache=True, ) logits, past_key_values = output.logits, output.past_key_values if _ss_profile: torch.cuda.synchronize() _ss_t["after_prefill"] = _time.perf_counter() # First token — use scaffold token (always '{') next_token = torch.tensor( [[scaffold_tokens[0]]], device=self.device, dtype=torch.long ) input_ids = torch.cat([input_ids, next_token], dim=1) tokens_per_step.append(1) scaffold_cursor = 1 step = 1 # ── Phase 2: Self-speculative decoding loop ── # Follows the exact same structure as # speculative_block_causal_sample_cache, with scaffold-aware draft. 
while scaffold_cursor < scaffold_len: if _ss_profile and (not _ss_marked_traj_start) and ( _ss_traj_start is not None and scaffold_cursor >= _ss_traj_start ): torch.cuda.synchronize() _ss_t["enter_traj"] = _time.perf_counter() _ss_marked_traj_start = True prompt_length = input_ids.shape[1] n_draft = min(block_size - 1, scaffold_len - scaffold_cursor) # Build draft block: [seed, scaffold_or_MASK × n_draft] sc_end = scaffold_cursor + n_draft is_fixed = scaffold_is_fixed[scaffold_cursor:sc_end] draft_tensor = torch.where( is_fixed, scaffold_tok_t[scaffold_cursor:sc_end], mask_id, ).unsqueeze(0) x_t = torch.cat([input_ids[:, -1:], draft_tensor], dim=1) mask_idx = (x_t == mask_id) # ── Draft (block-diff bidirectional via auto eval_mask) ── logits = self.forward( input_ids=x_t, use_cache=True, past_key_values=past_key_values, update_kv_cache=False, eval_bd_size=block_size, ).logits tokens_per_step.append(0) step += 1 # Shift logits (same as standard self-spec) logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1) if draft_temperature > 0: # Temperature sampling for draft diversity scaled = logits / draft_temperature draft_probs = torch.softmax(scaled, dim=-1) x_1 = torch.multinomial( draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1 ).view(draft_probs.shape[:-1]) # Confidence uses unscaled probs for thresholding probs = torch.softmax(logits, dim=-1) x1_p = torch.gather( probs, dim=-1, index=x_1.unsqueeze(-1) ).squeeze(-1) else: x_1 = logits.argmax(dim=-1) probs = torch.softmax(logits, dim=-1) x1_p = torch.gather( probs, dim=-1, index=x_1.unsqueeze(-1) ).squeeze(-1) # Only fill MASK positions; scaffold positions keep their tokens x1_p = torch.where(mask_idx, x1_p, -torch.inf) unmask_idx = (x1_p > 0) # threshold=0 for draft filling if unmask_idx.sum() > 0: x_t[unmask_idx] = x_1[unmask_idx] else: # Fallback: fill most confident MASK mask_only_p = x1_p.clone() mask_only_p[~mask_idx] = -torch.inf if mask_only_p.max() > -torch.inf: best = 
mask_only_p.argmax() x_t.view(-1)[best] = x_1.view(-1)[best] # ── Verify (causal via auto eval_mask, commit to cache) ── output = self.forward( input_ids=x_t, use_cache=True, past_key_values=past_key_values, update_kv_cache=True, eval_bd_size=block_size, ) past_key_values = output.past_key_values if verify_temperature > 0: verify_logits = output.logits / verify_temperature verify_probs = torch.softmax(verify_logits, dim=-1) ar_block_token = torch.multinomial( verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1 ).view(verify_probs.shape[:-1]) else: ar_block_token = output.logits.argmax(dim=-1) # ── AR acceptance (scaffold positions auto-pass) ── ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1]) accepted_token_num = 0 for i in range(n_draft): if is_fixed[i] or ar_matches[i]: accepted_token_num += 1 else: break accepted_token_num += 1 # bonus token tokens_per_step.append(accepted_token_num) # Force scaffold tokens at scaffold positions, AR predictions elsewhere accepted_ids = ar_block_token[:, :accepted_token_num].clone() acc_end = min(scaffold_cursor + accepted_token_num, scaffold_len) acc_fixed = scaffold_is_fixed[scaffold_cursor:acc_end] accepted_ids[0, :len(acc_fixed)][acc_fixed] = \ scaffold_tok_t[scaffold_cursor:acc_end][acc_fixed] input_ids = torch.cat([input_ids, accepted_ids], dim=1) scaffold_cursor += accepted_token_num past_key_values = _crop_cache(past_key_values, input_ids.shape[1] - 1) step += 1 # Stop conditions if input_ids.shape[1] - original_input_length > max_tokens: break if stop_token in input_ids[:, prompt_length:]: stop_token_idx = ( input_ids[:, prompt_length:] == stop_token ).nonzero()[0][1] if ( input_ids[:, prompt_length:prompt_length + stop_token_idx] == mask_id ).sum() == 0: break if _ss_profile: torch.cuda.synchronize() _ss_t["end"] = _time.perf_counter() _t_total = _ss_t["end"] - _ss_t["start"] _t_pre = _ss_t["after_prefill"] - _ss_t["start"] _t_traj_in = _ss_t.get("enter_traj") if _t_traj_in is not None: 
_t_prefix = _t_traj_in - _ss_t["after_prefill"] _t_traj = _ss_t["end"] - _t_traj_in else: _t_prefix = _ss_t["end"] - _ss_t["after_prefill"] _t_traj = 0.0 print( f"[ss profile] total={_t_total*1000:.0f}ms " f"prefill={_t_pre*1000:.0f}ms " f"prefix-decode={_t_prefix*1000:.0f}ms " f"traj-decode={_t_traj*1000:.0f}ms", flush=True, ) # ── Phase 3: Post-process — truncate at stop, strip NULL ── if stop_token in input_ids[:, original_input_length:]: stop_token_idx = ( input_ids[:, original_input_length:] == stop_token ).nonzero()[0][1] input_ids = input_ids[ :, :stop_token_idx + original_input_length + 1 ] gen_tokens = input_ids[0, original_input_length:].tolist() cleaned = [t for t in gen_tokens if t != null_id and t != mask_id] output_ids = torch.cat( [ input_ids[:, :original_input_length], torch.tensor( [cleaned], device=self.device, dtype=torch.long ), ], dim=1, ) gen_length = output_ids.shape[1] - original_input_length if return_stats: stats = { "tokens_per_step": tokens_per_step, "total_steps": step, "gen_length": gen_length, "null_tokens_stripped": len(gen_tokens) - len(cleaned), "block_size": block_size, "method": "scaffold_speculative_v5", } return output_ids, stats return output_ids @torch.no_grad() # --------------------------------------------------------------------------- # scaffold_spec_with_ss_multi_traj — SS multi-rollout inference scaling # --------------------------------------------------------------------------- def scaffold_spec_with_ss_multi_traj( self, input_ids, tokenizer, block_size=32, max_tokens=1024, pixel_values=None, image_grid_thw=None, mask_id=151665, null_id=151666, threshold=0.9, stop_token=151645, explanation_block_size=32, explanation_max_blocks=6, return_stats=False, num_traj_rollouts=4, traj_verify_temperature=0.5, traj_draft_temperature=0.0, merge_weights=None, batch_parallel=False, ): """Scaffold Spec with shared prefix + N SS rollouts on the trajectory section. 
Decoding pipeline: 0) Prompt prefill [shared] 1) Scaffold Spec for sections 1-3 (CoT) at verify_temp = 0 [shared, deterministic] 2) Fork KV cache N times [O(N) memory] 3) For each fork: continue Scaffold Spec on the trajectory section with verify_temperature = traj_verify_temperature (each rollout draws different samples in the AR-verify step because torch.multinomial is invoked with a global RNG). 4) Parse all N trajectories and return their weighted mean. Cost: roughly 1 full SS pass (sections 1-3 are ~88%% of decoded tokens on our schema) + N x trajectory-only SS passes. For N = 4 this is ~1.5x the cost of a single SS, vs ~4x for naive sequential rerolling. If batch_parallel = True, the N trajectory rollouts are executed in a batched (batch_size = N) manner: one shared model.forward per speculative draft / verify step over an N-replicated trajectory suffix, which removes the per-rollout serial overhead at the cost of replicating the per-layer KV cache N-fold along the batch dimension. Returns: (output_ids, stats) if return_stats else output_ids. 
""" from .section_utils import ( build_deep_json_scaffold, SECTION_KEYS, ) scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold( tokenizer, mask_id=mask_id, null_id=null_id, explanation_block_size=explanation_block_size, explanation_max_blocks=explanation_max_blocks, ) scaffold_len = len(scaffold_tokens) original_input_length = input_ids.shape[1] tokens_per_step = [] self.model.bd_size = block_size scaffold_tok_t = torch.tensor(scaffold_tokens, device=self.device, dtype=torch.long) scaffold_is_fixed = torch.tensor(scaffold_mask_list, device=self.device, dtype=torch.bool) traj_start_in_scaffold = section_ranges["trajectory"][0] _profile = bool(os.environ.get("SS_MT_PROFILE")) if _profile: import time as _time torch.cuda.synchronize() _t_phase = {"start": _time.perf_counter()} _phase_clone_total = 0.0 _phase_rollout_each = [] # ── Phase 0: Prefill prompt ── output = self.forward( input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw, use_cache=True, update_kv_cache=True, ) logits, past_key_values = output.logits, output.past_key_values if _profile: torch.cuda.synchronize() _t_phase["after_prefill"] = _time.perf_counter() next_token = torch.tensor( [[scaffold_tokens[0]]], device=self.device, dtype=torch.long, ) input_ids = torch.cat([input_ids, next_token], dim=1) tokens_per_step.append(1) scaffold_cursor = 1 step = 1 # ── Phase 1: Scaffold Spec for non-trajectory sections (shared, vt=0) ── while scaffold_cursor < scaffold_len and scaffold_cursor < traj_start_in_scaffold: remaining_before_traj = traj_start_in_scaffold - scaffold_cursor n_draft = min(block_size - 1, remaining_before_traj) if n_draft <= 0: break sc_end = scaffold_cursor + n_draft is_fixed = scaffold_is_fixed[scaffold_cursor:sc_end] draft_tensor = torch.where( is_fixed, scaffold_tok_t[scaffold_cursor:sc_end], mask_id, ).unsqueeze(0) x_t = torch.cat([input_ids[:, -1:], draft_tensor], dim=1) mask_idx = (x_t == mask_id) # Draft (block-bidirectional) logits = 
self.forward( input_ids=x_t, use_cache=True, past_key_values=past_key_values, update_kv_cache=False, eval_bd_size=block_size, ).logits tokens_per_step.append(0) step += 1 logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1) x_1 = logits.argmax(dim=-1) probs = torch.softmax(logits, dim=-1) x1_p = torch.gather(probs, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1) x1_p = torch.where(mask_idx, x1_p, -torch.inf) unmask_idx = (x1_p > 0) if unmask_idx.sum() > 0: x_t[unmask_idx] = x_1[unmask_idx] else: mask_only_p = x1_p.clone() mask_only_p[~mask_idx] = -torch.inf if mask_only_p.max() > -torch.inf: best = mask_only_p.argmax() x_t.view(-1)[best] = x_1.view(-1)[best] # Verify (causal, greedy) output = self.forward( input_ids=x_t, use_cache=True, past_key_values=past_key_values, update_kv_cache=True, eval_bd_size=block_size, ) past_key_values = output.past_key_values ar_block_token = output.logits.argmax(dim=-1) ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1]) accepted_token_num = 0 for i in range(n_draft): if is_fixed[i] or ar_matches[i]: accepted_token_num += 1 else: break accepted_token_num += 1 max_accept = traj_start_in_scaffold - scaffold_cursor if accepted_token_num > max_accept: accepted_token_num = max_accept tokens_per_step.append(accepted_token_num) accepted_ids = ar_block_token[:, :accepted_token_num].clone() acc_end = min(scaffold_cursor + accepted_token_num, scaffold_len) acc_fixed = scaffold_is_fixed[scaffold_cursor:acc_end] accepted_ids[0, :len(acc_fixed)][acc_fixed] = \ scaffold_tok_t[scaffold_cursor:acc_end][acc_fixed] input_ids = torch.cat([input_ids, accepted_ids], dim=1) scaffold_cursor += accepted_token_num past_key_values = _crop_cache(past_key_values, input_ids.shape[1] - 1) step += 1 if input_ids.shape[1] - original_input_length > max_tokens: break if _profile: torch.cuda.synchronize() _t_phase["after_phase1"] = _time.perf_counter() # ── Phase 2: Fork KV cache N times (one per trajectory rollout) ── prefix_input_ids = 
input_ids.clone() prefix_len = prefix_input_ids.shape[1] def _clone_cache(kv): if _profile: torch.cuda.synchronize() _t0 = _time.perf_counter() cloned = [] for layer_num in range(len(kv)): cloned.append(tuple(t.clone() for t in kv[layer_num])) ret = DynamicCache(cloned) if _profile: torch.cuda.synchronize() nonlocal _phase_clone_total _phase_clone_total += _time.perf_counter() - _t0 return ret # ── Phase 3: N SS rollouts on trajectory section, each with vt > 0 ── # All rollouts start from the same prefix; randomness comes from # the multinomial calls in draft / verify (RNG is process-global). N = max(1, int(num_traj_rollouts)) def _run_one_traj_rollout(start_kv, start_input_ids): """Continue Scaffold Spec from start_kv / start_input_ids over the trajectory section, applying traj_*_temperature. Returns the final ss_input_ids (with trajectory tokens appended) and the extracted trajectory value tokens.""" local_kv = start_kv local_input = start_input_ids local_cursor = scaffold_cursor while local_cursor < scaffold_len: n_draft = min(block_size - 1, scaffold_len - local_cursor) sc_end = local_cursor + n_draft is_fixed = scaffold_is_fixed[local_cursor:sc_end] draft_tensor = torch.where( is_fixed, scaffold_tok_t[local_cursor:sc_end], mask_id, ).unsqueeze(0) x_t = torch.cat([local_input[:, -1:], draft_tensor], dim=1) mask_idx = (x_t == mask_id) # Draft (block-bidirectional, optionally temp-sampled) draft_logits = self.forward( input_ids=x_t, use_cache=True, past_key_values=local_kv, update_kv_cache=False, eval_bd_size=block_size, ).logits draft_logits = torch.cat( [draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1, ) if traj_draft_temperature > 0: scaled = draft_logits / traj_draft_temperature draft_probs = torch.softmax(scaled, dim=-1) x_1 = torch.multinomial( draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1, ).view(draft_probs.shape[:-1]) else: x_1 = draft_logits.argmax(dim=-1) probs = torch.softmax(draft_logits, dim=-1) x1_p = torch.gather( probs, 
dim=-1, index=x_1.unsqueeze(-1), ).squeeze(-1) x1_p = torch.where(mask_idx, x1_p, -torch.inf) unmask_idx = (x1_p > 0) if unmask_idx.sum() > 0: x_t[unmask_idx] = x_1[unmask_idx] else: mask_only_p = x1_p.clone() mask_only_p[~mask_idx] = -torch.inf if mask_only_p.max() > -torch.inf: x_t.view(-1)[mask_only_p.argmax()] = \ x_1.view(-1)[mask_only_p.argmax()] # Verify (causal, optionally temp-sampled) v_out = self.forward( input_ids=x_t, use_cache=True, past_key_values=local_kv, update_kv_cache=True, eval_bd_size=block_size, ) local_kv = v_out.past_key_values if traj_verify_temperature > 0: v_logits = v_out.logits / traj_verify_temperature v_probs = torch.softmax(v_logits, dim=-1) ar_block_token = torch.multinomial( v_probs.view(-1, v_probs.shape[-1]), num_samples=1, ).view(v_probs.shape[:-1]) else: ar_block_token = v_out.logits.argmax(dim=-1) ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1]) accepted_token_num = 0 for i in range(n_draft): if is_fixed[i] or ar_matches[i]: accepted_token_num += 1 else: break accepted_token_num += 1 accepted_ids = ar_block_token[:, :accepted_token_num].clone() acc_end = min(local_cursor + accepted_token_num, scaffold_len) acc_fixed = scaffold_is_fixed[local_cursor:acc_end] accepted_ids[0, :len(acc_fixed)][acc_fixed] = \ scaffold_tok_t[local_cursor:acc_end][acc_fixed] local_input = torch.cat([local_input, accepted_ids], dim=1) local_cursor += accepted_token_num local_kv = _crop_cache(local_kv, local_input.shape[1] - 1) if local_input.shape[1] - original_input_length > max_tokens: break if stop_token in local_input[:, prefix_len:]: st_idx = (local_input[:, prefix_len:] == stop_token).nonzero() if st_idx.numel() > 0: cand_st = st_idx[0][1].item() if (local_input[:, prefix_len:prefix_len + cand_st] == mask_id).sum() == 0: break traj_values = [ t for i, t in enumerate(local_input[0, original_input_length:].tolist()) if i >= traj_start_in_scaffold and i < scaffold_len and not scaffold_mask_list[i] and t != null_id and t != 
mask_id ] return local_input, traj_values # Sequential N rollouts (Option A; batch_parallel=False). rollout_inputs = [] rollout_traj_values = [] for _i in range(N): if _profile: torch.cuda.synchronize() _t_r0 = _time.perf_counter() cand_kv = _clone_cache(past_key_values) cand_input = prefix_input_ids.clone() cand_input, traj_vals = _run_one_traj_rollout(cand_kv, cand_input) rollout_inputs.append(cand_input) rollout_traj_values.append(traj_vals) step += 1 if _profile: torch.cuda.synchronize() _phase_rollout_each.append(_time.perf_counter() - _t_r0) if _profile: torch.cuda.synchronize() _t_phase["after_rollouts"] = _time.perf_counter() _t_total = _t_phase["after_rollouts"] - _t_phase["start"] _t_pre = _t_phase["after_prefill"] - _t_phase["start"] _t_p1 = _t_phase["after_phase1"] - _t_phase["after_prefill"] _t_rolls = _t_phase["after_rollouts"] - _t_phase["after_phase1"] print( f"[ss_mt profile] total={_t_total*1000:.0f}ms " f"prefill(P0)={_t_pre*1000:.0f}ms " f"prefix-decode(P1)={_t_p1*1000:.0f}ms " f"rollouts(P2+P3)={_t_rolls*1000:.0f}ms " f"of which kv-clone={_phase_clone_total*1000:.0f}ms " f"per-rollout={[f'{r*1000:.0f}' for r in _phase_rollout_each]}ms", flush=True, ) # ── Phase 4: Parse all rollouts, weighted-merge waypoints ── def _decode_trajectory(traj_tokens): text = tokenizer.decode(traj_tokens, skip_special_tokens=False) text = text.replace("<|NULL|>", "").strip() coords = re.findall(r"[+-]?\d+\.?\d*", text) wps = [] for i in range(0, len(coords) - 1, 2): wps.append([float(coords[i]), float(coords[i + 1])]) return wps rollout_waypoints = [_decode_trajectory(v) for v in rollout_traj_values] if merge_weights is None or len(merge_weights) != N: ws = [1.0 / N] * N else: total = sum(merge_weights) ws = [w / total for w in merge_weights] if rollout_waypoints and all(len(w) > 0 for w in rollout_waypoints): n_wp = min(len(w) for w in rollout_waypoints) merged_waypoints = [] for i in range(n_wp): mx = sum(ws[c] * rollout_waypoints[c][i][0] for c in range(N)) my = 
sum(ws[c] * rollout_waypoints[c][i][1] for c in range(N)) merged_waypoints.append([mx, my]) else: merged_waypoints = next( (w for w in rollout_waypoints if w), [], ) # Output text: take rollout 0's full text but replace its trajectory # with the merged waypoints. base_input = rollout_inputs[0] if stop_token in base_input[:, original_input_length:]: st_idx = (base_input[:, original_input_length:] == stop_token).nonzero()[0][1] base_input = base_input[:, :st_idx + original_input_length + 1] base_raw_tokens = base_input[0, original_input_length:].tolist() base_cleaned = [t for t in base_raw_tokens if t != null_id and t != mask_id] base_null_stripped = len(base_raw_tokens) - len(base_cleaned) base_text = tokenizer.decode(base_cleaned, skip_special_tokens=False) traj_parts = [ f"[{x:+07.2f},{y:+06.2f}]" for x, y in merged_waypoints ] merged_traj_str = "[" + ", ".join(traj_parts) + "]" replaced_text = re.sub( r'("trajectory"\s*:\s*")(\[\[.*?\]\])', r"\g<1>" + merged_traj_str, base_text, ) merged_tokens = tokenizer.encode(replaced_text, add_special_tokens=False) output_ids = torch.cat([ input_ids[:, :original_input_length], torch.tensor([merged_tokens], device=self.device, dtype=torch.long), ], dim=1) gen_length = output_ids.shape[1] - original_input_length if return_stats: stats = { "tokens_per_step": tokens_per_step, "total_steps": step, "gen_length": gen_length, "null_tokens_stripped": base_null_stripped, "block_size": block_size, "method": "scaffold_spec_with_ss_multi_traj", "num_traj_rollouts": N, "traj_verify_temperature": traj_verify_temperature, "rollout_waypoints": rollout_waypoints, "merged_waypoints": merged_waypoints, "merge_weights": ws, } return output_ids, stats return output_ids @torch.no_grad() # --------------------------------------------------------------------------- # Bind decoding methods onto the model class. 
#
# ``modeling.py`` imports this module at the bottom of the file, once the
# ``Fast_dDriveForConditionalGeneration`` class exists. The helper below
# installs the three decoding paths on a class as plain methods, so callers
# can write ``model.mdm_sample_deep_scaffold(...)`` (and likewise for the
# other two) with no further registration step.
# ---------------------------------------------------------------------------
# NOTE(review): the ``@torch.no_grad()`` line directly above this comment
# header decorates ``attach_generation_methods`` (Python ignores comment lines
# between a decorator and its ``def``). It looks like a leftover; it appears
# harmless because this helper only sets class attributes -- confirm and
# remove it.
def attach_generation_methods(cls):
    """Install the released decoding entry points on ``cls``.

    Args:
        cls: Class to extend in place (intended to be
            ``Fast_dDriveForConditionalGeneration``).

    Returns:
        The same ``cls`` object, after ``mdm_sample_deep_scaffold``,
        ``scaffold_speculative_sample`` and
        ``scaffold_spec_with_ss_multi_traj`` have been set on it as
        attributes of the same names.
    """
    # Attribute name -> module-level function. The names are spelled out
    # rather than derived from ``fn.__name__`` so that a name-altering
    # decorator on any of the functions cannot change the public method names.
    decoding_paths = {
        "mdm_sample_deep_scaffold": mdm_sample_deep_scaffold,
        "scaffold_speculative_sample": scaffold_speculative_sample,
        "scaffold_spec_with_ss_multi_traj": scaffold_spec_with_ss_multi_traj,
    }
    for attr_name, fn in decoding_paths.items():
        setattr(cls, attr_name, fn)
    return cls


# Public API of this module: the three decoding paths plus the binder.
__all__ = [
    "mdm_sample_deep_scaffold",
    "scaffold_speculative_sample",
    "scaffold_spec_with_ss_multi_traj",
    "attach_generation_methods",
]