ZZandro commited on
Commit
9fb21de
·
verified ·
1 Parent(s): fcbf632

Upload hybrid_weight_predictor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hybrid_weight_predictor.py +427 -0
hybrid_weight_predictor.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid weight predictor for food delivery platforms.
3
+ Combines explicit weight extraction, rule-based knowledge base, and ML fallback.
4
+ """
5
+ import re
6
+ import numpy as np
7
+ import joblib
8
+
9
+ UNIT_MAP = {
10
+ 'ml': 1.0, 'milliliter': 1.0, 'milliliters': 1.0, 'millilitre': 1.0, 'millilitres': 1.0,
11
+ 'l': 1000.0, 'liter': 1000.0, 'liters': 1000.0, 'litre': 1000.0, 'litres': 1000.0,
12
+ 'cl': 10.0, 'centiliter': 10.0, 'centiliters': 10.0,
13
+ 'dl': 100.0, 'deciliter': 100.0, 'deciliters': 100.0,
14
+ 'g': 1.0, 'gram': 1.0, 'grams': 1.0, 'gr': 1.0,
15
+ 'kg': 1000.0, 'kilogram': 1000.0, 'kilograms': 1000.0, 'kilo': 1000.0,
16
+ 'mg': 0.001, 'milligram': 0.001, 'milligrams': 0.001,
17
+ 'oz': 28.3495, 'ounce': 28.3495, 'ounces': 28.3495,
18
+ 'lb': 453.592, 'lbs': 453.592, 'pound': 453.592, 'pounds': 453.592,
19
+ 'fl oz': 29.5735, 'fluid ounce': 29.5735, 'fluid ounces': 29.5735,
20
+ 'pt': 473.176, 'pint': 473.176, 'pints': 473.176,
21
+ 'qt': 946.353, 'quart': 946.353, 'quarts': 946.353,
22
+ 'gal': 3785.41, 'gallon': 3785.41, 'gallons': 3785.41,
23
+ 'cup': 240.0, 'cups': 240.0,
24
+ }
25
+
26
+ # Common single-unit weights for known grocery products (fallback when no explicit weight in text)
27
+ GROCERY_KB = {
28
+ # Beverages (per single unit if size not specified)
29
+ 'coca cola': 330, 'coke': 330, 'pepsi': 330, 'sprite': 330, 'fanta': 330,
30
+ '7up': 330, 'mountain dew': 500, 'dr pepper': 355,
31
+ 'red bull': 250, 'monster': 500, 'gatorade': 500, 'powerade': 500,
32
+ 'aquafina': 500, 'dasani': 500, 'evian': 500, 'smartwater': 700,
33
+ 'volvic': 500, 'perrier': 330, 'san pellegrino': 500,
34
+ 'minute maid': 250, 'tropicana': 250, 'honest tea': 500, 'fuze tea': 500,
35
+ 'lipton iced tea': 500, 'snapple': 473, 'vitaminwater': 500,
36
+ 'kombucha': 450, 'kefir': 250,
37
+ 'heineken': 330, 'budweiser': 355, 'corona': 355, 'stella artois': 330,
38
+ 'carlsberg': 330, 'guinness': 440,
39
+ # Cereals (per typical box if size not specified)
40
+ 'kelloggs corn flakes': 500, 'kelloggs special k': 500,
41
+ 'kelloggs rice krispies': 510, 'kelloggs coco pops': 350,
42
+ 'kelloggs frosted flakes': 425, 'cheerios': 340, 'lucky charms': 326,
43
+ 'cinnamon toast crunch': 340, 'honey nut cheerios': 306,
44
+ 'trix': 285, 'wheaties': 450,
45
+ # Snacks
46
+ 'oreo': 154, 'oreos': 154, 'pringles': 165, 'doritos': 175,
47
+ 'lays': 175, 'cheetos': 200, 'ruffles': 200,
48
+ 'kit kat': 42, 'snickers': 48, 'mars bar': 51, 'twix': 50,
49
+ 'm and m': 42, 'm and ms': 42, 'maltesers': 37, 'skittles': 45,
50
+ 'toblerone': 100, 'milka': 100, 'cadbury': 110,
51
+ 'haribo': 100, 'gummy bears': 100,
52
+ # Personal care (if no size specified)
53
+ 'toothpaste': 100, 'shampoo': 400, 'conditioner': 400,
54
+ 'body wash': 500, 'deodorant': 50, 'soap': 100,
55
+ # Cleaning
56
+ 'laundry detergent': 1500, 'dish soap': 500, 'bleach': 750,
57
+ # Baby
58
+ 'baby formula': 800, 'baby food': 120, 'diapers': 800,
59
+ }
60
+
61
+ # Fast-food portion knowledge base (item → typical single-serving weight in grams)
62
+ PORTION_KB = {
63
+ # Pizza
64
+ 'small pizza': 500, 'personal pizza': 400, 'medium pizza': 800,
65
+ 'large pizza': 1200, 'extra large pizza': 1700, 'xl pizza': 1700,
66
+ 'pizza slice': 150, 'pizza': 800,
67
+ # Burgers
68
+ 'hamburger': 150, 'cheeseburger': 220, 'double cheeseburger': 350,
69
+ 'big mac': 219, 'whopper': 291, 'quarter pounder': 220,
70
+ 'double quarter pounder': 350, 'mushroom swiss burger': 320,
71
+ 'bbq bacon burger': 350, 'blue cheese burger': 320,
72
+ 'veggie burger': 250, 'beyond burger': 250, 'impossible burger': 250,
73
+ 'turkey burger': 220, 'slider': 100, 'sliders': 100,
74
+ 'burger': 220,
75
+ # Chicken
76
+ 'chicken sandwich': 280, 'crispy chicken sandwich': 300,
77
+ 'grilled chicken sandwich': 280, 'spicy chicken sandwich': 290,
78
+ 'chicken burger': 250, 'fried chicken': 300,
79
+ 'chicken nuggets': 180, 'chicken tenders': 220, 'popcorn chicken': 180,
80
+ 'chicken wings': 300, 'buffalo wings': 300, 'boneless wings': 250,
81
+ 'chicken strips': 220, 'rotisserie chicken': 600,
82
+ # Sandwiches
83
+ 'sub sandwich': 450, 'club sandwich': 350, 'blt sandwich': 300,
84
+ 'tuna sandwich': 280, 'turkey sandwich': 300, 'ham sandwich': 280,
85
+ 'roast beef sandwich': 300, 'grilled cheese sandwich': 220,
86
+ 'reuben sandwich': 450, 'pastrami sandwich': 400,
87
+ 'meatball sub': 450, 'philly cheesesteak': 450,
88
+ 'italian sub': 450, 'sandwich': 300,
89
+ # Wraps & Mexican
90
+ 'wrap': 280, 'burrito': 500, 'taco': 150, 'soft taco': 180,
91
+ 'hard shell taco': 130, 'crunchy taco': 120, 'taco supreme': 180,
92
+ 'quesadilla': 350, 'nachos': 300, 'nachos supreme': 400,
93
+ 'nachos bellgrande': 450, 'enchilada': 320, 'fajita': 400,
94
+ 'chimichanga': 450, 'tostada': 200, 'churros': 100,
95
+ 'tamales': 200, 'bowl': 450,
96
+ # Sides
97
+ 'french fries': 150, 'small fries': 100, 'medium fries': 180,
98
+ 'large fries': 300, 'sweet potato fries': 200, 'curly fries': 150,
99
+ 'waffle fries': 180, 'steak fries': 200, 'onion rings': 180,
100
+ 'mozzarella sticks': 180, 'jalapeno poppers': 130,
101
+ 'loaded fries': 400, 'chili cheese fries': 450, 'tater tots': 150,
102
+ 'hash browns': 150, 'potato wedges': 200,
103
+ # Salads
104
+ 'side salad': 120, 'caesar salad': 250, 'garden salad': 150,
105
+ 'greek salad': 280, 'cobb salad': 350, 'chef salad': 350,
106
+ 'taco salad': 400, 'chicken salad': 280, 'pasta salad': 280,
107
+ 'potato salad': 220, 'coleslaw': 120, 'fruit salad': 180,
108
+ 'salad': 200,
109
+ # Breakfast
110
+ 'breakfast burrito': 450, 'breakfast sandwich': 250,
111
+ 'breakfast platter': 550, 'pancakes': 300, 'waffles': 300,
112
+ 'french toast': 300, 'omelette': 280, 'scrambled eggs': 200,
113
+ 'fried eggs': 180, 'bacon': 80, 'sausage': 100,
114
+ 'sausage patty': 90, 'sausage links': 100, 'hash browns': 150,
115
+ 'home fries': 180, 'biscuits and gravy': 350,
116
+ 'english muffin': 100, 'bagel': 110, 'croissant': 90,
117
+ 'cinnamon roll': 150, 'donut': 80, 'muffin': 110,
118
+ 'breakfast bowl': 400, 'eggs': 150,
119
+ # Pasta & Italian
120
+ 'pasta': 350, 'spaghetti': 400, 'penne pasta': 380, 'fettuccine': 350,
121
+ 'mac and cheese': 350, 'lasagna': 450, 'chicken parmigiana': 450,
122
+ 'chicken parmesan': 450, 'chicken alfredo': 450,
123
+ 'chicken marsala': 380, 'chicken piccata': 380,
124
+ # Asian
125
+ 'ramen': 550, 'pho': 550, 'pad thai': 450, 'lo mein': 450,
126
+ 'chow mein': 450, 'fried rice': 350, 'stir fry': 450,
127
+ 'curry': 450, 'beef stew': 500, 'dumplings': 250,
128
+ 'pot stickers': 250, 'egg rolls': 120, 'spring rolls': 120,
129
+ 'bao bun': 130, 'char siu': 250, 'kung pao chicken': 400,
130
+ 'general tso chicken': 450, 'orange chicken': 400,
131
+ 'sweet and sour chicken': 400, 'sesame chicken': 400,
132
+ 'beef and broccoli': 400, 'mongolian beef': 400,
133
+ 'kung pao shrimp': 350, 'mapo tofu': 350, 'hot pot': 600,
134
+ 'bibimbap': 450, 'bulgogi': 350, 'kimchi': 150,
135
+ 'bento box': 550, 'teriyaki chicken': 350, 'tonkatsu': 350,
136
+ 'udon': 500, 'soba': 400, 'sushi roll': 180, 'sashimi': 130,
137
+ 'tempura': 250, 'sushi platter': 700, 'sushi': 180,
138
+ # Soups
139
+ 'soup': 300, 'cup of soup': 250, 'bowl of soup': 350,
140
+ 'tomato soup': 280, 'chicken noodle soup': 300,
141
+ 'clam chowder': 350, 'lobster bisque': 350,
142
+ 'french onion soup': 320, 'minestrone': 320,
143
+ 'lentil soup': 320, 'vegetable soup': 280,
144
+ 'miso soup': 180, 'wonton soup': 300,
145
+ 'hot and sour soup': 280, 'egg drop soup': 250,
146
+ # Meat
147
+ 'beef steak': 300, 'sirloin steak': 300, 'ribeye steak': 380,
148
+ 'new york strip': 350, 'filet mignon': 250, 'pork chop': 280,
149
+ 'pork ribs': 400, 'bbq ribs': 450, 'meatloaf': 380,
150
+ 'spaghetti and meatballs': 500, 'meatballs': 280,
151
+ 'eggplant parmesan': 400, 'fish and chips': 450,
152
+ 'fish taco': 250, 'grilled salmon': 220, 'salmon fillet': 200,
153
+ 'shrimp': 200, 'fried shrimp': 250, 'calamari': 250,
154
+ 'crab cakes': 250, 'shrimp cocktail': 200,
155
+ # Appetizers
156
+ 'wings': 300, 'mozzarella sticks': 180, 'jalapeno poppers': 130,
157
+ 'onion rings': 180, 'breadsticks': 130, 'garlic bread': 150,
158
+ 'cheese bread': 160, 'spinach dip': 250, 'artichoke dip': 250,
159
+ 'queso dip': 200, 'guacamole': 200, 'salsa': 150,
160
+ 'bruschetta': 130, 'caprese': 200, 'antipasto': 250,
161
+ 'olives': 120, 'deviled eggs': 120, 'stuffed mushrooms': 200,
162
+ 'calamari': 250, 'crab rangoon': 200, 'edamame': 200,
163
+ 'gyoza': 200, 'hummus': 200, 'falafel': 200, 'samosa': 120,
164
+ 'loaded potato skins': 250, 'stuffed peppers': 250,
165
+ # Desserts
166
+ 'ice cream': 150, 'ice cream sundae': 300,
167
+ 'milkshake': 450, 'thick shake': 500, 'float': 400,
168
+ 'cookie': 50, 'chocolate chip cookie': 50,
169
+ 'brownie': 110, 'blondie': 100, 'cake': 150,
170
+ 'cake slice': 150, 'cheesecake': 150, 'pie slice': 150,
171
+ 'apple pie': 150, 'pudding': 200, 'flan': 200,
172
+ 'tiramisu': 180, 'creme brulee': 150, 'mousse': 130,
173
+ 'parfait': 250, 'fruit cup': 150, 'yogurt parfait': 250,
174
+ 'smoothie': 350, 'acai bowl': 350, 'frozen yogurt': 180,
175
+ 'sorbet': 150, 'gelato': 180, 'affogato': 180,
176
+ 'crepe': 200, 'waffle': 200, 'pancake': 130,
177
+ 'funnel cake': 250, 'churro': 90, 'beignet': 100,
178
+ 'baklava': 100, 'cannoli': 80, 'macaron': 15,
179
+ 'cupcake': 90, 'scone': 100, 'danish': 100,
180
+ 'eclair': 80, 'donut holes': 100, 'cinnamon roll': 150,
181
+ 'sticky bun': 130, 'apple fritter': 130, 'bear claw': 100,
182
+ 'dessert': 150,
183
+ # Platters
184
+ 'combo meal': 900, 'value meal': 800,
185
+ 'burger combo': 900, 'chicken combo': 900,
186
+ 'pizza combo': 1000, 'family meal': 2000,
187
+ 'party platter': 1500, 'feast': 2000,
188
+ 'dinner for two': 1500, 'dinner for four': 3000,
189
+ 'appetizer sampler': 600, 'sampler platter': 600,
190
+ 'wing platter': 600, 'sampler': 500, 'shareable': 500,
191
+ 'platter': 600,
192
+ # Beverages as menu items
193
+ 'soft drink': 400, 'soda': 400, 'cola': 400,
194
+ 'diet soda': 400, 'root beer': 400, 'ginger ale': 400,
195
+ 'cream soda': 400, 'lemon lime soda': 400,
196
+ 'iced tea': 500, 'sweet tea': 500, 'lemonade': 400,
197
+ 'fruit punch': 400, 'orange juice': 300,
198
+ 'apple juice': 300, 'cranberry juice': 300,
199
+ 'grapefruit juice': 300, 'tomato juice': 300,
200
+ 'milk': 300, 'chocolate milk': 350,
201
+ 'hot chocolate': 350, 'coffee': 350,
202
+ 'hot coffee': 350, 'iced coffee': 450,
203
+ 'latte': 350, 'cappuccino': 250, 'espresso': 40,
204
+ 'americano': 300, 'mocha': 350, 'macchiato': 250,
205
+ 'frappuccino': 450, 'cold brew': 450,
206
+ 'matcha latte': 400, 'chai latte': 400,
207
+ 'bubble tea': 500, 'milk tea': 400,
208
+ 'smoothie': 350, 'protein shake': 400,
209
+ 'meal replacement shake': 500,
210
+ 'energy drink': 350, 'sports drink': 500,
211
+ 'water bottle': 500, 'sparkling water': 400,
212
+ 'flavored water': 500, 'kombucha': 450,
213
+ 'kefir': 300,
214
+ # Generic
215
+ 'meal': 500, 'dish': 400, 'portion': 300,
216
+ 'appetizer': 200, 'entree': 450, 'main course': 500,
217
+ 'side dish': 150, 'side': 150,
218
+ }
219
+
220
+ # Size modifiers - only apply to portion size words, not item name words
221
+ SIZE_MODIFIERS = {
222
+ 'small': 0.6, 'mini': 0.4, 'junior': 0.5, 'kids': 0.5, 'child': 0.5,
223
+ 'medium': 1.0, 'regular': 1.0, 'standard': 1.0, 'normal': 1.0,
224
+ 'large': 1.5, 'big': 1.4, 'jumbo': 1.8, 'extra large': 1.8, 'xl': 1.8, 'xxl': 2.2,
225
+ 'double': 2.0, 'triple': 3.0, 'family': 2.5, 'party': 3.0,
226
+ 'supreme': 1.3, 'deluxe': 1.3, 'premium': 1.2, 'loaded': 1.3,
227
+ 'half': 0.5, 'full': 1.0, 'whole': 1.0, 'quarter': 0.25,
228
+ }
229
+
230
+
231
+ def extract_explicit_weight(text):
232
+ """Extract weight from explicit mentions like '500g', '2 liter', '12 oz'."""
233
+ text_lower = text.lower()
234
+ weights_found = []
235
+
236
+ # Pattern: number + unit (g, ml, kg, oz, lb, etc.)
237
+ patterns = [
238
+ r'(\d+(?:\.\d+)?)\s*(ml|milliliter|milliliters|millilitre|millilitres|cl|centiliter|centiliters|dl|deciliter|deciliters)',
239
+ r'(\d+(?:\.\d+)?)\s*(l|liter|liters|litre|litres)',
240
+ r'(\d+(?:\.\d+)?)\s*(g|gram|grams|gr)\b',
241
+ r'(\d+(?:\.\d+)?)\s*(kg|kilogram|kilograms|kilo)\b',
242
+ r'(\d+(?:\.\d+)?)\s*(mg|milligram|milligrams)',
243
+ r'(\d+(?:\.\d+)?)\s*(oz|ounce|ounces)',
244
+ r'(\d+(?:\.\d+)?)\s*(lb|lbs|pound|pounds)',
245
+ r'(\d+(?:\.\d+)?)\s*(fl\s*oz|fluid\s*ounce|fluid\s*ounces)',
246
+ r'(\d+(?:\.\d+)?)\s*(pt|pint|pints)',
247
+ r'(\d+(?:\.\d+)?)\s*(qt|quart|quarts)',
248
+ r'(\d+(?:\.\d+)?)\s*(gal|gallon|gallons)',
249
+ r'(\d+(?:\.\d+)?)\s*(cup|cups)',
250
+ ]
251
+
252
+ for pattern in patterns:
253
+ for match in re.finditer(pattern, text_lower):
254
+ val = float(match.group(1))
255
+ unit_str = match.group(2).strip()
256
+ for unit_key, conversion in UNIT_MAP.items():
257
+ if unit_str.startswith(unit_key):
258
+ weights_found.append(val * conversion)
259
+ break
260
+
261
+ # Extract pack size for multiplier
262
+ pack_match = re.search(r'(\d+)\s*pack(?:age|et)?s?\b', text_lower)
263
+ pack_size = int(pack_match.group(1)) if pack_match else 1
264
+
265
+ if weights_found:
266
+ # Use the largest weight found (usually the package weight)
267
+ return max(weights_found) * pack_size
268
+
269
+ return None
270
+
271
+
272
+ def get_knowledge_base_weight(text, item_type):
273
+ """Get weight from knowledge base for known food/grocery items."""
274
+ text_lower = text.lower()
275
+
276
+ if item_type == 'grocery':
277
+ # First try explicit weight
278
+ explicit = extract_explicit_weight(text)
279
+ if explicit is not None:
280
+ return explicit
281
+
282
+ # Then try grocery knowledge base for known brands
283
+ best_match = None
284
+ best_weight = None
285
+ best_len = 0
286
+ for item_name, weight in GROCERY_KB.items():
287
+ if item_name in text_lower:
288
+ if len(item_name) > best_len:
289
+ best_match = item_name
290
+ best_weight = weight
291
+ best_len = len(item_name)
292
+
293
+ # Apply pack multiplier
294
+ pack_match = re.search(r'(\d+)\s*pack', text_lower)
295
+ if pack_match and best_weight:
296
+ pack_size = int(pack_match.group(1))
297
+ # Estimate: pack_size * single_unit_weight * 0.9 (packaging savings)
298
+ return pack_size * best_weight * 0.95
299
+
300
+ if best_weight:
301
+ return best_weight
302
+
303
+ elif item_type == 'menu_item':
304
+ # First try explicit weight
305
+ explicit = extract_explicit_weight(text)
306
+ if explicit is not None:
307
+ return explicit
308
+
309
+ # Find best matching item from portion KB
310
+ best_match = None
311
+ best_weight = None
312
+ best_len = 0
313
+ for item_name, weight in PORTION_KB.items():
314
+ if item_name in text_lower:
315
+ if len(item_name) > best_len:
316
+ best_match = item_name
317
+ best_weight = weight
318
+ best_len = len(item_name)
319
+
320
+ if best_weight and best_match:
321
+ # Don't apply modifiers that are already part of the matched item name words
322
+ best_words = set(best_match.split())
323
+ multiplier = 1.0
324
+ for mod, mult in SIZE_MODIFIERS.items():
325
+ if mod in text_lower and mod not in best_words:
326
+ multiplier = max(multiplier, mult)
327
+ return best_weight * multiplier
328
+
329
+ return None
330
+
331
+
332
+ class HybridWeightPredictor:
333
+ """Hybrid predictor: explicit extraction → KB lookup → ML fallback."""
334
+ def __init__(self, ml_predictor=None):
335
+ self.ml_predictor = ml_predictor
336
+
337
+ def predict(self, text, item_type=None):
338
+ """Predict weight using hybrid approach."""
339
+ # Auto-detect item type
340
+ if item_type is None:
341
+ if text.startswith("[MENU_ITEM]"):
342
+ item_type = "menu_item"
343
+ elif text.startswith("[GROCERY]"):
344
+ item_type = "grocery"
345
+ elif text.startswith("[NON_FOOD]"):
346
+ item_type = "non_food"
347
+ else:
348
+ item_type = "grocery"
349
+
350
+ # Step 1: Explicit weight extraction (works for all types)
351
+ explicit_weight = extract_explicit_weight(text)
352
+ if explicit_weight is not None:
353
+ return explicit_weight
354
+
355
+ # Step 2: Knowledge base lookup
356
+ kb_weight = get_knowledge_base_weight(text, item_type)
357
+ if kb_weight is not None:
358
+ return kb_weight
359
+
360
+ # Step 3: ML fallback
361
+ if self.ml_predictor is not None:
362
+ return self.ml_predictor.predict(text, item_type)
363
+
364
+ # Step 4: Default
365
+ return {'menu_item': 300, 'grocery': 400, 'non_food': 500}.get(item_type, 300)
366
+
367
+ def predict_single(self, text, item_type=None):
368
+ return self.predict(text, item_type)
369
+
370
+
371
+ def build_hybrid_predictor(ml_model_path="/app/weight_predictor_v5/unified_predictor.pkl"):
372
+ try:
373
+ ml_predictor = joblib.load(ml_model_path)
374
+ except Exception as e:
375
+ print(f"Warning: Could not load ML model: {e}")
376
+ ml_predictor = None
377
+ return HybridWeightPredictor(ml_predictor)
378
+
379
+
380
+ if __name__ == "__main__":
381
+ predictor = build_hybrid_predictor()
382
+
383
+ test_cases = [
384
+ # FMCG with explicit sizes
385
+ ("[GROCERY] coca cola can 330ml", "grocery"),
386
+ ("[GROCERY] coca cola bottle 2 liter", "grocery"),
387
+ ("[GROCERY] pepsi 1 liter bottle", "grocery"),
388
+ ("[GROCERY] kelloggs corn flakes 500g", "grocery"),
389
+ ("[GROCERY] oreo cookies 154g", "grocery"),
390
+ ("[GROCERY] heinz ketchup 570ml", "grocery"),
391
+ ("[GROCERY] mars bar 51g", "grocery"),
392
+ ("[GROCERY] snickers 2 pack 96g", "grocery"),
393
+ ("[GROCERY] red bull 4 pack", "grocery"),
394
+ ("[GROCERY] tide laundry detergent 1.5kg", "grocery"),
395
+ ("[GROCERY] coca cola", "grocery"), # no size - should default to 330g
396
+ ("[GROCERY] pepsi", "grocery"), # no size - should default to 330g
397
+ ("[GROCERY] oreo", "grocery"), # no size - should default to 154g
398
+ # Menu items
399
+ ("[MENU_ITEM] large pizza", "menu_item"),
400
+ ("[MENU_ITEM] cheeseburger", "menu_item"),
401
+ ("[MENU_ITEM] double cheeseburger", "menu_item"),
402
+ ("[MENU_ITEM] big mac", "menu_item"),
403
+ ("[MENU_ITEM] french fries", "menu_item"),
404
+ ("[MENU_ITEM] large fries", "menu_item"),
405
+ ("[MENU_ITEM] chicken nuggets", "menu_item"),
406
+ ("[MENU_ITEM] burrito", "menu_item"),
407
+ ("[MENU_ITEM] caesar salad", "menu_item"),
408
+ ("[MENU_ITEM] caesar salad large", "menu_item"),
409
+ ("[MENU_ITEM] pho", "menu_item"),
410
+ ("[MENU_ITEM] ramen", "menu_item"),
411
+ ("[MENU_ITEM] sushi platter", "menu_item"),
412
+ ("[MENU_ITEM] medium pizza", "menu_item"),
413
+ ("[MENU_ITEM] personal pizza", "menu_item"),
414
+ ("[MENU_ITEM] combo meal", "menu_item"),
415
+ ("[MENU_ITEM] milkshake", "menu_item"),
416
+ ("[MENU_ITEM] iced coffee", "menu_item"),
417
+ ("[MENU_ITEM] family meal", "menu_item"),
418
+ ("[MENU_ITEM] sliders", "menu_item"),
419
+ # Non-food
420
+ ("[NON_FOOD] laptop computer", "non_food"),
421
+ ("[NON_FOOD] water bottle", "non_food"),
422
+ ]
423
+
424
+ print("=== Hybrid Weight Predictor Tests ===\n")
425
+ for text, item_type in test_cases:
426
+ weight = predictor.predict(text, item_type)
427
+ print(f" {text:55s} -> {weight:8.1f}g")