OmarAbualrob committed on
Commit
523b7a6
·
verified ·
1 Parent(s): 499d6a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -77
app.py CHANGED
@@ -1,110 +1,109 @@
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import JSONResponse
3
- from transformers import AutoProcessor, AutoModelForCausalLM
4
  from PIL import Image
5
- import torch
6
- import io
7
- import logging
8
 
9
# Configure module-level logging once for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# --- 1. Initialize FastAPI App ---
# Single application instance shared by every endpoint below.
app = FastAPI(
    title="Mixed-Content OCR API",
    description="An API to extract text from images containing both printed and handwritten text.",
)
 
 
15
 
16
# --- 2. Load the Model and Processor (at startup) ---
# Loading happens exactly once at import time; reloading the model per
# request would make every API call unbearably slow.
try:
    logger.info("Loading model and processor...")
    # The large checkpoint gives the best accuracy; Florence-2 requires
    # trust_remote_code to pull in its custom modeling code.
    checkpoint = "microsoft/Florence-2-large"
    model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(checkpoint, trust_remote_code=True)
    logger.info("Model and processor loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {e}")
    # Without a loaded model the API cannot serve OCR requests; the
    # endpoints check these for None and answer 503.
    model = None
    processor = None
33
# --- 3. Define the OCR Task Function ---
def run_ocr(image: Image.Image) -> str:
    """
    Performs OCR on a given PIL Image using the Florence-2 model.

    Args:
        image: Input image; converted to RGB if it is in another mode.

    Returns:
        The extracted text, or an error string if the model output could
        not be parsed.

    Raises:
        RuntimeError: If the model/processor failed to load at startup.
    """
    # Bug fix: 'os' is used below but was never imported at module level,
    # which made every call raise NameError. Import it locally here.
    import os

    if model is None or processor is None:
        raise RuntimeError("Model is not available. Check logs for loading errors.")

    # Ensure image is in RGB format
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Define the task prompt
    prompt = "<OCR>"

    # Preprocess the image and prompt
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # NOTE(review): environment variables are usually read at model-load
    # time; setting this after the model is already loaded may have no
    # effect — confirm and consider moving it before from_pretrained.
    os.environ["DISABLE_FLASH_ATTN"] = "1"

    # Generate text from the image
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=4096,  # Increased token limit for long documents
        do_sample=False,      # Use greedy decoding for deterministic output
        num_beams=3
    )

    # Decode the generated IDs to a string
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Post-process the output to get the clean text.
    # The model's output for OCR is typically in the format: <OCR>extracted_text</s>
    parsed_text = processor.post_process_generation(generated_text, task="<OCR>", image_size=(image.width, image.height))

    return parsed_text.get("<OCR>", "Error: Could not parse OCR output.")
70
 
71
 
72
# --- 4. Create the API Endpoint ---
@app.post("/ocr", summary="Extract Text from Image")
async def perform_ocr(file: UploadFile = File(..., description="Image file to perform OCR on.")):
    """
    Takes an image file, extracts both printed and handwritten text,
    and returns it as a JSON object.

    Responds 503 when the model is unavailable, 400 for non-image
    uploads, and 500 for unexpected processing errors.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or unavailable.")

    # Validate file type. content_type can be None for some clients, so
    # guard before calling startswith (previously an AttributeError).
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an image.")

    try:
        # Read the image content from the uploaded file
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))

        # Run the OCR task
        logger.info("Running OCR on the uploaded image...")
        extracted_text = run_ocr(image)
        logger.info("OCR completed successfully.")

        # Return the result
        return JSONResponse(
            content={"filename": file.filename, "text": extracted_text}
        )

    except Exception as e:
        logger.error(f"An error occurred during OCR processing: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {str(e)}")
104
 
105
- @app.get("/", summary="Health Check")
106
- def read_root():
107
- """
108
- A simple health check endpoint to confirm the API is running.
109
- """
110
- return {"status": "ok", "model_loaded": model is not None}
 
 
1
+ import io
2
+ import torch
3
  from fastapi import FastAPI, File, UploadFile, HTTPException
 
 
4
  from PIL import Image
5
+ from transformers import AutoModelForCausalLM, AutoProcessor
 
 
6
 
7
# --- 1. SCRIPT SETUP ---
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"--- Running on {DEVICE} ---")

# Hugging Face Hub identifier of the checkpoint to serve.
MODEL_ID = "microsoft/Florence-2-large"
# Alternative fine-tuned checkpoint, if your hardware supports it:
# MODEL_ID = "microsoft/Florence-2-large-ft"
16
 
17
# --- 2. LOAD MODEL AND PROCESSOR ---
# trust_remote_code=True is required for Florence-2.
# float16 halves memory and speeds up inference on GPU, but several CPU
# kernels (e.g. LayerNorm) are not implemented for half precision, so the
# previous unconditional float16 load could crash at inference on CPU.
# Pick the dtype from the device instead.
_MODEL_DTYPE = torch.float16 if DEVICE.type == "cuda" else torch.float32

try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=_MODEL_DTYPE).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("--- Model and processor loaded successfully ---")
except Exception as e:
    print(f"--- Error loading model: {e} ---")
    # Leave both as None so the endpoints can answer 503 instead of crashing.
    model = None
    processor = None
29
 
30
# --- 3. FASTAPI APP INITIALIZATION ---
_DESCRIPTION = (
    "An API for extracting text from images using Microsoft's Florence-2-large model. "
    "Handles both printed and handwritten text."
)
app = FastAPI(title="Florence-2 OCR API", description=_DESCRIPTION, version="1.0.0")
37
+
38
# --- 4. HELPER FUNCTION ---
def run_florence2_ocr(image: Image.Image) -> str:
    """
    Runs the Florence-2 model to perform OCR on a given image.

    Args:
        image (Image.Image): The input image in PIL format.

    Returns:
        str: The extracted text, or an error string if parsing failed.

    Raises:
        HTTPException: 503 when the model failed to load at startup.
    """
    # Explicit None checks: relying on object truthiness for a model or
    # processor is fragile (container-like objects may define __bool__/__len__).
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="Model is not available. Please check server logs.")

    # The task prompt for OCR
    task_prompt = "<OCR>"

    # Ensure image is in RGB format
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess the image and prompt
    inputs = processor(text=task_prompt, images=image, return_tensors="pt").to(DEVICE)

    # Cast ONLY the pixel values to half precision. Casting the whole batch
    # (inputs.to(torch.float16)) relies on the container skipping integer
    # tensors; an indiscriminate cast would corrupt input_ids.
    if model.dtype == torch.float16:
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

    # Generate text from the image
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=2048,  # Increased token limit for dense text
        num_beams=3,
        do_sample=False       # Use greedy decoding for more deterministic results
    )

    # Decode the generated IDs to text
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

    # Parse the output to get only the OCR result; the raw output format is
    # typically "<OCR>extracted_text</s>".
    parsed_text = processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))

    return parsed_text.get('<OCR>', "Error: Could not parse OCR output.")
 
 
 
 
84
 
85
 
86
# --- 5. API ENDPOINTS ---
@app.get("/", summary="Root Endpoint", description="Returns a welcome message.")
def read_root():
    """Landing endpoint that points clients at the interactive docs."""
    welcome = "Welcome to the Florence-2 OCR API. Go to /docs for usage."
    return {"message": welcome}
90
+
91
@app.post("/ocr", summary="Extract Text from Image", description="Upload an image file to extract text. Supports both computer and handwritten text.")
async def extract_text_from_image(file: UploadFile = File(..., description="Image file to process.")):
    """
    Endpoint to perform OCR on an uploaded image.

    Returns a JSON object with the original filename and the extracted
    text. Responds 400 for unreadable images, 503 when the model is
    unavailable, and 500 for unexpected inference errors.
    """
    # Read image content from the uploaded file
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file. Could not open image.")

    # Run the OCR model
    try:
        extracted_text = run_florence2_ocr(image)
        return {"filename": file.filename, "extracted_text": extracted_text}
    except HTTPException:
        # Bug fix: the generic handler below used to swallow the 503
        # HTTPException raised by run_florence2_ocr and re-wrap it as a
        # misleading 500. Propagate HTTP errors unchanged instead.
        raise
    except Exception as e:
        print(f"Error during model inference: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred during processing: {str(e)}")