AliSaadatV committed
Commit a84d7b6 · verified · 1 Parent(s): 411e478

Add example localization script

Files changed (1):
  1. example_localization.py +86 -0
example_localization.py ADDED
"""
Example: Protein subcellular localization prediction using different aggregation methods.

Demonstrates how to use ProteinSequenceClassifier with all 6 aggregation strategies.

Usage:
    python example_localization.py
"""

import torch

from protein_aggregator import ProteinSequenceClassifier


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Example: subcellular localization with 10 classes
    LOCALIZATION_CLASSES = [
        "Cytoplasm", "Nucleus", "Cell membrane", "Mitochondrion",
        "Endoplasmic reticulum", "Golgi apparatus", "Lysosome",
        "Peroxisome", "Extracellular", "Plastid",
    ]

    # Test sequences (short examples)
    sequences = [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSG",
        "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY",
        "MATLEKLMKAFESLKSFQHHMKAGPFLKENSSYRQNIDNFSDNFIDNF",
    ]

    print("=" * 70)
    print("Protein Subcellular Localization — Aggregation Method Comparison")
    print("=" * 70)

    for agg_name in ["mean", "max", "cls", "glot", "glot_residue", "covariance"]:
        # Build model
        agg_kwargs = {}
        if agg_name == "glot":
            agg_kwargs = {"p": 128, "K": 2, "tau": 0.6, "n_heads": 4}
        elif agg_name == "glot_residue":
            agg_kwargs = {"p": 128, "K": 2, "seq_neighbor_k": 5, "n_heads": 4}
        elif agg_name == "covariance":
            agg_kwargs = {"d_proj": 64}

        model = ProteinSequenceClassifier(
            esm2_model_name="facebook/esm2_t12_35M_UR50D",  # changeable!
            aggregation=agg_name,
            num_classes=len(LOCALIZATION_CLASSES),
            aggregator_kwargs=agg_kwargs,
            classifier_hidden=256,
            dropout=0.1,
        ).to(device)

        # Get predictions (untrained — just demonstrating the pipeline)
        model.eval()
        with torch.no_grad():
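            # Sequence-level embeddings from encode(); computed here only to
            # demonstrate the API — they are not used by the forward pass below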
            embeddings = model.encode(sequences, device=device)
            inputs = model.tokenizer(
                sequences, padding=True, truncation=True,
                max_length=1024, return_tensors="pt",
            ).to(device)
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        probs = torch.softmax(outputs["logits"], dim=-1)

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--- {agg_name.upper()} (trainable: {trainable:,}, emb_dim: {model.aggregator.out_dim}) ---")
        for i, seq in enumerate(sequences):
            pred_class = probs[i].argmax().item()
            confidence = probs[i].max().item()
            print(f" Seq {i+1} ({seq[:20]}...): {LOCALIZATION_CLASSES[pred_class]} ({confidence:.1%})")

        del model
        torch.cuda.empty_cache()

    print("\n" + "=" * 70)
    print("NOTE: Predictions above are from untrained models (random weights).")
    print("Train on a real localization dataset to get meaningful predictions.")
    print("=" * 70)


if __name__ == "__main__":
    main()
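
The script's closing note says that meaningful predictions require training on a real localization dataset. As a rough illustration (not part of this commit), a minimal fine-tuning loop might look like the sketch below. It assumes the same ProteinSequenceClassifier constructor and dict-style forward output used in the script; train_seqs and train_labels are hypothetical placeholders, not a real dataset.

# Hypothetical fine-tuning sketch; assumes the ProteinSequenceClassifier API
# shown in example_localization.py. Data and labels below are placeholders.
import torch

from protein_aggregator import ProteinSequenceClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ProteinSequenceClassifier(
    esm2_model_name="facebook/esm2_t12_35M_UR50D",
    aggregation="mean",
    num_classes=10,
    aggregator_kwargs={},
    classifier_hidden=256,
    dropout=0.1,
).to(device)

train_seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSG"]  # placeholder
train_labels = torch.tensor([0], device=device)                     # placeholder class indices

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(3):
    inputs = model.tokenizer(
        train_seqs, padding=True, truncation=True,
        max_length=1024, return_tensors="pt",
    ).to(device)
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    loss = criterion(outputs["logits"], train_labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")

In practice one would batch the dataset with a DataLoader and hold out a validation split; the loop above only shows where the logits from the script's forward pass plug into a standard cross-entropy objective.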