Add model fallback plus lightweight memory and cache
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import json
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
-
from typing import Any, Dict, List, Optional, Tuple
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import requests
|
|
@@ -11,12 +11,16 @@ from duckduckgo_search import DDGS
|
|
| 11 |
from huggingface_hub import InferenceClient
|
| 12 |
|
| 13 |
|
| 14 |
-
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "Qwen/Qwen2.5-7B-Instruct")
|
| 15 |
DEFAULT_FREE_MODELS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"Qwen/Qwen2.5-7B-Instruct",
|
| 17 |
"meta-llama/Llama-3.1-8B-Instruct",
|
| 18 |
-
"mistralai/Mistral-7B-Instruct-v0.3",
|
| 19 |
]
|
|
|
|
| 20 |
|
| 21 |
SYSTEM_PROMPT = """You are a Deep Research assistant.
|
| 22 |
You can think step by step, use tools, and then return a final answer.
|
|
@@ -47,6 +51,9 @@ TOOL_RESPONSE_TEMPLATE = """<tool_response>
|
|
| 47 |
{payload}
|
| 48 |
</tool_response>"""
|
| 49 |
|
|
|
|
|
|
|
|
|
|
| 50 |
CUSTOM_CSS = """
|
| 51 |
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700&display=swap');
|
| 52 |
|
|
@@ -100,6 +107,9 @@ CUSTOM_CSS = """
|
|
| 100 |
class AgentState:
|
| 101 |
searched_queries: List[str] = field(default_factory=list)
|
| 102 |
visited_urls: List[str] = field(default_factory=list)
|
|
|
|
|
|
|
|
|
|
| 103 |
trace: List[Dict[str, Any]] = field(default_factory=list)
|
| 104 |
|
| 105 |
|
|
@@ -128,6 +138,10 @@ def parse_tool_call(text: str) -> Tuple[Optional[str], Optional[Dict[str, Any]],
|
|
| 128 |
def run_search(query: str, max_results: int = 5) -> Dict[str, Any]:
|
| 129 |
if not query.strip():
|
| 130 |
return {"ok": False, "error": "Search query cannot be empty."}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
rows: List[Dict[str, str]] = []
|
| 132 |
with DDGS() as ddgs:
|
| 133 |
for item in ddgs.text(query, max_results=max_results):
|
|
@@ -138,7 +152,9 @@ def run_search(query: str, max_results: int = 5) -> Dict[str, Any]:
|
|
| 138 |
"body": item.get("body", ""),
|
| 139 |
}
|
| 140 |
)
|
| 141 |
-
|
|
|
|
|
|
|
| 142 |
|
| 143 |
|
| 144 |
def _clean_html_to_text(html: str, max_chars: int) -> str:
|
|
@@ -153,6 +169,9 @@ def _clean_html_to_text(html: str, max_chars: int) -> str:
|
|
| 153 |
def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
|
| 154 |
if not url.strip():
|
| 155 |
return {"ok": False, "error": "URL cannot be empty."}
|
|
|
|
|
|
|
|
|
|
| 156 |
try:
|
| 157 |
resp = requests.get(
|
| 158 |
url,
|
|
@@ -165,7 +184,9 @@ def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
|
|
| 165 |
text = _clean_html_to_text(resp.text, max_chars=max_chars)
|
| 166 |
else:
|
| 167 |
text = resp.text[:max_chars]
|
| 168 |
-
|
|
|
|
|
|
|
| 169 |
except Exception as exc:
|
| 170 |
return {"ok": False, "url": url, "error": str(exc)}
|
| 171 |
|
|
@@ -173,17 +194,30 @@ def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
|
|
| 173 |
def call_model(
|
| 174 |
client: InferenceClient,
|
| 175 |
messages: List[Dict[str, str]],
|
| 176 |
-
|
|
|
|
| 177 |
temperature: float,
|
| 178 |
max_new_tokens: int,
|
| 179 |
-
) -> str:
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
|
| 189 |
def build_research_agent(
|
|
@@ -196,6 +230,8 @@ def build_research_agent(
|
|
| 196 |
token = os.getenv("HF_TOKEN")
|
| 197 |
client = InferenceClient(token=token)
|
| 198 |
state = AgentState()
|
|
|
|
|
|
|
| 199 |
|
| 200 |
messages: List[Dict[str, str]] = [
|
| 201 |
{"role": "system", "content": SYSTEM_PROMPT},
|
|
@@ -205,10 +241,20 @@ def build_research_agent(
|
|
| 205 |
final_answer: Optional[str] = None
|
| 206 |
|
| 207 |
for turn in range(1, max_turns + 1):
|
| 208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
client=client,
|
| 210 |
messages=messages,
|
| 211 |
-
|
|
|
|
| 212 |
temperature=temperature,
|
| 213 |
max_new_tokens=1400,
|
| 214 |
)
|
|
@@ -237,16 +283,45 @@ def build_research_agent(
|
|
| 237 |
query = str(tool_args.get("query", "")).strip()
|
| 238 |
max_results = int(tool_args.get("max_results", max_search_results))
|
| 239 |
max_results = max(1, min(max_results, 10))
|
| 240 |
-
if query:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
state.searched_queries.append(query)
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
elif tool_name == "visit":
|
| 244 |
url = str(tool_args.get("url", "")).strip()
|
| 245 |
max_chars = int(tool_args.get("max_chars", 6000))
|
| 246 |
max_chars = max(500, min(max_chars, 20000))
|
| 247 |
-
if url:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
state.visited_urls.append(url)
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
else:
|
| 251 |
tool_response = {"ok": False, "error": f"Unknown tool: {tool_name}"}
|
| 252 |
|
|
@@ -267,13 +342,16 @@ def build_research_agent(
|
|
| 267 |
)
|
| 268 |
|
| 269 |
citations = "\n".join(f"- {url}" for url in sorted(set(state.visited_urls)))
|
|
|
|
| 270 |
if citations:
|
| 271 |
final_answer = f"{final_answer}\n\n### Visited Sources\n{citations}"
|
| 272 |
|
| 273 |
trace_text = json.dumps(
|
| 274 |
{
|
|
|
|
| 275 |
"searched_queries": state.searched_queries,
|
| 276 |
"visited_urls": state.visited_urls,
|
|
|
|
| 277 |
"trace": state.trace,
|
| 278 |
},
|
| 279 |
ensure_ascii=False,
|
|
|
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
from dataclasses import dataclass, field
|
| 5 |
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import requests
|
|
|
|
| 11 |
from huggingface_hub import InferenceClient
|
| 12 |
|
| 13 |
|
|
|
|
| 14 |
DEFAULT_FREE_MODELS = [
|
| 15 |
+
# Newer free-friendly candidates (availability depends on HF Inference quota/region)
|
| 16 |
+
"Qwen/Qwen3-8B",
|
| 17 |
+
"google/gemma-3-12b-it",
|
| 18 |
+
"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
| 19 |
+
# Fallback older but usually reliable
|
| 20 |
"Qwen/Qwen2.5-7B-Instruct",
|
| 21 |
"meta-llama/Llama-3.1-8B-Instruct",
|
|
|
|
| 22 |
]
|
| 23 |
+
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", DEFAULT_FREE_MODELS[0])
|
| 24 |
|
| 25 |
SYSTEM_PROMPT = """You are a Deep Research assistant.
|
| 26 |
You can think step by step, use tools, and then return a final answer.
|
|
|
|
| 51 |
{payload}
|
| 52 |
</tool_response>"""
|
| 53 |
|
| 54 |
+
SEARCH_CACHE: Dict[str, Dict[str, Any]] = {}
|
| 55 |
+
VISIT_CACHE: Dict[str, Dict[str, Any]] = {}
|
| 56 |
+
|
| 57 |
CUSTOM_CSS = """
|
| 58 |
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700&display=swap');
|
| 59 |
|
|
|
|
| 107 |
class AgentState:
|
| 108 |
searched_queries: List[str] = field(default_factory=list)
|
| 109 |
visited_urls: List[str] = field(default_factory=list)
|
| 110 |
+
searched_query_set: Set[str] = field(default_factory=set)
|
| 111 |
+
visited_url_set: Set[str] = field(default_factory=set)
|
| 112 |
+
trusted_notes: List[str] = field(default_factory=list)
|
| 113 |
trace: List[Dict[str, Any]] = field(default_factory=list)
|
| 114 |
|
| 115 |
|
|
|
|
| 138 |
def run_search(query: str, max_results: int = 5) -> Dict[str, Any]:
|
| 139 |
if not query.strip():
|
| 140 |
return {"ok": False, "error": "Search query cannot be empty."}
|
| 141 |
+
cache_key = f"{query.strip().lower()}::{max_results}"
|
| 142 |
+
if cache_key in SEARCH_CACHE:
|
| 143 |
+
return {**SEARCH_CACHE[cache_key], "cached": True}
|
| 144 |
+
|
| 145 |
rows: List[Dict[str, str]] = []
|
| 146 |
with DDGS() as ddgs:
|
| 147 |
for item in ddgs.text(query, max_results=max_results):
|
|
|
|
| 152 |
"body": item.get("body", ""),
|
| 153 |
}
|
| 154 |
)
|
| 155 |
+
payload = {"ok": True, "query": query, "results": rows, "cached": False}
|
| 156 |
+
SEARCH_CACHE[cache_key] = payload
|
| 157 |
+
return payload
|
| 158 |
|
| 159 |
|
| 160 |
def _clean_html_to_text(html: str, max_chars: int) -> str:
|
|
|
|
| 169 |
def run_visit(url: str, max_chars: int = 6000) -> Dict[str, Any]:
|
| 170 |
if not url.strip():
|
| 171 |
return {"ok": False, "error": "URL cannot be empty."}
|
| 172 |
+
cache_key = f"{url.strip()}::{max_chars}"
|
| 173 |
+
if cache_key in VISIT_CACHE:
|
| 174 |
+
return {**VISIT_CACHE[cache_key], "cached": True}
|
| 175 |
try:
|
| 176 |
resp = requests.get(
|
| 177 |
url,
|
|
|
|
| 184 |
text = _clean_html_to_text(resp.text, max_chars=max_chars)
|
| 185 |
else:
|
| 186 |
text = resp.text[:max_chars]
|
| 187 |
+
payload = {"ok": True, "url": url, "content": text, "cached": False}
|
| 188 |
+
VISIT_CACHE[cache_key] = payload
|
| 189 |
+
return payload
|
| 190 |
except Exception as exc:
|
| 191 |
return {"ok": False, "url": url, "error": str(exc)}
|
| 192 |
|
|
|
|
| 194 |
def call_model(
    client: InferenceClient,
    messages: List[Dict[str, str]],
    preferred_model: str,
    candidate_models: List[str],
    temperature: float,
    max_new_tokens: int,
) -> Tuple[str, str]:
    """Request a chat completion, falling back through candidate models.

    The preferred model is tried first, followed by each entry of
    ``candidate_models`` in order. Empty names and duplicates are skipped so
    that no model is called twice for the same request.

    Args:
        client: Object exposing ``chat_completion(model=..., messages=...,
            temperature=..., max_tokens=...)``.
        messages: Chat history passed through unchanged to the client.
        preferred_model: Model name to try first; may be empty.
        candidate_models: Fallback model names, tried in the given order.
        temperature: Sampling temperature forwarded to the client.
        max_new_tokens: Forwarded to the client as ``max_tokens``.

    Returns:
        A ``(text, model_name)`` tuple: the content of the first choice (an
        empty string when the content is falsy) and the name of the model
        that produced it.

    Raises:
        RuntimeError: If no usable model name was supplied, or if every
            candidate failed. In the latter case the last underlying
            exception is attached as ``__cause__``.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # preferred model always stays at the front.
    model_order = list(
        dict.fromkeys(m for m in [preferred_model, *candidate_models] if m)
    )
    if not model_order:
        raise RuntimeError("No model candidates were provided.")

    last_error: Optional[Exception] = None
    for model_name in model_order:
        try:
            completion = client.chat_completion(
                model=model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_new_tokens,
            )
            return completion.choices[0].message.content or "", model_name
        except Exception as exc:
            # Deliberately broad: any failure with this model (including a
            # malformed response) should trigger a fallback to the next one.
            last_error = exc

    # Chain the last failure so its traceback is preserved for debugging.
    raise RuntimeError(
        f"All model candidates failed. Last error: {last_error}"
    ) from last_error
|
| 221 |
|
| 222 |
|
| 223 |
def build_research_agent(
|
|
|
|
| 230 |
token = os.getenv("HF_TOKEN")
|
| 231 |
client = InferenceClient(token=token)
|
| 232 |
state = AgentState()
|
| 233 |
+
used_model = model
|
| 234 |
+
recent_model_candidates = [m for m in DEFAULT_FREE_MODELS if m != model]
|
| 235 |
|
| 236 |
messages: List[Dict[str, str]] = [
|
| 237 |
{"role": "system", "content": SYSTEM_PROMPT},
|
|
|
|
| 241 |
final_answer: Optional[str] = None
|
| 242 |
|
| 243 |
for turn in range(1, max_turns + 1):
|
| 244 |
+
if state.trusted_notes and turn > 1 and turn % 3 == 0:
|
| 245 |
+
summary_lines = "\n".join(f"- {n}" for n in state.trusted_notes[-6:])
|
| 246 |
+
messages.append(
|
| 247 |
+
{
|
| 248 |
+
"role": "user",
|
| 249 |
+
"content": f"RESEARCH STATE SUMMARY\n{summary_lines}\nUse this summary to avoid repeating work.",
|
| 250 |
+
}
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
model_output, used_model = call_model(
|
| 254 |
client=client,
|
| 255 |
messages=messages,
|
| 256 |
+
preferred_model=model,
|
| 257 |
+
candidate_models=recent_model_candidates,
|
| 258 |
temperature=temperature,
|
| 259 |
max_new_tokens=1400,
|
| 260 |
)
|
|
|
|
| 283 |
query = str(tool_args.get("query", "")).strip()
|
| 284 |
max_results = int(tool_args.get("max_results", max_search_results))
|
| 285 |
max_results = max(1, min(max_results, 10))
|
| 286 |
+
if query in state.searched_query_set:
|
| 287 |
+
tool_response = {
|
| 288 |
+
"ok": True,
|
| 289 |
+
"query": query,
|
| 290 |
+
"cached": True,
|
| 291 |
+
"note": "This query was already searched. Reusing cached result to avoid duplicate work.",
|
| 292 |
+
"results": [],
|
| 293 |
+
}
|
| 294 |
+
else:
|
| 295 |
state.searched_queries.append(query)
|
| 296 |
+
state.searched_query_set.add(query)
|
| 297 |
+
tool_response = run_search(query=query, max_results=max_results)
|
| 298 |
+
if tool_response.get("ok"):
|
| 299 |
+
first_titles = [r.get("title", "") for r in tool_response.get("results", [])[:2]]
|
| 300 |
+
if first_titles:
|
| 301 |
+
state.trusted_notes.append(
|
| 302 |
+
f"Searched '{query}' and found leads: {', '.join(t for t in first_titles if t)}"
|
| 303 |
+
)
|
| 304 |
elif tool_name == "visit":
|
| 305 |
url = str(tool_args.get("url", "")).strip()
|
| 306 |
max_chars = int(tool_args.get("max_chars", 6000))
|
| 307 |
max_chars = max(500, min(max_chars, 20000))
|
| 308 |
+
if url in state.visited_url_set:
|
| 309 |
+
tool_response = {
|
| 310 |
+
"ok": True,
|
| 311 |
+
"url": url,
|
| 312 |
+
"cached": True,
|
| 313 |
+
"note": "This URL was already visited. Reusing cached result to avoid duplicate work.",
|
| 314 |
+
}
|
| 315 |
+
else:
|
| 316 |
state.visited_urls.append(url)
|
| 317 |
+
state.visited_url_set.add(url)
|
| 318 |
+
tool_response = run_visit(url=url, max_chars=max_chars)
|
| 319 |
+
if tool_response.get("ok"):
|
| 320 |
+
snippet = str(tool_response.get("content", ""))[:180]
|
| 321 |
+
if snippet:
|
| 322 |
+
state.trusted_notes.append(
|
| 323 |
+
f"Visited {url} and extracted key context: {snippet}"
|
| 324 |
+
)
|
| 325 |
else:
|
| 326 |
tool_response = {"ok": False, "error": f"Unknown tool: {tool_name}"}
|
| 327 |
|
|
|
|
| 342 |
)
|
| 343 |
|
| 344 |
citations = "\n".join(f"- {url}" for url in sorted(set(state.visited_urls)))
|
| 345 |
+
final_answer = f"**Model used:** `{used_model}`\n\n{final_answer}"
|
| 346 |
if citations:
|
| 347 |
final_answer = f"{final_answer}\n\n### Visited Sources\n{citations}"
|
| 348 |
|
| 349 |
trace_text = json.dumps(
|
| 350 |
{
|
| 351 |
+
"used_model": used_model,
|
| 352 |
"searched_queries": state.searched_queries,
|
| 353 |
"visited_urls": state.visited_urls,
|
| 354 |
+
"trusted_notes": state.trusted_notes[-10:],
|
| 355 |
"trace": state.trace,
|
| 356 |
},
|
| 357 |
ensure_ascii=False,
|