# sandbox-5ca717e4 / network_admin_llm_train.py
# Justin-lee — "Add network admin LLM training script" (commit bbf6ae6, verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Network Admin LLM - QLoRA Fine-tuning Script
=============================================
Base Model: microsoft/Phi-4-mini-instruct
Method: QLoRA SFT (4-bit quantization + LoRA)
Datasets: NetEval + Telecom Intent Config
Run locally with GPU:
pip install transformers trl peft bitsandbytes accelerate datasets trackio
python network_admin_llm_train.py
Or on Google Colab:
!pip install transformers trl peft bitsandbytes accelerate datasets trackio
%cd /content
!python network_admin_llm_train.py
Author: Network Admin LLM Project
'''
import os
import sys
import torch
from datetime import datetime
# ============== CONFIGURATION ==============
# Edit the settings below before running.
MODEL_NAME = 'microsoft/Phi-4-mini-instruct'
HF_USERNAME = 'YOUR_HF_USERNAME' # replace with your HuggingFace username
HF_TOKEN = os.environ.get('HF_TOKEN', 'YOUR_HF_TOKEN') # HF token for upload
# Training hyperparameters
TRAINING_CONFIG = {
    'learning_rate': 2e-4, # LoRA needs a higher learning rate than full fine-tuning
    'num_epochs': 3,
    'batch_size': 4,
    'gradient_accumulation': 4, # effective batch = 16
    'max_seq_length': 2048,
    'lora_r': 16,               # LoRA rank
    'lora_alpha': 32,           # LoRA scaling factor
    'lora_dropout': 0.05,
    'warmup_ratio': 0.1,
}
# Hub repo id the trained adapter is pushed to.
OUTPUT_DIR = f'{HF_USERNAME}/network-admin-phi4-mini'
# ===========================================
def print_section(title):
    """Print *title* framed between two 60-character separator rules."""
    rule = '=' * 60
    print(f'\n{rule}')
    print(f' {title}')
    print(rule)
def install_dependencies():
    """Check that every required training package is importable.

    Returns True when all dependencies are present; otherwise prints the
    pip command for the missing ones and returns False.
    """
    print_section('CHECKING DEPENDENCIES')
    required = ['transformers', 'trl', 'peft', 'bitsandbytes', 'accelerate', 'datasets', 'trackio']
    missing = []
    for pkg in required:
        # pip names use dashes; module names use underscores.
        module_name = pkg.replace('-', '_')
        try:
            __import__(module_name)
        except ImportError:
            missing.append(pkg)
            print(f'❌ {pkg} - 需要安裝')
        else:
            print(f'✅ {pkg}')
    if not missing:
        return True
    print(f'\n請運行: pip install {" ".join(missing)}')
    return False
def load_and_prepare_datasets():
    """Download both training corpora, convert them to chat-format rows,
    merge them, and return a 90/10 (train, eval) split.

    Every returned row carries a 'messages' list of
    system/user/assistant turns suitable for TRL's SFTTrainer.
    """
    from datasets import load_dataset, concatenate_datasets

    print_section('LOADING DATASETS')

    # 1. NetEval: multiple-choice network-certification exam questions.
    print('📚 載入 NetEval 考試題庫...')
    exam_ds = load_dataset('NASP/neteval-exam', split='train')
    print(f' NetEval: {len(exam_ds)} 題')

    def convert_neteval(example):
        """Turn one exam Q&A row into a three-turn conversation."""
        prompt = example['Question'] + (
            f'\nA. {example.get("A", "")}'
            f'\nB. {example.get("B", "")}'
            f'\nC. {example.get("C", "")}'
            f'\nD. {example.get("D", "")}'
        )
        reply = f'正確答案是: {example["Answer"]}'
        if example.get('Explanation'):
            reply += f'\n\n📖 解說: {example["Explanation"]}'
        system_msg = '你是一位網路管理專家。請回答關於網路、安全、路由、交換機、VLAN、防火牆等IT基礎設施的問題。'
        return {
            'messages': [
                {'role': 'system', 'content': system_msg},
                {'role': 'user', 'content': prompt},
                {'role': 'assistant', 'content': reply},
            ]
        }

    exam_chat = exam_ds.map(
        convert_neteval,
        remove_columns=exam_ds.column_names,
        desc='轉換 NetEval 格式',
    )

    # 2. Telecom intent->config dataset (already in 'messages' chat format).
    print('📚 載入電信意圖配置數據集...')
    telecom_ds = load_dataset('nraptisss/telecom-intent-config-sft-10k', split='train')
    print(f' Telecom: {len(telecom_ds)} 條')
    extra_cols = [c for c in telecom_ds.column_names if c != 'messages']
    telecom_chat = telecom_ds.map(
        lambda x: {'messages': x['messages']},
        remove_columns=extra_cols,
    )

    # 3. Merge, then split with a fixed seed for reproducibility.
    print('🔄 合併數據集...')
    combined = concatenate_datasets([exam_chat, telecom_chat])
    split_data = combined.train_test_split(test_size=0.1, seed=42)
    train_ds, eval_ds = split_data['train'], split_data['test']

    print(f'\n📊 數據集統計:')
    print(f' 訓練集: {len(train_ds)} 條')
    print(f' 驗證集: {len(eval_ds)} 條')
    print(f' 總計: {len(combined)} 條')
    return train_ds, eval_ds
def setup_model_and_tokenizer():
    """Load the base model in 4-bit (QLoRA) and attach LoRA adapters.

    Returns (model, tokenizer, lora_config). The returned model is
    already wrapped by get_peft_model, so only the adapter parameters
    (plus modules_to_save) are trainable.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
    print_section('LOADING MODEL')
    print(f'🤖 模型: {MODEL_NAME}')
    # Tokenizer: reuse EOS as the pad token (Phi-4-mini ships no pad token);
    # right padding is the usual choice for causal-LM SFT.
    print('\n📝 載入 Tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'right'
    print(f' Vocab size: {len(tokenizer):,}')
    # QLoRA quantization: 4-bit NF4 weights, bf16 compute, double quantization.
    print('\n⚡ 配置 QLoRA (4-bit)...')
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4', # Normalized Float 4
        bnb_4bit_compute_dtype=torch.bfloat16, # compute dtype for matmuls
        bnb_4bit_use_double_quant=True, # also quantize the quantization constants
    )
    # Load the base model with 4-bit weights, sharded automatically over
    # available devices.
    print('📥 載入模型 (4-bit)...')
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map='auto',
        trust_remote_code=True,
    )
    # Prepare for k-bit training (casts norms, enables input grads);
    # must run BEFORE the LoRA adapters are attached.
    model = prepare_model_for_kbit_training(model)
    print('✅ 模型準備完成')
    # LoRA adapter configuration.
    print('\n🔧 配置 LoRA...')
    lora_config = LoraConfig(
        r=TRAINING_CONFIG['lora_r'],
        lora_alpha=TRAINING_CONFIG['lora_alpha'],
        lora_dropout=TRAINING_CONFIG['lora_dropout'],
        bias='none',
        task_type='CAUSAL_LM',
        target_modules=[
            'q_proj', 'k_proj', 'v_proj', 'o_proj', # attention projections
            'gate_proj', 'up_proj', 'down_proj', # MLP projections
        ],
        # NOTE(review): fully training lm_head and embed_tokens makes the saved
        # checkpoint far larger than a LoRA-only adapter and raises memory use —
        # confirm this is intended for a QLoRA run.
        modules_to_save=['lm_head', 'embed_tokens'],
    )
    # Wrap the quantized model with the trainable adapters.
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model, tokenizer, lora_config
def setup_trainer(model, tokenizer, train_ds, eval_ds, lora_config):
    """Build TRL's SFTTrainer: cosine LR schedule, periodic eval/save,
    Hub checkpoint upload, and optional trackio experiment tracking.

    Returns (trainer, run_name).
    """
    from trl import SFTTrainer, SFTConfig
    print_section('CONFIGURING TRAINER')
    # Timestamped run name, e.g. 'phi4-netadmin-0131-1542'.
    run_name = f'phi4-netadmin-{datetime.now().strftime("%m%d-%H%M")}'
    # Best-effort trackio init; fall back to no reporting on any failure.
    try:
        import trackio
        trackio.init(project='network-admin-llm', experiment='qlora-sft', run_name=run_name)
        print('✅ Trackio 初始化成功')
        report_to = ['trackio']
    except Exception as e:
        print(f'⚠️ Trackio 初始化失敗: {e}')
        report_to = ['none']
    # SFT configuration.
    # NOTE(review): several kwargs are version-sensitive — `evaluation_strategy`
    # was renamed `eval_strategy` in transformers>=4.46 and `max_seq_length`
    # became `max_length` in trl>=0.13; confirm against the pinned versions.
    training_args = SFTConfig(
        # Learning-rate schedule
        learning_rate=TRAINING_CONFIG['learning_rate'],
        lr_scheduler_type='cosine',
        warmup_ratio=TRAINING_CONFIG['warmup_ratio'],
        # Training loop
        num_train_epochs=TRAINING_CONFIG['num_epochs'],
        per_device_train_batch_size=TRAINING_CONFIG['batch_size'],
        gradient_accumulation_steps=TRAINING_CONFIG['gradient_accumulation'],
        max_seq_length=TRAINING_CONFIG['max_seq_length'],
        # Memory optimizations
        gradient_checkpointing=True,
        bf16=True,
        fp16=False,
        # Output / checkpointing
        output_dir='./output',
        logging_steps=10,
        save_steps=500,
        save_total_limit=2, # keep only the two most recent checkpoints
        evaluation_strategy='steps',
        eval_steps=500,
        # Hub upload
        push_to_hub=True,
        hub_model_id=OUTPUT_DIR,
        hub_strategy='checkpoint',
        # Monitoring
        report_to=report_to,
        logging_strategy='steps',
        logging_first_step=True,
        # Misc
        remove_unused_columns=False,
        dataloader_num_workers=4,
        seed=42,
    )
    # Create the trainer.
    # NOTE(review): `model` arrives already wrapped by get_peft_model in
    # setup_model_and_tokenizer, yet `peft_config` is passed again here —
    # this relies on TRL not re-wrapping an existing PeftModel; verify with
    # the installed trl version.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=lora_config,
    )
    return trainer, run_name
def train_model(trainer):
    """Run the training loop.

    Returns True on normal completion and False when the user aborts
    with Ctrl+C; any other exception is reported and re-raised.
    """
    print_section('STARTING TRAINING')
    print('🚀 開始訓練...')
    print(' (按 Ctrl+C 可隨時中斷)')
    print()
    try:
        trainer.train()
    except KeyboardInterrupt:
        print('\n⚠️ 訓練被用戶中斷')
        return False
    except Exception as e:
        print(f'\n❌ 訓練失敗: {e}')
        raise
    print('\n✅ 訓練完成!')
    return True
def save_and_upload(trainer):
    """Push the trained model to the HuggingFace Hub.

    On failure, points the user at the local ./output checkpoint instead
    of raising.
    """
    print_section('SAVING & UPLOADING')
    try:
        print('📤 上傳模型到 HuggingFace Hub...')
        trainer.push_to_hub()
    except Exception as e:
        print(f'\n⚠️ 上傳失敗: {e}')
        print('模型已保存在 ./output 目錄')
    else:
        print(f'\n✅ 模型已上傳!')
        print(f'🔗 連結: https://huggingface.co/{OUTPUT_DIR}')
def main():
    """Entry point: check environment, load data, build the model, train,
    then upload the result to the Hub."""
    print('''
╔═══════════════════════════════════════════════════════════╗
║ Network Admin LLM - QLoRA Fine-tuning ║
║ Base: microsoft/Phi-4-mini-instruct ║
╚═══════════════════════════════════════════════════════════╝
''')
    # Report GPU availability (a GPU is effectively required for 4-bit training).
    print(f'🖥️ GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else "無 GPU"}')
    if torch.cuda.is_available():
        print(f' Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')
    # Abort early if any dependency is missing.
    if not install_dependencies():
        sys.exit(1)
    # Load and convert the datasets.
    train_ds, eval_ds = load_and_prepare_datasets()
    # Load the 4-bit base model and attach the LoRA adapters.
    model, tokenizer, lora_config = setup_model_and_tokenizer()
    # Build the trainer.
    trainer, run_name = setup_trainer(model, tokenizer, train_ds, eval_ds, lora_config)
    # Train (may be interrupted with Ctrl+C).
    success = train_model(trainer)
    # Upload only after a fully completed run.
    if success:
        save_and_upload(trainer)
    print_section('DONE')
    print(f'Run name: {run_name}')
# Script entry point.
if __name__ == '__main__':
    main()