""" Example: Protein subcellular localization prediction using different aggregation methods. Demonstrates how to use ProteinSequenceClassifier with all 6 aggregation strategies. Usage: python example_localization.py """ import torch from protein_aggregator import ProteinSequenceClassifier def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Example: subcellular localization with 10 classes LOCALIZATION_CLASSES = [ "Cytoplasm", "Nucleus", "Cell membrane", "Mitochondrion", "Endoplasmic reticulum", "Golgi apparatus", "Lysosome", "Peroxisome", "Extracellular", "Plastid", ] # Test sequences (short examples) sequences = [ "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSG", "ACDEFGHIKLMNPQRSTVWYACDEFGHIKLMNPQRSTVWY", "MATLEKLMKAFESLKSFQHHMKAGPFLKENSSYRQNIDNFSDNFIDNF", ] print("=" * 70) print("Protein Subcellular Localization — Aggregation Method Comparison") print("=" * 70) for agg_name in ["mean", "max", "cls", "glot", "glot_residue", "covariance"]: # Build model agg_kwargs = {} if agg_name == "glot": agg_kwargs = {"p": 128, "K": 2, "tau": 0.6, "n_heads": 4} elif agg_name == "glot_residue": agg_kwargs = {"p": 128, "K": 2, "seq_neighbor_k": 5, "n_heads": 4} elif agg_name == "covariance": agg_kwargs = {"d_proj": 64} model = ProteinSequenceClassifier( esm2_model_name="facebook/esm2_t12_35M_UR50D", # changeable! aggregation=agg_name, num_classes=len(LOCALIZATION_CLASSES), aggregator_kwargs=agg_kwargs, classifier_hidden=256, dropout=0.1, ).to(device) # Get predictions (untrained — just demonstrating the pipeline) model.eval() with torch.no_grad(): embeddings = model.encode(sequences, device=device) inputs = model.tokenizer( sequences, padding=True, truncation=True, max_length=1024, return_tensors="pt", ).to(device) outputs = model( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], ) probs = torch.softmax(outputs["logits"], dim=-1) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"\n--- {agg_name.upper()} (trainable: {trainable:,}, emb_dim: {model.aggregator.out_dim}) ---") for i, seq in enumerate(sequences): pred_class = probs[i].argmax().item() confidence = probs[i].max().item() print(f" Seq {i+1} ({seq[:20]}...): {LOCALIZATION_CLASSES[pred_class]} ({confidence:.1%})") del model torch.cuda.empty_cache() print("\n" + "=" * 70) print("NOTE: Predictions above are from untrained models (random weights).") print("Train on a real localization dataset to get meaningful predictions.") print("=" * 70) if __name__ == "__main__": main()