# weight-predictor / hybrid_weight_predictor.py
# Uploaded by ZZandro with huggingface_hub (commit 9fb21de).
"""
Hybrid weight predictor for food delivery platforms.
Combines explicit weight extraction, rule-based knowledge base, and ML fallback.
"""
import re
import numpy as np
import joblib
# Conversion factors from a unit spelling (as it appears in item text) to
# grams. Mass units convert directly. Volume units convert to millilitres,
# and because 'ml' and 'g' share the factor 1.0, liquids are effectively
# treated as weighing ~1 g per ml. The 'fl oz', 'pt', 'qt', 'gal' and 'cup'
# factors are US customary measures (e.g. 1 US gallon = 3785.41 ml).
UNIT_MAP = {
'ml': 1.0, 'milliliter': 1.0, 'milliliters': 1.0, 'millilitre': 1.0, 'millilitres': 1.0,
'l': 1000.0, 'liter': 1000.0, 'liters': 1000.0, 'litre': 1000.0, 'litres': 1000.0,
'cl': 10.0, 'centiliter': 10.0, 'centiliters': 10.0,
'dl': 100.0, 'deciliter': 100.0, 'deciliters': 100.0,
'g': 1.0, 'gram': 1.0, 'grams': 1.0, 'gr': 1.0,
'kg': 1000.0, 'kilogram': 1000.0, 'kilograms': 1000.0, 'kilo': 1000.0,
'mg': 0.001, 'milligram': 0.001, 'milligrams': 0.001,
# Avoirdupois ounce / pound.
'oz': 28.3495, 'ounce': 28.3495, 'ounces': 28.3495,
'lb': 453.592, 'lbs': 453.592, 'pound': 453.592, 'pounds': 453.592,
# US customary volumes, expressed in ml.
'fl oz': 29.5735, 'fluid ounce': 29.5735, 'fluid ounces': 29.5735,
'pt': 473.176, 'pint': 473.176, 'pints': 473.176,
'qt': 946.353, 'quart': 946.353, 'quarts': 946.353,
'gal': 3785.41, 'gallon': 3785.41, 'gallons': 3785.41,
'cup': 240.0, 'cups': 240.0,
}
# Typical single-unit size of known grocery brands/products, used by
# get_knowledge_base_weight only when the text carries no explicit quantity.
# Keys are lower-case and matched as substrings of the lower-cased text. The
# longest matching key wins, so 'honey nut cheerios' beats 'cheerios' and
# 'dish soap' beats 'soap'.
GROCERY_KB = {
# Beverages (per single unit if size not specified). Values look like
# container volumes in ml (330 = standard can), which are scored as grams.
'coca cola': 330, 'coke': 330, 'pepsi': 330, 'sprite': 330, 'fanta': 330,
'7up': 330, 'mountain dew': 500, 'dr pepper': 355,
'red bull': 250, 'monster': 500, 'gatorade': 500, 'powerade': 500,
'aquafina': 500, 'dasani': 500, 'evian': 500, 'smartwater': 700,
'volvic': 500, 'perrier': 330, 'san pellegrino': 500,
'minute maid': 250, 'tropicana': 250, 'honest tea': 500, 'fuze tea': 500,
'lipton iced tea': 500, 'snapple': 473, 'vitaminwater': 500,
'kombucha': 450, 'kefir': 250,
'heineken': 330, 'budweiser': 355, 'corona': 355, 'stella artois': 330,
'carlsberg': 330, 'guinness': 440,
# Cereals (per typical box if size not specified), in grams
'kelloggs corn flakes': 500, 'kelloggs special k': 500,
'kelloggs rice krispies': 510, 'kelloggs coco pops': 350,
'kelloggs frosted flakes': 425, 'cheerios': 340, 'lucky charms': 326,
'cinnamon toast crunch': 340, 'honey nut cheerios': 306,
'trix': 285, 'wheaties': 450,
# Snacks (typical single pack / bar), in grams
'oreo': 154, 'oreos': 154, 'pringles': 165, 'doritos': 175,
'lays': 175, 'cheetos': 200, 'ruffles': 200,
'kit kat': 42, 'snickers': 48, 'mars bar': 51, 'twix': 50,
'm and m': 42, 'm and ms': 42, 'maltesers': 37, 'skittles': 45,
'toblerone': 100, 'milka': 100, 'cadbury': 110,
'haribo': 100, 'gummy bears': 100,
# Personal care (if no size specified)
'toothpaste': 100, 'shampoo': 400, 'conditioner': 400,
'body wash': 500, 'deodorant': 50, 'soap': 100,
# Cleaning
'laundry detergent': 1500, 'dish soap': 500, 'bleach': 750,
# Baby
'baby formula': 800, 'baby food': 120, 'diapers': 800,
}
# Typical single-serving weight in grams for restaurant / fast-food menu
# items, keyed by lower-case item name. get_knowledge_base_weight uses the
# longest key found in the text, then scales it with SIZE_MODIFIERS.
# Every key appears exactly once: items that fit two categories are listed
# only under the first one. A duplicate key in a dict literal silently
# replaces the earlier value, so a later edit to one copy is easy to lose.
PORTION_KB = {
    # Pizza
    'small pizza': 500, 'personal pizza': 400, 'medium pizza': 800,
    'large pizza': 1200, 'extra large pizza': 1700, 'xl pizza': 1700,
    'pizza slice': 150, 'pizza': 800,
    # Burgers
    'hamburger': 150, 'cheeseburger': 220, 'double cheeseburger': 350,
    'big mac': 219, 'whopper': 291, 'quarter pounder': 220,
    'double quarter pounder': 350, 'mushroom swiss burger': 320,
    'bbq bacon burger': 350, 'blue cheese burger': 320,
    'veggie burger': 250, 'beyond burger': 250, 'impossible burger': 250,
    'turkey burger': 220, 'slider': 100, 'sliders': 100,
    'burger': 220,
    # Chicken
    'chicken sandwich': 280, 'crispy chicken sandwich': 300,
    'grilled chicken sandwich': 280, 'spicy chicken sandwich': 290,
    'chicken burger': 250, 'fried chicken': 300,
    'chicken nuggets': 180, 'chicken tenders': 220, 'popcorn chicken': 180,
    'chicken wings': 300, 'buffalo wings': 300, 'boneless wings': 250,
    'chicken strips': 220, 'rotisserie chicken': 600,
    # Sandwiches
    'sub sandwich': 450, 'club sandwich': 350, 'blt sandwich': 300,
    'tuna sandwich': 280, 'turkey sandwich': 300, 'ham sandwich': 280,
    'roast beef sandwich': 300, 'grilled cheese sandwich': 220,
    'reuben sandwich': 450, 'pastrami sandwich': 400,
    'meatball sub': 450, 'philly cheesesteak': 450,
    'italian sub': 450, 'sandwich': 300,
    # Wraps & Mexican
    'wrap': 280, 'burrito': 500, 'taco': 150, 'soft taco': 180,
    'hard shell taco': 130, 'crunchy taco': 120, 'taco supreme': 180,
    'quesadilla': 350, 'nachos': 300, 'nachos supreme': 400,
    'nachos bellgrande': 450, 'enchilada': 320, 'fajita': 400,
    'chimichanga': 450, 'tostada': 200, 'churros': 100,
    'tamales': 200, 'bowl': 450,
    # Sides
    'french fries': 150, 'small fries': 100, 'medium fries': 180,
    'large fries': 300, 'sweet potato fries': 200, 'curly fries': 150,
    'waffle fries': 180, 'steak fries': 200, 'onion rings': 180,
    'mozzarella sticks': 180, 'jalapeno poppers': 130,
    'loaded fries': 400, 'chili cheese fries': 450, 'tater tots': 150,
    'hash browns': 150, 'potato wedges': 200,
    # Salads
    'side salad': 120, 'caesar salad': 250, 'garden salad': 150,
    'greek salad': 280, 'cobb salad': 350, 'chef salad': 350,
    'taco salad': 400, 'chicken salad': 280, 'pasta salad': 280,
    'potato salad': 220, 'coleslaw': 120, 'fruit salad': 180,
    'salad': 200,
    # Breakfast ('hash browns' is listed under Sides)
    'breakfast burrito': 450, 'breakfast sandwich': 250,
    'breakfast platter': 550, 'pancakes': 300, 'waffles': 300,
    'french toast': 300, 'omelette': 280, 'scrambled eggs': 200,
    'fried eggs': 180, 'bacon': 80, 'sausage': 100,
    'sausage patty': 90, 'sausage links': 100,
    'home fries': 180, 'biscuits and gravy': 350,
    'english muffin': 100, 'bagel': 110, 'croissant': 90,
    'cinnamon roll': 150, 'donut': 80, 'muffin': 110,
    'breakfast bowl': 400, 'eggs': 150,
    # Pasta & Italian
    'pasta': 350, 'spaghetti': 400, 'penne pasta': 380, 'fettuccine': 350,
    'mac and cheese': 350, 'lasagna': 450, 'chicken parmigiana': 450,
    'chicken parmesan': 450, 'chicken alfredo': 450,
    'chicken marsala': 380, 'chicken piccata': 380,
    # Asian
    'ramen': 550, 'pho': 550, 'pad thai': 450, 'lo mein': 450,
    'chow mein': 450, 'fried rice': 350, 'stir fry': 450,
    'curry': 450, 'beef stew': 500, 'dumplings': 250,
    'pot stickers': 250, 'egg rolls': 120, 'spring rolls': 120,
    'bao bun': 130, 'char siu': 250, 'kung pao chicken': 400,
    'general tso chicken': 450, 'orange chicken': 400,
    'sweet and sour chicken': 400, 'sesame chicken': 400,
    'beef and broccoli': 400, 'mongolian beef': 400,
    'kung pao shrimp': 350, 'mapo tofu': 350, 'hot pot': 600,
    'bibimbap': 450, 'bulgogi': 350, 'kimchi': 150,
    'bento box': 550, 'teriyaki chicken': 350, 'tonkatsu': 350,
    'udon': 500, 'soba': 400, 'sushi roll': 180, 'sashimi': 130,
    'tempura': 250, 'sushi platter': 700, 'sushi': 180,
    # Soups
    'soup': 300, 'cup of soup': 250, 'bowl of soup': 350,
    'tomato soup': 280, 'chicken noodle soup': 300,
    'clam chowder': 350, 'lobster bisque': 350,
    'french onion soup': 320, 'minestrone': 320,
    'lentil soup': 320, 'vegetable soup': 280,
    'miso soup': 180, 'wonton soup': 300,
    'hot and sour soup': 280, 'egg drop soup': 250,
    # Meat
    'beef steak': 300, 'sirloin steak': 300, 'ribeye steak': 380,
    'new york strip': 350, 'filet mignon': 250, 'pork chop': 280,
    'pork ribs': 400, 'bbq ribs': 450, 'meatloaf': 380,
    'spaghetti and meatballs': 500, 'meatballs': 280,
    'eggplant parmesan': 400, 'fish and chips': 450,
    'fish taco': 250, 'grilled salmon': 220, 'salmon fillet': 200,
    'shrimp': 200, 'fried shrimp': 250, 'calamari': 250,
    'crab cakes': 250, 'shrimp cocktail': 200,
    # Appetizers ('mozzarella sticks', 'jalapeno poppers' and 'onion rings'
    # are listed under Sides, 'calamari' under Meat)
    'wings': 300,
    'breadsticks': 130, 'garlic bread': 150,
    'cheese bread': 160, 'spinach dip': 250, 'artichoke dip': 250,
    'queso dip': 200, 'guacamole': 200, 'salsa': 150,
    'bruschetta': 130, 'caprese': 200, 'antipasto': 250,
    'olives': 120, 'deviled eggs': 120, 'stuffed mushrooms': 200,
    'crab rangoon': 200, 'edamame': 200,
    'gyoza': 200, 'hummus': 200, 'falafel': 200, 'samosa': 120,
    'loaded potato skins': 250, 'stuffed peppers': 250,
    # Desserts ('cinnamon roll' is listed under Breakfast)
    'ice cream': 150, 'ice cream sundae': 300,
    'milkshake': 450, 'thick shake': 500, 'float': 400,
    'cookie': 50, 'chocolate chip cookie': 50,
    'brownie': 110, 'blondie': 100, 'cake': 150,
    'cake slice': 150, 'cheesecake': 150, 'pie slice': 150,
    'apple pie': 150, 'pudding': 200, 'flan': 200,
    'tiramisu': 180, 'creme brulee': 150, 'mousse': 130,
    'parfait': 250, 'fruit cup': 150, 'yogurt parfait': 250,
    'smoothie': 350, 'acai bowl': 350, 'frozen yogurt': 180,
    'sorbet': 150, 'gelato': 180, 'affogato': 180,
    'crepe': 200, 'waffle': 200, 'pancake': 130,
    'funnel cake': 250, 'churro': 90, 'beignet': 100,
    'baklava': 100, 'cannoli': 80, 'macaron': 15,
    'cupcake': 90, 'scone': 100, 'danish': 100,
    'eclair': 80, 'donut holes': 100,
    'sticky bun': 130, 'apple fritter': 130, 'bear claw': 100,
    'dessert': 150,
    # Platters
    'combo meal': 900, 'value meal': 800,
    'burger combo': 900, 'chicken combo': 900,
    'pizza combo': 1000, 'family meal': 2000,
    'party platter': 1500, 'feast': 2000,
    'dinner for two': 1500, 'dinner for four': 3000,
    'appetizer sampler': 600, 'sampler platter': 600,
    'wing platter': 600, 'sampler': 500, 'shareable': 500,
    'platter': 600,
    # Beverages as menu items ('smoothie' is listed under Desserts)
    'soft drink': 400, 'soda': 400, 'cola': 400,
    'diet soda': 400, 'root beer': 400, 'ginger ale': 400,
    'cream soda': 400, 'lemon lime soda': 400,
    'iced tea': 500, 'sweet tea': 500, 'lemonade': 400,
    'fruit punch': 400, 'orange juice': 300,
    'apple juice': 300, 'cranberry juice': 300,
    'grapefruit juice': 300, 'tomato juice': 300,
    'milk': 300, 'chocolate milk': 350,
    'hot chocolate': 350, 'coffee': 350,
    'hot coffee': 350, 'iced coffee': 450,
    'latte': 350, 'cappuccino': 250, 'espresso': 40,
    'americano': 300, 'mocha': 350, 'macchiato': 250,
    'frappuccino': 450, 'cold brew': 450,
    'matcha latte': 400, 'chai latte': 400,
    'bubble tea': 500, 'milk tea': 400,
    'protein shake': 400,
    'meal replacement shake': 500,
    'energy drink': 350, 'sports drink': 500,
    'water bottle': 500, 'sparkling water': 400,
    'flavored water': 500, 'kombucha': 450,
    'kefir': 300,
    # Generic
    'meal': 500, 'dish': 400, 'portion': 300,
    'appetizer': 200, 'entree': 450, 'main course': 500,
    'side dish': 150, 'side': 150,
}
# Portion multipliers keyed by size word or phrase. get_knowledge_base_weight
# applies them to a matched PORTION_KB weight when the word occurs in the
# text but is not already part of the matched item name. For example,
# 'large pizza' is not enlarged a second time by 'large'.
SIZE_MODIFIERS = {
'small': 0.6, 'mini': 0.4, 'junior': 0.5, 'kids': 0.5, 'child': 0.5,
'medium': 1.0, 'regular': 1.0, 'standard': 1.0, 'normal': 1.0,
'large': 1.5, 'big': 1.4, 'jumbo': 1.8, 'extra large': 1.8, 'xl': 1.8, 'xxl': 2.2,
'double': 2.0, 'triple': 3.0, 'family': 2.5, 'party': 3.0,
'supreme': 1.3, 'deluxe': 1.3, 'premium': 1.2, 'loaded': 1.3,
'half': 0.5, 'full': 1.0, 'whole': 1.0, 'quarter': 0.25,
}
# Explicit-quantity regexes: (compiled pattern, canonical UNIT_MAP key).
# Each pattern captures the number and requires the unit to end on a word
# boundary. As a result "2 large" is not read as "2 l", "1 lb" is not read
# as litres, and "6 cupcakes" is not read as cups. Mapping each pattern to
# one canonical key avoids prefix lookups, which used to send 'lb' to 'l'
# and 'gal' to 'g'.
_EXPLICIT_UNIT_REGEXES = tuple(
    (re.compile(r'(\d+(?:\.\d+)?)\s*(?:' + spellings + r')\b'), unit_key)
    for spellings, unit_key in (
        (r'fl\.?\s*oz|fluid\s*ounces?', 'fl oz'),
        (r'ml|millilit(?:er|re)s?', 'ml'),
        (r'cl|centilit(?:er|re)s?', 'cl'),
        (r'dl|decilit(?:er|re)s?', 'dl'),
        (r'l|lit(?:er|re)s?', 'l'),
        (r'mg|milligrams?', 'mg'),
        (r'g|gr|grams?', 'g'),
        (r'kg|kilos?|kilograms?', 'kg'),
        (r'oz|ounces?', 'oz'),
        (r'lbs?|pounds?', 'lb'),
        (r'pt|pints?', 'pt'),
        (r'qt|quarts?', 'qt'),
        (r'gal|gallons?', 'gal'),
        (r'cups?', 'cup'),
    )
)

# "N pack", "N-pack", "N packs", "N package(s)", "N packet(s)".
_PACK_COUNT_RE = re.compile(r'(\d+)[\s-]*pack(?:age|et)?s?\b')


def extract_explicit_weight(text):
    """Extract a weight from explicit mentions like '500g', '2 liter', '12 oz'.

    Every ``<number> <unit>`` mention is converted to grams with ``UNIT_MAP``.
    Volume units convert to millilitres, which ``UNIT_MAP`` scores 1:1 with
    grams. The largest mention is taken as the package quantity and then
    multiplied by an "N pack" count if one is present.

    Args:
        text: Product or menu-item description (matched case-insensitively).

    Returns:
        The weight in grams as a float, or None when the text contains no
        recognisable ``<number> <unit>`` quantity.
    """
    text_lower = text.lower()
    weights_found = [
        float(match.group(1)) * UNIT_MAP[unit_key]
        for pattern, unit_key in _EXPLICIT_UNIT_REGEXES
        for match in pattern.finditer(text_lower)
    ]
    if not weights_found:
        return None

    pack_match = _PACK_COUNT_RE.search(text_lower)
    # A "0 pack" count is meaningless; treat it like no pack count at all.
    pack_size = (int(pack_match.group(1)) if pack_match else 0) or 1
    # Use the largest quantity found (usually the package rather than a
    # per-serving figure).
    return max(weights_found) * pack_size
def _kb_longest_match(text_lower, kb):
    """Return ``(name, weight)`` for the longest key of ``kb`` found in ``text_lower``.

    Keys are matched as plain substrings, so plurals such as 'tacos' still hit
    'taco'. On a length tie, the key that comes first in ``kb`` wins. Returns
    ``(None, None)`` when no key occurs in the text.
    """
    best_name, best_weight = None, None
    for name, weight in kb.items():
        if name in text_lower and (best_name is None or len(name) > len(best_name)):
            best_name, best_weight = name, weight
    return best_name, best_weight


def _contains_phrase(haystack, phrase):
    """Return True if ``phrase`` occurs in ``haystack`` as whole word(s)."""
    return re.search(r'\b' + re.escape(phrase) + r'\b', haystack) is not None


def _size_multiplier(text_lower, item_name):
    """Combine the SIZE_MODIFIERS mentioned in the text into one factor.

    A modifier counts only when both of these hold:
    - it appears as whole word(s) in the text ('big' no longer fires on
      "bigger");
    - it is not part of the matched KB item name, whose weight already
      reflects it ('large' in 'large pizza', 'extra large' in
      'extra large pizza').

    If any counted modifier enlarges the portion, the largest such factor is
    used. Otherwise the smallest one is used, so 'kids', 'half', 'small' etc.
    actually reduce the portion. With no counted modifier the factor is 1.0.
    """
    factors = [
        factor
        for modifier, factor in SIZE_MODIFIERS.items()
        if _contains_phrase(text_lower, modifier)
        and not _contains_phrase(item_name, modifier)
    ]
    if not factors:
        return 1.0
    largest = max(factors)
    # Upsizing takes precedence over downsizing when both are present.
    return largest if largest > 1.0 else min(factors)


def get_knowledge_base_weight(text, item_type):
    """Get weight from knowledge base for known food/grocery items.

    For both supported types, an explicit quantity in the text wins
    (``extract_explicit_weight``). Otherwise:

    * ``'grocery'``: the longest GROCERY_KB key found in the text gives the
      single-unit size. An "N pack" count multiplies it by N * 0.95.
    * ``'menu_item'``: the longest PORTION_KB key found in the text gives the
      base portion. It is scaled by SIZE_MODIFIERS words in the text (see
      ``_size_multiplier``).

    Args:
        text: Item description (matched case-insensitively).
        item_type: ``'grocery'`` or ``'menu_item'``. Any other value yields None.

    Returns:
        The estimated weight in grams, or None when nothing matched.
    """
    text_lower = text.lower()

    if item_type == 'grocery':
        explicit = extract_explicit_weight(text)
        if explicit is not None:
            return explicit
        _, unit_weight = _kb_longest_match(text_lower, GROCERY_KB)
        if unit_weight is None:
            return None
        # "N pack" / "N-pack" / "N packs" multiplier.
        pack_match = re.search(r'(\d+)[\s-]*pack(?:age|et)?s?\b', text_lower)
        pack_count = int(pack_match.group(1)) if pack_match else 0
        if pack_count > 0:
            # N single units, scaled by 0.95 as a heuristic allowance.
            return pack_count * unit_weight * 0.95
        return unit_weight

    if item_type == 'menu_item':
        explicit = extract_explicit_weight(text)
        if explicit is not None:
            return explicit
        item_name, base_weight = _kb_longest_match(text_lower, PORTION_KB)
        if item_name is None:
            return None
        return base_weight * _size_multiplier(text_lower, item_name)

    return None
class HybridWeightPredictor:
    """Hybrid predictor: explicit extraction → KB lookup → ML fallback.

    ``predict`` tries these stages in order and returns the first one that
    produces a value:

    1. an explicit quantity in the text (``extract_explicit_weight``);
    2. a knowledge-base lookup (``get_knowledge_base_weight``);
    3. the optional ML model passed to the constructor;
    4. a fixed per-item-type default.
    """

    def __init__(self, ml_predictor=None):
        # Optional fallback model: any object with ``predict(text, item_type)``.
        # None disables stage 3.
        self.ml_predictor = ml_predictor

    def predict(self, text, item_type=None):
        """Estimate the weight of the item described by ``text``.

        Args:
            text: Item description, optionally starting with a type tag:
                "[MENU_ITEM]", "[GROCERY]" or "[NON_FOOD]".
            item_type: "menu_item", "grocery" or "non_food". When None it is
                inferred from the tag, and untagged text counts as "grocery".

        Returns:
            The first non-None stage result. An ML result is returned as-is.
        """
        if item_type is None:
            tag_to_type = (
                ("[MENU_ITEM]", "menu_item"),
                ("[GROCERY]", "grocery"),
                ("[NON_FOOD]", "non_food"),
            )
            item_type = next(
                (kind for tag, kind in tag_to_type if text.startswith(tag)),
                "grocery",
            )

        # Stages 1 and 2: explicit quantity, then knowledge base.
        weight = extract_explicit_weight(text)
        if weight is None:
            weight = get_knowledge_base_weight(text, item_type)
        if weight is not None:
            return weight

        # Stage 4 when no model is available, otherwise stage 3.
        if self.ml_predictor is None:
            fallback_by_type = {'menu_item': 300, 'grocery': 400, 'non_food': 500}
            return fallback_by_type.get(item_type, 300)
        return self.ml_predictor.predict(text, item_type)

    def predict_single(self, text, item_type=None):
        """Alias of ``predict`` for a single item."""
        return self.predict(text, item_type)
def build_hybrid_predictor(ml_model_path="/app/weight_predictor_v5/unified_predictor.pkl"):
    """Create a HybridWeightPredictor, loading the ML fallback if possible.

    Loading is best-effort: on any failure a warning is printed and the
    predictor is built without an ML stage.

    NOTE(review): ``joblib.load`` unpickles the file, which can execute
    arbitrary code. Only point ``ml_model_path`` at trusted artifacts.

    Args:
        ml_model_path: Path of the joblib-serialised model.

    Returns:
        A HybridWeightPredictor whose ``ml_predictor`` is the loaded object,
        or None if loading failed.
    """
    ml_model = None
    try:
        ml_model = joblib.load(ml_model_path)
    except Exception as exc:
        print(f"Warning: Could not load ML model: {exc}")
    return HybridWeightPredictor(ml_model)
if __name__ == "__main__":
    predictor = build_hybrid_predictor()

    # FMCG with explicit sizes, followed by brand-only names that fall back to
    # the GROCERY_KB single-unit size.
    grocery_texts = [
        "[GROCERY] coca cola can 330ml",
        "[GROCERY] coca cola bottle 2 liter",
        "[GROCERY] pepsi 1 liter bottle",
        "[GROCERY] kelloggs corn flakes 500g",
        "[GROCERY] oreo cookies 154g",
        "[GROCERY] heinz ketchup 570ml",
        "[GROCERY] mars bar 51g",
        "[GROCERY] snickers 2 pack 96g",
        "[GROCERY] red bull 4 pack",
        "[GROCERY] tide laundry detergent 1.5kg",
        "[GROCERY] coca cola",  # no size: GROCERY_KB gives 330
        "[GROCERY] pepsi",      # no size: GROCERY_KB gives 330
        "[GROCERY] oreo",       # no size: GROCERY_KB gives 154
    ]
    # Restaurant menu items resolved through PORTION_KB (+ SIZE_MODIFIERS).
    menu_texts = [
        "[MENU_ITEM] large pizza",
        "[MENU_ITEM] cheeseburger",
        "[MENU_ITEM] double cheeseburger",
        "[MENU_ITEM] big mac",
        "[MENU_ITEM] french fries",
        "[MENU_ITEM] large fries",
        "[MENU_ITEM] chicken nuggets",
        "[MENU_ITEM] burrito",
        "[MENU_ITEM] caesar salad",
        "[MENU_ITEM] caesar salad large",
        "[MENU_ITEM] pho",
        "[MENU_ITEM] ramen",
        "[MENU_ITEM] sushi platter",
        "[MENU_ITEM] medium pizza",
        "[MENU_ITEM] personal pizza",
        "[MENU_ITEM] combo meal",
        "[MENU_ITEM] milkshake",
        "[MENU_ITEM] iced coffee",
        "[MENU_ITEM] family meal",
        "[MENU_ITEM] sliders",
    ]
    # Non-food items: no KB stage, so the ML model or default decides.
    non_food_texts = [
        "[NON_FOOD] laptop computer",
        "[NON_FOOD] water bottle",
    ]

    test_cases = (
        [(case, "grocery") for case in grocery_texts]
        + [(case, "menu_item") for case in menu_texts]
        + [(case, "non_food") for case in non_food_texts]
    )

    print("=== Hybrid Weight Predictor Tests ===\n")
    for case_text, case_type in test_cases:
        grams = predictor.predict(case_text, case_type)
        print(" {:55s} -> {:8.1f}g".format(case_text, grams))