Upload inference.py with huggingface_hub
Browse files- inference.py +48 -0
inference.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Simple inference script for the weight predictor model.
|
| 3 |
+
Usage:
|
| 4 |
+
python inference.py "cheeseburger with fries" menu_item
|
| 5 |
+
python inference.py "chocolate bar 100g" grocery
|
| 6 |
+
python inference.py "wireless mouse" non_food
|
| 7 |
+
"""
|
| 8 |
+
import sys
|
| 9 |
+
import joblib
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
# Import predictor classes so pickling works
|
| 13 |
+
import weight_predictor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_model():
|
| 17 |
+
print("Loading model from HF Hub...")
|
| 18 |
+
model_path = hf_hub_download(
|
| 19 |
+
repo_id="ZZandro/weight-predictor",
|
| 20 |
+
filename="unified_predictor.pkl",
|
| 21 |
+
repo_type="model"
|
| 22 |
+
)
|
| 23 |
+
return joblib.load(model_path)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def predict(text, item_type="grocery"):
|
| 27 |
+
predictor = load_model()
|
| 28 |
+
weight = predictor.predict(text, item_type=item_type)
|
| 29 |
+
return weight
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
if len(sys.argv) < 2:
|
| 34 |
+
print("Usage: python inference.py '<item description>' <item_type>")
|
| 35 |
+
print(" item_type: menu_item | grocery | non_food")
|
| 36 |
+
sys.exit(1)
|
| 37 |
+
|
| 38 |
+
description = sys.argv[1]
|
| 39 |
+
item_type = sys.argv[2] if len(sys.argv) > 2 else "grocery"
|
| 40 |
+
|
| 41 |
+
# Auto-prepend tag if not present
|
| 42 |
+
if not description.startswith("["):
|
| 43 |
+
description = f"[{item_type.upper()}] {description}"
|
| 44 |
+
|
| 45 |
+
weight = predict(description, item_type)
|
| 46 |
+
print(f"\nItem: {sys.argv[1]}")
|
| 47 |
+
print(f"Type: {item_type}")
|
| 48 |
+
print(f"Predicted weight: {weight:.1f} grams")
|