import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# --- 1. SCRIPT SETUP ---
# Use GPU if available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 halves memory and speeds up inference on GPU, but is poorly
# supported on CPU, so fall back to float32 there.
DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32
print(f"--- Running on {DEVICE} ---")

# Model and processor ID on the Hugging Face Hub.
MODEL_ID = "microsoft/Florence-2-large"
# A fine-tuned variant is also available:
# MODEL_ID = "microsoft/Florence-2-large-ft"

# --- 2. LOAD MODEL AND PROCESSOR ---
# trust_remote_code=True is required for Florence-2 (custom modeling code).
# On failure we keep the app up and report 503 from the OCR endpoint.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, torch_dtype=DTYPE
    ).to(DEVICE)
    model.eval()  # inference-only service; disable dropout etc.
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("--- Model and processor loaded successfully ---")
except Exception as e:
    print(f"--- Error loading model: {e} ---")
    model = None
    processor = None

# --- 3. FASTAPI APP INITIALIZATION ---
app = FastAPI(
    title="Florence-2 OCR API",
    description="An API for extracting text from images using Microsoft's Florence-2-large model. "
                "Handles both printed and handwritten text.",
    version="1.0.0",
)


# --- 4. HELPER FUNCTION ---
def run_florence2_ocr(image: Image.Image) -> str:
    """Run the Florence-2 model to perform OCR on a given image.

    Args:
        image: The input image in PIL format.

    Returns:
        The extracted text.

    Raises:
        HTTPException: 503 if the model failed to load at startup.
    """
    if not model or not processor:
        raise HTTPException(
            status_code=503,
            detail="Model is not available. Please check server logs.",
        )

    # Florence-2 selects its task via a special prompt token; "<OCR>"
    # requests plain-text extraction.
    task_prompt = "<OCR>"

    # The model expects 3-channel RGB input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and prompt.
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(DEVICE)
    # Cast only the image tensor to the model dtype; input_ids must stay
    # an integer tensor.
    if model.dtype == torch.float16:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=2048,  # generous limit for dense text
            num_beams=3,
            do_sample=False,  # beam search without sampling: deterministic output
        )

    # Keep special tokens: post_process_generation needs them to locate
    # the task output within the decoded string.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # The parsed result is a dict keyed by the task prompt,
    # e.g. {"<OCR>": "extracted text"}.
    parsed_text = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed_text.get(task_prompt, "Error: Could not parse OCR output.")


# --- 5. API ENDPOINTS ---
@app.get("/", summary="Root Endpoint", description="Returns a welcome message.")
def read_root():
    """Return a welcome message pointing at the interactive docs."""
    return {"message": "Welcome to the Florence-2 OCR API. Go to /docs for usage."}


@app.post("/ocr", summary="Extract Text from Image",
          description="Upload an image file to extract text. Supports both computer and handwritten text.")
async def extract_text_from_image(file: UploadFile = File(..., description="Image file to process.")):
    """Endpoint to perform OCR on an uploaded image."""
    # Read image content from the uploaded file.
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not open image.")

    # Run the OCR model.
    try:
        extracted_text = run_florence2_ocr(image)
        return {"filename": file.filename, "extracted_text": extracted_text}
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 503 model-unavailable
        # raised by run_florence2_ocr) instead of masking them as 500.
        raise
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")