hodfa840 committed on
Commit
a8820b1
·
1 Parent(s): d624b44

perf: replace local CPU inference with HF Inference API

Browse files

Running Qwen2.5-0.5B locally on CPU took 15-20s per query.
Switch to InferenceClient (serverless GPU) for ~1-2s responses.
- Use Qwen2.5-72B-Instruct via HF serverless endpoint
- Add structured fallback template if API is unavailable
- Drop torch/transformers from requirements (sentence-transformers
still brings them as transitive deps for embeddings)

Files changed (2) hide show
  1. modules/llm.py +49 -53
  2. requirements.txt +1 -2
modules/llm.py CHANGED
@@ -1,66 +1,65 @@
1
  """
2
- Local LLM inference engine for RetailMind.
3
 
4
- Uses Qwen2.5-0.5B-Instruct running entirely on CPU β€” no API keys, no GPU,
5
- no external dependencies. Prompt engineering is tuned to minimize
6
- hallucination by grounding all answers in the provided product context.
7
  """
8
 
9
  from __future__ import annotations
10
 
11
  import logging
12
- import time
13
  from typing import Any
14
 
15
- import torch
16
- from transformers import pipeline
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
- _generator = None
 
21
 
22
 
23
- def _get_pipeline():
24
- """Lazy-load the text-generation pipeline (singleton)."""
25
- global _generator
26
- if _generator is None:
27
- logger.info("Loading Qwen2.5-0.5B-Instruct on CPU (first call only)…")
28
- t0 = time.time()
29
- _generator = pipeline(
30
- "text-generation",
31
- model="Qwen/Qwen2.5-0.5B-Instruct",
32
- device="cpu",
33
- torch_dtype=torch.float32,
34
- )
35
- logger.info("Model loaded in %.1fs", time.time() - t0)
36
- return _generator
37
-
38
 
39
- def generate_response(
40
- system_prompt: str,
41
- user_query: str,
42
- retrieved_items: list[dict[str, Any]],
43
- ) -> str:
44
- """
45
- Generate a grounded product recommendation.
46
 
47
- The retrieved items are injected directly into the system prompt so
48
- the model can only reference real products.
49
- """
50
- # Build structured context from retrieved products
51
- context_lines = []
52
  for i, r in enumerate(retrieved_items, 1):
53
  p = r["product"]
54
  stars = "β˜…" * int(p.get("rating", 4)) + "β˜†" * (5 - int(p.get("rating", 4)))
55
- context_lines.append(
56
  f"{i}. {p['title']} β€” ${p['price']:.2f}\n"
57
  f" Category: {p['category']} | Rating: {stars} ({p.get('reviews', 0)} reviews)\n"
58
  f" Materials: {p.get('materials', 'N/A')}\n"
59
  f" Description: {p['desc']}"
60
  )
 
61
 
62
- context = "\n\n".join(context_lines)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  messages = [
65
  {
66
  "role": "system",
@@ -68,27 +67,24 @@ def generate_response(
68
  f"{system_prompt}\n\n"
69
  f"══════ Available Inventory ══════\n\n"
70
  f"{context}\n\n"
71
- f"══════════════════════════════════\n"
72
- f"IMPORTANT: You are an elite AI shopping assistant. "
73
- f"Only recommend from the products listed above. "
74
- f"Cite exact names and prices. Do not hallucinate items that are not in the provided inventory."
75
  ),
76
  },
77
  {"role": "user", "content": user_query},
78
  ]
79
 
80
  try:
81
- gen = _get_pipeline()
82
- result = gen(
83
- messages,
84
- max_new_tokens=80,
85
- do_sample=False,
86
- return_full_text=False,
87
  )
88
- generated = result[0]["generated_text"]
89
- if isinstance(generated, list):
90
- return generated[-1]["content"]
91
- return generated
92
  except Exception as e:
93
- logger.exception("LLM inference failed")
94
- return f"[RetailMind] I encountered an issue generating a response. Error: {e}"
 
1
  """
2
+ LLM inference engine for RetailMind.
3
 
4
+ Uses the HuggingFace Inference API (serverless, GPU-backed) so responses
5
+ arrive in ~1–2 s instead of 15–20 s on CPU. Falls back to a structured
6
+ template if the API is unavailable.
7
  """
8
 
9
  from __future__ import annotations
10
 
11
  import logging
12
+ import os
13
  from typing import Any
14
 
15
+ from huggingface_hub import InferenceClient
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
19
+ _client: InferenceClient | None = None
20
+ MODEL = "Qwen/Qwen2.5-72B-Instruct" # strong model, free on HF serverless
21
 
22
 
23
+ def _get_client() -> InferenceClient:
24
+ global _client
25
+ if _client is None:
26
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
27
+ _client = InferenceClient(token=token)
28
+ logger.info("InferenceClient ready (model=%s)", MODEL)
29
+ return _client
 
 
 
 
 
 
 
 
30
 
 
 
 
 
 
 
 
31
 
32
+ def _build_context(retrieved_items: list[dict[str, Any]]) -> str:
33
+ lines = []
 
 
 
34
  for i, r in enumerate(retrieved_items, 1):
35
  p = r["product"]
36
  stars = "β˜…" * int(p.get("rating", 4)) + "β˜†" * (5 - int(p.get("rating", 4)))
37
+ lines.append(
38
  f"{i}. {p['title']} β€” ${p['price']:.2f}\n"
39
  f" Category: {p['category']} | Rating: {stars} ({p.get('reviews', 0)} reviews)\n"
40
  f" Materials: {p.get('materials', 'N/A')}\n"
41
  f" Description: {p['desc']}"
42
  )
43
+ return "\n\n".join(lines)
44
 
 
45
 
46
+ def _fallback_response(retrieved_items: list[dict[str, Any]]) -> str:
47
+ """Structured template used when the API is unavailable."""
48
+ if not retrieved_items:
49
+ return "I couldn't find matching products for your query. Try different keywords."
50
+ lines = ["Here are my top picks for you:\n"]
51
+ for r in retrieved_items:
52
+ p = r["product"]
53
+ lines.append(f"β€’ **{p['title']}** β€” ${p['price']:.2f}\n {p['desc'][:120]}…")
54
+ return "\n".join(lines)
55
+
56
+
57
+ def generate_response(
58
+ system_prompt: str,
59
+ user_query: str,
60
+ retrieved_items: list[dict[str, Any]],
61
+ ) -> str:
62
+ context = _build_context(retrieved_items)
63
  messages = [
64
  {
65
  "role": "system",
 
67
  f"{system_prompt}\n\n"
68
  f"══════ Available Inventory ══════\n\n"
69
  f"{context}\n\n"
70
+ f"════════════════════════════════\n"
71
+ f"You are a helpful AI shopping assistant. "
72
+ f"Only recommend products listed above. "
73
+ f"Cite exact names and prices. Be concise (2–4 sentences)."
74
  ),
75
  },
76
  {"role": "user", "content": user_query},
77
  ]
78
 
79
  try:
80
+ client = _get_client()
81
+ result = client.chat.completions.create(
82
+ model=MODEL,
83
+ messages=messages,
84
+ max_tokens=150,
85
+ temperature=0.3,
86
  )
87
+ return result.choices[0].message.content.strip()
 
 
 
88
  except Exception as e:
89
+ logger.warning("Inference API failed (%s), using fallback template.", e)
90
+ return _fallback_response(retrieved_items)
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
- transformers
2
- torch
3
  sentence-transformers
 
4
  python-dotenv
5
  hf_transfer
6
  plotly
 
 
 
1
  sentence-transformers
2
+ huggingface_hub
3
  python-dotenv
4
  hf_transfer
5
  plotly