vikashmakeit committed
Commit 99a8fd3 · verified · 1 Parent(s): 70dedbc

Fix VLM models in refinement_loop.py: use actual vision models with correct providers
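The previous model list pointed at checkpoints that are not vision models, so the two-image comparison prompt could not work as intended; the list now names vision-capable checkpoints and routes each through a provider that serves it. For context, a minimal sketch of the router call shape the loop relies on. The endpoint URL, model/provider pair, max_tokens, temperature, and timeout come from the patch; the image_url data-URI part assumes the OpenAI-style multimodal schema the router generally accepts, and the blank test image is a stand-in for the garment photo and projection.

# Hedged sketch of the router call; assumes HF_TOKEN is set and the
# OpenAI-style image_url content schema (not shown verbatim in the diff).
import base64, io, os, requests
from PIL import Image

def image_part(img):
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return {"type": "image_url", "image_url": {"url": "data:image/png;base64," + b64}}

model_id, provider = "Qwen/Qwen2.5-VL-72B-Instruct", "together"  # first entry of the patched list
url = f"https://router.huggingface.co/{provider}/v1/chat/completions"
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}", "Content-Type": "application/json"}
img = Image.new("RGB", (64, 64), "white")  # stand-in for the garment photo / projection
payload = {"model": model_id, "max_tokens": 256, "temperature": 0.1,
           "messages": [{"role": "user", "content": [image_part(img), image_part(img),
                                                     {"type": "text", "text": "Compare these two images."}]}]}
resp = requests.post(url, headers=headers, json=payload, timeout=120)
if resp.status_code == 200:
    print(resp.json()['choices'][0]['message']['content'][:200])
else:
    print(f"HTTP {resp.status_code}: {resp.text[:200]}")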

Files changed (1)
  1. refinement_loop.py +52 -170
refinement_loop.py CHANGED
@@ -1,12 +1,5 @@
 """
 Agentic Refinement Loop: Image → Pattern → 3D → Projection → Compare → Refine
-
-Iteratively refines garment pattern parameters until the 3D garment projection
-matches the original input image. Uses:
-- Matplotlib 3D rendering for projection (CPU, no Chrome)
-- SSIM + Edge-SSIM for fast similarity gating (CPU)
-- VLM (via HF Inference API) for visual comparison and parameter adjustment
-- Keep-best tracking to prevent oscillation
 """
 import json, os, copy, base64, io, re
 import numpy as np
@@ -22,82 +15,50 @@ def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
     """Render a Plotly 3D figure to a PIL image using matplotlib."""
     fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
     ax = fig.add_subplot(111, projection='3d')
-
     for trace in plotly_fig.data:
         try:
             if trace.name == "Body":
                 x, y, z = np.array(trace.x), np.array(trace.y), np.array(trace.z)
-                ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0',
-                                edgecolor='none', shade=False)
+                ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0', edgecolor='none', shade=False)
             elif hasattr(trace, 'i') and trace.i is not None:
-                verts_x = np.array(trace.x, dtype=float)
-                verts_y = np.array(trace.y, dtype=float)
-                verts_z = np.array(trace.z, dtype=float)
-                faces_i = np.array(trace.i, dtype=int)
-                faces_j = np.array(trace.j, dtype=int)
-                faces_k = np.array(trace.k, dtype=int)
+                verts_x, verts_y, verts_z = np.array(trace.x, dtype=float), np.array(trace.y, dtype=float), np.array(trace.z, dtype=float)
+                faces_i, faces_j, faces_k = np.array(trace.i, dtype=int), np.array(trace.j, dtype=int), np.array(trace.k, dtype=int)
                 verts = list(zip(verts_x, verts_y, verts_z))
-                faces = [[verts[i], verts[j], verts[k]]
-                         for i, j, k in zip(faces_i, faces_j, faces_k)]
+                faces = [[verts[i], verts[j], verts[k]] for i, j, k in zip(faces_i, faces_j, faces_k)]
                 color = trace.color if hasattr(trace, 'color') and trace.color else '#4A90D9'
-                poly = Poly3DCollection(faces, alpha=0.75,
-                                        facecolor=color, edgecolor='none')
-                ax.add_collection3d(poly)
+                ax.add_collection3d(Poly3DCollection(faces, alpha=0.75, facecolor=color, edgecolor='none'))
             elif hasattr(trace, 'x') and trace.x is not None:
-                x = np.array(trace.x, dtype=float)
-                y = np.array(trace.y, dtype=float)
-                z = np.array(trace.z, dtype=float)
+                x, y, z = np.array(trace.x, dtype=float), np.array(trace.y, dtype=float), np.array(trace.z, dtype=float)
                 if x.ndim == 2:
-                    ax.plot_surface(x, y, z, alpha=0.6, color='#4A90D9',
-                                    edgecolor='none', shade=True)
+                    ax.plot_surface(x, y, z, alpha=0.6, color='#4A90D9', edgecolor='none', shade=True)
         except Exception:
             continue
-
     ax.view_init(elev=elev, azim=azim)
-    ax.set_xlim(-35, 35)
-    ax.set_ylim(-35, 35)
-    ax.set_zlim(0, 180)
-    ax.axis('off')
-    ax.set_facecolor('white')
-    fig.patch.set_facecolor('white')
-
+    ax.set_xlim(-35, 35); ax.set_ylim(-35, 35); ax.set_zlim(0, 180)
+    ax.axis('off'); ax.set_facecolor('white'); fig.patch.set_facecolor('white')
     buf = io.BytesIO()
-    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
-                facecolor='white', pad_inches=0.1)
-    plt.close(fig)
-    buf.seek(0)
+    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor='white', pad_inches=0.1)
+    plt.close(fig); buf.seek(0)
     return Image.open(buf).convert('RGB')


-def compute_similarity(img1: Image.Image, img2: Image.Image,
-                       size=(256, 256)) -> Dict:
-    """Compute CPU-based similarity metrics between two images."""
+def compute_similarity(img1, img2, size=(256, 256)):
+    """Compute CPU-based similarity metrics."""
     from skimage.metrics import structural_similarity as ssim_fn
     from skimage import filters
-
     arr1 = np.array(img1.resize(size).convert('RGB'), dtype=float)
     arr2 = np.array(img2.resize(size).convert('RGB'), dtype=float)
-
     ssim_val = ssim_fn(arr1 / 255.0, arr2 / 255.0, channel_axis=2, data_range=1.0)
     mse_val = 1.0 - np.mean((arr1 - arr2) ** 2) / (255.0 ** 2)
-
-    gray1 = arr1.mean(axis=2) / 255.0
-    gray2 = arr2.mean(axis=2) / 255.0
-    edges1 = filters.sobel(gray1)
-    edges2 = filters.sobel(gray2)
+    edges1 = filters.sobel(arr1.mean(axis=2) / 255.0)
+    edges2 = filters.sobel(arr2.mean(axis=2) / 255.0)
     edge_ssim_val = ssim_fn(edges1, edges2, data_range=1.0)
-
     composite = 0.4 * ssim_val + 0.3 * mse_val + 0.3 * edge_ssim_val
-
-    return {
-        'ssim': round(float(ssim_val), 4),
-        'mse': round(float(mse_val), 4),
-        'edge_ssim': round(float(edge_ssim_val), 4),
-        'composite': round(float(composite), 4),
-    }
+    return {'ssim': round(float(ssim_val), 4), 'mse': round(float(mse_val), 4),
+            'edge_ssim': round(float(edge_ssim_val), 4), 'composite': round(float(composite), 4)}


-def _image_to_b64(img: Image.Image, max_dim=512) -> str:
+def _image_to_b64(img, max_dim=512):
     if max(img.size) > max_dim:
         ratio = max_dim / max(img.size)
         img = img.resize((int(img.size[0] * ratio), int(img.size[1] * ratio)), Image.LANCZOS)
@@ -110,10 +71,8 @@ def vlm_compare_and_adjust(original_img, projection_img, current_params,
                            iteration, metrics, hf_token):
     """Use VLM to compare images and suggest parameter adjustments."""
     import requests
-
     orig_b64 = _image_to_b64(original_img)
     proj_b64 = _image_to_b64(projection_img)
-
     display_params = {k: v for k, v in current_params.items() if k != '_model_used'}

     prompt = f"""You are a garment pattern expert doing iterative refinement.
@@ -138,10 +97,11 @@ Only adjust params that exist in current params. Set converged=true if sufficien
         {"type": "text", "text": prompt}
     ]}]

+    # Use actual VLMs with correct providers
     models = [
-        ("Qwen/Qwen3.5-9B", "together"),
-        ("google/gemma-4-31B-it", "together"),
-        ("moonshotai/Kimi-K2.5", "together"),
+        ("Qwen/Qwen2.5-VL-72B-Instruct", "together"),
+        ("google/gemma-4-31B-it", "novita"),
+        ("moonshotai/Kimi-K2.5", "fireworks-ai"),
     ]

     for model_id, provider in models:
@@ -149,38 +109,30 @@ Only adjust params that exist in current params. Set converged=true if sufficien
             url = f"https://router.huggingface.co/{provider}/v1/chat/completions"
             headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
             payload = {"model": model_id, "messages": messages, "max_tokens": 1500, "temperature": 0.1}
-
             resp = requests.post(url, headers=headers, json=payload, timeout=120)
             if resp.status_code == 200:
                 text = resp.json()['choices'][0]['message'].get('content', '')
-                if not text:
-                    text = resp.json()['choices'][0]['message'].get('reasoning', '')
-
+                if not text: text = resp.json()['choices'][0]['message'].get('reasoning', '')
                 json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
-                if json_match:
-                    json_str = json_match.group(1)
-                else:
+                json_str = json_match.group(1) if json_match else None
+                if not json_str:
                     json_match = re.search(r'\{[\s\S]*\}', text)
-                    if json_match:
-                        json_str = json_match.group()
-                    else:
-                        continue
-
+                    json_str = json_match.group() if json_match else None
+                if not json_str: continue
                 result = json.loads(json_str)
                 result['_model'] = model_id.split('/')[-1]
                 return result
+            else:
+                print(f"[Refine] {model_id} via {provider}: HTTP {resp.status_code}")
         except Exception as e:
-            print(f"[Refine] {model_id}: {e}")
-            continue
+            print(f"[Refine] {model_id}: {e}"); continue
     return None


 def apply_adjustments(analysis, adjustments, lr=0.7):
-    """Apply parameter adjustments with damping factor."""
     updated = copy.deepcopy(analysis)
     measurements = updated.get('measurements', {})
     features = updated.get('features', {})
-
     for param, new_value in adjustments.items():
         if param in measurements:
             old_value = measurements[param]
@@ -192,7 +144,6 @@ def apply_adjustments(analysis, adjustments, lr=0.7):
             features[param] = new_value
         elif param == 'garment_type':
             updated['garment_type'] = new_value
-
     updated['measurements'] = measurements
     updated['features'] = features
     return updated
@@ -201,131 +152,62 @@ def apply_adjustments(analysis, adjustments, lr=0.7):
 def refinement_loop(original_image, initial_analysis, generate_fn,
                     max_iterations=8, target_composite=0.82,
                     plateau_threshold=0.005, plateau_patience=3, lr=0.7):
-    """Run the agentic refinement loop.
-
-    Args:
-        original_image: PIL Image of the garment
-        initial_analysis: dict with garment_type, measurements, features
-        generate_fn: function(analysis) → (pattern_img, fig_3d, summary, json_str)
-        max_iterations: max steps
-        target_composite: similarity target
-        lr: damping factor
-
-    Returns:
-        dict with best_analysis, history, scores, converged, etc.
-    """
+    """Run the agentic refinement loop."""
     hf_token = os.environ.get("HF_TOKEN", "")
-
     current_analysis = copy.deepcopy(initial_analysis)
     best_analysis = copy.deepcopy(initial_analysis)
     best_score = -1.0
-    history = []
-    scores = []
-    plateau_count = 0
+    history, scores, plateau_count = [], [], 0

     for iteration in range(1, max_iterations + 1):
         step = {"iteration": iteration}
-
-        # Generate pattern + 3D
         try:
             pattern_img, fig_3d, summary, json_str = generate_fn(current_analysis)
         except Exception as e:
-            step["status"] = "error"
-            step["reason"] = f"Generation failed: {e}"
-            history.append(step)
-            break
-
-        # Render 3D → 2D
+            step["status"] = "error"; step["reason"] = f"Generation failed: {e}"; history.append(step); break
         try:
             projection = render_3d_to_image(fig_3d, elev=15, azim=0)
         except Exception as e:
-            step["status"] = "error"
-            step["reason"] = f"Rendering failed: {e}"
-            history.append(step)
-            break
+            step["status"] = "error"; step["reason"] = f"Rendering failed: {e}"; history.append(step); break

-        # Compute similarity
         metrics = compute_similarity(original_image, projection)
-        step["metrics"] = metrics
-        step["projection"] = projection
-        step["pattern_image"] = pattern_img
-        step["fig_3d"] = fig_3d
-        step["params"] = copy.deepcopy(current_analysis)
+        step.update({"metrics": metrics, "projection": projection, "pattern_image": pattern_img,
+                     "fig_3d": fig_3d, "params": copy.deepcopy(current_analysis)})
         scores.append(metrics['composite'])

-        # Keep-best
         if metrics['composite'] > best_score:
-            best_score = metrics['composite']
-            best_analysis = copy.deepcopy(current_analysis)
-            step["new_best"] = True
+            best_score = metrics['composite']; best_analysis = copy.deepcopy(current_analysis); step["new_best"] = True
         else:
             step["new_best"] = False

-        # Convergence: target reached
         if metrics['composite'] >= target_composite:
-            step["status"] = "converged"
-            step["reason"] = f"Target {target_composite} reached: {metrics['composite']:.4f}"
-            history.append(step)
-            break
+            step["status"] = "converged"; step["reason"] = f"Target {target_composite} reached: {metrics['composite']:.4f}"; history.append(step); break

-        # Convergence: plateau
         if len(scores) >= 2:
-            if abs(scores[-1] - scores[-2]) < plateau_threshold:
-                plateau_count += 1
-            else:
-                plateau_count = 0
+            if abs(scores[-1] - scores[-2]) < plateau_threshold: plateau_count += 1
+            else: plateau_count = 0
             if plateau_count >= plateau_patience:
-                step["status"] = "plateau"
-                step["reason"] = f"Plateau for {plateau_patience} iterations"
-                history.append(step)
-                break
-
-        # VLM feedback
-        if hf_token:
-            vlm_result = vlm_compare_and_adjust(
-                original_image, projection, current_analysis,
-                iteration, metrics, hf_token)
-        else:
-            vlm_result = None
+                step["status"] = "plateau"; step["reason"] = f"Plateau for {plateau_patience} iterations"; history.append(step); break
+
+        vlm_result = vlm_compare_and_adjust(original_image, projection, current_analysis, iteration, metrics, hf_token) if hf_token else None

         if vlm_result:
             step["vlm_differences"] = vlm_result.get('differences', [])
             step["vlm_confidence"] = vlm_result.get('confidence', 0)
-
             if vlm_result.get('converged', False):
-                step["status"] = "vlm_converged"
-                step["reason"] = "VLM declared convergence"
-                history.append(step)
-                break
-
+                step["status"] = "vlm_converged"; step["reason"] = "VLM declared convergence"; history.append(step); break
             if vlm_result.get('confidence', 1.0) < 0.2:
-                step["status"] = "low_confidence"
-                step["reason"] = f"VLM confidence: {vlm_result['confidence']}"
-                history.append(step)
-                break
-
+                step["status"] = "low_confidence"; step["reason"] = f"VLM confidence: {vlm_result['confidence']}"; history.append(step); break
             adjustments = vlm_result.get('adjustments', {})
             if adjustments:
-                current_analysis = apply_adjustments(current_analysis, adjustments, lr=lr)
-                step["adjustments"] = adjustments
+                current_analysis = apply_adjustments(current_analysis, adjustments, lr=lr); step["adjustments"] = adjustments
         else:
-            step["status"] = "no_vlm"
-            step["reason"] = "No VLM available (set HF_TOKEN)"
-            history.append(step)
-            break
+            step["status"] = "no_vlm"; step["reason"] = "No VLM available (set HF_TOKEN)"; history.append(step); break

-        step["status"] = "continuing"
-        history.append(step)
+        step["status"] = "continuing"; history.append(step)

     if history and history[-1].get("status") == "continuing":
-        history[-1]["status"] = "max_iterations"
-        history[-1]["reason"] = f"Max {max_iterations} iterations reached"
-
-    return {
-        "best_analysis": best_analysis,
-        "best_score": best_score,
-        "history": history,
-        "total_iterations": len(history),
-        "converged": any(h.get("status") in ("converged", "vlm_converged") for h in history),
-        "scores": scores,
-    }
+        history[-1]["status"] = "max_iterations"; history[-1]["reason"] = f"Max {max_iterations} iterations reached"
+
+    return {"best_analysis": best_analysis, "best_score": best_score, "history": history,
+            "total_iterations": len(history), "converged": any(h.get("status") in ("converged", "vlm_converged") for h in history), "scores": scores}
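A quick sanity check of the CPU similarity gate, assuming refinement_loop.py is importable and scikit-image is installed. Identical images must score composite = 0.4*1 + 0.3*1 + 0.3*1 = 1.0, since the weights sum to one; a slightly rotated copy lands lower, which is what the loop then tries to push back above the 0.82 target. The drawn rectangle is hypothetical test data, not from the repo.

# Sanity check for compute_similarity; the rectangle is a stand-in silhouette.
from PIL import Image, ImageDraw
from refinement_loop import compute_similarity

a = Image.new("RGB", (256, 256), "white")
ImageDraw.Draw(a).rectangle([60, 40, 200, 230], outline="black", width=4)
b = a.rotate(5, fillcolor="white")  # small pose perturbation

print(compute_similarity(a, a))  # composite == 1.0
print(compute_similarity(a, b))  # composite < 1.0; the loop refines until >= 0.82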
 
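The condensed parsing keeps the same two-stage JSON recovery as before: a fenced ```json block is tried first, with a bare-brace fallback. A standalone demo of exactly those regexes on a reply shaped like what the prompt requests (the parameter name is illustrative):

import json, re

text = 'Assessment:\n```json\n{"adjustments": {"chest_cm": 98}, "confidence": 0.8, "converged": false}\n```'

# Stage 1: fenced block; stage 2: outermost braces.
m = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
json_str = m.group(1) if m else None
if not json_str:
    m = re.search(r'\{[\s\S]*\}', text)
    json_str = m.group() if m else None

print(json.loads(json_str))
# {'adjustments': {'chest_cm': 98}, 'confidence': 0.8, 'converged': False}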
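One behavioral note: apply_adjustments damps each numeric change with lr (default 0.7), but the damping line itself sits in an unchanged region between hunks, so it is not visible in this diff. A hypothetical damped update consistent with that default:

# NOT from the diff: an assumed damped update; the real line is in an
# unchanged region of refinement_loop.py that this commit does not touch.
def damped(old, new, lr=0.7):
    return old + lr * (new - old)

print(damped(96.0, 104.0))  # 101.6: moves 70% of the way toward the suggestion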
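The renderer change is formatting-only, and a one-triangle Mesh3d is enough to smoke-test it (assumes refinement_loop.py imports cleanly with matplotlib and plotly available):

import plotly.graph_objects as go
from refinement_loop import render_3d_to_image

# A single mesh triangle exercises the Poly3DCollection branch.
fig = go.Figure(go.Mesh3d(x=[0, 30, 0], y=[0, 0, 30], z=[0, 0, 160], i=[0], j=[1], k=[2]))
img = render_3d_to_image(fig, elev=15, azim=0)
print(img.size, img.mode)  # a white-background RGB image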
 