# UIT.CS2229.ReACC / run_lm.py
# TranTruongMMCII's picture
# Upload 3 files
# 7a911f3 verified
from __future__ import annotations

import argparse
import json
import math
import os
import random
import sys
from pathlib import Path

import torch
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_linear_schedule_with_warmup

# The script's own directory has to be on sys.path *before* the sibling
# modules (model_utils, dataset) are imported; inserting it afterwards, as the
# file previously did, cannot help those imports when this module is loaded
# from a different working directory or as part of a package.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
if CURRENT_DIR not in sys.path:
    sys.path.insert(0, CURRENT_DIR)

from model_utils import (  # noqa: E402  (must follow the sys.path tweak)
    evaluate_generation,
    generate_completion,
    load_model_and_tokenizer,
    perplexity_from_loss,
)
from dataset import ReACCGeneratorDataset, collate_batch, load_jsonl  # noqa: E402

# Speed-ups on CUDA-capable environments.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def set_seed(seed: int = 42):
    """Seed Python's ``random`` module and PyTorch with the same value.

    When CUDA is available, every CUDA device is seeded explicitly as well.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def get_safe_device(cpu_only: bool = False):
    """Pick the torch device to run on.

    The CPU device is returned when ``cpu_only`` is set or CUDA is not
    available. Otherwise each visible GPU is inspected in index order (its
    name and compute capability are printed); the first one whose major
    compute capability is at least 7 is made the current CUDA device and
    returned. If no GPU qualifies, a warning is printed and the CPU device is
    returned.
    """
    cuda_usable = (not cpu_only) and torch.cuda.is_available()
    if not cuda_usable:
        return torch.device('cpu')

    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        name = torch.cuda.get_device_name(i)
        print(f'Detected GPU {i}: {name} (sm_{major}{minor})')
        # current Kaggle torch builds generally support >= sm70 cleanly
        if major < 7:
            continue
        torch.cuda.set_device(i)
        print(f'[INFO] Using supported GPU {i}: {name}')
        return torch.device(f'cuda:{i}')

    print('[WARN] No supported GPU found for this PyTorch build. Falling back to CPU.')
    return torch.device('cpu')
def build_dataloader(dataset, tokenizer, batch_size: int, shuffle: bool, num_workers: int = 0, pin_memory: bool = False):
    """Wrap ``dataset`` in a ``DataLoader`` that collates with ``collate_batch``.

    ``tokenizer.pad_token_id`` is looked up each time a batch is collated and
    passed to ``collate_batch`` as its second argument. All remaining
    arguments are forwarded to ``DataLoader`` unchanged.
    """
    def _collate(batch):
        # Read pad_token_id lazily so a later change on the tokenizer is seen.
        return collate_batch(batch, tokenizer.pad_token_id)

    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': pin_memory,
        'collate_fn': _collate,
    }
    return DataLoader(dataset, **loader_options)
def train_one_epoch(model, loader, optimizer, scheduler, device, grad_accum: int, max_grad_norm: float, scaler: GradScaler):
    """Train ``model`` for one pass over ``loader`` with gradient accumulation.

    Args:
        model: Module called as ``model(**batch)``; its output must expose ``.loss``.
        loader: Iterable of dict batches; only the ``input_ids``,
            ``attention_mask`` and ``labels`` entries are moved to ``device``
            and fed to the model.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: LR scheduler stepped once per optimizer update.
        device: Target device for the batch tensors.
        grad_accum: Number of micro-batches per optimizer update. Values
            below 1 are treated as 1.
        max_grad_norm: Gradient-norm clipping threshold.
        scaler: ``GradScaler``; autocast is enabled iff the scaler is enabled.

    Returns:
        Mean (un-divided) loss over the micro-batches whose loss was finite,
        or NaN when every micro-batch was skipped.
    """
    # The caller sizes the LR schedule with max(1, grad_accum); apply the same
    # clamp here so a value of 0 cannot raise ZeroDivisionError in the
    # division and modulo below.
    grad_accum = max(1, int(grad_accum))

    model.train()
    total_loss = 0.0
    used_steps = 0
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(loader, desc='Training', leave=False)
    amp_enabled = scaler.is_enabled()
    for step, batch in enumerate(pbar):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items() if k in (
            'input_ids', 'attention_mask', 'labels')}
        with autocast(enabled=amp_enabled):
            outputs = model(**batch)
            loss = outputs.loss
        if not torch.isfinite(loss):
            print(
                f'[WARN] Skipping non-finite training loss at step {step}: {loss.item()}')
            # Drop whatever has been accumulated in the current window too, so
            # a bad micro-batch can never leak into an optimizer update.
            optimizer.zero_grad(set_to_none=True)
            continue
        # Scale down so the summed gradients of one window match one batch.
        loss = loss / grad_accum
        scaler.scale(loss).backward()
        # Undo the division for reporting purposes.
        total_loss += float(loss.item()) * grad_accum
        used_steps += 1
        # Update at every window boundary and once more at the end of the
        # epoch so a trailing partial window is not lost.
        if (step + 1) % grad_accum == 0 or (step + 1) == len(loader):
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        if used_steps > 0:
            pbar.set_postfix(loss=f'{total_loss / used_steps:.4f}')
    if used_steps == 0:
        return float('nan')
    return total_loss / used_steps
@torch.no_grad()
def evaluate_loss(model, loader, device, use_amp: bool = True):
    """Compute the mean batch loss of ``model`` over ``loader``.

    The model is switched to eval mode. Only the ``input_ids``,
    ``attention_mask`` and ``labels`` entries of each batch are used. Batches
    with a non-finite loss are reported and left out of the average.

    Returns:
        ``(avg_loss, perplexity_from_loss(avg_loss))``, or ``(nan, nan)`` when
        no batch produced a finite loss.
    """
    model.eval()
    wanted_keys = ('input_ids', 'attention_mask', 'labels')
    amp_enabled = (device.type == 'cuda') and use_amp
    running_total = 0.0
    finite_batches = 0
    for raw_batch in tqdm(loader, desc='Eval(loss)', leave=False):
        inputs = {}
        for key, tensor in raw_batch.items():
            if key in wanted_keys:
                inputs[key] = tensor.to(device, non_blocking=True)
        with autocast(enabled=amp_enabled):
            outputs = model(**inputs)
        loss_value = float(outputs.loss.item())
        if math.isfinite(loss_value):
            running_total += loss_value
            finite_batches += 1
        else:
            print('[WARN] Skipping non-finite eval batch loss')
    if finite_batches == 0:
        return float('nan'), float('nan')
    avg_loss = running_total / finite_batches
    return avg_loss, perplexity_from_loss(avg_loss)
@torch.no_grad()
def evaluate_generation_on_file(model, tokenizer, path: str, device, max_length: int, max_new_tokens: int, do_sample: bool = False, use_amp: bool = True):
    """Generate a completion for every row of a JSONL file and score them.

    Each row's ``retrieved`` and ``context`` fields (empty string when absent)
    are passed to ``generate_completion``; the row's ``target`` field (empty
    string when absent) is used as the reference.

    Args:
        model: Model used for generation; it is put into eval mode.
        tokenizer: Tokenizer forwarded to ``generate_completion``.
        path: JSONL file read with ``load_jsonl``.
        device: Device forwarded to ``generate_completion``.
        max_length: Forwarded as ``max_length``.
        max_new_tokens: Forwarded as ``max_new_tokens``.
        do_sample: Forwarded as ``do_sample``.
        use_amp: Enable autocast; only takes effect on a CUDA device.

    Returns:
        ``(evaluate_generation(preds, golds), preds)``.
    """
    # evaluate_loss switches to eval mode before scoring; do the same here so
    # generation does not depend on the mode the caller left the model in.
    model.eval()
    rows = load_jsonl(path)
    preds, golds = [], []
    amp_enabled = (device.type == 'cuda') and use_amp
    for ex in tqdm(rows, desc='Eval(gen)', leave=False):
        with autocast(enabled=amp_enabled):
            pred = generate_completion(
                model=model,
                tokenizer=tokenizer,
                retrieved=ex.get('retrieved', ''),
                context=ex.get('context', ''),
                device=device,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                stop_strings=['<EOL>'],
            )
        preds.append(pred)
        golds.append(ex.get('target', ''))
    return evaluate_generation(preds, golds), preds
def ensure_dir(path: str):
    """Create directory ``path`` (and any missing parents); no error if it exists."""
    os.makedirs(name=path, exist_ok=True)
def save_metrics(path: str, metrics: dict):
    """Write ``metrics`` to ``path`` as UTF-8 JSON, indented by two spaces.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(path, mode='w', encoding='utf-8') as handle:
        json.dump(metrics, handle, indent=2, ensure_ascii=False)
def save_predictions(path: str, rows, preds):
    """Write one JSON object per line: each input row plus its ``prediction``.

    Args:
        path: Output JSONL file. Missing parent directories are created so a
            long inference run is not lost to a missing output folder.
        rows: Input examples; each is copied with ``dict(ex)`` before the
            ``prediction`` key is set, so the inputs are not mutated.
        preds: Predictions, paired with ``rows`` by position (``zip``
            semantics: extra items on either side are ignored).
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        for ex, pred in zip(rows, preds):
            obj = dict(ex)
            obj['prediction'] = pred
            f.write(json.dumps(obj, ensure_ascii=False) + '\n')
def _looks_like_hub_id(name: str) -> bool:
    """Return True if ``name`` has the shape of a Hub repo id (``repo`` or ``org/repo``)."""
    if not name or '\\' in name:
        return False
    # Explicit filesystem spellings are never repo ids.
    if name.startswith(('.', '/', '~')) or Path(name).is_absolute():
        return False
    parts = name.split('/')
    if len(parts) > 2 or not all(parts):
        return False
    # These are the directory names this script writes; a missing one is a
    # missing local checkpoint, not something to look up remotely.
    return parts[-1] not in ('checkpoint-best', 'checkpoint-last')


def resolve_model_path(model_name_or_path: str) -> str:
    """Resolve a local checkpoint path or pass a Hub repo id through.

    Resolution order:
      1. An existing filesystem path is returned (normalised via ``Path``).
      2. A missing ``.../checkpoint-best`` falls back to a sibling
         ``checkpoint-last`` directory when that one exists.
      3. A name shaped like a Hub repo id (``gpt2`` or
         ``microsoft/CodeGPT-small-py``) is returned untouched.

    Raises:
        FileNotFoundError: if none of the above applies, e.g. for a missing
            ``./relative/dir``, a missing absolute path, or a missing
            checkpoint directory.
    """
    p = Path(model_name_or_path)
    if p.exists():
        return str(p)
    if p.name == 'checkpoint-best':
        alt = p.parent / 'checkpoint-last'
        if alt.exists():
            print(f'[WARN] {p} not found. Falling back to {alt}')
            return str(alt)
    if _looks_like_hub_id(model_name_or_path):
        return model_name_or_path
    raise FileNotFoundError(f'Model path not found: {model_name_or_path}')
def train_main(args):
    """Fine-tune the generator model and write checkpoints and metrics.

    Reads its configuration from ``args`` (see ``build_arg_parser``). After
    every epoch a ``metrics_epoch_<n>.json`` file is written to
    ``args.output_dir``. When a validation loader exists, ``checkpoint-best``
    is saved whenever the validation loss is finite and is the first or the
    lowest seen so far. ``checkpoint-last`` is always saved after the final
    epoch.

    Raises:
        ValueError: if the training dataset is empty after construction.
    """
    set_seed(args.seed)
    if args.num_threads and args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    device = get_safe_device(args.cpu_only)
    print('Using device:', device)
    if device.type == 'cuda':
        print('GPU:', torch.cuda.get_device_name(
            device.index if device.index is not None else 0))

    # Mixed precision is only used on CUDA and can be turned off with --no_fp16.
    use_amp = (device.type == 'cuda') and (not args.no_fp16)
    scaler = GradScaler(enabled=use_amp)
    print('AMP enabled:', scaler.is_enabled())

    print('Loading base model from:', args.model_name_or_path)
    tokenizer, model = load_model_and_tokenizer(args.model_name_or_path)
    model.to(device)

    train_rows = load_jsonl(args.train_file)
    valid_rows = load_jsonl(args.valid_file) if args.valid_file else []

    train_dataset = ReACCGeneratorDataset(
        train_rows,
        tokenizer,
        max_length=args.block_size,
        max_target_length=args.max_target_length,
    )
    valid_loader = None

    if len(train_dataset) == 0:
        raise ValueError(
            'No valid training examples after filtering. Check your JSONL targets.')

    pin = device.type == 'cuda'
    num_workers = min(2, args.num_threads if args.num_threads > 0 else 0)

    train_loader = build_dataloader(
        train_dataset,
        tokenizer,
        args.per_device_train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin,
    )

    if valid_rows:
        valid_dataset = ReACCGeneratorDataset(
            valid_rows,
            tokenizer,
            max_length=args.block_size,
            max_target_length=args.max_target_length,
        )
        if len(valid_dataset) > 0:
            valid_loader = build_dataloader(
                valid_dataset,
                tokenizer,
                args.per_device_eval_batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin,
            )
        else:
            print(
                '[WARN] No valid validation examples after filtering. Validation will be skipped.')

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    # Clamp once and use the same value for both the schedule length and the
    # training loop; previously only the schedule was protected against a
    # value of 0, so the loop could still divide by zero.
    grad_accum = max(1, args.gradient_accumulation_steps)
    total_updates = math.ceil(
        len(train_loader) / grad_accum) * args.num_train_epochs
    warmup_steps = int(total_updates * args.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, warmup_steps, total_updates)

    ensure_dir(args.output_dir)
    best_val = float('inf')

    for epoch in range(args.num_train_epochs):
        train_loss = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            grad_accum=grad_accum,
            max_grad_norm=args.max_grad_norm,
            scaler=scaler,
        )
        metrics = {'epoch': epoch + 1, 'train_loss': train_loss}

        if valid_loader is not None:
            val_loss, val_ppl = evaluate_loss(
                model, valid_loader, device, use_amp=use_amp)
            metrics.update({'valid_loss': val_loss, 'valid_ppl': val_ppl})
            if math.isfinite(val_loss) and ((epoch == 0) or (val_loss < best_val)):
                best_val = val_loss
                save_dir = os.path.join(args.output_dir, 'checkpoint-best')
                ensure_dir(save_dir)
                model.save_pretrained(save_dir)
                tokenizer.save_pretrained(save_dir)
                print(
                    f'[INFO] Saved checkpoint-best at epoch {epoch + 1} with valid_loss={val_loss:.6f}')
            elif not math.isfinite(val_loss):
                print(
                    '[WARN] Validation loss is non-finite; checkpoint-best will NOT be updated.')

        save_metrics(os.path.join(args.output_dir,
                     f'metrics_epoch_{epoch + 1}.json'), metrics)

    # Always keep the final weights, independent of validation results.
    last_dir = os.path.join(args.output_dir, 'checkpoint-last')
    ensure_dir(last_dir)
    model.save_pretrained(last_dir)
    tokenizer.save_pretrained(last_dir)
    print(f'[INFO] Saved checkpoint-last to {last_dir}')
def eval_main(args):
    """Evaluate a checkpoint on ``args.valid_file``.

    Computes loss and perplexity with ``evaluate_loss``. With
    ``args.eval_generation`` set, generation metrics are added and, if
    ``args.save_predictions_path`` is given, the predictions are written
    there. The metrics are printed as JSON and, when ``args.output_dir`` is
    set, also saved to ``eval_metrics.json`` inside it.

    Raises:
        ValueError: if the evaluation dataset is empty after construction.
    """
    threads = args.num_threads
    if threads and threads > 0:
        torch.set_num_threads(threads)

    device = get_safe_device(args.cpu_only)
    print('Using device:', device)
    on_cuda = device.type == 'cuda'
    use_amp = on_cuda and (not args.no_fp16)
    print('AMP enabled:', use_amp)

    model_path = resolve_model_path(args.model_name_or_path)
    print('Loading model from:', model_path)
    tokenizer, model = load_model_and_tokenizer(model_path)
    model.to(device)

    rows = load_jsonl(args.valid_file)
    dataset = ReACCGeneratorDataset(
        rows,
        tokenizer,
        max_length=args.block_size,
        max_target_length=args.max_target_length,
    )
    if len(dataset) == 0:
        raise ValueError(
            'No valid evaluation examples after filtering. Check your JSONL targets.')

    worker_count = min(2, args.num_threads if args.num_threads > 0 else 0)
    loader = build_dataloader(
        dataset,
        tokenizer,
        args.per_device_eval_batch_size,
        shuffle=False,
        num_workers=worker_count,
        pin_memory=(device.type == 'cuda'),
    )

    loss, ppl = evaluate_loss(model, loader, device, use_amp=use_amp)
    metrics = {'loss': loss, 'perplexity': ppl}

    if args.eval_generation:
        gen_metrics, preds = evaluate_generation_on_file(
            model,
            tokenizer,
            args.valid_file,
            device,
            args.block_size,
            args.max_new_tokens,
            do_sample=args.do_sample,
            use_amp=use_amp,
        )
        metrics.update(gen_metrics)
        if args.save_predictions_path:
            save_predictions(args.save_predictions_path, rows, preds)

    print(json.dumps(metrics, indent=2, ensure_ascii=False))

    if not args.output_dir:
        return
    ensure_dir(args.output_dir)
    metrics_file = os.path.join(args.output_dir, 'eval_metrics.json')
    save_metrics(metrics_file, metrics)
def predict_main(args):
    """Generate a completion for every row of ``args.test_file``.

    Each row's ``retrieved`` and ``context`` fields (empty string when absent)
    are passed to ``generate_completion``. Predictions are written to
    ``args.save_predictions_path`` when it is set; otherwise the first ten are
    printed.
    """
    if args.num_threads and args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    device = get_safe_device(args.cpu_only)
    print('Using device:', device)
    use_amp = (device.type == 'cuda') and (not args.no_fp16)
    print('AMP enabled:', use_amp)

    model_path = resolve_model_path(args.model_name_or_path)
    print('Loading model from:', model_path)
    tokenizer, model = load_model_and_tokenizer(model_path)
    model.to(device)
    # Pure inference: match the evaluation helpers by switching to eval mode
    # and disabling autograd bookkeeping for the whole loop.
    model.eval()

    rows = load_jsonl(args.test_file)
    preds = []
    with torch.no_grad():
        for ex in tqdm(rows, desc='Predict'):
            with autocast(enabled=use_amp):
                pred = generate_completion(
                    model=model,
                    tokenizer=tokenizer,
                    retrieved=ex.get('retrieved', ''),
                    context=ex.get('context', ''),
                    device=device,
                    max_length=args.block_size,
                    max_new_tokens=args.max_new_tokens,
                    do_sample=args.do_sample,
                    stop_strings=['<EOL>'],
                )
            preds.append(pred)

    if args.save_predictions_path:
        save_predictions(args.save_predictions_path, rows, preds)
        print(f'[INFO] Saved predictions to {args.save_predictions_path}')
    else:
        print('\nSample predictions:')
        for i, pred in enumerate(preds[:10]):
            print(f'[{i}] {pred}')
def build_arg_parser():
    """Build the command-line parser shared by the train, eval and predict tasks.

    Options are registered in a fixed order: the task selector, the typed
    value options, the boolean switches, and finally ``--no_fp16``.
    """
    parser = argparse.ArgumentParser(
        description='ReACC-style generator runner')
    parser.add_argument('--task', type=str, default='train',
                        choices=['train', 'eval', 'predict'])

    # (flag, type, default) for every option that takes a value.
    valued_options = (
        ('--model_name_or_path', str, 'microsoft/CodeGPT-small-py'),
        ('--output_dir', str, './save/reacc_gen'),
        ('--train_file', str, None),
        ('--valid_file', str, None),
        ('--test_file', str, None),
        ('--save_predictions_path', str, None),
        ('--block_size', int, 384),
        ('--max_target_length', int, 96),
        ('--max_new_tokens', int, 64),
        ('--num_train_epochs', int, 1),
        ('--per_device_train_batch_size', int, 1),
        ('--per_device_eval_batch_size', int, 1),
        ('--gradient_accumulation_steps', int, 8),
        ('--learning_rate', float, 2e-5),
        ('--weight_decay', float, 0.01),
        ('--warmup_ratio', float, 0.06),
        ('--max_grad_norm', float, 1.0),
        ('--seed', int, 42),
        ('--num_threads', int, 2),
    )
    for flag, value_type, default_value in valued_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    for switch in ('--cpu_only', '--do_sample', '--eval_generation'):
        parser.add_argument(switch, action='store_true')

    parser.add_argument('--no_fp16', action='store_true',
                        help='Disable automatic mixed precision on CUDA')
    return parser
if __name__ == '__main__':
    parser = build_arg_parser()
    args = parser.parse_args()

    # task -> (required input file value, error if it is missing, entry point)
    predict_entry = (
        args.test_file, '--test_file is required for predict task', predict_main)
    dispatch = {
        'train': (args.train_file, '--train_file is required for train task', train_main),
        'eval': (args.valid_file, '--valid_file is required for eval task', eval_main),
    }
    # Any task other than train/eval is handled as predict.
    required_file, missing_message, entry_point = dispatch.get(
        args.task, predict_entry)
    if not required_file:
        raise ValueError(missing_message)
    entry_point(args)