nraptisss commited on
Commit
dd91b1f
·
verified ·
1 Parent(s): 1dcff4f

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +226 -0
train.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ QLoRA Fine-Tuning Script for Telecom Intent-to-Config Translation
3
+ Optimized for Kaggle T4x2 (2x T4 GPUs, ~30h/week free)
4
+
5
+ Dataset: nraptisss/TMF921-intent-to-config-augmented (or any dataset with 'messages' column)
6
+ Model: Qwen/Qwen2.5-7B-Instruct (or meta-llama/Llama-3.1-8B-Instruct)
7
+ Output: LoRA adapters saved locally, then merge_and_push.py merges and pushes
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import torch
13
+ from datasets import load_dataset
14
+ from transformers import (
15
+ AutoModelForCausalLM,
16
+ AutoTokenizer,
17
+ BitsAndBytesConfig,
18
+ )
19
+ from peft import LoraConfig
20
+ from trl import SFTConfig, SFTTrainer
21
+
22
+ # ============================================================================
23
+ # CONFIGURATION — EDIT THESE
24
+ # ============================================================================
25
+
26
+ # Model
27
+ MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct" # or "meta-llama/Llama-3.1-8B-Instruct"
28
+
29
+ # Dataset
30
+ DATASET_NAME = "nraptisss/TMF921-intent-to-config-augmented"
31
+ DATASET_CONFIG = "default"
32
+ TRAIN_SPLIT = "train"
33
+ TEST_SPLIT = "test"
34
+
35
+ # Output
36
+ OUTPUT_DIR = "./qwen2.5-7b-telecom-intent-lora"
37
+
38
+ # Training hyperparameters (optimized for T4 16GB)
39
+ NUM_EPOCHS = 3
40
+ BATCH_SIZE = 1
41
+ GRAD_ACCUMULATION = 4 # effective batch = 4
42
+ LEARNING_RATE = 2.0e-4
43
+ MAX_LENGTH = 512
44
+ LORA_R = 64
45
+ LORA_ALPHA = 16
46
+ LORA_DROPOUT = 0.05
47
+
48
+ # ============================================================================
49
+ # SETUP
50
+ # ============================================================================
51
+
52
+ def setup():
53
+ """Verify GPU and set environment."""
54
+ if not torch.cuda.is_available():
55
+ print("WARNING: No GPU detected. This will be extremely slow on CPU.")
56
+ sys.exit(1)
57
+
58
+ gpu_count = torch.cuda.device_count()
59
+ print(f"Detected {gpu_count} GPU(s):")
60
+ for i in range(gpu_count):
61
+ props = torch.cuda.get_device_properties(i)
62
+ print(f" GPU {i}: {props.name} ({props.total_memory / 1e9:.1f} GB)")
63
+
64
+ # T4-specific: use fp16, not bf16
65
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
66
+ return gpu_count
67
+
68
+
69
+ def load_model_and_tokenizer(model_name: str):
70
+ """Load 4-bit quantized model and tokenizer."""
71
+ print(f"\nLoading model: {model_name}")
72
+
73
+ bnb_config = BitsAndBytesConfig(
74
+ load_in_4bit=True,
75
+ bnb_4bit_quant_type="nf4",
76
+ bnb_4bit_use_double_quant=True,
77
+ bnb_4bit_compute_dtype=torch.float16, # T4: fp16, not bf16
78
+ )
79
+
80
+ tokenizer = AutoTokenizer.from_pretrained(
81
+ model_name,
82
+ trust_remote_code=True,
83
+ padding_side="right",
84
+ )
85
+ if tokenizer.pad_token is None:
86
+ tokenizer.pad_token = tokenizer.eos_token
87
+ tokenizer.pad_token_id = tokenizer.eos_token_id
88
+
89
+ model = AutoModelForCausalLM.from_pretrained(
90
+ model_name,
91
+ quantization_config=bnb_config,
92
+ device_map="auto",
93
+ trust_remote_code=True,
94
+ torch_dtype=torch.float16,
95
+ )
96
+
97
+ # Enable gradient checkpointing for memory savings
98
+ model.gradient_checkpointing_enable()
99
+ model.enable_input_require_grads()
100
+
101
+ print(f"Model loaded. VRAM used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
102
+ return model, tokenizer
103
+
104
+
105
+ def load_and_inspect_dataset(dataset_name: str, config_name: str, split: str):
106
+ """Load dataset and verify messages column."""
107
+ print(f"\nLoading dataset: {dataset_name} (config={config_name}, split={split})")
108
+ ds = load_dataset(dataset_name, config_name, split=split)
109
+ print(f"Dataset size: {len(ds)} examples")
110
+
111
+ # Verify format
112
+ sample = ds[0]
113
+ if "messages" not in sample:
114
+ raise ValueError(
115
+ f"Dataset must have 'messages' column. Got: {list(sample.keys())}"
116
+ )
117
+
118
+ msgs = sample["messages"]
119
+ print(f"Sample messages structure: {len(msgs)} messages")
120
+ for m in msgs:
121
+ print(f" role={m.get('role')}, content_len={len(m.get('content', ''))}")
122
+
123
+ # Print a sample intent text
124
+ for m in msgs:
125
+ if m.get("role") == "user":
126
+ print(f"\nSample user intent:\n{m['content'][:200]}...")
127
+ break
128
+
129
+ return ds
130
+
131
+
132
+ def get_lora_config():
133
+ """Return LoRA config optimized for intent-to-config task."""
134
+ return LoraConfig(
135
+ r=LORA_R,
136
+ lora_alpha=LORA_ALPHA,
137
+ target_modules="all-linear",
138
+ lora_dropout=LORA_DROPOUT,
139
+ bias="none",
140
+ task_type="CAUSAL_LM",
141
+ )
142
+
143
+
144
+ def get_training_args(output_dir: str, num_gpus: int):
145
+ """Return SFTConfig optimized for Kaggle T4x2."""
146
+ return SFTConfig(
147
+ output_dir=output_dir,
148
+ num_train_epochs=NUM_EPOCHS,
149
+ per_device_train_batch_size=BATCH_SIZE,
150
+ per_device_eval_batch_size=BATCH_SIZE,
151
+ gradient_accumulation_steps=GRAD_ACCUMULATION,
152
+ learning_rate=LEARNING_RATE,
153
+ lr_scheduler_type="cosine",
154
+ warmup_ratio=0.05,
155
+ logging_steps=10,
156
+ save_strategy="epoch",
157
+ eval_strategy="epoch" if TEST_SPLIT else "no",
158
+ fp16=True,
159
+ bf16=False,
160
+ max_length=MAX_LENGTH,
161
+ gradient_checkpointing=True,
162
+ use_liger_kernel=True,
163
+ report_to="none",
164
+ load_best_model_at_end=False,
165
+ dataloader_num_workers=2,
166
+ remove_unused_columns=False,
167
+ )
168
+
169
+
170
+ def train(model, tokenizer, train_ds, eval_ds=None):
171
+ """Run SFT training with QLoRA."""
172
+ print("\n" + "=" * 60)
173
+ print("STARTING TRAINING")
174
+ print("=" * 60)
175
+
176
+ peft_config = get_lora_config()
177
+ training_args = get_training_args(OUTPUT_DIR, torch.cuda.device_count())
178
+
179
+ trainer = SFTTrainer(
180
+ model=model,
181
+ args=training_args,
182
+ train_dataset=train_ds,
183
+ eval_dataset=eval_ds,
184
+ processing_class=tokenizer,
185
+ peft_config=peft_config,
186
+ )
187
+
188
+ trainer.train()
189
+
190
+ # Save adapters
191
+ print(f"\nSaving LoRA adapters to {OUTPUT_DIR}")
192
+ trainer.save_model(OUTPUT_DIR)
193
+ tokenizer.save_pretrained(OUTPUT_DIR)
194
+
195
+ print("Training complete!")
196
+ return trainer
197
+
198
+
199
+ def main():
200
+ num_gpus = setup()
201
+
202
+ # Load everything
203
+ model, tokenizer = load_model_and_tokenizer(MODEL_NAME)
204
+ train_ds = load_and_inspect_dataset(DATASET_NAME, DATASET_CONFIG, TRAIN_SPLIT)
205
+
206
+ eval_ds = None
207
+ if TEST_SPLIT:
208
+ try:
209
+ eval_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT)
210
+ print(f"Eval dataset: {len(eval_ds)} examples")
211
+ except Exception as e:
212
+ print(f"No eval split available: {e}")
213
+
214
+ # Train
215
+ trainer = train(model, tokenizer, train_ds, eval_ds)
216
+
217
+ print("\n" + "=" * 60)
218
+ print("NEXT STEPS:")
219
+ print("=" * 60)
220
+ print("1. Run inference.py to test the model")
221
+ print("2. Run merge_and_push.py to merge adapters and push to hub")
222
+ print("3. Run benchmark.py to evaluate on the test set")
223
+
224
+
225
+ if __name__ == "__main__":
226
+ main()