"""
FIXED PixelText OCR Model with proper Hugging Face Hub support.

This version defines a PretrainedConfig subclass and registers it on the
model, so the model loads via from_pretrained() / AutoModel.from_pretrained().
"""
|
|
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| PaliGemmaForConditionalGeneration, |
| PaliGemmaProcessor, |
| AutoTokenizer, |
| PreTrainedModel, |
| PretrainedConfig |
| ) |
| from PIL import Image |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
class PixelTextConfig(PretrainedConfig):
    """Configuration for the PixelText OCR model.

    Attributes:
        base_model: Hub id of the PaliGemma checkpoint to wrap.
        hidden_size: Hidden dimension advertised by the wrapped model.
        vocab_size: Vocabulary size advertised by the wrapped model.
    """

    model_type = "pixeltext"

    def __init__(self, base_model="google/paligemma-3b-pt-224",
                 hidden_size=2048, vocab_size=257216, **kwargs):
        super().__init__(**kwargs)
        # Persist the OCR-specific settings on the config instance so they
        # round-trip through save_pretrained()/from_pretrained().
        self.base_model = base_model
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
|
|
class FixedPixelTextOCR(PreTrainedModel):
    """PixelText OCR: a thin PreTrainedModel wrapper around PaliGemma.

    Registering ``PixelTextConfig`` as ``config_class`` is what makes the
    model loadable through ``AutoModel.from_pretrained()``.
    """

    config_class = PixelTextConfig

    def __init__(self, config=None):
        """Load the underlying PaliGemma model, processor and tokenizer.

        Args:
            config: Optional ``PixelTextConfig``; a default instance is
                created when omitted.

        Raises:
            Exception: re-raised unchanged when any Hub component fails
                to download/load.
        """
        if config is None:
            config = PixelTextConfig()

        super().__init__(config)

        print("🚀 Loading FIXED PixelText OCR...")

        # fp16 on GPU for speed/memory, fp32 on CPU for numerical safety.
        # Stored as _device (not `device`) to avoid shadowing the read-only
        # `device` property of PreTrainedModel.
        if torch.cuda.is_available():
            self._device = "cuda"
            self.torch_dtype = torch.float16
        else:
            self._device = "cpu"
            self.torch_dtype = torch.float32

        print(f"🔧 Device: {self._device}")

        try:
            # base_model here is the actual nn.Module; the *name* of the
            # checkpoint lives on config.base_model.
            self.base_model = PaliGemmaForConditionalGeneration.from_pretrained(
                config.base_model,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True
            ).to(self._device)

            self.processor = PaliGemmaProcessor.from_pretrained(config.base_model)
            self.tokenizer = AutoTokenizer.from_pretrained(config.base_model)

            print("✅ FIXED PixelText OCR ready!")

        except Exception as e:
            # Surface the failure to the caller; the print keeps the
            # reason visible in interactive sessions.
            print(f"❌ Failed to load components: {e}")
            raise

        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size

    def forward(self, **kwargs):
        """Delegate the forward pass to the wrapped PaliGemma model."""
        return self.base_model(**kwargs)

    def generate_ocr_text(self, image, prompt="<image>Extract all text from this image:", max_length=512):
        """Extract text from a single image.

        Args:
            image: PIL Image, file path, or numpy array.
            prompt: Prompt for the model; an ``<image>`` token is
                prepended automatically when missing.
            max_length: Maximum total generated sequence length.

        Returns:
            dict with keys ``text``, ``confidence``, ``success``,
            ``method`` and either ``raw_output`` (on success) or
            ``error`` (on failure). Never raises for inference errors;
            only invalid image types raise ``ValueError``.
        """
        # Normalize every accepted input type to an RGB PIL image.
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif hasattr(image, 'shape'):
            image = Image.fromarray(image).convert('RGB')
        elif isinstance(image, Image.Image):
            # FIX: previously already-loaded PIL images were passed through
            # unconverted, so grayscale/RGBA/palette images reached the
            # processor in the wrong mode.
            image = image.convert('RGB')
        else:
            raise ValueError("Image must be PIL Image, file path, or numpy array")

        # PaliGemma expects the <image> placeholder token in the prompt.
        if "<image>" not in prompt:
            prompt = f"<image>{prompt}"

        try:
            inputs = self.processor(text=prompt, images=image, return_tensors="pt")

            # Move every tensor to the model's device.
            for key in inputs:
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].to(self._device)

            # Deterministic greedy decoding (no sampling, single beam).
            with torch.no_grad():
                generated_ids = self.base_model.generate(
                    **inputs,
                    max_length=max_length,
                    do_sample=False,
                    num_beams=1,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            generated_text = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0]

            text = self._clean_text(generated_text, prompt)
            confidence = self._calculate_confidence(text)

            return {
                'text': text,
                'confidence': confidence,
                'success': True,
                'method': 'fixed_pixeltext',
                'raw_output': generated_text
            }

        except Exception as e:
            # Best-effort API: report the failure instead of raising so
            # batch_ocr can continue with remaining images.
            return {
                'text': "",
                'confidence': 0.0,
                'success': False,
                'method': 'error',
                'error': str(e)
            }

    def _clean_text(self, generated_text, prompt):
        """Strip the echoed prompt and common narration artifacts."""
        # The decoder often echoes the prompt (minus the <image> token).
        clean_prompt = prompt.replace("<image>", "").strip()
        if clean_prompt and clean_prompt in generated_text:
            text = generated_text.replace(clean_prompt, "").strip()
        else:
            text = generated_text.strip()

        # Leading phrases the model sometimes prepends to the actual OCR.
        artifacts = [
            "The image shows", "The text in the image says",
            "The image contains", "I can see", "The text reads",
            "This image shows", "The picture shows"
        ]

        for artifact in artifacts:
            if text.lower().startswith(artifact.lower()):
                text = text[len(artifact):].strip()
                if text.startswith(":"):
                    text = text[1:].strip()

        # Drop one pair of wrapping quotes, if present.
        if text.startswith('"') and text.endswith('"'):
            text = text[1:-1].strip()

        return text

    def _calculate_confidence(self, text):
        """Heuristic confidence in [0, 0.95] based on text length/content."""
        if not text:
            return 0.0

        confidence = 0.5

        # Longer extractions are more likely to be real OCR output.
        if len(text) > 10:
            confidence += 0.2
        if len(text) > 50:
            confidence += 0.1
        if len(text) > 100:
            confidence += 0.1

        if any(c.isalpha() for c in text):
            confidence += 0.1
        if any(c.isdigit() for c in text):
            confidence += 0.05

        # Very short results are frequently noise.
        if len(text.strip()) < 3:
            confidence *= 0.5

        return min(0.95, confidence)

    def batch_ocr(self, images, prompt="<image>Extract all text from this image:", max_length=512):
        """Run generate_ocr_text over a sequence of images; returns a list of result dicts."""
        results = []

        for i, image in enumerate(images):
            print(f"📄 Processing image {i+1}/{len(images)}...")
            result = self.generate_ocr_text(image, prompt, max_length)
            results.append(result)

            if result['success']:
                print(f"  ✅ Success: {len(result['text'])} characters")
            else:
                print(f"  ❌ Failed: {result.get('error', 'Unknown error')}")

        return results

    def get_model_info(self):
        """Return a static metadata dict describing this model."""
        return {
            'model_name': 'FIXED PixelText OCR',
            'base_model': 'PaliGemma-3B',
            'device': self._device,
            'dtype': str(self.torch_dtype),
            'hidden_size': self.hidden_size,
            'vocab_size': self.vocab_size,
            'parameters': '~3B',
            'repository': 'BabaK07/pixeltext-ai',
            'status': 'FIXED - Hub loading works!',
            'features': [
                'Hub loading support',
                'from_pretrained method',
                'Fast OCR extraction',
                'Multi-language support',
                'Batch processing',
                'Production ready'
            ]
        }
|
|
# Backward-compatibility alias: earlier code imported the class under this
# Qwen-based name; keep it importable for existing callers.
WorkingQwenOCRModel = FixedPixelTextOCR
|
|