# Source: protein-sequence-aggregators / example_localization.py
# Committed by AliSaadatV
# Commit a84d7b6 ("Add example localization script", verified)
"""
Example: Protein subcellular localization prediction using different aggregation methods.
Demonstrates how to use ProteinSequenceClassifier with all 6 aggregation strategies.
Usage:
python example_localization.py
"""
import torch
from protein_aggregator import ProteinSequenceClassifier
def _aggregator_kwargs(agg_name):
    """Return the constructor kwargs for the aggregator called *agg_name*.

    Aggregators not listed in the table use an empty dict, i.e. the
    aggregator's own defaults. A fresh dict is returned on every call.
    """
    table = {
        "glot": {"p": 128, "K": 2, "tau": 0.6, "n_heads": 4},
        "glot_residue": {"p": 128, "K": 2, "seq_neighbor_k": 5, "n_heads": 4},
        "covariance": {"d_proj": 64},
    }
    return table.get(agg_name, {})


def _predict_probs(model, sequences, device):
    """Tokenize *sequences*, run *model* without gradients, and return softmax probabilities.

    Sequences are padded to a common length and truncated at 1024 tokens.
    The result is ``softmax(outputs["logits"], dim=-1)``, with one row per
    input sequence.
    """
    model.eval()
    with torch.no_grad():
        inputs = model.tokenizer(
            sequences, padding=True, truncation=True,
            max_length=1024, return_tensors="pt",
        ).to(device)
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        return torch.softmax(outputs["logits"], dim=-1)


def main():
    """Compare all aggregation methods on a few toy sequences and print predictions.

    For each aggregation method, the function builds a ProteinSequenceClassifier
    and prints its trainable-parameter count and output embedding size. It then
    prints the top predicted localization class and its probability for every
    test sequence. The classifiers are untrained, so the predictions only
    demonstrate that the pipeline runs end to end.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Example: subcellular localization with 10 classes
    LOCALIZATION_CLASSES = [
        "Cytoplasm", "Nucleus", "Cell membrane", "Mitochondrion",
        "Endoplasmic reticulum", "Golgi apparatus", "Lysosome",
        "Peroxisome", "Extracellular", "Plastid",
    ]
    # Test sequences (short examples)
    sequences = [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSG",
        "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY",
        "MATLEKLMKAFESLKSFQHHMKAGPFLKENSSYRQNIDNFSDNFIDNF",
    ]
    print("=" * 70)
    print("Protein Subcellular Localization — Aggregation Method Comparison")
    print("=" * 70)
    for agg_name in ["mean", "max", "cls", "glot", "glot_residue", "covariance"]:
        model = ProteinSequenceClassifier(
            esm2_model_name="facebook/esm2_t12_35M_UR50D",  # backbone checkpoint; changeable
            aggregation=agg_name,
            num_classes=len(LOCALIZATION_CLASSES),
            aggregator_kwargs=_aggregator_kwargs(agg_name),
            classifier_hidden=256,
            dropout=0.1,
        ).to(device)
        # Untrained model: this only demonstrates the inference pipeline.
        # (model.encode(...) is the call to use if you need pooled embeddings;
        # it is not needed for classification and would cost an extra forward pass.)
        probs = _predict_probs(model, sequences, device)
        # A single reduction yields both the winning probability and its class index.
        confidences, pred_classes = probs.max(dim=-1)
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--- {agg_name.upper()} (trainable: {trainable:,}, emb_dim: {model.aggregator.out_dim}) ---")
        rows = zip(sequences, pred_classes.tolist(), confidences.tolist())
        for i, (seq, pred_class, confidence) in enumerate(rows, start=1):
            # Only add an ellipsis when the sequence is actually truncated.
            preview = seq if len(seq) <= 20 else f"{seq[:20]}..."
            print(f" Seq {i} ({preview}): {LOCALIZATION_CLASSES[pred_class]} ({confidence:.1%})")
        # Free the model before building the next one to keep peak memory down.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()
    print("\n" + "=" * 70)
    print("NOTE: Predictions above are from untrained models (random weights).")
    print("Train on a real localization dataset to get meaningful predictions.")
    print("=" * 70)
if __name__ == "__main__":
main()