# TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation

TD3B is a sequence-based generative framework that designs peptide binders with specified agonist or antagonist behavior. It combines a Direction Oracle, a soft binding-affinity gate, and amortized fine-tuning of a pre-trained discrete diffusion model (MDLM).

## Installation

```bash
conda env create -f env.yml
conda activate td3b
pip install -e .
```

## Data and Checkpoints

Download the pretrained checkpoints and data from [Google Drive (TBA)](placeholder_link).

Place the files as follows:

```
TD3B/
├── checkpoints/
│   ├── pretrained.ckpt          # Pre-trained MDLM weights
│   ├── td3b.ckpt                # Fine-tuned TD3B model
│   └── direction_oracle.pt      # Direction Oracle weights
├── data/
│   ├── train.csv                # Training set (target-binder pairs)
│   └── test.csv                 # Test set
├── scoring/functions/classifiers/
│   ├── binding-affinity.pt
│   ├── hemolysis-xgboost.json
│   ├── nonfouling-xgboost.json
│   ├── permeability-xgboost.json
│   └── solubility-xgboost.json
└── tokenizer/
    ├── new_vocab.txt
    └── new_splits.txt
```
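
Before running anything, it can help to confirm that the downloaded files landed in the right places. The script below is a minimal sketch and is not part of the repository. `missing_files` and `EXPECTED_FILES` are hypothetical names, and the paths are copied from the layout above, relative to the `TD3B/` root.

```python
from pathlib import Path

# Hypothetical checklist; paths copied from the layout above (relative to TD3B/).
EXPECTED_FILES = [
    "checkpoints/pretrained.ckpt",
    "checkpoints/td3b.ckpt",
    "checkpoints/direction_oracle.pt",
    "data/train.csv",
    "data/test.csv",
    "scoring/functions/classifiers/binding-affinity.pt",
    "scoring/functions/classifiers/hemolysis-xgboost.json",
    "scoring/functions/classifiers/nonfouling-xgboost.json",
    "scoring/functions/classifiers/permeability-xgboost.json",
    "scoring/functions/classifiers/solubility-xgboost.json",
    "tokenizer/new_vocab.txt",
    "tokenizer/new_splits.txt",
]


def missing_files(root: str = ".") -> list[str]:
    """Return the expected files that are not present as regular files under `root`."""
    base = Path(root)
    return [rel for rel in EXPECTED_FILES if not (base / rel).is_file()]


if __name__ == "__main__":
    # Run from the TD3B/ directory.
    missing = missing_files(".")
    if missing:
        print("Missing files:")
        for rel in missing:
            print(f"  {rel}")
    else:
        print("All expected files found.")
```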

## Code Structure

```
TD3B/
├── inference.py                 # Generate binders (main inference entry point)
├── finetune_multi_target.py     # Multi-target TD3B training
├── finetune_utils.py            # Training utilities
├── launch_multi_target.sh       # Training launcher script
├── diffusion.py                 # MDLM backbone (TR2-D2)
├── roformer.py                  # RoFormer wrapper
├── noise_schedule.py            # Noise schedules
├── peptide_mcts.py              # MCTS tree search
├── td3b/
│   ├── direction_oracle.py      # Direction Oracle (f_φ)
│   ├── td3b_scoring.py          # Gated reward R = g_ψ · σ(d*·(f_φ−0.5)/τ)
│   ├── td3b_losses.py           # L_WDCE + λ·L_ctr + β·L_KL
│   ├── td3b_mcts.py             # TD3B-extended MCTS
│   ├── td3b_finetune.py         # Training loop
│   └── data_utils.py            # Data loading utilities
├── scoring/                     # Affinity predictor (g_ψ) and property classifiers
├── baselines/                   # CG, SMC, TDS, PepTune, Unguided baselines
├── tokenizer/                   # SMILES tokenizer (vocab + splits)
├── configs/                     # Model and training configs
└── utils/                       # Misc utilities
```
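
The gated reward R = g_ψ · σ(d*·(f_φ−0.5)/τ) computed in `td3b_scoring.py` can be read as follows: the affinity gate g_ψ scales a sigmoid that rewards the Direction Oracle output f_φ for landing on the requested side of 0.5. The sketch below is a scalar illustration, not the code in `td3b_scoring.py`. It assumes d* = +1 for agonist and −1 for antagonist, that f_φ is a probability in [0, 1], and that g_ψ is already a soft gate in [0, 1]. The function names are hypothetical.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gated_reward(affinity_gate: float, oracle_prob: float,
                 direction: int, tau: float = 0.1) -> float:
    """R = g_psi * sigmoid(d* * (f_phi - 0.5) / tau), scalar sketch.

    affinity_gate: g_psi, assumed to be a soft gate in [0, 1].
    oracle_prob:   f_phi, assumed to be the Direction Oracle probability in [0, 1].
    direction:     d*, assumed +1 for agonist and -1 for antagonist.
    tau:           sigmoid temperature (SIGMOID_TEMPERATURE in the launcher).
    """
    if direction not in (1, -1):
        raise ValueError("direction must be +1 (agonist) or -1 (antagonist)")
    return affinity_gate * sigmoid(direction * (oracle_prob - 0.5) / tau)


# A candidate with gate 0.9 that the oracle scores at 0.8:
print(gated_reward(0.9, 0.8, +1))  # ≈ 0.857
print(gated_reward(0.9, 0.8, -1))  # ≈ 0.043
```

Because σ(x) + σ(−x) = 1, the rewards for the two directions of the same candidate always sum to g_ψ, so a small τ such as the default 0.1 pushes almost all of the gated affinity toward whichever direction the oracle favors.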

## Inference

Generate agonist/antagonist binders for target proteins:

```bash
python inference.py \
    --ckpt_path checkpoints/td3b.ckpt \
    --val_csv data/test.csv \
    --save_path results/ \
    --seed 42 \
    --num_pool 32 \
    --val_samples_per_target 8 \
    --resample_alpha 0.1
```

This generates 32 candidates per (target, direction) pair (`--num_pool`), scores them with the Direction Oracle and the affinity predictor, applies weighted resampling (Algorithm 2 in the TD3B paper), and saves only valid peptide samples.
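
The exact resampling rule and the role of `--resample_alpha` are defined in `inference.py` and the paper, not here. As a rough illustration only, the sketch below assumes weights proportional to exp(R / α), i.e. a softmax over gated rewards with temperature α, drawn with replacement. `resample` is a hypothetical helper.

```python
import math
import random


def resample(candidates: list[str], rewards: list[float], k: int,
             alpha: float = 0.1, seed: int = 42) -> list[str]:
    """Draw k candidates with probability proportional to exp(reward / alpha).

    Hypothetical stand-in for the weighted-resampling step; treating alpha
    as a softmax temperature is an assumption, not the repository's definition.
    """
    if not candidates or len(candidates) != len(rewards):
        raise ValueError("need a non-empty pool and one reward per candidate")
    top = max(rewards)  # subtract the max so exp() cannot overflow
    weights = [math.exp((r - top) / alpha) for r in rewards]
    rng = random.Random(seed)  # seeded for reproducibility
    return rng.choices(candidates, weights=weights, k=k)
```

For example, `resample(pool, rewards, k=8, alpha=0.1)` keeps 8 draws from a 32-candidate pool; whether this matches what `--val_samples_per_target 8` does in the real script is also an assumption.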

Output: `results/td3b_results_seed42.csv`, with columns `target`, `sequence`, `direction`, `affinity`, `gated_reward`, `direction_oracle`, and `direction_accuracy`.
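
To compare targets or directions after a run, the output CSV can be aggregated with the standard library. This is a minimal sketch that assumes the averaged column (by default `gated_reward`) holds plain numbers. `mean_by_target_direction` is a hypothetical helper.

```python
import csv
from collections import defaultdict


def mean_by_target_direction(path: str,
                             column: str = "gated_reward") -> dict[tuple[str, str], float]:
    """Average a numeric column per (target, direction) pair in a results CSV."""
    sums: dict[tuple[str, str], float] = defaultdict(float)
    counts: dict[tuple[str, str], int] = defaultdict(int)
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            key = (row["target"], row["direction"])
            sums[key] += float(row[column])  # assumes the column is numeric
            counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}
```

Calling `mean_by_target_direction("results/td3b_results_seed42.csv")` returns a dict keyed by `(target, direction)`; pass `column="affinity"` to average predicted affinity instead.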

## Training

### Multi-target TD3B

1. Edit `launch_multi_target.sh` and set the paths to the checkpoints, data, and oracle:

```bash
BASE_PATH="/path/to/TD3B"
PRETRAINED_CHECKPOINT="${BASE_PATH}/checkpoints/pretrained.ckpt"
TRAIN_CSV="${BASE_PATH}/data/train.csv"
ORACLE_CKPT="${BASE_PATH}/checkpoints/direction_oracle.pt"
```

2. Launch training:

```bash
bash launch_multi_target.sh
```

Key hyperparameters (in `launch_multi_target.sh`):
- `CONTRASTIVE_WEIGHT=0.1`: λ, the weight on L_ctr
- `KL_BETA=0.1`: β, the weight on L_KL
- `SIGMOID_TEMPERATURE=0.1`: τ in the gated reward
- `NUM_ITER=20`: MCTS iterations per round
- `NUM_CHILDREN=16`: children per MCTS expansion
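
λ and β weight the auxiliary terms of the training objective L_WDCE + λ·L_ctr + β·L_KL listed for `td3b_losses.py`. The sketch below only shows how the two launcher values enter that sum; `LossWeights` and `td3b_total_loss` are hypothetical names, and the three loss terms are taken as precomputed inputs rather than computed here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LossWeights:
    contrastive_weight: float = 0.1  # lambda (CONTRASTIVE_WEIGHT)
    kl_beta: float = 0.1             # beta (KL_BETA)


def td3b_total_loss(l_wdce: float, l_ctr: float, l_kl: float,
                    w: LossWeights = LossWeights()) -> float:
    """L = L_WDCE + lambda * L_ctr + beta * L_KL (sketch with precomputed terms)."""
    return l_wdce + w.contrastive_weight * l_ctr + w.kl_beta * l_kl
```

With the defaults, `td3b_total_loss(2.0, 1.0, 0.5)` evaluates to about 2.15. The same expression works unchanged on PyTorch scalar tensors, since it uses only `+` and `*`.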

### Baselines

Run baseline methods (CG, SMC, TDS, PepTune, Unguided):

```bash
cd baselines/
bash run.sh --baseline cg --device cuda:0
bash run.sh --baseline smc --device cuda:0
bash run.sh --baseline tds --device cuda:0
```

## Citation

```bibtex
@article{caotd3b,
  title={TD3B: Transition-Directed Discrete Diffusion for Allosteric Binder Generation},
  author={Cao, Hanqun and Pal, Aastha and Tang, Sophia and Zhang, Yinuo and Zhang, Jingjie and Heng, Pheng-Ann and Chatterjee, Pranam}
}
```