import numpy as np
import torch
from tqdm import tqdm

from .bias import BiasForceTransformer
from .utils.utils import kabsch


class EntangledSBM:
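    """A learned bias force steers MD rollouts toward a target structure; the
    bias policy is trained against a target path measure using trajectories
    drawn from a ranked replay buffer."""
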
    def __init__(self, args, mds):
        self.bias_net = BiasForceTransformer(mds, args)
        self.target_measure = PathObjective(args, mds)

        if args.training:
            self.replay = ReplayBuffer(args, mds)

        self.rollout_idx = 0

    def increment_rollout(self):
        self.rollout_idx += 1

    def sample(self, args, mds, temperature):
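        """Roll out `args.num_samples` biased MD trajectories for
        `args.num_steps` steps, score them under the target path measure,
        save them to disk, and (when training) push them into the ranked
        replay buffer."""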

        positions = torch.zeros(
            (args.num_samples, args.num_steps + 1, mds.num_particles, 3),
            device=args.device,
        )
        forces = torch.zeros(
            (args.num_samples, args.num_steps + 1, mds.num_particles, 3),
            device=args.device,
        )

        # Record the initial state before resetting the integrator.
        position, force = mds.report()
        positions[:, 0] = position.detach().clone()
        forces[:, 0] = force.detach().clone()
        mds.reset()

        mds.set_temperature(temperature)
        prev_position = position.detach().clone()

        for step in tqdm(range(1, args.num_steps + 1), desc="Sampling"):
            # Finite-difference velocity estimate; zero at the first step.
            if step == 1:
                velocity = torch.zeros_like(position)
            else:
                velocity = (position - prev_position) / args.timestep

            bias_force = self.bias_net(
                position.detach().clone(),
                velocity.detach().clone(),
                mds.target_position,
            ).detach()

            mds.step(bias_force)
            position, force = mds.report()

            if not _is_finite(position, force):
                print(
                    f"MD produced non-finite values at step {step}: "
                    f"position nan/inf {torch.isnan(position).sum().item()}/{torch.isinf(position).sum().item()}, "
                    f"force nan/inf {torch.isnan(force).sum().item()}/{torch.isinf(force).sum().item()}"
                )
                # Keep the last finite position and stop the rollout early; the
                # raw (non-finite) force is stored so the validity mask in
                # add_ranked() can drop this trajectory.
                positions[:, step] = prev_position
                forces[:, step] = force
                break

            prev_position = position.detach().clone()

            positions[:, step] = position
            # Subtract the scaled bias so `forces` stores the unbiased force;
            # the 1e-6 factor matches the scaling applied in train().
            forces[:, step] = force - 1e-6 * bias_force

        mds.reset()
        log_tpm, final_idx, log_ri = self.target_measure(positions, forces)

        if args.training:
            self.replay.add_ranked((positions, forces, log_tpm), score=log_ri)

        # Save each trajectory up to the frame that best matches the target.
        for i in range(args.num_samples):
            np.save(
                f"{args.save_dir}/positions/{i}.npy",
                positions[i][: final_idx[i] + 1].cpu().numpy(),
            )

    def train(self, args, mds):
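        """Run `args.trains_per_rollout` gradient steps on replay-buffer
        batches and return the mean loss together with the last sampled
        batch of positions."""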
        # log_z (the learned control variate) gets its own learning rate;
        # the remaining bias-network parameters share the policy learning rate.
        exclude = {id(self.bias_net.log_z)}
        params_except = [p for p in self.bias_net.parameters() if id(p) not in exclude]
        optimizer = torch.optim.Adam(
            [
                {"params": [self.bias_net.log_z], "lr": args.log_z_lr},
                {"params": params_except, "lr": args.policy_lr},
            ]
        )
        loss_sum = 0

        for _ in tqdm(range(args.trains_per_rollout), desc="Training"):
            positions, forces, log_tpm, log_ri = self.replay.sample()

            assert positions.shape == forces.shape, f"{positions.shape=} != {forces.shape=}"
            velocities = (positions[:, 1:] - positions[:, :-1]) / args.timestep

            # Flatten (batch, time) so the bias network sees one frame at a time.
            biases = 1e-6 * self.bias_net(
                positions[:, :-1].reshape(-1, positions.size(-2), positions.size(-1)),
                velocities.view(-1, velocities.size(-2), velocities.size(-1)),
                mds.target_position,
            )
            biases = biases.view(*velocities.shape)

            # Mean of the next velocity under the discretized biased Langevin
            # dynamics: v' = (1 - friction * dt) * v + (dt / m) * (F + F_bias).
            means = (
                1 - args.friction * args.timestep
            ) * velocities + args.timestep / mds.m * (forces[:, :-1] + biases)

            resid = _sanitize(velocities[:, 1:] - means[:, :-1])
            log_bpm = mds.log_prob(resid).mean((1, 2, 3))

            if args.control_variate == "global":
                log_z = self.bias_net.log_z
            elif args.control_variate == "local":
                log_z = (log_tpm - log_bpm).mean().detach()
            elif args.control_variate == "zero":
                log_z = 0
            else:
                raise ValueError(f"unknown control variate: {args.control_variate}")

            if args.objective == "ce":
                # Cross-entropy objective: importance weights from the detached
                # log Radon-Nikodym derivative between target and policy measures.
                log_rnd = (log_tpm - log_bpm).detach()
                weights = torch.softmax(log_rnd, dim=0)

                if args.control_cost:
                    control_cost = 0.5 * args.timestep * biases[:, :-1].square().sum((-1, -2, -3)).mean()
                    loss = -(weights * log_bpm).sum() + control_cost
                else:
                    loss = -(weights * log_bpm).sum()
            elif args.objective == "lv":
                # Log-variance objective around the control variate log_z.
                loss = (log_z + log_bpm - log_tpm).square().mean()
            else:
                raise ValueError(f"unknown objective: {args.objective}")

            optimizer.zero_grad()
            loss.backward()

            # Clip each parameter group against its own norm so log_z and the
            # policy parameters are clipped independently.
            for group in optimizer.param_groups:
                torch.nn.utils.clip_grad_norm_(group["params"], args.max_grad_norm)

            optimizer.step()
            loss_sum += loss.item()

        loss = loss_sum / args.trains_per_rollout
        return loss, positions


class ReplayBuffer:
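    """Fixed-capacity trajectory buffer. `add` overwrites in FIFO (ring-buffer)
    order, while `add_ranked` keeps only the highest-scoring trajectories."""
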
    def __init__(self, args, mds):
        self.positions = torch.zeros(
            (args.buffer_size, args.num_steps + 1, mds.num_particles, 3),
            device=args.device,
        )
        self.forces = torch.zeros(
            (args.buffer_size, args.num_steps + 1, mds.num_particles, 3),
            device=args.device,
        )
        self.log_tpm = torch.zeros(args.buffer_size, device=args.device)
        self.idx = 0

        self.device = args.device
        self.batch_size = args.batch_size
        self.num_samples = args.num_samples
        self.buffer_size = args.buffer_size
        self.args = args

        self.scores = torch.zeros(args.buffer_size, device=args.device)
        self.count = 0

    def add(self, data):
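        """Insert a batch at the current write position, wrapping indices
        around the ring buffer (FIFO overwrite)."""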
        pos_batch, force_batch, tpm_batch = data
        newN = pos_batch.size(0)

        indices = torch.arange(self.idx, self.idx + newN, device=self.device) % self.buffer_size
        self.idx = (self.idx + newN) % self.buffer_size

        self.positions[indices] = pos_batch.detach().to(self.device).clone()
        self.forces[indices] = force_batch.detach().to(self.device).clone()
        self.log_tpm[indices] = tpm_batch.detach().to(self.device).clone()

        self.count = min(self.count + newN, self.buffer_size)

    @torch.no_grad()
    def add_ranked(self, data, score=None):
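        """Merge new trajectories with the current contents and keep only the
        top-scoring `buffer_size` entries; non-finite trajectories are dropped."""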
        positions, forces, log_tpm = data
        if score is None:
            score = log_tpm

        positions, forces, log_tpm, score = (
            positions.clone().detach(),
            forces.clone().detach(),
            log_tpm.clone().detach(),
            score.clone().detach(),
        )

        # Drop trajectories containing NaN/Inf entries; if nothing survives,
        # leave the buffer unchanged.
        valid = (
            torch.isfinite(positions).all((1, 2, 3))
            & torch.isfinite(forces).all((1, 2, 3))
            & torch.isfinite(log_tpm)
        )
        if not valid.any():
            return
        positions = positions[valid]
        forces = forces[valid]
        log_tpm = log_tpm[valid]
        score = score[valid]

        curr = self.count
        newN = positions.size(0)
        keepN = min(self.buffer_size, curr + newN)

        if curr > 0:
            pos_cat = torch.cat([self.positions[:curr], positions], dim=0)
            force_cat = torch.cat([self.forces[:curr], forces], dim=0)
            tpm_cat = torch.cat([self.log_tpm[:curr], log_tpm], dim=0)
            sco_cat = torch.cat([self.scores[:curr], score], dim=0)
        else:
            pos_cat, force_cat, tpm_cat, sco_cat = positions, forces, log_tpm, score

        # Retain the keepN highest-scoring trajectories.
        top_vals, top_idx = torch.topk(sco_cat, k=keepN, largest=True, sorted=False)

        self.positions[:keepN] = pos_cat.index_select(0, top_idx)
        self.forces[:keepN] = force_cat.index_select(0, top_idx)
        self.log_tpm[:keepN] = tpm_cat.index_select(0, top_idx)
        self.scores[:keepN] = top_vals
        self.count = keepN

    def sample(self):
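        """Draw a training batch, either uniformly or proportionally to
        softmax(scores) when importance sampling is enabled."""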
        assert self.count > 0, "buffer is empty"
        if self.args.importance_sample:
            idx = torch.multinomial(
                torch.softmax(self.scores[: self.count], 0),
                num_samples=self.batch_size,
                replacement=True,
            )
        else:
            idx = torch.randint(0, self.count, (self.batch_size,), device=self.device)

        return (
            self.positions[idx].clone().detach(),
            self.forces[idx].clone().detach(),
            self.log_tpm[idx].clone().detach(),
            self.scores[idx].clone().detach(),
        )


class PathObjective:
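    """Target path measure: the log-density of a trajectory under the unbiased
    Langevin dynamics plus a relaxed (RBF) indicator of reaching the target."""
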
    def __init__(self, args, mds):
        self.sigma = args.sigma
        self.timestep = args.timestep
        self.friction = args.friction
        self.heavy_atoms = mds.heavy_atoms
        self.target_position = mds.target_position
        self.m = mds.m
        self.log_prob = mds.log_prob

    def __call__(self, positions, forces):
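        """Combine the unbiased path measure with the relaxed target indicator;
        also return the index of the best-matching frame per trajectory."""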
        log_upm = self.unbiased_path_measure(positions, forces)
        log_ri, final_idx = self.relaxed_indicator(positions, self.target_position)
        log_tpm = log_upm + log_ri
        return log_tpm, final_idx, log_ri

    def unbiased_path_measure(self, positions, forces):
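        """Per-trajectory log-density of the velocity residuals under the
        unbiased discretized Langevin dynamics (no bias force)."""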
        velocities = (positions[:, 1:] - positions[:, :-1]) / self.timestep

        # Mean of the next velocity under the unbiased dynamics:
        # v' = (1 - friction * dt) * v + (dt / m) * F.
        means = (
            1 - self.friction * self.timestep
        ) * velocities + self.timestep / self.m * forces[:, :-1]

        resid = _sanitize(velocities[:, 1:] - means[:, :-1])
        log_upm = self.log_prob(resid).mean((1, 2, 3))
        return log_upm

    def relaxed_indicator(self, positions, target_position):
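        """For each trajectory, find the frame closest to the target over the
        heavy atoms and return its RBF log-score and frame index."""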
        positions = positions[:, :, self.heavy_atoms]
        target_position = target_position[:, self.heavy_atoms]

        log_ri = torch.zeros(positions.size(0), device=positions.device)
        final_idx = torch.zeros(
            positions.size(0), device=positions.device, dtype=torch.long
        )
        for i in range(positions.size(0)):
            # max over time picks the best-matching frame per trajectory.
            log_ri[i], final_idx[i] = self.rbf(positions[i], target_position).max(0)
        return log_ri, final_idx

    def rbf(self, positions, target_position):
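        """Gaussian (RBF) log-similarity between each frame and the target
        after optimal rigid alignment via the Kabsch algorithm."""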
        R, t = kabsch(positions, target_position)
        positions = torch.matmul(positions, R.transpose(-2, -1)) + t
        log_ri = (
            -0.5 / self.sigma**2 * (positions - target_position).square().mean((-2, -1))
        )
        return log_ri


def _is_finite(*tensors):
    """Return True iff every entry of every tensor is finite."""
    return all(torch.isfinite(t).all().item() for t in tensors)


def _sanitize(t, max_abs=1e6):
    """Replace NaN/Inf entries and clamp magnitudes to keep values numerically safe."""
    t = torch.nan_to_num(t, nan=0.0, posinf=max_abs, neginf=-max_abs)
    return torch.clamp(t, min=-max_abs, max=max_abs)