import torch
import numpy as np
from torchvision import models, transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from waste_logic import get_explanation
# -------------------------
# Device
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -------------------------
# Load ResNet50 (ImageNet object recognition)
# -------------------------
# Architecture only; trained weights come from a local checkpoint
resnet = models.resnet50(weights=None)
resnet.load_state_dict(torch.load("resnet50.pth", map_location=device))
resnet.to(device)
resnet.eval()
# ImageNet labels
with open("imagenet_classes.txt") as f:
    imagenet_labels = [line.strip() for line in f]
# -------------------------
# Load CLIP (waste classification)
# -------------------------
clip_model = CLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32"
).to(device)
clip_processor = CLIPProcessor.from_pretrained(
    "openai/clip-vit-base-patch32"
)
WASTE_LABELS = [
    "wet waste",
    "dry waste",
    "electronic waste",
    "hazardous waste",
]
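# Note: these phrases double as CLIP's zero-shot text prompts; the image
# embedding is scored against each phrase's text embedding, so rewording
# a label changes the classifier without any retraining.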
# -------------------------
# Image preprocessing
# -------------------------
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Standard ImageNet normalization statistics
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# -------------------------
# ResNet object recognition
# -------------------------
def detect_object(image):
    """Classify the image against the 1,000 ImageNet labels."""
    tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = resnet(tensor)
    probs = torch.softmax(outputs, dim=1)
    confidence, idx = torch.max(probs, 1)
    label = imagenet_labels[idx.item()]
    return label, round(confidence.item(), 3)
# -------------------------
# CLIP waste classification
# -------------------------
def classify_waste_clip(image):
    """Zero-shot waste classification: score the image against each label phrase."""
    inputs = clip_processor(
        text=WASTE_LABELS,
        images=image,
        return_tensors="pt",
        padding=True
    ).to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    logits = outputs.logits_per_image
    probs = logits.softmax(dim=1).cpu().numpy()[0]
    idx = np.argmax(probs)
    confidence = probs[idx]
    # Map the raw prompt phrases to display names
    label_map = {
        "wet waste": "Wet Waste",
        "dry waste": "Dry Waste",
        "electronic waste": "E-Waste",
        "hazardous waste": "Hazardous Waste",
    }
    return label_map[WASTE_LABELS[idx]], round(float(confidence), 3)
# -------------------------
# Final hybrid analysis
# -------------------------
def analyze_image(image_path):
    """Run both models on one image and merge their outputs into a single report."""
    image = Image.open(image_path).convert("RGB")
    object_label, object_conf = detect_object(image)
    waste_category, waste_conf = classify_waste_clip(image)
    explanation = get_explanation(waste_category)
    return {
        "detected_object": object_label,
        "object_confidence": object_conf,
        "waste_category": waste_category,
        "waste_confidence": waste_conf,
        "reason": explanation["reason"],
        "disposal": explanation["disposal"],
        "environmental_impact": explanation["impact"],
    }
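
# -------------------------
# Example usage (a minimal sketch; "sample.jpg" is a placeholder path,
# not a file shipped with this app)
# -------------------------
if __name__ == "__main__":
    result = analyze_image("sample.jpg")
    for key, value in result.items():
        print(f"{key}: {value}")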