dongjoaquin commited on
Commit
0522555
·
verified ·
1 Parent(s): 3720817

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +217 -0
inference.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RadioUNet V3 推理脚本
3
+ 使用训练好的模型对SoundMapDiff数据集进行推理
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ import torch
10
+ import numpy as np
11
+ from PIL import Image
12
+ import matplotlib.pyplot as plt
13
+ from pathlib import Path
14
+ from skimage.metrics import structural_similarity as ssim
15
+
16
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'lib'))
17
+
18
+ from lib.modules import RadioWNet
19
+ from lib.soundmap_loader import SoundMapDataset
20
+ from torch.utils.data import DataLoader
21
+
22
+
23
+ def calculate_metrics(pred, target):
24
+ """计算评估指标"""
25
+ pred_np = pred.cpu().numpy().squeeze()
26
+ target_np = target.cpu().numpy().squeeze()
27
+
28
+ # MSE
29
+ mse = np.mean((pred_np - target_np) ** 2)
30
+
31
+ # MAE
32
+ mae = np.mean(np.abs(pred_np - target_np))
33
+
34
+ # RMSE
35
+ rmse = np.sqrt(mse)
36
+
37
+ # SSIM
38
+ ssim_val = ssim(pred_np, target_np, data_range=1.0)
39
+
40
+ # PSNR
41
+ if mse > 0:
42
+ psnr = 10 * np.log10(1.0 / mse)
43
+ else:
44
+ psnr = float('inf')
45
+
46
+ return {
47
+ 'mse': mse,
48
+ 'mae': mae,
49
+ 'rmse': rmse,
50
+ 'ssim': ssim_val,
51
+ 'psnr': psnr
52
+ }
53
+
54
+
55
+ def visualize_prediction(inputs, target, pred, metrics, save_path):
56
+ """可视化预测结果"""
57
+ fig, axes = plt.subplots(2, 2, figsize=(12, 12))
58
+
59
+ # 建筑物布局
60
+ axes[0, 0].imshow(inputs[0].cpu().numpy(), cmap='gray')
61
+ axes[0, 0].set_title('Building Layout', fontsize=14)
62
+ axes[0, 0].axis('off')
63
+
64
+ # 声源位置
65
+ axes[0, 1].imshow(inputs[1].cpu().numpy(), cmap='hot')
66
+ axes[0, 1].set_title('Sound Source', fontsize=14)
67
+ axes[0, 1].axis('off')
68
+
69
+ # 真实热力图 - 使用viridis颜色方案(紫→蓝→绿→黄)
70
+ im1 = axes[1, 0].imshow(target.cpu().numpy().squeeze(), cmap='viridis', vmin=0, vmax=1)
71
+ axes[1, 0].set_title('Ground Truth', fontsize=14)
72
+ axes[1, 0].axis('off')
73
+ plt.colorbar(im1, ax=axes[1, 0], fraction=0.046, pad=0.04)
74
+
75
+ # 预测热力图 - 使用viridis颜色方案(紫→蓝→绿→黄)
76
+ im2 = axes[1, 1].imshow(pred.cpu().numpy().squeeze(), cmap='viridis', vmin=0, vmax=1)
77
+ axes[1, 1].set_title(f"Prediction (SSIM: {metrics['ssim']:.4f})", fontsize=14)
78
+ axes[1, 1].axis('off')
79
+ plt.colorbar(im2, ax=axes[1, 1], fraction=0.046, pad=0.04)
80
+
81
+ # 添加指标信息
82
+ metrics_text = f"MSE: {metrics['mse']:.6f} | MAE: {metrics['mae']:.4f} | SSIM: {metrics['ssim']:.4f} | PSNR: {metrics['psnr']:.2f} dB"
83
+ fig.suptitle(metrics_text, fontsize=12, y=0.02)
84
+
85
+ plt.tight_layout()
86
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
87
+ plt.close()
88
+
89
+
90
+ def main():
91
+ parser = argparse.ArgumentParser(description='RadioUNet V3 推理脚本')
92
+ parser.add_argument('--checkpoint', type=str,
93
+ default='outputs/radiounet_v3/checkpoints/best_model.pth',
94
+ help='模型检查点路径')
95
+ parser.add_argument('--dataset_dir', type=str,
96
+ default='/home/djk/generate/dataset/SoundMapDiff',
97
+ help='数据集目录')
98
+ parser.add_argument('--output_dir', type=str,
99
+ default='outputs/radiounet_v3/inference',
100
+ help='输出目录')
101
+ parser.add_argument('--num_samples', type=int, default=20,
102
+ help='推理样本数量')
103
+ parser.add_argument('--img_size', type=int, default=256,
104
+ help='图像尺寸')
105
+
106
+ args = parser.parse_args()
107
+
108
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
109
+ print(f"使用设备: {device}")
110
+
111
+ # 创建输出目录
112
+ output_dir = Path(args.output_dir)
113
+ output_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ # 加载模型
116
+ print(f"加载模型: {args.checkpoint}")
117
+ model = RadioWNet(inputs=2, phase="firstU").to(device)
118
+
119
+ checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
120
+ if 'model_state_dict' in checkpoint:
121
+ model.load_state_dict(checkpoint['model_state_dict'])
122
+ print(f"加载Epoch {checkpoint.get('epoch', 'unknown')}的模型")
123
+ else:
124
+ model.load_state_dict(checkpoint)
125
+
126
+ model.eval()
127
+
128
+ # 加载测试数据集
129
+ print(f"加载数据集: {args.dataset_dir}")
130
+ test_dataset = SoundMapDataset(
131
+ dataset_dir=args.dataset_dir,
132
+ phase="test",
133
+ img_size=args.img_size
134
+ )
135
+
136
+ # 均匀采样
137
+ total_samples = len(test_dataset)
138
+ indices = np.linspace(0, total_samples - 1, args.num_samples, dtype=int)
139
+
140
+ print(f"测试样本数: {total_samples}, 采样数: {args.num_samples}")
141
+ print(f"\n{'='*60}")
142
+ print("开始推理...")
143
+ print(f"{'='*60}\n")
144
+
145
+ all_metrics = []
146
+
147
+ with torch.no_grad():
148
+ for i, idx in enumerate(indices):
149
+ inputs, target = test_dataset[idx]
150
+ inputs = inputs.unsqueeze(0).to(device)
151
+ target = target.unsqueeze(0).to(device)
152
+
153
+ # 推理
154
+ outputs = model(inputs)
155
+ if isinstance(outputs, list):
156
+ outputs = outputs[0]
157
+
158
+ # 计算指标
159
+ metrics = calculate_metrics(outputs.squeeze(0), target.squeeze(0))
160
+ all_metrics.append(metrics)
161
+
162
+ # 可视化
163
+ save_path = output_dir / f'prediction_{i+1}_idx{idx}.png'
164
+ visualize_prediction(inputs.squeeze(0), target.squeeze(0),
165
+ outputs.squeeze(0), metrics, save_path)
166
+
167
+ print(f"样本 {i+1}/{args.num_samples} (idx={idx}): "
168
+ f"SSIM={metrics['ssim']:.4f}, MSE={metrics['mse']:.6f}, PSNR={metrics['psnr']:.2f}dB")
169
+
170
+ # 计算平均指标
171
+ avg_metrics = {
172
+ 'mse': np.mean([m['mse'] for m in all_metrics]),
173
+ 'mae': np.mean([m['mae'] for m in all_metrics]),
174
+ 'rmse': np.mean([m['rmse'] for m in all_metrics]),
175
+ 'ssim': np.mean([m['ssim'] for m in all_metrics]),
176
+ 'psnr': np.mean([m['psnr'] for m in all_metrics])
177
+ }
178
+
179
+ print(f"\n{'='*60}")
180
+ print("平均评估指标")
181
+ print(f"{'='*60}")
182
+ print(f" 平均 MSE: {avg_metrics['mse']:.6f}")
183
+ print(f" 平均 MAE: {avg_metrics['mae']:.4f}")
184
+ print(f" 平均 RMSE: {avg_metrics['rmse']:.4f}")
185
+ print(f" 平均 SSIM: {avg_metrics['ssim']:.4f}")
186
+ print(f" 平均 PSNR: {avg_metrics['psnr']:.2f} dB")
187
+ print(f"{'='*60}")
188
+
189
+ # 保存报告
190
+ report_path = output_dir / 'evaluation_report.txt'
191
+ with open(report_path, 'w', encoding='utf-8') as f:
192
+ f.write("RadioUNet V3 评估报告\n")
193
+ f.write("=" * 60 + "\n\n")
194
+ f.write(f"模型: {args.checkpoint}\n")
195
+ f.write(f"测试样本数: {args.num_samples}\n\n")
196
+
197
+ for i, (idx, m) in enumerate(zip(indices, all_metrics)):
198
+ f.write(f"样本 {i+1} (索引 {idx}):\n")
199
+ f.write(f" MSE: {m['mse']:.6f}\n")
200
+ f.write(f" MAE: {m['mae']:.4f}\n")
201
+ f.write(f" SSIM: {m['ssim']:.4f}\n")
202
+ f.write(f" PSNR: {m['psnr']:.2f} dB\n\n")
203
+
204
+ f.write("=" * 60 + "\n")
205
+ f.write("平均指标:\n")
206
+ f.write("=" * 60 + "\n")
207
+ f.write(f" 平均 MSE: {avg_metrics['mse']:.6f}\n")
208
+ f.write(f" 平均 MAE: {avg_metrics['mae']:.4f}\n")
209
+ f.write(f" 平均 RMSE: {avg_metrics['rmse']:.4f}\n")
210
+ f.write(f" 平均 SSIM: {avg_metrics['ssim']:.4f}\n")
211
+ f.write(f" 平均 PSNR: {avg_metrics['psnr']:.2f} dB\n")
212
+
213
+ print(f"\n✅ 推理完成!结果保存在: {output_dir}")
214
+
215
+
216
+ if __name__ == '__main__':
217
+ main()