dreamlessx commited on
Commit
d847b3c
·
verified ·
1 Parent(s): e316420

Upload landmarkdiff/api_client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/api_client.py +242 -0
landmarkdiff/api_client.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python client for the LandmarkDiff REST API.
2
+
3
+ Provides a clean interface for interacting with the FastAPI server,
4
+ handling image encoding/decoding, error handling, and session management.
5
+
6
+ Usage:
7
+ from landmarkdiff.api_client import LandmarkDiffClient
8
+
9
+ client = LandmarkDiffClient("http://localhost:8000")
10
+
11
+ # Single prediction
12
+ result = client.predict("patient.png", procedure="rhinoplasty", intensity=65)
13
+ result.save("output.png")
14
+
15
+ # Face analysis
16
+ analysis = client.analyze("patient.png")
17
+ print(f"Fitzpatrick type: {analysis['fitzpatrick_type']}")
18
+
19
+ # Batch processing
20
+ results = client.batch_predict(
21
+ ["patient1.png", "patient2.png"],
22
+ procedure="blepharoplasty",
23
+ )
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import base64
29
+ import io
30
+ from dataclasses import dataclass, field
31
+ from pathlib import Path
32
+ from typing import Any
33
+
34
+ import cv2
35
+ import numpy as np
36
+
37
+
38
+ @dataclass
39
+ class PredictionResult:
40
+ """Result from a single prediction."""
41
+
42
+ output_image: np.ndarray
43
+ procedure: str
44
+ intensity: float
45
+ confidence: float = 0.0
46
+ landmarks_before: list | None = None
47
+ landmarks_after: list | None = None
48
+ metrics: dict[str, float] = field(default_factory=dict)
49
+ metadata: dict[str, Any] = field(default_factory=dict)
50
+
51
+ def save(self, path: str | Path, fmt: str = ".png") -> None:
52
+ """Save the output image to a file."""
53
+ cv2.imwrite(str(path), self.output_image)
54
+
55
+ def show(self) -> None:
56
+ """Display the output image (requires GUI)."""
57
+ cv2.imshow("LandmarkDiff Prediction", self.output_image)
58
+ cv2.waitKey(0)
59
+ cv2.destroyAllWindows()
60
+
61
+
62
+ class LandmarkDiffClient:
63
+ """Client for the LandmarkDiff REST API.
64
+
65
+ Args:
66
+ base_url: Server URL (e.g. "http://localhost:8000").
67
+ timeout: Request timeout in seconds.
68
+ """
69
+
70
+ def __init__(self, base_url: str = "http://localhost:8000", timeout: float = 60.0) -> None:
71
+ self.base_url = base_url.rstrip("/")
72
+ self.timeout = timeout
73
+ self._session = None
74
+
75
+ def _get_session(self):
76
+ """Lazy-initialize requests session."""
77
+ if self._session is None:
78
+ try:
79
+ import requests
80
+ except ImportError:
81
+ raise ImportError("requests required. Install with: pip install requests")
82
+ self._session = requests.Session()
83
+ self._session.timeout = self.timeout
84
+ return self._session
85
+
86
+ def _read_image(self, image_path: str | Path) -> bytes:
87
+ """Read image file as bytes."""
88
+ path = Path(image_path)
89
+ if not path.exists():
90
+ raise FileNotFoundError(f"Image not found: {path}")
91
+ return path.read_bytes()
92
+
93
+ def _decode_base64_image(self, b64_string: str) -> np.ndarray:
94
+ """Decode a base64-encoded image to numpy array."""
95
+ img_bytes = base64.b64decode(b64_string)
96
+ arr = np.frombuffer(img_bytes, np.uint8)
97
+ img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
98
+ if img is None:
99
+ raise ValueError("Failed to decode base64 image")
100
+ return img
101
+
102
+ # ------------------------------------------------------------------
103
+ # API methods
104
+ # ------------------------------------------------------------------
105
+
106
+ def health(self) -> dict[str, Any]:
107
+ """Check server health.
108
+
109
+ Returns:
110
+ Dict with status and version info.
111
+ """
112
+ session = self._get_session()
113
+ resp = session.get(f"{self.base_url}/health")
114
+ resp.raise_for_status()
115
+ return resp.json()
116
+
117
+ def procedures(self) -> list[str]:
118
+ """List available surgical procedures.
119
+
120
+ Returns:
121
+ List of procedure names.
122
+ """
123
+ session = self._get_session()
124
+ resp = session.get(f"{self.base_url}/procedures")
125
+ resp.raise_for_status()
126
+ return resp.json().get("procedures", [])
127
+
128
+ def predict(
129
+ self,
130
+ image_path: str | Path,
131
+ procedure: str = "rhinoplasty",
132
+ intensity: float = 65.0,
133
+ seed: int = 42,
134
+ ) -> PredictionResult:
135
+ """Run surgical outcome prediction.
136
+
137
+ Args:
138
+ image_path: Path to input face image.
139
+ procedure: Surgical procedure type.
140
+ intensity: Intensity of the modification (0-100).
141
+ seed: Random seed for reproducibility.
142
+
143
+ Returns:
144
+ PredictionResult with output image and metadata.
145
+ """
146
+ session = self._get_session()
147
+ image_bytes = self._read_image(image_path)
148
+
149
+ files = {"image": ("image.png", image_bytes, "image/png")}
150
+ data = {
151
+ "procedure": procedure,
152
+ "intensity": str(intensity),
153
+ "seed": str(seed),
154
+ }
155
+
156
+ resp = session.post(f"{self.base_url}/predict", files=files, data=data)
157
+ resp.raise_for_status()
158
+ result = resp.json()
159
+
160
+ # Decode output image
161
+ output_img = self._decode_base64_image(result["output_image"])
162
+
163
+ return PredictionResult(
164
+ output_image=output_img,
165
+ procedure=procedure,
166
+ intensity=intensity,
167
+ confidence=result.get("confidence", 0.0),
168
+ metrics=result.get("metrics", {}),
169
+ metadata=result.get("metadata", {}),
170
+ )
171
+
172
+ def analyze(self, image_path: str | Path) -> dict[str, Any]:
173
+ """Analyze a face image without generating a prediction.
174
+
175
+ Returns face landmarks, Fitzpatrick type, pose estimation, etc.
176
+
177
+ Args:
178
+ image_path: Path to input face image.
179
+
180
+ Returns:
181
+ Dict with analysis results.
182
+ """
183
+ session = self._get_session()
184
+ image_bytes = self._read_image(image_path)
185
+
186
+ files = {"image": ("image.png", image_bytes, "image/png")}
187
+ resp = session.post(f"{self.base_url}/analyze", files=files)
188
+ resp.raise_for_status()
189
+ return resp.json()
190
+
191
+ def batch_predict(
192
+ self,
193
+ image_paths: list[str | Path],
194
+ procedure: str = "rhinoplasty",
195
+ intensity: float = 65.0,
196
+ seed: int = 42,
197
+ ) -> list[PredictionResult]:
198
+ """Run batch prediction on multiple images.
199
+
200
+ Args:
201
+ image_paths: List of image file paths.
202
+ procedure: Procedure to apply to all images.
203
+ intensity: Intensity for all images.
204
+ seed: Base random seed.
205
+
206
+ Returns:
207
+ List of PredictionResult objects.
208
+ """
209
+ results = []
210
+ for i, path in enumerate(image_paths):
211
+ try:
212
+ result = self.predict(
213
+ path,
214
+ procedure=procedure,
215
+ intensity=intensity,
216
+ seed=seed + i,
217
+ )
218
+ results.append(result)
219
+ except Exception as e:
220
+ # Create a failed result
221
+ results.append(PredictionResult(
222
+ output_image=np.zeros((512, 512, 3), dtype=np.uint8),
223
+ procedure=procedure,
224
+ intensity=intensity,
225
+ metadata={"error": str(e), "path": str(path)},
226
+ ))
227
+ return results
228
+
229
+ def close(self) -> None:
230
+ """Close the HTTP session."""
231
+ if self._session is not None:
232
+ self._session.close()
233
+ self._session = None
234
+
235
+ def __enter__(self):
236
+ return self
237
+
238
+ def __exit__(self, *args):
239
+ self.close()
240
+
241
+ def __repr__(self) -> str:
242
+ return f"LandmarkDiffClient(base_url='{self.base_url}')"