kaveh commited on
Commit
bae9663
·
1 Parent(s): f537675

fix def loss

Browse files
Files changed (2) hide show
  1. training/s2f_trainer.py +8 -3
  2. utils/metrics.py +30 -0
training/s2f_trainer.py CHANGED
@@ -16,12 +16,17 @@ if S2F_ROOT not in sys.path:
16
 
17
  from models.s2f_model import create_settings_channels
18
  from utils.substrate_settings import compute_settings_normalization
19
- from utils.metrics import calculate_psnr, calculate_ssim_tensor, calculate_pearson_correlation
 
 
 
 
 
20
  from scipy.stats import pearsonr
21
 
22
 
23
  class S2FLoss(nn.Module):
24
- """S2F loss: reconstruction (L1) + GAN + optional force consistency."""
25
  def __init__(self, lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
26
  gan_mode='vanilla', custom_loss=None, use_force_consistency=False,
27
  force_consistency_target='mean'):
@@ -32,7 +37,7 @@ class S2FLoss(nn.Module):
32
  self.gan_mode = gan_mode
33
  self.use_force_consistency = use_force_consistency
34
  self.force_consistency_target = force_consistency_target
35
- self.reconstruction_loss = custom_loss if custom_loss is not None else nn.L1Loss()
36
  self.force_consistency_loss = nn.MSELoss() if use_force_consistency else None
37
  self.gan_loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else nn.MSELoss()
38
 
 
16
 
17
  from models.s2f_model import create_settings_channels
18
  from utils.substrate_settings import compute_settings_normalization
19
+ from utils.metrics import (
20
+ WFMRMELoss,
21
+ calculate_psnr,
22
+ calculate_ssim_tensor,
23
+ calculate_pearson_correlation,
24
+ )
25
  from scipy.stats import pearsonr
26
 
27
 
28
  class S2FLoss(nn.Module):
29
+ """S2F loss: reconstruction (WFM-RME by default) + GAN + optional force consistency."""
30
  def __init__(self, lambda_L1=100.0, lambda_gan=1.0, lambda_force=1.0,
31
  gan_mode='vanilla', custom_loss=None, use_force_consistency=False,
32
  force_consistency_target='mean'):
 
37
  self.gan_mode = gan_mode
38
  self.use_force_consistency = use_force_consistency
39
  self.force_consistency_target = force_consistency_target
40
+ self.reconstruction_loss = custom_loss if custom_loss is not None else WFMRMELoss()
41
  self.force_consistency_loss = nn.MSELoss() if use_force_consistency else None
42
  self.gan_loss = nn.BCEWithLogitsLoss() if gan_mode == 'vanilla' else nn.MSELoss()
43
 
utils/metrics.py CHANGED
@@ -106,6 +106,36 @@ def _force_mag_wfm(f):
106
  return np.sqrt(fx**2 + fy**2)
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def wfm_correlation(y_true, y_pred, mode="magnitude"):
110
  """Pearson correlation between prediction and ground truth (magnitude mode for heatmaps)."""
111
  t = _ensure_shape_wfm(_to_numpy_wfm(y_true))
 
106
  return np.sqrt(fx**2 + fy**2)
107
 
108
 
109
+ def _force_magnitude_tensor(x: torch.Tensor) -> torch.Tensor:
110
+ """Per-pixel force magnitude: (B,1,H,W) uses channel 0 as magnitude; (B,2+,H,W) uses sqrt(fx^2+fy^2)."""
111
+ if x.dim() != 4:
112
+ raise ValueError(f"Expected 4D tensor, got shape {tuple(x.shape)}")
113
+ c = x.size(1)
114
+ if c == 1:
115
+ return x[:, 0]
116
+ if c >= 2:
117
+ return torch.sqrt(x[:, 0].pow(2) + x[:, 1].pow(2))
118
+ raise ValueError(f"Expected at least 1 channel, got {c}")
119
+
120
+
121
+ class WFMRMELoss(nn.Module):
122
+ """Weighted force-magnitude relative error; matches ``wfm_relative_magnitude_error`` (differentiable)."""
123
+
124
+ def __init__(self, eps: float = 1e-8):
125
+ super().__init__()
126
+ self.eps = eps
127
+
128
+ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
129
+ mag_t = _force_magnitude_tensor(target)
130
+ mag_p = _force_magnitude_tensor(pred)
131
+ if mag_t.shape != mag_p.shape:
132
+ raise ValueError(f"Shape mismatch after magnitude: {mag_t.shape} vs {mag_p.shape}")
133
+ fbar = mag_t.mean().clamp_min(self.eps)
134
+ w = mag_t / fbar
135
+ rel = (mag_p - mag_t).abs() / (mag_t + self.eps)
136
+ return (w * rel).mean()
137
+
138
+
139
  def wfm_correlation(y_true, y_pred, mode="magnitude"):
140
  """Pearson correlation between prediction and ground truth (magnitude mode for heatmaps)."""
141
  t = _ensure_shape_wfm(_to_numpy_wfm(y_true))