singhanshuman committed on
Commit 8c40084 · verified · 1 Parent(s): ab692ae

Upload app.py with huggingface_hub

Files changed (1)
app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
"""
HuggingFace Spaces – Gradio demo for lsr-lang.

Models are downloaded from singhanshuman/lsr-lang-models on first run.
"""

import io
import os
from pathlib import Path

import gradio as gr
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

# ── device ──────────────────────────────────────────────────────────────────
DEVICE = torch.device("cpu")
Z_DIM = 4
ACTION_DIM = 4
MODEL_REPO = "singhanshuman/lsr-lang-models"

# ── download weights ─────────────────────────────────────────────────────────
def _get(filename: str) -> str:
    return hf_hub_download(repo_id=MODEL_REPO, filename=filename)
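
# Note: hf_hub_download writes each file to the local Hugging Face cache, so
# repeated launches reuse the downloaded weights instead of fetching them again.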

# ── imports after sys.path is stable ─────────────────────────────────────────
from models.apm import APM
from models.clip_vae import ClipVAE
from models.lsr import LSR
from models.vae import VAE


def _load(ckpt: str, model: torch.nn.Module) -> torch.nn.Module:
    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    model.eval()
    return model


print("Loading models from HF Hub...")
vae = _load(_get("vae_best.pt"), VAE(z_dim=Z_DIM))
clip_vae = _load(_get("clip_vae_best.pt"), ClipVAE(z_dim=Z_DIM))
apm = _load(_get("apm_best.pt"), APM(z_dim=Z_DIM, action_dim=ACTION_DIM))
text_emb = torch.tensor(np.load(_get("text_emb.npy")), dtype=torch.float32)
lsr = LSR.load(_get("lsr_graph.pkl"))
print("Ready.")


# ── helpers ───────────────────────────────────────────────────────────────────
def _pil_to_tensor(img_pil: Image.Image) -> torch.Tensor:
    img = img_pil.convert("RGB").resize((64, 64))
    return torch.tensor(np.array(img), dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0


def _tensor_to_pil(t: torch.Tensor) -> Image.Image:
    arr = (t.squeeze().permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)
    return Image.fromarray(arr)


def _encode_img(img_pil: Image.Image, use_clip: bool, lang: str):
    x = _pil_to_tensor(img_pil)
    with torch.no_grad():
        if use_clip:
            t_emb = text_emb
            if lang.strip():
                from models.clip_vae import encode_text
                t_emb = encode_text([lang], DEVICE)
            t_emb = t_emb.expand(1, -1)
            _, mu, _ = clip_vae.encode(x, t_emb)
        else:
            _, mu, _ = vae.encode(x)
    return mu.squeeze().numpy()


def _decode(z_np: np.ndarray, use_clip: bool) -> Image.Image:
    z = torch.tensor(z_np, dtype=torch.float32).unsqueeze(0)
    model = clip_vae if use_clip else vae
    with torch.no_grad():
        return _tensor_to_pil(model.decode(z))


def _latent_plot(latent_path: np.ndarray) -> Image.Image:
    if lsr.latents is None:
        return None
    all_z = lsr.latents
    if all_z.shape[1] > 2:
        from sklearn.decomposition import PCA
        pca = PCA(n_components=2).fit(all_z)
        coords = pca.transform(all_z)
        path_coords = pca.transform(latent_path)
    else:
        coords = all_z
        path_coords = latent_path

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.scatter(coords[:, 0], coords[:, 1], s=4, alpha=0.35, c="lightgray")
    ax.plot(path_coords[:, 0], path_coords[:, 1], "r-o", ms=7, lw=2, label="plan")
    ax.scatter(*path_coords[0], c="green", s=90, zorder=5, label="start")
    ax.scatter(*path_coords[-1], c="blue", s=90, zorder=5, label="goal")
    ax.legend(fontsize=9); ax.grid(True, alpha=0.25)
    ax.set_title("Latent space + planned path")
    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=120)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()


# ── main inference ────────────────────────────────────────────────────────────
def run_plan(start_img, goal_img, language_goal, model_choice):
    use_clip = model_choice == "CLIP-VAE"
    if start_img is None or goal_img is None:
        return [], None, "Provide both a start and a goal image."

    z_start = _encode_img(Image.fromarray(start_img), use_clip, language_goal)
    z_goal = _encode_img(Image.fromarray(goal_img), use_clip, language_goal)

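    # lsr.plan returns None when the roadmap contains no path between the two
    # latents; otherwise it yields a (node_path, latent_path) pair.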
    result = lsr.plan(z_start, z_goal)
    if result is None:
        return [], None, "No path found in the latent roadmap."

    path, latent_path = result

    plan_imgs = [Image.fromarray(start_img).resize((64, 64))]
    for z_np in latent_path:
        plan_imgs.append(_decode(z_np, use_clip))
    plan_imgs.append(Image.fromarray(goal_img).resize((64, 64)))

    action_lines = []
    with torch.no_grad():
        for i in range(len(latent_path) - 1):
            z_i = torch.tensor(latent_path[i], dtype=torch.float32).unsqueeze(0)
            z_j = torch.tensor(latent_path[i + 1], dtype=torch.float32).unsqueeze(0)
            a = apm(z_i, z_j).squeeze().tolist()
            vals = a if isinstance(a, list) else [a]
            action_lines.append(f"Step {i+1}: [{', '.join(f'{v:.3f}' for v in vals)}]")

    info = (
        f"Model : {'CLIP-VAE' if use_clip else 'Baseline VAE'}\n"
        f"Path : {len(path)} nodes\n\n"
        "Predicted actions:\n" + "\n".join(action_lines)
    )
    return plan_imgs, _latent_plot(latent_path), info


# ── UI ────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="lsr-lang – Visual Action Planning") as demo:
    gr.Markdown(
        "## lsr-lang – Latent Space Roadmap with Language Conditioning\n"
        "Upload **start** and **goal** images of a box-stacking scene. "
        "The model plans a visual path through latent space and predicts robot actions."
    )
    with gr.Row():
        with gr.Column(scale=1):
            start_img = gr.Image(label="Start image", type="numpy")
            goal_img = gr.Image(label="Goal image", type="numpy")
            lang_goal = gr.Textbox(
                label="Language goal (optional, CLIP-VAE only)",
                placeholder="stack the red box on the blue box",
            )
            model_radio = gr.Radio(
                ["Baseline VAE", "CLIP-VAE"], value="Baseline VAE", label="Model"
            )
            run_btn = gr.Button("Plan →", variant="primary")
        with gr.Column(scale=2):
            gallery = gr.Gallery(label="Visual plan", columns=10, height="auto")
            latent_plot = gr.Image(label="Latent space + path")
            info_box = gr.Textbox(label="Actions", lines=8)

    run_btn.click(
        fn=run_plan,
        inputs=[start_img, goal_img, lang_goal, model_radio],
        outputs=[gallery, latent_plot, info_box],
    )

    gr.Markdown(
        "**Code:** [github.com/anshuman-dev/lsr-lang](https://github.com/anshuman-dev/lsr-lang) · "
        "**Paper:** [LSR-v2 (IEEE T-RO 2023)](https://arxiv.org/abs/2103.02554)"
    )

if __name__ == "__main__":
    demo.launch()
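
The click handler above is a thin wrapper around run_plan, so the same pipeline can be exercised without the UI. A minimal sketch, assuming app.py is importable from the working directory; the image file names are placeholders, not part of this commit:

# hypothetical_local_test.py
import numpy as np
from PIL import Image

from app import run_plan  # importing app builds the Blocks UI but does not launch it

start = np.array(Image.open("start.png").convert("RGB"))  # placeholder paths
goal = np.array(Image.open("goal.png").convert("RGB"))

plan_imgs, roadmap_plot, info = run_plan(start, goal, "stack the red box on the blue box", "CLIP-VAE")
print(info)  # model name, path length, and per-step predicted actions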