Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import numpy as np | |
| import gradio as gr | |
| # Optional offline fallback embeddings | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| # Try to import sentence-transformers, but we’ll fall back if it can’t download | |
| try: | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| _HAS_SBERT = True | |
| except Exception: | |
| _HAS_SBERT = False | |
| from datasets import load_dataset # datasets worked for you per logs | |
| # ======================== | |
| # Config | |
| # ======================== | |
| DATASET_ID = "motimmom/cocktails_clean_nobrand" | |
| EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2" | |
| FLAVOR_BOOST = 0.20 | |
| # Hardcoded background image | |
| BACKGROUND_IMAGE_URL = "https://huggingface.co/spaces/OGOGOG/Bartender-AI/resolve/main/bar.jpg" | |
| # If dataset is private, add Space secret HF_TOKEN (read scope) | |
| HF_READ_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") | |
| load_kwargs = {} | |
| if HF_READ_TOKEN: | |
| load_kwargs["token"] = HF_READ_TOKEN | |
| load_kwargs["use_auth_token"] = HF_READ_TOKEN | |
# ========================
# Base & Flavor tagging rules
# ========================
# Base-spirit regexes, matched against lower-cased text. Key order matters:
# tag_base() returns the FIRST key with a matching pattern, so e.g. a drink
# mentioning both vodka and gin is tagged "vodka".
BASE_SPIRITS = dict(
    vodka=[r"\bvodka\b"],
    gin=[r"\bgin\b"],
    rum=[r"\brum\b", r"\bwhite rum\b", r"\bdark rum\b"],
    tequila=[r"\btequila\b"],
    whiskey=[r"\bwhisk(?:e|)y\b", r"\bbourbon\b", r"\bscotch\b", r"\brye\b"],
    mezcal=[r"\bmezcal\b"],
    brandy=[r"\bbrandy\b", r"\bcognac\b"],
    vermouth=[r"\bvermouth\b"],
    other=[
        r"\btriple sec\b", r"\bliqueur\b", r"\bcointreau\b",
        r"\baperol\b", r"\bcampari\b",
    ],
)

# Flavor regexes; a document receives every flavor with at least one match.
# Some patterns are intentionally unanchored prefixes (e.g. "cranberr",
# "\bbitter") so plurals and variants also match.
FLAVORS = dict(
    citrus=[r"lime", r"lemon", r"grapefruit", r"orange", r"\bcitrus\b"],
    sweet=[
        r"simple syrup", r"\bsugar\b", r"\bhoney\b", r"\bagave\b",
        r"\bmaple\b", r"\bgrenadine\b", r"\bvanilla\b", r"\bsweet\b",
    ],
    sour=[r"\bsour\b", r"lemon juice", r"lime juice", r"\bacid\b"],
    bitter=[r"\bbitter", r"\bamaro\b", r"\bcampari\b", r"\baperol\b"],
    smoky=[r"\bsmoky\b", r"\bsmoked\b", r"\bmezcal\b", r"\bpeated\b"],
    spicy=[r"\bspicy\b", r"\bchili\b", r"\bginger\b", r"\bjalapeño\b", r"\bcayenne\b"],
    herbal=[
        r"\bmint\b", r"\bbasil\b", r"\brosemary\b", r"\bthyme\b",
        r"\bherb", r"\bchartreuse\b",
    ],
    fruity=[
        r"pineapple", r"cranberr", r"strawberr", r"mango",
        r"passion", r"peach", r"\bfruit",
    ],
    creamy=[r"\bcream\b", r"coconut cream", r"\begg white\b", r"\bcreamy\b"],
    floral=[r"\brose\b", r"\bviolet\b", r"\belderflower\b", r"\blavender\b", r"\bfloral\b"],
    refreshing=[
        r"soda water", r"\btonic\b", r"\bhighball\b", r"\bcollins\b",
        r"\bfizz\b", r"\brefreshing\b",
    ],
    boozy=[r"\bstirred\b", r"\bmartini\b", r"old fashioned", r"\bboozy\b", r"\bstrong\b"],
)

# Flavor names offered in the UI dropdown, in FLAVORS order.
FLAVOR_OPTIONS = [*FLAVORS]
| # ======================== | |
| # Robust extraction helpers (with measures) | |
| # ======================== | |
| def _clean(s): | |
| return s.strip() if isinstance(s, str) else "" | |
| def _norm_measure(s: str) -> str: | |
| if not isinstance(s, str): | |
| return "" | |
| s = re.sub(r"\s+", " ", s.strip()) | |
| s = re.sub(r"\bml\b", "ml", s, flags=re.I) | |
| s = re.sub(r"\boz\b", "oz", s, flags=re.I) | |
| s = re.sub(r"\btsp\b", "tsp", s, flags=re.I) | |
| s = re.sub(r"\btbsp\b", "tbsp", s, flags=re.I) | |
| return s | |
| def _join_measure_name(measure, name): | |
| m = _norm_measure(measure) | |
| n = name.strip() if isinstance(name, str) else "" | |
| if m and n: return f"{m} {n}" | |
| return n or m | |
| def _split_ingredient_blob(s): | |
| if not isinstance(s, str): return [] | |
| parts = re.split(r"[,\n;•\-–]+", s) | |
| return [p.strip() for p in parts if p.strip()] | |
| def _ingredients_from_any(val): | |
| if isinstance(val, str): | |
| lines = _split_ingredient_blob(val) | |
| tokens = [] | |
| for line in lines: | |
| parts = re.split(r"\s+", line) | |
| idx = 0 | |
| for i, p in enumerate(parts): | |
| if re.search(r"[A-Za-z]", p): | |
| idx = i; break | |
| tokens.append(" ".join(parts[idx:]).lower()) | |
| return lines, tokens | |
| return [], [] | |
| def _get_title(row, cols): | |
| for k in ["title","name","cocktail_name","drink","Drink","strDrink"]: | |
| if k in cols and _clean(row.get(k)): | |
| return _clean(row[k]) | |
| return "Untitled" | |
| def _get_ingredients_with_measures(row, cols): | |
| for key in ["ingredients","ingredients_raw","raw_ingredients","Raw_Ingredients","Raw Ingredients", | |
| "ingredient_list","ingredients_list"]: | |
| if key in cols and row.get(key) not in (None, "", [], {}): | |
| return _ingredients_from_any(row[key]) | |
| return [], [] | |
def tag_base(text):
    """Return the first BASE_SPIRITS key whose patterns match *text*, else "other".

    Matching is done on the lower-cased text, in BASE_SPIRITS iteration order.
    """
    lowered = text.lower()
    return next(
        (
            base
            for base, patterns in BASE_SPIRITS.items()
            if any(re.search(pattern, lowered) for pattern in patterns)
        ),
        "other",
    )


def tag_flavors(text):
    """Return every FLAVORS key with at least one pattern matching lower-cased *text*.

    The result preserves FLAVORS order and may be empty.
    """
    lowered = text.lower()
    tags = []
    for flavor_name, patterns in FLAVORS.items():
        if any(re.search(pattern, lowered) for pattern in patterns):
            tags.append(flavor_name)
    return tags
# ========================
# Load dataset & build docs
# ========================
ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
cols = ds.column_names


def _row_to_doc(row, columns):
    """Build the search document for one dataset row.

    The document's "text" (title plus ingredient tokens) is what gets embedded
    and what the base/flavor tags are derived from.
    """
    title = _get_title(row, columns)
    display, tokens = _get_ingredients_with_measures(row, columns)
    display = list(filter(None, display))
    tokens = list(filter(None, tokens))
    text = f"{title}\nIngredients: {', '.join(tokens)}"
    return {
        "title": title,
        "ingredients_display": display,
        "ingredients_tokens": tokens,
        "text": text,
        "base": tag_base(text),
        "flavors": tag_flavors(text),
    }


# One document per dataset row, in dataset order.
DOCS = [_row_to_doc(row, cols) for row in ds]
# ========================
# Embedding backends (SBERT -> TF-IDF fallback)
# ========================
class Embedder:
    """Embed documents and queries, preferring SBERT and falling back to TF-IDF.

    Attributes:
        mode: "sbert" when the SentenceTransformer model loaded, else "tfidf".
        encoder: the SentenceTransformer instance (sbert mode), otherwise None.
        vectorizer: the TfidfVectorizer (tfidf mode), otherwise None.
        doc_matrix: set by fit_docs(); float32 numpy embeddings with normalized
            rows (sbert) or the matrix returned by TfidfVectorizer.fit_transform.
    """

    def __init__(self):
        self.mode = "tfidf"
        self.encoder = None
        self.vectorizer = None
        self.doc_matrix = None
        # The import may have succeeded while the model download still fails.
        if _HAS_SBERT:
            try:
                self.encoder = SentenceTransformer(EMBED_MODEL)
            except Exception as err:
                print(f"[WARN] SBERT model load failed, falling back to TF-IDF. Reason: {err}")
            else:
                self.mode = "sbert"
        if self.mode == "tfidf":
            self.vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1)
        print(f"[INFO] Embedding mode: {self.mode}")

    def fit_docs(self, docs):
        """Embed (sbert) or fit-and-transform (tfidf) the corpus *docs* into doc_matrix."""
        if self.mode != "sbert":
            self.doc_matrix = self.vectorizer.fit_transform(docs)
            return
        self.doc_matrix = self.encoder.encode(
            docs, normalize_embeddings=True, convert_to_numpy=True
        ).astype("float32")

    def embed_query(self, q):
        """Return query *q* as a one-row matrix in the same space as doc_matrix."""
        if self.mode != "sbert":
            return self.vectorizer.transform([q])
        return self.encoder.encode(
            [q], normalize_embeddings=True, convert_to_numpy=True
        ).astype("float32")

    def scores(self, idxs, q_vec):
        """Return a 1-D array of similarities between the docs at *idxs* and *q_vec*."""
        subset = self.doc_matrix[idxs]
        if self.mode == "sbert":
            # Rows and query are L2-normalized, so the dot product is the cosine.
            return subset @ q_vec[0]
        return cosine_similarity(subset, q_vec)[:, 0]
# Texts to embed, aligned index-for-index with DOCS.
DOC_TEXTS = [doc["text"] for doc in DOCS]
# Module-level embedder shared by recommend(); fitted once at startup.
embedder = Embedder()
embedder.fit_docs(DOC_TEXTS)
| # ======================== | |
| # Pretty ingredient formatting | |
| # ======================== | |
| _MEASURE_RE = re.compile( | |
| r"^\s*(?P<meas>(?:\d+(\.\d+)?|\d+\s*/\s*\d+|\d+\s*\d*/\d+)\s*(?:ml|oz|tsp|tbsp)?|\d+\s*(?:ml|oz|tsp|tbsp)|(?:dash|dashes|drop|drops|barspoon)s?)\b[\s\-–:]*", | |
| flags=re.I | |
| ) | |
| def _split_measure_name_line(line: str): | |
| if not isinstance(line, str): return "", line | |
| m = _MEASURE_RE.match(line.strip()) | |
| if m: | |
| meas = _norm_measure(m.group("meas")) | |
| name = line[m.end():].strip() | |
| return meas, name or "" | |
| return "", line.strip() | |
| def _format_ingredients_markdown(lines): | |
| if not lines: | |
| return "—" | |
| formatted = [] | |
| for ln in lines: | |
| meas, name = _split_measure_name_line(ln) | |
| if meas and name: | |
| formatted.append(f"- **{meas}** — {name}") | |
| elif name: | |
| formatted.append(f"- {name}") | |
| else: | |
| formatted.append(f"- {ln}") | |
| return "\n".join(formatted) | |
# ========================
# Recommendation
# ========================
def recommend(base_alcohol_text, flavor, top_k=3):
    """Return Markdown describing the best-matching cocktails for a base and a flavor.

    The free-text base is mapped to a base tag with tag_base(). Candidates are
    the docs with that base, or every doc when the tag is "other" or nothing
    matches. Each candidate scores its embedding similarity to a synthetic
    query plus FLAVOR_BOOST when it is tagged with *flavor*.

    Args:
        base_alcohol_text: user-typed base spirit (may be empty or None).
        flavor: one of FLAVOR_OPTIONS; anything else returns a prompt message.
        top_k: number of results; converted with int() and clamped to >= 1.

    Returns:
        Markdown with one section per pick separated by horizontal rules, or
        "Please choose a flavor." / "No matches found.".
    """
    inferred_base = tag_base(base_alcohol_text or "")
    if flavor not in FLAVOR_OPTIONS:
        return "Please choose a flavor."

    candidates = [i for i, doc in enumerate(DOCS) if doc["base"] == inferred_base]
    if inferred_base == "other" or not candidates:
        candidates = list(range(len(DOCS)))

    query = f"Base spirit: {base_alcohol_text}. Flavor: {flavor}. Cocktail recipe."
    similarities = embedder.scores(candidates, embedder.embed_query(query))

    ranked = sorted(
        (
            (float(sim) + (FLAVOR_BOOST if flavor in DOCS[i]['flavors'] else 0.0), i)
            for sim, i in zip(similarities, candidates)
        ),
        reverse=True,
    )
    picks = ranked[:max(1, int(top_k))]
    if not picks:
        return "No matches found."

    sections = []
    for score, i in picks:
        doc = DOCS[i]
        ingredients_md = _format_ingredients_markdown(
            doc["ingredients_display"] or doc["ingredients_tokens"]
        )
        flavor_tags = ', '.join(doc['flavors']) or '—'
        meta = f"**Base:** {doc['base']} | **Flavor tags:** {flavor_tags} | **Score:** {score:.3f}"
        sections.append(
            f"### {doc['title']}\n"
            f"{meta}\n\n"
            f"**Ingredients:**\n{ingredients_md}"
        )
    return "\n\n---\n\n".join(sections)
# ========================
# CSS
# ========================
# Page stylesheet passed to gr.Blocks(css=...). This is an f-string: literal CSS
# braces are doubled ({{ }}) and {BACKGROUND_IMAGE_URL} is interpolated.
#  - html/body/#root/.gradio-container: transparent so the background shows through
#  - #app-bg: fixed, full-viewport, darkened background image at z-index -1
#  - .glass-card: translucent, blurred, rounded panel for the main column
#  - #title_md / #result_md: force white text for contrast on the dark background
CUSTOM_CSS = f"""
html, body, #root, .gradio-container {{
background: transparent !important;
}}
#app-bg {{
position: fixed;
inset: 0;
z-index: -1;
background-image: url('{BACKGROUND_IMAGE_URL}');
background-size: cover;
background-position: center;
filter: brightness(0.45);
}}
.glass-card {{
background: rgba(255, 255, 255, 0.08);
backdrop-filter: blur(6px);
border-radius: 14px;
padding: 18px;
border: 1px solid rgba(255, 255, 255, 0.12);
}}
#title_md, #result_md, #title_md *, #result_md * {{
color: #ffffff !important;
}}
"""
# ========================
# Gradio UI
# ========================
with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Empty element styled by the #app-bg rule as the full-page background.
    gr.HTML("<div id='app-bg'></div>")
    with gr.Column(elem_classes=["glass-card"]):
        gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor", elem_id="title_md")
        with gr.Row():
            base_text = gr.Textbox(value="gin", label="Base alcohol")
            flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
            topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
        with gr.Row():
            ex1 = gr.Button("Example: Gin + Citrus")
            ex2 = gr.Button("Example: Rum + Fruity")
            ex3 = gr.Button("Example: Mezcal + Smoky")
        # Recommend button UNDER the example buttons. Blocks render components
        # in creation order, so it must be created before the results area;
        # otherwise it ends up below the (growing) results.
        recommend_btn = gr.Button("Recommend")
        out = gr.Markdown(elem_id="result_md")
        recommend_btn.click(recommend, [base_text, flavor, topk], out)
        # Quick-fill examples: only populate the inputs; the user still clicks Recommend.
        ex1.click(lambda: ("gin", "citrus", 3), outputs=[base_text, flavor, topk])
        ex2.click(lambda: ("white rum", "fruity", 3), outputs=[base_text, flavor, topk])
        ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base_text, flavor, topk])

if __name__ == "__main__":
    demo.launch()