"""
Oculus Car Part Detection Demo

Demonstrates detection on car images using the extended training model.
"""
|
|
| import sys |
| import requests |
| from io import BytesIO |
| from PIL import Image, ImageDraw, ImageFont |
| import torch |
| import numpy as np |
|
|
| |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).parent)) |
|
|
| from oculus_unified_model import OculusForConditionalGeneration |
|
|
def visualize_results(image, output, filename="output_car_parts.png"):
    """Draw detection boxes and class labels onto *image* and save it.

    Args:
        image: PIL Image to annotate (modified in place).
        output: model output with parallel ``boxes`` (normalized
            ``[x1, y1, x2, y2]`` in 0..1 — assumed from the clamping/scaling
            below; confirm against the model), ``labels`` (COCO class
            indices) and ``confidences`` sequences.
        filename: path the annotated image is written to.
    """
    draw = ImageDraw.Draw(image)

    # Prefer the macOS system font; ImageFont.truetype raises OSError when
    # the font file cannot be found/read, so fall back to PIL's bitmap font.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
    except OSError:
        font = ImageFont.load_default()

    width, height = image.size

    # 80-class COCO label set; the detection head emits indices into this list.
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]

    for box, label, conf in zip(output.boxes, output.labels, output.confidences):
        x1, y1, x2, y2 = box

        # Clamp normalized coordinates into [0, 1] before scaling.
        x1 = max(0.0, min(1.0, x1))
        y1 = max(0.0, min(1.0, y1))
        x2 = max(0.0, min(1.0, x2))
        y2 = max(0.0, min(1.0, y2))

        # Skip degenerate (zero-area or inverted) boxes.
        if x2 <= x1 or y2 <= y1:
            continue

        # Scale to pixel coordinates.
        x1 *= width
        y1 *= height
        x2 *= width
        y2 *= height

        # Low-confidence detections in red, confident ones in green.
        color = "red" if conf < 0.5 else "green"

        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Map the class index to a name; fall back to the raw label when it
        # is out of range or not convertible to int.
        try:
            class_name = COCO_CLASSES[int(label)]
        except (IndexError, ValueError, TypeError):
            class_name = str(label)

        label_text = f"{class_name} ({conf:.2f})"

        # Filled background behind the text so it stays readable on any image.
        text_bbox = draw.textbbox((x1, y1), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, y1), label_text, fill="white", font=font)

    image.save(filename)
    # Bug fix: previously printed the literal "(unknown)" instead of the
    # actual output path.
    print(f"Saved visualization to {filename}")
|
|
def main():
    """CLI entry point: locate the newest checkpoint, load the model, and run
    one inference (detection, captioning, or VQA) on an image.

    Checkpoint resolution order:
      1. newest ``step_N`` directory under ``checkpoints/oculus_detection_v2``
      2. ``checkpoints/oculus_detection_v2/final``
      3. ``checkpoints/oculus_detection/final`` (V1 fallback)
    """
    import argparse
    parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo")
    parser.add_argument("--image", type=str, help="Path to image file to test")
    parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model")
    parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode")
    parser.add_argument("--threshold", type=float, default=0.2, help="Detection threshold")
    parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename")
    args = parser.parse_args()

    checkpoint_dir = Path("checkpoints/oculus_detection_v2")
    model_path = None

    if checkpoint_dir.exists():
        # Collect (step_number, dir) pairs for every "step_N" checkpoint dir.
        steps = []
        for d in checkpoint_dir.iterdir():
            if d.is_dir() and d.name.startswith("step_"):
                try:
                    step = int(d.name.split("_")[1])
                    steps.append((step, d))
                except (IndexError, ValueError):
                    # Not a well-formed "step_N" directory; ignore it.
                    pass

        # Latest checkpoint = highest step number (no need to sort the list).
        if steps:
            model_path = str(max(steps, key=lambda x: x[0])[1])
            print(f"✨ Found latest checkpoint: {model_path}")

    if model_path is None:
        model_path = str(checkpoint_dir / "final")

    if not Path(model_path).exists():
        model_path = "checkpoints/oculus_detection/final"
        print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}")

    print(f"Loading model from {model_path}...")
    try:
        model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Detection heads are stored separately from the base weights.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            # NOTE(security): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources (weights_only=True would
            # be safer but may break older checkpoints/torch versions).
            heads = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads['detection'])
            print("✓ Loaded detection heads")
    except Exception as e:
        # Top-level boundary: report the failure and bail out cleanly.
        print(f"Error loading model: {e}")
        return

    if args.image:
        image_path = args.image
        print(f"\nProcessing Custom Image: {image_path}...")
    else:
        image_path = "data/coco/images/000000071345.jpg"
        print(f"\nProcessing Default Image: {image_path}...")

    try:
        if Path(image_path).exists():
            image = Image.open(image_path).convert('RGB')
        else:
            # No local image available: download a public sample instead.
            url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg"
            print(f"Image not found, downloading sample {url}...")
            response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'})
            image = Image.open(BytesIO(response.content)).convert('RGB')

        if args.mode == "box":
            print(f"Running detection with prompt: '{args.prompt}'...")
            output = model.generate(
                image,
                mode="box",
                prompt=args.prompt,
                threshold=args.threshold
            )
            print(f"Found {len(output.boxes)} objects")
            visualize_results(image, output, args.output)

        elif args.mode == "caption":
            print("Generating caption...")
            output = model.generate(image, mode="text", prompt="A photo of")
            print(f"\n📝 Caption: {output.text}\n")

        elif args.mode == "vqa":
            # Default prompt doubles as a sentinel: replace it with a generic
            # question when the user did not supply their own prompt.
            question = args.prompt if args.prompt != "Detect objects" else "What is in this image?"
            print(f"Thinking about question: '{question}'...")
            output = model.generate(image, mode="text", prompt=question)
            print(f"\n🤔 Answer: {output.text}\n")

    except Exception as e:
        # Top-level boundary: surface the error without a traceback.
        print(f"Error processing image: {e}")
|
|
# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|