OGrohit commited on
Commit
236bf6d
·
verified ·
1 Parent(s): 5852458

Update classifier.py

Browse files
Files changed (1) hide show
  1. classifier.py +99 -99
classifier.py CHANGED
@@ -1,99 +1,99 @@
1
- import torch
2
- from torchvision import models, transforms
3
- from PIL import Image
4
-
5
- from waste_logic import map_to_waste, get_explanation
6
-
7
- # -------------------------
8
- # Device setup
9
- # -------------------------
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
-
12
- # -------------------------
13
- # Lazy-loaded model (important for backend)
14
- # -------------------------
15
- _model = None
16
-
17
- def get_model():
18
- global _model
19
- if _model is None:
20
- model = models.resnet50(weights=None)
21
- state_dict = torch.load("resnet50.pth", map_location="cpu")
22
- model.load_state_dict(state_dict)
23
- model.to(device)
24
- model.eval()
25
- _model = model
26
- return _model
27
-
28
- # -------------------------
29
- # Image preprocessing
30
- # -------------------------
31
- preprocess = transforms.Compose([
32
- transforms.Resize(256),
33
- transforms.CenterCrop(224),
34
- transforms.ToTensor(),
35
- transforms.Normalize(
36
- mean=[0.485, 0.456, 0.406],
37
- std=[0.229, 0.224, 0.225]
38
- )
39
- ])
40
-
41
- # -------------------------
42
- # Load ImageNet labels
43
- # -------------------------
44
- with open("imagenet_classes.txt", "r") as f:
45
- labels = [line.strip() for line in f.readlines()]
46
-
47
- # -------------------------
48
- # Image classification (Top-K)
49
- # -------------------------
50
- def classify_image(image_path, top_k=3):
51
- model = get_model()
52
-
53
- try:
54
- image = Image.open(image_path).convert("RGB")
55
- except Exception as e:
56
- raise ValueError(f"Invalid image file: {e}")
57
-
58
- tensor = preprocess(image).unsqueeze(0).to(device)
59
-
60
- with torch.no_grad():
61
- outputs = model(tensor)
62
- probs = torch.softmax(outputs, dim=1)
63
-
64
- top_probs, top_idxs = torch.topk(probs, top_k)
65
-
66
- results = []
67
- for prob, idx in zip(top_probs[0], top_idxs[0]):
68
- results.append((labels[idx.item()], prob.item()))
69
-
70
- return results
71
-
72
- # -------------------------
73
- # Public API function
74
- # -------------------------
75
- def analyze_image(image_path):
76
- predictions = classify_image(image_path)
77
-
78
- chosen_label = predictions[0][0]
79
- chosen_conf = predictions[0][1]
80
- waste_type = "Unknown"
81
-
82
- for label, conf in predictions:
83
- wt = map_to_waste(label)
84
- if wt != "Unknown":
85
- chosen_label = label
86
- chosen_conf = conf
87
- waste_type = wt
88
- break
89
-
90
- explanation = get_explanation(waste_type)
91
-
92
- return {
93
- "detected_object": chosen_label,
94
- "confidence": round(chosen_conf, 3),
95
- "waste_category": waste_type,
96
- "reason": explanation["reason"],
97
- "disposal": explanation["disposal"],
98
- "environmental_impact": explanation["impact"]
99
- }
 
1
+ import torch
2
+ from torchvision import models, transforms
3
+ from PIL import Image
4
+
5
+ from waste_logic import map_to_waste, get_explanation
6
+
7
+ # -------------------------
8
+ # Device setup
9
+ # -------------------------
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+ # -------------------------
13
+ # Lazy-loaded model (important for backend)
14
+ # -------------------------
15
+ _model = None
16
+
17
+ def get_model():
18
+ global _model
19
+ if _model is None:
20
+ model = models.resnet50(weights=None)
21
+ state_dict = torch.load("resnet50.pth", map_location=device)
22
+ model.load_state_dict(state_dict)
23
+ model.to(device)
24
+ model.eval()
25
+ _model = model
26
+ return _model
27
+
28
+ # -------------------------
29
+ # Image preprocessing
30
+ # -------------------------
31
+ preprocess = transforms.Compose([
32
+ transforms.Resize(256),
33
+ transforms.CenterCrop(224),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(
36
+ mean=[0.485, 0.456, 0.406],
37
+ std=[0.229, 0.224, 0.225]
38
+ )
39
+ ])
40
+
41
+ # -------------------------
42
+ # Load ImageNet labels
43
+ # -------------------------
44
+ with open("imagenet_classes.txt", "r") as f:
45
+ labels = [line.strip() for line in f.readlines()]
46
+
47
+ # -------------------------
48
+ # Image classification (Top-K)
49
+ # -------------------------
50
+ def classify_image(image_path, top_k=3):
51
+ model = get_model()
52
+
53
+ try:
54
+ image = Image.open(image_path).convert("RGB")
55
+ except Exception as e:
56
+ raise ValueError(f"Invalid image file: {e}")
57
+
58
+ tensor = preprocess(image).unsqueeze(0).to(device)
59
+
60
+ with torch.no_grad():
61
+ outputs = model(tensor)
62
+ probs = torch.softmax(outputs, dim=1)
63
+
64
+ top_probs, top_idxs = torch.topk(probs, top_k)
65
+
66
+ results = []
67
+ for prob, idx in zip(top_probs[0], top_idxs[0]):
68
+ results.append((labels[idx.item()], prob.item()))
69
+
70
+ return results
71
+
72
+ # -------------------------
73
+ # Public API function
74
+ # -------------------------
75
+ def analyze_image(image_path):
76
+ predictions = classify_image(image_path)
77
+
78
+ chosen_label = predictions[0][0]
79
+ chosen_conf = predictions[0][1]
80
+ waste_type = "Unknown"
81
+
82
+ for label, conf in predictions:
83
+ wt = map_to_waste(label)
84
+ if wt != "Unknown":
85
+ chosen_label = label
86
+ chosen_conf = conf
87
+ waste_type = wt
88
+ break
89
+
90
+ explanation = get_explanation(waste_type)
91
+
92
+ return {
93
+ "detected_object": chosen_label,
94
+ "confidence": round(chosen_conf, 3),
95
+ "waste_category": waste_type,
96
+ "reason": explanation["reason"],
97
+ "disposal": explanation["disposal"],
98
+ "environmental_impact": explanation["impact"]
99
+ }