Image-Text-to-Text
Transformers
Safetensors
fast_d_drive
feature-extraction
block-diffusion
vision-language-action
autonomous-driving
qwen2.5-vl
conversational
custom_code
Instructions to use xiwenyoumu/Fast-dDrive with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use xiwenyoumu/Fast-dDrive with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="xiwenyoumu/Fast-dDrive", trust_remote_code=True)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
pipe(text=messages)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("xiwenyoumu/Fast-dDrive", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use xiwenyoumu/Fast-dDrive with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "xiwenyoumu/Fast-dDrive"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "xiwenyoumu/Fast-dDrive",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in one sentence."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        }
                    }
                ]
            }
        ]
    }'
Use Docker
docker model run hf.co/xiwenyoumu/Fast-dDrive
- SGLang
How to use xiwenyoumu/Fast-dDrive with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "xiwenyoumu/Fast-dDrive" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "xiwenyoumu/Fast-dDrive",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in one sentence."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        }
                    }
                ]
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "xiwenyoumu/Fast-dDrive" \
        --host 0.0.0.0 \
        --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "xiwenyoumu/Fast-dDrive",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in one sentence."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
                        }
                    }
                ]
            }
        ]
    }'
- Docker Model Runner
How to use xiwenyoumu/Fast-dDrive with Docker Model Runner:
docker model run hf.co/xiwenyoumu/Fast-dDrive
| """Generation utilities for Fast-dDrive. | |
| This module provides the three inference paths exposed by the canonical paper | |
| release: | |
| * ``mdm_sample_deep_scaffold`` — Section Diffusion (SD): iterative MDM | |
| denoising over a pre-filled JSON scaffold, no AR verification. | |
| * ``scaffold_speculative_sample`` — Scaffold Spec (SS): scaffold-aware | |
| self-speculative decoding (MDM draft + AR verify per block). | |
| * ``scaffold_spec_with_ss_multi_traj`` — SS with shared-prefix multi-trajectory | |
| rollouts (the test-time inference-scaling path). | |
| All three are attached as bound methods on | |
| :class:`Fast_dDriveForConditionalGeneration` when this module is | |
| imported (see ``modeling.py`` for the import hook). | |
| """ | |
| import os | |
| import re | |
| import sys | |
| import math | |
| import torch | |
| import types | |
| import numpy as np | |
| from transformers.cache_utils import DynamicCache | |
def _crop_cache(past_key_values, max_length: int):
    """Return a new ``DynamicCache`` keeping only the first ``max_length`` positions.

    Each tensor of each layer in ``past_key_values`` is sliced as
    ``[:, :, :max_length, :]`` (dim 2 is the token axis in the Qwen cache
    layout this helper targets). The input cache's tensors are sliced, not
    mutated; the per-layer tuples are handed to ``DynamicCache(...)``.
    """
    cropped_layers = [
        tuple(
            past_key_values[layer][slot][:, :, :max_length, :]
            for slot in range(len(past_key_values[layer]))
        )
        for layer in range(len(past_key_values))
    ]
    return DynamicCache(cropped_layers)
| def _sample_from_logits(logits, temperature=0.0): | |
| """Sample token ids from logits with optional temperature scaling. | |
| When temperature <= 0, falls back to argmax (greedy). | |
| """ | |
| if temperature <= 0: | |
| return logits.argmax(dim=-1) | |
| scaled = logits / temperature | |
| probs = torch.softmax(scaled, dim=-1) | |
| original_shape = probs.shape[:-1] | |
| flat_probs = probs.reshape(-1, probs.shape[-1]) | |
| sampled = torch.multinomial(flat_probs, num_samples=1).squeeze(-1) | |
| return sampled.reshape(original_shape) | |
| # --------------------------------------------------------------------------- | |
| # mdm_sample_deep_scaffold — Section Diffusion (SD) | |
| # --------------------------------------------------------------------------- | |
def mdm_sample_deep_scaffold(
    self,
    input_ids,
    tokenizer,
    max_tokens=512,
    pixel_values=None,
    image_grid_thw=None,
    mask_id=151665,
    null_id=151666,
    threshold=0.9,
    stop_token=151645,
    explanation_block_size=32,
    explanation_max_blocks=6,
    block_size=32,
    return_stats=False,
    use_kv_cache=True,
    temperature=0.0,
):
    """
    Section Diffusion (SD): deep-scaffold MDM generation with a
    train-consistent hybrid block-causal mask.

    Pre-fills the entire JSON scaffold (including sub-keys for critical_objects,
    future_meta_behavior, trajectory) with MASK tokens at value positions only.
    Then denoises the value tokens block by block (in order) via iterative
    unmasking.

    The attention mask matches training: prompt tokens use causal attention,
    response tokens use block-causal attention where each section's denoise
    steps form separate blocks. Block i can see all prompt tokens and blocks
    0..i, but NOT blocks i+1..N (which still contain MASK tokens).

    For explanation (variable length), NULL tokens in the output signal that
    the section content is complete. In post-processing EVERY NULL token, and
    any MASK token left over after the iteration limits, is removed from the
    generated region (not only trailing ones).

    KV-cache path (``use_kv_cache=True``, default):
        Prompt K/V is computed once with vision embedding scatter, then each
        response block, once fully denoised, gets its K/V appended to the
        cache. Subsequent blocks' iterative unmasking only forwards their
        own block tokens against the cache (prompt plus prior committed
        blocks), avoiding O(seqlen^2) recomputation of the prompt + prior
        blocks every iteration. This relies on block-causal attention:
        block k only attends to prompt + blocks 0..k, which is exactly what
        the cache provides.

    Args:
        input_ids: Prompt token ids. Batch size must be 1: the scaffold tensor
            is built with a batch dim of 1 and concatenated along dim 1, and
            the denoising loop indexes row 0 only.
        tokenizer: Passed to ``build_deep_json_scaffold`` to tokenize the
            JSON scaffold.
        max_tokens: Budget on the global denoising-step counter. It is only
            checked inside a block's unmasking loop, so exceeding it ends
            that block's loop; every later block still runs one iteration
            before breaking again (leftover MASKs are stripped at the end).
        pixel_values, image_grid_thw: Optional vision inputs. Image features
            are computed once and scattered into the token embeddings of
            every forward; ``image_grid_thw`` also feeds ``get_rope_index``.
        mask_id: Token id placed at value positions that must be denoised.
        null_id: Token id removed from the generated region at the end.
        threshold: Masked positions whose predicted-token probability is
            strictly greater than this are unmasked together; if none
            qualifies, only the single most confident masked position is.
        stop_token: Accepted for interface parity; not used by this function.
        explanation_block_size, explanation_max_blocks: Forwarded to
            ``build_deep_json_scaffold``.
        block_size: Value tokens per denoising block inside a section; a
            section with n value tokens gets ``max(1, ceil(n / block_size))``
            blocks.
        return_stats: If True, also return a stats dict.
        use_kv_cache: Select the KV-cache path. Overridden by the environment
            variable ``MDM_DS_USE_KV_CACHE`` when it is set (the values
            ``"0"``, ``"false"``, ``"False"`` and ``""`` disable the cache).
        temperature: If > 0, candidate tokens are sampled from
            ``softmax(logits / temperature)``; otherwise greedy argmax.
            Confidence used for thresholding always comes from the
            untempered softmax.

    Returns:
        ``x_t``: prompt ids followed by the generated tokens with NULL and
        MASK ids removed, shape ``[1, prompt_len + n_kept]``. If
        ``return_stats`` is True, returns ``(x_t, stats)`` where ``stats``
        has keys ``tokens_per_step``, ``total_steps``, ``gen_length``,
        ``null_tokens_stripped`` (count of removed NULL *and* MASK ids) and
        ``block_size``.
    """
    import math
    import os as _os
    from .section_utils import (
        build_deep_json_scaffold,
        SECTION_KEYS,
        NULL_TOKEN_ID,
    )
    # Env override for A/B testing the KV cache path without editing code.
    _kv_env = _os.environ.get("MDM_DS_USE_KV_CACHE")
    if _kv_env is not None:
        use_kv_cache = _kv_env not in ("0", "false", "False", "")
    # scaffold_mask_list[i] is truthy for frozen scaffold (structural JSON)
    # tokens and falsy for value slots that hold mask_id and get denoised.
    scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold(
        tokenizer,
        mask_id=mask_id,
        null_id=null_id,
        explanation_block_size=explanation_block_size,
        explanation_max_blocks=explanation_max_blocks,
    )
    tokens_per_step = []
    original_input_length = input_ids.shape[1]
    # Phase 1: Build sequence with scaffold appended
    scaffold_tensor = torch.tensor(scaffold_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
    x_t = torch.cat([input_ids, scaffold_tensor], dim=1)
    seqlen = x_t.shape[1]
    # Track scaffold (frozen) vs value (to denoise) positions in scaffold region
    # NOTE(review): scaffold_frozen is not referenced later in this function.
    scaffold_frozen = torch.tensor(scaffold_mask_list, device=self.device, dtype=torch.bool)
    # ── Build response_block_idx matching training's compute_section_block_idx_deep_static ──
    # response_block_idx[p] is the response block of absolute position p;
    # -1 marks prompt positions (and any response position left unassigned).
    response_block_idx = torch.full((seqlen,), -1, device=self.device, dtype=torch.long)
    current_block = 0
    assigned = set()
    for section_name in SECTION_KEYS:
        if section_name not in section_ranges:
            continue
        sec_start, sec_end = section_ranges[section_name]
        # Find value positions (non-scaffold) in this section
        value_positions = []
        for i in range(sec_start, sec_end):
            if not scaffold_mask_list[i]:  # 0 = value token
                value_positions.append(original_input_length + i)
        if not value_positions:
            # Sections without value slots still consume one block index so
            # numbering stays aligned with training.
            current_block += 1
            continue
        # Block assignment MUST match training's
        # compute_section_block_idx_deep_static: n_blocks = ceil(value/block_size)
        # for every section. Previously non-explanation sections were forced to
        # a single block; that broke attention alignment for trajectory
        # (70 value tokens → training 3 blocks vs inference 1 block), causing
        # trajectory over-extrapolation. CO (12) and FMB (6) still resolve
        # to 1 block since their value counts are < block_size.
        tokens_per_step_sec = block_size
        n_steps = max(1, math.ceil(len(value_positions) / tokens_per_step_sec))
        # Assign block indices to value tokens
        for vi, abs_pos in enumerate(value_positions):
            block_in_section = min(vi // tokens_per_step_sec, n_steps - 1)
            response_block_idx[abs_pos] = current_block + block_in_section
            assigned.add(abs_pos)
        # Assign scaffold tokens to nearest value token's block
        # (the search tries the position after before the position before at
        # each distance, so ties go to the following token's block).
        for i in range(sec_start, sec_end):
            abs_pos = original_input_length + i
            if scaffold_mask_list[i] and abs_pos not in assigned:
                best_block = -1
                for delta in range(1, sec_end - sec_start + 10):
                    for cand in [abs_pos + delta, abs_pos - delta]:
                        if cand in assigned:
                            best_block = response_block_idx[cand].item()
                            break
                    if best_block >= 0:
                        break
                if best_block >= 0:
                    response_block_idx[abs_pos] = best_block
                    assigned.add(abs_pos)
        current_block += n_steps
    # Assign any remaining unassigned scaffold tokens (e.g. top-level separators)
    # using the same nearest-assigned-neighbour rule over the whole sequence.
    for i in range(len(scaffold_tokens)):
        abs_pos = original_input_length + i
        if abs_pos not in assigned:
            # Find nearest assigned position
            best_block = -1
            for delta in range(1, seqlen):
                for cand in [abs_pos + delta, abs_pos - delta]:
                    if 0 <= cand < seqlen and cand in assigned:
                        best_block = response_block_idx[cand].item()
                        break
                if best_block >= 0:
                    break
            if best_block >= 0:
                response_block_idx[abs_pos] = best_block
                assigned.add(abs_pos)
    # ── Build hybrid block causal mask (computed once, reused for all forward passes) ──
    # Only the no-cache path passes this mask; the cache path builds per-block masks.
    attention_mask = self.model.eval_hybrid_mask(seqlen, response_block_idx).to(self.device)
    # Section-MoE-LoRA: set section_ids before language model forward
    # In this release set_section_ids is a no-op, so _sec_ids has no effect.
    set_section_ids = lambda *a, **kw: None  # noqa: E731 (Section-MoE-LoRA disabled in release)
    # Map block indices to section IDs (0=CO, 1=Exp, 2=FMB, 3=Traj, 4=Other/Prompt)
    _sec_ids = torch.full((seqlen,), 4, device=self.device, dtype=torch.long)
    for section_name, (sec_start, sec_end) in section_ranges.items():
        abs_start = original_input_length + sec_start
        abs_end = original_input_length + sec_end
        if section_name == "critical_objects":
            _sec_ids[abs_start:abs_end] = 0
        elif section_name == "explanation":
            _sec_ids[abs_start:abs_end] = 1
        elif section_name == "future_meta_behavior":
            _sec_ids[abs_start:abs_end] = 2
        elif section_name == "trajectory":
            _sec_ids[abs_start:abs_end] = 3
    # Add batch dimension
    _sec_ids_batch = _sec_ids.unsqueeze(0)
    set_section_ids(_sec_ids_batch)
    # ── Precompute vision embeddings and position_ids once ──
    # Image features are scattered into the embeddings of EVERY forward
    # below. (Historically pixel_values was only passed on the first forward,
    # which silently dropped vision information on all later forwards.)
    _embed_fn = self.model.get_input_embeddings()
    _cached_image_embeds = None
    _cached_image_mask = None
    if pixel_values is not None:
        _cached_image_embeds = self.model.get_image_features(pixel_values, image_grid_thw)
        _cached_image_embeds = torch.cat(_cached_image_embeds, dim=0).to(
            self.device, _embed_fn.weight.dtype
        )
        # Placeholder mask over x_t marking where image features are scattered.
        _tmp_embeds = _embed_fn(x_t)
        _cached_image_mask, _ = self.model.get_placeholder_mask(
            x_t, inputs_embeds=_tmp_embeds, image_features=_cached_image_embeds
        )
    # Compute position_ids once with correct image_grid_thw (3D RoPE)
    # Positions are taken over the full prompt+scaffold sequence, so per-block
    # slices in the cache path keep their absolute positions.
    _position_ids, _rope_deltas = self.model.get_rope_index(
        x_t, image_grid_thw, None
    )
    self.model.rope_deltas = _rope_deltas
    # ── Compute contiguous block ranges in the response region ──
    # Each block's absolute [start, end) range in x_t is the maximal
    # contiguous span of positions sharing the same response_block_idx.
    # Spans are emitted in position order; positions with index -1 (the
    # prompt) close any open span and are skipped.
    _block_ranges = []  # list of (block_idx, abs_start, abs_end)
    _cur_bi = None
    _cur_start = None
    for _p in range(seqlen):
        _bi = int(response_block_idx[_p].item())
        if _bi < 0:
            if _cur_bi is not None:
                _block_ranges.append((_cur_bi, _cur_start, _p))
                _cur_bi, _cur_start = None, None
            continue
        if _cur_bi is None:
            _cur_bi, _cur_start = _bi, _p
        elif _bi != _cur_bi:
            _block_ranges.append((_cur_bi, _cur_start, _p))
            _cur_bi, _cur_start = _bi, _p
    if _cur_bi is not None:
        _block_ranges.append((_cur_bi, _cur_start, seqlen))
    # Map block_idx -> section_name for downstream logic (section-specific
    # behaviors like explanation NULL handling can still be scoped).
    # The first overlapping section in section_ranges order wins (setdefault).
    _block_idx_to_section = {}
    for _sname, (_sstart, _send) in section_ranges.items():
        _sabs_start = original_input_length + _sstart
        _sabs_end = original_input_length + _send
        for _bi, _bs, _be in _block_ranges:
            # Assign section by whether the block's range overlaps the section
            if _bs < _sabs_end and _be > _sabs_start:
                _block_idx_to_section.setdefault(_bi, _sname)
    # ── Phase 2: Denoise block-by-block with optional KV cache ──
    # Without cache (fallback): each forward replays the entire sequence.
    # With cache: prompt K/V computed once; each block's finalized K/V is
    # appended after denoising, so later blocks only forward their own
    # block tokens against the cache.
    step = 0
    past_kv = None
    prev_last_logit = None  # logit at the position just before the next block
    if use_kv_cache:
        # Phase 0: prompt prefill. Includes vision scatter; cache becomes
        # the reusable foundation for every scaffold block.
        prompt_tokens = x_t[:, :original_input_length]
        prompt_embeds = _embed_fn(prompt_tokens)
        if _cached_image_embeds is not None:
            prompt_image_mask = _cached_image_mask[:, :original_input_length]
            prompt_embeds = prompt_embeds.masked_scatter(
                prompt_image_mask, _cached_image_embeds
            )
        prompt_position_ids = _position_ids[..., :original_input_length]
        # Causal over prompt (matches training's prompt-side attention).
        # When attention_mask=None, the model's eval_mask auto-builds causal
        # because use_block_causal_mask=True and update_kv_cache=True.
        prompt_out = self.forward(
            inputs_embeds=prompt_embeds,
            position_ids=prompt_position_ids,
            attention_mask=None,
            past_key_values=None,
            use_cache=True,
            update_kv_cache=True,
        )
        past_kv = prompt_out.past_key_values
        # Logit at position (original_input_length - 1); used to predict
        # the first token of the first response block via causal shift.
        prev_last_logit = prompt_out.logits[:, -1:, :]
    # ── Iterate blocks in order ──
    for _block_idx, block_abs_start, block_abs_end in _block_ranges:
        B = block_abs_end - block_abs_start
        # NOTE(review): section_name is looked up but not used below.
        section_name = _block_idx_to_section.get(_block_idx, None)
        # Count MASK tokens in this block
        block_slice = x_t[0, block_abs_start:block_abs_end]
        n_masks_in_block = int((block_slice == mask_id).sum().item())
        # ── Iterative unmasking within this block (if any MASKs) ──
        if n_masks_in_block > 0:
            # Each iteration writes at least one masked position; if the
            # written token is itself mask_id the position stays masked, so
            # the loop is bounded by max_iter rather than the mask count.
            max_iter = n_masks_in_block + 5  # safety limit
            for _ in range(max_iter):
                current_block_masks = (x_t[:, block_abs_start:block_abs_end] == mask_id)
                if current_block_masks.sum() == 0:
                    break
                if use_kv_cache:
                    # Feed only this block; past_kv covers prompt + prior blocks.
                    block_tokens = x_t[:, block_abs_start:block_abs_end]
                    block_embeds = _embed_fn(block_tokens)
                    block_position_ids = _position_ids[..., block_abs_start:block_abs_end]
                    L_cached = past_kv.get_seq_length() if past_kv is not None else 0
                    # Block-causal + bidirectional-within-block ⇒ this
                    # block's queries attend to all cached KV plus all
                    # fresh block KV ⇒ all-True mask of shape [B, L+B].
                    block_attn = torch.ones(
                        B, L_cached + B, device=self.device, dtype=torch.bool
                    )
                    output = self.forward(
                        inputs_embeds=block_embeds,
                        attention_mask=block_attn,
                        position_ids=block_position_ids,
                        past_key_values=past_kv,
                        use_cache=True,
                        update_kv_cache=False,  # read-only during iteration
                    )
                    logits = output.logits  # [1, B, V]
                    # Shift: pred for abs_pos uses logit at abs_pos-1.
                    # logit at block_abs_start-1 is prev_last_logit; the
                    # rest come from this forward's earlier positions.
                    sec_logits = torch.cat([prev_last_logit, logits[:, :-1, :]], dim=1)
                else:
                    # Full-sequence forward (fallback path, same as before)
                    _cur_embeds = _embed_fn(x_t)
                    if _cached_image_embeds is not None:
                        _cur_embeds = _cur_embeds.masked_scatter(
                            _cached_image_mask, _cached_image_embeds
                        )
                    # NOTE(review): both input_ids and inputs_embeds are
                    # passed; presumably the custom forward uses the
                    # embeddings (which carry the image scatter) — confirm.
                    output = self.forward(
                        input_ids=x_t,
                        inputs_embeds=_cur_embeds,
                        attention_mask=attention_mask,
                        position_ids=_position_ids,
                        use_cache=False,
                    )
                    logits = output.logits
                    # Same causal shift as the cache path: the prediction for
                    # position p comes from the logit at p-1.
                    sec_logits = logits[:, block_abs_start:block_abs_end, :]
                    sec_logits = torch.cat(
                        [logits[:, block_abs_start - 1:block_abs_start, :],
                         sec_logits[:, :-1, :]], dim=1
                    )
                if temperature > 0:
                    # Temperature sampling for diverse generation (e.g. GRPO rollouts)
                    sampling_probs = torch.softmax(sec_logits / temperature, dim=-1)
                    x_1 = torch.multinomial(
                        sampling_probs.view(-1, sampling_probs.shape[-1]), num_samples=1
                    ).view(sampling_probs.shape[:-1])
                else:
                    # Greedy (default, backward compatible)
                    x_1 = sec_logits.argmax(dim=-1)
                # Confidence of the chosen token under the untempered softmax.
                probs = torch.softmax(sec_logits, dim=-1)
                x1_p = torch.gather(probs, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1)
                # Only consider currently-masked positions in this block
                x1_p = torch.where(current_block_masks, x1_p, -torch.inf)
                unmask_idx = (x1_p > threshold)
                if unmask_idx.sum() > 0:
                    # Writing through the slice view updates x_t in place.
                    x_t[:, block_abs_start:block_abs_end][unmask_idx] = x_1[unmask_idx]
                    tokens_per_step.append(int(unmask_idx.sum()))
                else:
                    # Fallback: unmask highest-confidence token. Non-masked
                    # positions hold -inf, so the argmax is a masked position;
                    # with batch size 1 the flat index is the column.
                    pos = x1_p.argmax()
                    row = 0
                    col = pos.item()
                    x_t[:, block_abs_start:block_abs_end][row, col] = x_1[row, col]
                    tokens_per_step.append(1)
                step += 1
                if step > max_tokens:
                    break
        # ── Commit this block's K/V to the cache ──
        # Run one final forward at block's fully-denoised state with
        # update_kv_cache=True so future blocks can attend to it via cache.
        # This also runs for blocks that had no MASKs so the cache stays
        # contiguous with x_t.
        # prev_last_logit is refreshed to the logit at the last position
        # of this block for the NEXT block's first-position prediction.
        if use_kv_cache:
            block_tokens = x_t[:, block_abs_start:block_abs_end]
            block_embeds = _embed_fn(block_tokens)
            block_position_ids = _position_ids[..., block_abs_start:block_abs_end]
            L_cached = past_kv.get_seq_length() if past_kv is not None else 0
            block_attn = torch.ones(
                B, L_cached + B, device=self.device, dtype=torch.bool
            )
            commit_out = self.forward(
                inputs_embeds=block_embeds,
                attention_mask=block_attn,
                position_ids=block_position_ids,
                past_key_values=past_kv,
                use_cache=True,
                update_kv_cache=True,
            )
            past_kv = commit_out.past_key_values
            prev_last_logit = commit_out.logits[:, -1:, :]
    # NOTE: a previous null_ratio>0.3 early-stopping heuristic was
    # removed. It computed the ratio globally across the whole
    # explanation and, when tripped, force-filled every remaining
    # MASK with NULL — including MASKs in middle positions that
    # should have held real text — which cut short explanations
    # mid-sentence. Training always produces 192 value tokens
    # (real text + <|NULL|> padding at the tail) and the model
    # learned to emit NULL cleanly at the tail, so the final
    # NULL-strip below is sufficient. Cost: every sample now
    # denoises all 6 explanation blocks.
    # Post-process: strip NULL tokens (and any leftover MASK tokens),
    # wherever they occur, from the generated region.
    gen_tokens = x_t[0, original_input_length:].tolist()
    cleaned = [t for t in gen_tokens if t != null_id and t != mask_id]
    x_t = torch.cat([
        input_ids,
        torch.tensor([cleaned], device=self.device, dtype=torch.long)
    ], dim=1)
    gen_length = x_t.shape[1] - original_input_length
    if return_stats:
        stats = {
            "tokens_per_step": tokens_per_step,
            "total_steps": step,
            "gen_length": gen_length,
            "null_tokens_stripped": len(gen_tokens) - len(cleaned),
            "block_size": block_size,
        }
        return x_t, stats
    return x_t
| # --------------------------------------------------------------------------- | |
| # scaffold_speculative_sample — Scaffold Spec (SS) | |
| # --------------------------------------------------------------------------- | |
def scaffold_speculative_sample(
    self,
    input_ids,
    tokenizer,
    block_size=32,
    max_tokens=1024,
    pixel_values=None,
    image_grid_thw=None,
    mask_id=151665,
    null_id=151666,
    threshold=0.9,
    stop_token=151645,
    explanation_block_size=32,
    explanation_max_blocks=6,
    return_stats=False,
    draft_temperature=0.0,
    verify_temperature=0.0,
):
    """
    Scaffold-aware self-speculative decoding (Scaffold Spec, SS).

    Minimal modification of standard self-spec
    (speculative_block_causal_sample_cache): scaffold (structural JSON)
    tokens are pre-filled in the draft block instead of MASK and
    auto-accepted during causal verification.

    Key design: uses *exactly the same* attention patterns as standard
    self-spec (block-diff for draft, **causal** for verify via
    auto eval_mask). Only the draft block content differs — scaffold
    positions carry known tokens instead of MASK, giving the draft
    better context while scaffold tokens are "free" during acceptance.

    Per loop iteration, a block ``[last committed token] + n_draft`` draft
    slots (``n_draft = min(block_size - 1, remaining scaffold)``) is run
    twice:
      1. Draft forward (cache not updated): every MASK slot is filled with a
         predicted token.
      2. Verify forward (cache updated): draft tokens are accepted left to
         right while they are scaffold positions or equal the verify
         prediction; one bonus verify token is always appended. Scaffold
         positions among the accepted tokens are forced to their scaffold
         token. The cache is then cropped to ``len(input_ids) - 1`` so the
         last committed token is re-fed as the next block's seed.

    Args:
        input_ids: Prompt token ids. Acceptance and stop-token logic index
            row 0 only, so batch size 1 is presumably required — confirm.
        tokenizer: Passed to ``build_deep_json_scaffold``.
        block_size: Block length for draft/verify (seed + up to
            ``block_size - 1`` draft slots); also stored on
            ``self.model.bd_size`` and passed as ``eval_bd_size``.
        max_tokens: Decoding stops once more than this many tokens have been
            appended after the prompt.
        pixel_values, image_grid_thw: Vision inputs, passed to the prefill
            forward only.
        mask_id: Token id placed in draft slots that are not scaffold.
        null_id: Token id stripped from the generated region at the end.
        threshold: Accepted for interface parity; not used here (draft
            filling uses an implicit threshold of 0).
        stop_token: The loop exits when it appears among the tokens appended
            in the current iteration (with no ``mask_id`` before it); the
            output is truncated just after its first occurrence.
        explanation_block_size, explanation_max_blocks: Forwarded to
            ``build_deep_json_scaffold``.
        return_stats: If True, also return a stats dict.
        draft_temperature: If > 0, draft tokens are sampled from
            ``softmax(logits / draft_temperature)``; otherwise argmax.
        verify_temperature: If > 0, verify tokens are sampled from
            ``softmax(logits / verify_temperature)``; otherwise argmax.

    Returns:
        ``output_ids``: prompt ids followed by the generated tokens
        (truncated after the first ``stop_token``, NULL and MASK ids
        removed). If ``return_stats`` is True, returns
        ``(output_ids, stats)`` with keys ``tokens_per_step``,
        ``total_steps``, ``gen_length``, ``null_tokens_stripped``,
        ``block_size`` and ``method``.

    Environment:
        SS_PROFILE: when set to a non-empty value, a timing breakdown is
        printed. This path calls ``torch.cuda.synchronize()``, so it
        assumes a CUDA device is available.
    """
    from .section_utils import (
        build_deep_json_scaffold,
        NULL_TOKEN_ID,
    )
    # scaffold_mask_list[i] is truthy for fixed scaffold tokens and falsy
    # for value slots (mask_id in the scaffold).
    scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold(
        tokenizer,
        mask_id=mask_id,
        null_id=null_id,
        explanation_block_size=explanation_block_size,
        explanation_max_blocks=explanation_max_blocks,
    )
    scaffold_len = len(scaffold_tokens)
    original_input_length = input_ids.shape[1]
    tokens_per_step = []
    self.model.bd_size = block_size
    _ss_profile = bool(os.environ.get("SS_PROFILE"))
    # Scaffold offset where the trajectory section begins (profiling only).
    _ss_traj_start = section_ranges.get("trajectory", (None, None))[0]
    if _ss_profile:
        import time as _time
        torch.cuda.synchronize()
        _ss_t = {"start": _time.perf_counter()}
    _ss_marked_traj_start = False
    # NOTE(review): the two counters below are never updated or read.
    _ss_n_fwd_prefix = 0
    _ss_n_fwd_traj = 0
    # Pre-convert to tensors for vectorized operations in the loop
    scaffold_tok_t = torch.tensor(
        scaffold_tokens, device=self.device, dtype=torch.long
    )
    scaffold_is_fixed = torch.tensor(
        scaffold_mask_list, device=self.device, dtype=torch.bool
    )
    # ── Phase 1: Prefill prompt (identical to standard self-spec) ──
    output = self.forward(
        input_ids=input_ids,
        pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        use_cache=True,
        update_kv_cache=True,
    )
    logits, past_key_values = output.logits, output.past_key_values
    if _ss_profile:
        torch.cuda.synchronize()
        _ss_t["after_prefill"] = _time.perf_counter()
    # First token — use scaffold token (always '{')
    # It is not fed yet: the cache holds only the prompt, and this token
    # becomes the seed of the first draft block.
    next_token = torch.tensor(
        [[scaffold_tokens[0]]], device=self.device, dtype=torch.long
    )
    input_ids = torch.cat([input_ids, next_token], dim=1)
    tokens_per_step.append(1)
    scaffold_cursor = 1
    step = 1
    # ── Phase 2: Self-speculative decoding loop ──
    # Follows the exact same structure as
    # speculative_block_causal_sample_cache, with scaffold-aware draft.
    # Invariant at loop top: the cache covers input_ids[:, :-1], and
    # scaffold_cursor is the scaffold offset of the next token to decode.
    while scaffold_cursor < scaffold_len:
        if _ss_profile and (not _ss_marked_traj_start) and (
            _ss_traj_start is not None and scaffold_cursor >= _ss_traj_start
        ):
            torch.cuda.synchronize()
            _ss_t["enter_traj"] = _time.perf_counter()
            _ss_marked_traj_start = True
        # Length before this iteration's tokens are appended; the stop-token
        # check below only scans tokens appended in this iteration.
        prompt_length = input_ids.shape[1]
        n_draft = min(block_size - 1, scaffold_len - scaffold_cursor)
        # Build draft block: [seed, scaffold_or_MASK × n_draft]
        sc_end = scaffold_cursor + n_draft
        is_fixed = scaffold_is_fixed[scaffold_cursor:sc_end]
        draft_tensor = torch.where(
            is_fixed,
            scaffold_tok_t[scaffold_cursor:sc_end],
            mask_id,
        ).unsqueeze(0)
        x_t = torch.cat([input_ids[:, -1:], draft_tensor], dim=1)
        mask_idx = (x_t == mask_id)
        # ── Draft (block-diff bidirectional via auto eval_mask) ──
        logits = self.forward(
            input_ids=x_t,
            use_cache=True,
            past_key_values=past_key_values,
            update_kv_cache=False,
            eval_bd_size=block_size,
        ).logits
        tokens_per_step.append(0)
        step += 1
        # Shift logits (same as standard self-spec)
        # Position j now holds the logit from j-1; position 0 (the seed) is
        # never masked, so its duplicated logit is irrelevant.
        logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
        if draft_temperature > 0:
            # Temperature sampling for draft diversity
            scaled = logits / draft_temperature
            draft_probs = torch.softmax(scaled, dim=-1)
            x_1 = torch.multinomial(
                draft_probs.view(-1, draft_probs.shape[-1]), num_samples=1
            ).view(draft_probs.shape[:-1])
            # Confidence uses unscaled probs for thresholding
            probs = torch.softmax(logits, dim=-1)
            x1_p = torch.gather(
                probs, dim=-1, index=x_1.unsqueeze(-1)
            ).squeeze(-1)
        else:
            x_1 = logits.argmax(dim=-1)
            probs = torch.softmax(logits, dim=-1)
            x1_p = torch.gather(
                probs, dim=-1, index=x_1.unsqueeze(-1)
            ).squeeze(-1)
        # Only fill MASK positions; scaffold positions keep their tokens
        x1_p = torch.where(mask_idx, x1_p, -torch.inf)
        unmask_idx = (x1_p > 0)  # threshold=0 for draft filling
        if unmask_idx.sum() > 0:
            x_t[unmask_idx] = x_1[unmask_idx]
        else:
            # Fallback: fill most confident MASK
            # (only reached if no MASK has a positive probability, e.g. the
            # block has no MASK slots at all, in which case nothing is filled).
            mask_only_p = x1_p.clone()
            mask_only_p[~mask_idx] = -torch.inf
            if mask_only_p.max() > -torch.inf:
                best = mask_only_p.argmax()
                x_t.view(-1)[best] = x_1.view(-1)[best]
        # ── Verify (causal via auto eval_mask, commit to cache) ──
        output = self.forward(
            input_ids=x_t,
            use_cache=True,
            past_key_values=past_key_values,
            update_kv_cache=True,
            eval_bd_size=block_size,
        )
        past_key_values = output.past_key_values
        # ar_block_token[0, j] is the verify prediction for x_t[0, j + 1].
        if verify_temperature > 0:
            verify_logits = output.logits / verify_temperature
            verify_probs = torch.softmax(verify_logits, dim=-1)
            ar_block_token = torch.multinomial(
                verify_probs.view(-1, verify_probs.shape[-1]), num_samples=1
            ).view(verify_probs.shape[:-1])
        else:
            ar_block_token = output.logits.argmax(dim=-1)
        # ── AR acceptance (scaffold positions auto-pass) ──
        ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1])
        accepted_token_num = 0
        for i in range(n_draft):
            if is_fixed[i] or ar_matches[i]:
                accepted_token_num += 1
            else:
                break
        accepted_token_num += 1  # bonus token
        tokens_per_step.append(accepted_token_num)
        # Force scaffold tokens at scaffold positions, AR predictions elsewhere
        # accepted_ids[0, j] corresponds to scaffold offset scaffold_cursor + j;
        # acc_end clips at the scaffold end, so a bonus token past the end is
        # kept as predicted.
        accepted_ids = ar_block_token[:, :accepted_token_num].clone()
        acc_end = min(scaffold_cursor + accepted_token_num, scaffold_len)
        acc_fixed = scaffold_is_fixed[scaffold_cursor:acc_end]
        accepted_ids[0, :len(acc_fixed)][acc_fixed] = \
            scaffold_tok_t[scaffold_cursor:acc_end][acc_fixed]
        input_ids = torch.cat([input_ids, accepted_ids], dim=1)
        scaffold_cursor += accepted_token_num
        # Keep cache entries for every committed token except the newest one,
        # which is re-fed as the next block's seed. The kept entries were
        # computed from draft tokens equal to the committed ones.
        past_key_values = _crop_cache(past_key_values, input_ids.shape[1] - 1)
        step += 1
        # Stop conditions
        if input_ids.shape[1] - original_input_length > max_tokens:
            break
        if stop_token in input_ids[:, prompt_length:]:
            stop_token_idx = (
                input_ids[:, prompt_length:] == stop_token
            ).nonzero()[0][1]
            if (
                input_ids[:, prompt_length:prompt_length + stop_token_idx]
                == mask_id
            ).sum() == 0:
                break
    if _ss_profile:
        torch.cuda.synchronize()
        _ss_t["end"] = _time.perf_counter()
        _t_total = _ss_t["end"] - _ss_t["start"]
        _t_pre = _ss_t["after_prefill"] - _ss_t["start"]
        _t_traj_in = _ss_t.get("enter_traj")
        if _t_traj_in is not None:
            _t_prefix = _t_traj_in - _ss_t["after_prefill"]
            _t_traj = _ss_t["end"] - _t_traj_in
        else:
            _t_prefix = _ss_t["end"] - _ss_t["after_prefill"]
            _t_traj = 0.0
        print(
            f"[ss profile] total={_t_total*1000:.0f}ms "
            f"prefill={_t_pre*1000:.0f}ms "
            f"prefix-decode={_t_prefix*1000:.0f}ms "
            f"traj-decode={_t_traj*1000:.0f}ms",
            flush=True,
        )
    # ── Phase 3: Post-process — truncate at stop, strip NULL ──
    # Truncation keeps the first stop token itself.
    if stop_token in input_ids[:, original_input_length:]:
        stop_token_idx = (
            input_ids[:, original_input_length:] == stop_token
        ).nonzero()[0][1]
        input_ids = input_ids[
            :, :stop_token_idx + original_input_length + 1
        ]
    gen_tokens = input_ids[0, original_input_length:].tolist()
    cleaned = [t for t in gen_tokens if t != null_id and t != mask_id]
    output_ids = torch.cat(
        [
            input_ids[:, :original_input_length],
            torch.tensor(
                [cleaned], device=self.device, dtype=torch.long
            ),
        ],
        dim=1,
    )
    gen_length = output_ids.shape[1] - original_input_length
    if return_stats:
        stats = {
            "tokens_per_step": tokens_per_step,
            "total_steps": step,
            "gen_length": gen_length,
            "null_tokens_stripped": len(gen_tokens) - len(cleaned),
            "block_size": block_size,
            "method": "scaffold_speculative_v5",
        }
        return output_ids, stats
    return output_ids
| # --------------------------------------------------------------------------- | |
| # scaffold_spec_with_ss_multi_traj — SS multi-rollout inference scaling | |
| # --------------------------------------------------------------------------- | |
def scaffold_spec_with_ss_multi_traj(
    self,
    input_ids,
    tokenizer,
    block_size=32,
    max_tokens=1024,
    pixel_values=None,
    image_grid_thw=None,
    mask_id=151665,
    null_id=151666,
    threshold=0.9,
    stop_token=151645,
    explanation_block_size=32,
    explanation_max_blocks=6,
    return_stats=False,
    num_traj_rollouts=4,
    traj_verify_temperature=0.5,
    traj_draft_temperature=0.0,
    merge_weights=None,
    batch_parallel=False,
):
    """Scaffold Spec with a shared prefix and N SS rollouts on the trajectory.

    Decoding pipeline:
      0) Prefill the prompt (shared).
      1) Scaffold Spec over every scaffold position before the trajectory
         section (the CoT sections), with greedy drafting and greedy
         verification (shared, deterministic).
      2) Clone the KV cache once per rollout (O(N) memory).
      3) For each clone, continue Scaffold Spec from the end of the shared
         prefix to the end of the scaffold, drafting with
         ``traj_draft_temperature`` and verifying with
         ``traj_verify_temperature``. Rollouts differ only through the
         ``torch.multinomial`` calls, which draw from the global RNG; with
         both temperatures <= 0 every rollout is greedy and the rollouts
         will generally be identical.
      4) Parse the waypoints of every rollout and return their weighted mean.

    Cost: roughly one full SS pass (sections 1-3 are ~88% of decoded tokens
    on our schema) plus N trajectory-only SS passes. For N = 4 this is about
    1.5x the cost of a single SS, versus about 4x for naive sequential
    rerolling.

    Rollouts always run sequentially. ``batch_parallel`` is accepted for
    API compatibility, but the batched path (one forward per draft/verify
    step over an N-replicated suffix) is not implemented. Passing ``True``
    emits a ``UserWarning`` and falls back to sequential rollouts.

    Args:
        self: Model instance; ``self.forward``, ``self.model`` and
            ``self.device`` are used.
        input_ids: Prompt token ids of shape ``(1, prompt_len)``. Only
            batch size 1 is supported.
        tokenizer: Used to build the scaffold, decode trajectory tokens and
            re-encode the final text.
        block_size: Block size for bidirectional drafting. Each step drafts
            up to ``block_size - 1`` scaffold positions.
        max_tokens: Generation budget, checked after every accepted block,
            so it can be exceeded by up to one block. If it is hit during
            step 1, the rollouts resume from wherever the shared prefix
            stopped.
        pixel_values, image_grid_thw: Vision inputs, used only by the
            prefill forward pass.
        mask_id: Token id placed at non-fixed scaffold positions to be drafted.
        null_id: Padding token id, stripped from the output.
        threshold: Currently unused by this decoding path.
        stop_token: Token id that terminates a rollout and truncates output.
        explanation_block_size, explanation_max_blocks: Forwarded to
            ``build_deep_json_scaffold``.
        return_stats: If True, also return a stats dict.
        num_traj_rollouts: Number of rollouts N (clamped to >= 1).
        traj_verify_temperature: > 0 samples the verify tokens at this
            temperature; otherwise verification is greedy.
        traj_draft_temperature: > 0 samples the draft tokens at this
            temperature; otherwise drafting is greedy.
        merge_weights: Optional per-rollout weights, normalised to sum to 1.
            Uniform weights are used when this is None, has length != N, or
            sums to <= 0.
        batch_parallel: Not implemented; see above.

    Returns:
        ``output_ids`` of shape ``(1, prompt_len + gen_len)``: the prompt
        followed by the re-encoded text of rollout 0, truncated at the stop
        token, with NULL/mask tokens removed and its trajectory replaced by
        the merged waypoints. Returns ``(output_ids, stats)`` when
        ``return_stats`` is True.
    """
    import warnings

    from .section_utils import build_deep_json_scaffold

    if batch_parallel:
        # Batched rollouts are not implemented; make the fallback visible
        # instead of silently ignoring the flag.
        warnings.warn(
            "scaffold_spec_with_ss_multi_traj: batch_parallel=True is not "
            "implemented; trajectory rollouts run sequentially.",
            stacklevel=2,
        )

    scaffold_tokens, section_ranges, scaffold_mask_list = build_deep_json_scaffold(
        tokenizer,
        mask_id=mask_id,
        null_id=null_id,
        explanation_block_size=explanation_block_size,
        explanation_max_blocks=explanation_max_blocks,
    )
    scaffold_len = len(scaffold_tokens)
    original_input_length = input_ids.shape[1]
    tokens_per_step = []
    self.model.bd_size = block_size
    scaffold_tok_t = torch.tensor(scaffold_tokens, device=self.device, dtype=torch.long)
    scaffold_is_fixed = torch.tensor(scaffold_mask_list, device=self.device, dtype=torch.bool)
    traj_start_in_scaffold = section_ranges["trajectory"][0]

    _profile = bool(os.environ.get("SS_MT_PROFILE"))
    # Profiling accumulators; bound unconditionally so the nested helpers
    # can always refer to them.
    _phase_clone_total = 0.0
    _phase_rollout_each = []
    if _profile:
        import time as _time
        torch.cuda.synchronize()
        _t_phase = {"start": _time.perf_counter()}

    # ── Phase 0: Prefill prompt ──
    output = self.forward(
        input_ids=input_ids, pixel_values=pixel_values,
        image_grid_thw=image_grid_thw,
        use_cache=True, update_kv_cache=True,
    )
    logits, past_key_values = output.logits, output.past_key_values
    if _profile:
        torch.cuda.synchronize()
        _t_phase["after_prefill"] = _time.perf_counter()
    # The first scaffold token is emitted directly after prefill.
    next_token = torch.tensor(
        [[scaffold_tokens[0]]], device=self.device, dtype=torch.long,
    )
    input_ids = torch.cat([input_ids, next_token], dim=1)
    tokens_per_step.append(1)
    scaffold_cursor = 1
    step = 1

    # ── Phase 1: Scaffold Spec for non-trajectory sections (shared, vt=0) ──
    while scaffold_cursor < scaffold_len and scaffold_cursor < traj_start_in_scaffold:
        remaining_before_traj = traj_start_in_scaffold - scaffold_cursor
        n_draft = min(block_size - 1, remaining_before_traj)
        if n_draft <= 0:
            break
        sc_end = scaffold_cursor + n_draft
        is_fixed = scaffold_is_fixed[scaffold_cursor:sc_end]
        # Fixed scaffold positions keep their token; others are masked.
        draft_tensor = torch.where(
            is_fixed, scaffold_tok_t[scaffold_cursor:sc_end], mask_id,
        ).unsqueeze(0)
        # Prepend the last committed token (not yet in the KV cache).
        x_t = torch.cat([input_ids[:, -1:], draft_tensor], dim=1)
        mask_idx = (x_t == mask_id)
        # Draft (block-bidirectional); the cache is not updated.
        logits = self.forward(
            input_ids=x_t, use_cache=True,
            past_key_values=past_key_values,
            update_kv_cache=False, eval_bd_size=block_size,
        ).logits
        tokens_per_step.append(0)
        step += 1
        # Shift right by one so position j holds the prediction for x_t[j].
        logits = torch.cat([logits[:, :1, :], logits[:, :-1, :]], dim=1)
        x_1 = logits.argmax(dim=-1)
        probs = torch.softmax(logits, dim=-1)
        x1_p = torch.gather(probs, dim=-1, index=x_1.unsqueeze(-1)).squeeze(-1)
        x1_p = torch.where(mask_idx, x1_p, -torch.inf)
        unmask_idx = (x1_p > 0)
        if unmask_idx.sum() > 0:
            x_t[unmask_idx] = x_1[unmask_idx]
        else:
            mask_only_p = x1_p.clone()
            mask_only_p[~mask_idx] = -torch.inf
            if mask_only_p.max() > -torch.inf:
                best = mask_only_p.argmax()
                x_t.view(-1)[best] = x_1.view(-1)[best]
        # Verify (causal, greedy); this pass writes the KV cache.
        output = self.forward(
            input_ids=x_t, use_cache=True,
            past_key_values=past_key_values,
            update_kv_cache=True, eval_bd_size=block_size,
        )
        past_key_values = output.past_key_values
        ar_block_token = output.logits.argmax(dim=-1)
        ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1])
        # Accept the longest prefix of fixed-or-matching positions, plus one
        # bonus AR token.
        accepted_token_num = 0
        for i in range(n_draft):
            if is_fixed[i] or ar_matches[i]:
                accepted_token_num += 1
            else:
                break
        accepted_token_num += 1
        # Never run past the start of the trajectory section in this phase.
        max_accept = traj_start_in_scaffold - scaffold_cursor
        if accepted_token_num > max_accept:
            accepted_token_num = max_accept
        tokens_per_step.append(accepted_token_num)
        accepted_ids = ar_block_token[:, :accepted_token_num].clone()
        acc_end = min(scaffold_cursor + accepted_token_num, scaffold_len)
        acc_fixed = scaffold_is_fixed[scaffold_cursor:acc_end]
        fixed_vals = scaffold_tok_t[scaffold_cursor:acc_end][acc_fixed]
        accepted_ids[0, :len(acc_fixed)][acc_fixed] = fixed_vals
        input_ids = torch.cat([input_ids, accepted_ids], dim=1)
        scaffold_cursor += accepted_token_num
        # Keep the cache one token behind input_ids (the last token is
        # re-fed at the head of the next block).
        past_key_values = _crop_cache(past_key_values, input_ids.shape[1] - 1)
        step += 1
        if input_ids.shape[1] - original_input_length > max_tokens:
            break
    if _profile:
        torch.cuda.synchronize()
        _t_phase["after_phase1"] = _time.perf_counter()

    # ── Phase 2: Fork KV cache N times (one per trajectory rollout) ──
    prefix_input_ids = input_ids.clone()
    prefix_len = prefix_input_ids.shape[1]

    def _clone_cache(kv):
        """Deep-copy every layer's tensors of ``kv`` into a new DynamicCache."""
        nonlocal _phase_clone_total
        if _profile:
            torch.cuda.synchronize()
            _t0 = _time.perf_counter()
        cloned = [tuple(t.clone() for t in kv[layer_num]) for layer_num in range(len(kv))]
        ret = DynamicCache(cloned)
        if _profile:
            torch.cuda.synchronize()
            _phase_clone_total += _time.perf_counter() - _t0
        return ret

    # ── Phase 3: N SS rollouts on the trajectory section, each with vt > 0 ──
    # All rollouts start from the same prefix; randomness comes from the
    # multinomial calls in draft / verify (the RNG is process-global).
    N = max(1, int(num_traj_rollouts))

    def _run_one_traj_rollout(start_kv, start_input_ids):
        """Continue Scaffold Spec from start_kv / start_input_ids to the end
        of the scaffold, applying traj_*_temperature. Returns the final
        token ids and the trajectory value tokens (non-fixed, non-NULL,
        non-mask ids inside the trajectory section)."""
        local_kv = start_kv
        local_input = start_input_ids
        local_cursor = scaffold_cursor
        while local_cursor < scaffold_len:
            n_draft = min(block_size - 1, scaffold_len - local_cursor)
            sc_end = local_cursor + n_draft
            is_fixed = scaffold_is_fixed[local_cursor:sc_end]
            draft_tensor = torch.where(
                is_fixed, scaffold_tok_t[local_cursor:sc_end], mask_id,
            ).unsqueeze(0)
            x_t = torch.cat([local_input[:, -1:], draft_tensor], dim=1)
            mask_idx = (x_t == mask_id)
            # Draft (block-bidirectional, optionally temperature-sampled).
            draft_logits = self.forward(
                input_ids=x_t, use_cache=True, past_key_values=local_kv,
                update_kv_cache=False, eval_bd_size=block_size,
            ).logits
            draft_logits = torch.cat(
                [draft_logits[:, :1, :], draft_logits[:, :-1, :]], dim=1,
            )
            if traj_draft_temperature > 0:
                scaled = draft_logits / traj_draft_temperature
                draft_probs = torch.softmax(scaled, dim=-1)
                x_1 = torch.multinomial(
                    draft_probs.view(-1, draft_probs.shape[-1]),
                    num_samples=1,
                ).view(draft_probs.shape[:-1])
            else:
                x_1 = draft_logits.argmax(dim=-1)
            probs = torch.softmax(draft_logits, dim=-1)
            x1_p = torch.gather(
                probs, dim=-1, index=x_1.unsqueeze(-1),
            ).squeeze(-1)
            x1_p = torch.where(mask_idx, x1_p, -torch.inf)
            unmask_idx = (x1_p > 0)
            if unmask_idx.sum() > 0:
                x_t[unmask_idx] = x_1[unmask_idx]
            else:
                mask_only_p = x1_p.clone()
                mask_only_p[~mask_idx] = -torch.inf
                if mask_only_p.max() > -torch.inf:
                    x_t.view(-1)[mask_only_p.argmax()] = \
                        x_1.view(-1)[mask_only_p.argmax()]
            # Verify (causal, optionally temperature-sampled).
            v_out = self.forward(
                input_ids=x_t, use_cache=True, past_key_values=local_kv,
                update_kv_cache=True, eval_bd_size=block_size,
            )
            local_kv = v_out.past_key_values
            if traj_verify_temperature > 0:
                v_logits = v_out.logits / traj_verify_temperature
                v_probs = torch.softmax(v_logits, dim=-1)
                ar_block_token = torch.multinomial(
                    v_probs.view(-1, v_probs.shape[-1]),
                    num_samples=1,
                ).view(v_probs.shape[:-1])
            else:
                ar_block_token = v_out.logits.argmax(dim=-1)
            ar_matches = (ar_block_token[0, :n_draft] == x_t[0, 1:n_draft + 1])
            accepted_token_num = 0
            for i in range(n_draft):
                if is_fixed[i] or ar_matches[i]:
                    accepted_token_num += 1
                else:
                    break
            accepted_token_num += 1
            accepted_ids = ar_block_token[:, :accepted_token_num].clone()
            acc_end = min(local_cursor + accepted_token_num, scaffold_len)
            acc_fixed = scaffold_is_fixed[local_cursor:acc_end]
            fixed_vals = scaffold_tok_t[local_cursor:acc_end][acc_fixed]
            accepted_ids[0, :len(acc_fixed)][acc_fixed] = fixed_vals
            local_input = torch.cat([local_input, accepted_ids], dim=1)
            local_cursor += accepted_token_num
            local_kv = _crop_cache(local_kv, local_input.shape[1] - 1)
            if local_input.shape[1] - original_input_length > max_tokens:
                break
            # Stop once a stop token has appeared in this rollout's suffix
            # and nothing before it is still masked.
            if stop_token in local_input[:, prefix_len:]:
                st_idx = (local_input[:, prefix_len:] == stop_token).nonzero()
                if st_idx.numel() > 0:
                    cand_st = st_idx[0][1].item()
                    if (local_input[:, prefix_len:prefix_len + cand_st] == mask_id).sum() == 0:
                        break
        # Generated position i lines up with scaffold position i.
        traj_values = [
            t for i, t in enumerate(local_input[0, original_input_length:].tolist())
            if traj_start_in_scaffold <= i < scaffold_len
            and not scaffold_mask_list[i] and t != null_id and t != mask_id
        ]
        return local_input, traj_values

    # Sequential N rollouts (batch_parallel is not implemented).
    rollout_inputs = []
    rollout_traj_values = []
    for _ in range(N):
        if _profile:
            torch.cuda.synchronize()
            _t_r0 = _time.perf_counter()
        cand_kv = _clone_cache(past_key_values)
        cand_input = prefix_input_ids.clone()
        cand_input, traj_vals = _run_one_traj_rollout(cand_kv, cand_input)
        rollout_inputs.append(cand_input)
        rollout_traj_values.append(traj_vals)
        step += 1
        if _profile:
            torch.cuda.synchronize()
            _phase_rollout_each.append(_time.perf_counter() - _t_r0)
    if _profile:
        torch.cuda.synchronize()
        _t_phase["after_rollouts"] = _time.perf_counter()
        _t_total = _t_phase["after_rollouts"] - _t_phase["start"]
        _t_pre = _t_phase["after_prefill"] - _t_phase["start"]
        _t_p1 = _t_phase["after_phase1"] - _t_phase["after_prefill"]
        _t_rolls = _t_phase["after_rollouts"] - _t_phase["after_phase1"]
        print(
            f"[ss_mt profile] total={_t_total*1000:.0f}ms "
            f"prefill(P0)={_t_pre*1000:.0f}ms "
            f"prefix-decode(P1)={_t_p1*1000:.0f}ms "
            f"rollouts(P2+P3)={_t_rolls*1000:.0f}ms "
            f"of which kv-clone={_phase_clone_total*1000:.0f}ms "
            f"per-rollout={[f'{r*1000:.0f}' for r in _phase_rollout_each]}ms",
            flush=True,
        )

    # ── Phase 4: Parse all rollouts, weighted-merge waypoints ──
    def _decode_trajectory(traj_tokens):
        """Decode trajectory tokens and pair up the numbers as [x, y]."""
        text = tokenizer.decode(traj_tokens, skip_special_tokens=False)
        text = text.replace("<|NULL|>", "").strip()
        coords = re.findall(r"[+-]?\d+\.?\d*", text)
        return [
            [float(coords[i]), float(coords[i + 1])]
            for i in range(0, len(coords) - 1, 2)
        ]

    rollout_waypoints = [_decode_trajectory(v) for v in rollout_traj_values]

    # Uniform weights unless valid user weights are given; a non-positive
    # total would divide by zero (or flip signs), so it also falls back.
    ws = [1.0 / N] * N
    if merge_weights is not None and len(merge_weights) == N:
        total = sum(merge_weights)
        if total > 0:
            ws = [w / total for w in merge_weights]

    if rollout_waypoints and all(len(w) > 0 for w in rollout_waypoints):
        # Merge element-wise over the shortest rollout's waypoint count.
        n_wp = min(len(w) for w in rollout_waypoints)
        merged_waypoints = []
        for i in range(n_wp):
            mx = sum(ws[c] * rollout_waypoints[c][i][0] for c in range(N))
            my = sum(ws[c] * rollout_waypoints[c][i][1] for c in range(N))
            merged_waypoints.append([mx, my])
    else:
        # Some rollout failed to parse: use the first non-empty rollout
        # unweighted, or [] when none parsed.
        merged_waypoints = next(
            (w for w in rollout_waypoints if w), [],
        )

    # Output text: take rollout 0's full text but replace its trajectory
    # with the merged waypoints.
    base_input = rollout_inputs[0]
    if stop_token in base_input[:, original_input_length:]:
        st_idx = (base_input[:, original_input_length:] == stop_token).nonzero()[0][1]
        base_input = base_input[:, :st_idx + original_input_length + 1]
    base_raw_tokens = base_input[0, original_input_length:].tolist()
    base_cleaned = [t for t in base_raw_tokens if t != null_id and t != mask_id]
    base_null_stripped = len(base_raw_tokens) - len(base_cleaned)
    base_text = tokenizer.decode(base_cleaned, skip_special_tokens=False)
    traj_parts = [
        f"[{x:+07.2f},{y:+06.2f}]" for x, y in merged_waypoints
    ]
    merged_traj_str = "[" + ", ".join(traj_parts) + "]"
    # A callable repl inserts merged_traj_str literally (no escape parsing).
    replaced_text = re.sub(
        r'("trajectory"\s*:\s*")(\[\[.*?\]\])',
        lambda m: m.group(1) + merged_traj_str,
        base_text,
    )
    merged_tokens = tokenizer.encode(replaced_text, add_special_tokens=False)
    output_ids = torch.cat([
        input_ids[:, :original_input_length],
        torch.tensor([merged_tokens], device=self.device, dtype=torch.long),
    ], dim=1)
    gen_length = output_ids.shape[1] - original_input_length
    if return_stats:
        stats = {
            "tokens_per_step": tokens_per_step,
            "total_steps": step,
            "gen_length": gen_length,
            "null_tokens_stripped": base_null_stripped,
            "block_size": block_size,
            "method": "scaffold_spec_with_ss_multi_traj",
            "num_traj_rollouts": N,
            "traj_verify_temperature": traj_verify_temperature,
            "rollout_waypoints": rollout_waypoints,
            "merged_waypoints": merged_waypoints,
            "merge_weights": ws,
        }
        return output_ids, stats
    return output_ids
| # --------------------------------------------------------------------------- | |
| # Bind decoding methods onto the model class. | |
| # | |
| # ``modeling.py`` imports this module at the bottom of the file, after the | |
| # ``Fast_dDriveForConditionalGeneration`` class has been defined. We | |
| # attach the three decoding paths as ordinary methods so callers can invoke | |
| # them as ``model.mdm_sample_deep_scaffold(...)`` etc. without any extra | |
| # registration step. | |
| # --------------------------------------------------------------------------- | |
def attach_generation_methods(cls):
    """Bind the three release decoding functions onto ``cls`` as methods.

    Each module-level function is set as an attribute of ``cls`` under its
    own name, so instances can call e.g. ``model.mdm_sample_deep_scaffold``.

    Args:
        cls: The class to extend; it is modified in place.

    Returns:
        The same ``cls`` object, so the function can also be used as a
        decorator.
    """
    bindings = {
        "mdm_sample_deep_scaffold": mdm_sample_deep_scaffold,
        "scaffold_speculative_sample": scaffold_speculative_sample,
        "scaffold_spec_with_ss_multi_traj": scaffold_spec_with_ss_multi_traj,
    }
    for attr_name, fn in bindings.items():
        setattr(cls, attr_name, fn)
    return cls
# Public surface of this module: the three decoding entry points and the
# helper that binds them onto a model class.
__all__ = [
    "mdm_sample_deep_scaffold",
    "scaffold_speculative_sample",
    "scaffold_spec_with_ss_multi_traj",
    "attach_generation_methods",
]