"""
Direct observable regression: f_θ(R, t) → {⟨Γ_μ⟩(R, t)}.

Bypasses the shadow/Q-conditioning pipeline entirely. Predicts signal matrix
entries directly, using learnable Fourier features for the time encoding.

Three frequency variants:

- Shared frequencies (v2/v3): ω_k are global learnable parameters.
- Geometry-conditioned (v4+): ω_k(R) = ω_k^{(0)} + g_φ(R)_k, where g_φ is a
  small MLP. Energy gaps depend on R, so the optimal Fourier basis should too.
- Adaptive bandwidth: ω_k(R) = ω_op(R) · sigmoid(g_φ(R))_k, which keeps every
  frequency inside (0, ω_op(R)) for a caller-supplied ceiling ω_op(R).

In the conditioned and adaptive variants g_φ takes either the scalar R or,
when ``n_orb_features > 0``, a vector of HF orbital energies.
"""

from __future__ import annotations

from dataclasses import asdict, dataclass

import numpy as np
import torch
import torch.nn as nn


@dataclass
class ObservableRegressorConfig:
    """Hyperparameters for :class:`ObservableRegressor`."""

    n_observables: int = 120
    d_hidden: int = 256
    n_layers: int = 3  # hidden Linear+GELU layers in the main MLP (values < 1 act as 1)
    n_fourier: int = 64
    fourier_scale: float = 10.0  # upper end of the log-uniform ω_base initialisation
    conditioned_frequencies: bool = False
    freq_net_hidden: int = 64
    freq_net_layers: int = 2  # Linear layers in freq_net (values < 2 act as 2)
    n_orb_features: int = 0  # 0 = use scalar R; >0 = use HF orbital energies
    adaptive_bandwidth: bool = False  # ω_k(R) = ω_op(R) · sigmoid(freq_net(ε(R)))_k

    def to_dict(self) -> dict:
        """Return the config as a plain dict (e.g. for checkpoint metadata)."""
        return asdict(self)


class ObservableRegressor(nn.Module):
    """Direct regression: (R, t) → K observable expectations.

    Architecture:
        1. Fourier features for t: sin(ω_k(R) · t), cos(ω_k(R) · t)
           - Shared mode:      ω_k = ω_base_k (global learnable parameters)
           - Conditioned mode: ω_k(R) = ω_base_k + freq_net(x)_k
           - Adaptive mode:    ω_k(R) = ω_op(R) · sigmoid(freq_net(x))_k
           where x is R, or the HF orbital energies when n_orb_features > 0.
        2. Input = [R, Fourier features] ∈ R^{1 + 2*n_fourier}
        3. MLP with GELU activations → K outputs
    """

    def __init__(self, config: ObservableRegressorConfig):
        super().__init__()
        self.config = config

        # Base frequencies — log-uniform initialisation for broad coverage,
        # spanning ~0.05 to ~fourier_scale Eₕ. The parameter is unused in
        # adaptive_bandwidth mode but always registered, so state_dict keys
        # are identical across modes.
        log_omega = torch.linspace(
            np.log(0.05), np.log(config.fourier_scale), config.n_fourier
        )
        self.omega_base = nn.Parameter(log_omega.exp())

        if config.adaptive_bandwidth and not config.conditioned_frequencies:
            raise ValueError("adaptive_bandwidth requires conditioned_frequencies=True")

        # Geometry-conditioned frequency shift (also used as the σ_k head in
        # adaptive_bandwidth mode). Built before `net` so the module
        # registration order (and hence seeded initialisation) is unchanged.
        if config.conditioned_frequencies:
            self.freq_net = self._build_freq_net(config)
        else:
            self.freq_net = None

        input_dim = 1 + 2 * config.n_fourier
        layers = [nn.Linear(input_dim, config.d_hidden), nn.GELU()]
        for _ in range(config.n_layers - 1):
            layers.extend([nn.Linear(config.d_hidden, config.d_hidden), nn.GELU()])
        layers.append(nn.Linear(config.d_hidden, config.n_observables))
        self.net = nn.Sequential(*layers)

        self._init_weights()

    @staticmethod
    def _build_freq_net(config: ObservableRegressorConfig) -> nn.Sequential:
        """Build the MLP g_φ mapping R (or orbital energies) to n_fourier values."""
        freq_in = config.n_orb_features if config.n_orb_features > 0 else 1
        fn_layers = [nn.Linear(freq_in, config.freq_net_hidden), nn.GELU()]
        for _ in range(config.freq_net_layers - 2):
            fn_layers.extend([
                nn.Linear(config.freq_net_hidden, config.freq_net_hidden),
                nn.GELU(),
            ])
        fn_layers.append(nn.Linear(config.freq_net_hidden, config.n_fourier))
        return nn.Sequential(*fn_layers)

    def _init_weights(self) -> None:
        """Xavier-initialise all weight matrices, then set freq_net's output head."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        if self.freq_net is None:
            return

        head = self.freq_net[-1]
        # Zero weights make freq_net's initial output equal to its bias for
        # every input, so the starting frequencies are input-independent.
        nn.init.zeros_(head.weight)
        if self.config.adaptive_bandwidth:
            # Spread initial σ_k across (0, 1) so initial ω_k tile [0, ω_op]
            # rather than collapsing at ω_op/2.
            init_sigma = torch.linspace(0.05, 0.95, self.config.n_fourier)
            init_logits = torch.log(init_sigma / (1.0 - init_sigma))
            with torch.no_grad():
                head.bias.copy_(init_logits)
        else:
            # Original behaviour: initial ω ≈ ω_base.
            nn.init.zeros_(head.bias)

    def _freq_net_input(
        self, R: torch.Tensor, orb_energies: torch.Tensor | None
    ) -> torch.Tensor:
        """Select freq_net's input: orbital energies when configured, else R.

        Raises:
            ValueError: if the model was configured with n_orb_features > 0
                but no orbital energies were supplied. (Feeding R instead
                would be the wrong quantity and, for n_orb_features > 1,
                a shape mismatch.)
        """
        if self.config.n_orb_features > 0:
            if orb_energies is None:
                raise ValueError(
                    f"n_orb_features={self.config.n_orb_features} "
                    "requires orb_energies input"
                )
            return orb_energies
        return R

    def forward(
        self,
        rt: torch.Tensor,
        orb_energies: torch.Tensor | None = None,
        omega_op: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            rt: (B, 2) tensor with [R, t] per sample.
            orb_energies: (B, n_orb) tensor of HF orbital energies, or None.
                Required when config.n_orb_features > 0 and freq_net is used.
            omega_op: operational frequency ceilings ω_op(R), required when
                config.adaptive_bandwidth is True. Accepts shape (B,), (B, 1)
                or a scalar; it is cast to the model's dtype/device.

        Returns:
            (B, K) predicted observable expectations.

        Raises:
            ValueError: if a required input (omega_op / orb_energies) is
                missing or omega_op does not match the batch size.
        """
        R = rt[:, 0:1]  # (B, 1)
        t = rt[:, 1:2]  # (B, 1)

        if self.config.adaptive_bandwidth:
            if omega_op is None:
                raise ValueError("adaptive_bandwidth requires omega_op input")
            x = self._freq_net_input(R, orb_energies)
            sigma = torch.sigmoid(self.freq_net(x))  # (B, n_fourier) in (0, 1)
            # Normalise to a (B, 1) column: a (B, 1) input would otherwise
            # broadcast against sigma to (B, B, n_fourier), and a float64
            # input would promote the features past the float32 MLP.
            ceiling = torch.as_tensor(
                omega_op, dtype=sigma.dtype, device=sigma.device
            ).reshape(-1, 1)
            if ceiling.shape[0] not in (1, sigma.shape[0]):
                raise ValueError(
                    f"omega_op has {ceiling.shape[0]} entries for a batch of "
                    f"{sigma.shape[0]}"
                )
            omega = ceiling * sigma  # (B, n_fourier) in (0, ω_op(R))
        elif self.freq_net is not None:
            x = self._freq_net_input(R, orb_energies)
            omega = self.omega_base[None, :] + self.freq_net(x)  # (B, n_fourier)
        else:
            omega = self.omega_base[None, :]  # (1, n_fourier)

        phase = omega * t  # (B, n_fourier)
        fourier = torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)  # (B, 2*n_fourier)
        return self.net(torch.cat([R, fourier], dim=-1))


def init_observable_regressor(
    n_observables: int = 120,
    d_hidden: int = 256,
    n_layers: int = 3,
    n_fourier: int = 64,
    fourier_scale: float = 10.0,
    conditioned_frequencies: bool = False,
    freq_net_hidden: int = 64,
    freq_net_layers: int = 2,
    n_orb_features: int = 0,
    adaptive_bandwidth: bool = False,
) -> ObservableRegressor:
    """Build an :class:`ObservableRegressor` from keyword hyperparameters.

    Each argument maps one-to-one onto the field of the same name in
    :class:`ObservableRegressorConfig`.
    """
    return ObservableRegressor(
        ObservableRegressorConfig(
            n_observables=n_observables,
            d_hidden=d_hidden,
            n_layers=n_layers,
            n_fourier=n_fourier,
            fourier_scale=fourier_scale,
            conditioned_frequencies=conditioned_frequencies,
            freq_net_hidden=freq_net_hidden,
            freq_net_layers=freq_net_layers,
            n_orb_features=n_orb_features,
            adaptive_bandwidth=adaptive_bandwidth,
        )
    )