| """ |
| TD3B Finetuning Loop |
Extends TR2-D2 finetuning with a contrastive loss over directional embeddings, directional rewards, and KL regularization against a frozen reference model.
| """ |
|
|
| import numpy as np |
| import torch |
| import wandb |
| import os |
| from finetune_utils import loss_wdce |
| from .td3b_losses import TD3BTotalLoss, extract_embeddings_from_mdlm |
| from tqdm import tqdm |
| import pandas as pd |
| from plotting import plot_data_with_distribution_seaborn, plot_data |
|
|
|
|
| def td3b_finetune( |
| args, |
| cfg, |
| policy_model, |
| reward_model, |
| mcts=None, |
| pretrained=None, |
| filename=None, |
| prot_name=None, |
| eps=1e-5, |
| contrastive_weight=0.1, |
| contrastive_margin=1.0, |
| contrastive_type='margin', |
| embedding_pool_method='mean', |
| kl_beta=0.1 |
| ): |
| """ |
| TD3B finetuning with combined WDCE + contrastive loss + KL regularization. |
| |
| Args: |
| args: Configuration arguments |
| cfg: Hydra config |
| policy_model: Policy model (MDLM) |
| reward_model: Reward scoring functions (TD3BRewardFunction) |
| mcts: TD3B_MCTS instance |
| pretrained: Pretrained model (for no-MCTS mode) |
| filename: Output filename |
| prot_name: Target protein name |
eps: Small epsilon used as the terminal time of the diffusion schedule (keeps t away from 0)
| contrastive_weight: λ for contrastive loss |
| contrastive_margin: Margin for margin-based contrastive loss |
| contrastive_type: 'margin' or 'infonce' |
| embedding_pool_method: 'mean', 'max', or 'cls' |
| kl_beta: β coefficient for KL divergence regularization |
| Returns: |
| batch_losses: List of training losses |
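
Example (illustrative; the args/cfg objects and model instances come from the
surrounding training scripts, so the names below are assumptions):

    losses = td3b_finetune(
        args, cfg,
        policy_model=policy_model,
        reward_model=reward_model,
        mcts=mcts,
        filename="run0",
        prot_name="target",
        contrastive_weight=0.1,
        kl_beta=0.1,
    )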
| """ |
| base_path = args.base_path |
| dt = (1 - eps) / args.total_num_steps |
|
|
| if args.no_mcts: |
assert pretrained is not None, "a pretrained model is required when args.no_mcts is set"
| else: |
assert mcts is not None, "an MCTS instance is required when MCTS search is enabled (args.no_mcts is False)"
|
|
| |
| |
| |
| print("[TD3B] Creating reference model for KL regularization...") |
|
|
| |
| from diffusion import Diffusion |
|
|
| |
| reference_model = Diffusion( |
| config=policy_model.config, |
| tokenizer=policy_model.tokenizer, |
| mode="eval", |
| device=policy_model.device if hasattr(policy_model, 'device') else args.device |
| ) |
|
|
| |
| device = policy_model.device if hasattr(policy_model, 'device') else args.device |
| if device is None: |
| device = next(policy_model.parameters()).device |
|
|
| |
| |
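# Clone each tensor so the reference model gets independent copies of the weights rather
# than views into the policy parameters (the assertion below re-checks this).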
| state_dict_copy = { |
| key: value.clone().detach() |
| for key, value in policy_model.state_dict().items() |
| } |
| reference_model.load_state_dict(state_dict_copy) |
|
|
| |
| reference_model = reference_model.to(device) |
|
|
| |
| reference_model.eval() |
| for param in reference_model.parameters(): |
| param.requires_grad = False |
|
|
| print(f"[TD3B] Reference model frozen with {sum(p.numel() for p in reference_model.parameters())} parameters") |
| print(f"[TD3B] Reference model on device: {device}") |
|
|
| |
| policy_params = {id(p) for p in policy_model.parameters()} |
| ref_params = {id(p) for p in reference_model.parameters()} |
| assert len(policy_params.intersection(ref_params)) == 0, \ |
| "ERROR: Reference model shares parameters with policy model!" |
| print("[TD3B] ✓ Verified: No parameter sharing between policy and reference model") |
|
|
| |
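# Combined objective (see td3b_losses.TD3BTotalLoss for the exact formulation):
# roughly WDCE + contrastive_weight * contrastive + kl_beta * KL, matching the terms logged below.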
| td3b_loss_fn = TD3BTotalLoss( |
| contrastive_weight=contrastive_weight, |
| contrastive_margin=contrastive_margin, |
| contrastive_type=contrastive_type, |
| kl_beta=kl_beta, |
| reference_model=reference_model |
| ) |
|
|
| |
| policy_model.train() |
| torch.set_grad_enabled(True) |
| optim = torch.optim.AdamW(policy_model.parameters(), lr=args.learning_rate) |
|
|
| |
| batch_losses = [] |
| batch_wdce_losses = [] |
| batch_contrastive_losses = [] |
| batch_kl_losses = [] |
|
|
| |
| x_saved, log_rnd_saved, final_rewards_saved = None, None, None |
| directional_labels_saved, confidences_saved = None, None |
|
|
| |
| valid_fraction_log = [] |
| affinity_log = [] |
| gated_reward_log = [] |
| confidence_log = [] |
| direction_prediction_log = [] |
| consistency_reward_log = [] |
|
|
| |
| pbar = tqdm(range(args.num_epochs)) |
|
|
| for epoch in pbar: |
| rewards = [] |
| losses = [] |
|
|
| policy_model.train() |
|
|
| with torch.no_grad(): |
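# Search / sampling phase (no gradients): fresh trajectories are drawn every
# resample_every_n_step epochs and cached; intermediate epochs reuse the cached batch.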
| if x_saved is None or epoch % args.resample_every_n_step == 0: |
| |
| if args.no_mcts: |
| |
| x_final, log_rnd, final_rewards = policy_model.sample_finetuned_with_rnd( |
| args, reward_model, pretrained |
| ) |
| directional_labels = torch.zeros(x_final.size(0), dtype=torch.float32) |
| confidences = torch.ones(x_final.size(0), dtype=torch.float32) |
| else: |
| |
| |
| if hasattr(args, 'target_direction') and args.target_direction == 'both': |
| print(f"[Dual-direction] Epoch {epoch}: Sampling BOTH agonist and antagonist binders") |
|
|
| |
| reward_model.target_direction = 1.0 |
| if epoch % args.reset_every_n_step == 0: |
| results_agonist = mcts.forward(resetTree=True) |
| else: |
| results_agonist = mcts.forward(resetTree=False) |
|
|
| |
| reward_model.target_direction = -1.0 |
| |
| results_antagonist = mcts.forward(resetTree=False) |
|
|
| |
| if len(results_agonist) == 7 and len(results_antagonist) == 7: |
| x_agonist, log_rnd_agonist, rewards_agonist, _, _, labels_agonist, conf_agonist = results_agonist |
| x_antagonist, log_rnd_antagonist, rewards_antagonist, _, _, labels_antagonist, conf_antagonist = results_antagonist |
|
|
| |
# Force the labels to match the direction each half of the batch was sampled under.
labels_agonist = torch.full((x_agonist.size(0),), 1.0, dtype=torch.float32)
labels_antagonist = torch.full((x_antagonist.size(0),), -1.0, dtype=torch.float32)
|
|
| |
| x_final = torch.cat([x_agonist, x_antagonist], dim=0) |
| log_rnd = torch.cat([log_rnd_agonist, log_rnd_antagonist], dim=0) |
| final_rewards = np.concatenate([rewards_agonist, rewards_antagonist], axis=0) |
| directional_labels = torch.cat([labels_agonist, labels_antagonist], dim=0) |
| confidences = torch.cat([ |
| conf_agonist if isinstance(conf_agonist, torch.Tensor) else torch.tensor(conf_agonist), |
| conf_antagonist if isinstance(conf_antagonist, torch.Tensor) else torch.tensor(conf_antagonist) |
| ], dim=0) |
|
|
| print(f" → Combined batch: {x_agonist.size(0)} agonists + {x_antagonist.size(0)} antagonists = {x_final.size(0)} total") |
| print(f" → Directional labels: {torch.unique(directional_labels).tolist()} (DIVERSITY CONFIRMED!)") |
| else: |
| raise ValueError("Dual-direction mode requires 7-value return from MCTS") |
| else: |
| |
| if epoch % args.reset_every_n_step == 0: |
| results = mcts.forward(resetTree=True) |
| else: |
| results = mcts.forward(resetTree=False) |
|
|
| |
| if len(results) == 7: |
| x_final, log_rnd, final_rewards, score_vectors, sequences, directional_labels, confidences = results |
| |
| if not isinstance(directional_labels, torch.Tensor): |
| directional_labels = torch.tensor(directional_labels, dtype=torch.float32) |
| if not isinstance(confidences, torch.Tensor): |
| confidences = torch.tensor(confidences, dtype=torch.float32) |
| else: |
| |
| x_final, log_rnd, final_rewards, score_vectors, sequences = results |
| directional_labels = torch.zeros(x_final.size(0), dtype=torch.float32) |
| confidences = torch.ones(x_final.size(0), dtype=torch.float32) |
|
|
| |
| x_saved = x_final |
| log_rnd_saved = log_rnd |
| final_rewards_saved = final_rewards |
| directional_labels_saved = directional_labels |
| confidences_saved = confidences |
| else: |
| |
| x_final = x_saved |
| log_rnd = log_rnd_saved |
| final_rewards = final_rewards_saved |
| directional_labels = directional_labels_saved |
| confidences = confidences_saved |
|
|
| |
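# WDCE term from TR2-D2: denoising cross-entropy over the searched sequences,
# re-weighted by the per-trajectory log weights in log_rnd.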
| wdce_loss = loss_wdce( |
| policy_model, |
| log_rnd, |
| x_final, |
| num_replicates=args.wdce_num_replicates, |
| centering=args.centering |
| ) |
|
|
| |
| |
| mask_index = policy_model.mask_index |
| device = x_final.device |
|
|
| |
# Sample a per-sequence masking level lamda and the matching noise scale
# sigma = -log(1 - (1 - eps) * lamda) expected by the KL term.
lamda = torch.rand(x_final.shape[0], device=device)
sigma_kl = -torch.log1p(-(1 - eps) * lamda)

# Re-mask x_final at that level; the KL loss compares policy and frozen
# reference predictions on this perturbed batch.
masked_index = torch.rand(*x_final.shape, device=device) < lamda[..., None]
perturbed_batch = torch.where(masked_index, mask_index, x_final)
attn_mask_kl = torch.ones_like(perturbed_batch)
|
|
| |
| kl_loss = td3b_loss_fn.compute_kl_loss( |
| policy_model, |
| perturbed_batch, |
| attn_mask_kl, |
| sigma_kl |
| ) |
|
|
| |
| |
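# The contrastive term is only meaningful when the batch contains more than one
# directional label; otherwise fall back to WDCE + KL below.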
| if directional_labels is not None and len(torch.unique(directional_labels)) > 1: |
| |
| device = policy_model.backbone.device if hasattr(policy_model.backbone, 'device') else x_final.device |
|
|
| embeddings = extract_embeddings_from_mdlm( |
| policy_model, |
| x_final.to(device), |
| pool_method=embedding_pool_method |
| ) |
|
|
| |
| directional_labels = directional_labels.to(embeddings.device) |
|
|
| |
# Verbose loss diagnostics during the first few epochs, or whenever the contrastive loss appears to have collapsed to ~0.
debug_mode = bool((epoch < 3) or (batch_contrastive_losses and batch_contrastive_losses[-1] < 1e-6))
|
|
| |
| total_loss, loss_dict = td3b_loss_fn.compute_loss( |
| wdce_loss, |
| embeddings, |
| directional_labels, |
| kl_loss=kl_loss, |
| debug=debug_mode |
| ) |
| else: |
| |
| print(f"[WARNING] Epoch {epoch}: No directional diversity! Skipping contrastive loss.") |
| print(f" Labels: {directional_labels.cpu().tolist() if directional_labels is not None else 'None'}") |
| total_loss = wdce_loss + td3b_loss_fn.kl_beta * kl_loss |
| loss_dict = { |
| 'total_loss': total_loss.item(), |
| 'wdce_loss': wdce_loss.item(), |
| 'contrastive_loss': 0.0, |
| 'kl_loss': kl_loss.item() |
| } |
|
|
| |
| total_loss.backward() |
|
|
| |
| if args.grad_clip: |
| torch.nn.utils.clip_grad_norm_(policy_model.parameters(), args.gradnorm_clip) |
|
|
| optim.step() |
| optim.zero_grad() |
|
|
| pbar.set_postfix( |
| total_loss=loss_dict['total_loss'], |
| wdce=loss_dict['wdce_loss'], |
| ctr=loss_dict['contrastive_loss'] |
| ) |
|
|
| |
| x_eval, eval_metrics = policy_model.sample_finetuned_td3b( |
| args, |
| reward_model, |
| batch_size=50, |
| dataframe=False |
| ) |
|
|
| |
| affinity = eval_metrics.get('affinity', [0]) |
| gated_reward = eval_metrics.get('gated_reward', [0]) |
| confidence = eval_metrics.get('confidence', [1]) |
| valid_fraction = eval_metrics.get('valid_fraction', 0) |
|
|
| |
| direction_predictions = eval_metrics.get('direction_predictions', [0.5]) |
|
|
| |
| |
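# Directional consistency reward r = d_star * (f_phi - 0.5): positive when the direction
# oracle's prediction f_phi leans toward the target direction d_star in {+1, -1}.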
| d_star = reward_model.target_direction |
| consistency_rewards = [d_star * (f_phi - 0.5) for f_phi in direction_predictions] |
|
|
| |
| affinity_log.append(affinity) |
| gated_reward_log.append(gated_reward) |
| confidence_log.append(confidence) |
| valid_fraction_log.append(valid_fraction) |
| direction_prediction_log.append(direction_predictions) |
| consistency_reward_log.append(consistency_rewards) |
|
|
| batch_losses.append(loss_dict['total_loss']) |
| batch_wdce_losses.append(loss_dict['wdce_loss']) |
| batch_contrastive_losses.append(loss_dict['contrastive_loss']) |
| batch_kl_losses.append(loss_dict.get('kl_loss', 0.0)) |
|
|
| |
| if args.no_mcts: |
| mean_reward_search = final_rewards.mean().item() |
| min_reward_search = final_rewards.min().item() |
| max_reward_search = final_rewards.max().item() |
| median_reward_search = final_rewards.median().item() |
| else: |
| mean_reward_search = np.mean(final_rewards) |
| min_reward_search = np.min(final_rewards) |
| max_reward_search = np.max(final_rewards) |
| median_reward_search = np.median(final_rewards) |
|
|
| |
| mean_direction = np.mean(direction_predictions) if len(direction_predictions) > 0 else 0.5 |
| std_direction = np.std(direction_predictions) if len(direction_predictions) > 0 else 0.0 |
| mean_consistency = np.mean(consistency_rewards) if len(consistency_rewards) > 0 else 0.0 |
| std_consistency = np.std(consistency_rewards) if len(consistency_rewards) > 0 else 0.0 |
|
|
| print( |
| f"epoch {epoch} | " |
| f"affinity {np.mean(affinity):.4f} | " |
| f"gated_reward {np.mean(gated_reward):.4f} | " |
| f"confidence {np.mean(confidence):.4f} | " |
| f"valid_frac {valid_fraction:.4f} | " |
| f"direction_oracle {mean_direction:.4f}±{std_direction:.4f} | " |
| f"consistency_reward {mean_consistency:.4f}±{std_consistency:.4f} | " |
| f"total_loss {loss_dict['total_loss']:.4f} | " |
| f"wdce_loss {loss_dict['wdce_loss']:.4f} | " |
| f"contrastive_loss {loss_dict['contrastive_loss']:.4f} | " |
| f"kl_loss {loss_dict.get('kl_loss', 0.0):.4f}" |
| ) |
|
|
| |
| wandb.log({ |
| "epoch": epoch, |
| "affinity": np.mean(affinity), |
| "gated_reward": np.mean(gated_reward), |
| "confidence": np.mean(confidence), |
| "valid_fraction": valid_fraction, |
| "direction_oracle/mean": mean_direction, |
| "direction_oracle/std": std_direction, |
| "consistency_reward/mean": mean_consistency, |
| "consistency_reward/std": std_consistency, |
| "total_loss": loss_dict['total_loss'], |
| "wdce_loss": loss_dict['wdce_loss'], |
| "contrastive_loss": loss_dict['contrastive_loss'], |
| "kl_loss": loss_dict.get('kl_loss', 0.0), |
| "mean_reward_search": mean_reward_search, |
| "min_reward_search": min_reward_search, |
| "max_reward_search": max_reward_search, |
| "median_reward_search": median_reward_search |
| }) |
|
|
| |
| if (epoch + 1) % args.save_every_n_epochs == 0: |
| model_path = os.path.join(args.save_path, f'model_{epoch}.ckpt') |
| torch.save(policy_model.state_dict(), model_path) |
| print(f"model saved at epoch {epoch}") |
|
|
| |
|
|
| wandb.finish() |
|
|
| |
plot_path = f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}'
os.makedirs(plot_path, exist_ok=True)
output_log_path = f'{plot_path}/log_{filename}.csv'
| save_td3b_logs_to_file( |
| valid_fraction_log, |
| affinity_log, |
| gated_reward_log, |
| confidence_log, |
| direction_prediction_log, |
| consistency_reward_log, |
| output_log_path |
| ) |
|
|
| plot_data(valid_fraction_log, |
| save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/valid_{filename}.png') |
|
|
| plot_data_with_distribution_seaborn( |
| log1=affinity_log, |
| save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/affinity_{filename}.png', |
| label1=f"Average Affinity to {prot_name}", |
| title=f"Average Affinity to {prot_name} Over Iterations" |
| ) |
|
|
| plot_data_with_distribution_seaborn( |
| log1=gated_reward_log, |
| save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/gated_reward_{filename}.png', |
| label1="Average Gated Reward", |
| title="Average Gated Reward Over Iterations" |
| ) |
|
|
| plot_data_with_distribution_seaborn( |
| log1=confidence_log, |
| save_path=f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/confidence_{filename}.png', |
| label1="Average Confidence", |
| title="Average Confidence Over Iterations" |
| ) |
|
|
| |
| x_eval, eval_metrics, df = policy_model.sample_finetuned_td3b( |
| args, |
| reward_model, |
| batch_size=200, |
| dataframe=True |
| ) |
| df.to_csv(f'{base_path}/TR2-D2/tr2d2-pep/results/{args.run_name}/{prot_name}_generation_results.csv', index=False) |
|
|
| return batch_losses |
|
|
|
|
| def save_td3b_logs_to_file(valid_fraction_log, affinity_log, gated_reward_log, confidence_log, |
| direction_prediction_log, consistency_reward_log, output_path): |
| """ |
| Saves TD3B-specific logs to a CSV file. |
| |
| Parameters: |
| valid_fraction_log (list): Log of valid fractions over iterations. |
| affinity_log (list): Log of binding affinity over iterations. |
| gated_reward_log (list): Log of gated rewards over iterations. |
| confidence_log (list): Log of confidence scores over iterations. |
| direction_prediction_log (list): Log of direction oracle predictions over iterations. |
| consistency_reward_log (list): Log of consistency rewards over iterations. |
| output_path (str): Path to save the log CSV file. |
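
Note:
Each per-iteration affinity/reward/confidence/direction entry may itself be an
array (one value per sampled sequence), so those CSV cells hold list-like values.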
| """ |
| os.makedirs(os.path.dirname(output_path), exist_ok=True) |
|
|
| |
| log_data = { |
| "Iteration": list(range(1, len(valid_fraction_log) + 1)), |
| "Valid Fraction": valid_fraction_log, |
| "Binding Affinity": affinity_log, |
| "Gated Reward": gated_reward_log, |
| "Confidence": confidence_log, |
| "Direction Oracle": direction_prediction_log, |
| "Consistency Reward": consistency_reward_log |
| } |
|
|
| df = pd.DataFrame(log_data) |
|
|
| |
| df.to_csv(output_path, index=False) |
| print(f"Logs saved to {output_path}") |
|
|
|
|
| |
| def add_td3b_sampling_to_model(model): |
| """ |
| Adds TD3B-specific sampling method to the model. |
| This is a helper function to extend the existing model. |
| """ |
| def sample_finetuned_td3b(self, args, reward_model, batch_size=50, dataframe=False): |
| """ |
| TD3B-specific sampling that returns directional metrics. |
| """ |
| self.backbone.eval() |
| self.noise.eval() |
|
|
| if batch_size is None: |
| batch_size = args.batch_size |
|
|
| eps = getattr(args, "sampling_eps", 1e-5) |
| num_steps = args.total_num_steps |
| x_rollout = self.sample_prior( |
| batch_size, |
| args.seq_length).to(self.device, dtype=torch.long) |
|
|
| timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device) |
| dt = torch.tensor((1 - eps) / num_steps, device=self.device) |
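# Reverse-diffusion rollout: step time from 1 down to eps over num_steps uniform steps,
# denoising the fully masked prior sample.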
|
|
| for i in range(num_steps): |
| t = timesteps[i] * torch.ones(x_rollout.shape[0], 1, device=self.device) |
| log_p, x_next = self.single_reverse_step(x_rollout, t=t, dt=dt) |
| x_rollout = x_next.to(self.device) |
|
|
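# Resolve any tokens that are still masked with an extra noise-removal step.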
| mask_positions = (x_rollout == self.mask_index) |
| if mask_positions.any().item(): |
| log_p, x_next = self.single_noise_removal(x_rollout, t=t, dt=dt) |
| x_rollout = x_next.to(self.device) |
|
|
| |
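# Validity filtering: decode the rollouts and keep only sequences that PeptideAnalyzer
# recognizes as peptides; valid_fraction records the yield of this filter.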
| from utils.app import PeptideAnalyzer |
| analyzer = PeptideAnalyzer() |
| sequences = self.tokenizer.batch_decode(x_rollout) |
| valid_mask = torch.tensor([analyzer.is_peptide(seq) for seq in sequences], device=self.device) |
| valid_sequences = [seq for seq, keep in zip(sequences, valid_mask.tolist()) if keep] |
| valid_x_final = x_rollout[valid_mask] if valid_mask.any().item() else torch.empty(0, device=self.device) |
| valid_fraction = len(valid_sequences) / batch_size |
|
|
| if len(valid_sequences) > 0: |
| result = reward_model(valid_sequences) |
| if isinstance(result, tuple): |
| total_rewards, info = result |
| affinity = np.asarray(info.get('affinities', total_rewards)) |
| confidence = np.asarray(info.get('confidences', np.ones_like(affinity))) |
| direction_predictions = np.asarray(info.get('directions', np.zeros_like(affinity))) |
| else: |
| total_rewards = np.asarray(result) |
| if total_rewards.ndim > 1: |
| affinity = total_rewards[:, 0] |
| else: |
| affinity = total_rewards |
| confidence = np.ones_like(affinity) |
| direction_predictions = np.zeros_like(affinity) |
|
|
| rewards_t = torch.as_tensor(total_rewards, dtype=torch.float32, device=self.device) |
| alpha = max(float(getattr(args, "alpha", 0.1)), 1e-6) |
| weights = torch.softmax(rewards_t / alpha, dim=0) |
| idx = torch.multinomial(weights, num_samples=batch_size, replacement=True) |
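# Reward-weighted resampling: softmax(rewards / alpha) defines the sampling weights and
# batch_size sequences are drawn with replacement, biasing the evaluation batch toward
# high-reward candidates.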
|
|
| idx_np = idx.detach().cpu().numpy() |
| x_resampled = valid_x_final[idx] |
| sequences = [valid_sequences[i] for i in idx_np] |
| total_rewards = total_rewards[idx_np] |
| affinity = affinity[idx_np] |
| confidence = confidence[idx_np] |
| direction_predictions = direction_predictions[idx_np] |
| else: |
| x_resampled = x_rollout |
| total_rewards = np.array([]) |
| affinity = np.array([]) |
| confidence = np.array([]) |
| direction_predictions = np.array([]) |
|
|
| eval_metrics = { |
| 'affinity': affinity, |
| 'gated_reward': total_rewards, |
| 'confidence': confidence, |
| 'direction_predictions': direction_predictions, |
| 'valid_fraction': valid_fraction |
| } |
|
|
| if dataframe: |
| df = pd.DataFrame({ |
| 'sequence': sequences if len(total_rewards) else [], |
| 'affinity': affinity, |
| 'gated_reward': total_rewards, |
| 'confidence': confidence |
| }) |
| return x_resampled, eval_metrics, df |
| else: |
| return x_resampled, eval_metrics |
|
|
| |
| model.sample_finetuned_td3b = sample_finetuned_td3b.__get__(model, type(model)) |
| return model |
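

# Example wiring (illustrative sketch; `policy_model` and `reward_model` are assumed to be
# constructed by the surrounding training script):
#
#   policy_model = add_td3b_sampling_to_model(policy_model)
#   x_eval, eval_metrics = policy_model.sample_finetuned_td3b(args, reward_model, batch_size=50)
#   print(eval_metrics["valid_fraction"])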
|
|