| """ |
| Author : Janarddan Sarkar |
| file_name : mistral_ocr_st.py |
| date : 10-03-2025 |
| description : Streamlit app that runs Mistral OCR on an uploaded PDF or image and displays the result. |
| """ |
| import os |
| import json |
| import base64 |
| import streamlit as st |
| from mistralai import Mistral |
| from dotenv import find_dotenv, load_dotenv |
| from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk |
| from mistralai.models import OCRResponse |
| from enum import Enum |
| from pydantic import BaseModel |
| import pycountry |
|
|
| |
# Load variables from the nearest .env file, then build the shared Mistral
# client from the MISTRAL_API_KEY environment variable.
load_dotenv(find_dotenv())
api_key = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=api_key)
|
|
| |
# Map each alpha-2 code to its language name, for every pycountry language
# entry that exposes an ``alpha_2`` attribute.
languages = {
    entry.alpha_2: entry.name
    for entry in pycountry.languages
    if hasattr(entry, "alpha_2")
}
|
|
|
|
class LanguageMeta(type(Enum)):
    """Enum metaclass that fills the class body with one member per language.

    For every entry of the module-level ``languages`` mapping, a member is
    added whose name is the upper-cased language name with spaces replaced by
    underscores, and whose value is the language name itself.
    """

    def __new__(metacls, cls, bases, classdict, **kwds):
        # Only the language names are needed; the mapping keys are not used.
        for name in languages.values():
            classdict[name.upper().replace(' ', '_')] = name
        # Forward class keyword arguments instead of rejecting them, so they
        # reach the underlying Enum metaclass unchanged.
        return super().__new__(metacls, cls, bases, classdict, **kwds)
|
|
|
|
# Enumeration of language names. The body is intentionally empty: all members
# are injected by LanguageMeta when the class is created.
# NOTE: kept docstring-free on purpose so the class's __doc__ is unchanged.
class Language(Enum, metaclass=LanguageMeta):
    ...
|
|
|
|
# Response schema passed as ``response_format`` to the chat model in
# process_image. Documented with comments rather than a docstring so the
# model's __doc__ (and anything derived from it) stays unchanged.
class StructuredOCR(BaseModel):
    # Name of the processed file.
    file_name: str
    # Topics reported for the document.
    topics: list[str]
    # Languages reported for the document, restricted to Language members.
    languages: list[Language]
    # Free-form dictionary holding the structured OCR contents.
    ocr_contents: dict
|
|
def replace_images_in_markdown(markdown_str: str, images_dict: dict) -> str:
    """Inline image data into the image references of a markdown string.

    Every reference of the form ``![<id>](<id>)`` whose ``<id>`` is a key of
    ``images_dict`` is rewritten so that its target is the mapped base64
    string, i.e. ``![<id>](<base64>)``.

    Args:
        markdown_str: Markdown text containing image references.
        images_dict: Mapping of image id to its base64 string
            (presumably a ``data:`` URI -- confirm against the OCR API).

    Returns:
        The markdown with all matching references rewritten. References whose
        image has no data (``None`` or empty) are left untouched.
    """
    for img_name, base64_str in images_dict.items():
        if not base64_str:
            # No payload for this image: keep the original reference rather
            # than emitting a broken "![id](None)" link.
            continue
        markdown_str = markdown_str.replace(
            f"![{img_name}]({img_name})", f"![{img_name}]({base64_str})"
        )
    return markdown_str
|
|
def get_combined_markdown(ocr_response: OCRResponse) -> str:
    """Merge the markdown of all OCR pages into a single string.

    For each page, a mapping of image id to base64 data is built from the
    page's images and passed to ``replace_images_in_markdown`` together with
    the page markdown. The per-page results are joined with a blank line.
    """
    return "\n\n".join(
        replace_images_in_markdown(
            page.markdown,
            {image.id: image.image_base64 for image in page.images},
        )
        for page in ocr_response.pages
    )
|
|
def process_pdf(pdf_bytes, file_name):
    """Run OCR on a PDF and return the OCR response.

    The PDF is uploaded through the client's files API with purpose "ocr",
    a signed URL is requested for the uploaded file, and that URL is passed
    to the OCR endpoint with base64 image output enabled.

    Args:
        pdf_bytes: Raw content of the PDF.
        file_name: Name under which the file is uploaded.

    Returns:
        The OCR response; a plain ``dict`` result is converted to
        ``OCRResponse`` first.
    """
    upload = client.files.upload(
        purpose="ocr",
        file={"file_name": file_name, "content": pdf_bytes},
    )
    url_info = client.files.get_signed_url(file_id=upload.id, expiry=1)
    document = DocumentURLChunk(document_url=url_info.url)
    ocr_result = client.ocr.process(
        model="mistral-ocr-latest",
        document=document,
        include_image_base64=True,
    )

    # Normalise to an OCRResponse in case a plain dict came back.
    return OCRResponse(**ocr_result) if isinstance(ocr_result, dict) else ocr_result
|
|
|
|
def process_image(image_bytes, file_name):
    """Run OCR on an image and return the result as a structured dict.

    The image is embedded in a base64 ``data:`` URL and sent to the OCR
    endpoint. The markdown of the first OCR page is then passed, together
    with the image itself, to a chat model that is asked to produce a
    response matching ``StructuredOCR``.

    Args:
        image_bytes: Raw content of the image.
        file_name: Original file name; its extension selects the MIME type
            declared in the data URL.

    Returns:
        The parsed ``StructuredOCR`` response as a JSON-compatible dict.
    """
    import mimetypes  # local import keeps this block self-contained

    # Declare the real image type (e.g. image/png) instead of always claiming
    # JPEG. Unknown or non-image extensions fall back to the former
    # hard-coded "image/jpeg".
    mime_type, _ = mimetypes.guess_type(file_name or "")
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    encoded_image = base64.b64encode(image_bytes).decode()
    base64_data_url = f"data:{mime_type};base64,{encoded_image}"
    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=base64_data_url), model="mistral-ocr-latest"
    )
    image_ocr_markdown = image_response.pages[0].markdown

    chat_response = client.chat.parse(
        model="pixtral-12b-latest",
        messages=[
            {
                "role": "user",
                "content": [
                    ImageURLChunk(image_url=base64_data_url),
                    TextChunk(
                        text=(
                            "This is the image's OCR in markdown:\n"
                            f"<BEGIN_IMAGE_OCR>\n{image_ocr_markdown}\n<END_IMAGE_OCR>.\n"
                            "Convert this into a structured JSON response with the OCR contents in a dictionary."
                        )
                    ),
                ],
            },
        ],
        response_format=StructuredOCR,
        temperature=0,
    )
    # Round-trip through JSON so enum members and nested models become plain
    # JSON-compatible values.
    return json.loads(chat_response.choices[0].message.parsed.model_dump_json())
|
|
|
|
| |
# ---- Streamlit UI: upload a file, run OCR on submit, render the result ----
st.title("OLMOCR")

uploaded_file = st.file_uploader("Upload a PDF or Image", type=["pdf", "png", "jpg", "jpeg"])

if uploaded_file:
    file_type = uploaded_file.type
    # getvalue() returns the whole buffer regardless of the current stream
    # position, so a script rerun (e.g. after clicking Submit) can never see
    # an already-consumed, empty file.
    file_bytes = uploaded_file.getvalue()
    file_name = uploaded_file.name

    if st.button("Submit"):
        st.write(f"**Processing file:** {file_name}")

        if "pdf" in file_type:
            # PDFs: show the combined markdown of all OCR pages.
            pdf_response = process_pdf(file_bytes, file_name)
            st.markdown(get_combined_markdown(pdf_response))
        else:
            # Images: show the structured JSON produced from the OCR output.
            result = process_image(file_bytes, file_name)
            st.json(result)
|
|