Karan6933 commited on
Commit
fc9000d
·
verified ·
1 Parent(s): 07f244f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +58 -73
main.py CHANGED
@@ -18,7 +18,7 @@ import httpx
18
  from duckduckgo_search import DDGS
19
  from PIL import Image
20
 
21
- # --- HuggingFace Official Client (Fix for 410 Error) ---
22
  from huggingface_hub import InferenceClient
23
 
24
  # --- LangChain / AI Core ---
@@ -40,55 +40,47 @@ BASE_URL = "http://localhost:11434"
40
 
41
  HF_TOKEN_GLOBAL = os.getenv("HF_TOKEN", "")
42
 
43
- # --- NEW MODEL: SDXL Base 1.0 (Reliable on Free Tier) ---
44
- # Instruct-Pix2Pix band ho gaya, isliye hum SDXL Image-to-Image use karenge
45
- EDIT_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
46
 
47
  http_client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)
48
 
49
  @asynccontextmanager
50
  async def lifespan(app: FastAPI):
51
- try:
52
- os.makedirs("static/images", exist_ok=True)
53
- os.makedirs("static/uploads", exist_ok=True)
54
- except PermissionError:
55
- logger.error("Permission denied. Check Dockerfile.")
56
  yield
57
  await http_client.aclose()
58
 
59
- app = FastAPI(title="GenAI Fixed Agent", lifespan=lifespan)
60
  app.mount("/static", StaticFiles(directory="static"), name="static")
61
 
62
  # --------------------------------------------------------------------------------------
63
- # 2. Tools
64
  # --------------------------------------------------------------------------------------
65
 
66
  @tool
67
  async def web_search(query: str) -> str:
68
  """Search the web for information."""
69
- def run_sync_search(q):
70
- try:
71
- with DDGS() as ddgs:
72
- return list(ddgs.text(q, max_results=4))
73
- except Exception as e:
74
- return str(e)
75
  try:
 
 
76
  results = await asyncio.to_thread(run_sync_search, query)
77
- if isinstance(results, str) or not results: return "No results."
78
- return "\n".join([f"Link: {r.get('href')}\nSnippet: {r.get('body')}" for r in results])
79
  except Exception as e:
80
  return f"Error: {str(e)}"
81
 
82
  @tool
83
  async def generate_image(prompt: str) -> str:
84
- """Create a NEW image from scratch (Pollinations AI)."""
85
  try:
86
  seed = random.randint(0, 99999)
87
  safe_prompt = prompt.replace(" ", "%20")
88
  url = f"https://image.pollinations.ai/prompt/{safe_prompt}?seed={seed}&nologo=true&width=1024&height=1024&model=flux"
89
  resp = await http_client.get(url)
90
  if resp.status_code != 200: return "Failed."
91
-
92
  filename = f"static/images/gen_{int(time.time())}.png"
93
  img = Image.open(io.BytesIO(resp.content))
94
  await asyncio.to_thread(img.save, filename)
@@ -96,76 +88,70 @@ async def generate_image(prompt: str) -> str:
96
  except Exception as e:
97
  return f"Error: {str(e)}"
98
 
99
- # --- FIXED EDIT TOOL (Uses huggingface_hub client) ---
100
  @tool
101
  async def edit_image(instruction: str, image_path: str) -> str:
102
  """
103
- Edits an uploaded image using Stable Diffusion XL (Image-to-Image).
104
- Best for: Changing background, style, or adding elements while keeping structure.
105
  """
106
  logger.info(f"🎨 Editing {image_path} | Instruction: {instruction}")
107
 
108
- if not os.path.exists(image_path):
109
- return "Error: Image file not found."
110
-
111
- if not HF_TOKEN_GLOBAL:
112
- return "Error: HuggingFace Token is missing. Please enter it in the UI."
113
 
114
- # Helper function to run HF Client in thread (Sync library)
115
  def run_hf_edit():
116
  try:
117
- # Initialize Client
118
  client = InferenceClient(model=EDIT_MODEL_ID, token=HF_TOKEN_GLOBAL)
119
-
120
- # Load User Image
121
  image = Image.open(image_path).convert("RGB")
122
 
123
- # SDXL Image-to-Image Call
124
- # strength: 0.0 = exact copy, 1.0 = completely new image
125
- # 0.75 is a good balance for "Make it snowy/cinematic" without losing the person completely
 
 
 
 
 
 
 
 
 
126
  output_image = client.image_to_image(
127
  image=image,
128
- prompt=instruction,
129
- negative_prompt="bad quality, distorted face, ugly, blurry, low resolution, cartoon",
130
- strength=0.75,
131
- guidance_scale=8.5
132
  )
133
  return output_image
134
  except Exception as e:
135
  return str(e)
136
 
137
  try:
138
- # Run heavy task in thread
139
  result = await asyncio.to_thread(run_hf_edit)
 
140
 
141
- if isinstance(result, str): # Error returned as string
142
- return f"Edit Failed: {result}"
143
-
144
- # Save Result
145
  filename = f"static/images/edited_{int(time.time())}_{random.randint(0,999)}.png"
146
  await asyncio.to_thread(result.save, filename)
147
  return f"Image Edited Successfully: {filename}"
148
-
149
  except Exception as e:
150
  return f"System Error: {str(e)}"
151
 
152
  tools = [web_search, generate_image, edit_image]
153
 
154
  # --------------------------------------------------------------------------------------
155
- # 3. Agent Logic
156
  # --------------------------------------------------------------------------------------
157
 
158
  class AgentState(TypedDict):
159
  messages: Annotated[List[BaseMessage], "add_messages"]
160
 
161
- llm = ChatOllama(model=MODEL_NAME, base_url=BASE_URL, temperature=0.2).bind_tools(tools)
162
-
163
- SYSTEM_PROMPT = """You are an AI visual expert.
164
 
165
- 1. **New Image:** Use `generate_image` if user asks to create/draw from scratch.
166
- 2. **Edit Image:** Use `edit_image` ONLY if user provides an image or asks to modify "this" image.
167
- - Input: User's instruction + Exact path of the uploaded image.
168
- - Note: The underlying model is SDXL.
169
  """
170
 
171
  async def agent_node(state: AgentState):
@@ -176,13 +162,19 @@ async def agent_node(state: AgentState):
176
  workflow = StateGraph(AgentState)
177
  workflow.add_node("agent", agent_node)
178
  workflow.add_node("tools", ToolNode(tools))
 
179
  workflow.add_edge(START, "agent")
 
 
 
 
180
  workflow.add_conditional_edges("agent", lambda s: "tools" if s["messages"][-1].tool_calls else END)
181
- workflow.add_edge("tools", "agent")
 
182
  app_graph = workflow.compile(checkpointer=MemorySaver())
183
 
184
  # --------------------------------------------------------------------------------------
185
- # 4. API Endpoints
186
  # --------------------------------------------------------------------------------------
187
 
188
  class ChatRequest(BaseModel):
@@ -197,19 +189,15 @@ async def chat_endpoint(req: ChatRequest):
197
  if req.hf_token: HF_TOKEN_GLOBAL = req.hf_token
198
 
199
  initial_msg = req.query
200
-
201
  if req.image_base64:
202
  try:
203
- if "," in req.image_base64: image_base64_data = req.image_base64.split(",")[1]
204
- else: image_base64_data = req.image_base64
205
-
206
- img_data = base64.b64decode(image_base64_data)
207
- filename = f"static/uploads/user_upload_{req.thread_id}_{int(time.time())}.png"
208
- with open(filename, "wb") as f: f.write(img_data)
209
 
210
- initial_msg = f"User uploaded an image at path: '{filename}'. Request: {req.query}"
211
- except Exception:
212
- pass
 
213
 
214
  config = {"configurable": {"thread_id": req.thread_id}}
215
  inputs = {"messages": [HumanMessage(content=initial_msg)]}
@@ -218,19 +206,16 @@ async def chat_endpoint(req: ChatRequest):
218
  try:
219
  async for event in app_graph.astream_events(inputs, config=config, version="v1"):
220
  event_type = event["event"]
221
- if event_type == "on_chat_model_stream":
222
- chunk = event["data"]["chunk"].content
223
- if chunk: yield chunk
224
- elif event_type == "on_tool_start":
225
  yield f"\n\n⚙️ **Processing:** {event['name']}...\n\n"
226
  elif event_type == "on_tool_end":
227
  out = str(event['data'].get('output'))
228
  if "static/" in out:
229
- match = re.search(r'(static/.*\.png)', out)
230
- path = match.group(1) if match else out
231
  yield f"\n\n![Result]({path})\n\n"
232
  else:
233
- yield f"\nInfo: {out[:100]}\n"
234
  except Exception as e:
235
  yield f"Error: {str(e)}"
236
 
 
18
  from duckduckgo_search import DDGS
19
  from PIL import Image
20
 
21
+ # --- HuggingFace Client ---
22
  from huggingface_hub import InferenceClient
23
 
24
  # --- LangChain / AI Core ---
 
40
 
41
  HF_TOKEN_GLOBAL = os.getenv("HF_TOKEN", "")
42
 
43
+ # --- BETTER MODEL FOR REALISM ---
44
+ # SDXL Base ki jagah RealVisXL use kar rahe hain (Better photorealism & face consistency)
45
+ EDIT_MODEL_ID = "SG161222/RealVisXL_V4.0"
46
 
47
  http_client = httpx.AsyncClient(timeout=120.0, follow_redirects=True)
48
 
49
  @asynccontextmanager
50
  async def lifespan(app: FastAPI):
51
+ os.makedirs("static/images", exist_ok=True)
52
+ os.makedirs("static/uploads", exist_ok=True)
 
 
 
53
  yield
54
  await http_client.aclose()
55
 
56
+ app = FastAPI(title="GenAI Stable Agent", lifespan=lifespan)
57
  app.mount("/static", StaticFiles(directory="static"), name="static")
58
 
59
  # --------------------------------------------------------------------------------------
60
+ # 2. Tools (Tuned for Consistency)
61
  # --------------------------------------------------------------------------------------
62
 
63
  @tool
64
  async def web_search(query: str) -> str:
65
  """Search the web for information."""
 
 
 
 
 
 
66
  try:
67
+ def run_sync_search(q):
68
+ with DDGS() as ddgs: return list(ddgs.text(q, max_results=3))
69
  results = await asyncio.to_thread(run_sync_search, query)
70
+ if not results: return "No results."
71
+ return "\n".join([f"Snippet: {r.get('body')}" for r in results])
72
  except Exception as e:
73
  return f"Error: {str(e)}"
74
 
75
  @tool
76
  async def generate_image(prompt: str) -> str:
77
+ """Create a NEW image from scratch (No input image)."""
78
  try:
79
  seed = random.randint(0, 99999)
80
  safe_prompt = prompt.replace(" ", "%20")
81
  url = f"https://image.pollinations.ai/prompt/{safe_prompt}?seed={seed}&nologo=true&width=1024&height=1024&model=flux"
82
  resp = await http_client.get(url)
83
  if resp.status_code != 200: return "Failed."
 
84
  filename = f"static/images/gen_{int(time.time())}.png"
85
  img = Image.open(io.BytesIO(resp.content))
86
  await asyncio.to_thread(img.save, filename)
 
88
  except Exception as e:
89
  return f"Error: {str(e)}"
90
 
 
91
  @tool
92
  async def edit_image(instruction: str, image_path: str) -> str:
93
  """
94
+ Edits the uploaded image.
95
+ IMPORTANT: Provide the EXACT image path.
96
  """
97
  logger.info(f"🎨 Editing {image_path} | Instruction: {instruction}")
98
 
99
+ if not os.path.exists(image_path): return "Error: Image file not found."
100
+ if not HF_TOKEN_GLOBAL: return "Error: HuggingFace Token is missing."
 
 
 
101
 
 
102
  def run_hf_edit():
103
  try:
 
104
  client = InferenceClient(model=EDIT_MODEL_ID, token=HF_TOKEN_GLOBAL)
 
 
105
  image = Image.open(image_path).convert("RGB")
106
 
107
+ # --- CONSISTENCY HACKS ---
108
+ # 1. Prompt Booster: Force identity terms
109
+ full_prompt = f"photorealistic, {instruction}, same person, consistent face, high detail, 8k, sharp focus"
110
+
111
+ # 2. Strong Negatives: Prevent face changing
112
+ neg_prompt = "cartoon, painting, illustration, distorted face, changed face, different person, ugly, blur, low quality, morphing"
113
+
114
+ # 3. Strength Tuning (Crucial):
115
+ # 0.5 - 0.6 = Best for keeping face (Face won't change, but background change will be subtle)
116
+ # 0.7 - 0.8 = Face changes
117
+ # Hum 0.6 use karenge (Balance)
118
+
119
  output_image = client.image_to_image(
120
  image=image,
121
+ prompt=full_prompt,
122
+ negative_prompt=neg_prompt,
123
+ strength=0.6, # <--- FIXED STRENGTH (Isse loop nahi hoga, consistency maintain rahegi)
124
+ guidance_scale=7.5
125
  )
126
  return output_image
127
  except Exception as e:
128
  return str(e)
129
 
130
  try:
 
131
  result = await asyncio.to_thread(run_hf_edit)
132
+ if isinstance(result, str): return f"Edit Failed: {result}"
133
 
 
 
 
 
134
  filename = f"static/images/edited_{int(time.time())}_{random.randint(0,999)}.png"
135
  await asyncio.to_thread(result.save, filename)
136
  return f"Image Edited Successfully: {filename}"
 
137
  except Exception as e:
138
  return f"System Error: {str(e)}"
139
 
140
  tools = [web_search, generate_image, edit_image]
141
 
142
  # --------------------------------------------------------------------------------------
143
+ # 3. Agent Logic (LOOP FIX HERE)
144
  # --------------------------------------------------------------------------------------
145
 
146
  class AgentState(TypedDict):
147
  messages: Annotated[List[BaseMessage], "add_messages"]
148
 
149
+ llm = ChatOllama(model=MODEL_NAME, base_url=BASE_URL, temperature=0).bind_tools(tools)
 
 
150
 
151
+ SYSTEM_PROMPT = """You are an AI visual assistant.
152
+ 1. Use `edit_image` ONLY if user provides an image path.
153
+ 2. Use `generate_image` for new creations.
154
+ 3. Once you call a tool, your job is DONE.
155
  """
156
 
157
  async def agent_node(state: AgentState):
 
162
  workflow = StateGraph(AgentState)
163
  workflow.add_node("agent", agent_node)
164
  workflow.add_node("tools", ToolNode(tools))
165
+
166
  workflow.add_edge(START, "agent")
167
+
168
+ # --- THE LOOP FIX ---
169
+ # Logic: Agent -> Tools -> END
170
+ # Tool chalne ke baad wapis Agent ke paas mat jao. Seedha khatam karo.
171
  workflow.add_conditional_edges("agent", lambda s: "tools" if s["messages"][-1].tool_calls else END)
172
+ workflow.add_edge("tools", END) # <--- STOP LOOP HERE
173
+
174
  app_graph = workflow.compile(checkpointer=MemorySaver())
175
 
176
  # --------------------------------------------------------------------------------------
177
+ # 4. API (Same logic)
178
  # --------------------------------------------------------------------------------------
179
 
180
  class ChatRequest(BaseModel):
 
189
  if req.hf_token: HF_TOKEN_GLOBAL = req.hf_token
190
 
191
  initial_msg = req.query
 
192
  if req.image_base64:
193
  try:
194
+ if "," in req.image_base64: d = req.image_base64.split(",")[1]
195
+ else: d = req.image_base64
 
 
 
 
196
 
197
+ fname = f"static/uploads/user_upload_{req.thread_id}_{int(time.time())}.png"
198
+ with open(fname, "wb") as f: f.write(base64.b64decode(d))
199
+ initial_msg = f"User uploaded an image at path: '{fname}'. Request: {req.query}"
200
+ except: pass
201
 
202
  config = {"configurable": {"thread_id": req.thread_id}}
203
  inputs = {"messages": [HumanMessage(content=initial_msg)]}
 
206
  try:
207
  async for event in app_graph.astream_events(inputs, config=config, version="v1"):
208
  event_type = event["event"]
209
+ # Sirf Tool Output aur Final result stream karo
210
+ if event_type == "on_tool_start":
 
 
211
  yield f"\n\n⚙️ **Processing:** {event['name']}...\n\n"
212
  elif event_type == "on_tool_end":
213
  out = str(event['data'].get('output'))
214
  if "static/" in out:
215
+ path = re.search(r'(static/.*\.png)', out).group(1)
 
216
  yield f"\n\n![Result]({path})\n\n"
217
  else:
218
+ yield f"Info: {out}\n"
219
  except Exception as e:
220
  yield f"Error: {str(e)}"
221