vikashmakeit committed
Commit ee597f9 · verified · 1 Parent(s): 775c082

Add agentic refinement loop module

Files changed (1)
  1. refinement_loop.py +331 -0
refinement_loop.py ADDED
@@ -0,0 +1,331 @@
"""
Agentic Refinement Loop: Image → Pattern → 3D → Projection → Compare → Refine

Iteratively refines garment pattern parameters until the 3D garment projection
matches the original input image. Uses:
- Matplotlib 3D rendering for projection (CPU, no Chrome)
- SSIM + edge-SSIM for fast similarity gating (CPU)
- VLM (via HF Inference API) for visual comparison and parameter adjustment
- Keep-best tracking to prevent oscillation
"""
import json, os, copy, base64, io, re
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; must be set before importing pyplot
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from PIL import Image
from typing import Dict


def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
    """Render a Plotly 3D figure to a PIL image using matplotlib (no Chrome)."""
    fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
    ax = fig.add_subplot(111, projection='3d')

    for trace in plotly_fig.data:
        try:
            if trace.name == "Body":
                # Body surface: drawn nearly transparent so the garment dominates.
                x, y, z = np.array(trace.x), np.array(trace.y), np.array(trace.z)
                ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0',
                                edgecolor='none', shade=False)
            elif hasattr(trace, 'i') and trace.i is not None:
                # Mesh3d trace: rebuild the triangle list for Poly3DCollection.
                verts_x = np.array(trace.x, dtype=float)
                verts_y = np.array(trace.y, dtype=float)
                verts_z = np.array(trace.z, dtype=float)
                faces_i = np.array(trace.i, dtype=int)
                faces_j = np.array(trace.j, dtype=int)
                faces_k = np.array(trace.k, dtype=int)
                verts = list(zip(verts_x, verts_y, verts_z))
                faces = [[verts[i], verts[j], verts[k]]
                         for i, j, k in zip(faces_i, faces_j, faces_k)]
                color = trace.color if hasattr(trace, 'color') and trace.color else '#4A90D9'
                poly = Poly3DCollection(faces, alpha=0.75,
                                        facecolor=color, edgecolor='none')
                ax.add_collection3d(poly)
            elif hasattr(trace, 'x') and trace.x is not None:
                # Surface-style trace: plot_surface needs 2D coordinate grids.
                x = np.array(trace.x, dtype=float)
                y = np.array(trace.y, dtype=float)
                z = np.array(trace.z, dtype=float)
                if x.ndim == 2:
                    ax.plot_surface(x, y, z, alpha=0.6, color='#4A90D9',
                                    edgecolor='none', shade=True)
        except Exception:
            continue  # skip traces that don't map cleanly onto matplotlib

    ax.view_init(elev=elev, azim=azim)
    ax.set_xlim(-35, 35)
    ax.set_ylim(-35, 35)
    ax.set_zlim(0, 180)
    ax.axis('off')
    ax.set_facecolor('white')
    fig.patch.set_facecolor('white')

    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
                facecolor='white', pad_inches=0.1)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert('RGB')


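# Illustrative call pattern (not part of the pipeline): the same figure can be
# rendered from several fixed viewpoints for a quick visual sanity check, e.g.
#   front = render_3d_to_image(fig_3d, elev=15, azim=0)
#   side = render_3d_to_image(fig_3d, elev=15, azim=90)

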
def compute_similarity(img1: Image.Image, img2: Image.Image,
                       size=(256, 256)) -> Dict:
    """Compute CPU-based similarity metrics between two images."""
    from skimage.metrics import structural_similarity as ssim_fn
    from skimage import filters

    arr1 = np.array(img1.resize(size).convert('RGB'), dtype=float)
    arr2 = np.array(img2.resize(size).convert('RGB'), dtype=float)

    ssim_val = ssim_fn(arr1 / 255.0, arr2 / 255.0, channel_axis=2, data_range=1.0)
    # Inverted, normalized MSE: 1.0 means identical, lower means more pixel error.
    mse_val = 1.0 - np.mean((arr1 - arr2) ** 2) / (255.0 ** 2)

    # Edge-SSIM: SSIM over Sobel edge maps, emphasizing silhouette agreement.
    gray1 = arr1.mean(axis=2) / 255.0
    gray2 = arr2.mean(axis=2) / 255.0
    edges1 = filters.sobel(gray1)
    edges2 = filters.sobel(gray2)
    edge_ssim_val = ssim_fn(edges1, edges2, data_range=1.0)

    composite = 0.4 * ssim_val + 0.3 * mse_val + 0.3 * edge_ssim_val

    return {
        'ssim': round(float(ssim_val), 4),
        'mse': round(float(mse_val), 4),
        'edge_ssim': round(float(edge_ssim_val), 4),
        'composite': round(float(composite), 4),
    }


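# Worked example of the composite weighting (illustrative numbers): with
# ssim=0.70, mse=0.95, edge_ssim=0.40 the gate scores
# 0.4*0.70 + 0.3*0.95 + 0.3*0.40 = 0.685, below the default 0.82 target,
# so the loop would keep iterating.

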
def _image_to_b64(img: Image.Image, max_dim=512) -> str:
    """Downscale to max_dim on the long side and return a base64 JPEG string."""
    if max(img.size) > max_dim:
        ratio = max_dim / max(img.size)
        img = img.resize((int(img.size[0] * ratio), int(img.size[1] * ratio)), Image.LANCZOS)
    buf = io.BytesIO()
    img.convert('RGB').save(buf, format='JPEG', quality=85)
    return base64.b64encode(buf.getvalue()).decode('utf-8')


def vlm_compare_and_adjust(original_img, projection_img, current_params,
                           iteration, metrics, hf_token):
    """Use a VLM to compare images and suggest parameter adjustments."""
    import requests

    orig_b64 = _image_to_b64(original_img)
    proj_b64 = _image_to_b64(projection_img)

    display_params = {k: v for k, v in current_params.items() if k != '_model_used'}

    prompt = f"""You are a garment pattern expert doing iterative refinement.

Iteration {iteration}. Current similarity: SSIM={metrics['ssim']:.3f}, Edge={metrics['edge_ssim']:.3f}, Composite={metrics['composite']:.3f}

Current garment parameters:
{json.dumps(display_params, indent=2)}

Image 1 = ORIGINAL garment photo. Image 2 = 3D pattern projection.

Compare carefully. Identify differences in: silhouette, sleeve length/width, neckline/collar, hem length/flare, fit.

Return ONLY valid JSON (no markdown):
{{"differences": ["diff1", "diff2"], "adjustments": {{"param": value}}, "confidence": 0.0_to_1.0, "converged": true_or_false}}

Only adjust params that exist in current params. Set converged=true if sufficiently similar."""

    messages = [{"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{orig_b64}"}},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{proj_b64}"}},
        {"type": "text", "text": prompt}
    ]}]

    # Fallback chain of vision-capable models; swap in whichever multimodal
    # models your inference provider currently serves.
    models = [
        ("Qwen/Qwen2.5-VL-72B-Instruct", "together"),
        ("google/gemma-3-27b-it", "together"),
        ("meta-llama/Llama-3.2-11B-Vision-Instruct", "together"),
    ]

    for model_id, provider in models:
        try:
            url = f"https://router.huggingface.co/{provider}/v1/chat/completions"
            headers = {"Authorization": f"Bearer {hf_token}", "Content-Type": "application/json"}
            payload = {"model": model_id, "messages": messages, "max_tokens": 1500, "temperature": 0.1}

            resp = requests.post(url, headers=headers, json=payload, timeout=120)
            if resp.status_code == 200:
                text = resp.json()['choices'][0]['message'].get('content', '')
                if not text:
                    # Some reasoning models put their text in a 'reasoning' field.
                    text = resp.json()['choices'][0]['message'].get('reasoning', '')

                # Prefer a fenced JSON block; otherwise take the widest {...} span.
                json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
                if json_match:
                    json_str = json_match.group(1)
                else:
                    json_match = re.search(r'\{[\s\S]*\}', text)
                    if json_match:
                        json_str = json_match.group()
                    else:
                        continue

                result = json.loads(json_str)
                result['_model'] = model_id.split('/')[-1]
                return result
        except Exception as e:
            print(f"[Refine] {model_id}: {e}")
            continue
    return None


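# Shape of a successful reply (illustrative values), as consumed by the loop:
#   {"differences": ["sleeves too long", "hem too narrow"],
#    "adjustments": {"sleeve_length_cm": 20.0},
#    "confidence": 0.7, "converged": False, "_model": "Qwen2.5-VL-72B-Instruct"}

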
def apply_adjustments(analysis, adjustments, lr=0.7):
    """Apply parameter adjustments, damped by lr to avoid overshooting."""
    updated = copy.deepcopy(analysis)
    measurements = updated.get('measurements', {})
    features = updated.get('features', {})

    for param, new_value in adjustments.items():
        if param in measurements:
            old_value = measurements[param]
            if isinstance(old_value, (int, float)) and isinstance(new_value, (int, float)):
                # Move only a fraction lr of the way toward the suggested value.
                measurements[param] = round(old_value + lr * (new_value - old_value), 1)
            else:
                measurements[param] = new_value
        elif param in features:
            features[param] = new_value
        elif param == 'garment_type':
            updated['garment_type'] = new_value

    updated['measurements'] = measurements
    updated['features'] = features
    return updated


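# Worked example of the damping step (illustrative numbers): with lr=0.7, a
# suggested chest change from 96.0 to 104.0 cm is applied as
# 96.0 + 0.7 * (104.0 - 96.0) = 101.6 cm, so one noisy VLM suggestion cannot
# swing a measurement all the way in a single iteration.

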
def refinement_loop(original_image, initial_analysis, generate_fn,
                    max_iterations=8, target_composite=0.82,
                    plateau_threshold=0.005, plateau_patience=3, lr=0.7):
    """Run the agentic refinement loop.

    Args:
        original_image: PIL Image of the garment
        initial_analysis: dict with garment_type, measurements, features
        generate_fn: function(analysis) → (pattern_img, fig_3d, summary, json_str)
        max_iterations: maximum number of refinement steps
        target_composite: composite similarity score that counts as converged
        plateau_threshold: minimum score change that counts as progress
        plateau_patience: consecutive no-progress iterations before stopping
        lr: damping factor passed to apply_adjustments

    Returns:
        dict with best_analysis, best_score, history, scores, converged, etc.
    """
    hf_token = os.environ.get("HF_TOKEN", "")

    current_analysis = copy.deepcopy(initial_analysis)
    best_analysis = copy.deepcopy(initial_analysis)
    best_score = -1.0
    history = []
    scores = []
    plateau_count = 0

    for iteration in range(1, max_iterations + 1):
        step = {"iteration": iteration}

        # Generate pattern + 3D
        try:
            pattern_img, fig_3d, summary, json_str = generate_fn(current_analysis)
        except Exception as e:
            step["status"] = "error"
            step["reason"] = f"Generation failed: {e}"
            history.append(step)
            break

        # Render 3D → 2D (front view: azim=0)
        try:
            projection = render_3d_to_image(fig_3d, elev=15, azim=0)
        except Exception as e:
            step["status"] = "error"
            step["reason"] = f"Rendering failed: {e}"
            history.append(step)
            break

        # Compute similarity
        metrics = compute_similarity(original_image, projection)
        step["metrics"] = metrics
        step["projection"] = projection
        step["pattern_image"] = pattern_img
        step["fig_3d"] = fig_3d
        step["params"] = copy.deepcopy(current_analysis)
        scores.append(metrics['composite'])

        # Keep-best: never lose the highest-scoring parameters to a bad step
        if metrics['composite'] > best_score:
            best_score = metrics['composite']
            best_analysis = copy.deepcopy(current_analysis)
            step["new_best"] = True
        else:
            step["new_best"] = False

        # Convergence: target reached
        if metrics['composite'] >= target_composite:
            step["status"] = "converged"
            step["reason"] = f"Target {target_composite} reached: {metrics['composite']:.4f}"
            history.append(step)
            break

        # Convergence: plateau (scores stopped moving)
        if len(scores) >= 2:
            if abs(scores[-1] - scores[-2]) < plateau_threshold:
                plateau_count += 1
            else:
                plateau_count = 0
            if plateau_count >= plateau_patience:
                step["status"] = "plateau"
                step["reason"] = f"Plateau for {plateau_patience} iterations"
                history.append(step)
                break

        # VLM feedback
        if hf_token:
            vlm_result = vlm_compare_and_adjust(
                original_image, projection, current_analysis,
                iteration, metrics, hf_token)
        else:
            vlm_result = None

        if vlm_result:
            step["vlm_differences"] = vlm_result.get('differences', [])
            step["vlm_confidence"] = vlm_result.get('confidence', 0)

            if vlm_result.get('converged', False):
                step["status"] = "vlm_converged"
                step["reason"] = "VLM declared convergence"
                history.append(step)
                break

            if vlm_result.get('confidence', 1.0) < 0.2:
                step["status"] = "low_confidence"
                step["reason"] = f"VLM confidence: {vlm_result['confidence']}"
                history.append(step)
                break

            adjustments = vlm_result.get('adjustments', {})
            if adjustments:
                current_analysis = apply_adjustments(current_analysis, adjustments, lr=lr)
                step["adjustments"] = adjustments
        else:
            step["status"] = "no_vlm"
            step["reason"] = "No VLM available (set HF_TOKEN)"
            history.append(step)
            break

        step["status"] = "continuing"
        history.append(step)

    if history and history[-1].get("status") == "continuing":
        history[-1]["status"] = "max_iterations"
        history[-1]["reason"] = f"Max {max_iterations} iterations reached"

    return {
        "best_analysis": best_analysis,
        "best_score": best_score,
        "history": history,
        "total_iterations": len(history),
        "converged": any(h.get("status") in ("converged", "vlm_converged") for h in history),
        "scores": scores,
    }
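
Usage sketch (illustrative, not part of the commit): a minimal driver showing
the generate_fn contract the loop expects. The stub below returns a fixed
Mesh3d figure and a blank pattern image; a real caller would draft both from
analysis['measurements']. The file path and parameter names are placeholders.

import plotly.graph_objects as go
from PIL import Image
from refinement_loop import refinement_loop

def generate_fn(analysis):
    # Hypothetical stub: two triangles standing in for a draped garment mesh.
    fig_3d = go.Figure(go.Mesh3d(
        x=[-20, 20, 20, -20], y=[0, 0, 0, 0], z=[40, 40, 120, 120],
        i=[0, 0], j=[1, 2], k=[2, 3], color='#4A90D9'))
    pattern_img = Image.new('RGB', (512, 512), 'white')
    return pattern_img, fig_3d, "stub summary", "{}"

original = Image.open("shirt.jpg")  # placeholder input photo
initial = {"garment_type": "shirt",
           "measurements": {"chest_cm": 96.0, "length_cm": 70.0},
           "features": {"sleeve_style": "short"}}

result = refinement_loop(original, initial, generate_fn,
                         max_iterations=3, target_composite=0.82)
print(result["best_score"], result["total_iterations"], result["converged"])

Without HF_TOKEN set, the loop runs a single iteration and stops with status
"no_vlm", which is a convenient smoke test for the render-and-score path.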