| """Prepare MLX-LM-compatible train/valid files from existing SFT data.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Dict, List | |
| from transformers import AutoTokenizer | |
| ROOT = Path(__file__).resolve().parents[1] | |
| TRUNCATION_MARKER = "\n...[truncated observation]...\n" | |


def load_jsonl(path: Path) -> List[Dict[str, object]]:
    """Read a JSONL file, skipping blank lines."""
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def dump_jsonl(path: Path, rows: List[Dict[str, object]]) -> None:
    """Write rows as JSONL, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, sort_keys=True) + "\n")


def trim_prompt_to_budget(prompt: str, tokenizer, budget: int) -> str:
    """Middle-truncate a prompt to at most `budget` tokens, keeping head and tail."""
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if len(prompt_ids) <= budget:
        return prompt
    marker_ids = tokenizer.encode(TRUNCATION_MARKER, add_special_tokens=False)
    marker_len = len(marker_ids)
    if budget <= marker_len + 8:
        # Budget too tight to fit the marker plus useful context: keep only the tail.
        return tokenizer.decode(prompt_ids[-budget:])
    # Split the remaining budget roughly 55/45 between head and tail,
    # splicing the truncation marker in between.
    remaining = budget - marker_len
    head_len = max(1, int(remaining * 0.55))
    tail_len = max(1, remaining - head_len)
    trimmed_ids = prompt_ids[:head_len] + marker_ids + prompt_ids[-tail_len:]
    if len(trimmed_ids) > budget:
        trimmed_ids = trimmed_ids[:budget]
    return tokenizer.decode(trimmed_ids, skip_special_tokens=False)
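

# A quick worked example of the budget math above (the marker's token count
# depends on the tokenizer; 10 is illustrative): with budget=100 and a
# 10-token marker, remaining = 90, head_len = int(90 * 0.55) = 49, and
# tail_len = 41, so the result is 49 head tokens + 10 marker tokens +
# 41 tail tokens = exactly 100 tokens.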


def rendered_length(prompt: str, completion: str, tokenizer) -> int:
    """Token count of the example once rendered through the chat template."""
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": completion},
    ]
    # apply_chat_template tokenizes by default, so len() counts tokens,
    # including the template's own special tokens.
    return len(tokenizer.apply_chat_template(messages, return_dict=False))


def normalize_record(
    record: Dict[str, object], tokenizer, max_seq_length: int
) -> tuple[Dict[str, object] | None, Dict[str, int]]:
    """Trim a record's prompt so the rendered example fits in max_seq_length.

    Returns the normalized record (or None if it must be dropped) along with
    per-record trim/drop stats.
    """
    prompt = str(record["prompt"])
    completion = str(record["completion"])
    stats = {"trimmed": 0, "dropped": 0}
    completion_ids = tokenizer.encode(completion, add_special_tokens=False)
    # Reserve 32 tokens of headroom for the chat template's special tokens.
    prompt_budget = max_seq_length - len(completion_ids) - 32
    if prompt_budget <= 0:
        stats["dropped"] = 1
        return None, stats
    normalized_prompt = trim_prompt_to_budget(prompt, tokenizer, prompt_budget)
    # Re-encoding a decoded prompt can shift token counts, so shrink the
    # budget 10% at a time until the rendered example fits or the budget
    # bottoms out at 64 tokens.
    while (
        rendered_length(normalized_prompt, completion, tokenizer) > max_seq_length
        and prompt_budget > 64
    ):
        prompt_budget = max(64, int(prompt_budget * 0.9))
        normalized_prompt = trim_prompt_to_budget(prompt, tokenizer, prompt_budget)
    if rendered_length(normalized_prompt, completion, tokenizer) > max_seq_length:
        stats["dropped"] = 1
        return None, stats
    if normalized_prompt != prompt:
        stats["trimmed"] = 1
    text = f"{normalized_prompt}\n{completion}"
    normalized = dict(record)
    normalized["prompt"] = normalized_prompt
    normalized["text"] = text
    return normalized, stats
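

# The record shape, for reference (field values are illustrative; only the
# "prompt" and "completion" keys are required by normalize_record):
#   input:  {"prompt": "<rendered conversation>", "completion": "<target>"}
#   output: the same record with "prompt" possibly trimmed and a combined
#           "text" field ("{prompt}\n{completion}") added for MLX-LM.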


def transform_split(src: Path, dst: Path, tokenizer, max_seq_length: int) -> Dict[str, int]:
    """Normalize one split file, write it to dst, and return counters."""
    rows = load_jsonl(src)
    normalized_rows: List[Dict[str, object]] = []
    stats = {
        "input_examples": len(rows),
        "written_examples": 0,
        "trimmed_examples": 0,
        "dropped_examples": 0,
    }
    for row in rows:
        normalized, row_stats = normalize_record(row, tokenizer, max_seq_length)
        stats["trimmed_examples"] += row_stats["trimmed"]
        stats["dropped_examples"] += row_stats["dropped"]
        if normalized is not None:
            normalized_rows.append(normalized)
    stats["written_examples"] = len(normalized_rows)
    dump_jsonl(dst, normalized_rows)
    return stats


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-root", default="artifacts/lora_qwen3_4b/data")
    parser.add_argument("--output-root", default="artifacts/mlx_qwen3_4b/data")
    parser.add_argument("--model", default="Qwen/Qwen3-4B")
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--include-valid", action="store_true")
    parser.add_argument("--force", action="store_true")
    args = parser.parse_args()

    source_root = (ROOT / args.source_root).resolve()
    output_root = (ROOT / args.output_root).resolve()
    output_root.mkdir(parents=True, exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    # MLX-LM expects train.jsonl (and optionally valid.jsonl) in one directory.
    mapping = {source_root / "train.jsonl": output_root / "train.jsonl"}
    if args.include_valid:
        mapping[source_root / "eval.jsonl"] = output_root / "valid.jsonl"

    summary: Dict[str, object] = {
        "model": args.model,
        "max_seq_length": args.max_seq_length,
        "splits": {},
    }
    for src, dst in mapping.items():
        if not src.exists():
            raise FileNotFoundError(f"Missing source file: {src}")
        if dst.exists() and not args.force:
            # Keep existing outputs unless --force is given.
            continue
        summary["splits"][dst.stem] = transform_split(src, dst, tokenizer, args.max_seq_length)

    # Remove a stale valid.jsonl so MLX-LM does not pick it up unintentionally.
    valid_path = output_root / "valid.jsonl"
    if not args.include_valid and valid_path.exists():
        valid_path.unlink()

    summary_path = output_root.parent / "prepare_stats.json"
    summary_path.write_text(json.dumps(summary, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    print(output_root)
    print(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()
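

# Example invocation (the script filename is illustrative; ROOT assumes this
# file lives one directory below the repo root, e.g. in scripts/):
#
#   python scripts/prepare_mlx_data.py --include-valid --force --max-seq-length 2048
#
# This reads artifacts/lora_qwen3_4b/data/{train,eval}.jsonl, writes
# artifacts/mlx_qwen3_4b/data/{train,valid}.jsonl, and records per-split
# counters in artifacts/mlx_qwen3_4b/prepare_stats.json.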