ZZandro commited on
Commit
64ad746
·
verified ·
1 Parent(s): 0c66dd1

Upload weight_predictor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. weight_predictor.py +44 -0
weight_predictor.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Weight predictor classes for food delivery platforms.
3
+ """
4
+ import numpy as np
5
+
6
+
7
+ class WeightPredictor:
8
+ """Per-item-type weight predictor using TF-IDF + Ridge regression."""
9
+ def __init__(self, tfidf, model):
10
+ self.tfidf = tfidf
11
+ self.model = model
12
+
13
+ def predict(self, texts):
14
+ """Predict weights for a list of texts."""
15
+ X = self.tfidf.transform(texts)
16
+ return np.expm1(self.model.predict(X))
17
+
18
+ def predict_single(self, text):
19
+ """Predict weight for a single text."""
20
+ return self.predict([text])[0]
21
+
22
+
23
+ class UnifiedWeightPredictor:
24
+ """Unified predictor that routes to per-type models."""
25
+ def __init__(self, predictors, default_type="grocery"):
26
+ self.predictors = predictors
27
+ self.default_type = default_type
28
+
29
+ def predict(self, text, item_type=None):
30
+ """
31
+ Predict weight from text description.
32
+ item_type should be one of: menu_item, grocery, non_food
33
+ """
34
+ if item_type is None:
35
+ if text.startswith("[MENU_ITEM]"):
36
+ item_type = "menu_item"
37
+ elif text.startswith("[GROCERY]"):
38
+ item_type = "grocery"
39
+ elif text.startswith("[NON_FOOD]"):
40
+ item_type = "non_food"
41
+ else:
42
+ item_type = self.default_type
43
+ predictor = self.predictors.get(item_type, self.predictors.get(self.default_type))
44
+ return predictor.predict([text])[0]