"""
IO 工具模块

图像(RGB)、深度图(.npy)与 JSON 文件的读写,以及按 glob 模式列出目录中的文件。
"""
|
|
| import json |
| import numpy as np |
| from pathlib import Path |
| from typing import Dict, Any, Union |
| import cv2 |
|
|
|
|
def load_image(path: Union[str, Path]) -> np.ndarray:
    """Load an image from disk as an RGB array.

    Args:
        path: Path to the image file.

    Returns:
        img: (H, W, 3) RGB uint8 image. ``cv2.IMREAD_COLOR`` always yields
            three 8-bit channels (grayscale is expanded, alpha is dropped).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file cannot be read or decoded as an image.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Image not found: {path}")

    img = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread returns None (instead of raising) both for undecodable
        # files and for paths it cannot open -- notably non-ASCII paths on
        # Windows. Retry by reading the bytes with Python's own file API and
        # decoding them in memory.
        try:
            data = path.read_bytes()
        except OSError as err:  # e.g. path is a directory
            raise ValueError(f"Failed to load image: {path}") from err
        if data:  # cv2.imdecode rejects an empty buffer
            img = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Failed to load image: {path}")

    # OpenCV decodes to BGR channel order; callers expect RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img
|
|
|
|
def save_image(img: np.ndarray, path: Union[str, Path]) -> None:
    """Write an RGB(A) image to disk, creating parent directories as needed.

    The output format is chosen by OpenCV from the file extension of ``path``.

    Args:
        img: (H, W, 3) RGB or (H, W, 4) RGBA image. Any other shape
            (e.g. (H, W) grayscale) is passed to OpenCV unchanged.
        path: Output path.

    Raises:
        OSError: If the image could not be encoded or written.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # OpenCV expects BGR(A) channel order, so convert from RGB(A).
    if img.ndim == 3 and img.shape[2] == 3:
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    elif img.ndim == 3 and img.shape[2] == 4:
        img_bgr = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
    else:
        img_bgr = img

    if cv2.imwrite(str(path), img_bgr):
        return

    # cv2.imwrite reports many failures (unwritable location, non-ASCII path
    # on Windows) only by returning False. Retry by encoding in memory and
    # writing through Python's file API; a genuinely unwritable location then
    # raises a real OSError instead of failing silently.
    ok, buf = cv2.imencode(path.suffix, img_bgr)
    if not ok:
        raise OSError(f"Failed to save image: {path}")
    path.write_bytes(buf.tobytes())
|
|
|
|
def load_json(path: Union[str, Path]) -> Dict[str, Any]:
    """Read and parse a UTF-8 encoded JSON file.

    Args:
        path: Location of the JSON file.

    Returns:
        data: The decoded JSON document.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    json_path = Path(path)
    if json_path.exists():
        with json_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"JSON not found: {json_path}")
|
|
|
|
def save_json(data: Dict[str, Any], path: Union[str, Path], indent: int = 2) -> None:
    """Serialize ``data`` to a UTF-8 JSON file, creating parent directories.

    The data is encoded to a string before the file is opened, so a
    serialization error does not truncate or partially overwrite an existing
    file at ``path``.

    Args:
        data: JSON-serializable object to save.
        path: Output path.
        indent: Indentation width passed to the JSON encoder.

    Raises:
        TypeError: If ``data`` contains a value the JSON encoder cannot
            serialize (``ValueError`` for e.g. circular references).
    """
    path = Path(path)

    # Encode first: opening with "w" truncates the file immediately, so a
    # failure inside the encoder would otherwise destroy the old contents.
    # ensure_ascii=False keeps non-ASCII text readable, hence explicit UTF-8.
    text = json.dumps(data, indent=indent, ensure_ascii=False)

    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
|
|
|
|
def load_depth(path: Union[str, Path]) -> np.ndarray:
    """Read a depth map stored as a NumPy ``.npy`` file.

    Args:
        path: Location of the ``.npy`` file.

    Returns:
        depth: The stored array cast to float32 (always a new array);
            expected to be (H, W).

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    depth_file = Path(path)
    if depth_file.exists():
        stored = np.load(str(depth_file))
        return stored.astype(np.float32)
    raise FileNotFoundError(f"Depth not found: {depth_file}")
|
|
|
|
def save_depth(depth: np.ndarray, path: Union[str, Path]) -> None:
    """Persist a depth map as a float32 NumPy ``.npy`` file.

    Missing parent directories are created first.

    Args:
        depth: (H, W) depth map; cast to float32 before saving.
        path: Destination file. ``np.save`` appends ``.npy`` when the name
            does not already end with it, so pass a ``.npy`` path if the
            same string will later be given to ``load_depth``.
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)
    as_float = depth.astype(np.float32)
    np.save(str(target), as_float)
|
|
|
|
def list_files(
    directory: Union[str, Path],
    pattern: str = "*",
    sort: bool = True,
) -> list:
    """Return the entries of ``directory`` that match a glob pattern.

    Args:
        directory: Directory to search.
        pattern: Glob pattern handed to ``Path.glob`` (e.g. ``"*.png"`` or
            ``"**/*.json"``). Matches are not filtered by type, so
            subdirectories matching the pattern are included as well.
        sort: If True, return the paths in sorted order; otherwise keep the
            order in which ``Path.glob`` yields them.

    Returns:
        files: List of matching ``Path`` objects; empty when ``directory``
            does not exist.
    """
    root = Path(directory)
    if not root.exists():
        return []
    matches = root.glob(pattern)
    return sorted(matches) if sort else list(matches)
|
|