# KHLR / Scripts/inference.py
# Uploaded by Karez with huggingface_hub (commit c70e7c0, verified)
"""
Kurdish Handwritten Line Recognition - Inference Script
Usage:
# Single image
python inference.py --image sample.tif --model_path best_model.pth --vocab_path vocab.json
# Directory of images
python inference.py --image_dir ./test_images --model_path best_model.pth --vocab_path vocab.json
# With safetensors format
python inference.py --image sample.tif --model_path model.safetensors --vocab_path vocab.json
"""
import os
import glob
import json
import math
import argparse
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
# ===============================
# Argument Parser
# ===============================
def parse_args():
    """Parse the command-line options of the inference script.

    Reads ``sys.argv`` and returns the resulting ``argparse.Namespace``.
    ``--model_path`` and ``--vocab_path`` are mandatory; ``--image`` and
    ``--image_dir`` both default to ``None``; every numeric option is an
    integer with a default value.
    """
    cli = argparse.ArgumentParser(description="Kurdish Handwritten Line Recognition - Inference")

    # Input selection and required artefacts.
    cli.add_argument("--image", type=str, default=None, help="Path to a single image")
    cli.add_argument("--image_dir", type=str, default=None, help="Directory of images to process")
    cli.add_argument("--model_path", type=str, required=True, help="Path to model (.pth or .safetensors)")
    cli.add_argument("--vocab_path", type=str, required=True, help="Path to vocab.json")

    # Integer options, registered in a fixed order so --help output is stable.
    integer_options = (
        ("--img_height", 96),
        ("--img_width", 1235),
        ("--hidden_size", 256),
        ("--encoder_layers", 3),
        ("--decoder_layers", 3),
        ("--num_heads", 8),
        ("--ff_dim", 1024),
        ("--max_seq_len", 150),
    )
    for flag, default_value in integer_options:
        cli.add_argument(flag, type=int, default=default_value)

    cli.add_argument("--device", type=str, default=None, help="Device (cuda/cpu, auto-detected if not set)")
    return cli.parse_args()
# ===============================
# Vocabulary
# ===============================
def load_vocabulary(vocab_path):
    """Load the token vocabulary from a JSON file.

    The JSON object must provide one of:

    * ``"vocab_list"`` - a list whose position is the token index, or
    * ``"char_to_idx"`` - a mapping from token string to a non-negative
      integer index.

    Args:
        vocab_path: Path to the UTF-8 encoded vocabulary JSON file.

    Returns:
        A ``(char_list, idx_to_char)`` tuple. ``char_list[i]`` is the token
        with index ``i`` (``None`` for indices that no token maps to), and
        ``idx_to_char`` maps every index that has a token to that token.

    Raises:
        ValueError: If neither key is present, or if ``char_to_idx``
            contains an index that is not a non-negative integer.
    """
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab_data = json.load(f)

    if "vocab_list" in vocab_data:
        char_list = vocab_data["vocab_list"]
    elif "char_to_idx" in vocab_data:
        char_to_idx = vocab_data["char_to_idx"]
        for char, idx in char_to_idx.items():
            # A negative index would silently wrap around to the end of the list.
            if not isinstance(idx, int) or idx < 0:
                raise ValueError(
                    f"Invalid index {idx!r} for token {char!r} in 'char_to_idx'"
                )
        # Size by the highest index so sparse mappings do not raise IndexError;
        # never shrink below the number of entries (the previous sizing rule).
        highest = max(char_to_idx.values(), default=-1)
        char_list = [None] * max(len(char_to_idx), highest + 1)
        for char, idx in char_to_idx.items():
            char_list[idx] = char
    else:
        raise ValueError("Vocabulary JSON must contain 'vocab_list' or 'char_to_idx'")

    # Unassigned slots are left out so lookups never yield None.
    idx_to_char = {idx: char for idx, char in enumerate(char_list) if char is not None}
    return char_list, idx_to_char
# ===============================
# Model Architecture
# ===============================
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for sequence-first tensors.

    A table of shape ``(max_len, 1, d_model)`` is precomputed and stored as
    the buffer ``pe``. Even feature columns hold ``sin`` values and odd
    columns hold ``cos`` values.

    Args:
        d_model: Feature dimension of the inputs (even or odd).
        max_len: Number of positions in the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cos column than sin columns,
        # so trim the frequencies; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over dim 1.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        ``x`` is indexed by position along dimension 0; the table is
        broadcast over dimension 1.
        """
        return x + self.pe[:x.size(0), :]
class DenseNetFeatureExtractor(nn.Module):
    """Turn an image batch into a sequence of feature vectors.

    The backbone is a torchvision DenseNet-121 created with its default
    weights, with the last child module removed (presumably the classifier
    head). A 1x1 convolution maps the 1024 backbone channels to
    ``output_dim``.

    Args:
        output_dim: Number of channels of each output feature vector.
    """

    def __init__(self, output_dim=256):
        super().__init__()
        backbone = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        # Keep every child except the final one.
        backbone_children = list(backbone.children())
        self.features = nn.Sequential(*backbone_children[:-1])
        self.adapt = nn.Conv2d(1024, output_dim, kernel_size=1)

    def forward(self, x):
        """Return features shaped ``(batch, width, output_dim)``."""
        feature_map = self.adapt(self.features(x))
        # Average the height axis down to 1 while leaving the width untouched.
        pooled = nn.functional.adaptive_avg_pool2d(feature_map, (1, None))
        # (B, C, 1, W) -> (B, C, W) -> (B, W, C)
        return pooled.squeeze(2).permute(0, 2, 1)
class TransformerOCRModel(nn.Module):
def __init__(self, vocab_size, hidden_size=256, nhead=8,
num_encoder_layers=3, num_decoder_layers=3,
dim_feedforward=1024, dropout=0.0, max_seq_len=150):
super().__init__()
self.feature_extractor = DenseNetFeatureExtractor(output_dim=hidden_size)
self.pos_encoder = PositionalEncoding(hidden_size)
self.transformer = nn.Transformer(
d_model=hidden_size, nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout, batch_first=True)
self.token_embedding = nn.Embedding(vocab_size, hidden_size)
self.output_projection = nn.Linear(hidden_size, vocab_size)
self.max_seq_len = max_seq_len
self.SOS_token = 1
self.EOS_token = 2
self.PAD_token = 0
def _generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
return mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
def _apply_pos_encoding(self, x):
x = x.permute(1, 0, 2)
x = self.pos_encoder(x)
return x.permute(1, 0, 2)
def generate(self, img):
self.eval()
with torch.no_grad():
if img.dim() == 3:
img = img.unsqueeze(0)
memory = self._apply_pos_encoding(self.feature_extractor(img))
ys = torch.ones(1, 1).fill_(self.SOS_token).long().to(img.device)
for _ in range(self.max_seq_len - 1):
tgt_embedded = self._apply_pos_encoding(self.token_embedding(ys))
tgt_mask = self._generate_square_subsequent_mask(ys.size(1)).to(img.device)
out = self.transformer(src=memory, tgt=tgt_embedded, tgt_mask=tgt_mask)
out = self.output_projection(out)
next_word = out[0, -1].argmax().item()
ys = torch.cat([ys, torch.ones(1, 1).long().fill_(next_word).to(img.device)], dim=1)
if next_word == self.EOS_token:
break
return ys[0]
# ===============================
# Image Preprocessing
# ===============================
def preprocess_image(image_path, img_height, img_width):
    """Load an image and convert it to a normalized, fixed-size tensor.

    The image is converted to RGB, scaled to ``img_height`` while keeping
    its aspect ratio (width capped at ``img_width``), pasted at the top-left
    corner of a white ``img_width`` x ``img_height`` canvas, and normalized
    per channel.

    Args:
        image_path: Path of the image file to read.
        img_height: Height of the output canvas in pixels.
        img_width: Width of the output canvas in pixels.

    Returns:
        Tensor produced by ``ToTensor`` + ``Normalize`` on the canvas
        (channels first, 3 x ``img_height`` x ``img_width``).
    """
    # convert() returns an independent, fully loaded image, so the file can be
    # closed right away instead of staying open until garbage collection.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    orig_w, orig_h = image.size
    new_h = img_height
    # Keep the aspect ratio, cap at the canvas width, and never let the width
    # round down to 0 for very tall, narrow inputs (resize rejects empty sizes).
    new_w = max(1, min(int(new_h * (orig_w / orig_h)), img_width))
    image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # White background; the resized line is left-aligned, padding on the right.
    canvas = Image.new("RGB", (img_width, img_height), color=(255, 255, 255))
    canvas.paste(image, (0, 0))

    transform = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean / std (the standard ImageNet statistics).
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    return transform(canvas)
# ===============================
# Decode Output
# ===============================
def decode_output(tensor, idx_to_char, pad_idx=0, sos_idx=1, eos_idx=2):
    """Convert a sequence of token indices into text.

    Padding and start-of-sequence tokens are skipped, decoding stops at the
    first end-of-sequence token, and indices that have no character in
    ``idx_to_char`` (missing or mapped to ``None``) are ignored.

    Args:
        tensor: ``torch.Tensor`` or iterable of integer token indices.
        idx_to_char: Mapping from token index to its string.
        pad_idx: Index of the padding token.
        sos_idx: Index of the start-of-sequence token.
        eos_idx: Index of the end-of-sequence token.

    Returns:
        The decoded string.
    """
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.cpu().tolist()

    pieces = []
    for idx in tensor:
        if idx == pad_idx or idx == sos_idx:
            continue
        if idx == eos_idx:
            break
        char = idx_to_char.get(idx)
        # Unknown indices and unassigned vocabulary slots contribute nothing.
        if char is not None:
            pieces.append(char)
    return "".join(pieces)
# ===============================
# Main
# ===============================
def _load_weights(model, model_path, device):
    """Load the checkpoint at *model_path* into *model* with strict key matching."""
    if model_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(model_path)
    else:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source (or prefer the .safetensors format).
        checkpoint = torch.load(model_path, map_location=device)
        # Training checkpoints may wrap the weights under "model_state_dict".
        if "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
    model.load_state_dict(state_dict, strict=True)


def _collect_image_paths(image, image_dir):
    """Return the image paths to process: the single *image*, or the sorted images in *image_dir*."""
    if image:
        return [image]
    if not image_dir:
        return []
    extensions = (".tif", ".tiff", ".png", ".jpg", ".jpeg", ".bmp")
    # Escape the directory so glob metacharacters in its name ("[", "*", "?")
    # are taken literally; compare extensions case-insensitively so files such
    # as "SCAN.TIF" are found on case-sensitive file systems too.
    pattern = os.path.join(glob.escape(image_dir), "*")
    return sorted(p for p in glob.glob(pattern) if p.lower().endswith(extensions))


def main():
    """Run recognition on the requested image(s) and print one line per file.

    Steps: parse the command line, select the device, load the vocabulary,
    build the model and load its weights, then preprocess, decode and print
    each image in turn. Returns early with a message when no input option
    was given or no image was found.
    """
    args = parse_args()
    if args.image is None and args.image_dir is None:
        print("Error: Provide --image or --image_dir")
        return

    # Device
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Vocabulary
    char_list, idx_to_char = load_vocabulary(args.vocab_path)
    vocab_size = len(char_list)
    print(f"Vocabulary: {vocab_size} tokens")

    # Model
    model = TransformerOCRModel(
        vocab_size=vocab_size,
        hidden_size=args.hidden_size,
        nhead=args.num_heads,
        num_encoder_layers=args.encoder_layers,
        num_decoder_layers=args.decoder_layers,
        dim_feedforward=args.ff_dim,
        max_seq_len=args.max_seq_len
    ).to(device)

    # Load weights
    _load_weights(model, args.model_path, device)
    model.eval()
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} parameters\n")

    # Collect images (--image takes precedence over --image_dir)
    image_paths = _collect_image_paths(args.image, args.image_dir)
    if not image_paths:
        print("No images found.")
        return

    print(f"Processing {len(image_paths)} image(s)...\n")
    print(f"{'File':<40} {'Predicted Text'}")
    print("-" * 80)
    for img_path in image_paths:
        tensor = preprocess_image(img_path, args.img_height, args.img_width).to(device)
        output = model.generate(tensor)
        text = decode_output(output, idx_to_char)
        filename = os.path.basename(img_path)
        print(f"{filename:<40} {text}")
    print(f"\nDone. {len(image_paths)} image(s) processed.")


if __name__ == "__main__":
    main()