TomLii commited on
Commit
dd903d9
·
1 Parent(s): f656413

Add model fallback plus lightweight memory and cache

Browse files
Files changed (1) hide show
  1. app.py +98 -20
app.py CHANGED
@@ -2,7 +2,7 @@ import json
2
  import os
3
  import re
4
  from dataclasses import dataclass, field
5
- from typing import Any, Dict, List, Optional, Tuple
6
 
7
  import gradio as gr
8
  import requests
@@ -11,12 +11,16 @@ from duckduckgo_search import DDGS
11
  from huggingface_hub import InferenceClient
12
 
13
 
14
- DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "Qwen/Qwen2.5-7B-Instruct")
15
  DEFAULT_FREE_MODELS = [
 
 
 
 
 
16
  "Qwen/Qwen2.5-7B-Instruct",
17
  "meta-llama/Llama-3.1-8B-Instruct",
18
- "mistralai/Mistral-7B-Instruct-v0.3",
19
  ]
 
20
 
21
  SYSTEM_PROMPT = """You are a Deep Research assistant.
22
  You can think step by step, use tools, and then return a final answer.
@@ -47,6 +51,9 @@ TOOL_RESPONSE_TEMPLATE = """<tool_response>
47
  {payload}
48
  </tool_response>"""
49
 
 
 
 
50
  CUSTOM_CSS = """
51
  @import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700&display=swap');
52
 
@@ -100,6 +107,9 @@ CUSTOM_CSS = """
100
  class AgentState:
101
  searched_queries: List[str] = field(default_factory=list)
102
  visited_urls: List[str] = field(default_factory=list)
 
 
 
103
  trace: List[Dict[str, Any]] = field(default_factory=list)
104
 
105
 
@@ -128,6 +138,10 @@ def parse_tool_call(text: str) -> Tuple[Optional[str], Optional[Dict[str, Any]],
128
  def run_search(query: str, max_results: int = 5) -> Dict[str, Any]:
129
  if not query.strip():
130
  return {"ok": False, "error": "Search query cannot be empty."}
 
 
 
 
131
  rows: List[Dict[str, str]] = []
132
  with DDGS() as ddgs:
133
  for item in ddgs.text(query, max_results=max_results):
@@ -138,7 +152,9 @@ def run_search(query: str, max_results: int = 5) -> Dict[str, Any]:
138
  "body": item.get("body", ""),
139
  }
140
  )
141
- return {"ok": True, "query": query, "results": rows}
 
 
142
 
143
 
144
  def _clean_html_to_text(html: str, max_chars: int) -> str:
@@ -153,6 +169,9 @@ def _clean_html_to_text(html: str, max_chars: int) -> str:
153
  def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
154
  if not url.strip():
155
  return {"ok": False, "error": "URL cannot be empty."}
 
 
 
156
  try:
157
  resp = requests.get(
158
  url,
@@ -165,7 +184,9 @@ def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
165
  text = _clean_html_to_text(resp.text, max_chars=max_chars)
166
  else:
167
  text = resp.text[:max_chars]
168
- return {"ok": True, "url": url, "content": text}
 
 
169
  except Exception as exc:
170
  return {"ok": False, "url": url, "error": str(exc)}
171
 
@@ -173,17 +194,30 @@ def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
173
  def call_model(
174
  client: InferenceClient,
175
  messages: List[Dict[str, str]],
176
- model: str,
 
177
  temperature: float,
178
  max_new_tokens: int,
179
- ) -> str:
180
- completion = client.chat_completion(
181
- model=model,
182
- messages=messages,
183
- temperature=temperature,
184
- max_tokens=max_new_tokens,
185
- )
186
- return completion.choices[0].message.content or ""
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
 
189
  def build_research_agent(
@@ -196,6 +230,8 @@ def build_research_agent(
196
  token = os.getenv("HF_TOKEN")
197
  client = InferenceClient(token=token)
198
  state = AgentState()
 
 
199
 
200
  messages: List[Dict[str, str]] = [
201
  {"role": "system", "content": SYSTEM_PROMPT},
@@ -205,10 +241,20 @@ def build_research_agent(
205
  final_answer: Optional[str] = None
206
 
207
  for turn in range(1, max_turns + 1):
208
- model_output = call_model(
 
 
 
 
 
 
 
 
 
209
  client=client,
210
  messages=messages,
211
- model=model,
 
212
  temperature=temperature,
213
  max_new_tokens=1400,
214
  )
@@ -237,16 +283,45 @@ def build_research_agent(
237
  query = str(tool_args.get("query", "")).strip()
238
  max_results = int(tool_args.get("max_results", max_search_results))
239
  max_results = max(1, min(max_results, 10))
240
- if query:
 
 
 
 
 
 
 
 
241
  state.searched_queries.append(query)
242
- tool_response = run_search(query=query, max_results=max_results)
 
 
 
 
 
 
 
243
  elif tool_name == "visit":
244
  url = str(tool_args.get("url", "")).strip()
245
  max_chars = int(tool_args.get("max_chars", 6000))
246
  max_chars = max(500, min(max_chars, 20000))
247
- if url:
 
 
 
 
 
 
 
248
  state.visited_urls.append(url)
249
- tool_response = run_visit(url=url, max_chars=max_chars)
 
 
 
 
 
 
 
250
  else:
251
  tool_response = {"ok": False, "error": f"Unknown tool: {tool_name}"}
252
 
@@ -267,13 +342,16 @@ def build_research_agent(
267
  )
268
 
269
  citations = "\n".join(f"- {url}" for url in sorted(set(state.visited_urls)))
 
270
  if citations:
271
  final_answer = f"{final_answer}\n\n### Visited Sources\n{citations}"
272
 
273
  trace_text = json.dumps(
274
  {
 
275
  "searched_queries": state.searched_queries,
276
  "visited_urls": state.visited_urls,
 
277
  "trace": state.trace,
278
  },
279
  ensure_ascii=False,
 
2
  import os
3
  import re
4
  from dataclasses import dataclass, field
5
+ from typing import Any, Dict, List, Optional, Set, Tuple
6
 
7
  import gradio as gr
8
  import requests
 
11
  from huggingface_hub import InferenceClient
12
 
13
 
 
14
  DEFAULT_FREE_MODELS = [
15
+ # Newer free-friendly candidates (availability depends on HF Inference quota/region)
16
+ "Qwen/Qwen3-8B",
17
+ "google/gemma-3-12b-it",
18
+ "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
19
+ # Fallback older but usually reliable
20
  "Qwen/Qwen2.5-7B-Instruct",
21
  "meta-llama/Llama-3.1-8B-Instruct",
 
22
  ]
23
+ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", DEFAULT_FREE_MODELS[0])
24
 
25
  SYSTEM_PROMPT = """You are a Deep Research assistant.
26
  You can think step by step, use tools, and then return a final answer.
 
51
  {payload}
52
  </tool_response>"""
53
 
54
+ SEARCH_CACHE: Dict[str, Dict[str, Any]] = {}
55
+ VISIT_CACHE: Dict[str, Dict[str, Any]] = {}
56
+
57
  CUSTOM_CSS = """
58
  @import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700&display=swap');
59
 
 
107
  class AgentState:
108
  searched_queries: List[str] = field(default_factory=list)
109
  visited_urls: List[str] = field(default_factory=list)
110
+ searched_query_set: Set[str] = field(default_factory=set)
111
+ visited_url_set: Set[str] = field(default_factory=set)
112
+ trusted_notes: List[str] = field(default_factory=list)
113
  trace: List[Dict[str, Any]] = field(default_factory=list)
114
 
115
 
 
138
  def run_search(query: str, max_results: int = 5) -> Dict[str, Any]:
139
  if not query.strip():
140
  return {"ok": False, "error": "Search query cannot be empty."}
141
+ cache_key = f"{query.strip().lower()}::{max_results}"
142
+ if cache_key in SEARCH_CACHE:
143
+ return {**SEARCH_CACHE[cache_key], "cached": True}
144
+
145
  rows: List[Dict[str, str]] = []
146
  with DDGS() as ddgs:
147
  for item in ddgs.text(query, max_results=max_results):
 
152
  "body": item.get("body", ""),
153
  }
154
  )
155
+ payload = {"ok": True, "query": query, "results": rows, "cached": False}
156
+ SEARCH_CACHE[cache_key] = payload
157
+ return payload
158
 
159
 
160
  def _clean_html_to_text(html: str, max_chars: int) -> str:
 
169
  def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
170
  if not url.strip():
171
  return {"ok": False, "error": "URL cannot be empty."}
172
+ cache_key = f"{url.strip()}::{max_chars}"
173
+ if cache_key in VISIT_CACHE:
174
+ return {**VISIT_CACHE[cache_key], "cached": True}
175
  try:
176
  resp = requests.get(
177
  url,
 
184
  text = _clean_html_to_text(resp.text, max_chars=max_chars)
185
  else:
186
  text = resp.text[:max_chars]
187
+ payload = {"ok": True, "url": url, "content": text, "cached": False}
188
+ VISIT_CACHE[cache_key] = payload
189
+ return payload
190
  except Exception as exc:
191
  return {"ok": False, "url": url, "error": str(exc)}
192
 
 
194
  def call_model(
195
  client: InferenceClient,
196
  messages: List[Dict[str, str]],
197
+ preferred_model: str,
198
+ candidate_models: List[str],
199
  temperature: float,
200
  max_new_tokens: int,
201
+ ) -> Tuple[str, str]:
202
+ model_order: List[str] = []
203
+ for m in [preferred_model] + candidate_models:
204
+ if m and m not in model_order:
205
+ model_order.append(m)
206
+
207
+ last_error = None
208
+ for model_name in model_order:
209
+ try:
210
+ completion = client.chat_completion(
211
+ model=model_name,
212
+ messages=messages,
213
+ temperature=temperature,
214
+ max_tokens=max_new_tokens,
215
+ )
216
+ return completion.choices[0].message.content or "", model_name
217
+ except Exception as exc:
218
+ last_error = exc
219
+ continue
220
+ raise RuntimeError(f"All model candidates failed. Last error: {last_error}")
221
 
222
 
223
  def build_research_agent(
 
230
  token = os.getenv("HF_TOKEN")
231
  client = InferenceClient(token=token)
232
  state = AgentState()
233
+ used_model = model
234
+ recent_model_candidates = [m for m in DEFAULT_FREE_MODELS if m != model]
235
 
236
  messages: List[Dict[str, str]] = [
237
  {"role": "system", "content": SYSTEM_PROMPT},
 
241
  final_answer: Optional[str] = None
242
 
243
  for turn in range(1, max_turns + 1):
244
+ if state.trusted_notes and turn > 1 and turn % 3 == 0:
245
+ summary_lines = "\n".join(f"- {n}" for n in state.trusted_notes[-6:])
246
+ messages.append(
247
+ {
248
+ "role": "user",
249
+ "content": f"RESEARCH STATE SUMMARY\n{summary_lines}\nUse this summary to avoid repeating work.",
250
+ }
251
+ )
252
+
253
+ model_output, used_model = call_model(
254
  client=client,
255
  messages=messages,
256
+ preferred_model=model,
257
+ candidate_models=recent_model_candidates,
258
  temperature=temperature,
259
  max_new_tokens=1400,
260
  )
 
283
  query = str(tool_args.get("query", "")).strip()
284
  max_results = int(tool_args.get("max_results", max_search_results))
285
  max_results = max(1, min(max_results, 10))
286
+ if query in state.searched_query_set:
287
+ tool_response = {
288
+ "ok": True,
289
+ "query": query,
290
+ "cached": True,
291
+ "note": "This query was already searched. Reusing cached result to avoid duplicate work.",
292
+ "results": [],
293
+ }
294
+ else:
295
  state.searched_queries.append(query)
296
+ state.searched_query_set.add(query)
297
+ tool_response = run_search(query=query, max_results=max_results)
298
+ if tool_response.get("ok"):
299
+ first_titles = [r.get("title", "") for r in tool_response.get("results", [])[:2]]
300
+ if first_titles:
301
+ state.trusted_notes.append(
302
+ f"Searched '{query}' and found leads: {', '.join(t for t in first_titles if t)}"
303
+ )
304
  elif tool_name == "visit":
305
  url = str(tool_args.get("url", "")).strip()
306
  max_chars = int(tool_args.get("max_chars", 6000))
307
  max_chars = max(500, min(max_chars, 20000))
308
+ if url in state.visited_url_set:
309
+ tool_response = {
310
+ "ok": True,
311
+ "url": url,
312
+ "cached": True,
313
+ "note": "This URL was already visited. Reusing cached result to avoid duplicate work.",
314
+ }
315
+ else:
316
  state.visited_urls.append(url)
317
+ state.visited_url_set.add(url)
318
+ tool_response = run_visit(url=url, max_chars=max_chars)
319
+ if tool_response.get("ok"):
320
+ snippet = str(tool_response.get("content", ""))[:180]
321
+ if snippet:
322
+ state.trusted_notes.append(
323
+ f"Visited {url} and extracted key context: {snippet}"
324
+ )
325
  else:
326
  tool_response = {"ok": False, "error": f"Unknown tool: {tool_name}"}
327
 
 
342
  )
343
 
344
  citations = "\n".join(f"- {url}" for url in sorted(set(state.visited_urls)))
345
+ final_answer = f"**Model used:** `{used_model}`\n\n{final_answer}"
346
  if citations:
347
  final_answer = f"{final_answer}\n\n### Visited Sources\n{citations}"
348
 
349
  trace_text = json.dumps(
350
  {
351
+ "used_model": used_model,
352
  "searched_queries": state.searched_queries,
353
  "visited_urls": state.visited_urls,
354
+ "trusted_notes": state.trusted_notes[-10:],
355
  "trace": state.trace,
356
  },
357
  ensure_ascii=False,