Justin-lee committed
Commit bbf6ae6 · verified · 1 Parent(s): fcb5d72

Add network admin LLM training script

Files changed (1)
  1. network_admin_llm_train.py +333 -0
network_admin_llm_train.py ADDED
@@ -0,0 +1,333 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ '''
+ Network Admin LLM - QLoRA Fine-tuning Script
+ =============================================
+ Base Model: microsoft/Phi-4-mini-instruct
+ Method: QLoRA SFT (4-bit quantization + LoRA)
+ Datasets: NetEval + Telecom Intent Config
+
+ Run locally with a GPU:
+     pip install transformers trl peft bitsandbytes accelerate datasets trackio
+     python network_admin_llm_train.py
+
+ Or on Google Colab:
+     !pip install transformers trl peft bitsandbytes accelerate datasets trackio
+     %cd /content
+     !python network_admin_llm_train.py
+
+ Author: Network Admin LLM Project
+ '''
+
+ import os
+ import sys
+ import torch
+ from datetime import datetime
+
+ # ============== CONFIGURATION ==============
+ # Edit the settings below
+ MODEL_NAME = 'microsoft/Phi-4-mini-instruct'
+ HF_USERNAME = 'YOUR_HF_USERNAME'  # replace with your HuggingFace username
+ HF_TOKEN = os.environ.get('HF_TOKEN', 'YOUR_HF_TOKEN')  # HF token for upload (huggingface_hub also reads HF_TOKEN from the environment)
+
+ # Training hyperparameters
+ TRAINING_CONFIG = {
+     'learning_rate': 2e-4,       # LoRA generally needs a higher learning rate
+     'num_epochs': 3,
+     'batch_size': 4,
+     'gradient_accumulation': 4,  # effective batch = 4 * 4 = 16
+     'max_seq_length': 2048,
+     'lora_r': 16,
+     'lora_alpha': 32,
+     'lora_dropout': 0.05,
+     'warmup_ratio': 0.1,
+ }
+
+ OUTPUT_DIR = f'{HF_USERNAME}/network-admin-phi4-mini'
+ # ===========================================
+
+ def print_section(title):
+     print(f'\n{"="*60}')
+     print(f'  {title}')
+     print('='*60)
+
+ def install_dependencies():
+     '''Check that all required dependencies are installed'''
+     print_section('CHECKING DEPENDENCIES')
+
+     required = ['transformers', 'trl', 'peft', 'bitsandbytes', 'accelerate', 'datasets', 'trackio']
+     missing = []
+
+     for pkg in required:
+         try:
+             __import__(pkg.replace('-', '_'))
+             print(f'✅ {pkg}')
+         except ImportError:
+             missing.append(pkg)
+             print(f'❌ {pkg} - not installed')
+
+     if missing:
+         print(f'\nPlease run: pip install {" ".join(missing)}')
+         return False
+     return True
+
+ def load_and_prepare_datasets():
+     '''Load the datasets and convert them to chat format'''
+     from datasets import load_dataset, concatenate_datasets
+
+     print_section('LOADING DATASETS')
+
+     # 1. Load the NetEval exam question bank
+     print('📚 Loading the NetEval exam question bank...')
+     neteval_dataset = load_dataset('NASP/neteval-exam', split='train')
+     print(f'   NetEval: {len(neteval_dataset)} questions')
+
+     def convert_neteval(example):
+         '''Convert a Q&A record into chat format'''
+         question = example['Question']
+         options = f'\nA. {example.get("A", "")}\nB. {example.get("B", "")}\nC. {example.get("C", "")}\nD. {example.get("D", "")}'
+
+         answer = f'The correct answer is: {example["Answer"]}'
+         if example.get('Explanation'):
+             answer += f'\n\n📖 Explanation: {example["Explanation"]}'
+
+         return {
+             'messages': [
+                 {'role': 'system', 'content': 'You are a network administration expert. Answer questions about networking, security, routing, switches, VLANs, firewalls, and other IT infrastructure.'},
+                 {'role': 'user', 'content': f'{question}{options}'},
+                 {'role': 'assistant', 'content': answer}
+             ]
+         }
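+     # For illustration, a converted record looks roughly like this
+     # (hypothetical values; real content comes from NASP/neteval-exam):
+     # {'messages': [
+     #     {'role': 'system', 'content': 'You are a network administration expert. ...'},
+     #     {'role': 'user', 'content': 'At which OSI layer does a switch operate?\nA. ...\nD. ...'},
+     #     {'role': 'assistant', 'content': 'The correct answer is: B'}]}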
+
+     neteval_converted = neteval_dataset.map(
+         convert_neteval,
+         remove_columns=neteval_dataset.column_names,
+         desc='Converting NetEval format'
+     )
+
+     # 2. Load the telecom intent-configuration dataset
+     print('📚 Loading the telecom intent-configuration dataset...')
+     telecom_dataset = load_dataset('nraptisss/telecom-intent-config-sft-10k', split='train')
+     print(f'   Telecom: {len(telecom_dataset)} examples')
+
+     telecom_messages = telecom_dataset.map(
+         lambda x: {'messages': x['messages']},
+         remove_columns=[c for c in telecom_dataset.column_names if c != 'messages']
+     )
+
+     # 3. Merge the datasets and hold out 10% for evaluation
+     print('🔄 Merging datasets...')
+     combined = concatenate_datasets([neteval_converted, telecom_messages])
+     split_data = combined.train_test_split(test_size=0.1, seed=42)
+
+     train_ds = split_data['train']
+     eval_ds = split_data['test']
+
+     print('\n📊 Dataset statistics:')
+     print(f'   Train: {len(train_ds)} examples')
+     print(f'   Eval:  {len(eval_ds)} examples')
+     print(f'   Total: {len(combined)} examples')
+
+     return train_ds, eval_ds
+
+ def setup_model_and_tokenizer():
+     '''Set up the model and tokenizer'''
+     from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+     from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
+
+     print_section('LOADING MODEL')
+     print(f'🤖 Model: {MODEL_NAME}')
+
+     # Tokenizer
+     print('\n📝 Loading tokenizer...')
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = 'right'
+     print(f'   Vocab size: {len(tokenizer):,}')
+
+     # QLoRA configuration (4-bit)
+     print('\n🔧 Configuring QLoRA (4-bit)...')
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_quant_type='nf4',               # NormalFloat4
+         bnb_4bit_compute_dtype=torch.bfloat16,
+         bnb_4bit_use_double_quant=True,          # nested (double) quantization
+     )
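+     # Rough footprint, assuming Phi-4-mini-instruct's ~3.8B parameters:
+     # NF4 stores weights at ~0.5 bytes/param, i.e. roughly 2 GB, before
+     # activations, LoRA adapters, and optimizer state.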
+
+     # Load the model
+     print('📥 Loading model (4-bit)...')
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         quantization_config=bnb_config,
+         device_map='auto',
+         trust_remote_code=True,
+     )
+
+     # Prepare for k-bit training
+     model = prepare_model_for_kbit_training(model)
+     print('✅ Model ready')
+
+     # LoRA configuration. Phi-4-mini follows the Phi-3 architecture, which
+     # fuses the Q/K/V and gate/up projections, so the fused module names are
+     # targeted here (per-projection names like 'q_proj' would not match).
+     print('\n🔧 Configuring LoRA...')
+     lora_config = LoraConfig(
+         r=TRAINING_CONFIG['lora_r'],
+         lora_alpha=TRAINING_CONFIG['lora_alpha'],
+         lora_dropout=TRAINING_CONFIG['lora_dropout'],
+         bias='none',
+         task_type='CAUSAL_LM',
+         target_modules=[
+             'qkv_proj', 'o_proj',         # attention
+             'gate_up_proj', 'down_proj',  # MLP
+         ],
+         modules_to_save=['lm_head', 'embed_tokens'],
+     )
+
+     # Apply LoRA
+     model = get_peft_model(model, lora_config)
+     model.print_trainable_parameters()
+
+     return model, tokenizer, lora_config
+
+ def setup_trainer(model, tokenizer, train_ds, eval_ds, lora_config):
+     '''Set up the trainer'''
+     from trl import SFTTrainer, SFTConfig
+
+     print_section('CONFIGURING TRAINER')
+
+     # Generate a run name
+     run_name = f'phi4-netadmin-{datetime.now().strftime("%m%d-%H%M")}'
+
+     # Try to initialize trackio; fall back to no reporting if it is unavailable
+     try:
+         import trackio
+         trackio.init(project='network-admin-llm', experiment='qlora-sft', run_name=run_name)
+         print('✅ Trackio initialized')
+         report_to = ['trackio']
+     except Exception as e:
+         print(f'⚠️ Trackio initialization failed: {e}')
+         report_to = ['none']
+
+     # SFT configuration
+     training_args = SFTConfig(
+         # Learning rate
+         learning_rate=TRAINING_CONFIG['learning_rate'],
+         lr_scheduler_type='cosine',
+         warmup_ratio=TRAINING_CONFIG['warmup_ratio'],
+
+         # Training
+         num_train_epochs=TRAINING_CONFIG['num_epochs'],
+         per_device_train_batch_size=TRAINING_CONFIG['batch_size'],
+         gradient_accumulation_steps=TRAINING_CONFIG['gradient_accumulation'],
+         max_seq_length=TRAINING_CONFIG['max_seq_length'],  # renamed to 'max_length' in newer TRL releases
+
+         # Memory optimizations
+         gradient_checkpointing=True,
+         bf16=True,   # requires bf16-capable hardware (e.g. Ampere or newer)
+         fp16=False,
+
+         # Output
+         output_dir='./output',
+         logging_steps=10,
+         save_steps=500,
+         save_total_limit=2,
+         eval_strategy='steps',   # 'evaluation_strategy' in older transformers releases
+         eval_steps=500,
+
+         # Hub upload
+         push_to_hub=True,
+         hub_model_id=OUTPUT_DIR,
+         hub_strategy='checkpoint',
+
+         # Monitoring
+         report_to=report_to,
+         logging_strategy='steps',
+         logging_first_step=True,
+
+         # Misc
+         remove_unused_columns=False,
+         dataloader_num_workers=4,
+         seed=42,
+     )
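+     # Note (approximate): with an effective batch of 16, one epoch is roughly
+     # len(train_ds) / 16 optimizer steps, so whether the 500-step eval/save
+     # interval fires depends on the combined dataset size.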
+
+     # Create the trainer (the model is already a PeftModel from
+     # get_peft_model, so peft_config is not passed again here)
+     trainer = SFTTrainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_ds,
+         eval_dataset=eval_ds,
+         processing_class=tokenizer,
+     )
+
+     return trainer, run_name
+
+ def train_model(trainer):
+     '''Run training'''
+     print_section('STARTING TRAINING')
+     print('🚀 Starting training...')
+     print('   (press Ctrl+C to interrupt at any time)')
+     print()
+
+     try:
+         trainer.train()
+         print('\n✅ Training complete!')
+         return True
+     except KeyboardInterrupt:
+         print('\n⚠️ Training interrupted by user')
+         return False
+     except Exception as e:
+         print(f'\n❌ Training failed: {e}')
+         raise
+
+ def save_and_upload(trainer):
+     '''Save the model and upload it to the Hub'''
+     print_section('SAVING & UPLOADING')
+
+     try:
+         print('📤 Uploading the model to the HuggingFace Hub...')
+         trainer.push_to_hub()
+         print('\n✅ Model uploaded!')
+         print(f'🔗 Link: https://huggingface.co/{OUTPUT_DIR}')
+     except Exception as e:
+         print(f'\n⚠️ Upload failed: {e}')
+         print('The model has been saved to the ./output directory')
+
+ def main():
+     '''Main entry point'''
+     print('''
+ ╔═══════════════════════════════════════════════════════════╗
+ ║           Network Admin LLM - QLoRA Fine-tuning            ║
+ ║           Base: microsoft/Phi-4-mini-instruct              ║
+ ╚═══════════════════════════════════════════════════════════╝
+ ''')
+
+     # Check for a GPU
+     print(f'🖥️ GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else "no GPU"}')
+     if torch.cuda.is_available():
+         print(f'   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')
+
+     # Check dependencies
+     if not install_dependencies():
+         sys.exit(1)
+
+     # Load the data
+     train_ds, eval_ds = load_and_prepare_datasets()
+
+     # Set up the model
+     model, tokenizer, lora_config = setup_model_and_tokenizer()
+
+     # Set up the trainer
+     trainer, run_name = setup_trainer(model, tokenizer, train_ds, eval_ds, lora_config)
+
+     # Train
+     success = train_model(trainer)
+
+     # Save
+     if success:
+         save_and_upload(trainer)
+
+     print_section('DONE')
+     print(f'Run name: {run_name}')
+
+ if __name__ == '__main__':
+     main()
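
Once training finishes, the adapter can be loaded for inference. A minimal sketch, assuming the adapter was pushed to the default OUTPUT_DIR of YOUR_HF_USERNAME/network-admin-phi4-mini (adjust to your repo id; the prompt is a hypothetical example):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE = 'microsoft/Phi-4-mini-instruct'
ADAPTER = 'YOUR_HF_USERNAME/network-admin-phi4-mini'  # hypothetical; match OUTPUT_DIR

tokenizer = AutoTokenizer.from_pretrained(BASE)
model = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True
)
model = PeftModel.from_pretrained(model, ADAPTER)  # attach the LoRA adapter

messages = [
    {'role': 'system', 'content': 'You are a network administration expert.'},
    {'role': 'user', 'content': 'What is the purpose of VLAN trunking?'},
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors='pt'
).to(model.device)
output = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))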