amar6de2 committed on
Commit
4d37213
·
1 Parent(s): 03afbfb

Add working ViT food classifier Space

Browse files
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
### 1. Imports and class names setup ###
import gradio as gr
import os
import torch

# Fixed: was `from model import create_effnetb2_model`, but the code below
# calls create_vit_model (defined in model.py) — that mismatch raised a
# NameError at startup.
from model import create_vit_model
from timeit import default_timer as timer
from typing import Tuple, Dict

# The 121 Food121 class labels, in the index order the model was trained with.
class_names = ["apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare", "beet_salad", "beignets", "bibimbap", "biryani", "bread_pudding", "breakfast_burrito", "bruschetta", "caesar_salad", "cannoli", "caprese_salad", "carrot_cake", "ceviche", "chai", "chapati", "cheese_plate", "cheesecake", "chicken_curry", "chicken_quesadilla", "chicken_wings", "chocolate_cake", "chocolate_mousse", "chole_bhature", "churros", "clam_chowder", "club_sandwich", "crab_cakes", "creme_brulee", "croque_madame", "cup_cakes", "dabeli", "dal", "deviled_eggs", "dhokla", "donuts", "dosa", "dumplings", "edamame", "eggs_benedict", "escargots", "falafel", "filet_mignon", "fish_and_chips", "foie_gras", "french_fries", "french_onion_soup", "french_toast", "fried_calamari", "fried_rice", "frozen_yogurt", "garlic_bread", "gnocchi", "greek_salad", "grilled_cheese_sandwich", "grilled_salmon", "guacamole", "gyoza", "hamburger", "hot_and_sour_soup", "hot_dog", "huevos_rancheros", "hummus", "ice_cream", "idli", "jalebi", "kathi_rolls", "kofta", "kulfi", "lasagna", "lobster_bisque", "lobster_roll_sandwich", "macaroni_and_cheese", "macarons", "miso_soup", "momos", "mussels", "naan", "nachos", "omelette", "onion_rings", "oysters", "pad_thai", "paella", "pakoda", "pancakes", "pani_puri", "panna_cotta", "panner_butter_masala", "pav_bhaji", "peking_duck", "pho", "pizza", "pork_chop", "poutine", "prime_rib", "pulled_pork_sandwich", "ramen", "ravioli", "red_velvet_cake", "risotto", "samosa", "sashimi", "scallops", "seaweed_salad", "shrimp_and_grits", "spaghetti_bolognese", "spaghetti_carbonara", "spring_rolls", "steak", "strawberry_shortcake", "sushi", "tacos", "takoyaki", "tiramisu", "tuna_tartare", "vadapav", "waffles"]

# Build the ViT-B/16 model and its matching preprocessing transforms.
vit, vit_transforms = create_vit_model(
    num_classes=121, # len(class_names) would also work
)

# Load the fine-tuned weights; map to CPU since the Space has no GPU.
vit.load_state_dict(
    torch.load(
        f="vit_epoch_2.pth",
        map_location=torch.device("cpu"), # load to CPU
    )
)
19
# Create predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify a food image with the ViT model.

    Returns a mapping from each class name to its predicted probability
    (the format Gradio's Label output expects) together with the
    wall-clock inference time in seconds, rounded to 5 decimal places.
    """
    started = timer()

    # Preprocess and prepend a batch dimension: (C, H, W) -> (1, C, H, W).
    batch = vit_transforms(img).unsqueeze(0)

    # Inference only: eval mode disables dropout, inference_mode skips autograd.
    vit.eval()
    with torch.inference_mode():
        probabilities = torch.softmax(vit(batch), dim=1)

    # {class label: probability} for every class, as plain Python floats.
    label_to_prob = {
        name: float(probabilities[0][idx])
        for idx, name in enumerate(class_names)
    }

    elapsed = round(timer() - started, 5)

    return label_to_prob, elapsed
43
### 4. Gradio app ###

# Create title, description and article strings
title = "VisionBite 🍕🥩🍣"
description = "A Vision Transformer (ViT-Base-16) model trained to classify images of food into 121 distinct categories. The model uses a transformer-based architecture to extract visual features and achieve accurate classification across diverse food items."
# Fixed: the original double-quoted string contained unescaped double quotes
# around the URL, which is a SyntaxError. Single quotes avoid the clash.
article = 'Model Has been trained on Food121 dataset("https://huggingface.co/datasets/ItsNotRohit/Food121") and has an accuracy of 95% on top 5 predictions.'

# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create the Gradio demo
demo = gr.Interface(fn=predict, # mapping function from input to output
                    inputs=gr.Image(type="pil"), # what are the inputs?
                    outputs=[gr.Label(num_top_classes=121, label="Predictions"), # what are the outputs?
                             gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
                    examples=example_list,
                    title=title,
                    description=description,
                    article=article)

# Launch the demo!
demo.launch()
examples/apple_pie_0.jpg ADDED
examples/chai_3400.jpg ADDED
examples/chapati_3600.jpg ADDED
model.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torchvision

from torch import nn


def create_vit_model(num_classes: int = 121,
                     seed: int = 42):
    """Creates a ViT-B/16 feature extractor model and transforms.

    The pretrained backbone is frozen; only the replacement classifier
    head is trainable.

    Args:
        num_classes (int, optional): number of target classes. Defaults to 121.
        seed (int, optional): random seed value for output layer. Defaults to 42.

    Returns:
        model (torch.nn.Module): ViT-B/16 feature extractor model.
        transforms (torchvision.transforms): ViT-B/16 image transforms.
    """
    # Create ViT_B_16 pretrained weights, transforms and model
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze all layers in model so only the new head is trained
    for param in model.parameters():
        param.requires_grad = False

    # Change classifier head to suit our needs (this will be trainable).
    # Seeded so the head's random init is reproducible across runs.
    torch.manual_seed(seed)
    model.heads = nn.Sequential(
        nn.LayerNorm(768),      # 768 = ViT-B/16 hidden dimension
        nn.Dropout(0.2),        # Try 0.1 or 0.2
        # Fixed: output size was hard-coded to 121, silently ignoring
        # the num_classes parameter.
        nn.Linear(768, num_classes)
    )

    return model, transforms
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4
vit_epoch_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e2feb708f66db4d26e017d955e6e8f8e64842e9a71e67e81cd3f4dc3f956eff
3
+ size 343634614