nkshirsa committed on
Commit fde181e · verified · 1 Parent(s): f6c2b19

Add upgraded SFT training script with SciRIFF data + proper QLoRA config

phd_research_os_v2/training/train_sft_v2.py ADDED
@@ -0,0 +1,233 @@
"""
PhD Research OS — Upgraded SFT Training Script
=================================================
Stage 1 of the 4-stage training pipeline.

Changes from original train.py:
- Integrates SciRIFF data (72× more training examples)
- Proper QLoRA configuration based on TRL v1.2.0 docs
- Trackio monitoring for loss tracking
- push_to_hub enabled (model not lost when job ends)
- Proper eval strategy with paper-level awareness
- Logging configured for headless training (no tqdm)

Usage:
    python -m phd_research_os_v2.training.train_sft_v2

Dependencies:
    pip install trl peft transformers datasets bitsandbytes accelerate trackio torch
"""

import os
import sys
import json
import logging
import torch
from datetime import datetime

# ── Logging setup ─────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("train_sft_v2")

# ── Configuration ─────────────────────────────────────────────────────

# Model
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-3B-Instruct")

# Data
EXISTING_DATASET = "nkshirsa/phd-research-os-sft-data"
SCIRIFF_MAX = int(os.environ.get("SCIRIFF_MAX", "8000"))  # SciRIFF examples to include

# Training
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "3"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "2"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "2e-4"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "2048"))
LORA_R = int(os.environ.get("LORA_R", "64"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "16"))

# Output
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "./research-os-sft-v2")
HUB_MODEL_ID = os.environ.get("HUB_MODEL_ID", "nkshirsa/phd-research-os-brain-v2")
PUSH_TO_HUB = os.environ.get("PUSH_TO_HUB", "true").lower() == "true"
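
# With the defaults above, the effective batch size per device is
# BATCH_SIZE * GRAD_ACCUM = 2 * 8 = 16 sequences per optimizer step, and the
# LoRA scaling applied to the adapters is LORA_ALPHA / LORA_R = 16 / 64 = 0.25.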


def main():
    logger.info("=" * 60)
    logger.info("PhD Research OS — SFT Training v2")
    logger.info("=" * 60)
    logger.info(f"Base model: {BASE_MODEL}")
    logger.info(f"SciRIFF max examples: {SCIRIFF_MAX}")
    logger.info(f"Epochs: {NUM_EPOCHS}, Batch: {BATCH_SIZE}, Grad accum: {GRAD_ACCUM}")
    logger.info(f"LR: {LEARNING_RATE}, Max seq: {MAX_SEQ_LENGTH}")
    logger.info(f"LoRA r={LORA_R}, alpha={LORA_ALPHA}")
    logger.info(f"Output: {OUTPUT_DIR}")
    logger.info(f"Push to hub: {PUSH_TO_HUB} → {HUB_MODEL_ID}")

    # ── 1. Setup Trackio monitoring ──────────────────────────────────
    try:
        import trackio
        trackio.init(name="phd-research-os-sft-v2")
        logger.info("Trackio monitoring initialized")
    except ImportError:
        logger.warning("Trackio not available — training will proceed without monitoring")

    # ── 2. Load and merge datasets ───────────────────────────────────
    logger.info("Loading datasets...")
    from datasets import load_dataset, concatenate_datasets

    # Load existing data
    existing_ds = load_dataset(EXISTING_DATASET, split="train", trust_remote_code=True)
    existing_test = load_dataset(EXISTING_DATASET, split="test", trust_remote_code=True)
    logger.info(f"Existing dataset: {len(existing_ds)} train, {len(existing_test)} test")

    # Load and convert SciRIFF
    logger.info(f"Loading SciRIFF (max {SCIRIFF_MAX} examples)...")
    try:
        from phd_research_os_v2.training.sciriff_integration import load_sciriff
        sciriff_examples = load_sciriff(config="4096", max_examples=SCIRIFF_MAX)

        from datasets import Dataset
        sciriff_ds = Dataset.from_list(sciriff_examples)

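        # Note: concatenate_datasets() below requires both datasets to expose the
        # same features, so load_sciriff() is expected to return examples in the
        # same chat-style schema as the existing SFT dataset.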
        # Merge
        train_ds = concatenate_datasets([existing_ds, sciriff_ds])
        train_ds = train_ds.shuffle(seed=42)
        logger.info(f"Merged: {len(existing_ds)} + {len(sciriff_ds)} = {len(train_ds)} training examples")
    except Exception as e:
        logger.warning(f"SciRIFF loading failed: {e}. Using existing data only.")
        train_ds = existing_ds

    test_ds = existing_test
    logger.info(f"Final: {len(train_ds)} train, {len(test_ds)} test")

    # ── 3. Load model with QLoRA quantization ────────────────────────
    logger.info(f"Loading {BASE_MODEL} with 4-bit quantization...")

    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
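    # The settings above follow the standard QLoRA recipe: NF4 (4-bit NormalFloat)
    # weight quantization, double quantization of the quantization constants for a
    # small extra memory saving, and bfloat16 as the compute dtype for matmuls.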

    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Model loaded: {model.num_parameters():,} parameters")

    # ── 4. Configure LoRA ────────────────────────────────────────────
    from peft import LoraConfig, prepare_model_for_kbit_training

    model = prepare_model_for_kbit_training(model)

    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
    )

    # prepare_model_for_kbit_training() freezes the base weights; the LoRA adapters
    # are only attached once peft_config is handed to SFTTrainer below, so the
    # trainable-parameter count is logged after the trainer is created.
    logger.info(f"LoRA: r={LORA_R}, alpha={LORA_ALPHA}")

    # ── 5. Configure training ────────────────────────────────────────
    from trl import SFTConfig, SFTTrainer

    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        bf16=True,
        gradient_checkpointing=True,

        # Logging — critical for headless training
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        report_to="none",

        # Evaluation
        eval_strategy="steps",
        eval_steps=200,
        save_strategy="steps",
        save_steps=400,  # must be a round multiple of eval_steps when load_best_model_at_end=True
        save_total_limit=3,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",

        # SFT-specific
        max_seq_length=MAX_SEQ_LENGTH,
        dataset_text_field=None,  # Auto-detect 'messages' column
        packing=False,

        # Hub
        push_to_hub=PUSH_TO_HUB,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="end",
    )

    # ── 6. Create trainer ────────────────────────────────────────────
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=test_ds,
        peft_config=peft_config,
        processing_class=tokenizer,
    )
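
    # SFTTrainer applies peft_config when it is constructed, so trainer.model now
    # carries the LoRA adapters; the count below therefore reflects only the
    # trainable adapter weights on top of the frozen 4-bit base model.
    trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in trainer.model.parameters())
    logger.info(f"Trainable: {trainable_params:,} / {total_params:,} ({trainable_params / total_params:.1%})")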

    # ── 7. Train ─────────────────────────────────────────────────────
    logger.info("Starting training...")
    logger.info(f"Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
    logger.info(f"Total steps: ~{len(train_ds) // (BATCH_SIZE * GRAD_ACCUM) * NUM_EPOCHS}")

    result = trainer.train()

    logger.info("Training complete!")
    logger.info(f"Final metrics: {result.metrics}")

    # ── 8. Save and push ─────────────────────────────────────────────
    logger.info("Saving model...")
    trainer.save_model()

    if PUSH_TO_HUB:
        logger.info(f"Pushing to hub: {HUB_MODEL_ID}")
        trainer.push_to_hub()
        logger.info(f"Model pushed: https://huggingface.co/{HUB_MODEL_ID}")

    logger.info("=" * 60)
    logger.info("SFT Training v2 — COMPLETE")
    logger.info("=" * 60)

    return result


if __name__ == "__main__":
    main()
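
# Example invocation (hypothetical values; every constant above can be overridden
# through the environment):
#   SCIRIFF_MAX=4000 BATCH_SIZE=4 GRAD_ACCUM=4 PUSH_TO_HUB=false \
#       python -m phd_research_os_v2.training.train_sft_v2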