"""
Fine-tuning runner for Gemma-3n-E4B-it.
- Trains a LoRA adapter on top of the HF Transformers-format base model (not GGUF),
  using Hugging Face Transformers + PEFT.
- Output: a PEFT adapter that can later be merged/exported to GGUF separately if desired.

This is a minimal, production-friendly CLI so the API server can spawn it as a subprocess.
"""
import argparse
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict

logger = logging.getLogger(__name__)


def _import_training_libs() -> Dict[str, Any]:
    """Import the Hugging Face Transformers + PEFT training stack.

    An Unsloth fast path could be slotted in here later; for now this runner
    always uses the standard Transformers/PEFT code path. Imports are deferred
    to this function so that --dry-run (used in CI) never pays for them.

    Returns a dict with keys:
        load_dataset, AutoTokenizer, AutoModelForCausalLM,
        DataCollatorForLanguageModeling, get_peft_model, LoraConfig,
        Trainer, TrainingArguments, torch
    """
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )
    from peft import get_peft_model, LoraConfig
    import torch

    return {
        "load_dataset": load_dataset,
        "AutoTokenizer": AutoTokenizer,
        "AutoModelForCausalLM": AutoModelForCausalLM,
        "DataCollatorForLanguageModeling": DataCollatorForLanguageModeling,
        "get_peft_model": get_peft_model,
        "LoraConfig": LoraConfig,
        "Trainer": Trainer,
        "TrainingArguments": TrainingArguments,
        "torch": torch,
    }


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--job-id", required=True)
    p.add_argument("--output-dir", required=True)
    p.add_argument("--dataset", required=True, help="HF dataset path or local JSON/JSONL file")
    p.add_argument("--text-field", dest="text_field", default=None)
    p.add_argument("--prompt-field", dest="prompt_field", default=None)
    p.add_argument("--response-field", dest="response_field", default=None)
    p.add_argument("--model-id", dest="model_id", default="unsloth/gemma-3n-E4B-it")
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--max-steps", dest="max_steps", type=int, default=None)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--batch-size", dest="batch_size", type=int, default=1)
    p.add_argument("--gradient-accumulation", dest="gradient_accumulation", type=int, default=8)
    p.add_argument("--lora-r", dest="lora_r", type=int, default=16)
    p.add_argument("--lora-alpha", dest="lora_alpha", type=int, default=32)
    p.add_argument("--cutoff-len", dest="cutoff_len", type=int, default=4096)
    p.add_argument("--use-bf16", dest="use_bf16", action="store_true")
    p.add_argument("--use-fp16", dest="use_fp16", action="store_true")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dry-run", dest="dry_run", action="store_true", help="Write DONE and exit without training (for CI)")
    # The next two flags are accepted for forward compatibility but are not yet
    # wired into the HF training path below.
    p.add_argument("--grpo", dest="use_grpo", action="store_true", help="Enable GRPO (not yet implemented in this runner)")
    p.add_argument("--cpt", dest="use_cpt", action="store_true", help="Enable CPT (not yet implemented in this runner)")
    p.add_argument("--export-gguf", dest="export_gguf", action="store_true", help="Print GGUF export instructions after training (conversion itself is done with llama.cpp)")
    p.add_argument("--gguf-out", dest="gguf_out", default=None, help="Path for the GGUF file (used in the printed instructions)")
    return p.parse_args()


def _is_local_path(s: str) -> bool:
    return os.path.exists(s)


def _load_dataset(load_dataset: Any, path: str) -> Any:
    """Load a local JSON/JSONL file or a dataset from the Hugging Face Hub."""
    if _is_local_path(path):
        # The datasets "json" builder handles both JSON and JSONL, including
        # gzip-compressed files.
        if path.endswith((".json", ".jsonl", ".jsonl.gz")):
            return load_dataset("json", data_files=path, split="train")
        raise ValueError("Unsupported local dataset format. Use JSON or JSONL.")
    return load_dataset(path, split="train")
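

# For reference, the two supported row shapes (illustrative field names and
# values, not taken from a real dataset):
#
#   with --text-field text:
#       {"text": "<pre-formatted training text>"}
#   with --prompt-field prompt --response-field response:
#       {"prompt": "What is 2+2?", "response": "4"}
#
# Prompt/response rows are rendered into Gemma's turn format by format_row()
# inside main() below.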


def main():
    args = parse_args()
    start = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "meta.json").write_text(json.dumps({
        "job_id": args.job_id,
        "model_id": args.model_id,
        "dataset": args.dataset,
        "created_at": int(start),
    }, indent=2))

    if args.dry_run:
        (out_dir / "DONE").write_text("dry_run")
        print("[train] Dry run complete. DONE written.")
        return

    # Heavy imports happen only past the dry-run gate, so CI stays fast.
    libs: Dict[str, Any] = _import_training_libs()
    load_dataset = libs["load_dataset"]
    AutoTokenizer = libs["AutoTokenizer"]
    AutoModelForCausalLM = libs["AutoModelForCausalLM"]
    DataCollatorForLanguageModeling = libs["DataCollatorForLanguageModeling"]
    get_peft_model = libs["get_peft_model"]
    LoraConfig = libs["LoraConfig"]
    Trainer = libs["Trainer"]
    TrainingArguments = libs["TrainingArguments"]
    torch = libs["torch"]

    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    print(f"[train] Loading base model: {args.model_id}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True, trust_remote_code=True)

    # Pick a dtype: honour --use-fp16/--use-bf16 on CUDA/CPU, but stay in
    # float32 on Apple MPS, where half-precision support is uneven.
    use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if use_mps:
        dtype = torch.float32
    elif args.use_fp16:
        dtype = torch.float16
    elif args.use_bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if use_mps:
        model.to("mps")
| print("[train] Attaching LoRA adapter (PEFT)") |
| lora_config = LoraConfig( |
| r=args.lora_r, |
| lora_alpha=args.lora_alpha, |
| target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], |
| lora_dropout=0.0, |
| bias="none", |
| task_type="CAUSAL_LM", |
| ) |
| model = get_peft_model(model, lora_config) |
|
|
| |
| print(f"[train] Loading dataset: {args.dataset}") |
| ds = _load_dataset(load_dataset, args.dataset) |
|
|
| |
| text_field = args.text_field |
| prompt_field = args.prompt_field |
| response_field = args.response_field |
|
|
| if text_field: |
| |
| def format_row(ex: Dict[str, Any]) -> str: |
| if text_field not in ex: |
| raise KeyError(f"Missing required text field '{text_field}' in example: {ex}") |
| return ex[text_field] |
| elif prompt_field and response_field: |
| |
| def format_row(ex: Dict[str, Any]) -> str: |
| missing = [f for f in (prompt_field, response_field) if f not in ex] |
| if missing: |
| raise KeyError(f"Missing required field(s) {missing} in example: {ex}") |
| return ( |
| f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n" |
| f"<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n" |
| ) |
| else: |
| raise ValueError("Provide either --text-field or both --prompt-field and --response-field") |
|
|
| def map_fn(ex: Dict[str, Any]) -> Dict[str, str]: |
| return {"text": format_row(ex)} |
|
|
| ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"]) |

    # Tokenise without padding; the collator below pads each batch dynamically,
    # which is cheaper than padding everything to max_length.
    def tokenize_fn(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=args.cutoff_len,
        )

    tokenized_ds = ds.map(tokenize_fn, batched=True, remove_columns=["text"])

    training_args = TrainingArguments(
        output_dir=str(out_dir / "hf"),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps if args.max_steps else -1,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=args.use_bf16,
        fp16=args.use_fp16,
        seed=args.seed,
        report_to=[],
    )
    # mlm=False makes the collator pad each batch dynamically and copy
    # input_ids into labels (padding positions set to -100), giving the
    # Trainer the labels it needs to compute a causal-LM loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
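    # Design note: with this collator both prompt and response tokens contribute
    # to the loss. Completion-only training (masking prompt tokens to -100 in the
    # labels) would be a possible refinement; this minimal runner keeps the
    # simpler full-sequence objective.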

    print("[train] Starting training...")
    trainer.train()
    print("[train] Saving adapter...")
    adapter_path = out_dir / "adapter"
    adapter_path.mkdir(parents=True, exist_ok=True)
    try:
        model.save_pretrained(str(adapter_path))
    except Exception:
        # Log and re-raise: a job whose adapter failed to save must not report
        # success (DONE is only written further down on the happy path).
        logger.error("Error while saving the adapter", exc_info=True)
        raise
    tokenizer.save_pretrained(str(adapter_path))
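
    # Hedged sketch (not executed): loading the saved adapter later for
    # inference, assuming the standard PEFT API:
    #
    #   from peft import PeftModel
    #   base = AutoModelForCausalLM.from_pretrained(args.model_id)
    #   model = PeftModel.from_pretrained(base, str(adapter_path))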

    if args.export_gguf:
        # This HF-only runner cannot export GGUF itself, and llama.cpp has no
        # q4_k_xl outtype; Q4_K_M via llama-quantize is the closest standard
        # equivalent. The adapter must also be merged into the base model
        # before conversion (see the sketch below).
        gguf_path = args.gguf_out or str(out_dir / "model-q4_k_m.gguf")
        print("[train] GGUF export is not supported in Hugging Face-only mode.")
        print("[train] Merge the adapter into the base model, then run llama.cpp:")
        print("  python convert_hf_to_gguf.py <merged_model_dir> --outtype f16 --outfile model-f16.gguf")
        print(f"  llama-quantize model-f16.gguf {gguf_path} Q4_K_M")

    (out_dir / "DONE").write_text("ok")
    elapsed = time.time() - start
    print(f"[train] Finished in {elapsed:.1f}s. Artifacts at: {out_dir}")


if __name__ == "__main__":
    main()