# main.py -- GenAI image generation / editing agent (FastAPI + LangGraph + Ollama)
# --- Standard library ---
import asyncio
import base64
import io
import logging
import os
import random
import re
import time
from contextlib import asynccontextmanager
from typing import Annotated, List, TypedDict, Optional
from urllib.parse import quote

# --- Third-party ---
import httpx
from duckduckgo_search import DDGS
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel
# --- HuggingFace Client ---
from huggingface_hub import InferenceClient
# --- LangChain / AI Core ---
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
# --------------------------------------------------------------------------------------
# 1. Configuration
# --------------------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("GenAI-Agent")
# Local Ollama chat model used for agent reasoning and tool selection.
MODEL_NAME = "qwen2.5:3b"
BASE_URL = "http://localhost:11434"
# HuggingFace token from the environment; may be overridden per-request via /chat.
HF_TOKEN_GLOBAL = os.getenv("HF_TOKEN", "")
# --- BETTER MODEL FOR REALISM ---
# Using RealVisXL instead of SDXL Base (better photorealism & face consistency).
EDIT_MODEL_ID = "SG161222/RealVisXL_V4.0"
# Shared async HTTP client for image downloads; closed in the lifespan handler.
http_client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the static asset directories on startup; close the shared HTTP client on shutdown."""
    for directory in ("static/images", "static/uploads"):
        os.makedirs(directory, exist_ok=True)
    yield
    await http_client.aclose()


app = FastAPI(title="GenAI Stable Agent", lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")
# --------------------------------------------------------------------------------------
# 2. Tools (Tuned for Consistency)
# --------------------------------------------------------------------------------------
@tool
async def web_search(query: str) -> str:
    """Search the web for information."""
    def _search(q: str):
        # DDGS is synchronous; keep it off the event loop via to_thread.
        with DDGS() as ddgs:
            return list(ddgs.text(q, max_results=3))

    try:
        hits = await asyncio.to_thread(_search, query)
        if not hits:
            return "No results."
        return "\n".join(f"Snippet: {hit.get('body')}" for hit in hits)
    except Exception as e:
        return f"Error: {str(e)}"
@tool
async def generate_image(prompt: str) -> str:
    """Create a NEW image from scratch (No input image)."""
    try:
        seed = random.randint(0, 99999)
        # BUGFIX: use real percent-encoding. The old `.replace(" ", "%20")`
        # broke the URL on characters like '&', '?', '#' or non-ASCII text.
        safe_prompt = quote(prompt, safe="")
        url = (
            f"https://image.pollinations.ai/prompt/{safe_prompt}"
            f"?seed={seed}&nologo=true&width=1024&height=1024&model=flux"
        )
        resp = await http_client.get(url)
        if resp.status_code != 200:
            return "Failed."
        filename = f"static/images/gen_{int(time.time())}.png"
        img = Image.open(io.BytesIO(resp.content))
        # PIL save is blocking disk I/O -- run it off the event loop.
        await asyncio.to_thread(img.save, filename)
        # BUGFIX: include the saved path -- the /chat streamer extracts
        # `static/...png` from this string to render the image.
        return f"Image Created: {filename}"
    except Exception as e:
        return f"Error: {str(e)}"
@tool
async def edit_image(instruction: str, image_path: str) -> str:
    """
    Edits the uploaded image.
    IMPORTANT: Provide the EXACT image path.
    """
    logger.info(f"🎨 Editing {image_path} | Instruction: {instruction}")
    if not os.path.exists(image_path):
        return "Error: Image file not found."
    if not HF_TOKEN_GLOBAL:
        return "Error: HuggingFace Token is missing."

    def run_hf_edit():
        # Blocking HF inference call; returns a PIL image, or an error string on failure.
        try:
            client = InferenceClient(model=EDIT_MODEL_ID, token=HF_TOKEN_GLOBAL)
            image = Image.open(image_path).convert("RGB")
            # --- CONSISTENCY HACKS ---
            # 1. Prompt booster: force identity-preserving terms.
            full_prompt = f"photorealistic, {instruction}, same person, consistent face, high detail, 8k, sharp focus"
            # 2. Strong negatives: prevent the face from changing.
            neg_prompt = "cartoon, painting, illustration, distorted face, changed face, different person, ugly, blur, low quality, morphing"
            # 3. Strength tuning (crucial):
            #    0.5 - 0.6 keeps the face (background edits are subtler);
            #    0.7 - 0.8 lets the face drift. 0.6 is the balance point.
            output_image = client.image_to_image(
                image=image,
                prompt=full_prompt,
                negative_prompt=neg_prompt,
                strength=0.6,  # fixed strength: preserves identity, avoids re-edit loops
                guidance_scale=7.5,
            )
            return output_image
        except Exception as e:
            return str(e)

    try:
        result = await asyncio.to_thread(run_hf_edit)
        if isinstance(result, str):
            return f"Edit Failed: {result}"
        filename = f"static/images/edited_{int(time.time())}_{random.randint(0,999)}.png"
        await asyncio.to_thread(result.save, filename)
        # BUGFIX: include the saved path -- the /chat streamer extracts
        # `static/...png` from this string to render the edited image.
        return f"Image Edited Successfully: {filename}"
    except Exception as e:
        return f"System Error: {str(e)}"
tools = [web_search, generate_image, edit_image]
# --------------------------------------------------------------------------------------
# 3. Agent Logic (LOOP FIX HERE)
# --------------------------------------------------------------------------------------
class AgentState(TypedDict):
    """Graph state: the running conversation transcript."""
    # BUGFIX: the reducer must be the langgraph `add_messages` FUNCTION, not the
    # string "add_messages". A bare string in Annotated metadata is ignored, so
    # every node REPLACED the message list instead of appending to it -- ToolNode
    # then could not see the AI message carrying the tool calls.
    messages: Annotated[List[BaseMessage], add_messages]
# temperature=0 keeps tool-selection decisions deterministic for the same input.
llm = ChatOllama(model=MODEL_NAME, base_url=BASE_URL, temperature=0).bind_tools(tools)
# System prompt steering the agent toward a single tool call per request.
SYSTEM_PROMPT = """You are an AI visual assistant.
1. Use `edit_image` ONLY if user provides an image path.
2. Use `generate_image` for new creations.
3. Once you call a tool, your job is DONE.
"""
async def agent_node(state: AgentState):
    """Run the LLM over the transcript with the system prompt prepended."""
    prompt_messages: List[BaseMessage] = [SystemMessage(content=SYSTEM_PROMPT), *state["messages"]]
    reply = await llm.ainvoke(prompt_messages)
    return {"messages": [reply]}
workflow = StateGraph(AgentState)
workflow.add_node("agent", agent_node)
workflow.add_node("tools", ToolNode(tools))
workflow.add_edge(START, "agent")
# --- THE LOOP FIX ---
# Logic: Agent -> Tools -> END
# After a tool runs, do NOT route back to the agent -- finish immediately.
# This prevents the small model from re-invoking the same tool in a loop.
workflow.add_conditional_edges("agent", lambda s: "tools" if s["messages"][-1].tool_calls else END)
workflow.add_edge("tools", END)  # <--- STOP LOOP HERE
# MemorySaver keeps per-thread_id conversation history across /chat calls.
app_graph = workflow.compile(checkpointer=MemorySaver())
# --------------------------------------------------------------------------------------
# 4. API (Same logic)
# --------------------------------------------------------------------------------------
class ChatRequest(BaseModel):
    """Payload for POST /chat."""
    query: str                          # user's natural-language request
    thread_id: str                      # conversation id (checkpointer key)
    image_base64: Optional[str] = None  # optional image: raw base64 or a "data:...;base64," URI
    hf_token: Optional[str] = None      # optional per-request HuggingFace token override
@app.post("/chat")
async def chat_endpoint(req: ChatRequest):
    """Run the agent for one request and stream progress/results as plain text."""
    global HF_TOKEN_GLOBAL
    if req.hf_token:
        # NOTE(review): a per-request token mutates process-wide state, so concurrent
        # users can overwrite each other's token. Kept for interface compatibility.
        HF_TOKEN_GLOBAL = req.hf_token

    initial_msg = req.query
    if req.image_base64:
        try:
            # Accept both raw base64 and "data:image/png;base64,..." URIs.
            if "," in req.image_base64:
                d = req.image_base64.split(",")[1]
            else:
                d = req.image_base64
            fname = f"static/uploads/user_upload_{req.thread_id}_{int(time.time())}.png"
            with open(fname, "wb") as f:
                f.write(base64.b64decode(d))
            initial_msg = f"User uploaded an image at path: '{fname}'. Request: {req.query}"
        except Exception:
            # BUGFIX: was a silent bare `except: pass` -- log so failed uploads
            # are visible instead of the agent mysteriously "not seeing" the image.
            logger.exception("Failed to decode/save uploaded image")

    config = {"configurable": {"thread_id": req.thread_id}}
    inputs = {"messages": [HumanMessage(content=initial_msg)]}

    async def generator():
        try:
            async for event in app_graph.astream_events(inputs, config=config, version="v1"):
                event_type = event["event"]
                # Only surface tool activity; suppress raw LLM token noise.
                if event_type == "on_tool_start":
                    yield f"\n\n⚙️ **Processing:** {event['name']}...\n\n"
                elif event_type == "on_tool_end":
                    out = str(event['data'].get('output'))
                    # BUGFIX: guard the regex match. Previously `"static/" in out`
                    # alone gated `.group(1)`, so output containing "static/" but
                    # no ".png" path crashed with AttributeError on None.
                    match = re.search(r'(static/.*\.png)', out)
                    if match:
                        yield f"\n\n![Result]({match.group(1)})\n\n"
                    else:
                        yield f"Info: {out}\n"
        except Exception as e:
            yield f"Error: {str(e)}"

    return StreamingResponse(generator(), media_type="text/plain")
if __name__ == "__main__":
    import uvicorn

    # Directories are also created by the lifespan hook; doing it here as well
    # keeps direct `python main.py` runs (with reload) safe before startup.
    for directory in ("static/images", "static/uploads"):
        os.makedirs(directory, exist_ok=True)
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)