Spaces:
Running
Running
Fix VLM models in refinement_loop.py to match app.py (Llama-4-Scout, Kimi-K2.6, Qwen3.5-9B)
Browse files- refinement_loop.py +11 -14
refinement_loop.py
CHANGED
|
@@ -12,7 +12,6 @@ from typing import Dict, List, Tuple, Optional
|
|
| 12 |
|
| 13 |
|
| 14 |
def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
|
| 15 |
-
"""Render a Plotly 3D figure to a PIL image using matplotlib."""
|
| 16 |
fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
|
| 17 |
ax = fig.add_subplot(111, projection='3d')
|
| 18 |
for trace in plotly_fig.data:
|
|
@@ -21,10 +20,10 @@ def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
|
|
| 21 |
x, y, z = np.array(trace.x), np.array(trace.y), np.array(trace.z)
|
| 22 |
ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0', edgecolor='none', shade=False)
|
| 23 |
elif hasattr(trace, 'i') and trace.i is not None:
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
verts = list(zip(
|
| 27 |
-
faces = [[verts[i], verts[j], verts[k]] for i, j, k in zip(
|
| 28 |
color = trace.color if hasattr(trace, 'color') and trace.color else '#4A90D9'
|
| 29 |
ax.add_collection3d(Poly3DCollection(faces, alpha=0.75, facecolor=color, edgecolor='none'))
|
| 30 |
elif hasattr(trace, 'x') and trace.x is not None:
|
|
@@ -43,7 +42,6 @@ def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
|
|
| 43 |
|
| 44 |
|
| 45 |
def compute_similarity(img1, img2, size=(256, 256)):
|
| 46 |
-
"""Compute CPU-based similarity metrics."""
|
| 47 |
from skimage.metrics import structural_similarity as ssim_fn
|
| 48 |
from skimage import filters
|
| 49 |
arr1 = np.array(img1.resize(size).convert('RGB'), dtype=float)
|
|
@@ -69,7 +67,6 @@ def _image_to_b64(img, max_dim=512):
|
|
| 69 |
|
| 70 |
def vlm_compare_and_adjust(original_img, projection_img, current_params,
|
| 71 |
iteration, metrics, hf_token):
|
| 72 |
-
"""Use VLM to compare images and suggest parameter adjustments."""
|
| 73 |
import requests
|
| 74 |
orig_b64 = _image_to_b64(original_img)
|
| 75 |
proj_b64 = _image_to_b64(projection_img)
|
|
@@ -97,11 +94,11 @@ Only adjust params that exist in current params. Set converged=true if sufficien
|
|
| 97 |
{"type": "text", "text": prompt}
|
| 98 |
]}]
|
| 99 |
|
| 100 |
-
#
|
| 101 |
models = [
|
| 102 |
-
("
|
| 103 |
-
("
|
| 104 |
-
("
|
| 105 |
]
|
| 106 |
|
| 107 |
for model_id, provider in models:
|
|
@@ -111,8 +108,9 @@ Only adjust params that exist in current params. Set converged=true if sufficien
|
|
| 111 |
payload = {"model": model_id, "messages": messages, "max_tokens": 1500, "temperature": 0.1}
|
| 112 |
resp = requests.post(url, headers=headers, json=payload, timeout=120)
|
| 113 |
if resp.status_code == 200:
|
| 114 |
-
|
| 115 |
-
|
|
|
|
| 116 |
json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
|
| 117 |
json_str = json_match.group(1) if json_match else None
|
| 118 |
if not json_str:
|
|
@@ -152,7 +150,6 @@ def apply_adjustments(analysis, adjustments, lr=0.7):
|
|
| 152 |
def refinement_loop(original_image, initial_analysis, generate_fn,
|
| 153 |
max_iterations=8, target_composite=0.82,
|
| 154 |
plateau_threshold=0.005, plateau_patience=3, lr=0.7):
|
| 155 |
-
"""Run the agentic refinement loop."""
|
| 156 |
hf_token = os.environ.get("HF_TOKEN", "")
|
| 157 |
current_analysis = copy.deepcopy(initial_analysis)
|
| 158 |
best_analysis = copy.deepcopy(initial_analysis)
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def render_3d_to_image(plotly_fig, elev=15, azim=45, width=512, height=512):
|
|
|
|
| 15 |
fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
|
| 16 |
ax = fig.add_subplot(111, projection='3d')
|
| 17 |
for trace in plotly_fig.data:
|
|
|
|
| 20 |
x, y, z = np.array(trace.x), np.array(trace.y), np.array(trace.z)
|
| 21 |
ax.plot_surface(x, y, z, alpha=0.08, color='#E8D0B0', edgecolor='none', shade=False)
|
| 22 |
elif hasattr(trace, 'i') and trace.i is not None:
|
| 23 |
+
vx, vy, vz = np.array(trace.x, dtype=float), np.array(trace.y, dtype=float), np.array(trace.z, dtype=float)
|
| 24 |
+
fi, fj, fk = np.array(trace.i, dtype=int), np.array(trace.j, dtype=int), np.array(trace.k, dtype=int)
|
| 25 |
+
verts = list(zip(vx, vy, vz))
|
| 26 |
+
faces = [[verts[i], verts[j], verts[k]] for i, j, k in zip(fi, fj, fk)]
|
| 27 |
color = trace.color if hasattr(trace, 'color') and trace.color else '#4A90D9'
|
| 28 |
ax.add_collection3d(Poly3DCollection(faces, alpha=0.75, facecolor=color, edgecolor='none'))
|
| 29 |
elif hasattr(trace, 'x') and trace.x is not None:
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
def compute_similarity(img1, img2, size=(256, 256)):
|
|
|
|
| 45 |
from skimage.metrics import structural_similarity as ssim_fn
|
| 46 |
from skimage import filters
|
| 47 |
arr1 = np.array(img1.resize(size).convert('RGB'), dtype=float)
|
|
|
|
| 67 |
|
| 68 |
def vlm_compare_and_adjust(original_img, projection_img, current_params,
|
| 69 |
iteration, metrics, hf_token):
|
|
|
|
| 70 |
import requests
|
| 71 |
orig_b64 = _image_to_b64(original_img)
|
| 72 |
proj_b64 = _image_to_b64(projection_img)
|
|
|
|
| 94 |
{"type": "text", "text": prompt}
|
| 95 |
]}]
|
| 96 |
|
| 97 |
+
# Verified working VLMs (tested 2026-04-25)
|
| 98 |
models = [
|
| 99 |
+
("meta-llama/Llama-4-Scout-17B-16E-Instruct", "nscale"),
|
| 100 |
+
("moonshotai/Kimi-K2.6", "together"),
|
| 101 |
+
("Qwen/Qwen3.5-9B", "together"),
|
| 102 |
]
|
| 103 |
|
| 104 |
for model_id, provider in models:
|
|
|
|
| 108 |
payload = {"model": model_id, "messages": messages, "max_tokens": 1500, "temperature": 0.1}
|
| 109 |
resp = requests.post(url, headers=headers, json=payload, timeout=120)
|
| 110 |
if resp.status_code == 200:
|
| 111 |
+
msg = resp.json()['choices'][0]['message']
|
| 112 |
+
text = (msg.get('content', '') or '').strip() or (msg.get('reasoning', '') or '').strip()
|
| 113 |
+
if not text: continue
|
| 114 |
json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
|
| 115 |
json_str = json_match.group(1) if json_match else None
|
| 116 |
if not json_str:
|
|
|
|
| 150 |
def refinement_loop(original_image, initial_analysis, generate_fn,
|
| 151 |
max_iterations=8, target_composite=0.82,
|
| 152 |
plateau_threshold=0.005, plateau_patience=3, lr=0.7):
|
|
|
|
| 153 |
hf_token = os.environ.get("HF_TOKEN", "")
|
| 154 |
current_analysis = copy.deepcopy(initial_analysis)
|
| 155 |
best_analysis = copy.deepcopy(initial_analysis)
|