from __future__ import annotations from model_utils import evaluate_generation, generate_completion, load_model_and_tokenizer, perplexity_from_loss from dataset import ReACCGeneratorDataset, collate_batch, load_jsonl import argparse import json import math import os import random import sys from pathlib import Path import torch from torch.utils.data import DataLoader from transformers import get_linear_schedule_with_warmup from tqdm import tqdm from torch.cuda.amp import autocast, GradScaler CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) if CURRENT_DIR not in sys.path: sys.path.insert(0, CURRENT_DIR) # Speed-ups on CUDA-capable environments. torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True def set_seed(seed: int = 42): random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def get_safe_device(cpu_only: bool = False): """Use a supported GPU if available; otherwise fall back to CPU.""" if cpu_only or not torch.cuda.is_available(): return torch.device('cpu') for i in range(torch.cuda.device_count()): major, minor = torch.cuda.get_device_capability(i) name = torch.cuda.get_device_name(i) print(f'Detected GPU {i}: {name} (sm_{major}{minor})') # current Kaggle torch builds generally support >= sm70 cleanly if major >= 7: torch.cuda.set_device(i) print(f'[INFO] Using supported GPU {i}: {name}') return torch.device(f'cuda:{i}') print('[WARN] No supported GPU found for this PyTorch build. Falling back to CPU.') return torch.device('cpu') def build_dataloader(dataset, tokenizer, batch_size: int, shuffle: bool, num_workers: int = 0, pin_memory: bool = False): return DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, collate_fn=lambda batch: collate_batch(batch, tokenizer.pad_token_id), ) def train_one_epoch(model, loader, optimizer, scheduler, device, grad_accum: int, max_grad_norm: float, scaler: GradScaler): model.train() total_loss = 0.0 used_steps = 0 optimizer.zero_grad(set_to_none=True) pbar = tqdm(loader, desc='Training', leave=False) amp_enabled = scaler.is_enabled() for step, batch in enumerate(pbar): batch = {k: v.to(device, non_blocking=True) for k, v in batch.items() if k in ( 'input_ids', 'attention_mask', 'labels')} with autocast(enabled=amp_enabled): outputs = model(**batch) loss = outputs.loss if not torch.isfinite(loss): print( f'[WARN] Skipping non-finite training loss at step {step}: {loss.item()}') optimizer.zero_grad(set_to_none=True) continue loss = loss / grad_accum scaler.scale(loss).backward() total_loss += float(loss.item()) * grad_accum used_steps += 1 if (step + 1) % grad_accum == 0 or (step + 1) == len(loader): scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) scaler.step(optimizer) scaler.update() scheduler.step() optimizer.zero_grad(set_to_none=True) if used_steps > 0: pbar.set_postfix(loss=f'{total_loss / used_steps:.4f}') if used_steps == 0: return float('nan') return total_loss / used_steps @torch.no_grad() def evaluate_loss(model, loader, device, use_amp: bool = True): model.eval() total_loss = 0.0 count = 0 amp_enabled = (device.type == 'cuda') and use_amp for batch in tqdm(loader, desc='Eval(loss)', leave=False): batch = {k: v.to(device, non_blocking=True) for k, v in batch.items() if k in ( 'input_ids', 'attention_mask', 'labels')} with autocast(enabled=amp_enabled): outputs = model(**batch) loss_value = float(outputs.loss.item()) if not math.isfinite(loss_value): print('[WARN] Skipping non-finite eval batch loss') continue total_loss += loss_value count += 1 if count == 0: return float('nan'), float('nan') avg_loss = total_loss / count return avg_loss, perplexity_from_loss(avg_loss) @torch.no_grad() def evaluate_generation_on_file(model, tokenizer, path: str, device, max_length: int, max_new_tokens: int, do_sample: bool = False, use_amp: bool = True): rows = load_jsonl(path) preds, golds = [], [] amp_enabled = (device.type == 'cuda') and use_amp for ex in tqdm(rows, desc='Eval(gen)', leave=False): with autocast(enabled=amp_enabled): pred = generate_completion( model=model, tokenizer=tokenizer, retrieved=ex.get('retrieved', ''), context=ex.get('context', ''), device=device, max_length=max_length, max_new_tokens=max_new_tokens, do_sample=do_sample, stop_strings=[''], ) preds.append(pred) golds.append(ex.get('target', '')) return evaluate_generation(preds, golds), preds def ensure_dir(path: str): os.makedirs(path, exist_ok=True) def save_metrics(path: str, metrics: dict): with open(path, 'w', encoding='utf-8') as f: json.dump(metrics, f, ensure_ascii=False, indent=2) def save_predictions(path: str, rows, preds): with open(path, 'w', encoding='utf-8') as f: for ex, pred in zip(rows, preds): obj = dict(ex) obj['prediction'] = pred f.write(json.dumps(obj, ensure_ascii=False) + '\n') def resolve_model_path(model_name_or_path: str) -> str: """If checkpoint-best is missing, fallback to checkpoint-last for local checkpoints. HF repo ids like 'microsoft/CodeGPT-small-py' are returned untouched. """ p = Path(model_name_or_path) if p.exists(): return str(p) if p.name == 'checkpoint-best': alt = p.parent / 'checkpoint-last' if alt.exists(): print(f'[WARN] {p} not found. Falling back to {alt}') return str(alt) if '/' in model_name_or_path and not model_name_or_path.startswith('/'): return model_name_or_path raise FileNotFoundError(f'Model path not found: {model_name_or_path}') def train_main(args): set_seed(args.seed) if args.num_threads and args.num_threads > 0: torch.set_num_threads(args.num_threads) device = get_safe_device(args.cpu_only) print('Using device:', device) if device.type == 'cuda': print('GPU:', torch.cuda.get_device_name( device.index if device.index is not None else 0)) use_amp = (device.type == 'cuda') and (not args.no_fp16) scaler = GradScaler(enabled=use_amp) print('AMP enabled:', scaler.is_enabled()) print('Loading base model from:', args.model_name_or_path) tokenizer, model = load_model_and_tokenizer(args.model_name_or_path) model.to(device) train_rows = load_jsonl(args.train_file) valid_rows = load_jsonl(args.valid_file) if args.valid_file else [] train_dataset = ReACCGeneratorDataset( train_rows, tokenizer, max_length=args.block_size, max_target_length=args.max_target_length, ) valid_loader = None if len(train_dataset) == 0: raise ValueError( 'No valid training examples after filtering. Check your JSONL targets.') pin = device.type == 'cuda' num_workers = min(2, args.num_threads if args.num_threads > 0 else 0) train_loader = build_dataloader( train_dataset, tokenizer, args.per_device_train_batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin, ) if valid_rows: valid_dataset = ReACCGeneratorDataset( valid_rows, tokenizer, max_length=args.block_size, max_target_length=args.max_target_length, ) if len(valid_dataset) > 0: valid_loader = build_dataloader( valid_dataset, tokenizer, args.per_device_eval_batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin, ) else: print( '[WARN] No valid validation examples after filtering. Validation will be skipped.') optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) total_updates = math.ceil(len( train_loader) / max(1, args.gradient_accumulation_steps)) * args.num_train_epochs warmup_steps = int(total_updates * args.warmup_ratio) scheduler = get_linear_schedule_with_warmup( optimizer, warmup_steps, total_updates) ensure_dir(args.output_dir) best_val = float('inf') for epoch in range(args.num_train_epochs): train_loss = train_one_epoch( model=model, loader=train_loader, optimizer=optimizer, scheduler=scheduler, device=device, grad_accum=args.gradient_accumulation_steps, max_grad_norm=args.max_grad_norm, scaler=scaler, ) metrics = {'epoch': epoch + 1, 'train_loss': train_loss} if valid_loader is not None: val_loss, val_ppl = evaluate_loss( model, valid_loader, device, use_amp=use_amp) metrics.update({'valid_loss': val_loss, 'valid_ppl': val_ppl}) if math.isfinite(val_loss) and ((epoch == 0) or (val_loss < best_val)): best_val = val_loss save_dir = os.path.join(args.output_dir, 'checkpoint-best') ensure_dir(save_dir) model.save_pretrained(save_dir) tokenizer.save_pretrained(save_dir) print( f'[INFO] Saved checkpoint-best at epoch {epoch + 1} with valid_loss={val_loss:.6f}') elif not math.isfinite(val_loss): print( '[WARN] Validation loss is non-finite; checkpoint-best will NOT be updated.') save_metrics(os.path.join(args.output_dir, f'metrics_epoch_{epoch + 1}.json'), metrics) last_dir = os.path.join(args.output_dir, 'checkpoint-last') ensure_dir(last_dir) model.save_pretrained(last_dir) tokenizer.save_pretrained(last_dir) print(f'[INFO] Saved checkpoint-last to {last_dir}') def eval_main(args): if args.num_threads and args.num_threads > 0: torch.set_num_threads(args.num_threads) device = get_safe_device(args.cpu_only) print('Using device:', device) use_amp = (device.type == 'cuda') and (not args.no_fp16) print('AMP enabled:', use_amp) model_path = resolve_model_path(args.model_name_or_path) print('Loading model from:', model_path) tokenizer, model = load_model_and_tokenizer(model_path) model.to(device) rows = load_jsonl(args.valid_file) dataset = ReACCGeneratorDataset( rows, tokenizer, max_length=args.block_size, max_target_length=args.max_target_length) if len(dataset) == 0: raise ValueError( 'No valid evaluation examples after filtering. Check your JSONL targets.') loader = build_dataloader( dataset, tokenizer, args.per_device_eval_batch_size, shuffle=False, num_workers=min(2, args.num_threads if args.num_threads > 0 else 0), pin_memory=(device.type == 'cuda') ) loss, ppl = evaluate_loss(model, loader, device, use_amp=use_amp) metrics = {'loss': loss, 'perplexity': ppl} if args.eval_generation: gen_metrics, preds = evaluate_generation_on_file( model, tokenizer, args.valid_file, device, args.block_size, args.max_new_tokens, do_sample=args.do_sample, use_amp=use_amp ) metrics.update(gen_metrics) if args.save_predictions_path: save_predictions(args.save_predictions_path, rows, preds) print(json.dumps(metrics, indent=2, ensure_ascii=False)) if args.output_dir: ensure_dir(args.output_dir) save_metrics(os.path.join(args.output_dir, 'eval_metrics.json'), metrics) def predict_main(args): if args.num_threads and args.num_threads > 0: torch.set_num_threads(args.num_threads) device = get_safe_device(args.cpu_only) print('Using device:', device) use_amp = (device.type == 'cuda') and (not args.no_fp16) print('AMP enabled:', use_amp) model_path = resolve_model_path(args.model_name_or_path) print('Loading model from:', model_path) tokenizer, model = load_model_and_tokenizer(model_path) model.to(device) rows = load_jsonl(args.test_file) preds = [] for ex in tqdm(rows, desc='Predict'): with autocast(enabled=use_amp): pred = generate_completion( model=model, tokenizer=tokenizer, retrieved=ex.get('retrieved', ''), context=ex.get('context', ''), device=device, max_length=args.block_size, max_new_tokens=args.max_new_tokens, do_sample=args.do_sample, stop_strings=[''], ) preds.append(pred) if args.save_predictions_path: save_predictions(args.save_predictions_path, rows, preds) print(f'[INFO] Saved predictions to {args.save_predictions_path}') else: print('\nSample predictions:') for i, pred in enumerate(preds[:10]): print(f'[{i}] {pred}') def build_arg_parser(): parser = argparse.ArgumentParser( description='ReACC-style generator runner') parser.add_argument('--task', type=str, default='train', choices=['train', 'eval', 'predict']) parser.add_argument('--model_name_or_path', type=str, default='microsoft/CodeGPT-small-py') parser.add_argument('--output_dir', type=str, default='./save/reacc_gen') parser.add_argument('--train_file', type=str, default=None) parser.add_argument('--valid_file', type=str, default=None) parser.add_argument('--test_file', type=str, default=None) parser.add_argument('--save_predictions_path', type=str, default=None) parser.add_argument('--block_size', type=int, default=384) parser.add_argument('--max_target_length', type=int, default=96) parser.add_argument('--max_new_tokens', type=int, default=64) parser.add_argument('--num_train_epochs', type=int, default=1) parser.add_argument('--per_device_train_batch_size', type=int, default=1) parser.add_argument('--per_device_eval_batch_size', type=int, default=1) parser.add_argument('--gradient_accumulation_steps', type=int, default=8) parser.add_argument('--learning_rate', type=float, default=2e-5) parser.add_argument('--weight_decay', type=float, default=0.01) parser.add_argument('--warmup_ratio', type=float, default=0.06) parser.add_argument('--max_grad_norm', type=float, default=1.0) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--num_threads', type=int, default=2) parser.add_argument('--cpu_only', action='store_true') parser.add_argument('--do_sample', action='store_true') parser.add_argument('--eval_generation', action='store_true') parser.add_argument('--no_fp16', action='store_true', help='Disable automatic mixed precision on CUDA') return parser if __name__ == '__main__': parser = build_arg_parser() args = parser.parse_args() if args.task == 'train': if not args.train_file: raise ValueError('--train_file is required for train task') train_main(args) elif args.task == 'eval': if not args.valid_file: raise ValueError('--valid_file is required for eval task') eval_main(args) else: if not args.test_file: raise ValueError('--test_file is required for predict task') predict_main(args)