# Fast-dDrive / section_utils.py
# xiwenyoumu's picture
# Initial Fast-dDrive 3B release
# 5e9a603 verified
"""
Section-aware block scheduling for JSON structured output.
Inspired by the S3 (Self-adaptive Schema Scaffolding) paper (arXiv:2507.04504),
this module provides utilities to:
1. Parse tokenized JSON output into sections (critical_objects, explanation, etc.)
2. Assign section-aware block indices (fixed block size, variable number of blocks per section)
3. Build JSON scaffolds for inference (pre-fill structural tokens)
The DVLM-AD output schema has 4 sections:
- critical_objects: ~88 tokens (12 yes/no fields, nearly constant)
- explanation: ~114 tokens (variable, 72-172)
- future_meta_behavior: ~40 tokens (nearly constant)
- trajectory: ~80 tokens (nearly constant)
"""
import torch
from typing import Dict, List, Optional, Tuple
import math
# Section keys, in the order the sections appear in the JSON response.
SECTION_KEYS = "critical_objects explanation future_meta_behavior trajectory".split()

# Default token budget for each section (based on training data analysis),
# listed in SECTION_KEYS order.
DEFAULT_TOKEN_BUDGETS = dict(zip(SECTION_KEYS, (88, 128, 40, 80)))

# Default number of denoising steps for each section, in SECTION_KEYS order.
DEFAULT_SECTION_STEPS = dict(zip(SECTION_KEYS, (1, 3, 1, 1)))
def _v1_removed_parse_json_sections(*args, **kwargs):
raise NotImplementedError("DS v1 parse_json_sections has been removed. Use deep scaffold v2.")
def _v1_removed_compute_section_block_idx(*args, **kwargs):
raise NotImplementedError("DS v1 compute_section_block_idx has been removed. Use compute_section_block_idx_deep_static.")
def _v1_removed_build_json_scaffold(*args, **kwargs):
raise NotImplementedError("DS v1 build_json_scaffold has been removed. Use build_deep_json_scaffold.")
def _v1_removed_compute_section_block_sizes(*args, **kwargs):
raise NotImplementedError("DS v1 compute_section_block_sizes has been removed.")
def build_static_scaffold_sequences(tokenizer) -> Dict[str, List[int]]:
    """Tokenize the literal text that introduces each top-level JSON section.

    Keys name the boundary and values are its token ids, each encoded with
    ``add_special_tokens=False``:

    - ``prefix``: ``{"critical_objects":``
    - ``between_co_exp``: `` "explanation":``
    - ``between_exp_fmb``: `` "future_meta_behavior":``
    - ``between_fmb_traj``: `` "trajectory":``

    Used internally by :func:`build_deep_scaffold_sequences`.
    """
    boundary_text = {
        "prefix": '{"critical_objects":',
        "between_co_exp": ' "explanation":',
        "between_exp_fmb": ' "future_meta_behavior":',
        "between_fmb_traj": ' "trajectory":',
    }
    return {
        name: tokenizer.encode(text, add_special_tokens=False)
        for name, text in boundary_text.items()
    }
def _v1_removed_compute_section_block_idx_static(*args, **kwargs):
raise NotImplementedError("DS v1 compute_section_block_idx_static has been removed. Use compute_section_block_idx_deep_static.")
# Backward-compatible aliases: the old v1 public names still import, but
# calling any of them raises NotImplementedError that points at the v2 API.
(
    parse_json_sections,
    compute_section_block_idx,
    build_json_scaffold,
    compute_section_block_sizes,
    compute_section_block_idx_static,
) = (
    _v1_removed_parse_json_sections,
    _v1_removed_compute_section_block_idx,
    _v1_removed_build_json_scaffold,
    _v1_removed_compute_section_block_sizes,
    _v1_removed_compute_section_block_idx_static,
)
# ═══════════════════════════════════════════════════════════════
# Deep scaffold v2: constants and utilities
# ═══════════════════════════════════════════════════════════════

# Token id used for the <|NULL|> padding token.
NULL_TOKEN_ID = 151666

# critical_objects sub-keys, in the order they appear in the JSON output.
# Each value is a single token (yes=9693 / no=2152 per the original notes).
CRITICAL_OBJECTS_SUBKEYS = (
    "nearby_vehicle pedestrian cyclist construction "
    "traffic_element weather_condition road_hazard "
    "emergency_vehicle animal special_vehicle "
    "conflicting_vehicle door_opening_vehicle"
).split()

# future_meta_behavior: every sub-key value is padded to exactly this many
# tokens, e.g. "keep speed" -> [4867, 4732, 151667] and
# "go straight" -> [2849, 7833, 151667].
# NOTE(review): the pad id in these examples (151667) differs from
# NULL_TOKEN_ID (151666); confirm which id the <|NULL|> token maps to.
FMB_VALUE_BUDGET = 3
def build_deep_json_scaffold(
    tokenizer,
    section_token_budgets: Optional[Dict[str, int]] = None,
    mask_id: Optional[int] = None,
    null_id: Optional[int] = None,
    explanation_block_size: int = 32,
    explanation_max_blocks: int = 6,
) -> Tuple[List[int], Dict[str, Tuple[int, int]], List[int]]:
    """Build a deep JSON scaffold for inference (v2).

    Constructs a template response by building a Python dict and
    processing it through the same pipeline as the training
    dataloader (``multi_modal_dataset.py``):

    1. Build a realistic dict with placeholder values.
    2. Pad explanation with ``<|NULL|>`` to ``exp_budget`` tokens.
    3. Pad FMB values with ``<|NULL|>`` to 3 tokens each.
    4. Normalize trajectory to ``+XXX.XX`` format with spaces.
    5. Serialize with ``json.dumps(obj, ensure_ascii=False)``.
    6. Tokenize the whole string as one piece.
    7. Run ``compute_section_block_idx_deep_static`` to get scaffold/value.
    8. Replace value positions with MASK tokens.

    Tokenizing the whole serialized string is intended to reproduce the
    BPE splits seen in training data.

    Parameters
    ----------
    tokenizer
        Object exposing ``encode(text, add_special_tokens=False) -> list[int]``.
    section_token_budgets
        Unused; accepted for backward compatibility.
    mask_id
        Token id written at value positions. When ``None``, ``"|<MASK>|"`` is
        encoded and used if it is a single token; otherwise 151665 is used.
    null_id
        Unused; accepted for backward compatibility. Padding is inserted as
        the literal text ``"<|NULL|>"`` and tokenized with the template.
    explanation_block_size, explanation_max_blocks
        Their product is the target number of explanation value tokens
        (``exp_budget``; 192 with the defaults).

    Returns
    -------
    scaffold_tokens : list[int]
        Token IDs with MASK at value positions.
    section_ranges : dict
        Section name -> (start, end) within scaffold_tokens.
    scaffold_mask : list[int]
        0 = value (to denoise), 1 = scaffold (frozen).
    """
    import json
    import re

    if mask_id is None:
        mask_tok = tokenizer.encode("|<MASK>|", add_special_tokens=False)
        mask_id = mask_tok[0] if len(mask_tok) == 1 else 151665

    exp_budget = explanation_block_size * explanation_max_blocks  # default 192

    # ── Step 1: template pieces matching the training data structure ──
    # Placeholder explanation text (replaced with MASK anyway).
    filler_explanation = (
        "The ego vehicle is driving forward on the road. "
        "There are nearby vehicles ahead that may affect the path. "
        "No pedestrians or cyclists are detected in the immediate area. "
        "The road conditions appear normal with no hazards present. "
        "Speed adjustment may be needed based on the traffic ahead. "
        "No lateral maneuvering is required at this time."
    )

    # Trajectory normalization identical to the dataloader: zero-pad each
    # signed coordinate to width 6 with 2 decimals, then insert a space after
    # the x/y comma and after each waypoint's opening bracket. It does not
    # depend on the explanation padding, so it is computed once.
    def _fmt_coord(m):
        raw = m.group(0)
        sign = raw[0]
        num = float(raw[1:])
        return f"{sign}{num:06.2f}"

    trajectory = "[[+14.70,-00.04], [+29.55,-00.21], [+44.51,-00.56], [+59.50,-01.06], [+74.39,-01.69]]"
    trajectory = re.sub(r'[+-]\d+\.\d+', _fmt_coord, trajectory)
    trajectory = re.sub(r',([+-])', r', \1', trajectory)
    trajectory = re.sub(r'\[([+-])', r'[ \1', trajectory)

    def _build_template(n_exp_nulls: int) -> str:
        """Serialize the template with *n_exp_nulls* explanation pads, like the dataloader."""
        data_obj = {
            # Same 12 keys, same order as the dataloader output.
            "critical_objects": {key: "no" for key in CRITICAL_OBJECTS_SUBKEYS},
            "explanation": filler_explanation + "<|NULL|>" * n_exp_nulls,
            "future_meta_behavior": {
                "longitudinal": "come to stop",
                "lateral": "go straight<|NULL|>",
            },
            "trajectory": trajectory,
        }
        return json.dumps(data_obj, ensure_ascii=False)

    # ── Step 2: adjust the NULL count until the explanation hits exp_budget ──
    deep_seqs = build_deep_scaffold_sequences(tokenizer)
    top_seqs = deep_seqs["top"]
    co_exp_pat = top_seqs["between_co_exp"]
    exp_fmb_pat = top_seqs["between_exp_fmb"]

    def _count_exp_value_tokens(tok_list: List[int]) -> Optional[int]:
        """Count explanation VALUE tokens (between boundary patterns), or None if not found."""
        co_exp_pos = _find_subseq(tok_list, co_exp_pat, 0)
        if co_exp_pos < 0:
            return None
        exp_start = co_exp_pos + len(co_exp_pat)
        exp_fmb_pos = _find_subseq(tok_list, exp_fmb_pat, exp_start)
        if exp_fmb_pos < 0:
            return None
        # exp_start..exp_fmb_pos includes the opening/closing quote tokens,
        # which are scaffold, hence the -2.
        return (exp_fmb_pos - exp_start) - 2

    base_tokens = tokenizer.encode(_build_template(0), add_special_tokens=False)
    base_exp = _count_exp_value_tokens(base_tokens)
    if base_exp is not None:
        needed_nulls = max(0, exp_budget - base_exp)
    else:
        needed_nulls = exp_budget // 2  # fallback

    # Re-measure after every adjustment (bounded), instead of adjusting once
    # without verifying the result. template_tokens always corresponds to
    # the most recently built template.
    max_attempts = 4
    for _attempt in range(max_attempts):
        template_tokens = tokenizer.encode(
            _build_template(needed_nulls), add_special_tokens=False
        )
        actual_exp = _count_exp_value_tokens(template_tokens)
        if actual_exp is None or actual_exp == exp_budget:
            break
        adjusted = max(0, needed_nulls + (exp_budget - actual_exp))
        if adjusted == needed_nulls:
            # e.g. the filler text alone already exceeds the budget.
            break
        needed_nulls = adjusted

    # ── Step 3: run the training-time scaffold detection ──
    # A dummy 10-token prompt (labels -100) precedes the response so the
    # detector sees the same prompt/response layout as in training.
    prompt_len = 10
    labels = torch.tensor([[-100] * prompt_len + template_tokens])
    token_ids = torch.tensor([[1] * prompt_len + template_tokens])
    detection = compute_section_block_idx_deep_static(
        labels, token_ids, deep_seqs, fallback_block_size=32,
    )
    scaffold_mask_tensor = detection[3]

    # ── Step 4: extract scaffold/value and replace value with MASK ──
    response_flags = scaffold_mask_tensor[prompt_len:prompt_len + len(template_tokens)].tolist()
    scaffold_mask_list: List[int] = [1 if flag else 0 for flag in response_flags]
    scaffold_tokens = [
        tok if keep else mask_id
        for tok, keep in zip(template_tokens, scaffold_mask_list)
    ]

    # ── Step 5: compute section ranges (value start -> next boundary) ──
    section_ranges: Dict[str, Tuple[int, int]] = {}
    boundary_order = [
        ("prefix", "critical_objects"),
        ("between_co_exp", "explanation"),
        ("between_exp_fmb", "future_meta_behavior"),
        ("between_fmb_traj", "trajectory"),
    ]
    search_from = 0
    prev_section_name = None
    prev_value_start = None
    for boundary_key, section_name in boundary_order:
        pattern = top_seqs.get(boundary_key)
        if pattern is None:
            continue
        pos = _find_subseq(template_tokens, pattern, search_from)
        if pos < 0:
            continue
        if prev_section_name is not None and prev_value_start is not None:
            section_ranges[prev_section_name] = (prev_value_start, pos)
        value_start = pos + len(pattern)
        prev_section_name = section_name
        prev_value_start = value_start
        search_from = value_start
    if prev_section_name is not None and prev_value_start is not None:
        section_ranges[prev_section_name] = (prev_value_start, len(template_tokens))

    return scaffold_tokens, section_ranges, scaffold_mask_list
def _find_subseq(seq: List[int], pattern: List[int], start: int = 0) -> int:
"""Find first occurrence of *pattern* in *seq* starting at *start*. Returns -1 if not found."""
n = len(pattern)
for i in range(start, len(seq) - n + 1):
if seq[i : i + n] == pattern:
return i
return -1
def build_deep_scaffold_sequences(tokenizer) -> Dict[str, object]:
    """
    Tokenize every literal JSON fragment the deep-scaffold matcher looks for.

    The returned dict contains:
    - ``"top"``: top-level section boundaries from
      ``build_static_scaffold_sequences``;
    - ``"co_subkeys"``: one ``{"key", "pattern", "index"}`` entry per
      critical_objects sub-key, in ``CRITICAL_OBJECTS_SUBKEYS`` order;
    - ``"co_*"``, ``"fmb_*"``, ``"traj_*"``: opening / separator / closing
      fragments for critical_objects, future_meta_behavior and trajectory;
    - ``"traj_only_boundaries"``: key spellings tried when a response holds
      only a trajectory.

    Every fragment is encoded with ``add_special_tokens=False`` in isolation.
    NOTE(review): matching assumes the same token split occurs inside the
    full response; BPE merges across fragment edges can break this, so
    re-check the fragments for a different tokenizer.
    """
    def encode(text: str) -> List[int]:
        return tokenizer.encode(text, add_special_tokens=False)

    seqs: Dict[str, object] = {"top": build_static_scaffold_sequences(tokenizer)}

    # critical_objects sub-keys. The first key follows the opening brace
    # (' {"key": "'); every later key follows the previous value's closing
    # quote ('", "key": "'). The original notes report ' {"' and '","' as
    # merged single tokens in context.
    co_patterns = []
    for index, key in enumerate(CRITICAL_OBJECTS_SUBKEYS):
        lead = ' {"' if index == 0 else '", "'
        co_patterns.append({"key": key, "pattern": encode(lead + key + '": "'), "index": index})
    seqs["co_subkeys"] = co_patterns

    fragments = {
        # critical_objects closing; json.dumps emits "}," which may merge.
        "co_closing": '"}',
        "co_closing_comma": '"},',
        # future_meta_behavior: ' {"longitudinal": "<v>", "lateral": "<v>"}'
        # Scaffold is everything except the value text between quotes.
        "fmb_prefix": ' {"longitudinal": "',
        "fmb_closing": '"}',
        "fmb_closing_comma": '"},',
        "fmb_between": '", "lateral": "',
        # trajectory: ' "[[ +014.70, -000.04], [ +029.55, -000.21], ...]]"'
        # after the dataloader inserts spaces, so '],', ' [' and ',' split
        # into independent tokens.
        "traj_open": ' "[[',
        "traj_wp_sep": '],',
        "traj_wp_open": ' [',
        "traj_coord_comma": ',',
        "traj_close": ']]"}',
        "traj_close_split": ']]"',
        "traj_close_split2": ']]',
    }
    for name, text in fragments.items():
        seqs[name] = encode(text)

    # Trajectory-only output support, e.g. {"trajectory": "..."}.
    seqs["traj_only_boundaries"] = [
        encode(text)
        for text in (
            '{"trajectory":',
            ' {"trajectory":',
            '"trajectory":',
            ' "trajectory":',
        )
    ]
    return seqs
def _mark_scaffold_range(scaffold_positions: List[int], start: int, length: int):
"""Add positions [start, start+length) to scaffold_positions."""
for i in range(length):
scaffold_positions.append(start + i)
def compute_section_block_idx_deep_static(
    labels: torch.Tensor,
    token_ids: torch.Tensor,
    deep_scaffold_sequences: Dict[str, object],
    fallback_block_size: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, Dict[int, str]]:
    """
    Deep-scaffold v2 block index computation.

    Freezes sub-keys within sections:
    - critical_objects: only yes/no values are denoised
    - future_meta_behavior: only value tokens are denoised
    - trajectory: only coordinate digits are denoised
    - explanation: all content is denoised (its surrounding quotes are scaffold)

    Block count per section is computed dynamically:
    ``n_blocks = ceil(num_value_tokens / fallback_block_size)``.
    Scaffold tokens take the block of the nearest value token in the same
    section; top-level key boundaries (e.g. ``"explanation":``) are attached
    to the first block of the section they introduce. Any response token
    still unassigned afterwards gets block
    ``current_block + offset // fallback_block_size`` (offset from the
    response start).

    Only the first batch row (index 0) of *labels* / *token_ids* is used.

    Args:
        labels: [B, seq_len]; positions equal to -100 are non-response.
        token_ids: [B, seq_len]
        deep_scaffold_sequences: output of ``build_deep_scaffold_sequences``
        fallback_block_size: block size (bd_size), default 32

    Returns:
        A 5-tuple ``(response_block_idx, turn_idx, n_blocks, scaffold_mask,
        block_to_section)``:

        - response_block_idx: int64 [seq_len]; block index per position, -1
          for non-response positions and for tokens after the trajectory's
          closing brackets (when those are detected).
        - turn_idx: int64 [seq_len]; running count of block-index changes
          between consecutive positions.
        - n_blocks: number of blocks used.
        - scaffold_mask: bool [seq_len]; True where the token is frozen scaffold.
        - block_to_section: block index -> section name (empty when the
          response is empty).
    """
    labels_single = labels[0]
    token_list = token_ids[0].tolist()
    seq_len = labels_single.shape[0]
    device = labels.device

    response_mask = labels_single != -100
    response_block_idx = torch.full((seq_len,), -1, device=device, dtype=torch.int64)
    turn_idx = torch.zeros((seq_len,), device=device, dtype=torch.int64)
    scaffold_mask = torch.zeros((seq_len,), device=device, dtype=torch.bool)
    block_to_section: Dict[int, str] = {}  # block_idx -> section_name (for SASD compatibility)

    response_positions = response_mask.nonzero(as_tuple=True)[0]
    if len(response_positions) == 0:
        # Return the same 5-tuple shape as the normal path; a 4-tuple here
        # made callers that unpack five values raise ValueError.
        return response_block_idx, turn_idx, 0, scaffold_mask, block_to_section

    resp_start = response_positions[0].item()
    resp_end = response_positions[-1].item() + 1
    effective_resp_end = resp_end
    resp_tokens = token_list[resp_start:resp_end]
    top_seqs = deep_scaffold_sequences["top"]

    # ── Step 1: Find top-level section boundaries ──
    boundary_order = [
        ("prefix", "critical_objects"),
        ("between_co_exp", "explanation"),
        ("between_exp_fmb", "future_meta_behavior"),
        ("between_fmb_traj", "trajectory"),
    ]
    sections: Dict[str, Tuple[int, int]] = {}  # response-relative (start, end)
    scaffold_positions: List[int] = []  # response-relative scaffold indices
    # Top-level boundary scaffold tokens belong to the *following* section's
    # first block (e.g. `"explanation":` -> explanation block 0).
    boundary_scaffold_to_section: Dict[str, List[int]] = {}
    search_from = 0
    prev_section_name: Optional[str] = None
    prev_value_start: Optional[int] = None
    for boundary_key, section_name in boundary_order:
        pattern = top_seqs.get(boundary_key)
        if pattern is None:
            continue
        pos = _find_subseq(resp_tokens, pattern, search_from)
        if pos < 0:
            continue
        if prev_section_name is not None and prev_value_start is not None:
            sections[prev_section_name] = (prev_value_start, pos)
        _mark_scaffold_range(scaffold_positions, pos, len(pattern))
        boundary_scaffold_to_section.setdefault(section_name, []).extend(
            range(pos, pos + len(pattern))
        )
        value_start = pos + len(pattern)
        prev_section_name = section_name
        prev_value_start = value_start
        search_from = value_start
    if prev_section_name is not None and prev_value_start is not None:
        sections[prev_section_name] = (prev_value_start, len(resp_tokens))

    # New dataset compatibility: the response may contain only a trajectory.
    # If the 4-section boundaries are not found, try a direct trajectory key.
    if "trajectory" not in sections:
        traj_only_patterns = list(deep_scaffold_sequences.get("traj_only_boundaries", []))
        # The legacy boundary (`"trajectory":` in old-format responses) is an
        # additional fallback.
        between_fmb_traj = top_seqs.get("between_fmb_traj")
        if between_fmb_traj:
            traj_only_patterns.append(between_fmb_traj)
        traj_pos = -1
        traj_pat: Optional[List[int]] = None
        for pat in traj_only_patterns:
            if not pat:
                continue
            pos = _find_subseq(resp_tokens, pat, 0)
            if pos >= 0:
                traj_pos = pos
                traj_pat = pat
                break
        if traj_pos >= 0 and traj_pat is not None:
            _mark_scaffold_range(scaffold_positions, traj_pos, len(traj_pat))
            boundary_scaffold_to_section.setdefault("trajectory", []).extend(
                range(traj_pos, traj_pos + len(traj_pat))
            )
            sections["trajectory"] = (traj_pos + len(traj_pat), len(resp_tokens))

    # ── Step 2: Deep scaffold within critical_objects ──
    if "critical_objects" in sections:
        co_start, co_end = sections["critical_objects"]
        co_tokens = resp_tokens[co_start:co_end]
        co_search = 0
        for entry in deep_scaffold_sequences["co_subkeys"]:
            pattern = entry["pattern"]
            pos = _find_subseq(co_tokens, pattern, co_search)
            if pos < 0:
                continue
            _mark_scaffold_range(scaffold_positions, co_start + pos, len(pattern))
            # The single value token right after the pattern stays a value.
            co_search = pos + len(pattern) + 1
        # Closing '"}' (searched only near the end of the section).
        co_close = deep_scaffold_sequences["co_closing"]
        close_pos = _find_subseq(co_tokens, co_close,
                                 max(0, len(co_tokens) - len(co_close) - 2))
        if close_pos >= 0:
            _mark_scaffold_range(scaffold_positions, co_start + close_pos, len(co_close))
        else:
            # json.dumps may produce "}," as a single token.
            co_close_comma = deep_scaffold_sequences.get("co_closing_comma")
            if co_close_comma:
                close_pos = _find_subseq(co_tokens, co_close_comma,
                                         max(0, len(co_tokens) - len(co_close_comma) - 2))
                if close_pos >= 0:
                    _mark_scaffold_range(scaffold_positions, co_start + close_pos, len(co_close_comma))

    # ── Step 2b: Explanation opening/closing quotes as scaffold ──
    # Explanation content is all VALUE, but the surrounding quotes are
    # SCAFFOLD so that VALUE tokens can be exactly block-aligned.
    if "explanation" in sections:
        exp_start, exp_end = sections["explanation"]
        if exp_start < exp_end:
            scaffold_positions.append(exp_start)  # opening quote (e.g. ' "')
            scaffold_positions.append(exp_end - 1)  # closing quote+comma (e.g. '",')

    # ── Step 3: Deep scaffold within future_meta_behavior ──
    # Format: ' {"longitudinal": "keep speed", "lateral": "go straight"}'
    # fmb_prefix finds the start, fmb_between splits long/lat values, and
    # fmb_closing finds the end. Everything except value content is scaffold.
    if "future_meta_behavior" in sections:
        fmb_start, fmb_end = sections["future_meta_behavior"]
        fmb_tokens = resp_tokens[fmb_start:fmb_end]
        fmb_scaffold_positions = set()

        # 1. ' {"longitudinal": "'
        fmb_prefix = deep_scaffold_sequences["fmb_prefix"]
        prefix_pos = _find_subseq(fmb_tokens, fmb_prefix, 0)
        if prefix_pos >= 0:
            fmb_scaffold_positions.update(range(prefix_pos, prefix_pos + len(fmb_prefix)))
            long_value_start = prefix_pos + len(fmb_prefix)
            # 2. '", "lateral": "' (searched after the longitudinal value starts)
            fmb_between = deep_scaffold_sequences.get("fmb_between")
            if fmb_between:
                between_pos = _find_subseq(fmb_tokens, fmb_between, long_value_start)
                if between_pos >= 0:
                    fmb_scaffold_positions.update(
                        range(between_pos, between_pos + len(fmb_between))
                    )

        # 3. Closing '"}' or '"},' near the end of the section.
        fmb_close = deep_scaffold_sequences["fmb_closing"]
        close_pos = _find_subseq(fmb_tokens, fmb_close,
                                 max(0, len(fmb_tokens) - len(fmb_close) - 2))
        if close_pos < 0:
            fmb_close_comma = deep_scaffold_sequences.get("fmb_closing_comma")
            if fmb_close_comma:
                close_pos = _find_subseq(fmb_tokens, fmb_close_comma,
                                         max(0, len(fmb_tokens) - len(fmb_close_comma) - 2))
                if close_pos >= 0:
                    fmb_close = fmb_close_comma
        if close_pos >= 0:
            fmb_scaffold_positions.update(range(close_pos, close_pos + len(fmb_close)))

        for i in fmb_scaffold_positions:
            scaffold_positions.append(fmb_start + i)

    # ── Step 4: Deep scaffold within trajectory ──
    # e.g. ' "[[ +014.70, -000.04], [ +029.55, -000.21], ...]]"'
    if "trajectory" in sections:
        traj_start, traj_end = sections["trajectory"]
        traj_tokens = resp_tokens[traj_start:traj_end]

        # Opening "[[
        traj_open = deep_scaffold_sequences["traj_open"]
        open_pos = _find_subseq(traj_tokens, traj_open, 0)
        if open_pos >= 0:
            _mark_scaffold_range(scaffold_positions, traj_start + open_pos, len(traj_open))

        # Waypoint separators '],' (4 of them between 5 waypoints)
        traj_wp_sep = deep_scaffold_sequences["traj_wp_sep"]
        sep_search = 0
        for _ in range(4):
            sep_pos = _find_subseq(traj_tokens, traj_wp_sep, sep_search)
            if sep_pos < 0:
                break
            _mark_scaffold_range(scaffold_positions, traj_start + sep_pos, len(traj_wp_sep))
            sep_search = sep_pos + len(traj_wp_sep)

        # Intermediate waypoint openings ' [' (4 of them)
        traj_wp_open = deep_scaffold_sequences.get("traj_wp_open")
        if traj_wp_open:
            wo_search = 0
            for _ in range(4):
                wo_pos = _find_subseq(traj_tokens, traj_wp_open, wo_search)
                if wo_pos < 0:
                    break
                _mark_scaffold_range(scaffold_positions, traj_start + wo_pos, len(traj_wp_open))
                wo_search = wo_pos + len(traj_wp_open)

        # x/y comma ',' within each waypoint (5 of them)
        traj_coord_comma = deep_scaffold_sequences.get("traj_coord_comma")
        if traj_coord_comma:
            cc_search = 0
            for _ in range(5):
                cc_pos = _find_subseq(traj_tokens, traj_coord_comma, cc_search)
                if cc_pos < 0:
                    break
                _mark_scaffold_range(scaffold_positions, traj_start + cc_pos, len(traj_coord_comma))
                cc_search = cc_pos + len(traj_coord_comma)

        # Closing ]]"} , ]]" or ]] near the end of the section.
        traj_close = deep_scaffold_sequences["traj_close"]
        close_pos = _find_subseq(traj_tokens, traj_close,
                                 max(0, len(traj_tokens) - len(traj_close) - 6))
        if close_pos < 0:
            for split_key in ("traj_close_split", "traj_close_split2"):
                tcs = deep_scaffold_sequences.get(split_key)
                if tcs:
                    close_pos = _find_subseq(traj_tokens, tcs,
                                             max(0, len(traj_tokens) - len(tcs) - 6))
                    if close_pos >= 0:
                        traj_close = tcs
                        break
        if close_pos >= 0:
            _mark_scaffold_range(scaffold_positions, traj_start + close_pos, len(traj_close))
            # Align training with the inference scaffold: tokens after the
            # trajectory's JSON closing (e.g. "<|im_end|>\n") are excluded
            # from section/block scheduling.
            effective_resp_end = min(
                effective_resp_end,
                resp_start + traj_start + close_pos + len(traj_close),
            )

        # Opening quote " (first token of the trajectory value)
        if len(traj_tokens) > 0:
            scaffold_positions.append(traj_start)

    # ── Mark scaffold mask (absolute positions) ──
    scaffold_positions_set = set(scaffold_positions)
    for sp in scaffold_positions_set:
        abs_pos = resp_start + sp
        if abs_pos < seq_len:
            scaffold_mask[abs_pos] = True

    # ── Assign block indices to VALUE tokens, section by section ──
    current_block = 0
    assigned = set()
    section_first_block: Dict[str, int] = {}
    for section_name in SECTION_KEYS:
        if section_name not in sections:
            continue
        rel_start, rel_end = sections[section_name]
        abs_start = max(resp_start + rel_start, resp_start)
        abs_end = min(resp_start + rel_end, effective_resp_end)
        if abs_end - abs_start <= 0:
            continue

        # Only non-scaffold tokens count toward block sizing.
        value_positions = [p for p in range(abs_start, abs_end)
                           if response_mask[p] and (p - resp_start) not in scaffold_positions_set]
        num_value_tokens = len(value_positions)
        if num_value_tokens <= 0:
            # Section with only scaffold: still reserve one block for it.
            section_first_block[section_name] = current_block
            block_to_section[current_block] = section_name
            current_block += 1
            continue

        # Fixed block size (bd_size); the number of blocks is dynamic.
        tokens_per_step = fallback_block_size
        n_steps = max(1, math.ceil(num_value_tokens / tokens_per_step))
        for b in range(n_steps):
            block_to_section[current_block + b] = section_name
        section_first_block[section_name] = current_block
        for vi, pos in enumerate(value_positions):
            block_in_section = min(vi // tokens_per_step, n_steps - 1)
            response_block_idx[pos] = current_block + block_in_section
            assigned.add(pos)
        current_block += n_steps

    # Scaffold tokens inside each section take the block of the nearest value
    # token in the SAME section, so section-closing tokens such as `"},` stay
    # with their section instead of drifting to the next one. Bounded by
    # effective_resp_end, like every other assignment pass.
    for section_name in SECTION_KEYS:
        if section_name not in sections:
            continue
        rel_start, rel_end = sections[section_name]
        abs_start = max(resp_start + rel_start, resp_start)
        abs_end = min(resp_start + rel_end, effective_resp_end)
        if abs_end <= abs_start:
            continue
        for abs_pos in range(abs_start, abs_end):
            rel_pos = abs_pos - resp_start
            if (
                abs_pos >= seq_len
                or not response_mask[abs_pos]
                or abs_pos in assigned
                or rel_pos not in scaffold_positions_set
            ):
                continue
            best_block = -1
            max_delta = max(1, abs_end - abs_start)
            for delta in range(1, max_delta + 1):
                # Prefer left first so closing punctuation tends to stay with
                # the preceding content in the same section.
                for cand in (abs_pos - delta, abs_pos + delta):
                    if abs_start <= cand < abs_end and cand in assigned:
                        best_block = response_block_idx[cand].item()
                        break
                if best_block >= 0:
                    break
            if best_block < 0:
                best_block = section_first_block.get(section_name, -1)
            if best_block >= 0:
                response_block_idx[abs_pos] = best_block
                assigned.add(abs_pos)

    # Top-level boundary tokens are explicitly attached to the following
    # section's first block, overriding nearest-neighbour assignment.
    for section_name, rel_positions in boundary_scaffold_to_section.items():
        first_block = section_first_block.get(section_name)
        if first_block is None:
            continue
        for rel_pos in rel_positions:
            abs_pos = resp_start + rel_pos
            if abs_pos >= seq_len or not response_mask[abs_pos]:
                continue
            response_block_idx[abs_pos] = first_block
            assigned.add(abs_pos)

    # Remaining scaffold tokens -> block of the nearest assigned neighbour
    # (right side preferred here), across section boundaries.
    for sp in scaffold_positions_set:
        abs_pos = resp_start + sp
        if (
            abs_pos >= seq_len
            or abs_pos >= effective_resp_end
            or not response_mask[abs_pos]
            or abs_pos in assigned
        ):
            continue
        best_block = -1
        for delta in range(1, seq_len):
            for cand in (abs_pos + delta, abs_pos - delta):
                if 0 <= cand < seq_len and cand in assigned:
                    best_block = response_block_idx[cand].item()
                    break
            if best_block >= 0:
                break
        if best_block >= 0:
            response_block_idx[abs_pos] = best_block
            assigned.add(abs_pos)

    # Fallback for response tokens that are still unassigned.
    for pos in range(resp_start, effective_resp_end):
        if response_mask[pos] and pos not in assigned:
            offset = pos - resp_start
            response_block_idx[pos] = current_block + offset // fallback_block_size
            assigned.add(pos)
    fallback_positions = [p for p in range(resp_start, effective_resp_end)
                          if response_mask[p] and response_block_idx[p].item() >= current_block]
    if fallback_positions:
        current_block = max(response_block_idx[p].item() for p in fallback_positions) + 1
    n_blocks = current_block

    # Turn index: turn_idx[i] counts how many times the block index changed
    # between consecutive positions in [0, i] (turn_idx[0] == 0).
    block_changed = (response_block_idx[1:] != response_block_idx[:-1]).to(torch.int64)
    turn_idx[1:] = torch.cumsum(block_changed, dim=0)

    return response_block_idx, turn_idx, n_blocks, scaffold_mask, block_to_section