vikashmakeit committed on
Commit
8e1bee3
·
verified ·
1 Parent(s): e9ba07a

Fix VLM models in refinement_loop.py to match app.py (Llama-4-Scout, Kimi-K2.6, Qwen3.5-9B)

Browse files
Files changed (1) hide show
  1. refinement_loop.py +11 -14
refinement_loop.py CHANGED
@@ -12,7 +12,6 @@ from typing import Dict, List, Tuple, Optional
12
 
13
 
14
  def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
15
- """Render a Plotly 3D figure to a PIL image using matplotlib."""
16
  fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
17
  ax = fig.add_subplot(111, projection='3d')
18
  for trace in plotly_fig.data:
@@ -21,10 +20,10 @@ def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
21
  x, y, z = np.array(trace.x), np.array(trace.y), np.array(trace.z)
22
  ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0', edgecolor='none', shade=False)
23
  elif hasattr(trace, 'i') and trace.i is not None:
24
- verts_x, verts_y, verts_z = np.array(trace.x, dtype=float), np.array(trace.y, dtype=float), np.array(trace.z, dtype=float)
25
- faces_i, faces_j, faces_k = np.array(trace.i, dtype=int), np.array(trace.j, dtype=int), np.array(trace.k, dtype=int)
26
- verts = list(zip(verts_x, verts_y, verts_z))
27
- faces = [[verts[i], verts[j], verts[k]] for i, j, k in zip(faces_i, faces_j, faces_k)]
28
  color = trace.color if hasattr(trace, 'color') and trace.color else '#4A90D9'
29
  ax.add_collection3d(Poly3DCollection(faces, alpha=0.75, facecolor=color, edgecolor='none'))
30
  elif hasattr(trace, 'x') and trace.x is not None:
@@ -43,7 +42,6 @@ def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
43
 
44
 
45
  def compute_similarity(img1, img2, size=(256, 256)):
46
- """Compute CPU-based similarity metrics."""
47
  from skimage.metrics import structural_similarity as ssim_fn
48
  from skimage import filters
49
  arr1 = np.array(img1.resize(size).convert('RGB'), dtype=float)
@@ -69,7 +67,6 @@ def _image_to_b64(img, max_dim=512):
69
 
70
  def vlm_compare_and_adjust(original_img, projection_img, current_params,
71
  iteration, metrics, hf_token):
72
- """Use VLM to compare images and suggest parameter adjustments."""
73
  import requests
74
  orig_b64 = _image_to_b64(original_img)
75
  proj_b64 = _image_to_b64(projection_img)
@@ -97,11 +94,11 @@ Only adjust params that exist in current params. Set converged=true if sufficien
97
  {"type": "text", "text": prompt}
98
  ]}]
99
 
100
- # Use actual VLMs with correct providers
101
  models = [
102
- ("Qwen/Qwen2.5-VL-72B-Instruct", "together"),
103
- ("google/gemma-4-31B-it", "novita"),
104
- ("moonshotai/Kimi-K2.5", "fireworks-ai"),
105
  ]
106
 
107
  for model_id, provider in models:
@@ -111,8 +108,9 @@ Only adjust params that exist in current params. Set converged=true if sufficien
111
  payload = {"model": model_id, "messages": messages, "max_tokens": 1500, "temperature": 0.1}
112
  resp = requests.post(url, headers=headers, json=payload, timeout=120)
113
  if resp.status_code == 200:
114
- text = resp.json()['choices'][0]['message'].get('content', '')
115
- if not text: text = resp.json()['choices'][0]['message'].get('reasoning', '')
 
116
  json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
117
  json_str = json_match.group(1) if json_match else None
118
  if not json_str:
@@ -152,7 +150,6 @@ def apply_adjustments(analysis, adjustments, lr=0.7):
152
  def refinement_loop(original_image, initial_analysis, generate_fn,
153
  max_iterations=8, target_composite=0.82,
154
  plateau_threshold=0.005, plateau_patience=3, lr=0.7):
155
- """Run the agentic refinement loop."""
156
  hf_token = os.environ.get("HF_TOKEN", "")
157
  current_analysis = copy.deepcopy(initial_analysis)
158
  best_analysis = copy.deepcopy(initial_analysis)
 
12
 
13
 
14
  def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
 
15
  fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
16
  ax = fig.add_subplot(111, projection='3d')
17
  for trace in plotly_fig.data:
 
20
  x, y, z = np.array(trace.x), np.array(trace.y), np.array(trace.z)
21
  ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0', edgecolor='none', shade=False)
22
  elif hasattr(trace, 'i') and trace.i is not None:
23
+ vx, vy, vz = np.array(trace.x, dtype=float), np.array(trace.y, dtype=float), np.array(trace.z, dtype=float)
24
+ fi, fj, fk = np.array(trace.i, dtype=int), np.array(trace.j, dtype=int), np.array(trace.k, dtype=int)
25
+ verts = list(zip(vx, vy, vz))
26
+ faces = [[verts[i], verts[j], verts[k]] for i, j, k in zip(fi, fj, fk)]
27
  color = trace.color if hasattr(trace, 'color') and trace.color else '#4A90D9'
28
  ax.add_collection3d(Poly3DCollection(faces, alpha=0.75, facecolor=color, edgecolor='none'))
29
  elif hasattr(trace, 'x') and trace.x is not None:
 
42
 
43
 
44
  def compute_similarity(img1, img2, size=(256, 256)):
 
45
  from skimage.metrics import structural_similarity as ssim_fn
46
  from skimage import filters
47
  arr1 = np.array(img1.resize(size).convert('RGB'), dtype=float)
 
67
 
68
  def vlm_compare_and_adjust(original_img, projection_img, current_params,
69
  iteration, metrics, hf_token):
 
70
  import requests
71
  orig_b64 = _image_to_b64(original_img)
72
  proj_b64 = _image_to_b64(projection_img)
 
94
  {"type": "text", "text": prompt}
95
  ]}]
96
 
97
+ # Verified working VLMs (tested 2026-04-25)
98
  models = [
99
+ ("meta-llama/Llama-4-Scout-17B-16E-Instruct", "nscale"),
100
+ ("moonshotai/Kimi-K2.6", "together"),
101
+ ("Qwen/Qwen3.5-9B", "together"),
102
  ]
103
 
104
  for model_id, provider in models:
 
108
  payload = {"model": model_id, "messages": messages, "max_tokens": 1500, "temperature": 0.1}
109
  resp = requests.post(url, headers=headers, json=payload, timeout=120)
110
  if resp.status_code == 200:
111
+ msg = resp.json()['choices'][0]['message']
112
+ text = (msg.get('content', '') or '').strip() or (msg.get('reasoning', '') or '').strip()
113
+ if not text: continue
114
  json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
115
  json_str = json_match.group(1) if json_match else None
116
  if not json_str:
 
150
  def refinement_loop(original_image, initial_analysis, generate_fn,
151
  max_iterations=8, target_composite=0.82,
152
  plateau_threshold=0.005, plateau_patience=3, lr=0.7):
 
153
  hf_token = os.environ.get("HF_TOKEN", "")
154
  current_analysis = copy.deepcopy(initial_analysis)
155
  best_analysis = copy.deepcopy(initial_analysis)