import torch
from torch.utils.data import DataLoader
from typing import Dict
from tqdm import tqdm
from torch.amp import autocast, GradScaler


class ModelTrainer:
    """Wraps a training/evaluation loop with optional CUDA mixed precision (AMP)."""

    def __init__(self, model, optimizer, criterion, device, scaler: GradScaler = None, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scheduler = scheduler
        # Use automatic mixed precision only when running on a CUDA device.
        self.use_amp = device.type == 'cuda'
        # Create a default GradScaler if none is supplied; disable it when AMP is off
        # so the trainer can also be constructed safely on CPU-only machines.
        self.scaler = scaler or GradScaler('cuda', enabled=self.use_amp)

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        self.model.train()
        total_loss = 0.0

        for batch in tqdm(dataloader, desc="Training"):
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['labels'].to(self.device)

            self.optimizer.zero_grad()

            if self.use_amp:
                # Forward pass in mixed precision; scale the loss before backward
                # to avoid gradient underflow in float16.
                with autocast('cuda'):
                    outputs = self.model(input_ids, attention_mask)
                    loss = self.criterion(outputs, labels)

                self.scaler.scale(loss).backward()

                # Unscale before clipping so the norm is computed on true gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(input_ids, attention_mask)
                loss = self.criterion(outputs, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            # Per-step scheduler (e.g. linear warmup/decay), stepped once per batch.
            if self.scheduler is not None:
                self.scheduler.step()

            total_loss += loss.item()

        return {'loss': total_loss / len(dataloader)}

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        self.model.eval()
        total_loss = 0.0
        predictions = []
        true_labels = []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['labels'].to(self.device)

                if self.use_amp:
                    with autocast('cuda'):
                        outputs = self.model(input_ids, attention_mask)
                        loss = self.criterion(outputs, labels)
                else:
                    outputs = self.model(input_ids, attention_mask)
                    loss = self.criterion(outputs, labels)

                # Logits -> probabilities for binary / multi-label classification.
                probs = torch.sigmoid(outputs)

                total_loss += loss.item()
                predictions.extend(probs.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        return {
            'loss': total_loss / len(dataloader),
            'predictions': predictions,
            'true_labels': true_labels
        }