"""
Section-aware block scheduling for JSON structured output.

Inspired by the S3 (Self-adaptive Schema Scaffolding) paper (arXiv:2507.04504),
this module provides utilities to:

1. Parse tokenized JSON output into sections (critical_objects, explanation, etc.)
2. Assign section-aware block indices for variable block sizes per section
3. Build JSON scaffolds for inference (pre-fill structural tokens)

The DVLM-AD output schema has 4 sections:

- critical_objects: ~88 tokens (12 yes/no fields, nearly constant)
- explanation: ~114 tokens (variable, 72-172)
- future_meta_behavior: ~40 tokens (nearly constant)
- trajectory: ~80 tokens (nearly constant)
"""

import math
from typing import Dict, List, Optional, Tuple

import torch

# Ordered list of section keys as they appear in the JSON output
SECTION_KEYS = [
    "critical_objects",
    "explanation",
    "future_meta_behavior",
    "trajectory",
]

# Default token budgets per section (based on training data analysis)
DEFAULT_TOKEN_BUDGETS = {
    "critical_objects": 88,
    "explanation": 128,
    "future_meta_behavior": 40,
    "trajectory": 80,
}

# Default steps per section
DEFAULT_SECTION_STEPS = {
    "critical_objects": 1,
    "explanation": 3,
    "future_meta_behavior": 1,
    "trajectory": 1,
}

# (boundary pattern key in the "top" sequences, section that follows it),
# in the order the sections appear in the serialized JSON.
_TOP_LEVEL_BOUNDARIES = (
    ("prefix", "critical_objects"),
    ("between_co_exp", "explanation"),
    ("between_exp_fmb", "future_meta_behavior"),
    ("between_fmb_traj", "trajectory"),
)


def _v1_removed_parse_json_sections(*args, **kwargs):
    """Stub for the removed v1 API; always raises ``NotImplementedError``."""
    raise NotImplementedError("DS v1 parse_json_sections has been removed. Use deep scaffold v2.")


def _v1_removed_compute_section_block_idx(*args, **kwargs):
    """Stub for the removed v1 API; always raises ``NotImplementedError``."""
    raise NotImplementedError("DS v1 compute_section_block_idx has been removed. Use compute_section_block_idx_deep_static.")


def _v1_removed_build_json_scaffold(*args, **kwargs):
    """Stub for the removed v1 API; always raises ``NotImplementedError``."""
    raise NotImplementedError("DS v1 build_json_scaffold has been removed. Use build_deep_json_scaffold.")


def _v1_removed_compute_section_block_sizes(*args, **kwargs):
    """Stub for the removed v1 API; always raises ``NotImplementedError``."""
    raise NotImplementedError("DS v1 compute_section_block_sizes has been removed.")


def build_static_scaffold_sequences(tokenizer) -> Dict[str, List[int]]:
    """Pre-compute token sequences for top-level JSON boundary matching.

    Used internally by :func:`build_deep_scaffold_sequences`.

    Args:
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``
            that returns a list of token ids.

    Returns:
        Mapping from boundary key (``prefix``, ``between_co_exp``,
        ``between_exp_fmb``, ``between_fmb_traj``) to its token-id sequence.
    """
    return {
        "prefix": tokenizer.encode('{"critical_objects":', add_special_tokens=False),
        "between_co_exp": tokenizer.encode(' "explanation":', add_special_tokens=False),
        "between_exp_fmb": tokenizer.encode(' "future_meta_behavior":', add_special_tokens=False),
        "between_fmb_traj": tokenizer.encode(' "trajectory":', add_special_tokens=False),
    }


def _v1_removed_compute_section_block_idx_static(*args, **kwargs):
    """Stub for the removed v1 API; always raises ``NotImplementedError``."""
    raise NotImplementedError("DS v1 compute_section_block_idx_static has been removed. Use compute_section_block_idx_deep_static.")


# Backward-compatible aliases so stale imports produce clear errors
parse_json_sections = _v1_removed_parse_json_sections
compute_section_block_idx = _v1_removed_compute_section_block_idx
build_json_scaffold = _v1_removed_build_json_scaffold
compute_section_block_sizes = _v1_removed_compute_section_block_sizes
compute_section_block_idx_static = _v1_removed_compute_section_block_idx_static


# ---------------------------------------------------------------
# Deep scaffold v2: constants and utilities
# ---------------------------------------------------------------

NULL_TOKEN_ID = 151666

# critical_objects: 12 sub-keys, each value is exactly 1 token (yes=9693 / no=2152)
CRITICAL_OBJECTS_SUBKEYS = [
    "nearby_vehicle",
    "pedestrian",
    "cyclist",
    "construction",
    "traffic_element",
    "weather_condition",
    "road_hazard",
    "emergency_vehicle",
    "animal",
    "special_vehicle",
    "conflicting_vehicle",
    "door_opening_vehicle",
]

# future_meta_behavior: each sub-key value is exactly 3 tokens
# (e.g., "keep speed" -> [4867, 4732, 151667] or "go straight" -> [2849, 7833, 151667])
FMB_VALUE_BUDGET = 3


def build_deep_json_scaffold(
    tokenizer,
    section_token_budgets: Optional[Dict[str, int]] = None,
    mask_id: Optional[int] = None,
    null_id: Optional[int] = None,
    explanation_block_size: int = 32,
    explanation_max_blocks: int = 6,
) -> Tuple[List[int], Dict[str, Tuple[int, int]], List[int]]:
    """Build a deep JSON scaffold for inference (v2).

    A template response is built as a Python dict with placeholder values,
    serialized with ``json.dumps(obj, ensure_ascii=False)``, tokenized as one
    piece, and passed through :func:`compute_section_block_idx_deep_static`
    to classify every token as scaffold or value. Value tokens are then
    replaced by ``mask_id``.

    The explanation placeholder is padded with ``<|NULL|>`` so that its value
    length approaches ``explanation_block_size * explanation_max_blocks``
    tokens (one measurement plus a single corrective adjustment). The
    trajectory placeholder is rewritten to the ``+XXX.XX`` format with a
    space after ``[`` and ``,``. Per the original notes this mirrors the
    training dataloader so that tokenization matches training data.

    Args:
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``.
        section_token_budgets: accepted for backward compatibility; not used.
        mask_id: token id written at value positions. When ``None`` it is
            taken from ``tokenizer.encode("||")`` if that yields exactly one
            token, otherwise ``151665``.
        null_id: accepted for backward compatibility; not used.
        explanation_block_size: block size used for the explanation budget.
        explanation_max_blocks: number of explanation blocks in the budget.

    Returns:
        scaffold_tokens: token ids with ``mask_id`` at value positions.
        section_ranges: section name -> ``(start, end)`` within
            ``scaffold_tokens`` (value span between top-level boundaries).
        scaffold_mask: per token, ``0`` = value (to denoise),
            ``1`` = scaffold (frozen).
    """
    import json as _json
    import re as _re

    if mask_id is None:
        mask_tok = tokenizer.encode("||", add_special_tokens=False)
        mask_id = mask_tok[0] if len(mask_tok) == 1 else 151665

    exp_budget = explanation_block_size * explanation_max_blocks  # default 192

    # Placeholder explanation text (every value token becomes MASK anyway).
    filler_explanation = (
        "The ego vehicle is driving forward on the road. "
        "There are nearby vehicles ahead that may affect the path. "
        "No pedestrians or cyclists are detected in the immediate area. "
        "The road conditions appear normal with no hazards present. "
        "Speed adjustment may be needed based on the traffic ahead. "
        "No lateral maneuvering is required at this time."
    )

    def _fmt_coord(m):
        # Keep the sign, zero-pad the magnitude to the form XXX.XX.
        raw = m.group(0)
        sign = raw[0]
        num = float(raw[1:])
        return f"{sign}{num:06.2f}"

    def _build_template(n_exp_nulls: int) -> str:
        """Serialize the placeholder response with *n_exp_nulls* NULL pads."""
        null_pad = "<|NULL|>" * n_exp_nulls
        data_obj = {
            "critical_objects": {
                "nearby_vehicle": "no",
                "pedestrian": "no",
                "cyclist": "no",
                "construction": "no",
                "traffic_element": "no",
                "weather_condition": "no",
                "road_hazard": "no",
                "emergency_vehicle": "no",
                "animal": "no",
                "special_vehicle": "no",
                "conflicting_vehicle": "no",
                "door_opening_vehicle": "no",
            },
            "explanation": filler_explanation + null_pad,
            "future_meta_behavior": {
                "longitudinal": "come to stop",
                "lateral": "go straight<|NULL|>",
            },
            # Raw trajectory; normalized below.
            "trajectory": "[[+14.70,-00.04], [+29.55,-00.21], [+44.51,-00.56], [+59.50,-01.06], [+74.39,-01.69]]",
        }

        # Trajectory normalization: fixed-width coordinates, then a space
        # after ',' and '[' whenever a sign follows.
        traj = data_obj["trajectory"]
        traj = _re.sub(r'[+-]\d+\.\d+', _fmt_coord, traj)
        traj = _re.sub(r',([+-])', r', \1', traj)
        traj = _re.sub(r'\[([+-])', r'[ \1', traj)
        data_obj["trajectory"] = traj

        return _json.dumps(data_obj, ensure_ascii=False)

    deep_seqs = build_deep_scaffold_sequences(tokenizer)
    top_seqs = deep_seqs["top"]

    def _count_exp_value_tokens(tok_list):
        """Count explanation value tokens, or return None if not locatable."""
        co_exp_pat = top_seqs["between_co_exp"]
        exp_fmb_pat = top_seqs["between_exp_fmb"]
        co_exp_pos = _find_subseq(tok_list, co_exp_pat, 0)
        if co_exp_pos < 0:
            return None
        exp_start = co_exp_pos + len(co_exp_pat)
        exp_fmb_pos = _find_subseq(tok_list, exp_fmb_pat, exp_start)
        if exp_fmb_pos < 0:
            return None
        # The span includes the opening and closing quote tokens (scaffold).
        return (exp_fmb_pos - exp_start) - 2

    # Measure the explanation without padding, then pad towards the budget.
    base_exp = _count_exp_value_tokens(
        tokenizer.encode(_build_template(0), add_special_tokens=False)
    )
    if base_exp is not None:
        needed_nulls = max(0, exp_budget - base_exp)
    else:
        needed_nulls = exp_budget // 2  # fallback

    template_tokens = tokenizer.encode(_build_template(needed_nulls), add_special_tokens=False)
    actual_exp = _count_exp_value_tokens(template_tokens)
    if actual_exp is not None and actual_exp != exp_budget:
        # Single corrective pass.
        needed_nulls = max(0, needed_nulls + (exp_budget - actual_exp))
        template_tokens = tokenizer.encode(_build_template(needed_nulls), add_special_tokens=False)

    # Run scaffold detection behind a dummy prompt whose labels are -100.
    prompt_len = 10
    labels = torch.tensor([[-100] * prompt_len + template_tokens])
    token_ids = torch.tensor([[1] * prompt_len + template_tokens])
    _, _, _, scaffold_mask_tensor, _ = compute_section_block_idx_deep_static(
        labels,
        token_ids,
        deep_seqs,
        fallback_block_size=32,
    )

    response_flags = scaffold_mask_tensor[prompt_len:prompt_len + len(template_tokens)].tolist()
    scaffold_mask_list: List[int] = [1 if flag else 0 for flag in response_flags]
    scaffold_tokens = [
        tok if flag else mask_id
        for tok, flag in zip(template_tokens, scaffold_mask_list)
    ]

    section_ranges, _ = _locate_top_level_sections(template_tokens, top_seqs)
    return scaffold_tokens, section_ranges, scaffold_mask_list


def _find_subseq(seq: List[int], pattern: List[int], start: int = 0) -> int:
    """Find first occurrence of *pattern* in *seq* starting at *start*. Returns -1 if not found."""
    n = len(pattern)
    for i in range(start, len(seq) - n + 1):
        if seq[i:i + n] == pattern:
            return i
    return -1


def _locate_top_level_sections(
    tokens: List[int],
    top_seqs: Dict[str, List[int]],
) -> Tuple[Dict[str, Tuple[int, int]], Dict[str, Tuple[int, int]]]:
    """Scan *tokens* for the top-level boundary patterns, in schema order.

    Returns:
        sections: section name -> ``(value_start, value_end)``; a section
            ends where the next found boundary starts, the last one at
            ``len(tokens)``.
        boundaries: section name -> ``(pos, length)`` of the boundary pattern
            that introduces it.
    """
    sections: Dict[str, Tuple[int, int]] = {}
    boundaries: Dict[str, Tuple[int, int]] = {}
    search_from = 0
    prev_name: Optional[str] = None
    prev_value_start = 0
    for boundary_key, section_name in _TOP_LEVEL_BOUNDARIES:
        pattern = top_seqs.get(boundary_key)
        if pattern is None:
            continue
        pos = _find_subseq(tokens, pattern, search_from)
        if pos < 0:
            continue
        if prev_name is not None:
            sections[prev_name] = (prev_value_start, pos)
        boundaries[section_name] = (pos, len(pattern))
        prev_name = section_name
        prev_value_start = pos + len(pattern)
        search_from = prev_value_start
    if prev_name is not None:
        sections[prev_name] = (prev_value_start, len(tokens))
    return sections, boundaries


def build_deep_scaffold_sequences(tokenizer) -> Dict[str, object]:
    """Pre-compute token sequences for deep scaffold matching.

    Args:
        tokenizer: object exposing ``encode(text, add_special_tokens=False)``.

    Returns:
        Dict with:
        - ``top``: top-level boundary patterns
          (see :func:`build_static_scaffold_sequences`)
        - sub-key / punctuation patterns for critical_objects,
          future_meta_behavior and trajectory
        - ``traj_only_boundaries``: key patterns for trajectory-only output
    """

    def enc(text: str) -> List[int]:
        return tokenizer.encode(text, add_special_tokens=False)

    seqs: Dict[str, object] = {}

    seqs["top"] = build_static_scaffold_sequences(tokenizer)

    # critical_objects sub-keys. The first entry is matched as ' {"key": "',
    # every later one as '", "key": "', i.e. together with the closing quote
    # and comma of the previous value.
    co_patterns = []
    for i, key in enumerate(CRITICAL_OBJECTS_SUBKEYS):
        if i == 0:
            pattern = enc(' {"' + key + '": "')
        else:
            pattern = enc('", "' + key + '": "')
        co_patterns.append({"key": key, "pattern": pattern, "index": i})
    seqs["co_subkeys"] = co_patterns
    seqs["co_closing"] = enc('"}')
    # json.dumps produces '"},' which may tokenize differently from '"}'.
    seqs["co_closing_comma"] = enc('"},')

    # future_meta_behavior, expected form:
    #   ' {"longitudinal": "keep speed", "lateral": "go straight"}'
    # Scaffold = everything except the value content between quotes.
    seqs["fmb_prefix"] = enc(' {"longitudinal": "')
    seqs["fmb_closing"] = enc('"}')
    seqs["fmb_closing_comma"] = enc('"},')
    seqs["fmb_between"] = enc('", "lateral": "')

    # trajectory, expected form after the dataloader inserts spaces:
    #   ' "[[ +14.70, -00.04], [ +29.55, -00.21], ...]]"'
    seqs["traj_open"] = enc(' "[[')
    seqs["traj_wp_sep"] = enc('],')
    seqs["traj_wp_open"] = enc(' [')
    seqs["traj_coord_comma"] = enc(',')
    seqs["traj_close"] = enc(']]"}')
    seqs["traj_close_split"] = enc(']]"')
    seqs["traj_close_split2"] = enc(']]')

    # Trajectory-only output support, e.g. {"trajectory": "..."}.
    seqs["traj_only_boundaries"] = [
        enc('{"trajectory":'),
        enc(' {"trajectory":'),
        enc('"trajectory":'),
        enc(' "trajectory":'),
    ]
    return seqs


def _mark_scaffold_range(scaffold_positions: List[int], start: int, length: int):
    """Add positions [start, start+length) to scaffold_positions."""
    scaffold_positions.extend(range(start, start + length))


def _find_tail_pattern(
    tokens: List[int],
    candidates: List[Optional[List[int]]],
    slack: int,
) -> Tuple[int, Optional[List[int]]]:
    """Find the first candidate that occurs near the end of *tokens*.

    Each candidate is searched from ``len(tokens) - len(candidate) - slack``
    (clamped at 0). Missing or empty candidates are skipped.

    Returns:
        ``(pos, pattern)`` of the first hit, or ``(-1, None)``.
    """
    for pattern in candidates:
        if not pattern:
            continue
        pos = _find_subseq(tokens, pattern, max(0, len(tokens) - len(pattern) - slack))
        if pos >= 0:
            return pos, pattern
    return -1, None


def _repeated_pattern_positions(
    tokens: List[int],
    pattern: Optional[List[int]],
    max_hits: int,
) -> List[int]:
    """Positions covered by the first *max_hits* non-overlapping hits of *pattern*."""
    covered: List[int] = []
    if not pattern:
        return covered
    cursor = 0
    for _ in range(max_hits):
        hit = _find_subseq(tokens, pattern, cursor)
        if hit < 0:
            break
        _mark_scaffold_range(covered, hit, len(pattern))
        cursor = hit + len(pattern)
    return covered


def _critical_objects_scaffold(co_tokens: List[int], seqs: Dict[str, object]) -> List[int]:
    """Section-relative scaffold positions inside critical_objects."""
    covered: List[int] = []
    cursor = 0
    for entry in seqs["co_subkeys"]:
        pattern = entry["pattern"]
        pos = _find_subseq(co_tokens, pattern, cursor)
        if pos < 0:
            continue
        _mark_scaffold_range(covered, pos, len(pattern))
        # One value token follows the key pattern; resume after it.
        cursor = pos + len(pattern) + 1

    close_pos, close_pat = _find_tail_pattern(
        co_tokens, [seqs["co_closing"], seqs.get("co_closing_comma")], 2
    )
    if close_pos >= 0:
        _mark_scaffold_range(covered, close_pos, len(close_pat))
    return covered


def _fmb_scaffold(fmb_tokens: List[int], seqs: Dict[str, object]) -> List[int]:
    """Section-relative scaffold positions inside future_meta_behavior."""
    covered: List[int] = []

    # ' {"longitudinal": "'
    fmb_prefix = seqs["fmb_prefix"]
    prefix_pos = _find_subseq(fmb_tokens, fmb_prefix, 0)
    # Defaulting to 0 keeps the search below well-defined when the prefix is
    # absent.
    long_value_start = 0
    if prefix_pos >= 0:
        _mark_scaffold_range(covered, prefix_pos, len(fmb_prefix))
        long_value_start = prefix_pos + len(fmb_prefix)

    # '", "lateral": "'
    fmb_between = seqs.get("fmb_between")
    if fmb_between:
        between_pos = _find_subseq(fmb_tokens, fmb_between, long_value_start)
        if between_pos >= 0:
            _mark_scaffold_range(covered, between_pos, len(fmb_between))

    # Closing '"}' or '"},'
    close_pos, close_pat = _find_tail_pattern(
        fmb_tokens, [seqs["fmb_closing"], seqs.get("fmb_closing_comma")], 2
    )
    if close_pos >= 0:
        _mark_scaffold_range(covered, close_pos, len(close_pat))
    return covered


def _trajectory_scaffold(
    traj_tokens: List[int],
    seqs: Dict[str, object],
) -> Tuple[List[int], Optional[int]]:
    """Section-relative scaffold positions inside trajectory.

    Returns:
        ``(positions, close_end)`` where ``close_end`` is the section-relative
        index just past the closing pattern, or ``None`` if no closing
        pattern was found.
    """
    covered: List[int] = []

    traj_open = seqs["traj_open"]
    open_pos = _find_subseq(traj_tokens, traj_open, 0)
    if open_pos >= 0:
        _mark_scaffold_range(covered, open_pos, len(traj_open))

    # 5 waypoints: 4 '],' separators, 4 ' [' openers, 5 coordinate commas.
    covered.extend(_repeated_pattern_positions(traj_tokens, seqs["traj_wp_sep"], 4))
    covered.extend(_repeated_pattern_positions(traj_tokens, seqs.get("traj_wp_open"), 4))
    covered.extend(_repeated_pattern_positions(traj_tokens, seqs.get("traj_coord_comma"), 5))

    close_end: Optional[int] = None
    close_pos, close_pat = _find_tail_pattern(
        traj_tokens,
        [seqs["traj_close"], seqs.get("traj_close_split"), seqs.get("traj_close_split2")],
        6,
    )
    if close_pos >= 0:
        _mark_scaffold_range(covered, close_pos, len(close_pat))
        close_end = close_pos + len(close_pat)

    # First token of the trajectory value is always scaffold.
    if traj_tokens:
        covered.append(0)
    return covered, close_end


def _nearest_assigned_block(
    pos: int,
    assigned: Dict[int, int],
    lo: int,
    hi: int,
    max_delta: int,
    prefer_left: bool,
) -> int:
    """Block id of the closest assigned position within ``[lo, hi)``, or -1."""
    for delta in range(1, max_delta + 1):
        if prefer_left:
            candidates = (pos - delta, pos + delta)
        else:
            candidates = (pos + delta, pos - delta)
        for cand in candidates:
            if lo <= cand < hi and cand in assigned:
                return assigned[cand]
    return -1


def compute_section_block_idx_deep_static(
    labels: torch.Tensor,
    token_ids: torch.Tensor,
    deep_scaffold_sequences: Dict[str, object],
    fallback_block_size: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, Dict[int, str]]:
    """Deep-scaffold v2 block index computation.

    Freezes sub-keys within sections:

    - critical_objects: only the value tokens after each sub-key are denoised
    - future_meta_behavior: only value tokens are denoised
    - trajectory: structural punctuation is frozen, the rest is denoised
    - explanation: everything except the first and last token is denoised

    Block count per section is computed dynamically:
    ``n_blocks = ceil(num_value_tokens / fallback_block_size)``; a section
    without value tokens still occupies one block.

    Only row 0 of ``labels`` / ``token_ids`` is read. The response span is
    the range from the first to the last position whose label is not -100.

    Args:
        labels: ``[B, seq_len]`` label ids, -100 outside the response.
        token_ids: ``[B, seq_len]`` token ids.
        deep_scaffold_sequences: output of :func:`build_deep_scaffold_sequences`.
        fallback_block_size: block size (bd_size), default 32.

    Returns:
        A 5-tuple ``(response_block_idx, turn_idx, n_blocks, scaffold_mask,
        block_to_section)``:

        - response_block_idx: ``[seq_len]`` int64, block id per token,
          -1 where unassigned
        - turn_idx: ``[seq_len]`` int64, incremented at every position whose
          block id differs from the previous position
        - n_blocks: number of blocks
        - scaffold_mask: ``[seq_len]`` bool, True at frozen scaffold tokens
        - block_to_section: block id -> section name

        The same 5-tuple shape is returned when there is no response token
        (``n_blocks == 0`` and an empty ``block_to_section``).

    Raises:
        ValueError: if ``fallback_block_size`` is not positive.
    """
    if fallback_block_size <= 0:
        raise ValueError(f"fallback_block_size must be positive, got {fallback_block_size}")

    labels_single = labels[0]
    token_list = token_ids[0].tolist()
    seq_len = labels_single.shape[0]
    device = labels.device

    response_block_idx = torch.full((seq_len,), -1, device=device, dtype=torch.int64)
    turn_idx = torch.zeros((seq_len,), device=device, dtype=torch.int64)
    scaffold_mask = torch.zeros((seq_len,), device=device, dtype=torch.bool)
    block_to_section: Dict[int, str] = {}  # block_idx -> section_name

    # Plain Python list: avoids per-element tensor indexing in the loops below.
    is_response: List[bool] = (labels_single != -100).tolist()
    if True not in is_response:
        return response_block_idx, turn_idx, 0, scaffold_mask, block_to_section

    resp_start = is_response.index(True)
    resp_end = seq_len - is_response[::-1].index(True)
    effective_resp_end = resp_end
    resp_tokens = token_list[resp_start:resp_end]

    top_seqs = deep_scaffold_sequences["top"]

    # ---- Step 1: top-level section boundaries ----
    sections, boundaries = _locate_top_level_sections(resp_tokens, top_seqs)
    scaffold_rel: List[int] = []  # response-relative scaffold positions
    # Boundary tokens belong to the first block of the section they introduce
    # (e.g. `"explanation":` -> explanation block 0).
    boundary_scaffold_to_section: Dict[str, List[int]] = {}
    for section_name, (pos, length) in boundaries.items():
        _mark_scaffold_range(scaffold_rel, pos, length)
        boundary_scaffold_to_section.setdefault(section_name, []).extend(
            range(pos, pos + length)
        )

    # Trajectory-only responses: match the trajectory key directly, with the
    # regular ' "trajectory":' boundary as an additional fallback.
    if "trajectory" not in sections:
        traj_only_patterns = list(deep_scaffold_sequences.get("traj_only_boundaries", []))
        between_fmb_traj = top_seqs.get("between_fmb_traj")
        if between_fmb_traj:
            traj_only_patterns.append(between_fmb_traj)
        for pat in traj_only_patterns:
            if not pat:
                continue
            pos = _find_subseq(resp_tokens, pat, 0)
            if pos < 0:
                continue
            _mark_scaffold_range(scaffold_rel, pos, len(pat))
            boundary_scaffold_to_section.setdefault("trajectory", []).extend(
                range(pos, pos + len(pat))
            )
            sections["trajectory"] = (pos + len(pat), len(resp_tokens))
            break

    # ---- Step 2: deep scaffold inside each section ----
    if "critical_objects" in sections:
        co_start, co_end = sections["critical_objects"]
        scaffold_rel.extend(
            co_start + p
            for p in _critical_objects_scaffold(
                resp_tokens[co_start:co_end], deep_scaffold_sequences
            )
        )

    if "explanation" in sections:
        exp_start, exp_end = sections["explanation"]
        if exp_start < exp_end:
            # First token (opening quote) and last token (closing quote/comma)
            # are scaffold so that the value tokens can be block-aligned.
            scaffold_rel.append(exp_start)
            scaffold_rel.append(exp_end - 1)

    if "future_meta_behavior" in sections:
        fmb_start, fmb_end = sections["future_meta_behavior"]
        scaffold_rel.extend(
            fmb_start + p
            for p in _fmb_scaffold(resp_tokens[fmb_start:fmb_end], deep_scaffold_sequences)
        )

    if "trajectory" in sections:
        traj_start, traj_end = sections["trajectory"]
        traj_positions, close_end = _trajectory_scaffold(
            resp_tokens[traj_start:traj_end], deep_scaffold_sequences
        )
        scaffold_rel.extend(traj_start + p for p in traj_positions)
        if close_end is not None:
            # Exclude tokens after the JSON closing of the trajectory
            # (e.g. an end-of-turn marker) from section/block scheduling.
            effective_resp_end = min(
                effective_resp_end, resp_start + traj_start + close_end
            )

    # ---- Scaffold mask (absolute positions) ----
    scaffold_positions_set = set(scaffold_rel)
    scaffold_abs = [
        resp_start + sp for sp in scaffold_positions_set if resp_start + sp < seq_len
    ]
    if scaffold_abs:
        scaffold_mask[torch.tensor(scaffold_abs, device=device, dtype=torch.int64)] = True

    # ---- Block indices for value tokens, section by section ----
    current_block = 0
    assigned: Dict[int, int] = {}  # absolute position -> block id
    section_first_block: Dict[str, int] = {}

    for section_name in SECTION_KEYS:
        if section_name not in sections:
            continue
        rel_start, rel_end = sections[section_name]
        abs_start = max(resp_start + rel_start, resp_start)
        abs_end = min(resp_start + rel_end, effective_resp_end)
        if abs_end - abs_start <= 0:
            continue

        # Only non-scaffold tokens count towards block sizing.
        value_positions = [
            p
            for p in range(abs_start, abs_end)
            if is_response[p] and (p - resp_start) not in scaffold_positions_set
        ]
        section_first_block[section_name] = current_block
        if not value_positions:
            block_to_section[current_block] = section_name
            current_block += 1
            continue

        n_steps = max(1, math.ceil(len(value_positions) / fallback_block_size))
        for b in range(n_steps):
            block_to_section[current_block + b] = section_name
        for vi, pos in enumerate(value_positions):
            assigned[pos] = current_block + min(vi // fallback_block_size, n_steps - 1)
        current_block += n_steps

    # ---- Scaffold tokens inside a section: nearest value in the SAME section ----
    # Left is preferred so closing punctuation stays with the preceding
    # content instead of drifting to the next section.
    for section_name in SECTION_KEYS:
        if section_name not in sections:
            continue
        rel_start, rel_end = sections[section_name]
        abs_start = max(resp_start + rel_start, resp_start)
        abs_end = min(resp_start + rel_end, resp_end)
        if abs_end <= abs_start:
            continue
        max_delta = max(1, abs_end - abs_start)
        for abs_pos in range(abs_start, abs_end):
            if (
                abs_pos >= seq_len
                or not is_response[abs_pos]
                or abs_pos in assigned
                or (abs_pos - resp_start) not in scaffold_positions_set
            ):
                continue
            best_block = _nearest_assigned_block(
                abs_pos, assigned, abs_start, abs_end, max_delta, prefer_left=True
            )
            if best_block < 0:
                best_block = section_first_block.get(section_name, -1)
            if best_block >= 0:
                assigned[abs_pos] = best_block

    # ---- Top-level boundary tokens: first block of the following section ----
    for section_name, rel_positions in boundary_scaffold_to_section.items():
        first_block = section_first_block.get(section_name)
        if first_block is None:
            continue
        for rel_pos in rel_positions:
            abs_pos = resp_start + rel_pos
            if abs_pos >= seq_len or not is_response[abs_pos]:
                continue
            assigned[abs_pos] = first_block

    # ---- Remaining scaffold tokens: nearest assigned neighbour anywhere ----
    # Sorted so the outcome does not depend on set iteration order.
    for sp in sorted(scaffold_positions_set):
        abs_pos = resp_start + sp
        if (
            abs_pos >= seq_len
            or abs_pos >= effective_resp_end
            or not is_response[abs_pos]
            or abs_pos in assigned
        ):
            continue
        best_block = _nearest_assigned_block(
            abs_pos, assigned, 0, seq_len, seq_len - 1, prefer_left=False
        )
        if best_block >= 0:
            assigned[abs_pos] = best_block

    # ---- Fallback: fixed-size blocks appended after the section blocks ----
    fallback_max = -1
    for pos in range(resp_start, effective_resp_end):
        if is_response[pos] and pos not in assigned:
            block = current_block + (pos - resp_start) // fallback_block_size
            assigned[pos] = block
            fallback_max = max(fallback_max, block)
    if fallback_max >= 0:
        current_block = fallback_max + 1
    n_blocks = current_block

    # Single bulk write of all assignments.
    if assigned:
        pos_t = torch.tensor(list(assigned.keys()), device=device, dtype=torch.int64)
        blk_t = torch.tensor(list(assigned.values()), device=device, dtype=torch.int64)
        response_block_idx[pos_t] = blk_t

    # Turn index: running count of block-id changes along the sequence.
    if seq_len > 1:
        changed = response_block_idx[1:] != response_block_idx[:-1]
        turn_idx[1:] = torch.cumsum(changed.to(torch.int64), dim=0)

    return response_block_idx, turn_idx, n_blocks, scaffold_mask, block_to_section