import os
import logging
import asyncio
import io
import random
import time
import base64
import re
from typing import Annotated, List, TypedDict, Optional
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import httpx
from duckduckgo_search import DDGS
from PIL import Image

# --- HuggingFace Client ---
from huggingface_hub import InferenceClient

# --- LangChain / AI Core ---
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver

# --------------------------------------------------------------------------------------
# 1. Configuration
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GenAI-Agent")

# Local Ollama chat model used for agent reasoning / tool selection.
MODEL_NAME = "qwen2.5:3b"
BASE_URL = "http://localhost:11434"
# Read once at startup; may be overridden per-request via the /chat payload.
HF_TOKEN_GLOBAL = os.getenv("HF_TOKEN", "")

# --- BETTER MODEL FOR REALISM ---
# Using RealVisXL instead of SDXL Base (better photorealism & face consistency).
EDIT_MODEL_ID = "SG161222/RealVisXL_V4.0"

# Shared async HTTP client; generous timeout because image generation is slow.
http_client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifecycle: ensure static dirs exist on startup, close the shared
    HTTP client on shutdown."""
    os.makedirs("static/images", exist_ok=True)
    os.makedirs("static/uploads", exist_ok=True)
    yield
    await http_client.aclose()


app = FastAPI(title="GenAI Stable Agent", lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")

# --------------------------------------------------------------------------------------
# 2.
# Tools (Tuned for Consistency)
# --------------------------------------------------------------------------------------

@tool
async def web_search(query: str) -> str:
    """Search the web for information."""
    try:
        def run_sync_search(q):
            # DDGS is a blocking client; run it off the event loop.
            with DDGS() as ddgs:
                return list(ddgs.text(q, max_results=3))

        results = await asyncio.to_thread(run_sync_search, query)
        if not results:
            return "No results."
        return "\n".join([f"Snippet: {r.get('body')}" for r in results])
    except Exception as e:
        return f"Error: {str(e)}"


@tool
async def generate_image(prompt: str) -> str:
    """Create a NEW image from scratch (No input image)."""
    try:
        from urllib.parse import quote  # stdlib; function-scope to avoid touching file header

        seed = random.randint(0, 99999)
        # FIX: percent-encode the full prompt. The old `replace(" ", "%20")`
        # left '?', '&', '#' etc. unescaped, which corrupts the request URL.
        safe_prompt = quote(prompt, safe="")
        url = (
            f"https://image.pollinations.ai/prompt/{safe_prompt}"
            f"?seed={seed}&nologo=true&width=1024&height=1024&model=flux"
        )
        resp = await http_client.get(url)
        if resp.status_code != 200:
            return "Failed."
        filename = f"static/images/gen_{int(time.time())}.png"
        img = Image.open(io.BytesIO(resp.content))
        await asyncio.to_thread(img.save, filename)
        # FIX: return the saved path — the /chat streamer extracts
        # "static/...png" from tool output to embed the image in the reply.
        return f"Image Created: {filename}"
    except Exception as e:
        return f"Error: {str(e)}"


@tool
async def edit_image(instruction: str, image_path: str) -> str:
    """
    Edits the uploaded image.
    IMPORTANT: Provide the EXACT image path.
    """
    logger.info(f"🎨 Editing {image_path} | Instruction: {instruction}")
    if not os.path.exists(image_path):
        return "Error: Image file not found."
    if not HF_TOKEN_GLOBAL:
        return "Error: HuggingFace Token is missing."

    def run_hf_edit():
        # Blocking HuggingFace inference call; executed in a worker thread.
        # Returns a PIL image on success, or an error string on failure.
        try:
            client = InferenceClient(model=EDIT_MODEL_ID, token=HF_TOKEN_GLOBAL)
            image = Image.open(image_path).convert("RGB")

            # --- CONSISTENCY HACKS ---
            # 1. Prompt booster: force identity-preserving terms.
            full_prompt = f"photorealistic, {instruction}, same person, consistent face, high detail, 8k, sharp focus"
            # 2. Strong negatives: prevent the face from changing.
            neg_prompt = "cartoon, painting, illustration, distorted face, changed face, different person, ugly, blur, low quality, morphing"
            # 3. Strength tuning (crucial):
            #    0.5-0.6 keeps the face (edits stay subtle);
            #    0.7-0.8 starts altering the face. 0.6 is the balance point.
            output_image = client.image_to_image(
                image=image,
                prompt=full_prompt,
                negative_prompt=neg_prompt,
                strength=0.6,  # fixed strength: no runaway edits, identity preserved
                guidance_scale=7.5,
            )
            return output_image
        except Exception as e:
            return str(e)

    try:
        result = await asyncio.to_thread(run_hf_edit)
        if isinstance(result, str):
            return f"Edit Failed: {result}"
        filename = f"static/images/edited_{int(time.time())}_{random.randint(0,999)}.png"
        await asyncio.to_thread(result.save, filename)
        # FIX: return the saved path (was a dead placeholder) so the
        # streaming endpoint can locate and embed the edited image.
        return f"Image Edited Successfully: {filename}"
    except Exception as e:
        return f"System Error: {str(e)}"


tools = [web_search, generate_image, edit_image]

# --------------------------------------------------------------------------------------
# 3. Agent Logic (LOOP FIX HERE)
# --------------------------------------------------------------------------------------
# FIX: the reducer in Annotated must be the `add_messages` CALLABLE, not the
# string "add_messages". A plain string is not a reducer, so each graph step
# would overwrite the message list instead of appending to it, losing history.
from langgraph.graph.message import add_messages


class AgentState(TypedDict):
    # Conversation history; add_messages appends/merges new messages per step.
    messages: Annotated[List[BaseMessage], add_messages]


llm = ChatOllama(model=MODEL_NAME, base_url=BASE_URL, temperature=0).bind_tools(tools)

SYSTEM_PROMPT = """You are an AI visual assistant.
1. Use `edit_image` ONLY if user provides an image path.
2. Use `generate_image` for new creations.
3. Once you call a tool, your job is DONE.
"""


async def agent_node(state: AgentState):
    """Single LLM step: prepend the system prompt and invoke the model."""
    messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"]
    response = await llm.ainvoke(messages)
    return {"messages": [response]}


workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", ToolNode(tools))
workflow.add_edge(START, "agent")

# --- THE LOOP FIX ---
# Logic: Agent -> Tools -> END
# After a tool runs, do NOT route back to the agent; terminate immediately.
# Route: run tools only when the last agent message actually requested them.
workflow.add_conditional_edges(
    "agent", lambda s: "tools" if s["messages"][-1].tool_calls else END
)
workflow.add_edge("tools", END)  # <--- STOP LOOP HERE

app_graph = workflow.compile(checkpointer=MemorySaver())


# --------------------------------------------------------------------------------------
# 4. API (Same logic)
# --------------------------------------------------------------------------------------
class ChatRequest(BaseModel):
    """Payload for /chat: user query, a conversation thread id, plus an
    optional base64 image and an optional per-request HF token."""
    query: str
    thread_id: str
    image_base64: Optional[str] = None
    hf_token: Optional[str] = None


@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    """Run one agent turn and stream tool progress/results as plain text."""
    global HF_TOKEN_GLOBAL
    # NOTE(review): overriding a process-global token per request leaks state
    # between concurrent users — acceptable for a single-user demo only.
    if req.hf_token:
        HF_TOKEN_GLOBAL = req.hf_token

    initial_msg = req.query
    if req.image_base64:
        try:
            # Accept both bare base64 and "data:image/...;base64,<data>" URIs.
            if "," in req.image_base64:
                d = req.image_base64.split(",")[1]
            else:
                d = req.image_base64
            fname = f"static/uploads/user_upload_{req.thread_id}_{int(time.time())}.png"
            with open(fname, "wb") as f:
                f.write(base64.b64decode(d))
            initial_msg = f"User uploaded an image at path: '{fname}'. Request: {req.query}"
        except Exception:
            # FIX: was a bare `except: pass`, which silently swallowed decode
            # failures (and even SystemExit/KeyboardInterrupt). Log the error
            # and deliberately fall back to a text-only request.
            logger.exception("Failed to decode/save uploaded image")

    config = {"configurable": {"thread_id": req.thread_id}}
    inputs = {"messages": [HumanMessage(content=initial_msg)]}

    async def generator():
        try:
            async for event in app_graph.astream_events(inputs, config=config, version="v1"):
                event_type = event["event"]
                # Only stream tool progress and tool results.
                if event_type == "on_tool_start":
                    yield f"\n\n⚙️ **Processing:** {event['name']}...\n\n"
                elif event_type == "on_tool_end":
                    out = str(event['data'].get('output'))
                    if "static/" in out:
                        # FIX: re.search can return None (output mentions
                        # "static/" without a .png path); guard before
                        # .group(1) to avoid an AttributeError mid-stream.
                        match = re.search(r'(static/.*\.png)', out)
                        if match:
                            yield f"\n\n![Result]({match.group(1)})\n\n"
                        else:
                            yield f"Info: {out}\n"
                    else:
                        yield f"Info: {out}\n"
        except Exception as e:
            yield f"Error: {str(e)}"

    return StreamingResponse(generator(), media_type="text/plain")


if __name__ == "__main__":
    import uvicorn
    os.makedirs("static/images", exist_ok=True)
    os.makedirs("static/uploads", exist_ok=True)
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)