ZZandro commited on
Commit
27cc340
·
verified ·
1 Parent(s): a1803e0

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +66 -37
inference.py CHANGED
@@ -1,48 +1,77 @@
1
  """
2
- Simple inference script for the weight predictor model.
3
- Usage:
4
- python inference.py "cheeseburger with fries" menu_item
5
- python inference.py "chocolate bar 100g" grocery
6
- python inference.py "wireless mouse" non_food
 
 
 
 
7
  """
8
- import sys
9
- import joblib
10
  from huggingface_hub import hf_hub_download
 
 
 
11
 
12
- # Import predictor classes so pickling works
13
- import weight_predictor
14
 
 
 
15
 
16
- def load_model():
17
- print("Loading model from HF Hub...")
18
- model_path = hf_hub_download(
19
- repo_id="ZZandro/weight-predictor",
20
- filename="unified_predictor.pkl",
21
- repo_type="model"
22
- )
23
- return joblib.load(model_path)
24
 
 
 
 
25
 
26
- def predict(text, item_type="grocery"):
27
- predictor = load_model()
28
- weight = predictor.predict(text, item_type=item_type)
29
- return weight
 
30
 
31
 
32
  if __name__ == "__main__":
33
- if len(sys.argv) < 2:
34
- print("Usage: python inference.py '<item description>' <item_type>")
35
- print(" item_type: menu_item | grocery | non_food")
36
- sys.exit(1)
37
-
38
- description = sys.argv[1]
39
- item_type = sys.argv[2] if len(sys.argv) > 2 else "grocery"
40
-
41
- # Auto-prepend tag if not present
42
- if not description.startswith("["):
43
- description = f"[{item_type.upper()}] {description}"
44
-
45
- weight = predict(description, item_type)
46
- print(f"\nItem: {sys.argv[1]}")
47
- print(f"Type: {item_type}")
48
- print(f"Predicted weight: {weight:.1f} grams")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Usage example for the hybrid weight predictor.
3
+
4
+ Requirements:
5
+ pip install huggingface_hub joblib scikit-learn
6
+
7
+ This uses a hybrid approach:
8
+ 1. Explicit weight extraction from text (e.g., "500g", "2 liter")
9
+ 2. Knowledge base lookup for known FMCG brands and fast food items
10
+ 3. ML model fallback for unknown items
11
  """
 
 
12
  from huggingface_hub import hf_hub_download
13
+ import joblib
14
+ import sys
15
+ import os
16
 
17
+ # Make sure the weight_predictor module is importable
18
+ sys.path.insert(0, os.path.dirname(__file__))
19
 
20
+ # Load the hybrid predictor (pure Python, no ML model needed)
21
+ from hybrid_weight_predictor import HybridWeightPredictor, build_hybrid_predictor
22
 
23
+ def predict_weight(text, item_type=None):
24
+ """
25
+ Predict item weight in grams.
 
 
 
 
 
26
 
27
+ Args:
28
+ text: Item description. Should start with [MENU_ITEM], [GROCERY], or [NON_FOOD]
29
+ item_type: "menu_item", "grocery", or "non_food" (auto-detected from text if None)
30
 
31
+ Returns:
32
+ Predicted weight in grams (float)
33
+ """
34
+ predictor = build_hybrid_predictor()
35
+ return predictor.predict(text, item_type)
36
 
37
 
38
  if __name__ == "__main__":
39
+ # Example usage
40
+ examples = [
41
+ ("[GROCERY] coca cola can 330ml", "grocery"),
42
+ ("[GROCERY] coca cola", "grocery"), # uses KB default
43
+ ("[GROCERY] pepsi 1 liter bottle", "grocery"),
44
+ ("[GROCERY] kelloggs corn flakes 500g", "grocery"),
45
+ ("[GROCERY] oreo cookies 154g", "grocery"),
46
+ ("[GROCERY] heinz ketchup 570ml", "grocery"),
47
+ ("[GROCERY] mars bar 51g", "grocery"),
48
+ ("[GROCERY] snickers 2 pack 96g", "grocery"),
49
+ ("[GROCERY] red bull 4 pack", "grocery"),
50
+ ("[GROCERY] tide laundry detergent 1.5kg", "grocery"),
51
+ ("[MENU_ITEM] large pizza", "menu_item"),
52
+ ("[MENU_ITEM] cheeseburger", "menu_item"),
53
+ ("[MENU_ITEM] double cheeseburger", "menu_item"),
54
+ ("[MENU_ITEM] big mac", "menu_item"),
55
+ ("[MENU_ITEM] french fries", "menu_item"),
56
+ ("[MENU_ITEM] large fries", "menu_item"),
57
+ ("[MENU_ITEM] chicken nuggets", "menu_item"),
58
+ ("[MENU_ITEM] burrito", "menu_item"),
59
+ ("[MENU_ITEM] caesar salad", "menu_item"),
60
+ ("[MENU_ITEM] caesar salad large", "menu_item"),
61
+ ("[MENU_ITEM] pho", "menu_item"),
62
+ ("[MENU_ITEM] ramen", "menu_item"),
63
+ ("[MENU_ITEM] sushi platter", "menu_item"),
64
+ ("[MENU_ITEM] medium pizza", "menu_item"),
65
+ ("[MENU_ITEM] combo meal", "menu_item"),
66
+ ("[MENU_ITEM] milkshake", "menu_item"),
67
+ ("[MENU_ITEM] iced coffee", "menu_item"),
68
+ ("[MENU_ITEM] family meal", "menu_item"),
69
+ ("[MENU_ITEM] sliders", "menu_item"),
70
+ ("[NON_FOOD] laptop computer", "non_food"),
71
+ ("[NON_FOOD] water bottle", "non_food"),
72
+ ]
73
+
74
+ print("=== Weight Predictions ===\n")
75
+ for text, item_type in examples:
76
+ weight = predict_weight(text, item_type)
77
+ print(f" {text:55s} -> {weight:8.1f}g")