| """ |
| TD3B-specific MCTS modifications. |
| Extends the base MCTS to support directional rewards and confidence weighting. |
| """ |
|
|
| import logging |
|
| import numpy as np |
| import torch |
| from peptide_mcts import MCTS as BaseMCTS, dominated_by, dominates |
| from .td3b_scoring import TD3BRewardFunction, TD3BConfidenceWeighting |
|
|
|
|
| class TD3B_MCTS(BaseMCTS): |
| """ |
| TD3B version of MCTS that: |
| 1. Uses gated directional rewards instead of multi-objective scalarization |
| 2. Stores directional labels and confidence scores in the buffer |
| 3. Applies confidence-weighted importance sampling |
| """ |
|
|
| def __init__( |
| self, |
| args, |
| diffusion_model, |
| td3b_reward_function: TD3BRewardFunction, |
| confidence_weighting: TD3BConfidenceWeighting, |
| mask_index: int, |
| buffer_size: int = 100, |
| noise=None, |
| tokenizer=None |
| ): |
| """ |
| Args: |
| args: Configuration arguments |
| diffusion_model: MDLM model for sampling |
| td3b_reward_function: TD3BRewardFunction instance |
| confidence_weighting: TD3BConfidenceWeighting instance |
| mask_index: Token ID for masked positions |
| buffer_size: Maximum buffer size |
| noise: Noise schedule |
| tokenizer: Peptide tokenizer |
| """ |
| |
| |
| |
| # The base MCTS expects a config object exposing a `noise` namespace; |
| # build a minimal stub with a log-linear schedule rather than threading |
| # the full training config through this constructor. |
| class MinimalConfig: |
| def __init__(self): |
| self.noise = type('obj', (object,), { |
| 'type': 'loglinear', |
| 'sigma_min': 1e-4, |
| 'sigma_max': 20 |
| })() |
| config = MinimalConfig() |
|
|
| super().__init__( |
| args=args, |
| config=config, |
| policy_model=diffusion_model, |
| pretrained=diffusion_model, |
| # The base class assumes five objectives; pad with placeholders |
| # (see _td3b_reward_wrapper, which zero-fills the last three columns). |
| score_func_names=['affinity', 'gated_reward', 'placeholder1', 'placeholder2', 'placeholder3'] |
| ) |
|
|
| |
| self.td3b_reward_func = td3b_reward_function |
| self.confidence_weighting = confidence_weighting |
| self.mask_index = mask_index |
| self.buffer_size = buffer_size |
| self.noise = noise |
| self.tokenizer = tokenizer if tokenizer is not None else diffusion_model.tokenizer |
|
|
| |
| self.num_obj = 5  # matches the padded five-column score vectors |
|
|
| |
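| # Route the base class's reward hook through the TD3B wrapper so every |
| # child evaluation uses the gated directional reward. |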
| self.rewardFunc = self._td3b_reward_wrapper |
|
|
| def _td3b_reward_wrapper(self, input_seqs): |
| """ |
| Wrapper to make TD3BRewardFunction compatible with existing MCTS interface. |
| Returns (N, 5) array to match base MCTS expectations. |
| The 5 columns are: [affinity, gated_reward, 0, 0, 0] (padding last 3) |
| """ |
| total_rewards, info = self.td3b_reward_func(input_seqs) |
| |
|
|
| |
| # Cache confidences so updateBuffer() can read them for the same batch. |
| self._last_confidences = info['confidences'] |
|
|
| |
| |
| |
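| # Pad the (N, 2) [affinity, gated_reward] vectors out to the five columns |
| # the base MCTS expects; the trailing three columns stay zero. |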
| score_vectors = info['score_vectors'] |
| padded = np.zeros((score_vectors.shape[0], 5)) |
| padded[:, :2] = score_vectors |
|
|
| return padded |
|
|
| def updateBuffer(self, x_final, log_rnd, score_vectors, childSequences): |
| """ |
| TD3B version: stores directional labels and confidence scores. |
| |
| Args: |
| x_final: (B, L) final sequence tokens |
| log_rnd: (B,) log importance weights (trajectory-level) |
| score_vectors: (B, K) score arrays |
| childSequences: List of B SMILES strings |
| Returns: |
| traj_log_rnds: (B,) updated log importance weights |
| scalar_rewards: (B,) scalar rewards |
| """ |
| B = x_final.shape[0] |
| traj_log_rnds, scalar_rewards = [], [] |
|
|
| |
| # Fall back to neutral confidence if rewards were computed outside the wrapper. |
| confidences = getattr(self, '_last_confidences', np.ones(B)) |
|
|
| for i in range(B): |
| sv = np.asarray(score_vectors[i], dtype=float) |
| confidence = confidences[i] |
|
|
| |
| # Column 1 of the score vector is the gated reward (column 0 is affinity). |
| scalar_reward = float(sv[1]) |
|
|
| |
| |
| |
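| # Confidence-weighted importance weight, computed in log space: |
| #   log w_i = log_rnd_i + r_i / alpha + log(max(c_i, c_min)) |
| # so w_i ∝ exp(r_i / alpha) * max(c_i, c_min), and low-confidence |
| # directional labels are down-weighted multiplicatively. Illustrative |
| # numbers: with alpha = 0.1, r = 0.5, c = 0.8, the shift is |
| # 0.5 / 0.1 + log(0.8) ≈ 5.0 - 0.22 = 4.78. |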
| log_confidence = np.log(np.maximum(confidence, self.confidence_weighting.min_confidence)) |
| traj_log_rnd = log_rnd[i] + (scalar_reward / self.args.alpha) + log_confidence |
|
|
| |
| |
| |
| directional_label = float(np.sign(scalar_reward))  # np.sign(0) is already 0.0 |
|
|
| item = { |
| "x_final": x_final[i].clone(), |
| "log_rnd": traj_log_rnd.clone() if isinstance(traj_log_rnd, torch.Tensor) else torch.tensor(traj_log_rnd), |
| "final_reward": scalar_reward, |
| "score_vector": sv.copy(), |
| "seq": childSequences[i], |
| |
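| # TD3B-specific fields, read back out in consolidateBuffer(): |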
| "directional_label": directional_label, |
| "confidence": confidence, |
| } |
|
|
| |
| # Pareto filtering: reject the candidate if any buffered item dominates it. |
| if any(dominated_by(sv, bi["score_vector"]) for bi in self.buffer): |
| self._debug_buffer_decision(sv, "rejected_dominated") |
| continue |
|
|
| |
| # Evict buffered items that the new candidate dominates. |
| self.buffer = [bi for bi in self.buffer if not dominates(sv, bi["score_vector"])] |
|
|
| |
| if len(self.buffer) < self.buffer_size: |
| self.buffer.append(item) |
| else: |
| # Buffer is full: replace the entry with the lowest summed score vector. |
| worst_i = int(np.argmin([np.sum(bi["score_vector"]) for bi in self.buffer])) |
| self.buffer[worst_i] = item |
|
|
| self._debug_buffer_decision(sv, "inserted", {"new_len": len(self.buffer)}) |
|
|
| traj_log_rnds.append(traj_log_rnd) |
| scalar_rewards.append(scalar_reward) |
|
|
| if traj_log_rnds: |
| traj_log_rnds = torch.stack( |
| [x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in traj_log_rnds], dim=0 |
| ) |
| else: |
| traj_log_rnds = torch.empty(0) |
| scalar_rewards = np.asarray(scalar_rewards, dtype=float) |
| return traj_log_rnds, scalar_rewards |
|
|
| def forward(self, resetTree=False): |
| """ |
| TD3B version of forward that returns 7 values. |
| |
| Returns: |
| x_final: (N, L) sequence tokens |
| log_rnd: (N,) log importance weights |
| final_rewards: (N,) scalar rewards |
| score_vectors: (N, K) score arrays |
| sequences: List of N SMILES strings |
| directional_labels: (N,) directional labels |
| confidences: (N,) confidence scores |
| """ |
| self.reset(resetTree) |
|
|
| while self.iter_num < self.num_iter: |
| self.iter_num += 1 |
|
|
| |
| with self.timer.section("select"): |
| leafNode, _ = self.select(self.rootNode) |
|
|
| |
| with self.timer.section("expand"): |
| self.expand(leafNode) |
|
|
| final_x, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = self.consolidateBuffer() |
|
|
| rows = self.timer.summary() |
| print("\n=== Timing summary (by total time) ===") |
| for name, cnt, total, mean, p50, p95 in rows: |
| print(f"{name:30s} n={cnt:5d} total={total:8.3f}s mean={mean*1e3:7.2f}ms " |
| f"p50={p50*1e3:7.2f}ms p95={p95*1e3:7.2f}ms") |
|
|
| return final_x, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences |
|
|
| def consolidateBuffer(self): |
| """ |
| TD3B version: includes directional labels and confidences. |
| |
| Returns: |
| x_final: (N, L) sequence tokens |
| log_rnd: (N,) log importance weights |
| final_rewards: (N,) scalar rewards |
| score_vectors: (N, K) score arrays |
| sequences: List of N SMILES strings |
| directional_labels: (N,) directional labels |
| confidences: (N,) confidence scores |
| """ |
| |
| if len(self.buffer) == 0: |
| logging.getLogger(__name__).warning( |
| "MCTS buffer is empty - no valid sequences found. Returning empty results." |
| ) |
|
|
| |
| |
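| # Return correctly-typed empty containers so callers can simply check |
| # len(sequences) == 0 instead of special-casing shapes. |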
| device = self.policy_model.device if hasattr(self.policy_model, 'device') else 'cpu' |
| return ( |
| torch.empty(0, 0, dtype=torch.long, device=device), |
| torch.empty(0, dtype=torch.float32, device=device), |
| np.empty(0, dtype=np.float32), |
| np.empty((0, 0), dtype=np.float32), |
| [], |
| np.empty(0, dtype=np.float32), |
| np.empty(0, dtype=np.float32) |
| ) |
|
|
| x_final = [] |
| log_rnd = [] |
| final_rewards = [] |
| score_vectors = [] |
| sequences = [] |
| directional_labels = [] |
| confidences = [] |
|
|
| for item in self.buffer: |
| x_final.append(item["x_final"]) |
| log_rnd.append(item["log_rnd"]) |
| final_rewards.append(item["final_reward"]) |
| score_vectors.append(item["score_vector"]) |
| sequences.append(item["seq"]) |
| directional_labels.append(item.get("directional_label", 0.0)) |
| confidences.append(item.get("confidence", 1.0)) |
|
|
| x_final = torch.stack(x_final, dim=0) |
| log_rnd = torch.stack(log_rnd, dim=0).to(dtype=torch.float32) |
| final_rewards = np.asarray(final_rewards, dtype=np.float32) |
| score_vectors = np.stack(score_vectors, axis=0).astype(np.float32) |
| directional_labels = np.array(directional_labels, dtype=np.float32) |
| confidences = np.array(confidences, dtype=np.float32) |
|
|
| return x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences |
|
|
|
|
| def create_td3b_mcts( |
| args, |
| diffusion_model, |
| td3b_reward_function: TD3BRewardFunction, |
| alpha: float = 0.1, |
| **kwargs |
| ) -> TD3B_MCTS: |
| """ |
| Factory function to create TD3B MCTS instance. |
| |
| Args: |
| args: Configuration arguments |
| diffusion_model: MDLM model |
| td3b_reward_function: TD3BRewardFunction instance |
| alpha: Temperature for importance weighting |
| **kwargs: Additional MCTS arguments |
| |
| Returns: |
| mcts: TD3B_MCTS instance |
| """ |
| |
| confidence_weighting = TD3BConfidenceWeighting( |
| alpha=alpha, |
| min_confidence=0.1 |
| ) |
|
|
| |
| mcts = TD3B_MCTS( |
| args=args, |
| diffusion_model=diffusion_model, |
| td3b_reward_function=td3b_reward_function, |
| confidence_weighting=confidence_weighting, |
| **kwargs |
| ) |
|
|
| return mcts |
|
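| # Minimal wiring sketch (hypothetical `args`, `model`, and `reward_fn`; the |
| # real setup lives in the surrounding training script): |
| # |
| #     reward_fn = TD3BRewardFunction(...)  # project-specific construction |
| #     mcts = create_td3b_mcts( |
| #         args, model, reward_fn, alpha=0.1, |
| #         mask_index=model.tokenizer.mask_token_id,  # assumed tokenizer attribute |
| #         buffer_size=100, |
| #     ) |
| #     (x_final, log_rnd, rewards, score_vecs, |
| #      seqs, labels, confs) = mcts.forward(resetTree=True) |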
|