Zhaohan-Meng commited on
Commit
9397da2
·
1 Parent(s): 3e5df8f

Add ExplainBind Demo

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ bin/foldseek filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .ipynb_checkpoints/
3
+ __pycache__/
4
+ *.pyc
5
+ bin/
6
+ *.log
README.md CHANGED
@@ -8,7 +8,8 @@ sdk_version: 6.6.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: An interaction-aware demo UI
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ short_description: **ExplainBind** is an interaction-aware framework for **protein–ligand binding (PLB)** prediction.
12
+ It supervises token-level cross-attention using **non-covalent interaction maps** (e.g. hydrogen bonds, salt bridges, hydrophobic contacts, van der Waals, π–π, and cation–π interactions) derived from curated **PDB** protein–ligand complexes in **InteractBind**. By aligning model attention with these physically grounded signals, ExplainBind turns PLB prediction from black-box reasoning into a **chemistry-grounded** process suitable for large-scale screening.
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ─── Patch gradio_client so boolean schemas don’t crash json_schema_to_python_type ───
import gradio_client.utils as _gc_utils

# Keep references to the untouched originals so the wrappers can delegate.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Delegate to the original, coercing a bare boolean schema to the empty schema."""
    return _orig_get_type({} if isinstance(schema, bool) else schema)


def _patched_json_schema_to_python_type(schema, defs=None):
    """Boolean-tolerant wrapper around the original schema-to-type converter."""
    return _orig_json2py({} if isinstance(schema, bool) else schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
15
+
16
+ # ─── Imports ───────────────────────────────────────────────────────────────────
17
+ import os
18
+ import io
19
+ import base64
20
+ import argparse
21
+ from typing import Optional, List, Tuple
22
+
23
+ import numpy as np
24
+ import torch
25
+ from torch.utils.data import DataLoader
26
+
27
+ import selfies
28
+ from rdkit import Chem
29
+
30
+ import matplotlib
31
+ matplotlib.use("Agg")
32
+ import matplotlib.pyplot as plt
33
+ from matplotlib import cm
34
+
35
+ from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel, AutoTokenizer
36
+ from Bio.PDB import PDBParser, MMCIFParser
37
+ from Bio.Data import IUPACData
38
+
39
+ import gradio as gr
40
+
41
+ # Project utils (ensure these exist in your repository)
42
+ from utils.metric_learning_models_att_maps import Pre_encoded, ExplainBind
43
+ from utils.foldseek_util import get_struc_seq
44
+
45
# ───────────────────── Paths & Logos ─────────────────────
# Resolve all asset paths relative to this file so the app works from any CWD.
ROOT = os.path.dirname(os.path.abspath(__file__))
ASSET_DIR = os.path.join(ROOT, "assets")

# Path to the lab logo PNG; may be absent (handled by _load_logo_b64 below).
LOSCAZLO_LOGO = os.path.join(ASSET_DIR, "loscalzo.png")
50
+
51
+ def _load_logo_b64(path):
52
+ if not os.path.exists(path):
53
+ return ""
54
+ with open(path, "rb") as f:
55
+ return base64.b64encode(f.read()).decode("utf-8")
56
+
57
# Base64 payload of the logo ("" when the asset is absent) for inline HTML embedding.
LOSCAZLO_B64 = _load_logo_b64(LOSCAZLO_LOGO)


# ───────────────────── Configurable constants ─────────────────────
# UI-visible names (Halogen bonding removed)
INTERACTION_NAMES = [
    "Hydrogen bonding",
    "Salt Bridging",
    "π–π Stacking",
    "Cation–π",
    "Hydrophobic",
    "Van der Waals",
    "Overall Interaction",
]

# Map visible indices (0..5 = specific, 6 = combined) to underlying channel indices
# Underlying channels originally had Halogen at index=5 (0-based). We skip 5 entirely.
VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cation-Pi, Hydro, VdW
N_VISIBLE_SPEC = len(VISIBLE2UNDERLYING)  # 6

# ───── Helper utilities ───────────────────────────────────────────
# 3-letter → 1-letter residue code mapping, extended with common
# non-standard residues (selenomethionine, selenocysteine, pyrrolysine).
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
STANDARD_AA_SET = set("ACDEFGHIKLMNPQRSTVWY")  # Uppercase FASTA amino acids
81
+
82
+
83
def simple_seq_from_structure(path: str) -> str:
    """Extract the longest chain from a PDB/mmCIF file and return a simple 1-letter sequence.

    Unknown residue names map to 'X'. Returns "" when the structure has no chains.
    """
    # mmCIF files need the MMCIF parser; everything else is treated as PDB.
    parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)
    chains = list(structure.get_chains())
    if not chains:
        return ""
    longest = max(chains, key=lambda ch: len(list(ch.get_residues())))
    letters = [three2one.get(res.get_resname().upper(), "X") for res in longest]
    return "".join(letters)
92
+
93
+
94
def smiles_to_selfies(smiles_text: str) -> Optional[str]:
    """Validate and convert SMILES to SELFIES; return None if invalid."""
    try:
        # RDKit returns None for unparsable SMILES — treat that as invalid input.
        if Chem.MolFromSmiles(smiles_text) is None:
            return None
        return selfies.encoder(smiles_text)
    except Exception:
        # Any encoder failure is reported as "invalid" rather than raised.
        return None
103
+
104
+
105
def detect_protein_type(seq: str) -> str:
    """
    Heuristic for protein input:
    - All uppercase and only the standard 20 amino acids → 'fasta'
    - Otherwise (contains lowercase or non-standard characters) → 'sa'

    Empty / None input defaults to 'fasta'.
    """
    stripped = (seq or "").strip()
    if not stripped:
        return "fasta"
    upper = stripped.upper()
    # SaProt structure-aware tokens contain lowercase 3Di letters,
    # so any lowercase or non-standard character routes to 'sa'.
    if stripped == upper and all(ch in STANDARD_AA_SET for ch in upper):
        return "fasta"
    return "sa"
118
+
119
+
120
def detect_ligand_type(text: str) -> str:
    """
    Heuristic for ligand input:
    - Starts with '[' and contains ']' → 'selfies'
    - Otherwise (including empty / None) → 'smiles'
    """
    stripped = (text or "").strip()
    if stripped and stripped.startswith("[") and "]" in stripped:
        return "selfies"
    return "smiles"
130
+
131
+
132
def parse_config():
    """Parse command-line options and return the populated argparse namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--agg_mode", type=str, default="mean_all_tok")
    parser.add_argument("--group_size", type=int, default=1)
    parser.add_argument("--fusion", default="CAN")
    # Prefer GPU when one is visible; fall back to CPU.
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    # Root folder containing the per-input-type checkpoint subfolders.
    parser.add_argument("--save_path_prefix", default="save_model_ckp/")
    parser.add_argument("--dataset", default="Human")
    return parser.parse_args()
142
+
143
+
144
# Parse CLI options once at import time; DEVICE drives all `.to(...)` placement below.
args = parse_config()
DEVICE = args.device

# ───── Dynamic model registry ─────────────────────────────────────
# Protein encoders keyed by detected input type:
#   'sa'    → SaProt (structure-aware tokens), 'fasta' → ESM-2 (plain sequence).
PROT_MODELS = {
    "sa": "westlake-repl/SaProt_650M_AF2",
    "fasta": "facebook/esm2_t33_650M_UR50D",
}
# Ligand encoder registry; only SELFIES (SELFormer) is currently enabled.
DRUG_MODELS = {
    "selfies": "HUBioDataLab/SELFormer",
    # "smiles": "ibm/MoLFormer-XL-both-10pct",
}
156
+
157
+
158
def load_encoders(ptype: str, ltype: str, args):
    """
    Dynamically load encoders and tokenisers based on input types.

    Returns: (prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding_module)

    Note: `ltype` is accepted for interface compatibility, but the ligand
    encoder is always the SELFIES model (SELFormer).
    """
    # Protein encoder: ESM-2 for plain FASTA input, SaProt otherwise ('sa').
    if ptype == "fasta":
        prot_path = PROT_MODELS["fasta"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path, do_lower_case=False)
    else:
        prot_path = PROT_MODELS["sa"]
        prot_tokenizer = EsmTokenizer.from_pretrained(prot_path)
    prot_model = EsmForMaskedLM.from_pretrained(prot_path)

    # Ligand encoder (SELFIES only).
    drug_path = DRUG_MODELS["selfies"]
    drug_tokenizer = AutoTokenizer.from_pretrained(drug_path)
    drug_model = AutoModel.from_pretrained(drug_path)

    # Wrap both encoders in the project's Pre_encoded module on the target device.
    encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
    return prot_tokenizer, prot_model, drug_tokenizer, drug_model, encoding
189
+
190
+
191
def make_collate_fn(prot_tokenizer, drug_tokenizer):
    """Create a batch collation function using the given tokenisers.

    The returned callable maps [(protein, ligand, score), ...] to
    (prot_ids, prot_mask, drug_ids, drug_mask, scores) with boolean masks.
    """
    # Identical tokeniser settings for both modalities.
    tok_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    def _collate_fn(batch):
        prots, drugs, labels = zip(*batch)
        enc_p = prot_tokenizer(list(prots), **tok_kwargs)
        enc_d = drug_tokenizer(list(drugs), **tok_kwargs)
        return (
            enc_p["input_ids"],
            enc_p["attention_mask"].bool(),
            enc_d["input_ids"],
            enc_d["attention_mask"].bool(),
            torch.tensor(list(labels)),
        )

    return _collate_fn
209
+
210
+
211
def get_case_feature(model, loader):
    """Generate features for one protein–ligand pair using the provided model.

    Only the first batch from `loader` is consumed; embeddings and inputs are
    moved back to CPU and returned as a single-element feature list.
    """
    model.eval()
    with torch.no_grad():
        for p_ids, p_mask, d_ids, d_mask, _ in loader:
            p_ids = p_ids.to(DEVICE)
            p_mask = p_mask.to(DEVICE)
            d_ids = d_ids.to(DEVICE)
            d_mask = d_mask.to(DEVICE)
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            # Return immediately after the first batch (single-pair use case).
            return [(
                p_emb.cpu(), d_emb.cpu(),
                p_ids.cpu(), d_ids.cpu(),
                p_mask.cpu(), d_mask.cpu(), None,
            )]
222
+
223
+ # ─────────────── SELFIES grouping by ORIGINAL string ─────────────
224
+ def _group_rows_by_selfies_string(n_rows: int, selfies_str: str):
225
+ """
226
+ Partition the attention matrix's n_rows along ligand axis into groups per SELFIES token '[ ... ]'.
227
+ Each group is a contiguous row span; we assign rows ≈ equally using linspace.
228
+ Returns:
229
+ groups: List[(start_row, end_row)] inclusive
230
+ labels: List['[X]','[=O]', ...]
231
+ """
232
+ if n_rows <= 0:
233
+ return [], []
234
+
235
+ try:
236
+ toks = list(selfies.split_selfies((selfies_str or "").strip()))
237
+ except Exception:
238
+ toks = []
239
+
240
+ if not toks:
241
+ # Fallback: treat whole ligand as one token
242
+ return [(0, n_rows - 1)], [selfies_str or "[?]"]
243
+
244
+ g = len(toks)
245
+ edges = np.linspace(0, n_rows, g + 1, dtype=int)
246
+ groups = []
247
+ for i in range(g):
248
+ s, e = edges[i], edges[i + 1] - 1
249
+ if e < s:
250
+ e = s
251
+ groups.append((s, e))
252
+ return groups, toks
253
+
254
+
255
+
256
+ def _connected_components_2d(mask: torch.Tensor) -> List[List[Tuple[int, int]]]:
257
+ """4-connected components over a 2D boolean mask (rows=ligand tokens, cols=protein residues)."""
258
+ h, w = mask.shape
259
+ visited = torch.zeros_like(mask, dtype=torch.bool)
260
+ comps: List[List[Tuple[int,int]]] = []
261
+ for i in range(h):
262
+ for j in range(w):
263
+ if mask[i, j] and not visited[i, j]:
264
+ stack = [(i, j)]
265
+ visited[i, j] = True
266
+ comp = []
267
+ while stack:
268
+ r, c = stack.pop()
269
+ comp.append((r, c))
270
+ for dr, dc in ((1,0), (-1,0), (0,1), (0,-1)):
271
+ rr, cc = r + dr, c + dc
272
+ if 0 <= rr < h and 0 <= cc < w and mask[rr, cc] and not visited[rr, cc]:
273
+ visited[rr, cc] = True
274
+ stack.append((rr, cc))
275
+ comps.append(comp)
276
+ return comps
277
+
278
def _format_component_table(
    components,
    p_tokens,
    d_tokens,
    *,
    mode: str = "pair",  # "pair" | "residue"
):
    """
    Render HTML table for highlighted interaction components.

    Parameters
    ----------
    components : List[List[Tuple[int,int]]]
        Each component is a list of (ligand_index, protein_index) pairs.
    p_tokens : List[str]
        Protein token strings.
    d_tokens : List[str]
        Ligand token strings.
    mode : str
        "pair" -> show Protein range + Ligand range
        "residue" -> show Protein residue(s) only

    Returns
    -------
    str
        Self-contained HTML fragment (heading + table, or a "nothing selected" note).
    """

    # ----------------------------
    # Residue-only mode
    # ----------------------------
    if mode == "residue":
        if not components:
            return (
                "<h4 style='margin:12px 0 6px'>Highlighted protein residues</h4>"
                "<p>No residues selected.</p>"
            )

        rows = []
        for comp in components:
            # comp = [(lig_idx, prot_idx), ...]; only the protein axis matters here.
            prot_indices = [j for (_, j) in comp]
            p_start, p_end = min(prot_indices), max(prot_indices)

            # Display indices are 1-based; out-of-range tokens shown as "?".
            p_s_idx, p_e_idx = p_start + 1, p_end + 1
            p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?"
            p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?"

            if p_start == p_end:
                label = f"{p_s_idx}:{p_s_tok}"
            else:
                label = f"{p_s_idx}:{p_s_tok} – {p_e_idx}:{p_e_tok}"

            rows.append(
                f"<tr>"
                f"<td style='border:1px solid #ddd;padding:6px'>"
                f"<strong>{label}</strong>"
                f"</td>"
                f"</tr>"
            )

        return (
            "<h4 style='margin:12px 0 6px'>Highlighted protein residues</h4>"
            "<table style='border-collapse:collapse;margin:6px 0 16px;width:60%'>"
            "<thead><tr style='background:#f5f5f5'>"
            "<th style='border:1px solid #ddd;padding:6px'>Protein residue(s)</th>"
            "</tr></thead>"
            f"<tbody>{''.join(rows)}</tbody></table>"
        )

    # ----------------------------
    # Pair mode (default behaviour)
    # ----------------------------
    if not components:
        return (
            "<h4 style='margin:12px 0 6px'>Highlighted interaction segments</h4>"
            "<p>No interaction pairs selected.</p>"
        )

    rows = []
    for comp in components:
        # Bounding range of each component on both axes.
        lig_indices = [i for (i, _) in comp]
        prot_indices = [j for (_, j) in comp]

        d_start, d_end = min(lig_indices), max(lig_indices)
        p_start, p_end = min(prot_indices), max(prot_indices)

        # 1-based display indices.
        d_s_idx, d_e_idx = d_start + 1, d_end + 1
        p_s_idx, p_e_idx = p_start + 1, p_end + 1

        d_s_tok = d_tokens[d_start] if d_start < len(d_tokens) else "?"
        d_e_tok = d_tokens[d_end] if d_end < len(d_tokens) else "?"
        p_s_tok = p_tokens[p_start] if p_start < len(p_tokens) else "?"
        p_e_tok = p_tokens[p_end] if p_end < len(p_tokens) else "?"

        rows.append(
            f"<tr>"
            f"<td style='border:1px solid #ddd;padding:6px'>Protein: "
            f"<strong>{p_s_idx}:{p_s_tok}</strong>"
            f"{' – <strong>'+str(p_e_idx)+':'+p_e_tok+'</strong>' if p_end > p_start else ''}"
            f"</td>"
            f"<td style='border:1px solid #ddd;padding:6px'>Ligand: "
            f"<strong>{d_s_idx}:{d_s_tok}</strong>"
            f"{' – <strong>'+str(d_e_idx)+':'+d_e_tok+'</strong>' if d_end > d_start else ''}"
            f"</td>"
            f"</tr>"
        )

    return (
        "<h4 style='margin:12px 0 6px'></h4>"
        "<table style='border-collapse:collapse;margin:6px 0 16px;width:100%'>"
        "<thead><tr style='background:#f5f5f5'>"
        "<th style='border:1px solid #ddd;padding:6px'>Protein range</th>"
        "<th style='border:1px solid #ddd;padding:6px'>Ligand range</th>"
        "</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody></table>"
    )
390
+
391
+
392
def visualize_attention_and_ranges(
    model,
    feats,
    head_idx: int,
    *,
    mode: str = "pair",          # "pair" | "residue"
    topk_pairs: int = 1,         # Top-K interaction pairs (default=1)
    topk_residues: int = 1,      # Top-K residues (1–20, default=1)
    prot_tokenizer=None,
    drug_tokenizer=None,
    ligand_type: str = "selfies",
    raw_selfies: Optional[str] = None,
) -> Tuple[str, str]:
    """
    Visualise interaction attention with two complementary Top-K modes.

    Modes
    -----
    mode="pair":
        - Select Top-K highest-scoring (ligand token, protein residue) pairs
        - Project selected pairs onto protein axis (evaluation-aligned)
        - Default K = 1 (user-controlled)

    mode="residue":
        - Aggregate attention over ligand dimension
        - Rank residues by aggregated score
        - Select Top-K residues (1–20)
        - Default K = 1 (binding pocket discovery)

    Notes
    -----
    - Per-head GLOBAL SUM normalisation (matches test()).
    - Specific heads mapped exactly to GT channels.
    - Combined head = sum of 6 specific heads (NOT overall=7).

    Returns
    -------
    Tuple[str, str]
        (heatmap HTML with inline PNG and a PDF download link, ranges HTML table)
    """

    # NOTE(review): `assert` is stripped under `python -O`; these act as
    # lightweight argument validation only.
    assert mode in {"pair", "residue"}
    assert topk_pairs >= 1
    assert 1 <= topk_residues <= 20

    model.eval()
    with torch.no_grad():
        # --------------------------------------------------
        # Unpack features (single precomputed pair from get_case_feature)
        # --------------------------------------------------
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # --------------------------------------------------
        # Forward
        # --------------------------------------------------
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        att = att_pd.squeeze(0)
        # expected: [Ld, Lp, 8] or [8, Ld, Lp]

        # --------------------------------------------------
        # Channel mapping (must match test())
        # --------------------------------------------------
        # Intentionally shadows the module-level constants of the same names.
        VISIBLE2UNDERLYING = [1, 2, 3, 4, 6, 0]  # HB, Salt, Pi, Cat-Pi, Hydro, VdW
        N_VISIBLE_SPEC = 6

        def select_channel_map(att_):
            # Handles both channel-last [Ld, Lp, C] and channel-first [C, Ld, Lp]
            # layouts; head_idx >= N_VISIBLE_SPEC means "combined" (sum of 6).
            if att_.dim() == 3 and att_.shape[-1] >= 8:
                if head_idx < N_VISIBLE_SPEC:
                    return att_[:, :, VISIBLE2UNDERLYING[head_idx]].cpu()
                return att_[:, :, VISIBLE2UNDERLYING].sum(dim=2).cpu()
            if att_.dim() == 3 and att_.shape[0] >= 8:
                if head_idx < N_VISIBLE_SPEC:
                    return att_[VISIBLE2UNDERLYING[head_idx]].cpu()
                return att_[VISIBLE2UNDERLYING].sum(dim=0).cpu()
            return att_.squeeze().cpu()

        att2d_raw = select_channel_map(att)  # [Ld, Lp]

        # --------------------------------------------------
        # Per-head GLOBAL SUM normalisation (critical)
        # --------------------------------------------------
        # Epsilon guards against an all-zero attention map.
        att2d_raw = att2d_raw / (att2d_raw.sum() + 1e-8)

        # --------------------------------------------------
        # Token decoding & trimming
        # --------------------------------------------------
        def clean_tokens(ids, tokenizer):
            # Drop special tokens (CLS/SEP/PAD) so tokens align with attention rows.
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            if hasattr(tokenizer, "all_special_tokens"):
                toks = [t for t in toks if t not in tokenizer.all_special_tokens]
            return toks

        p_tokens_full = clean_tokens(p_ids[0], prot_tokenizer)
        d_tokens_full = clean_tokens(d_ids[0], drug_tokenizer)

        # Trim to the common extent of tokens and attention dimensions.
        n_d = min(len(d_tokens_full), att2d_raw.size(0))
        n_p = min(len(p_tokens_full), att2d_raw.size(1))

        att2d = att2d_raw[:n_d, :n_p]
        p_tokens = p_tokens_full[:n_p]
        d_tokens = d_tokens_full[:n_d]

        # 1-based display indices.
        p_indices = list(range(1, n_p + 1))
        d_indices = list(range(1, n_d + 1))

        # --------------------------------------------------
        # SELFIES row merging (for interpretability)
        # --------------------------------------------------
        # Collapse sub-token rows into one averaged row per SELFIES token.
        if ligand_type == "selfies" and raw_selfies:
            groups, labels = _group_rows_by_selfies_string(att2d.size(0), raw_selfies)
            if groups:
                merged = []
                for s, e in groups:
                    merged.append(att2d[s:e + 1].mean(dim=0, keepdim=True))
                att2d = torch.cat(merged, dim=0)
                d_tokens = labels
                d_indices = list(range(1, len(labels) + 1))

        # --------------------------------------------------
        # Top-K selection (two modes)
        # --------------------------------------------------
        if mode == "pair":
            # --- Top-K interaction pairs over the flattened map ---
            flat = att2d.reshape(-1)
            k_eff = min(topk_pairs, flat.numel())
            idx = torch.topk(flat, k=k_eff).indices

            mask_top = torch.zeros_like(flat, dtype=torch.bool)
            mask_top[idx] = True
            mask_top = mask_top.view_as(att2d)

        else:
            # --- Top-K residues by ligand-aggregated score ---
            residue_score = att2d.sum(dim=0)  # [protein]
            k_eff = min(topk_residues, residue_score.numel())
            topk_res_idx = torch.topk(residue_score, k=k_eff).indices

            mask_top = torch.zeros_like(att2d, dtype=torch.bool)
            mask_top[:, topk_res_idx] = True  # keep all ligand rows

        # --------------------------------------------------
        # Connected components (visual coherence)
        # --------------------------------------------------
        components = _connected_components_2d(mask_top)
        ranges_html = _format_component_table(
            components,
            p_tokens,
            d_tokens,
            mode=mode,
        )

        # --------------------------------------------------
        # Crop to union of selected rows / columns
        # --------------------------------------------------
        rows_keep = mask_top.any(dim=1)
        cols_keep = mask_top.any(dim=0)

        # Degenerate case: nothing selected → show everything.
        if not rows_keep.any():
            rows_keep[:] = True
        if not cols_keep.any():
            cols_keep[:] = True

        vis = att2d[rows_keep][:, cols_keep]

        d_tokens_vis = [t for k, t in zip(rows_keep.tolist(), d_tokens) if k]
        p_tokens_vis = [t for k, t in zip(cols_keep.tolist(), p_tokens) if k]
        d_indices_vis = [i for k, i in zip(rows_keep.tolist(), d_indices) if k]
        p_indices_vis = [i for k, i in zip(cols_keep.tolist(), p_indices) if k]

        # Cap columns for readability
        if vis.size(1) > 150:
            topc = torch.topk(vis.sum(0), k=150).indices
            vis = vis[:, topc]
            p_tokens_vis = [p_tokens_vis[i] for i in topc]
            p_indices_vis = [p_indices_vis[i] for i in topc]

        # --------------------------------------------------
        # Plot
        # --------------------------------------------------
        x_labels = [f"{i}:{t}" for i, t in zip(p_indices_vis, p_tokens_vis)]
        y_labels = [f"{i}:{t}" for i, t in zip(d_indices_vis, d_tokens_vis)]

        # Figure size scales with label counts, clamped to sane bounds.
        fig_w = min(22, max(6, len(x_labels) * 0.6))
        fig_h = min(24, max(6, len(y_labels) * 0.8))

        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(vis.numpy(), aspect="auto", cmap=cm.viridis)

        title = INTERACTION_NAMES[head_idx]
        suffix = "Top-K pairs" if mode == "pair" else "Top-K residues"
        ax.set_title(f"Ligand × Protein — {title} ({suffix})", fontsize=10, pad=8)
        ax.set_xlabel("Protein residues")
        ax.set_ylabel("Ligand tokens")

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)

        # Protein axis labels on top for readability.
        ax.xaxis.tick_top()
        ax.xaxis.set_label_position("top")
        ax.tick_params(axis="x", bottom=False)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # --------------------------------------------------
        # Export (PNG for display, PDF for download — both inline base64)
        # --------------------------------------------------
        buf_png = io.BytesIO()
        buf_pdf = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        fig.savefig(buf_pdf, format="pdf")
        plt.close(fig)

        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

        heat_html = f"""
        <div style='position:relative'>
          <a href='data:application/pdf;base64,{pdf_b64}' download='attention_{head_idx+1}.pdf'
             style='position:absolute;top:10px;right:10px;
                    background:#111;color:#fff;padding:8px 12px;
                    border-radius:10px;font-size:.85rem;text-decoration:none'>
             Download PDF
          </a>
          <img src='data:image/png;base64,{png_b64}' />
        </div>
        """

        return heat_html, ranges_html
621
+
622
+
623
+
624
+
625
# ───── Gradio callbacks ─────────────────────────────────────────
# NOTE(review): ROOT is recomputed here although it is already defined near the
# top of the file — redundant but harmless. The foldseek binary is expected at
# ./bin/foldseek (tracked via git-lfs per .gitattributes).
ROOT = os.path.dirname(os.path.abspath(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")
628
+
629
def extract_sequence_cb(structure_file, drug_text, protein_text):
    """
    Unified “Extract / Convert” callback.

    Behaviour:
    • If a .pdb / .cif file is uploaded and successfully parsed:
        → overwrite protein textbox with extracted SA sequence
    • If ligand textbox contains SMILES:
        → convert to SELFIES
    • Otherwise:
        → keep existing protein / ligand unchanged

    Returns
    -------
    (protein_text, ligand_text, status_html) — the two textbox values plus an
    HTML status box summarising what happened ("" when nothing was attempted).
    """

    msgs = []

    # --------------------------------------------------
    # Default: ALWAYS keep current contents
    # --------------------------------------------------
    prot_seq_out = (protein_text or "").strip()
    drug_seq_out = (drug_text or "").strip()

    # --------------------------------------------------
    # Structure file → SA sequence (overwrite protein ONLY on success)
    # --------------------------------------------------
    if structure_file is not None and os.path.exists(structure_file.name):
        try:
            parsed = get_struc_seq(
                FOLDSEEK_BIN,
                structure_file.name,
                None,
                plddt_mask=False,
            )
            # NOTE(review): uses the first chain in foldseek's output — assumes
            # that is the chain of interest; confirm for multi-chain structures.
            first_chain = next(iter(parsed))
            _, _, struct_seq = parsed[first_chain]

            if struct_seq:
                prot_seq_out = struct_seq
                msgs.append(
                    "<li>✅ Extracted <b>SA sequence</b> from the uploaded structure.</li>"
                )
            else:
                msgs.append(
                    "<li>❌ Structure parsed, but no valid protein sequence found.</li>"
                )

        except Exception as e:
            # Parsing/foldseek failures are reported in the UI, never raised.
            msgs.append(
                f"<li>❌ Failed to extract SA sequence: <b>{e}</b></li>"
            )

    # --------------------------------------------------
    # SMILES → SELFIES (do NOT touch protein)
    # --------------------------------------------------
    if drug_seq_out:
        lig_type = detect_ligand_type(drug_seq_out)
        if lig_type == "smiles":
            try:
                conv = smiles_to_selfies(drug_seq_out)
                if conv is None:
                    msgs.append(
                        "<li>❌ SMILES → SELFIES failed: <b>invalid SMILES</b>.</li>"
                    )
                else:
                    drug_seq_out = conv
                    msgs.append(
                        "<li>✅ Converted <b>SMILES</b> to <b>SELFIES</b>.</li>"
                    )
            except Exception as e:
                msgs.append(
                    f"<li>❌ SMILES → SELFIES error: <b>{e}</b></li>"
                )

    # --------------------------------------------------
    # Status message box
    # --------------------------------------------------
    if msgs:
        status_html = (
            "<div style='margin:10px 0;padding:10px 12px;"
            "border:1px solid #e5e7eb;border-radius:10px;"
            "background:#f8fafc;color:#0f172a'>"
            "<ul style='margin:0 0 0 18px;padding:0'>"
            f"{''.join(msgs)}"
            "</ul></div>"
        )
    else:
        status_html = ""

    return prot_seq_out, drug_seq_out, status_html
717
+
718
+
719
def _choose_ckpt_by_types(prot_seq: str, ligand_text: str) -> Tuple[str, str, str]:
    """Return (folder_name, protein_type, ligand_type) for checkpoint routing."""
    ptype = detect_protein_type(prot_seq)
    ltype = detect_ligand_type(ligand_text)
    # Folder is one of: sa_selfies / fasta_selfies / sa_smiles / fasta_smiles
    return f"{ptype}_{ltype}", ptype, ltype
725
+
726
+
727
def inference_cb(prot_seq, drug_seq, head_choice, topk_choice, mode_choice):
    """Run ExplainBind inference and build the attention visualisation.

    Supports two Top-K ranking modes:
      * pair    -- rank individual protein-residue / ligand-atom pairs;
      * residue -- rank protein residues by attention aggregated over all
                   ligand tokens.

    Args:
        prot_seq: protein SA/FASTA sequence.
        drug_seq: ligand SELFIES or SMILES string.
        head_choice: interaction-channel name (or an "N. name" label).
        topk_choice: Top-K value as a string/int.
        mode_choice: Top-K mode label from the UI dropdown.

    Returns:
        (table_html, heat_html); on invalid input the first element carries
        an error message and the second is "".
    """
    # ------------------------------
    # Input validation
    # ------------------------------
    if not prot_seq or not prot_seq.strip():
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>", ""

    if not drug_seq or not drug_seq.strip():
        return "<p style='color:red'>Please enter a ligand sequence (SELFIES or SMILES).</p>", ""

    prot_seq = prot_seq.strip()
    drug_seq_in = drug_seq.strip()

    # ------------------------------
    # Detect types & checkpoint routing
    # ------------------------------
    folder, ptype, ltype = _choose_ckpt_by_types(prot_seq, drug_seq_in)

    # Ligand normalisation: the drug tokenizer always consumes SELFIES.
    if ltype == "smiles":
        conv = smiles_to_selfies(drug_seq_in)
        if conv is None:
            return (
                "<p style='color:red'>SMILES→SELFIES conversion failed. "
                "The SMILES appears invalid.</p>",
                "",
            )
        drug_seq_for_tokenizer = conv
    else:
        drug_seq_for_tokenizer = drug_seq_in

    # Force a single ligand representation so checkpoint routing and the
    # visualiser agree (the folder is always "<ptype>_selfies").
    ltype = "selfies"
    ligand_type_flag = "selfies"
    raw_selfies = drug_seq_for_tokenizer
    folder = f"{ptype}_selfies"

    # ------------------------------
    # Load encoders and featurise the single (protein, ligand) case
    # ------------------------------
    prot_tok, prot_m, drug_tok, drug_m, encoding = load_encoders(ptype, ltype, args)

    loader = DataLoader(
        [(prot_seq, drug_seq_for_tokenizer, 1)],
        batch_size=1,
        collate_fn=make_collate_fn(prot_tok, drug_tok),
    )

    feats = get_case_feature(encoding, loader)

    # ------------------------------
    # Load trained checkpoint (if present)
    # ------------------------------
    ckpt = os.path.join(args.save_path_prefix, folder, "best_model.ckpt")
    model = ExplainBind(1280, 768, args=args).to(DEVICE)

    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
        warn_html = (
            "<div style='margin:8px 0 14px;padding:8px 10px;"
            "border-left:4px solid #10b981;background:#ecfdf5'>"
            f"<b>Loaded model:</b> <code>{folder}/best_model.ckpt</code></div>"
        )
    else:
        warn_html = (
            "<div style='margin:8px 0 14px;padding:8px 10px;"
            "border-left:4px solid #f59e0b;background:#fffbeb'>"
            "<b>Warning:</b> checkpoint not found "
            f"<code>{folder}/best_model.ckpt</code>. "
            "Using randomly initialised weights for visualisation.</div>"
        )
    # NOTE(review): warn_html is built but not returned — the info box that
    # displayed it was disabled; confirm whether it should be prepended to
    # table_html.

    # ------------------------------
    # Parse interaction head
    # ------------------------------
    sel = str(head_choice).strip()
    if sel in INTERACTION_NAMES:
        head_idx = INTERACTION_NAMES.index(sel)
    else:
        # Accept "N. Channel name" labels; otherwise fall back to the
        # combined-interaction head (last entry).
        try:
            n = int(sel.split(".", 1)[0])
            head_idx = max(0, min(len(INTERACTION_NAMES) - 1, n - 1))
        except ValueError:
            head_idx = len(INTERACTION_NAMES) - 1

    # ------------------------------
    # Parse Top-K value
    # ------------------------------
    try:
        topk = int(str(topk_choice).strip())
    except ValueError:
        topk = 1
    topk = max(1, topk)

    # ------------------------------
    # Parse mode (pair vs residue)
    # ------------------------------
    # Fix: both dropdown labels ("Top-K residues-atom pairs" and
    # "Top-K residues") contain the substring "residue", so testing for
    # "residue" first made pair mode unreachable. Key on "pair" instead.
    mode_choice = str(mode_choice).lower()
    if "pair" in mode_choice:
        mode = "pair"
        topk_pairs = topk
        topk_residues = 1
    else:
        mode = "residue"
        topk_pairs = 1
        topk_residues = min(20, topk)

    # ------------------------------
    # Visualisation
    # ------------------------------
    heat_html, table_html = visualize_attention_and_ranges(
        model,
        feats,
        head_idx,
        mode=mode,
        topk_pairs=topk_pairs,
        topk_residues=topk_residues,
        prot_tokenizer=prot_tok,
        drug_tokenizer=drug_tok,
        ligand_type=ligand_type_flag,
        raw_selfies=raw_selfies,
    )

    return table_html, heat_html
898
def clear_cb():
    """Reset the UI: protein, drug, table, heat map, file upload, status."""
    blank = ""
    return (blank, blank, blank, blank, None, blank)
902
+
903
+ # ───── Gradio interface definition ───────────────────────────────
904
# Global stylesheet for the Gradio app.
# Fix: the original string ended with a stray unmatched "}" after the
# .loscalzo-block rule, which made the trailing CSS invalid.
css = """
:root{
  --bg:#f8fafc; --card:#f8fafc; --text:#0f172a;
  --muted:#6b7280; --border:#e5e7eb; --shadow:0 6px 24px rgba(2,6,23,.06);
  --radius:14px; --icon-size:20px;
}

*{box-sizing:border-box}
html,body{background:#fff!important;color:var(--text)!important}
.gradio-container{max-width:1120px;margin:0 auto}

/* Title and subtitle */
h1{
  font-family:Inter,ui-sans-serif;letter-spacing:.2px;font-weight:700;
  font-size:32px;margin:22px 0 12px;text-align:center
}
.subtle{color:var(--muted);font-size:14px;text-align:center;margin:-6px 0 18px}

/* Card style */
.card{
  background:var(--card); border:1px solid var(--border); border-radius:var(--radius);
  box-shadow:var(--shadow); padding:22px;
}

/* Top links */
.link-row{display:flex;justify-content:center;gap:14px;margin:0 auto 18px;flex-wrap:wrap}

/* Two-column grid: left=input, right=controls */
.grid-2{display:grid;grid-template-columns:1.4fr .9fr;gap:16px}
.grid-2 .col{display:flex;flex-direction:column;gap:12px}

/* Buttons */
.gr-button{border-radius:12px !important;font-weight:700 !important;letter-spacing:.2px}
#extract-btn{background:linear-gradient(90deg,#EFAFB2,#EFAFB2); color:#0f172a}
#inference-btn{background:linear-gradient(90deg,#B2CBDF,#B2CBDF); color:#0f172a}
#clear-btn{background:#FFE2B5; color:#0A0A0A; border:1px solid var(--border)}

/* Result spacing */
#result-table{margin-bottom:16px}

/* Figure container */
.figure-wrap{border:1px solid var(--border);border-radius:12px;overflow:hidden;box-shadow:var(--shadow)}
.figure-wrap img{display:block;width:100%;height:auto}

/* Right pane: vertical radio layout and full-width controls (kept for button styling) */
.right-pane .gr-button{
  width:100% !important;
  height:48px !important;
  border-radius:12px !important;
  font-weight:700 !important;
  letter-spacing:.2px;
}

/* ───────── Publication links (Bulma-like) ───────── */

.publication-links {
  display: flex;
  justify-content: center;
  gap: 14px;
  flex-wrap: wrap;
  margin: 6px 0 18px;
}

.link-block a {
  display: inline-flex;
  align-items: center;
  gap: 8px;
  padding: 10px 18px;
  font-size: 14px;
  font-weight: 600;
  border-radius: 9999px;
  text-decoration: none;
  transition: all 0.15s ease-in-out;
}

/* colour variants */
.btn-danger { background:#e2e8f0; color:#0f172a; }
.btn-dark { background:#e2e8f0; color:#0f172a; }
.btn-link { background:#e2e8f0; color:#0f172a; }
.btn-warning { background:#e2e8f0; color:#0f172a; }

.link-block a:hover {
  filter: brightness(0.95);
  transform: translateY(-1px);
}

.loscalzo-block img {
  height: 100px;
  width: auto;
  object-fit: contain;
}

.loscalzo-block {
  display: flex;
  align-items: center;
  gap: 10px;
  margin: 0 auto;
  justify-content: center;
}
"""
1007
# ───── Gradio interface definition ───────────────────────────────
# Fix: the Project Page link previously used a scheme-less href
# ("ZhaohanM.github.io/ExplainBind/"), which the browser resolves as a
# path relative to the Space and 404s. Large commented-out markdown
# blocks were removed.
with gr.Blocks(
    theme=gr.themes.Default(),
    css=css,
    head="""
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.1/css/all.min.css">
    """
) as demo:
    gr.Markdown("<h1>ExplainBind: Token-level Protein–Ligand Interaction Visualiser</h1>")
    gr.Markdown(
        '<p class="subtle">'
        'Upload structure → Extract SA sequence · Paste SMILES/SELFIES · '
        'Choose interaction channel and Top-K mode'
        '</p>'
    )

    # ===== Loscalzo logo + group name (base64-embedded, no asset file) =====
    gr.HTML(f"""
    <div class="loscalzo-block">
        <img src="data:image/png;base64,{LOSCAZLO_B64}"
             alt="Loscalzo Research Group logo" />
        <a class="loscalzo-name"
           href="https://ogephd.hms.harvard.edu/people/joseph-loscalzo"
           target="_blank" rel="noopener">
        </a>
    </div>
    """)

    # ───────────────────────────────
    # Top links
    # ───────────────────────────────
    gr.Markdown("""
    <div class="publication-links">

    <span class="link-block">
        <a href="https://arxiv.org/abs/2406.01651" target="_blank" class="btn-danger">
        <i class="fa-solid fa-file-lines"></i> Paper
        </a>
    </span>

    <span class="link-block">
        <a href="https://github.com/ZhaohanM/ExplainBind" target="_blank" class="btn-dark">
        <i class="fa-brands fa-github"></i> Source Code
        </a>
    </span>

    <span class="link-block">
        <a href="https://zhaohanm.github.io/ExplainBind/" target="_blank" class="btn-link">
        <i class="fa-solid fa-globe"></i> Project Page
        </a>
    </span>
    </div>
    """)

    # ───────────────────────────────
    # Guidelines
    # ───────────────────────────────
    with gr.Accordion("Guidelines for Users", open=True, elem_classes=["card"]):
        gr.HTML("""
        <ol style="font-size:1rem;line-height:1.6;margin-left:22px;">
          <li>
            <strong>Input types:</strong>
            The model supports <em> structure-aware (SA)</em> or <em>FASTA</em> protein sequences,
            and <em>SMILES</em> or <em>SELFIES</em> ligand inputs.
          </li>
          <li>
            <strong>Extract sequence:</strong>
            (1) Converts <em>SMILES</em> to <em>SELFIES</em>;
            (2) Extracts an <em>SA</em> sequence from uploaded
            <code>.pdb</code> or <code>.cif</code> files.
          </li>
          <li>
            <strong>Top-K mode:</strong>
            <ul style="margin-top:6px;">
              <li>
                <em>Top-K residues-atom pairs</em>:
                ranks individual protein-residue and ligand-atom pairs by attention score.
              </li>
              <li>
                <em>Top-K residues</em>:
                ranks protein residues by attention aggregated over all ligand tokens.
              </li>
            </ul>
          </li>
          <li>
            <strong>Inference output:</strong>
            Generates a token-level attention heat map
            and a corresponding results table
            based on the selected Top-K mode.
          </li>
        </ol>
        """)

    # ───────────────────────────────
    # Inputs (left) + Controls (right)
    # ───────────────────────────────
    with gr.Row():
        with gr.Column(elem_classes=["card", "grid-2"]):
            # ── Left: Inputs ──
            with gr.Column(elem_id="left"):
                protein_seq = gr.Textbox(
                    label="Protein structure-aware / FASTA sequence",
                    lines=3,
                    placeholder="Paste SA/FASTA sequence or click Extract…",
                    elem_id="protein-seq",
                )

                drug_seq = gr.Textbox(
                    label="Ligand (SELFIES / SMILES)",
                    lines=3,
                    placeholder="Paste SELFIES or SMILES",
                    elem_id="drug-seq",
                )

                structure_file = gr.File(
                    label="Upload protein structure (.pdb / .cif)",
                    file_types=[".pdb", ".cif"],
                    elem_id="structure-file",
                )

                gr.Examples(
                    examples=[[
                        "MTLSILVAHDLQRVIGFENQLPWHLPNDLKHVKKLSTGHTLVMGRKTFESIGKPLPNRRNVVLTSDTSFNVEGVDVIHSIEDIYQLPGHVFIFGGQTLFEEMIDKVDDMYITVIEGKFRGDTFFPPYTFEDWEVASSVEGKLDEKNTIPHTFLHLIRKK",
                        "[C][O][C][=C][C][Branch1][=C][C][C][=C][N][=C][Branch1][C][N][N][=C][Ring1][#Branch1][N][=C][C][Branch1][Ring1][O][C][=C][Ring1][P][O][C]"
                    ]],
                    inputs=[protein_seq, drug_seq],
                    label="Click to load an example",
                )

            # ── Right: Controls ──
            with gr.Column(elem_id="right", elem_classes=["right-pane"]):
                head_dd = gr.Dropdown(
                    label="Interaction Type/Overall",
                    choices=INTERACTION_NAMES,
                    value="Overall Interaction",
                    interactive=True,
                )

                mode_dd = gr.Dropdown(
                    label="Top-K selection mode",
                    choices=[
                        "Top-K residues-atom pairs",
                        "Top-K residues",
                    ],
                    value="Top-K residues-atom pairs",
                    interactive=True,
                )

                top_k_dd = gr.Dropdown(
                    label="Top-K value",
                    choices=[str(i) for i in range(1, 21)],
                    value="1",
                    interactive=True,
                )

                btn_extract = gr.Button("Extract / Convert sequences", elem_id="extract-btn")
                btn_infer = gr.Button("Inference", elem_id="inference-btn")
                clear_btn = gr.Button("Clear", elem_id="clear-btn")

    # ───────────────────────────────
    # Outputs
    # ───────────────────────────────
    with gr.Column(elem_classes=["card"]):
        status_box = gr.HTML(elem_id="status-box")
        output_table = gr.HTML(elem_id="result-table")
        output_heat = gr.HTML(elem_id="result-heat")

    # ───────────────────────────────
    # Wiring
    # ───────────────────────────────
    btn_extract.click(
        fn=extract_sequence_cb,
        inputs=[
            structure_file,
            drug_seq,
            protein_seq,
        ],
        outputs=[
            protein_seq,
            drug_seq,
            status_box,
        ],
    )

    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, head_dd, top_k_dd, mode_dd],
        outputs=[output_table, output_heat],
    )

    clear_btn.click(
        fn=clear_cb,
        inputs=[],
        outputs=[
            protein_seq,
            drug_seq,
            output_table,
            output_heat,
            structure_file,
            status_box,
        ],
    )
1239
if __name__ == "__main__":
    # NOTE(review): server_name="127.0.0.1" binds to loopback only; a hosted
    # Space typically needs 0.0.0.0 (or Gradio's default) — confirm the
    # intended deployment target. share=True additionally opens a public
    # gradio.live tunnel.
    demo.launch(
        server_name="127.0.0.1",
        server_port=7860,
        share=True,
        inbrowser=False,
        show_error=True,
    )
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ torch
3
+ biopython
4
+ transformers
5
+ selfies
6
+ rdkit-pypi
7
+ gradio==6.5.1
8
+ matplotlib
9
+ scipy<1.9.0
10
+ numpy<1.23.0
11
+ scikit-learn
12
+ pandas
13
+ ipython
save_model_ckp/fasta_selfies/best_model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19fbf033e333cc71a47931880eef70fc404101cc0b887ef1e123b1f9f8fe4624
3
+ size 35855458
save_model_ckp/sa_selfies/best_model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3402a8f2f1ac484ca42ca4f54e96ad67e5e1c764a02e70a76dda42b6f8a1c2d5
3
+ size 35855458
utils/foldseek_util.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import numpy as np
5
+ import re
6
+ import sys
7
+
8
+ from Bio.PDB import PDBParser, MMCIFParser
9
+
10
+
11
+ sys.path.append(".")
12
+
13
+
14
# Get structural seqs from pdb file
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask="auto",
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """Extract per-chain structural (3Di) sequences via foldseek.

    Args:
        foldseek: Path to the foldseek binary executable.
        path: Path to the pdb/cif file.
        chains: Chains to extract. If None, all chains are extracted.
        process_id: Process ID used to name temporary files (parallel runs).
        plddt_mask: If True, mask positions with plddt < plddt_threshold
            with "#". The default "auto" enables masking only when the file
            mentions AlphaFold (whose B-factor column stores plddt).
        plddt_threshold: plddt cut-off used when masking.
        foldseek_verbose: If True, let foldseek print verbose messages.

    Returns:
        Dict mapping chain ID -> (seq, struc_seq, combined_seq), where
        combined_seq interleaves each residue with its lower-cased 3Di code.
    """
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"
    # Build argv directly (no shell) so paths with spaces or shell
    # metacharacters are handled safely; check=True surfaces failures.
    cmd = [foldseek, "structureto3didescriptor",
           "--threads", "1", "--chain-name-mode", "1", path, tmp_save_path]
    if not foldseek_verbose:
        cmd[2:2] = ["-v", "0"]
    subprocess.run(cmd, check=True)

    try:
        # Check whether the structure is predicted by AlphaFold2
        if plddt_mask == "auto":
            with open(path, "r") as r:
                plddt_mask = "alphafold" in r.read().lower()

        # plddt depends only on `path`, so parse the structure once instead
        # of once per chain line.
        plddts = None
        if plddt_mask:
            try:
                plddts = extract_plddt(path)
            except Exception as e:
                print(f"Error: {e}")
                print(f"Failed to mask plddt for {os.path.basename(path)}")

        seq_dict = {}
        name = os.path.basename(path)
        with open(tmp_save_path, "r") as r:
            for line in r:
                desc, seq, struc_seq = line.split("\t")[:3]

                # Mask low-plddt regions with "#"
                if plddts is not None:
                    try:
                        assert len(plddts) == len(struc_seq), \
                            f"Length mismatch: {len(plddts)} != {len(struc_seq)}"
                        np_seq = np.array(list(struc_seq))
                        np_seq[np.where(plddts < plddt_threshold)[0]] = "#"
                        struc_seq = "".join(np_seq)
                    except Exception as e:
                        print(f"Error: {e}")
                        print(f"Failed to mask plddt for {name}")

                # foldseek names entries "<file><sep>chain ..."; recover the chain ID.
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    # Keep only the first occurrence of each chain.
                    if chain not in seq_dict:
                        combined_seq = "".join(a + b.lower() for a, b in zip(seq, struc_seq))
                        seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # foldseek writes the tsv plus a ".dbtype" sidecar; clean up both
        # even when parsing fails part-way.
        for p in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(p):
                os.remove(p)

    return seq_dict
92
+
93
def extract_plddt(pdb_path: str, chain_id: str = "A") -> np.ndarray:
    """Extract per-residue plddt scores from a structure file.

    plddt values are read from the B-factor column (the AlphaFold
    convention) and averaged over the atoms of each residue.

    Args:
        pdb_path: Path to a .pdb or .cif file.
        chain_id: Chain to read. Defaults to "A", matching single-chain
            AlphaFold models (previously hard-coded).

    Returns:
        1-D array of per-residue mean plddt scores.

    Raises:
        ValueError: If the file extension is neither .pdb nor .cif.
    """
    # Choose the Biopython parser from the file extension.
    if pdb_path.endswith(".cif"):
        parser = MMCIFParser()
    elif pdb_path.endswith(".pdb"):
        parser = PDBParser()
    else:
        raise ValueError("Invalid file format for plddt extraction. Must be '.cif' or '.pdb'.")

    structure = parser.get_structure('protein', pdb_path)
    chain = structure[0][chain_id]

    # Average atom-level B-factors into one plddt value per residue.
    plddts = [
        np.mean([atom.get_bfactor() for atom in residue])
        for residue in chain
    ]
    return np.array(plddts)
128
+
129
def transform_pdb_dir(foldseek: str, pdb_dir: str, seq_type: str, save_path: str):
    """Transform a directory of pdb files into a fasta file.

    Args:
        foldseek: Path to the foldseek binary executable.
        pdb_dir: Directory of pdb files.
        seq_type: "aa" for amino-acid sequences, "foldseek" for 3Di sequences.
        save_path: Path of the fasta file to write.
    """
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert seq_type in ["aa", "foldseek"], "seq_type must be 'aa' or 'foldseek'!"

    tmp_save_path = f"get_struc_seq_{time.time()}.tsv"
    # Run without a shell so paths containing spaces are handled safely;
    # check=True raises if foldseek fails instead of silently continuing.
    subprocess.run(
        [foldseek, "structureto3didescriptor", "--chain-name-mode", "1",
         pdb_dir, tmp_save_path],
        check=True,
    )

    try:
        with open(tmp_save_path, "r") as r, open(save_path, "w") as w:
            for line in r:
                protein_id, aa_seq, foldseek_seq = line.strip().split("\t")[:3]

                if seq_type == "aa":
                    w.write(f">{protein_id}\n{aa_seq}\n")
                else:
                    # 3Di alphabet is conventionally written lower-case.
                    w.write(f">{protein_id}\n{foldseek_seq.lower()}\n")
    finally:
        # Remove the tsv and its ".dbtype" sidecar even on failure.
        for p in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(p):
                os.remove(p)
160
+
161
if __name__ == '__main__':
    # Smoke test on a local AlphaFold structure.
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # Fix: get_struc_seq has no `plddt_path` parameter (calling it with one
    # raised TypeError); plddt values are read from the structure's
    # B-factor column instead.
    res = get_struc_seq(foldseek, test_path, plddt_mask=True, plddt_threshold=70.)
    print(res["A"][1].lower())
utils/metric_learning_models_att_maps.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ sys.path.append("../")
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.cuda.amp import autocast
11
+ from torch.nn import Module
12
+ from tqdm import tqdm
13
+ from torch.nn.utils.weight_norm import weight_norm
14
+ from torch.utils.data import Dataset
15
+
16
+ LOGGER = logging.getLogger(__name__)
17
+
18
class ExplainBind(nn.Module):
    """Protein–ligand binding predictor with optional cross-attention fusion.

    Args:
        prot_out_dim: Hidden size of the protein encoder output.
        drug_out_dim: Hidden size of the drug encoder output.
        args: Namespace providing at least ``fusion`` ("CAN" or "Nan") and,
            for CAN fusion, ``agg_mode`` / ``group_size``.
    """

    def __init__(self, prot_out_dim, drug_out_dim, args):
        super(ExplainBind, self).__init__()
        self.fusion = args.fusion
        # Project both modalities into a shared 768-d space.
        self.drug_reg = nn.Linear(drug_out_dim, 768)
        self.prot_reg = nn.Linear(prot_out_dim, 768)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=768, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1536)
        elif self.fusion == "Nan":
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        """Return ``(score, att)``; ``att`` is None unless CAN fusion is used."""
        # Fix: `att` was previously unbound on the "Nan" path, so
        # `return score, att` raised NameError.
        att = None
        if self.fusion == "Nan":
            # Baseline: mean-pool each modality and concatenate.
            prot_embed = prot_embed.mean(1)  # [batch_size, hidden]
            drug_embed = drug_embed.mean(1)  # [batch_size, hidden]
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            score = self.mlp_classifier_nan(joint_embed)
        elif self.fusion == "CAN":
            prot_embed = self.prot_reg(prot_embed)
            drug_embed = self.drug_reg(drug_embed)
            joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
            score = self.mlp_classifier(joint_embed)
        else:
            # Fix: an unknown fusion previously fell through to a NameError
            # on `joint_embed`; fail loudly and explicitly instead.
            raise ValueError(f"Unknown fusion mode: {self.fusion!r}")

        return score, att
49
class Pre_encoded(nn.Module):
    """Wraps the pretrained protein and drug encoders for feature extraction."""

    def __init__(
        self, prot_encoder, drug_encoder, args
    ):
        """Constructor for the model.

        Args:
            prot_encoder: Protein structure-aware sequence encoder
                (a HuggingFace model exposing ``hidden_states``).
            drug_encoder: Drug SELFIES encoder
                (a HuggingFace model exposing ``last_hidden_state``).
            args: Parsed command-line arguments (not used in this class).
        """
        super(Pre_encoded, self).__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        """Encode tokenised protein and drug inputs into token-level embeddings.

        Returns:
            (prot_embed, drug_embed): last-layer hidden states of the
            protein encoder and the drug encoder, respectively.
        """
        # Protein: request all hidden states and keep the final layer.
        prot_embed = self.prot_encoder(
            input_ids=prot_input_ids,
            attention_mask=prot_attention_mask,
            output_hidden_states=True,  # Request hidden states
            return_dict=True
        ).hidden_states[-1]

        # Drug: the encoder's last hidden state is used directly.
        drug_embed = self.drug_encoder(
            input_ids=drug_input_ids, attention_mask=drug_attention_mask, return_dict=True
        ).last_hidden_state

        return prot_embed, drug_embed
82
+
83
class CAN_Layer(nn.Module):
    """Cross-Attention Network (CAN) fusion layer.

    Tokens are first averaged into groups of ``args.group_size`` to control
    the fusion scale; four-way multi-head attention (prot→prot, prot→drug,
    drug→prot, drug→drug) then mixes the two modalities.
    """

    def __init__(self, hidden_dim, num_heads, args):
        super(CAN_Layer, self).__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # Control Fusion Scale
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        # Separate Q/K/V projections for protein (_p) and drug (_d) tokens.
        self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Masked softmax over the key axis of pairwise logits.

        logits: [N, L1, L2, H]; mask_row / mask_col: boolean [N, L1] / [N, L2].
        Masked key positions are pushed towards -inf before the softmax, and
        fully masked query rows are zeroed afterwards.
        """
        N, L1, L2, H = logits.shape
        mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)

        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha

    def apply_heads(self, x, n_heads, n_ch):
        # Split the last dimension into (n_heads, n_ch).
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def group_embeddings(self, x, mask, group_size):
        # Average consecutive tokens into groups; a group is "valid" if any
        # of its tokens is valid.
        # NOTE(review): assumes L is divisible by group_size — confirm that
        # upstream padding guarantees this, otherwise view() raises.
        N, L, D = x.shape
        groups = L // group_size
        x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
        mask_grouped = mask.view(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped

    def forward(self, protein, drug, mask_prot, mask_drug):
        """Fuse protein and drug token embeddings.

        Returns:
            (query_embed, alpha_dp): pooled joint embedding of shape
            [N, 2 * hidden_dim] and the drug→protein attention weights.
        """
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)

        # Compute queries, keys, values for both protein and drug after grouping
        query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
        key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
        value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)

        query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
        key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
        value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)

        # Pairwise attention logits for all four directions.
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
        logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
        logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)

        alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
        alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
        alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)

        # Each modality's new embedding averages its self- and cross-attention outputs.
        prot_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)) / 2
        drug_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)) / 2

        # Pool token embeddings into one vector per modality.
        if self.agg_mode == "cls":
            prot_embed = prot_embedding[:, 0]  # query : [batch_size, hidden]
            drug_embed = drug_embedding[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            prot_embed = prot_embedding.mean(1)  # query : [batch_size, hidden]
            drug_embed = drug_embedding.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            # Masked mean: ignore padded groups.
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
            drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / mask_drug_grouped.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()

        query_embed = torch.cat([prot_embed, drug_embed], dim=1)

        return query_embed, alpha_dp
177
class MlPdecoder_CAN(nn.Module):
    """Three-layer MLP head with batch-norm, ending in a sigmoid score.

    Layer widths shrink input_dim → input_dim/2 → input_dim/4 → 1.
    Attribute names (fc1/bn1/…) are kept so existing checkpoints load.
    """

    def __init__(self, input_dim):
        super(MlPdecoder_CAN, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, input_dim // 2)
        self.bn2 = nn.BatchNorm1d(input_dim // 2)
        self.fc3 = nn.Linear(input_dim // 2, input_dim // 4)
        self.bn3 = nn.BatchNorm1d(input_dim // 4)
        self.output = nn.Linear(input_dim // 4, 1)

    def forward(self, x):
        """Map a joint embedding [N, input_dim] to a binary score in (0, 1)."""
        hidden = torch.relu(self.fc1(x))
        hidden = self.bn1(hidden)
        hidden = torch.relu(self.fc2(hidden))
        hidden = self.bn2(hidden)
        hidden = torch.relu(self.fc3(hidden))
        hidden = self.bn3(hidden)
        return torch.sigmoid(self.output(hidden))
195
class MLPdecoder_BAN(nn.Module):
    """Four-layer MLP head used after bilinear attention; sigmoid output.

    Widths: in_dim → hidden_dim → hidden_dim → out_dim → binary.
    Attribute names (fc1/bn1/…) are kept so existing checkpoints load.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super(MLPdecoder_BAN, self).__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        """Map features [N, in_dim] to a score in (0, 1)."""
        hidden = self.bn1(F.relu(self.fc1(x)))
        hidden = self.bn2(F.relu(self.fc2(hidden)))
        hidden = self.bn3(F.relu(self.fc3(hidden)))
        return torch.sigmoid(self.fc4(hidden))
214
class BANLayer(nn.Module):
    """ Bilinear attention network
    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py

    Computes multi-head bilinear attention between two token sequences
    (``v``, e.g. drug atoms, and ``q``, e.g. protein residues) and pools
    them into one joint (batch, h_dim) representation.

    Args:
        v_dim: feature width of the first sequence ``v``.
        q_dim: feature width of the second sequence ``q``.
        h_dim: joint hidden width of the pooled output.
        h_out: number of attention heads (attention maps produced).
        act: activation name used inside the FCNet projections.
        dropout: dropout rate passed to the FCNet projections.
        k: low-rank factor; projections are widened to h_dim * k and
           sum-pooled back to h_dim in ``attention_pooling``.
    """
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super(BANLayer, self).__init__()

        # Head-count threshold deciding between the einsum path (few heads)
        # and the weight-normed linear path (many heads) in forward().
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        # Project both inputs into a shared (h_dim * k)-dimensional space.
        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        # self.dropout = nn.Dropout(dropout[1])
        if 1 < k:
            # Average pool over k adjacent channels; attention_pooling scales
            # by k afterwards, so the pair implements sum-pooling.
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out <= self.c:
            # Few heads: learn per-head bilinear parameters directly.
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            # Many heads: build the full v x q interaction tensor and map it
            # to h_out heads with a weight-normalized linear layer.
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool projected v/q (b x len x h_dim*k) under one attention map
        (b x v_len x q_len) into fused logits of shape (b, h_dim)."""
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            # AvgPool1d over k channels * k == sum over each k-sized group.
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum-pooling
        return fusion_logits

    def forward(self, v, q, softmax=False):
        """Compute attention maps and the pooled joint representation.

        Args:
            v: (batch, v_len, v_dim) sequence features.
            q: (batch, q_len, q_dim) sequence features.
            softmax: if True, normalize each head's map over all v x q cells.

        Returns:
            logits: (batch, h_dim) batch-normed sum of per-head pooled logits.
            att_maps: (batch, h_out, v_len, q_len) attention maps.
        """
        v_num = v.size(1)
        q_num = q.size(1)
        # print("v_num", v_num)
        # print("v_num ", v_num)
        if self.h_out <= self.c:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            # print("v_", v_.shape)
            # print("q_ ", q_.shape)
            # Per-head bilinear score for every (v token, q token) pair.
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
            # print("Attention map_1",att_maps.shape)
        else:
            # NOTE(review): in this branch v_/q_ become 4-D, which does not
            # match the 3-D einsum in attention_pooling below — presumably
            # h_out <= 32 always holds in practice; confirm before relying
            # on the h_out > c path.
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
            # print("Attention map_2",att_maps.shape)
        if softmax:
            # Flatten each head's map and normalize jointly over all cells.
            p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
            # print("Attention map_softmax", att_maps.shape)
        # Accumulate the pooled logits of every head, then batch-normalize.
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps
278
+
279
+
280
class FCNet(nn.Module):
    """Simple class for non-linear fully connect network
    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py

    For each consecutive pair in ``dims`` this stacks an optional Dropout,
    a weight-normalized Linear layer, and an optional activation.
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        super(FCNet, self).__init__()

        modules = []
        # One stage per consecutive (dims[i], dims[i+1]) pair — identical to
        # building all-but-last pairs in a loop and then appending the final
        # pair separately.
        for src_dim, dst_dim in zip(dims[:-1], dims[1:]):
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
            modules.append(weight_norm(nn.Linear(src_dim, dst_dim), dim=None))
            if act != '':
                # Look the activation class up by name, e.g. nn.ReLU.
                modules.append(getattr(nn, act)())

        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.main(x)
307
+
308
+
309
class BatchFileDataset_Case(Dataset):
    """Dataset whose items are whole pre-saved batches on disk.

    Each entry of ``file_list`` is a path to a ``torch.save``-d dict holding
    the protein/drug tensors, token ids, masks, and label for one batch.
    """

    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        # One dataset item per batch file.
        return len(self.file_list)

    def __getitem__(self, idx):
        # Lazily load the requested batch from disk.
        # NOTE(review): torch.load can execute arbitrary code on untrusted
        # files — file_list must come from a trusted source.
        batch = torch.load(self.file_list[idx])
        keys = ('prot', 'drug', 'prot_ids', 'drug_ids',
                'prot_mask', 'drug_mask', 'y')
        return tuple(batch[key] for key in keys)