| |
| """ |
| Fixed Custom OCR Model based on PaliGemma-3B |
| Handles device placement issues and provides better OCR performance |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| PaliGemmaForConditionalGeneration, |
| PaliGemmaProcessor, |
| AutoTokenizer |
| ) |
| from PIL import Image |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
class FixedPaliGemmaOCR(nn.Module):
    """
    OCR wrapper around a PaliGemma checkpoint with explicit device/dtype handling.

    The constructor loads the model, processor and tokenizer; the public entry
    points are ``generate_ocr_text`` (one image), ``batch_ocr`` (several
    images) and ``get_model_info``.
    """

    # Prompts tried in order by ``_extract_with_fallback``.
    _FALLBACK_PROMPTS = (
        "<image>What text is visible in this image?",
        "<image>Read all the text in this image.",
        "<image>OCR this image.",
        "<image>Transcribe the text.",
        "<image>",
    )

    # Conversational lead-ins stripped from the start of the model output.
    _ARTIFACT_PREFIXES = (
        "The image shows",
        "The text in the image says",
        "The image contains the text",
        "I can see the text",
        "The text reads",
    )

    def __init__(self, model_name="google/paligemma-3b-pt-224"):
        """
        Load the model, processor and tokenizer for ``model_name``.

        Uses CUDA with float16 when ``torch.cuda.is_available()``, otherwise
        CPU with float32.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained``.

        Raises:
            Exception: Whatever the loading calls raise; it is printed and
                re-raised unchanged.
        """
        super().__init__()

        print("🚀 Initializing Fixed PaliGemma OCR Model...")
        print(f"📦 Base model: {model_name}")

        if torch.cuda.is_available():
            self.device = "cuda"
            self.torch_dtype = torch.float16
            print("🔧 Using CUDA with float16")
        else:
            self.device = "cpu"
            self.torch_dtype = torch.float32
            print("🔧 Using CPU with float32")

        try:
            print("📥 Loading PaliGemma model...")
            self.base_model = PaliGemmaForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
            )

            print("📥 Loading processor...")
            self.processor = PaliGemmaProcessor.from_pretrained(model_name)

            print("📥 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)

            self.base_model = self.base_model.to(self.device)

            print("✅ All components loaded successfully")

        except Exception as e:
            print(f"❌ Failed to load PaliGemma model: {e}")
            raise

        self.hidden_size = self.base_model.config.text_config.hidden_size
        self.vocab_size = self.base_model.config.text_config.vocab_size

        print("🔧 Model ready:")
        print(f" - Device: {self.device}")
        print(f" - Hidden size: {self.hidden_size}")
        print(f" - Vocab size: {self.vocab_size}")
        print(" - Parameters: ~3B")

    def generate_ocr_text(self, image, prompt="<image>Extract all text from this image:", max_length=512):
        """
        Extract text from one image, falling back to alternative prompts on failure.

        Args:
            image: PIL Image, or a filesystem path (``str`` or ``os.PathLike``).
            prompt: Text prompt for the OCR task; ``<image>`` is prepended by
                the standard path when missing.
            max_length: Maximum number of newly generated tokens.

        Returns:
            dict: ``text``, ``confidence``, ``quality`` and ``method``; also
            ``raw_output`` on success, or ``error`` when both paths failed.

        Raises:
            ValueError: If ``image`` is neither a PIL Image nor a path.
        """
        image = self._load_image(image)

        try:
            result = self._extract_with_paligemma(image, prompt, max_length)
            result['method'] = 'paligemma_standard'
            return result

        except Exception as e:
            print(f"⚠️ Standard method failed: {e}")

            try:
                result = self._extract_with_fallback(image, max_length)
                result['method'] = 'paligemma_fallback'
                return result

            except Exception as e2:
                print(f"⚠️ Fallback method failed: {e2}")

                return {
                    'text': "Error: Could not extract text from image",
                    'confidence': 0.0,
                    'quality': 'error',
                    'method': 'error',
                    'error': str(e2)
                }

    def _load_image(self, image):
        """Return ``image`` as an RGB PIL Image, opening it first if it is a path."""
        import os

        if isinstance(image, (str, os.PathLike)):
            # The context manager closes the file handle; convert() returns a
            # fully loaded copy that stays valid afterwards.
            with Image.open(image) as opened:
                return opened.convert('RGB')

        if isinstance(image, Image.Image):
            # Path inputs are always converted to RGB, so do the same here.
            return image if image.mode == 'RGB' else image.convert('RGB')

        raise ValueError("Image must be PIL Image or path string")

    def _prepare_inputs(self, inputs):
        """
        Move every tensor in ``inputs`` to the model device, in place.

        Floating-point tensors are also cast to ``self.torch_dtype`` so they
        match the dtype the model was loaded with; integer tensors keep their
        dtype. Non-tensor entries are left untouched.
        """
        for key in inputs:
            value = inputs[key]
            if not isinstance(value, torch.Tensor):
                continue
            if value.is_floating_point():
                inputs[key] = value.to(self.device, dtype=self.torch_dtype)
            else:
                inputs[key] = value.to(self.device)
        return inputs

    def _generate(self, image, prompt, max_length, **generate_kwargs):
        """
        Run one generation pass and decode it.

        Returns:
            tuple: ``(answer_text, raw_text)``. ``raw_text`` is the decoded
            full sequence; ``answer_text`` is decoded from only the tokens
            after the prompt.
        """
        inputs = self._prepare_inputs(
            self.processor(text=prompt, images=image, return_tensors="pt")
        )
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            generated_ids = self.base_model.generate(
                **inputs,
                # max_new_tokens limits the answer only; max_length would
                # also count the prompt and image tokens.
                max_new_tokens=max_length,
                pad_token_id=self.tokenizer.eos_token_id,
                **generate_kwargs,
            )

        raw_text = self.processor.batch_decode(
            generated_ids,
            skip_special_tokens=True
        )[0]
        # Slice off the prompt by token position rather than by string
        # matching, so prompt wording inside the recognised text survives.
        answer_text = self.processor.decode(
            generated_ids[0][prompt_length:],
            skip_special_tokens=True
        )
        return answer_text, raw_text

    def _extract_with_paligemma(self, image, prompt, max_length):
        """
        Extract text with the caller's prompt using greedy decoding.

        Returns a dict with ``text``, ``confidence``, ``quality`` and
        ``raw_output``. Any exception is printed and re-raised.
        """
        try:
            if "<image>" not in prompt:
                prompt = f"<image>{prompt}"

            answer_text, raw_text = self._generate(
                image,
                prompt,
                max_length,
                do_sample=False,
                num_beams=1,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            extracted_text = self._clean_generated_text(answer_text, prompt)

            return {
                'text': extracted_text,
                'confidence': self._estimate_confidence(extracted_text),
                'quality': self._assess_quality(extracted_text),
                'raw_output': raw_text
            }

        except Exception as e:
            print(f"❌ PaliGemma extraction failed: {e}")
            raise

    def _extract_with_fallback(self, image, max_length):
        """
        Try each fallback prompt with low-temperature sampling.

        Returns the first non-empty extraction with fixed confidence 0.7 and
        quality ``'good'``. A prompt that raises is reported and skipped; if
        no prompt yields text, an empty ``'poor'`` result is returned.
        """
        for prompt in self._FALLBACK_PROMPTS:
            try:
                answer_text, raw_text = self._generate(
                    image,
                    prompt,
                    max_length,
                    do_sample=True,
                    temperature=0.1,
                    top_p=0.9,
                    num_beams=1,
                )

                extracted_text = self._clean_generated_text(answer_text, prompt)

                if extracted_text.strip():
                    return {
                        'text': extracted_text,
                        'confidence': 0.7,
                        'quality': 'good',
                        'raw_output': raw_text
                    }

            except Exception as e:
                print(f"⚠️ Fallback prompt '{prompt}' failed: {e}")
                continue

        return {
            'text': "",
            'confidence': 0.0,
            'quality': 'poor',
            'raw_output': ""
        }

    def _clean_generated_text(self, generated_text, prompt):
        """
        Strip an echoed prompt and known lead-in phrases from model output.

        Only a prompt echo at the start is removed; the same wording later in
        the text is kept. After a lead-in phrase, a following colon and a
        surrounding pair of double quotes are removed as well.
        """
        extracted_text = generated_text.strip()

        clean_prompt = prompt.replace("<image>", "").strip()
        if clean_prompt and extracted_text.startswith(clean_prompt):
            extracted_text = extracted_text[len(clean_prompt):].strip()

        for artifact in self._ARTIFACT_PREFIXES:
            if extracted_text.lower().startswith(artifact.lower()):
                extracted_text = extracted_text[len(artifact):].strip()
                if extracted_text.startswith(":"):
                    extracted_text = extracted_text[1:].strip()
                if extracted_text.startswith('"') and extracted_text.endswith('"'):
                    extracted_text = extracted_text[1:-1].strip()

        return extracted_text

    def _estimate_confidence(self, text):
        """
        Return a heuristic confidence in ``[0.0, 0.95]`` from text shape alone.

        Starts at 0.5, adds bonuses for length (>10, >50 characters) and for
        containing letters or digits, halves the score when fewer than three
        non-blank characters remain, and caps the result at 0.95. Empty or
        blank text scores 0.0.
        """
        if not text or not text.strip():
            return 0.0

        confidence = 0.5

        if len(text) > 10:
            confidence += 0.2
        if len(text) > 50:
            confidence += 0.1

        if any(c.isalpha() for c in text):
            confidence += 0.1
        if any(c.isdigit() for c in text):
            confidence += 0.05

        if len(text.strip()) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def _assess_quality(self, text):
        """
        Bucket text by stripped length.

        Returns ``'poor'`` (<5 characters or empty), ``'fair'`` (<20),
        ``'good'`` (<100) or ``'excellent'`` (100 or more).
        """
        length = len(text.strip()) if text else 0

        if length < 5:
            return 'poor'
        if length < 20:
            return 'fair'
        if length < 100:
            return 'good'
        return 'excellent'

    def batch_ocr(self, images, prompt="<image>Extract all text from this image:", max_length=512):
        """
        Run ``generate_ocr_text`` on each image in turn.

        Args:
            images: Sized sequence of PIL Images or image paths.
            prompt: Prompt forwarded to ``generate_ocr_text``.
            max_length: Forwarded to ``generate_ocr_text``.

        Returns:
            list: One result dict per input, in order. An image that raises is
            reported and represented by an ``'error'`` result instead of
            aborting the batch.
        """
        results = []

        for i, image in enumerate(images):
            print(f"🔄 Processing image {i+1}/{len(images)}...")

            try:
                result = self.generate_ocr_text(image, prompt, max_length)
                results.append(result)

                print(f" ✅ Success: {len(result['text'])} characters extracted")

            except Exception as e:
                print(f" ❌ Error: {e}")
                results.append({
                    'text': f"Error processing image {i+1}",
                    'confidence': 0.0,
                    'quality': 'error',
                    'method': 'error',
                    'error': str(e)
                })

        return results

    def get_model_info(self):
        """Return a dict describing the model; only ``device``, ``dtype``,
        ``hidden_size`` and ``vocab_size`` come from the loaded instance, the
        remaining entries are fixed descriptive strings."""
        return {
            'base_model': 'PaliGemma-3B',
            'device': self.device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'optimized_for': 'OCR and Document Understanding',
            'supported_languages': '100+',
            'features': [
                'Multi-language OCR',
                'Document understanding',
                'Robust error handling',
                'Batch processing',
                'Confidence estimation'
            ]
        }
|
|
|
|
def _create_test_invoice_image():
    """Render the synthetic invoice image used as the OCR smoke-test input."""
    from PIL import Image, ImageDraw, ImageFont

    img = Image.new('RGB', (500, 300), color='white')
    draw = ImageDraw.Draw(img)

    try:
        font = ImageFont.truetype("/System/Library/Fonts/Arial.ttf", 20)
        title_font = ImageFont.truetype("/System/Library/Fonts/Arial.ttf", 28)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType unavailable: fall back to
        # Pillow's built-in font for both sizes.
        font = ImageFont.load_default()
        title_font = font

    rows = (
        ((20, 30), "INVOICE #12345", 'black', title_font),
        ((20, 80), "Date: January 15, 2024", 'black', font),
        ((20, 110), "Customer: John Smith", 'blue', font),
        ((20, 140), "Amount: $1,234.56", 'red', font),
        ((20, 170), "Description: Professional Services", 'black', font),
        ((20, 200), "Tax (10%): $123.46", 'black', font),
        ((20, 230), "Total: $1,358.02", 'black', title_font),
    )
    for position, text, colour, face in rows:
        draw.text(position, text, fill=colour, font=face)

    return img


def main():
    """
    Smoke-test the OCR model on a generated invoice image.

    Builds the model, prints its info, writes ``test_paligemma_ocr.png`` to
    the current directory, runs OCR on it and prints the result.

    Returns:
        The constructed model, or ``None`` if any step raised (the traceback
        is printed).
    """
    print("🚀 Testing Fixed PaliGemma OCR Model")
    print("=" * 50)

    try:
        model = FixedPaliGemmaOCR()

        info = model.get_model_info()
        print("\n📊 Model Information:")
        for key, value in info.items():
            if isinstance(value, list):
                print(f" {key}:")
                for item in value:
                    print(f" - {item}")
            else:
                print(f" {key}: {value}")

        print("\n🧪 Creating test image...")
        img = _create_test_invoice_image()
        img.save("test_paligemma_ocr.png")
        print("✅ Test image created: test_paligemma_ocr.png")

        print("\n🔍 Testing OCR extraction...")
        result = model.generate_ocr_text(img)

        print("\n📝 OCR Results:")
        print(f" Text: {result['text']}")
        print(f" Confidence: {result['confidence']:.3f}")
        print(f" Quality: {result['quality']}")
        print(f" Method: {result['method']}")

        # The error result carries a non-empty message in 'text', so check the
        # method as well before declaring success.
        if result['method'] != 'error' and len(result['text']) > 0:
            print("\n✅ PaliGemma OCR Model is working perfectly!")
        else:
            print("\n⚠️ OCR extracted no text - may need adjustment")

        return model

    except Exception as e:
        print(f"❌ Error testing model: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
# Run the smoke test only when executed as a script; main()'s return value
# stays bound to the module-level name `model`.
if __name__ == "__main__":
    model = main()