ZZandro commited on
Commit
e301b1c
·
verified ·
1 Parent(s): 64ad746

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +48 -0
inference.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple inference script for the weight predictor model.
3
+ Usage:
4
+ python inference.py "cheeseburger with fries" menu_item
5
+ python inference.py "chocolate bar 100g" grocery
6
+ python inference.py "wireless mouse" non_food
7
+ """
8
+ import sys
9
+ import joblib
10
+ from huggingface_hub import hf_hub_download
11
+
12
+ # Import predictor classes so pickling works
13
+ import weight_predictor
14
+
15
+
16
+ def load_model():
17
+ print("Loading model from HF Hub...")
18
+ model_path = hf_hub_download(
19
+ repo_id="ZZandro/weight-predictor",
20
+ filename="unified_predictor.pkl",
21
+ repo_type="model"
22
+ )
23
+ return joblib.load(model_path)
24
+
25
+
26
+ def predict(text, item_type="grocery"):
27
+ predictor = load_model()
28
+ weight = predictor.predict(text, item_type=item_type)
29
+ return weight
30
+
31
+
32
+ if __name__ == "__main__":
33
+ if len(sys.argv) < 2:
34
+ print("Usage: python inference.py '<item description>' <item_type>")
35
+ print(" item_type: menu_item | grocery | non_food")
36
+ sys.exit(1)
37
+
38
+ description = sys.argv[1]
39
+ item_type = sys.argv[2] if len(sys.argv) > 2 else "grocery"
40
+
41
+ # Auto-prepend tag if not present
42
+ if not description.startswith("["):
43
+ description = f"[{item_type.upper()}] {description}"
44
+
45
+ weight = predict(description, item_type)
46
+ print(f"\nItem: {sys.argv[1]}")
47
+ print(f"Type: {item_type}")
48
+ print(f"Predicted weight: {weight:.1f} grams")