Spaces:
Build error
Build error
Robotics committed on
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,25 +2,29 @@ import gradio as gr
|
|
| 2 |
import numpy as np, imageio, tempfile, os
|
| 3 |
import pybullet as p
|
| 4 |
import openai
|
| 5 |
-
from openai import AuthenticationError
|
| 6 |
from contactvla.env import BoxPushEnv
|
| 7 |
from contactvla.mppi import MPPI
|
| 8 |
from contactvla.llm_feedback import LLMFeedback
|
| 9 |
|
|
|
|
| 10 |
def run_demo(api_key):
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
|
|
|
|
| 18 |
api_key = api_key.strip()
|
| 19 |
if not api_key:
|
| 20 |
return None, "⚠️ Please enter a valid OpenAI API key."
|
| 21 |
|
| 22 |
try:
|
| 23 |
openai.api_key = api_key
|
|
|
|
| 24 |
openai.chat.completions.create(
|
| 25 |
model="gpt-4o-mini",
|
| 26 |
messages=[{"role": "user", "content": "ping"}],
|
|
@@ -30,7 +34,8 @@ def run_demo(api_key):
|
|
| 30 |
return None, "⚠️ Invalid OpenAI API key. Please check and try again."
|
| 31 |
except Exception as e:
|
| 32 |
return None, f"⚠️ Could not reach OpenAI API: {e}"
|
| 33 |
-
|
|
|
|
| 34 |
cid = p.connect(p.DIRECT)
|
| 35 |
env = BoxPushEnv(cid)
|
| 36 |
mppi = MPPI(env)
|
|
@@ -44,29 +49,55 @@ def run_demo(api_key):
|
|
| 44 |
env.reset()
|
| 45 |
history = []
|
| 46 |
|
| 47 |
-
|
|
|
|
| 48 |
u = mppi.compute_control()
|
| 49 |
box_pos, ee_pos = env.step(u)
|
| 50 |
cost = env.state_cost()
|
| 51 |
-
history.append({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
-
#
|
| 54 |
_, _, rgb, _, _ = p.getCameraImage(cam_w, cam_h, view, proj, renderer=p.ER_TINY_RENDERER)
|
| 55 |
frame = np.array(rgb, dtype=np.uint8).reshape(cam_h, cam_w, 4)[:, :, :3]
|
| 56 |
frames.append(frame)
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
|
|
|
| 61 |
tmp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
|
| 62 |
imageio.mimsave(tmp_path, frames, fps=10)
|
| 63 |
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
|
|
|
| 66 |
gr.Interface(
|
| 67 |
fn=run_demo,
|
| 68 |
-
inputs=gr.Textbox(
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
).launch()
|
|
|
|
| 2 |
import numpy as np, imageio, tempfile, os
|
| 3 |
import pybullet as p
|
| 4 |
import openai
|
| 5 |
+
from openai import AuthenticationError
|
| 6 |
from contactvla.env import BoxPushEnv
|
| 7 |
from contactvla.mppi import MPPI
|
| 8 |
from contactvla.llm_feedback import LLMFeedback
|
| 9 |
|
| 10 |
+
|
| 11 |
def run_demo(api_key):
|
| 12 |
+
"""Main demo entrypoint for Hugging Face Space"""
|
| 13 |
|
| 14 |
+
# --- 1. PyBullet cleanup: disconnect any previous physics servers ---
|
| 15 |
+
try:
|
| 16 |
+
p.disconnect() # closes any active simulation session
|
| 17 |
+
except Exception:
|
| 18 |
+
pass
|
| 19 |
|
| 20 |
+
# --- 2. Validate OpenAI API key ---
|
| 21 |
api_key = api_key.strip()
|
| 22 |
if not api_key:
|
| 23 |
return None, "⚠️ Please enter a valid OpenAI API key."
|
| 24 |
|
| 25 |
try:
|
| 26 |
openai.api_key = api_key
|
| 27 |
+
# Quick ping test for key validation
|
| 28 |
openai.chat.completions.create(
|
| 29 |
model="gpt-4o-mini",
|
| 30 |
messages=[{"role": "user", "content": "ping"}],
|
|
|
|
| 34 |
return None, "⚠️ Invalid OpenAI API key. Please check and try again."
|
| 35 |
except Exception as e:
|
| 36 |
return None, f"⚠️ Could not reach OpenAI API: {e}"
|
| 37 |
+
|
| 38 |
+
# --- 3. Initialize PyBullet env ---
|
| 39 |
cid = p.connect(p.DIRECT)
|
| 40 |
env = BoxPushEnv(cid)
|
| 41 |
mppi = MPPI(env)
|
|
|
|
| 49 |
env.reset()
|
| 50 |
history = []
|
| 51 |
|
| 52 |
+
# --- 4. Run MPPI controller ---
|
| 53 |
+
for step in range(30):
|
| 54 |
u = mppi.compute_control()
|
| 55 |
box_pos, ee_pos = env.step(u)
|
| 56 |
cost = env.state_cost()
|
| 57 |
+
history.append({
|
| 58 |
+
"step": step,
|
| 59 |
+
"box_pos": box_pos.tolist(),
|
| 60 |
+
"ee_pos": ee_pos.tolist(),
|
| 61 |
+
"cost": cost
|
| 62 |
+
})
|
| 63 |
|
| 64 |
+
# Off-screen rendering (TinyRenderer)
|
| 65 |
_, _, rgb, _, _ = p.getCameraImage(cam_w, cam_h, view, proj, renderer=p.ER_TINY_RENDERER)
|
| 66 |
frame = np.array(rgb, dtype=np.uint8).reshape(cam_h, cam_w, 4)[:, :, :3]
|
| 67 |
frames.append(frame)
|
| 68 |
|
| 69 |
+
# --- 5. Query LLM for improved cost function ---
|
| 70 |
+
try:
|
| 71 |
+
code, expl = llm.ask_state_cost_fn("Flip the box using the wall", history, env)
|
| 72 |
+
env.update_state_cost(code)
|
| 73 |
+
except AuthenticationError:
|
| 74 |
+
return None, "⚠️ Your API key was rejected during LLM call. Please try again."
|
| 75 |
+
except Exception as e:
|
| 76 |
+
return None, f"⚠️ Error while querying LLM: {e}"
|
| 77 |
|
| 78 |
+
# --- 6. Render final video ---
|
| 79 |
tmp_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
|
| 80 |
imageio.mimsave(tmp_path, frames, fps=10)
|
| 81 |
|
| 82 |
+
# --- 7. Disconnect PyBullet to clean up ---
|
| 83 |
+
p.disconnect()
|
| 84 |
+
|
| 85 |
+
# --- 8. Return output ---
|
| 86 |
+
return tmp_path, f"✅ LLM updated cost:\n{code}\n\nExplanation:\n{expl}"
|
| 87 |
+
|
| 88 |
|
| 89 |
+
# --- Gradio UI ---
|
| 90 |
# Single text input for the user's OpenAI key; rendered as a password field.
api_key_box = gr.Textbox(
    label="Enter your OpenAI API Key (without quotation marks)",
    type="password",
    placeholder="sk-...",
)

# Blurb shown under the title; the two literals are concatenated into one string.
demo_description = (
    "This demo uses a PyBullet simulation of a Panda robot pushing a box against a wall. "
    "After several steps, an LLM rewrites the cost function guiding the control policy."
)

# Wire run_demo (one input -> video + text outputs) into a Gradio interface.
demo = gr.Interface(
    fn=run_demo,
    inputs=api_key_box,
    outputs=["video", "text"],
    title="ContactVLA: LLM-guided Box Flipping",
    description=demo_description,
)

# Start the web app when this script is executed.
demo.launch()
|