| |
r"""
Convert UltraChat-style records (prompt, messages[], prompt_id) into a single
"text" field using the chat formatting expected by training/train_gemma_unsloth.py:

    <start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n...<end_of_turn>\n(repeat)

Only the "messages" list of each record is used. Records whose "messages" value
is missing, empty, or contains an entry without string "role" and "content"
fields are skipped.

Usage:
    python scripts/convert_ultrachat_to_text.py \
        --in sample_data/train_sft.jsonl \
        --out sample_data/train_sft_text.jsonl
"""
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| from typing import List, TypedDict, Any, Dict, cast |
|
|
# Translates UltraChat message roles into the turn names used in the Gemma
# chat markup. Roles that are not listed here are rendered as "user" by
# to_chat_text.
ROLE_MAP: Dict[str, str] = dict(
    user="user",
    assistant="model",
)
|
|
|
|
# A single chat message as passed to to_chat_text. "content" is the message
# text; "role" is the speaker as it appears in the UltraChat record, such as
# "user" or "assistant".
Msg = TypedDict("Msg", {"content": str, "role": str})
|
|
|
|
def to_chat_text(messages: List[Msg]) -> str:
    """Render a conversation as Gemma-style turn markup.

    Each message becomes one turn: the "<start_of_turn>" marker, the mapped
    role, a newline, the message content, then "<end_of_turn>". Turns are
    separated by newlines, and the result always ends with one newline (an
    empty conversation therefore yields just a newline).

    Roles are translated through ROLE_MAP. A missing or unrecognized role is
    rendered as "user". Trailing whitespace is stripped from each content
    value, and missing or falsy content produces an empty turn body.
    """

    def render_turn(msg: Msg) -> str:
        speaker = ROLE_MAP.get(msg.get("role", "user"), "user")
        body = (msg.get("content", "") or "").rstrip()
        return f"<start_of_turn>{speaker}\n{body}<end_of_turn>"

    rendered = "\n".join(render_turn(msg) for msg in messages)
    return rendered + "\n"
|
|
|
|
def _extract_messages(obj: Any) -> List[Msg]:
    """Return the validated ``messages`` of one decoded record, or ``[]`` to skip it.

    A record is usable only when it is a JSON object whose ``messages`` value is
    a non-empty list, and every element of that list is an object with a string
    ``content`` and a string ``role``. Any other shape yields an empty list.
    Extra keys on a message are ignored.
    """
    if not isinstance(obj, dict):
        # A JSON line may decode to a list, string or number; such lines
        # have no "messages" key, so skip them instead of crashing on .get().
        return []
    raw: Any = cast(Dict[str, Any], obj).get("messages")
    if not isinstance(raw, list) or not raw:
        return []
    messages: List[Msg] = []
    for item in cast(List[Any], raw):
        if not isinstance(item, dict):
            return []
        fields = cast(Dict[str, Any], item)
        content = fields.get("content")
        role = fields.get("role")
        if not isinstance(content, str) or not isinstance(role, str):
            # One malformed message invalidates the whole conversation.
            return []
        messages.append({"content": content, "role": role})
    return messages


def convert(in_path: str, out_path: str) -> int:
    """Convert an UltraChat-style JSONL file into a JSONL file of {"text": ...} rows.

    Each non-blank input line is parsed as JSON. If the record has a usable
    message list (see ``_extract_messages``), it is rendered with
    ``to_chat_text`` and written as one ``{"text": ...}`` object per line.
    Every other line is skipped. A summary line
    ``Converted <written>/<read> records -> <out_path>`` is printed, where
    <read> counts the non-blank input lines.

    Args:
        in_path: Path of the UTF-8 JSONL input file.
        out_path: Path of the UTF-8 JSONL output file; it is overwritten. Its
            parent directory is created if it does not exist.

    Returns:
        Always 0, so the value can be used directly as a process exit status.

    Raises:
        OSError: If the input cannot be read or the output cannot be written.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
    """
    out_dir = os.path.dirname(out_path)
    if out_dir:
        # os.makedirs("") raises FileNotFoundError. Only create a directory
        # when out_path actually has one (a bare "out.jsonl" does not).
        os.makedirs(out_dir, exist_ok=True)
    n_in = 0
    n_out = 0
    with open(in_path, "r", encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                continue
            n_in += 1
            messages = _extract_messages(json.loads(line))
            if not messages:
                continue
            text = to_chat_text(messages)
            fout.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            n_out += 1
    print(f"Converted {n_out}/{n_in} records -> {out_path}")
    return 0
|
|
|
|
def main() -> int:
    """Parse the required --in and --out options and run convert on them.

    Returns the status value produced by convert.
    """
    parser = argparse.ArgumentParser()
    for flag, dest in (("--in", "in_path"), ("--out", "out_path")):
        parser.add_argument(flag, dest=dest, required=True)
    ns = parser.parse_args()
    return convert(ns.in_path, ns.out_path)
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|
|