"""
Kurdish Handwritten Line Recognition - Training Script

DenseNet121-Transformer Architecture with Constrained Synthetic Line Generation

Usage:
    python train.py --data_dir ./data/DASTNUS --vocab_path ./vocab.json
    python train.py --data_dir ./data/DASTNUS --vocab_path ./vocab.json --use_synthetic --use_writer_mixing
"""

import argparse
import glob
import json
import math
import os
import random
import re
import time
from datetime import datetime

import numpy as np
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import ConcatDataset
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
from tqdm import tqdm


# ===============================
# Argument Parser
# ===============================
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding data paths, data options, image size,
        optimisation hyper-parameters, model sizes, early-stopping patience,
        the augmentation switch and the output directory.
    """
    parser = argparse.ArgumentParser(description="Kurdish Handwritten Line Recognition Training")

    # Data paths
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Root directory of DASTNUS dataset")
    parser.add_argument("--vocab_path", type=str, required=True,
                        help="Path to vocabulary JSON file (vocab.json)")
    parser.add_argument("--synthetic_dir", type=str, default=None,
                        help="Directory containing synthetic handwritten lines")
    parser.add_argument("--fixed_lines_dir", type=str, default=None,
                        help="Directory containing fixed-content handwritten lines")

    # Data options
    parser.add_argument("--use_synthetic", action="store_true",
                        help="Include synthetic handwritten lines in training")
    parser.add_argument("--use_writer_mixing", action="store_true",
                        help="Include fixed-content lines from random writers")
    parser.add_argument("--num_writers", type=int, default=50,
                        help="Number of writers to randomly select for mixing")

    # Image dimensions
    parser.add_argument("--img_height", type=int, default=96)
    parser.add_argument("--img_width", type=int, default=1235)

    # Training hyperparameters
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--num_epochs", type=int, default=80)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--grad_clip", type=float, default=5.0)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--seed", type=int, default=42)

    # Model parameters
    parser.add_argument("--hidden_size", type=int, default=256)
    parser.add_argument("--encoder_layers", type=int, default=3)
    parser.add_argument("--decoder_layers", type=int, default=3)
    parser.add_argument("--num_heads", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.4)
    parser.add_argument("--ff_dim", type=int, default=1024)

    # Early stopping
    parser.add_argument("--patience", type=int, default=10)

    # Augmentation
    parser.add_argument("--no_aug", action="store_true", help="Disable adaptive augmentation")

    # Output
    parser.add_argument("--output_dir", type=str, default="./output",
                        help="Directory to save models and logs")

    return parser.parse_args()


# ===============================
# Vocabulary Loader
# ===============================
def load_vocabulary(vocab_path):
    """Load the character vocabulary from a JSON file.

    The JSON must contain either ``"vocab_list"`` (a list whose position is
    the token index) or ``"char_to_idx"`` (a char -> index mapping whose
    indices form the contiguous range ``0 .. len-1``).

    Returns:
        (char_list, char_to_idx, idx_to_char, PAD_token, SOS_token, EOS_token).
        The special-token indices are hard-coded to 0/1/2; the vocabulary
        file is assumed to place those tokens there.

    Raises:
        ValueError: if neither key is present, or if ``char_to_idx`` indices
            are out of range, duplicated or leave gaps (which would otherwise
            surface later as ``None`` characters during decoding).
    """
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab_data = json.load(f)

    if "vocab_list" in vocab_data:
        char_list = vocab_data["vocab_list"]
    elif "char_to_idx" in vocab_data:
        mapping = vocab_data["char_to_idx"]
        char_list = [None] * len(mapping)
        for char, idx in mapping.items():
            if not 0 <= idx < len(char_list):
                raise ValueError(
                    f"char_to_idx index {idx!r} for {char!r} is outside 0..{len(char_list) - 1}")
            char_list[idx] = char
        if any(char is None for char in char_list):
            raise ValueError("char_to_idx indices must be unique and contiguous starting at 0")
    else:
        raise ValueError("Vocabulary JSON must contain 'vocab_list' or 'char_to_idx'")

    char_to_idx = {char: idx for idx, char in enumerate(char_list)}
    idx_to_char = {idx: char for idx, char in enumerate(char_list)}

    PAD_token = 0
    SOS_token = 1
    EOS_token = 2

    return char_list, char_to_idx, idx_to_char, PAD_token, SOS_token, EOS_token


# ===============================
# Helper Functions
# ===============================
def tensor_to_text(tensor, idx_to_char, PAD_token, SOS_token, EOS_token):
    """Decode a sequence of token indices (tensor or iterable of ints) to text.

    PAD and SOS tokens are skipped, decoding stops at the first EOS, and
    indices absent from ``idx_to_char`` are dropped.
    """
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().tolist()
    chars = []
    for idx in tensor:
        if idx == PAD_token or idx == SOS_token:
            continue
        if idx == EOS_token:
            break
        if idx in idx_to_char:
            chars.append(idx_to_char[idx])
    return "".join(chars)


_WRITER_ID_RE = re.compile(r"DNDK(\d+)")


def extract_writer_id(filename):
    """Extract writer ID from filename (e.g., DNDK00002_2_1.tif -> 2).

    Returns None when the basename does not start with ``DNDK<digits>``.
    """
    match = _WRITER_ID_RE.match(os.path.basename(filename))
    return int(match.group(1)) if match else None


def get_unique_writers(directory):
    """Return the sorted list of writer IDs found among ``directory/*.tif``."""
    writer_ids = {extract_writer_id(f) for f in glob.glob(os.path.join(directory, "*.tif"))}
    writer_ids.discard(None)
    return sorted(writer_ids)


def filter_files_by_writers(directory, selected_writers):
    """Return ``directory/*.tif`` files whose writer ID is in ``selected_writers``.

    ``selected_writers`` should support fast membership tests (e.g. a set).
    Results are sorted so the dataset order is reproducible across runs.
    """
    all_files = sorted(glob.glob(os.path.join(directory, "*.tif")))
    return [f for f in all_files if extract_writer_id(f) in selected_writers]


# ===============================
# Dataset Class
# ===============================
# Global variables for adaptive augmentation. main() updates them each epoch;
# the augmentation transforms below read them to scale their strength.
current_epoch = 0            # 0-based index of the epoch being trained
num_epochs_global = 80       # total planned epochs (denominator of progress)
overfitting_detected = False  # set by main() from recent train/val losses
validation_loss_history = []
training_loss_history = []


class KurdishLineDataset(data.Dataset):
    """Line-image dataset pairing each ``X.tif`` with its label file ``X.txt``.

    Each item is ``(image, target, target_length, text)``:
    the line image resized to ``img_height`` (aspect ratio kept, width capped
    at ``img_width``) and pasted left-aligned on a white canvas, then passed
    through ``transform``; ``target`` is a LongTensor
    ``[SOS, c1, ..., cn, EOS]``; ``target_length`` is its length; ``text`` is
    the raw first line of the label file.
    """

    def __init__(self, root_dir=None, transform=None, max_samples=None, dataset_name="",
                 image_files=None, img_height=96, img_width=1235, char_to_idx=None,
                 SOS_token=1, EOS_token=2):
        self.transform = transform
        self.dataset_name = dataset_name
        self.img_height = img_height
        self.img_width = img_width
        self.char_to_idx = char_to_idx
        self.SOS_token = SOS_token
        self.EOS_token = EOS_token

        if image_files is not None:
            self.image_files = image_files
        else:
            # glob order is filesystem-dependent; sort so max_samples
            # truncation and seeded shuffling are reproducible.
            self.image_files = sorted(glob.glob(os.path.join(root_dir, "*.tif")))

        if max_samples and max_samples < len(self.image_files):
            self.image_files = self.image_files[:max_samples]

        print(f"Loaded {len(self.image_files)} images for {dataset_name}")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        label_path = os.path.splitext(img_path)[0] + ".txt"

        # The context manager releases the file handle; convert() returns a
        # fully loaded copy, so the image stays usable after the file closes.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height
        new_height = self.img_height
        # Keep aspect ratio, cap at the canvas width, never collapse to 0 px.
        new_width = max(1, min(int(new_height * aspect_ratio), self.img_width))
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        target_img = Image.new("RGB", (self.img_width, self.img_height), color=(255, 255, 255))
        target_img.paste(image, (0, 0))

        if self.transform:
            target_img = self.transform(target_img)

        # "utf-8-sig" decodes plain UTF-8 and also strips a leading BOM, which
        # plain "utf-8" would leave in the label as U+FEFF.
        with open(label_path, "r", encoding="utf-8-sig") as f:
            text = f.readline().strip()

        # Characters missing from the vocabulary map to the space token
        # (or 0 if the vocabulary has no space).
        unknown_idx = self.char_to_idx.get(" ", 0)
        indices = ([self.SOS_token]
                   + [self.char_to_idx.get(char, unknown_idx) for char in text]
                   + [self.EOS_token])
        target = torch.LongTensor(indices)
        target_length = len(indices)

        return target_img, target, target_length, text


def collate_fn(batch):
    """Collate dataset items into a padded batch.

    Items are ordered by descending target length. Targets are right-padded
    with 0, which matches PAD_token as fixed by load_vocabulary.

    Returns:
        (stacked images, padded_targets [B, max_len], lengths [B], texts).
    """
    batch = sorted(batch, key=lambda item: item[2], reverse=True)
    images, targets, lengths, original_texts = zip(*batch)

    images = torch.stack(images, 0)
    max_length = max(lengths)
    padded_targets = torch.zeros(len(targets), max_length).long()
    for i, target in enumerate(targets):
        padded_targets[i, :lengths[i]] = target[:lengths[i]]

    lengths = torch.LongTensor(lengths)
    return images, padded_targets, lengths, original_texts


# ===============================
# Adaptive Augmentation
# ===============================
def _augmentation_schedule(overfit_boost):
    """Return ``(progress, factor)`` for the current epoch.

    ``progress`` is ``current_epoch / num_epochs_global`` clipped to 1.0;
    ``factor`` is ``overfit_boost`` while ``overfitting_detected`` is set,
    otherwise 1.0.
    """
    progress = min(current_epoch / num_epochs_global, 1.0)
    factor = overfit_boost if overfitting_detected else 1.0
    return progress, factor


class AdaptiveStrokeWidthJitter:
    """Randomly thin or thicken strokes of a PIL image with a min/max filter.

    The firing probability grows from ``base_p`` towards ``max_p`` and the
    kernel from ``base_kernel`` towards ``max_kernel`` as training progresses.
    """

    def __init__(self, base_p=0.2, max_p=0.6, base_kernel=3, max_kernel=5):
        self.base_p, self.max_p = base_p, max_p
        self.base_kernel, self.max_kernel = base_kernel, max_kernel

    def __call__(self, img):
        progress, factor = _augmentation_schedule(1.5)
        p = min(self.base_p + (self.max_p - self.base_p) * progress * factor, self.max_p)
        kernel = self.base_kernel + int(2 * progress)
        if kernel % 2 == 0:
            kernel += 1  # keep the kernel odd (PIL rank filters reject even sizes)
        kernel = min(kernel, self.max_kernel)
        if random.random() < p:
            if random.random() < 0.5:
                return img.filter(ImageFilter.MinFilter(kernel))
            return img.filter(ImageFilter.MaxFilter(kernel))
        return img


class AdaptiveGaussianNoise:
    """Add Gaussian noise to an image tensor and clamp the result to [0, 1].

    Both the firing probability and the upper bound of the noise std grow
    with training progress, boosted while overfitting is detected.
    """

    def __init__(self, base_std=(0.0, 0.01), max_std=(0.0, 0.03), base_p=0.3, max_p=0.7):
        self.base_std, self.max_std = base_std, max_std
        self.base_p, self.max_p = base_p, max_p

    def __call__(self, tensor):
        progress, factor = _augmentation_schedule(1.5)
        p = min(self.base_p + (self.max_p - self.base_p) * progress * factor, self.max_p)
        std_high = min(self.base_std[1] + (self.max_std[1] - self.base_std[1]) * progress * factor,
                       self.max_std[1])
        if torch.rand(1).item() < p:
            noise = torch.randn_like(tensor) * random.uniform(self.base_std[0], std_high)
            tensor = torch.clamp(tensor + noise, 0.0, 1.0)
        return tensor


def build_adaptive_train_transform():
    """Return the epoch-adaptive training transform (PIL image -> normalized tensor).

    Each stage fires with a probability, and at a magnitude, that grows with
    training progress and is boosted while ``overfitting_detected`` is set.
    The stages are colour jitter, affine, perspective, blur, stroke-width
    jitter, ToTensor, Gaussian noise, random erasing and ImageNet
    normalisation.
    """
    # Stateless helpers, built once instead of on every sample.
    stroke_jitter = AdaptiveStrokeWidthJitter()
    gaussian_noise = AdaptiveGaussianNoise()
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    class AdaptiveTransform:
        def __call__(self, img):
            progress, factor = _augmentation_schedule(1.3)

            if random.random() < min(0.6 + 0.35 * progress * factor, 0.95):
                b = min(0.1 + 0.2 * progress * factor, 0.3)
                img = transforms.ColorJitter(brightness=b, contrast=b)(img)

            if random.random() < min(0.7 + 0.25 * progress * factor, 0.95):
                deg = min(1 + 4 * progress * factor, 5)
                shear = min(3 + 7 * progress * factor, 10)
                img = transforms.RandomAffine(
                    degrees=deg,
                    translate=(min(0.01 + 0.02 * progress, 0.03),
                               min(0.03 + 0.05 * progress, 0.08)),
                    scale=(max(1 - 0.02 - 0.08 * progress, 0.90),
                           min(1 + 0.02 + 0.08 * progress, 1.10)),
                    shear=(-shear, shear),
                    interpolation=InterpolationMode.BILINEAR,
                    fill=255)(img)

            if random.random() < min(0.1 + 0.4 * progress * factor, 0.5):
                dist = min(0.02 + 0.06 * progress * factor, 0.08)
                img = transforms.RandomPerspective(
                    distortion_scale=dist, p=1.0,
                    interpolation=InterpolationMode.BILINEAR, fill=255)(img)

            if random.random() < min(0.15 + 0.2 * progress, 0.35):
                img = transforms.GaussianBlur(
                    kernel_size=3, sigma=(0.1, min(0.5 + 0.5 * progress, 1.0)))(img)

            img = stroke_jitter(img)
            img = to_tensor(img)
            img = gaussian_noise(img)

            if random.random() < min(0.1 + 0.3 * progress * factor, 0.4):
                img = transforms.RandomErasing(
                    p=1.0, scale=(0.01, min(0.01 + 0.04 * progress, 0.05)),
                    ratio=(0.3, 3.3), value="random")(img)

            return normalize(img)

    return AdaptiveTransform()


def build_eval_transform():
    """Return the deterministic evaluation transform: ToTensor + ImageNet normalisation."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])


# ===============================
# Model Architecture
# ===============================
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The buffer has shape ``(max_len, 1, d_model)`` and ``forward`` slices it
    on dim 0, so inputs are expected sequence-first: ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[:x.size(0), :]


class DenseNetFeatureExtractor(nn.Module):
    """Pretrained DenseNet-121 trunk turned into a horizontal feature sequence.

    The trunk's 1024-channel map is projected to ``output_dim`` with a 1x1
    conv, average-pooled over height, and returned as
    ``(batch, width_positions, output_dim)``.
    """

    def __init__(self, output_dim=256):
        super().__init__()
        densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        # Keep only the convolutional trunk (drop the classifier head).
        self.features = nn.Sequential(*list(densenet.children())[:-1])
        self.adapt = nn.Conv2d(1024, output_dim, kernel_size=1)

    def forward(self, x):
        # NOTE(review): torchvision's DenseNet.forward applies a ReLU after
        # `features`; here the final BatchNorm output feeds `adapt` directly.
        # Confirm that omission is intended.
        x = self.features(x)
        x = self.adapt(x)
        x = nn.functional.adaptive_avg_pool2d(x, (1, None))  # collapse height to 1
        x = x.squeeze(2)            # (batch, output_dim, width)
        return x.permute(0, 2, 1)   # (batch, width, output_dim)


class TransformerOCRModel(nn.Module):
    """DenseNet-121 features + ``nn.Transformer`` encoder/decoder for line OCR.

    The image feature sequence is the transformer source; the target is the
    token sequence ``[SOS, ..., EOS]``, decoded greedily at inference time.
    """

    def __init__(self, vocab_size, hidden_size=256, nhead=8, num_encoder_layers=3,
                 num_decoder_layers=3, dim_feedforward=1024, dropout=0.4,
                 PAD_token=0, SOS_token=1, EOS_token=2, max_seq_len=150):
        super().__init__()
        self.feature_extractor = DenseNetFeatureExtractor(output_dim=hidden_size)
        self.pos_encoder = PositionalEncoding(hidden_size)
        self.transformer = nn.Transformer(
            d_model=hidden_size, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.PAD_token = PAD_token
        self.SOS_token = SOS_token
        self.EOS_token = EOS_token
        self.max_seq_len = max_seq_len
        self._init_parameters()

    def _init_parameters(self):
        nn.init.xavier_uniform_(self.token_embedding.weight)
        nn.init.xavier_uniform_(self.output_projection.weight)

    def _generate_square_subsequent_mask(self, sz, device=None):
        """Additive causal mask (sz, sz): 0.0 on/below the diagonal, -inf above."""
        mask = (torch.triu(torch.ones(sz, sz, device=device)) == 1).transpose(0, 1)
        return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)

    def _apply_pos_encoding(self, x):
        """Add positional encoding to a batch-first ``(batch, seq, dim)`` tensor."""
        x = x.permute(1, 0, 2)   # PositionalEncoding is sequence-first
        x = self.pos_encoder(x)
        return x.permute(1, 0, 2)

    def forward(self, src, tgt):
        """Teacher-forced logits for ``tgt[:, 1:]`` given images ``src``.

        The decoder input is ``tgt[:, :-1]``; output is
        ``(batch, tgt_len - 1, vocab_size)``.
        """
        memory = self._apply_pos_encoding(self.feature_extractor(src))
        tgt_input = tgt[:, :-1]
        tgt_embedded = self._apply_pos_encoding(self.token_embedding(tgt_input))
        tgt_mask = self._generate_square_subsequent_mask(tgt_embedded.size(1), device=src.device)
        output = self.transformer(
            src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask,
            src_is_causal=False, tgt_is_causal=True)
        return self.output_projection(output)

    def generate(self, img):
        """Greedily decode a single image; return the 1-D token sequence (starts with SOS).

        The module's train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        with torch.no_grad():
            if img.dim() == 3:
                img = img.unsqueeze(0)
            memory = self._apply_pos_encoding(self.feature_extractor(img))
            ys = torch.full((1, 1), self.SOS_token, dtype=torch.long, device=img.device)
            for _ in range(self.max_seq_len - 1):
                tgt_embedded = self._apply_pos_encoding(self.token_embedding(ys))
                tgt_mask = self._generate_square_subsequent_mask(ys.size(1), device=img.device)
                out = self.transformer(src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask)
                out = self.output_projection(out)
                next_word = out[0, -1].argmax().item()
                next_col = torch.full((1, 1), next_word, dtype=torch.long, device=img.device)
                ys = torch.cat([ys, next_col], dim=1)
                if next_word == self.EOS_token:
                    break
        if was_training:
            self.train(True)
        return ys[0]

    def generate_batch(self, imgs):
        """Greedily decode a batch; return ``(batch, steps)`` token ids starting with SOS.

        Sequences that already emitted EOS are filled with PAD. The module's
        train/eval mode is restored afterwards, matching ``generate``.
        """
        was_training = self.training
        self.eval()
        batch_size = imgs.size(0)
        with torch.no_grad():
            memory = self._apply_pos_encoding(self.feature_extractor(imgs))
            ys = torch.full((batch_size, 1), self.SOS_token, dtype=torch.long, device=imgs.device)
            finished = torch.zeros(batch_size, dtype=torch.bool, device=imgs.device)
            for _ in range(self.max_seq_len - 1):
                tgt_embedded = self._apply_pos_encoding(self.token_embedding(ys))
                tgt_mask = self._generate_square_subsequent_mask(ys.size(1), device=imgs.device)
                out = self.transformer(src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask)
                out = self.output_projection(out)
                next_tokens = out[:, -1].argmax(dim=-1)
                next_tokens[finished] = self.PAD_token
                ys = torch.cat([ys, next_tokens.unsqueeze(1)], dim=1)
                finished = finished | (next_tokens == self.EOS_token)
                if finished.all():
                    break
        if was_training:
            self.train(True)
        return ys


# ===============================
# Metrics
# ===============================
def levenshtein_distance(s1, s2):
    """Edit distance between two sequences (strings or lists), O(len1*len2) time."""
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)
    if len(s2) == 0:
        return len(s1)
    prev = range(len(s2) + 1)
    for c1 in s1:
        curr = [prev[0] + 1]
        for j, c2 in enumerate(s2):
            curr.append(min(prev[j + 1] + 1, curr[j] + 1, prev[j] + (c1 != c2)))
        prev = curr
    return prev[-1]


def calculate_cer(preds, targets):
    """Character error rate: total edit distance / total reference characters."""
    total_dist = sum(levenshtein_distance(p, t) for p, t in zip(preds, targets))
    total_chars = sum(len(t) for _, t in zip(preds, targets))
    return total_dist / max(1, total_chars)


def calculate_wer(preds, targets):
    """Word error rate over whitespace-split tokens."""
    total_dist = sum(levenshtein_distance(p.split(), t.split()) for p, t in zip(preds, targets))
    total_words = sum(len(t.split()) for _, t in zip(preds, targets))
    return total_dist / max(1, total_words)


def evaluate_cer(model, dataloader, device, idx_to_char, PAD_token, SOS_token, EOS_token):
    """Greedy-decode every batch and return ``(cer, predictions, references)``."""
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for images, _, _, texts in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            batch_output = model.generate_batch(images)
            for seq in batch_output:
                all_preds.append(tensor_to_text(seq, idx_to_char, PAD_token, SOS_token, EOS_token))
            all_targets.extend(texts)
    cer = calculate_cer(all_preds, all_targets)
    return cer, all_preds, all_targets


# ===============================
# Early Stopping
# ===============================
class EarlyStopping:
    """Checkpoint on validation-CER improvement; flag a stop after ``patience`` misses.

    On improvement, ``{"epoch", "model_state_dict", "val_cer"}`` is saved to
    ``path`` with ``torch.save``.
    """

    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0
        self.best_cer = float("inf")
        self.early_stop = False

    def __call__(self, val_cer, model, epoch, path):
        if val_cer < self.best_cer:
            self.best_cer = val_cer
            self.counter = 0
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "val_cer": val_cer,
            }, path)
            print(f"Model saved (Val CER: {val_cer:.4f})")
        else:
            self.counter += 1
            print(f"Early stopping: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                print("Early stopping triggered.")


# ===============================
# Training Loop
# ===============================
def train_epoch(model, dataloader, optimizer, criterion, device, scheduler, PAD_token,
                grad_clip=None):
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        scheduler: stepped after every optimizer step (a per-batch schedule).
        PAD_token: accepted for API symmetry; not used in this function.
        grad_clip: max gradient norm. When None, falls back to the
            module-level ``args.grad_clip`` assigned by ``main()``.
    """
    if grad_clip is None:
        grad_clip = args.grad_clip
    model.train()
    epoch_loss = 0.0
    for images, targets, _, _ in tqdm(dataloader, desc="Training"):
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(images, targets)
        outputs = outputs.reshape(-1, outputs.shape[-1])
        targets_flat = targets[:, 1:].reshape(-1)  # predictions are for tokens 1..n
        loss = criterion(outputs, targets_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
    return epoch_loss / max(1, len(dataloader))
def evaluate_model(model, dataloader, criterion, device, PAD_token):
    """Return the mean teacher-forced loss over ``dataloader`` (no gradient).

    PAD_token is accepted for API symmetry; it is not used here. An empty
    loader yields 0.0 instead of raising ZeroDivisionError.
    """
    model.eval()
    epoch_loss = 0.0
    with torch.no_grad():
        for images, targets, _, _ in dataloader:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images, targets)
            outputs = outputs.reshape(-1, outputs.shape[-1])
            targets_flat = targets[:, 1:].reshape(-1)
            epoch_loss += criterion(outputs, targets_flat).item()
    return epoch_loss / max(1, len(dataloader))


# ===============================
# Main
# ===============================
def main():
    """Entry point: parse args, build data and model, train, then test.

    Training uses AdamW with a per-batch OneCycleLR schedule. EarlyStopping
    keeps the checkpoint with the lowest validation CER and stops after
    ``--patience`` epochs without improvement. The best checkpoint is then
    evaluated on the test split (CER, WER, CRR and a few samples).
    """
    # Module-level state: read by the adaptive augmentation transforms
    # (epoch progress / overfitting flag) and by train_epoch (args).
    global current_epoch, num_epochs_global, overfitting_detected
    global validation_loss_history, training_loss_history, args
    args = parse_args()

    # Reproducibility (torch.manual_seed also seeds the CUDA generators).
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Total epochs are the denominator of the augmentation progress fraction.
    num_epochs_global = args.num_epochs

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs(args.output_dir, exist_ok=True)

    # Vocabulary
    char_list, char_to_idx, idx_to_char, PAD_token, SOS_token, EOS_token = \
        load_vocabulary(args.vocab_path)
    vocab_size = len(char_list)
    print(f"Vocabulary size: {vocab_size}")

    # Transforms: augmentation for training only (unless --no_aug)
    train_transform = build_eval_transform() if args.no_aug else build_adaptive_train_transform()
    eval_transform = build_eval_transform()

    ds_kwargs = dict(img_height=args.img_height, img_width=args.img_width,
                     char_to_idx=char_to_idx, SOS_token=SOS_token, EOS_token=EOS_token)

    real_train_dir = os.path.join(args.data_dir, "Training")
    real_val_dir = os.path.join(args.data_dir, "Validation")
    real_test_dir = os.path.join(args.data_dir, "Testing")

    train_datasets = [
        KurdishLineDataset(real_train_dir, transform=train_transform,
                           dataset_name="Real Training", **ds_kwargs)
    ]

    if args.use_synthetic:
        if args.synthetic_dir:
            syn_dir = os.path.join(args.synthetic_dir, "Training")
            train_datasets.append(
                KurdishLineDataset(syn_dir, transform=train_transform,
                                   dataset_name="Synthetic Training", **ds_kwargs))
        else:
            print("Warning: --use_synthetic requires --synthetic_dir; skipping synthetic data.")

    if args.use_writer_mixing:
        if args.fixed_lines_dir:
            fix_dir = os.path.join(args.fixed_lines_dir, "Training")
            all_writers = get_unique_writers(fix_dir)
            selected = random.sample(all_writers, min(args.num_writers, len(all_writers)))
            selected_files = filter_files_by_writers(fix_dir, set(selected))
            train_datasets.append(
                KurdishLineDataset(image_files=selected_files, transform=train_transform,
                                   dataset_name=f"Fixed {len(selected)} Writers", **ds_kwargs))
        else:
            print("Warning: --use_writer_mixing requires --fixed_lines_dir; skipping writer mixing.")

    train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
    val_dataset = KurdishLineDataset(real_val_dir, transform=eval_transform,
                                     dataset_name="Validation", **ds_kwargs)
    test_dataset = KurdishLineDataset(real_test_dir, transform=eval_transform,
                                      dataset_name="Testing", **ds_kwargs)

    print(f"\nTraining: {len(train_dataset)} | Validation: {len(val_dataset)} | Testing: {len(test_dataset)}")

    # Data loaders. num_workers=0 presumably keeps the augmentation globals
    # (current_epoch, overfitting_detected) visible to the transforms.
    # Pinned memory only helps host->GPU copies, so enable it only on CUDA.
    loader_kwargs = dict(num_workers=0, pin_memory=device.type == "cuda", collate_fn=collate_fn)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    val_loader = data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)
    test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, **loader_kwargs)

    # Model
    model = TransformerOCRModel(
        vocab_size=vocab_size, hidden_size=args.hidden_size, nhead=args.num_heads,
        num_encoder_layers=args.encoder_layers, num_decoder_layers=args.decoder_layers,
        dim_feedforward=args.ff_dim, dropout=args.dropout,
        PAD_token=PAD_token, SOS_token=SOS_token, EOS_token=EOS_token).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {total_params:,}")

    # Optimizer and LR schedule. OneCycleLR is stepped per batch inside
    # train_epoch and sets the LR absolutely on every step, so a second
    # scheduler on the same optimizer (e.g. ReduceLROnPlateau) would have its
    # reductions overwritten on the next batch; only OneCycleLR is used.
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate,
                            weight_decay=args.weight_decay)
    onecycle = optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=args.learning_rate, steps_per_epoch=len(train_loader),
        epochs=args.num_epochs, pct_start=0.1)

    criterion = nn.CrossEntropyLoss(ignore_index=PAD_token)
    early_stopping = EarlyStopping(patience=args.patience)
    best_model_path = os.path.join(args.output_dir, "best_model.pth")

    # Training
    print(f"\nStarting training for {args.num_epochs} epochs...")
    for epoch in range(args.num_epochs):
        current_epoch = epoch
        start = time.time()

        train_loss = train_epoch(model, train_loader, optimizer, criterion, device,
                                 onecycle, PAD_token)
        val_loss = evaluate_model(model, val_loader, criterion, device, PAD_token)
        val_cer, _, _ = evaluate_cer(model, val_loader, device, idx_to_char,
                                     PAD_token, SOS_token, EOS_token)

        # Overfitting detection: mean val loss of the last 3 epochs exceeds
        # the mean train loss by more than 20%; boosts augmentation strength.
        training_loss_history.append(train_loss)
        validation_loss_history.append(val_loss)
        if len(training_loss_history) >= 3:
            overfitting_detected = (np.mean(validation_loss_history[-3:])
                                    > np.mean(training_loss_history[-3:]) * 1.2)

        mins, secs = divmod(time.time() - start, 60)
        print(f"\nEpoch {epoch + 1}/{args.num_epochs} ({mins:.0f}m {secs:.0f}s)")
        print(f" Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val CER: {val_cer:.4f}")

        early_stopping(val_cer, model, epoch, best_model_path)
        if early_stopping.early_stop:
            break

    # Final evaluation
    print(f"\nBest validation CER: {early_stopping.best_cer:.4f}")
    if os.path.exists(best_model_path):
        print("Loading best model for test evaluation...")
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        # Happens e.g. with --num_epochs 0 or if validation CER was never finite.
        print("Warning: no checkpoint was saved; evaluating the current model weights.")

    test_cer, test_preds, test_targets = evaluate_cer(
        model, test_loader, device, idx_to_char, PAD_token, SOS_token, EOS_token)
    test_wer = calculate_wer(test_preds, test_targets)

    print(f"\nTest CER: {test_cer:.4f}")
    print(f"Test WER: {test_wer:.4f}")
    print(f"Test CRR: {(1 - test_cer) * 100:.2f}%")

    for i in range(min(5, len(test_preds))):
        print(f"\nSample {i + 1}:")
        print(f" Predicted: {test_preds[i]}")
        print(f" Actual: {test_targets[i]}")


if __name__ == "__main__":
    main()