Georg commited on
Commit
e219ce4
·
1 Parent(s): 1fd398f

Add joblib dependency and implement CAD-based & model-free init tabs

Browse files
Dockerfile.base CHANGED
@@ -70,6 +70,7 @@ RUN pip install --no-cache-dir \
70
  timm==0.9.16 \
71
  transformations==2024.6.1 \
72
  pyyaml==6.0.1 \
 
73
  && pip cache purge
74
 
75
  # Note: nvdiffrast will be built in final Dockerfile on HuggingFace (needs GPU)
 
70
  timm==0.9.16 \
71
  transformations==2024.6.1 \
72
  pyyaml==6.0.1 \
73
+ joblib==1.4.0 \
74
  && pip cache purge
75
 
76
  # Note: nvdiffrast will be built in final Dockerfile on HuggingFace (needs GPU)
app.py CHANGED
@@ -167,8 +167,51 @@ pose_estimator = FoundationPoseInference()
167
 
168
 
169
  # Gradio wrapper functions
170
- def gradio_initialize(object_id: str, reference_files: List, fx: float, fy: float, cx: float, cy: float):
171
- """Gradio wrapper for object initialization."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  try:
173
  if not reference_files:
174
  return "Error: No reference images provided"
@@ -185,6 +228,9 @@ def gradio_initialize(object_id: str, reference_files: List, fx: float, fy: floa
185
  if not reference_images:
186
  return "Error: Could not load any reference images"
187
 
 
 
 
188
  # Prepare camera intrinsics
189
  camera_intrinsics = {
190
  "fx": fx,
@@ -193,20 +239,21 @@ def gradio_initialize(object_id: str, reference_files: List, fx: float, fy: floa
193
  "cy": cy
194
  }
195
 
196
- # Register object
197
  success = pose_estimator.register_object(
198
  object_id=object_id,
199
  reference_images=reference_images,
200
- camera_intrinsics=camera_intrinsics
 
201
  )
202
 
203
  if success:
204
- return f"✓ Object '{object_id}' initialized with {len(reference_images)} reference images"
205
  else:
206
  return f"✗ Failed to initialize object '{object_id}'"
207
 
208
  except Exception as e:
209
- logger.error(f"Gradio initialization error: {e}", exc_info=True)
210
  return f"Error: {str(e)}"
211
 
212
 
@@ -290,47 +337,108 @@ with gr.Blocks(title="FoundationPose Inference", theme=gr.themes.Soft()) as demo
290
  # Tab 1: Initialize Object
291
  with gr.Tab("Initialize Object"):
292
  gr.Markdown("""
293
- Upload reference images of your object from different angles (8-20 images recommended).
294
- The model will learn the object's appearance for pose estimation.
295
  """)
296
 
297
- with gr.Row():
298
- with gr.Column():
299
- init_object_id = gr.Textbox(
300
- label="Object ID",
301
- placeholder="e.g., target_cube",
302
- value="target_cube"
303
- )
 
 
304
 
305
- init_ref_files = gr.File(
306
- label="Reference Images",
307
- file_count="multiple",
308
- file_types=["image"]
309
- )
310
-
311
- gr.Markdown("### Camera Intrinsics")
312
  with gr.Row():
313
- init_fx = gr.Number(label="fx (focal length x)", value=500.0)
314
- init_fy = gr.Number(label="fy (focal length y)", value=500.0)
315
- with gr.Row():
316
- init_cx = gr.Number(label="cx (principal point x)", value=320.0)
317
- init_cy = gr.Number(label="cy (principal point y)", value=240.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
- init_button = gr.Button("Initialize Object", variant="primary")
 
 
 
 
 
 
 
320
 
321
- with gr.Column():
322
- init_output = gr.Textbox(
323
- label="Initialization Result",
324
- lines=5,
325
- interactive=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  )
327
 
328
- init_button.click(
329
- fn=gradio_initialize,
330
- inputs=[init_object_id, init_ref_files, init_fx, init_fy, init_cx, init_cy],
331
- outputs=init_output
332
- )
333
-
334
  # Tab 2: Estimate Pose
335
  with gr.Tab("Estimate Pose"):
336
  gr.Markdown("""
 
167
 
168
 
169
  # Gradio wrapper functions
170
+ def gradio_initialize_cad(object_id: str, mesh_file, reference_files: List, fx: float, fy: float, cx: float, cy: float):
171
+ """Gradio wrapper for CAD-based object initialization."""
172
+ try:
173
+ if not mesh_file:
174
+ return "Error: No mesh file provided"
175
+
176
+ # Load reference images (optional for CAD mode)
177
+ reference_images = []
178
+ if reference_files:
179
+ for file in reference_files:
180
+ img = cv2.imread(file.name)
181
+ if img is None:
182
+ continue
183
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
184
+ reference_images.append(img)
185
+
186
+ # Prepare camera intrinsics
187
+ camera_intrinsics = {
188
+ "fx": fx,
189
+ "fy": fy,
190
+ "cx": cx,
191
+ "cy": cy
192
+ }
193
+
194
+ # Register object with mesh
195
+ success = pose_estimator.register_object(
196
+ object_id=object_id,
197
+ reference_images=reference_images if reference_images else [],
198
+ camera_intrinsics=camera_intrinsics,
199
+ mesh_path=mesh_file.name
200
+ )
201
+
202
+ if success:
203
+ ref_info = f" and {len(reference_images)} reference images" if reference_images else ""
204
+ return f"✓ Object '{object_id}' initialized with CAD model{ref_info}"
205
+ else:
206
+ return f"✗ Failed to initialize object '{object_id}'"
207
+
208
+ except Exception as e:
209
+ logger.error(f"CAD initialization error: {e}", exc_info=True)
210
+ return f"Error: {str(e)}"
211
+
212
+
213
+ def gradio_initialize_model_free(object_id: str, reference_files: List, fx: float, fy: float, cx: float, cy: float):
214
+ """Gradio wrapper for model-free object initialization."""
215
  try:
216
  if not reference_files:
217
  return "Error: No reference images provided"
 
228
  if not reference_images:
229
  return "Error: Could not load any reference images"
230
 
231
+ if len(reference_images) < 8:
232
+ return f"Warning: Only {len(reference_images)} images provided. 16-24 recommended for best results."
233
+
234
  # Prepare camera intrinsics
235
  camera_intrinsics = {
236
  "fx": fx,
 
239
  "cy": cy
240
  }
241
 
242
+ # Register object without mesh (model-free)
243
  success = pose_estimator.register_object(
244
  object_id=object_id,
245
  reference_images=reference_images,
246
+ camera_intrinsics=camera_intrinsics,
247
+ mesh_path=None
248
  )
249
 
250
  if success:
251
+ return f"✓ Object '{object_id}' initialized with {len(reference_images)} reference images (model-free mode)"
252
  else:
253
  return f"✗ Failed to initialize object '{object_id}'"
254
 
255
  except Exception as e:
256
+ logger.error(f"Model-free initialization error: {e}", exc_info=True)
257
  return f"Error: {str(e)}"
258
 
259
 
 
337
  # Tab 1: Initialize Object
338
  with gr.Tab("Initialize Object"):
339
  gr.Markdown("""
340
+ Choose the initialization mode based on whether you have a 3D CAD model of your object.
 
341
  """)
342
 
343
+ with gr.Tabs():
344
+ # Sub-tab 1.1: CAD-Based Init
345
+ with gr.Tab("CAD-Based (Model-Based)"):
346
+ gr.Markdown("""
347
+ **Model-Based Mode**: Use this if you have a 3D mesh/CAD model (.obj, .stl, .ply).
348
+ - Upload your 3D mesh file
349
+ - Optionally upload reference images for better initialization
350
+ - More accurate and robust
351
+ """)
352
 
 
 
 
 
 
 
 
353
  with gr.Row():
354
+ with gr.Column():
355
+ cad_object_id = gr.Textbox(
356
+ label="Object ID",
357
+ placeholder="e.g., target_cube",
358
+ value="target_cube"
359
+ )
360
+
361
+ cad_mesh_file = gr.File(
362
+ label="3D Mesh File (.obj, .stl, .ply)",
363
+ file_count="single",
364
+ file_types=[".obj", ".stl", ".ply", ".mesh"]
365
+ )
366
+
367
+ cad_ref_files = gr.File(
368
+ label="Reference Images (Optional)",
369
+ file_count="multiple",
370
+ file_types=["image"]
371
+ )
372
+
373
+ gr.Markdown("### Camera Intrinsics")
374
+ with gr.Row():
375
+ cad_fx = gr.Number(label="fx", value=500.0)
376
+ cad_fy = gr.Number(label="fy", value=500.0)
377
+ with gr.Row():
378
+ cad_cx = gr.Number(label="cx", value=320.0)
379
+ cad_cy = gr.Number(label="cy", value=240.0)
380
+
381
+ cad_init_button = gr.Button("Initialize with CAD", variant="primary")
382
+
383
+ with gr.Column():
384
+ cad_init_output = gr.Textbox(
385
+ label="Initialization Result",
386
+ lines=5,
387
+ interactive=False
388
+ )
389
+
390
+ cad_init_button.click(
391
+ fn=gradio_initialize_cad,
392
+ inputs=[cad_object_id, cad_mesh_file, cad_ref_files, cad_fx, cad_fy, cad_cx, cad_cy],
393
+ outputs=cad_init_output
394
+ )
395
 
396
+ # Sub-tab 1.2: Model-Free Init
397
+ with gr.Tab("Model-Free (Reference-Based)"):
398
+ gr.Markdown("""
399
+ **Model-Free Mode**: Use this if you don't have a 3D model.
400
+ - Upload 16-24 reference images from different viewpoints
401
+ - Works without a 3D model
402
+ - Less accurate than CAD-based but more flexible
403
+ """)
404
 
405
+ with gr.Row():
406
+ with gr.Column():
407
+ free_object_id = gr.Textbox(
408
+ label="Object ID",
409
+ placeholder="e.g., target_cube",
410
+ value="target_cube"
411
+ )
412
+
413
+ free_ref_files = gr.File(
414
+ label="Reference Images (16-24 recommended)",
415
+ file_count="multiple",
416
+ file_types=["image"]
417
+ )
418
+
419
+ gr.Markdown("### Camera Intrinsics")
420
+ with gr.Row():
421
+ free_fx = gr.Number(label="fx", value=500.0)
422
+ free_fy = gr.Number(label="fy", value=500.0)
423
+ with gr.Row():
424
+ free_cx = gr.Number(label="cx", value=320.0)
425
+ free_cy = gr.Number(label="cy", value=240.0)
426
+
427
+ free_init_button = gr.Button("Initialize Model-Free", variant="primary")
428
+
429
+ with gr.Column():
430
+ free_init_output = gr.Textbox(
431
+ label="Initialization Result",
432
+ lines=5,
433
+ interactive=False
434
+ )
435
+
436
+ free_init_button.click(
437
+ fn=gradio_initialize_model_free,
438
+ inputs=[free_object_id, free_ref_files, free_fx, free_fy, free_cx, free_cy],
439
+ outputs=free_init_output
440
  )
441
 
 
 
 
 
 
 
442
  # Tab 2: Estimate Pose
443
  with gr.Tab("Estimate Pose"):
444
  gr.Markdown("""
client.py CHANGED
@@ -5,22 +5,21 @@ This client can be used from the robot-ml training pipeline to call the
5
  FoundationPose inference API hosted on Hugging Face Spaces.
6
  """
7
 
8
- import base64
9
  import json
10
  import logging
11
- from io import BytesIO
12
  from pathlib import Path
13
  from typing import Dict, List, Optional
14
 
15
  import cv2
16
  import numpy as np
17
- import requests
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
 
22
  class FoundationPoseClient:
23
- """Client for FoundationPose API."""
24
 
25
  def __init__(self, api_url: str = "https://gpue-foundationpose.hf.space"):
26
  """Initialize client.
@@ -29,27 +28,26 @@ class FoundationPoseClient:
29
  api_url: Base URL of the FoundationPose Space
30
  """
31
  self.api_url = api_url.rstrip("/")
32
- self.session = requests.Session()
33
- self.session.headers.update({"Content-Type": "application/json"})
 
34
 
35
- def _encode_image(self, image: np.ndarray) -> str:
36
- """Encode image as base64 JPEG.
37
 
38
  Args:
39
  image: RGB image as numpy array
40
 
41
  Returns:
42
- Base64-encoded JPEG string
43
  """
44
  # Convert RGB to BGR for OpenCV
45
  image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
46
 
47
- # Encode as JPEG
48
- _, buffer = cv2.imencode(".jpg", image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 85])
49
-
50
- # Convert to base64
51
- image_b64 = base64.b64encode(buffer).decode("utf-8")
52
- return image_b64
53
 
54
  def initialize(
55
  self,
@@ -62,7 +60,7 @@ class FoundationPoseClient:
62
  Args:
63
  object_id: Unique ID for the object
64
  reference_images: List of RGB images (numpy arrays)
65
- camera_intrinsics: Optional camera parameters
66
 
67
  Returns:
68
  True if successful
@@ -72,39 +70,60 @@ class FoundationPoseClient:
72
  """
73
  logger.info(f"Initializing object '{object_id}' with {len(reference_images)} reference images")
74
 
75
- # Encode images
76
- images_b64 = [self._encode_image(img) for img in reference_images]
77
-
78
- # Prepare request
79
- payload = {
80
- "object_id": object_id,
81
- "reference_images_b64": images_b64,
82
- }
83
-
84
- if camera_intrinsics:
85
- payload["camera_intrinsics"] = json.dumps(camera_intrinsics)
86
-
87
- # Send request
88
  try:
89
- response = self.session.post(
90
- f"{self.api_url}/api/initialize",
91
- json=payload,
92
- timeout=120 # Long timeout for model loading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  )
94
- response.raise_for_status()
95
-
96
- result = response.json()
97
-
98
- if not result.get("success"):
99
- error = result.get("error", "Unknown error")
100
- raise RuntimeError(f"Initialization failed: {error}")
101
 
102
- logger.info(f"Object '{object_id}' initialized successfully")
103
- return True
104
-
105
- except requests.exceptions.RequestException as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  logger.error(f"API request failed: {e}")
107
  raise RuntimeError(f"Failed to initialize object: {e}")
 
 
 
 
 
 
 
108
 
109
  def estimate_pose(
110
  self,
@@ -117,7 +136,7 @@ class FoundationPoseClient:
117
  Args:
118
  object_id: ID of object to detect
119
  query_image: RGB query image as numpy array
120
- camera_intrinsics: Optional camera parameters
121
 
122
  Returns:
123
  List of detected poses:
@@ -134,38 +153,74 @@ class FoundationPoseClient:
134
  Raises:
135
  RuntimeError: If estimation fails
136
  """
137
- # Encode image
138
- image_b64 = self._encode_image(query_image)
139
-
140
- # Prepare request
141
- payload = {
142
- "object_id": object_id,
143
- "query_image_b64": image_b64,
144
- }
145
-
146
- if camera_intrinsics:
147
- payload["camera_intrinsics"] = json.dumps(camera_intrinsics)
148
 
149
- # Send request
150
  try:
151
- response = self.session.post(
152
- f"{self.api_url}/api/estimate",
153
- json=payload,
154
- timeout=30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
- response.raise_for_status()
157
-
158
- result = response.json()
159
-
160
- if not result.get("success"):
161
- error = result.get("error", "Unknown error")
162
- raise RuntimeError(f"Pose estimation failed: {error}")
163
-
164
- return result.get("poses", [])
165
 
166
- except requests.exceptions.RequestException as e:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  logger.error(f"API request failed: {e}")
168
  raise RuntimeError(f"Failed to estimate pose: {e}")
 
 
 
 
 
 
169
 
170
 
171
  def load_reference_images(directory: Path) -> List[np.ndarray]:
 
5
  FoundationPose inference API hosted on Hugging Face Spaces.
6
  """
7
 
 
8
  import json
9
  import logging
10
+ import tempfile
11
  from pathlib import Path
12
  from typing import Dict, List, Optional
13
 
14
  import cv2
15
  import numpy as np
16
+ from gradio_client import Client, handle_file
17
 
18
  logger = logging.getLogger(__name__)
19
 
20
 
21
  class FoundationPoseClient:
22
+ """Client for FoundationPose Gradio API."""
23
 
24
  def __init__(self, api_url: str = "https://gpue-foundationpose.hf.space"):
25
  """Initialize client.
 
28
  api_url: Base URL of the FoundationPose Space
29
  """
30
  self.api_url = api_url.rstrip("/")
31
+ logger.info(f"Initializing Gradio client for {self.api_url}")
32
+ self.client = Client(self.api_url)
33
+ logger.info("Gradio client initialized")
34
 
35
+ def _save_image_temp(self, image: np.ndarray) -> str:
36
+ """Save image to temporary file.
37
 
38
  Args:
39
  image: RGB image as numpy array
40
 
41
  Returns:
42
+ Path to temporary file
43
  """
44
  # Convert RGB to BGR for OpenCV
45
  image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
46
 
47
+ # Save to temp file
48
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jpg")
49
+ cv2.imwrite(temp_file.name, image_bgr, [cv2.IMWRITE_JPEG_QUALITY, 95])
50
+ return temp_file.name
 
 
51
 
52
  def initialize(
53
  self,
 
60
  Args:
61
  object_id: Unique ID for the object
62
  reference_images: List of RGB images (numpy arrays)
63
+ camera_intrinsics: Optional camera parameters (dict with fx, fy, cx, cy)
64
 
65
  Returns:
66
  True if successful
 
70
  """
71
  logger.info(f"Initializing object '{object_id}' with {len(reference_images)} reference images")
72
 
73
+ # Save images to temporary files
74
+ temp_files = []
 
 
 
 
 
 
 
 
 
 
 
75
  try:
76
+ for img in reference_images:
77
+ temp_path = self._save_image_temp(img)
78
+ temp_files.append(temp_path)
79
+
80
+ # Extract camera intrinsics or use defaults
81
+ if camera_intrinsics:
82
+ fx = camera_intrinsics.get("fx", 600.0)
83
+ fy = camera_intrinsics.get("fy", 600.0)
84
+ cx = camera_intrinsics.get("cx", 320.0)
85
+ cy = camera_intrinsics.get("cy", 240.0)
86
+ else:
87
+ fx, fy, cx, cy = 600.0, 600.0, 320.0, 240.0
88
+
89
+ # Call Gradio API
90
+ result = self.client.predict(
91
+ object_id=object_id,
92
+ reference_files=[handle_file(f) for f in temp_files],
93
+ fx=fx,
94
+ fy=fy,
95
+ cx=cx,
96
+ cy=cy,
97
+ api_name="/gradio_initialize"
98
  )
 
 
 
 
 
 
 
99
 
100
+ # Parse result - Gradio returns plain text
101
+ logger.info(f"API result: {result}")
102
+ if isinstance(result, str):
103
+ # Check if result indicates success (contains ✓ or "initialized")
104
+ if "✓" in result or "initialized" in result.lower():
105
+ logger.info("Initialization successful")
106
+ return True
107
+ elif "Error" in result or "error" in result:
108
+ raise RuntimeError(f"Initialization failed: {result}")
109
+ else:
110
+ # Assume success if no error indication
111
+ return True
112
+ else:
113
+ raise RuntimeError(f"Unexpected result type: {type(result)}")
114
+
115
+ except RuntimeError:
116
+ raise
117
+ except Exception as e:
118
  logger.error(f"API request failed: {e}")
119
  raise RuntimeError(f"Failed to initialize object: {e}")
120
+ finally:
121
+ # Clean up temp files
122
+ for temp_file in temp_files:
123
+ try:
124
+ Path(temp_file).unlink()
125
+ except Exception:
126
+ pass
127
 
128
  def estimate_pose(
129
  self,
 
136
  Args:
137
  object_id: ID of object to detect
138
  query_image: RGB query image as numpy array
139
+ camera_intrinsics: Optional camera parameters (dict with fx, fy, cx, cy)
140
 
141
  Returns:
142
  List of detected poses:
 
153
  Raises:
154
  RuntimeError: If estimation fails
155
  """
156
+ # Save query image to temp file
157
+ temp_file = self._save_image_temp(query_image)
 
 
 
 
 
 
 
 
 
158
 
 
159
  try:
160
+ # Extract camera intrinsics or use defaults
161
+ if camera_intrinsics:
162
+ fx = camera_intrinsics.get("fx", 600.0)
163
+ fy = camera_intrinsics.get("fy", 600.0)
164
+ cx = camera_intrinsics.get("cx", 320.0)
165
+ cy = camera_intrinsics.get("cy", 240.0)
166
+ else:
167
+ fx, fy, cx, cy = 600.0, 600.0, 320.0, 240.0
168
+
169
+ # Call Gradio API
170
+ result = self.client.predict(
171
+ object_id=object_id,
172
+ query_image=handle_file(temp_file),
173
+ fx=fx,
174
+ fy=fy,
175
+ cx=cx,
176
+ cy=cy,
177
+ api_name="/gradio_estimate"
178
  )
 
 
 
 
 
 
 
 
 
179
 
180
+ # Parse result - Gradio may return tuple (text, image) or just text
181
+ logger.info(f"API result type: {type(result)}")
182
+
183
+ # If tuple, take first element (text output)
184
+ if isinstance(result, tuple):
185
+ result = result[0]
186
+
187
+ if isinstance(result, str):
188
+ logger.info(f"API result: {result}")
189
+
190
+ # Check for errors
191
+ if "Error" in result or "not initialized" in result:
192
+ raise RuntimeError(f"Pose estimation failed: {result}")
193
+
194
+ # Try to parse as JSON (in case app.py returns JSON string)
195
+ try:
196
+ result_dict = json.loads(result)
197
+ if isinstance(result_dict, dict) and "poses" in result_dict:
198
+ return result_dict["poses"]
199
+ except (json.JSONDecodeError, ValueError):
200
+ pass
201
+
202
+ # Check if the result indicates no poses detected
203
+ if "No poses detected" in result or "⚠" in result:
204
+ logger.info("No poses detected in query image")
205
+ return []
206
+
207
+ # For now, return empty list with a warning
208
+ logger.warning(f"Could not parse pose from result: {result}")
209
+ return []
210
+ else:
211
+ raise RuntimeError(f"Unexpected result type: {type(result)}")
212
+
213
+ except RuntimeError:
214
+ raise
215
+ except Exception as e:
216
  logger.error(f"API request failed: {e}")
217
  raise RuntimeError(f"Failed to estimate pose: {e}")
218
+ finally:
219
+ # Clean up temp file
220
+ try:
221
+ Path(temp_file).unlink()
222
+ except Exception:
223
+ pass
224
 
225
 
226
  def load_reference_images(directory: Path) -> List[np.ndarray]:
tests/README.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FoundationPose Tests
2
+
3
+ This directory contains test scripts for the FoundationPose estimator.
4
+
5
+ ## Test Data
6
+
7
+ Reference images for test objects are stored in `reference/target_cube/`.
8
+
9
+ ## Running Tests
10
+
11
+ ### Test Estimator Locally
12
+
13
+ ```bash
14
+ cd /path/to/foundationpose
15
+ python tests/test_estimator.py
16
+ ```
17
+
18
+ ### Test Against HuggingFace Space
19
+
20
+ Use the client script to test the deployed API:
21
+
22
+ ```bash
23
+ python client.py
24
+ ```
25
+
26
+ ## Test Coverage
27
+
28
+ **test_estimator.py** tests:
29
+ 1. Estimator initialization
30
+ 2. Object registration with reference images
31
+ 3. Pose estimation on query images
32
+
33
+ The test uses images from `reference/target_cube/` to register an object, then randomly selects one image to test pose estimation.
tests/reference/target_cube/image_001.jpg ADDED
tests/reference/target_cube/image_002.jpg ADDED
tests/reference/target_cube/image_003.jpg ADDED
tests/reference/target_cube/image_004.jpg ADDED
tests/reference/target_cube/image_005.jpg ADDED
tests/reference/target_cube/image_006.jpg ADDED
tests/reference/target_cube/image_007.jpg ADDED
tests/reference/target_cube/image_008.jpg ADDED
tests/reference/target_cube/image_009.jpg ADDED
tests/reference/target_cube/image_010.jpg ADDED
tests/reference/target_cube/image_011.jpg ADDED
tests/reference/target_cube/image_012.jpg ADDED
tests/reference/target_cube/image_013.jpg ADDED
tests/reference/target_cube/image_014.jpg ADDED
tests/reference/target_cube/image_015.jpg ADDED
tests/test_estimator.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test script for FoundationPose HuggingFace API.
3
+
4
+ This test verifies that the API can:
5
+ 1. Load reference images
6
+ 2. Initialize an object with reference images
7
+ 3. Estimate pose from a query image
8
+ """
9
+
10
+ import sys
11
+ from pathlib import Path
12
+ import random
13
+ import cv2
14
+
15
+ # Add parent directory to path to import client
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
18
+ from client import FoundationPoseClient
19
+
20
+
21
+ def load_reference_images(reference_dir: Path):
22
+ """Load all reference images from directory."""
23
+ image_files = sorted(reference_dir.glob("*.jpg"))
24
+ images = []
25
+
26
+ for img_path in image_files:
27
+ # Use cv2 to load images (same as client.py)
28
+ img = cv2.imread(str(img_path))
29
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
30
+ images.append(img)
31
+
32
+ return images, image_files
33
+
34
+
35
+ def test_client_initialization():
36
+ """Test that API client initializes without errors."""
37
+ print("=" * 60)
38
+ print("Test 1: API Client Initialization")
39
+ print("=" * 60)
40
+
41
+ try:
42
+ client = FoundationPoseClient(api_url="https://gpue-foundationpose.hf.space")
43
+ print("✓ API client initialized successfully")
44
+ return client
45
+ except Exception as e:
46
+ print(f"✗ API client initialization failed: {e}")
47
+ return None
48
+
49
+
50
+ def test_object_initialization(client, reference_images):
51
+ """Test object initialization with reference images via API."""
52
+ print("\n" + "=" * 60)
53
+ print("Test 2: Object Initialization via API")
54
+ print("=" * 60)
55
+
56
+ # Define camera intrinsics (typical values for RGB camera)
57
+ camera_intrinsics = {
58
+ "fx": 600.0,
59
+ "fy": 600.0,
60
+ "cx": 320.0,
61
+ "cy": 240.0
62
+ }
63
+
64
+ try:
65
+ success = client.initialize(
66
+ object_id="target_cube",
67
+ reference_images=reference_images,
68
+ camera_intrinsics=camera_intrinsics
69
+ )
70
+
71
+ if success:
72
+ print(f"✓ Object initialized successfully with {len(reference_images)} reference images")
73
+ return True
74
+ else:
75
+ print("✗ Object initialization failed")
76
+ return False
77
+ except Exception as e:
78
+ print(f"✗ Object initialization failed with exception: {e}")
79
+ import traceback
80
+ traceback.print_exc()
81
+ return False
82
+
83
+
84
+ def test_pose_estimation(client, query_image, query_name):
85
+ """Test pose estimation on a query image via API."""
86
+ print("\n" + "=" * 60)
87
+ print("Test 3: Pose Estimation via API")
88
+ print("=" * 60)
89
+ print(f"Query image: {query_name}")
90
+
91
+ # Define camera intrinsics (same as initialization)
92
+ camera_intrinsics = {
93
+ "fx": 600.0,
94
+ "fy": 600.0,
95
+ "cx": 320.0,
96
+ "cy": 240.0
97
+ }
98
+
99
+ try:
100
+ poses = client.estimate_pose(
101
+ object_id="target_cube",
102
+ query_image=query_image,
103
+ camera_intrinsics=camera_intrinsics
104
+ )
105
+
106
+ if poses and len(poses) > 0:
107
+ print(f"✓ Pose estimation completed successfully (detected {len(poses)} object(s))")
108
+
109
+ for i, pose in enumerate(poses):
110
+ print(f"\nDetected Object {i+1}:")
111
+ print(f" Position: x={pose['position']['x']:.3f}, "
112
+ f"y={pose['position']['y']:.3f}, "
113
+ f"z={pose['position']['z']:.3f}")
114
+ print(f" Orientation (quaternion): w={pose['orientation']['w']:.3f}, "
115
+ f"x={pose['orientation']['x']:.3f}, "
116
+ f"y={pose['orientation']['y']:.3f}, "
117
+ f"z={pose['orientation']['z']:.3f}")
118
+ print(f" Confidence: {pose['confidence']:.3f}")
119
+
120
+ return True
121
+ else:
122
+ print("✗ Pose estimation returned no detections")
123
+ return False
124
+ except Exception as e:
125
+ print(f"✗ Pose estimation failed with exception: {e}")
126
+ import traceback
127
+ traceback.print_exc()
128
+ return False
129
+
130
+
131
+ def main():
132
+ """Run all tests."""
133
+ print("\n" + "=" * 60)
134
+ print("FoundationPose HuggingFace API Test Suite")
135
+ print("=" * 60)
136
+
137
+ # Setup paths
138
+ test_dir = Path(__file__).parent
139
+ reference_dir = test_dir / "reference" / "target_cube"
140
+
141
+ if not reference_dir.exists():
142
+ print(f"✗ Reference directory not found: {reference_dir}")
143
+ return
144
+
145
+ # Load reference images
146
+ print(f"\nLoading reference images from: {reference_dir}")
147
+ reference_images, image_files = load_reference_images(reference_dir)
148
+ print(f"✓ Loaded {len(reference_images)} reference images")
149
+
150
+ # Test 1: Initialize API client
151
+ client = test_client_initialization()
152
+ if client is None:
153
+ print("\n" + "=" * 60)
154
+ print("TESTS ABORTED: API client initialization failed")
155
+ print("=" * 60)
156
+ return
157
+
158
+ # Test 2: Initialize object via API
159
+ success = test_object_initialization(client, reference_images)
160
+ if not success:
161
+ print("\n" + "=" * 60)
162
+ print("TESTS ABORTED: Object initialization failed")
163
+ print("=" * 60)
164
+ return
165
+
166
+ # Test 3: Estimate pose on a random reference image
167
+ random_idx = random.randint(0, len(reference_images) - 1)
168
+ query_image = reference_images[random_idx]
169
+ query_name = image_files[random_idx].name
170
+
171
+ success = test_pose_estimation(client, query_image, query_name)
172
+
173
+ # Print final results
174
+ print("\n" + "=" * 60)
175
+ if success:
176
+ print("ALL TESTS PASSED ✓")
177
+ else:
178
+ print("SOME TESTS FAILED ✗")
179
+ print("=" * 60)
180
+
181
+
182
+ if __name__ == "__main__":
183
+ main()