"""
Example: Protein subcellular localization prediction using different aggregation methods.

Demonstrates how to use ProteinSequenceClassifier with all 6 aggregation strategies.

Usage:
    python example_localization.py
"""

import torch
from protein_aggregator import ProteinSequenceClassifier


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Example: subcellular localization with 10 classes
    LOCALIZATION_CLASSES = [
        "Cytoplasm", "Nucleus", "Cell membrane", "Mitochondrion",
        "Endoplasmic reticulum", "Golgi apparatus", "Lysosome",
        "Peroxisome", "Extracellular", "Plastid",
    ]

    # Test sequences (short examples)
    sequences = [
        "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSG",
        "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY",
        "MATLEKLMKAFESLKSFQHHMKAGPFLKENSSYRQNIDNFSDNFIDNF",
    ]

    print("=" * 70)
    print("Protein Subcellular Localization — Aggregation Method Comparison")
    print("=" * 70)

    for agg_name in ["mean", "max", "cls", "glot", "glot_residue", "covariance"]:
        # Build the model; each aggregation strategy takes its own
        # hyperparameters, passed through as aggregator_kwargs.
        agg_kwargs = {}
        if agg_name == "glot":
            agg_kwargs = {"p": 128, "K": 2, "tau": 0.6, "n_heads": 4}
        elif agg_name == "glot_residue":
            agg_kwargs = {"p": 128, "K": 2, "seq_neighbor_k": 5, "n_heads": 4}
        elif agg_name == "covariance":
            agg_kwargs = {"d_proj": 64}

        model = ProteinSequenceClassifier(
            esm2_model_name="facebook/esm2_t12_35M_UR50D",  # any ESM-2 checkpoint
            aggregation=agg_name,
            num_classes=len(LOCALIZATION_CLASSES),
            aggregator_kwargs=agg_kwargs,
            classifier_hidden=256,
            dropout=0.1,
        ).to(device)

        # Get predictions (untrained — just demonstrating the pipeline)
        model.eval()
        with torch.no_grad():
            # Tokenize and run the full pipeline; the forward pass returns a
            # dict whose "logits" entry is used below.
            inputs = model.tokenizer(
                sequences, padding=True, truncation=True,
                max_length=1024, return_tensors="pt",
            ).to(device)
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        probs = torch.softmax(outputs["logits"], dim=-1)

        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(
            f"\n--- {agg_name.upper()} "
            f"(trainable: {trainable:,}, emb_dim: {model.aggregator.out_dim}) ---"
        )
        for i, seq in enumerate(sequences):
            pred_class = probs[i].argmax().item()
            confidence = probs[i].max().item()
            print(f"  Seq {i+1} ({seq[:20]}...): {LOCALIZATION_CLASSES[pred_class]} ({confidence:.1%})")

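        # Free the model before building the next variant;
        # empty_cache() is a no-op when CUDA is unavailable.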
        del model
        torch.cuda.empty_cache()

    print("\n" + "=" * 70)
    print("NOTE: The ESM-2 encoder is pretrained, but the aggregator and")
    print("classifier heads above are randomly initialized and untrained.")
    print("Train on a real localization dataset to get meaningful predictions.")
    print("=" * 70)

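
# Sketch of the standalone embedding path: the class also exposes
# model.encode(sequences, device=...), whose signature appears in the original
# pipeline. Assumption (not verified against the repo): it returns one pooled
# vector per input sequence. Illustrative only.
def extract_embeddings(model, sequences, device):
    """Return pooled per-sequence embeddings for downstream use."""
    model.eval()
    with torch.no_grad():
        return model.encode(sequences, device=device)
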

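# A minimal fine-tuning sketch for the "train on a real dataset" note printed
# in main(). It reuses the tokenizer + forward API exactly as above; the
# (sequences, labels) batch and the optimizer are illustrative placeholders,
# not part of this repository.
def train_step(model, optimizer, sequences, labels, device):
    """Run one gradient step on a batch of raw sequences and integer labels."""
    model.train()
    inputs = model.tokenizer(
        sequences, padding=True, truncation=True,
        max_length=1024, return_tensors="pt",
    ).to(device)
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    loss = torch.nn.functional.cross_entropy(
        outputs["logits"], torch.as_tensor(labels, device=device)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
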
if __name__ == "__main__":
    main()