bdck commited on
Commit
09d788a
·
verified ·
1 Parent(s): ac25220

Upload learn_region_grow/io.py

Browse files
Files changed (1) hide show
  1. learn_region_grow/io.py +279 -0
learn_region_grow/io.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """I/O utilities for PLY, PCD, and NumPy point clouds."""
2
+
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from typing import Tuple, Optional
6
+
7
+
8
+ def load_ply(ply_path: str) -> Tuple[np.ndarray, np.ndarray]:
9
+ """
10
+ Load an ASCII PLY file and return points (N, 6: xyzrgb) and optionally normals.
11
+
12
+ Only supports 'vertex' elements with x,y,z and optionally r,g,b,nx,ny,nz.
13
+ Returns a (N, 3) xyz array and a (N, 6+) feature array (xyz + rgb + ...).
14
+ """
15
+ path = Path(ply_path)
16
+ if not path.is_file():
17
+ raise FileNotFoundError(f"PLY file not found: {ply_path}")
18
+
19
+ with open(path, 'r') as f:
20
+ header = []
21
+ while True:
22
+ line = f.readline().strip()
23
+ header.append(line)
24
+ if line.startswith("end_header"):
25
+ break
26
+ if not line:
27
+ raise ValueError("Invalid PLY header: missing end_header")
28
+
29
+ # Parse header to find vertex count and property names
30
+ n_vertices = 0
31
+ in_vertex = False
32
+ prop_names = []
33
+ for line in header:
34
+ if line.startswith("element vertex"):
35
+ parts = line.split()
36
+ n_vertices = int(parts[2])
37
+ in_vertex = True
38
+ prop_names = []
39
+ elif line.startswith("element"):
40
+ in_vertex = False
41
+ elif line.startswith("property") and in_vertex:
42
+ parts = line.split()
43
+ prop_names.append(parts[-1])
44
+
45
+ # Read body
46
+ data = np.loadtxt(path, skiprows=len(header), max_rows=n_vertices)
47
+ if data.ndim == 1:
48
+ data = data.reshape(1, -1)
49
+
50
+ # Map known properties
51
+ xyz = _get_props(data, prop_names, ["x", "y", "z"])
52
+ rgb = _get_props(data, prop_names, ["r", "g", "b"], dtype=np.uint8)
53
+ normals = _get_props(data, prop_names, ["nx", "ny", "nz"], optional=True)
54
+ return xyz, rgb, normals
55
+
56
+
57
+ def _get_props(data, prop_names, keys, optional=False, dtype=np.float32):
58
+ indices = []
59
+ for k in keys:
60
+ if k in prop_names:
61
+ indices.append(prop_names.index(k))
62
+ elif k.capitalize() in prop_names:
63
+ indices.append(prop_names.index(k.capitalize()))
64
+ elif optional:
65
+ return None
66
+ else:
67
+ raise ValueError(f"Property {k} not found in PLY file. Available: {prop_names}")
68
+ if indices:
69
+ return data[:, indices].astype(dtype)
70
+ return None
71
+
72
+
73
+ def load_pcd(pcd_path: str) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
74
+ """
75
+ Load a PCD (Point Cloud Data) file.
76
+
77
+ Supports ASCII PCD with fields: x y z rgb / rgba / r g b / normal_x normal_y normal_z.
78
+ Returns (xyz, rgb, normals) where rgb is uint8 (N,3) and normals is float32 (N,3) or None.
79
+ """
80
+ path = Path(pcd_path)
81
+ if not path.is_file():
82
+ raise FileNotFoundError(f"PCD file not found: {pcd_path}")
83
+
84
+ with open(path, 'r') as f:
85
+ lines = f.readlines()
86
+
87
+ # Find header end
88
+ header_end = 0
89
+ data_mode = "ascii"
90
+ fields = []
91
+ for i, line in enumerate(lines):
92
+ l = line.strip().split()
93
+ if not l:
94
+ continue
95
+ if l[0] == "DATA":
96
+ data_mode = l[1] if len(l) > 1 else "ascii"
97
+ header_end = i + 1
98
+ break
99
+ if l[0] == "FIELDS":
100
+ fields = l[1:]
101
+
102
+ if data_mode != "ascii":
103
+ raise NotImplementedError("Only ASCII PCD files are supported.")
104
+
105
+ data = np.loadtxt(lines[header_end:])
106
+ if data.ndim == 1:
107
+ data = data.reshape(1, -1)
108
+
109
+ xyz = _get_props(data, fields, ["x", "y", "z"])
110
+ rgb = None
111
+ normals = None
112
+
113
+ # Try RGB channels
114
+ if all(c in fields for c in ["r", "g", "b"]):
115
+ rgb = _get_props(data, fields, ["r", "g", "b"], dtype=np.uint8)
116
+ elif "rgb" in fields:
117
+ rgb_idx = fields.index("rgb")
118
+ rgb_packed = data[:, rgb_idx].astype(np.float32)
119
+ rgb = np.zeros((len(rgb_packed), 3), dtype=np.uint8)
120
+ # PCL packs RGB as a single float: int bits = 0x00RRGGBB
121
+ rgb_floats = rgb_packed.astype(np.float32)
122
+ rgb[:, 0] = ((rgb_floats.view(np.int32) >> 16) & 0xFF).astype(np.uint8)
123
+ rgb[:, 1] = ((rgb_floats.view(np.int32) >> 8) & 0xFF).astype(np.uint8)
124
+ rgb[:, 2] = ((rgb_floats.view(np.int32)) & 0xFF).astype(np.uint8)
125
+ elif "rgba" in fields:
126
+ rgba_idx = fields.index("rgba")
127
+ rgb_floats = data[:, rgba_idx].astype(np.float32)
128
+ rgb = np.zeros((len(rgb_floats), 3), dtype=np.uint8)
129
+ rgb[:, 0] = ((rgb_floats.view(np.int32) >> 16) & 0xFF).astype(np.uint8)
130
+ rgb[:, 1] = ((rgb_floats.view(np.int32) >> 8) & 0xFF).astype(np.uint8)
131
+ rgb[:, 2] = ((rgb_floats.view(np.int32)) & 0xFF).astype(np.uint8)
132
+
133
+ # Try normals
134
+ normals = _get_props(data, fields, ["normal_x", "normal_y", "normal_z"], optional=True)
135
+ if normals is None:
136
+ normals = _get_props(data, fields, ["nx", "ny", "nz"], optional=True)
137
+
138
+ return xyz, rgb, normals
139
+
140
+
141
+ def save_ply(ply_path: str, xyz: np.ndarray, rgb: np.ndarray = None,
142
+ labels: np.ndarray = None, normals: np.ndarray = None):
143
+ """
144
+ Save a point cloud as ASCII PLY.
145
+
146
+ Parameters
147
+ ----------
148
+ xyz : np.ndarray, shape (N, 3)
149
+ XYZ coordinates.
150
+ rgb : np.ndarray, shape (N, 3), uint8, optional
151
+ Per-point colors. If labels is given instead, rgb is ignored and
152
+ colors are looked up from label_to_color.
153
+ labels : np.ndarray, shape (N,), optional
154
+ Integer instance labels. Used to colorize output.
155
+ normals : np.ndarray, shape (N, 3), optional
156
+ Surface normals.
157
+ """
158
+ n = len(xyz)
159
+ if labels is not None:
160
+ rgb = label_to_color(labels)
161
+ if rgb is None:
162
+ rgb = np.full((n, 3), 128, dtype=np.uint8)
163
+
164
+ props = ["property float x", "property float y", "property float z"]
165
+ data_cols = [xyz[:, 0:1], xyz[:, 1:2], xyz[:, 2:3]]
166
+
167
+ if normals is not None:
168
+ props += ["property float nx", "property float ny", "property float nz"]
169
+ data_cols += [normals[:, 0:1], normals[:, 1:2], normals[:, 2:3]]
170
+
171
+ props += ["property uchar red", "property uchar green", "property uchar blue"]
172
+ data_cols += [rgb[:, 0:1], rgb[:, 1:2], rgb[:, 2:3]]
173
+
174
+ if labels is not None:
175
+ props += ["property int label"]
176
+ data_cols += [labels.reshape(-1, 1)]
177
+
178
+ header = f"""ply
179
+ format ascii 1.0
180
+ element vertex {n}
181
+ """ + "\n".join(props) + """
182
+ end_header
183
+ """
184
+ data = np.hstack(data_cols)
185
+ if labels is not None:
186
+ # last column is int, rest float/uint8 — save manually
187
+ lines = [header]
188
+ for i in range(n):
189
+ parts = [f"{data[i,j]:.6f}" for j in range(data.shape[1]-1)]
190
+ parts.append(f"{int(data[i,-1])}")
191
+ lines.append(" ".join(parts) + "\n")
192
+ with open(ply_path, 'w') as f:
193
+ f.writelines(lines)
194
+ else:
195
+ np.savetxt(ply_path, data, header=header.strip(), comments='', fmt='%.6f')
196
+
197
+
198
+ def save_pcd(pcd_path: str, xyz: np.ndarray, rgb: np.ndarray = None,
199
+ labels: np.ndarray = None, normals: np.ndarray = None):
200
+ """
201
+ Save a point cloud as ASCII PCD.
202
+ """
203
+ n = len(xyz)
204
+ if labels is not None:
205
+ rgb = label_to_color(labels)
206
+ if rgb is None:
207
+ rgb = np.full((n, 3), 128, dtype=np.uint8)
208
+
209
+ fields = ["x", "y", "z"]
210
+ types = ["F", "F", "F"]
211
+ sizes = [4, 4, 4]
212
+ counts = [1, 1, 1]
213
+
214
+ if normals is not None:
215
+ fields += ["normal_x", "normal_y", "normal_z"]
216
+ types += ["F", "F", "F"]
217
+ sizes += [4, 4, 4]
218
+ counts += [1, 1, 1]
219
+
220
+ fields += ["r", "g", "b"]
221
+ types += ["U", "U", "U"]
222
+ sizes += [1, 1, 1]
223
+ counts += [1, 1, 1]
224
+
225
+ data = np.hstack([xyz, rgb.astype(np.float32)])
226
+ if normals is not None:
227
+ data = np.hstack([xyz, normals, rgb.astype(np.float32)])
228
+
229
+ header = f"""# .PCD v0.7 - Point Cloud Data file format
230
+ VERSION 0.7
231
+ FIELDS {' '.join(fields)}
232
+ SIZE {' '.join(str(s) for s in sizes)}
233
+ TYPE {' '.join(types)}
234
+ COUNT {' '.join(str(c) for c in counts)}
235
+ WIDTH {n}
236
+ HEIGHT 1
237
+ VIEWPOINT 0 0 0 1 0 0 0
238
+ POINTS {n}
239
+ DATA ascii
240
+ """
241
+ np.savetxt(pcd_path, data, header=header.strip(), comments='', fmt='%.6f')
242
+
243
+
244
+ def load_point_cloud(path: str) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
245
+ """
246
+ Load a point cloud from either PLY or PCD format.
247
+
248
+ Returns
249
+ -------
250
+ xyz : np.ndarray, shape (N, 3), float32
251
+ rgb : np.ndarray, shape (N, 3), uint8
252
+ normals : np.ndarray, shape (N, 3), float32 or None
253
+ """
254
+ p = Path(path)
255
+ suffix = p.suffix.lower()
256
+ if suffix == ".ply":
257
+ return load_ply(path)
258
+ elif suffix == ".pcd":
259
+ return load_pcd(path)
260
+ else:
261
+ raise ValueError(f"Unsupported point cloud format: {suffix}. Use .ply or .pcd")
262
+
263
+
264
+ # ------------------------------------------------------------------
265
+ # Color map for instance labels (cyclic, similar to S3DIS / ScanNet)
266
+ # ------------------------------------------------------------------
267
+ def label_to_color(labels: np.ndarray) -> np.ndarray:
268
+ """Map integer instance labels to RGB colors."""
269
+ colors = np.array([
270
+ [0.8, 0.2, 0.2], [0.2, 0.8, 0.2], [0.2, 0.2, 0.8],
271
+ [0.8, 0.8, 0.2], [0.8, 0.2, 0.8], [0.2, 0.8, 0.8],
272
+ [0.5, 0.3, 0.1], [0.1, 0.5, 0.3], [0.3, 0.1, 0.5],
273
+ [0.9, 0.6, 0.1], [0.1, 0.9, 0.6], [0.6, 0.1, 0.9],
274
+ [0.4, 0.4, 0.4], [0.7, 0.7, 0.7], [0.3, 0.7, 0.5],
275
+ [0.7, 0.3, 0.5], [0.5, 0.7, 0.3], [0.3, 0.5, 0.7],
276
+ [0.9, 0.3, 0.1], [0.1, 0.3, 0.9], [0.3, 0.9, 0.1],
277
+ ], dtype=np.float32)
278
+ rgb = colors[labels % len(colors)] * 255
279
+ return rgb.astype(np.uint8)