| """ |
| Example: Protein subcellular localization prediction using different aggregation methods. |
| |
| Demonstrates how to use ProteinSequenceClassifier with all 6 aggregation strategies. |
| |
| Usage: |
| python example_localization.py |
| """ |
|
|
| import torch |
| from protein_aggregator import ProteinSequenceClassifier |
|
|
|
|
def main():
    """Compare all six aggregation strategies on a toy localization task.

    For each aggregation method, builds an untrained ProteinSequenceClassifier,
    runs the same example sequences through it once, and prints the
    (random-weight) predicted class, confidence, trainable-parameter count,
    and aggregator output dimension for side-by-side comparison.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Ten common subcellular compartments; list order defines the class indices.
    LOCALIZATION_CLASSES = [
        "Cytoplasm", "Nucleus", "Cell membrane", "Mitochondrion",
        "Endoplasmic reticulum", "Golgi apparatus", "Lysosome",
        "Peroxisome", "Extracellular", "Plastid",
    ]

    # Toy amino-acid sequences (single-letter codes) — placeholders, not real proteins.
    sequences = [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSG",
        "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY",
        "MATLEKLMKAFESLKSFQHHMKAGPFLKENSSYRQNIDNFSDNFIDNF",
    ]

    print("=" * 70)
    print("Protein Subcellular Localization — Aggregation Method Comparison")
    print("=" * 70)

    for agg_name in ["mean", "max", "cls", "glot", "glot_residue", "covariance"]:
        # Method-specific hyperparameters; the simple poolers (mean/max/cls)
        # take none.
        agg_kwargs = {}
        if agg_name == "glot":
            agg_kwargs = {"p": 128, "K": 2, "tau": 0.6, "n_heads": 4}
        elif agg_name == "glot_residue":
            agg_kwargs = {"p": 128, "K": 2, "seq_neighbor_k": 5, "n_heads": 4}
        elif agg_name == "covariance":
            agg_kwargs = {"d_proj": 64}

        model = ProteinSequenceClassifier(
            esm2_model_name="facebook/esm2_t12_35M_UR50D",
            aggregation=agg_name,
            num_classes=len(LOCALIZATION_CLASSES),
            aggregator_kwargs=agg_kwargs,
            classifier_hidden=256,
            dropout=0.1,
        ).to(device)

        # Inference only: eval mode disables dropout, no_grad skips autograd.
        # NOTE: the original example also called model.encode(sequences) here,
        # but the result was never used and it re-ran the encoder — removed.
        model.eval()
        with torch.no_grad():
            inputs = model.tokenizer(
                sequences, padding=True, truncation=True,
                max_length=1024, return_tensors="pt",
            ).to(device)
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        probs = torch.softmax(outputs["logits"], dim=-1)

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--- {agg_name.upper()} (trainable: {trainable:,}, emb_dim: {model.aggregator.out_dim}) ---")
        for i, seq in enumerate(sequences):
            pred_class = probs[i].argmax().item()
            confidence = probs[i].max().item()
            print(f"  Seq {i+1} ({seq[:20]}...): {LOCALIZATION_CLASSES[pred_class]} ({confidence:.1%})")

        # Free GPU memory before instantiating the next model variant.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\n" + "=" * 70)
    print("NOTE: Predictions above are from untrained models (random weights).")
    print("Train on a real localization dataset to get meaningful predictions.")
    print("=" * 70)
|
|
|
| if __name__ == "__main__": |
| main() |
|
|