fix def loss
Browse files- training/s2f_trainer.py +8 -3
- utils/metrics.py +30 -0
training/s2f_trainer.py
CHANGED
|
@@ -16,12 +16,17 @@ if S2F_ROOT not in sys.path:
|
|
| 16 |
|
| 17 |
from models.s2f_model import create_settings_channels
|
| 18 |
from utils.substrate_settings import compute_settings_normalization
|
| 19 |
-
from utils.metrics import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
from scipy.stats import pearsonr
|
| 21 |
|
| 22 |
|
| 23 |
class S2FLoss(nn.Module):
|
| 24 |
-
"""S2F loss: reconstruction (
|
| 25 |
def __init__(self, lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
|
| 26 |
gan_mode='vanilla', custom_loss=None, use_force_consistency=False,
|
| 27 |
force_consistency_target='mean'):
|
|
@@ -32,7 +37,7 @@ class S2FLoss(nn.Module):
|
|
| 32 |
self.gan_mode = gan_mode
|
| 33 |
self.use_force_consistency = use_force_consistency
|
| 34 |
self.force_consistency_target = force_consistency_target
|
| 35 |
-
self.reconstruction_loss = custom_loss if custom_loss is not None else
|
| 36 |
self.force_consistency_loss = nn.MSELoss() if use_force_consistency else None
|
| 37 |
self.gan_loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else nn.MSELoss()
|
| 38 |
|
|
|
|
| 16 |
|
| 17 |
from models.s2f_model import create_settings_channels
|
| 18 |
from utils.substrate_settings import compute_settings_normalization
|
| 19 |
+
from utils.metrics import (
|
| 20 |
+
WFMRMELoss,
|
| 21 |
+
calculate_psnr,
|
| 22 |
+
calculate_ssim_tensor,
|
| 23 |
+
calculate_pearson_correlation,
|
| 24 |
+
)
|
| 25 |
from scipy.stats import pearsonr
|
| 26 |
|
| 27 |
|
| 28 |
class S2FLoss(nn.Module):
|
| 29 |
+
"""S2F loss: reconstruction (WFM-RME by default) + GAN + optional force consistency."""
|
| 30 |
def __init__(self, lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
|
| 31 |
gan_mode='vanilla', custom_loss=None, use_force_consistency=False,
|
| 32 |
force_consistency_target='mean'):
|
|
|
|
| 37 |
self.gan_mode = gan_mode
|
| 38 |
self.use_force_consistency = use_force_consistency
|
| 39 |
self.force_consistency_target = force_consistency_target
|
| 40 |
+
self.reconstruction_loss = custom_loss if custom_loss is not None else WFMRMELoss()
|
| 41 |
self.force_consistency_loss = nn.MSELoss() if use_force_consistency else None
|
| 42 |
self.gan_loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else nn.MSELoss()
|
| 43 |
|
utils/metrics.py
CHANGED
|
@@ -106,6 +106,36 @@ def _force_mag_wfm(f):
|
|
| 106 |
return np.sqrt(fx**2 + fy**2)
|
| 107 |
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
def wfm_correlation(y_true, y_pred, mode="magnitude"):
|
| 110 |
"""Pearson correlation between prediction and ground truth (magnitude mode for heatmaps)."""
|
| 111 |
t = _ensure_shape_wfm(_to_numpy_wfm(y_true))
|
|
|
|
| 106 |
return np.sqrt(fx**2 + fy**2)
|
| 107 |
|
| 108 |
|
| 109 |
+
def _force_magnitude_tensor(x: torch.Tensor) -> torch.Tensor:
|
| 110 |
+
"""Per-pixel force magnitude: (B,1,H,W) uses channel 0 as magnitude; (B,2+,H,W) uses sqrt(fx^2+fy^2)."""
|
| 111 |
+
if x.dim() != 4:
|
| 112 |
+
raise ValueError(f"Expected 4D tensor, got shape {tuple(x.shape)}")
|
| 113 |
+
c = x.size(1)
|
| 114 |
+
if c == 1:
|
| 115 |
+
return x[:, 0]
|
| 116 |
+
if c >= 2:
|
| 117 |
+
return torch.sqrt(x[:, 0].pow(2) + x[:, 1].pow(2))
|
| 118 |
+
raise ValueError(f"Expected at least 1 channel, got {c}")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class WFMRMELoss(nn.Module):
|
| 122 |
+
"""Weighted force-magnitude relative error; matches ``wfm_relative_magnitude_error`` (differentiable)."""
|
| 123 |
+
|
| 124 |
+
def __init__(self, eps: float = 1e-8):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.eps = eps
|
| 127 |
+
|
| 128 |
+
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 129 |
+
mag_t = _force_magnitude_tensor(target)
|
| 130 |
+
mag_p = _force_magnitude_tensor(pred)
|
| 131 |
+
if mag_t.shape != mag_p.shape:
|
| 132 |
+
raise ValueError(f"Shape mismatch after magnitude: {mag_t.shape} vs {mag_p.shape}")
|
| 133 |
+
fbar = mag_t.mean().clamp_min(self.eps)
|
| 134 |
+
w = mag_t / fbar
|
| 135 |
+
rel = (mag_p - mag_t).abs() / (mag_t + self.eps)
|
| 136 |
+
return (w * rel).mean()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
def wfm_correlation(y_true, y_pred, mode="magnitude"):
|
| 140 |
"""Pearson correlation between prediction and ground truth (magnitude mode for heatmaps)."""
|
| 141 |
t = _ensure_shape_wfm(_to_numpy_wfm(y_true))
|