Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import torch | |
| import tempfile | |
| import gradio as gr | |
| import csv | |
| from PIL import Image | |
| from pdf2image import convert_from_path | |
| from transformers import NougatProcessor, VisionEncoderDecoderModel | |
| import nltk | |
# Initialize NLTK
# NOTE(review): no other use of nltk is visible in this file; this startup
# download may be removable — confirm before deleting.
nltk.download('punkt', quiet=True)

# Configuration
MODEL_REPO = "Tamazight/Tifinagh-Nougat-Small"  # weights loaded into VisionEncoderDecoderModel below
BASE_REPO = "facebook/nougat-small"  # source of the NougatProcessor (tokenizer + image processor)

print("Loading processor and model...")
processor = NougatProcessor.from_pretrained(BASE_REPO)

# Fix image processor attributes: keep plain resize / rescale / normalize and
# turn off Nougat's margin cropping, thumbnailing and long-axis alignment.
# Each flag is only assigned when the image processor actually has it.
_IMAGE_PROCESSOR_FLAGS = {
    "do_crop_margin": False,
    "do_thumbnail": False,
    "do_align_long_axis": False,
    "do_rescale": True,
    "do_normalize": True,
    "do_resize": True,
}
for attr, enabled in _IMAGE_PROCESSOR_FLAGS.items():
    if hasattr(processor.image_processor, attr):
        setattr(processor.image_processor, attr, enabled)

processor.image_processor.rescale_factor = 1/255.0  # 0-255 pixel values -> [0, 1]
# Standard ImageNet per-channel mean / std.
processor.image_processor.image_mean = [0.485, 0.456, 0.406]
processor.image_processor.image_std = [0.229, 0.224, 0.225]
processor.image_processor.size = {"height": 896, "width": 672}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(MODEL_REPO).to(device)
model.eval()  # inference only
# ==========================================
# Post-Processing: Dictionary Corrections
# ==========================================
DICTIONARY_FILE = "Tokenizer.csv"


def load_corrections_dict(path):
    """Load a wrong -> correct word mapping from a two-column CSV file.

    Rows that do not have exactly two cells, or whose cells are blank after
    stripping whitespace, are skipped. Later rows override earlier ones for
    the same wrong word.

    Args:
        path: Path to the CSV file.

    Returns:
        A dict mapping each wrong word to its correction. A missing file
        yields an empty dict. On a read/decode/parse error the error is
        printed and the entries loaded up to that point are returned.
    """
    corrections = {}
    if not os.path.exists(path):
        return corrections
    try:
        # "utf-8-sig" drops a leading byte-order mark (written by some
        # spreadsheet exports) that would otherwise become part of the first
        # wrong word; newline="" is what the csv module expects of its files.
        with open(path, mode="r", encoding="utf-8-sig", newline="") as f:
            for row in csv.reader(f):
                if len(row) != 2:
                    continue
                wrong_word, correct_word = row[0].strip(), row[1].strip()
                if wrong_word and correct_word:
                    corrections[wrong_word] = correct_word
    except (OSError, UnicodeDecodeError, csv.Error) as e:
        print(f"Error loading dictionary: {e}")
    return corrections


corrections_dict = load_corrections_dict(DICTIONARY_FILE)
def apply_dictionary_corrections(text, corrections=None):
    """Replace known OCR mistakes in *text* using a wrong -> correct mapping.

    Matching is plain substring matching (no word boundaries). All entries
    are applied in one left-to-right pass. Where several wrong words match at
    the same position, the longest wins. Text produced by a replacement is
    never re-scanned, so corrections cannot cascade into each other.

    Args:
        text: The OCR output to correct.
        corrections: Optional mapping to use instead of the module-level
            ``corrections_dict``. Empty keys are ignored.

    Returns:
        The corrected text. *text* is returned unchanged when it is empty or
        there are no usable corrections.
    """
    import re

    if corrections is None:
        corrections = corrections_dict
    # Longest first so that, e.g., "abc" is not pre-empted by an "ab" entry.
    keys = sorted((k for k in corrections if k), key=len, reverse=True)
    if not text or not keys:
        return text
    pattern = re.compile("|".join(map(re.escape, keys)))
    # A function replacement keeps correction strings literal (no backslash
    # or group-reference interpretation), like str.replace did.
    return pattern.sub(lambda match: corrections[match.group(0)], text)
| # ========================================== | |
| # Image Optimization | |
| # ========================================== | |
def optimize_image(img):
    """Return a shrunken copy of *img* bounded by a 1024x1024 box.

    ``thumbnail`` resizes in place (LANCZOS filter), so it is applied to a
    copy and the caller's image is left untouched.
    """
    max_box = (1024, 1024)
    shrunk = img.copy()
    shrunk.thumbnail(max_box, Image.Resampling.LANCZOS)
    return shrunk
def get_compressed_text(text):
    """Collapse every whitespace run (spaces, tabs, newlines) into one space.

    Leading and trailing whitespace is dropped, so multi-line Markdown is
    flattened into a single line.
    """
    tokens = text.split()
    separator = " "
    return separator.join(tokens)
def ocr_inference(image):
    """Run the Nougat model on one page image and return corrected text.

    Args:
        image: A PIL image (callers in this file pass RGB-converted pages).

    Returns:
        The decoded model output, post-processed by the Nougat processor when
        that succeeds (plain decoded text otherwise), then passed through
        ``apply_dictionary_corrections``.
    """
    optimized_img = optimize_image(image)
    # These arguments repeat the values assigned to processor.image_processor
    # at load time; passing them here makes the per-call preprocessing explicit.
    pixel_values = processor(
        images=optimized_img,
        return_tensors="pt",
        do_resize=True,
        size={"height": 896, "width": 672},
        resample=Image.Resampling.BILINEAR,
        do_rescale=True,
        rescale_factor=1/255.0,
        do_normalize=True,
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
        do_crop_margin=False,
        do_thumbnail=False,
        do_align_long_axis=False
    ).pixel_values.to(device)
    with torch.no_grad():
        outputs = model.generate(
            pixel_values,
            min_length=1,
            max_new_tokens=1500,  # Safe limit to cover full pages without freezing the server
            bad_words_ids=[[processor.tokenizer.unk_token_id]],  # never emit the unknown token
        )
    # A single image was encoded, so take the first decoded sequence.
    sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    try:
        raw_output = processor.post_process_generation(sequence, fix_markdown=False).strip()
    except Exception:
        # Post-processing is best-effort: fall back to the plain decoded text.
        # (A bare ``except:`` here would also swallow KeyboardInterrupt/SystemExit.)
        raw_output = sequence.strip()
    return apply_dictionary_corrections(raw_output)
def _extract_raw_text(input_file, file_type):
    """Run OCR on one uploaded image or PDF and return the raw text.

    Args:
        input_file: A PIL image when ``file_type == "image"``; otherwise the
            value from the PDF ``gr.File`` input.
        file_type: ``"image"`` or ``"pdf"`` (anything else is treated as PDF).

    Returns:
        The OCR text. For PDFs, each page is prefixed with ``### Page N`` and
        pages are separated by a blank line.
    """
    if file_type == "image":
        return ocr_inference(input_file.convert("RGB"))
    # NOTE(review): the original reads the upload path via ``.name``; fall
    # back to the value itself in case Gradio hands over a plain path string.
    pdf_path = getattr(input_file, "name", input_file)
    extracted = []
    for page_number, page_img in enumerate(convert_from_path(pdf_path), start=1):
        text = ocr_inference(page_img.convert("RGB"))
        extracted.append(f"### Page {page_number}\n\n{text}")
    return "\n\n".join(extracted)


def handle_request_with_loading(input_file, file_type):
    """Gradio generator handler shared by the image and PDF extract buttons.

    Yields 5-tuples of updates for (display_result, display_markdown,
    raw_storage, output_section, tab_markdown). It first yields a "⏳"
    placeholder while OCR runs, then either the results or an error message.

    Args:
        input_file: The uploaded image/PDF, or None when nothing was provided.
        file_type: "image" or "pdf".
    """
    if input_file is None:
        yield gr.update(), gr.update(), gr.update(), gr.update(visible=False), gr.update(visible=False)
        return
    yield (
        gr.update(value="⏳ Processing... Please wait.", visible=True),
        gr.update(value=""),
        gr.update(value=""),
        gr.update(visible=True),
        gr.update(visible=False)
    )
    try:
        full_raw_text = _extract_raw_text(input_file, file_type)
    except Exception as e:
        # Without this the result box would keep showing the "⏳" placeholder
        # from the yield above. Log for the server and surface the error.
        print(f"OCR failed: {e}")
        yield (
            gr.update(value=f"❌ Extraction failed: {e}"),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(visible=True),
            gr.update(visible=False)
        )
        return
    compressed = get_compressed_text(full_raw_text)
    yield (
        gr.update(value=compressed),
        gr.update(value=full_raw_text),
        gr.update(value=full_raw_text),
        gr.update(visible=True),
        gr.update(visible=True)
    )
def build_download_files(raw_text):
    """Write the OCR result to temporary files in several formats for download.

    Args:
        raw_text: The raw OCR text held in ``raw_storage``.

    Returns:
        A ``gr.update`` listing the written .mmd/.md/.txt/.jsonl paths (and
        making the file component visible). It hides the component instead
        when the text is empty or starts with the "⏳" placeholder.
    """
    nothing_to_export = not raw_text or raw_text.startswith("⏳")
    if nothing_to_export:
        return gr.update(visible=False)
    target_dir = tempfile.mkdtemp()
    flat_text = get_compressed_text(raw_text)
    jsonl_line = json.dumps({"markdown": raw_text, "text": flat_text}, ensure_ascii=False) + "\n"
    contents_by_name = (
        ("output.mmd", raw_text),
        ("output.md", raw_text),
        ("output.txt", flat_text),
        ("output.jsonl", jsonl_line),
    )
    written = []
    for name, body in contents_by_name:
        destination = os.path.join(target_dir, name)
        with open(destination, "w", encoding="utf-8") as handle:
            handle.write(body)
        written.append(destination)
    return gr.update(value=written, visible=True)
# Gradio UI: inputs on the left, results on the right (hidden until a request runs).
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center; margin-bottom: 0;'>Tifinagh OCR</h1>")
    gr.HTML("<h3 style='text-align: center; margin-top: 0; font-weight: normal;'>Image/PDF to Text</h3>")
    with gr.Row():
        # Left column: one tab per input type, each with its own trigger button.
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    image_in = gr.Image(type="pil", label="Select Image")
                    img_btn = gr.Button("Extract Text", variant="primary")
                with gr.Tab("PDF Input"):
                    gr.Markdown("⚠️ **Note:** Extraction takes time based on GPU and page count.")
                    pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"])
                    pdf_btn = gr.Button("Extract PDF", variant="primary")
        # Right column: starts hidden; handle_request_with_loading toggles its
        # visibility (and the Markdown tab's) through the click outputs below.
        with gr.Column(scale=1, visible=False) as output_section:
            with gr.Tabs():
                with gr.Tab("Result"):
                    display_result = gr.Textbox(label="Extracted Text", lines=20, max_lines=5000)
                with gr.Tab("Markdown View", visible=False) as tab_markdown:
                    display_markdown = gr.Markdown()
            # Raw (uncompressed) OCR text of the last run; read by the download button.
            raw_storage = gr.State("")
            gr.Markdown("---")
            with gr.Row():
                with gr.Column():
                    gen_files_btn = gr.Button("📥 Download result", variant="secondary")
                    output_files = gr.File(label="Choose the format", visible=False)
    gr.Markdown("---")
    with gr.Row():
        with gr.Column():
            # Sample images bound to the image input.
            gr.Examples(
                examples=[
                    ["img_01.webp"], ["img_02.webp"], ["img_03.webp"],
                    ["img_04.webp"], ["img_05.webp"], ["img_06.webp"],
                    ["img_07.webp"], ["img_08.webp"], ["img_09.webp"]
                ],
                inputs=image_in
            )
    # Both extract buttons share one handler; the gr.State constant tells it
    # which input type it received. The outputs order must match the 5-tuples
    # the handler yields.
    img_btn.click(
        fn=handle_request_with_loading,
        inputs=[image_in, gr.State("image")],
        outputs=[display_result, display_markdown, raw_storage, output_section, tab_markdown],
        scroll_to_output=True
    )
    pdf_btn.click(
        fn=handle_request_with_loading,
        inputs=[pdf_in, gr.State("pdf")],
        outputs=[display_result, display_markdown, raw_storage, output_section, tab_markdown],
        scroll_to_output=True
    )
    # Export the stored raw text to downloadable files on demand.
    gen_files_btn.click(
        fn=build_download_files,
        inputs=raw_storage,
        outputs=output_files
    )
# Enable Gradio's request queue (queue() returns the same Blocks object) so long
# OCR jobs don't freeze the app on Hugging Face Spaces, then start the server.
demo.queue()
demo.launch(theme=gr.themes.Soft())