# Author: Akseltinfat — commit 8684ac7 "Update app.py" (verified)
import csv
import json
import os
import re
import tempfile
import traceback

import gradio as gr
import nltk
import torch
from PIL import Image
from pdf2image import convert_from_path
from transformers import NougatProcessor, VisionEncoderDecoderModel
# Initialize NLTK
# NOTE(review): no other code visible in this file calls into nltk; this
# 'punkt' download may be removable — confirm before dropping it.
nltk.download('punkt', quiet=True)

# Configuration: the model weights come from MODEL_REPO, while the processor
# (tokenizer + image preprocessing) is taken from BASE_REPO.
MODEL_REPO = "Tamazight/Tifinagh-Nougat-Small"
BASE_REPO = "facebook/nougat-small"

print("Loading processor and model...")
processor = NougatProcessor.from_pretrained(BASE_REPO)

# Target value for each preprocessing switch on the image processor. Nougat's
# margin cropping, thumbnailing and long-axis alignment are turned off; resize,
# rescale and normalize stay on. Switches the processor doesn't expose are
# skipped. (ocr_inference also passes these values explicitly per call.)
_IMAGE_PROCESSOR_FLAGS = {
    "do_crop_margin": False,
    "do_thumbnail": False,
    "do_align_long_axis": False,
    "do_rescale": True,
    "do_normalize": True,
    "do_resize": True,
}
_image_processor = processor.image_processor
for _flag_name, _flag_value in _IMAGE_PROCESSOR_FLAGS.items():
    if hasattr(_image_processor, _flag_name):
        setattr(_image_processor, _flag_name, _flag_value)

# Scale 0-255 pixels into [0, 1], then normalize with the standard ImageNet
# per-channel mean/std; inputs are resized to 896x672 (height x width).
_image_processor.rescale_factor = 1/255.0
_image_processor.image_mean = [0.485, 0.456, 0.406]
_image_processor.image_std = [0.229, 0.224, 0.225]
_image_processor.size = {"height": 896, "width": 672}

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionEncoderDecoderModel.from_pretrained(MODEL_REPO).to(device)
model.eval()  # inference mode for layers such as dropout
# ==========================================
# Post-Processing: Dictionary Corrections
# ==========================================
DICTIONARY_FILE = "Tokenizer.csv"


def _load_corrections(path):
    """Read a two-column ``wrong,correct`` CSV into a dict.

    Rows that do not have exactly two cells, or whose stripped cells are
    empty, are ignored; a later row overrides an earlier one with the same
    wrong word. A missing file yields an empty dict. Any error while reading
    is printed and whatever was collected up to that point is returned.
    """
    table = {}
    try:
        if not os.path.exists(path):
            return table
        with open(path, mode='r', encoding='utf-8') as handle:
            for cells in csv.reader(handle):
                if len(cells) != 2:
                    continue
                wrong, right = (cell.strip() for cell in cells)
                if wrong and right:
                    table[wrong] = right
    except Exception as e:
        # Best effort: the app still runs without corrections.
        print(f"Error loading dictionary: {e}")
    return table


# Mapping of OCR mistakes -> corrected spellings, applied to model output.
corrections_dict = _load_corrections(DICTIONARY_FILE)
def apply_dictionary_corrections(text, corrections=None):
    """Replace known OCR mistakes in *text* using a wrong -> correct mapping.

    All wrong words are matched in a single left-to-right pass, and at each
    position the longest wrong word is tried first. The result therefore does
    not depend on the mapping's insertion order, and text produced by one
    correction is never rewritten again by another entry. Matching is plain
    substring matching (not restricted to word boundaries).

    Args:
        text: The OCR output to correct.
        corrections: Optional mapping of wrong -> correct strings. Defaults to
            the module-level ``corrections_dict``.

    Returns:
        The corrected text, or *text* itself when it is empty or there is
        nothing to apply.
    """
    if corrections is None:
        corrections = corrections_dict
    # Longest first so e.g. "abc" wins over its prefix "ab"; empty keys would
    # match at every position, so they are skipped.
    wrong_words = sorted((word for word in corrections if word), key=len, reverse=True)
    if not text or not wrong_words:
        return text
    pattern = re.compile("|".join(map(re.escape, wrong_words)))
    return pattern.sub(lambda match: corrections[match.group(0)], text)
# ==========================================
# Image Optimization
# ==========================================
def optimize_image(img):
    """Return a copy of *img* downscaled to fit within a 1024x1024 box.

    ``Image.thumbnail`` preserves the aspect ratio and never enlarges, and it
    resizes in place, so it is applied to a copy to leave the caller's image
    untouched.
    """
    max_box = (1024, 1024)
    bounded = img.copy()
    bounded.thumbnail(max_box, Image.Resampling.LANCZOS)
    return bounded
def get_compressed_text(text):
    """Collapse every whitespace run in *text* to one space and trim both ends."""
    tokens = text.split()
    return " ".join(tokens)
def ocr_inference(image):
    """Transcribe a single page image with the Tifinagh Nougat model.

    Args:
        image: The page as a PIL image; callers in this file pass the result
            of ``.convert("RGB")``.

    Returns:
        The decoded text, cleaned by ``processor.post_process_generation``
        when that succeeds (otherwise the raw decoded sequence), stripped of
        surrounding whitespace and passed through the dictionary corrections.
    """
    optimized_img = optimize_image(image)
    # These options repeat the values configured on processor.image_processor
    # at start-up; passing them per call keeps preprocessing explicit here.
    pixel_values = processor(
        images=optimized_img,
        return_tensors="pt",
        do_resize=True,
        size={"height": 896, "width": 672},
        resample=Image.BILINEAR,
        do_rescale=True,
        rescale_factor=1/255.0,
        do_normalize=True,
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.229, 0.224, 0.225],
        do_crop_margin=False,
        do_thumbnail=False,
        do_align_long_axis=False,
    ).pixel_values.to(device)

    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model.generate(
            pixel_values,
            min_length=1,
            max_new_tokens=1500,  # Safe limit to cover full pages without freezing the server
            # Forbid the tokenizer's unknown token in the generated output.
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
        )

    # A single image goes in, so only the first decoded sequence is used.
    sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    try:
        raw_output = processor.post_process_generation(sequence, fix_markdown=False).strip()
    except Exception as exc:
        # Post-processing is optional clean-up: fall back to the raw decoded
        # text, but never swallow KeyboardInterrupt/SystemExit.
        print(f"post_process_generation failed, using raw output: {exc}")
        raw_output = sequence.strip()

    return apply_dictionary_corrections(raw_output)
def handle_request_with_loading(input_file, file_type):
    """Run OCR on an uploaded image or PDF, streaming UI updates to Gradio.

    This is a generator. Every ``yield`` is a 5-tuple of ``gr.update`` objects
    for, in order: the result textbox, the Markdown view, the raw-text state,
    the output column and the Markdown tab.

    Args:
        input_file: A PIL image when *file_type* is ``"image"``. Otherwise it
            is the uploaded PDF from ``gr.File``: an object exposing its path
            as ``.name``, or a plain path string.
        file_type: ``"image"`` for image input; any other value is treated
            as a PDF.
    """
    if input_file is None:
        yield gr.update(), gr.update(), gr.update(), gr.update(visible=False), gr.update(visible=False)
        return

    # Show the output column with a loading message and clear stale results.
    yield (
        gr.update(value="⏳ Processing... Please wait.", visible=True),
        gr.update(value=""),
        gr.update(value=""),
        gr.update(visible=True),
        gr.update(visible=False),
    )

    try:
        if file_type == "image":
            full_raw_text = ocr_inference(input_file.convert("RGB"))
        else:
            pdf_path = getattr(input_file, "name", input_file)
            pages = convert_from_path(pdf_path)
            total_pages = len(pages)
            extracted = []
            for page_number, page_img in enumerate(pages, start=1):
                # Per-page progress; long PDFs otherwise look frozen.
                yield (
                    gr.update(value=f"⏳ Processing page {page_number}/{total_pages}... Please wait."),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                )
                text = ocr_inference(page_img.convert("RGB"))
                extracted.append(f"### Page {page_number}\n\n{text}")
            full_raw_text = "\n\n".join(extracted)
    except Exception as exc:
        # Replace the loading message with the failure instead of leaving the
        # UI stuck on "Processing"; keep the full traceback in the server log.
        traceback.print_exc()
        yield (
            gr.update(value=f"❌ Extraction failed: {exc}"),
            gr.update(value=""),
            gr.update(value=""),
            gr.update(visible=True),
            gr.update(visible=False),
        )
        return

    compressed = get_compressed_text(full_raw_text)
    yield (
        gr.update(value=compressed),
        gr.update(value=full_raw_text),
        gr.update(value=full_raw_text),
        gr.update(visible=True),
        gr.update(visible=True),
    )
def build_download_files(raw_text):
    """Write the OCR result to export files and reveal them in the download widget.

    Produces ``output.mmd`` and ``output.md`` (the raw text), ``output.txt``
    (whitespace-collapsed text) and ``output.jsonl`` (one JSON record holding
    both forms). Returns a hidden update when there is nothing to export or
    the value is still a loading message.
    """
    if not raw_text or raw_text.startswith("⏳"):
        return gr.update(visible=False)

    # NOTE(review): a fresh temp directory is created on every click and never
    # removed here — consider cleanup if the server runs for a long time.
    export_dir = tempfile.mkdtemp()
    flat_text = get_compressed_text(raw_text)
    jsonl_line = json.dumps({"markdown": raw_text, "text": flat_text}, ensure_ascii=False) + "\n"
    exports = (
        ("output.mmd", raw_text),
        ("output.md", raw_text),
        ("output.txt", flat_text),
        ("output.jsonl", jsonl_line),
    )

    written_paths = []
    for file_name, body in exports:
        target = os.path.join(export_dir, file_name)
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(body)
        written_paths.append(target)
    return gr.update(value=written_paths, visible=True)
with gr.Blocks() as demo:
    gr.HTML("<h1 style='text-align: center; margin-bottom: 0;'>Tifinagh OCR</h1>")
    gr.HTML("<h3 style='text-align: center; margin-top: 0; font-weight: normal;'>Image/PDF to Text</h3>")

    with gr.Row():
        # Left column: the two input modes, each with its own extract button.
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("Image Input"):
                    image_in = gr.Image(type="pil", label="Select Image")
                    img_btn = gr.Button("Extract Text", variant="primary")
                with gr.Tab("PDF Input"):
                    gr.Markdown("⚠️ **Note:** Extraction takes time based on GPU and page count.")
                    pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"])
                    pdf_btn = gr.Button("Extract PDF", variant="primary")

        # Right column: starts hidden; the extract handler toggles its visibility.
        with gr.Column(scale=1, visible=False) as output_section:
            with gr.Tabs():
                with gr.Tab("Result"):
                    display_result = gr.Textbox(label="Extracted Text", lines=20, max_lines=5000)
                with gr.Tab("Markdown View", visible=False) as tab_markdown:
                    display_markdown = gr.Markdown()
            # Full (un-collapsed) OCR text, read by the export step.
            raw_storage = gr.State("")
            gr.Markdown("---")
            with gr.Row():
                with gr.Column():
                    gen_files_btn = gr.Button("📥 Download result", variant="secondary")
                    output_files = gr.File(label="Choose the format", visible=False)

    gr.Markdown("---")
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    ["img_01.webp"], ["img_02.webp"], ["img_03.webp"],
                    ["img_04.webp"], ["img_05.webp"], ["img_06.webp"],
                    ["img_07.webp"], ["img_08.webp"], ["img_09.webp"],
                ],
                inputs=image_in,
            )

    # Both extract buttons share one handler and one output set; a constant
    # gr.State tells the handler which kind of input it received.
    _ocr_outputs = [display_result, display_markdown, raw_storage, output_section, tab_markdown]
    for _trigger, _source, _input_kind in ((img_btn, image_in, "image"), (pdf_btn, pdf_in, "pdf")):
        _trigger.click(
            fn=handle_request_with_loading,
            inputs=[_source, gr.State(_input_kind)],
            outputs=list(_ocr_outputs),
            scroll_to_output=True,
        )

    gen_files_btn.click(fn=build_download_files, inputs=raw_storage, outputs=output_files)
# Enable Gradio's request queue explicitly before launching: OCR requests are
# slow and the extract handler streams intermediate updates. Launch only when
# run as a script, so importing this module (e.g. for tests) does not start
# a server.
if __name__ == "__main__":
    demo.queue().launch(theme=gr.themes.Soft())