| """ |
| 数据集格式验证脚本 |
| 用于验证 train_loader 加载的 input 和 target 格式 |
| 特别是验证 target[0] 是否为 [image_idx, class_id, x_center, y_center, width, height] |
| """ |
| import os |
| import sys |
| BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| sys.path.append(BASE_DIR) |
|
|
| import torch |
| import torchvision.transforms as transforms |
| from lib.config import cfg |
| import lib.dataset as dataset |
| from lib.utils import DataLoaderX |
|
|
| def check_dataset_format(): |
| """验证数据集加载格式""" |
| |
| print("="*80) |
| print("开始验证数据集加载格式...") |
| print("="*80) |
| |
| |
| normalize = transforms.Normalize( |
| mean=[0.485, 0.456, 0.406], |
| std=[0.229, 0.224, 0.225] |
| ) |
| |
| |
| print("\n1. 创建数据集...") |
| train_dataset = eval('dataset.' + cfg.DATASET.DATASET)( |
| cfg=cfg, |
| is_train=True, |
| inputsize=cfg.MODEL.IMAGE_SIZE, |
| transform=transforms.Compose([ |
| transforms.ToTensor(), |
| normalize, |
| ]) |
| ) |
| print(f" 数据集类型: {cfg.DATASET.DATASET}") |
| print(f" 数据集大小: {len(train_dataset)}") |
| |
| |
| if hasattr(train_dataset, 'names'): |
| print(f" 数据集类别: {train_dataset.names}") |
| print(f" 类别数量: {len(train_dataset.names)}") |
| else: |
| print(" 数据集没有 names 属性") |
| |
| if hasattr(train_dataset, "names"): |
| print(f" 数据集类别数量: {len(train_dataset.names)}") |
| else: |
| print(" 数据集不包含 names 属性,无法统计类别数量。") |
| |
| |
| print("\n2. 创建 DataLoader...") |
| train_loader = DataLoaderX( |
| train_dataset, |
| batch_size=4, |
| shuffle=False, |
| num_workers=0, |
| pin_memory=False, |
| collate_fn=dataset.AutoDriveDataset.collate_fn |
| ) |
| print(f" Batch size: ") |
| print(f" Total batches: {len(train_loader)}") |
| |
| |
| print("\n3. 加载第一个 batch...") |
| for i, (input, target, paths, shapes) in enumerate(train_loader): |
| print("\n" + "="*80) |
| print(f"Batch {i} 数据格式分析:") |
| print("="*80) |
| |
| |
| print("\n[INPUT - 图像数据]") |
| print(f" 类型: {type(input)}") |
| print(f" 形状: {input.shape}") |
| print(f" dtype: {input.dtype}") |
| print(f" 值范围: [{input.min():.3f}, {input.max():.3f}]") |
| |
| |
| print("\n[TARGET - 标注数据]") |
| print(f" 类型: {type(target)}") |
| print(f" 长度: {len(target)} (包含 3 个元素: det, da_seg, ll_seg)") |
| |
| |
| print(f"\n target[0] - 检测标签 (Detection Labels):") |
| print(f" 类型: {type(target[0])}") |
| print(f" 形状: {target[0].shape}") |
| print(f" dtype: {target[0].dtype}") |
| print(f" 说明: [N, 6] 其中 N 是所有图片的目标总数,6 维度为:") |
| print(f" [image_idx, class_id, x_center, y_center, width, height]") |
| |
| |
| if target[0].shape[0] > 0: |
| print(f"\n 前 5 个目标样本:") |
| print(f" {'索引':<6} {'img_idx':<10} {'class_id':<10} {'x_center':<12} {'y_center':<12} {'width':<12} {'height':<12}") |
| print(f" {'-'*76}") |
| for idx in range(min(5, target[0].shape[0])): |
| obj = target[0][idx] |
| print(f" {idx:<6} {obj[0].item():<10.0f} {obj[1].item():<10.0f} {obj[2].item():<12.6f} {obj[3].item():<12.6f} {obj[4].item():<12.6f} {obj[5].item():<12.6f}") |
| |
| |
| print(f"\n 验证坐标是否归一化到 [0, 1]:") |
| xywh_data = target[0][:, 2:] |
| print(f" x_center 范围: [{xywh_data[:, 0].min():.6f}, {xywh_data[:, 0].max():.6f}]") |
| print(f" y_center 范围: [{xywh_data[:, 1].min():.6f}, {xywh_data[:, 1].max():.6f}]") |
| print(f" width 范围: [{xywh_data[:, 2].min():.6f}, {xywh_data[:, 2].max():.6f}]") |
| print(f" height 范围: [{xywh_data[:, 3].min():.6f}, {xywh_data[:, 3].max():.6f}]") |
| |
| |
| is_normalized = (xywh_data >= 0).all() and (xywh_data <= 1).all() |
| if is_normalized: |
| print(f" ✓ 坐标已归一化到 [0, 1]") |
| else: |
| print(f" ✗ 警告: 坐标未完全归一化!") |
| |
| |
| print(f"\n 每张图片的目标数量:") |
| for img_idx in range(input.shape[0]): |
| count = (target[0][:, 0] == img_idx).sum().item() |
| print(f" 图片 {img_idx}: {count} 个目标") |
| else: |
| print(f" (该 batch 没有检测目标)") |
| |
| |
| print(f"\n target[1] - 驾驶区域分割标签 (Drivable Area Segmentation):") |
| print(f" 类型: {type(target[1])}") |
| print(f" 形状: {target[1].shape}") |
| print(f" dtype: {target[1].dtype}") |
| print(f" 值范围: [{target[1].min():.3f}, {target[1].max():.3f}]") |
| print(f" 说明: [batch_size, num_classes, H, W]") |
| |
| |
| print(f"\n target[2] - 车道线分割标签 (Lane Line Segmentation):") |
| print(f" 类型: {type(target[2])}") |
| print(f" 形状: {target[2].shape}") |
| print(f" dtype: {target[2].dtype}") |
| print(f" 值范围: [{target[2].min():.3f}, {target[2].max():.3f}]") |
| print(f" 说明: [batch_size, num_classes, H, W]") |
| |
| |
| print(f"\n[PATHS - 图像路径]") |
| print(f" 类型: {type(paths)}") |
| print(f" 长度: {len(paths)}") |
| if len(paths) > 0: |
| print(f" 示例路径:") |
| for idx, path in enumerate(paths): |
| print(f" [{idx}] {path}") |
| |
| |
| print(f"\n[SHAPES - 图像尺寸信息]") |
| print(f" 类型: {type(shapes)}") |
| print(f" 长度: {len(shapes)}") |
| if len(shapes) > 0: |
| print(f" 示例 (原始尺寸, ((缩放比例), (padding))):") |
| for idx, shape in enumerate(shapes[:2]): |
| print(f" [{idx}] {shape}") |
| |
| print("\n" + "="*80) |
| print("验证结论:") |
| print("="*80) |
| print("✓ target[0] 格式为: [image_idx, class_id, x_center, y_center, width, height]") |
| print("✓ xywh 坐标已归一化到 [0, 1]") |
| print("✓ image_idx 用于区分 batch 中不同图片的目标") |
| print("✓ class_id 表示目标类别") |
| print("="*80) |
| |
| |
| break |
| |
| print("\n验证完成!") |
|
|
|
|
| if __name__ == '__main__': |
| check_dataset_format() |
|
|