# molecular-shadows-h2-v10 / observable_regressor.py
# Uploaded by aniketdesh — commit d8e765d: "upload molecular-shadows-h2-v10 (v10)"
"""
Direct observable regression: f_θ(R, t) → {⟨Γ_μ⟩(R, t)}.
Bypasses the shadow/Q-conditioning pipeline entirely. Predicts signal
matrix entries directly using learnable Fourier features for time encoding.
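Three frequency modes:
- Shared frequencies (v2/v3): ω_k are global learnable parameters.
- Geometry-conditioned (v4+): ω_k(R) = ω_k^{(0)} + g_φ(R)_k, where g_φ
  is a small MLP fed either R or the HF orbital energies ε(R). Energy
  gaps depend on R, so the optimal Fourier basis should too.
- Adaptive bandwidth (extends the geometry-conditioned mode):
  ω_k(R) = ω_op(R) · sigmoid(g_φ(R))_k, which keeps every frequency
  inside (0, ω_op(R)).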
"""
from __future__ import annotations
from dataclasses import asdict, dataclass
import numpy as np
import torch
import torch.nn as nn
@dataclass
class ObservableRegressorConfig:
    """Hyperparameters for ``ObservableRegressor``.

    Field order is part of the positional constructor signature, so new
    fields should only ever be appended.
    """

    # Number K of observables predicted per (R, t) sample (output width).
    n_observables: int = 120
    # Width of every hidden layer of the main MLP.
    d_hidden: int = 256
    # Number of Linear+GELU stages in the main MLP (minimum effectively 1).
    n_layers: int = 3
    # Number of Fourier frequencies ω_k; each yields one sin and one cos feature.
    n_fourier: int = 64
    # Upper end of the log-uniform ω_k initialisation (lower end is 0.05).
    fourier_scale: float = 10.0
    # True -> a small MLP (freq_net) makes the frequencies geometry-dependent.
    conditioned_frequencies: bool = False
    # Hidden width of freq_net.
    freq_net_hidden: int = 64
    # Number of Linear layers in freq_net (values below 2 still build 2).
    freq_net_layers: int = 2
    n_orb_features: int = 0  # 0 = use scalar R; >0 = use HF orbital energies
    adaptive_bandwidth: bool = False  # ω_k(R) = ω_op(R) · sigmoid(freq_net(ε(R)))_k

    def to_dict(self) -> dict:
        """Serialise the configuration to a plain ``dict`` of field values."""
        return asdict(self)
class ObservableRegressor(nn.Module):
    """Direct regression: (R, t) → K observable expectations.

    Architecture:
      1. Fourier features for t: sin(ω_k(R) * t), cos(ω_k(R) * t)
         - Shared mode: ω_k are global learnable parameters (``omega_base``).
         - Conditioned mode: ω_k(R) = ω_base_k + freq_net(x)_k, where x is the
           scalar R, or the HF orbital energies when ``n_orb_features > 0``.
         - Adaptive-bandwidth mode (requires conditioned mode):
           ω_k(R) = ω_op(R) · sigmoid(freq_net(x))_k, so every ω_k lies in
           (0, ω_op(R)). ``omega_base`` is still created in this mode but is
           not used by ``forward``.
      2. Input = [R, Fourier features] ∈ R^{1 + 2*n_fourier}
      3. MLP with GELU activations → K outputs
    """

    def __init__(self, config: ObservableRegressorConfig):
        super().__init__()
        # Reject the invalid mode combination before allocating anything.
        if config.adaptive_bandwidth and not config.conditioned_frequencies:
            raise ValueError("adaptive_bandwidth requires conditioned_frequencies=True")
        self.config = config

        # Base frequencies — log-uniform initialization for broad coverage,
        # from ~0.05 to ~fourier_scale Eₕ.
        log_omega = torch.linspace(
            np.log(0.05), np.log(config.fourier_scale), config.n_fourier
        )
        self.omega_base = nn.Parameter(log_omega.exp())

        # Geometry-conditioned frequency head (its output is also the σ_k
        # logit head in adaptive_bandwidth mode). It is built before the main
        # MLP on purpose: default nn.Linear inits draw from the global RNG in
        # construction order, so this order fixes seeded initialisation.
        self.freq_net = (
            self._build_freq_net(config) if config.conditioned_frequencies else None
        )

        # Main MLP: n_layers Linear+GELU stages, then a linear read-out to K.
        width = config.d_hidden
        layers = [nn.Linear(1 + 2 * config.n_fourier, width), nn.GELU()]
        for _ in range(config.n_layers - 1):
            layers += [nn.Linear(width, width), nn.GELU()]
        layers.append(nn.Linear(width, config.n_observables))
        self.net = nn.Sequential(*layers)

        self._reset_parameters()

    @staticmethod
    def _build_freq_net(config: ObservableRegressorConfig) -> nn.Sequential:
        """Build the MLP mapping geometry features to n_fourier outputs."""
        freq_in = config.n_orb_features if config.n_orb_features > 0 else 1
        hidden = config.freq_net_hidden
        layers = [nn.Linear(freq_in, hidden), nn.GELU()]
        # freq_net_layers counts Linear layers; values below 2 still yield the
        # minimal input->hidden->output pair (range() of a negative is empty).
        for _ in range(config.freq_net_layers - 2):
            layers += [nn.Linear(hidden, hidden), nn.GELU()]
        layers.append(nn.Linear(hidden, config.n_fourier))
        return nn.Sequential(*layers)

    def _reset_parameters(self) -> None:
        """Xavier-init every weight matrix, then set up the freq_net head.

        1-D parameters (biases, ``omega_base``) are skipped by the Xavier
        pass, so ``omega_base`` keeps its log-uniform values.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        if self.freq_net is None:
            return
        head = self.freq_net[-1]
        # Zero weights: the head initially outputs its bias for every input.
        nn.init.zeros_(head.weight)
        if self.config.adaptive_bandwidth:
            # Spread initial σ_k across (0, 1) so initial ω_k tile [0, ω_op]
            # rather than collapsing at ω_op/2 (sigmoid(0) = 0.5).
            init_sigma = torch.linspace(0.05, 0.95, self.config.n_fourier)
            with torch.no_grad():
                head.bias.copy_(torch.log(init_sigma / (1.0 - init_sigma)))
        else:
            # Zero bias too, so the initial frequencies are exactly ω_base.
            nn.init.zeros_(head.bias)

    def _frequencies(
        self,
        R: torch.Tensor,
        orb_energies: torch.Tensor | None,
        omega_op: torch.Tensor | None,
    ) -> torch.Tensor:
        """Return ω_k for the configured mode, shape (B, n_fourier) or (1, n_fourier)."""
        cfg = self.config
        if self.freq_net is None:
            return self.omega_base[None, :]  # (1, n_fourier), shared by the batch
        if cfg.adaptive_bandwidth and omega_op is None:
            raise ValueError("adaptive_bandwidth requires omega_op input")
        if cfg.n_orb_features > 0:
            # freq_net's first layer expects n_orb_features inputs; silently
            # substituting R would be a shape error or, when n_orb_features
            # == 1, a wrong input of the right shape.
            if orb_energies is None:
                raise ValueError("n_orb_features > 0 requires orb_energies input")
            features = orb_energies
        else:
            features = R
        head_out = self.freq_net(features)  # (B, n_fourier)
        if cfg.adaptive_bandwidth:
            sigma = torch.sigmoid(head_out)  # in (0, 1)
            # reshape accepts (B,), (B, 1) or a single broadcast value.
            return omega_op.reshape(-1, 1) * sigma  # in (0, ω_op(R))
        return self.omega_base[None, :] + head_out

    def forward(
        self,
        rt: torch.Tensor,
        orb_energies: torch.Tensor | None = None,
        omega_op: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict observable expectations for a batch of (R, t) points.

        Args:
            rt: (B, 2) tensor with [R, t] per sample.
            orb_energies: (B, n_orb_features) tensor of HF orbital energies.
                Required when frequencies are conditioned and
                ``config.n_orb_features > 0``; ignored otherwise.
            omega_op: operational frequency ceilings ω_op(R), shape (B,) or
                (B, 1); a single value broadcasts over the batch. Required
                when ``config.adaptive_bandwidth`` is True; ignored otherwise.

        Returns:
            (B, K) predicted observable expectations, K = ``config.n_observables``.

        Raises:
            ValueError: if a required ``omega_op`` or ``orb_energies`` is missing.
        """
        R = rt[:, 0:1]  # (B, 1)
        t = rt[:, 1:2]  # (B, 1)
        omega = self._frequencies(R, orb_energies, omega_op)
        phase = omega * t  # (B, n_fourier)
        fourier = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        x = torch.cat([R, fourier], dim=-1)  # (B, 1 + 2*n_fourier)
        return self.net(x)
def init_observable_regressor(
    n_observables: int = 120,
    d_hidden: int = 256,
    n_layers: int = 3,
    n_fourier: int = 64,
    fourier_scale: float = 10.0,
    conditioned_frequencies: bool = False,
    freq_net_hidden: int = 64,
    freq_net_layers: int = 2,
    n_orb_features: int = 0,
    adaptive_bandwidth: bool = False,
):
    """Build an ``ObservableRegressor`` from keyword hyperparameters.

    Each argument is forwarded unchanged to the same-named field of
    ``ObservableRegressorConfig``; the defaults here match that dataclass.

    Returns:
        A freshly initialised ``ObservableRegressor``.
    """
    config = ObservableRegressorConfig(
        n_observables=n_observables,
        d_hidden=d_hidden,
        n_layers=n_layers,
        n_fourier=n_fourier,
        fourier_scale=fourier_scale,
        conditioned_frequencies=conditioned_frequencies,
        freq_net_hidden=freq_net_hidden,
        freq_net_layers=freq_net_layers,
        n_orb_features=n_orb_features,
        adaptive_bandwidth=adaptive_bandwidth,
    )
    model = ObservableRegressor(config)
    return model