File size: 8,333 Bytes
411e478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
"""
ESM2 backbone + pluggable aggregation head + classification head.

The ESM2 backbone is always frozen.  Only the aggregation module and the
classifier head are trained.

ESM2 model variants (all from facebook):
    esm2_t6_8M_UR50D      ->  d=320,   8M params
    esm2_t12_35M_UR50D    ->  d=480,  35M params   (default)
    esm2_t30_150M_UR50D   ->  d=640, 150M params
    esm2_t33_650M_UR50D   ->  d=1280, 650M params
    esm2_t36_3B_UR50D     ->  d=2560,  3B params
"""

from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoTokenizer, EsmModel

from .aggregators import (
    CLSPooling,
    CovariancePooling,
    GLOTPooling,
    GLOTResidueGraphPooling,
    MaxPooling,
    MeanPooling,
)

# Lookup table: user-facing aggregation name -> aggregator class.
# ProteinSequenceClassifier validates its `aggregation` argument against
# these keys and instantiates the matching class.
AGGREGATOR_REGISTRY = dict(
    mean=MeanPooling,
    max=MaxPooling,
    cls=CLSPooling,
    glot=GLOTPooling,
    glot_residue=GLOTResidueGraphPooling,
    covariance=CovariancePooling,
)

# Hidden (per-token embedding) width of each known ESM2 checkpoint, keyed by
# the full HuggingFace model ID.
ESM2_HIDDEN_DIMS = dict(
    (
        ("facebook/esm2_t6_8M_UR50D", 320),
        ("facebook/esm2_t12_35M_UR50D", 480),
        ("facebook/esm2_t30_150M_UR50D", 640),
        ("facebook/esm2_t33_650M_UR50D", 1280),
        ("facebook/esm2_t36_3B_UR50D", 2560),
    )
)


class ProteinSequenceClassifier(nn.Module):
    """End-to-end model: frozen ESM2 -> aggregation -> classification.

    Every ESM2 parameter has ``requires_grad=False`` and the backbone is kept
    in eval mode even when this module is switched to training mode (see
    ``train``).  Only the aggregator and the classifier head are trainable.

    Args:
        esm2_model_name:  HuggingFace model ID for ESM2.
        aggregation:      Name of aggregation method (see AGGREGATOR_REGISTRY).
        num_classes:      Number of output classes.
        aggregator_kwargs: Extra arguments passed to the aggregator constructor.
        classifier_hidden: If >0, adds a hidden layer in the classifier head.
        dropout:          Dropout rate before the classifier.

    Raises:
        ValueError: If ``aggregation`` is not a key of AGGREGATOR_REGISTRY.

    Attributes:
        strip_special: Derived from ``aggregation`` (it is not a constructor
            argument).  True for every method except ``"cls"``: the <cls>
            position is dropped and the <eos> position is masked out before
            aggregation.  CLS pooling operates on the raw ESM2 output.
    """

    def __init__(
        self,
        esm2_model_name: str = "facebook/esm2_t12_35M_UR50D",
        aggregation: str = "mean",
        num_classes: int = 10,
        aggregator_kwargs: Optional[Dict] = None,
        classifier_hidden: int = 0,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Validate the cheap argument first so a typo fails before the
        # (potentially large) backbone is loaded.
        if aggregation not in AGGREGATOR_REGISTRY:
            raise ValueError(
                f"Unknown aggregation '{aggregation}'. "
                f"Choose from: {list(AGGREGATOR_REGISTRY.keys())}"
            )

        self.esm2_model_name = esm2_model_name
        self.aggregation_name = aggregation

        # ---- ESM2 backbone (frozen) ----
        self.esm2 = EsmModel.from_pretrained(esm2_model_name)
        for param in self.esm2.parameters():
            param.requires_grad = False
        self.esm2.eval()

        # ---- Determine hidden size ----
        # Known checkpoints come from the table; anything else falls back to
        # the loaded model's own config.
        self.d_esm2 = ESM2_HIDDEN_DIMS.get(
            esm2_model_name, self.esm2.config.hidden_size
        )

        # ---- Aggregation head ----
        agg_cls = AGGREGATOR_REGISTRY[aggregation]
        agg_kwargs = aggregator_kwargs or {}
        self.aggregator = agg_cls(d_in=self.d_esm2, **agg_kwargs)

        # Whether to strip <cls>/<eos> before aggregation
        self.strip_special = aggregation != "cls"

        # ---- Classification head ----
        agg_dim = self.aggregator.out_dim
        if classifier_hidden > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, classifier_hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(classifier_hidden, num_classes),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(agg_dim, num_classes),
            )

    def train(self, mode: bool = True) -> "ProteinSequenceClassifier":
        """Set the training mode of the trainable heads only.

        ``nn.Module.train`` recurses into every child, which would put the
        frozen backbone back into training mode and undo the ``eval()`` call
        made in ``__init__``.  This override re-applies eval mode to the
        backbone after delegating to the base implementation.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, matching the ``nn.Module.train`` contract.
        """
        super().train(mode)
        self.esm2.eval()
        return self

    @property
    def tokenizer(self):
        """Tokenizer for ``esm2_model_name``, loaded on first access and cached."""
        if not hasattr(self, "_tokenizer"):
            self._tokenizer = AutoTokenizer.from_pretrained(self.esm2_model_name)
        return self._tokenizer

    def get_residue_embeddings(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> tuple:
        """Extract per-residue embeddings from frozen ESM2.

        The backbone is run under ``torch.no_grad()``.  When
        ``self.strip_special`` is True the <cls> position is sliced off and
        the <eos> position is zeroed in the returned mask (its embedding is
        still present in ``token_embeddings``).  This assumes right-padded
        input: [<cls>, AA1, ..., AAN, <eos>, <pad>, ...].

        Args:
            input_ids:      [B, L_full] token ids.
            attention_mask: [B, L_full] mask, 1 for real tokens, 0 for padding.
                It is never modified in place.

        Returns:
            token_embeddings: [B, L, d]  (L = L_full - 1 when stripping,
                              otherwise L_full)
            mask:             [B, L]
        """
        with torch.no_grad():
            outputs = self.esm2(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )

        hidden_states = outputs.last_hidden_state  # [B, L_full, d]

        if not self.strip_special:
            return hidden_states, attention_mask

        # Drop <cls> (position 0) from both the embeddings and the mask.
        token_embeddings = hidden_states[:, 1:, :]
        mask = attention_mask[:, 1:]

        # After removing <cls>, the <eos> token is the last valid position of
        # each row, i.e. index (number of valid tokens - 1).
        lengths = mask.sum(dim=1).long()
        positions = torch.arange(mask.shape[1], device=mask.device)
        # Rows with no valid tokens yield index -1, which matches no position,
        # so they are left untouched.
        is_eos = positions.unsqueeze(0) == (lengths - 1).unsqueeze(1)
        # masked_fill is out-of-place, so the caller's attention_mask is
        # preserved without an explicit clone.
        mask = mask.masked_fill(is_eos, 0)

        return token_embeddings, mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        pdb_paths: Optional[List[Optional[str]]] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Run backbone, aggregation and classification on a batch.

        Args:
            input_ids:      [B, L] tokenized protein sequences.
            attention_mask: [B, L] attention mask.
            labels:         [B] class labels (optional, for loss computation).
            pdb_paths:      List of PDB file paths (only for glot_residue
                            aggregation).  Forwarded to the aggregator as a
                            keyword argument only when it is not None.
            **kwargs:       Accepted and ignored.

        Returns:
            Dict with keys 'logits' ([B, num_classes]) and 'embeddings'
            ([B, agg_dim]); 'loss' (cross-entropy) is added only when
            ``labels`` is given.
        """
        # Extract residue embeddings from frozen ESM2
        token_embeddings, mask = self.get_residue_embeddings(input_ids, attention_mask)

        # Aggregate to sequence-level
        extra_kwargs = {}
        if pdb_paths is not None:
            extra_kwargs["pdb_paths"] = pdb_paths

        sequence_embedding = self.aggregator(
            token_embeddings, mask, **extra_kwargs
        )  # [B, agg_dim]

        # Classify
        logits = self.classifier(sequence_embedding)  # [B, num_classes]

        result = {"logits": logits, "embeddings": sequence_embedding}

        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            result["loss"] = loss_fn(logits, labels)

        return result

    def encode(
        self,
        sequences: Union[str, List[str]],
        pdb_paths: Optional[List[Optional[str]]] = None,
        max_length: int = 1024,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        """Convenience method: tokenize + forward to get sequence embeddings.

        Note:
            This calls ``self.eval()`` and does not restore the previous
            mode; call ``train()`` again afterwards if encoding in the middle
            of a training loop.

        Args:
            sequences: Single protein sequence or list of sequences.
            pdb_paths: Optional PDB paths for glot_residue aggregation.
            max_length: Maximum sequence length (ESM2 supports up to 1026).
                Longer sequences are truncated by the tokenizer.
            device: Device to run on.  Defaults to the device of the model's
                first parameter.

        Returns:
            Sequence-level embeddings [B, agg_dim].
        """
        if isinstance(sequences, str):
            sequences = [sequences]

        if device is None:
            device = next(self.parameters()).device

        inputs = self.tokenizer(
            sequences,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)

        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pdb_paths=pdb_paths,
            )

        return outputs["embeddings"]