ynuozhang committed on
Commit · 0da30f9
1 Parent(s): ab57276
update readme
- README.md +272 -97
- download_light.py +2 -1
- fit_mapie_adaptive.py +333 -0
- inference.py +5 -0
README.md
CHANGED
|
@@ -8,7 +8,7 @@ license: apache-2.0
|
|
| 8 |
|
| 9 |
This is the repository for [PeptiVerse: A Unified Platform for Therapeutic Peptide Property Prediction](https://www.biorxiv.org/content/10.64898/2025.12.31.697180), a collection of machine learning predictors for canonical and non-canonical peptide property prediction using sequence and SMILES representations. PeptiVerse enables evaluation of key biophysical and therapeutic properties of peptides for property-optimized generation.
|
| 10 |
|
| 11 |
-
## Table of Contents
|
| 12 |
|
| 13 |
- [Quick start](#quick-start)
|
| 14 |
- [Installation](#installation)
|
|
@@ -20,13 +20,14 @@ This is the repository for [PeptiVerse: A Unified Platform for Therapeutic Pepti
|
|
| 20 |
- [Usage](#usage)
|
| 21 |
- [Local Application Hosting](#local-application-hosting)
|
| 22 |
- [Dataset integration](#dataset-integration)
|
|
|
|
| 23 |
- [Quick inference by property per model](#quick-inference-by-property-per-model)
|
| 24 |
- [Property Interpretations](#property-interpretations)
|
| 25 |
- [Model Architecture](#model-architecture)
|
| 26 |
- [Troubleshooting](#troubleshooting)
|
| 27 |
- [Citation](#citation)
|
| 28 |
|
| 29 |
-
## Quick Start
|
| 30 |
- Lightweight start (basic models, no cuML; see below for details)
|
| 31 |
```bash
|
| 32 |
# Ignore all LFS files, you will see an empty folder first
|
|
@@ -69,7 +70,7 @@ pip install -r requirements.txt
|
|
| 69 |
# Run inference
|
| 70 |
python inference.py
|
| 71 |
```
|
| 72 |
-
## Installation
|
| 73 |
### Minimal Setup
|
| 74 |
- Easy start-up environment (using transformers, xgboost models)
|
| 75 |
```bash
|
|
@@ -85,7 +86,7 @@ pip install -r requirements.txt
|
|
| 85 |
# run inference (see below)
|
| 86 |
apptainer exec peptiverse.sif python inference.py
|
| 87 |
```
|
| 88 |
-
## Repository Structure
|
| 89 |
This repo contains important large files for [PeptiVerse](https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse), an interactive app for peptide property prediction. [Paper link.](https://www.biorxiv.org/content/10.64898/2025.12.31.697180v1)
|
| 90 |
|
| 91 |
```
|
|
@@ -105,8 +106,9 @@ PeptiVerse/
|
|
| 105 |
├── best_models.txt # Model selection manifest
|
| 106 |
└── requirements.txt # Python dependencies
|
| 107 |
```
|
|
|
|
| 108 |
|
| 109 |
-
## Training Data Collection
|
| 110 |
|
| 111 |
<table>
|
| 112 |
<caption><strong>Data distribution.</strong> Classification tasks report counts for class 0/1; regression tasks report total sample size (N).</caption>
|
|
@@ -145,15 +147,15 @@ PeptiVerse/
|
|
| 145 |
<td>Solubility</td>
|
| 146 |
<td>9668</td>
|
| 147 |
<td>8785</td>
|
| 148 |
-
<td>
|
| 149 |
-
<td>
|
| 150 |
</tr>
|
| 151 |
<tr>
|
| 152 |
<td>Permeability (Penetrance)</td>
|
| 153 |
<td>1162</td>
|
| 154 |
<td>1162</td>
|
| 155 |
-
<td>
|
| 156 |
-
<td>
|
| 157 |
</tr>
|
| 158 |
<tr>
|
| 159 |
<td>Toxicity</td>
|
|
@@ -189,39 +191,39 @@ PeptiVerse/
|
|
| 189 |
</table>
|
| 190 |
|
| 191 |
|
| 192 |
-
## Best Model List
|
| 193 |
|
| 194 |
### Full model set (cuML-enabled)
|
| 195 |
-
| Property
|
| 196 |
-
|---
|
| 197 |
-
| Hemolysis
|
| 198 |
-
| Non-Fouling
|
| 199 |
-
| Solubility
|
| 200 |
-
| Permeability (Penetrance)
|
| 201 |
-
| Toxicity
|
| 202 |
-
| Binding Affinity
|
| 203 |
-
| Permeability (PAMPA)
|
| 204 |
-
| Permeability (Caco-2)
|
| 205 |
-
| Half-life
|
|
|
|
| 206 |
>Note: *unpooled* indicates models operating on token-level embeddings with cross-attention, rather than mean-pooled representations.
|
| 207 |
|
| 208 |
### Minimal deployable model set (no cuML)
|
| 209 |
-
| Property
|
| 210 |
-
|---
|
| 211 |
-
| Hemolysis
|
| 212 |
-
| Non-Fouling
|
| 213 |
-
| Solubility
|
| 214 |
-
| Permeability (Penetrance)
|
| 215 |
-
| Toxicity
|
| 216 |
-
| Binding Affinity
|
| 217 |
-
| Permeability (PAMPA)
|
| 218 |
-
| Permeability (Caco-2)
|
| 219 |
-
| Half-life
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
## Usage
|
| 225 |
|
| 226 |
### Local Application Hosting
|
| 227 |
- Host the [PeptiVerse UI](https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse) locally with your own resources.
|
|
@@ -231,6 +233,9 @@ PeptiVerse/
|
|
| 231 |
git clone https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse
|
| 232 |
python app.py
|
| 233 |
```
|
|
|
|
|
|
|
|
|
|
| 234 |
### Dataset integration
|
| 235 |
- All properties are provided with raw_data/split_ready_csvs/[huggingface_datasets](https://huggingface.co/docs/datasets/en/index).
|
| 236 |
- Selectively download the data you need with `huggingface-cli`
|
|
@@ -266,41 +271,138 @@ print("Downloaded to:", local_dir)
|
|
| 266 |
- Pooled (fixed-length vector per sequence)
|
| 267 |
- Generated by mean-pooling token embeddings excluding special tokens (CLS/EOS) and padding.
|
| 268 |
- Each item:
|
| 269 |
-
sequence: `str`
|
| 270 |
-
label: `int` (classification) or `float` (regression)
|
| 271 |
-
embedding: `float32[H]` (H=1280 for ESM-2 650M)
|
| 272 |
- Unpooled (variable-length token matrix)
|
| 273 |
- Generated by keeping all valid token embeddings (excluding special tokens + padding) as a per-sequence matrix.
|
| 274 |
- Each item:
|
| 275 |
-
sequence: `str`
|
| 276 |
-
label: `int` (classification) or `float` (regression)
|
| 277 |
-
embedding: `float16[L, H]` (nested lists)
|
| 278 |
-
attention_mask: `int8[L]`
|
| 279 |
-
length: `int` (=L)
|
| 280 |
- B) SMILES-based ([PeptideCLM](https://github.com/AaronFeller/PeptideCLM) embeddings)
|
| 281 |
- Pooled (fixed-length vector per sequence)
|
| 282 |
- Generated by mean-pooling token embeddings excluding special tokens (CLS/EOS) and padding.
|
| 283 |
- Each item:
|
| 284 |
-
sequence: `str` (SMILES)
|
| 285 |
-
label: `int` (classification) or `float` (regression)
|
| 286 |
-
embedding: `float32[H]`
|
| 287 |
- Unpooled (variable-length token matrix)
|
| 288 |
- Generated by keeping all valid token embeddings (excluding special tokens + padding) as a per-sequence matrix.
|
| 289 |
- Each item:
|
| 290 |
-
sequence: `str` (SMILES)
|
| 291 |
-
label: `int` (classification) or `float` (regression)
|
| 292 |
-
embedding: `float16[L, H]` (nested lists)
|
| 293 |
-
attention_mask: `int8[L]`
|
| 294 |
-
length: `int` (=L)
|
| 295 |
-
|
|
| 296 |
|
| 297 |
-
### Quick
|
| 298 |
```python
|
| 299 |
from inference import PeptiVersePredictor
|
|
|
|
| 300 |
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
| 304 |
device="cuda", # or "cpu"
|
| 305 |
)
|
| 306 |
|
|
@@ -383,78 +485,150 @@ print(out)
|
|
| 383 |
|
| 384 |
```
|
| 385 |
|
| 386 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
-
|
|
|
|
| 389 |
|
|
| 390 |
---
|
| 391 |
-
##
|
| 392 |
-
HC50 is the concentration (x ug/ml) at which 50% of red blood cells are lysed. If HC50 < 100uM, the peptide is considered hemolytic, otherwise non-hemolytic, resulting in a binary 0/1 dataset. The predicted probability should therefore be interpreted as a risk indicator, not an exact concentration estimate. <br>
|
| 393 |
|
| 394 |
-
|
| 395 |
|
| 396 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
- Score close to 0.0 = non-hemolytic
|
| 398 |
---
|
| 399 |
|
| 400 |
-
###
|
| 401 |
-
Outputs a probability (0–1) that a peptide remains soluble in aqueous conditions.
|
| 402 |
-
|
| 403 |
-
|
|
|
|
| 404 |
|
| 405 |
-
- Score close to 1.0 = highly soluble<br>
|
| 406 |
-
- Score close to 0.0 = poorly soluble<br>
|
| 407 |
---
|
| 408 |
|
| 409 |
-
###
|
| 410 |
-
Higher scores indicate stronger non-fouling behavior, desirable for circulation and surface-exposed applications.
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
- Score close to
|
| 414 |
-
- Score close to 0.0 = fouling<br>
|
| 415 |
-
|
| 416 |
---
|
| 417 |
|
| 418 |
-
###
|
| 419 |
-
Predicts membrane permeability on a log P scale.
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
-
|
| 423 |
-
- For penetrance predictions, it is a classification prediction, so within the [0, 1] range, closer to 1 indicates more permeable.<br>
|
| 424 |
-
|
| 425 |
---
|
| 426 |
|
| 427 |
-
###
|
| 428 |
**Interpretation:** Predicted values reflect relative peptide stability, expressed in hours. Higher scores indicate longer persistence in serum, while lower scores suggest faster degradation.
|
| 429 |
|
| 430 |
---
|
| 431 |
|
| 432 |
-
###
|
| 433 |
**Interpretation:** Outputs a probability (0–1) that a peptide exhibits toxic effects. Higher scores indicate increased toxicity risk.
|
| 434 |
|
| 435 |
---
|
| 436 |
|
| 437 |
-
###
|
| 438 |
|
| 439 |
-
Predicts peptide-protein binding affinity. Requires both peptide and target protein sequence.
|
| 440 |
|
| 441 |
**Interpretation:**<br>
|
| 442 |
- Scores ≥ 9 correspond to tight binders (K ≤ 10⁻⁹ M, nanomolar to picomolar range)<br>
|
| 443 |
- Scores between 7 and 9 correspond to medium binders (10⁻⁷–10⁻⁹ M, nanomolar to micromolar range)<br>
|
| 444 |
- Scores < 7 correspond to weak binders (K ≥ 10⁻⁶ M, micromolar and weaker)<br>
|
| 445 |
- A difference of 1 unit in score corresponds to an approximately tenfold change in binding affinity.<br>
|
|
| 446 |
|
| 447 |
|
| 448 |
-
## Model Architecture
|
| 449 |
|
| 450 |
-
- **Sequence Embeddings:** [ESM-2 650M model](https://huggingface.co/facebook/esm2_t33_650M_UR50D) / [PeptideCLM model](https://huggingface.co/aaronfeller/PeptideCLM-23M-all). Foundational embeddings are frozen.
|
| 451 |
- **XGBoost Model:** Gradient boosting on pooled embedding features for efficient, high-performance prediction.
|
| 452 |
- **CNN/Transformer Model:** One-dimensional convolutional/self-attention transformer networks operating on unpooled embeddings to capture local sequence patterns.
|
| 453 |
- **Binding Model:** Transformer-based architecture with cross-attention between protein and peptide representations.
|
| 454 |
- **SVR Model:** Support Vector Regression applied to pooled embeddings, providing a kernel-based, nonparametric regression baseline that is robust on smaller or noisy datasets.
|
| 455 |
- **Others:** SVM and Elastic Nets were trained with [RAPIDS cuML](https://github.com/rapidsai/cuml), which requires a CUDA environment and is therefore not supported in the web app. Model checkpoints remain available in the Hugging Face repository.
|
| 456 |
|
| 457 |
-
## Troubleshooting
|
| 458 |
|
| 459 |
### LFS Download Issues
|
| 460 |
|
|
@@ -466,21 +640,22 @@ huggingface-cli download ChatterjeeLab/PeptiVerse \
|
|
| 466 |
--local-dir . \
|
| 467 |
--local-dir-use-symlinks False
|
| 468 |
```
|
| 469 |
-
### Trouble installing cuML
|
| 470 |
-
For error related to cuda library, reinstall the `torch` after installing `cuML`.
|
| 471 |
|
| 472 |
-
## Citation
|
| 473 |
|
| 474 |
If you find this repository helpful for your publications, please consider citing our paper:
|
| 475 |
|
| 476 |
```
|
| 477 |
-
@article {
|
| 478 |
author = {Zhang, Yinuo and Tang, Sophia and Chen, Tong and Mahood, Elizabeth and Vincoff, Sophia and Chatterjee, Pranam},
|
| 479 |
title = {PeptiVerse: A Unified Platform for Therapeutic Peptide Property Prediction},
|
|
|
|
| 480 |
year = {2026},
|
| 481 |
doi = {10.64898/2025.12.31.697180},
|
| 482 |
-
|
|
|
|
|
|
|
| 483 |
journal = {bioRxiv}
|
| 484 |
}
|
| 485 |
```
|
| 486 |
-
To use this repository, you agree to abide by the
|
|
|
|
| 8 |
|
| 9 |
This is the repository for [PeptiVerse: A Unified Platform for Therapeutic Peptide Property Prediction](https://www.biorxiv.org/content/10.64898/2025.12.31.697180), a collection of machine learning predictors for canonical and non-canonical peptide property prediction using sequence and SMILES representations. PeptiVerse enables evaluation of key biophysical and therapeutic properties of peptides for property-optimized generation.
|
| 10 |
|
| 11 |
+
## Table of Contents
|
| 12 |
|
| 13 |
- [Quick start](#quick-start)
|
| 14 |
- [Installation](#installation)
|
|
|
|
| 20 |
- [Usage](#usage)
|
| 21 |
- [Local Application Hosting](#local-application-hosting)
|
| 22 |
- [Dataset integration](#dataset-integration)
|
| 23 |
+
- [Training](#training)
|
| 24 |
- [Quick inference by property per model](#quick-inference-by-property-per-model)
|
| 25 |
- [Property Interpretations](#property-interpretations)
|
| 26 |
- [Model Architecture](#model-architecture)
|
| 27 |
- [Troubleshooting](#troubleshooting)
|
| 28 |
- [Citation](#citation)
|
| 29 |
|
| 30 |
+
## Quick Start
|
| 31 |
- Lightweight start (basic models, no cuML; see below for details)
|
| 32 |
```bash
|
| 33 |
# Ignore all LFS files, you will see an empty folder first
|
|
|
|
| 70 |
# Run inference
|
| 71 |
python inference.py
|
| 72 |
```
|
| 73 |
+
## Installation
|
| 74 |
### Minimal Setup
|
| 75 |
- Easy start-up environment (using transformers, xgboost models)
|
| 76 |
```bash
|
|
|
|
| 86 |
# run inference (see below)
|
| 87 |
apptainer exec peptiverse.sif python inference.py
|
| 88 |
```
|
| 89 |
+
## Repository Structure
|
| 90 |
This repo contains important large files for [PeptiVerse](https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse), an interactive app for peptide property prediction. [Paper link.](https://www.biorxiv.org/content/10.64898/2025.12.31.697180v1)
|
| 91 |
|
| 92 |
```
|
|
|
|
| 106 |
├── best_models.txt # Model selection manifest
|
| 107 |
└── requirements.txt # Python dependencies
|
| 108 |
```
|
| 109 |
+
For full data access, please download the corresponding `training_data_cleaned` and `training_classifiers` folders from Zenodo. The current Hugging Face repo hosts only the best model weights and the metadata with split labels.
|
| 110 |
|
| 111 |
+
## Training Data Collection
|
| 112 |
|
| 113 |
<table>
|
| 114 |
<caption><strong>Data distribution.</strong> Classification tasks report counts for class 0/1; regression tasks report total sample size (N).</caption>
|
|
|
|
| 147 |
<td>Solubility</td>
|
| 148 |
<td>9668</td>
|
| 149 |
<td>8785</td>
|
| 150 |
+
<td>9668</td>
|
| 151 |
+
<td>8785</td>
|
| 152 |
</tr>
|
| 153 |
<tr>
|
| 154 |
<td>Permeability (Penetrance)</td>
|
| 155 |
<td>1162</td>
|
| 156 |
<td>1162</td>
|
| 157 |
+
<td>1162</td>
|
| 158 |
+
<td>1162</td>
|
| 159 |
</tr>
|
| 160 |
<tr>
|
| 161 |
<td>Toxicity</td>
|
|
|
|
| 191 |
</table>
|
| 192 |
|
| 193 |
|
| 194 |
+
## Best Model List
|
| 195 |
|
| 196 |
### Full model set (cuML-enabled)
|
| 197 |
+
| Property | Best Model (Sequence) | Best Model (SMILES) | Task Type | Threshold (Sequence) | Threshold (SMILES) |
|
| 198 |
+
|---|---|---|---|---|---|
|
| 199 |
+
| Hemolysis | SVM | CNN (chemberta) | Classifier | 0.2521 | 0.564 |
|
| 200 |
+
| Non-Fouling | Transformer | ENET (peptideclm) | Classifier | 0.57 | 0.6969 |
|
| 201 |
+
| Solubility | CNN | – | Classifier | 0.377 | – |
|
| 202 |
+
| Permeability (Penetrance) | SVM | SVM (chemberta) | Classifier | 0.5493 | 0.573 |
|
| 203 |
+
| Toxicity | – | CNN (chemberta) | Classifier | – | 0.49 |
|
| 204 |
+
| Binding Affinity | unpooled | unpooled | Regression | – | – |
|
| 205 |
+
| Permeability (PAMPA) | – | CNN (chemberta) | Regression | – | – |
|
| 206 |
+
| Permeability (Caco-2) | – | SVR (chemberta) | Regression | – | – |
|
| 207 |
+
| Half-life | Transformer | XGB (peptideclm) | Regression | – | – |
|
| 208 |
+
|
| 209 |
>Note: *unpooled* indicates models operating on token-level embeddings with cross-attention, rather than mean-pooled representations.
|
| 210 |
|
| 211 |
### Minimal deployable model set (no cuML)
|
| 212 |
+
| Property | Best Model (WT) | Best Model (SMILES) | Task Type | Threshold (WT) | Threshold (SMILES) |
|
| 213 |
+
|---|---|---|---|---|---|
|
| 214 |
+
| Hemolysis | XGB | CNN (chemberta) | Classifier | 0.2801 | 0.564 |
|
| 215 |
+
| Non-Fouling | Transformer | XGB (peptideclm) | Classifier | 0.57 | 0.3892 |
|
| 216 |
+
| Solubility | CNN | – | Classifier | 0.377 | – |
|
| 217 |
+
| Permeability (Penetrance) | XGB | XGB (chemberta) | Classifier | 0.4301 | 0.5028 |
|
| 218 |
+
| Toxicity | – | CNN (chemberta) | Classifier | – | 0.49 |
|
| 219 |
+
| Binding Affinity | wt_wt_pooled | chemberta_smiles_pooled | Regression | – | – |
|
| 220 |
+
| Permeability (PAMPA) | – | CNN (chemberta) | Regression | – | – |
|
| 221 |
+
| Permeability (Caco-2) | – | SVR (chemberta) | Regression | – | – |
|
| 222 |
+
| Half-life | Transformer | XGB (peptideclm) | Regression | – | – |
|
| 223 |
+
>Note: Models marked as SVM or ENET are replaced with XGB as these models are not currently supported in the deployment environment without cuML setups.
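
The thresholds in the tables turn a classifier score into a 0/1 label. The snippet below is a minimal sketch of that step, not the code in `inference.py`: the dictionary keys, the helper name `to_label`, and the use of `>=` at the boundary are assumptions made for illustration. The threshold values are copied from the WT column of the minimal model set.

```python
# Hypothetical helper: keys, function name, and the ">=" rule are assumptions.
# Threshold values come from the minimal deployable model set (WT column).
WT_THRESHOLDS = {
    "hemolysis": 0.2801,
    "nf": 0.57,
    "solubility": 0.377,
    "permeability_penetrance": 0.4301,
}


def to_label(prop: str, score: float) -> int:
    """Return 1 if the score reaches the property's threshold, else 0."""
    return int(score >= WT_THRESHOLDS[prop])


print(to_label("nf", 0.00014520535252195523))  # score far below 0.57
print(to_label("hemolysis", 0.31))             # score above 0.2801
```

Because each property has its own threshold, a score of 0.4 can map to label 1 for one property and label 0 for another.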
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
## Usage
|
|
|
|
| 227 |
|
| 228 |
### Local Application Hosting
|
| 229 |
- Host the [PeptiVerse UI](https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse) locally with your own resources.
|
|
|
|
| 233 |
git clone https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse
|
| 234 |
python app.py
|
| 235 |
```
|
| 236 |
+
### Data pre-processing
|
| 237 |
+
Under `training_data_cleaned`, we provide the generated embeddings in Hugging Face dataset format. The following scripts show the steps used to generate the data.
|
| 238 |
+
|
| 239 |
### Dataset integration
|
| 240 |
- All properties are provided with raw_data/split_ready_csvs/[huggingface_datasets](https://huggingface.co/docs/datasets/en/index).
|
| 241 |
- Selectively download the data you need with `huggingface-cli`
|
|
|
|
| 271 |
- Pooled (fixed-length vector per sequence)
|
| 272 |
- Generated by mean-pooling token embeddings excluding special tokens (CLS/EOS) and padding.
|
| 273 |
- Each item:
|
| 274 |
+
sequence: `str`
|
| 275 |
+
label: `int` (classification) or `float` (regression)
|
| 276 |
+
embedding: `float32[H]` (H=1280 for ESM-2 650M)
|
| 277 |
- Unpooled (variable-length token matrix)
|
| 278 |
- Generated by keeping all valid token embeddings (excluding special tokens + padding) as a per-sequence matrix.
|
| 279 |
- Each item:
|
| 280 |
+
sequence: `str`
|
| 281 |
+
label: `int` (classification) or `float` (regression)
|
| 282 |
+
embedding: `float16[L, H]` (nested lists)
|
| 283 |
+
attention_mask: `int8[L]`
|
| 284 |
+
length: `int` (=L)
|
| 285 |
- B) SMILES-based ([PeptideCLM](https://github.com/AaronFeller/PeptideCLM) embeddings)
|
| 286 |
- Pooled (fixed-length vector per sequence)
|
| 287 |
- Generated by mean-pooling token embeddings excluding special tokens (CLS/EOS) and padding.
|
| 288 |
- Each item:
|
| 289 |
+
sequence: `str` (SMILES)
|
| 290 |
+
label: `int` (classification) or `float` (regression)
|
| 291 |
+
embedding: `float32[H]`
|
| 292 |
- Unpooled (variable-length token matrix)
|
| 293 |
- Generated by keeping all valid token embeddings (excluding special tokens + padding) as a per-sequence matrix.
|
| 294 |
- Each item:
|
| 295 |
+
sequence: `str` (SMILES)
|
| 296 |
+
label: `int` (classification) or `float` (regression)
|
| 297 |
+
embedding: `float16[L, H]` (nested lists)
|
| 298 |
+
attention_mask: `int8[L]`
|
| 299 |
+
length: `int` (=L)
|
| 300 |
+
- C) SMILES-based ([ChemBERTa](https://huggingface.co/DeepChem/ChemBERTa-77M-MLM) embeddings)
|
| 301 |
+
- Pooled (fixed-length vector per sequence)
|
| 302 |
+
- Generated by mean-pooling token embeddings excluding special tokens (CLS/EOS) and padding.
|
| 303 |
+
- Each item:
|
| 304 |
+
sequence: `str` (SMILES)
|
| 305 |
+
label: `int` (classification) or `float` (regression)
|
| 306 |
+
embedding: `float32[H]`
|
| 307 |
+
- Unpooled (variable-length token matrix)
|
| 308 |
+
- Generated by keeping all valid token embeddings (excluding special tokens + padding) as a per-sequence matrix.
|
| 309 |
+
- Each item:
|
| 310 |
+
sequence: `str` (SMILES)
|
| 311 |
+
label: `int` (classification) or `float` (regression)
|
| 312 |
+
embedding: `float16[L, H]` (nested lists)
|
| 313 |
+
attention_mask: `int8[L]`
|
| 314 |
+
length: `int` (=L)
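
To make the relation between the two item layouts concrete, here is a small sketch that mean-pools an unpooled item into a pooled one. It uses plain Python lists so it runs without dependencies; the helper name `mean_pool` and the toy values are illustrative assumptions, and the sketch assumes `attention_mask` marks valid tokens with 1. The released datasets were produced by the authors' own preprocessing scripts, which may differ.

```python
from typing import Dict, List


def mean_pool(item: Dict) -> List[float]:
    """Average the token vectors whose attention_mask entry is 1."""
    rows = [vec for vec, keep in zip(item["embedding"], item["attention_mask"]) if keep]
    if not rows:
        raise ValueError("no valid tokens to pool")
    hidden = len(rows[0])
    return [sum(row[j] for row in rows) / len(rows) for j in range(hidden)]


# Toy unpooled item with L=3 tokens and H=2 (ESM-2 650M embeddings use H=1280).
unpooled = {
    "sequence": "GIG",
    "label": 0,
    "embedding": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    "attention_mask": [1, 1, 1],
    "length": 3,
}

# The pooled item keeps sequence and label and replaces the matrix by one vector.
pooled = {
    "sequence": unpooled["sequence"],
    "label": unpooled["label"],
    "embedding": mean_pool(unpooled),
}
print(pooled["embedding"])
```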
|
| 315 |
+
### Training
|
| 316 |
+
Under the `training_classifiers` folder, we provide the Python scripts used to train the different models. The workflow is as follows:
|
| 317 |
+
1. The scripts read the pre-processed Hugging Face dataset from the `training_data_cleaned` folder;
|
| 318 |
+
2. Each script runs an Optuna hyperparameter sweep when called;
|
| 319 |
+
3. All training was conducted on HPC using the SLURM scripts under the `training_classifiers/src` folder;
|
| 320 |
+
4. You can customize or isolate individual model training scripts as needed.
|
| 321 |
+
##### Example of training
|
| 322 |
+
###### ML models
|
| 323 |
+
```bash
|
| 324 |
+
HOME_LOC=/home
|
| 325 |
+
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
|
| 326 |
+
EMB_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
|
| 327 |
+
|
| 328 |
+
OBJECTIVE='hemolysis' # nf/solubility/hemolysis/permeability_pampa/permeability_caco2
|
| 329 |
+
WT='smiles' # wt/smiles
|
| 330 |
+
DATA_FILE="hemo_${WT}_with_embeddings"
|
| 331 |
+
LOG_LOC=$SCRIPT_LOC/src/logs
|
| 332 |
+
DATE=$(date +%m_%d)
|
| 333 |
+
MODEL_TYPE='svm_gpu' # xgb/enet_gpu/svm_gpu
|
| 334 |
+
SPECIAL_PREFIX="${MODEL_TYPE}-${OBJECTIVE}-${WT}_new"
|
| 335 |
+
|
| 336 |
+
# Create log directory if it doesn't exist
|
| 337 |
+
mkdir -p $LOG_LOC
|
| 338 |
+
|
| 339 |
+
cd $SCRIPT_LOC
|
| 340 |
+
|
| 341 |
+
python -u train_ml.py \
|
| 342 |
+
--dataset_path "${EMB_LOC}/${OBJECTIVE}/${DATA_FILE}" \
|
| 343 |
+
--out_dir "${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}" \
|
| 344 |
+
--model "${MODEL_TYPE}" \
|
| 345 |
+
--n_trials 200 > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
|
| 346 |
+
```
|
| 347 |
+
###### DNN models
|
| 348 |
+
```bash
|
| 349 |
+
HOME_LOC=/home
|
| 350 |
+
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
|
| 351 |
+
EMB_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
|
| 352 |
+
|
| 353 |
+
OBJECTIVE='nf' # nf/solubility/hemolysis
|
| 354 |
+
WT='smiles' #wt/smiles
|
| 355 |
+
DATA_FILE="nf_${WT}_with_embeddings_unpooled"
|
| 356 |
+
LOG_LOC=$SCRIPT_LOC/src/logs
|
| 357 |
+
DATE=$(date +%m_%d)
|
| 358 |
+
MODEL_TYPE='cnn' #mlp/cnn/transformer
|
| 359 |
+
SPECIAL_PREFIX="${MODEL_TYPE}-${OBJECTIVE}-${WT}"
|
| 360 |
+
|
| 361 |
+
# Create log directory if it doesn't exist
|
| 362 |
+
mkdir -p $LOG_LOC
|
| 363 |
+
|
| 364 |
+
cd $SCRIPT_LOC
|
| 365 |
+
|
| 366 |
+
python -u train_nn.py \
|
| 367 |
+
--dataset_path "${EMB_LOC}/${OBJECTIVE}/${DATA_FILE}" \
|
| 368 |
+
--out_dir "${SCRIPT_LOC}/${OBJECTIVE}/${MODEL_TYPE}_${WT}" \
|
| 369 |
+
--model "${MODEL_TYPE}" \
|
| 370 |
+
--n_trials 200 > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
|
| 371 |
+
```
|
| 372 |
+
###### Binding Affinity
|
| 373 |
+
```bash
|
| 374 |
+
HOME_LOC=/home
|
| 375 |
+
SCRIPT_LOC=$HOME_LOC/PeptiVerse/training_classifiers
|
| 376 |
+
EMB_LOC=$HOME_LOC/PeptiVerse/training_data_cleaned
|
| 377 |
+
|
| 378 |
+
OBJECTIVE='binding_affinity'
|
| 379 |
+
BINDER_MODEL='chemberta' # peptideclm / chemberta
|
| 380 |
+
STATUS='unpooled' # pooled / unpooled
|
| 381 |
+
TYPE='smiles'
|
| 382 |
+
DATA_FILE="pair_wt_${TYPE}_${STATUS}" # double quotes so the variables expand
|
| 383 |
+
|
| 384 |
+
LOG_LOC=$SCRIPT_LOC/src/logs
|
| 385 |
+
DATE=$(date +%m_%d)
|
| 386 |
+
SPECIAL_PREFIX="${OBJECTIVE}-${BINDER_MODEL}-${STATUS}"
|
| 387 |
+
|
| 388 |
+
python -u binding_training.py \
|
| 389 |
+
--dataset_path "${EMB_LOC}/${OBJECTIVE}/${BINDER_MODEL}/${DATA_FILE}" \
|
| 390 |
+
--mode "${STATUS}" \
|
| 391 |
+
--out_dir "${SCRIPT_LOC}/${OBJECTIVE}/${BINDER_MODEL}_${TYPE}_${STATUS}" \
|
| 392 |
+
--n_trials 200 > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
|
| 393 |
+
```
|
| 394 |
|
| 395 |
+
### Quick inference by property per model
|
| 396 |
```python
|
| 397 |
from inference import PeptiVersePredictor
|
| 398 |
+
from pathlib import Path
|
| 399 |
|
| 400 |
+
root = Path(__file__).resolve().parent # current script folder
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
predictor = PeptiVersePredictor(
|
| 404 |
+
manifest_path=root / "best_models.txt",
|
| 405 |
+
classifier_weight_root=root,
|
| 406 |
device="cuda", # or "cpu"
|
| 407 |
)
|
| 408 |
|
|
|
|
| 485 |
|
| 486 |
```
|
| 487 |
|
| 488 |
+
#### Advanced inference with uncertainty prediction
|
| 489 |
+
Uncertainty prediction is exposed as a parameter in the inference code. The full classifier folder from [zenodo]() is required to enable this functionality. The model uncertainty is computed by the scripts under the `training_classifiers` folder whose names start with "**refit**". A detailed description can be found in the methodology section of the manuscript.
|
| 490 |
+
At inference time, PeptiVersePredictor returns an `uncertainty` field with every prediction when `uncertainty=True` is passed. The method and interpretation depend on the model class, determined automatically at inference time.
|
| 491 |
+
```python
|
| 492 |
+
seq = "GIGAVLKVLTTGLPALISWIKRKRQQ"
|
| 493 |
+
smiles = "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O"
|
| 494 |
|
| 495 |
+
print(predictor.predict_property("nf", "wt", seq, uncertainty=True))
|
| 496 |
+
print(predictor.predict_property("nf", "smiles", smiles, uncertainty=True))
|
| 497 |
|
| 498 |
+
# Example output:
# {'property': 'nf', 'col': 'wt', 'score': 0.00014520535252195523, 'emb_tag': 'wt', 'label': 0, 'threshold': 0.57, 'uncertainty': 0.0017192508727321288, 'uncertainty_type': 'ensemble_predictive_entropy'}
|
| 499 |
+
# {'property': 'nf', 'col': 'smiles', 'score': 0.025485480204224586, 'emb_tag': 'peptideclm', 'label': 0, 'threshold': 0.6969, 'uncertainty': 0.11868063130587676, 'uncertainty_type': 'binary_predictive_entropy_single_model'}
|
| 500 |
+
```
|
| 501 |
+
|
| 502 |
+
---
|
| 503 |
+
|
| 504 |
+
##### Method by Model Class
|
| 505 |
+
|
| 506 |
+
| Model Class | Task | Uncertainty Method | Output Type | Range |
|
| 507 |
+
|---|---|---|---|---|
|
| 508 |
+
| MLP, CNN, Transformer | Classifier | Deep ensemble predictive entropy (5 seeds) | `float` | [0, ln(2) ≈ 0.693] |
|
| 509 |
+
| MLP, CNN, Transformer | Regression | Adaptive conformal interval; falls back to ensemble std if no MAPIE bundle | `(lo, hi)` or `float` | unbounded |
|
| 510 |
+
| SVM / SVC / XGBoost | Classifier | Binary predictive entropy (sigmoid of decision function) | `float` | [0, ln(2) ≈ 0.693] |
|
| 511 |
+
| SVR / ElasticNet / XGBoost | Regression | Adaptive conformal interval | `(lo, hi)` | unbounded |
|
| 512 |
+
|
| 513 |
+
> **Uncertainty is `None`** when: a DNN classifier has no seed ensemble trained, or a regression model has no `mapie_calibration.joblib` in its model directory.
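
Since the `uncertainty` field can be a float, a `(lo, hi)` tuple, or `None`, downstream code has to branch on its type. The following is a rough sketch of such handling; `describe_uncertainty` is a hypothetical helper written for this README, not part of `inference.py`, and it assumes only the result keys shown in the example output.

```python
def describe_uncertainty(result: dict) -> str:
    """Summarize the `uncertainty` field of a prediction dict as text."""
    unc = result.get("uncertainty")
    if unc is None:
        # No seed ensemble or no MAPIE bundle was available for this model.
        return "no uncertainty available"
    if isinstance(unc, (tuple, list)):
        # Conformal interval for regression models.
        lo, hi = unc
        return f"interval [{lo:.3f}, {hi:.3f}] (width {hi - lo:.3f})"
    # Scalar value: entropy for classifiers, or ensemble std as a fallback.
    return f"{result.get('uncertainty_type', 'scalar')} = {unc:.4f}"


example = {
    "property": "nf",
    "score": 0.025485480204224586,
    "uncertainty": 0.11868063130587676,
    "uncertainty_type": "binary_predictive_entropy_single_model",
}
print(describe_uncertainty(example))
print(describe_uncertainty({"uncertainty": (1.0, 3.5)}))
print(describe_uncertainty({"uncertainty": None}))
```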
|
| 514 |
+
|
| 515 |
---
|
| 516 |
+
## Interpretation
|
|
|
|
| 517 |
|
| 518 |
+
You can also find the same description in the paper or in the PeptiVerse app `Documentation` tab.
|
| 519 |
|
| 520 |
+
---
|
| 521 |
+
### 🩸 Hemolysis Prediction
|
| 522 |
+
HC50 is the concentration (x ug/ml) at which 50% of red blood cells are lysed. If HC50 < 100uM, the peptide is considered hemolytic, otherwise non-hemolytic, resulting in a binary 0/1 dataset. The predicted probability should therefore be interpreted as a risk indicator, not an exact concentration estimate.
|
| 523 |
+
**Output interpretation:**
|
| 524 |
+
- Score close to 1.0 = high probability of red blood cell membrane disruption
|
| 525 |
- Score close to 0.0 = non-hemolytic
|
| 526 |
---
|
| 527 |
|
| 528 |
+
### 💧 Solubility Prediction
|
| 529 |
+
Outputs a probability (0–1) that a peptide remains soluble in aqueous conditions.
|
| 530 |
+
**Output interpretation:**
|
| 531 |
+
- Score close to 1.0 = highly soluble
|
| 532 |
+
- Score close to 0.0 = poorly soluble
|
| 533 |
|
|
|
|
|
|
|
| 534 |
---
|
| 535 |
|
| 536 |
+
### Non-Fouling Prediction
|
| 537 |
+
Higher scores indicate stronger non-fouling behavior, desirable for circulation and surface-exposed applications.
|
| 538 |
+
**Output interpretation:**
|
| 539 |
+
- Score close to 1.0 = non-fouling
|
| 540 |
+
- Score close to 0.0 = fouling
|
|
|
|
|
|
|
| 541 |
---
|
| 542 |
|
| 543 |
+
### 🪣 Permeability Prediction
|
| 544 |
+
Predicts membrane permeability on a log P scale.
|
| 545 |
+
**Output interpretation:**
|
| 546 |
+
- Higher values = more permeable (>-6.0)
|
| 547 |
+
- For penetrance predictions, it is a classification prediction, so within the [0, 1] range, closer to 1 indicates more permeable.
|
|
|
|
|
|
|
| 548 |
---
|
| 549 |
|
| 550 |
+
### ⏱️ Half-Life Prediction
|
| 551 |
**Interpretation:** Predicted values reflect relative peptide stability, expressed in hours. Higher scores indicate longer persistence in serum, while lower scores suggest faster degradation.
|
| 552 |
|
| 553 |
---
|
| 554 |
|
| 555 |
+
### ☠️ Toxicity Prediction
|
| 556 |
**Interpretation:** Outputs a probability (0–1) that a peptide exhibits toxic effects. Higher scores indicate increased toxicity risk.
|
| 557 |
|
| 558 |
---
|
| 559 |
|
| 560 |
+
### Binding Affinity Prediction
|
| 561 |
|
| 562 |
+
Predicts peptide-protein binding affinity. Requires both peptide and target protein sequence.
|
| 563 |
|
| 564 |
**Interpretation:**<br>
|
| 565 |
- Scores ≥ 9 correspond to tight binders (K ≤ 10⁻⁹ M, nanomolar to picomolar range)<br>
|
| 566 |
- Scores between 7 and 9 correspond to medium binders (10⁻⁷–10⁻⁹ M, nanomolar to micromolar range)<br>
|
| 567 |
- Scores < 7 correspond to weak binders (K ≥ 10⁻⁶ M, micromolar and weaker)<br>
|
| 568 |
- A difference of 1 unit in score corresponds to an approximately tenfold change in binding affinity.<br>
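
The interpretation rules above can be written as a small lookup. This is an illustrative sketch only: the function names are hypothetical, and the conversion assumes the score behaves like -log10(K in mol/L), which is what the "one unit is roughly tenfold" rule suggests.

```python
def binding_class(score: float) -> str:
    """Map a predicted affinity score to the tight/medium/weak categories."""
    if score >= 9:
        return "tight"
    if score >= 7:
        return "medium"
    return "weak"


def approx_k_molar(score: float) -> float:
    """Approximate K in mol/L, assuming score = -log10(K)."""
    return 10.0 ** (-score)


for s in (9.4, 8.2, 6.5):
    print(s, binding_class(s), f"{approx_k_molar(s):.2e} M")
```

Under this assumption, a peptide scoring 8 binds about ten times more tightly than one scoring 7.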
|
| 569 |
+
|
| 570 |
+
---
|
| 571 |
+
|
| 572 |
+
### Uncertainty Interpretation
|
| 573 |
+
#### Entropy (classifiers)
|
| 574 |
+
|
| 575 |
+
Binary predictive entropy of the output probability $\bar{p}$:
|
| 576 |
+
|
| 577 |
+
$$\mathcal{H} = -\bar{p}\log\bar{p} - (1 - \bar{p})\log(1 - \bar{p})$$
|
| 578 |
+
|
| 579 |
+
- For **DNN classifiers**: $\bar{p}$ is the mean probability across 5 independently seeded models (deep ensemble). High entropy reflects both epistemic uncertainty (seed disagreement) and aleatoric uncertainty (collectively diffuse predictions).
|
| 580 |
+
- For **XGBoost / SVM / ElasticNet classifiers**: $\bar{p}$ is the single model's output probability (or sigmoid of decision function for ElasticNet). Entropy reflects output confidence of a single model only.
|
| 581 |
+
|
| 582 |
+
| Range | Interpretation |
|
| 583 |
+
|---|---|
|
| 584 |
+
| < 0.1 | High confidence |
|
| 585 |
+
| 0.1 – 0.4 | Moderate uncertainty |
|
| 586 |
+
| 0.4 – 0.6 | Low confidence |
|
| 587 |
+
| > 0.6 | Very low confidence, model close to guessing |
|
| 588 |
+
| ≈ 0.693 | Maximum uncertainty, predicted probability ≈ 0.5 |
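
As a worked check of the entropy formula and the ranges in the table, the sketch below computes the binary predictive entropy in nats and maps it to a confidence band. The function names and the handling of values that fall exactly on a band boundary are assumptions for illustration. Applied to the single-model SMILES score from the earlier example output (about 0.0255), it gives roughly 0.119, which agrees with the uncertainty reported there.

```python
import math


def binary_entropy(p: float) -> float:
    """Binary predictive entropy in nats; taken as 0 at p = 0 or p = 1."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -p * math.log(p) - (1.0 - p) * math.log(1.0 - p)


def confidence_band(h: float) -> str:
    """Map an entropy value to the bands of the table (boundary handling assumed)."""
    if h < 0.1:
        return "High confidence"
    if h < 0.4:
        return "Moderate uncertainty"
    if h <= 0.6:
        return "Low confidence"
    return "Very low confidence"


h_max = binary_entropy(0.5)                      # equals ln(2)
h_smiles = binary_entropy(0.025485480204224586)  # score from the example output
print(h_max, confidence_band(h_max))
print(h_smiles, confidence_band(h_smiles))
```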
|
| 589 |
+
|
| 590 |
+
---
|
| 591 |
+
|
| 592 |
+
#### Adaptive Conformal Prediction Interval (regressors)
|
| 593 |
+
|
| 594 |
+
Returned as a tuple `(lo, hi)` with a 90% marginal coverage guarantee.
|
| 595 |
+
|
| 596 |
+
We implement the **residual normalised conformity score** following [Lei et al. (2018)](https://doi.org/10.1080/01621459.2017.1307116) and [Cordier et al. (2023) / MAPIE](https://proceedings.mlr.press/v204/cordier23a.html). An auxiliary XGBoost model $\hat{\sigma}(\mathbf{x})$ is trained on held-out embeddings and absolute residuals $|y_i - \hat{y}_i|$. At inference:
|
| 597 |
+
|
| 598 |
+
$$[\hat{y}(\mathbf{x}) - q \cdot \hat{\sigma}(\mathbf{x}),\ \hat{y}(\mathbf{x}) + q \cdot \hat{\sigma}(\mathbf{x})]$$
|
| 599 |
+
|
| 600 |
+
where $q$ is the $\lceil(n+1)(1-\alpha)\rceil / n$ quantile of the normalized scores $s_i = |y_i - \hat{y}_i| / \hat{\sigma}(\mathbf{x}_i)$ over the $n$ held-out calibration samples, with $\alpha = 0.1$ for 90% coverage.
|
| 601 |
+
|
| 602 |
+
- **Interval width varies per input** -- molecules more dissimilar to training data tend to receive wider intervals
|
| 603 |
+
- **Coverage guarantee**: on exchangeable data, $P(y \in [\hat{y} - q\hat{\sigma},\ \hat{y} + q\hat{\sigma}]) \geq 0.90$
|
| 604 |
+
- **The guarantee is marginal**, not conditional: it holds on average over inputs, so an unusually narrow interval on an out-of-distribution molecule does not guarantee correctness
|
| 605 |
+
- **Full access**: MAPIE bundles are already computed for all regression models, so users can use them directly with customized model lists.
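
The calibration and interval steps described above can be sketched as follows. This is a simplified, standard-library illustration rather than the repository's implementation: it takes the conformal quantile as an order statistic, whereas `fit_mapie_adaptive.py` calls `np.quantile` at the same level (which interpolates), and the per-sample `sigma` values here are made-up numbers standing in for the XGBoost $\hat{\sigma}$ model. All function names are hypothetical.

```python
import math


def conformal_quantile(scores, alpha: float = 0.1) -> float:
    """k-th smallest score with k = ceil((n + 1) * (1 - alpha)), capped at n."""
    n = len(scores)
    k = min(n, math.ceil((n + 1) * (1 - alpha)))
    return sorted(scores)[k - 1]


def calibrate(y_true, y_pred, sigma, alpha: float = 0.1, floor: float = 1e-6) -> float:
    """Compute q from normalized scores s_i = |y_i - y_hat_i| / sigma_hat(x_i)."""
    scores = [abs(t - p) / max(s, floor) for t, p, s in zip(y_true, y_pred, sigma)]
    return conformal_quantile(scores, alpha)


def interval(y_hat: float, sigma_hat: float, q: float, floor: float = 1e-6):
    """Adaptive interval [y_hat - q * sigma_hat, y_hat + q * sigma_hat]."""
    s = max(sigma_hat, floor)  # same positivity floor as used during calibration
    return (y_hat - q * s, y_hat + q * s)


if __name__ == "__main__":
    y_true = [0.1 * i for i in range(1, 10)]  # toy calibration targets
    y_pred = [0.0] * 9                        # toy predictions
    sigma = [1.0] * 9                         # stand-in for sigma_hat(x_i)
    q = calibrate(y_true, y_pred, sigma, alpha=0.1)
    print("q =", q)
    print("interval =", interval(5.0, 2.0, q))
```

Because $q$ is shared across inputs, the interval width changes only through $\hat{\sigma}(\mathbf{x})$, which is what makes the interval adaptive.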
|
| 606 |
+
|
| 607 |
+
---
|
| 608 |
+
|
| 609 |
+
#### Generating a MAPIE Bundle for a New Model
|
| 610 |
+
|
| 611 |
+
To enable conformal uncertainty for a newly trained regression model:
|
| 612 |
+
|
| 613 |
+
```bash
|
| 614 |
+
# Fit adaptive conformal bundle from val_predictions.csv
|
| 615 |
+
python fit_mapie_adaptive.py --root training_classifiers --prop <property_name>
|
| 616 |
+
```
|
| 617 |
+
|
| 618 |
+
The script reads `sequence`/`smiles` and `y_pred`/`y_true` columns from the CSV, recomputes embeddings, fits the XGBoost $\hat{\sigma}$ model, and saves `mapie_calibration.joblib` into the model directory. The bundle is automatically detected and loaded by `PeptiVersePredictor` on next initialisation.
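
Before running the script, it can help to confirm that the CSV has the columns it expects. The following is a minimal pre-flight sketch using only the standard library; `check_val_csv` is a hypothetical helper, it checks only the column names mentioned above (the script itself accepts additional aliases), and the 30-row minimum mirrors the sample-count check in `fit_mapie_adaptive.py`.

```python
import csv
import io
import math

SEQ_COLS = ("sequence", "smiles")  # either column may hold the input molecule
MIN_SAMPLES = 30                   # the fitting script skips smaller sets


def check_val_csv(handle):
    """Return (ok, message) for an open val_predictions.csv-style file."""
    reader = csv.DictReader(handle)
    cols = {c.lower(): c for c in (reader.fieldnames or [])}
    seq_col = next((cols[c] for c in SEQ_COLS if c in cols), None)
    if seq_col is None or "y_pred" not in cols or "y_true" not in cols:
        return False, f"missing required columns, found: {reader.fieldnames}"
    usable = 0
    for row in reader:
        try:
            pred = float(row[cols["y_pred"]])
            true = float(row[cols["y_true"]])
        except (TypeError, ValueError):
            continue  # empty or non-numeric cell
        if math.isfinite(pred) and math.isfinite(true):
            usable += 1
    if usable < MIN_SAMPLES:
        return False, f"only {usable} usable rows"
    return True, f"{usable} usable rows"


if __name__ == "__main__":
    demo = "sequence,y_true,y_pred\n" + "".join(
        f"ACDE,{i},{i + 0.5}\n" for i in range(30)
    )
    print(check_val_csv(io.StringIO(demo)))
```

In practice the handle would come from `open(model_dir / "val_predictions.csv", newline="")`, where `model_dir` is the directory of the newly trained model.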
|
| 619 |
+
|
| 620 |
|
| 621 |
|
| 622 |
+
## Model Architecture
|
| 623 |
|
| 624 |
+
- **Sequence Embeddings:** [ESM-2 650M model](https://huggingface.co/facebook/esm2_t33_650M_UR50D) / [PeptideCLM model](https://huggingface.co/aaronfeller/PeptideCLM-23M-all) / [ChemBERTa](https://huggingface.co/DeepChem/ChemBERTa-77M-MLM). Foundational embeddings are frozen.
|
| 625 |
- **XGBoost Model:** Gradient boosting on pooled embedding features for efficient, high-performance prediction.
|
| 626 |
- **CNN/Transformer Model:** One-dimensional convolutional/self-attention transformer networks operating on unpooled embeddings to capture local sequence patterns.
|
| 627 |
- **Binding Model:** Transformer-based architecture with cross-attention between protein and peptide representations.
|
| 628 |
- **SVR Model:** Support Vector Regression applied to pooled embeddings, providing a kernel-based, nonparametric regression baseline that is robust on smaller or noisy datasets.
|
| 629 |
- **Others:** SVM and Elastic Nets were trained with [RAPIDS cuML](https://github.com/rapidsai/cuml), which requires a CUDA environment and is therefore not supported in the web app. Model checkpoints remain available in the Hugging Face repository.
|
| 630 |
|
| 631 |
+
## Troubleshooting
|
| 632 |
|
| 633 |
### LFS Download Issues
|
| 634 |
|
|
|
|
| 640 |
--local-dir . \
|
| 641 |
--local-dir-use-symlinks False
|
| 642 |
```
|
|
|
|
|
|
|
| 643 |
|
| 644 |
+
## Citation
|
| 645 |
|
| 646 |
If you find this repository helpful for your publications, please consider citing our paper:
|
| 647 |
|
| 648 |
```
|
| 649 |
+
@article {Zhang2025.12.31.697180,
|
| 650 |
author = {Zhang, Yinuo and Tang, Sophia and Chen, Tong and Mahood, Elizabeth and Vincoff, Sophia and Chatterjee, Pranam},
|
| 651 |
title = {PeptiVerse: A Unified Platform for Therapeutic Peptide Property Prediction},
|
| 652 |
+
elocation-id = {2025.12.31.697180},
|
| 653 |
year = {2026},
|
| 654 |
doi = {10.64898/2025.12.31.697180},
|
| 655 |
+
publisher = {Cold Spring Harbor Laboratory},
|
| 656 |
+
URL = {https://www.biorxiv.org/content/early/2026/01/03/2025.12.31.697180},
|
| 657 |
+
eprint = {https://www.biorxiv.org/content/early/2026/01/03/2025.12.31.697180.full.pdf},
|
| 658 |
journal = {bioRxiv}
|
| 659 |
}
|
| 660 |
```
|
| 661 |
+
To use this repository, you agree to abide by the MIT License.
|
download_light.py
CHANGED
|
@@ -15,8 +15,9 @@ from inference import (
|
|
| 15 |
# -----------------------------
|
| 16 |
# Config
|
| 17 |
# -----------------------------
|
|
|
|
| 18 |
MODEL_REPO = "ChatterjeeLab/PeptiVerse"
|
| 19 |
-
DEFAULT_ASSETS_DIR = Path(
|
| 20 |
DEFAULT_MANIFEST = Path("./basic_models.txt")
|
| 21 |
|
| 22 |
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}
|
|
|
|
| 15 |
# -----------------------------
|
| 16 |
# Config
|
| 17 |
# -----------------------------
|
| 18 |
+
root = Path(__file__).resolve().parent # current script folder
|
| 19 |
MODEL_REPO = "ChatterjeeLab/PeptiVerse"
|
| 20 |
+
DEFAULT_ASSETS_DIR = Path(root) # where downloaded models live
|
| 21 |
DEFAULT_MANIFEST = Path("./basic_models.txt")
|
| 22 |
|
| 23 |
BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"}
|
fit_mapie_adaptive.py
ADDED
|
@@ -0,0 +1,333 @@
|
| 1 |
+
"""
|
| 2 |
+
Bundle format:
|
| 3 |
+
{
|
| 4 |
+
"quantile": q,
|
| 5 |
+
"sigma_model": xgb_booster,
|
| 6 |
+
"emb_tag": "wt"|"peptideclm"|"chemberta",
|
| 7 |
+
"alpha": 0.1,
|
| 8 |
+
"adaptive": True,
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
Binding affinity bundles additionally store "target_emb_tag": "wt" since
|
| 12 |
+
both binder and target embeddings are concatenated for the sigma model.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import sys
|
| 17 |
+
import numpy as np
|
| 18 |
+
import pandas as pd
|
| 19 |
+
import joblib
|
| 20 |
+
import xgboost as xgb
|
| 21 |
+
import torch
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 26 |
+
|
| 27 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 28 |
+
WEIGHT_ROOT = Path(__file__).parent
|
| 29 |
+
|
| 30 |
+
# Properties to skip
|
| 31 |
+
|
| 32 |
+
SKIP_PROPS = {"half_life", "halflife"}
|
| 33 |
+
|
| 34 |
+
def should_skip(model_dir: Path) -> bool:
|
| 35 |
+
return any(part in SKIP_PROPS for part in model_dir.parts)
|
| 36 |
+
|
| 37 |
+
# Embedding tag inference
|
| 38 |
+
|
| 39 |
+
def infer_emb_tag(folder_name: str) -> Optional[str]:
|
| 40 |
+
n = folder_name.lower()
|
| 41 |
+
if "chemberta" in n: return "chemberta"
|
| 42 |
+
if "peptideclm" in n: return "peptideclm"
|
| 43 |
+
if "smiles" in n: return "peptideclm"
|
| 44 |
+
if "wt" in n: return "wt"
|
| 45 |
+
return None
|
| 46 |
+
|
| 47 |
+
def is_binding_affinity(model_dir: Path) -> bool:
|
| 48 |
+
return "binding_affinity" in model_dir.parts
|
| 49 |
+
|
| 50 |
+
def infer_binding_emb_tags(folder_name: str):
|
| 51 |
+
"""
|
| 52 |
+
Returns (binder_emb_tag, target_emb_tag) for binding affinity folders.
|
| 53 |
+
Folder convention: {target_emb}_{binder_emb}_{pooled|unpooled}
|
| 54 |
+
e.g. wt_wt_unpooled, chemberta_smiles_unpooled, peptideclm_smiles_unpooled
|
| 55 |
+
"""
|
| 56 |
+
n = folder_name.lower()
|
| 57 |
+
# target is always ESM2 (wt)
|
| 58 |
+
target_emb = "wt"
|
| 59 |
+
# binder emb from folder name
|
| 60 |
+
if "chemberta" in n: binder_emb = "chemberta"
|
| 61 |
+
elif "peptideclm" in n: binder_emb = "peptideclm"
|
| 62 |
+
else: binder_emb = "wt"
|
| 63 |
+
return binder_emb, target_emb
|
| 64 |
+
|
| 65 |
+
SEQ_CANDIDATES = ["sequence", "smiles", "seq", "peptide", "molecule"]
|
| 66 |
+
PRED_CANDIDATES = ["y_prob", "y_pred", "pred_prob", "pred_score", "score", "pred", "prediction"]
|
| 67 |
+
TRUE_CANDIDATES = ["y_true", "label", "true_label", "affinity", "y", "target"]
|
| 68 |
+
|
| 69 |
+
def resolve_col(df, candidates, label):
|
| 70 |
+
cl = {c.lower(): c for c in df.columns}
|
| 71 |
+
for c in candidates:
|
| 72 |
+
if c.lower() in cl:
|
| 73 |
+
return cl[c.lower()]
|
| 74 |
+
raise ValueError(f"Cannot find {label} column. Available: {list(df.columns)}")
|
| 75 |
+
|
| 76 |
+
_embedders = {}
|
| 77 |
+
|
| 78 |
+
def get_embedder(emb_tag: str):
|
| 79 |
+
if emb_tag in _embedders:
|
| 80 |
+
return _embedders[emb_tag]
|
| 81 |
+
if emb_tag == "wt":
|
| 82 |
+
from inference_new import WTEmbedder
|
| 83 |
+
emb = WTEmbedder(DEVICE)
|
| 84 |
+
elif emb_tag == "peptideclm":
|
| 85 |
+
from inference_new import SMILESEmbedder
|
| 86 |
+
emb = SMILESEmbedder(
|
| 87 |
+
DEVICE,
|
| 88 |
+
vocab_path=str(WEIGHT_ROOT / "tokenizer/new_vocab.txt"),
|
| 89 |
+
splits_path=str(WEIGHT_ROOT / "tokenizer/new_splits.txt"),
|
| 90 |
+
)
|
| 91 |
+
elif emb_tag == "chemberta":
|
| 92 |
+
from inference_new import ChemBERTaEmbedder
|
| 93 |
+
emb = ChemBERTaEmbedder(DEVICE)
|
| 94 |
+
else:
|
| 95 |
+
raise ValueError(f"Unknown emb_tag: {emb_tag}")
|
| 96 |
+
_embedders[emb_tag] = emb
|
| 97 |
+
return emb
|
| 98 |
+
|
| 99 |
+
@torch.no_grad()
|
| 100 |
+
def embed_sequences(sequences: list, emb_tag: str) -> np.ndarray:
|
| 101 |
+
embedder = get_embedder(emb_tag)
|
| 102 |
+
vecs = []
|
| 103 |
+
for seq in sequences:
|
| 104 |
+
v = embedder.pooled(seq.strip())
|
| 105 |
+
vecs.append(v.cpu().float().numpy())
|
| 106 |
+
return np.vstack(vecs).astype(np.float32)
|
| 107 |
+
|
| 108 |
+
# Sigma model, simple XGB
|
| 109 |
+
|
| 110 |
+
def fit_sigma_model(X: np.ndarray, residuals: np.ndarray) -> xgb.Booster:
|
| 111 |
+
dtrain = xgb.DMatrix(X, label=residuals)
|
| 112 |
+
params = {
|
| 113 |
+
"objective": "reg:squarederror",
|
| 114 |
+
"max_depth": 4,
|
| 115 |
+
"eta": 0.05,
|
| 116 |
+
"subsample": 0.8,
|
| 117 |
+
"colsample_bytree": 0.3,
|
| 118 |
+
"min_child_weight": 5,
|
| 119 |
+
"tree_method": "hist",
|
| 120 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu",
|
| 121 |
+
"seed": 1986,
|
| 122 |
+
}
|
| 123 |
+
return xgb.train(params, dtrain, num_boost_round=200, verbose_eval=False)
|
| 124 |
+
|
| 125 |
+
# Standard model dir fitting
|
| 126 |
+
|
| 127 |
+
def fit_standard(model_dir: Path, alpha: float, dry_run: bool) -> str:
|
| 128 |
+
val_path = model_dir / "val_predictions.csv"
|
| 129 |
+
if not val_path.exists():
|
| 130 |
+
val_path = model_dir / "oof_predictions.csv"
|
| 131 |
+
if not val_path.exists():
|
| 132 |
+
return "skip (no val/oof CSV)"
|
| 133 |
+
|
| 134 |
+
emb_tag = infer_emb_tag(model_dir.name)
|
| 135 |
+
if emb_tag is None:
|
| 136 |
+
return "skip (cannot infer emb_tag)"
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
df = pd.read_csv(val_path)
|
| 140 |
+
seq_col = resolve_col(df, SEQ_CANDIDATES, "sequence")
|
| 141 |
+
pred_col = resolve_col(df, PRED_CANDIDATES, "pred")
|
| 142 |
+
true_col = resolve_col(df, TRUE_CANDIDATES, "true")
|
| 143 |
+
except Exception as e:
|
| 144 |
+
return f"error: {e}"
|
| 145 |
+
|
| 146 |
+
sequences = df[seq_col].astype(str).tolist()
|
| 147 |
+
y_pred = df[pred_col].values.astype(np.float64)
|
| 148 |
+
y_true = df[true_col].values.astype(np.float64)
|
| 149 |
+
|
| 150 |
+
mask = np.isfinite(y_pred) & np.isfinite(y_true)
|
| 151 |
+
sequences = [s for s, m in zip(sequences, mask) if m]
|
| 152 |
+
y_pred, y_true = y_pred[mask], y_true[mask]
|
| 153 |
+
n = len(y_pred)
|
| 154 |
+
|
| 155 |
+
if n < 30:
|
| 156 |
+
return f"skip (only {n} samples)"
|
| 157 |
+
|
| 158 |
+
if dry_run:
|
| 159 |
+
return f"would fit (n={n}, emb={emb_tag})"
|
| 160 |
+
|
| 161 |
+
try:
|
| 162 |
+
X = embed_sequences(sequences, emb_tag)
|
| 163 |
+
except Exception as e:
|
| 164 |
+
return f"error embedding: {e}"
|
| 165 |
+
|
| 166 |
+
residuals = np.abs(y_true - y_pred).astype(np.float32)
|
| 167 |
+
sigma_model = fit_sigma_model(X, residuals)
|
| 168 |
+
sigma_cal = np.clip(sigma_model.predict(xgb.DMatrix(X)).astype(np.float64), 1e-6, None)
|
| 169 |
+
norm_scores = (residuals / sigma_cal)
|
| 170 |
+
level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
|
| 171 |
+
q = float(np.quantile(norm_scores, level))
|
| 172 |
+
lo, hi = y_pred - q * sigma_cal, y_pred + q * sigma_cal
|
| 173 |
+
coverage = float(np.mean((y_true >= lo) & (y_true <= hi)))
|
| 174 |
+
avg_width = float(np.mean(hi - lo))
|
| 175 |
+
|
| 176 |
+
bundle = {"quantile": q, "sigma_model": sigma_model,
|
| 177 |
+
"emb_tag": emb_tag, "alpha": alpha, "adaptive": True}
|
| 178 |
+
joblib.dump(bundle, model_dir / "mapie_calibration.joblib")
|
| 179 |
+
return f"ok (n={n}, emb={emb_tag}, q={q:.4f}, cov={coverage:.3f}, avg_width={avg_width:.3f})"
|
| 180 |
+
|
| 181 |
+
# Binding affinity fitting
|
| 182 |
+
|
| 183 |
+
def fit_binding_affinity(model_dir: Path, alpha: float, dry_run: bool) -> str:
|
| 184 |
+
val_path = model_dir / "val_predictions.csv"
|
| 185 |
+
if not val_path.exists():
|
| 186 |
+
return "skip (no val_predictions.csv)"
|
| 187 |
+
|
| 188 |
+
binder_emb, target_emb = infer_binding_emb_tags(model_dir.name)
|
| 189 |
+
|
| 190 |
+
try:
|
| 191 |
+
df = pd.read_csv(val_path)
|
| 192 |
+
except Exception as e:
|
| 193 |
+
return f"error reading CSV: {e}"
|
| 194 |
+
|
| 195 |
+
# Binding affinity CSV has both sequence (binder) and target_sequence
|
| 196 |
+
cl = {c.lower(): c for c in df.columns}
|
| 197 |
+
if "sequence" not in cl or "target_sequence" not in cl:
|
| 198 |
+
return f"skip (missing sequence/target_sequence columns, have: {list(df.columns)})"
|
| 199 |
+
|
| 200 |
+
binder_seqs = df[cl["sequence"]].astype(str).tolist()
|
| 201 |
+
target_seqs = df[cl["target_sequence"]].astype(str).tolist()
|
| 202 |
+
|
| 203 |
+
try:
|
| 204 |
+
pred_col = resolve_col(df, PRED_CANDIDATES, "pred")
|
| 205 |
+
true_col = resolve_col(df, TRUE_CANDIDATES, "true")
|
| 206 |
+
except Exception as e:
|
| 207 |
+
return f"error: {e}"
|
| 208 |
+
|
| 209 |
+
y_pred = df[pred_col].values.astype(np.float64)
|
| 210 |
+
y_true = df[true_col].values.astype(np.float64)
|
| 211 |
+
|
| 212 |
+
mask = np.isfinite(y_pred) & np.isfinite(y_true)
|
| 213 |
+
binder_seqs = [s for s, m in zip(binder_seqs, mask) if m]
|
| 214 |
+
target_seqs = [s for s, m in zip(target_seqs, mask) if m]
|
| 215 |
+
y_pred, y_true = y_pred[mask], y_true[mask]
|
| 216 |
+
n = len(y_pred)
|
| 217 |
+
|
| 218 |
+
if n < 30:
|
| 219 |
+
return f"skip (only {n} samples)"
|
| 220 |
+
|
| 221 |
+
if dry_run:
|
| 222 |
+
return f"would fit (n={n}, binder_emb={binder_emb}, target_emb={target_emb})"
|
| 223 |
+
|
| 224 |
+
try:
|
| 225 |
+
X_binder = embed_sequences(binder_seqs, binder_emb) # (n, H_b)
|
| 226 |
+
X_target = embed_sequences(target_seqs, target_emb) # (n, H_t)
|
| 227 |
+
X = np.concatenate([X_target, X_binder], axis=1) # (n, H_t+H_b)
|
| 228 |
+
except Exception as e:
|
| 229 |
+
return f"error embedding: {e}"
|
| 230 |
+
|
| 231 |
+
# Absolute residuals on held-out validation predictions.
|
| 232 |
+
# Equivalent to ResidualNormalisedScore.get_signed_conformity_scores()
|
| 233 |
+
# in MAPIE (Cordier et al. 2023), which computes |y - y_pred| / sigma_hat.
|
| 234 |
+
# We compute residuals first, then fit sigma_hat below, matching MAPIE's
|
| 235 |
+
# two-stage procedure (fit residual estimator → normalize → take quantile).
|
| 236 |
+
residuals = np.abs(y_true - y_pred).astype(np.float32)
|
| 237 |
+
|
| 238 |
+
# Fit sigma_hat: auxiliary XGBoost regressor trained on (embeddings, |residuals|).
|
| 239 |
+
# Corresponds to ResidualNormalisedScore's residual_estimator fitted on
|
| 240 |
+
# (X_res, |y_res - y_hat_res|) per the MAPIE tutorial. MAPIE fits sigma_hat
|
| 241 |
+
# on log-residuals and exponentiates predictions to ensure positivity; we
|
| 242 |
+
# instead clip sigma to 1e-6 for the same effect.
|
| 243 |
+
sigma_model = fit_sigma_model(X, residuals)
|
| 244 |
+
sigma_cal = np.clip(sigma_model.predict(xgb.DMatrix(X)).astype(np.float64), 1e-6, None)
|
| 245 |
+
|
| 246 |
+
# Normalized conformity scores: s_i = |y_i - y_hat_i| / sigma_hat(x_i).
|
| 247 |
+
# This is the ResidualNormalisedScore formula from MAPIE. Larger scores
|
| 248 |
+
# encode worse agreement between prediction and observation (Vovk et al. 2005).
|
| 249 |
+
norm_scores = residuals / sigma_cal
|
| 250 |
+
|
| 251 |
+
# Finite-sample corrected conformal quantile at level ceil((n+1)(1-alpha))/n.
|
| 252 |
+
# Guarantees marginal coverage >= 1-alpha under exchangeability
|
| 253 |
+
# (Lei et al. 2018, Theorem 1).
|
| 254 |
+
level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
|
| 255 |
+
q = float(np.quantile(norm_scores, level))
|
| 256 |
+
lo, hi = y_pred - q * sigma_cal, y_pred + q * sigma_cal
|
| 257 |
+
coverage = float(np.mean((y_true >= lo) & (y_true <= hi)))
|
| 258 |
+
avg_width = float(np.mean(hi - lo))
|
| 259 |
+
|
| 260 |
+
bundle = {
|
| 261 |
+
"quantile": q,
|
| 262 |
+
"sigma_model": sigma_model,
|
| 263 |
+
"emb_tag": binder_emb,
|
| 264 |
+
"target_emb_tag": target_emb,
|
| 265 |
+
"alpha": alpha,
|
| 266 |
+
"adaptive": True,
|
| 267 |
+
}
|
| 268 |
+
joblib.dump(bundle, model_dir / "mapie_calibration.joblib")
|
| 269 |
+
return (f"ok (n={n}, binder={binder_emb}, target={target_emb}, "
|
| 270 |
+
f"q={q:.4f}, cov={coverage:.3f}, avg_width={avg_width:.3f})")
|
| 271 |
+
|
| 272 |
+
MODEL_PATTERNS = [
|
| 273 |
+
"xgb_*", "enet_*", "svm_*", "svr_*", "mlp_*", "cnn_*", "transformer_*",
|
| 274 |
+
"wt_wt_*", "wt_smiles_*", "peptideclm_smiles_*", "chemberta_smiles_*",
|
| 275 |
+
]
|
| 276 |
+
|
| 277 |
+
def main():
|
| 278 |
+
parser = argparse.ArgumentParser()
|
| 279 |
+
parser.add_argument("--root", type=Path, required=True)
|
| 280 |
+
parser.add_argument("--alpha", type=float, default=0.1)
|
| 281 |
+
parser.add_argument("--prop", type=str, default=None,
|
| 282 |
+
help="Only process a specific property subfolder")
|
| 283 |
+
parser.add_argument("--dry-run", action="store_true")
|
| 284 |
+
parser.add_argument("--overwrite", action="store_true")
|
| 285 |
+
args = parser.parse_args()
|
| 286 |
+
|
| 287 |
+
search_root = args.root / args.prop if args.prop else args.root
|
| 288 |
+
|
| 289 |
+
model_dirs = []
|
| 290 |
+
for pat in MODEL_PATTERNS:
|
| 291 |
+
model_dirs.extend(sorted(search_root.rglob(pat)))
|
| 292 |
+
model_dirs = [d for d in model_dirs if d.is_dir()]
|
| 293 |
+
|
| 294 |
+
print(f"Found {len(model_dirs)} model dirs under {search_root}")
|
| 295 |
+
if args.dry_run:
|
| 296 |
+
print("DRY RUN\n")
|
| 297 |
+
|
| 298 |
+
counts = {"ok": 0, "skip": 0, "error": 0}
|
| 299 |
+
|
| 300 |
+
for model_dir in model_dirs:
|
| 301 |
+
rel = model_dir.relative_to(args.root)
|
| 302 |
+
|
| 303 |
+
if should_skip(model_dir):
|
| 304 |
+
print(f" SKIP {rel} (halflife: no sequence in OOF CSV)")
|
| 305 |
+
counts["skip"] += 1
|
| 306 |
+
continue
|
| 307 |
+
|
| 308 |
+
out = model_dir / "mapie_calibration.joblib"
|
| 309 |
+
if out.exists() and not args.overwrite:
|
| 310 |
+
try:
|
| 311 |
+
b = joblib.load(out)
|
| 312 |
+
if b.get("adaptive"):
|
| 313 |
+
print(f" OK {rel} (already adaptive)")
|
| 314 |
+
counts["ok"] += 1
|
| 315 |
+
continue
|
| 316 |
+
except Exception:
|
| 317 |
+
pass
|
| 318 |
+
|
| 319 |
+
print(f" FITTING {rel} ...", end=" ", flush=True)
|
| 320 |
+
if is_binding_affinity(model_dir):
|
| 321 |
+
status = fit_binding_affinity(model_dir, args.alpha, args.dry_run)
|
| 322 |
+
else:
|
| 323 |
+
status = fit_standard(model_dir, args.alpha, args.dry_run)
|
| 324 |
+
|
| 325 |
+
tag = "ok" if status.startswith("ok") else ("skip" if status.startswith("skip") else "error")
|
| 326 |
+
counts[tag] += 1
|
| 327 |
+
print(status)
|
| 328 |
+
|
| 329 |
+
print(f"\nDone. {counts}")
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
if __name__ == "__main__":
|
| 333 |
+
main()
|
inference.py
CHANGED
|
@@ -361,6 +361,11 @@ def _mapie_uncertainty(mapie_bundle: dict, score: float,
|
|
| 361 |
if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
|
| 362 |
q = float(mapie_bundle["quantile"])
|
| 363 |
if embedding is not None:
|
| 364 |
sigma_model = mapie_bundle["sigma_model"]
|
| 365 |
sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
|
| 366 |
sigma = max(sigma, 1e-6)
|
|
|
|
| 361 |
if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
|
| 362 |
q = float(mapie_bundle["quantile"])
|
| 363 |
if embedding is not None:
|
| 364 |
+
# Adaptive interval: y_hat ± q * sigma_hat(x).
|
| 365 |
+
# Equivalent to MAPIE's get_estimation_distribution():
|
| 366 |
+
# y_pred + conformity_scores * r_pred
|
| 367 |
+
# where conformity_scores=q and r_pred=sigma_hat(x).
|
| 368 |
+
# (ResidualNormalisedScore, Cordier et al. 2023)
|
| 369 |
sigma_model = mapie_bundle["sigma_model"]
|
| 370 |
sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
|
| 371 |
sigma = max(sigma, 1e-6)
|