sakhi / scripts /train_unsloth.py
Tushar9802's picture
HF Space deploy β€” initial
745f62a
"""
Sakhi β€” Unsloth LoRA Fine-Tuning + Auto-Evaluation
Full pipeline:
1. Prepare training data (clean schema leakage, match production prompts)
2. LoRA fine-tune Gemma 4 E4B via Unsloth
3. Export to GGUF
4. Register in Ollama as "sakhi"
5. Run quality test suite (15 tests) against both base and fine-tuned
6. Print A/B comparison
Usage:
python scripts/train_unsloth.py # Full pipeline
python scripts/train_unsloth.py --dry-run # 10 steps, skip eval
python scripts/train_unsloth.py --eval-only # Skip training, just evaluate
python scripts/train_unsloth.py --config configs/training.yaml
"""
import argparse
import json
import os
import subprocess
import sys
import time
# Disable torch.compile β€” required for PyTorch nightly + Gemma 4 on SM 12.0
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["PYTHONIOENCODING"] = "utf-8"
import yaml
def step_prepare_data():
"""Step 0: Re-prepare training data with cleaned pipeline."""
print("\n" + "=" * 60)
print("STEP 0: Preparing training data (cleaned)")
print("=" * 60)
result = subprocess.run(
[sys.executable, "scripts/prepare_training.py"],
capture_output=False, timeout=120,
)
if result.returncode != 0:
print("ABORT: Data preparation failed")
sys.exit(1)
# Verify no schema leakage
leaked = 0
with open("data/processed/train.jsonl", "r", encoding="utf-8") as f:
for line in f:
content = json.loads(line)["messages"][2]["content"]
if '"$schema"' in content:
leaked += 1
if leaked:
print(f"ABORT: {leaked} examples still have schema leakage after cleaning!")
sys.exit(1)
print("Data verification: CLEAN (no schema leakage)")
def step_train(config_path: str, dry_run: bool):
"""Step 1: LoRA fine-tuning via Unsloth."""
import torch
torch._dynamo.config.suppress_errors = True
with open(config_path, "r") as f:
config = yaml.safe_load(f)
model_name = config["model"]["base_model"]
max_seq_length = config["model"]["max_seq_length"]
load_in_4bit = config["model"]["load_in_4bit"]
lora_config = config["lora"]
train_config = config["training"]
print("\n" + "=" * 60)
print("STEP 1: LoRA Fine-Tuning")
print(f"Model: {model_name}")
print(f"LR: {train_config['learning_rate']}, Epochs: {train_config['num_train_epochs']}")
print(f"LoRA r={lora_config['r']}, alpha={lora_config['lora_alpha']}, dropout={lora_config['lora_dropout']}")
print("=" * 60)
# Gate: training data exists
train_file = config["data"]["train_file"]
val_file = config["data"]["validation_file"]
for f_path in [train_file, val_file]:
if not os.path.exists(f_path):
print(f"ABORT: {f_path} not found")
sys.exit(1)
with open(train_file, "r", encoding="utf-8") as f:
train_count = sum(1 for _ in f)
with open(val_file, "r", encoding="utf-8") as f:
val_count = sum(1 for _ in f)
print(f"Training: {train_count} examples, Validation: {val_count} examples")
# Load model
from unsloth import FastLanguageModel
print(f"\nLoading {model_name}...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
dtype=None,
)
# Apply LoRA
print("Applying LoRA adapter...")
model = FastLanguageModel.get_peft_model(
model,
r=lora_config["r"],
lora_alpha=lora_config["lora_alpha"],
lora_dropout=lora_config["lora_dropout"],
target_modules=lora_config["target_modules"],
bias=lora_config["bias"],
use_gradient_checkpointing="unsloth",
random_state=config["experiment"]["seed"],
)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({trainable/total*100:.2f}%)")
# Load dataset
from datasets import Dataset
def load_jsonl(path):
data = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
data.append(json.loads(line.strip()))
return data
train_data = load_jsonl(train_file)
val_data = load_jsonl(val_file)
def format_for_sft(examples):
formatted = []
for ex in examples:
text = tokenizer.apply_chat_template(
ex["messages"], tokenize=False, add_generation_prompt=False,
)
formatted.append({"text": text})
return formatted
train_dataset = Dataset.from_list(format_for_sft(train_data))
val_dataset = Dataset.from_list(format_for_sft(val_data))
print(f"Formatted {len(train_dataset)} train / {len(val_dataset)} val")
# Training
from trl import SFTTrainer
from transformers import TrainingArguments
max_steps = 10 if dry_run else -1
num_epochs = 1 if dry_run else train_config["num_train_epochs"]
training_args = TrainingArguments(
output_dir=train_config["output_dir"],
per_device_train_batch_size=train_config["per_device_train_batch_size"],
gradient_accumulation_steps=train_config["gradient_accumulation_steps"],
gradient_checkpointing=train_config["gradient_checkpointing"],
optim=train_config["optim"],
learning_rate=train_config["learning_rate"],
weight_decay=train_config["weight_decay"],
max_grad_norm=train_config["max_grad_norm"],
num_train_epochs=num_epochs,
max_steps=max_steps,
warmup_ratio=train_config["warmup_ratio"],
lr_scheduler_type=train_config["lr_scheduler_type"],
bf16=train_config["bf16"],
tf32=train_config["tf32"],
logging_steps=train_config["logging_steps"],
save_strategy=train_config["save_strategy"],
save_steps=train_config["save_steps"],
save_total_limit=train_config["save_total_limit"],
eval_strategy=train_config["evaluation_strategy"],
eval_steps=train_config["eval_steps"],
load_best_model_at_end=train_config["load_best_model_at_end"],
metric_for_best_model=train_config["metric_for_best_model"],
dataloader_num_workers=train_config["dataloader_num_workers"],
dataloader_pin_memory=train_config["dataloader_pin_memory"],
seed=config["experiment"]["seed"],
report_to="tensorboard",
logging_dir="./logs/tensorboard",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=val_dataset,
args=training_args,
dataset_text_field="text",
max_seq_length=max_seq_length,
packing=True,
)
tag = "(DRY RUN β€” 10 steps)" if dry_run else ""
print(f"\nStarting training {tag}...")
t0 = time.time()
trainer.train()
train_time = time.time() - t0
# Evaluate
print("\nRunning evaluation...")
eval_results = trainer.evaluate()
print(f"Eval loss: {eval_results.get('eval_loss', 'N/A'):.4f}")
print(f"Training time: {train_time/60:.1f} min")
# Save LoRA adapter
output_dir = os.path.join(train_config["output_dir"], "final")
print(f"\nSaving adapter to {output_dir}...")
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# Export to GGUF
if not dry_run:
export_config = config.get("export", {})
gguf_quant = export_config.get("gguf_quantization", "q4_k_m")
gguf_dir = export_config.get("output_dir", "./models/exported")
os.makedirs(gguf_dir, exist_ok=True)
print(f"\nExporting to GGUF ({gguf_quant})...")
model.save_pretrained_gguf(gguf_dir, tokenizer, quantization_method=gguf_quant)
print(f"GGUF exported to {gguf_dir}")
return not dry_run # return whether we should proceed to Ollama registration
def step_export_only(config_path: str):
"""Load saved adapter and export to GGUF (no training)."""
import torch
torch._dynamo.config.suppress_errors = True
with open(config_path, "r") as f:
config = yaml.safe_load(f)
model_name = config["model"]["base_model"]
max_seq_length = config["model"]["max_seq_length"]
load_in_4bit = config["model"]["load_in_4bit"]
lora_config = config["lora"]
train_config = config["training"]
adapter_dir = os.path.join(train_config["output_dir"], "final")
print("\n" + "=" * 60)
print("EXPORT-ONLY: Loading adapter -> GGUF")
print(f"Adapter: {adapter_dir}")
print("=" * 60)
if not os.path.exists(os.path.join(adapter_dir, "adapter_model.safetensors")):
print(f"ABORT: No adapter at {adapter_dir}")
sys.exit(1)
from unsloth import FastLanguageModel
print(f"\nLoading base {model_name}...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
dtype=None,
)
print(f"Applying saved LoRA from {adapter_dir}...")
model = FastLanguageModel.get_peft_model(
model,
r=lora_config["r"],
lora_alpha=lora_config["lora_alpha"],
lora_dropout=lora_config["lora_dropout"],
target_modules=lora_config["target_modules"],
bias=lora_config["bias"],
use_gradient_checkpointing="unsloth",
random_state=config["experiment"]["seed"],
)
# Load adapter weights
from peft import PeftModel
model.load_adapter(adapter_dir, adapter_name="default")
export_config = config.get("export", {})
gguf_quant = export_config.get("gguf_quantization", "q4_k_m")
gguf_dir = export_config.get("output_dir", "./models/exported")
os.makedirs(gguf_dir, exist_ok=True)
print(f"\nExporting to GGUF ({gguf_quant})...")
model.save_pretrained_gguf(gguf_dir, tokenizer, quantization_method=gguf_quant)
print(f"GGUF exported to {gguf_dir}")
def step_register_ollama(config_path: str):
"""Step 2: Register model in Ollama."""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
model_name = config.get("export", {}).get("ollama_model_name", "sakhi")
gguf_dir = config.get("export", {}).get("output_dir", "./models/exported")
print("\n" + "=" * 60)
print(f"STEP 2: Registering '{model_name}' in Ollama")
print("=" * 60)
# Find GGUF file
gguf_files = [f for f in os.listdir(gguf_dir) if f.endswith(".gguf")]
if not gguf_files:
print(f"ABORT: No GGUF in {gguf_dir}")
return False
gguf_path = os.path.join(gguf_dir, gguf_files[0])
print(f"GGUF: {gguf_path} ({os.path.getsize(gguf_path) / 1e9:.1f} GB)")
# Write Modelfile pointing to GGUF
modelfile_content = f"""FROM {gguf_path}
TEMPLATE \"\"\"{{{{ if .System }}}}<start_of_turn>system
{{{{ .System }}}}<end_of_turn>
{{{{ end }}}}{{{{ if .Prompt }}}}<start_of_turn>user
{{{{ .Prompt }}}}<end_of_turn>
<start_of_turn>model
{{{{ end }}}}{{{{ .Response }}}}<end_of_turn>\"\"\"
PARAMETER temperature 0.1
PARAMETER num_ctx 4096
PARAMETER stop "<end_of_turn>"
PARAMETER stop "<eos>"
"""
modelfile_path = "configs/Modelfile"
with open(modelfile_path, "w") as f:
f.write(modelfile_content)
# Remove old model if exists
subprocess.run(["ollama", "rm", model_name], capture_output=True, timeout=30)
# Create new model
result = subprocess.run(
["ollama", "create", model_name, "-f", modelfile_path],
capture_output=True, text=True, timeout=300,
)
if result.returncode == 0:
print(f"Model '{model_name}' registered successfully")
return True
else:
print(f"Failed: {result.stderr}")
return False
def step_evaluate(config_path: str = "configs/training.yaml"):
"""Step 3: A/B evaluation β€” base vs fine-tuned on 15-test quality suite.
Saves results to RETRAIN_RESULTS.md in project root."""
print("\n" + "=" * 60)
print("STEP 3: A/B Evaluation (base vs fine-tuned)")
print("=" * 60)
# Run the quality test suite with both models
result = subprocess.run(
[sys.executable, "-u", "scripts/test_ollama_quality.py"],
capture_output=True, text=True, timeout=1200,
env={**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONIOENCODING": "utf-8"},
)
full_output = result.stdout
print(full_output)
if result.stderr:
print(result.stderr)
# Parse per-model results
lines = full_output.strip().split("\n")
scores = {}
model_details = {} # model -> list of PASS/FAIL lines
current_model = None
for line in lines:
stripped = line.strip()
if "gemma4:e4b" in stripped and "=" * 5 not in stripped:
current_model = "gemma4:e4b-it-q4_K_M"
model_details[current_model] = []
elif "sakhi:" in stripped and "=" * 5 not in stripped:
current_model = "sakhi:latest"
model_details[current_model] = []
elif current_model and ("PASS" in stripped or "FAIL" in stripped) and "[" in stripped:
model_details[current_model].append(stripped)
if "%" in stripped and ("gemma4" in stripped or "sakhi" in stripped):
parts = stripped.split()
for i, p in enumerate(parts):
if "gemma4" in p or "sakhi" in p:
scores[p] = parts[0] if i > 0 else parts[i-1]
break
# Load training config for the report
with open(config_path, "r") as f:
config = yaml.safe_load(f)
lora = config.get("lora", {})
train = config.get("training", {})
# Determine winner and diagnose
base_n, sakhi_n = 0, 0
for model, score in scores.items():
try:
n = int(score.split("/")[0])
if "gemma4" in model:
base_n = n
elif "sakhi" in model:
sakhi_n = n
except ValueError:
pass
if sakhi_n > base_n:
verdict = "FINE-TUNED MODEL WINS β€” switch production to sakhi:latest"
action = "Set `OLLAMA_MODEL=sakhi:latest` in env or update app.py"
elif sakhi_n == base_n:
verdict = "TIE β€” fine-tuned matches base quality. May be faster (shorter outputs)."
action = "Check timing above. If sakhi is faster, consider switching."
else:
verdict = "BASE MODEL WINS β€” keep using gemma4:e4b-it-q4_K_M"
action = "Fine-tuning did not improve quality. Skip Unsloth track."
# Diagnose failures
sakhi_failures = [l for l in model_details.get("sakhi:latest", []) if "FAIL" in l]
base_failures = [l for l in model_details.get("gemma4:e4b-it-q4_K_M", []) if "FAIL" in l]
# Check for common failure patterns
diagnostics = []
sakhi_fail_text = "\n".join(sakhi_failures)
if "MISSING" in sakhi_fail_text:
diagnostics.append("Model is under-extracting (MISSING fields). Possible causes: LR too low (model didn't learn enough), or training data doesn't cover these patterns well.")
if "HALLUC" in sakhi_fail_text:
diagnostics.append("Model is hallucinating values. Possible causes: LR too high (overfitting to training data quirks), insufficient negative examples, or training data has noisy labels.")
if "WRONG" in sakhi_fail_text:
diagnostics.append("Model extracts wrong values. Could be: training data has mismatched transcript-extraction pairs, or model is confusing similar fields.")
if "FALSE_POS" in sakhi_fail_text:
diagnostics.append("Model over-flags danger signs. Need more negative (no-danger) training examples, or lower the danger sign oversampling ratio.")
if "FALSE_NEG" in sakhi_fail_text:
diagnostics.append("Model under-flags danger signs. Training data may not have enough diverse danger scenarios.")
if not diagnostics and sakhi_n < base_n:
diagnostics.append("No clear pattern in failures. The base model may simply be better at zero-shot extraction than a LoRA fine-tune on 981 examples can achieve.")
# Build markdown report
report = f"""# Retrain Results
**Date:** {time.strftime('%Y-%m-%d %H:%M')}
**Training config:** LR={train.get('learning_rate')}, epochs={train.get('num_train_epochs')}, LoRA r={lora.get('r')}, alpha={lora.get('lora_alpha')}, dropout={lora.get('lora_dropout')}
**Training data:** 981 examples (schema leakage fixed, trimmed danger schema)
## Scores
| Model | Score |
|-------|-------|
| gemma4:e4b-it-q4_K_M (base) | {scores.get('gemma4:e4b-it-q4_K_M', '?')} |
| sakhi:latest (fine-tuned) | {scores.get('sakhi:latest', '?')} |
## Verdict
**{verdict}**
{action}
## Base Model Details
```
{chr(10).join(model_details.get('gemma4:e4b-it-q4_K_M', ['No results']))}
```
## Fine-Tuned Model Details
```
{chr(10).join(model_details.get('sakhi:latest', ['No results']))}
```
## Diagnostics
"""
if diagnostics:
for d in diagnostics:
report += f"- {d}\n"
else:
report += "No issues detected.\n"
report += f"""
## What was fixed in this retrain (vs previous 9/15 attempt)
1. **Schema leakage removed** β€” 454/981 training examples had `$schema`, `title`, `description` in assistant output. Stripped.
2. **Trimmed danger schema** β€” training now uses the same trimmed schema as production (no checklists).
3. **System prompts match production** β€” exact same prompts in training and inference.
4. **LR reduced** β€” 2e-4 -> 5e-5 (4x lower to prevent overfitting).
5. **Epochs reduced** β€” 3 -> 1 (less overfitting on small dataset).
6. **LoRA alpha doubled** β€” 16 -> 32 (alpha=2*r is standard practice).
7. **Dropout added** β€” 0.0 -> 0.05 (regularization).
## If results are still bad, next steps to try
- Further lower LR to 2e-5
- Use only form_extraction examples (skip danger sign training, let base model handle it)
- Increase training data to 2000+ examples with better diversity
- Try r=8 instead of r=16 (smaller adapter, less capacity to overfit)
"""
# Write report
report_path = "RETRAIN_RESULTS.md"
with open(report_path, "w", encoding="utf-8") as f:
f.write(report)
print(f"\nResults saved to {report_path}")
print(f"\n>>> {verdict}")
return result.returncode
def main():
parser = argparse.ArgumentParser(description="Sakhi β€” Full Retrain Pipeline")
parser.add_argument("--config", type=str, default="configs/training.yaml")
parser.add_argument("--dry-run", action="store_true", help="10 steps only, skip export/eval")
parser.add_argument("--eval-only", action="store_true", help="Skip training, just evaluate")
parser.add_argument("--export-only", action="store_true", help="Load saved adapter β†’ export GGUF β†’ register β†’ evaluate")
args = parser.parse_args()
start = time.time()
if args.export_only:
step_export_only(args.config)
success = step_register_ollama(args.config)
if success:
_enable_sakhi_in_quality_test()
step_evaluate(args.config)
else:
print("\nSkipping evaluation β€” Ollama registration failed")
elif args.eval_only:
# Enable both models in quality test
_enable_sakhi_in_quality_test()
step_evaluate(args.config)
else:
# Full pipeline
step_prepare_data()
should_export = step_train(args.config, args.dry_run)
if should_export:
success = step_register_ollama(args.config)
if success:
_enable_sakhi_in_quality_test()
step_evaluate(args.config)
else:
print("\nSkipping evaluation β€” Ollama registration failed")
else:
print("\nDry run complete. Skipping export and evaluation.")
elapsed = time.time() - start
print(f"\nTotal pipeline time: {elapsed/60:.1f} min")
def _enable_sakhi_in_quality_test():
"""Temporarily enable sakhi in the quality test for A/B comparison."""
test_path = "scripts/test_ollama_quality.py"
with open(test_path, "r", encoding="utf-8") as f:
content = f.read()
# Uncomment sakhi if it's commented out
if '# "sakhi:latest"' in content:
content = content.replace(
'# "sakhi:latest"',
'"sakhi:latest"',
)
with open(test_path, "w", encoding="utf-8") as f:
f.write(content)
print("Enabled sakhi:latest in quality test for A/B comparison")
if __name__ == "__main__":
main()