"""Inference convenience wrapper around the Oculus unified vision-language model."""

import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Union

import requests
import torch
from PIL import Image

# Make the sibling oculus modules importable regardless of the caller's CWD.
OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, str(OCULUS_ROOT))

try:
    from oculus_unified_model import OculusForConditionalGeneration
except ImportError:
    # Fallback when this file is imported from outside the package directory.
    from Oculus.oculus_unified_model import OculusForConditionalGeneration
|
class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.
    Supports Object Detection, VQA, and Captioning.
    """

    def __init__(self, model_path: Union[str, None] = None, device: str = "cpu"):
        """Load the backbone, optional detection heads, and optional VQA model.

        Args:
            model_path: Directory of a saved checkpoint. When ``None``, the
                bundled ``oculus_detection_v2`` checkpoint is preferred, with
                ``oculus_detection`` (v1) as a fallback.
            device: torch device string used when loading the detection heads.
        """
        self.device = device

        # Resolve a default checkpoint: prefer v2, fall back to v1.
        if model_path is None:
            base_dir = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2"
            if (base_dir / "final").exists():
                model_path = str(base_dir / "final")
            else:
                model_path = str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Detection head weights are stored separately from the backbone.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            # NOTE(review): torch.load unpickles arbitrary objects; the
            # checkpoint is assumed trusted. Use weights_only=True if these
            # files can ever come from an untrusted source.
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Optionally swap in an instruction-tuned VQA model when present.
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        print("✓ Model loaded successfully")

    def load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """Load an RGB image from a filesystem path, an http(s) URL, or a PIL object."""
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")

        if image_source.startswith("http"):
            # Some image hosts reject requests without a browser-like User-Agent.
            response = requests.get(
                image_source,
                headers={'User-Agent': 'Mozilla/5.0'},
                timeout=30,  # don't hang forever on an unresponsive host
            )
            # Fail loudly on HTTP errors instead of handing an error page to Image.open.
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        return Image.open(image_source).convert("RGB")

    def detect(self, image_source: Union[str, Image.Image], prompt: str = "Detect objects", threshold: float = 0.2) -> Dict[str, Any]:
        """
        Run object detection.

        Args:
            image_source: Path, URL, or PIL image.
            prompt: Text prompt conditioning the detector.
            threshold: Confidence cutoff passed to the model's generate().

        Returns: {'boxes': [[x1,y1,x2,y2], ...], 'labels': [...], 'confidences': [...], 'image_size': (w, h)}
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        return {
            'boxes': output.boxes,
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size
        }

    def ask(self, image_source: Union[str, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA). Returns the model's text answer."""
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text

    def caption(self, image_source: Union[str, Image.Image]) -> str:
        """Generate a caption for the image via the VQA path with a captioning prompt."""
        return self.ask(image_source, "A photo of")