""" Kurdish Handwritten Line Recognition - Inference Script Usage: # Single image python inference.py --image sample.tif --model_path best_model.pth --vocab_path vocab.json # Directory of images python inference.py --image_dir ./test_images --model_path best_model.pth --vocab_path vocab.json # With safetensors format python inference.py --image sample.tif --model_path model.safetensors --vocab_path vocab.json """ import os import glob import json import math import argparse from PIL import Image import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision.models as models # =============================== # Argument Parser # =============================== def parse_args(): parser = argparse.ArgumentParser(description="Kurdish Handwritten Line Recognition - Inference") parser.add_argument("--image", type=str, default=None, help="Path to a single image") parser.add_argument("--image_dir", type=str, default=None, help="Directory of images to process") parser.add_argument("--model_path", type=str, required=True, help="Path to model (.pth or .safetensors)") parser.add_argument("--vocab_path", type=str, required=True, help="Path to vocab.json") parser.add_argument("--img_height", type=int, default=96) parser.add_argument("--img_width", type=int, default=1235) parser.add_argument("--hidden_size", type=int, default=256) parser.add_argument("--encoder_layers", type=int, default=3) parser.add_argument("--decoder_layers", type=int, default=3) parser.add_argument("--num_heads", type=int, default=8) parser.add_argument("--ff_dim", type=int, default=1024) parser.add_argument("--max_seq_len", type=int, default=150) parser.add_argument("--device", type=str, default=None, help="Device (cuda/cpu, auto-detected if not set)") return parser.parse_args() # =============================== # Vocabulary # =============================== def load_vocabulary(vocab_path): with open(vocab_path, "r", encoding="utf-8") as f: vocab_data = json.load(f) if "vocab_list" in vocab_data: char_list = vocab_data["vocab_list"] elif "char_to_idx" in vocab_data: char_to_idx = vocab_data["char_to_idx"] char_list = [None] * len(char_to_idx) for char, idx in char_to_idx.items(): char_list[idx] = char else: raise ValueError("Vocabulary JSON must contain 'vocab_list' or 'char_to_idx'") idx_to_char = {idx: char for idx, char in enumerate(char_list)} return char_list, idx_to_char # =============================== # Model Architecture # =============================== class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) self.register_buffer("pe", pe) def forward(self, x): return x + self.pe[:x.size(0), :] class DenseNetFeatureExtractor(nn.Module): def __init__(self, output_dim=256): super().__init__() densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT) self.features = nn.Sequential(*list(densenet.children())[:-1]) self.adapt = nn.Conv2d(1024, output_dim, kernel_size=1) def forward(self, x): x = self.features(x) x = self.adapt(x) x = nn.functional.adaptive_avg_pool2d(x, (1, None)) x = x.squeeze(2) return x.permute(0, 2, 1) class TransformerOCRModel(nn.Module): def __init__(self, vocab_size, hidden_size=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=1024, dropout=0.0, max_seq_len=150): super().__init__() self.feature_extractor = DenseNetFeatureExtractor(output_dim=hidden_size) self.pos_encoder = PositionalEncoding(hidden_size) self.transformer = nn.Transformer( d_model=hidden_size, nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True) self.token_embedding = nn.Embedding(vocab_size, hidden_size) self.output_projection = nn.Linear(hidden_size, vocab_size) self.max_seq_len = max_seq_len self.SOS_token = 1 self.EOS_token = 2 self.PAD_token = 0 def _generate_square_subsequent_mask(self, sz): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) def _apply_pos_encoding(self, x): x = x.permute(1, 0, 2) x = self.pos_encoder(x) return x.permute(1, 0, 2) def generate(self, img): self.eval() with torch.no_grad(): if img.dim() == 3: img = img.unsqueeze(0) memory = self._apply_pos_encoding(self.feature_extractor(img)) ys = torch.ones(1, 1).fill_(self.SOS_token).long().to(img.device) for _ in range(self.max_seq_len - 1): tgt_embedded = self._apply_pos_encoding(self.token_embedding(ys)) tgt_mask = self._generate_square_subsequent_mask(ys.size(1)).to(img.device) out = self.transformer(src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask) out = self.output_projection(out) next_word = out[0, -1].argmax().item() ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).to(img.device)], dim=1) if next_word == self.EOS_token: break return ys[0] # =============================== # Image Preprocessing # =============================== def preprocess_image(image_path, img_height, img_width): image = Image.open(image_path).convert("RGB") orig_w, orig_h = image.size new_h = img_height new_w = min(int(new_h * (orig_w / orig_h)), img_width) image = image.resize((new_w, new_h), Image.Resampling.LANCZOS) canvas = Image.new("RGB", (img_width, img_height), color=(255, 255, 255)) canvas.paste(image, (0, 0)) transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) return transform(canvas) # =============================== # Decode Output # =============================== def decode_output(tensor, idx_to_char): if isinstance(tensor, torch.Tensor): tensor = tensor.cpu().tolist() text = "" for idx in tensor: if idx == 0 or idx == 1: # PAD or SOS continue if idx == 2: # EOS break if idx in idx_to_char: text += idx_to_char[idx] return text # =============================== # Main # =============================== def main(): args = parse_args() if args.image is None and args.image_dir is None: print("Error: Provide --image or --image_dir") return # Device if args.device: device = torch.device(args.device) else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") # Vocabulary char_list, idx_to_char = load_vocabulary(args.vocab_path) vocab_size = len(char_list) print(f"Vocabulary: {vocab_size} tokens") # Model model = TransformerOCRModel( vocab_size=vocab_size, hidden_size=args.hidden_size, nhead=args.num_heads, num_encoder_layers=args.encoder_layers, num_decoder_layers=args.decoder_layers, dim_feedforward=args.ff_dim, max_seq_len=args.max_seq_len ).to(device) # Load weights if args.model_path.endswith(".safetensors"): from safetensors.torch import load_file state_dict = load_file(args.model_path) model.load_state_dict(state_dict, strict=True) else: checkpoint = torch.load(args.model_path, map_location=device) if "model_state_dict" in checkpoint: model.load_state_dict(checkpoint["model_state_dict"], strict=True) else: model.load_state_dict(checkpoint, strict=True) model.eval() print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters\n") # Collect images image_paths = [] if args.image: image_paths = [args.image] elif args.image_dir: for ext in ("*.tif", "*.tiff", "*.png", "*.jpg", "*.jpeg", "*.bmp"): image_paths.extend(glob.glob(os.path.join(args.image_dir, ext))) image_paths.sort() if not image_paths: print("No images found.") return print(f"Processing {len(image_paths)} image(s)...\n") print(f"{'File':<40} {'Predicted Text'}") print("-" * 80) for img_path in image_paths: tensor = preprocess_image(img_path, args.img_height, args.img_width).to(device) output = model.generate(tensor) text = decode_output(output, idx_to_char) filename = os.path.basename(img_path) print(f"{filename:<40} {text}") print(f"\nDone. {len(image_paths)} image(s) processed.") if __name__ == "__main__": main()