# KHLR / Scripts / train.py
# (Uploaded with huggingface_hub, revision 384d3b8.)
"""
Kurdish Handwritten Line Recognition - Training Script
DenseNet121-Transformer Architecture with Constrained Synthetic Line Generation
Usage:
python train.py --data_dir ./data/DASTNUS --vocab_path ./vocab.json
python train.py --data_dir ./data/DASTNUS --vocab_path ./vocab.json --use_synthetic --use_writer_mixing
"""
import os
import glob
import time
import argparse
import json
import math
import random
import re
import numpy as np
from PIL import Image, ImageFilter
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import ConcatDataset
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
# ===============================
# Argument Parser
# ===============================
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options. ``--data_dir`` and
        ``--vocab_path`` are required; every other option has a default.
    """
    cli = argparse.ArgumentParser(description="Kurdish Handwritten Line Recognition Training")
    add = cli.add_argument

    # Data paths
    add("--data_dir", type=str, required=True,
        help="Root directory of DASTNUS dataset")
    add("--vocab_path", type=str, required=True,
        help="Path to vocabulary JSON file (vocab.json)")
    add("--synthetic_dir", type=str, default=None,
        help="Directory containing synthetic handwritten lines")
    add("--fixed_lines_dir", type=str, default=None,
        help="Directory containing fixed-content handwritten lines")

    # Data options
    add("--use_synthetic", action="store_true",
        help="Include synthetic handwritten lines in training")
    add("--use_writer_mixing", action="store_true",
        help="Include fixed-content lines from random writers")
    add("--num_writers", type=int, default=50,
        help="Number of writers to randomly select for mixing")

    # Options that carry no help text, registered in their original order:
    # image dimensions, training hyperparameters, model size, early stopping.
    plain_options = (
        ("--img_height", int, 96),
        ("--img_width", int, 1235),
        ("--batch_size", int, 64),
        ("--num_epochs", int, 80),
        ("--learning_rate", float, 5e-4),
        ("--grad_clip", float, 5.0),
        ("--weight_decay", float, 1e-4),
        ("--seed", int, 42),
        ("--hidden_size", int, 256),
        ("--encoder_layers", int, 3),
        ("--decoder_layers", int, 3),
        ("--num_heads", int, 8),
        ("--dropout", float, 0.4),
        ("--ff_dim", int, 1024),
        ("--patience", int, 10),
    )
    for flag, value_type, default in plain_options:
        add(flag, type=value_type, default=default)

    # Augmentation
    add("--no_aug", action="store_true",
        help="Disable adaptive augmentation")

    # Output
    add("--output_dir", type=str, default="./output",
        help="Directory to save models and logs")

    return cli.parse_args()
# ===============================
# Vocabulary Loader
# ===============================
def load_vocabulary(vocab_path):
    """Load the character vocabulary from a JSON file.

    The JSON object must contain either ``"vocab_list"`` (characters ordered
    by index) or ``"char_to_idx"`` (character -> index mapping). When both
    are present ``"vocab_list"`` wins.

    Args:
        vocab_path: path to the vocabulary JSON file.

    Returns:
        Tuple ``(char_list, char_to_idx, idx_to_char, PAD_token, SOS_token,
        EOS_token)``. The three special-token ids are always 0, 1 and 2.

    Raises:
        ValueError: if neither key is present, or if ``"char_to_idx"`` does
            not map its characters onto exactly the indices ``0..n-1``.
    """
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab_data = json.load(f)

    if "vocab_list" in vocab_data:
        char_list = vocab_data["vocab_list"]
    elif "char_to_idx" in vocab_data:
        mapping = vocab_data["char_to_idx"]
        size = len(mapping)
        char_list = [None] * size
        for char, idx in mapping.items():
            # Reject anything that would leave a hole in the table or write
            # to the wrong slot (negative indices would silently wrap).
            is_valid_index = (
                isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < size
            )
            if not is_valid_index:
                raise ValueError(
                    f"Invalid index {idx!r} for character {char!r}: "
                    f"indices must be integers in [0, {size - 1}]"
                )
            if char_list[idx] is not None:
                raise ValueError(
                    f"Index {idx} is assigned to both {char_list[idx]!r} and {char!r}"
                )
            char_list[idx] = char
    else:
        raise ValueError("Vocabulary JSON must contain 'vocab_list' or 'char_to_idx'")

    char_to_idx = {char: idx for idx, char in enumerate(char_list)}
    idx_to_char = dict(enumerate(char_list))
    PAD_token = 0
    SOS_token = 1
    EOS_token = 2
    return char_list, char_to_idx, idx_to_char, PAD_token, SOS_token, EOS_token
# ===============================
# Helper Functions
# ===============================
def tensor_to_text(tensor, idx_to_char, PAD_token, SOS_token, EOS_token):
    """Decode a sequence of character indices into a string.

    PAD and SOS ids are skipped, decoding stops at the first EOS id, and ids
    that are missing from ``idx_to_char`` are dropped.

    Args:
        tensor: a ``torch.Tensor`` or any iterable of integer ids.
        idx_to_char: mapping from id to character.
        PAD_token, SOS_token, EOS_token: special-token ids.

    Returns:
        The decoded text.
    """
    token_ids = tensor.cpu().tolist() if isinstance(tensor, torch.Tensor) else tensor
    pieces = []
    for token_id in token_ids:
        if token_id in (PAD_token, SOS_token):
            continue
        if token_id == EOS_token:
            break
        if token_id in idx_to_char:
            pieces.append(idx_to_char[token_id])
    return "".join(pieces)
def extract_writer_id(filename):
    """Return the writer number encoded in a file name, or ``None``.

    The base name must start with ``DNDK`` followed by digits; those digits
    are returned as an int (e.g. ``DNDK00002_2_1.tif`` -> ``2``).
    """
    found = re.match(r"DNDK(\d+)", os.path.basename(filename))
    return int(found.group(1)) if found else None
def get_unique_writers(directory):
    """Return the sorted, de-duplicated writer IDs of the ``*.tif`` files in ``directory``.

    Files whose name yields no writer ID are ignored.
    """
    tif_paths = glob.glob(os.path.join(directory, "*.tif"))
    parsed_ids = (extract_writer_id(path) for path in tif_paths)
    return sorted({writer_id for writer_id in parsed_ids if writer_id is not None})
def filter_files_by_writers(directory, selected_writers):
    """Return the ``*.tif`` files in ``directory`` written by ``selected_writers``.

    Args:
        directory: folder to scan (non-recursive).
        selected_writers: container of integer writer IDs; a ``set`` gives
            constant-time membership tests.

    Returns:
        Matching file paths in sorted order. ``glob`` itself makes no
        ordering promise, so the result is sorted to keep the dataset
        order, and therefore seeded shuffling, reproducible across runs.
    """
    all_files = sorted(glob.glob(os.path.join(directory, "*.tif")))
    return [path for path in all_files if extract_writer_id(path) in selected_writers]
# ===============================
# Dataset Class
# ===============================
# Global variables for adaptive augmentation.
# main() rebinds the first three every run/epoch and appends to the two loss
# histories; the adaptive transforms read current_epoch, num_epochs_global
# and overfitting_detected each time they are called.
current_epoch = 0  # zero-based index of the epoch currently being trained
num_epochs_global = 80  # planned number of epochs; main() overwrites it with --num_epochs
overfitting_detected = False  # set by main() from the recent train/validation loss averages
validation_loss_history = []  # one validation loss per finished epoch
training_loss_history = []  # one training loss per finished epoch
class KurdishLineDataset(data.Dataset):
    """Line images paired with the transcription stored next to each image.

    Every ``<name>.tif`` is expected to have a sibling ``<name>.txt`` whose
    first line is the ground-truth text.

    Args:
        root_dir: folder scanned (non-recursively) for ``*.tif`` files when
            ``image_files`` is not given.
        transform: optional callable applied to the padded PIL image.
        max_samples: if set and smaller than the number of files, keep only
            the first ``max_samples`` files.
        dataset_name: label used in the load message.
        image_files: explicit list of image paths; takes precedence over
            ``root_dir`` and is used in the given order.
        img_height, img_width: size of the white canvas each line is
            pasted onto.
        char_to_idx: mapping from character to vocabulary index.
        SOS_token, EOS_token: ids wrapped around every target sequence.

    Raises:
        ValueError: if neither ``root_dir`` nor ``image_files`` is provided.
    """

    def __init__(self, root_dir=None, transform=None, max_samples=None,
                 dataset_name="", image_files=None, img_height=96, img_width=1235,
                 char_to_idx=None, SOS_token=1, EOS_token=2):
        self.transform = transform
        self.dataset_name = dataset_name
        self.img_height = img_height
        self.img_width = img_width
        self.char_to_idx = char_to_idx
        self.SOS_token = SOS_token
        self.EOS_token = EOS_token

        if image_files is not None:
            self.image_files = image_files
        elif root_dir is not None:
            # glob order is arbitrary; sorting makes sample order (and the
            # max_samples truncation below) reproducible across runs.
            self.image_files = sorted(glob.glob(os.path.join(root_dir, "*.tif")))
        else:
            raise ValueError("Either root_dir or image_files must be provided")

        if max_samples and max_samples < len(self.image_files):
            self.image_files = self.image_files[:max_samples]
        print(f"Loaded {len(self.image_files)} images for {dataset_name}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target, target_length, text)`` for sample ``idx``.

        ``target`` is a LongTensor ``[SOS, ids..., EOS]`` and ``text`` is the
        raw transcription line.
        """
        img_path = self.image_files[idx]
        label_path = os.path.splitext(img_path)[0] + ".txt"

        # convert() returns an independent in-memory image, so the file
        # handle can be released right away.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Scale to the target height keeping the aspect ratio, capped at the
        # canvas width and never narrower than one pixel.
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height
        new_height = self.img_height
        new_width = max(1, min(int(new_height * aspect_ratio), self.img_width))
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Left-align the line on a white canvas of fixed size.
        target_img = Image.new("RGB", (self.img_width, self.img_height), color=(255, 255, 255))
        target_img.paste(image, (0, 0))
        if self.transform:
            target_img = self.transform(target_img)

        # "utf-8-sig" reads plain UTF-8 and also strips a leading BOM. Plain
        # "utf-8" never fails on a BOM; it would keep it as a stray
        # character at the start of the label.
        with open(label_path, "r", encoding="utf-8-sig") as f:
            text = f.readline().strip()

        # Characters outside the vocabulary fall back to the space index,
        # or to 0 when the vocabulary has no space.
        fallback = self.char_to_idx.get(" ", 0)
        indices = ([self.SOS_token] +
                   [self.char_to_idx.get(char, fallback) for char in text] +
                   [self.EOS_token])
        target = torch.LongTensor(indices)
        target_length = len(indices)
        return target_img, target, target_length, text
def collate_fn(batch):
    """Collate dataset samples into a padded batch.

    Args:
        batch: iterable of ``(image, target, target_length, text)`` samples.
            It is not modified; sorting happens on a copy.

    Returns:
        Tuple ``(images, padded_targets, lengths, original_texts)`` ordered
        by decreasing target length: the stacked images, a LongTensor of
        targets right-padded with 0, a LongTensor of the true lengths and a
        tuple of the transcription strings.
    """
    ordered = sorted(batch, key=lambda sample: sample[2], reverse=True)
    images, targets, lengths, original_texts = zip(*ordered)
    images = torch.stack(images, 0)

    padded_targets = torch.zeros(len(targets), max(lengths), dtype=torch.long)
    for row, (target, length) in enumerate(zip(targets, lengths)):
        padded_targets[row, :length] = target[:length]

    return images, padded_targets, torch.LongTensor(lengths), original_texts
# ===============================
# Adaptive Augmentation
# ===============================
class AdaptiveStrokeWidthJitter:
    """Randomly apply a min or max rank filter to a PIL image.

    The probability rises from ``base_p`` towards ``max_p`` with training
    progress (``current_epoch / num_epochs_global``), rising faster while
    ``overfitting_detected`` is set. The filter size starts at
    ``base_kernel``, is kept odd and is capped at ``max_kernel``.
    """

    def __init__(self, base_p=0.2, max_p=0.6, base_kernel=3, max_kernel=5):
        self.base_p = base_p
        self.max_p = max_p
        self.base_kernel = base_kernel
        self.max_kernel = max_kernel

    def __call__(self, img):
        progress = min(current_epoch / num_epochs_global, 1.0)
        factor = 1.5 if overfitting_detected else 1.0
        p = min(self.base_p + (self.max_p - self.base_p) * progress * factor, self.max_p)

        # Grow the kernel with progress, bump even sizes to the next odd one.
        kernel = self.base_kernel + int(2 * progress)
        kernel += 1 - kernel % 2
        kernel = min(kernel, self.max_kernel)

        if random.random() >= p:
            return img
        # Even odds between the two rank filters.
        rank_filter = ImageFilter.MinFilter if random.random() < 0.5 else ImageFilter.MaxFilter
        return img.filter(rank_filter(kernel))
class AdaptiveGaussianNoise:
    """Randomly add zero-mean Gaussian noise to a tensor and clamp it to [0, 1].

    Both the probability of applying noise and the upper bound of the noise
    standard deviation grow with training progress
    (``current_epoch / num_epochs_global``), growing faster while
    ``overfitting_detected`` is set. Only the second element of ``max_std``
    and both elements of ``base_std`` are used.
    """

    def __init__(self, base_std=(0.0, 0.01), max_std=(0.0, 0.03), base_p=0.3, max_p=0.7):
        self.base_std = base_std
        self.max_std = max_std
        self.base_p = base_p
        self.max_p = max_p

    def __call__(self, tensor):
        progress = min(current_epoch / num_epochs_global, 1.0)
        factor = 1.5 if overfitting_detected else 1.0
        p = min(self.base_p + (self.max_p - self.base_p) * progress * factor, self.max_p)
        std_low, std_base_high = self.base_std
        std_cap = self.max_std[1]
        std_high = min(std_base_high + (std_cap - std_base_high) * progress * factor, std_cap)

        if torch.rand(1).item() >= p:
            return tensor
        # One standard deviation is drawn per call and shared by all elements.
        noise = torch.randn_like(tensor) * random.uniform(std_low, std_high)
        return torch.clamp(tensor + noise, 0.0, 1.0)
def build_adaptive_train_transform():
    """Build the training-time augmentation pipeline.

    Returns:
        A callable taking a PIL image and returning a normalised tensor.
        Each call reads ``current_epoch``, ``num_epochs_global`` and
        ``overfitting_detected``, so augmentation probabilities and
        strengths grow as training advances.
    """

    class AdaptiveTransform:
        def __init__(self):
            # These steps take no per-call parameters, so build them once.
            self._stroke_jitter = AdaptiveStrokeWidthJitter()
            self._to_tensor = transforms.ToTensor()
            self._noise = AdaptiveGaussianNoise()
            self._normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        def __call__(self, img):
            progress = min(current_epoch / num_epochs_global, 1.0)
            factor = 1.3 if overfitting_detected else 1.0

            # Brightness / contrast jitter.
            if random.random() < min(0.6 + 0.35 * progress * factor, 0.95):
                strength = min(0.1 + 0.2 * progress * factor, 0.3)
                img = transforms.ColorJitter(brightness=strength, contrast=strength)(img)

            # Small rotation, translation, scaling and shear.
            if random.random() < min(0.7 + 0.25 * progress * factor, 0.95):
                max_degrees = min(1 + 4 * progress * factor, 5)
                max_shear = min(3 + 7 * progress * factor, 10)
                shift = (min(0.01 + 0.02 * progress, 0.03),
                         min(0.03 + 0.05 * progress, 0.08))
                zoom = (max(1 - 0.02 - 0.08 * progress, 0.90),
                        min(1 + 0.02 + 0.08 * progress, 1.10))
                affine = transforms.RandomAffine(
                    degrees=max_degrees,
                    translate=shift,
                    scale=zoom,
                    shear=(-max_shear, max_shear),
                    interpolation=InterpolationMode.BILINEAR, fill=255)
                img = affine(img)

            # Perspective warp.
            if random.random() < min(0.1 + 0.4 * progress * factor, 0.5):
                distortion = min(0.02 + 0.06 * progress * factor, 0.08)
                warp = transforms.RandomPerspective(
                    distortion_scale=distortion, p=1.0,
                    interpolation=InterpolationMode.BILINEAR, fill=255)
                img = warp(img)

            # Blur.
            if random.random() < min(0.15 + 0.2 * progress, 0.35):
                max_sigma = min(0.5 + 0.5 * progress, 1.0)
                img = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, max_sigma))(img)

            # Stroke-width jitter on the PIL image, then noise on the tensor.
            img = self._stroke_jitter(img)
            img = self._to_tensor(img)
            img = self._noise(img)

            # Erase a small random patch before normalisation.
            if random.random() < min(0.1 + 0.3 * progress * factor, 0.4):
                max_area = min(0.01 + 0.04 * progress, 0.05)
                eraser = transforms.RandomErasing(
                    p=1.0, scale=(0.01, max_area),
                    ratio=(0.3, 3.3), value="random")
                img = eraser(img)

            return self._normalize(img)

    return AdaptiveTransform()
def build_eval_transform():
    """Return the deterministic evaluation pipeline: tensor conversion, then normalisation.

    The mean/std constants are the same ones the training transform uses.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(steps)
# ===============================
# Model Architecture
# ===============================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    Args:
        d_model: feature dimension; both even and odd values are supported.
        max_len: longest sequence that can be encoded.

    The table is registered as buffer ``pe`` with shape
    ``(max_len, 1, d_model)``, so it is saved with the model and follows it
    across devices but is not trained.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        ``x`` is indexed by position along dimension 0; the size-1 middle
        dimension of the table broadcasts over dimension 1 of ``x``.
        """
        return x + self.pe[:x.size(0), :]
class DenseNetFeatureExtractor(nn.Module):
    """DenseNet-121 backbone that turns an image batch into a feature sequence.

    Args:
        output_dim: number of channels produced by the 1x1 adaptation
            convolution, i.e. the feature size of every sequence step.
    """

    def __init__(self, output_dim=256):
        super().__init__()
        backbone = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        # Keep every child module of the backbone except the last one.
        trunk = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk)
        # 1x1 convolution mapping the 1024 backbone channels to output_dim.
        self.adapt = nn.Conv2d(1024, output_dim, kernel_size=1)

    def forward(self, x):
        """Return features shaped ``(batch, width, output_dim)``."""
        feature_map = self.adapt(self.features(x))
        # Average over the height axis only; the width axis keeps its size.
        pooled = nn.functional.adaptive_avg_pool2d(feature_map, (1, None))
        return pooled.squeeze(2).permute(0, 2, 1)
class TransformerOCRModel(nn.Module):
def __init__(self, vocab_size, hidden_size=256, nhead=8,
num_encoder_layers=3, num_decoder_layers=3,
dim_feedforward=1024, dropout=0.4,
PAD_token=0, SOS_token=1, EOS_token=2, max_seq_len=150):
super().__init__()
self.feature_extractor = DenseNetFeatureExtractor(output_dim=hidden_size)
self.pos_encoder = PositionalEncoding(hidden_size)
self.transformer = nn.Transformer(
d_model=hidden_size, nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout, batch_first=True)
self.token_embedding = nn.Embedding(vocab_size, hidden_size)
self.output_projection = nn.Linear(hidden_size, vocab_size)
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.PAD_token = PAD_token
self.SOS_token = SOS_token
self.EOS_token = EOS_token
self.max_seq_len = max_seq_len
self._init_parameters()
def _init_parameters(self):
nn.init.xavier_uniform_(self.token_embedding.weight)
nn.init.xavier_uniform_(self.output_projection.weight)
def _generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
def _apply_pos_encoding(self, x):
x = x.permute(1, 0, 2)
x = self.pos_encoder(x)
return x.permute(1, 0, 2)
def forward(self, src, tgt):
memory = self._apply_pos_encoding(self.feature_extractor(src))
tgt_input = tgt[:, :-1]
tgt_embedded = self._apply_pos_encoding(self.token_embedding(tgt_input))
tgt_mask = self._generate_square_subsequent_mask(tgt_embedded.size(1)).to(src.device)
output = self.transformer(
src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask,
src_is_causal=False, tgt_is_causal=True)
return self.output_projection(output)
def generate(self, img):
"""Generate text from a single image"""
was_training = self.training
self.eval()
with torch.no_grad():
if img.dim() == 3:
img = img.unsqueeze(0)
memory = self._apply_pos_encoding(self.feature_extractor(img))
ys = torch.ones(1, 1).fill_(self.SOS_token).long().to(img.device)
for _ in range(self.max_seq_len - 1):
tgt_embedded = self._apply_pos_encoding(self.token_embedding(ys))
tgt_mask = self._generate_square_subsequent_mask(ys.size(1)).to(img.device)
out = self.transformer(src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask)
out = self.output_projection(out)
next_word = out[0, -1].argmax().item()
ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).to(img.device)], dim=1)
if next_word == self.EOS_token:
break
if was_training:
self.train(True)
return ys[0]
def generate_batch(self, imgs):
"""Generate text from a batch of images"""
self.eval()
batch_size = imgs.size(0)
with torch.no_grad():
memory = self._apply_pos_encoding(self.feature_extractor(imgs))
ys = torch.ones(batch_size, 1).fill_(self.SOS_token).long().to(imgs.device)
finished = torch.zeros(batch_size, dtype=torch.bool, device=imgs.device)
for _ in range(self.max_seq_len - 1):
tgt_embedded = self._apply_pos_encoding(self.token_embedding(ys))
tgt_mask = self._generate_square_subsequent_mask(ys.size(1)).to(imgs.device)
out = self.transformer(src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask)
out = self.output_projection(out)
next_tokens = out[:, -1].argmax(dim=-1)
next_tokens[finished] = self.PAD_token
ys = torch.cat([ys, next_tokens.unsqueeze(1)], dim=1)
finished = finished | (next_tokens == self.EOS_token)
if finished.all():
break
return ys
# ===============================
# Metrics
# ===============================
def levenshtein_distance(s1, s2):
    """Return the edit distance between two sequences.

    Insertions, deletions and substitutions each cost 1. Works on any pair
    of sized, iterable sequences (strings, lists of words, ...).
    """
    # Iterate over the longer sequence so each row has the shorter length.
    longer, shorter = (s2, s1) if len(s1) < len(s2) else (s1, s2)
    if len(shorter) == 0:
        return len(longer)

    previous_row = list(range(len(shorter) + 1))
    for item_a in longer:
        current_row = [previous_row[0] + 1]
        for col, item_b in enumerate(shorter, start=1):
            deletion = previous_row[col] + 1
            insertion = current_row[col - 1] + 1
            substitution = previous_row[col - 1] + (item_a != item_b)
            current_row.append(min(deletion, insertion, substitution))
        previous_row = current_row
    return previous_row[-1]
def calculate_cer(preds, targets):
    """Return the character error rate over paired predictions and references.

    The rate is the summed character-level edit distance divided by the
    summed reference length (with a minimum denominator of 1). Pairs are
    formed with ``zip``, so surplus items in the longer list are ignored.
    """
    pairs = list(zip(preds, targets))
    total_dist = sum(levenshtein_distance(pred, ref) for pred, ref in pairs)
    total_chars = sum(len(ref) for _, ref in pairs)
    return total_dist / max(1, total_chars)
def calculate_wer(preds, targets):
    """Return the word error rate over paired predictions and references.

    Texts are split on whitespace; the rate is the summed word-level edit
    distance divided by the summed reference word count (with a minimum
    denominator of 1).
    """
    total_dist = 0
    total_words = 0
    for pred, ref in zip(preds, targets):
        ref_words = ref.split()
        total_dist += levenshtein_distance(pred.split(), ref_words)
        total_words += len(ref_words)
    return total_dist / max(1, total_words)
def evaluate_cer(model, dataloader, device, idx_to_char, PAD_token, SOS_token, EOS_token):
    """Decode every batch with ``model.generate_batch`` and score it.

    Args:
        model: model exposing ``eval()`` and ``generate_batch(images)``.
        dataloader: yields ``(images, _, _, texts)`` batches.
        device: device the images are moved to.
        idx_to_char, PAD_token, SOS_token, EOS_token: decoding vocabulary.

    Returns:
        Tuple ``(cer, predictions, references)``.
    """
    model.eval()
    predictions, references = [], []
    with torch.no_grad():
        for images, _, _, texts in tqdm(dataloader, desc="Evaluating"):
            decoded = model.generate_batch(images.to(device))
            predictions.extend(
                tensor_to_text(seq, idx_to_char, PAD_token, SOS_token, EOS_token)
                for seq in decoded)
            references.extend(texts)
    return calculate_cer(predictions, references), predictions, references
# ===============================
# Early Stopping
# ===============================
class EarlyStopping:
    """Checkpoint the best model and signal when validation CER stops improving.

    Args:
        patience: number of consecutive non-improving calls after which
            ``early_stop`` becomes ``True``.
    """

    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0
        self.best_cer = float("inf")
        self.early_stop = False

    def __call__(self, val_cer, model, epoch, path):
        """Record one epoch's validation CER.

        A strictly lower CER resets the counter and saves a checkpoint
        (epoch, model state dict, CER) to ``path``; anything else
        increments the counter.
        """
        if not val_cer < self.best_cer:
            self.counter += 1
            print(f"Early stopping: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                print("Early stopping triggered.")
            return

        self.best_cer = val_cer
        self.counter = 0
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "val_cer": val_cer,
        }
        torch.save(checkpoint, path)
        print(f"Model saved (Val CER: {val_cer:.4f})")
# ===============================
# Training Loop
# ===============================
def train_epoch(model, dataloader, optimizer, criterion, device, scheduler, PAD_token,
                grad_clip=None):
    """Train the model for one pass over ``dataloader``.

    Args:
        model: called as ``model(images, targets)``; returns logits for
            ``targets[:, 1:]``.
        dataloader: yields ``(images, targets, _, _)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss applied to flattened logits and flattened targets.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.
        PAD_token: accepted for call compatibility; not used here.
        grad_clip: maximum gradient norm. When ``None`` the value is taken
            from the module-level ``args`` namespace that ``main()``
            publishes, so existing callers keep working.

    Returns:
        Mean batch loss over the epoch.
    """
    if grad_clip is None:
        # Fallback for callers that rely on main() having set the global.
        grad_clip = args.grad_clip

    model.train()
    epoch_loss = 0
    for images, targets, _, _ in tqdm(dataloader, desc="Training"):
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images, targets)
        outputs = outputs.reshape(-1, outputs.shape[-1])
        # Targets are shifted by one: the leading SOS is never predicted.
        targets_flat = targets[:, 1:].reshape(-1)
        loss = criterion(outputs, targets_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)
def evaluate_model(model, dataloader, criterion, device, PAD_token):
    """Return the mean teacher-forced loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and no gradients are tracked.
    ``PAD_token`` is accepted for call compatibility and not used here.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for images, targets, _, _ in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images, targets)
            flat_logits = logits.reshape(-1, logits.shape[-1])
            # The leading SOS token is never a prediction target.
            flat_targets = targets[:, 1:].reshape(-1)
            batch_losses.append(criterion(flat_logits, flat_targets).item())
    return sum(batch_losses) / len(dataloader)
# ===============================
# Main
# ===============================
def main():
    """Train the recogniser, checkpoint the best epoch and report test metrics.

    Steps: parse the CLI, seed the RNGs, build the datasets (real lines plus
    the optional synthetic and writer-mixing sets), train with early
    stopping on validation CER, then evaluate the best checkpoint on the
    test split and print CER, WER and a few sample predictions.

    Rebinds the module-level augmentation state (``current_epoch``,
    ``num_epochs_global``, ``overfitting_detected``), appends to the loss
    histories and publishes the parsed options as the global ``args``.
    """
    global current_epoch, num_epochs_global, overfitting_detected
    global validation_loss_history, training_loss_history, args
    args = parse_args()

    # Set seeds
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    num_epochs_global = args.num_epochs

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load vocabulary
    char_list, char_to_idx, idx_to_char, PAD_token, SOS_token, EOS_token = \
        load_vocabulary(args.vocab_path)
    vocab_size = len(char_list)
    print(f"Vocabulary size: {vocab_size}")

    # Transforms
    train_transform = build_eval_transform() if args.no_aug else build_adaptive_train_transform()
    eval_transform = build_eval_transform()

    # Dataset common kwargs
    ds_kwargs = dict(img_height=args.img_height, img_width=args.img_width,
                     char_to_idx=char_to_idx, SOS_token=SOS_token, EOS_token=EOS_token)

    # Build datasets
    real_train_dir = os.path.join(args.data_dir, "Training")
    real_val_dir = os.path.join(args.data_dir, "Validation")
    real_test_dir = os.path.join(args.data_dir, "Testing")
    train_datasets = [
        KurdishLineDataset(real_train_dir, transform=train_transform,
                           dataset_name="Real Training", **ds_kwargs)
    ]

    if args.use_synthetic:
        if args.synthetic_dir:
            syn_dir = os.path.join(args.synthetic_dir, "Training")
            train_datasets.append(
                KurdishLineDataset(syn_dir, transform=train_transform,
                                   dataset_name="Synthetic Training", **ds_kwargs))
        else:
            # Make the otherwise silent no-op visible.
            print("Warning: --use_synthetic was given without --synthetic_dir; "
                  "no synthetic lines will be used.")

    if args.use_writer_mixing:
        if args.fixed_lines_dir:
            fix_dir = os.path.join(args.fixed_lines_dir, "Training")
            all_writers = get_unique_writers(fix_dir)
            selected = random.sample(all_writers, min(args.num_writers, len(all_writers)))
            selected_files = filter_files_by_writers(fix_dir, set(selected))
            train_datasets.append(
                KurdishLineDataset(image_files=selected_files, transform=train_transform,
                                   dataset_name=f"Fixed {len(selected)} Writers", **ds_kwargs))
        else:
            print("Warning: --use_writer_mixing was given without --fixed_lines_dir; "
                  "no fixed-content lines will be used.")

    train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
    val_dataset = KurdishLineDataset(real_val_dir, transform=eval_transform,
                                     dataset_name="Validation", **ds_kwargs)
    test_dataset = KurdishLineDataset(real_test_dir, transform=eval_transform,
                                      dataset_name="Testing", **ds_kwargs)
    print(f"\nTraining: {len(train_dataset)} | Validation: {len(val_dataset)} | Testing: {len(test_dataset)}")

    # Data loaders. num_workers stays 0: the adaptive transforms read module
    # globals that this function updates every epoch. Pinned memory is only
    # requested when a CUDA device is in use.
    loader_kwargs = dict(num_workers=0, pin_memory=(device.type == "cuda"), collate_fn=collate_fn)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)
    test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)

    # Model
    model = TransformerOCRModel(
        vocab_size=vocab_size, hidden_size=args.hidden_size,
        nhead=args.num_heads, num_encoder_layers=args.encoder_layers,
        num_decoder_layers=args.decoder_layers, dim_feedforward=args.ff_dim,
        dropout=args.dropout, PAD_token=PAD_token, SOS_token=SOS_token,
        EOS_token=EOS_token).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Optimizer and schedulers
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    onecycle = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.learning_rate,
        steps_per_epoch=len(train_loader), epochs=args.num_epochs, pct_start=0.1)
    # NOTE(review): OneCycleLR is stepped every batch and is documented as
    # not chainable, so a reduction applied by this plateau scheduler is
    # presumably overwritten by the next OneCycleLR step. Confirm whether
    # both schedulers are really intended.
    plateau = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-6)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_token)
    early_stopping = EarlyStopping(patience=args.patience)
    best_model_path = os.path.join(args.output_dir, "best_model.pth")

    # Training
    print(f"\nStarting training for {args.num_epochs} epochs...")
    best_val_cer = float("inf")
    for epoch in range(args.num_epochs):
        current_epoch = epoch
        start = time.time()
        train_loss = train_epoch(model, train_loader, optimizer, criterion, device, onecycle, PAD_token)
        val_loss = evaluate_model(model, val_loader, criterion, device, PAD_token)
        val_cer, _, _ = evaluate_cer(model, val_loader, device, idx_to_char,
                                     PAD_token, SOS_token, EOS_token)

        # Overfitting detection: compare the 3-epoch mean validation loss
        # with 1.2x the 3-epoch mean training loss.
        training_loss_history.append(train_loss)
        validation_loss_history.append(val_loss)
        if len(training_loss_history) >= 3:
            overfitting_detected = (np.mean(validation_loss_history[-3:]) >
                                    np.mean(training_loss_history[-3:]) * 1.2)

        plateau.step(val_cer)
        mins, secs = divmod(time.time() - start, 60)
        print(f"\nEpoch {epoch + 1}/{args.num_epochs} ({mins:.0f}m {secs:.0f}s)")
        print(f" Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f}")
        if val_cer < best_val_cer:
            best_val_cer = val_cer
        early_stopping(val_cer, model, epoch, best_model_path)
        if early_stopping.early_stop:
            break

    # Final evaluation
    print(f"\nBest validation CER: {best_val_cer:.4f}")
    if math.isfinite(early_stopping.best_cer):
        # A checkpoint was written during this run.
        print("Loading best model for test evaluation...")
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        print("Warning: no checkpoint was saved during this run; "
              "evaluating the current model weights.")

    test_cer, test_preds, test_targets = evaluate_cer(
        model, test_loader, device, idx_to_char, PAD_token, SOS_token, EOS_token)
    test_wer = calculate_wer(test_preds, test_targets)
    print(f"\nTest CER: {test_cer:.4f}")
    print(f"Test WER: {test_wer:.4f}")
    print(f"Test CRR: {(1 - test_cer) * 100:.2f}%")
    for i in range(min(5, len(test_preds))):
        print(f"\nSample {i + 1}:")
        print(f" Predicted: {test_preds[i]}")
        print(f" Actual: {test_targets[i]}")
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()