# NOTE(review): the three lines below were HuggingFace Spaces page residue
# ("Spaces: Paused") accidentally pasted into the module; kept only as a
# comment so the file remains importable.
import os
import logging
import asyncio
import io
import random
import time
import base64
import re
from contextlib import asynccontextmanager
from typing import Annotated, List, TypedDict, Optional
from urllib.parse import quote

import httpx
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from duckduckgo_search import DDGS
from PIL import Image
# --- HuggingFace Client ---
from huggingface_hub import InferenceClient
# --- LangChain / AI Core ---
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
# --------------------------------------------------------------------------------------
# 1. Configuration
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GenAI-Agent")

# Local Ollama chat model used for agent reasoning / tool selection.
MODEL_NAME = "qwen2.5:3b"
BASE_URL = "http://localhost:11434"

# HuggingFace Inference API token; may be overridden per-request by the chat endpoint.
HF_TOKEN_GLOBAL = os.getenv("HF_TOKEN", "")

# --- BETTER MODEL FOR REALISM ---
# RealVisXL is used instead of SDXL Base (better photorealism & face consistency).
EDIT_MODEL_ID = "SG161222/RealVisXL_V4.0"

# Shared async HTTP client for all outbound requests; closed in the app lifespan.
http_client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: create the static dirs on startup, close the shared HTTP client on shutdown.

    BUG FIX: the @asynccontextmanager decorator was missing — FastAPI's
    `lifespan=` parameter expects an async context manager factory, and bare
    async-generator lifespans are deprecated in Starlette. The decorator was
    already imported at the top of the file but never applied.
    """
    os.makedirs("static/images", exist_ok=True)
    os.makedirs("static/uploads", exist_ok=True)
    yield  # application runs while suspended here
    await http_client.aclose()


app = FastAPI(title="GenAI Stable Agent", lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")
| # -------------------------------------------------------------------------------------- | |
| # 2. Tools (Tuned for Consistency) | |
| # -------------------------------------------------------------------------------------- | |
async def web_search(query: str) -> str:
    """Search the web for information."""
    def _sync_search(q: str):
        # duckduckgo_search is synchronous; run it off the event loop.
        with DDGS() as ddgs:
            return list(ddgs.text(q, max_results=3))

    try:
        hits = await asyncio.to_thread(_sync_search, query)
        if not hits:
            return "No results."
        return "\n".join(f"Snippet: {hit.get('body')}" for hit in hits)
    except Exception as e:
        return f"Error: {str(e)}"
async def generate_image(prompt: str) -> str:
    """Create a NEW image from scratch (No input image).

    Returns a status string that contains the saved file path on success —
    the streaming endpoint regex-extracts 'static/*.png' from this text to
    render the image, so the path MUST be in the return value.
    """
    try:
        seed = random.randint(0, 99999)
        # BUG FIX: properly percent-encode the whole prompt. The previous
        # `" " -> "%20"` replace broke URLs containing '&', '?', '#' or
        # non-ASCII characters.
        safe_prompt = quote(prompt, safe="")
        url = (
            f"https://image.pollinations.ai/prompt/{safe_prompt}"
            f"?seed={seed}&nologo=true&width=1024&height=1024&model=flux"
        )
        resp = await http_client.get(url)
        if resp.status_code != 200:
            return "Failed."
        filename = f"static/images/gen_{int(time.time())}.png"
        img = Image.open(io.BytesIO(resp.content))
        # PIL save is blocking — keep it off the event loop.
        await asyncio.to_thread(img.save, filename)
        # BUG FIX: include the actual file path (was a garbled placeholder),
        # otherwise the streamer can never find the image to display.
        return f"Image Created: {filename}"
    except Exception as e:
        return f"Error: {str(e)}"
async def edit_image(instruction: str, image_path: str) -> str:
    """
    Edits the uploaded image.
    IMPORTANT: Provide the EXACT image path.

    Returns a status string; on success it contains the edited file path —
    the streaming endpoint regex-extracts 'static/*.png' from this text.
    """
    logger.info(f"🎨 Editing {image_path} | Instruction: {instruction}")
    if not os.path.exists(image_path):
        return "Error: Image file not found."
    if not HF_TOKEN_GLOBAL:
        return "Error: HuggingFace Token is missing."

    def run_hf_edit():
        # Runs in a worker thread: InferenceClient is synchronous/blocking.
        try:
            client = InferenceClient(model=EDIT_MODEL_ID, token=HF_TOKEN_GLOBAL)
            image = Image.open(image_path).convert("RGB")
            # --- CONSISTENCY HACKS ---
            # 1. Prompt booster: force identity-preserving terms.
            full_prompt = f"photorealistic, {instruction}, same person, consistent face, high detail, 8k, sharp focus"
            # 2. Strong negatives: prevent the face from changing.
            neg_prompt = "cartoon, painting, illustration, distorted face, changed face, different person, ugly, blur, low quality, morphing"
            # 3. Strength tuning (crucial):
            #    0.5 - 0.6 = best for keeping the face (background edits stay subtle)
            #    0.7 - 0.8 = the face starts changing
            #    0.6 is used as the balance point.
            output_image = client.image_to_image(
                image=image,
                prompt=full_prompt,
                negative_prompt=neg_prompt,
                strength=0.6,  # fixed strength: avoids re-edit loops, keeps identity
                guidance_scale=7.5,
            )
            return output_image
        except Exception as e:
            # A str return signals failure to the caller below.
            return str(e)

    try:
        result = await asyncio.to_thread(run_hf_edit)
        if isinstance(result, str):
            return f"Edit Failed: {result}"
        filename = f"static/images/edited_{int(time.time())}_{random.randint(0,999)}.png"
        await asyncio.to_thread(result.save, filename)
        # BUG FIX: include the actual file path (was a garbled placeholder),
        # otherwise the streamer can never find the edited image to display.
        return f"Image Edited Successfully: {filename}"
    except Exception as e:
        return f"System Error: {str(e)}"
tools = [web_search, generate_image, edit_image]

# --------------------------------------------------------------------------------------
# 3. Agent Logic (LOOP FIX HERE)
# --------------------------------------------------------------------------------------
class AgentState(TypedDict):
    # BUG FIX: the reducer must be the `add_messages` callable, not the
    # string "add_messages". LangGraph ignores non-callable Annotated
    # metadata, so with the string each node OVERWROTE the message history
    # instead of appending — the MemorySaver checkpointer never accumulated
    # the conversation across turns.
    messages: Annotated[List[BaseMessage], add_messages]


llm = ChatOllama(model=MODEL_NAME, base_url=BASE_URL, temperature=0).bind_tools(tools)

SYSTEM_PROMPT = """You are an AI visual assistant.
1. Use `edit_image` ONLY if user provides an image path.
2. Use `generate_image` for new creations.
3. Once you call a tool, your job is DONE.
"""
async def agent_node(state: AgentState):
    """Invoke the LLM on the system prompt followed by the conversation so far."""
    convo: List[BaseMessage] = [SystemMessage(content=SYSTEM_PROMPT)]
    convo.extend(state["messages"])
    reply = await llm.ainvoke(convo)
    return {"messages": [reply]}
# --- THE LOOP FIX ---
# Flow is strictly Agent -> Tools -> END: once a tool has run we do NOT
# route back to the agent — the run ends immediately, which prevents the
# agent from re-invoking tools in an endless loop.
def _route_after_agent(state: AgentState):
    """Route to the tool node when the latest message carries tool calls."""
    latest = state["messages"][-1]
    return "tools" if latest.tool_calls else END


workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", ToolNode(tools))
workflow.add_edge(START, "agent")
workflow.add_conditional_edges("agent", _route_after_agent)
workflow.add_edge("tools", END)  # <--- STOP LOOP HERE
app_graph = workflow.compile(checkpointer=MemorySaver())
# --------------------------------------------------------------------------------------
# 4. API (Same logic)
# --------------------------------------------------------------------------------------
class ChatRequest(BaseModel):
    """Request payload for the chat endpoint."""
    query: str                           # user's natural-language request
    thread_id: str                       # conversation id; used as the checkpointer thread key
    image_base64: Optional[str] = None   # optional upload: raw base64 or "data:...;base64," data-URL
    hf_token: Optional[str] = None       # optional per-request HuggingFace token override
@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    """Run the agent for one user turn and stream tool activity as plain text.

    BUG FIXES vs. the original:
    - the route was never registered (missing @app.post decorator);
    - the extracted image path was dropped (the old yield was an f-string
      with no placeholders, producing blank output) — it is now embedded
      as a markdown image;
    - `re.search(...).group(1)` could raise AttributeError when the tool
      output mentioned "static/" without a .png path — now guarded;
    - upload decode failures were silently swallowed — now logged.
    """
    global HF_TOKEN_GLOBAL
    if req.hf_token:
        HF_TOKEN_GLOBAL = req.hf_token

    initial_msg = req.query
    if req.image_base64:
        try:
            # Accept both raw base64 and "data:image/png;base64,..." payloads.
            if "," in req.image_base64:
                data = req.image_base64.split(",", 1)[1]
            else:
                data = req.image_base64
            fname = f"static/uploads/user_upload_{req.thread_id}_{int(time.time())}.png"
            with open(fname, "wb") as f:
                f.write(base64.b64decode(data))
            # Tell the agent exactly where the file landed so edit_image can find it.
            initial_msg = f"User uploaded an image at path: '{fname}'. Request: {req.query}"
        except Exception:
            logger.exception("Failed to decode/save uploaded image; continuing without it")

    config = {"configurable": {"thread_id": req.thread_id}}
    inputs = {"messages": [HumanMessage(content=initial_msg)]}

    async def generator():
        try:
            async for event in app_graph.astream_events(inputs, config=config, version="v1"):
                event_type = event["event"]
                # Only stream tool activity and tool results.
                if event_type == "on_tool_start":
                    yield f"\n\n⚙️ **Processing:** {event['name']}...\n\n"
                elif event_type == "on_tool_end":
                    out = str(event['data'].get('output'))
                    match = re.search(r'(static/.*\.png)', out)
                    if match:
                        # Embed the produced image so the client can render it.
                        yield f"\n\n![result](/{match.group(1)})\n\n"
                    else:
                        yield f"Info: {out}\n"
        except Exception as e:
            yield f"Error: {str(e)}"

    return StreamingResponse(generator(), media_type="text/plain")
if __name__ == "__main__":
    import uvicorn

    # Ensure the static directories exist even when launched directly.
    for directory in ("static/images", "static/uploads"):
        os.makedirs(directory, exist_ok=True)
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)