Karan6933 committed on
Commit
5f8879c
·
verified ·
1 Parent(s): 6306cb4

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +170 -265
main.py CHANGED
@@ -1,139 +1,96 @@
1
  """
2
  GenAI Advanced Agent - Production Ready
3
- Architecture: FastAPI + LangGraph + Ollama + DuckDuckGo
4
- Features: Streaming, Memory, Tools, Structured Output, Error Handling
5
  """
6
 
7
  import os
8
  import logging
9
  import asyncio
10
- import nest_asyncio # ADD THIS
11
  from typing import Annotated, TypedDict, List, Dict, Any, Optional, AsyncGenerator
12
  from contextlib import asynccontextmanager
13
  from datetime import datetime
14
  from enum import Enum
15
 
16
- from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
17
- from fastapi.responses import StreamingResponse, JSONResponse
18
  from fastapi.middleware.cors import CORSMiddleware
19
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
20
  from pydantic import BaseModel, Field, validator
21
 
22
- # Async & Network
23
  import httpx
24
  from duckduckgo_search import DDGS
25
  from bs4 import BeautifulSoup
26
 
27
- # LangChain / AI Core
28
  from langchain_ollama import ChatOllama
29
- from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage, ToolMessage
30
  from langchain_core.tools import tool, BaseTool
31
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
32
- from langchain_core.output_parsers import StrOutputParser
33
  from langchain_core.runnables import RunnableConfig
34
 
35
- # LangGraph
36
  from langgraph.graph import StateGraph, END, START
37
  from langgraph.prebuilt import ToolNode, tools_condition
38
  from langgraph.checkpoint.memory import MemorySaver
39
  from langgraph.checkpoint.base import BaseCheckpointSaver
40
 
41
- # FIX: Apply nest_asyncio for Jupyter/IPython environments
42
  nest_asyncio.apply()
43
 
44
  # --------------------------------------------------------------------------------------
45
- # CONFIGURATION & LOGGING
46
  # --------------------------------------------------------------------------------------
47
 
48
class Settings(BaseModel):
    """Application configuration"""
    # Ollama model and endpoint.
    MODEL_NAME: str = "qwen2.5:3b"
    BASE_URL: str = "http://localhost:11434"
    TEMPERATURE: float = 0.3
    MAX_TOKENS: int = 4096
    TIMEOUT: float = 30.0
    # Web-tool limits: result count per search, characters kept per page.
    MAX_SEARCH_RESULTS: int = 5
    MAX_CONTENT_LENGTH: int = 4000
    LOG_LEVEL: str = "INFO"

    class Config:
        # NOTE(review): env_file is honored by pydantic BaseSettings, not a
        # plain BaseModel — as written these defaults are never overridden
        # from .env; confirm whether BaseSettings was intended.
        env_file = ".env"
61
 
62
  settings = Settings()
63
 
64
- # Structured Logging
65
  logging.basicConfig(
66
- level=getattr(logging, settings.LOG_LEVEL),
67
- format='%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s'
68
  )
69
  logger = logging.getLogger("GenAI-Agent")
70
 
71
  # --------------------------------------------------------------------------------------
72
- # MODELS & SCHEMAS
73
  # --------------------------------------------------------------------------------------
74
 
75
class MessageType(str, Enum):
    """Role tag for a chat message; str-valued so it serializes as plain JSON."""
    HUMAN = "human"
    AI = "ai"
    SYSTEM = "system"
    TOOL = "tool"
80
-
81
class ChatMessage(BaseModel):
    """One message of client-supplied conversation history."""
    role: MessageType
    content: str
    # Stamped at model-construction time when the client omits it.
    timestamp: Optional[datetime] = Field(default_factory=datetime.now)
    metadata: Optional[Dict[str, Any]] = None
86
-
87
class ChatRequest(BaseModel):
    """Request body for the chat endpoints."""
    query: str = Field(..., min_length=1, max_length=10000, description="User query")
    thread_id: str = Field(..., min_length=1, description="Conversation thread ID")
    stream: bool = Field(default=True, description="Enable streaming response")
    context: Optional[List[ChatMessage]] = Field(default=None, description="Previous messages")

    @validator('thread_id')
    def validate_thread_id(cls, v):
        # min_length=1 still admits whitespace-only IDs, so re-check after strip.
        if not v.strip():
            raise ValueError("thread_id cannot be empty")
        return v.strip()
98
 
99
class ChatResponse(BaseModel):
    """Response body for non-streaming chat."""
    response: str
    thread_id: str
    tools_used: List[str]
    tokens_used: Optional[int] = None
    # Seconds. NOTE(review): the /chat handler currently hard-codes this to 0.0.
    processing_time: float
105
-
106
class HealthStatus(BaseModel):
    """Payload returned by GET /health."""
    status: str
    model: str
    version: str
    timestamp: datetime
111
-
112
  # --------------------------------------------------------------------------------------
113
- # STATE MANAGEMENT
114
  # --------------------------------------------------------------------------------------
115
 
116
class AgentState(TypedDict):
    """LangGraph state definition"""
    # NOTE(review): the strings "add_messages"/"append" inside Annotated are
    # inert metadata, not reducer callables — LangGraph will not merge these
    # fields automatically; confirm langgraph's add_messages function was
    # intended here.
    messages: Annotated[List[BaseMessage], "add_messages"]
    thread_id: str
    tools_used: Annotated[List[str], "append"]
    metadata: Dict[str, Any]
 
122
 
123
  # --------------------------------------------------------------------------------------
124
- # TOOLS IMPLEMENTATION
125
  # --------------------------------------------------------------------------------------
126
 
127
  class ToolRegistry:
128
- """Centralized tool management with caching and metrics"""
129
-
130
  def __init__(self):
131
  self._tools: Dict[str, BaseTool] = {}
132
- self._metrics: Dict[str, Dict] = {}
133
 
134
  def register(self, tool_instance: BaseTool):
135
  self._tools[tool_instance.name] = tool_instance
136
- self._metrics[tool_instance.name] = {"calls": 0, "errors": 0, "avg_time": 0}
137
  return tool_instance
138
 
139
  def get(self, name: str) -> Optional[BaseTool]:
@@ -141,159 +98,116 @@ class ToolRegistry:
141
 
142
  def all_tools(self) -> List[BaseTool]:
143
  return list(self._tools.values())
144
-
145
- def record_usage(self, name: str, duration: float, error: bool = False):
146
- if name in self._metrics:
147
- self._metrics[name]["calls"] += 1
148
- if error:
149
- self._metrics[name]["errors"] += 1
150
- # Update running average
151
- prev_avg = self._metrics[name]["avg_time"]
152
- n = self._metrics[name]["calls"]
153
- self._metrics[name]["avg_time"] = (prev_avg * (n-1) + duration) / n
154
 
155
  tool_registry = ToolRegistry()
156
 
157
@tool
async def web_search(query: str, max_results: int = 5) -> str:
    """
    Advanced web search with result ranking and filtering.
    Use for: current events, technical documentation, news, facts verification.
    """
    # Wall-clock start for the registry's running-average latency metric.
    start_time = asyncio.get_event_loop().time()

    def _sync_search(q: str):
        # DDGS is blocking; executed off the event loop via asyncio.to_thread below.
        try:
            with DDGS() as ddgs:
                results = ddgs.text(q, max_results=max_results)
                return list(results)
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    try:
        logger.info(f"🔍 Web search: {query}")
        results = await asyncio.to_thread(_sync_search, query)

        if not results:
            # NOTE(review): this early return skips record_usage, so empty
            # searches never show up in the tool metrics.
            return "No relevant results found."

        # Number the hits so the model can cite them as [1], [2], ...
        formatted = []
        for idx, r in enumerate(results, 1):
            formatted.append(
                f"[{idx}] {r.get('title', 'Untitled')}\n"
                f"URL: {r.get('href', 'N/A')}\n"
                f"Summary: {r.get('body', 'No description')}\n"
            )

        duration = asyncio.get_event_loop().time() - start_time
        tool_registry.record_usage("web_search", duration)

        return "\n".join(formatted)

    except Exception as e:
        # Errors are recorded (with zero duration) and surfaced as a string
        # so the agent can react instead of the graph crashing.
        tool_registry.record_usage("web_search", 0, error=True)
        return f"Search error: {str(e)}"
197
 
198
@tool
async def read_webpage(url: str, extract_code: bool = False) -> str:
    """
    Intelligent webpage reader with content extraction and cleaning.
    Use for: deep technical details, documentation, code examples.
    """
    # Wall-clock start for the registry's running-average latency metric.
    start_time = asyncio.get_event_loop().time()

    try:
        logger.info(f"📖 Reading: {url}")

        # Browser-like headers to reduce trivial bot blocking.
        headers = {
            "User-Agent": "Mozilla/5.0 (AppleWebKit/537.36) GenAI-Agent/2.0",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.5",
            "Accept-Encoding": "gzip, deflate",
            "DNT": "1",
            "Connection": "keep-alive",
        }

        async with httpx.AsyncClient(timeout=20.0, follow_redirects=True) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()

        soup = BeautifulSoup(response.text, 'lxml')

        # Remove noise
        for element in soup(["script", "style", "nav", "footer", "header",
                             "aside", "advertisement", "svg", "iframe"]):
            element.decompose()

        # Extract main content (prefer article/main tags)
        main_content = soup.find('article') or soup.find('main') or soup.find('body')

        if extract_code:
            # Extract code blocks specifically.
            # NOTE(review): main_content can be None for pathological documents,
            # which would raise AttributeError here; the generic except below
            # converts that into a "Scraping error" string.
            code_blocks = main_content.find_all(['pre', 'code'])
            code_content = '\n\n'.join(
                block.get_text() for block in code_blocks if block.get_text().strip()
            )
            if code_content:
                return f"Code extracted:\n{code_content[:settings.MAX_CONTENT_LENGTH]}"

        # Clean text
        text = main_content.get_text(separator='\n') if main_content else soup.get_text()
        lines = (line.strip() for line in text.splitlines())
        # NOTE(review): splitting on a single space puts every word on its own
        # line; the common recipe splits on a double space ("  ") — possibly
        # mangled in transit; confirm intent.
        chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
        clean_text = '\n'.join(chunk for chunk in chunks if chunk)

        # Smart truncation with context preservation
        if len(clean_text) > settings.MAX_CONTENT_LENGTH:
            truncated = clean_text[:settings.MAX_CONTENT_LENGTH]
            # Try to end at a sentence boundary
            last_period = truncated.rfind('.')
            if last_period > len(truncated) * 0.8:
                truncated = truncated[:last_period + 1]
            clean_text = truncated + "\n\n[Content truncated...]"

        duration = asyncio.get_event_loop().time() - start_time
        tool_registry.record_usage("read_webpage", duration)

        return clean_text

    except httpx.HTTPStatusError as e:
        tool_registry.record_usage("read_webpage", 0, error=True)
        return f"HTTP Error {e.response.status_code}: Unable to access {url}"
    except Exception as e:
        tool_registry.record_usage("read_webpage", 0, error=True)
        return f"Scraping error: {str(e)}"
267
 
268
  @tool
269
async def calculate(expression: str) -> str:
    """
    Safe mathematical expression evaluator.
    Use for: calculations, data processing, unit conversions.

    Evaluates arithmetic (+, -, *, /, //, %, **, unary +/-), tuple/list
    literals, and calls to a small whitelist of numeric builtins by walking
    the parsed AST. eval() with an emptied __builtins__ is NOT safe — it is
    escapable through attribute chains such as ().__class__.__mro__ — so no
    form of eval is used at all.

    Returns "Result: <value>" on success, "Calculation error: ..." otherwise.
    """
    import ast
    import operator as op

    # Same callables the previous eval-based version exposed.
    allowed_funcs = {
        "abs": abs, "round": round, "max": max, "min": min,
        "sum": sum, "pow": pow, "len": len
    }
    binary_ops = {
        ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
        ast.Div: op.truediv, ast.FloorDiv: op.floordiv,
        ast.Mod: op.mod, ast.Pow: op.pow,
    }
    unary_ops = {ast.UAdd: op.pos, ast.USub: op.neg}

    def _eval(node):
        # Recursive whitelist evaluator; any unsupported node raises ValueError.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        if isinstance(node, (ast.Tuple, ast.List)):
            # Sequence literals so sum([...]) / max((...)) keep working.
            return [_eval(elt) for elt in node.elts]
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                and node.func.id in allowed_funcs and not node.keywords):
            return allowed_funcs[node.func.id](*[_eval(arg) for arg in node.args])
        raise ValueError(f"unsupported expression: {type(node).__name__}")

    try:
        result = _eval(ast.parse(expression, mode="eval"))
        return f"Result: {result}"
    except Exception as e:
        return f"Calculation error: {str(e)}"
284
 
285
- # Register all tools
286
  tool_registry.register(web_search)
287
  tool_registry.register(read_webpage)
288
  tool_registry.register(calculate)
289
 
290
  # --------------------------------------------------------------------------------------
291
- # LANGGRAPH AGENT ARCHITECTURE
292
  # --------------------------------------------------------------------------------------
293
 
294
  class AgentBuilder:
295
- """Factory for building configurable LangGraph agents"""
296
-
297
  def __init__(self, model_name: str, base_url: str, temperature: float = 0.3):
298
  self.model_name = model_name
299
  self.base_url = base_url
@@ -306,7 +220,6 @@ class AgentBuilder:
306
  return self
307
 
308
  def build(self) -> StateGraph:
309
- # Initialize LLM with tools
310
  llm = ChatOllama(
311
  model=self.model_name,
312
  base_url=self.base_url,
@@ -315,59 +228,122 @@ class AgentBuilder:
315
  num_ctx=8192
316
  ).bind_tools(self.tools)
317
 
318
- # System prompt with dynamic tool descriptions
319
  tool_descriptions = "\n".join([
320
  f"- {t.name}: {t.description}" for t in self.tools
321
  ])
322
 
323
- system_prompt = f"""You are an advanced GenAI technical assistant with access to real-time tools.
 
324
 
325
- AVAILABLE TOOLS:
326
  {tool_descriptions}
327
 
328
- CORE INSTRUCTIONS:
329
- 1. **Always use tools** for current information, calculations, or external data
330
- 2. **Chain tools intelligently**: Search Read Analyze
331
- 3. **Format responses**:
332
- - Start with a brief executive summary
333
- - Use markdown headers (##) for sections
334
- - For code, use XML tags: <code lang="python">your code</code>
335
- - Cite sources when using web data
336
- 4. **Be concise** but thorough. Avoid hallucinations.
 
 
337
 
338
  Current date: {datetime.now().strftime("%Y-%m-%d")}
339
  """
340
 
341
- # Agent node
342
  async def agent_node(state: AgentState):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  messages = [SystemMessage(content=system_prompt)] + state["messages"]
344
  response = await llm.ainvoke(messages)
345
 
346
  # Track tool usage
 
347
  if response.tool_calls:
348
- state["tools_used"].extend([
349
- tc["name"] for tc in response.tool_calls
350
- ])
351
 
352
- return {"messages": [response], "tools_used": state["tools_used"]}
 
 
 
 
 
353
 
354
- # Tool node with error handling
355
- tool_node = ToolNode(self.tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
- # Build graph
358
  workflow = StateGraph(AgentState)
359
  workflow.add_node("agent", agent_node)
360
- workflow.add_node("tools", tool_node)
361
 
362
  workflow.add_edge(START, "agent")
363
  workflow.add_conditional_edges(
364
  "agent",
365
- tools_condition,
366
  {"tools": "tools", END: END}
367
  )
368
  workflow.add_edge("tools", "agent")
369
 
370
- # Compile with memory
371
  if self.checkpointer:
372
  return workflow.compile(checkpointer=self.checkpointer)
373
  return workflow.compile()
@@ -376,7 +352,6 @@ Current date: {datetime.now().strftime("%Y-%m-%d")}
376
  # FASTAPI APPLICATION
377
  # --------------------------------------------------------------------------------------
378
 
379
- # Global state
380
  class AppState:
381
  def __init__(self):
382
  self.http_client: Optional[httpx.AsyncClient] = None
@@ -387,8 +362,6 @@ app_state = AppState()
387
 
388
  @asynccontextmanager
389
  async def lifespan(app: FastAPI):
390
- """Application lifecycle management"""
391
- # Startup
392
  logger.info("🚀 Starting GenAI Agent...")
393
 
394
  app_state.http_client = httpx.AsyncClient(
@@ -396,7 +369,6 @@ async def lifespan(app: FastAPI):
396
  limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
397
  )
398
 
399
- # Initialize memory and agent
400
  app_state.memory = MemorySaver()
401
  builder = AgentBuilder(
402
  model_name=settings.MODEL_NAME,
@@ -405,25 +377,19 @@ async def lifespan(app: FastAPI):
405
  )
406
  app_state.agent = builder.with_memory(app_state.memory).build()
407
 
408
- logger.info(f"✅ Agent ready with model: {settings.MODEL_NAME}")
409
  yield
410
 
411
- # Shutdown
412
  logger.info("🛑 Shutting down...")
413
  if app_state.http_client:
414
  await app_state.http_client.aclose()
415
 
416
- # Create FastAPI app
417
  app = FastAPI(
418
- title="GenAI Advanced Agent API",
419
- description="Production-ready AI agent with web search, RAG, and memory",
420
- version="3.0.0",
421
- lifespan=lifespan,
422
- docs_url="/docs",
423
- redoc_url="/redoc"
424
  )
425
 
426
- # CORS
427
  app.add_middleware(
428
  CORSMiddleware,
429
  allow_origins=["*"],
@@ -432,45 +398,28 @@ app.add_middleware(
432
  allow_headers=["*"],
433
  )
434
 
435
- # Security
436
- security = HTTPBearer(auto_error=False)
437
-
438
  # --------------------------------------------------------------------------------------
439
  # API ENDPOINTS
440
  # --------------------------------------------------------------------------------------
441
 
442
- @app.get("/health", response_model=HealthStatus)
443
  async def health_check():
444
- """Health check endpoint"""
445
- return HealthStatus(
446
- status="healthy",
447
- model=settings.MODEL_NAME,
448
- version="3.0.0",
449
- timestamp=datetime.now()
450
- )
451
-
452
- @app.get("/tools")
453
- async def list_tools():
454
- """List available tools and their metrics"""
455
  return {
456
- "tools": [
457
- {
458
- "name": name,
459
- "description": tool_registry.get(name).description,
460
- "metrics": tool_registry._metrics.get(name, {})
461
- }
462
- for name in tool_registry._tools.keys()
463
- ]
464
  }
465
 
466
  async def stream_response(query: str, thread_id: str) -> AsyncGenerator[str, None]:
467
- """Generate streaming response with real-time updates"""
468
  config = RunnableConfig(configurable={"thread_id": thread_id})
469
  inputs = {
470
  "messages": [HumanMessage(content=query)],
471
  "thread_id": thread_id,
472
  "tools_used": [],
473
- "metadata": {}
 
474
  }
475
 
476
  yield f"event: start\ndata: {thread_id}\n\n"
@@ -489,8 +438,10 @@ async def stream_response(query: str, thread_id: str) -> AsyncGenerator[str, Non
489
  yield f"event: tool_start\ndata: {tool_name}\n\n"
490
 
491
  elif event_type == "on_tool_end":
492
- output = str(event["data"].get("output", ""))[:200]
493
- yield f"event: tool_end\ndata: {output}...\n\n"
 
 
494
 
495
  yield "event: complete\ndata: done\n\n"
496
 
@@ -500,9 +451,6 @@ async def stream_response(query: str, thread_id: str) -> AsyncGenerator[str, Non
500
 
501
  @app.post("/chat")
502
  async def chat_endpoint(request: ChatRequest):
503
- """
504
- Main chat endpoint with streaming support
505
- """
506
  try:
507
  if request.stream:
508
  return StreamingResponse(
@@ -511,78 +459,35 @@ async def chat_endpoint(request: ChatRequest):
511
  headers={
512
  "Cache-Control": "no-cache",
513
  "Connection": "keep-alive",
514
- "X-Thread-ID": request.thread_id
515
  }
516
  )
517
  else:
518
- # Non-streaming response
519
  config = RunnableConfig(configurable={"thread_id": request.thread_id})
520
  inputs = {
521
  "messages": [HumanMessage(content=request.query)],
522
  "thread_id": request.thread_id,
523
  "tools_used": [],
524
- "metadata": {}
 
525
  }
526
 
527
  result = await app_state.agent.ainvoke(inputs, config=config)
528
  final_message = result["messages"][-1]
529
 
530
- return ChatResponse(
531
- response=final_message.content,
532
- thread_id=request.thread_id,
533
- tools_used=result.get("tools_used", []),
534
- processing_time=0.0 # Calculate if needed
535
- )
536
 
537
  except Exception as e:
538
  logger.error(f"Chat error: {e}")
539
  raise HTTPException(status_code=500, detail=str(e))
540
 
541
- @app.post("/chat/sync", response_model=ChatResponse)
542
- async def chat_sync(request: ChatRequest):
543
- """Synchronous chat endpoint for simple requests"""
544
- return await chat_endpoint(request)
545
-
546
- @app.delete("/memory/{thread_id}")
547
- async def clear_memory(thread_id: str):
548
- """Clear conversation memory for a thread"""
549
- try:
550
- # MemorySaver specific implementation
551
- if hasattr(app_state.memory, 'delete'):
552
- await app_state.memory.delete(thread_id)
553
- return {"status": "success", "message": f"Memory cleared for {thread_id}"}
554
- except Exception as e:
555
- raise HTTPException(status_code=500, detail=str(e))
556
-
557
- @app.get("/memory/{thread_id}")
558
- async def get_conversation(thread_id: str):
559
- """Retrieve conversation history"""
560
- try:
561
- config = RunnableConfig(configurable={"thread_id": thread_id})
562
- # This depends on your checkpointer implementation
563
- return {"thread_id": thread_id, "history": []}
564
- except Exception as e:
565
- raise HTTPException(status_code=500, detail=str(e))
566
-
567
- # --------------------------------------------------------------------------------------
568
- # MAIN ENTRY - FIXED FOR JUPYTER
569
- # --------------------------------------------------------------------------------------
570
-
571
def run_server():
    """Run server with proper async handling for Jupyter"""
    # Local import keeps uvicorn optional for library-style imports of this module.
    import uvicorn

    # uvicorn.run drives its own event loop; together with nest_asyncio.apply()
    # at module import time this also works inside Jupyter/IPython.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info"
    )
583
- # For Jupyter/IPython - run directly
584
- # run_server()
585
 
586
- # For normal Python execution
587
  if __name__ == "__main__":
588
  run_server()
 
1
  """
2
  GenAI Advanced Agent - Production Ready
3
+ Fixed: Infinite loop prevention, better tool error handling
 
4
  """
5
 
6
  import os
7
  import logging
8
  import asyncio
9
+ import nest_asyncio
10
  from typing import Annotated, TypedDict, List, Dict, Any, Optional, AsyncGenerator
11
  from contextlib import asynccontextmanager
12
  from datetime import datetime
13
  from enum import Enum
14
 
15
+ from fastapi import FastAPI, HTTPException
16
+ from fastapi.responses import StreamingResponse
17
  from fastapi.middleware.cors import CORSMiddleware
 
18
  from pydantic import BaseModel, Field, validator
19
 
 
20
  import httpx
21
  from duckduckgo_search import DDGS
22
  from bs4 import BeautifulSoup
23
 
 
24
  from langchain_ollama import ChatOllama
25
+ from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage, ToolMessage, AIMessage
26
  from langchain_core.tools import tool, BaseTool
 
 
27
  from langchain_core.runnables import RunnableConfig
28
 
 
29
  from langgraph.graph import StateGraph, END, START
30
  from langgraph.prebuilt import ToolNode, tools_condition
31
  from langgraph.checkpoint.memory import MemorySaver
32
  from langgraph.checkpoint.base import BaseCheckpointSaver
33
 
 
34
  nest_asyncio.apply()
35
 
36
  # --------------------------------------------------------------------------------------
37
+ # CONFIGURATION
38
  # --------------------------------------------------------------------------------------
39
 
40
class Settings(BaseModel):
    """Application configuration."""
    # Ollama model and endpoint.
    MODEL_NAME: str = "qwen2.5:3b"
    BASE_URL: str = "http://localhost:11434"
    TEMPERATURE: float = 0.3
    MAX_ITERATIONS: int = 3  # Prevent infinite loops
    # Web-tool limits: result count per search, characters kept per page.
    MAX_SEARCH_RESULTS: int = 5
    MAX_CONTENT_LENGTH: int = 4000
    TIMEOUT: float = 30.0

    class Config:
        # NOTE(review): env_file is honored by pydantic BaseSettings, not a
        # plain BaseModel — as written these defaults are never overridden
        # from .env; confirm whether BaseSettings was intended.
        env_file = ".env"
51
 
52
  settings = Settings()
53
 
 
54
  logging.basicConfig(
55
+ level=logging.INFO,
56
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
57
  )
58
  logger = logging.getLogger("GenAI-Agent")
59
 
60
  # --------------------------------------------------------------------------------------
61
+ # MODELS
62
  # --------------------------------------------------------------------------------------
63
 
 
 
 
 
 
 
 
 
 
 
 
 
64
class ChatRequest(BaseModel):
    """Request body for the /chat endpoint."""
    query: str = Field(..., min_length=1, max_length=10000)
    thread_id: str = Field(..., min_length=1)
    stream: bool = Field(default=True)

    @validator('thread_id')
    def validate_thread_id(cls, v):
        # FIX: min_length=1 is checked on the raw value, so a whitespace-only
        # thread_id (" ") used to slip through and silently become "". Reject
        # it explicitly, as the previous revision of this model did.
        if not v.strip():
            raise ValueError("thread_id cannot be empty")
        return v.strip()
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # --------------------------------------------------------------------------------------
74
+ # STATE MANAGEMENT - FIXED: Added iteration counter
75
  # --------------------------------------------------------------------------------------
76
 
77
class AgentState(TypedDict):
    """LangGraph state carried between the agent and tool nodes."""
    # NOTE(review): the strings "add_messages"/"append" inside Annotated are
    # inert metadata, not reducer callables — LangGraph will not merge these
    # fields automatically; confirm langgraph's add_messages function was
    # intended here.
    messages: Annotated[List[BaseMessage], "add_messages"]
    thread_id: str
    tools_used: Annotated[List[str], "append"]
    iteration_count: int  # NEW: Track iterations to prevent loops
    last_tool_result: Optional[str]  # NEW: Track last tool result ("error" on tool failure)
83
 
84
  # --------------------------------------------------------------------------------------
85
+ # TOOLS - FIXED: Better error messages and validation
86
  # --------------------------------------------------------------------------------------
87
 
88
  class ToolRegistry:
 
 
89
  def __init__(self):
90
  self._tools: Dict[str, BaseTool] = {}
 
91
 
92
  def register(self, tool_instance: BaseTool):
93
  self._tools[tool_instance.name] = tool_instance
 
94
  return tool_instance
95
 
96
  def get(self, name: str) -> Optional[BaseTool]:
 
98
 
99
  def all_tools(self) -> List[BaseTool]:
100
  return list(self._tools.values())
 
 
 
 
 
 
 
 
 
 
101
 
102
  tool_registry = ToolRegistry()
103
 
104
@tool
async def web_search(query: str, max_results: int = 5) -> str:
    """
    Search the web for current information. Returns formatted search results.
    Use this for: current events, documentation, news, facts.

    Failures are reported as strings starting with "ERROR" (never raised) so
    the agent graph can detect them and stop retrying.
    """
    # Reject degenerate queries before doing any network work.
    if not query or len(query.strip()) < 2:
        return "ERROR: Query too short or empty"

    # FIX: clamp the LLM-supplied result count — previously an arbitrary
    # max_results could request an unbounded number of results.
    max_results = max(1, min(max_results, settings.MAX_SEARCH_RESULTS))

    def _sync_search(q: str):
        # DDGS is blocking; executed off the event loop via asyncio.to_thread.
        # Failures come back as an "ERROR: ..." string, not an exception.
        try:
            with DDGS() as ddgs:
                return list(ddgs.text(q, max_results=max_results))
        except Exception as e:
            logger.error(f"DDGS Error: {e}")
            return f"ERROR: Search failed - {str(e)}"

    try:
        logger.info(f"🔍 Searching: {query}")
        results = await asyncio.to_thread(_sync_search, query)

        # _sync_search signals failure with an error string instead of a list.
        if isinstance(results, str) and results.startswith("ERROR"):
            return results

        if not results:
            return "ERROR: No results found for this query. Try a different search term."

        # Number the hits so the model can cite them as [1], [2], ...
        formatted = []
        for idx, r in enumerate(results, 1):
            title = r.get('title', 'Untitled')
            link = r.get('href', 'N/A')
            body = r.get('body', 'No description')
            formatted.append(f"[{idx}] {title}\nURL: {link}\nSummary: {body}\n")

        return "\n".join(formatted)

    except Exception as e:
        logger.error(f"Search error: {e}")
        return f"ERROR: {str(e)}"
145
 
146
@tool
async def read_webpage(url: str) -> str:
    """
    Read content from a specific URL. Use for detailed documentation.

    Returns cleaned page text (truncated to settings.MAX_CONTENT_LENGTH) or a
    string starting with "ERROR" on any failure — this function never raises.
    """
    # Only plain http(s) URLs are fetched.
    if not url.startswith(('http://', 'https://')):
        return "ERROR: Invalid URL format"

    try:
        logger.info(f"📖 Reading: {url}")

        headers = {
            "User-Agent": "Mozilla/5.0 GenAI-Agent/2.0",
            "Accept": "text/html,application/xhtml+xml",
        }

        async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()

        soup = BeautifulSoup(response.text, 'lxml')

        # Remove boilerplate elements that would pollute the extracted text.
        for element in soup(["script", "style", "nav", "footer", "header"]):
            element.decompose()

        # Prefer semantic containers; fall back to the whole document.
        main_content = soup.find('article') or soup.find('main') or soup.find('body')
        text = main_content.get_text(separator='\n') if main_content else soup.get_text()

        # Collapse whitespace: strip each line, break on double-space runs,
        # drop blanks. FIX: splitting on a single space put every word on its
        # own line; the standard recipe splits on two spaces.
        lines = (line.strip() for line in text.splitlines())
        chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
        clean_text = '\n'.join(chunk for chunk in chunks if chunk)

        # A near-empty page usually means a bot wall or JS-only content.
        if len(clean_text) < 100:
            return "ERROR: Content too short or page blocked"

        # FIX: explicit truncation branch — the previous single-line
        # conditional expression ("a + b if cond else c") was an easy
        # precedence misread.
        if len(clean_text) > settings.MAX_CONTENT_LENGTH:
            return clean_text[:settings.MAX_CONTENT_LENGTH] + "\n[Content truncated...]"
        return clean_text

    except httpx.HTTPStatusError as e:
        return f"ERROR: HTTP {e.response.status_code} - Unable to access page"
    except Exception as e:
        return f"ERROR: {str(e)}"
 
189
 
190
  @tool
191
async def calculate(expression: str) -> str:
    """
    Calculate mathematical expressions safely.

    Evaluates arithmetic (+, -, *, /, //, %, **, unary +/-), tuple/list
    literals, and calls to a whitelist of numeric builtins by walking the
    parsed AST. eval() with an emptied __builtins__ is NOT safe — it is
    escapable through attribute chains such as ().__class__.__mro__ — so no
    form of eval is used at all.

    Returns "Result: <value>" on success, "ERROR: Invalid expression - ..."
    otherwise (matching this module's tool error-string convention).
    """
    import ast
    import operator as op

    # Same callables the previous eval-based version exposed.
    allowed_funcs = {"abs": abs, "round": round, "max": max, "min": min, "sum": sum, "pow": pow}
    binary_ops = {
        ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
        ast.Div: op.truediv, ast.FloorDiv: op.floordiv,
        ast.Mod: op.mod, ast.Pow: op.pow,
    }
    unary_ops = {ast.UAdd: op.pos, ast.USub: op.neg}

    def _eval(node):
        # Recursive whitelist evaluator; any unsupported node raises ValueError.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        if isinstance(node, (ast.Tuple, ast.List)):
            # Sequence literals so sum([...]) / max((...)) keep working.
            return [_eval(elt) for elt in node.elts]
        if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                and node.func.id in allowed_funcs and not node.keywords):
            return allowed_funcs[node.func.id](*[_eval(arg) for arg in node.args])
        raise ValueError(f"unsupported expression: {type(node).__name__}")

    try:
        result = _eval(ast.parse(expression, mode="eval"))
        return f"Result: {result}"
    except Exception as e:
        return f"ERROR: Invalid expression - {str(e)}"
201
 
 
202
  tool_registry.register(web_search)
203
  tool_registry.register(read_webpage)
204
  tool_registry.register(calculate)
205
 
206
  # --------------------------------------------------------------------------------------
207
+ # LANGGRAPH - FIXED: Added iteration limit and better routing
208
  # --------------------------------------------------------------------------------------
209
 
210
  class AgentBuilder:
 
 
211
  def __init__(self, model_name: str, base_url: str, temperature: float = 0.3):
212
  self.model_name = model_name
213
  self.base_url = base_url
 
220
  return self
221
 
222
  def build(self) -> StateGraph:
 
223
  llm = ChatOllama(
224
  model=self.model_name,
225
  base_url=self.base_url,
 
228
  num_ctx=8192
229
  ).bind_tools(self.tools)
230
 
 
231
  tool_descriptions = "\n".join([
232
  f"- {t.name}: {t.description}" for t in self.tools
233
  ])
234
 
235
+ # FIXED: Stronger instructions to prevent loops
236
+ system_prompt = f"""You are an advanced AI assistant with tools.
237
 
238
+ TOOLS AVAILABLE:
239
  {tool_descriptions}
240
 
241
+ CRITICAL RULES:
242
+ 1. **MAXIMUM 2 tool calls per conversation** - After that, answer with available info
243
+ 2. **NEVER call the same tool twice** with similar queries
244
+ 3. **If a tool returns ERROR**, do NOT retry - explain the limitation to user
245
+ 4. **If web_search returns no results**, tell user you couldn't find info online
246
+ 5. **DO NOT LOOP** - If you've searched once, don't search again
247
+
248
+ Response Format:
249
+ - Start with brief summary
250
+ - Use ## for headers
251
+ - Use XML for code: <code lang="python">code</code>
252
 
253
  Current date: {datetime.now().strftime("%Y-%m-%d")}
254
  """
255
 
 
256
        async def agent_node(state: AgentState):
            """Single LLM step: prepend the system prompt, invoke the bound model.

            If the iteration budget (settings.MAX_ITERATIONS) is already spent,
            the LLM is skipped entirely and a canned wrap-up AIMessage is
            appended instead.
            """
            # FIXED: Check iteration limit
            if state.get("iteration_count", 0) >= settings.MAX_ITERATIONS:
                logger.warning("Max iterations reached, forcing end")
                # Force final response.
                # NOTE(review): this returns the FULL history plus the canned
                # message, while the normal path below returns only the new
                # message — whether that replaces or duplicates history depends
                # on how the "add_messages" annotation on AgentState is
                # interpreted by LangGraph; confirm it does not double entries.
                messages = state["messages"] + [
                    AIMessage(content="I've reached the maximum number of tool calls. Let me provide the best answer based on the information gathered so far.")
                ]
                return {
                    "messages": messages,
                    "iteration_count": state["iteration_count"],
                    "tools_used": state.get("tools_used", []),
                    "last_tool_result": state.get("last_tool_result")
                }

            messages = [SystemMessage(content=system_prompt)] + state["messages"]
            response = await llm.ainvoke(messages)

            # Track tool usage
            tools_used = state.get("tools_used", []).copy()  # copy: don't mutate shared state in place
            if response.tool_calls:
                tools_used.extend([tc["name"] for tc in response.tool_calls])

            return {
                "messages": [response],
                "iteration_count": state.get("iteration_count", 0) + 1,
                "tools_used": tools_used,
                "last_tool_result": state.get("last_tool_result")
            }
285
 
286
+ # FIXED: Custom tool node with error tracking
287
        async def tool_node_with_tracking(state: AgentState):
            """Run the prebuilt ToolNode, then tag the state when a tool errored.

            The tools in this module report failures as strings starting with
            "ERROR" (or containing "No results found") instead of raising, so
            errors are detected by inspecting the last message's content.
            Setting last_tool_result="error" makes should_continue end the run.
            """
            tool_node = ToolNode(self.tools)
            result = await tool_node.ainvoke(state)

            # Check if tool returned error
            last_msg = result["messages"][-1] if result["messages"] else None
            if last_msg and hasattr(last_msg, 'content'):
                content = str(last_msg.content)
                if content.startswith("ERROR") or "No results found" in content:
                    logger.warning(f"Tool error detected: {content[:100]}")
                    # Add error context to state
                    result["last_tool_result"] = "error"

            # Increment iteration count
            result["iteration_count"] = state.get("iteration_count", 0) + 1
            result["tools_used"] = state.get("tools_used", [])

            return result
305
+
306
+ # FIXED: Better conditional routing
307
def should_continue(state: AgentState) -> str:
    """Route after the agent node: continue to the tools node or stop.

    Stops when the iteration budget is spent, when the previous tool
    reported an error, or when a repeated web_search would loop;
    otherwise routes to "tools" whenever the last message requests
    tool calls.
    """
    tail = state["messages"][-1] if state["messages"] else None

    # Hard stop: iteration budget exhausted.
    if state.get("iteration_count", 0) >= settings.MAX_ITERATIONS:
        logger.info("Max iterations reached, ending")
        return END

    # Hard stop: the previous tool reported an error.
    if state.get("last_tool_result") == "error":
        logger.info("Previous tool had error, ending to prevent loop")
        return END

    # No pending tool calls means nothing left to do.
    if not (hasattr(tail, 'tool_calls') and tail.tool_calls):
        return END

    # Guard: never run web_search again once it has already executed.
    requested = [tc["name"] for tc in tail.tool_calls]
    already_used = state.get("tools_used", [])
    if requested.count("web_search") > 0 and already_used.count("web_search") >= 1:
        logger.warning("Preventing web_search loop")
        return END

    return "tools"
334
 
 
335
  workflow = StateGraph(AgentState)
336
  workflow.add_node("agent", agent_node)
337
+ workflow.add_node("tools", tool_node_with_tracking)
338
 
339
  workflow.add_edge(START, "agent")
340
  workflow.add_conditional_edges(
341
  "agent",
342
+ should_continue,
343
  {"tools": "tools", END: END}
344
  )
345
  workflow.add_edge("tools", "agent")
346
 
 
347
  if self.checkpointer:
348
  return workflow.compile(checkpointer=self.checkpointer)
349
  return workflow.compile()
 
352
  # FASTAPI APPLICATION
353
  # --------------------------------------------------------------------------------------
354
 
 
355
  class AppState:
356
  def __init__(self):
357
  self.http_client: Optional[httpx.AsyncClient] = None
 
362
 
363
  @asynccontextmanager
364
  async def lifespan(app: FastAPI):
 
 
365
  logger.info("🚀 Starting GenAI Agent...")
366
 
367
  app_state.http_client = httpx.AsyncClient(
 
369
  limits=httpx.Limits(max_keepalive_connections=20, max_connections=100)
370
  )
371
 
 
372
  app_state.memory = MemorySaver()
373
  builder = AgentBuilder(
374
  model_name=settings.MODEL_NAME,
 
377
  )
378
  app_state.agent = builder.with_memory(app_state.memory).build()
379
 
380
+ logger.info(f"✅ Agent ready: {settings.MODEL_NAME}")
381
  yield
382
 
 
383
  logger.info("🛑 Shutting down...")
384
  if app_state.http_client:
385
  await app_state.http_client.aclose()
386
 
 
387
# Application object; the lifespan hook wires up the HTTP client and agent.
app = FastAPI(lifespan=lifespan, title="GenAI Agent API", version="3.1.0")
392
 
 
393
  app.add_middleware(
394
  CORSMiddleware,
395
  allow_origins=["*"],
 
398
  allow_headers=["*"],
399
  )
400
 
 
 
 
401
  # --------------------------------------------------------------------------------------
402
  # API ENDPOINTS
403
  # --------------------------------------------------------------------------------------
404
 
405
@app.get("/health")
async def health_check():
    """Liveness/readiness probe.

    Returns static service metadata plus a timezone-aware UTC timestamp
    (FastAPI serialises the datetime to an ISO-8601 string).
    """
    # Local import: the file-level import only brings in `datetime` itself.
    from datetime import timezone

    return {
        "status": "healthy",
        "model": settings.MODEL_NAME,
        "version": "3.1.0",
        "max_iterations": settings.MAX_ITERATIONS,
        # FIX: naive datetime.now() depends on server-local time, which
        # varies by deployment; report an unambiguous aware UTC timestamp.
        "timestamp": datetime.now(timezone.utc)
    }
414
 
415
  async def stream_response(query: str, thread_id: str) -> AsyncGenerator[str, None]:
 
416
  config = RunnableConfig(configurable={"thread_id": thread_id})
417
  inputs = {
418
  "messages": [HumanMessage(content=query)],
419
  "thread_id": thread_id,
420
  "tools_used": [],
421
+ "iteration_count": 0, # Initialize counter
422
+ "last_tool_result": None
423
  }
424
 
425
  yield f"event: start\ndata: {thread_id}\n\n"
 
438
  yield f"event: tool_start\ndata: {tool_name}\n\n"
439
 
440
  elif event_type == "on_tool_end":
441
+ output = str(event["data"].get("output", ""))
442
+ # Truncate long outputs
443
+ preview = output[:200] + "..." if len(output) > 200 else output
444
+ yield f"event: tool_end\ndata: {preview}\n\n"
445
 
446
  yield "event: complete\ndata: done\n\n"
447
 
 
451
 
452
  @app.post("/chat")
453
  async def chat_endpoint(request: ChatRequest):
 
 
 
454
  try:
455
  if request.stream:
456
  return StreamingResponse(
 
459
  headers={
460
  "Cache-Control": "no-cache",
461
  "Connection": "keep-alive",
 
462
  }
463
  )
464
  else:
 
465
  config = RunnableConfig(configurable={"thread_id": request.thread_id})
466
  inputs = {
467
  "messages": [HumanMessage(content=request.query)],
468
  "thread_id": request.thread_id,
469
  "tools_used": [],
470
+ "iteration_count": 0,
471
+ "last_tool_result": None
472
  }
473
 
474
  result = await app_state.agent.ainvoke(inputs, config=config)
475
  final_message = result["messages"][-1]
476
 
477
+ return {
478
+ "response": final_message.content,
479
+ "thread_id": request.thread_id,
480
+ "tools_used": result.get("tools_used", []),
481
+ "iterations": result.get("iteration_count", 0)
482
+ }
483
 
484
  except Exception as e:
485
  logger.error(f"Chat error: {e}")
486
  raise HTTPException(status_code=500, detail=str(e))
487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
def run_server(host: str = "0.0.0.0", port: int = 8000, log_level: str = "info") -> None:
    """Start the API with uvicorn.

    Args:
        host: Interface to bind; the default keeps the original 0.0.0.0.
        port: TCP port; the PORT environment variable (common on container
            platforms such as HF Spaces / Cloud Run) overrides it.
        log_level: uvicorn log level string.
    """
    # Deferred import so importing this module never requires uvicorn.
    import uvicorn

    # GENERALIZED: host/port/log level were hard-coded; defaults preserve
    # the original behaviour, and PORT allows platform-assigned ports.
    port = int(os.getenv("PORT", port))
    uvicorn.run(app, host=host, port=port, log_level=log_level)


if __name__ == "__main__":
    run_server()