import csv
import json
import os
import re
import tempfile

import torch
import gradio as gr
from PIL import Image
from pdf2image import convert_from_path
from transformers import NougatProcessor, VisionEncoderDecoderModel
import nltk

# Initialize NLTK.
# NOTE(review): nothing visible in this file calls nltk directly; presumably the
# 'punkt' data is needed by Nougat post-processing — confirm before removing.
nltk.download('punkt', quiet=True)

# Configuration
MODEL_REPO = "Tamazight/Tifinagh-Nougat-Small"  # model weights are loaded from here
BASE_REPO = "facebook/nougat-small"  # processor (image processor + tokenizer) is loaded from here

print("Loading processor and model...")
processor = NougatProcessor.from_pretrained(BASE_REPO)

# Fix image processor attributes: disable margin cropping, thumbnailing and
# long-axis alignment; enable resize/rescale/normalize. Attributes the image
# processor does not define are left alone.
_IMAGE_PROCESSOR_FLAGS = {
    "do_crop_margin": False,
    "do_thumbnail": False,
    "do_align_long_axis": False,
    "do_rescale": True,
    "do_normalize": True,
    "do_resize": True,
}
for attr, value in _IMAGE_PROCESSOR_FLAGS.items():
    if hasattr(processor.image_processor, attr):
        setattr(processor.image_processor, attr, value)

processor.image_processor.rescale_factor = 1 / 255.0
processor.image_processor.image_mean = [0.485, 0.456, 0.406]
processor.image_processor.image_std = [0.229, 0.224, 0.225]
processor.image_processor.size = {"height": 896, "width": 672}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(MODEL_REPO).to(device)
model.eval()

# ==========================================
# Post-Processing: Dictionary Corrections
# ==========================================
# Tokenizer.csv holds rows of "wrong,correct". Only rows with exactly two
# columns whose cells are non-empty after stripping are used.
DICTIONARY_FILE = "Tokenizer.csv"
corrections_dict = {}  # wrong spelling -> corrected spelling
try:
    if os.path.exists(DICTIONARY_FILE):
        # "utf-8-sig" strips a leading BOM (common in spreadsheet exports) that
        # would otherwise stick to the first key so it could never match;
        # newline="" is what the csv module requires for correct quoting.
        with open(DICTIONARY_FILE, mode="r", encoding="utf-8-sig", newline="") as f:
            for row in csv.reader(f):
                if len(row) != 2:
                    continue
                wrong_word, correct_word = row[0].strip(), row[1].strip()
                if wrong_word and correct_word:
                    corrections_dict[wrong_word] = correct_word
except Exception as e:
    # Best effort: a broken dictionary file must not keep the app from starting.
    print(f"Error loading dictionary: {e}")


def apply_dictionary_corrections(text):
    """Apply the loaded spelling corrections to ``text``.

    Every occurrence of a key of ``corrections_dict`` is replaced by its value
    in a single left-to-right pass:

    * where several keys match at the same position, the longest key wins, so
      the result no longer depends on CSV row order;
    * replaced text is never re-scanned, so rules do not chain
      (with ``A -> B`` and ``B -> C``, ``A`` becomes ``B``, not ``C``).

    Matching is plain substring matching (not restricted to word boundaries),
    as with the previous ``str.replace`` implementation.

    Args:
        text: OCR output to correct.

    Returns:
        The corrected string, or ``text`` unchanged when no corrections are
        loaded.
    """
    if not corrections_dict:
        return text
    # Empty keys would match everywhere; the loader never stores them, but guard anyway.
    keys = sorted((k for k in corrections_dict if k), key=len, reverse=True)
    if not keys:
        return text
    # Python's regex alternation takes the first alternative that matches at a
    # position, so longest-first ordering gives longest-match semantics.
    pattern = re.compile("|".join(map(re.escape, keys)))
    return pattern.sub(lambda m: corrections_dict[m.group(0)], text)
========================================== # Image Optimization # ========================================== def optimize_image(img): optimized_img = img.copy() optimized_img.thumbnail((1024, 1024), Image.Resampling.LANCZOS) return optimized_img def get_compressed_text(text): return " ".join(text.split()) def ocr_inference(image): optimized_img = optimize_image(image) pixel_values = processor( images=optimized_img, return_tensors="pt", do_resize=True, size={"height": 896, "width": 672}, resample=Image.BILINEAR, do_rescale=True, rescale_factor=1/255.0, do_normalize=True, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], do_crop_margin=False, do_thumbnail=False, do_align_long_axis=False ).pixel_values.to(device) with torch.no_grad(): outputs = model.generate( pixel_values, min_length=1, max_new_tokens=1500, # Safe limit to cover full pages without freezing the server bad_words_ids=[[processor.tokenizer.unk_token_id]] ) sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0] try: raw_output = processor.post_process_generation(sequence, fix_markdown=False).strip() except: raw_output = sequence.strip() corrected_output = apply_dictionary_corrections(raw_output) return corrected_output def handle_request_with_loading(input_file, file_type): if input_file is None: yield gr.update(), gr.update(), gr.update(), gr.update(visible=False), gr.update(visible=False) return yield ( gr.update(value="⏳ Processing... 
Please wait.", visible=True), gr.update(value=""), gr.update(value=""), gr.update(visible=True), gr.update(visible=False) ) full_raw_text = "" if file_type == "image": full_raw_text = ocr_inference(input_file.convert("RGB")) else: pages = convert_from_path(input_file.name) extracted = [] for i, page_img in enumerate(pages): text = ocr_inference(page_img.convert("RGB")) extracted.append(f"### Page {i+1}\n\n{text}") full_raw_text = "\n\n".join(extracted) compressed = get_compressed_text(full_raw_text) yield ( gr.update(value=compressed), gr.update(value=full_raw_text), gr.update(value=full_raw_text), gr.update(visible=True), gr.update(visible=True) ) def build_download_files(raw_text): if not raw_text or raw_text.startswith("⏳"): return gr.update(visible=False) output_dir = tempfile.mkdtemp() compressed = get_compressed_text(raw_text) export_data = { "output.mmd": raw_text, "output.md": raw_text, "output.txt": compressed, "output.jsonl": json.dumps({"markdown": raw_text, "text": compressed}, ensure_ascii=False) + "\n" } final_paths = [] for filename, content in export_data.items(): file_path = os.path.join(output_dir, filename) with open(file_path, "w", encoding="utf-8") as f: f.write(content) final_paths.append(file_path) return gr.update(value=final_paths, visible=True) with gr.Blocks() as demo: gr.HTML("