github-actions commited on
Commit ·
4466c5e
0
Parent(s):
Sync from GitHub main
Browse files- .github/workflows/push_to_hf.yml +38 -0
- .gitignore +7 -0
- Data/Data Editors/csvCleanup.py +14 -0
- Data/Data Editors/csvCombiner.py +58 -0
- Data/Data Editors/fastaCleanup.py +40 -0
- Data/Sequence Fastas/amps.fasta +0 -0
- Data/Sequence Fastas/non_amps.fasta +0 -0
- Data/ampData.csv +0 -0
- MLModels/ampModel.ipynb +303 -0
- PeptideAI/Data/Data Editors/csvCleanup.py +14 -0
- PeptideAI/Data/Data Editors/csvCombiner.py +58 -0
- PeptideAI/Data/Data Editors/fastaCleanup.py +40 -0
- PeptideAI/Data/Sequence Fastas/amps.fasta +0 -0
- PeptideAI/Data/Sequence Fastas/non_amps.fasta +0 -0
- PeptideAI/Data/ampData.csv +0 -0
- PeptideAI/MLModels/ampModel.ipynb +303 -0
- PeptideAI/StreamlitApp/StreamlitApp.py +368 -0
- PeptideAI/StreamlitApp/utils/__init__.py +0 -0
- PeptideAI/StreamlitApp/utils/__pycache__/__init__.cpython-313.pyc +0 -0
- PeptideAI/StreamlitApp/utils/__pycache__/analyze.cpython-313.pyc +0 -0
- PeptideAI/StreamlitApp/utils/__pycache__/optimize.cpython-313.pyc +0 -0
- PeptideAI/StreamlitApp/utils/__pycache__/predict.cpython-313.pyc +0 -0
- PeptideAI/StreamlitApp/utils/__pycache__/visualize.cpython-313.pyc +0 -0
- PeptideAI/StreamlitApp/utils/analyze.py +21 -0
- PeptideAI/StreamlitApp/utils/optimize.py +59 -0
- PeptideAI/StreamlitApp/utils/predict.py +140 -0
- PeptideAI/StreamlitApp/utils/rateLimit.py +30 -0
- PeptideAI/StreamlitApp/utils/visualize.py +31 -0
- README.md +22 -0
- StreamlitApp/StreamlitApp.py +368 -0
- StreamlitApp/utils/__init__.py +0 -0
- StreamlitApp/utils/analyze.py +21 -0
- StreamlitApp/utils/optimize.py +59 -0
- StreamlitApp/utils/predict.py +144 -0
- StreamlitApp/utils/rateLimit.py +30 -0
- StreamlitApp/utils/visualize.py +31 -0
- requirements.txt +9 -0
- space.yaml +6 -0
.github/workflows/push_to_hf.yml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync with Hugging Face Space
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [main]
|
| 6 |
+
workflow_dispatch:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
sync-to-hub:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout
|
| 13 |
+
uses: actions/checkout@v4
|
| 14 |
+
with:
|
| 15 |
+
fetch-depth: 0
|
| 16 |
+
lfs: true
|
| 17 |
+
|
| 18 |
+
- name: Push to Hugging Face Space
|
| 19 |
+
env:
|
| 20 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 21 |
+
run: |
|
| 22 |
+
git remote remove space 2>/dev/null || true
|
| 23 |
+
git remote add space https://m0ksh:${HF_TOKEN}@huggingface.co/spaces/m0ksh/PeptideAI-App
|
| 24 |
+
|
| 25 |
+
# Create a history-free sync branch so we don't push old binary blobs.
|
| 26 |
+
git config user.name "github-actions"
|
| 27 |
+
git config user.email "actions@github.com"
|
| 28 |
+
|
| 29 |
+
git checkout --orphan hf-sync
|
| 30 |
+
git rm -rf . >/dev/null 2>&1 || true
|
| 31 |
+
git checkout main -- .
|
| 32 |
+
|
| 33 |
+
# Hugging Face Spaces rejects raw binary blobs in git history.
|
| 34 |
+
rm -f PeptideAI/MLModels/*.pt PeptideAI/StreamlitApp/models/*.pt StreamlitApp/models/*.pt 2>/dev/null || true
|
| 35 |
+
|
| 36 |
+
git add -A
|
| 37 |
+
git commit -m "Sync from GitHub main" || true
|
| 38 |
+
git push --force space hf-sync:main
|
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.vscode/
|
| 2 |
+
Data/**/*.tmp
|
| 3 |
+
Data/**/*.log
|
| 4 |
+
MLModels/**/*.pt
|
| 5 |
+
MLModels/**/*.pth
|
| 6 |
+
StreamlitApp/utils/__pycache__/
|
| 7 |
+
StreamlitApp/models/*.pt
|
Data/Data Editors/csvCleanup.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
# Load data
|
| 4 |
+
df = pd.read_csv("cleaned_amp_data.csv")
|
| 5 |
+
|
| 6 |
+
# Drop index column if it exists
|
| 7 |
+
if 'Unnamed: 0' in df.columns:
|
| 8 |
+
df = df.drop(columns=['Unnamed: 0'])
|
| 9 |
+
|
| 10 |
+
# Drop duplicate sequences
|
| 11 |
+
df = df.drop_duplicates(subset='sequence')
|
| 12 |
+
|
| 13 |
+
# Save cleaned data
|
| 14 |
+
df.to_csv("2cleaned_amp_data.csv", index=False)
|
Data/Data Editors/csvCombiner.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from Bio import SeqIO
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
amp_fasta = "amps.fasta"
|
| 6 |
+
non_amp_fasta = "non_amps.fasta"
|
| 7 |
+
output_csv = "ampData3.csv"
|
| 8 |
+
valid_aas = set("ACDEFGHIKLMNPQRSTVWY")
|
| 9 |
+
|
| 10 |
+
# HELPER: clean and validate sequences
|
| 11 |
+
def clean_seq(seq):
|
| 12 |
+
seq = seq.strip().upper()
|
| 13 |
+
if not seq or any(aa not in valid_aas for aa in seq):
|
| 14 |
+
return None
|
| 15 |
+
return seq
|
| 16 |
+
|
| 17 |
+
# LOAD FASTAS
|
| 18 |
+
def load_fasta(filepath, label):
|
| 19 |
+
"""Load fasta file. Accepts a filename or path. If the path does not exist
|
| 20 |
+
as given, try resolving it relative to this script's directory.
|
| 21 |
+
Returns list of dicts: {"sequence": seq, "label": label}.
|
| 22 |
+
"""
|
| 23 |
+
p = Path(filepath)
|
| 24 |
+
|
| 25 |
+
if not p.exists():
|
| 26 |
+
p = Path(__file__).resolve().parent / filepath
|
| 27 |
+
if not p.exists():
|
| 28 |
+
raise FileNotFoundError(f"Fasta file not found: '{filepath}' (tried '{p}')")
|
| 29 |
+
|
| 30 |
+
records = []
|
| 31 |
+
for record in SeqIO.parse(str(p), "fasta"):
|
| 32 |
+
seq = clean_seq(str(record.seq))
|
| 33 |
+
if seq:
|
| 34 |
+
records.append({"sequence": seq, "label": label})
|
| 35 |
+
return records
|
| 36 |
+
|
| 37 |
+
amps = load_fasta(amp_fasta, 1)
|
| 38 |
+
non_amps = load_fasta(non_amp_fasta, 0)
|
| 39 |
+
|
| 40 |
+
print(f"Loaded {len(amps)} AMPs and {len(non_amps)} non-AMPs before cleaning.")
|
| 41 |
+
|
| 42 |
+
# REMOVE DUPLICATES
|
| 43 |
+
amp_df = pd.DataFrame(amps).drop_duplicates(subset=["sequence"])
|
| 44 |
+
non_amp_df = pd.DataFrame(non_amps).drop_duplicates(subset=["sequence"])
|
| 45 |
+
|
| 46 |
+
# BALANCE CLASSES
|
| 47 |
+
min_count = min(len(amp_df), len(non_amp_df))
|
| 48 |
+
amp_balanced = amp_df.sample(n=min_count, random_state=42)
|
| 49 |
+
non_amp_balanced = non_amp_df.sample(n=min_count, random_state=42)
|
| 50 |
+
|
| 51 |
+
# COMBINE AND SHUFFLE
|
| 52 |
+
final_df = pd.concat([amp_balanced, non_amp_balanced]).sample(frac=1, random_state=42).reset_index(drop=True)
|
| 53 |
+
|
| 54 |
+
# SAVE TO CSV
|
| 55 |
+
final_df.to_csv(output_csv, index=False)
|
| 56 |
+
|
| 57 |
+
print(f"Saved balanced dataset with {len(final_df)} total sequences ({min_count} per class).")
|
| 58 |
+
print(f"Output file: {output_csv}")
|
Data/Data Editors/fastaCleanup.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from Bio import SeqIO
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
# CONFIG
|
| 5 |
+
input_fasta = "amps.fasta"
|
| 6 |
+
output_fasta = "amps_clean.fasta"
|
| 7 |
+
output_csv = "amps_clean.csv"
|
| 8 |
+
|
| 9 |
+
min_len = 5
|
| 10 |
+
max_len = 100
|
| 11 |
+
valid_aas = set("ACDEFGHIKLMNPQRSTVWY")
|
| 12 |
+
|
| 13 |
+
# CLEAN FUNCTION
|
| 14 |
+
def clean_seq(seq):
|
| 15 |
+
seq = seq.strip().upper()
|
| 16 |
+
if not (min_len <= len(seq) <= max_len):
|
| 17 |
+
return None
|
| 18 |
+
if any(aa not in valid_aas for aa in seq):
|
| 19 |
+
return None
|
| 20 |
+
return seq
|
| 21 |
+
|
| 22 |
+
# READ AND CLEAN
|
| 23 |
+
clean_records = []
|
| 24 |
+
for record in SeqIO.parse(input_fasta, "fasta"):
|
| 25 |
+
seq = clean_seq(str(record.seq))
|
| 26 |
+
if seq:
|
| 27 |
+
clean_records.append(seq)
|
| 28 |
+
|
| 29 |
+
# DEDUPLICATE
|
| 30 |
+
clean_records = list(set(clean_records))
|
| 31 |
+
|
| 32 |
+
# SAVE CLEAN FASTA
|
| 33 |
+
with open(output_fasta, "w") as f:
|
| 34 |
+
for i, seq in enumerate(clean_records, start=1):
|
| 35 |
+
f.write(f">AMP_{i}\n{seq}\n")
|
| 36 |
+
|
| 37 |
+
# SAVE CSV
|
| 38 |
+
pd.DataFrame({"sequence": clean_records}).to_csv(output_csv, index=False)
|
| 39 |
+
|
| 40 |
+
print(f"✅ Cleaned {len(clean_records)} sequences saved to '{output_fasta}' and '{output_csv}'.")
|
Data/Sequence Fastas/amps.fasta
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Data/Sequence Fastas/non_amps.fasta
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Data/ampData.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MLModels/ampModel.ipynb
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"id": "PwhltETnCLY1"
|
| 7 |
+
},
|
| 8 |
+
"source": [
|
| 9 |
+
"#**AMP Classification using ProtBERT Embeddings + Fast MLP**\n",
|
| 10 |
+
"This notebook extracts ProtBERT embeddings for peptide sequences and trains a simple Multi-Layer Perceptron (MLP) to classify antimicrobial peptides (AMPs) vs non-AMPs."
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": null,
|
| 16 |
+
"metadata": {
|
| 17 |
+
"collapsed": true,
|
| 18 |
+
"id": "qv_84qo0CLY6"
|
| 19 |
+
},
|
| 20 |
+
"outputs": [],
|
| 21 |
+
"source": [
|
| 22 |
+
"!pip install torch transformers scikit-learn numpy pandas tqdm"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": null,
|
| 28 |
+
"metadata": {
|
| 29 |
+
"collapsed": true,
|
| 30 |
+
"id": "4wld_6KBCLY7"
|
| 31 |
+
},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"import torch\n",
|
| 35 |
+
"import numpy as np\n",
|
| 36 |
+
"import pandas as pd\n",
|
| 37 |
+
"from tqdm import tqdm\n",
|
| 38 |
+
"from transformers import AutoTokenizer, AutoModel\n",
|
| 39 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 40 |
+
"from sklearn.preprocessing import LabelEncoder\n",
|
| 41 |
+
"from torch import nn, optim\n",
|
| 42 |
+
"from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score\n",
|
| 43 |
+
"import sys\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 46 |
+
"print('Device:', device)"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "markdown",
|
| 51 |
+
"metadata": {
|
| 52 |
+
"id": "7n3m1GLLCLY8"
|
| 53 |
+
},
|
| 54 |
+
"source": [
|
| 55 |
+
"##Load Dataset"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {
|
| 62 |
+
"collapsed": true,
|
| 63 |
+
"id": "wAg_vM3JCLY8"
|
| 64 |
+
},
|
| 65 |
+
"outputs": [],
|
| 66 |
+
"source": [
|
| 67 |
+
"IN_COLAB = 'google.colab' in sys.modules\n",
|
| 68 |
+
"if IN_COLAB:\n",
|
| 69 |
+
" from google.colab import drive\n",
|
| 70 |
+
" drive.mount('/content/drive')\n",
|
| 71 |
+
" file_path = '/content/drive/MyDrive/ampData.csv'\n",
|
| 72 |
+
"else:\n",
|
| 73 |
+
" file_path = 'ampData.csv'\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"df = pd.read_csv(file_path)\n",
|
| 76 |
+
"df['sequence'] = df['sequence'].astype(str).str.upper().str.strip()\n",
|
| 77 |
+
"df = df.dropna(subset=['sequence','label']).reset_index(drop=True)\n",
|
| 78 |
+
"df.head()"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "markdown",
|
| 83 |
+
"metadata": {
|
| 84 |
+
"id": "8HxUUO6SCLY8"
|
| 85 |
+
},
|
| 86 |
+
"source": [
|
| 87 |
+
"## Extract ProtBERT Embeddings"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {
|
| 94 |
+
"collapsed": true,
|
| 95 |
+
"id": "CltjDxknCLY9"
|
| 96 |
+
},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert')\n",
|
| 100 |
+
"model = AutoModel.from_pretrained('Rostlab/prot_bert').to(device)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"def get_embedding(sequence):\n",
|
| 103 |
+
" seq = ' '.join(list(sequence))\n",
|
| 104 |
+
" tokens = tokenizer(seq, return_tensors='pt', truncation=True, padding=True).to(device)\n",
|
| 105 |
+
" with torch.no_grad():\n",
|
| 106 |
+
" outputs = model(**tokens)\n",
|
| 107 |
+
" emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()\n",
|
| 108 |
+
" return emb\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"embeddings = []\n",
|
| 111 |
+
"for seq in tqdm(df['sequence'], desc='Extracting Embeddings'):\n",
|
| 112 |
+
" embeddings.append(get_embedding(seq))\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"X = np.array(embeddings)\n",
|
| 115 |
+
"y = df['label'].values\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"np.save('X_embeddings.npy', X)\n",
|
| 118 |
+
"np.save('y_labels.npy', y)"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "markdown",
|
| 123 |
+
"metadata": {
|
| 124 |
+
"id": "TZpCHIpTCLY9"
|
| 125 |
+
},
|
| 126 |
+
"source": [
|
| 127 |
+
"## Train-Test Split"
|
| 128 |
+
]
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"cell_type": "code",
|
| 132 |
+
"execution_count": null,
|
| 133 |
+
"metadata": {
|
| 134 |
+
"id": "HUhsld4YCLY9"
|
| 135 |
+
},
|
| 136 |
+
"outputs": [],
|
| 137 |
+
"source": [
|
| 138 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"X_train = torch.tensor(X_train, dtype=torch.float32).to(device)\n",
|
| 141 |
+
"X_test = torch.tensor(X_test, dtype=torch.float32).to(device)\n",
|
| 142 |
+
"y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)\n",
|
| 143 |
+
"y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device)"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "markdown",
|
| 148 |
+
"metadata": {
|
| 149 |
+
"id": "aeeNh2s9CLY-"
|
| 150 |
+
},
|
| 151 |
+
"source": [
|
| 152 |
+
"## Define MLP Classifier"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": null,
|
| 158 |
+
"metadata": {
|
| 159 |
+
"collapsed": true,
|
| 160 |
+
"id": "V04ShQ1VCLY-"
|
| 161 |
+
},
|
| 162 |
+
"outputs": [],
|
| 163 |
+
"source": [
|
| 164 |
+
"class MLPClassifier(nn.Module):\n",
|
| 165 |
+
" def __init__(self, input_dim):\n",
|
| 166 |
+
" super().__init__()\n",
|
| 167 |
+
" self.layers = nn.Sequential(\n",
|
| 168 |
+
" nn.Linear(input_dim, 512),\n",
|
| 169 |
+
" nn.ReLU(),\n",
|
| 170 |
+
" nn.Dropout(0.3),\n",
|
| 171 |
+
" nn.Linear(512, 128),\n",
|
| 172 |
+
" nn.ReLU(),\n",
|
| 173 |
+
" nn.Linear(128, 1),\n",
|
| 174 |
+
" nn.Sigmoid()\n",
|
| 175 |
+
" )\n",
|
| 176 |
+
" def forward(self, x):\n",
|
| 177 |
+
" return self.layers(x)\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"model_mlp = MLPClassifier(X_train.shape[1]).to(device)\n",
|
| 180 |
+
"criterion = nn.BCELoss()\n",
|
| 181 |
+
"optimizer = optim.Adam(model_mlp.parameters(), lr=1e-4)\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"print(model_mlp)"
|
| 184 |
+
]
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"cell_type": "markdown",
|
| 188 |
+
"metadata": {
|
| 189 |
+
"id": "XAsOa6l7CLY-"
|
| 190 |
+
},
|
| 191 |
+
"source": [
|
| 192 |
+
"## Train MLP"
|
| 193 |
+
]
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "code",
|
| 197 |
+
"execution_count": null,
|
| 198 |
+
"metadata": {
|
| 199 |
+
"collapsed": true,
|
| 200 |
+
"id": "7sXSUh3WCLY-"
|
| 201 |
+
},
|
| 202 |
+
"outputs": [],
|
| 203 |
+
"source": [
|
| 204 |
+
"epochs = 20\n",
|
| 205 |
+
"batch_size = 64\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"for epoch in range(epochs):\n",
|
| 208 |
+
" model_mlp.train()\n",
|
| 209 |
+
" perm = torch.randperm(X_train.size(0))\n",
|
| 210 |
+
" total_loss = 0\n",
|
| 211 |
+
" for i in range(0, X_train.size(0), batch_size):\n",
|
| 212 |
+
" idx = perm[i:i+batch_size]\n",
|
| 213 |
+
" x_batch, y_batch = X_train[idx], y_train[idx]\n",
|
| 214 |
+
" optimizer.zero_grad()\n",
|
| 215 |
+
" outputs = model_mlp(x_batch)\n",
|
| 216 |
+
" loss = criterion(outputs, y_batch)\n",
|
| 217 |
+
" loss.backward()\n",
|
| 218 |
+
" optimizer.step()\n",
|
| 219 |
+
" total_loss += loss.item()\n",
|
| 220 |
+
" print(f\"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}\")"
|
| 221 |
+
]
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"cell_type": "markdown",
|
| 225 |
+
"metadata": {
|
| 226 |
+
"id": "A4XbUrqRCLY-"
|
| 227 |
+
},
|
| 228 |
+
"source": [
|
| 229 |
+
"## Evaluate"
|
| 230 |
+
]
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"cell_type": "code",
|
| 234 |
+
"execution_count": null,
|
| 235 |
+
"metadata": {
|
| 236 |
+
"collapsed": true,
|
| 237 |
+
"id": "YtieKVFhCLY_"
|
| 238 |
+
},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"model_mlp.eval()\n",
|
| 242 |
+
"with torch.no_grad():\n",
|
| 243 |
+
" preds = model_mlp(X_test).cpu().numpy().flatten()\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"pred_labels = (preds >= 0.5).astype(int)\n",
|
| 246 |
+
"print('ROC-AUC:', roc_auc_score(y_test.cpu(), preds))\n",
|
| 247 |
+
"print('PR-AUC:', average_precision_score(y_test.cpu(), preds))\n",
|
| 248 |
+
"print('\\nClassification Report:\\n', classification_report(y_test.cpu(), pred_labels))\n",
|
| 249 |
+
"print('Confusion Matrix:\\n', confusion_matrix(y_test.cpu(), pred_labels))"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "markdown",
|
| 254 |
+
"metadata": {
|
| 255 |
+
"id": "ADjCmp8PCLY_"
|
| 256 |
+
},
|
| 257 |
+
"source": [
|
| 258 |
+
"## Save Model"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "code",
|
| 263 |
+
"execution_count": null,
|
| 264 |
+
"metadata": {
|
| 265 |
+
"collapsed": true,
|
| 266 |
+
"id": "v0j_4vwKCLY_"
|
| 267 |
+
},
|
| 268 |
+
"outputs": [],
|
| 269 |
+
"source": [
|
| 270 |
+
"torch.save(model_mlp.state_dict(), 'fast_mlp_amp.pt')\n",
|
| 271 |
+
"print('Model saved as fast_mlp_amp.pt')"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"cell_type": "code",
|
| 276 |
+
"execution_count": null,
|
| 277 |
+
"metadata": {
|
| 278 |
+
"id": "IuJCNyBTXkBH"
|
| 279 |
+
},
|
| 280 |
+
"outputs": [],
|
| 281 |
+
"source": [
|
| 282 |
+
"from google.colab import files\n",
|
| 283 |
+
"files.download('fast_mlp_amp.pt')"
|
| 284 |
+
]
|
| 285 |
+
}
|
| 286 |
+
],
|
| 287 |
+
"metadata": {
|
| 288 |
+
"colab": {
|
| 289 |
+
"provenance": []
|
| 290 |
+
},
|
| 291 |
+
"kernelspec": {
|
| 292 |
+
"display_name": "Python 3",
|
| 293 |
+
"language": "python",
|
| 294 |
+
"name": "python3"
|
| 295 |
+
},
|
| 296 |
+
"language_info": {
|
| 297 |
+
"name": "python",
|
| 298 |
+
"version": "3.x"
|
| 299 |
+
}
|
| 300 |
+
},
|
| 301 |
+
"nbformat": 4,
|
| 302 |
+
"nbformat_minor": 0
|
| 303 |
+
}
|
PeptideAI/Data/Data Editors/csvCleanup.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
|
| 3 |
+
# Load data
|
| 4 |
+
df = pd.read_csv("cleaned_amp_data.csv")
|
| 5 |
+
|
| 6 |
+
# Drop index column if it exists
|
| 7 |
+
if 'Unnamed: 0' in df.columns:
|
| 8 |
+
df = df.drop(columns=['Unnamed: 0'])
|
| 9 |
+
|
| 10 |
+
# Drop duplicate sequences
|
| 11 |
+
df = df.drop_duplicates(subset='sequence')
|
| 12 |
+
|
| 13 |
+
# Save cleaned data
|
| 14 |
+
df.to_csv("2cleaned_amp_data.csv", index=False)
|
PeptideAI/Data/Data Editors/csvCombiner.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from Bio import SeqIO
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
amp_fasta = "amps.fasta"
|
| 6 |
+
non_amp_fasta = "non_amps.fasta"
|
| 7 |
+
output_csv = "ampData3.csv"
|
| 8 |
+
valid_aas = set("ACDEFGHIKLMNPQRSTVWY")
|
| 9 |
+
|
| 10 |
+
# HELPER: clean and validate sequences
|
| 11 |
+
def clean_seq(seq):
|
| 12 |
+
seq = seq.strip().upper()
|
| 13 |
+
if not seq or any(aa not in valid_aas for aa in seq):
|
| 14 |
+
return None
|
| 15 |
+
return seq
|
| 16 |
+
|
| 17 |
+
# LOAD FASTAS
|
| 18 |
+
def load_fasta(filepath, label):
|
| 19 |
+
"""Load fasta file. Accepts a filename or path. If the path does not exist
|
| 20 |
+
as given, try resolving it relative to this script's directory.
|
| 21 |
+
Returns list of dicts: {"sequence": seq, "label": label}.
|
| 22 |
+
"""
|
| 23 |
+
p = Path(filepath)
|
| 24 |
+
|
| 25 |
+
if not p.exists():
|
| 26 |
+
p = Path(__file__).resolve().parent / filepath
|
| 27 |
+
if not p.exists():
|
| 28 |
+
raise FileNotFoundError(f"Fasta file not found: '{filepath}' (tried '{p}')")
|
| 29 |
+
|
| 30 |
+
records = []
|
| 31 |
+
for record in SeqIO.parse(str(p), "fasta"):
|
| 32 |
+
seq = clean_seq(str(record.seq))
|
| 33 |
+
if seq:
|
| 34 |
+
records.append({"sequence": seq, "label": label})
|
| 35 |
+
return records
|
| 36 |
+
|
| 37 |
+
amps = load_fasta(amp_fasta, 1)
|
| 38 |
+
non_amps = load_fasta(non_amp_fasta, 0)
|
| 39 |
+
|
| 40 |
+
print(f"Loaded {len(amps)} AMPs and {len(non_amps)} non-AMPs before cleaning.")
|
| 41 |
+
|
| 42 |
+
# REMOVE DUPLICATES
|
| 43 |
+
amp_df = pd.DataFrame(amps).drop_duplicates(subset=["sequence"])
|
| 44 |
+
non_amp_df = pd.DataFrame(non_amps).drop_duplicates(subset=["sequence"])
|
| 45 |
+
|
| 46 |
+
# BALANCE CLASSES
|
| 47 |
+
min_count = min(len(amp_df), len(non_amp_df))
|
| 48 |
+
amp_balanced = amp_df.sample(n=min_count, random_state=42)
|
| 49 |
+
non_amp_balanced = non_amp_df.sample(n=min_count, random_state=42)
|
| 50 |
+
|
| 51 |
+
# COMBINE AND SHUFFLE
|
| 52 |
+
final_df = pd.concat([amp_balanced, non_amp_balanced]).sample(frac=1, random_state=42).reset_index(drop=True)
|
| 53 |
+
|
| 54 |
+
# SAVE TO CSV
|
| 55 |
+
final_df.to_csv(output_csv, index=False)
|
| 56 |
+
|
| 57 |
+
print(f"Saved balanced dataset with {len(final_df)} total sequences ({min_count} per class).")
|
| 58 |
+
print(f"Output file: {output_csv}")
|
PeptideAI/Data/Data Editors/fastaCleanup.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from Bio import SeqIO
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
# CONFIG
|
| 5 |
+
input_fasta = "amps.fasta"
|
| 6 |
+
output_fasta = "amps_clean.fasta"
|
| 7 |
+
output_csv = "amps_clean.csv"
|
| 8 |
+
|
| 9 |
+
min_len = 5
|
| 10 |
+
max_len = 100
|
| 11 |
+
valid_aas = set("ACDEFGHIKLMNPQRSTVWY")
|
| 12 |
+
|
| 13 |
+
# CLEAN FUNCTION
|
| 14 |
+
def clean_seq(seq):
|
| 15 |
+
seq = seq.strip().upper()
|
| 16 |
+
if not (min_len <= len(seq) <= max_len):
|
| 17 |
+
return None
|
| 18 |
+
if any(aa not in valid_aas for aa in seq):
|
| 19 |
+
return None
|
| 20 |
+
return seq
|
| 21 |
+
|
| 22 |
+
# READ AND CLEAN
|
| 23 |
+
clean_records = []
|
| 24 |
+
for record in SeqIO.parse(input_fasta, "fasta"):
|
| 25 |
+
seq = clean_seq(str(record.seq))
|
| 26 |
+
if seq:
|
| 27 |
+
clean_records.append(seq)
|
| 28 |
+
|
| 29 |
+
# DEDUPLICATE
|
| 30 |
+
clean_records = list(set(clean_records))
|
| 31 |
+
|
| 32 |
+
# SAVE CLEAN FASTA
|
| 33 |
+
with open(output_fasta, "w") as f:
|
| 34 |
+
for i, seq in enumerate(clean_records, start=1):
|
| 35 |
+
f.write(f">AMP_{i}\n{seq}\n")
|
| 36 |
+
|
| 37 |
+
# SAVE CSV
|
| 38 |
+
pd.DataFrame({"sequence": clean_records}).to_csv(output_csv, index=False)
|
| 39 |
+
|
| 40 |
+
print(f"✅ Cleaned {len(clean_records)} sequences saved to '{output_fasta}' and '{output_csv}'.")
|
PeptideAI/Data/Sequence Fastas/amps.fasta
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PeptideAI/Data/Sequence Fastas/non_amps.fasta
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PeptideAI/Data/ampData.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PeptideAI/MLModels/ampModel.ipynb
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"id": "PwhltETnCLY1"
|
| 7 |
+
},
|
| 8 |
+
"source": [
|
| 9 |
+
"#**AMP Classification using ProtBERT Embeddings + Fast MLP**\n",
|
| 10 |
+
"This notebook extracts ProtBERT embeddings for peptide sequences and trains a simple Multi-Layer Perceptron (MLP) to classify antimicrobial peptides (AMPs) vs non-AMPs."
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": null,
|
| 16 |
+
"metadata": {
|
| 17 |
+
"collapsed": true,
|
| 18 |
+
"id": "qv_84qo0CLY6"
|
| 19 |
+
},
|
| 20 |
+
"outputs": [],
|
| 21 |
+
"source": [
|
| 22 |
+
"!pip install torch transformers scikit-learn numpy pandas tqdm"
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "code",
|
| 27 |
+
"execution_count": null,
|
| 28 |
+
"metadata": {
|
| 29 |
+
"collapsed": true,
|
| 30 |
+
"id": "4wld_6KBCLY7"
|
| 31 |
+
},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"import torch\n",
|
| 35 |
+
"import numpy as np\n",
|
| 36 |
+
"import pandas as pd\n",
|
| 37 |
+
"from tqdm import tqdm\n",
|
| 38 |
+
"from transformers import AutoTokenizer, AutoModel\n",
|
| 39 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 40 |
+
"from sklearn.preprocessing import LabelEncoder\n",
|
| 41 |
+
"from torch import nn, optim\n",
|
| 42 |
+
"from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, average_precision_score\n",
|
| 43 |
+
"import sys\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
| 46 |
+
"print('Device:', device)"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "markdown",
|
| 51 |
+
"metadata": {
|
| 52 |
+
"id": "7n3m1GLLCLY8"
|
| 53 |
+
},
|
| 54 |
+
"source": [
|
| 55 |
+
"##Load Dataset"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"metadata": {
|
| 62 |
+
"collapsed": true,
|
| 63 |
+
"id": "wAg_vM3JCLY8"
|
| 64 |
+
},
|
| 65 |
+
"outputs": [],
|
| 66 |
+
"source": [
|
| 67 |
+
"IN_COLAB = 'google.colab' in sys.modules\n",
|
| 68 |
+
"if IN_COLAB:\n",
|
| 69 |
+
" from google.colab import drive\n",
|
| 70 |
+
" drive.mount('/content/drive')\n",
|
| 71 |
+
" file_path = '/content/drive/MyDrive/ampData.csv'\n",
|
| 72 |
+
"else:\n",
|
| 73 |
+
" file_path = 'ampData.csv'\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"df = pd.read_csv(file_path)\n",
|
| 76 |
+
"df['sequence'] = df['sequence'].astype(str).str.upper().str.strip()\n",
|
| 77 |
+
"df = df.dropna(subset=['sequence','label']).reset_index(drop=True)\n",
|
| 78 |
+
"df.head()"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "markdown",
|
| 83 |
+
"metadata": {
|
| 84 |
+
"id": "8HxUUO6SCLY8"
|
| 85 |
+
},
|
| 86 |
+
"source": [
|
| 87 |
+
"## Extract ProtBERT Embeddings"
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"cell_type": "code",
|
| 92 |
+
"execution_count": null,
|
| 93 |
+
"metadata": {
|
| 94 |
+
"collapsed": true,
|
| 95 |
+
"id": "CltjDxknCLY9"
|
| 96 |
+
},
|
| 97 |
+
"outputs": [],
|
| 98 |
+
"source": [
|
| 99 |
+
"tokenizer = AutoTokenizer.from_pretrained('Rostlab/prot_bert')\n",
|
| 100 |
+
"model = AutoModel.from_pretrained('Rostlab/prot_bert').to(device)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"def get_embedding(sequence):\n",
|
| 103 |
+
" seq = ' '.join(list(sequence))\n",
|
| 104 |
+
" tokens = tokenizer(seq, return_tensors='pt', truncation=True, padding=True).to(device)\n",
|
| 105 |
+
" with torch.no_grad():\n",
|
| 106 |
+
" outputs = model(**tokens)\n",
|
| 107 |
+
" emb = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()\n",
|
| 108 |
+
" return emb\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"embeddings = []\n",
|
| 111 |
+
"for seq in tqdm(df['sequence'], desc='Extracting Embeddings'):\n",
|
| 112 |
+
" embeddings.append(get_embedding(seq))\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"X = np.array(embeddings)\n",
|
| 115 |
+
"y = df['label'].values\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"np.save('X_embeddings.npy', X)\n",
|
| 118 |
+
"np.save('y_labels.npy', y)"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"cell_type": "markdown",
|
| 123 |
+
"metadata": {
|
| 124 |
+
"id": "TZpCHIpTCLY9"
|
| 125 |
+
},
|
| 126 |
+
"source": [
|
| 127 |
+
"## Train-Test Split"
|
| 128 |
+
]
|
| 129 |
+
},
|
| 130 |
+
{
|
| 131 |
+
"cell_type": "code",
|
| 132 |
+
"execution_count": null,
|
| 133 |
+
"metadata": {
|
| 134 |
+
"id": "HUhsld4YCLY9"
|
| 135 |
+
},
|
| 136 |
+
"outputs": [],
|
| 137 |
+
"source": [
|
| 138 |
+
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"X_train = torch.tensor(X_train, dtype=torch.float32).to(device)\n",
|
| 141 |
+
"X_test = torch.tensor(X_test, dtype=torch.float32).to(device)\n",
|
| 142 |
+
"y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)\n",
|
| 143 |
+
"y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device)"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "markdown",
|
| 148 |
+
"metadata": {
|
| 149 |
+
"id": "aeeNh2s9CLY-"
|
| 150 |
+
},
|
| 151 |
+
"source": [
|
| 152 |
+
"## Define MLP Classifier"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": null,
|
| 158 |
+
"metadata": {
|
| 159 |
+
"collapsed": true,
|
| 160 |
+
"id": "V04ShQ1VCLY-"
|
| 161 |
+
},
|
| 162 |
+
"outputs": [],
|
| 163 |
+
"source": [
|
| 164 |
+
"class MLPClassifier(nn.Module):\n",
|
| 165 |
+
" def __init__(self, input_dim):\n",
|
| 166 |
+
" super().__init__()\n",
|
| 167 |
+
" self.layers = nn.Sequential(\n",
|
| 168 |
+
" nn.Linear(input_dim, 512),\n",
|
| 169 |
+
" nn.ReLU(),\n",
|
| 170 |
+
" nn.Dropout(0.3),\n",
|
| 171 |
+
" nn.Linear(512, 128),\n",
|
| 172 |
+
" nn.ReLU(),\n",
|
| 173 |
+
" nn.Linear(128, 1),\n",
|
| 174 |
+
" nn.Sigmoid()\n",
|
| 175 |
+
" )\n",
|
| 176 |
+
" def forward(self, x):\n",
|
| 177 |
+
" return self.layers(x)\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"model_mlp = MLPClassifier(X_train.shape[1]).to(device)\n",
|
| 180 |
+
"criterion = nn.BCELoss()\n",
|
| 181 |
+
"optimizer = optim.Adam(model_mlp.parameters(), lr=1e-4)\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"print(model_mlp)"
|
| 184 |
+
]
|
| 185 |
+
},
|
| 186 |
+
{
|
| 187 |
+
"cell_type": "markdown",
|
| 188 |
+
"metadata": {
|
| 189 |
+
"id": "XAsOa6l7CLY-"
|
| 190 |
+
},
|
| 191 |
+
"source": [
|
| 192 |
+
"## Train MLP"
|
| 193 |
+
]
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "code",
|
| 197 |
+
"execution_count": null,
|
| 198 |
+
"metadata": {
|
| 199 |
+
"collapsed": true,
|
| 200 |
+
"id": "7sXSUh3WCLY-"
|
| 201 |
+
},
|
| 202 |
+
"outputs": [],
|
| 203 |
+
"source": [
|
| 204 |
+
"epochs = 20\n",
|
| 205 |
+
"batch_size = 64\n",
|
| 206 |
+
"\n",
|
| 207 |
+
"for epoch in range(epochs):\n",
|
| 208 |
+
" model_mlp.train()\n",
|
| 209 |
+
" perm = torch.randperm(X_train.size(0))\n",
|
| 210 |
+
" total_loss = 0\n",
|
| 211 |
+
" for i in range(0, X_train.size(0), batch_size):\n",
|
| 212 |
+
" idx = perm[i:i+batch_size]\n",
|
| 213 |
+
" x_batch, y_batch = X_train[idx], y_train[idx]\n",
|
| 214 |
+
" optimizer.zero_grad()\n",
|
| 215 |
+
" outputs = model_mlp(x_batch)\n",
|
| 216 |
+
" loss = criterion(outputs, y_batch)\n",
|
| 217 |
+
" loss.backward()\n",
|
| 218 |
+
" optimizer.step()\n",
|
| 219 |
+
" total_loss += loss.item()\n",
|
| 220 |
+
" print(f\"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}\")"
|
| 221 |
+
]
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"cell_type": "markdown",
|
| 225 |
+
"metadata": {
|
| 226 |
+
"id": "A4XbUrqRCLY-"
|
| 227 |
+
},
|
| 228 |
+
"source": [
|
| 229 |
+
"## Evaluate"
|
| 230 |
+
]
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"cell_type": "code",
|
| 234 |
+
"execution_count": null,
|
| 235 |
+
"metadata": {
|
| 236 |
+
"collapsed": true,
|
| 237 |
+
"id": "YtieKVFhCLY_"
|
| 238 |
+
},
|
| 239 |
+
"outputs": [],
|
| 240 |
+
"source": [
|
| 241 |
+
"model_mlp.eval()\n",
|
| 242 |
+
"with torch.no_grad():\n",
|
| 243 |
+
" preds = model_mlp(X_test).cpu().numpy().flatten()\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"pred_labels = (preds >= 0.5).astype(int)\n",
|
| 246 |
+
"print('ROC-AUC:', roc_auc_score(y_test.cpu(), preds))\n",
|
| 247 |
+
"print('PR-AUC:', average_precision_score(y_test.cpu(), preds))\n",
|
| 248 |
+
"print('\\nClassification Report:\\n', classification_report(y_test.cpu(), pred_labels))\n",
|
| 249 |
+
"print('Confusion Matrix:\\n', confusion_matrix(y_test.cpu(), pred_labels))"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "markdown",
|
| 254 |
+
"metadata": {
|
| 255 |
+
"id": "ADjCmp8PCLY_"
|
| 256 |
+
},
|
| 257 |
+
"source": [
|
| 258 |
+
"## Save Model"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "code",
|
| 263 |
+
"execution_count": null,
|
| 264 |
+
"metadata": {
|
| 265 |
+
"collapsed": true,
|
| 266 |
+
"id": "v0j_4vwKCLY_"
|
| 267 |
+
},
|
| 268 |
+
"outputs": [],
|
| 269 |
+
"source": [
|
| 270 |
+
"torch.save(model_mlp.state_dict(), 'fast_mlp_amp.pt')\n",
|
| 271 |
+
"print('Model saved as fast_mlp_amp.pt')"
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"cell_type": "code",
|
| 276 |
+
"execution_count": null,
|
| 277 |
+
"metadata": {
|
| 278 |
+
"id": "IuJCNyBTXkBH"
|
| 279 |
+
},
|
| 280 |
+
"outputs": [],
|
| 281 |
+
"source": [
|
| 282 |
+
"from google.colab import files\n",
|
| 283 |
+
"files.download('fast_mlp_amp.pt')"
|
| 284 |
+
]
|
| 285 |
+
}
|
| 286 |
+
],
|
| 287 |
+
"metadata": {
|
| 288 |
+
"colab": {
|
| 289 |
+
"provenance": []
|
| 290 |
+
},
|
| 291 |
+
"kernelspec": {
|
| 292 |
+
"display_name": "Python 3",
|
| 293 |
+
"language": "python",
|
| 294 |
+
"name": "python3"
|
| 295 |
+
},
|
| 296 |
+
"language_info": {
|
| 297 |
+
"name": "python",
|
| 298 |
+
"version": "3.x"
|
| 299 |
+
}
|
| 300 |
+
},
|
| 301 |
+
"nbformat": 4,
|
| 302 |
+
"nbformat_minor": 0
|
| 303 |
+
}
|
PeptideAI/StreamlitApp/StreamlitApp.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PeptideAI Streamlit front-end.

Pages: Predict (batch AMP classification), Analyze (single-sequence
physicochemical report), Optimize (greedy mutation search), Visualize
(t-SNE of model embeddings), About. All user-visible state is kept in
``st.session_state`` so results survive page navigation.
"""

import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torch
import plotly.express as px
from sklearn.manifold import TSNE

# modular imports
from utils.predict import load_model, predict_amp, encode_sequence
from utils.analyze import aa_composition, compute_properties
from utils.optimize import optimize_sequence

# APP CONFIG
st.set_page_config(page_title="AMP Predictor", layout="wide")

# App title
st.title("PeptideAI: Antimicrobial Peptide Predictor and Optimizer")
st.write("Use the sidebar to navigate between prediction, analysis, optimization, and visualization tools.")
st.markdown("---")

# SESSION STATE KEYS (one-time init)
if "predictions" not in st.session_state:
    st.session_state.predictions = []  # list of dicts, one per predicted sequence
if "predict_ran" not in st.session_state:
    st.session_state.predict_ran = False
if "analyze_input" not in st.session_state:
    st.session_state.analyze_input = ""  # last analyze input
if "analyze_output" not in st.session_state:
    # 6-tuple: (label, conf, conf_display, comp, props, analysis)
    st.session_state.analyze_output = None
if "optimize_input" not in st.session_state:
    st.session_state.optimize_input = ""  # last optimize input
if "optimize_output" not in st.session_state:
    # 5-tuple: (orig_seq, orig_conf, improved_seq, improved_conf, history)
    st.session_state.optimize_output = None
if "visualize_sequences" not in st.session_state:
    st.session_state.visualize_sequences = None
if "visualize_df" not in st.session_state:
    st.session_state.visualize_df = None

# SIDEBAR: navigation + global clear
st.sidebar.header("Navigation")
page = st.sidebar.radio("Go to", ["Predict", "Analyze", "Optimize", "Visualize", "About"])

if st.sidebar.button("Clear All Fields"):
    # clear only our known keys
    keys = ["predictions", "predict_ran",
            "analyze_input", "analyze_output",
            "optimize_input", "optimize_output",
            "visualize_sequences", "visualize_df"]
    for k in keys:
        if k in st.session_state:
            del st.session_state[k]
    st.sidebar.success("Cleared app state.")
    # Fix: st.experimental_rerun() was removed in Streamlit >= 1.37.
    # Prefer st.rerun() and fall back for older Streamlit versions.
    (st.rerun if hasattr(st, "rerun") else st.experimental_rerun)()

# Load model once
model = load_model()

# PREDICT PAGE
if page == "Predict":
    st.header("AMP Prediction")

    seq_input = st.text_area("Enter peptide sequences (one per line):",
                             value="", height=150)
    uploaded_file = st.file_uploader("Or upload a FASTA/text file", type=["txt", "fasta"])

    run = st.button("Run Prediction")

    if run:
        # Gather sequences from the text area and/or the uploaded file.
        sequences = []
        if seq_input:
            sequences += [s.strip() for s in seq_input.splitlines() if s.strip()]
        if uploaded_file:
            text = uploaded_file.read().decode("utf-8")
            # FASTA header lines (">...") are metadata, not residues.
            sequences += [l.strip() for l in text.splitlines() if not l.startswith(">") and l.strip()]

        if not sequences:
            st.warning("Please input or upload sequences first.")
        else:
            with st.spinner("Predicting..."):
                results = []
                for seq in sequences:
                    label, conf = predict_amp(seq, model)
                    # Show confidence in the *predicted* class, not raw AMP prob.
                    conf_display = round(conf * 100, 1) if label == "AMP" else round((1 - conf) * 100, 1)
                    results.append({
                        "Sequence": seq,
                        "Prediction": label,
                        "Confidence": conf,
                        "Description": f"{label} with {conf_display}% confidence"
                    })

            # Persist new predictions and mark that we ran
            st.session_state.predictions = results
            st.session_state.predict_ran = True
            st.success("Prediction complete.")

    # Show the last saved results (fresh run or restored on navigation).
    # Simplified from the original condition: once predict_ran is set the
    # extra "not (run and ...)" clause was always True anyway.
    if st.session_state.predictions:
        st.subheader("Predictions (last run)")
        st.dataframe(pd.DataFrame(st.session_state.predictions), use_container_width=True)
        csv = pd.DataFrame(st.session_state.predictions).to_csv(index=False)
        st.download_button("Download predictions as CSV", csv, "predictions.csv", "text/csv")

# ANALYZE PAGE
elif page == "Analyze":
    st.header("Sequence Analysis")

    # show the last saved analyze output if user navigated back
    last_seq = st.session_state.analyze_input
    seq = st.text_input("Enter a peptide sequence to analyze:",
                        value=last_seq)

    # only run analysis when input changed from last saved input
    if seq and seq != st.session_state.get("analyze_input", ""):
        with st.spinner("Running analysis..."):
            label, conf = predict_amp(seq, model)
            conf_pct = round(conf * 100, 1)
            conf_display = conf_pct if label == "AMP" else 100 - conf_pct

            comp = aa_composition(seq)
            props = compute_properties(seq)

            # normalize property key names if necessary
            net_charge = props.get("Net Charge (approx.)",
                                   props.get("Net charge", props.get("NetCharge", 0)))

            # build analysis summary
            length = props.get("Length", len(seq))
            hydro = props.get("Hydrophobic Fraction", props.get("Hydrophobic", 0))
            charge = net_charge

            analysis = []
            if (conf_pct if label == "AMP" else (100 - conf_pct)) >= 80:
                analysis.append(f"Highly likely to be {label}.")
            elif (conf_pct if label == "AMP" else (100 - conf_pct)) >= 60:
                analysis.append(f"Moderately likely to be {label}.")
            else:
                analysis.append(f"Low likelihood to be {label}.")

            if hydro < 0.4:
                analysis.append("Low hydrophobicity may reduce membrane interaction.")
            elif hydro > 0.6:
                analysis.append("High hydrophobicity may reduce solubility.")

            if charge <= 0:
                analysis.append("Low or negative charge may limit antimicrobial activity.")

            if length < 10:
                analysis.append("Short sequence may reduce efficacy.")
            elif length > 50:
                analysis.append("Long sequence may affect stability.")

            if comp.get("K", 0) + comp.get("R", 0) + comp.get("H", 0) >= 3:
                analysis.append("High basic residue content enhances membrane binding.")
            if comp.get("C", 0) + comp.get("W", 0) >= 2:
                analysis.append("Multiple cysteine/tryptophan residues may improve activity.")

            # Save to session state
            st.session_state.analyze_input = seq
            st.session_state.analyze_output = (label, conf, conf_display, comp, props, analysis)

    # If we have stored output, display it
    if st.session_state.analyze_output:
        label, conf, conf_display, comp, props, analysis = st.session_state.analyze_output

        st.subheader("AMP Prediction")
        display_conf = round(conf * 100, 1) if label == "AMP" else round((1 - conf) * 100, 1)
        st.write(f"Prediction: **{label}** with **{display_conf}%** confidence")

        st.subheader("Amino Acid Composition")
        comp_df = pd.DataFrame(list(comp.items()), columns=["Amino Acid", "Frequency"]).set_index("Amino Acid")
        st.bar_chart(comp_df)

        st.subheader("Physicochemical Properties and Favorability")

        # pull properties safely
        length = props.get("Length", len(st.session_state.analyze_input))
        hydro = props.get("Hydrophobic Fraction", 0)
        charge = props.get("Net Charge (approx.)", props.get("Net charge", 0))
        mw = props.get("Molecular Weight (Da)", 0)

        favorability = {
            "Length": "Good" if 10 <= length <= 50 else "Too short" if length < 10 else "Too long",
            "Hydrophobic Fraction": "Good" if 0.4 <= hydro <= 0.6 else "Low" if hydro < 0.4 else "High",
            "Net Charge": "Favorable" if charge > 0 else "Neutral" if charge == 0 else "Unfavorable",
            "Molecular Weight": "Acceptable" if 500 <= mw <= 5000 else "Extreme"
        }
        st.table(pd.DataFrame([
            {"Property": "Length", "Value": length, "Favorability": favorability["Length"]},
            {"Property": "Hydrophobic Fraction", "Value": hydro, "Favorability": favorability["Hydrophobic Fraction"]},
            {"Property": "Net Charge", "Value": charge, "Favorability": favorability["Net Charge"]},
            {"Property": "Molecular Weight", "Value": mw, "Favorability": favorability["Molecular Weight"]}
        ]))

        st.subheader("Property Radar Chart")
        # Properties are rescaled into [0, 1] against rough AMP-ideal ranges.
        categories = ["Length", "Hydrophobic Fraction", "Net Charge", "Molecular Weight"]
        values = [min(length / 50, 1), min(hydro, 1), 1 if charge > 0 else 0, min(mw / 5000, 1)]
        values += values[:1]  # close the polygon
        ideal_min = [10/50, 0.4, 1/6, 500/5000] + [10/50]
        ideal_max = [50/50, 0.6, 6/6, 5000/5000] + [50/50]
        angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
        angles += angles[:1]

        # Adjusted figsize for better vertical space
        fig, ax = plt.subplots(figsize=(2.8, 3.2), subplot_kw=dict(polar=True))
        fig.patch.set_facecolor("white")
        ax.fill_between(angles, ideal_min, ideal_max, color='#457a00', alpha=0.15, label="Ideal AMP range")
        ax.plot(angles, values, 'o-', color='#457a00', linewidth=2, label="Sequence")
        ax.fill(angles, values, color='#457a00', alpha=0.25)
        ax.set_thetagrids(np.degrees(angles[:-1]), categories, fontsize=8)
        ax.set_ylim(0, 1)
        ax.tick_params(axis='y', labelsize=7)
        ax.legend(loc='lower center', bbox_to_anchor=(0.85, 1.15), ncol=2, fontsize=7)
        st.pyplot(fig, use_container_width=False)

        # Analysis Summary
        st.subheader("Analysis Summary")
        for line in analysis:
            st.write(f"- {line}")

# OPTIMIZE PAGE
elif page == "Optimize":
    st.header("AMP Sequence Optimizer")

    # Single entry point: text input retained across navigation
    seq = st.text_input("Enter a peptide sequence to optimize:",
                        value=st.session_state.get("optimize_input", ""))

    # Run optimization when user changes input and clicks button
    if seq and st.button("Run Optimization"):
        st.session_state.optimize_input = seq
        with st.spinner("Optimizing sequence..."):
            improved_seq, improved_conf, history = optimize_sequence(seq, model)
            orig_label, orig_conf = predict_amp(seq, model)
            st.session_state.optimize_output = (seq, orig_conf, improved_seq, improved_conf, history)
        st.success("Optimization finished.")

    # If there is saved output show it
    if st.session_state.optimize_output:
        orig_seq, orig_conf, improved_seq, improved_conf, history = st.session_state.optimize_output
        st.subheader("Results")
        st.write(f"**Original Sequence:** {orig_seq} — Confidence: {round(orig_conf*100,1)}%")
        st.write(f"**Optimized Sequence:** {improved_seq} — Confidence: {round(improved_conf*100,1)}%")

        # history[0] is the unmodified input; mutation steps start at index 1.
        if len(history) > 1:
            df_steps = pd.DataFrame([{
                "Step": i,
                "Change": change,
                "Old Type": old_type,
                "New Type": new_type,
                "Reason for Improvement": reason,
                "New Confidence (%)": round(conf * 100, 2)
            } for i, (seq_after, conf, change, old_type, new_type, reason) in enumerate(history[1:], start=1)])
            st.subheader("Mutation Steps")
            st.dataframe(df_steps, use_container_width=True)

            # Confidence improvement plot
            step_nums = df_steps["Step"].tolist()
            conf_values = df_steps["New Confidence (%)"].tolist()
            df_graph = pd.DataFrame({"Step": step_nums, "Confidence (%)": conf_values})
            fig = px.line(df_graph, x="Step", y="Confidence (%)", markers=True, color_discrete_sequence=["#457a00"])
            fig.update_layout(yaxis=dict(range=[0, 100]), title="Confidence Improvement Over Steps")
            st.plotly_chart(fig, use_container_width=True)

# VISUALIZE PAGE
elif page == "Visualize":
    st.header("Sequence Embedding Visualization")
    st.write("Upload peptide sequences (FASTA or plain list) to visualize embeddings with t-SNE.")

    uploaded_file = st.file_uploader("Upload FASTA or text file", type=["txt", "fasta"])

    # If file uploaded, set session sequences (replacing previous)
    if uploaded_file:
        text = uploaded_file.read().decode("utf-8")
        sequences = [l.strip() for l in text.splitlines() if not l.startswith(">") and l.strip()]
        st.session_state.visualize_sequences = sequences
        # Clear any previous df so we recompute
        st.session_state.visualize_df = None

    # If we have sequences stored, compute embeddings and t-SNE if no df present
    if st.session_state.visualize_sequences and st.session_state.visualize_df is None:
        sequences = st.session_state.visualize_sequences
        if len(sequences) < 2:
            st.warning("Need at least 2 sequences for t-SNE visualization.")
        else:
            with st.spinner("Generating embeddings and running t-SNE..."):
                embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []

                # Penultimate-layer activations serve as the embedding;
                # relies on model.layers being an nn.Sequential (FastMLP).
                embedding_extractor = torch.nn.Sequential(*list(model.layers)[:-1])

                for s in sequences:
                    x = torch.tensor(encode_sequence(s), dtype=torch.float32).unsqueeze(0)
                    with torch.no_grad():
                        emb = embedding_extractor(x).squeeze().numpy()
                    embeddings_list.append(emb)
                    label, conf = predict_amp(s, model)
                    labels.append(label)
                    confs.append(conf)
                    props = compute_properties(s)
                    lengths.append(props.get("Length", len(s)))
                    hydros.append(props.get("Hydrophobic Fraction", 0))
                    charges.append(props.get("Net Charge (approx.)", props.get("Net charge", 0)))

                embeddings_array = np.stack(embeddings_list)
                # t-SNE requires perplexity < n_samples; clamp accordingly.
                perplexity = min(30, max(2, len(sequences) - 1))
                tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
                reduced = tsne.fit_transform(embeddings_array)

                df = pd.DataFrame(reduced, columns=["x", "y"])
                df["Sequence"] = sequences
                df["Label"] = labels
                df["Confidence"] = confs
                df["Length"] = lengths
                df["Hydrophobic Fraction"] = hydros
                df["Net Charge"] = charges

                st.session_state.visualize_df = df

    # If we have a t-SNE dataframe, show plot and sidebar filters
    if st.session_state.visualize_df is not None:
        df = st.session_state.visualize_df
        st.subheader("t-SNE plot")

        st.sidebar.subheader("Filter Sequences")
        min_len, max_len = int(df["Length"].min()), int(df["Length"].max())
        if min_len == max_len:
            # Slider with equal bounds is invalid; degrade gracefully.
            st.sidebar.write(f"All sequences have length {min_len}")
            length_range = (min_len, max_len)
        else:
            length_range = st.sidebar.slider("Sequence length", min_len, max_len, (min_len, max_len))

        label_options = st.sidebar.multiselect("Label", ["AMP", "Non-AMP"], default=["AMP", "Non-AMP"])
        filtered_df = df[(df["Length"].between(length_range[0], length_range[1])) & (df["Label"].isin(label_options))]
        color_by = st.sidebar.selectbox("Color points by", ["Label", "Confidence", "Hydrophobic Fraction", "Net Charge", "Length"])

        color_map = {"AMP": "#2ca02c", "Non-AMP": "#d62728"}
        fig = px.scatter(
            filtered_df,
            x="x", y="y",
            color=color_by if color_by != "Label" else "Label",
            color_discrete_map=color_map if color_by == "Label" else None,
            hover_data={"Sequence": True, "Label": True, "Confidence": True, "Length": True, "Hydrophobic Fraction": True, "Net Charge": True},
            title="t-SNE Visualization of Model Embeddings"
        )
        st.plotly_chart(fig, use_container_width=True)

        st.subheader("t-SNE Analysis")
        st.markdown("""
        • Each point represents a peptide sequence.
        • Sequences close together have similar internal representations in the model.
        • AMP and Non-AMP clusters indicate strong model separation.
        • Coloring by properties reveals biochemical trends.
        """)

# ABOUT PAGE
elif page == "About":
    st.header("About the Project")
    st.markdown("""
    **Problem:** Antimicrobial resistance is a global health threat. Traditional peptide screening is slow and costly.
    **Solution:** This tool predicts antimicrobial activity directly from sequence using deep learning, speeding up AMP discovery.
    """)
|
PeptideAI/StreamlitApp/utils/__init__.py
ADDED
|
File without changes
|
PeptideAI/StreamlitApp/utils/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
PeptideAI/StreamlitApp/utils/__pycache__/analyze.cpython-313.pyc
ADDED
|
Binary file (2.54 kB). View file
|
|
|
PeptideAI/StreamlitApp/utils/__pycache__/optimize.cpython-313.pyc
ADDED
|
Binary file (2.49 kB). View file
|
|
|
PeptideAI/StreamlitApp/utils/__pycache__/predict.cpython-313.pyc
ADDED
|
Binary file (3.59 kB). View file
|
|
|
PeptideAI/StreamlitApp/utils/__pycache__/visualize.cpython-313.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
PeptideAI/StreamlitApp/utils/analyze.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
def aa_composition(sequence):
    """Return the relative frequency of each of the 20 standard amino acids.

    Args:
        sequence: Peptide string of one-letter residue codes (upper-case
            expected; other characters simply never match).

    Returns:
        dict mapping each standard amino acid to its fraction of the
        sequence length. An empty sequence yields all zeros instead of
        raising ZeroDivisionError (bug in the original implementation).
    """
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    total = len(sequence)
    if total == 0:
        # Guard: avoid dividing by zero for empty input.
        return {aa: 0.0 for aa in amino_acids}
    counts = Counter(sequence)
    return {aa: counts.get(aa, 0) / total for aa in amino_acids}
|
| 8 |
+
|
| 9 |
+
# Compute sequence properties
|
| 10 |
+
def compute_properties(sequence):
    """Compute simple physicochemical descriptors for a peptide.

    Args:
        sequence: Peptide string of one-letter residue codes. Unknown
            characters contribute zero weight and are not counted as
            hydrophobic or charged.

    Returns:
        dict with keys "Length", "Molecular Weight (Da)" (sum of residue
        weights, not corrected for water loss on peptide-bond formation),
        "Hydrophobic Fraction", and "Net Charge (approx.)"
        (count of K/R/H minus count of D/E at neutral pH).
        Handles the empty sequence (original raised ZeroDivisionError).
    """
    # Average residue masses in Daltons.
    aa_weights = {'A': 89.1, 'R': 174.2, 'N': 132.1, 'D': 133.1, 'C': 121.2,
                  'E': 147.1, 'Q': 146.2, 'G': 75.1, 'H': 155.2, 'I': 131.2,
                  'L': 131.2, 'K': 146.2, 'M': 149.2, 'F': 165.2, 'P': 115.1,
                  'S': 105.1, 'T': 119.1, 'W': 204.2, 'Y': 181.2, 'V': 117.1}
    length = len(sequence)
    mw = sum(aa_weights.get(aa, 0) for aa in sequence)
    # Guard against division by zero on empty input.
    hydrophobic = (sum(1 for aa in sequence if aa in "AILMFWYV") / length) if length else 0.0
    charge = sum(1 for aa in sequence if aa in "KRH") - sum(1 for aa in sequence if aa in "DE")
    return {"Length": length, "Molecular Weight (Da)": round(mw, 2),
            "Hydrophobic Fraction": round(hydrophobic, 3), "Net Charge (approx.)": charge}
|
PeptideAI/StreamlitApp/utils/optimize.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from utils.predict import predict_amp
|
| 3 |
+
|
| 4 |
+
# Residue classes used by the mutation heuristics. H appears in both
# POSITIVE and HYDROPHILIC; POSITIVE is checked first, so it wins.
HYDROPHOBIC = set("AILMFWVPG")
HYDROPHILIC = set("STNQYCH")
POSITIVE = set("KRH")
NEGATIVE = set("DE")


def mutate_residue(residue):
    """Propose a heuristic single-residue substitution.

    The rules bias sequences toward the classic AMP profile: cationic
    and amphipathic.

    Args:
        residue: One-letter amino-acid code.

    Returns:
        Tuple ``(new_residue, reason)``. Cationic residues are returned
        unchanged; all other classes get a random replacement drawn from
        the class the heuristic favors.
    """
    if residue in POSITIVE:
        # Positive charge drives membrane binding - never mutate it away.
        return residue, "Retained strong positive residue"
    if residue in NEGATIVE:
        return random.choice(list(POSITIVE)), "Increased positive charge"
    if residue in HYDROPHILIC:
        return random.choice(list(HYDROPHOBIC)), "Improved hydrophobicity balance"
    if residue in HYDROPHOBIC:
        return random.choice(list(POSITIVE | HYDROPHILIC)), "Enhanced amphipathicity"
    # Non-standard character: nudge it toward a hydrophobic residue.
    return random.choice(list(HYDROPHOBIC)), "Adjusted physicochemical profile"
|
| 21 |
+
|
| 22 |
+
# Sequence optimization function
|
| 23 |
+
def optimize_sequence(seq, model, max_rounds=20, confidence_threshold=0.001):
    """Greedy hill-climbing search that raises a sequence's AMP confidence.

    Each round proposes one heuristic mutation per position (via
    ``mutate_residue``), scores every candidate with ``predict_amp``,
    and accepts only the single best one. The search stops early when
    the best available gain is below ``confidence_threshold``.

    Args:
        seq: Starting peptide sequence.
        model: Classifier passed through to ``predict_amp``.
        max_rounds: Upper bound on accepted mutations.
        confidence_threshold: Minimum confidence gain required to accept
            a mutation.

    Returns:
        Tuple ``(optimized_seq, best_conf, history)`` where ``history``
        is a list of ``(sequence, confidence, change, old_residue,
        new_residue, reason)`` tuples; entry 0 is the unmodified input.
    """
    best_seq = seq
    _, best_conf = predict_amp(best_seq, model)
    history = [(best_seq, best_conf, "-", "-", "-", "Original sequence")]

    for _ in range(max_rounds):
        candidate = None
        candidate_conf = best_conf

        # Score one proposed substitution at every position this round.
        for idx, current_res in enumerate(best_seq):
            proposed_res, why = mutate_residue(current_res)
            if proposed_res == current_res:
                continue
            trial_seq = best_seq[:idx] + proposed_res + best_seq[idx + 1:]
            _, trial_conf = predict_amp(trial_seq, model)
            if trial_conf > candidate_conf:
                candidate_conf = trial_conf
                candidate = (trial_seq, idx, current_res, proposed_res, why)

        # Stop when nothing clears the acceptance threshold.
        if candidate is None or candidate_conf - best_conf < confidence_threshold:
            break

        best_seq, idx, current_res, proposed_res, why = candidate
        best_conf = candidate_conf
        history.append((best_seq, best_conf,
                        f"Pos {idx+1}: {current_res} → {proposed_res}",
                        current_res, proposed_res, why))

    return best_seq, best_conf, history
|
PeptideAI/StreamlitApp/utils/predict.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import requests
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import streamlit as st
|
| 7 |
+
from torch import nn
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import shutil
|
| 10 |
+
|
| 11 |
+
# Model Definition
class FastMLP(nn.Module):
    """Small feed-forward binary classifier over flattened sequence encodings.

    Architecture: input_dim -> 512 -> 128 -> 1 with ReLU activations and one
    dropout layer after the first hidden block. The final layer emits a raw
    logit; callers apply a sigmoid to obtain the AMP probability.
    """

    def __init__(self, input_dim=1024):
        super().__init__()
        # Kept as a single nn.Sequential attribute named `layers` because
        # other modules index into it (e.g. model.layers[0]) for embeddings.
        stages = [
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # single logit for binary classification
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw (pre-sigmoid) logit for each row of `x`."""
        return self.layers(x)
|
| 26 |
+
|
| 27 |
+
# Utility: download file from URL to local path (streaming)
def _download_file(url: str, dest_path: str):
    """Stream `url` to `dest_path`, creating parent directories as needed.

    Raises requests.HTTPError for non-2xx responses. The download is chunked
    so large model files never need to fit in memory at once.
    """
    target = pathlib.Path(dest_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True) as response:
        response.raise_for_status()
        with target.open("wb") as out:
            for block in response.iter_content(chunk_size=8192):
                # iter_content may yield empty keep-alive chunks; skip them.
                if block:
                    out.write(block)
|
| 37 |
+
|
| 38 |
+
def _get_env(key: str) -> Optional[str]:
    """Return the environment variable `key`, treating empty string as unset."""
    value = os.environ.get(key)
    return value or None
|
| 41 |
+
|
| 42 |
+
# Model Loader
@st.cache_resource
def load_model():
    """Load (and cache) the AMP classifier, fetching the weights if needed.

    Resolution order for the weights file:
      1. Local file  <StreamlitApp>/models/ampMLModel.pt
      2. MODEL_URL        env var - direct HTTP download
      3. MODEL_REPO_ID    env var - Hugging Face Hub repo, with optional
         MODEL_FILENAME and HF_TOKEN / HUGGINGFACE_TOKEN for private repos

    Raises FileNotFoundError when no source is configured or the download
    produced no file, and RuntimeError when huggingface_hub is missing.
    Decorated with st.cache_resource so the model is built once per process.
    """
    # Always resolve relative to the StreamlitApp folder, not the process CWD.
    streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
    model_path = streamlitapp_dir / "models" / "ampMLModel.pt"

    # If the model file doesn't exist, try to download it from a configured URL
    if not model_path.exists():

        model_url = _get_env("MODEL_URL")

        if model_url:
            try:
                _download_file(model_url, str(model_path))
            except Exception as e:
                # Surface the failure in the UI, then re-raise so the app
                # does not continue running with a missing model.
                st.error(f"Failed to download model from MODEL_URL: {e}")
                raise
        else:
            model_repo_id = _get_env("MODEL_REPO_ID")
            model_filename = _get_env("MODEL_FILENAME") or "ampMLModel.pt"

            if not model_repo_id:
                raise FileNotFoundError(
                    "Model file './models/ampMLModel.pt' not found.\n"
                    "Set one of:\n"
                    "- MODEL_URL (direct download URL), or\n"
                    "- MODEL_REPO_ID (Hugging Face model repo id) and optional MODEL_FILENAME.\n"
                    "\n"
                    "Debug (env vars detected): "
                    f"MODEL_URL={'set' if _get_env('MODEL_URL') else 'missing'}, "
                    f"MODEL_REPO_ID={'set' if _get_env('MODEL_REPO_ID') else 'missing'}, "
                    f"MODEL_FILENAME={'set' if _get_env('MODEL_FILENAME') else 'missing'}\n"
                )

            # Imported lazily so the app can still start (and show the error
            # above) when huggingface_hub is not installed.
            try:
                from huggingface_hub import hf_hub_download
            except Exception as e:
                raise RuntimeError(
                    "Missing dependency 'huggingface_hub'. Add it to requirements.txt.\n"
                    f"Import error: {e}"
                ) from e

            token = _get_env("HF_TOKEN") or _get_env("HUGGINGFACE_TOKEN")
            downloaded_path = hf_hub_download(
                repo_id=model_repo_id,
                filename=model_filename,
                token=token,
            )
            # Copy into the expected local path so later runs skip the hub.
            model_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(downloaded_path, model_path)

    # Sanity check: whichever branch ran must have produced the file.
    if not model_path.exists():
        raise FileNotFoundError(
            f"Model download did not produce file at: {model_path}\n"
            "Check MODEL_URL or MODEL_REPO_ID/MODEL_FILENAME configuration."
        )

    # Build model and load weights
    model = FastMLP(input_dim=1024)
    model.load_state_dict(torch.load(str(model_path), map_location="cpu"))
    model.eval()
    return model
|
| 105 |
+
|
| 106 |
+
# Sequence Encoder
def encode_sequence(seq, max_len=51):
    """
    Convert an amino-acid sequence into a flat one-hot vector of length 1024.

    The sequence is truncated to `max_len` residues; each residue becomes a
    20-dim one-hot row (canonical amino acids only — unknown characters stay
    all-zero). The flattened matrix (max_len * 20 values) is then padded, or
    truncated, to exactly 1024 entries so it always matches the model's
    input_dim of 1024.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}

    one_hot = np.zeros((max_len, len(amino_acids)))  # max_len x 20
    for i, aa in enumerate(seq[:max_len]):
        if aa in aa_to_idx:
            one_hot[i, aa_to_idx[aa]] = 1

    flat = one_hot.flatten()  # length = max_len * 20 (1020 for the default)

    # Normalize to the fixed model width. Previously only padding was done,
    # so calling with max_len > 51 produced a vector longer than 1024 and
    # crashed the model; now we also truncate, as the docstring promises.
    if len(flat) < 1024:
        flat = np.pad(flat, (0, 1024 - len(flat)))
    elif len(flat) > 1024:
        flat = flat[:1024]

    return flat
|
| 126 |
+
|
| 127 |
+
# Prediction Function
def predict_amp(sequence, model):
    """
    Classify one peptide sequence with the trained model.

    Returns a (label, probability) pair: label is "AMP" when the sigmoid
    probability is at least 0.5 and "Non-AMP" otherwise, and the probability
    is rounded to 3 decimal places.
    """
    features = encode_sequence(sequence)
    batch = torch.tensor(features, dtype=torch.float32).unsqueeze(0)

    with torch.no_grad():
        prob = torch.sigmoid(model(batch)).item()

    if prob >= 0.5:
        return "AMP", round(prob, 3)
    return "Non-AMP", round(prob, 3)
|
PeptideAI/StreamlitApp/utils/rateLimit.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from collections import deque
|
| 3 |
+
|
| 4 |
+
class RateLimiter:
    """Sliding-window rate limiter: at most `max_calls` per `period_seconds`."""

    def __init__(self, max_calls: int, period_seconds: float):
        self.max_calls = max_calls
        self.period = period_seconds
        self.calls = deque()  # timestamps of accepted calls, oldest first

    def allow(self) -> bool:
        """Record and permit a call if the window has capacity, else refuse."""
        now = time.time()
        cutoff = now - self.period
        # Evict timestamps that have aged out of the window.
        while self.calls and self.calls[0] <= cutoff:
            self.calls.popleft()
        if len(self.calls) >= self.max_calls:
            return False
        self.calls.append(now)
        return True

    def time_until_next(self) -> float:
        """Seconds until a slot frees up; 0.0 when one is free right now."""
        if len(self.calls) < self.max_calls:
            return 0.0
        return max(0.0, (self.calls[0] + self.period) - time.time())
|
PeptideAI/StreamlitApp/utils/visualize.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from sklearn.manifold import TSNE
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from utils.predict import encode_sequence
|
| 8 |
+
|
| 9 |
+
# t-SNE Visualization
def tsne_visualization(sequences, model):
    """Project first-layer embeddings of `sequences` to 2-D and plot them.

    Embeddings are taken from the model's first Linear layer (pre-activation),
    i.e. model.layers[0]. Effectively needs at least 3 sequences: t-SNE
    perplexity must be >= 2 and is capped at min(30, len(sequences) - 1).
    """
    st.info("Generating embeddings... this may take a moment.")

    vectors = []
    for peptide in sequences:
        encoded = torch.tensor(encode_sequence(peptide), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            first_layer_out = model.layers[0](encoded)  # pre-ReLU embedding
        vectors.append(first_layer_out.numpy().flatten())

    stacked = np.vstack(vectors)

    perplexity = min(30, len(sequences) - 1)
    if perplexity < 2:
        st.warning("Need at least 2 sequences for visualization.")
        return

    projector = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    coords = projector.fit_transform(stacked)
    plot_df = pd.DataFrame(coords, columns=["x", "y"])

    st.success("t-SNE visualization complete.")
    st.scatter_chart(plot_df)
|
README.md
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PeptideAI
|
| 3 |
+
emoji: 🔬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: "1.41.1"
|
| 8 |
+
python_version: "3.13"
|
| 9 |
+
app_file: PeptideAI/StreamlitApp/StreamlitApp.py
|
| 10 |
+
pinned: false
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# PeptideAI
|
| 14 |
+
Antimicrobial Peptide (AMP) Prediction App
|
| 15 |
+
A machine learning web app that predicts antimicrobial activity from peptide sequences.
|
| 16 |
+
Built with Python, PyTorch, and Streamlit, it uses ProtBERT embeddings to represent biological sequences and a custom neural network classifier for prediction.
|
| 17 |
+
Includes features for:
|
| 18 |
+
|
| 19 |
+
- AMP probability prediction
|
| 20 |
+
- Amino acid composition analysis
|
| 21 |
+
- Physicochemical property computation
|
| 22 |
+
- t-SNE visualization of embeddings
|
StreamlitApp/StreamlitApp.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.patches as mpatches
|
| 6 |
+
import torch
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
from sklearn.manifold import TSNE
|
| 9 |
+
|
| 10 |
+
# modular imports
|
| 11 |
+
from utils.predict import load_model, predict_amp, encode_sequence
|
| 12 |
+
from utils.analyze import aa_composition, compute_properties
|
| 13 |
+
from utils.optimize import optimize_sequence
|
| 14 |
+
|
| 15 |
+
# APP CONFIG
|
| 16 |
+
st.set_page_config(page_title="AMP Predictor", layout="wide")
|
| 17 |
+
|
| 18 |
+
# App title
|
| 19 |
+
st.title("PeptideAI: Antimicrobial Peptide Predictor and Optimizer")
|
| 20 |
+
st.write("Use the sidebar to navigate between prediction, analysis, optimization, and visualization tools.")
|
| 21 |
+
st.markdown("---")
|
| 22 |
+
|
| 23 |
+
# SESSION STATE KEYS (one-time init)
|
| 24 |
+
if "predictions" not in st.session_state:
|
| 25 |
+
st.session_state.predictions = [] # list of dicts
|
| 26 |
+
if "predict_ran" not in st.session_state:
|
| 27 |
+
st.session_state.predict_ran = False
|
| 28 |
+
if "analyze_input" not in st.session_state:
|
| 29 |
+
st.session_state.analyze_input = "" # last analyze input
|
| 30 |
+
if "analyze_output" not in st.session_state:
|
| 31 |
+
st.session_state.analyze_output = None # (label, conf_display, comp, props, analysis)
|
| 32 |
+
if "optimize_input" not in st.session_state:
|
| 33 |
+
st.session_state.optimize_input = "" # last optimize input
|
| 34 |
+
if "optimize_output" not in st.session_state:
|
| 35 |
+
st.session_state.optimize_output = None # (orig_seq, orig_conf, improved_seq, improved_conf, history)
|
| 36 |
+
if "visualize_sequences" not in st.session_state:
|
| 37 |
+
st.session_state.visualize_sequences = None
|
| 38 |
+
if "visualize_df" not in st.session_state:
|
| 39 |
+
st.session_state.visualize_df = None
|
| 40 |
+
|
| 41 |
+
# SIDEBAR: navigation + global clear
st.sidebar.header("Navigation")
page = st.sidebar.radio("Go to", ["Predict", "Analyze", "Optimize", "Visualize", "About"])

if st.sidebar.button("Clear All Fields"):

    # clear only our known keys
    keys = ["predictions", "predict_ran",
            "analyze_input", "analyze_output",
            "optimize_input", "optimize_output",
            "visualize_sequences", "visualize_df"]
    for k in keys:
        if k in st.session_state:
            del st.session_state[k]
    st.sidebar.success("Cleared app state.")
    # Bug fix: st.experimental_rerun() was removed from Streamlit (gone by
    # the pinned 1.41 runtime, where it raises AttributeError); st.rerun()
    # is the supported replacement. Keep a fallback for very old installs.
    if hasattr(st, "rerun"):
        st.rerun()
    else:
        st.experimental_rerun()
|
| 57 |
+
|
| 58 |
+
# Load model once
|
| 59 |
+
model = load_model()
|
| 60 |
+
|
| 61 |
+
# PREDICT PAGE
|
| 62 |
+
if page == "Predict":
|
| 63 |
+
st.header("AMP Prediction")
|
| 64 |
+
|
| 65 |
+
seq_input = st.text_area("Enter peptide sequences (one per line):",
|
| 66 |
+
value="", height=150)
|
| 67 |
+
uploaded_file = st.file_uploader("Or upload a FASTA/text file", type=["txt", "fasta"])
|
| 68 |
+
|
| 69 |
+
run = st.button("Run Prediction")
|
| 70 |
+
|
| 71 |
+
if run:
|
| 72 |
+
|
| 73 |
+
# Gather sequences
|
| 74 |
+
sequences = []
|
| 75 |
+
if seq_input:
|
| 76 |
+
sequences += [s.strip() for s in seq_input.splitlines() if s.strip()]
|
| 77 |
+
if uploaded_file:
|
| 78 |
+
text = uploaded_file.read().decode("utf-8")
|
| 79 |
+
sequences += [l.strip() for l in text.splitlines() if not l.startswith(">") and l.strip()]
|
| 80 |
+
|
| 81 |
+
if not sequences:
|
| 82 |
+
st.warning("Please input or upload sequences first.")
|
| 83 |
+
else:
|
| 84 |
+
with st.spinner("Predicting..."):
|
| 85 |
+
results = []
|
| 86 |
+
for seq in sequences:
|
| 87 |
+
label, conf = predict_amp(seq, model)
|
| 88 |
+
conf_display = round(conf * 100, 1) if label == "AMP" else round((1 - conf) * 100, 1)
|
| 89 |
+
results.append({
|
| 90 |
+
"Sequence": seq,
|
| 91 |
+
"Prediction": label,
|
| 92 |
+
"Confidence": conf,
|
| 93 |
+
"Description": f"{label} with {conf_display}% confidence"
|
| 94 |
+
})
|
| 95 |
+
|
| 96 |
+
# Persist new predictions and mark that we ran
|
| 97 |
+
st.session_state.predictions = results
|
| 98 |
+
st.session_state.predict_ran = True
|
| 99 |
+
st.success("Prediction complete.")
|
| 100 |
+
|
| 101 |
+
# If user hasn't just run predictions, show the last saved results (if any)
|
| 102 |
+
if st.session_state.predictions and not (run and st.session_state.predict_ran is False):
|
| 103 |
+
st.subheader("Predictions (last run)")
|
| 104 |
+
st.dataframe(pd.DataFrame(st.session_state.predictions), use_container_width=True)
|
| 105 |
+
csv = pd.DataFrame(st.session_state.predictions).to_csv(index=False)
|
| 106 |
+
st.download_button("Download predictions as CSV", csv, "predictions.csv", "text/csv")
|
| 107 |
+
|
| 108 |
+
# ANALYZE PAGE
|
| 109 |
+
elif page == "Analyze":
|
| 110 |
+
st.header("Sequence Analysis")
|
| 111 |
+
|
| 112 |
+
# show the last saved analyze output if user navigated back
|
| 113 |
+
last_seq = st.session_state.analyze_input
|
| 114 |
+
seq = st.text_input("Enter a peptide sequence to analyze:",
|
| 115 |
+
value=last_seq)
|
| 116 |
+
|
| 117 |
+
# only run analysis when input changed from last saved input
|
| 118 |
+
if seq and seq != st.session_state.get("analyze_input", ""):
|
| 119 |
+
with st.spinner("Running analysis..."):
|
| 120 |
+
label, conf = predict_amp(seq, model)
|
| 121 |
+
conf_pct = round(conf * 100, 1)
|
| 122 |
+
conf_display = conf_pct if label == "AMP" else 100 - conf_pct
|
| 123 |
+
|
| 124 |
+
comp = aa_composition(seq)
|
| 125 |
+
props = compute_properties(seq)
|
| 126 |
+
|
| 127 |
+
# normalize property key names if necessary
|
| 128 |
+
net_charge = props.get("Net Charge (approx.)",
|
| 129 |
+
props.get("Net charge", props.get("NetCharge", 0)))
|
| 130 |
+
|
| 131 |
+
# build analysis summary (same rules as before)
|
| 132 |
+
length = props.get("Length", len(seq))
|
| 133 |
+
hydro = props.get("Hydrophobic Fraction", props.get("Hydrophobic", 0))
|
| 134 |
+
charge = net_charge
|
| 135 |
+
mw = props.get("Molecular Weight (Da)", props.get("MolecularWeight", 0))
|
| 136 |
+
|
| 137 |
+
analysis = []
|
| 138 |
+
if (conf_pct if label == "AMP" else (100 - conf_pct)) >= 80:
|
| 139 |
+
analysis.append(f"Highly likely to be {label}.")
|
| 140 |
+
elif (conf_pct if label == "AMP" else (100 - conf_pct)) >= 60:
|
| 141 |
+
analysis.append(f"Moderately likely to be {label}.")
|
| 142 |
+
else:
|
| 143 |
+
analysis.append(f"Low likelihood to be {label}.")
|
| 144 |
+
|
| 145 |
+
if hydro < 0.4:
|
| 146 |
+
analysis.append("Low hydrophobicity may reduce membrane interaction.")
|
| 147 |
+
elif hydro > 0.6:
|
| 148 |
+
analysis.append("High hydrophobicity may reduce solubility.")
|
| 149 |
+
|
| 150 |
+
if charge <= 0:
|
| 151 |
+
analysis.append("Low or negative charge may limit antimicrobial activity.")
|
| 152 |
+
|
| 153 |
+
if length < 10:
|
| 154 |
+
analysis.append("Short sequence may reduce efficacy.")
|
| 155 |
+
elif length > 50:
|
| 156 |
+
analysis.append("Long sequence may affect stability.")
|
| 157 |
+
|
| 158 |
+
if comp.get("K", 0) + comp.get("R", 0) + comp.get("H", 0) >= 3:
|
| 159 |
+
analysis.append("High basic residue content enhances membrane binding.")
|
| 160 |
+
if comp.get("C", 0) + comp.get("W", 0) >= 2:
|
| 161 |
+
analysis.append("Multiple cysteine/tryptophan residues may improve activity.")
|
| 162 |
+
|
| 163 |
+
# Save to session state
|
| 164 |
+
st.session_state.analyze_input = seq
|
| 165 |
+
st.session_state.analyze_output = (label, conf, conf_display, comp, props, analysis)
|
| 166 |
+
|
| 167 |
+
# If we have stored output, display it
|
| 168 |
+
if st.session_state.analyze_output:
|
| 169 |
+
label, conf, conf_display, comp, props, analysis = st.session_state.analyze_output
|
| 170 |
+
|
| 171 |
+
st.subheader("AMP Prediction")
|
| 172 |
+
display_conf = round(conf * 100, 1) if label == "AMP" else round((1 - conf) * 100, 1)
|
| 173 |
+
st.write(f"Prediction: **{label}** with **{display_conf}%** confidence")
|
| 174 |
+
|
| 175 |
+
st.subheader("Amino Acid Composition")
|
| 176 |
+
comp_df = pd.DataFrame(list(comp.items()), columns=["Amino Acid", "Frequency"]).set_index("Amino Acid")
|
| 177 |
+
st.bar_chart(comp_df)
|
| 178 |
+
|
| 179 |
+
st.subheader("Physicochemical Properties and Favorability")
|
| 180 |
+
|
| 181 |
+
# pull properties safely
|
| 182 |
+
length = props.get("Length", len(st.session_state.analyze_input))
|
| 183 |
+
hydro = props.get("Hydrophobic Fraction", 0)
|
| 184 |
+
charge = props.get("Net Charge (approx.)", props.get("Net charge", 0))
|
| 185 |
+
mw = props.get("Molecular Weight (Da)", 0)
|
| 186 |
+
|
| 187 |
+
favorability = {
|
| 188 |
+
"Length": "Good" if 10 <= length <= 50 else "Too short" if length < 10 else "Too long",
|
| 189 |
+
"Hydrophobic Fraction": "Good" if 0.4 <= hydro <= 0.6 else "Low" if hydro < 0.4 else "High",
|
| 190 |
+
"Net Charge": "Favorable" if charge > 0 else "Neutral" if charge == 0 else "Unfavorable",
|
| 191 |
+
"Molecular Weight": "Acceptable" if 500 <= mw <= 5000 else "Extreme"
|
| 192 |
+
}
|
| 193 |
+
st.table(pd.DataFrame([
|
| 194 |
+
{"Property": "Length", "Value": length, "Favorability": favorability["Length"]},
|
| 195 |
+
{"Property": "Hydrophobic Fraction", "Value": hydro, "Favorability": favorability["Hydrophobic Fraction"]},
|
| 196 |
+
{"Property": "Net Charge", "Value": charge, "Favorability": favorability["Net Charge"]},
|
| 197 |
+
{"Property": "Molecular Weight", "Value": mw, "Favorability": favorability["Molecular Weight"]}
|
| 198 |
+
]))
|
| 199 |
+
|
| 200 |
+
st.subheader("Property Radar Chart")
|
| 201 |
+
categories = ["Length", "Hydrophobic Fraction", "Net Charge", "Molecular Weight"]
|
| 202 |
+
values = [min(length / 50, 1), min(hydro, 1), 1 if charge > 0 else 0, min(mw / 5000, 1)]
|
| 203 |
+
values += values[:1]
|
| 204 |
+
ideal_min = [10/50, 0.4, 1/6, 500/5000] + [10/50]
|
| 205 |
+
ideal_max = [50/50, 0.6, 6/6, 5000/5000] + [50/50]
|
| 206 |
+
angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
|
| 207 |
+
angles += angles[:1]
|
| 208 |
+
|
| 209 |
+
# Adjusted figsize for better vertical space
|
| 210 |
+
fig, ax = plt.subplots(figsize=(2.8, 3.2), subplot_kw=dict(polar=True))
|
| 211 |
+
fig.patch.set_facecolor("white")
|
| 212 |
+
ax.fill_between(angles, ideal_min, ideal_max, color='#457a00', alpha=0.15, label="Ideal AMP range")
|
| 213 |
+
ax.plot(angles, values, 'o-', color='#457a00', linewidth=2, label="Sequence")
|
| 214 |
+
ax.fill(angles, values, color='#457a00', alpha=0.25)
|
| 215 |
+
ax.set_thetagrids(np.degrees(angles[:-1]), categories, fontsize=8)
|
| 216 |
+
ax.set_ylim(0, 1)
|
| 217 |
+
ax.tick_params(axis='y', labelsize=7)
|
| 218 |
+
ax.legend(loc='lower center', bbox_to_anchor=(0.85, 1.15), ncol=2, fontsize=7)
|
| 219 |
+
st.pyplot(fig, use_container_width=False)
|
| 220 |
+
|
| 221 |
+
# Analysis Summary
|
| 222 |
+
st.subheader("Analysis Summary")
|
| 223 |
+
for line in analysis:
|
| 224 |
+
st.write(f"- {line}")
|
| 225 |
+
|
| 226 |
+
# OPTIMIZE PAGE
|
| 227 |
+
elif page == "Optimize":
|
| 228 |
+
st.header("AMP Sequence Optimizer")
|
| 229 |
+
|
| 230 |
+
# Single entry point: text input retained across navigation
|
| 231 |
+
seq = st.text_input("Enter a peptide sequence to optimize:",
|
| 232 |
+
value=st.session_state.get("optimize_input", ""))
|
| 233 |
+
|
| 234 |
+
# Run optimization when user changes input and clicks button
|
| 235 |
+
if seq and st.button("Run Optimization"):
|
| 236 |
+
st.session_state.optimize_input = seq
|
| 237 |
+
with st.spinner("Optimizing sequence..."):
|
| 238 |
+
improved_seq, improved_conf, history = optimize_sequence(seq, model)
|
| 239 |
+
orig_label, orig_conf = predict_amp(seq, model)
|
| 240 |
+
st.session_state.optimize_output = (seq, orig_conf, improved_seq, improved_conf, history)
|
| 241 |
+
st.success("Optimization finished.")
|
| 242 |
+
|
| 243 |
+
# If there is saved output show it
|
| 244 |
+
if st.session_state.optimize_output:
|
| 245 |
+
orig_seq, orig_conf, improved_seq, improved_conf, history = st.session_state.optimize_output
|
| 246 |
+
st.subheader("Results")
|
| 247 |
+
st.write(f"**Original Sequence:** {orig_seq} — Confidence: {round(orig_conf*100,1)}%")
|
| 248 |
+
st.write(f"**Optimized Sequence:** {improved_seq} — Confidence: {round(improved_conf*100,1)}%")
|
| 249 |
+
|
| 250 |
+
if len(history) > 1:
|
| 251 |
+
df_steps = pd.DataFrame([{
|
| 252 |
+
"Step": i,
|
| 253 |
+
"Change": change,
|
| 254 |
+
"Old Type": old_type,
|
| 255 |
+
"New Type": new_type,
|
| 256 |
+
"Reason for Improvement": reason,
|
| 257 |
+
"New Confidence (%)": round(conf * 100, 2)
|
| 258 |
+
} for i, (seq_after, conf, change, old_type, new_type, reason) in enumerate(history[1:], start=1)])
|
| 259 |
+
st.subheader("Mutation Steps")
|
| 260 |
+
st.dataframe(df_steps, use_container_width=True)
|
| 261 |
+
|
| 262 |
+
# Confidence improvement plot
|
| 263 |
+
step_nums = df_steps["Step"].tolist()
|
| 264 |
+
conf_values = df_steps["New Confidence (%)"].tolist()
|
| 265 |
+
df_graph = pd.DataFrame({"Step": step_nums, "Confidence (%)": conf_values})
|
| 266 |
+
fig = px.line(df_graph, x="Step", y="Confidence (%)", markers=True, color_discrete_sequence=["#457a00"])
|
| 267 |
+
fig.update_layout(yaxis=dict(range=[0, 100]), title="Confidence Improvement Over Steps")
|
| 268 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 269 |
+
|
| 270 |
+
# VISUALIZE PAGE
|
| 271 |
+
elif page == "Visualize":
|
| 272 |
+
st.header("Sequence Embedding Visualization")
|
| 273 |
+
st.write("Upload peptide sequences (FASTA or plain list) to visualize embeddings with t-SNE.")
|
| 274 |
+
|
| 275 |
+
uploaded_file = st.file_uploader("Upload FASTA or text file", type=["txt", "fasta"])
|
| 276 |
+
|
| 277 |
+
# If file uploaded, set session sequences (replacing previous)
|
| 278 |
+
if uploaded_file:
|
| 279 |
+
text = uploaded_file.read().decode("utf-8")
|
| 280 |
+
sequences = [l.strip() for l in text.splitlines() if not l.startswith(">") and l.strip()]
|
| 281 |
+
st.session_state.visualize_sequences = sequences
|
| 282 |
+
|
| 283 |
+
# Clear any previous df so we recompute
|
| 284 |
+
st.session_state.visualize_df = None
|
| 285 |
+
|
| 286 |
+
# If we have sequences stored, compute embeddings and t-SNE if no df present
|
| 287 |
+
if st.session_state.visualize_sequences and st.session_state.visualize_df is None:
|
| 288 |
+
sequences = st.session_state.visualize_sequences
|
| 289 |
+
if len(sequences) < 2:
|
| 290 |
+
st.warning("Need at least 2 sequences for t-SNE visualization.")
|
| 291 |
+
else:
|
| 292 |
+
with st.spinner("Generating embeddings and running t-SNE..."):
|
| 293 |
+
embeddings_list, labels, confs, lengths, hydros, charges = [], [], [], [], [], []
|
| 294 |
+
|
| 295 |
+
# Use model internals for embeddings; keep same approach as your module
|
| 296 |
+
embedding_extractor = torch.nn.Sequential(*list(model.layers)[:-1])
|
| 297 |
+
|
| 298 |
+
for s in sequences:
|
| 299 |
+
x = torch.tensor(encode_sequence(s), dtype=torch.float32).unsqueeze(0)
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
emb = embedding_extractor(x).squeeze().numpy()
|
| 302 |
+
embeddings_list.append(emb)
|
| 303 |
+
label, conf = predict_amp(s, model)
|
| 304 |
+
labels.append(label)
|
| 305 |
+
confs.append(conf)
|
| 306 |
+
props = compute_properties(s)
|
| 307 |
+
lengths.append(props.get("Length", len(s)))
|
| 308 |
+
hydros.append(props.get("Hydrophobic Fraction", 0))
|
| 309 |
+
charges.append(props.get("Net Charge (approx.)", props.get("Net charge", 0)))
|
| 310 |
+
|
| 311 |
+
embeddings_array = np.stack(embeddings_list)
|
| 312 |
+
perplexity = min(30, max(2, len(sequences) - 1))
|
| 313 |
+
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
|
| 314 |
+
reduced = tsne.fit_transform(embeddings_array)
|
| 315 |
+
|
| 316 |
+
df = pd.DataFrame(reduced, columns=["x", "y"])
|
| 317 |
+
df["Sequence"] = sequences
|
| 318 |
+
df["Label"] = labels
|
| 319 |
+
df["Confidence"] = confs
|
| 320 |
+
df["Length"] = lengths
|
| 321 |
+
df["Hydrophobic Fraction"] = hydros
|
| 322 |
+
df["Net Charge"] = charges
|
| 323 |
+
|
| 324 |
+
st.session_state.visualize_df = df
|
| 325 |
+
|
| 326 |
+
# If we have a t-SNE dataframe, show plot and sidebar filters
|
| 327 |
+
if st.session_state.visualize_df is not None:
|
| 328 |
+
df = st.session_state.visualize_df
|
| 329 |
+
st.subheader("t-SNE plot")
|
| 330 |
+
|
| 331 |
+
st.sidebar.subheader("Filter Sequences")
|
| 332 |
+
min_len, max_len = int(df["Length"].min()), int(df["Length"].max())
|
| 333 |
+
if min_len == max_len:
|
| 334 |
+
st.sidebar.write(f"All sequences have length {min_len}")
|
| 335 |
+
length_range = (min_len, max_len)
|
| 336 |
+
else:
|
| 337 |
+
length_range = st.sidebar.slider("Sequence length", min_len, max_len, (min_len, max_len))
|
| 338 |
+
|
| 339 |
+
label_options = st.sidebar.multiselect("Label", ["AMP", "Non-AMP"], default=["AMP", "Non-AMP"])
|
| 340 |
+
filtered_df = df[(df["Length"].between(length_range[0], length_range[1])) & (df["Label"].isin(label_options))]
|
| 341 |
+
color_by = st.sidebar.selectbox("Color points by", ["Label", "Confidence", "Hydrophobic Fraction", "Net Charge", "Length"])
|
| 342 |
+
|
| 343 |
+
color_map = {"AMP": "#2ca02c", "Non-AMP": "#d62728"}
|
| 344 |
+
fig = px.scatter(
|
| 345 |
+
filtered_df,
|
| 346 |
+
x="x", y="y",
|
| 347 |
+
color=color_by if color_by != "Label" else "Label",
|
| 348 |
+
color_discrete_map=color_map if color_by == "Label" else None,
|
| 349 |
+
hover_data={"Sequence": True, "Label": True, "Confidence": True, "Length": True, "Hydrophobic Fraction": True, "Net Charge": True},
|
| 350 |
+
title="t-SNE Visualization of Model Embeddings"
|
| 351 |
+
)
|
| 352 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 353 |
+
|
| 354 |
+
st.subheader("t-SNE Analysis")
|
| 355 |
+
st.markdown("""
|
| 356 |
+
• Each point represents a peptide sequence.
|
| 357 |
+
• Sequences close together have similar internal representations in the model.
|
| 358 |
+
• AMP and Non-AMP clusters indicate strong model separation.
|
| 359 |
+
• Coloring by properties reveals biochemical trends.
|
| 360 |
+
""")
|
| 361 |
+
|
| 362 |
+
# ABOUT PAGE
|
| 363 |
+
elif page == "About":
|
| 364 |
+
st.header("About the Project")
|
| 365 |
+
st.markdown("""
|
| 366 |
+
**Problem:** Antimicrobial resistance is a global health threat. Traditional peptide screening is slow and costly.
|
| 367 |
+
**Solution:** This tool predicts antimicrobial activity directly from sequence using deep learning, speeding up AMP discovery.
|
| 368 |
+
""")
|
StreamlitApp/utils/__init__.py
ADDED
|
File without changes
|
StreamlitApp/utils/analyze.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
def aa_composition(sequence):
    """Return the relative frequency of each canonical amino acid in `sequence`.

    Always returns all 20 canonical residues as keys. Characters outside the
    canonical alphabet contribute to the denominator but get no key of their
    own. An empty sequence yields all-zero frequencies instead of raising
    ZeroDivisionError (previous behavior).
    """
    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    total = len(sequence)
    if total == 0:
        # Guard: empty input previously crashed with ZeroDivisionError.
        return {aa: 0.0 for aa in amino_acids}
    counts = Counter(sequence)
    return {aa: counts.get(aa, 0) / total for aa in amino_acids}
|
| 8 |
+
|
| 9 |
+
# Compute sequence properties
def compute_properties(sequence):
    """Compute simple physicochemical descriptors for a peptide sequence.

    Returns a dict with:
      - "Length": residue count
      - "Molecular Weight (Da)": sum of free-residue weights (unknown
        residues count as 0; no water-loss correction for peptide bonds)
      - "Hydrophobic Fraction": share of residues in AILMFWYV; 0.0 for an
        empty sequence instead of raising ZeroDivisionError (previous bug)
      - "Net Charge (approx.)": #(K/R/H) minus #(D/E)
    """
    aa_weights = {'A': 89.1, 'R': 174.2, 'N': 132.1, 'D': 133.1, 'C': 121.2,
                  'E': 147.1, 'Q': 146.2, 'G': 75.1, 'H': 155.2, 'I': 131.2,
                  'L': 131.2, 'K': 146.2, 'M': 149.2, 'F': 165.2, 'P': 115.1,
                  'S': 105.1, 'T': 119.1, 'W': 204.2, 'Y': 181.2, 'V': 117.1}
    mw = sum(aa_weights.get(aa, 0) for aa in sequence)
    # Guard against empty input (previously ZeroDivisionError).
    if sequence:
        hydrophobic = sum(1 for aa in sequence if aa in "AILMFWYV") / len(sequence)
    else:
        hydrophobic = 0.0
    charge = sum(1 for aa in sequence if aa in "KRH") - sum(1 for aa in sequence if aa in "DE")
    return {"Length": len(sequence), "Molecular Weight (Da)": round(mw, 2),
            "Hydrophobic Fraction": round(hydrophobic, 3), "Net Charge (approx.)": charge}
|
StreamlitApp/utils/optimize.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from utils.predict import predict_amp
|
| 3 |
+
|
| 4 |
+
# Residue categories used by the mutation heuristics.
HYDROPHOBIC = set("AILMFWVPG")
HYDROPHILIC = set("STNQYCH")
POSITIVE = set("KRH")
NEGATIVE = set("DE")

# Function to mutate a residue based on simple heuristics
def mutate_residue(residue):
    """Propose a replacement residue plus a human-readable rationale.

    Cationic residues are kept as-is; negative residues are swapped for a
    positive one; hydrophilic and hydrophobic residues are exchanged to
    push the peptide toward an amphipathic, cationic profile.
    """
    if residue in POSITIVE:
        # Cationic residues are desirable; leave them untouched.
        return residue, "Retained strong positive residue"
    if residue in NEGATIVE:
        return random.choice(list(POSITIVE)), "Increased positive charge"
    if residue in HYDROPHILIC:
        return random.choice(list(HYDROPHOBIC)), "Improved hydrophobicity balance"
    if residue in HYDROPHOBIC:
        return random.choice(list(POSITIVE | HYDROPHILIC)), "Enhanced amphipathicity"
    # Fallback for non-standard characters outside every category.
    return random.choice(list(HYDROPHOBIC)), "Adjusted physicochemical profile"
|
| 21 |
+
|
| 22 |
+
# Sequence optimization function
def optimize_sequence(seq, model, max_rounds=20, confidence_threshold=0.001):
    """
    Iteratively optimize sequence to increase AMP probability.
    Tries mutating all positions per round and accepts the best change.

    Args:
        seq: Starting peptide sequence (one-letter codes).
        model: Loaded classifier, passed straight through to predict_amp.
        max_rounds: Upper bound on accepted mutations (one per round).
        confidence_threshold: Minimum confidence gain required to accept
            a round's best mutation; below it the search stops early.

    Returns:
        (optimized_sequence, best_confidence, history) where history is a
        list of tuples (sequence, confidence, change, old_res, new_res,
        reason); the first entry describes the unmodified input.
    """
    current_seq = seq
    label, conf = predict_amp(current_seq, model)
    best_conf = conf
    # Seed history with the unmutated sequence; "-" placeholders keep the
    # tuple shape uniform with later mutation entries.
    history = [(current_seq, conf, "-", "-", "-", "Original sequence")]

    # Optimization loop
    for _ in range(max_rounds):
        best_mutation = None
        best_mutation_conf = best_conf

        # Greedy scan: sample one heuristic mutation per position and keep
        # only the single best-scoring candidate for this round.
        # NOTE(review): mutate_residue is stochastic, so each round tries
        # one random alternative per position, not every possible residue.
        for pos, old_res in enumerate(current_seq):
            new_res, reason = mutate_residue(old_res)
            if new_res == old_res:
                # No-op mutation (e.g. retained positive residue) — skip.
                continue
            new_seq = current_seq[:pos] + new_res + current_seq[pos+1:]
            _, new_conf = predict_amp(new_seq, model)

            if new_conf > best_mutation_conf:
                best_mutation_conf = new_conf
                best_mutation = (new_seq, pos, old_res, new_res, reason)

        # Accept the round's best mutation only if it improves confidence
        # by at least the threshold; otherwise the search has converged.
        if best_mutation and best_mutation_conf - best_conf >= confidence_threshold:
            current_seq, pos, old_res, new_res, reason = best_mutation
            best_conf = best_mutation_conf
            change = f"Pos {pos+1}: {old_res} → {new_res}"
            history.append((current_seq, best_conf, change, old_res, new_res, reason))
        else:

            # No further improvement, stop
            break

    return current_seq, best_conf, history
|
StreamlitApp/utils/predict.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import requests
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import streamlit as st
|
| 7 |
+
from torch import nn
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import shutil
|
| 10 |
+
|
| 11 |
+
# Model Definition
class FastMLP(nn.Module):
    """Small feed-forward binary classifier over fixed-length encodings.

    Architecture: input_dim -> 512 -> 128 -> 1 with ReLU activations and
    a 0.3 dropout after the first hidden layer. The single output is a
    raw logit; apply a sigmoid to obtain an AMP probability.
    """

    def __init__(self, input_dim=1024):
        super(FastMLP, self).__init__()
        stack = [
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),  # Single output for binary classification
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        # Returns raw logits of shape (batch, 1).
        return self.layers(x)
|
| 26 |
+
|
| 27 |
+
# Utility: download file from URL to local path (streaming)
def _download_file(url: str, dest_path: str, timeout: float = 60.0):
    """Stream *url* to *dest_path*, creating parent directories as needed.

    Args:
        url: Direct HTTP(S) download URL.
        dest_path: Local filesystem destination path.
        timeout: Per-request timeout in seconds. requests applies no
            timeout by default, so without this a stalled connection
            would hang the app indefinitely.

    Raises:
        requests.HTTPError: On non-2xx responses.
        requests.Timeout / requests.ConnectionError: On network failure.
    """
    dest = pathlib.Path(dest_path)
    dest.parent.mkdir(parents=True, exist_ok=True)
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(dest, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                # Skip keep-alive chunks, which arrive as empty bytes.
                if chunk:
                    f.write(chunk)
|
| 37 |
+
|
| 38 |
+
def _get_env(key: str) -> Optional[str]:
|
| 39 |
+
v = os.environ.get(key)
|
| 40 |
+
return v if v else None
|
| 41 |
+
|
| 42 |
+
# Model Loader
@st.cache_resource
def load_model():
    """Load the AMP classifier weights, fetching them if absent.

    Resolution order for 'models/ampMLModel.pt':
      1. Local file under the StreamlitApp folder.
      2. Direct download from the MODEL_URL environment variable.
      3. Hugging Face Hub via MODEL_REPO_ID (+ optional MODEL_FILENAME;
         HF_TOKEN/HUGGINGFACE_TOKEN for private repos).

    Cached by Streamlit so the model is built once per process.

    Returns:
        FastMLP (input_dim=1024) in eval mode with CPU-mapped weights.

    Raises:
        FileNotFoundError: If no weight source is configured, or the
            download produced no file.
        RuntimeError: If huggingface_hub is needed but not installed.
    """
    # Always resolve relative to the StreamlitApp folder, not the process CWD.
    streamlitapp_dir = pathlib.Path(__file__).resolve().parent.parent
    model_path = streamlitapp_dir / "models" / "ampMLModel.pt"

    # If the model file doesn't exist, try to download it from a configured URL
    if not model_path.exists():

        model_url = _get_env("MODEL_URL")

        if model_url:
            try:
                _download_file(model_url, str(model_path))
            except Exception as e:
                # Surface the failure in the UI, then re-raise so the
                # cached resource is not populated with a broken state.
                st.error(f"Failed to download model from MODEL_URL: {e}")
                raise
        else:
            # Fall back to Hugging Face Hub model repo download.
            # Configure these in HF Space secrets/vars, or locally in env:
            # - MODEL_REPO_ID (e.g. "m0ksh/peptideai-models")
            # - MODEL_FILENAME (default: "ampMLModel.pt")
            model_repo_id = _get_env("MODEL_REPO_ID")
            model_filename = _get_env("MODEL_FILENAME") or "ampMLModel.pt"

            if not model_repo_id:
                raise FileNotFoundError(
                    "Model file './models/ampMLModel.pt' not found.\n"
                    "Set one of:\n"
                    "- MODEL_URL (direct download URL), or\n"
                    "- MODEL_REPO_ID (Hugging Face model repo id) and optional MODEL_FILENAME.\n"
                    "\n"
                    "Debug (env vars detected): "
                    f"MODEL_URL={'set' if _get_env('MODEL_URL') else 'missing'}, "
                    f"MODEL_REPO_ID={'set' if _get_env('MODEL_REPO_ID') else 'missing'}, "
                    f"MODEL_FILENAME={'set' if _get_env('MODEL_FILENAME') else 'missing'}\n"
                )

            # Import lazily so the app still starts without huggingface_hub
            # when the model is already present locally.
            try:
                from huggingface_hub import hf_hub_download
            except Exception as e:
                raise RuntimeError(
                    "Missing dependency 'huggingface_hub'. Add it to requirements.txt.\n"
                    f"Import error: {e}"
                ) from e

            token = _get_env("HF_TOKEN") or _get_env("HUGGINGFACE_TOKEN")
            downloaded_path = hf_hub_download(
                repo_id=model_repo_id,
                filename=model_filename,
                token=token,
            )
            # Copy out of the HF cache so subsequent runs hit the fast
            # local-file path above.
            model_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(downloaded_path, model_path)

    if not model_path.exists():
        raise FileNotFoundError(
            f"Model download did not produce file at: {model_path}\n"
            "Check MODEL_URL or MODEL_REPO_ID/MODEL_FILENAME configuration."
        )

    # Build model and load weights
    model = FastMLP(input_dim=1024)
    model.load_state_dict(torch.load(str(model_path), map_location="cpu"))
    model.eval()
    return model
|
| 109 |
+
|
| 110 |
+
# Sequence Encoder
|
| 111 |
+
def encode_sequence(seq, max_len=51):
|
| 112 |
+
"""
|
| 113 |
+
Converts amino acid sequence to flattened one-hot vector
|
| 114 |
+
padded/truncated to match model input_dim (1024)
|
| 115 |
+
"""
|
| 116 |
+
amino_acids = "ACDEFGHIKLMNPQRSTVWY"
|
| 117 |
+
aa_to_idx = {aa: i for i, aa in enumerate(amino_acids)}
|
| 118 |
+
|
| 119 |
+
one_hot = np.zeros((max_len, len(amino_acids))) # max_len x 20
|
| 120 |
+
for i, aa in enumerate(seq[:max_len]):
|
| 121 |
+
if aa in aa_to_idx:
|
| 122 |
+
one_hot[i, aa_to_idx[aa]] = 1
|
| 123 |
+
|
| 124 |
+
flat = one_hot.flatten() # length = max_len*20 = 1020
|
| 125 |
+
|
| 126 |
+
if len(flat) < 1024:
|
| 127 |
+
flat = np.pad(flat, (0, 1024 - len(flat)))
|
| 128 |
+
|
| 129 |
+
return flat
|
| 130 |
+
|
| 131 |
+
# Prediction Function
def predict_amp(sequence, model):
    """
    Takes an amino acid sequence string and the loaded model,
    returns ("AMP"/"Non-AMP") and probability
    """
    # One-hot encode and add a batch dimension of size 1.
    encoded = encode_sequence(sequence)
    features = torch.tensor(encoded, dtype=torch.float32).unsqueeze(0)

    # Inference only — no gradients needed.
    with torch.no_grad():
        prob = torch.sigmoid(model(features)).item()

    if prob >= 0.5:
        label = "AMP"
    else:
        label = "Non-AMP"
    return label, round(prob, 3)
|
StreamlitApp/utils/rateLimit.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
from collections import deque
|
| 3 |
+
|
| 4 |
+
class RateLimiter:

    # Sliding-window rate limiter per instance
    def __init__(self, max_calls: int, period_seconds: float):
        """Allow at most *max_calls* calls per *period_seconds* window."""
        self.max_calls = max_calls
        self.period = period_seconds
        self.calls = deque()  # timestamps of accepted calls, oldest first

    def allow(self) -> bool:
        """Record and permit a call if the window still has capacity."""
        now = time.time()

        # Drop entries older than window
        cutoff = now - self.period
        while self.calls and self.calls[0] <= cutoff:
            self.calls.popleft()

        if len(self.calls) >= self.max_calls:
            return False
        self.calls.append(now)
        return True

    def time_until_next(self) -> float:
        """Seconds until next slot is available (0 if already available)."""
        now = time.time()
        if len(self.calls) < self.max_calls:
            return 0.0
        # The earliest recorded call leaving the window frees a slot.
        return max(0.0, (self.calls[0] + self.period) - now)
|
StreamlitApp/utils/visualize.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
from sklearn.manifold import TSNE
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from utils.predict import encode_sequence
|
| 8 |
+
|
| 9 |
+
# t-SNE Visualization
def tsne_visualization(sequences, model):
    """Embed sequences with the model's first layer and plot a 2-D t-SNE.

    Args:
        sequences: Iterable of peptide strings.
        model: Loaded FastMLP; the first module of its `layers`
            Sequential provides the embedding for each encoded sequence.

    Side effects:
        Renders status messages and a scatter chart via Streamlit.
    """
    sequences = list(sequences)
    # t-SNE needs perplexity >= 2 and perplexity < n_samples, i.e. at
    # least 3 sequences. Checking up front also avoids np.vstack crashing
    # on an empty list and skips the embedding work entirely (previously
    # an empty input raised ValueError before the guard was reached).
    if len(sequences) < 3:
        st.warning("Need at least 3 sequences for visualization.")
        return

    st.info("Generating embeddings... this may take a moment.")
    embeddings = []
    for seq in sequences:
        x = torch.tensor(encode_sequence(seq), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            emb = model.layers[0](x)  # Grab first layer embedding
        embeddings.append(emb.numpy().flatten())

    embeddings = np.vstack(embeddings)

    perplexity = min(30, len(sequences) - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced = tsne.fit_transform(embeddings)
    df = pd.DataFrame(reduced, columns=["x", "y"])

    st.success("t-SNE visualization complete.")
    st.scatter_chart(df)
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runtime dependencies for the PeptideAI Streamlit app.
streamlit
pandas
numpy
torch
scikit-learn
matplotlib
plotly
requests
huggingface_hub
|
space.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Space configuration.
sdk: streamlit
name: peptideai
title: PeptideAI
emoji: 🔬
# Entry point launched by the Space (nested copy of the app).
app_file: PeptideAI/StreamlitApp/StreamlitApp.py
python_version: "3.13"
|