OGOGOG commited on
Commit
233677c
·
verified ·
1 Parent(s): e9395e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -27
app.py CHANGED
@@ -2,8 +2,19 @@ import os
2
  import re
3
  import numpy as np
4
  import gradio as gr
5
- from datasets import load_dataset
6
- from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # ========================
9
  # Config
@@ -53,7 +64,7 @@ FLAVORS = {
53
  FLAVOR_OPTIONS = list(FLAVORS.keys())
54
 
55
  # ========================
56
- # Helpers
57
  # ========================
58
  def _clean(s):
59
  return s.strip() if isinstance(s, str) else ""
@@ -77,12 +88,7 @@ def _join_measure_name(measure, name):
77
  def _split_ingredient_blob(s):
78
  if not isinstance(s, str): return []
79
  parts = re.split(r"[,\n;•\-–]+", s)
80
- out = []
81
- for p in parts:
82
- p = p.strip()
83
- if p:
84
- out.append(p)
85
- return out
86
 
87
  def _ingredients_from_any(val):
88
  if isinstance(val, str):
@@ -105,8 +111,9 @@ def _get_title(row, cols):
105
  return "Untitled"
106
 
107
  def _get_ingredients_with_measures(row, cols):
108
- for key in ["ingredients","ingredients_raw","raw_ingredients"]:
109
- if key in cols and row.get(key):
 
110
  return _ingredients_from_any(row[key])
111
  return [], []
112
 
@@ -122,7 +129,7 @@ def tag_flavors(text):
122
  return [flv for flv, pats in FLAVORS.items() if any(re.search(p, t) for p in pats)]
123
 
124
  # ========================
125
- # Load dataset
126
  # ========================
127
  ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
128
  cols = ds.column_names
@@ -144,14 +151,81 @@ for r in ds:
144
  })
145
 
146
  # ========================
147
- # Embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # ========================
149
- encoder = SentenceTransformer(EMBED_MODEL)
150
- doc_embs = encoder.encode(
151
- [d["text"] for d in DOCS],
152
- normalize_embeddings=True,
153
- convert_to_numpy=True
154
- ).astype("float32")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  # ========================
157
  # Recommendation
@@ -160,18 +234,38 @@ def recommend(base_alcohol_text, flavor, top_k=3):
160
  inferred_base = tag_base(base_alcohol_text or "")
161
  if flavor not in FLAVOR_OPTIONS:
162
  return "Please choose a flavor."
163
- idxs = [i for i, d in enumerate(DOCS) if d["base"] == inferred_base] or list(range(len(DOCS)))
 
 
 
 
164
  q_text = f"Base spirit: {base_alcohol_text}. Flavor: {flavor}. Cocktail recipe."
165
- q_emb = encoder.encode([q_text], normalize_embeddings=True, convert_to_numpy=True).astype("float32")[0]
166
- sims = doc_embs[idxs].dot(q_emb)
167
- scored = [(s + (FLAVOR_BOOST if flavor in DOCS[i]['flavors'] else 0), i) for s, i in zip(sims, idxs)]
 
 
 
 
 
168
  scored.sort(reverse=True)
169
- picks = scored[:max(1, int(top_k))]
 
 
 
 
 
170
  blocks = []
171
  for sc, i in picks:
172
  d = DOCS[i]
173
- meta = f"**Base:** {d['base']} | **Flavor tags:** {', '.join(d['flavors']) or '—'} | **Score:** {sc:.3f}"
174
- blocks.append(f"### {d['title']}\n{meta}\n\n**Ingredients:**\n" + "\n".join(f"- {x}" for x in d["ingredients_display"]))
 
 
 
 
 
 
175
  return "\n\n---\n\n".join(blocks)
176
 
177
  # ========================
@@ -209,16 +303,22 @@ with gr.Blocks(css=CUSTOM_CSS) as demo:
209
  gr.HTML("<div id='app-bg'></div>")
210
  with gr.Column(elem_classes=["glass-card"]):
211
  gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor", elem_id="title_md")
 
212
  with gr.Row():
213
  base_text = gr.Textbox(value="gin", label="Base alcohol")
214
  flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
215
  topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
 
216
  with gr.Row():
217
  ex1 = gr.Button("Example: Gin + Citrus")
218
  ex2 = gr.Button("Example: Rum + Fruity")
219
  ex3 = gr.Button("Example: Mezcal + Smoky")
220
- gr.Button("Recommend").click(recommend, [base_text, flavor, topk], gr.Markdown(elem_id="result_md"))
 
221
  out = gr.Markdown(elem_id="result_md")
 
 
 
222
  ex1.click(lambda: ("gin", "citrus", 3), outputs=[base_text, flavor, topk])
223
  ex2.click(lambda: ("white rum", "fruity", 3), outputs=[base_text, flavor, topk])
224
  ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base_text, flavor, topk])
 
2
  import re
3
  import numpy as np
4
  import gradio as gr
5
+
6
+ # Optional offline fallback embeddings
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+
10
+ # Try to import sentence-transformers, but we’ll fall back if it can’t download
11
+ try:
12
+ from datasets import load_dataset
13
+ from sentence_transformers import SentenceTransformer
14
+ _HAS_SBERT = True
15
+ except Exception:
16
+ _HAS_SBERT = False
17
+ from datasets import load_dataset # datasets worked for you per logs
18
 
19
  # ========================
20
  # Config
 
64
  FLAVOR_OPTIONS = list(FLAVORS.keys())
65
 
66
  # ========================
67
+ # Robust extraction helpers (with measures)
68
  # ========================
69
  def _clean(s):
70
  return s.strip() if isinstance(s, str) else ""
 
88
  def _split_ingredient_blob(s):
89
  if not isinstance(s, str): return []
90
  parts = re.split(r"[,\n;•\-–]+", s)
91
+ return [p.strip() for p in parts if p.strip()]
 
 
 
 
 
92
 
93
  def _ingredients_from_any(val):
94
  if isinstance(val, str):
 
111
  return "Untitled"
112
 
113
  def _get_ingredients_with_measures(row, cols):
114
+ for key in ["ingredients","ingredients_raw","raw_ingredients","Raw_Ingredients","Raw Ingredients",
115
+ "ingredient_list","ingredients_list"]:
116
+ if key in cols and row.get(key) not in (None, "", [], {}):
117
  return _ingredients_from_any(row[key])
118
  return [], []
119
 
 
129
  return [flv for flv, pats in FLAVORS.items() if any(re.search(p, t) for p in pats)]
130
 
131
  # ========================
132
+ # Load dataset & build docs
133
  # ========================
134
  ds = load_dataset(DATASET_ID, split="train", **load_kwargs)
135
  cols = ds.column_names
 
151
  })
152
 
153
  # ========================
154
+ # Embedding backends (SBERT -> TF-IDF fallback)
155
+ # ========================
156
+ class Embedder:
157
+ def __init__(self):
158
+ self.mode = "tfidf"
159
+ self.encoder = None
160
+ self.vectorizer = None
161
+ self.doc_matrix = None
162
+ # Try SBERT if available and downloadable
163
+ if _HAS_SBERT:
164
+ try:
165
+ self.encoder = SentenceTransformer(EMBED_MODEL)
166
+ self.mode = "sbert"
167
+ except Exception as e:
168
+ print(f"[WARN] SBERT model load failed, falling back to TF-IDF. Reason: {e}")
169
+ if self.mode == "tfidf":
170
+ self.vectorizer = TfidfVectorizer(ngram_range=(1,2), min_df=1)
171
+ print(f"[INFO] Embedding mode: {self.mode}")
172
+
173
+ def fit_docs(self, docs):
174
+ if self.mode == "sbert":
175
+ embs = self.encoder.encode(docs, normalize_embeddings=True, convert_to_numpy=True).astype("float32")
176
+ self.doc_matrix = embs
177
+ else:
178
+ self.doc_matrix = self.vectorizer.fit_transform(docs)
179
+
180
+ def embed_query(self, q):
181
+ if self.mode == "sbert":
182
+ v = self.encoder.encode([q], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
183
+ return v
184
+ else:
185
+ return self.vectorizer.transform([q])
186
+
187
+ def scores(self, idxs, q_vec):
188
+ if self.mode == "sbert":
189
+ # cosine since normalized
190
+ return self.doc_matrix[idxs].dot(q_vec[0])
191
+ else:
192
+ sims = cosine_similarity(self.doc_matrix[idxs], q_vec)
193
+ return sims[:,0]
194
+
195
+ embedder = Embedder()
196
+ DOC_TEXTS = [d["text"] for d in DOCS]
197
+ embedder.fit_docs(DOC_TEXTS)
198
+
199
+ # ========================
200
+ # Pretty ingredient formatting
201
  # ========================
202
+ _MEASURE_RE = re.compile(
203
+ r"^\s*(?P<meas>(?:\d+(\.\d+)?|\d+\s*/\s*\d+|\d+\s*\d*/\d+)\s*(?:ml|oz|tsp|tbsp)?|\d+\s*(?:ml|oz|tsp|tbsp)|(?:dash|dashes|drop|drops|barspoon)s?)\b[\s\-–:]*",
204
+ flags=re.I
205
+ )
206
+
207
+ def _split_measure_name_line(line: str):
208
+ if not isinstance(line, str): return "", line
209
+ m = _MEASURE_RE.match(line.strip())
210
+ if m:
211
+ meas = _norm_measure(m.group("meas"))
212
+ name = line[m.end():].strip()
213
+ return meas, name or ""
214
+ return "", line.strip()
215
+
216
+ def _format_ingredients_markdown(lines):
217
+ if not lines:
218
+ return "—"
219
+ formatted = []
220
+ for ln in lines:
221
+ meas, name = _split_measure_name_line(ln)
222
+ if meas and name:
223
+ formatted.append(f"- **{meas}** — {name}")
224
+ elif name:
225
+ formatted.append(f"- {name}")
226
+ else:
227
+ formatted.append(f"- {ln}")
228
+ return "\n".join(formatted)
229
 
230
  # ========================
231
  # Recommendation
 
234
  inferred_base = tag_base(base_alcohol_text or "")
235
  if flavor not in FLAVOR_OPTIONS:
236
  return "Please choose a flavor."
237
+
238
+ idxs = [i for i, d in enumerate(DOCS) if d["base"] == inferred_base]
239
+ if inferred_base == "other" or not idxs:
240
+ idxs = list(range(len(DOCS)))
241
+
242
  q_text = f"Base spirit: {base_alcohol_text}. Flavor: {flavor}. Cocktail recipe."
243
+ q_vec = embedder.embed_query(q_text)
244
+ sims = embedder.scores(idxs, q_vec)
245
+
246
+ scored = []
247
+ for pos, i in enumerate(idxs):
248
+ base_score = float(sims[pos])
249
+ score = base_score + (FLAVOR_BOOST if flavor in DOCS[i]['flavors'] else 0.0)
250
+ scored.append((score, i))
251
  scored.sort(reverse=True)
252
+
253
+ k = max(1, int(top_k))
254
+ picks = scored[:k]
255
+ if not picks:
256
+ return "No matches found."
257
+
258
  blocks = []
259
  for sc, i in picks:
260
  d = DOCS[i]
261
+ ing_lines = d["ingredients_display"] or d["ingredients_tokens"]
262
+ ing_md = _format_ingredients_markdown(ing_lines)
263
+ meta = f"**Base:** {d['base']} | **Flavor tags:** {', '.join(d['flavors']) or '—'} | **Score:** {sc:.3f}"
264
+ blocks.append(
265
+ f"### {d['title']}\n"
266
+ f"{meta}\n\n"
267
+ f"**Ingredients:**\n{ing_md}"
268
+ )
269
  return "\n\n---\n\n".join(blocks)
270
 
271
  # ========================
 
303
  gr.HTML("<div id='app-bg'></div>")
304
  with gr.Column(elem_classes=["glass-card"]):
305
  gr.Markdown("# 🍹 AI Bartender — Type a Base + Flavor", elem_id="title_md")
306
+
307
  with gr.Row():
308
  base_text = gr.Textbox(value="gin", label="Base alcohol")
309
  flavor = gr.Dropdown(choices=FLAVOR_OPTIONS, value="citrus", label="Flavor")
310
  topk = gr.Slider(1, 10, value=3, step=1, label="Number of recommendations")
311
+
312
  with gr.Row():
313
  ex1 = gr.Button("Example: Gin + Citrus")
314
  ex2 = gr.Button("Example: Rum + Fruity")
315
  ex3 = gr.Button("Example: Mezcal + Smoky")
316
+
317
+ # Recommend button UNDER the example buttons
318
  out = gr.Markdown(elem_id="result_md")
319
+ gr.Button("Recommend").click(recommend, [base_text, flavor, topk], out)
320
+
321
+ # Quick-fill examples
322
  ex1.click(lambda: ("gin", "citrus", 3), outputs=[base_text, flavor, topk])
323
  ex2.click(lambda: ("white rum", "fruity", 3), outputs=[base_text, flavor, topk])
324
  ex3.click(lambda: ("mezcal", "smoky", 3), outputs=[base_text, flavor, topk])