# Florence-2 OCR API server (FastAPI + Hugging Face Transformers)
| import io | |
| import torch | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from PIL import Image | |
| from transformers import AutoModelForCausalLM, AutoProcessor | |
# --- 1. SCRIPT SETUP ---
# Set up device (use GPU if available, otherwise CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Running on {DEVICE} ---")

# Define model and processor IDs from Hugging Face Hub
MODEL_ID = "microsoft/Florence-2-large"
# A fine-tuned variant is also available:
# MODEL_ID = "microsoft/Florence-2-large-ft"

# --- 2. LOAD MODEL AND PROCESSOR ---
# trust_remote_code=True is required for Florence-2 (custom model code on the Hub).
# float16 halves memory and speeds up inference on GPU, but many CPU kernels do
# not support half precision -- loading in float16 on CPU fails or misbehaves.
# BUG FIX: choose the dtype from the device instead of hard-coding float16.
TORCH_DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=TORCH_DTYPE,
    ).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("--- Model and processor loaded successfully ---")
except Exception as e:
    # Keep the server process alive so /docs still works; endpoints that need
    # the model will respond with 503 until the problem is fixed.
    print(f"--- Error loading model: {e} ---")
    model = None
    processor = None

# --- 3. FASTAPI APP INITIALIZATION ---
app = FastAPI(
    title="Florence-2 OCR API",
    description="An API for extracting text from images using Microsoft's Florence-2-large model. "
                "Handles both printed and handwritten text.",
    version="1.0.0"
)
# --- 4. HELPER FUNCTION ---
def run_florence2_ocr(image: Image.Image) -> str:
    """
    Run the Florence-2 model to perform OCR on a given image.

    Args:
        image (Image.Image): The input image in PIL format.

    Returns:
        str: The extracted text.

    Raises:
        HTTPException: 503 if the model/processor failed to load at startup.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model is not available. Please check server logs.")

    # Task token that tells Florence-2 to perform plain OCR.
    task_prompt = "<OCR>"

    # Florence-2's vision tower expects 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Tokenize the prompt and preprocess the image, then move to the device.
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(DEVICE)

    # BUG FIX: the original did `inputs.to(torch.float16)`, casting the whole
    # batch and relying on BatchFeature to skip the integer `input_ids` (which
    # must stay integral for the embedding lookup). Cast only the vision input
    # explicitly so the intent is clear and version-independent.
    if model.dtype == torch.float16:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

    # Generate text from the image; beam search without sampling keeps the
    # output deterministic for a given input.
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=2048,  # Increased token limit for dense text
        num_beams=3,
        do_sample=False,  # Use greedy decoding for more deterministic results
    )

    # Keep special tokens: post_process_generation needs the task/end tokens
    # to parse the raw output (typically "<OCR>extracted_text</s>").
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # post_process_generation returns a dict keyed by the task prompt,
    # e.g. {"<OCR>": "extracted text"}.
    parsed_text = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed_text.get('<OCR>', "Error: Could not parse OCR output.")
# --- 5. API ENDPOINTS ---
# BUG FIX: the handler was never registered with the app (missing decorator),
# so GET / would have returned 404.
@app.get("/")
def read_root():
    """Root endpoint: a simple landing message pointing users to /docs."""
    return {"message": "Welcome to the Florence-2 OCR API. Go to /docs for usage."}
# BUG FIX: the handler was never registered with the app (missing decorator).
@app.post("/ocr")
async def extract_text_from_image(file: UploadFile = File(..., description="Image file to process.")):
    """
    Endpoint to perform OCR on an uploaded image.

    Returns:
        dict: {"filename": <original filename>, "extracted_text": <OCR result>}

    Raises:
        HTTPException: 400 for an unreadable image, 503 if the model is
            unavailable, 500 for unexpected inference errors.
    """
    # Read image content from the uploaded file.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not open image.")

    # Run the OCR model.
    try:
        extracted_text = run_florence2_ocr(image)
        return {"filename": file.filename, "extracted_text": extracted_text}
    except HTTPException:
        # BUG FIX: propagate intentional HTTP errors (e.g. the 503 raised when
        # the model is unavailable) instead of masking them as a generic 500.
        raise
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")