# Fast-dDrive / generation_utils.py
# xiwenyoumu — Initial Fast-dDrive 3B release (revision 5e9a603)
"""Generation utilities for Fast-dDrive.
This module provides the three inference paths exposed by the canonical paper
release:
* ``mdm_sample_deep_scaffold`` — Section Diffusion (SD): iterative MDM
denoising over a pre-filled JSON scaffold, no AR verification.
* ``scaffold_speculative_sample`` — Scaffold Spec (SS): scaffold-aware
self-speculative decoding (MDM draft + AR verify per block).
* ``scaffold_spec_with_ss_multi_traj`` — SS with shared-prefix multi-trajectory
rollouts (the test-time inference-scaling path).
All three are attached as bound methods on
:class:`Fast_dDriveForConditionalGeneration` when this module is
imported (see ``modeling.py`` for the import hook).
"""
import os
import re
import sys
import math
import torch
import types
import numpy as np
from transformers.cache_utils import DynamicCache
def _crop_cache(past_key_values, max_length: int):
    """Return a new DynamicCache holding only the first ``max_length`` positions.

    Every per-layer entry of ``past_key_values`` (a tuple of 4-D tensors,
    e.g. key and value) is sliced along dim 2, the token axis of the Qwen
    cache layout. The slices are views of the original tensors; the input
    cache object itself is not modified.
    """
    truncated_layers = [
        tuple(
            layer_entry[kv_pos][:, :, :max_length, :]
            for kv_pos in range(len(layer_entry))
        )
        for layer_entry in (
            past_key_values[layer_pos] for layer_pos in range(len(past_key_values))
        )
    ]
    return DynamicCache(truncated_layers)
def _sample_from_logits(logits, temperature=0.0):
"""Sample token ids from logits with optional temperature scaling.
When temperature <= 0, falls back to argmax (greedy).
"""
if temperature <= 0:
return logits.argmax(dim=-1)
scaled = logits / temperature
probs = torch.softmax(scaled, dim=-1)
original_shape = probs.shape[:-1]
flat_probs = probs.reshape(-1, probs.shape[-1])
sampled = torch.multinomial(flat_probs, num_samples=1).squeeze(-1)
return sampled.reshape(original_shape)
# ---------------------------------------------------------------------------
# mdm_sample_deep_scaffold — Section Diffusion (SD)
# ---------------------------------------------------------------------------
@torch.no_grad()
def mdm_sample_deep_scaffold(
    self,
    input_ids,
    tokenizer,
    max_tokens=512,
    pixel_values=None,
    image_grid_thw=None,
    mask_id=151665,
    null_id=151666,
    threshold=0.9,
    stop_token=151645,
    explanation_block_size=32,
    explanation_max_blocks=6,
    block_size=32,
    return_stats=False,
    use_kv_cache=True,
    temperature=0.0,
):
    """
    Deep scaffold MDM generation with train-consistent hybrid block causal mask.

    Pre-fills the entire JSON scaffold (including sub-keys for critical_objects,
    future_meta_behavior, trajectory) with MASK tokens at value positions only.
    Then denoises each section's value tokens via iterative unmasking.
    The attention mask matches training: prompt tokens use causal attention,
    response tokens use block-causal attention where each section's denoise
    steps form separate blocks. Block i can see all prompt tokens and blocks
    0..i, but NOT blocks i+1..N (which still contain MASK tokens).
    For explanation (variable length), NULL tokens in the output signal that
    the section content is complete. All NULL tokens (and any MASK left over)
    are stripped from the returned sequence.

    KV-cache path (``use_kv_cache=True``, default):
        Prompt K/V is computed once with vision embedding scatter, then each
        response block, once fully denoised, gets its K/V appended to the
        cache. Subsequent blocks' iterative unmasking only forwards their
        own ~block_size tokens against the cache (plus prior committed
        blocks), avoiding O(seqlen^2) recomputation of the prompt + prior
        blocks every iteration. Correctness is preserved because
        block-causal attention means block k only attends to prompt +
        blocks 0..k, which is exactly what the cache provides.

    Runs under ``torch.no_grad()`` like the other inference entry points in
    this module. Without it, every forward (and, on the KV path, the cache
    carried from block to block) would hold an autograd graph.

    Args:
        input_ids: prompt token ids. The scaffold is concatenated to it along
            dim 1 as a single-row tensor, so a batch of 1 is assumed.
        tokenizer: passed to ``build_deep_json_scaffold``.
        max_tokens: budget on denoising *iterations* (``step``), not on the
            output length. Once exceeded, every remaining block that still has
            MASK tokens runs one unmasking iteration and moves on. Leftover
            MASKs are stripped from the output.
        pixel_values, image_grid_thw: optional vision inputs. Image features
            are computed once and scattered into the embeddings of every
            forward.
        mask_id, null_id: MASK / NULL token ids; both are removed from the output.
        threshold: confidence above which a masked position is committed in
            an iteration. If no position qualifies, the single most confident
            one is committed.
        stop_token: accepted for API parity with the other samplers; not used
            by this path.
        explanation_block_size, explanation_max_blocks: forwarded to the
            scaffold builder.
        block_size: number of value tokens per denoising block in a section.
        return_stats: also return a stats dict.
        use_kv_cache: use the block-wise KV-cache path. The
            ``MDM_DS_USE_KV_CACHE`` env var overrides it ("0", "false",
            "False" or "" disable it).
        temperature: 0 means greedy argmax; > 0 means multinomial sampling
            from ``softmax(logits / temperature)``. Confidence is always
            measured with the unscaled softmax.

    Returns:
        ``x_t`` (prompt ids followed by the generated ids with NULL/MASK
        removed), or ``(x_t, stats)`` when ``return_stats`` is true.
    """
    from .section_utils import (
        build_deep_json_scaffold,
        SECTION_KEYS,
    )

    # Env override for A/B testing the KV cache path without editing code.
    _kv_env = os.environ.get("MDM_DS_USE_KV_CACHE")
    if _kv_env is not None:
        use_kv_cache = _kv_env not in ("0", "false", "False", "")

    device = self.device
    scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold(
        tokenizer,
        mask_id=mask_id,
        null_id=null_id,
        explanation_block_size=explanation_block_size,
        explanation_max_blocks=explanation_max_blocks,
    )
    tokens_per_step = []
    original_input_length = input_ids.shape[1]

    # Phase 1: Build sequence with scaffold appended. scaffold_mask_list[i] is
    # truthy for frozen scaffold tokens and falsy (0) for value slots.
    scaffold_tensor = torch.tensor(scaffold_tokens, device=device, dtype=torch.long).unsqueeze(0)
    x_t = torch.cat([input_ids, scaffold_tensor], dim=1)
    seqlen = x_t.shape[1]

    # ── Build response_block_idx matching training's compute_section_block_idx_deep_static ──
    # The index is filled in a host-side list and moved to the device once.
    # The nearest-neighbour searches below read it position by position, and
    # per-element ``.item()`` on a device tensor would sync on every read.
    block_idx = [-1] * seqlen  # -1 = prompt / not (yet) assigned
    assigned = set()

    def _nearest_assigned_block(pos, search_limit):
        # Block of the closest already-assigned position to `pos`, trying
        # pos + d before pos - d for d = 1 .. search_limit - 1; -1 if none.
        for delta in range(1, search_limit):
            for cand in (pos + delta, pos - delta):
                if cand in assigned:
                    return block_idx[cand]
        return -1

    current_block = 0
    for section_name in SECTION_KEYS:
        if section_name not in section_ranges:
            continue
        sec_start, sec_end = section_ranges[section_name]
        # Absolute positions of this section's value (non-scaffold) tokens.
        value_positions = [
            original_input_length + i
            for i in range(sec_start, sec_end)
            if not scaffold_mask_list[i]  # 0 = value token
        ]
        if not value_positions:
            current_block += 1
            continue
        # Block assignment MUST match training's
        # compute_section_block_idx_deep_static: n_blocks = ceil(value/block_size)
        # for every section. Previously non-explanation sections were forced to
        # a single block; that broke attention alignment for trajectory
        # (70 value tokens → training 3 blocks vs inference 1 block), causing
        # trajectory over-extrapolation. CO (12) and FMB (6) still resolve
        # to 1 block since their value counts are < block_size.
        n_steps = max(1, math.ceil(len(value_positions) / block_size))
        for vi, abs_pos in enumerate(value_positions):
            block_idx[abs_pos] = current_block + min(vi // block_size, n_steps - 1)
            assigned.add(abs_pos)
        # Scaffold tokens inherit the block of the nearest assigned token.
        for i in range(sec_start, sec_end):
            abs_pos = original_input_length + i
            if scaffold_mask_list[i] and abs_pos not in assigned:
                best_block = _nearest_assigned_block(abs_pos, sec_end - sec_start + 10)
                if best_block >= 0:
                    block_idx[abs_pos] = best_block
                    assigned.add(abs_pos)
        current_block += n_steps

    # Assign any remaining unassigned scaffold tokens (e.g. top-level separators)
    for i in range(len(scaffold_tokens)):
        abs_pos = original_input_length + i
        if abs_pos not in assigned:
            best_block = _nearest_assigned_block(abs_pos, seqlen)
            if best_block >= 0:
                block_idx[abs_pos] = best_block
                assigned.add(abs_pos)

    response_block_idx = torch.tensor(block_idx, device=device, dtype=torch.long)

    # ── Build hybrid block causal mask (computed once) ──
    # Only the full-sequence (non-cache) path below reads it.
    # NOTE(review): could be skipped when use_kv_cache is on, if
    # eval_hybrid_mask has no side effects; confirm before changing.
    attention_mask = self.model.eval_hybrid_mask(seqlen, response_block_idx).to(device)

    # NOTE: Section-MoE-LoRA (per-position section ids) is disabled in this
    # release, so no section ids are computed or set here.

    # ── Precompute vision embeddings and position_ids once ──
    # Every forward below is fed embeddings rather than pixels, so the image
    # features must be scattered into the embeddings on each call.
    _embed_fn = self.model.get_input_embeddings()
    _cached_image_embeds = None
    _cached_image_mask = None
    if pixel_values is not None:
        _cached_image_embeds = self.model.get_image_features(pixel_values, image_grid_thw)
        _cached_image_embeds = torch.cat(_cached_image_embeds, dim=0).to(
            device, _embed_fn.weight.dtype
        )
        _tmp_embeds = _embed_fn(x_t)
        _cached_image_mask, _ = self.model.get_placeholder_mask(
            x_t, inputs_embeds=_tmp_embeds, image_features=_cached_image_embeds
        )
    # Compute position_ids once with correct image_grid_thw (3D RoPE)
    _position_ids, _rope_deltas = self.model.get_rope_index(
        x_t, image_grid_thw, None
    )
    self.model.rope_deltas = _rope_deltas

    # ── Compute contiguous block ranges in the response region ──
    # Each entry (block_idx, abs_start, abs_end) is a maximal run of positions
    # sharing one block index, listed in position order. That is also the
    # order in which the blocks are denoised below.
    _block_ranges = []
    _cur_bi = None
    _cur_start = None
    for _p, _bi in enumerate(block_idx):
        if _bi < 0:
            if _cur_bi is not None:
                _block_ranges.append((_cur_bi, _cur_start, _p))
                _cur_bi, _cur_start = None, None
            continue
        if _cur_bi is None:
            _cur_bi, _cur_start = _bi, _p
        elif _bi != _cur_bi:
            _block_ranges.append((_cur_bi, _cur_start, _p))
            _cur_bi, _cur_start = _bi, _p
    if _cur_bi is not None:
        _block_ranges.append((_cur_bi, _cur_start, seqlen))

    # ── Phase 2: Denoise block-by-block with optional KV cache ──
    # Without cache (fallback): each forward replays the entire sequence.
    # With cache: prompt K/V computed once; each block's finalized K/V is
    # appended after denoising, so later blocks only forward their own
    # ~block_size tokens against the cache.
    step = 0
    past_kv = None
    prev_last_logit = None  # logit at the position just before the next block

    def _forward_block(abs_start, abs_end, commit):
        # Forward only this block's current tokens against `past_kv` (prompt +
        # previously committed blocks). Block-causal attention with
        # bidirectional attention inside the block means every query sees all
        # cached keys plus every key of the block: an all-True [B, L+B] mask.
        # `commit` controls whether the model appends this block's K/V.
        n_block = abs_end - abs_start
        block_embeds = _embed_fn(x_t[:, abs_start:abs_end])
        n_cached = past_kv.get_seq_length() if past_kv is not None else 0
        block_attn = torch.ones(n_block, n_cached + n_block, device=device, dtype=torch.bool)
        return self.forward(
            inputs_embeds=block_embeds,
            attention_mask=block_attn,
            position_ids=_position_ids[..., abs_start:abs_end],
            past_key_values=past_kv,
            use_cache=True,
            update_kv_cache=commit,
        )

    if use_kv_cache:
        # Phase 0: prompt prefill. Includes vision scatter; cache becomes
        # the reusable foundation for every scaffold block.
        prompt_embeds = _embed_fn(x_t[:, :original_input_length])
        if _cached_image_embeds is not None:
            prompt_embeds = prompt_embeds.masked_scatter(
                _cached_image_mask[:, :original_input_length], _cached_image_embeds
            )
        # Causal over prompt (matches training's prompt-side attention).
        # When attention_mask=None, the model's eval_mask auto-builds causal
        # because use_block_causal_mask=True and update_kv_cache=True.
        prompt_out = self.forward(
            inputs_embeds=prompt_embeds,
            position_ids=_position_ids[..., :original_input_length],
            attention_mask=None,
            past_key_values=None,
            use_cache=True,
            update_kv_cache=True,
        )
        past_kv = prompt_out.past_key_values
        # Logit at position (original_input_length - 1); used to predict
        # the first token of the first response block via causal shift.
        prev_last_logit = prompt_out.logits[:, -1:, :]

    for _block_idx, block_abs_start, block_abs_end in _block_ranges:
        block_slice = x_t[0, block_abs_start:block_abs_end]
        n_masks_in_block = int((block_slice == mask_id).sum().item())

        # ── Iterative unmasking within this block (if any MASKs) ──
        if n_masks_in_block > 0:
            max_iter = n_masks_in_block + 5  # safety limit
            for _ in range(max_iter):
                current_block_masks = x_t[:, block_abs_start:block_abs_end] == mask_id
                if not current_block_masks.any():
                    break
                if use_kv_cache:
                    # Read-only pass: the cache is not extended while the
                    # block is still being denoised.
                    logits = _forward_block(block_abs_start, block_abs_end, commit=False).logits
                    # Shift: pred for abs_pos uses logit at abs_pos-1. The
                    # logit before block_abs_start is prev_last_logit; the
                    # rest come from this forward's earlier positions.
                    sec_logits = torch.cat([prev_last_logit, logits[:, :-1, :]], dim=1)
                else:
                    # Full-sequence forward (fallback path)
                    _cur_embeds = _embed_fn(x_t)
                    if _cached_image_embeds is not None:
                        _cur_embeds = _cur_embeds.masked_scatter(
                            _cached_image_mask, _cached_image_embeds
                        )
                    logits = self.forward(
                        input_ids=x_t,
                        inputs_embeds=_cur_embeds,
                        attention_mask=attention_mask,
                        position_ids=_position_ids,
                        use_cache=False,
                    ).logits
                    # Same causal shift: position p is predicted by logit p-1
                    # (block_abs_start >= 1 because the prompt precedes it).
                    sec_logits = logits[:, block_abs_start - 1:block_abs_end - 1, :]

                # Greedy when temperature <= 0, multinomial otherwise (e.g.
                # diverse GRPO rollouts).
                x_1 = _sample_from_logits(sec_logits, temperature)
                probs = torch.softmax(sec_logits, dim=-1)
                x1_p = torch.gather(probs, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1)
                # Only consider currently-masked positions in this block
                x1_p = torch.where(current_block_masks, x1_p, -torch.inf)
                unmask_idx = x1_p > threshold
                n_unmask = int(unmask_idx.sum())
                # Basic-slice view: writes through to x_t.
                block_view = x_t[:, block_abs_start:block_abs_end]
                if n_unmask > 0:
                    block_view[unmask_idx] = x_1[unmask_idx]
                    tokens_per_step.append(n_unmask)
                else:
                    # Fallback: unmask the single highest-confidence position
                    col = int(x1_p.argmax())
                    block_view[0, col] = x_1[0, col]
                    tokens_per_step.append(1)
                step += 1
                if step > max_tokens:
                    break

        # ── Commit this block's K/V to the cache ──
        # One final forward at the block's denoised state with
        # update_kv_cache=True so later blocks can attend to it via cache.
        # prev_last_logit becomes the logit at this block's last position,
        # which predicts the next block's first position.
        if use_kv_cache:
            commit_out = _forward_block(block_abs_start, block_abs_end, commit=True)
            past_kv = commit_out.past_key_values
            prev_last_logit = commit_out.logits[:, -1:, :]

    # NOTE: a previous null_ratio>0.3 early-stopping heuristic was
    # removed. It computed the ratio globally across the whole
    # explanation and, when tripped, force-filled every remaining
    # MASK with NULL — including MASKs in middle positions that
    # should have held real text — which cut short explanations
    # mid-sentence. Training always produces 192 value tokens
    # (real text + <|NULL|> padding at the tail) and the model
    # learned to emit NULL cleanly at the tail, so the final
    # NULL-strip below is sufficient. Cost: every sample now
    # denoises all 6 explanation blocks.

    # Post-process: strip NULL tokens (and any leftover MASK) from the output
    gen_tokens = x_t[0, original_input_length:].tolist()
    cleaned = [t for t in gen_tokens if t != null_id and t != mask_id]
    x_t = torch.cat(
        [input_ids, torch.tensor([cleaned], device=device, dtype=torch.long)],
        dim=1,
    )
    gen_length = x_t.shape[1] - original_input_length
    if return_stats:
        stats = {
            "tokens_per_step": tokens_per_step,
            "total_steps": step,
            "gen_length": gen_length,
            "null_tokens_stripped": len(gen_tokens) - len(cleaned),
            "block_size": block_size,
        }
        return x_t, stats
    return x_t
@torch.no_grad()
# ---------------------------------------------------------------------------
# scaffold_speculative_sample — Scaffold Spec (SS)
# ---------------------------------------------------------------------------
def scaffold_speculative_sample(
    self,
    input_ids,
    tokenizer,
    block_size=32,
    max_tokens=1024,
    pixel_values=None,
    image_grid_thw=None,
    mask_id=151665,
    null_id=151666,
    threshold=0.9,
    stop_token=151645,
    explanation_block_size=32,
    explanation_max_blocks=6,
    return_stats=False,
    draft_temperature=0.0,
    verify_temperature=0.0,
):
    """
    Scaffold-aware self-speculative decoding.

    Minimal modification of standard self-spec
    (speculative_block_causal_sample_cache): scaffold (structural JSON)
    tokens are pre-filled in the draft block instead of MASK and
    auto-accepted during causal verification.

    Key design: uses *exactly the same* attention patterns as standard
    self-spec (block-diff for draft, **causal** for verify via
    auto eval_mask). Only the draft block content differs — scaffold
    positions carry known tokens instead of MASK, giving the draft
    better context while scaffold tokens are "free" during acceptance.

    Each loop iteration runs three steps:
      * Draft: up to ``block_size - 1`` scaffold positions in one forward,
        with the cache read-only.
      * Verify: the same block causally in a second forward, which
        extends the cache.
      * Accept: the longest prefix where every position is either a
        scaffold token or equals the verifier's token, plus one bonus
        verifier token. The cache is then cropped back to the accepted
        length.

    Args:
        input_ids: prompt ids. A ``[[token]]`` tensor is concatenated to it
            along dim 1, so a batch of 1 is assumed.
        tokenizer: passed to ``build_deep_json_scaffold``.
        block_size: draft/verify block length (seed token + up to
            ``block_size - 1`` drafted positions). Also written to
            ``self.model.bd_size`` as a side effect.
        max_tokens: stop once more than this many tokens were generated.
        pixel_values, image_grid_thw: vision inputs, used only in the prefill.
        mask_id, null_id: MASK / NULL ids; both are stripped from the output.
        threshold: unused here, because the draft fills every MASK in one pass.
        stop_token: decoding stops on it, and the output is truncated just
            after its first occurrence.
        explanation_block_size, explanation_max_blocks: forwarded to the
            scaffold builder.
        return_stats: also return a stats dict.
        draft_temperature: > 0 samples draft tokens instead of argmax.
        verify_temperature: > 0 samples verifier tokens instead of argmax.
            Acceptance is still an exact match between draft token and
            verifier token. It is not the probabilistic accept/reject rule
            of standard speculative sampling.

    Returns:
        ``output_ids`` (prompt + generated ids), or ``(output_ids, stats)``
        when ``return_stats`` is true.

    Setting the ``SS_PROFILE`` environment variable prints a timing
    breakdown. It calls ``torch.cuda.synchronize``, so it assumes CUDA.
    """
    # NOTE(review): NULL_TOKEN_ID is imported but not used in this function.
    from .section_utils import (
        build_deep_json_scaffold,
        NULL_TOKEN_ID,
    )
    scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold(
        tokenizer,
        mask_id=mask_id,
        null_id=null_id,
        explanation_block_size=explanation_block_size,
        explanation_max_blocks=explanation_max_blocks,
    )
    scaffold_len = len(scaffold_tokens)
    original_input_length = input_ids.shape[1]
    # One entry per forward: 0 for a draft pass, #accepted for a verify pass.
    tokens_per_step = []
    # Side effect: the model's block-diffusion size is set for later forwards.
    self.model.bd_size = block_size
    _ss_profile = bool(os.environ.get("SS_PROFILE"))
    _ss_traj_start = section_ranges.get("trajectory", (None, None))[0]
    if _ss_profile:
        import time as _time
        torch.cuda.synchronize()
        _ss_t = {"start": _time.perf_counter()}
        _ss_marked_traj_start = False
        _ss_n_fwd_prefix = 0
        _ss_n_fwd_traj = 0
    # Pre-convert to tensors for vectorized operations in the loop
    scaffold_tok_t = torch.tensor(
        scaffold_tokens, device=self.device, dtype=torch.long
    )
    # True = frozen scaffold token, False = value slot (drafted as MASK).
    scaffold_is_fixed = torch.tensor(
        scaffold_mask_list, device=self.device, dtype=torch.bool
    )
    # ── Phase 1: Prefill prompt (identical to standard self-spec) ──
    output = self.forward(
        input_ids=input_ids,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        use_cache=True,
        update_kv_cache=True,
    )
    logits, past_key_values = output.logits, output.past_key_values
    if _ss_profile:
        torch.cuda.synchronize()
        _ss_t["after_prefill"] = _time.perf_counter()
    # First token — use scaffold token (always '{')
    next_token = torch.tensor(
        [[scaffold_tokens[0]]], device=self.device, dtype=torch.long
    )
    input_ids = torch.cat([input_ids, next_token], dim=1)
    tokens_per_step.append(1)
    # scaffold_cursor = index of the next scaffold position to decode.
    scaffold_cursor = 1
    step = 1
    # ── Phase 2: Self-speculative decoding loop ──
    # Follows the exact same structure as
    # speculative_block_causal_sample_cache, with scaffold-aware draft.
    # Invariant at the top of each iteration: the cache covers every token of
    # input_ids except the last one, which is fed again as the block seed.
    while scaffold_cursor < scaffold_len:
        if _ss_profile and (not _ss_marked_traj_start) and (
            _ss_traj_start is not None and scaffold_cursor >= _ss_traj_start
        ):
            torch.cuda.synchronize()
            _ss_t["enter_traj"] = _time.perf_counter()
            _ss_marked_traj_start = True
        prompt_length = input_ids.shape[1]
        n_draft = min(block_size - 1, scaffold_len - scaffold_cursor)
        # Build draft block: [seed, scaffold_or_MASK × n_draft]
        sc_end = scaffold_cursor + n_draft
        is_fixed = scaffold_is_fixed[scaffold_cursor:sc_end]
        draft_tensor = torch.where(
            is_fixed,
            scaffold_tok_t[scaffold_cursor:sc_end],
            mask_id,
        ).unsqueeze(0)
        x_t = torch.cat([input_ids[:, -1:], draft_tensor], dim=1)
        mask_idx = (x_t == mask_id)
        # ── Draft (block-diff bidirectional via auto eval_mask) ──
        # update_kv_cache=False: the draft forward does not extend the cache.
        logits = self.forward(
            input_ids=x_t,
            use_cache=True,
            past_key_values=past_key_values,
            update_kv_cache=False,
            eval_bd_size=block_size,
        ).logits
        tokens_per_step.append(0)
        step += 1
        # Shift logits (same as standard self-spec): position j >= 1 is
        # predicted by the logit at j-1. Position 0 (the known seed) reuses
        # its own logit, which only keeps the shape unchanged.
        logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
        if draft_temperature > 0:
            # Temperature sampling for draft diversity
            scaled = logits / draft_temperature
            draft_probs = torch.softmax(scaled, dim=-1)
            x_1 = torch.multinomial(
                draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
            ).view(draft_probs.shape[:-1])
            # Confidence uses unscaled probs for thresholding
            probs = torch.softmax(logits, dim=-1)
            x1_p = torch.gather(
                probs, dim=-1, index=x_1.unsqueeze(-1)
            ).squeeze(-1)
        else:
            x_1 = logits.argmax(dim=-1)
            probs = torch.softmax(logits, dim=-1)
            x1_p = torch.gather(
                probs, dim=-1, index=x_1.unsqueeze(-1)
            ).squeeze(-1)
        # Only fill MASK positions; scaffold positions keep their tokens
        x1_p = torch.where(mask_idx, x1_p, -torch.inf)
        # Threshold 0: every MASK with non-zero probability is filled in this
        # single draft pass (the `threshold` argument is not used).
        unmask_idx = (x1_p > 0)  # threshold=0 for draft filling
        if unmask_idx.sum() > 0:
            x_t[unmask_idx] = x_1[unmask_idx]
        else:
            # Fallback: fill most confident MASK
            mask_only_p = x1_p.clone()
            mask_only_p[~mask_idx] = -torch.inf
            if mask_only_p.max() > -torch.inf:
                best = mask_only_p.argmax()
                x_t.view(-1)[best] = x_1.view(-1)[best]
        # ── Verify (causal via auto eval_mask, commit to cache) ──
        # The cache is extended by all n_draft + 1 block positions here and
        # cropped back to the accepted length below.
        output = self.forward(
            input_ids=x_t,
            use_cache=True,
            past_key_values=past_key_values,
            update_kv_cache=True,
            eval_bd_size=block_size,
        )
        past_key_values = output.past_key_values
        if verify_temperature > 0:
            verify_logits = output.logits / verify_temperature
            verify_probs = torch.softmax(verify_logits, dim=-1)
            ar_block_token = torch.multinomial(
                verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
            ).view(verify_probs.shape[:-1])
        else:
            ar_block_token = output.logits.argmax(dim=-1)
        # ── AR acceptance (scaffold positions auto-pass) ──
        # ar_block_token[0, i] is the verifier's token for block position i+1.
        ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1])
        accepted_token_num = 0
        for i in range(n_draft):
            if is_fixed[i] or ar_matches[i]:
                accepted_token_num += 1
            else:
                break
        accepted_token_num += 1  # bonus token
        # NOTE: when the final block is fully accepted, the bonus token lies
        # one position past the end of the scaffold. It is appended as the
        # verifier predicted it, and scaffold_cursor then exceeds scaffold_len,
        # which ends the loop.
        tokens_per_step.append(accepted_token_num)
        # Force scaffold tokens at scaffold positions, AR predictions elsewhere
        accepted_ids = ar_block_token[:, :accepted_token_num].clone()
        acc_end = min(scaffold_cursor + accepted_token_num, scaffold_len)
        acc_fixed = scaffold_is_fixed[scaffold_cursor:acc_end]
        accepted_ids[0, :len(acc_fixed)][acc_fixed] = \
            scaffold_tok_t[scaffold_cursor:acc_end][acc_fixed]
        input_ids = torch.cat([input_ids, accepted_ids], dim=1)
        scaffold_cursor += accepted_token_num
        # Keep cached K/V for every token except the newly appended last one
        # (the bonus). Its cached entry, if any, was computed from the draft
        # token, and it is fed as the seed of the next block instead.
        past_key_values = _crop_cache(past_key_values, input_ids.shape[1] - 1)
        step += 1
        # Stop conditions
        if input_ids.shape[1] - original_input_length > max_tokens:
            break
        # Stop when the tokens accepted in this iteration contain stop_token
        # with no MASK before it.
        if stop_token in input_ids[:, prompt_length:]:
            stop_token_idx = (
                input_ids[:, prompt_length:] == stop_token
            ).nonzero()[0][1]
            if (
                input_ids[:, prompt_length:prompt_length + stop_token_idx]
                == mask_id
            ).sum() == 0:
                break
    if _ss_profile:
        torch.cuda.synchronize()
        _ss_t["end"] = _time.perf_counter()
        _t_total = _ss_t["end"] - _ss_t["start"]
        _t_pre = _ss_t["after_prefill"] - _ss_t["start"]
        _t_traj_in = _ss_t.get("enter_traj")
        if _t_traj_in is not None:
            _t_prefix = _t_traj_in - _ss_t["after_prefill"]
            _t_traj = _ss_t["end"] - _t_traj_in
        else:
            _t_prefix = _ss_t["end"] - _ss_t["after_prefill"]
            _t_traj = 0.0
        print(
            f"[ss profile] total={_t_total*1000:.0f}ms "
            f"prefill={_t_pre*1000:.0f}ms "
            f"prefix-decode={_t_prefix*1000:.0f}ms "
            f"traj-decode={_t_traj*1000:.0f}ms",
            flush=True,
        )
    # ── Phase 3: Post-process — truncate at stop, strip NULL ──
    # Truncation keeps the first stop token itself (the `+ 1`).
    if stop_token in input_ids[:, original_input_length:]:
        stop_token_idx = (
            input_ids[:, original_input_length:] == stop_token
        ).nonzero()[0][1]
        input_ids = input_ids[
            :, :stop_token_idx + original_input_length + 1
        ]
    gen_tokens = input_ids[0, original_input_length:].tolist()
    cleaned = [t for t in gen_tokens if t != null_id and t != mask_id]
    output_ids = torch.cat(
        [
            input_ids[:, :original_input_length],
            torch.tensor(
                [cleaned], device=self.device, dtype=torch.long
            ),
        ],
        dim=1,
    )
    gen_length = output_ids.shape[1] - original_input_length
    if return_stats:
        stats = {
            "tokens_per_step": tokens_per_step,
            "total_steps": step,
            "gen_length": gen_length,
            "null_tokens_stripped": len(gen_tokens) - len(cleaned),
            "block_size": block_size,
            "method": "scaffold_speculative_v5",
        }
        return output_ids, stats
    return output_ids
@torch.no_grad()
# ---------------------------------------------------------------------------
# scaffold_spec_with_ss_multi_traj — SS multi-rollout inference scaling
# ---------------------------------------------------------------------------
def scaffold_spec_with_ss_multi_traj(
self,
input_ids,
tokenizer,
block_size=32,
max_tokens=1024,
pixel_values=None,
image_grid_thw=None,
mask_id=151665,
null_id=151666,
threshold=0.9,
stop_token=151645,
explanation_block_size=32,
explanation_max_blocks=6,
return_stats=False,
num_traj_rollouts=4,
traj_verify_temperature=0.5,
traj_draft_temperature=0.0,
merge_weights=None,
batch_parallel=False,
):
"""Scaffold Spec with shared prefix + N SS rollouts on the trajectory section.
Decoding pipeline:
0) Prompt prefill [shared]
1) Scaffold Spec for sections 1-3 (CoT) at verify_temp = 0 [shared, deterministic]
2) Fork KV cache N times [O(N) memory]
3) For each fork: continue Scaffold Spec on the trajectory
section with verify_temperature = traj_verify_temperature
(each rollout draws different samples in the AR-verify step
because torch.multinomial is invoked with a global RNG).
4) Parse all N trajectories and return their weighted mean.
Cost: roughly 1 full SS pass (sections 1-3 are ~88%% of decoded tokens
on our schema) + N x trajectory-only SS passes. For N = 4 this is
~1.5x the cost of a single SS, vs ~4x for naive sequential rerolling.
If batch_parallel = True, the N trajectory rollouts are executed in a
batched (batch_size = N) manner: one shared model.forward per
speculative draft / verify step over an N-replicated trajectory
suffix, which removes the per-rollout serial overhead at the cost of
replicating the per-layer KV cache N-fold along the batch dimension.
Returns: (output_ids, stats) if return_stats else output_ids.
"""
from .section_utils import (
build_deep_json_scaffold,
SECTION_KEYS,
)
scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold(
tokenizer,
mask_id=mask_id,
null_id=null_id,
explanation_block_size=explanation_block_size,
explanation_max_blocks=explanation_max_blocks,
)
scaffold_len = len(scaffold_tokens)
original_input_length = input_ids.shape[1]
tokens_per_step = []
self.model.bd_size = block_size
scaffold_tok_t = torch.tensor(scaffold_tokens, device=self.device, dtype=torch.long)
scaffold_is_fixed = torch.tensor(scaffold_mask_list, device=self.device, dtype=torch.bool)
traj_start_in_scaffold = section_ranges["trajectory"][0]
_profile = bool(os.environ.get("SS_MT_PROFILE"))
if _profile:
import time as _time
torch.cuda.synchronize()
_t_phase = {"start": _time.perf_counter()}
_phase_clone_total = 0.0
_phase_rollout_each = []
# ── Phase 0: Prefill prompt ──
output = self.forward(
input_ids=input_ids, pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
use_cache=True, update_kv_cache=True,
)
logits, past_key_values = output.logits, output.past_key_values
if _profile:
torch.cuda.synchronize()
_t_phase["after_prefill"] = _time.perf_counter()
next_token = torch.tensor(
[[scaffold_tokens[0]]], device=self.device, dtype=torch.long,
)
input_ids = torch.cat([input_ids, next_token], dim=1)
tokens_per_step.append(1)
scaffold_cursor = 1
step = 1
# ── Phase 1: Scaffold Spec for non-trajectory sections (shared, vt=0) ──
while scaffold_cursor < scaffold_len and scaffold_cursor < traj_start_in_scaffold:
remaining_before_traj = traj_start_in_scaffold - scaffold_cursor
n_draft = min(block_size - 1, remaining_before_traj)
if n_draft <= 0:
break
sc_end = scaffold_cursor + n_draft
is_fixed = scaffold_is_fixed[scaffold_cursor:sc_end]
draft_tensor = torch.where(
is_fixed, scaffold_tok_t[scaffold_cursor:sc_end], mask_id,
).unsqueeze(0)
x_t = torch.cat([input_ids[:, -1:], draft_tensor], dim=1)
mask_idx = (x_t == mask_id)
# Draft (block-bidirectional)
logits = self.forward(
input_ids=x_t, use_cache=True,
past_key_values=past_key_values,
update_kv_cache=False, eval_bd_size=block_size,
).logits
tokens_per_step.append(0)
step += 1
logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
x_1 = logits.argmax(dim=-1)
probs = torch.softmax(logits, dim=-1)
x1_p = torch.gather(probs, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1)
x1_p = torch.where(mask_idx, x1_p, -torch.inf)
unmask_idx = (x1_p > 0)
if unmask_idx.sum() > 0:
x_t[unmask_idx] = x_1[unmask_idx]
else:
mask_only_p = x1_p.clone()
mask_only_p[~mask_idx] = -torch.inf
if mask_only_p.max() > -torch.inf:
best = mask_only_p.argmax()
x_t.view(-1)[best] = x_1.view(-1)[best]
# Verify (causal, greedy)
output = self.forward(
input_ids=x_t, use_cache=True,
past_key_values=past_key_values,
update_kv_cache=True, eval_bd_size=block_size,
)
past_key_values = output.past_key_values
ar_block_token = output.logits.argmax(dim=-1)
ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1])
accepted_token_num = 0
for i in range(n_draft):
if is_fixed[i] or ar_matches[i]:
accepted_token_num += 1
else:
break
accepted_token_num += 1
max_accept = traj_start_in_scaffold - scaffold_cursor
if accepted_token_num > max_accept:
accepted_token_num = max_accept
tokens_per_step.append(accepted_token_num)
accepted_ids = ar_block_token[:, :accepted_token_num].clone()
acc_end = min(scaffold_cursor + accepted_token_num, scaffold_len)
acc_fixed = scaffold_is_fixed[scaffold_cursor:acc_end]
accepted_ids[0, :len(acc_fixed)][acc_fixed] = \
scaffold_tok_t[scaffold_cursor:acc_end][acc_fixed]
input_ids = torch.cat([input_ids, accepted_ids], dim=1)
scaffold_cursor += accepted_token_num
past_key_values = _crop_cache(past_key_values, input_ids.shape[1] - 1)
step += 1
if input_ids.shape[1] - original_input_length > max_tokens:
break
if _profile:
torch.cuda.synchronize()
_t_phase["after_phase1"] = _time.perf_counter()
# ── Phase 2: Fork KV cache N times (one per trajectory rollout) ──
prefix_input_ids = input_ids.clone()
prefix_len = prefix_input_ids.shape[1]
def _clone_cache(kv):
if _profile:
torch.cuda.synchronize()
_t0 = _time.perf_counter()
cloned = []
for layer_num in range(len(kv)):
cloned.append(tuple(t.clone() for t in kv[layer_num]))
ret = DynamicCache(cloned)
if _profile:
torch.cuda.synchronize()
nonlocal _phase_clone_total
_phase_clone_total += _time.perf_counter() - _t0
return ret
# ── Phase 3: N SS rollouts on trajectory section, each with vt > 0 ──
# All rollouts start from the same prefix; randomness comes from
# the multinomial calls in draft / verify (RNG is process-global).
N = max(1, int(num_traj_rollouts))
def _run_one_traj_rollout(start_kv, start_input_ids):
"""Continue Scaffold Spec from start_kv / start_input_ids over the
trajectory section, applying traj_*_temperature. Returns the
final ss_input_ids (with trajectory tokens appended) and the
extracted trajectory value tokens."""
local_kv = start_kv
local_input = start_input_ids
local_cursor = scaffold_cursor
while local_cursor < scaffold_len:
n_draft = min(block_size - 1, scaffold_len - local_cursor)
sc_end = local_cursor + n_draft
is_fixed = scaffold_is_fixed[local_cursor:sc_end]
draft_tensor = torch.where(
is_fixed, scaffold_tok_t[local_cursor:sc_end], mask_id,
).unsqueeze(0)
x_t = torch.cat([local_input[:, -1:], draft_tensor], dim=1)
mask_idx = (x_t == mask_id)
# Draft (block-bidirectional, optionally temp-sampled)
draft_logits = self.forward(
input_ids=x_t, use_cache=True, past_key_values=local_kv,
update_kv_cache=False, eval_bd_size=block_size,
).logits
draft_logits = torch.cat(
[draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1,
)
if traj_draft_temperature > 0:
scaled = draft_logits / traj_draft_temperature
draft_probs = torch.softmax(scaled, dim=-1)
x_1 = torch.multinomial(
draft_probs.view(-1, draft_probs.shape[-1]),
num_samples=1,
).view(draft_probs.shape[:-1])
else:
x_1 = draft_logits.argmax(dim=-1)
probs = torch.softmax(draft_logits, dim=-1)
x1_p = torch.gather(
probs, dim=-1, index=x_1.unsqueeze(-1),
).squeeze(-1)
x1_p = torch.where(mask_idx, x1_p, -torch.inf)
unmask_idx = (x1_p > 0)
if unmask_idx.sum() > 0:
x_t[unmask_idx] = x_1[unmask_idx]
else:
mask_only_p = x1_p.clone()
mask_only_p[~mask_idx] = -torch.inf
if mask_only_p.max() > -torch.inf:
x_t.view(-1)[mask_only_p.argmax()] = \
x_1.view(-1)[mask_only_p.argmax()]
# Verify (causal, optionally temp-sampled)
v_out = self.forward(
input_ids=x_t, use_cache=True, past_key_values=local_kv,
update_kv_cache=True, eval_bd_size=block_size,
)
local_kv = v_out.past_key_values
if traj_verify_temperature > 0:
v_logits = v_out.logits / traj_verify_temperature
v_probs = torch.softmax(v_logits, dim=-1)
ar_block_token = torch.multinomial(
v_probs.view(-1, v_probs.shape[-1]),
num_samples=1,
).view(v_probs.shape[:-1])
else:
ar_block_token = v_out.logits.argmax(dim=-1)
ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1])
accepted_token_num = 0
for i in range(n_draft):
if is_fixed[i] or ar_matches[i]:
accepted_token_num += 1
else:
break
accepted_token_num += 1
accepted_ids = ar_block_token[:, :accepted_token_num].clone()
acc_end = min(local_cursor + accepted_token_num, scaffold_len)
acc_fixed = scaffold_is_fixed[local_cursor:acc_end]
accepted_ids[0, :len(acc_fixed)][acc_fixed] = \
scaffold_tok_t[local_cursor:acc_end][acc_fixed]
local_input = torch.cat([local_input, accepted_ids], dim=1)
local_cursor += accepted_token_num
local_kv = _crop_cache(local_kv, local_input.shape[1] - 1)
if local_input.shape[1] - original_input_length > max_tokens:
break
if stop_token in local_input[:, prefix_len:]:
st_idx = (local_input[:, prefix_len:] == stop_token).nonzero()
if st_idx.numel() > 0:
cand_st = st_idx[0][1].item()
if (local_input[:, prefix_len:prefix_len + cand_st] == mask_id).sum() == 0:
break
traj_values = [
t for i, t in enumerate(local_input[0, original_input_length:].tolist())
if i >= traj_start_in_scaffold and i < scaffold_len
and not scaffold_mask_list[i] and t != null_id and t != mask_id
]
return local_input, traj_values
# Sequential N rollouts (Option A; batch_parallel=False).
rollout_inputs = []
rollout_traj_values = []
for _i in range(N):
if _profile:
torch.cuda.synchronize()
_t_r0 = _time.perf_counter()
cand_kv = _clone_cache(past_key_values)
cand_input = prefix_input_ids.clone()
cand_input, traj_vals = _run_one_traj_rollout(cand_kv, cand_input)
rollout_inputs.append(cand_input)
rollout_traj_values.append(traj_vals)
step += 1
if _profile:
torch.cuda.synchronize()
_phase_rollout_each.append(_time.perf_counter() - _t_r0)
if _profile:
torch.cuda.synchronize()
_t_phase["after_rollouts"] = _time.perf_counter()
_t_total = _t_phase["after_rollouts"] - _t_phase["start"]
_t_pre = _t_phase["after_prefill"] - _t_phase["start"]
_t_p1 = _t_phase["after_phase1"] - _t_phase["after_prefill"]
_t_rolls = _t_phase["after_rollouts"] - _t_phase["after_phase1"]
print(
f"[ss_mt profile] total={_t_total*1000:.0f}ms "
f"prefill(P0)={_t_pre*1000:.0f}ms "
f"prefix-decode(P1)={_t_p1*1000:.0f}ms "
f"rollouts(P2+P3)={_t_rolls*1000:.0f}ms "
f"of which kv-clone={_phase_clone_total*1000:.0f}ms "
f"per-rollout={[f'{r*1000:.0f}' for r in _phase_rollout_each]}ms",
flush=True,
)
# ── Phase 4: Parse all rollouts, weighted-merge waypoints ──
def _decode_trajectory(traj_tokens):
text = tokenizer.decode(traj_tokens, skip_special_tokens=False)
text = text.replace("<|NULL|>", "").strip()
coords = re.findall(r"[+-]?\d+\.?\d*", text)
wps = []
for i in range(0, len(coords) - 1, 2):
wps.append([float(coords[i]), float(coords[i + 1])])
return wps
rollout_waypoints = [_decode_trajectory(v) for v in rollout_traj_values]
if merge_weights is None or len(merge_weights) != N:
ws = [1.0 / N] * N
else:
total = sum(merge_weights)
ws = [w / total for w in merge_weights]
if rollout_waypoints and all(len(w) > 0 for w in rollout_waypoints):
n_wp = min(len(w) for w in rollout_waypoints)
merged_waypoints = []
for i in range(n_wp):
mx = sum(ws[c] * rollout_waypoints[c][i][0] for c in range(N))
my = sum(ws[c] * rollout_waypoints[c][i][1] for c in range(N))
merged_waypoints.append([mx, my])
else:
merged_waypoints = next(
(w for w in rollout_waypoints if w), [],
)
# Output text: take rollout 0's full text but replace its trajectory
# with the merged waypoints.
base_input = rollout_inputs[0]
if stop_token in base_input[:, original_input_length:]:
st_idx = (base_input[:, original_input_length:] == stop_token).nonzero()[0][1]
base_input = base_input[:, :st_idx + original_input_length + 1]
base_raw_tokens = base_input[0, original_input_length:].tolist()
base_cleaned = [t for t in base_raw_tokens if t != null_id and t != mask_id]
base_null_stripped = len(base_raw_tokens) - len(base_cleaned)
base_text = tokenizer.decode(base_cleaned, skip_special_tokens=False)
traj_parts = [
f"[{x:+07.2f},{y:+06.2f}]" for x, y in merged_waypoints
]
merged_traj_str = "[" + ", ".join(traj_parts) + "]"
replaced_text = re.sub(
r'("trajectory"\s*:\s*")(\[\[.*?\]\])',
r"\g<1>" + merged_traj_str, base_text,
)
merged_tokens = tokenizer.encode(replaced_text, add_special_tokens=False)
output_ids = torch.cat([
input_ids[:, :original_input_length],
torch.tensor([merged_tokens], device=self.device, dtype=torch.long),
], dim=1)
gen_length = output_ids.shape[1] - original_input_length
if return_stats:
stats = {
"tokens_per_step": tokens_per_step,
"total_steps": step,
"gen_length": gen_length,
"null_tokens_stripped": base_null_stripped,
"block_size": block_size,
"method": "scaffold_spec_with_ss_multi_traj",
"num_traj_rollouts": N,
"traj_verify_temperature": traj_verify_temperature,
"rollout_waypoints": rollout_waypoints,
"merged_waypoints": merged_waypoints,
"merge_weights": ws,
}
return output_ids, stats
return output_ids
# ---------------------------------------------------------------------------
# Bind decoding methods onto the model class.
#
# ``modeling.py`` imports this module at the bottom of the file, after the
# ``Fast_dDriveForConditionalGeneration`` class has been defined. We
# attach the three decoding paths as ordinary methods so callers can invoke
# them as ``model.mdm_sample_deep_scaffold(...)`` etc. without any extra
# registration step.
# ---------------------------------------------------------------------------
def attach_generation_methods(cls):
    """Attach the three release decoding paths as methods of ``cls``.

    Each module-level decoding function is assigned as a plain class
    attribute. Python therefore binds it as a regular method, and the model
    instance is passed as ``self`` when it is called.

    Args:
        cls: The class to patch (presumably
            ``Fast_dDriveForConditionalGeneration``; confirm against the
            import hook in ``modeling.py``).

    Returns:
        The same ``cls`` object, mutated in place. Because it takes and
        returns the class, the function can also be used as a class
        decorator.
    """
    # NOTE: deliberately *not* wrapped in ``torch.no_grad()``. This helper
    # only sets class attributes and performs no tensor computation, so a
    # grad-disabling context here would have no effect.
    cls.mdm_sample_deep_scaffold = mdm_sample_deep_scaffold
    cls.scaffold_speculative_sample = scaffold_speculative_sample
    cls.scaffold_spec_with_ss_multi_traj = scaffold_spec_with_ss_multi_traj
    return cls
# Public names exported by a star-import of this module: the three decoding
# entry points, followed by the helper that binds them onto the model class.
__all__ = [
    "mdm_sample_deep_scaffold", "scaffold_speculative_sample",
    "scaffold_spec_with_ss_multi_traj", "attach_generation_methods",
]