import logging
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F

DEFAULT_EPS = 1e-5
logger = logging.getLogger(__name__)


def _sample_categorical(categorical_probs: torch.Tensor) -> torch.Tensor:
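    """Draw one sample per row via the exponential-race (Gumbel-max) trick.

    Dividing probabilities by i.i.d. Exp(1) noise and taking the argmax is
    equivalent to categorical sampling, so unnormalized weights are fine.
    """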
    gumbel = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
    return (categorical_probs / gumbel).argmax(dim=-1).to(dtype=torch.long)


def _normalize_probs(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
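    """Normalize along ``dim``, guarding against an all-zero row."""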
    return probs / probs.sum(dim=dim, keepdim=True).clamp_min(1e-12)


def _safe_resample_weights(weights: torch.Tensor) -> torch.Tensor:
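    """Sanitize importance weights: zero out non-finite entries, and fall
    back to uniform weights when the total mass is zero or non-finite."""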
    if weights.numel() == 0:
        return weights
    weights = torch.where(torch.isfinite(weights), weights, torch.zeros_like(weights))
    total = weights.sum()
    if not torch.isfinite(total) or total <= 0:
        return torch.full_like(weights, 1.0 / weights.numel())
    return weights / total


def _sequence_logprob(
    probs: torch.Tensor,
    x_next: torch.Tensor,
    x_current: torch.Tensor,
    mask_idx: int,
) -> torch.Tensor:
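    """Log-probability of moving x_current -> x_next under ``probs``, summed
    over the masked positions (the only positions an absorbing-state reverse
    step can change)."""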
    # Normalize first: callers may pass the unnormalized reverse kernel (rows
    # sum to t), and importance ratios need both kernels on the same scale.
    probs = _normalize_probs(probs)
    gather = probs.gather(-1, x_next.unsqueeze(-1)).squeeze(-1).clamp_min(1e-12)
    mask = (x_current == mask_idx).to(gather.dtype)
    return (gather.log() * mask).sum(dim=-1)


def _transition_probs_from_logits(
    log_probs: torch.Tensor,
    t: torch.Tensor,
    dt: torch.Tensor,
    mask_idx: int,
) -> torch.Tensor:
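    """Absorbing-state reverse-transition weights from SUBS-style log-probs:
    non-mask tokens get p_theta * (t - s) and the mask token keeps weight
    s = t - dt. Assuming the log-probs put no mass on the mask token, rows
    sum to t; the race-based sampler handles the missing normalization."""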
    change_prob_t = t[:, None, None]
    change_prob_s = (t - dt)[:, None, None]
    q_xs = log_probs.exp() * (change_prob_t - change_prob_s)
    q_xs[:, :, mask_idx] = change_prob_s[:, :, 0]
    return q_xs


def _sample_from_q(
    q_probs: torch.Tensor,
    x_current: torch.Tensor,
    mask_idx: int,
) -> torch.Tensor:
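    """Sample replacements for masked positions; already-decoded tokens are
    carried over unchanged."""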
    x_changed = _sample_categorical(q_probs)
    copy_flag = (x_current != mask_idx)
    return torch.where(copy_flag, x_current, x_changed)


def _protein_tokens_to_device(tokens: torch.Tensor, device: torch.device) -> torch.Tensor:
    if tokens.device != device:
        return tokens.to(device)
    return tokens


def _tokens_to_one_hot(tokens: torch.Tensor, vocab_size: int) -> torch.Tensor:
    return F.one_hot(tokens, num_classes=vocab_size).float()


def _decode_sequences(tokenizer, token_ids: torch.Tensor) -> List[str]:
    return tokenizer.batch_decode(token_ids)


def _affinity_from_scoring(
    scoring_fn: Callable,
    sequences: List[str],
    device: torch.device,
    protein_seq: Optional[str] = None,
) -> torch.Tensor:
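    """Call ``scoring_fn`` with whichever protein-argument convention it
    accepts, then coerce the result (first tuple element, first column of a
    2-D array) into a float32 affinity tensor on ``device``."""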
    if protein_seq is not None:
        try:
            scores = scoring_fn(sequences, protein_seq)
        except TypeError:
            try:
                scores = scoring_fn(sequences, prot_seq=protein_seq)
            except TypeError:
                scores = scoring_fn(sequences)
    else:
        scores = scoring_fn(sequences)
    if isinstance(scores, tuple):
        scores = scores[0]
    scores = np.asarray(scores)
    if scores.ndim == 1:
        affinity = scores
    else:
        affinity = scores[:, 0]
    return torch.as_tensor(affinity, device=device, dtype=torch.float32)


def _roformer_hidden_from_inputs(
    base_model,
    input_ids: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
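    """Final hidden states of the RoFormer backbone."""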
    outputs = base_model.backbone.model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        output_hidden_states=True,
        return_dict=True,
    )
    return outputs.hidden_states[-1]


def _logits_from_inputs(
    base_model,
    input_ids: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
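    """Language-model logits of the RoFormer backbone."""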
    outputs = base_model.backbone.model(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        output_hidden_states=False,
        return_dict=True,
    )
    return outputs.logits


@dataclass
class RewardInputs:
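    """Target-specific inputs: the tokenized protein, the direction sign
    ``d_star`` applied to (direction - 0.5) in the gated reward, and the raw
    protein sequence for the affinity model."""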
    protein_tokens: torch.Tensor
    d_star: float
    protein_seq: str


class RewardWrapper:
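    """Combines a sequence-level affinity score with a direction oracle into
    a single gated reward, exposing both a token-space path (hard ids) and a
    differentiable probability-space path."""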
    def __init__(
        self,
        scoring_fn: Callable,
        direction_oracle: torch.nn.Module,
        base_model,
        tokenizer,
        reward_inputs: RewardInputs,
        device: torch.device,
        fast_direction: bool = False,
        reward_alpha: float = 0.1,
    ):
        self.scoring_fn = scoring_fn
        self.direction_oracle = direction_oracle
        self.base_model = base_model
        self.tokenizer = tokenizer
        self.reward_inputs = reward_inputs
        self.device = device
        self.fast_direction = fast_direction
        self.reward_alpha = reward_alpha
        self._supports_hidden_direction = all(
            hasattr(direction_oracle, attr)
            for attr in ("protein_embedder", "fusion", "classifier")
        )
        self._supports_predict = hasattr(direction_oracle, "predict_with_confidence")
        if self.fast_direction and not self._supports_hidden_direction:
            logger.warning(
                "fast_direction requested but oracle lacks hidden-direction modules; disabling fast_direction."
            )
            self.fast_direction = False
        self._protein_emb_cache = None
        if self.reward_inputs.protein_seq is None:
            raise ValueError("RewardInputs.protein_seq is required for conditioned sampling.")

    def _protein_emb(self, batch_size: int) -> torch.Tensor:
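        """Protein embedding, computed once and cached, expanded to the batch."""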
        if not self._supports_hidden_direction:
            raise RuntimeError("direction_oracle does not support hidden-direction inference.")
        if self._protein_emb_cache is None:
            prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
            self._protein_emb_cache = self.direction_oracle.protein_embedder(prot_tokens)
        return self._protein_emb_cache.expand(batch_size, -1)

    def _direction_from_hidden(
        self,
        hidden: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
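        """Mask-aware mean-pool of the hidden states, fused with the cached
        protein embedding and classified into a direction score."""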
        if not self._supports_hidden_direction:
            raise RuntimeError("direction_oracle does not support hidden-direction inference.")
        mask = attn_mask.to(hidden.dtype).unsqueeze(-1)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
        protein_emb = self._protein_emb(pooled.size(0))
        fused = self.direction_oracle.fusion(pooled, protein_emb)
        return self.direction_oracle.classifier(fused).squeeze(-1)

    def _direction_from_probs(
        self,
        y_probs: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
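        """Direction score from soft token probabilities. Uses the oracle's
        own ``predict_from_probs`` if available, otherwise pushes the expected
        embeddings through the backbone (or skips the backbone when
        fast_direction is set), falling back to hard argmax tokens."""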
        if hasattr(self.direction_oracle, "predict_from_probs"):
            prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
            return self.direction_oracle.predict_from_probs(y_probs, prot_tokens, attn_mask)
        if not self._supports_hidden_direction:
            token_ids = y_probs.argmax(dim=-1)
            return self._direction_from_tokens(token_ids)
        emb_weight = self.base_model.backbone.model.roformer.embeddings.word_embeddings.weight
        inputs_embeds = y_probs @ emb_weight
        if self.fast_direction:
            # Fast path: skip the backbone forward pass and pool the expected
            # embeddings directly; a coarse but cheap approximation.
            hidden = inputs_embeds
        else:
            hidden = _roformer_hidden_from_inputs(
                self.base_model,
                inputs_embeds=inputs_embeds,
                attn_mask=attn_mask,
            )
        return self._direction_from_hidden(hidden, attn_mask)

    def _direction_from_tokens(self, token_ids: torch.Tensor) -> torch.Tensor:
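        """Direction score from hard token ids."""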
        prot_tokens = _protein_tokens_to_device(self.reward_inputs.protein_tokens, self.device)
        if prot_tokens.dim() == 2 and prot_tokens.size(0) == 1:
            prot_tokens = prot_tokens.expand(token_ids.size(0), -1)
        if self._supports_predict:
            direction, _ = self.direction_oracle.predict_with_confidence(token_ids, prot_tokens)
            return direction
        return self.direction_oracle(token_ids, prot_tokens)

    def _gated_reward(self, affinity: torch.Tensor, direction: torch.Tensor) -> torch.Tensor:
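        """Gate affinity by directional agreement:
        affinity * sigmoid((direction - 0.5) * d_star / reward_alpha)."""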
        d_star = torch.as_tensor(self.reward_inputs.d_star, device=self.device, dtype=direction.dtype)
        directional_score = (direction - 0.5) * d_star
        gate = torch.sigmoid(directional_score / self.reward_alpha)
        return affinity * gate

    def evaluate_tokens(self, token_ids: torch.Tensor, attn_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
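        """Score decoded sequences. ``attn_mask`` is accepted for interface
        parity with the probability-space path but is not used here."""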
        sequences = _decode_sequences(self.tokenizer, token_ids)
        affinity = _affinity_from_scoring(
            self.scoring_fn,
            sequences,
            self.device,
            protein_seq=self.reward_inputs.protein_seq,
        )
        with torch.no_grad():
            direction = self._direction_from_tokens(token_ids)
        gated_reward = self._gated_reward(affinity, direction)
        return {
            "sequences": sequences,
            "affinity": affinity,
            "direction": direction,
            "gated_reward": gated_reward,
        }

    def reward_from_tokens(
        self,
        token_ids: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
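        """Gated reward for hard token ids (non-differentiable path)."""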
        return self.evaluate_tokens(token_ids, attn_mask)["gated_reward"]

    def reward_from_probs(
        self,
        y_probs: torch.Tensor,
        token_ids_for_affinity: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
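        """Gated reward for soft token probabilities. Prefers a differentiable
        affinity head (``forward_from_probs``) when the scoring function
        exposes one; otherwise falls back to scoring the argmax decode."""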
        affinity = None
        if hasattr(self.scoring_fn, "forward_from_probs"):
            try:
                affinity = self.scoring_fn.forward_from_probs(
                    y_probs,
                    attn_mask,
                    prot_seq=self.reward_inputs.protein_seq,
                )
            except Exception as exc:
                logger.warning("Differentiable affinity failed; falling back to argmax. Error: %s", exc)
                affinity = None
        if affinity is None:
            sequences = _decode_sequences(self.tokenizer, token_ids_for_affinity)
            affinity = _affinity_from_scoring(
                self.scoring_fn,
                sequences,
                self.device,
                protein_seq=self.reward_inputs.protein_seq,
            )
        direction = self._direction_from_probs(y_probs, attn_mask)
        return self._gated_reward(affinity, direction)


class PepTuneSampler:
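    """MCTS-guided unmasking in the style of PepTune: expand children with the
    base model's reverse step, roll each child out to the end of the schedule,
    score valid peptides, and keep a Pareto front over (affinity, direction)
    scores."""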
    def __init__(
        self,
        base_model,
        reward_fn: RewardWrapper,
        seq_length: int,
        num_steps: int,
        mcts_iterations: int,
        num_children: int,
        sample_prob_weight: float,
        invalid_penalty: float,
        pareto_max_size: Optional[int],
        eps: float,
    ):
        # Imported lazily so the rest of this module works without the MCTS
        # dependencies installed.
        from peptide_mcts import Node, updateParetoFront
        from utils.app import PeptideAnalyzer

        self.base_model = base_model
        self.reward_fn = reward_fn
        self.seq_length = seq_length
        self.num_steps = num_steps
        self.mcts_iterations = mcts_iterations
        self.num_children = num_children
        self.sample_prob_weight = sample_prob_weight
        self.invalid_penalty = invalid_penalty
        self.pareto_max_size = pareto_max_size
        self.eps = eps

        self.device = base_model.device
        self.mask_idx = base_model.mask_index
        self.tokenizer = base_model.tokenizer
        self.analyzer = PeptideAnalyzer()
        self.Node = Node
        self.updateParetoFront = updateParetoFront

        self.timesteps = torch.linspace(1, eps, num_steps + 1, device=self.device)
        self.dt = torch.as_tensor((1 - eps) / num_steps, device=self.device)
        self.args = SimpleNamespace(
            num_obj=1,
            total_num_steps=num_steps,
            seq_length=seq_length,
            num_children=num_children,
        )

    def _init_root(self):
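        """Root node: a fully masked sequence at timestep 0."""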
        masked_seq = torch.full((self.seq_length,), self.mask_idx, device=self.device, dtype=torch.long)
        attn_mask = torch.ones_like(masked_seq)
        tokens = {"seqs": masked_seq, "attention_mask": attn_mask}
        return self.Node(
            args=self.args,
            tokens=tokens,
            log_rnd=torch.zeros((), device=self.device),
            log_policy_step=torch.zeros((), device=self.device),
            log_pretrained_step=torch.zeros((), device=self.device),
            totalReward=np.zeros(self.args.num_obj),
            timestep=0,
        )

    def _select(self, root):
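        """Walk down the tree via selectNode() until it returns a status other
        than 3, which this loop treats as "keep descending"."""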
        node = root
        while True:
            node, status = node.selectNode()
            if status != 3:
                return node, status

    def _update_pareto(self, pareto_front, pareto_tokens, seq, token_ids, score_vector):
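        """Update the Pareto front with (seq, score_vector) and keep the token
        cache in sync with the surviving front members."""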
        pareto_front = self.updateParetoFront(
            pareto_front,
            seq,
            score_vector,
            totalSize=self.pareto_max_size,
        )
        # Drop cached tokens for sequences evicted from the front.
        pareto_tokens = {k: pareto_tokens[k] for k in pareto_front if k in pareto_tokens}
        if seq in pareto_front:
            pareto_tokens[seq] = token_ids.detach().clone()
        return pareto_front, pareto_tokens

    def _expand(self, parent, pareto_front, pareto_tokens):
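        """Expand ``parent`` with num_children one-step proposals, roll each
        out to the end of the schedule, score the valid peptides, update the
        Pareto front, and back-propagate the mean gated reward."""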
        parent_tokens = parent.tokens["seqs"].to(self.device)
        attn_mask = parent.tokens["attention_mask"].to(self.device)
        t = self.timesteps[parent.timestep] * torch.ones(1, 1, device=self.device)

        with torch.no_grad():
            _, x_children, log_policy_step, log_pretrained_step = self.base_model.batch_mcts_reverse_step(
                token_array=parent_tokens,
                t=t,
                dt=self.dt,
                batch_size=self.num_children,
                pretrained=self.base_model,
            )

        # Accumulate the log Radon-Nikodym derivative of the pretrained model
        # w.r.t. the sampling policy; the policy log-prob is scaled by
        # sample_prob_weight before it is stored in the child nodes.
        child_log_rnd = parent.log_rnd + (log_pretrained_step - log_policy_step)
        log_policy_step = log_policy_step * self.sample_prob_weight

        # Roll each child out to the end of the schedule without branching.
        x_rollout = x_children
        t_step = self.timesteps[parent.timestep] * torch.ones(self.num_children, 1, device=self.device)
        for i in range(1, self.num_steps - parent.timestep):
            t_step = self.timesteps[parent.timestep + i] * torch.ones(self.num_children, 1, device=self.device)
            with torch.no_grad():
                _, x_next, _, _ = self.base_model.mcts_reverse_step(
                    x_rollout,
                    t=t_step,
                    dt=self.dt,
                    pretrained=self.base_model,
                )
            x_rollout = x_next

        if (x_rollout == self.mask_idx).any().item():
            with torch.no_grad():
                _, x_next, _, _ = self.base_model.mcts_noise_removal(
                    x_rollout,
                    t=t_step,
                    dt=self.dt,
                    pretrained=self.base_model,
                )
            x_rollout = x_next

        sequences = self.tokenizer.batch_decode(x_rollout)
        valid_mask = [self.analyzer.is_peptide(seq) for seq in sequences]

        reward_values = np.full(self.num_children, -float(self.invalid_penalty), dtype=np.float32)
        if any(valid_mask):
            valid_tokens = x_rollout[valid_mask]
            valid_sequences = [seq for seq, keep in zip(sequences, valid_mask) if keep]
            affinity = _affinity_from_scoring(
                self.reward_fn.scoring_fn,
                valid_sequences,
                self.device,
                protein_seq=self.reward_fn.reward_inputs.protein_seq,
            )
            with torch.no_grad():
                direction = self.reward_fn._direction_from_tokens(valid_tokens)
            gated_reward = self.reward_fn._gated_reward(affinity, direction)
            d_star = self.reward_fn.reward_inputs.d_star
            dir_score = (direction - 0.5) * d_star

            # The Pareto front tracks 2-D (affinity, direction) scores even
            # though the tree statistics use the 1-D gated reward (num_obj=1).
            for idx, seq in enumerate(valid_sequences):
                score_vector = np.array(
                    [float(affinity[idx].item()), float(dir_score[idx].item())],
                    dtype=np.float32,
                )
                pareto_front, pareto_tokens = self._update_pareto(
                    pareto_front,
                    pareto_tokens,
                    seq,
                    valid_tokens[idx],
                    score_vector,
                )

            reward_values[np.array(valid_mask)] = gated_reward.detach().cpu().numpy()

        reward_vectors = []
        for i in range(self.num_children):
            child_tokens = {"seqs": x_children[i].to(dtype=torch.long), "attention_mask": attn_mask}
            reward_vec = np.array([float(reward_values[i])], dtype=np.float32)
            parent.addChildNode(
                tokens=child_tokens,
                log_rnd=child_log_rnd[i],
                log_policy_step=log_policy_step[i],
                log_pretrained_step=log_pretrained_step[i],
                totalReward=reward_vec,
            )
            reward_vectors.append(reward_vec)

        # Back-propagate the mean child reward to the root.
        avg_reward = np.mean(np.stack(reward_vectors, axis=0), axis=0)
        node = parent
        while node:
            node.updateNode(avg_reward)
            node = node.parentNode

        return pareto_front, pareto_tokens

    def _select_from_pareto(self, pareto_front, pareto_tokens, batch_size):
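        """Pick ``batch_size`` sequences from the Pareto front, ranked by the
        gated reward recomputed from the stored scores; sample from the prior
        when the front is empty."""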
        if not pareto_front:
            return self.base_model.sample_prior(batch_size, self.seq_length).to(self.device)

        seqs = list(pareto_front.keys())
        scores = np.stack([pareto_front[seq] for seq in seqs], axis=0)
        affinity = scores[:, 0]
        dir_score = scores[:, 1]
        # Recompute the gated reward from the stored scores and rank by it.
        gate = 1.0 / (1.0 + np.exp(-dir_score / max(self.reward_fn.reward_alpha, 1e-6)))
        gated = affinity * gate
        order = np.argsort(-gated)

        if len(order) >= batch_size:
            selected = [seqs[i] for i in order[:batch_size]]
        else:
            repeats = np.random.choice(order, size=batch_size, replace=True)
            selected = [seqs[i] for i in repeats]

        tokens = [pareto_tokens[seq] for seq in selected]
        return torch.stack(tokens, dim=0).to(self.device)

    def sample(self, batch_size):
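        """Run MCTS iterations from a fully masked root and return token
        sequences chosen from the resulting Pareto front."""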
        self.base_model.eval()
        root = self._init_root()
        pareto_front = {}
        pareto_tokens = {}

        for _ in range(self.mcts_iterations):
            leaf, status = self._select(root)
            if status == 1:
                continue
            pareto_front, pareto_tokens = self._expand(leaf, pareto_front, pareto_tokens)

        return self._select_from_pareto(pareto_front, pareto_tokens, batch_size)


def _logits_and_probs_from_tokens(
    base_model,
    token_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
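    """SUBS log-probabilities for hard token ids (despite the name, only the
    log-probs are returned)."""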
    logits = _logits_from_inputs(base_model, input_ids=token_ids, attn_mask=attn_mask)
    log_probs = base_model.subs_parameterization(logits, token_ids)
    return log_probs


def _logits_and_probs_from_one_hot(
    base_model,
    y_one_hot: torch.Tensor,
    token_ids: torch.Tensor,
    attn_mask: torch.Tensor,
) -> torch.Tensor:
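    """SUBS log-probabilities from a soft/one-hot token relaxation, keeping
    the graph so gradients can flow back into ``y_one_hot``."""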
    emb_weight = base_model.backbone.model.roformer.embeddings.word_embeddings.weight
    inputs_embeds = y_one_hot @ emb_weight
    logits = _logits_from_inputs(base_model, inputs_embeds=inputs_embeds, attn_mask=attn_mask)
    log_probs = base_model.subs_parameterization(logits, token_ids)
    return log_probs


def classifier_guidance(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    eps: float = DEFAULT_EPS,
    guidance_steps: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
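    """Classifier-guided reverse diffusion: at each step, tilt the base
    transition kernel by exp(scale * dR/dy) computed through a one-hot
    relaxation of the current tokens. Guidance can be restricted to the last
    ``guidance_steps`` steps and is disabled entirely if the reward turns out
    to be non-differentiable."""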
    device = base_model.device
    mask_idx = base_model.mask_index
    vocab_size = base_model.vocab_size
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    guidance_enabled = True
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        use_guidance = guidance_enabled and (guidance_steps is None or step >= num_steps - guidance_steps)
        if not use_guidance:
            with torch.no_grad():
                log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)
            continue

        # Differentiable path: relax the current tokens to one-hot vectors so
        # the reward gradient w.r.t. token probabilities is defined.
        y_one_hot = _tokens_to_one_hot(x, vocab_size).to(device)
        y_one_hot.requires_grad_(True)
        token_ids = x.detach()
        log_probs = _logits_and_probs_from_one_hot(base_model, y_one_hot, token_ids, attn_mask)
        y_probs = log_probs.exp()
        token_ids_for_affinity = y_probs.argmax(dim=-1).detach()
        reward = reward_fn.reward_from_probs(y_probs, token_ids_for_affinity, attn_mask)
        if not reward.requires_grad:
            if guidance_enabled:
                logger.warning(
                    "Reward does not require grad; disabling gradient guidance for classifier_guidance."
                )
                guidance_enabled = False
            q_base = _transition_probs_from_logits(log_probs.detach(), t, dt, mask_idx)
            x = _sample_from_q(q_base, x, mask_idx)
            continue
        reward.sum().backward()
        grad = y_one_hot.grad
        # The kernel itself needs no gradient once the reward has been
        # backpropagated, so detach before building it.
        q_base = _transition_probs_from_logits(log_probs.detach(), t, dt, mask_idx)
        # Tilt each token's transition weight by its reward gradient relative
        # to the mask token's, clamped for numerical stability.
        guidance = guidance_scale * (grad - grad[:, :, mask_idx].unsqueeze(-1))
        guidance = guidance.clamp(min=-50.0, max=50.0)
        q_guided = _normalize_probs(q_base * torch.exp(guidance))
        x = _sample_from_q(q_guided, x, mask_idx)

    return {"tokens": x}


def unguided_sampling(
    base_model,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
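    """Plain reverse-diffusion sampling from the base model (no reward)."""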
    device = base_model.device
    mask_idx = base_model.mask_index
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        with torch.no_grad():
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
        q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
        x = _sample_from_q(q_base, x, mask_idx)

    return {"tokens": x}


def sequential_monte_carlo(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    alpha: float,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
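    """Sequential Monte Carlo: propose with the base kernel, then resample
    particles with weights exp((r_next - r_current) / alpha)."""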
    device = base_model.device
    mask_idx = base_model.mask_index
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    with torch.no_grad():
        r_current = reward_fn.reward_from_tokens(x, attn_mask).detach()
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        with torch.no_grad():
            log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
        q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
        x_next = _sample_from_q(q_base, x, mask_idx)

        with torch.no_grad():
            r_next = reward_fn.reward_from_tokens(x_next, attn_mask).detach()
        # Resample particles toward higher reward; clamp to avoid inf weights.
        weights = torch.exp((r_next - r_current) / alpha).clamp_max(1e6)
        weights = _safe_resample_weights(weights)
        indices = torch.multinomial(weights, num_samples=batch_size, replacement=True)
        x = x_next[indices]
        r_current = r_next[indices]

    return {"tokens": x}


def twisted_diffusion_sampler(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    guidance_scale: float,
    alpha: float,
    eps: float = DEFAULT_EPS,
    guidance_steps: Optional[int] = None,
) -> Dict[str, torch.Tensor]:
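    """Twisted SMC: propose with the gradient-guided kernel and correct the
    resampling weights by the proposal ratio q_base / q_guided, so guidance
    changes the proposal, not the target."""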
    device = base_model.device
    mask_idx = base_model.mask_index
    vocab_size = base_model.vocab_size
    x = base_model.sample_prior(batch_size, seq_length).to(device)
    attn_mask = torch.ones_like(x, device=device)
    timesteps = torch.linspace(1, eps, num_steps + 1, device=device)
    dt = torch.as_tensor((1 - eps) / num_steps, device=device)

    with torch.no_grad():
        r_current = reward_fn.reward_from_tokens(x, attn_mask).detach()
    guidance_enabled = True
    for step in range(num_steps):
        t = timesteps[step].repeat(batch_size)
        use_guidance = guidance_enabled and (guidance_steps is None or step >= num_steps - guidance_steps)

        if use_guidance:
            y_one_hot = _tokens_to_one_hot(x, vocab_size).to(device)
            y_one_hot.requires_grad_(True)
            token_ids = x.detach()
            log_probs = _logits_and_probs_from_one_hot(base_model, y_one_hot, token_ids, attn_mask)
            y_probs = log_probs.exp()
            token_ids_for_affinity = y_probs.argmax(dim=-1).detach()
            reward = reward_fn.reward_from_probs(y_probs, token_ids_for_affinity, attn_mask)
            # The kernel needs no gradient; only the reward is backpropagated.
            q_base = _transition_probs_from_logits(log_probs.detach(), t, dt, mask_idx)
            if not reward.requires_grad:
                if guidance_enabled:
                    logger.warning(
                        "Reward does not require grad; disabling gradient guidance for twisted_diffusion_sampler."
                    )
                    guidance_enabled = False
                q_guided = q_base
            else:
                reward.sum().backward()
                grad = y_one_hot.grad
                guidance = guidance_scale * (grad - grad[:, :, mask_idx].unsqueeze(-1))
                guidance = guidance.clamp(min=-50.0, max=50.0)
                q_guided = _normalize_probs(q_base * torch.exp(guidance))
        else:
            with torch.no_grad():
                log_probs = _logits_and_probs_from_tokens(base_model, x, attn_mask)
            q_base = _transition_probs_from_logits(log_probs, t, dt, mask_idx)
            q_guided = q_base

        x_next = _sample_from_q(q_guided, x, mask_idx)
        with torch.no_grad():
            r_next = reward_fn.reward_from_tokens(x_next, attn_mask).detach()

        # Importance weights: reward improvement tempered by alpha, corrected
        # by the proposal ratio q_base / q_guided.
        logp_guided = _sequence_logprob(q_guided, x_next, x, mask_idx)
        logp_base = _sequence_logprob(q_base, x_next, x, mask_idx)
        weights = torch.exp((r_next - r_current) / alpha + (logp_base - logp_guided)).clamp_max(1e6)
        weights = _safe_resample_weights(weights)
        indices = torch.multinomial(weights, num_samples=batch_size, replacement=True)
        x = x_next[indices]
        r_current = r_next[indices]

    return {"tokens": x}


def peptune_mctg_sampling(
    base_model,
    reward_fn: RewardWrapper,
    batch_size: int,
    seq_length: int,
    num_steps: int,
    mcts_iterations: int,
    num_children: int,
    alpha: float,
    sample_prob_weight: float,
    invalid_penalty: float = 1.0,
    pareto_max_size: Optional[int] = None,
    eps: float = DEFAULT_EPS,
) -> Dict[str, torch.Tensor]:
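    """Convenience wrapper around PepTuneSampler. ``alpha`` is accepted for
    signature parity with the other samplers but is unused by MCTS."""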
    sampler = PepTuneSampler(
        base_model=base_model,
        reward_fn=reward_fn,
        seq_length=seq_length,
        num_steps=num_steps,
        mcts_iterations=mcts_iterations,
        num_children=num_children,
        sample_prob_weight=sample_prob_weight,
        invalid_penalty=invalid_penalty,
        pareto_max_size=pareto_max_size,
        eps=eps,
    )
    tokens = sampler.sample(batch_size=batch_size)
    return {"tokens": tokens}