nkshirsa commited on
Commit
5db4c88
·
verified ·
1 Parent(s): 9f0a004

Add training script (QLoRA SFT on Qwen2.5-3B-Instruct)

Browse files
Files changed (1) hide show
  1. train.py +91 -0
train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PhD Research OS — SFT Training Script
3
+ =======================================
4
+ Fine-tunes Qwen2.5-3B-Instruct using QLoRA on multi-task scientific research data.
5
+
6
+ Tasks trained:
7
+ 1. Scientific Claim Extraction (structured JSON output)
8
+ 2. Epistemic Classification (Fact/Interpretation/Hypothesis/Conflict_Hypothesis)
9
+ 3. Confidence Scoring (evidence_strength × study_quality × journal_tier × completeness)
10
+ 4. Contradiction Detection (claim pair → conflict analysis)
11
+ 5. Query Decomposition (broad question → sub-queries)
12
+ 6. Decision Object Generation (knowledge gaps → proposed research actions)
13
+
14
+ Base model: Qwen/Qwen2.5-3B-Instruct
15
+ Method: QLoRA (r=64, all-linear) following "LoRA Without Regret" recipe
16
+ Reference: arxiv:2212.05238 (LLM-NERRE), arxiv:2401.00579 (multi-task biomedical SFT)
17
+
18
+ Usage:
19
+ pip install torch transformers trl peft datasets bitsandbytes accelerate trackio
20
+ python train.py
21
+ """
22
+
23
+ import os
24
+ import torch
25
+ from datasets import load_dataset
26
+ from transformers import BitsAndBytesConfig
27
+ from peft import LoraConfig
28
+ from trl import SFTConfig, SFTTrainer
29
+ import trackio
30
+
31
+ MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
32
+ DATASET_NAME = "nkshirsa/phd-research-os-sft-data"
33
+ OUTPUT_DIR = "./phd-research-os-brain"
34
+ HUB_MODEL_ID = "nkshirsa/phd-research-os-brain"
35
+
36
+ trackio.init(project="phd-research-os-training", run="sft-qwen25-3b-qlora-v1")
37
+
38
+ print("Loading dataset...")
39
+ dataset = load_dataset(DATASET_NAME)
40
+ train_dataset = dataset["train"]
41
+ eval_dataset = dataset["test"]
42
+ print(f"Train: {len(train_dataset)} examples, Eval: {len(eval_dataset)} examples")
43
+
44
+ bnb_config = BitsAndBytesConfig(
45
+ load_in_4bit=True, bnb_4bit_use_double_quant=True,
46
+ bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
47
+
48
+ peft_config = LoraConfig(
49
+ r=64, lora_alpha=16, lora_dropout=0.05, bias="none",
50
+ task_type="CAUSAL_LM", target_modules="all-linear")
51
+
52
+ training_args = SFTConfig(
53
+ output_dir=OUTPUT_DIR,
54
+ num_train_epochs=3,
55
+ per_device_train_batch_size=2,
56
+ per_device_eval_batch_size=2,
57
+ gradient_accumulation_steps=8,
58
+ learning_rate=2e-4,
59
+ lr_scheduler_type="cosine",
60
+ warmup_ratio=0.05,
61
+ weight_decay=0.01,
62
+ max_grad_norm=1.0,
63
+ bf16=True,
64
+ gradient_checkpointing=True,
65
+ max_length=2048,
66
+ model_init_kwargs={"quantization_config": bnb_config, "torch_dtype": torch.bfloat16},
67
+ assistant_only_loss=True,
68
+ logging_steps=5, logging_first_step=True, disable_tqdm=True,
69
+ report_to=["tensorboard"], logging_dir=f"{OUTPUT_DIR}/logs",
70
+ eval_strategy="steps", eval_steps=50,
71
+ save_strategy="steps", save_steps=100, save_total_limit=3,
72
+ load_best_model_at_end=True, metric_for_best_model="eval_loss", greater_is_better=False,
73
+ push_to_hub=True, hub_model_id=HUB_MODEL_ID, hub_strategy="every_save",
74
+ seed=42, data_seed=42)
75
+
76
+ trainer = SFTTrainer(
77
+ model=MODEL_NAME, args=training_args,
78
+ train_dataset=train_dataset, eval_dataset=eval_dataset,
79
+ peft_config=peft_config)
80
+
81
+ trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
82
+ total = sum(p.numel() for p in trainer.model.parameters())
83
+ print(f"Model: {MODEL_NAME} | Total: {total:,} | Trainable: {trainable:,} ({100*trainable/total:.2f}%)")
84
+
85
+ train_result = trainer.train()
86
+ trainer.save_model()
87
+ trainer.push_to_hub(commit_message="Final model: PhD Research OS Brain v1")
88
+
89
+ print(f"\nTraining complete! Model at: https://huggingface.co/{HUB_MODEL_ID}")
90
+ for k, v in train_result.metrics.items():
91
+ print(f" {k}: {v}")