Spaces:
Sleeping
Sleeping
from __future__ import annotations

import argparse
import json
import math
import os
import random
import sys
from pathlib import Path

# Make the sibling modules (``model_utils``, ``dataset``) importable even when
# this file is loaded from another working directory. This must run BEFORE the
# local imports below; placed after them it cannot help them resolve.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
if CURRENT_DIR not in sys.path:
    sys.path.insert(0, CURRENT_DIR)

import torch  # noqa: E402
from torch.cuda.amp import GradScaler, autocast  # noqa: E402
from torch.utils.data import DataLoader  # noqa: E402
from tqdm import tqdm  # noqa: E402
from transformers import get_linear_schedule_with_warmup  # noqa: E402

from dataset import ReACCGeneratorDataset, collate_batch, load_jsonl  # noqa: E402
from model_utils import (  # noqa: E402
    evaluate_generation,
    generate_completion,
    load_model_and_tokenizer,
    perplexity_from_loss,
)

# Speed-ups on CUDA-capable environments.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def set_seed(seed: int = 42):
    """Seed Python's ``random`` module and PyTorch with the same value.

    The CUDA generators are seeded as well, but only when CUDA is available.
    """
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def get_safe_device(cpu_only: bool = False):
    """Pick the device to run on.

    Returns the CPU device when ``cpu_only`` is set or CUDA is unavailable.
    Otherwise the visible GPUs are scanned in index order and the first one
    whose compute capability major version is at least 7 is made current and
    returned. If none qualifies, a warning is printed and the CPU is used.
    """
    cpu = torch.device('cpu')
    if cpu_only:
        return cpu
    if not torch.cuda.is_available():
        return cpu

    for index in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(index)
        gpu_name = torch.cuda.get_device_name(index)
        print(f'Detected GPU {index}: {gpu_name} (sm_{major}{minor})')
        # current Kaggle torch builds generally support >= sm70 cleanly
        if major < 7:
            continue
        torch.cuda.set_device(index)
        print(f'[INFO] Using supported GPU {index}: {gpu_name}')
        return torch.device(f'cuda:{index}')

    print('[WARN] No supported GPU found for this PyTorch build. Falling back to CPU.')
    return cpu
def _collate_with_pad(batch, pad_token_id):
    """Module-level wrapper around ``collate_batch`` so the collate callable can be pickled."""
    return collate_batch(batch, pad_token_id)


def build_dataloader(dataset, tokenizer, batch_size: int, shuffle: bool, num_workers: int = 0, pin_memory: bool = False):
    """Wrap ``dataset`` in a ``DataLoader`` that collates with the tokenizer's pad id.

    Args:
        dataset: Dataset whose items are accepted by ``collate_batch``.
        tokenizer: Object exposing ``pad_token_id``; the value is read once,
            when the loader is built.
        batch_size: Number of examples per batch.
        shuffle: Whether the loader shuffles the data.
        num_workers: Number of worker processes used for loading.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from functools import partial

    # A lambda cannot be pickled, which breaks ``num_workers > 0`` whenever
    # worker processes are started with "spawn". A partial over a module-level
    # function pickles fine.
    collate_fn = partial(_collate_with_pad, pad_token_id=tokenizer.pad_token_id)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
def train_one_epoch(model, loader, optimizer, scheduler, device, grad_accum: int, max_grad_norm: float, scaler: GradScaler):
    """Run one training pass over ``loader`` with gradient accumulation and optional AMP.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=..., labels=...)``
            whose output exposes ``.loss``.
        loader: Iterable of dict batches; only the keys ``input_ids``,
            ``attention_mask`` and ``labels`` are forwarded to the model.
        optimizer: Optimizer stepped through ``scaler``.
        scheduler: LR scheduler, advanced once per optimizer update that was
            actually applied.
        device: Device the batch tensors are moved to.
        grad_accum: Number of micro-batches per optimizer update; values below
            1 are treated as 1.
        max_grad_norm: Clipping threshold for the global gradient norm.
        scaler: ``GradScaler``; autocast is enabled iff the scaler is enabled.

    Returns:
        Mean un-scaled loss over the micro-batches that were used, or ``nan``
        when every batch was skipped (or the loader was empty).
    """
    # A value of 0 would otherwise raise ZeroDivisionError in the modulo below;
    # train_main already clamps the same quantity when sizing the schedule.
    grad_accum = max(1, int(grad_accum))

    model.train()
    total_loss = 0.0
    used_steps = 0
    optimizer.zero_grad(set_to_none=True)
    pbar = tqdm(loader, desc='Training', leave=False)
    amp_enabled = scaler.is_enabled()
    num_batches = len(loader)
    model_keys = ('input_ids', 'attention_mask', 'labels')

    for step, batch in enumerate(pbar):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items() if k in model_keys}
        with autocast(enabled=amp_enabled):
            outputs = model(**batch)
            loss = outputs.loss

        if not torch.isfinite(loss):
            print(
                f'[WARN] Skipping non-finite training loss at step {step}: {loss.item()}')
            # Drop whatever was accumulated so far in this window as well.
            optimizer.zero_grad(set_to_none=True)
            continue

        # Scale down so the accumulated gradient averages over the window.
        loss = loss / grad_accum
        scaler.scale(loss).backward()
        total_loss += float(loss.item()) * grad_accum
        used_steps += 1

        # Update at the end of every accumulation window and on the final batch.
        if (step + 1) % grad_accum == 0 or (step + 1) == num_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # GradScaler skips optimizer.step() when it finds inf/nan gradients
            # and then lowers the scale. Only advance the LR schedule when the
            # optimizer really stepped.
            if scaler.get_scale() >= scale_before:
                scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        if used_steps > 0:
            pbar.set_postfix(loss=f'{total_loss / used_steps:.4f}')

    if used_steps == 0:
        return float('nan')
    return total_loss / used_steps
def evaluate_loss(model, loader, device, use_amp: bool = True):
    """Compute the mean batch loss (and its perplexity) over ``loader``.

    The model is switched to eval mode and is not switched back. Batches whose
    loss is not finite are skipped with a warning.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and
            ``labels``; its output must expose ``.loss``.
        loader: Iterable of dict batches.
        device: Device the batch tensors are moved to.
        use_amp: Enable autocast; only takes effect on CUDA devices.

    Returns:
        ``(avg_loss, perplexity_from_loss(avg_loss))``, or ``(nan, nan)`` when
        no batch produced a finite loss.
    """
    model.eval()
    total_loss = 0.0
    count = 0
    amp_enabled = (device.type == 'cuda') and use_amp
    model_keys = ('input_ids', 'attention_mask', 'labels')

    # Evaluation needs no gradients; without no_grad every forward pass keeps
    # autograd state alive and inflates memory use.
    with torch.no_grad():
        for batch in tqdm(loader, desc='Eval(loss)', leave=False):
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items() if k in model_keys}
            with autocast(enabled=amp_enabled):
                outputs = model(**batch)
            loss_value = float(outputs.loss.item())
            if not math.isfinite(loss_value):
                print('[WARN] Skipping non-finite eval batch loss')
                continue
            total_loss += loss_value
            count += 1

    if count == 0:
        return float('nan'), float('nan')
    avg_loss = total_loss / count
    return avg_loss, perplexity_from_loss(avg_loss)
def evaluate_generation_on_file(model, tokenizer, path: str, device, max_length: int, max_new_tokens: int, do_sample: bool = False, use_amp: bool = True):
    """Generate a completion for every row of a JSONL file and score the results.

    Each row supplies ``retrieved`` and ``context`` (missing keys default to
    ``''``) as the prompt parts and ``target`` as the reference. Generation
    uses ``'<EOL>'`` as its stop string and runs under autocast when ``use_amp``
    is set and the device is CUDA.

    Returns:
        ``(evaluate_generation(predictions, references), predictions)``.
    """
    examples = load_jsonl(path)
    autocast_on = (device.type == 'cuda') and use_amp

    def _complete(example):
        # One generation call per row, under (optional) mixed precision.
        with autocast(enabled=autocast_on):
            return generate_completion(
                model=model,
                tokenizer=tokenizer,
                retrieved=example.get('retrieved', ''),
                context=example.get('context', ''),
                device=device,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                stop_strings=['<EOL>'],
            )

    predictions = []
    references = []
    for example in tqdm(examples, desc='Eval(gen)', leave=False):
        predictions.append(_complete(example))
        references.append(example.get('target', ''))
    return evaluate_generation(predictions, references), predictions
def ensure_dir(path: str):
    """Create directory ``path`` (and any missing parents); a no-op if it already exists."""
    os.makedirs(name=path, exist_ok=True)
def save_metrics(path: str, metrics: dict):
    """Write ``metrics`` to ``path`` as UTF-8 JSON, indented by 2, non-ASCII kept verbatim."""
    with open(path, mode='w', encoding='utf-8') as handle:
        handle.write(json.dumps(metrics, ensure_ascii=False, indent=2))
def save_predictions(path: str, rows, preds):
    """Write one JSON object per line: each row copied, plus a ``prediction`` field.

    Rows and predictions are paired positionally with ``zip``, so any surplus
    on either side is ignored. The input rows are not modified.
    """
    with open(path, mode='w', encoding='utf-8') as handle:
        for example, prediction in zip(rows, preds):
            record = dict(example)
            record.update(prediction=prediction)
            line = json.dumps(record, ensure_ascii=False)
            handle.write(line + '\n')
def resolve_model_path(model_name_or_path: str) -> str:
    """Resolve a local checkpoint path or pass a Hub-style repo id through.

    Resolution order:
      1. If the path exists locally it is returned (as ``str(Path(...))``).
      2. If it names a missing ``checkpoint-best`` whose sibling
         ``checkpoint-last`` exists, that sibling is returned with a warning.
      3. If it looks like a ``namespace/name`` repo id (for example
         ``'microsoft/CodeGPT-small-py'``) it is returned untouched.

    Raises:
        FileNotFoundError: If none of the above applies. This includes missing
            relative paths such as ``./save/run/checkpoint-best``, which must
            not be mistaken for a repo id just because they contain ``/``.
    """
    p = Path(model_name_or_path)
    if p.exists():
        return str(p)

    if p.name == 'checkpoint-best':
        alt = p.parent / 'checkpoint-last'
        if alt.exists():
            print(f'[WARN] {p} not found. Falling back to {alt}')
            return str(alt)

    # A repo id has exactly two non-empty components and does not start with a
    # path marker. Anything deeper, absolute, or ./ ../ ~/ is a missing path.
    parts = model_name_or_path.split('/')
    looks_like_repo_id = (
        len(parts) == 2
        and all(parts)
        and parts[0] not in ('.', '..', '~')
    )
    if looks_like_repo_id:
        return model_name_or_path

    raise FileNotFoundError(f'Model path not found: {model_name_or_path}')
def train_main(args):
    """Fine-tune the generator and write checkpoints and per-epoch metrics.

    Reads from ``args``: ``seed``, ``num_threads``, ``cpu_only``, ``no_fp16``,
    ``model_name_or_path``, ``train_file``, ``valid_file``, ``block_size``,
    ``max_target_length``, ``per_device_train_batch_size``,
    ``per_device_eval_batch_size``, ``learning_rate``, ``weight_decay``,
    ``gradient_accumulation_steps``, ``num_train_epochs``, ``warmup_ratio``,
    ``max_grad_norm`` and ``output_dir``.

    Outputs under ``args.output_dir``:
      * ``metrics_epoch_<n>.json`` after every epoch;
      * ``checkpoint-best`` whenever a finite validation loss improves on the
        best seen so far (only when a validation loader exists);
      * ``checkpoint-last`` holding the final model and tokenizer.

    Raises:
        ValueError: If the training dataset is empty after filtering.
    """
    set_seed(args.seed)
    if args.num_threads and args.num_threads > 0:
        torch.set_num_threads(args.num_threads)

    device = get_safe_device(args.cpu_only)
    print('Using device:', device)
    if device.type == 'cuda':
        print('GPU:', torch.cuda.get_device_name(
            device.index if device.index is not None else 0))

    # Mixed precision only on CUDA and only when not disabled by the user.
    use_amp = (device.type == 'cuda') and (not args.no_fp16)
    scaler = GradScaler(enabled=use_amp)
    print('AMP enabled:', scaler.is_enabled())

    print('Loading base model from:', args.model_name_or_path)
    tokenizer, model = load_model_and_tokenizer(args.model_name_or_path)
    model.to(device)

    train_rows = load_jsonl(args.train_file)
    valid_rows = load_jsonl(args.valid_file) if args.valid_file else []

    train_dataset = ReACCGeneratorDataset(
        train_rows,
        tokenizer,
        max_length=args.block_size,
        max_target_length=args.max_target_length,
    )
    valid_loader = None
    if len(train_dataset) == 0:
        raise ValueError(
            'No valid training examples after filtering. Check your JSONL targets.')

    pin = device.type == 'cuda'
    # Same guard as the torch.set_num_threads call above, so a missing or
    # non-positive thread count simply means "no worker processes".
    num_workers = min(2, args.num_threads if args.num_threads and args.num_threads > 0 else 0)
    train_loader = build_dataloader(
        train_dataset,
        tokenizer,
        args.per_device_train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin,
    )

    if valid_rows:
        valid_dataset = ReACCGeneratorDataset(
            valid_rows,
            tokenizer,
            max_length=args.block_size,
            max_target_length=args.max_target_length,
        )
        if len(valid_dataset) > 0:
            valid_loader = build_dataloader(
                valid_dataset,
                tokenizer,
                args.per_device_eval_batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin,
            )
        else:
            print(
                '[WARN] No valid validation examples after filtering. Validation will be skipped.')

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    # Clamp once and use the same value for the schedule length and the
    # training loop, so the two can never disagree (and 0 cannot reach the
    # loop's modulo).
    grad_accum = max(1, args.gradient_accumulation_steps)
    total_updates = math.ceil(len(train_loader) / grad_accum) * args.num_train_epochs
    warmup_steps = int(total_updates * args.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, warmup_steps, total_updates)

    ensure_dir(args.output_dir)
    best_val = float('inf')

    for epoch in range(args.num_train_epochs):
        train_loss = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            grad_accum=grad_accum,
            max_grad_norm=args.max_grad_norm,
            scaler=scaler,
        )
        metrics = {'epoch': epoch + 1, 'train_loss': train_loss}

        if valid_loader is not None:
            val_loss, val_ppl = evaluate_loss(
                model, valid_loader, device, use_amp=use_amp)
            metrics.update({'valid_loss': val_loss, 'valid_ppl': val_ppl})
            if math.isfinite(val_loss):
                # best_val starts at +inf, so the first finite loss always wins.
                if val_loss < best_val:
                    best_val = val_loss
                    save_dir = os.path.join(args.output_dir, 'checkpoint-best')
                    ensure_dir(save_dir)
                    model.save_pretrained(save_dir)
                    tokenizer.save_pretrained(save_dir)
                    print(
                        f'[INFO] Saved checkpoint-best at epoch {epoch + 1} with valid_loss={val_loss:.6f}')
            else:
                print(
                    '[WARN] Validation loss is non-finite; checkpoint-best will NOT be updated.')

        save_metrics(os.path.join(args.output_dir,
                                  f'metrics_epoch_{epoch + 1}.json'), metrics)

    # NOTE(review): the original indentation was lost in this copy of the file;
    # the final checkpoint is assumed to be written once, after the epoch loop.
    # Confirm against the upstream source.
    last_dir = os.path.join(args.output_dir, 'checkpoint-last')
    ensure_dir(last_dir)
    model.save_pretrained(last_dir)
    tokenizer.save_pretrained(last_dir)
    print(f'[INFO] Saved checkpoint-last to {last_dir}')
def eval_main(args):
    """Evaluate a checkpoint on ``args.valid_file`` and report the metrics.

    Always computes loss and perplexity. With ``args.eval_generation`` set it
    also runs generation over the same file, merges those metrics in, and
    writes predictions when ``args.save_predictions_path`` is given. The
    metrics are printed as JSON and, if ``args.output_dir`` is set, stored in
    ``eval_metrics.json`` inside it.

    Raises:
        ValueError: If the evaluation dataset is empty after filtering.
    """
    threads = args.num_threads
    if threads and threads > 0:
        torch.set_num_threads(threads)

    device = get_safe_device(args.cpu_only)
    print('Using device:', device)
    on_cuda = device.type == 'cuda'
    use_amp = on_cuda and (not args.no_fp16)
    print('AMP enabled:', use_amp)

    model_path = resolve_model_path(args.model_name_or_path)
    print('Loading model from:', model_path)
    tokenizer, model = load_model_and_tokenizer(model_path)
    model.to(device)

    rows = load_jsonl(args.valid_file)
    dataset = ReACCGeneratorDataset(
        rows,
        tokenizer,
        max_length=args.block_size,
        max_target_length=args.max_target_length,
    )
    if len(dataset) == 0:
        raise ValueError(
            'No valid evaluation examples after filtering. Check your JSONL targets.')

    worker_count = min(2, args.num_threads if args.num_threads > 0 else 0)
    loader = build_dataloader(
        dataset,
        tokenizer,
        args.per_device_eval_batch_size,
        shuffle=False,
        num_workers=worker_count,
        pin_memory=on_cuda,
    )

    loss, ppl = evaluate_loss(model, loader, device, use_amp=use_amp)
    metrics = {'loss': loss, 'perplexity': ppl}

    if args.eval_generation:
        gen_metrics, preds = evaluate_generation_on_file(
            model,
            tokenizer,
            args.valid_file,
            device,
            args.block_size,
            args.max_new_tokens,
            do_sample=args.do_sample,
            use_amp=use_amp,
        )
        metrics.update(gen_metrics)
        if args.save_predictions_path:
            save_predictions(args.save_predictions_path, rows, preds)

    print(json.dumps(metrics, indent=2, ensure_ascii=False))

    if args.output_dir:
        ensure_dir(args.output_dir)
        metrics_file = os.path.join(args.output_dir, 'eval_metrics.json')
        save_metrics(metrics_file, metrics)
def predict_main(args):
    """Generate a completion for every row of ``args.test_file``.

    Predictions are written to ``args.save_predictions_path`` when it is set;
    otherwise the first ten are printed.
    """
    threads = args.num_threads
    if threads and threads > 0:
        torch.set_num_threads(threads)

    device = get_safe_device(args.cpu_only)
    print('Using device:', device)
    use_amp = (device.type == 'cuda') and (not args.no_fp16)
    print('AMP enabled:', use_amp)

    model_path = resolve_model_path(args.model_name_or_path)
    print('Loading model from:', model_path)
    tokenizer, model = load_model_and_tokenizer(model_path)
    model.to(device)

    rows = load_jsonl(args.test_file)

    def _complete(example):
        # One generation call per row, under (optional) mixed precision.
        with autocast(enabled=use_amp):
            return generate_completion(
                model=model,
                tokenizer=tokenizer,
                retrieved=example.get('retrieved', ''),
                context=example.get('context', ''),
                device=device,
                max_length=args.block_size,
                max_new_tokens=args.max_new_tokens,
                do_sample=args.do_sample,
                stop_strings=['<EOL>'],
            )

    preds = [_complete(example) for example in tqdm(rows, desc='Predict')]

    if not args.save_predictions_path:
        print('\nSample predictions:')
        for idx, text in enumerate(preds[:10]):
            print(f'[{idx}] {text}')
        return

    save_predictions(args.save_predictions_path, rows, preds)
    print(f'[INFO] Saved predictions to {args.save_predictions_path}')
def build_arg_parser():
    """Build the command-line parser shared by the train, eval and predict tasks.

    Options are registered in a fixed order (task, paths, sizes and batch
    settings, optimizer hyper-parameters, runtime settings, boolean switches)
    so the generated ``--help`` output keeps its layout.
    """
    parser = argparse.ArgumentParser(
        description='ReACC-style generator runner')
    add = parser.add_argument

    add('--task', type=str, default='train',
        choices=['train', 'eval', 'predict'])

    # String-valued options: model location and data/output paths.
    text_options = (
        ('--model_name_or_path', 'microsoft/CodeGPT-small-py'),
        ('--output_dir', './save/reacc_gen'),
        ('--train_file', None),
        ('--valid_file', None),
        ('--test_file', None),
        ('--save_predictions_path', None),
    )
    for flag, default in text_options:
        add(flag, type=str, default=default)

    # Integer options: sequence lengths, epochs, batch sizes, accumulation.
    size_options = (
        ('--block_size', 384),
        ('--max_target_length', 96),
        ('--max_new_tokens', 64),
        ('--num_train_epochs', 1),
        ('--per_device_train_batch_size', 1),
        ('--per_device_eval_batch_size', 1),
        ('--gradient_accumulation_steps', 8),
    )
    for flag, default in size_options:
        add(flag, type=int, default=default)

    # Float options: optimizer and schedule hyper-parameters.
    optim_options = (
        ('--learning_rate', 2e-5),
        ('--weight_decay', 0.01),
        ('--warmup_ratio', 0.06),
        ('--max_grad_norm', 1.0),
    )
    for flag, default in optim_options:
        add(flag, type=float, default=default)

    for flag, default in (('--seed', 42), ('--num_threads', 2)):
        add(flag, type=int, default=default)

    # Boolean switches, all off by default.
    for flag in ('--cpu_only', '--do_sample', '--eval_generation'):
        add(flag, action='store_true')
    add('--no_fp16', action='store_true',
        help='Disable automatic mixed precision on CUDA')

    return parser
if __name__ == '__main__':
    # Script entry point: parse the CLI, check that the file argument the
    # chosen task needs was supplied, then hand over to that task's main.
    parser = build_arg_parser()
    args = parser.parse_args()

    # task -> (name of the required file argument, entry point)
    dispatch = {
        'train': ('train_file', train_main),
        'eval': ('valid_file', eval_main),
    }
    # Anything that is not train/eval is handled as the predict task.
    task_label = args.task if args.task in dispatch else 'predict'
    required_attr, entry_point = dispatch.get(
        args.task, ('test_file', predict_main))

    if not getattr(args, required_attr):
        raise ValueError(f'--{required_attr} is required for {task_label} task')
    entry_point(args)