Irwiny123 committed on
Commit
7a1bbaf
·
1 Parent(s): a4262ae

添加PhysDock初始代码

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +204 -0
  2. .idea/.gitignore +8 -0
  3. .idea/PhysDock.iml +12 -0
  4. .idea/inspectionProfiles/Project_Default.xml +24 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +7 -0
  7. .idea/modules.xml +8 -0
  8. .idea/vcs.xml +6 -0
  9. License +21 -0
  10. PhysDock/__init__.py +3 -0
  11. PhysDock/configs.py +195 -0
  12. PhysDock/configs_old.py +245 -0
  13. PhysDock/data/__init__.py +109 -0
  14. PhysDock/data/alignment_runner.py +937 -0
  15. PhysDock/data/alignment_runner_v2.py +327 -0
  16. PhysDock/data/constants/PDBData.py +348 -0
  17. PhysDock/data/constants/__init__.py +0 -0
  18. PhysDock/data/constants/periodic_table.py +27 -0
  19. PhysDock/data/constants/residue_constants.py +562 -0
  20. PhysDock/data/constants/restype_constants.py +107 -0
  21. PhysDock/data/feature_loader.py +1283 -0
  22. PhysDock/data/feature_loader_plinder.py +1258 -0
  23. PhysDock/data/generate_system.py +148 -0
  24. PhysDock/data/relaxation.py +259 -0
  25. PhysDock/data/tools/PDBData.py +348 -0
  26. PhysDock/data/tools/__init__.py +0 -0
  27. PhysDock/data/tools/alignment_runner.py +588 -0
  28. PhysDock/data/tools/convert_unifold_template_to_stfold.py +127 -0
  29. PhysDock/data/tools/dataset_manager.py +570 -0
  30. PhysDock/data/tools/feature_processing_multimer.py +257 -0
  31. PhysDock/data/tools/get_metrics.py +294 -0
  32. PhysDock/data/tools/hhblits.py +175 -0
  33. PhysDock/data/tools/hhsearch.py +126 -0
  34. PhysDock/data/tools/hmmalign.py +66 -0
  35. PhysDock/data/tools/hmmbuild.py +165 -0
  36. PhysDock/data/tools/hmmsearch.py +137 -0
  37. PhysDock/data/tools/jackhmmer.py +262 -0
  38. PhysDock/data/tools/kalign.py +114 -0
  39. PhysDock/data/tools/mmcif_parsing.py +519 -0
  40. PhysDock/data/tools/msa_identifiers.py +90 -0
  41. PhysDock/data/tools/msa_pairing.py +496 -0
  42. PhysDock/data/tools/nhmmer.py +257 -0
  43. PhysDock/data/tools/parse_msas.py +328 -0
  44. PhysDock/data/tools/parsers.py +727 -0
  45. PhysDock/data/tools/rdkit.py +220 -0
  46. PhysDock/data/tools/residue_constants.py +604 -0
  47. PhysDock/data/tools/templates.py +1357 -0
  48. PhysDock/data/tools/utils.py +48 -0
  49. PhysDock/models/__init__.py +0 -0
  50. PhysDock/models/layers/__init__.py +0 -0
.gitignore ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Marimo
198
+ marimo/_static/
199
+ marimo/_lsp/
200
+ __marimo__/
201
+
202
+ # Streamlit
203
+ .streamlit/secrets.toml
204
+ params/params.pt
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/PhysDock.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="PhysDock" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="GOOGLE" />
10
+ <option name="myDocStringFormat" value="Google" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="11">
8
+ <item index="0" class="java.lang.String" itemvalue="tqdm" />
9
+ <item index="1" class="java.lang.String" itemvalue="scipy" />
10
+ <item index="2" class="java.lang.String" itemvalue="deepspeed" />
11
+ <item index="3" class="java.lang.String" itemvalue="PyYAML" />
12
+ <item index="4" class="java.lang.String" itemvalue="pytorch_lightning" />
13
+ <item index="5" class="java.lang.String" itemvalue="ml-collections" />
14
+ <item index="6" class="java.lang.String" itemvalue="torch" />
15
+ <item index="7" class="java.lang.String" itemvalue="typing-extensions" />
16
+ <item index="8" class="java.lang.String" itemvalue="numpy" />
17
+ <item index="9" class="java.lang.String" itemvalue="requests" />
18
+ <item index="10" class="java.lang.String" itemvalue="dm-tree" />
19
+ </list>
20
+ </value>
21
+ </option>
22
+ </inspection_tool>
23
+ </profile>
24
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="PhysDock" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="PhysDock" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/PhysDock.iml" filepath="$PROJECT_DIR$/.idea/PhysDock.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
+ </component>
6
+ </project>
License ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 ShanghaiTech University
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PhysDock/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from PhysDock.models.model import PhysDock
2
+ from PhysDock.models.loss import PhysDockLoss
3
+ from PhysDock.configs import PhysDockConfig
PhysDock/configs.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ml_collections as mlc
2
+
3
+
4
+ def PhysDockConfig(
5
+ inference_mode=True,
6
+ model_name="medium",
7
+ num_augmentation_sample=48,
8
+
9
+ crop_size=256,
10
+ atom_crop_size=256 * 8,
11
+
12
+ alpha_confifdence=1e-4,
13
+ alpha_diffusion=4,
14
+ alpha_bond=0,
15
+ alpha_distogram=3e-2,
16
+ alpha_pae=0,
17
+ inf=1e9,
18
+ eps=1e-8,
19
+
20
+
21
+ # Inference Config
22
+ infer_pocket_type="atom", # "ca"
23
+ infer_pocket_cutoff=6, # 8 10 12
24
+ infer_pocket_dist_type="ligand", # "ligand_centre"
25
+ infer_use_pocket=True,
26
+ infer_use_key_res=True,
27
+
28
+ # Training Config
29
+ train_pocket_type_atom_ratio=0.5,
30
+ train_pocket_cutoff_ligand_min=6,
31
+ train_pocket_cutoff_ligand_max=12,
32
+ train_pocket_cutoff_ligand_centre_min=10,
33
+ train_pocket_cutoff_ligand_centre_max=16,
34
+ train_pocket_dist_type_ligand_ratio=0.5,
35
+ train_use_pocket_ratio=0.5,
36
+ train_use_key_res_ratio=0.5,
37
+
38
+ train_shuffle_sym_id=True,
39
+ train_spatial_crop_ligand_ratio=0.2,
40
+ train_spatial_crop_interface_ratio=0.4,
41
+ train_spatial_crop_interface_threshold=15.,
42
+ train_charility_augmentation_ratio=0.1,
43
+ train_use_template_ratio=0.75,
44
+ train_template_mask_max_ratio=0.4,
45
+
46
+ # Other Configs
47
+ max_msa_clusters=128,
48
+ key_res_random_mask_ratio=0.5,
49
+ token_bond_threshold=2.4,
50
+ sigma_data=16.,
51
+ ):
52
+ ref_dim = 167
53
+ target_dim = 65
54
+ msa_dim = 34
55
+
56
+ inf = inf
57
+ eps = eps
58
+
59
+ c_m = 256 # 256
60
+ c_s = 512 # 1024
61
+ c_z = 128 # 64 | 128
62
+ c_a = 128 # 128
63
+ c_ap = 16 # 16 | 32
64
+
65
+ if model_name == "toy":
66
+ no_blocks_atom = 2
67
+ no_blocks_evoformer = 2
68
+ no_blocks_pairformer = 2
69
+ no_blocks_dit = 2
70
+ no_blocks_heads = 2
71
+ elif model_name == "tiny":
72
+ no_blocks_atom = 2
73
+ no_blocks_evoformer = 2
74
+ no_blocks_pairformer = 8
75
+ no_blocks_dit = 4
76
+ no_blocks_heads = 2
77
+ elif model_name == "small":
78
+ no_blocks_atom = 2
79
+ no_blocks_evoformer = 3
80
+ no_blocks_pairformer = 16
81
+ no_blocks_dit = 8
82
+ no_blocks_heads = 2
83
+ elif model_name == "medium":
84
+ no_blocks_atom = 3
85
+ no_blocks_evoformer = 4
86
+ no_blocks_pairformer = 24
87
+ no_blocks_dit = 12
88
+ no_blocks_heads = 3
89
+ elif model_name == "full":
90
+ no_blocks_atom = 3
91
+ no_blocks_evoformer = 4
92
+ no_blocks_pairformer = 48
93
+ no_blocks_dit = 24
94
+ no_blocks_heads = 4
95
+ else:
96
+ raise ValueError("Unknown model name")
97
+
98
+ config = {
99
+ "inference_mode": inference_mode,
100
+ "sigma_data": sigma_data,
101
+ "data": {
102
+ "crop_size": crop_size,
103
+ "atom_crop_size": atom_crop_size,
104
+ "max_msa_seqs": 16384,
105
+ "max_uniprot_msa_seqs": 8192,
106
+ "interface_threshold": 15,
107
+ "token_bond_threshold": token_bond_threshold,
108
+ "covalent_bond_threshold": 1.8,
109
+ "max_msa_clusters": max_msa_clusters,
110
+ "resample_msa_in_recycling": True,
111
+ },
112
+ "model": {
113
+ "c_z": c_z,
114
+ "num_augmentation_sample": num_augmentation_sample,
115
+ "diffusion_conditioning": {
116
+ "ref_dim": ref_dim,
117
+ "target_dim": target_dim,
118
+ "msa_dim": msa_dim,
119
+ "c_a": c_a,
120
+ "c_ap": c_ap,
121
+ "c_s": c_s,
122
+ "c_m": c_m,
123
+ "c_z": c_z,
124
+ "inf": inf,
125
+ "eps": eps,
126
+ "no_blocks_atom": no_blocks_atom,
127
+ "no_blocks_evoformer": no_blocks_evoformer,
128
+ "no_blocks_pairformer": no_blocks_pairformer
129
+ },
130
+ "dit": {
131
+ "c_a": c_a,
132
+ "c_ap": c_ap,
133
+ "c_s": c_s,
134
+ "c_z": c_z,
135
+ "inf": inf,
136
+ "eps": eps,
137
+ "no_blocks_atom": no_blocks_atom,
138
+ "no_blocks_dit": no_blocks_dit,
139
+ "sigma_data": sigma_data
140
+ },
141
+ "confidence_module": {
142
+ "c_a": c_a,
143
+ "c_ap": c_ap,
144
+ "c_s": c_s,
145
+ "c_z": c_z,
146
+ "inf": inf,
147
+ "eps": eps,
148
+ "no_blocks_heads": no_blocks_heads,
149
+ "no_blocks_atom": no_blocks_atom,
150
+ }
151
+ },
152
+ "loss": {
153
+ "weighted_mse_loss": {
154
+ "weight": alpha_diffusion,
155
+ "sigma_data": sigma_data,
156
+ "alpha_dna": 5.0,
157
+ "alpha_rna": 5.0,
158
+ "alpha_ligand": 10.0,
159
+ },
160
+ "smooth_lddt_loss": {
161
+ "weight": alpha_diffusion,
162
+ "max_clamp_distance": 15.,
163
+ },
164
+
165
+ "bond_loss": {
166
+ "weight": alpha_diffusion * alpha_bond,
167
+ "sigma_data": sigma_data,
168
+ },
169
+ "key_res_loss": {
170
+ "weight": alpha_diffusion * alpha_bond,
171
+ "sigma_data": sigma_data,
172
+ },
173
+ "distogram_loss": {
174
+ "weight": alpha_distogram,
175
+ "min_bin": 3.25,
176
+ "max_bin": 50.75,
177
+ "no_bins": 39,
178
+ "eps": 1e-9,
179
+ },
180
+ "plddt_loss": {
181
+ "weight": alpha_confifdence,
182
+ "no_bins": 50,
183
+ },
184
+ "pae_loss": {
185
+ "weight": alpha_confifdence * alpha_pae,
186
+ },
187
+ "pde_loss": {
188
+ "weight": alpha_confifdence,
189
+ "min_bin": 0,
190
+ "max_bin": 32,
191
+ "no_bins": 64,
192
+ },
193
+ }
194
+ }
195
+ return mlc.ConfigDict(config)
PhysDock/configs_old.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ml_collections as mlc
2
+
3
+
4
+ def model_config(
5
+ model_name="full",
6
+ max_recycling_iters=1, # 0
7
+ max_msa_clusters=128, # 32
8
+ crop_size=256, #
9
+ num_augmentation_sample=48, # 128
10
+ alpha_confifdence=1e-4,
11
+ alpha_diffusion=4,
12
+ alpha_bond=0,
13
+ alpha_distogram=3e-2,
14
+ alpha_pae=0,
15
+ use_template=True, # False
16
+ use_mini_rollout=True, # False
17
+ use_flash_attn=False, # False
18
+ custom_rel_token=-1, # 42
19
+ ref_dim=1 + 2 + 2 + 128 + 256, # 167
20
+ mini_rollout_steps=20,
21
+ atom_attention_type="full",
22
+ templ_dim=108,
23
+ interaction_aware=True,
24
+ ):
25
+ sigma_data = 16
26
+ # ref_dim = 1 + 2 + 2 + 128 + 256
27
+ msa_dim = 34
28
+ templ_dim = templ_dim
29
+
30
+ inf = 1e9
31
+ eps = 1e-8
32
+
33
+ pair_dropout = 0.25
34
+ msa_dropout = 0.15
35
+
36
+ c_m = 256 # 256
37
+ c_s = 768 # 1024
38
+ c_z = 128 # 64 | 128
39
+ c_tz = 64
40
+ c_a = 128 # 128
41
+ c_ap = 16 # 16 | 32
42
+
43
+ no_blocks_templ = 2
44
+ no_blocks_evo = 48
45
+ no_blocks_atom = 3
46
+ no_blocks_dit = 24
47
+ no_blocks_heads = 4
48
+ if model_name == "small_toy":
49
+ no_blocks_templ = 1
50
+ no_blocks_evo = 1
51
+ no_blocks_atom = 1
52
+ no_blocks_dit = 1
53
+ no_blocks_heads = 1
54
+ elif model_name == "toy":
55
+ no_blocks_templ = 2
56
+ no_blocks_evo = 2
57
+ no_blocks_atom = 2
58
+ no_blocks_dit = 2
59
+ no_blocks_heads = 2
60
+
61
+ elif model_name == "small":
62
+ no_blocks_templ = 2
63
+ no_blocks_evo = 4
64
+ no_blocks_atom = 2
65
+ no_blocks_dit = 2
66
+ no_blocks_heads = 2
67
+ elif model_name == "docking":
68
+ no_blocks_templ = 2
69
+ no_blocks_evo = 8
70
+ no_blocks_atom = 2
71
+ no_blocks_dit = 4
72
+ no_blocks_heads = 2
73
+ elif model_name == "medium":
74
+ no_blocks_templ = 2
75
+ no_blocks_evo = 16
76
+ no_blocks_atom = 3
77
+ no_blocks_dit = 8
78
+ no_blocks_heads = 2
79
+ elif model_name == "large":
80
+ no_blocks_templ = 2
81
+ no_blocks_evo = 24
82
+ no_blocks_atom = 3
83
+ no_blocks_dit = 12
84
+ no_blocks_heads = 4
85
+ elif model_name == "full":
86
+ no_blocks_templ = 2
87
+ no_blocks_evo = 48
88
+ no_blocks_atom = 3
89
+ no_blocks_dit = 24
90
+ no_blocks_heads = 4
91
+
92
+ return mlc.ConfigDict({
93
+ "use_template": use_template,
94
+ "use_mini_rollout": use_mini_rollout,
95
+ "mini_rollout_steps": mini_rollout_steps,
96
+
97
+ "data": {
98
+ "crop_size": crop_size,
99
+ "atom_crop_factor": 10,
100
+ "max_msa_seqs": 16384,
101
+ "max_uniprot_msa_seqs": 8192,
102
+ "interface_threshold": 15,
103
+ "token_bond_threshold": 2.4,
104
+ "covalent_bond_threshold": 1.8,
105
+ "max_msa_clusters": max_msa_clusters,
106
+ "resample_msa_in_recycling": True,
107
+ "max_recycling_iters": max_recycling_iters, # TODO 3
108
+ "sample_msa": {
109
+ "max_msa_clusters": 128,
110
+ "resample_msa_in_recycling": True,
111
+ },
112
+ "make_crop_ids": {
113
+ "crop_size": 384
114
+ }
115
+ },
116
+ "model": {
117
+ "input_feature_embedder": {
118
+ "msa_dim": msa_dim,
119
+ "ref_dim": ref_dim,
120
+ "c_s": c_s,
121
+ "c_m": c_m,
122
+ "c_z": c_z,
123
+ "c_ap": c_ap,
124
+ "c_a": c_a,
125
+ "no_heads": 4,
126
+ "c_hidden": 16,
127
+ "inf": inf,
128
+ "eps": eps,
129
+ "no_blocks": 3,
130
+ "interaction_aware": interaction_aware,
131
+ "custom_rel_token": custom_rel_token,
132
+ },
133
+ "template_pair_embedder": {
134
+ "templ_dim": templ_dim,
135
+ "c_z": c_z,
136
+ "c_tz": c_tz,
137
+ "c_hidden_tz": 16,
138
+ "no_heads_tz": 4,
139
+ "inf": inf,
140
+ "eps": eps,
141
+ "no_blocks": no_blocks_templ,
142
+ },
143
+ "recycling_embedder": {
144
+ "c_m": c_m,
145
+ "c_z": c_z,
146
+ },
147
+ "evoformer_stack": {
148
+ "c_m": c_m,
149
+ "c_z": c_z,
150
+ "c_hidden_m": 32,
151
+ "no_heads_m": 8,
152
+ "c_hidden_z": 32,
153
+ "no_heads_z": 4,
154
+ "c_hidden_opm": 32,
155
+ "inf": inf,
156
+ "eps": eps,
157
+ "no_blocks": no_blocks_evo,
158
+ "single_mode": False,
159
+ },
160
+ "diffusion_module": {
161
+ "ref_dim": ref_dim,
162
+ "c_m": c_m,
163
+ "c_s": c_s,
164
+ "c_z": c_z,
165
+ "c_a": c_a,
166
+ "c_ap": c_ap,
167
+ "no_heads_atom": 4,
168
+ "c_hidden_atom": 16,
169
+ "no_heads": c_ap,
170
+ "c_hidden": 32,
171
+ "inf": inf,
172
+ "eps": eps,
173
+ "no_blocks": no_blocks_dit,
174
+ "no_blocks_atom": no_blocks_atom,
175
+ "num_augmentation_sample": num_augmentation_sample,
176
+ "custom_rel_token": custom_rel_token,
177
+ "use_flash_attn": use_flash_attn,
178
+ "atom_attention_type": atom_attention_type
179
+ },
180
+ "confidence_module": {
181
+ "c_a": c_a,
182
+ "c_ap": c_ap,
183
+ "c_s": c_s,
184
+ "c_m": c_m,
185
+ "c_z": c_z,
186
+ "no_heads_a": 4,
187
+ "c_hidden_a": 16,
188
+ "c_hidden_m": 32,
189
+ "no_heads_m": 8,
190
+ "c_hidden_z": 32,
191
+ "no_heads_z": 4,
192
+ "c_hidden_opm": 32,
193
+ "inf": inf,
194
+ "eps": eps,
195
+ "no_blocks": no_blocks_heads,
196
+ "no_blocks_atom": no_blocks_atom,
197
+ "c_pae": 64,
198
+ "c_pde": 64,
199
+ "c_plddt": 50,
200
+ "c_distogram": 39,
201
+ },
202
+ "loss": {
203
+ "weighted_mse_loss": {
204
+ "weight": alpha_diffusion,
205
+ "sigma_data": sigma_data,
206
+ "alpha_dna": 5.0,
207
+ "alpha_rna": 5.0,
208
+ "alpha_ligand": 10.0,
209
+ },
210
+ "smooth_lddt_loss": {
211
+ "weight": alpha_diffusion,
212
+ "max_clamp_distance": 15.,
213
+ },
214
+ "clamp_distance_loss": {
215
+ "weight": alpha_diffusion * 0.2,
216
+ "max_clamp_distance": 10,
217
+ },
218
+
219
+ "bond_loss": {
220
+ "weight": alpha_diffusion * alpha_bond,
221
+ "sigma_data": sigma_data,
222
+ },
223
+ "distogram_loss": {
224
+ "weight": alpha_distogram,
225
+ "min_bin": 3.25,
226
+ "max_bin": 50.75,
227
+ "no_bins": 39,
228
+ "eps": 1e-9,
229
+ },
230
+ "plddt_loss": {
231
+ "weight": alpha_confifdence,
232
+ "no_bins": 50,
233
+ },
234
+ "pae_loss": {
235
+ "weight": alpha_confifdence * alpha_pae,
236
+ },
237
+ "pde_loss": {
238
+ "weight": alpha_confifdence,
239
+ "min_bin": 0,
240
+ "max_bin": 32,
241
+ "no_bins": 64,
242
+ },
243
+ }
244
+ }
245
+ })
PhysDock/data/__init__.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Dict, Union, Any
import torch
import numpy as np
# FIX: `scipy.sparse.coo` is a private submodule that was deprecated in
# SciPy 1.8 and removed in later releases; import from the public package
# namespace instead.
from scipy.sparse import coo_matrix

# TODO: Keep only ref_mask eq 1 for all atomwise features


'''
Notation:
    batch,
    token dimension: i, j, k
    flat atom dimension: l, m | WARNING: We should flatten due to local atom attention mask
    sequence dimension: s(msa) t(time)
    head dimension: h

    ####################
    z_ij: pair repr
    {z_ij}: all pair repr
    x: atom position
    {x_l}: flat atom list, full atomic structure
    exist a mapping: flat atom index -> token index and within token atom index: l -> i, a


    ####################
    a: atom representation
        have the same shape as s
        exist a transform: flat atom representation -> atom representation
    s: token representation
    z: pair representation

'''

# Type aliases for feature dictionaries used across the data pipeline.
FeatureDict = Dict[str, Union[np.ndarray, coo_matrix, None, Any]]
TensorDict = Dict[str, Union[torch.Tensor, Any]]

# PDB chain identifiers: 26 uppercase + 26 lowercase letters + 10 digits.
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"

# Symbolic placeholders used in SHAPE_SCHIME to mark variable-length axes.
NUM_CONFORMER = "num conformer placeholder"
NUM_TOKEN = "num tokens placeholder"
NUM_ATOM = "num atoms placeholder"

NUM_SEQ = "num MSAs placeholder"
NUM_TEMPL = "num templates placeholder"

NUM_RECYCLING = "num recycling placeholder"
NUM_SAMPLE = "num sample placeholder"

# Shape schema for every feature produced by the data pipeline.
# NOTE(review): name keeps the original's misspelling of "SCHEME" because
# other modules may import it by this name.
SHAPE_SCHIME = {
    ################################################################
    # Conformerwise Feature

    # Tokenwise Feature
    "residue_index": [NUM_TOKEN],
    "restype": [NUM_TOKEN],
    "token_index": [NUM_TOKEN],
    "s_mask": [NUM_TOKEN],
    "is_protein": [NUM_TOKEN],
    "is_rna": [NUM_TOKEN],
    "is_dna": [NUM_TOKEN],
    "is_ligand": [NUM_TOKEN],
    "token_id_to_centre_atom_id": [NUM_TOKEN],
    "token_id_to_pseudo_beta_atom_id": [NUM_TOKEN],
    "token_id_to_chunk_sizes": [NUM_TOKEN],
    "token_id_to_conformer_id": [NUM_TOKEN],
    "asym_id": [NUM_TOKEN],
    "entity_id": [NUM_TOKEN],
    "sym_id": [NUM_TOKEN],
    "token_bonds": [NUM_TOKEN, NUM_TOKEN],
    "target_feat": [NUM_TOKEN],
    "token_exists": [NUM_TOKEN],
    "spatial_crop_target_res_mask": [NUM_TOKEN],

    # Atomwise features
    "ref_space_uid": [NUM_ATOM],
    "atom_index": [NUM_ATOM],
    "ref_feat": [NUM_ATOM, 389],
    "ref_pos": [NUM_ATOM, 3],
    "a_mask": [NUM_ATOM],
    "atom_id_to_token_id": [NUM_ATOM],
    "x_gt": [NUM_ATOM, 3],
    "x_exists": [NUM_ATOM],
    "rec_mask": [NUM_ATOM, NUM_ATOM],

    "msa": [NUM_SEQ, NUM_TOKEN],
    "deletion_matrix": [NUM_SEQ, NUM_TOKEN],
    "msa_feat": [NUM_SEQ, NUM_TOKEN, None],
    "crop_idx": [None],
    "crop_idx_atom": [None],

    #
    "x_centre": [None],
    # # Template features
    # "template_restype": [NUM_TEMPLATES, NUM_TOKEN],
    # "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_TOKEN],
    # "template_backbone_frame_mask": [NUM_TEMPLATES, NUM_TOKEN],
    # "template_distogram": [NUM_TEMPLATES, NUM_TOKEN, NUM_TOKEN, 39],
    # "template_unit_vector": [NUM_TEMPLATES, NUM_TOKEN, NUM_TOKEN, 3],

    ###########################################################
}

SUPERVISED_FEATURES = [

]

UNSUPERVISED_FEATURES = [

]
PhysDock/data/alignment_runner.py ADDED
@@ -0,0 +1,937 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os.path
3
+ import shutil
4
+ from functools import partial
5
+ import tqdm
6
+ from typing import Optional, Mapping, Any, Union
7
+
8
+ from PhysDock.data.tools import jackhmmer, nhmmer, hhblits, kalign, hmmalign, parsers, hmmbuild, hhsearch, templates
9
+ from PhysDock.utils.io_utils import load_pkl, load_txt, load_json, run_pool_tasks, convert_md5_string, dump_pkl
10
+ from PhysDock.data.tools.parsers import parse_fasta
11
+
12
+ TemplateSearcher = Union[hhsearch.HHSearch]
13
+
14
+
15
class AlignmentRunner:
    """Runs genetic (MSA) searches and PDB template search for one query sequence.

    Tool runners are assembled in ``__init__`` and are only enabled when both
    the tool binary and its database exist on disk; a disabled runner stays
    ``None`` and is silently skipped by :meth:`run`.
    """

    def __init__(
            self,
            # Homology search tool binaries
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,
            nhmmer_binary_path: Optional[str] = None,
            hmmbuild_binary_path: Optional[str] = None,
            hmmalign_binary_path: Optional[str] = None,
            kalign_binary_path: Optional[str] = None,

            # Template search tools
            hhsearch_binary_path: Optional[str] = None,
            template_searcher: Optional[TemplateSearcher] = None,
            template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,

            # Databases
            uniref90_database_path: Optional[str] = None,
            uniprot_database_path: Optional[str] = None,
            uniclust30_database_path: Optional[str] = None,
            uniref30_database_path: Optional[str] = None,
            bfd_database_path: Optional[str] = None,
            reduced_bfd_database_path: Optional[str] = None,
            mgnify_database_path: Optional[str] = None,
            rfam_database_path: Optional[str] = None,
            rnacentral_database_path: Optional[str] = None,
            nt_database_path: Optional[str] = None,
            # Parallelism
            no_cpus: int = 8,
            # Limitations
            uniref90_seq_limit: int = 100000,
            uniprot_seq_limit: int = 500000,
            reduced_bfd_seq_limit: int = 50000,
            mgnify_seq_limit: int = 50000,
            uniref90_max_hits: int = 10000,
            uniprot_max_hits: int = 50000,
            reduced_bfd_max_hits: int = 5000,
            mgnify_max_hits: int = 5000,
            rfam_max_hits: int = 10000,
            rnacentral_max_hits: int = 10000,
            nt_max_hits: int = 10000,
    ):
        # Every runner defaults to None and is replaced by a configured
        # ``functools.partial`` below iff its binary + database are available.
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.reduced_bfd_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        self.bfd_uniref30_hhblits_runner = None
        self.bfd_uniclust30_hhblits_runner = None
        self.rfam_nhmmer_runner = None
        self.rnacentral_nhmmer_runner = None
        self.nt_nhmmer_runner = None
        self.rna_realign_runner = None
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer

        def _all_exists(*objs, hhblits_mode=False):
            """Returns True iff every path is non-None and exists.

            In ``hhblits_mode`` database paths are *prefixes* (e.g.
            ``.../uniclust30_2018_08``), so only the parent directory can be
            checked for existence.
            """
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
                fasta_path: str,
                msa_out_path: str,
                msa_runner,
                msa_format: str,
                max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Runs an MSA tool and writes its output to ``msa_out_path``."""
            if msa_format == "sto" and max_sto_sequences is not None:
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            # The output file extension must match the requested MSA format.
            assert msa_out_path.split('.')[-1] == msa_format
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        def _run_rna_realign_tool(
                fasta_path: str,
                msa_in_path: str,
                msa_out_path: str,
                use_precompute=True,
        ):
            """Realigns an RNA ``.sto`` MSA against the query with hmmalign.

            An empty input MSA produces an empty output file. With
            ``use_precompute`` a non-empty realigned file is reused; an empty
            realigned file next to a non-empty input is treated as a failed
            previous run and redone.
            """
            runner = hmmalign.Hmmalign(
                hmmbuild_binary_path=hmmbuild_binary_path,
                hmmalign_binary_path=hmmalign_binary_path,
            )
            if os.path.exists(msa_in_path) and os.path.getsize(msa_in_path) == 0:
                # Nothing to realign: emit an empty output file.
                with open(msa_out_path, "w") as f:
                    pass
                return
            if use_precompute:
                if os.path.exists(msa_in_path) and os.path.exists(msa_out_path):
                    if os.path.getsize(msa_in_path) > 0 and os.path.getsize(msa_out_path) == 0:
                        logging.warning(f"The msa realign file size is zero but the origin file size is over 0! "
                                        f"fasta: {fasta_path} msa_in_file: {msa_in_path}")
                        runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
                else:
                    runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
            else:
                runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)

        # uniclust30 and uniref30 are mutually exclusive hhblits companions.
        assert uniclust30_database_path is None or uniref30_database_path is None, "Only one used"

        # Jackhmmer
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )

        if _all_exists(jackhmmer_binary_path, reduced_bfd_database_path):
            self.reduced_bfd_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=reduced_bfd_database_path,
                    seq_limit=reduced_bfd_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=reduced_bfd_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits (bfd + uniref30 preferred over bfd + uniclust30)
        if _all_exists(hhblits_binary_path, bfd_database_path, uniref30_database_path, hhblits_mode=True):
            self.bfd_uniref30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniref30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )
        elif _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

        # Nhmmer (RNA databases)
        if _all_exists(nhmmer_binary_path, rfam_database_path):
            self.rfam_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rfam_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rfam_max_hits
            )
        if _all_exists(nhmmer_binary_path, rnacentral_database_path):
            self.rnacentral_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rnacentral_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rnacentral_max_hits
            )
        if _all_exists(nhmmer_binary_path, nt_database_path):
            self.nt_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=nt_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=nt_max_hits
            )

        if _all_exists(hmmbuild_binary_path, hmmalign_binary_path):
            self.rna_realign_runner = _run_rna_realign_tool

    def run(self, input_fasta_path, output_msas_dir, use_precompute=True):
        """Runs every configured search tool for one fasta query.

        Existing tool outputs are reused when ``use_precompute`` is True, and
        a whole stage is skipped when its final ``{md5}.pkl.gz`` feature file
        already exists under the ``features`` directory two levels above
        ``output_msas_dir``.

        Args:
            input_fasta_path: Path to a single-sequence fasta file.
            output_msas_dir: Per-sequence output directory for hits files.
            use_precompute: Reuse on-disk tool outputs when present.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        templates_out_path = os.path.join(output_msas_dir, "templates")
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        reduced_bfd_out_path = os.path.join(output_msas_dir, "reduced_bfd_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniref30_out_path = os.path.join(output_msas_dir, "bfd_uniref30_hits.a3m")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, "bfd_uniclust30_hits.a3m")

        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        # Feature caches are keyed by md5("protein:<sequence>") and live two
        # directory levels above the msas dir.
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        output_feature = os.path.dirname(os.path.dirname(output_msas_dir))
        pkl_save_path_msa = os.path.join(output_feature, "msa_features", f"{md5}.pkl.gz")
        pkl_save_path_msa_uni = os.path.join(output_feature, "uniprot_msa_features", f"{md5}.pkl.gz")
        pkl_save_path_temp = os.path.join(output_feature, "template_features", f"{md5}.pkl.gz")

        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_temp):
            if not os.path.exists(uniref90_out_path) or not use_precompute or not os.path.exists(templates_out_path):
                if not os.path.exists(uniref90_out_path):
                    print(uniref90_out_path)
                    self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)

                print("begin templates")
                if templates_out_path is not None:
                    try:
                        os.makedirs(templates_out_path, exist_ok=True)
                        seq, dec = parsers.parse_fasta(load_txt(input_fasta_path))
                        input_sequence = seq[0]
                        # Prepare the uniref90 MSA for the template searcher.
                        msa_for_templates = parsers.truncate_stockholm_msa(
                            uniref90_out_path, max_sequences=10000
                        )
                        msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
                        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
                            msa_for_templates
                        )
                        if self.template_searcher.input_format == "sto":
                            pdb_templates_result = self.template_searcher.query(msa_for_templates)
                        elif self.template_searcher.input_format == "a3m":
                            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
                            pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
                        else:
                            raise ValueError(
                                "Unrecognized template input format: "
                                f"{self.template_searcher.input_format}"
                            )

                        pdb_hits_out_path = os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}.pkl.gz"
                        )
                        with open(os.path.join(
                                templates_out_path, f"pdb_hits.{self.template_searcher.output_format}"
                        ), "w") as f:
                            f.write(pdb_templates_result)

                        pdb_template_hits = self.template_searcher.get_template_hits(
                            output_string=pdb_templates_result, input_sequence=input_sequence
                        )
                        templates_result = self.template_featurizer.get_templates(
                            query_sequence=input_sequence, hits=pdb_template_hits
                        )
                        # Bug fix: this dump previously ran *after* the except
                        # clause, raising NameError on the unbound
                        # ``templates_result``/``pdb_hits_out_path`` whenever
                        # template search failed. Keep it inside the try so a
                        # failure is logged and skipped cleanly.
                        dump_pkl(templates_result.features, pdb_hits_out_path, compress=True)
                    except Exception:
                        logging.exception("An error in template searching")

        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa_uni):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.reduced_bfd_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(reduced_bfd_out_path) or not use_precompute:
                self.reduced_bfd_jackhmmer_runner(input_fasta_path, reduced_bfd_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniref30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniref30_out_path) or not use_precompute:
                self.bfd_uniref30_hhblits_runner(input_fasta_path, bfd_uniref30_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
358
+
359
+
360
class DataProcessor:
    """Dispatches :class:`AlignmentRunner` jobs over batches of fasta files.

    Resolves all database paths under ``alphafold3_database_path`` once and
    keeps a ``runner_args_map`` of named tool/database bundles ("uniref90",
    "alphafold3", "rna", ...) used to configure one AlignmentRunner per task.
    """

    def __init__(
            self,
            alphafold3_database_path,
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,
            nhmmer_binary_path: Optional[str] = None,
            kalign_binary_path: Optional[str] = None,
            hmmbuild_binary_path: Optional[str] = None,
            hmmalign_binary_path: Optional[str] = None,
            hhsearch_binary_path: Optional[str] = None,
            template_searcher: Optional[TemplateSearcher] = None,
            template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,
            n_cpus: int = 8,
            n_workers: int = 1,
    ):
        """
        Database Versions:
            Training:
                uniref90: v2022_05, uniclust30: v2018_08, uniprot: v2020_05,
                mgnify: v2022_05, rfam: v14.9, rnacentral: v21.0, nt: v2023_02_23
            Inference:
                same as training except uniprot: v2021_04
            Inference Ligand:
                same as training except uniref90: v2020_01, mgnify: v2018_12

        Args:
            alphafold3_database_path: Database dir that contains all alphafold3 databases.
            jackhmmer_binary_path: Path to the jackhmmer executable.
            hhblits_binary_path: Path to the hhblits executable.
            nhmmer_binary_path: Path to the nhmmer executable.
            kalign_binary_path: Path to the kalign executable (currently unused here).
            hmmbuild_binary_path: Path to the hmmbuild executable.
            hmmalign_binary_path: Path to the hmmalign executable.
            hhsearch_binary_path: Path to the hhsearch executable.
            template_searcher: Pre-built template searcher (e.g. HHSearch).
            template_featurizer: Pre-built template hit featurizer.
            n_cpus: CPUs given to each search tool invocation.
            n_workers: Number of parallel alignment tasks.
        """
        self.jackhmmer_binary_path = jackhmmer_binary_path
        self.hhblits_binary_path = hhblits_binary_path
        self.nhmmer_binary_path = nhmmer_binary_path
        self.hmmbuild_binary_path = hmmbuild_binary_path
        self.hmmalign_binary_path = hmmalign_binary_path
        self.hhsearch_binary_path = hhsearch_binary_path

        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer

        self.n_cpus = n_cpus
        self.n_workers = n_workers

        # Fixed on-disk layout of the alphafold3 database directory.
        self.uniref90_database_path = os.path.join(
            alphafold3_database_path, "uniref90", "uniref90.fasta"
        )
        self.uniprot_database_path = os.path.join(
            alphafold3_database_path, "uniprot", "uniprot.fasta"
        )
        self.bfd_database_path = os.path.join(
            alphafold3_database_path, "bfd", "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
        )
        self.uniclust30_database_path = os.path.join(
            alphafold3_database_path, "uniclust30", "uniclust30_2018_08", "uniclust30_2018_08"
        )
        # TODO: check alphafold2 multimer uniref30 version
        self.uniref_30_database_path = os.path.join(
            alphafold3_database_path, "uniref30", "v2020_06"
        )

        self.mgnify_database_path = os.path.join(
            alphafold3_database_path, "mgnify", "mgnify", "mgy_clusters.fa"
        )
        self.rfam_database_path = os.path.join(
            alphafold3_database_path, "rfam", "v14.9", "Rfam_af3_clustered_rep_seq.fasta"
        )
        self.rnacentral_database_path = os.path.join(
            alphafold3_database_path, "rnacentral", "v21.0", "rnacentral_db_rep_seq.fasta"
        )
        self.nt_database_path = os.path.join(
            alphafold3_database_path, "nt", "v2023_02_23", "nt.fasta"
        )

        # Named bundles of AlignmentRunner keyword arguments; every key must
        # match an AlignmentRunner.__init__ parameter name exactly.
        self.runner_args_map = {
            "uniref90": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "uniref90_database_path": self.uniref90_database_path,
            },
            "bfd_uniclust30": {
                "hhblits_binary_path": self.hhblits_binary_path,
                "bfd_database_path": self.bfd_database_path,
                "uniclust30_database_path": self.uniclust30_database_path
            },
            "bfd_uniref30": {
                "hhblits_binary_path": self.hhblits_binary_path,
                "bfd_database_path": self.bfd_database_path,
                # Bug fix: was "uniref_30_database_path", which is not an
                # AlignmentRunner parameter and raised TypeError when used.
                "uniref30_database_path": self.uniref_30_database_path
            },
            "mgnify": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "mgnify_database_path": self.mgnify_database_path,
            },
            "uniprot": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "uniprot_database_path": self.uniprot_database_path,
            },
            ###################### RNA ########################
            "rfam": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "rfam_database_path": self.rfam_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
            "rnacentral": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "rnacentral_database_path": self.rnacentral_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
            "nt": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "nt_database_path": self.nt_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
            ###################################################
            "alphafold2": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "hhblits_binary_path": self.hhblits_binary_path,
                "uniref90_database_path": self.uniref90_database_path,
                "bfd_database_path": self.bfd_database_path,
                "uniclust30_database_path": self.uniclust30_database_path,
                "mgnify_database_path": self.mgnify_database_path,
            },
            "alphafold2_multimer": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "hhblits_binary_path": self.hhblits_binary_path,
                "uniref90_database_path": self.uniref90_database_path,
                "bfd_database_path": self.bfd_database_path,
                # Bug fix: was "uniref_30_database_path" (see "bfd_uniref30").
                "uniref30_database_path": self.uniref_30_database_path,
                "mgnify_database_path": self.mgnify_database_path,
                "uniprot_database_path": self.uniprot_database_path,
            },
            "alphafold3": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "hhblits_binary_path": self.hhblits_binary_path,
                "template_searcher": self.template_searcher,
                "template_featurizer": self.template_featurizer,
                "uniref90_database_path": self.uniref90_database_path,
                "bfd_database_path": self.bfd_database_path,
                "uniclust30_database_path": self.uniclust30_database_path,
                "mgnify_database_path": self.mgnify_database_path,
                "uniprot_database_path": self.uniprot_database_path,
            },
            "rna": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "rfam_database_path": self.rfam_database_path,
                "rnacentral_database_path": self.rnacentral_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
        }

    def _parse_io_tuples(self, input_fasta_path, output_dir, convert_md5=True, prefix="protein"):
        """Expands the input into (fasta_path, per-sequence msas dir) pairs.

        ``input_fasta_path`` may be a list of fasta paths, a directory of
        fasta files, or a single fasta file. Output directories are named by
        md5("{prefix}:{sequence}") when ``convert_md5`` is set, otherwise by
        the fasta file stem.

        Raises:
            ValueError: if the input is neither a list, directory, nor file.
        """
        os.makedirs(output_dir, exist_ok=True)
        if isinstance(input_fasta_path, list):
            input_fasta_paths = input_fasta_path
        elif os.path.isdir(input_fasta_path):
            input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
        elif os.path.isfile(input_fasta_path):
            input_fasta_paths = [input_fasta_path]
        else:
            # Bug fix: the original built this Exception without raising it,
            # silently producing an empty task list.
            raise ValueError("Can't parse input fasta path!")
        seqs = [parse_fasta(load_txt(i))[0][0] for i in input_fasta_paths]
        if convert_md5:
            output_msas_dirs = [
                os.path.join(output_dir, convert_md5_string(f"{prefix}:{seq}"))
                for seq in seqs
            ]
        else:
            output_msas_dirs = [
                os.path.join(output_dir, os.path.split(p)[1].split(".")[0])
                for p in input_fasta_paths
            ]
        return [(i, o) for i, o in zip(input_fasta_paths, output_msas_dirs)]

    def _process_iotuple(self, io_tuple, msas_type, use_precompute=True):
        """Runs one alignment task; failures are logged, never propagated."""
        i, o = io_tuple
        alignment_runner = AlignmentRunner(
            **self.runner_args_map[msas_type],
            no_cpus=self.n_cpus
        )
        try:
            alignment_runner.run(i, o, use_precompute=use_precompute)
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate; include the traceback for debuggability.
            logging.warning(f"{i}:{o} task failed!", exc_info=True)

    def process(self, input_fasta_path, output_dir, msas_type="rfam", convert_md5=True, use_precompute=True):
        """Runs alignment tasks for all inputs in a worker pool."""
        # RNA database bundles key their md5 cache with the "rna" prefix.
        prefix = "rna" if msas_type in ["rfam", "rnacentral", "nt", "rna"] else "protein"
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=convert_md5, prefix=prefix)
        run_pool_tasks(partial(self._process_iotuple, msas_type=msas_type, use_precompute=use_precompute), io_tuples,
                       num_workers=self.n_workers,
                       return_dict=False)

    def convert_output_to_md5(self, input_fasta_path, output_dir, md5_output_dir, prefix="protein"):
        """Copies per-sequence output dirs into md5-named dirs under ``md5_output_dir``."""
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=False, prefix=prefix)
        io_tuples_md5 = self._parse_io_tuples(input_fasta_path, md5_output_dir, convert_md5=True, prefix=prefix)

        for io0, io1 in tqdm.tqdm(zip(io_tuples, io_tuples_md5)):
            o, o_md5 = io0[1], io1[1]
            # NOTE(review): shell copy breaks on paths with spaces/metacharacters;
            # consider shutil.copytree if the "cp -r into existing dir" semantics
            # are not required here.
            os.system(f"cp -r {os.path.abspath(o)} {os.path.abspath(o_md5)}")
598
+
599
+
600
def run_homo_search(
        out_dir,
        save_dir,
        feature_dir,
        pdb_70_dir,
        template_mmcif_dir,
        max_template_date="2021-09-30",
        obsolete_pdbs_path=None,
        use_precompute=True,
):
    """Runs the full alphafold3-style homology pipeline for every fasta in ``save_dir``.

    Stages (each is best-effort; failures are printed/ignored so later stages
    still run):
      1. MSA + template search into ``{out_dir}/features/msas``.
      2. Conversion of raw MSA outputs to msa / uniprot-msa feature pickles.
      3. Conversion of template hits to template feature pickles.

    Args:
        out_dir: Root output directory; features land under ``{out_dir}/features``.
        save_dir: Directory containing the input fasta files, one per sequence.
        feature_dir: alphafold3 database root passed to DataProcessor.
        pdb_70_dir: pdb70 database prefix for HHSearch.
        template_mmcif_dir: Directory of template mmCIF files.
        max_template_date: Latest allowed template release date.
        obsolete_pdbs_path: Optional obsolete-PDB mapping file.
        use_precompute: Reuse existing on-disk search outputs when present.
    """
    # save_dir = os.path.join(out_dir,"cache")
    # NOTE(review): tool binaries are hard-coded to /usr/bin/* — confirm they
    # exist on the deployment host (no configurability here).
    data_processor = DataProcessor(
        alphafold3_database_path=feature_dir,
        # nhmmer_binary_path="/usr/bin/nhmmer",
        jackhmmer_binary_path="/usr/bin/jackhmmer",
        hhblits_binary_path="/usr/bin/hhblits",
        hhsearch_binary_path="/usr/bin/hhsearch",
        template_searcher=hhsearch.HHSearch(
            binary_path="/usr/bin/hhsearch",
            databases=[pdb_70_dir]
        ),
        template_featurizer=templates.HhsearchHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=20,
            kalign_binary_path="/usr/bin/kalign",
            release_dates_path=None,
            obsolete_pdbs_path=obsolete_pdbs_path,
        ),
        n_cpus=32,
        n_workers=12
    )

    output_dir = os.path.join(out_dir, "features/msas")
    # output_dir = "/2022133002/data/stfold-data-v5/features/msas"
    # output_dir = "/2022133002/data/benchmark/stfold/dta/features/msas"
    # output_dir = "/2022133002/data/benchmark/features/msas"

    os.makedirs(output_dir, exist_ok=True)
    files = os.listdir(save_dir)

    files = [os.path.join(save_dir, file) for file in files]
    # files = chunk_lists(files,num_workers=4)[3]

    # Stage 1: raw MSA + template search (best-effort).
    try:
        data_processor.process(
            input_fasta_path=files,
            output_dir=output_dir,
            msas_type="alphafold3",
            convert_md5=True,
            use_precompute=use_precompute
        )
        print(f"save msa to {output_dir}")
    except Exception as e:
        print(e)
        pass

    # Stage 2a: convert raw MSAs to msa feature pickles (best-effort).
    # msa_dir = "/2022133002/data/stfold-data-v5/features/msa_features"
    msa_dir = os.path.join(out_dir, "features/msa_features")
    os.makedirs(msa_dir, exist_ok=True)
    from PhysDock.data.tools.dataset_manager import DatasetManager
    from PhysDock.data.tools.convert_unifold_template_to_stfold import \
        convert_unifold_template_feature_to_stfold_unifold_feature

    try:
        out = DatasetManager.convert_msas_out_to_msa_features(
            input_fasta_path=save_dir,
            output_dir=output_dir,
            msa_feature_dir=msa_dir,
            convert_md5=True,
            num_workers=2
        )
        print(f"save msa feature to {msa_dir}")
    except:
        # NOTE(review): bare except silently swallows all errors here,
        # including KeyboardInterrupt — consider `except Exception`.
        pass

    # Stage 2b: convert uniprot MSAs to pairing feature pickles (best-effort).
    try:
        msa_dir_uni = os.path.join(out_dir, "features/uniprot_msa_features")
        # msa_dir_uni = "/2022133002/data/stfold-data-v5/features/uniprot_msa_features"
        os.makedirs(msa_dir_uni, exist_ok=True)
        out = DatasetManager.convert_msas_out_to_uniprot_msa_features(
            input_fasta_path=save_dir,
            output_dir=output_dir,
            uniprot_msa_feature_dir=msa_dir_uni,
            convert_md5=True,
            num_workers=2
        )
        print(f"save uni msa feature to {msa_dir_uni}")
    except Exception as e:
        print(e)
        pass

    # Stage 3: convert template hits to template feature pickles (best-effort).
    templ_dir_uni = os.path.join(out_dir, "features/template_features")
    # templ_dir_uni = "/2022133002/data/stfold-data-v5/features/template_features"
    os.makedirs(templ_dir_uni, exist_ok=True)
    try:
        files = os.listdir(save_dir)
        # NOTE(review): `files` names come from save_dir but are joined with
        # out_dir here — confirm this is intentional and not a save_dir typo.
        files = [os.path.join(out_dir, file) for file in files[::-1]]
        run_pool_tasks(convert_unifold_template_feature_to_stfold_unifold_feature, files, num_workers=16)
    except:
        pass
701
+
702
+
703
+ class STDockAlignmentRunner():
704
+ def __init__(
705
+ self,
706
+ # Homo Search Tools Path
707
+ jackhmmer_binary_path: Optional[str] = None,
708
+ hhblits_binary_path: Optional[str] = None,
709
+ kalign_binary_path: Optional[str] = None,
710
+ hhsearch_binary_path: Optional[str] = None,
711
+
712
+ # Databases
713
+ uniref90_database_path: Optional[str] = None,
714
+ uniprot_database_path: Optional[str] = None,
715
+ uniclust30_database_path: Optional[str] = None,
716
+ bfd_database_path: Optional[str] = None,
717
+ mgnify_database_path: Optional[str] = None,
718
+ pdb_70_database_path: Optional[str] = None,
719
+ mmcif_files_path: Optional[str] = None,
720
+ obsolete_pdbs_path: Optional[str] = None,
721
+
722
+ # Settings
723
+ max_template_date: str = "2021-09-30",
724
+ max_template_hits: int = 20,
725
+ #
726
+ no_cpus: int = 8,
727
+ # Limitations
728
+ uniref90_seq_limit: int = 100000,
729
+ uniprot_seq_limit: int = 500000,
730
+ mgnify_seq_limit: int = 50000,
731
+ uniref90_max_hits: int = 10000,
732
+ uniprot_max_hits: int = 50000,
733
+ mgnify_max_hits: int = 5000,
734
+ ):
735
+ super().__init__()
736
+ #
737
+ self.jackhmmer_binary_path = jackhmmer_binary_path
738
+ self.hhblits_binary_path = hhblits_binary_path
739
+
740
+ self.uniref90_database_path = uniref90_database_path
741
+ self.mgnify_database_path = mgnify_database_path
742
+ self.uniclust30_database_path = uniclust30_database_path
743
+ self.bfd_database_path = bfd_database_path
744
+ self.uniprot_database_path = uniprot_database_path
745
+
746
+ self.template_searcher = hhsearch.HHSearch(
747
+ binary_path=hhsearch_binary_path,
748
+ databases=[pdb_70_database_path]
749
+ )
750
+ self.template_featurizer = templates.HhsearchHitFeaturizer(
751
+ mmcif_dir=mmcif_files_path,
752
+ max_template_date=max_template_date,
753
+ max_hits=max_template_hits,
754
+ kalign_binary_path=kalign_binary_path,
755
+ obsolete_pdbs_path=obsolete_pdbs_path,
756
+ )
757
+
758
+ def _all_exists(*objs, hhblits_mode=False):
759
+ if not hhblits_mode:
760
+ for obj in objs:
761
+ if obj is None or not os.path.exists(obj):
762
+ return False
763
+ else:
764
+ for obj in objs:
765
+ if obj is None or not os.path.exists(os.path.split(obj)[0]):
766
+ return False
767
+ return True
768
+
769
+ def _run_msa_tool(
770
+ fasta_path: str,
771
+ msa_out_path: str,
772
+ msa_runner,
773
+ msa_format: str,
774
+ max_sto_sequences: Optional[int] = None,
775
+ ) -> Mapping[str, Any]:
776
+ """Runs an MSA tool, checking if output already exists first."""
777
+ if (msa_format == "sto" and max_sto_sequences is not None):
778
+ result = msa_runner.query(fasta_path, max_sto_sequences)[0]
779
+ else:
780
+ result = msa_runner.query(fasta_path)[0]
781
+
782
+ assert msa_out_path.split('.')[-1] == msa_format
783
+ with open(msa_out_path, "w") as f:
784
+ f.write(result[msa_format])
785
+
786
+ return result
787
+
788
+ # Jackhmmer
789
+ if _all_exists(jackhmmer_binary_path, uniref90_database_path):
790
+ self.uniref90_jackhmmer_runner = partial(
791
+ _run_msa_tool,
792
+ msa_runner=jackhmmer.Jackhmmer(
793
+ binary_path=jackhmmer_binary_path,
794
+ database_path=uniref90_database_path,
795
+ seq_limit=uniref90_seq_limit,
796
+ n_cpu=no_cpus,
797
+ ),
798
+ msa_format="sto",
799
+ max_sto_sequences=uniref90_max_hits
800
+ )
801
+
802
+ if _all_exists(jackhmmer_binary_path, uniprot_database_path):
803
+ self.uniprot_jackhmmer_runner = partial(
804
+ _run_msa_tool,
805
+ msa_runner=jackhmmer.Jackhmmer(
806
+ binary_path=jackhmmer_binary_path,
807
+ database_path=uniprot_database_path,
808
+ seq_limit=uniprot_seq_limit,
809
+ n_cpu=no_cpus,
810
+ ),
811
+ msa_format="sto",
812
+ max_sto_sequences=uniprot_max_hits
813
+ )
814
+
815
+ if _all_exists(jackhmmer_binary_path, mgnify_database_path):
816
+ self.mgnify_jackhmmer_runner = partial(
817
+ _run_msa_tool,
818
+ msa_runner=jackhmmer.Jackhmmer(
819
+ binary_path=jackhmmer_binary_path,
820
+ database_path=mgnify_database_path,
821
+ seq_limit=mgnify_seq_limit,
822
+ n_cpu=no_cpus,
823
+ ),
824
+ msa_format="sto",
825
+ max_sto_sequences=mgnify_max_hits
826
+ )
827
+
828
+ # HHblits
829
+ if _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
830
+ self.bfd_uniclust30_hhblits_runner = partial(
831
+ _run_msa_tool,
832
+ msa_runner=hhblits.HHBlits(
833
+ binary_path=hhblits_binary_path,
834
+ databases=[bfd_database_path, uniclust30_database_path],
835
+ n_cpu=no_cpus,
836
+ ),
837
+ msa_format="a3m",
838
+ )
839
+
840
    def run_protein_msas(
        self,
        input_fasta_path,
        output_msas_dir,
        use_precompute=True,
        copy_to_dataset=False,
        dataset_path=None,
    ):
        """Run all configured homology searches for a single protein fasta.

        Raw alignments are written into ``output_msas_dir`` and packed
        feature files are expected under ``output_msas_dir/features/...``.
        A search is skipped when its runner was not configured, when a
        cached raw output exists (and ``use_precompute`` is True), or when
        the corresponding packed feature file is already present.

        Args:
            input_fasta_path: Path to a fasta file; only the first sequence
                is used for the md5 cache key.
            output_msas_dir: Directory receiving raw hits and feature files.
            use_precompute: If True, reuse existing raw alignment outputs.
            copy_to_dataset: NOTE(review): accepted but never read in this
                method — confirm whether it was meant to gate the
                dataset-copy step below.
            dataset_path: Optional dataset root; pre-computed feature files
                found there are copied into ``output_msas_dir`` first, which
                then causes the matching searches to be skipped.
        """
        # Feature sub-directories mirror the dataset layout so files can be
        # copied back and forth 1:1.
        os.makedirs(output_msas_dir, exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features"), exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features", "msa_features"), exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features", "uniprot_msa_features"), exist_ok=True)
        os.makedirs(os.path.join(output_msas_dir, "features", "template_features"), exist_ok=True)

        templates_out_path = os.path.join(output_msas_dir, "templates")
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, f"bfd_uniclust30_hits.a3m")

        # Cache key: md5 over "protein:<first sequence>", so identical
        # sequences share feature files regardless of fasta filename.
        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")

        msa_md5_save_path = os.path.join(output_msas_dir, "features", "msa_features", f"{md5}.pkl.gz")
        uniprot_msa_md5_save_path = os.path.join(output_msas_dir, "features", "uniprot_msa_features", f"{md5}.pkl.gz")
        template_md5_save_path = os.path.join(output_msas_dir, "features", "template_features", f"{md5}.pkl.gz")

        # If a dataset already holds features for this sequence, copy them in
        # so the expensive searches below are skipped.
        if dataset_path is not None:
            dataset_msa_md5_save_path = os.path.join(
                dataset_path, "features", "msa_features", f"{md5}.pkl.gz")
            dataset_uniprot_msa_md5_save_path = os.path.join(
                dataset_path, "features", "uniprot_msa_features", f"{md5}.pkl.gz")
            dataset_template_md5_save_path = os.path.join(
                dataset_path, "features", "template_features", f"{md5}.pkl.gz")

            if os.path.exists(dataset_msa_md5_save_path):
                shutil.copyfile(dataset_msa_md5_save_path, msa_md5_save_path)
            if os.path.exists(dataset_uniprot_msa_md5_save_path):
                shutil.copyfile(dataset_uniprot_msa_md5_save_path, uniprot_msa_md5_save_path)
            if os.path.exists(dataset_template_md5_save_path):
                shutil.copyfile(dataset_template_md5_save_path, template_md5_save_path)

        # Uniref90 search + template search. The uniref90 MSA both feeds the
        # template searcher and serves as a cached raw alignment.
        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(template_md5_save_path):
            if not os.path.exists(uniref90_out_path) or not use_precompute or not os.path.exists(templates_out_path):
                if not os.path.exists(uniref90_out_path):
                    self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)
                if templates_out_path is not None:
                    # Template search is best-effort: any failure is logged
                    # and the remaining MSA searches still run.
                    try:
                        os.makedirs(templates_out_path, exist_ok=True)
                        seq, dec = parsers.parse_fasta(load_txt(input_fasta_path))
                        input_sequence = seq[0]
                        # msa_for_templates = jackhmmer_uniref90_result["sto"]
                        # Trim/dedup the Stockholm MSA before handing it to
                        # the template searcher.
                        msa_for_templates = parsers.truncate_stockholm_msa(
                            uniref90_out_path, max_sequences=10000
                        )
                        msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
                        msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
                            msa_for_templates
                        )
                        # The searcher declares which MSA format it accepts.
                        if self.template_searcher.input_format == "sto":
                            pdb_templates_result = self.template_searcher.query(msa_for_templates)
                        elif self.template_searcher.input_format == "a3m":
                            uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
                            pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
                        else:
                            raise ValueError(
                                "Unrecognized template input format: "
                                f"{self.template_searcher.input_format}"
                            )

                        # NOTE(review): the featurized templates (not the raw
                        # hits) are dumped to this "pdb_hits.*.pkl.gz" path;
                        # the raw searcher output goes to the plain-text file
                        # written just below — confirm the naming is intended.
                        pdb_hits_out_path = os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}.pkl.gz"
                        )
                        with open(os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}"
                        ), "w") as f:
                            f.write(pdb_templates_result)

                        pdb_template_hits = self.template_searcher.get_template_hits(
                            output_string=pdb_templates_result, input_sequence=input_sequence
                        )
                        templates_result = self.template_featurizer.get_templates(
                            query_sequence=input_sequence, hits=pdb_template_hits
                        )
                        dump_pkl(templates_result.features, pdb_hits_out_path, compress=True)
                    except Exception as e:
                        logging.exception("An error in template searching")

        # Remaining MSA searches, each skipped when its packed feature file
        # already exists or a cached raw output may be reused.
        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(uniprot_msa_md5_save_path):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(msa_md5_save_path):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(msa_md5_save_path):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
PhysDock/data/alignment_runner_v2.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os.path
3
+ import shutil
4
+ from functools import partial
5
+ import tqdm
6
+ from typing import Optional, Mapping, Any, Union
7
+
8
+ from PhysDock.data.tools import jackhmmer, nhmmer, hhblits, kalign, hmmalign, parsers, hmmbuild, hhsearch, templates
9
+ from PhysDock.utils.io_utils import load_pkl, load_txt, load_json, run_pool_tasks, convert_md5_string, dump_pkl
10
+ from PhysDock.data.tools.parsers import parse_fasta
11
+ from PhysDock.data.tools.dataset_manager import DatasetManager
12
+
13
+ TemplateSearcher = Union[hhsearch.HHSearch]
14
+
15
+
16
class AlignmentRunner:
    """Configures and runs the jackhmmer/HHblits homology searches.

    Each search tool is wired up only when its binary and database paths
    exist on disk; otherwise the corresponding ``*_runner`` attribute stays
    ``None`` and that search is silently skipped at :meth:`run` time.
    """

    def __init__(
        self,
        # Databases
        uniref90_database_path: Optional[str] = None,
        uniprot_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        bfd_database_path: Optional[str] = None,
        mgnify_database_path: Optional[str] = None,

        # Homo Search Tools
        jackhmmer_binary_path: str = "/usr/bin/jackhmmer",
        hhblits_binary_path: str = "/usr/bin/hhblits",

        # Params
        no_cpus: int = 8,

        # Thresholds
        uniref90_seq_limit: int = 100000,
        uniprot_seq_limit: int = 500000,
        mgnify_seq_limit: int = 50000,
        uniref90_max_hits: int = 10000,
        uniprot_max_hits: int = 50000,
        mgnify_max_hits: int = 5000,
    ):
        # Runners default to None and are only built when the matching
        # binary + database files are found below.
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        # NOTE(review): never assigned anywhere else in this class; kept for
        # attribute compatibility with code that may introspect it.
        self.bfd_uniref30_hhblits_runner = None
        self.bfd_uniclust30_hhblits_runner = None

        def _all_exists(*objs, hhblits_mode=False):
            """True iff every path is non-None and present on disk.

            HHblits databases are path *prefixes* (e.g.
            ``.../uniclust30_2018_08``), so in ``hhblits_mode`` only the
            containing directory is checked.
            """
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
            fasta_path: str,
            msa_out_path: str,
            msa_runner,
            msa_format: str,
            max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Run an MSA tool and write its output to ``msa_out_path``."""
            if msa_format == "sto" and max_sto_sequences is not None:
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            # Fixed: was a bare ``assert``, which is stripped under
            # ``python -O``; validate the output extension explicitly.
            if msa_out_path.split('.')[-1] != msa_format:
                raise ValueError(
                    f"MSA output path {msa_out_path!r} does not end with "
                    f"expected format {msa_format!r}"
                )
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        # Jackhmmer
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits
        if _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

    def run(self, input_fasta_path, output_msas_dir, use_precompute=True):
        """Run every configured search for one fasta file.

        A search is skipped when its runner is ``None``, when the packed
        feature file for the sequence's md5 already exists two directories
        above ``output_msas_dir``, or when the raw alignment output exists
        and ``use_precompute`` is True.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, f"bfd_uniclust30_hits.a3m")

        # Cache key: md5 over "protein:<first sequence>".
        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        # Feature files live two levels above the per-sequence msas dir.
        output_feature = os.path.dirname(output_msas_dir)
        output_feature = os.path.dirname(output_feature)

        pkl_save_path_msa = os.path.join(output_feature, "msa_features", f"{md5}.pkl.gz")
        pkl_save_path_msa_uni = os.path.join(output_feature, "uniprot_msa_features", f"{md5}.pkl.gz")

        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(uniref90_out_path) or not use_precompute:
                self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)

        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa_uni):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
158
+
159
+
160
class DataProcessor:
    """Batch driver that runs :class:`AlignmentRunner` over many fasta files."""

    def __init__(
        self,
        bfd_database_path,
        uniclust30_database_path,
        uniref90_database_path,
        mgnify_database_path,
        uniprot_database_path,
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,

        n_cpus: int = 8,
        n_workers: int = 1,
    ):
        """Store tool/database paths; nothing is validated until processing."""
        self.jackhmmer_binary_path = jackhmmer_binary_path
        self.hhblits_binary_path = hhblits_binary_path

        # n_cpus is passed to each search tool; n_workers is the number of
        # parallel alignment tasks.
        self.n_cpus = n_cpus
        self.n_workers = n_workers

        self.uniref90_database_path = uniref90_database_path
        self.uniprot_database_path = uniprot_database_path
        self.bfd_database_path = bfd_database_path
        self.uniclust30_database_path = uniclust30_database_path
        self.mgnify_database_path = mgnify_database_path

    def _parse_io_tuples(self, input_fasta_path, output_dir, convert_md5=True, prefix="protein"):
        """Return ``(input_fasta, output_msas_dir)`` pairs for every input.

        ``input_fasta_path`` may be a list of fasta files, a directory of
        fasta files, or a single fasta file.

        Raises:
            ValueError: if ``input_fasta_path`` is none of the above.
        """
        os.makedirs(output_dir, exist_ok=True)
        if isinstance(input_fasta_path, list):
            input_fasta_paths = input_fasta_path
        elif os.path.isdir(input_fasta_path):
            input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
        elif os.path.isfile(input_fasta_path):
            input_fasta_paths = [input_fasta_path]
        else:
            # Bug fix: the original constructed this exception without
            # raising it, silently returning an empty task list instead.
            raise ValueError("Can't parse input fasta path!")
        seqs = [parse_fasta(load_txt(i))[0][0] for i in input_fasta_paths]
        if convert_md5:
            # Output folders keyed by md5("protein:<sequence>") so identical
            # sequences share one result directory.
            output_msas_dirs = [
                os.path.join(output_dir, convert_md5_string(f"{prefix}:{i}")) for i in seqs
            ]
        else:
            # Otherwise key by the fasta filename stem.
            output_msas_dirs = [
                os.path.join(output_dir, os.path.split(i)[1].split(".")[0]) for i in input_fasta_paths
            ]
        io_tuples = [(i, o) for i, o in zip(input_fasta_paths, output_msas_dirs)]
        return io_tuples

    def _process_iotuple(self, io_tuple, use_precompute=True):
        """Run one (fasta, out_dir) alignment task; failures are logged, not raised."""
        i, o = io_tuple
        kwargs = {
            "jackhmmer_binary_path": self.jackhmmer_binary_path,
            "hhblits_binary_path": self.hhblits_binary_path,
            "uniref90_database_path": self.uniref90_database_path,
            "bfd_database_path": self.bfd_database_path,
            "uniclust30_database_path": self.uniclust30_database_path,
            "mgnify_database_path": self.mgnify_database_path,
            "uniprot_database_path": self.uniprot_database_path,
        }
        alignment_runner = AlignmentRunner(
            **kwargs,
            no_cpus=self.n_cpus
        )
        try:
            alignment_runner.run(i, o, use_precompute=use_precompute)
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate; logging.exception records the traceback.
            logging.exception(f"{i}:{o} task failed!")

    def process(self, input_fasta_path, output_dir, convert_md5=True, use_precompute=True):
        """Align every input fasta, fanning tasks out over ``n_workers``."""
        prefix = "protein"
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=convert_md5, prefix=prefix)
        run_pool_tasks(partial(self._process_iotuple, use_precompute=use_precompute), io_tuples,
                       num_workers=self.n_workers,
                       return_dict=False)

    def convert_output_to_md5(self, input_fasta_path, output_dir, md5_output_dir, prefix="protein"):
        """Copy filename-keyed result directories to md5-keyed ones."""
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=False, prefix=prefix)
        io_tuples_md5 = self._parse_io_tuples(input_fasta_path, md5_output_dir, convert_md5=True, prefix=prefix)

        for io0, io1 in tqdm.tqdm(zip(io_tuples, io_tuples_md5)):
            o, o_md5 = io0[1], io1[1]
            # NOTE(review): shell copy with unquoted interpolated paths —
            # breaks on paths containing spaces/shell metacharacters;
            # consider shutil.copytree.
            os.system(f"cp -r {os.path.abspath(o)} {os.path.abspath(o_md5)}")
261
+
262
+
263
def run_homo_search(
    bfd_database_path,
    uniclust30_database_path,
    uniref90_database_path,
    mgnify_database_path,
    uniprot_database_path,
    jackhmmer_binary_path,
    hhblits_binary_path,

    input_fasta_path,
    out_dir,

    n_cpus=16,
    n_workers=1,
):
    """Run the full homology-search pipeline for one or more fasta files.

    Raw alignments are written to ``<out_dir>/msas`` (one md5-keyed
    sub-directory per sequence), then packed into feature files under
    ``<out_dir>/msa_features`` and ``<out_dir>/uniprot_msa_features``.

    Args:
        *_database_path: Search database locations handed to DataProcessor.
        jackhmmer_binary_path / hhblits_binary_path: Search tool binaries.
        input_fasta_path: A single fasta file or a directory of fasta files.
        out_dir: Root output directory.
        n_cpus: CPUs per search tool invocation.
        n_workers: Number of parallel alignment tasks.
    """
    data_processor = DataProcessor(
        bfd_database_path,
        uniclust30_database_path,
        uniref90_database_path,
        mgnify_database_path,
        uniprot_database_path,
        jackhmmer_binary_path=jackhmmer_binary_path,
        hhblits_binary_path=hhblits_binary_path,
        n_cpus=n_cpus,
        n_workers=n_workers
    )

    output_dir = os.path.join(out_dir, "msas")
    os.makedirs(output_dir, exist_ok=True)
    if os.path.isfile(input_fasta_path):
        files = [input_fasta_path]
    else:
        # Directory input: entries are processed in reverse listing order
        # (kept from the original implementation).
        files = [os.path.join(input_fasta_path, file) for file in os.listdir(input_fasta_path)[::-1]]

    data_processor.process(
        input_fasta_path=files,
        output_dir=output_dir,
        convert_md5=True
    )
    print(f"save msa to {output_dir}")

    msa_dir = os.path.join(out_dir, "msa_features")
    os.makedirs(msa_dir, exist_ok=True)

    # Fixed: return values were previously bound to an unused ``out`` local;
    # the calls are made for their file-writing side effects.
    DatasetManager.convert_msas_out_to_msa_features(
        input_fasta_path=input_fasta_path,
        output_dir=output_dir,
        msa_feature_dir=msa_dir,
        convert_md5=True,
        num_workers=2
    )
    print(f"save msa feature to {msa_dir}")

    msa_dir_uni = os.path.join(out_dir, "uniprot_msa_features")
    os.makedirs(msa_dir_uni, exist_ok=True)
    DatasetManager.convert_msas_out_to_uniprot_msa_features(
        input_fasta_path=input_fasta_path,
        output_dir=output_dir,
        uniprot_msa_feature_dir=msa_dir_uni,
        convert_md5=True,
        num_workers=2
    )
    print(f"save uni msa feature to {msa_dir_uni}")
PhysDock/data/constants/PDBData.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2000 Andrew Dalke. All rights reserved.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.
"""Information about the IUPAC alphabets."""

# One-letter codes of the 20 standard amino acids.
protein_letters = "ACDEFGHIKLMNPQRSTVWY"
# Standard letters plus the ambiguity/rare codes documented below.
extended_protein_letters = "ACDEFGHIKLMNPQRSTVWYBXZJUO"

# B = "Asx"; aspartic acid or asparagine (D or N)
# X = "Xxx"; unknown or 'other' amino acid
# Z = "Glx"; glutamic acid or glutamine (E or Q)
# http://www.chem.qmul.ac.uk/iupac/AminoAcid/A2021.html#AA212
#
# J = "Xle"; leucine or isoleucine (L or I, used in NMR)
# Mentioned in http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
# Also the International Nucleotide Sequence Database Collaboration (INSDC)
# (i.e. GenBank, EMBL, DDBJ) adopted this in 2006
# http://www.ddbj.nig.ac.jp/insdc/icm2006-e.html
#
# Xle (J); Leucine or Isoleucine
# The residue abbreviations, Xle (the three-letter abbreviation) and J
# (the one-letter abbreviation) are reserved for the case that cannot
# experimentally distinguish leucine from isoleucine.
#
# U = "Sec"; selenocysteine
# http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
#
# O = "Pyl"; pyrrolysine
# http://www.chem.qmul.ac.uk/iubmb/newsletter/2009.html#item35

# One-letter -> mixed-case three-letter code; normalized to upper case
# immediately below.
protein_letters_1to3 = {
    "A": "Ala",
    "C": "Cys",
    "D": "Asp",
    "E": "Glu",
    "F": "Phe",
    "G": "Gly",
    "H": "His",
    "I": "Ile",
    "K": "Lys",
    "L": "Leu",
    "M": "Met",
    "N": "Asn",
    "P": "Pro",
    "Q": "Gln",
    "R": "Arg",
    "S": "Ser",
    "T": "Thr",
    "V": "Val",
    "W": "Trp",
    "Y": "Tyr",
}
# Normalize to upper-case three-letter codes (e.g. "Ala" -> "ALA").
protein_letters_1to3 = {k.upper(): v.upper() for k, v in protein_letters_1to3.items()}
# Inverse mapping: upper-case three-letter code -> one-letter code.
protein_letters_3to1 = {v: k for k, v in protein_letters_1to3.items()}
58
+
59
# Extended map from three-letter PDB residue codes — including modified and
# non-standard amino acids — to the one-letter code of the closest standard
# parent residue. Note "MA " is a space-padded three-character key.
protein_letters_3to1_extended = {
    "A5N": "N", "A8E": "V", "A9D": "S", "AA3": "A", "AA4": "A", "AAR": "R",
    "ABA": "A", "ACL": "R", "AEA": "C", "AEI": "D", "AFA": "N", "AGM": "R",
    "AGQ": "Y", "AGT": "C", "AHB": "N", "AHL": "R", "AHO": "A", "AHP": "A",
    "AIB": "A", "AKL": "D", "AKZ": "D", "ALA": "A", "ALC": "A", "ALM": "A",
    "ALN": "A", "ALO": "T", "ALS": "A", "ALT": "A", "ALV": "A", "ALY": "K",
    "AME": "M", "AN6": "L", "AN8": "A", "API": "K", "APK": "K", "AR2": "R",
    "AR4": "E", "AR7": "R", "ARG": "R", "ARM": "R", "ARO": "R", "AS7": "N",
    "ASA": "D", "ASB": "D", "ASI": "D", "ASK": "D", "ASL": "D", "ASN": "N",
    "ASP": "D", "ASQ": "D", "AYA": "A", "AZH": "A", "AZK": "K", "AZS": "S",
    "AZY": "Y", "AVJ": "H", "A30": "Y", "A3U": "F", "ECC": "Q", "ECX": "C",
    "EFC": "C", "EHP": "F", "ELY": "K", "EME": "E", "EPM": "M", "EPQ": "Q",
    "ESB": "Y", "ESC": "M", "EXY": "L", "EXA": "K", "E0Y": "P", "E9V": "H",
    "E9M": "W", "EJA": "C", "EUP": "T", "EZY": "G", "E9C": "Y", "EW6": "S",
    "EXL": "W", "I2M": "I", "I4G": "G", "I58": "K", "IAM": "A", "IAR": "R",
    "ICY": "C", "IEL": "K", "IGL": "G", "IIL": "I", "ILE": "I", "ILG": "E",
    "ILM": "I", "ILX": "I", "ILY": "K", "IML": "I", "IOR": "R", "IPG": "G",
    "IT1": "K", "IYR": "Y", "IZO": "M", "IC0": "G", "M0H": "C", "M2L": "K",
    "M2S": "M", "M30": "G", "M3L": "K", "M3R": "K", "MA ": "A", "MAA": "A",
    "MAI": "R", "MBQ": "Y", "MC1": "S", "MCL": "K", "MCS": "C", "MD3": "C",
    "MD5": "C", "MD6": "G", "MDF": "Y", "ME0": "M", "MEA": "F", "MEG": "E",
    "MEN": "N", "MEQ": "Q", "MET": "M", "MEU": "G", "MFN": "E", "MGG": "R",
    "MGN": "Q", "MGY": "G", "MH1": "H", "MH6": "S", "MHL": "L", "MHO": "M",
    "MHS": "H", "MHU": "F", "MIR": "S", "MIS": "S", "MK8": "L", "ML3": "K",
    "MLE": "L", "MLL": "L", "MLY": "K", "MLZ": "K", "MME": "M", "MMO": "R",
    "MNL": "L", "MNV": "V", "MP8": "P", "MPQ": "G", "MSA": "G", "MSE": "M",
    "MSL": "M", "MSO": "M", "MT2": "M", "MTY": "Y", "MVA": "V", "MYK": "K",
    "MYN": "R", "QCS": "C", "QIL": "I", "QMM": "Q", "QPA": "C", "QPH": "F",
    "Q3P": "K", "QVA": "C", "QX7": "A", "Q2E": "W", "Q75": "M", "Q78": "F",
    "QM8": "L", "QMB": "A", "QNQ": "C", "QNT": "C", "QNW": "C", "QO2": "C",
    "QO5": "C", "QO8": "C", "QQ8": "Q", "U2X": "Y", "U3X": "F", "UF0": "S",
    "UGY": "G", "UM1": "A", "UM2": "A", "UMA": "A", "UQK": "A", "UX8": "W",
    "UXQ": "F", "YCM": "C", "YOF": "Y", "YPR": "P", "YPZ": "Y", "YTH": "T",
    "Y1V": "L", "Y57": "K", "YHA": "K", "200": "F", "23F": "F", "23P": "A",
    "26B": "T", "28X": "T", "2AG": "A", "2CO": "C", "2FM": "M", "2GX": "F",
    "2HF": "H", "2JG": "S", "2KK": "K", "2KP": "K", "2LT": "Y", "2LU": "L",
    "2ML": "L", "2MR": "R", "2MT": "P", "2OR": "R", "2P0": "P", "2QZ": "T",
    "2R3": "Y", "2RA": "A", "2RX": "S", "2SO": "H", "2TY": "Y", "2VA": "V",
    "2XA": "C", "2ZC": "S", "6CL": "K", "6CW": "W", "6GL": "A", "6HN": "K",
    "60F": "C", "66D": "I", "6CV": "A", "6M6": "C", "6V1": "C", "6WK": "C",
    "6Y9": "P", "6DN": "K", "DA2": "R", "DAB": "A", "DAH": "F", "DBS": "S",
    "DBU": "T", "DBY": "Y", "DBZ": "A", "DC2": "C", "DDE": "H", "DDZ": "A",
    "DI7": "Y", "DHA": "S", "DHN": "V", "DIR": "R", "DLS": "K", "DM0": "K",
    "DMH": "N", "DMK": "D", "DNL": "K", "DNP": "A", "DNS": "K", "DNW": "A",
    "DOH": "D", "DON": "L", "DP1": "R", "DPL": "P", "DPP": "A", "DPQ": "Y",
    "DYS": "C", "D2T": "D", "DYA": "D", "DJD": "F", "DYJ": "P", "DV9": "E",
    "H14": "F", "H1D": "M", "H5M": "P", "HAC": "A", "HAR": "R", "HBN": "H",
    "HCM": "C", "HGY": "G", "HHI": "H", "HIA": "H", "HIC": "H", "HIP": "H",
    "HIQ": "H", "HIS": "H", "HL2": "L", "HLU": "L", "HMR": "R", "HNC": "C",
    "HOX": "F", "HPC": "F", "HPE": "F", "HPH": "F", "HPQ": "F", "HQA": "A",
    "HR7": "R", "HRG": "R", "HRP": "W", "HS8": "H", "HS9": "H", "HSE": "S",
    "HSK": "H", "HSL": "S", "HSO": "H", "HT7": "W", "HTI": "C", "HTR": "W",
    "HV5": "A", "HVA": "V", "HY3": "P", "HYI": "M", "HYP": "P", "HZP": "P",
    "HIX": "A", "HSV": "H", "HLY": "K", "HOO": "H", "H7V": "A", "L5P": "K",
    "LRK": "K", "L3O": "L", "LA2": "K", "LAA": "D", "LAL": "A", "LBY": "K",
    "LCK": "K", "LCX": "K", "LDH": "K", "LE1": "V", "LED": "L", "LEF": "L",
    "LEH": "L", "LEM": "L", "LEN": "L", "LET": "K", "LEU": "L", "LEX": "L",
    "LGY": "K", "LLO": "K", "LLP": "K", "LLY": "K", "LLZ": "K", "LME": "E",
    "LMF": "K", "LMQ": "Q", "LNE": "L", "LNM": "L", "LP6": "K", "LPD": "P",
    "LPG": "G", "LPS": "S", "LSO": "K", "LTR": "W", "LVG": "G", "LVN": "V",
    "LWY": "P", "LYF": "K", "LYK": "K", "LYM": "K", "LYN": "K", "LYO": "K",
    "LYP": "K", "LYR": "K", "LYS": "K", "LYU": "K", "LYX": "K", "LYZ": "K",
    "LAY": "L", "LWI": "F", "LBZ": "K", "P1L": "C", "P2Q": "Y", "P2Y": "P",
    "P3Q": "Y", "PAQ": "Y", "PAS": "D", "PAT": "W", "PBB": "C", "PBF": "F",
    "PCA": "Q", "PCC": "P", "PCS": "F", "PE1": "K", "PEC": "C", "PF5": "F",
    "PFF": "F", "PG1": "S", "PGY": "G", "PHA": "F", "PHD": "D", "PHE": "F",
    "PHI": "F", "PHL": "F", "PHM": "F", "PKR": "P", "PLJ": "P", "PM3": "F",
    "POM": "P", "PPN": "F", "PR3": "C", "PR4": "P", "PR7": "P", "PR9": "P",
    "PRJ": "P", "PRK": "K", "PRO": "P", "PRS": "P", "PRV": "G", "PSA": "F",
    "PSH": "H", "PTH": "Y", "PTM": "Y", "PTR": "Y", "PVH": "H", "PXU": "P",
    "PYA": "A", "PYH": "K", "PYX": "C", "PH6": "P", "P9S": "C", "P5U": "S",
    "POK": "R", "T0I": "Y", "T11": "F", "TAV": "D", "TBG": "V", "TBM": "T",
    "TCQ": "Y", "TCR": "W", "TEF": "F", "TFQ": "F", "TH5": "T", "TH6": "T",
    "THC": "T", "THR": "T", "THZ": "R", "TIH": "A", "TIS": "S", "TLY": "K",
    "TMB": "T", "TMD": "T", "TNB": "C", "TNR": "S", "TNY": "T", "TOQ": "W",
    "TOX": "W", "TPJ": "P", "TPK": "P", "TPL": "W", "TPO": "T", "TPQ": "Y",
    "TQI": "W", "TQQ": "W", "TQZ": "C", "TRF": "W", "TRG": "K", "TRN": "W",
    "TRO": "W", "TRP": "W", "TRQ": "W", "TRW": "W", "TRX": "W", "TRY": "W",
    "TS9": "I", "TSY": "C", "TTQ": "W", "TTS": "Y", "TXY": "Y", "TY1": "Y",
    "TY2": "Y", "TY3": "Y", "TY5": "Y", "TY8": "Y", "TY9": "Y", "TYB": "Y",
    "TYC": "Y", "TYE": "Y", "TYI": "Y", "TYJ": "Y", "TYN": "Y", "TYO": "Y",
    "TYQ": "Y", "TYR": "Y", "TYS": "Y", "TYT": "Y", "TYW": "Y", "TYY": "Y",
    "T8L": "T", "T9E": "T", "TNQ": "W", "TSQ": "F", "TGH": "W", "X2W": "E",
    "XCN": "C", "XPR": "P", "XSN": "N", "XW1": "A", "XX1": "K", "XYC": "A",
    "XA6": "F", "11Q": "P", "11W": "E", "12L": "P", "12X": "P", "12Y": "P",
    "143": "C", "1AC": "A", "1L1": "A", "1OP": "Y", "1PA": "F", "1PI": "A",
    "1TQ": "W", "1TY": "Y", "1X6": "S", "56A": "H", "5AB": "A", "5CS": "C",
    "5CW": "W", "5HP": "E", "5OH": "A", "5PG": "G", "51T": "Y", "54C": "W",
    "5CR": "F", "5CT": "K", "5FQ": "A", "5GM": "I", "5JP": "S", "5T3": "K",
    "5MW": "K", "5OW": "K", "5R5": "S", "5VV": "N", "5XU": "A", "55I": "F",
    "999": "D", "9DN": "N", "9NE": "E", "9NF": "F", "9NR": "R", "9NV": "V",
    "9E7": "K", "9KP": "K", "9WV": "A", "9TR": "K", "9TU": "K", "9TX": "K",
    "9U0": "K", "9IJ": "F", "B1F": "F", "B27": "T", "B2A": "A", "B2F": "F",
    "B2I": "I", "B2V": "V", "B3A": "A", "B3D": "D", "B3E": "E", "B3K": "K",
    "B3U": "H", "B3X": "N", "B3Y": "Y", "BB6": "C", "BB7": "C", "BB8": "F",
    "BB9": "C", "BBC": "C", "BCS": "C", "BCX": "C", "BFD": "D", "BG1": "S",
    "BH2": "D", "BHD": "D", "BIF": "F", "BIU": "I", "BL2": "L", "BLE": "L",
    "BLY": "K", "BMT": "T", "BNN": "F", "BOR": "R", "BP5": "A", "BPE": "C",
    "BSE": "S", "BTA": "L", "BTC": "C", "BTK": "K", "BTR": "W", "BUC": "C",
    "BUG": "V", "BYR": "Y", "BWV": "R", "BWB": "S", "BXT": "S", "F2F": "F",
    "F2Y": "Y", "FAK": "K", "FB5": "A", "FB6": "A", "FC0": "F", "FCL": "F",
    "FDL": "K", "FFM": "C", "FGL": "G", "FGP": "S", "FH7": "K", "FHL": "K",
    "FHO": "K", "FIO": "R", "FLA": "A", "FLE": "L", "FLT": "Y", "FME": "M",
    "FOE": "C", "FP9": "P", "FPK": "P", "FT6": "W", "FTR": "W", "FTY": "Y",
    "FVA": "V", "FZN": "K", "FY3": "Y", "F7W": "W", "FY2": "Y", "FQA": "K",
    "F7Q": "Y", "FF9": "K", "FL6": "D", "JJJ": "C", "JJK": "C", "JJL": "C",
    "JLP": "K", "J3D": "C", "J9Y": "R", "J8W": "S", "JKH": "P", "N10": "S",
    "N7P": "P", "NA8": "A", "NAL": "A", "NAM": "A", "NBQ": "Y", "NC1": "S",
    "NCB": "A", "NEM": "H", "NEP": "H", "NFA": "F", "NIY": "Y", "NLB": "L",
    "NLE": "L", "NLN": "L", "NLO": "L", "NLP": "L", "NLQ": "Q", "NLY": "G",
    "NMC": "G", "NMM": "R", "NNH": "R", "NOT": "L", "NPH": "C", "NPI": "A",
    "NTR": "Y", "NTY": "Y", "NVA": "V", "NWD": "A", "NYB": "C", "NYS": "C",
    "NZH": "H", "N80": "P", "NZC": "T", "NLW": "L", "N0A": "F", "N9P": "A",
    "N65": "K", "R1A": "C", "R4K": "W", "RE0": "W", "RE3": "W", "RGL": "R",
    "RGP": "E", "RT0": "P", "RVX": "S", "RZ4": "S", "RPI": "R", "RVJ": "A",
    "VAD": "V", "VAF": "V", "VAH": "V", "VAI": "V", "VAL": "V", "VB1": "K",
    "VH0": "P", "VR0": "R", "V44": "C", "V61": "F", "VPV": "K", "V5N": "H",
    "V7T": "K", "Z01": "A", "Z3E": "T", "Z70": "H", "ZBZ": "C", "ZCL": "F",
    "ZU0": "T", "ZYJ": "P", "ZYK": "P", "ZZD": "C", "ZZJ": "A", "ZIQ": "W",
    "ZPO": "P", "ZDJ": "Y", "ZT1": "K", "30V": "C", "31Q": "C", "33S": "F",
    "33W": "A", "34E": "V", "3AH": "H", "3BY": "P", "3CF": "F", "3CT": "Y",
    "3GA": "A", "3GL": "E", "3MD": "D", "3MY": "Y", "3NF": "Y", "3O3": "E",
    "3PX": "P", "3QN": "K", "3TT": "P", "3XH": "G", "3YM": "Y", "3WS": "A",
    "3WX": "P", "3X9": "C", "3ZH": "H", "7JA": "I", "73C": "S", "73N": "R",
    "73O": "Y", "73P": "K", "74P": "K", "7N8": "F", "7O5": "A", "7XC": "F",
    "7ID": "D", "7OZ": "A", "C1S": "C", "C1T": "C", "C1X": "K", "C22": "A",
    "C3Y": "C", "C4R": "C", "C5C": "C", "C6C": "C", "CAF": "C", "CAS": "C",
    "CAY": "C", "CCS": "C", "CEA": "C", "CGA": "E", "CGU": "E", "CGV": "C",
    "CHP": "G", "CIR": "R", "CLE": "L", "CLG": "K", "CLH": "K", "CME": "C",
    "CMH": "C", "CML": "C", "CMT": "C", "CR5": "G", "CS0": "C", "CS1": "C",
    "CS3": "C", "CS4": "C", "CSA": "C", "CSB": "C", "CSD": "C", "CSE": "C",
    "CSJ": "C", "CSO": "C", "CSP": "C", "CSR": "C", "CSS": "C", "CSU": "C",
    "CSW": "C", "CSX": "C", "CSZ": "C", "CTE": "W", "CTH": "T", "CWD": "A",
    "CWR": "S", "CXM": "M", "CY0": "C", "CY1": "C", "CY3": "C", "CY4": "C",
    "CYA": "C", "CYD": "C", "CYF": "C", "CYG": "C", "CYJ": "K", "CYM": "C",
    "CYQ": "C", "CYR": "C", "CYS": "C", "CYW": "C", "CZ2": "C", "CZZ": "C",
    "CG6": "C", "C1J": "R", "C4G": "R", "C67": "R", "C6D": "R", "CE7": "N",
    "CZS": "A", "G01": "E", "G8M": "E", "GAU": "E", "GEE": "G", "GFT": "S",
    "GHC": "E", "GHG": "Q", "GHW": "E", "GL3": "G", "GLH": "Q", "GLJ": "E",
    "GLK": "E", "GLN": "Q", "GLQ": "E", "GLU": "E", "GLY": "G", "GLZ": "G",
    "GMA": "E", "GME": "E", "GNC": "Q", "GPL": "K", "GSC": "G", "GSU": "E",
    "GT9": "C", "GVL": "S", "G3M": "R", "G5G": "L", "G1X": "Y", "G8X": "P",
    "K1R": "C", "KBE": "K", "KCX": "K", "KFP": "K", "KGC": "K", "KNB": "A",
    "KOR": "M", "KPI": "K", "KPY": "K", "KST": "K", "KYN": "W", "KYQ": "K",
    "KCR": "K", "KPF": "K", "K5L": "S", "KEO": "K", "KHB": "K", "KKD": "D",
    "K5H": "C", "K7K": "S", "OAR": "R", "OAS": "S", "OBS": "K", "OCS": "C",
    "OCY": "C", "OHI": "H", "OHS": "D", "OLD": "H", "OLT": "T", "OLZ": "S",
    "OMH": "S", "OMT": "M", "OMX": "Y", "OMY": "Y", "ONH": "A", "ORN": "A",
    "ORQ": "R", "OSE": "S", "OTH": "T", "OXX": "D", "OYL": "H", "O7A": "T",
    "O7D": "W", "O7G": "V", "O2E": "S", "O6H": "W", "OZW": "F", "S12": "S",
    "S1H": "S", "S2C": "C", "S2P": "A", "SAC": "S", "SAH": "C", "SAR": "G",
    "SBG": "S", "SBL": "S", "SCH": "C", "SCS": "C", "SCY": "C", "SD4": "N",
    "SDB": "S", "SDP": "S", "SEB": "S", "SEE": "S", "SEG": "A", "SEL": "S",
    "SEM": "S", "SEN": "S", "SEP": "S", "SER": "S", "SET": "S", "SGB": "S",
    "SHC": "C", "SHP": "G", "SHR": "K", "SIB": "C", "SLL": "K", "SLZ": "K",
    "SMC": "C", "SME": "M", "SMF": "F", "SNC": "C", "SNN": "N", "SOY": "S",
    "SRZ": "S", "STY": "Y", "SUN": "S", "SVA": "S", "SVV": "S", "SVW": "S",
    "SVX": "S", "SVY": "S", "SVZ": "S", "SXE": "S", "SKH": "K", "SNM": "S",
    "SNK": "H", "SWW": "S", "WFP": "F", "WLU": "L", "WPA": "F", "WRP": "W",
    "WVL": "V", "02K": "A", "02L": "N", "02O": "A", "02Y": "A", "033": "V",
    "037": "P", "03Y": "C", "04U": "P", "04V": "P", "05N": "P", "07O": "C",
    "0A0": "D", "0A1": "Y", "0A2": "K", "0A8": "C", "0A9": "F", "0AA": "V",
    "0AB": "V", "0AC": "G", "0AF": "W", "0AG": "L", "0AH": "S", "0AK": "D",
    "0AR": "R", "0BN": "F", "0CS": "A", "0E5": "T", "0EA": "Y", "0FL": "A",
    "0LF": "P", "0NC": "A", "0PR": "Y", "0QL": "C", "0TD": "D", "0UO": "W",
    "0WZ": "Y", "0X9": "R", "0Y8": "P", "4AF": "F", "4AR": "R", "4AW": "W",
    "4BF": "F", "4CF": "F", "4CY": "M", "4DP": "W", "4FB": "P", "4FW": "W",
    "4HL": "Y", "4HT": "W", "4IN": "W", "4MM": "M", "4PH": "F", "4U7": "A",
    "41H": "F", "41Q": "N", "42Y": "S", "432": "S", "45F": "P", "4AK": "K",
    "4D4": "R", "4GJ": "C", "4KY": "P", "4L0": "P", "4LZ": "Y", "4N7": "P",
    "4N8": "P", "4N9": "P", "4OG": "W", "4OU": "F", "4OV": "S", "4OZ": "S",
    "4PQ": "W", "4SJ": "F", "4WQ": "A", "4HH": "S", "4HJ": "S", "4J4": "C",
    "4J5": "R", "4II": "F", "4VI": "R", "823": "N", "8SP": "S", "8AY": "A",
}
233
+
234
# Nucleic Acids
# Keys are CCD residue codes space-padded to a fixed width of 3 characters
# (e.g. "A  ", "DA ").  NOTE(review): the padding width is reconstructed to 3
# to match the f"{code:<3}" convention used elsewhere in this package — confirm
# against the repository file, since the rendered diff collapses spaces.
nucleic_letters_3to1 = {
    "A  ": "A", "C  ": "C", "G  ": "G", "U  ": "U",
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}

# RNA-only subset of the mapping above.
rna_letters_3to1 = {
    "A  ": "A", "C  ": "C", "G  ": "G", "U  ": "U",
}

# DNA-only subset ("D"-prefixed CCD codes).
dna_letters_3to1 = {
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}
247
+
248
+ # fmt: off
249
+ nucleic_letters_3to1_extended = {
250
+ "A ": "A", "A23": "A", "A2L": "A", "A2M": "A", "A34": "A", "A35": "A",
251
+ "A38": "A", "A39": "A", "A3A": "A", "A3P": "A", "A40": "A", "A43": "A",
252
+ "A44": "A", "A47": "A", "A5L": "A", "A5M": "C", "A5O": "A", "A6A": "A",
253
+ "A6C": "C", "A6G": "G", "A6U": "U", "A7E": "A", "A9Z": "A", "ABR": "A",
254
+ "ABS": "A", "AD2": "A", "ADI": "A", "ADP": "A", "AET": "A", "AF2": "A",
255
+ "AFG": "G", "AMD": "A", "AMO": "A", "AP7": "A", "AS ": "A", "ATD": "T",
256
+ "ATL": "T", "ATM": "T", "AVC": "A", "AI5": "C", "E ": "A", "E1X": "A",
257
+ "EDA": "A", "EFG": "G", "EHG": "G", "EIT": "T", "EXC": "C", "E3C": "C",
258
+ "E6G": "G", "E7G": "G", "EQ4": "G", "EAN": "T", "I5C": "C", "IC ": "C",
259
+ "IG ": "G", "IGU": "G", "IMC": "C", "IMP": "G", "IU ": "U", "I4U": "U",
260
+ "IOO": "G", "M1G": "G", "M2G": "G", "M4C": "C", "M5M": "C", "MA6": "A",
261
+ "MA7": "A", "MAD": "A", "MCY": "C", "ME6": "C", "MEP": "U", "MG1": "G",
262
+ "MGQ": "A", "MGT": "G", "MGV": "G", "MIA": "A", "MMT": "T", "MNU": "U",
263
+ "MRG": "G", "MTR": "T", "MTU": "A", "MFO": "G", "M7A": "A", "MHG": "G",
264
+ "MMX": "C", "QUO": "G", "QCK": "T", "QSQ": "A", "U ": "U", "U25": "U",
265
+ "U2L": "U", "U2P": "U", "U31": "U", "U34": "U", "U36": "U", "U37": "U",
266
+ "U8U": "U", "UAR": "U", "UBB": "U", "UBD": "U", "UD5": "U", "UPV": "U",
267
+ "UR3": "U", "URD": "U", "US3": "T", "US5": "U", "UZR": "U", "UMO": "U",
268
+ "U23": "U", "U48": "C", "U7B": "C", "Y ": "A", "YCO": "C", "YG ": "G",
269
+ "YYG": "G", "23G": "G", "26A": "A", "2AR": "A", "2AT": "T", "2AU": "U",
270
+ "2BT": "T", "2BU": "A", "2DA": "A", "2DT": "T", "2EG": "G", "2GT": "T",
271
+ "2JV": "G", "2MA": "A", "2MG": "G", "2MU": "U", "2NT": "T", "2OM": "U",
272
+ "2OT": "T", "2PR": "G", "2SG": "G", "2ST": "T", "63G": "G", "63H": "G",
273
+ "64T": "T", "68Z": "G", "6CT": "T", "6HA": "A", "6HB": "A", "6HC": "C",
274
+ "6HG": "G", "6HT": "T", "6IA": "A", "6MA": "A", "6MC": "A", "6MP": "A",
275
+ "6MT": "A", "6MZ": "A", "6OG": "G", "6PO": "G", "6FK": "G", "6NW": "A",
276
+ "6OO": "C", "D00": "C", "D3T": "T", "D4M": "T", "DA ": "A", "DC ": "C",
277
+ "DCG": "G", "DCT": "C", "DDG": "G", "DFC": "C", "DFG": "G", "DG ": "G",
278
+ "DG8": "G", "DGI": "G", "DGP": "G", "DHU": "U", "DNR": "C", "DOC": "C",
279
+ "DPB": "T", "DRT": "T", "DT ": "T", "DZM": "A", "D4B": "C", "H2U": "U",
280
+ "HN0": "G", "HN1": "G", "LC ": "C", "LCA": "A", "LCG": "G", "LG ": "G",
281
+ "LGP": "G", "LHU": "U", "LSH": "T", "LST": "T", "LDG": "G", "L3X": "A",
282
+ "LHH": "C", "LV2": "C", "L1J": "G", "P ": "G", "P2T": "T", "P5P": "A",
283
+ "PG7": "G", "PGN": "G", "PGP": "G", "PMT": "C", "PPU": "A", "PPW": "G",
284
+ "PR5": "A", "PRN": "A", "PST": "T", "PSU": "U", "PU ": "A", "PVX": "C",
285
+ "PYO": "U", "PZG": "G", "P4U": "U", "P7G": "G", "T ": "T", "T2S": "T",
286
+ "T31": "U", "T32": "T", "T36": "T", "T37": "T", "T38": "T", "T39": "T",
287
+ "T3P": "T", "T41": "T", "T48": "T", "T49": "T", "T4S": "T", "T5S": "T",
288
+ "T64": "T", "T6A": "A", "TA3": "T", "TAF": "T", "TBN": "A", "TC1": "C",
289
+ "TCP": "T", "TCY": "A", "TDY": "T", "TED": "T", "TFE": "T", "TFF": "T",
290
+ "TFO": "A", "TFT": "T", "TGP": "G", "TCJ": "C", "TLC": "T", "TP1": "T",
291
+ "TPC": "C", "TPG": "G", "TSP": "T", "TTD": "T", "TTM": "T", "TXD": "A",
292
+ "TXP": "A", "TC ": "C", "TG ": "G", "T0N": "G", "T0Q": "G", "X ": "G",
293
+ "XAD": "A", "XAL": "A", "XCL": "C", "XCR": "C", "XCT": "C", "XCY": "C",
294
+ "XGL": "G", "XGR": "G", "XGU": "G", "XPB": "G", "XTF": "T", "XTH": "T",
295
+ "XTL": "T", "XTR": "T", "XTS": "G", "XUA": "A", "XUG": "G", "102": "G",
296
+ "10C": "C", "125": "U", "126": "U", "127": "U", "12A": "A", "16B": "C",
297
+ "18M": "G", "1AP": "A", "1CC": "C", "1FC": "C", "1MA": "A", "1MG": "G",
298
+ "1RN": "U", "1SC": "C", "5AA": "A", "5AT": "T", "5BU": "U", "5CG": "G",
299
+ "5CM": "C", "5FA": "A", "5FC": "C", "5FU": "U", "5HC": "C", "5HM": "C",
300
+ "5HT": "T", "5IC": "C", "5IT": "T", "5MC": "C", "5MU": "U", "5NC": "C",
301
+ "5PC": "C", "5PY": "T", "9QV": "U", "94O": "T", "9SI": "A", "9SY": "A",
302
+ "B7C": "C", "BGM": "G", "BOE": "T", "B8H": "U", "B8K": "G", "B8Q": "C",
303
+ "B8T": "C", "B8W": "G", "B9B": "G", "B9H": "C", "BGH": "G", "F3H": "T",
304
+ "F3N": "A", "F4H": "T", "FA2": "A", "FDG": "G", "FHU": "U", "FMG": "G",
305
+ "FNU": "U", "FOX": "G", "F2T": "U", "F74": "G", "F4Q": "G", "F7H": "C",
306
+ "F7K": "G", "JDT": "T", "JMH": "C", "J0X": "C", "N5M": "C", "N6G": "G",
307
+ "N79": "A", "NCU": "C", "NMS": "T", "NMT": "T", "NTT": "T", "N7X": "C",
308
+ "R ": "A", "RBD": "A", "RDG": "G", "RIA": "A", "RMP": "A", "RPC": "C",
309
+ "RSP": "C", "RSQ": "C", "RT ": "T", "RUS": "U", "RFJ": "G", "V3L": "A",
310
+ "VC7": "G", "Z ": "C", "ZAD": "A", "ZBC": "C", "ZBU": "U", "ZCY": "C",
311
+ "ZGU": "G", "31H": "A", "31M": "A", "3AU": "U", "3DA": "A", "3ME": "U",
312
+ "3MU": "U", "3TD": "U", "70U": "U", "7AT": "A", "7DA": "A", "7GU": "G",
313
+ "7MG": "G", "7BG": "G", "73W": "C", "75B": "U", "7OK": "C", "7S3": "G",
314
+ "7SN": "G", "C ": "C", "C25": "C", "C2L": "C", "C2S": "C", "C31": "C",
315
+ "C32": "C", "C34": "C", "C36": "C", "C37": "C", "C38": "C", "C42": "C",
316
+ "C43": "C", "C45": "C", "C46": "C", "C49": "C", "C4S": "C", "C5L": "C",
317
+ "C6G": "G", "CAR": "C", "CB2": "C", "CBR": "C", "CBV": "C", "CCC": "C",
318
+ "CDW": "C", "CFL": "C", "CFZ": "C", "CG1": "G", "CH ": "C", "CMR": "C",
319
+ "CNU": "U", "CP1": "C", "CSF": "C", "CSL": "C", "CTG": "T", "CX2": "C",
320
+ "C7S": "C", "C7R": "C", "G ": "G", "G1G": "G", "G25": "G", "G2L": "G",
321
+ "G2S": "G", "G31": "G", "G32": "G", "G33": "G", "G36": "G", "G38": "G",
322
+ "G42": "G", "G46": "G", "G47": "G", "G48": "G", "G49": "G", "G7M": "G",
323
+ "GAO": "G", "GCK": "C", "GDO": "G", "GDP": "G", "GDR": "G", "GF2": "G",
324
+ "GFL": "G", "GH3": "G", "GMS": "G", "GN7": "G", "GNG": "G", "GOM": "G",
325
+ "GRB": "G", "GS ": "G", "GSR": "G", "GSS": "G", "GTP": "G", "GX1": "G",
326
+ "KAG": "G", "KAK": "G", "O2G": "G", "OGX": "G", "OMC": "C", "OMG": "G",
327
+ "OMU": "U", "ONE": "U", "O2Z": "A", "OKN": "C", "OKQ": "C", "S2M": "T",
328
+ "S4A": "A", "S4C": "C", "S4G": "G", "S4U": "U", "S6G": "G", "SC ": "C",
329
+ "SDE": "A", "SDG": "G", "SDH": "G", "SMP": "A", "SMT": "T", "SPT": "T",
330
+ "SRA": "A", "SSU": "U", "SUR": "U", "00A": "A", "0AD": "G", "0AM": "A",
331
+ "0AP": "C", "0AV": "A", "0R8": "C", "0SP": "A", "0UH": "G", "47C": "C",
332
+ "4OC": "C", "4PC": "C", "4PD": "C", "4PE": "C", "4SC": "C", "4SU": "U",
333
+ "45A": "A", "4U3": "C", "8AG": "G", "8AN": "A", "8BA": "A", "8FG": "G",
334
+ "8MG": "G", "8OG": "G", "8PY": "G", "8AA": "G", "85Y": "U", "8OS": "G",
335
+ "UNK": "X", # DEBUG
336
+ }
337
+
338
# Convenience views over the protein/nucleic tables defined above.
standard_protein_letters_3to1 = protein_letters_3to1
standard_protein_letters_1to3 = protein_letters_1to3
# Extended entries that are NOT one of the standard residue codes.
nonstandard_protein_letters_3to1 = {
    code: one_letter
    for code, one_letter in protein_letters_3to1_extended.items()
    if code not in standard_protein_letters_3to1
}

standard_nucleic_letters_3to1 = nucleic_letters_3to1
# Inverse mapping.  When two codes share a one-letter symbol (e.g. "A  " and
# "DA " both map to "A"), the code iterated last wins.
standard_nucleic_letters_1to3 = {
    one_letter: code for code, one_letter in standard_nucleic_letters_3to1.items()
}
nonstandard_nucleic_letters_3to1 = {
    code: one_letter
    for code, one_letter in nucleic_letters_3to1_extended.items()
    if code not in standard_nucleic_letters_3to1
}

# Combined protein + nucleic extended lookup (nucleic entries win on key clash).
letters_3to1_extended = {**protein_letters_3to1_extended, **nucleic_letters_3to1_extended}
PhysDock/data/constants/__init__.py ADDED
File without changes
PhysDock/data/constants/periodic_table.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# 118 IUPAC element symbols, ordered by atomic number (index = Z - 1).
PeriodicTable = [
    "H", "He",
    "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar",
    "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", "Se", "Br", "Kr",
    "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn", "Sb", "Te", "I", "Xe",
    "Cs", "Ba",
    "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu",
    "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi", "Po", "At", "Rn",
    "Fr", "Ra",
    "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", "Md", "No", "Lr",
    "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
]

# Lowercase variant, derived from PeriodicTable so the two lists can never
# drift apart (the original kept two hand-maintained 118-element copies).
periodic_table = [symbol.lower() for symbol in PeriodicTable]
PhysDock/data/constants/residue_constants.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
# One-letter -> three-letter amino-acid codes; "X"/"UNK" is the unknown residue.
amino_acid_1to3 = {
    one: three
    for one, three in zip(
        "ARNDCQEGHILKMFPSTWYVX",
        "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
        "LEU LYS MET PHE PRO SER THR TRP TYR VAL UNK".split(),
    )
}

# Inverse mapping: three-letter -> one-letter.
amino_acid_3to1 = {three: one for one, three in amino_acid_1to3.items()}
28
+
29
# Ligand atoms are represented as "UNK" at the token level.
# The standard residue codes below double as CCD component ids.
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
# NOTE(review): nucleic codes are space-padded to width 3 ("A  ", "DA "),
# matching the f"{code:<3}" convention used elsewhere in the package — confirm
# against the repository file (the rendered diff collapses spaces).
standard_rna = ["A  ", "G  ", "C  ", "U  ", "N  ", ]
standard_dna = ["DA ", "DG ", "DC ", "DT ", "DN ", ]
standard_nucleics = standard_rna + standard_dna
standard_ccds_without_gap = standard_protein + standard_nucleics
GAP = ["GAP"]  # used in msa one-hot
standard_ccds = standard_protein + standard_nucleics + GAP

# Map each standard CCD code to its index in `standard_ccds`
# (loop variable renamed: the original shadowed the builtin `id`).
standard_ccd_to_order = {ccd: idx for idx, ccd in enumerate(standard_ccds)}

standard_purines = ["A  ", "G  ", "DA ", "DG "]
standard_pyrimidines = ["C  ", "U  ", "DC ", "DT "]


# Plain functions instead of lambda assignments (PEP 8 E731); behavior unchanged.
def is_standard(ccd):
    """True if `ccd` is one of the standard protein/nucleic codes (incl. GAP)."""
    return ccd in standard_ccds


def is_unk(ccd):
    """True for unknown/placeholder codes (protein UNK, nucleic N/DN, GAP, ligand UNL)."""
    return ccd in ["UNK", "N  ", "DN ", "GAP", "UNL"]


def is_protein(ccd):
    """True for a concrete (non-UNK) standard amino-acid code."""
    return ccd in standard_protein and not is_unk(ccd)


def is_rna(ccd):
    """True for a concrete (non-N) standard RNA code."""
    return ccd in standard_rna and not is_unk(ccd)


def is_dna(ccd):
    """True for a concrete (non-DN) standard DNA code."""
    return ccd in standard_dna and not is_unk(ccd)


def is_nucleics(ccd):
    """True for any concrete standard RNA/DNA code."""
    return ccd in standard_nucleics and not is_unk(ccd)


def is_purines(ccd):
    """True for A/G (RNA or DNA) codes."""
    return ccd in standard_purines


def is_pyrimidines(ccd):
    """True for C/U/T (RNA or DNA) codes."""
    return ccd in standard_pyrimidines
53
+
54
# Heavy-atom count per standard residue, in `standard_ccds` order.
# Placeholders (UNK, N, DN, GAP) have no defined atom count -> None.
standard_ccd_to_atoms_num = dict(zip(standard_ccds, [
    5, 11, 8, 8, 6, 9, 9, 4, 10, 8,
    8, 9, 8, 11, 7, 6, 7, 14, 12, 7, None,
    22, 23, 20, 20, None,
    21, 22, 19, 20, None,
    None,
]))
61
+
62
# Representative atom names per standard residue, keyed by CCD code.

# Token centre atom: CA for amino acids, C1' for nucleotides.
standard_ccd_to_token_centre_atom_name = {
    **dict.fromkeys(standard_protein, "CA"),
    **dict.fromkeys(standard_nucleics, "C1'"),
}

# Backbone frame atoms: (N, CA, C) for protein, (C1', C3', C4') for nucleics.
standard_ccd_to_frame_atom_name_0 = {
    **dict.fromkeys(standard_protein, "N"),
    **dict.fromkeys(standard_nucleics, "C1'"),
}

standard_ccd_to_frame_atom_name_1 = {
    **dict.fromkeys(standard_protein, "CA"),
    **dict.fromkeys(standard_nucleics, "C3'"),
}

standard_ccd_to_frame_atom_name_2 = {
    **dict.fromkeys(standard_protein, "C"),
    **dict.fromkeys(standard_nucleics, "C4'"),
}

# Pseudo-beta atom: CB for amino acids (CA for glycine, which has no CB),
# C4 for purines and C2 for pyrimidines.
standard_ccd_to_token_pseudo_beta_atom_name = {
    **dict.fromkeys(standard_protein, "CB"),
    **dict.fromkeys(standard_purines, "C4"),
    **dict.fromkeys(standard_pyrimidines, "C2"),
}
standard_ccd_to_token_pseudo_beta_atom_name["GLY"] = "CA"
88
+
89
########################################################
#     periodic table that used to encode elements      #
########################################################
# 118 element symbols (lowercase), ordered by atomic number (index = Z - 1).
periodic_table = (
    "h he "
    "li be b c n o f ne "
    "na mg al si p s cl ar "
    "k ca sc ti v cr mn fe co ni cu zn ga ge as se br kr "
    "rb sr y zr nb mo tc ru rh pd ag cd in sn sb te i xe "
    "cs ba "
    "la ce pr nd pm sm eu gd tb dy ho er tm yb lu "
    "hf ta w re os ir pt au hg tl pb bi po at rn "
    "fr ra "
    "ac th pa u np pu am cm bk cf es fm md no lr "
    "rf db sg bh hs mt ds rg cn nh fl mc lv ts og"
).split()

# Lowercase element symbol -> zero-based index (atomic number - 1).
get_element_id = {symbol: index for index, symbol in enumerate(periodic_table)}
107
+
108
+ ##########################################################
109
+
110
+ standard_ccd_to_reference_features_table = {
111
+ # letters_3: [ref_pos,ref_charge, ref_mask, ref_elements, ref_atom_name_chars]
112
+ "ALA": [
113
+ [-0.966, 0.493, 1.500, 0., 1, "N", "N"],
114
+ [0.257, 0.418, 0.692, 0., 1, "C", "CA"],
115
+ [-0.094, 0.017, -0.716, 0., 1, "C", "C"],
116
+ [-1.056, -0.682, -0.923, 0., 1, "O", "O"],
117
+ [1.204, -0.620, 1.296, 0., 1, "C", "CB"],
118
+ [0.661, 0.439, -1.742, 0., 0, "O", "OXT"],
119
+ ],
120
+ "ARG": [
121
+ [-0.469, 1.110, -0.993, 0., 1, "N", "N"],
122
+ [0.004, 2.294, -1.708, 0., 1, "C", "CA"],
123
+ [-0.907, 2.521, -2.901, 0., 1, "C", "C"],
124
+ [-1.827, 1.789, -3.242, 0., 1, "O", "O"],
125
+ [1.475, 2.150, -2.127, 0., 1, "C", "CB"],
126
+ [1.745, 1.017, -3.130, 0., 1, "C", "CG"],
127
+ [3.210, 0.954, -3.557, 0., 1, "C", "CD"],
128
+ [4.071, 0.726, -2.421, 0., 1, "N", "NE"],
129
+ [5.469, 0.624, -2.528, 0., 1, "C", "CZ"],
130
+ [6.259, 0.404, -1.405, 0., 1, "N", "NH1"],
131
+ [6.078, 0.744, -3.773, 0., 1, "N", "NH2"],
132
+ [-0.588, 3.659, -3.574, 0., 0, "O", "OXT"],
133
+ ],
134
+ "ASN": [
135
+ [-0.293, 1.686, 0.094, 0., 1, "N", "N"],
136
+ [-0.448, 0.292, -0.340, 0., 1, "C", "CA"],
137
+ [-1.846, -0.179, -0.031, 0., 1, "C", "C"],
138
+ [-2.510, 0.402, 0.794, 0., 1, "O", "O"],
139
+ [0.562, -0.588, 0.401, 0., 1, "C", "CB"],
140
+ [1.960, -0.197, -0.002, 0., 1, "C", "CG"],
141
+ [2.132, 0.697, -0.804, 0., 1, "O", "OD1"],
142
+ [3.019, -0.841, 0.527, 0., 1, "N", "ND2"],
143
+ [-2.353, -1.243, -0.673, 0., 0, "O", "OXT"],
144
+ ],
145
+ "ASP": [
146
+ [-0.317, 1.688, 0.066, 0., 1, "N", "N"],
147
+ [-0.470, 0.286, -0.344, 0., 1, "C", "CA"],
148
+ [-1.868, -0.180, -0.029, 0., 1, "C", "C"],
149
+ [-2.534, 0.415, 0.786, 0., 1, "O", "O"],
150
+ [0.539, -0.580, 0.413, 0., 1, "C", "CB"],
151
+ [1.938, -0.195, 0.004, 0., 1, "C", "CG"],
152
+ [2.109, 0.681, -0.810, 0., 1, "O", "OD1"],
153
+ [2.992, -0.826, 0.543, 0., 1, "O", "OD2"],
154
+ [-2.374, -1.256, -0.652, 0., 0, "O", "OXT"],
155
+ ],
156
+ "CYS": [
157
+ [1.585, 0.483, -0.081, 0., 1, "N", "N"],
158
+ [0.141, 0.450, 0.186, 0., 1, "C", "CA"],
159
+ [-0.095, 0.006, 1.606, 0., 1, "C", "C"],
160
+ [0.685, -0.742, 2.143, 0., 1, "O", "O"],
161
+ [-0.533, -0.530, -0.774, 0., 1, "C", "CB"],
162
+ [-0.247, 0.004, -2.484, 0., 1, "S", "SG"],
163
+ [-1.174, 0.443, 2.275, 0., 0, "O", "OXT"],
164
+ ],
165
+ "GLN": [
166
+ [1.858, -0.148, 1.125, 0., 1, "N", "N"],
167
+ [0.517, 0.451, 1.112, 0., 1, "C", "CA"],
168
+ [-0.236, 0.022, 2.344, 0., 1, "C", "C"],
169
+ [-0.005, -1.049, 2.851, 0., 1, "O", "O"],
170
+ [-0.236, -0.013, -0.135, 0., 1, "C", "CB"],
171
+ [0.529, 0.421, -1.385, 0., 1, "C", "CG"],
172
+ [-0.213, -0.036, -2.614, 0., 1, "C", "CD"],
173
+ [-1.252, -0.650, -2.500, 0., 1, "O", "OE1"],
174
+ [0.277, 0.236, -3.839, 0., 1, "N", "NE2"],
175
+ [-1.165, 0.831, 2.878, 0., 0, "O", "OXT"],
176
+ ],
177
+ "GLU": [
178
+ [1.199, 1.867, -0.117, 0., 1, "N", "N"],
179
+ [1.138, 0.515, 0.453, 0., 1, "C", "CA"],
180
+ [2.364, -0.260, 0.041, 0., 1, "C", "C"],
181
+ [3.010, 0.096, -0.916, 0., 1, "O", "O"],
182
+ [-0.113, -0.200, -0.062, 0., 1, "C", "CB"],
183
+ [-1.360, 0.517, 0.461, 0., 1, "C", "CG"],
184
+ [-2.593, -0.187, -0.046, 0., 1, "C", "CD"],
185
+ [-2.485, -1.161, -0.753, 0., 1, "O", "OE1"],
186
+ [-3.811, 0.269, 0.287, 0., 1, "O", "OE2"],
187
+ [2.737, -1.345, 0.737, 0., 0, "O", "OXT"],
188
+ ],
189
+ "GLY": [
190
+ [1.931, 0.090, -0.034, 0., 1, "N", "N"],
191
+ [0.761, -0.799, -0.008, 0., 1, "C", "CA"],
192
+ [-0.498, 0.029, -0.005, 0., 1, "C", "C"],
193
+ [-0.429, 1.235, -0.023, 0., 1, "O", "O"],
194
+ [-1.697, -0.574, 0.018, 0., 0, "O", "OXT"],
195
+ ],
196
+ "HIS": [
197
+ [-0.040, -1.210, 0.053, 0., 1, "N", "N"],
198
+ [1.172, -1.709, 0.652, 0., 1, "C", "CA"],
199
+ [1.083, -3.207, 0.905, 0., 1, "C", "C"],
200
+ [0.040, -3.770, 1.222, 0., 1, "O", "O"],
201
+ [1.484, -0.975, 1.962, 0., 1, "C", "CB"],
202
+ [2.940, -1.060, 2.353, 0., 1, "C", "CG"],
203
+ [3.380, -2.075, 3.129, 0., 1, "N", "ND1"],
204
+ [3.960, -0.251, 2.046, 0., 1, "C", "CD2"],
205
+ [4.693, -1.908, 3.317, 0., 1, "C", "CE1"],
206
+ [5.058, -0.801, 2.662, 0., 1, "N", "NE2"],
207
+ [2.247, -3.882, 0.744, 0., 0, "O", "OXT"],
208
+ ],
209
+ "ILE": [
210
+ [-1.944, 0.335, -0.343, 0., 1, "N", "N"],
211
+ [-0.487, 0.519, -0.369, 0., 1, "C", "CA"],
212
+ [0.066, -0.032, -1.657, 0., 1, "C", "C"],
213
+ [-0.484, -0.958, -2.203, 0., 1, "O", "O"],
214
+ [0.140, -0.219, 0.814, 0., 1, "C", "CB"],
215
+ [-0.421, 0.341, 2.122, 0., 1, "C", "CG1"],
216
+ [1.658, -0.027, 0.788, 0., 1, "C", "CG2"],
217
+ [0.206, -0.397, 3.305, 0., 1, "C", "CD1"],
218
+ [1.171, 0.504, -2.197, 0., 0, "O", "OXT"],
219
+ ],
220
+ "LEU": [
221
+ [-1.661, 0.627, -0.406, 0., 1, "N", "N"],
222
+ [-0.205, 0.441, -0.467, 0., 1, "C", "CA"],
223
+ [0.180, -0.055, -1.836, 0., 1, "C", "C"],
224
+ [-0.591, -0.731, -2.474, 0., 1, "O", "O"],
225
+ [0.221, -0.583, 0.585, 0., 1, "C", "CB"],
226
+ [-0.170, -0.079, 1.976, 0., 1, "C", "CG"],
227
+ [0.256, -1.104, 3.029, 0., 1, "C", "CD1"],
228
+ [0.526, 1.254, 2.250, 0., 1, "C", "CD2"],
229
+ [1.382, 0.254, -2.348, 0., 0, "O", "OXT"],
230
+ ],
231
+ "LYS": [
232
+ [1.422, 1.796, 0.198, 0., 1, "N", "N"],
233
+ [1.394, 0.355, 0.484, 0., 1, "C", "CA"],
234
+ [2.657, -0.284, -0.032, 0., 1, "C", "C"],
235
+ [3.316, 0.275, -0.876, 0., 1, "O", "O"],
236
+ [0.184, -0.278, -0.206, 0., 1, "C", "CB"],
237
+ [-1.102, 0.282, 0.407, 0., 1, "C", "CG"],
238
+ [-2.313, -0.351, -0.283, 0., 1, "C", "CD"],
239
+ [-3.598, 0.208, 0.329, 0., 1, "C", "CE"],
240
+ [-4.761, -0.400, -0.332, 0., 1, "N", "NZ"],
241
+ [3.050, -1.476, 0.446, 0., 0, "O", "OXT"],
242
+ ],
243
+ "MET": [
244
+ [-1.816, 0.142, -1.166, 0., 1, "N", "N"],
245
+ [-0.392, 0.499, -1.214, 0., 1, "C", "CA"],
246
+ [0.206, 0.002, -2.504, 0., 1, "C", "C"],
247
+ [-0.236, -0.989, -3.033, 0., 1, "O", "O"],
248
+ [0.334, -0.145, -0.032, 0., 1, "C", "CB"],
249
+ [-0.273, 0.359, 1.277, 0., 1, "C", "CG"],
250
+ [0.589, -0.405, 2.678, 0., 1, "S", "SD"],
251
+ [-0.314, 0.353, 4.056, 0., 1, "C", "CE"],
252
+ [1.232, 0.661, -3.066, 0., 0, "O", "OXT"],
253
+ ],
254
+ "PHE": [
255
+ [1.317, 0.962, 1.014, 0., 1, "N", "N"],
256
+ [-0.020, 0.426, 1.300, 0., 1, "C", "CA"],
257
+ [-0.109, 0.047, 2.756, 0., 1, "C", "C"],
258
+ [0.879, -0.317, 3.346, 0., 1, "O", "O"],
259
+ [-0.270, -0.809, 0.434, 0., 1, "C", "CB"],
260
+ [-0.181, -0.430, -1.020, 0., 1, "C", "CG"],
261
+ [1.031, -0.498, -1.680, 0., 1, "C", "CD1"],
262
+ [-1.314, -0.018, -1.698, 0., 1, "C", "CD2"],
263
+ [1.112, -0.150, -3.015, 0., 1, "C", "CE1"],
264
+ [-1.231, 0.333, -3.032, 0., 1, "C", "CE2"],
265
+ [-0.018, 0.265, -3.691, 0., 1, "C", "CZ"],
266
+ [-1.286, 0.113, 3.396, 0., 0, "O", "OXT"],
267
+ ],
268
+ "PRO": [
269
+ [-0.816, 1.108, 0.254, 0., 1, "N", "N"],
270
+ [0.001, -0.107, 0.509, 0., 1, "C", "CA"],
271
+ [1.408, 0.091, 0.005, 0., 1, "C", "C"],
272
+ [1.650, 0.980, -0.777, 0., 1, "O", "O"],
273
+ [-0.703, -1.227, -0.286, 0., 1, "C", "CB"],
274
+ [-2.163, -0.753, -0.439, 0., 1, "C", "CG"],
275
+ [-2.218, 0.614, 0.276, 0., 1, "C", "CD"],
276
+ [2.391, -0.721, 0.424, 0., 0, "O", "OXT"],
277
+ ],
278
+ "SER": [
279
+ [1.525, 0.493, -0.608, 0., 1, "N", "N"],
280
+ [0.100, 0.469, -0.252, 0., 1, "C", "CA"],
281
+ [-0.053, 0.004, 1.173, 0., 1, "C", "C"],
282
+ [0.751, -0.760, 1.649, 0., 1, "O", "O"],
283
+ [-0.642, -0.489, -1.184, 0., 1, "C", "CB"],
284
+ [-0.496, -0.049, -2.535, 0., 1, "O", "OG"],
285
+ [-1.084, 0.440, 1.913, 0., 0, "O", "OXT"],
286
+ ],
287
+ "THR": [
288
+ [1.543, -0.702, 0.430, 0., 1, "N", "N"],
289
+ [0.122, -0.706, 0.056, 0., 1, "C", "CA"],
290
+ [-0.038, -0.090, -1.309, 0., 1, "C", "C"],
291
+ [0.732, 0.761, -1.683, 0., 1, "O", "O"],
292
+ [-0.675, 0.104, 1.079, 0., 1, "C", "CB"],
293
+ [-0.193, 1.448, 1.103, 0., 1, "O", "OG1"],
294
+ [-0.511, -0.521, 2.466, 0., 1, "C", "CG2"],
295
+ [-1.039, -0.488, -2.110, 0., 0, "O", "OXT"],
296
+ ],
297
+ "TRP": [
298
+ [1.278, 1.121, 2.059, 0., 1, "N", "N"],
299
+ [-0.008, 0.417, 1.970, 0., 1, "C", "CA"],
300
+ [-0.490, 0.076, 3.357, 0., 1, "C", "C"],
301
+ [0.308, -0.130, 4.240, 0., 1, "O", "O"],
302
+ [0.168, -0.868, 1.161, 0., 1, "C", "CB"],
303
+ [0.650, -0.526, -0.225, 0., 1, "C", "CG"],
304
+ [1.928, -0.418, -0.622, 0., 1, "C", "CD1"],
305
+ [-0.186, -0.256, -1.396, 0., 1, "C", "CD2"],
306
+ [1.978, -0.095, -1.951, 0., 1, "N", "NE1"],
307
+ [0.701, 0.014, -2.454, 0., 1, "C", "CE2"],
308
+ [-1.564, -0.210, -1.615, 0., 1, "C", "CE3"],
309
+ [0.190, 0.314, -3.712, 0., 1, "C", "CZ2"],
310
+ [-2.044, 0.086, -2.859, 0., 1, "C", "CZ3"],
311
+ [-1.173, 0.348, -3.907, 0., 1, "C", "CH2"],
312
+ [-1.806, 0.001, 3.610, 0., 0, "O", "OXT"],
313
+ ],
314
+ "TYR": [
315
+ [1.320, 0.952, 1.428, 0., 1, "N", "N"],
316
+ [-0.018, 0.429, 1.734, 0., 1, "C", "CA"],
317
+ [-0.103, 0.094, 3.201, 0., 1, "C", "C"],
318
+ [0.886, -0.254, 3.799, 0., 1, "O", "O"],
319
+ [-0.274, -0.831, 0.907, 0., 1, "C", "CB"],
320
+ [-0.189, -0.496, -0.559, 0., 1, "C", "CG"],
321
+ [1.022, -0.589, -1.219, 0., 1, "C", "CD1"],
322
+ [-1.324, -0.102, -1.244, 0., 1, "C", "CD2"],
323
+ [1.103, -0.282, -2.563, 0., 1, "C", "CE1"],
324
+ [-1.247, 0.210, -2.587, 0., 1, "C", "CE2"],
325
+ [-0.032, 0.118, -3.252, 0., 1, "C", "CZ"],
326
+ [0.044, 0.420, -4.574, 0., 1, "O", "OH"],
327
+ [-1.279, 0.184, 3.842, 0., 0, "O", "OXT"],
328
+ ],
329
+ "VAL": [
330
+ [1.564, -0.642, 0.454, 0., 1, "N", "N"],
331
+ [0.145, -0.698, 0.079, 0., 1, "C", "CA"],
332
+ [-0.037, -0.093, -1.288, 0., 1, "C", "C"],
333
+ [0.703, 0.784, -1.664, 0., 1, "O", "O"],
334
+ [-0.682, 0.086, 1.098, 0., 1, "C", "CB"],
335
+ [-0.497, -0.528, 2.487, 0., 1, "C", "CG1"],
336
+ [-0.218, 1.543, 1.119, 0., 1, "C", "CG2"],
337
+ [-1.022, -0.529, -2.089, 0., 0, "O", "OXT"],
338
+ ],
339
+ "A ": [
340
+ [2.135, -1.141, -5.313, 0., 0, "O", "OP3"],
341
+ [1.024, -0.137, -4.723, 0., 1, "P", "P"],
342
+ [1.633, 1.190, -4.488, 0., 1, "O", "OP1"],
343
+ [-0.183, 0.005, -5.778, 0., 1, "O", "OP2"],
344
+ [0.456, -0.720, -3.334, 0., 1, "O", "O5'"],
345
+ [-0.520, 0.209, -2.863, 0., 1, "C", "C5'"],
346
+ [-1.101, -0.287, -1.538, 0., 1, "C", "C4'"],
347
+ [-0.064, -0.383, -0.538, 0., 1, "O", "O4'"],
348
+ [-2.105, 0.739, -0.969, 0., 1, "C", "C3'"],
349
+ [-3.445, 0.360, -1.287, 0., 1, "O", "O3'"],
350
+ [-1.874, 0.684, 0.558, 0., 1, "C", "C2'"],
351
+ [-3.065, 0.271, 1.231, 0., 1, "O", "O2'"],
352
+ [-0.755, -0.367, 0.729, 0., 1, "C", "C1'"],
353
+ [0.158, 0.029, 1.803, 0., 1, "N", "N9"],
354
+ [1.265, 0.813, 1.672, 0., 1, "C", "C8"],
355
+ [1.843, 0.963, 2.828, 0., 1, "N", "N7"],
356
+ [1.143, 0.292, 3.773, 0., 1, "C", "C5"],
357
+ [1.290, 0.091, 5.156, 0., 1, "C", "C6"],
358
+ [2.344, 0.664, 5.846, 0., 1, "N", "N6"],
359
+ [0.391, -0.656, 5.787, 0., 1, "N", "N1"],
360
+ [-0.617, -1.206, 5.136, 0., 1, "C", "C2"],
361
+ [-0.792, -1.051, 3.841, 0., 1, "N", "N3"],
362
+ [0.056, -0.320, 3.126, 0., 1, "C", "C4"],
363
+ ],
364
+ "G ": [
365
+ [-1.945, -1.360, 5.599, 0., 0, "O", "OP3"],
366
+ [-0.911, -0.277, 5.008, 0., 1, "P", "P"],
367
+ [-1.598, 1.022, 4.844, 0., 1, "O", "OP1"],
368
+ [0.325, -0.105, 6.025, 0., 1, "O", "OP2"],
369
+ [-0.365, -0.780, 3.580, 0., 1, "O", "O5'"],
370
+ [0.542, 0.217, 3.109, 0., 1, "C", "C5'"],
371
+ [1.100, -0.200, 1.748, 0., 1, "C", "C4'"],
372
+ [0.033, -0.318, 0.782, 0., 1, "O", "O4'"],
373
+ [2.025, 0.898, 1.182, 0., 1, "C", "C3'"],
374
+ [3.395, 0.582, 1.439, 0., 1, "O", "O3'"],
375
+ [1.741, 0.884, -0.338, 0., 1, "C", "C2'"],
376
+ [2.927, 0.560, -1.066, 0., 1, "O", "O2'"],
377
+ [0.675, -0.220, -0.507, 0., 1, "C", "C1'"],
378
+ [-0.297, 0.162, -1.534, 0., 1, "N", "N9"],
379
+ [-1.440, 0.880, -1.334, 0., 1, "C", "C8"],
380
+ [-2.066, 1.037, -2.464, 0., 1, "N", "N7"],
381
+ [-1.364, 0.431, -3.453, 0., 1, "C", "C5"],
382
+ [-1.556, 0.279, -4.846, 0., 1, "C", "C6"],
383
+ [-2.534, 0.755, -5.397, 0., 1, "O", "O6"],
384
+ [-0.626, -0.401, -5.551, 0., 1, "N", "N1"],
385
+ [0.459, -0.934, -4.923, 0., 1, "C", "C2"],
386
+ [1.384, -1.626, -5.664, 0., 1, "N", "N2"],
387
+ [0.649, -0.800, -3.630, 0., 1, "N", "N3"],
388
+ [-0.226, -0.134, -2.868, 0., 1, "C", "C4"],
389
+ ],
390
+ "C ": [
391
+ [2.147, -1.021, -4.678, 0., 0, "O", "OP3"],
392
+ [1.049, -0.039, -4.028, 0., 1, "P", "P"],
393
+ [1.692, 1.237, -3.646, 0., 1, "O", "OP1"],
394
+ [-0.116, 0.246, -5.102, 0., 1, "O", "OP2"],
395
+ [0.415, -0.733, -2.721, 0., 1, "O", "O5'"],
396
+ [-0.546, 0.181, -2.193, 0., 1, "C", "C5'"],
397
+ [-1.189, -0.419, -0.942, 0., 1, "C", "C4'"],
398
+ [-0.190, -0.648, 0.076, 0., 1, "O", "O4'"],
399
+ [-2.178, 0.583, -0.307, 0., 1, "C", "C3'"],
400
+ [-3.518, 0.283, -0.703, 0., 1, "O", "O3'"],
401
+ [-2.001, 0.373, 1.215, 0., 1, "C", "C2'"],
402
+ [-3.228, -0.059, 1.806, 0., 1, "O", "O2'"],
403
+ [-0.924, -0.729, 1.317, 0., 1, "C", "C1'"],
404
+ [-0.036, -0.470, 2.453, 0., 1, "N", "N1"],
405
+ [0.652, 0.683, 2.514, 0., 1, "C", "C2"],
406
+ [0.529, 1.504, 1.620, 0., 1, "O", "O2"],
407
+ [1.467, 0.945, 3.535, 0., 1, "N", "N3"],
408
+ [1.620, 0.070, 4.520, 0., 1, "C", "C4"],
409
+ [2.464, 0.350, 5.569, 0., 1, "N", "N4"],
410
+ [0.916, -1.151, 4.483, 0., 1, "C", "C5"],
411
+ [0.087, -1.399, 3.442, 0., 1, "C", "C6"],
412
+ ],
413
+ "U ": [
414
+ [-2.122, 1.033, -4.690, 0., 0, "O", "OP3"],
415
+ [-1.030, 0.047, -4.037, 0., 1, "P", "P"],
416
+ [-1.679, -1.228, -3.660, 0., 1, "O", "OP1"],
417
+ [0.138, -0.241, -5.107, 0., 1, "O", "OP2"],
418
+ [-0.399, 0.736, -2.726, 0., 1, "O", "O5'"],
419
+ [0.557, -0.182, -2.196, 0., 1, "C", "C5'"],
420
+ [1.197, 0.415, -0.942, 0., 1, "C", "C4'"],
421
+ [0.194, 0.645, 0.074, 0., 1, "O", "O4'"],
422
+ [2.181, -0.588, -0.301, 0., 1, "C", "C3'"],
423
+ [3.524, -0.288, -0.686, 0., 1, "O", "O3'"],
424
+ [1.995, -0.383, 1.218, 0., 1, "C", "C2'"],
425
+ [3.219, 0.046, 1.819, 0., 1, "O", "O2'"],
426
+ [0.922, 0.723, 1.319, 0., 1, "C", "C1'"],
427
+ [0.028, 0.464, 2.451, 0., 1, "N", "N1"],
428
+ [-0.690, -0.671, 2.486, 0., 1, "C", "C2"],
429
+ [-0.587, -1.474, 1.580, 0., 1, "O", "O2"],
430
+ [-1.515, -0.936, 3.517, 0., 1, "N", "N3"],
431
+ [-1.641, -0.055, 4.530, 0., 1, "C", "C4"],
432
+ [-2.391, -0.292, 5.460, 0., 1, "O", "O4"],
433
+ [-0.894, 1.146, 4.502, 0., 1, "C", "C5"],
434
+ [-0.070, 1.384, 3.459, 0., 1, "C", "C6"],
435
+ ],
436
+ "DA ": [
437
+ [1.845, -1.282, -5.339, 0., 0, "O", "OP3"],
438
+ [0.934, -0.156, -4.636, 0., 1, "P", "P"],
439
+ [1.781, 0.996, -4.255, 0., 1, "O", "OP1"],
440
+ [-0.204, 0.331, -5.665, 0., 1, "O", "OP2"],
441
+ [0.241, -0.771, -3.320, 0., 1, "O", "O5'"],
442
+ [-0.549, 0.270, -2.744, 0., 1, "C", "C5'"],
443
+ [-1.239, -0.251, -1.482, 0., 1, "C", "C4'"],
444
+ [-0.267, -0.564, -0.458, 0., 1, "O", "O4'"],
445
+ [-2.105, 0.859, -0.835, 0., 1, "C", "C3'"],
446
+ [-3.409, 0.895, -1.418, 0., 1, "O", "O3'"],
447
+ [-2.173, 0.398, 0.640, 0., 1, "C", "C2'"],
448
+ [-0.965, -0.545, 0.797, 0., 1, "C", "C1'"],
449
+ [-0.078, -0.047, 1.852, 0., 1, "N", "N9"],
450
+ [0.962, 0.817, 1.689, 0., 1, "C", "C8"],
451
+ [1.535, 1.044, 2.835, 0., 1, "N", "N7"],
452
+ [0.897, 0.346, 3.805, 0., 1, "C", "C5"],
453
+ [1.069, 0.196, 5.191, 0., 1, "C", "C6"],
454
+ [2.079, 0.869, 5.856, 0., 1, "N", "N6"],
455
+ [0.236, -0.603, 5.850, 0., 1, "N", "N1"],
456
+ [-0.729, -1.249, 5.224, 0., 1, "C", "C2"],
457
+ [-0.925, -1.144, 3.927, 0., 1, "N", "N3"],
458
+ [-0.142, -0.368, 3.184, 0., 1, "C", "C4"],
459
+ ],
460
+ "DG ": [
461
+ [-1.603, -1.547, 5.624, 0., 0, "O", "OP3"],
462
+ [-0.818, -0.321, 4.935, 0., 1, "P", "P"],
463
+ [-1.774, 0.766, 4.630, 0., 1, "O", "OP1"],
464
+ [0.312, 0.224, 5.941, 0., 1, "O", "OP2"],
465
+ [-0.126, -0.826, 3.572, 0., 1, "O", "O5'"],
466
+ [0.550, 0.300, 3.011, 0., 1, "C", "C5'"],
467
+ [1.233, -0.113, 1.706, 0., 1, "C", "C4'"],
468
+ [0.253, -0.471, 0.705, 0., 1, "O", "O4'"],
469
+ [1.976, 1.091, 1.073, 0., 1, "C", "C3'"],
470
+ [3.294, 1.218, 1.612, 0., 1, "O", "O3'"],
471
+ [2.026, 0.692, -0.421, 0., 1, "C", "C2'"],
472
+ [0.897, -0.345, -0.573, 0., 1, "C", "C1'"],
473
+ [-0.068, 0.111, -1.575, 0., 1, "N", "N9"],
474
+ [-1.172, 0.877, -1.341, 0., 1, "C", "C8"],
475
+ [-1.804, 1.094, -2.458, 0., 1, "N", "N7"],
476
+ [-1.145, 0.482, -3.472, 0., 1, "C", "C5"],
477
+ [-1.361, 0.377, -4.866, 0., 1, "C", "C6"],
478
+ [-2.321, 0.914, -5.391, 0., 1, "O", "O6"],
479
+ [-0.473, -0.327, -5.601, 0., 1, "N", "N1"],
480
+ [0.593, -0.928, -5.003, 0., 1, "C", "C2"],
481
+ [1.474, -1.643, -5.774, 0., 1, "N", "N2"],
482
+ [0.804, -0.839, -3.709, 0., 1, "N", "N3"],
483
+ [-0.027, -0.152, -2.917, 0., 1, "C", "C4"],
484
+ ],
485
+ "DC ": [
486
+ [1.941, -1.055, -4.672, 0., 0, "O", "OP3"],
487
+ [0.987, -0.017, -3.894, 0., 1, "P", "P"],
488
+ [1.802, 1.099, -3.365, 0., 1, "O", "OP1"],
489
+ [-0.119, 0.560, -4.910, 0., 1, "O", "OP2"],
490
+ [0.255, -0.772, -2.674, 0., 1, "O", "O5'"],
491
+ [-0.571, 0.196, -2.027, 0., 1, "C", "C5'"],
492
+ [-1.300, -0.459, -0.852, 0., 1, "C", "C4'"],
493
+ [-0.363, -0.863, 0.171, 0., 1, "O", "O4'"],
494
+ [-2.206, 0.569, -0.129, 0., 1, "C", "C3'"],
495
+ [-3.488, 0.649, -0.756, 0., 1, "O", "O3'"],
496
+ [-2.322, -0.040, 1.288, 0., 1, "C", "C2'"],
497
+ [-1.106, -0.981, 1.395, 0., 1, "C", "C1'"],
498
+ [-0.267, -0.584, 2.528, 0., 1, "N", "N1"],
499
+ [0.270, 0.648, 2.563, 0., 1, "C", "C2"],
500
+ [0.052, 1.424, 1.647, 0., 1, "O", "O2"],
501
+ [1.037, 1.035, 3.581, 0., 1, "N", "N3"],
502
+ [1.291, 0.212, 4.589, 0., 1, "C", "C4"],
503
+ [2.085, 0.622, 5.635, 0., 1, "N", "N4"],
504
+ [0.746, -1.088, 4.580, 0., 1, "C", "C5"],
505
+ [-0.035, -1.465, 3.541, 0., 1, "C", "C6"],
506
+ ],
507
+ "DT ": [
508
+ [-3.912, -2.311, 1.636, 0., 0, "O", "OP3"],
509
+ [-3.968, -1.665, 3.118, 0., 1, "P", "P"],
510
+ [-4.406, -2.599, 4.208, 0., 1, "O", "OP1"],
511
+ [-4.901, -0.360, 2.920, 0., 1, "O", "OP2"],
512
+ [-2.493, -1.028, 3.315, 0., 1, "O", "O5'"],
513
+ [-2.005, -0.136, 2.327, 0., 1, "C", "C5'"],
514
+ [-0.611, 0.328, 2.728, 0., 1, "C", "C4'"],
515
+ [0.247, -0.829, 2.764, 0., 1, "O", "O4'"],
516
+ [0.008, 1.286, 1.720, 0., 1, "C", "C3'"],
517
+ [0.965, 2.121, 2.368, 0., 1, "O", "O3'"],
518
+ [0.710, 0.360, 0.754, 0., 1, "C", "C2'"],
519
+ [1.157, -0.778, 1.657, 0., 1, "C", "C1'"],
520
+ [1.164, -2.047, 0.989, 0., 1, "N", "N1"],
521
+ [2.333, -2.544, 0.374, 0., 1, "C", "C2"],
522
+ [3.410, -1.945, 0.363, 0., 1, "O", "O2"],
523
+ [2.194, -3.793, -0.240, 0., 1, "N", "N3"],
524
+ [1.047, -4.570, -0.300, 0., 1, "C", "C4"],
525
+ [0.995, -5.663, -0.857, 0., 1, "O", "O4"],
526
+ [-0.143, -3.980, 0.369, 0., 1, "C", "C5"],
527
+ [-1.420, -4.757, 0.347, 0., 1, "C", "C7"],
528
+ [-0.013, -2.784, 0.958, 0., 1, "C", "C6"],
529
+ ],
530
+ }
531
+
532
# Per-residue list of reference atom names, in reference-table row order.
# Placeholder codes (UNK / N / DN / GAP) carry no reference atoms and are skipped.
standard_ccd_to_ref_atom_name_chars = {
    code: [atom_row[-1] for atom_row in standard_ccd_to_reference_features_table[code]]
    for code in standard_ccds
    if not is_unk(code)
}
536
+
537
# Identity matrices reused as one-hot lookup tables when building features.
eye_64 = np.eye(64)
eye_128 = np.eye(128)
eye_9 = np.eye(9)
eye_7 = np.eye(7)
eye_3 = np.eye(3)
eye_32 = np.eye(32)
eye_5 = np.eye(5)
# NOTE(review): eye8/eye5 break the eye_* naming scheme above, and eye5 is a
# second, independent copy of eye_5; both are kept for backward compatibility.
eye8 = np.eye(8)
eye5 = np.eye(5)
546
+
547
+
548
def _get_ref_feat_from_ccd_data(ccd, ref_feat_table):
    """Assemble the per-atom reference feature matrix for one residue.

    Each table row looks like [x, y, z, charge, mask, element, atom_name].
    The output row concatenates: the 5 leading scalars, a 128-way one-hot of
    the element symbol, and four 64-way one-hots of the atom name padded to
    4 characters (printable ASCII, offset by 32).

    Returns an array of shape (num_atoms, 5 + 128 + 4 * 64).
    """
    rows = []
    for atom in ref_feat_table[ccd]:
        element_one_hot = eye_128[get_element_id[atom[5].lower()]]
        name_one_hots = [eye_64[ord(char) - 32] for char in f"{atom[-1]:<4}"]
        rows.append(
            np.concatenate([np.array(atom[:5]), element_one_hot, *name_one_hots], axis=-1)
        )
    return np.stack(rows, axis=0)
557
+
558
+
559
# Precomputed reference-feature matrices for every concrete standard residue;
# placeholder codes (UNK / N / DN / GAP) have no reference geometry and are omitted.
standard_ccd_to_ref_feat = {
    code: _get_ref_feat_from_ccd_data(code, standard_ccd_to_reference_features_table)
    for code in standard_ccds
    if not is_unk(code)
}
PhysDock/data/constants/restype_constants.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from .PDBData import protein_letters_3to1_extended, nucleic_letters_3to1_extended
3
+
4
# One-letter residue code -> CCD-style code. Amino acids use the standard
# alphabet plus "X" for UNK; nucleotides are encoded with digits "0"-"9"
# (RNA A/G/C/U/N, then DNA DA/DG/DC/DT/DN), space-padded so every value is a
# fixed-width three-character key.
restype_1_to_3 = {
    "A": "ALA", "R": "ARG", "N": "ASN", "D": "ASP", "C": "CYS",
    "Q": "GLN", "E": "GLU", "G": "GLY", "H": "HIS", "I": "ILE",
    "L": "LEU", "K": "LYS", "M": "MET", "F": "PHE", "P": "PRO",
    "S": "SER", "T": "THR", "W": "TRP", "Y": "TYR", "V": "VAL",
    "X": "UNK",
    "0": "A ", "1": "G ", "2": "C ", "3": "U ", "4": "N ",
    "5": "DA ", "6": "DG ", "7": "DC ", "8": "DT ", "9": "DN ",
}
# Single nucleic-acid letter -> padded RNA-style code; "T" and "X" fall back
# to "T "/"N " respectively.
na_c_to_type = {
    "A": "A ", "G": "G ", "C": "C ", "U": "U ", "N": "N ", "T": "T ", "X": "N "
}

# Inverse mapping (values of restype_1_to_3 are unique, so this is lossless).
# Ribo-thymidine "T " has no one-letter code of its own and is folded onto
# "8" (the DT digit).
restype_3_to_1 = {v: k for k, v in restype_1_to_3.items()}
restype_3_to_1["T "]="8"
19
+
20
+
21
# Canonical ordering of the 31 supported CCD codes: 20 amino acids + UNK,
# 5 RNA, 5 DNA. This order defines the integer restype ids used elsewhere.
restypes3 = [
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK",
    "A ", "G ", "C ", "U ", "N ",
    "DA ", "DG ", "DC ", "DT ", "DN ",
]
# Same codes in one-letter/digit form, in the same order.
restypes1 = [restype_3_to_1[ccd] for ccd in restypes3]

# Extended 3-char -> 1-char mapping covering modified residues/nucleotides.
# Protein keys are left-padded to width 3 to match the fixed-width convention.
restype_3_to_1_extended = {}

for c3, c in protein_letters_3to1_extended.items():
    restype_3_to_1_extended[f"{c3:<3}"] = c
# TODO: How to distinguish RNA and DNA
# NOTE(review): modified nucleotides are mapped through na_c_to_type, which is
# RNA-biased — DNA-only modifications also resolve to RNA codes here.
for c3, c in nucleic_letters_3to1_extended.items():
    restype_3_to_1_extended[c3] = restype_3_to_1[na_c_to_type[c]]
# Standard codes take precedence over any extended entry with the same key.
restype_3_to_1_extended.update(restype_3_to_1)
37
+
38
+ ############
39
+ standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
40
+ "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
41
+ standard_rna = ["A ", "G ", "C ", "U ", "N ", ]
42
+ standard_dna = ["DA ", "DG ", "DC ", "DT ", "DN ", ]
43
+ standard_nucleics = standard_rna + standard_dna
44
+ standard_ccds_without_gap = standard_protein + standard_nucleics
45
+ GAP = ["GAP"] # used in msa one-hot
46
+ standard_ccds = standard_protein + standard_nucleics + GAP
47
+
48
+ standard_ccd_to_order = {ccd: id for id, ccd in enumerate(standard_ccds)}
49
+
50
+ standard_purines = ["A ", "G ", "DA ", "DG "]
51
+ standard_pyrimidines = ["C ", "U ", "DC ", "DT "]
52
+
53
+ is_standard = lambda x: x in standard_ccds
54
+ is_unk = lambda x: x in ["UNK", "N ", "DN ", "GAP", "UNL"]
55
+ is_protein = lambda x: x in standard_protein and not is_unk(x)
56
+ is_rna = lambda x: x in standard_rna and not is_unk(x)
57
+ is_dna = lambda x: x in standard_dna and not is_unk(x)
58
+ is_nucleics = lambda x: x in standard_nucleics and not is_unk(x)
59
+ is_purines = lambda x: x in standard_purines
60
+ is_pyrimidines = lambda x: x in standard_pyrimidines
61
+
62
+
63
+ standard_ccd_to_atoms_num = {s: n for s, n in zip(standard_ccds, [
64
+ 5, 11, 8, 8, 6, 9, 9, 4, 10, 8,
65
+ 8, 9, 8, 11, 7, 6, 7, 14, 12, 7, None,
66
+ 22, 23, 20, 20, None,
67
+ 21, 22, 19, 20, None,
68
+ None,
69
+ ])}
70
+
71
# Representative ("centre") atom used to collapse a token to a single atom:
# CA for amino acids, C1' for nucleotides.
standard_ccd_to_token_centre_atom_name = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

# The three atoms defining a per-token rigid frame:
# proteins use (N, CA, C); nucleotides use (C1', C3', C4').
standard_ccd_to_frame_atom_name_0 = {
    **{residue: "N" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_1 = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C3'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_2 = {
    **{residue: "C" for residue in standard_protein},
    **{residue: "C4'" for residue in standard_nucleics},
}

# Pseudo-beta atom per token: CB for amino acids (overridden to CA for GLY,
# which has no CB), C4 for purines and C2 for pyrimidines.
standard_ccd_to_token_pseudo_beta_atom_name = {
    **{residue: "CB" for residue in standard_protein},
    **{residue: "C4" for residue in standard_purines},
    **{residue: "C2" for residue in standard_pyrimidines},
}
standard_ccd_to_token_pseudo_beta_atom_name.update({"GLY": "CA"})
97
+
98
+
99
# Pre-built identity matrices used as one-hot lookup tables by feature
# builders importing this module.
# NOTE(review): eye_5/eye5 and eye8 (vs the eye_* spelling) are inconsistent
# duplicates; both spellings are kept because callers may import either.
eye_64 = np.eye(64)
eye_128 = np.eye(128)
eye_9 = np.eye(9)
eye_7 = np.eye(7)
eye_3 = np.eye(3)
eye_32 = np.eye(32)
eye_5 = np.eye(5)
eye8 = np.eye(8)
eye5 = np.eye(5)
PhysDock/data/feature_loader.py ADDED
@@ -0,0 +1,1283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import random
4
+ from functools import reduce
5
+ from operator import add
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from typing import Optional
10
+
11
+ from PhysDock.data.constants.PDBData import protein_letters_3to1_extended
12
+ from PhysDock.data.constants import restype_constants as rc
13
+ from PhysDock.utils.io_utils import convert_md5_string, load_json, load_pkl, dump_txt, find_files
14
+ from PhysDock.data.tools.feature_processing_multimer import pair_and_merge
15
+ from PhysDock.utils.tensor_utils import centre_random_augmentation_np_apply, dgram_from_positions, \
16
+ centre_random_augmentation_np_batch
17
+ from PhysDock.data.constants.periodic_table import PeriodicTable
18
+ from PhysDock.data.tools.rdkit import get_features_from_smi
19
+
20
# The 62 single-character chain identifiers usable in PDB files (A-Z, a-z, 0-9).
PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
21
+
22
+
23
+ class FeatureLoader:
24
    def __init__(
            self,
            # Dataset Config
            dataset_path=None,
            msa_features_dir=None,
            ccd_id_meta_data=None,

            crop_size=256,
            atom_crop_size=256 * 8,

            # Infer config
            inference_mode=False,
            infer_pocket_type="atom",  # "ca"
            infer_pocket_cutoff=6,  # 8 10 12
            infer_pocket_dist_type="ligand",  # "ligand_centre"
            infer_use_pocket=True,
            infer_use_key_res=True,

            # Train Config
            train_pocket_type_atom_ratio=0.5,
            train_pocket_cutoff_ligand_min=6,
            train_pocket_cutoff_ligand_max=12,
            train_pocket_cutoff_ligand_centre_min=10,
            train_pocket_cutoff_ligand_centre_max=16,
            train_pocket_dist_type_ligand_ratio=0.5,
            train_use_pocket_ratio=0.5,
            train_use_key_res_ratio=0.5,

            train_shuffle_sym_id=True,
            train_spatial_crop_ligand_ratio=0.2,
            train_spatial_crop_interface_ratio=0.4,
            train_spatial_crop_interface_threshold=15.,
            train_charility_augmentation_ratio=0.1,
            train_use_template_ratio=0.75,
            train_template_mask_max_ratio=0.4,

            # Other Configs
            max_msa_clusters=128,
            key_res_random_mask_ratio=0.5,

            # Abalation
            use_x_gt_ligand_as_ref_pos=False,

            # Recycle
            num_recycles=None
    ):
        """Configure a feature loader for training or inference.

        ``dataset_path`` is expected to contain ``msa_features/``,
        ``uniprot_msa_features/``, optionally ``train_val/`` +
        ``train_val_weights.json`` (sampling weights), and
        ``ccd_id_meta_data.pkl.gz``. ``msa_features_dir`` overrides the MSA
        locations; ``ccd_id_meta_data`` overrides the metadata pickle path.
        The remaining arguments are stored verbatim as sampling/cropping/
        augmentation knobs consumed by the other methods.
        """
        # Init Dataset
        if dataset_path is not None:
            self.msa_features_path = os.path.join(dataset_path, "msa_features")
            self.uniprot_msa_features_path = os.path.join(dataset_path, "uniprot_msa_features")
            if os.path.exists(os.path.join(dataset_path, "train_val")):
                self.used_sample_ids = find_files(os.path.join(dataset_path, "train_val"))
                if os.path.exists(os.path.join(dataset_path, "train_val_weights.json")):
                    weights = load_json(os.path.join(dataset_path, "train_val_weights.json"))
                    self.weights = np.array([weights[sample_id] for sample_id in self.used_sample_ids])
                    # Normalised sampling distribution over train/val samples.
                    self.probabilities = torch.from_numpy(self.weights / self.weights.sum())
        if msa_features_dir is not None:
            self.msa_features_path = os.path.join(msa_features_dir, "msa_features")
            self.uniprot_msa_features_path = os.path.join(msa_features_dir, "uniprot_msa_features")
        if ccd_id_meta_data is None:
            # NOTE(review): the ``or ccd_id_meta_data is not None`` disjunct is
            # always False inside this branch — the assert reduces to the
            # os.path.exists check; also requires dataset_path to be set here.
            assert os.path.exists(os.path.join(dataset_path, "ccd_id_meta_data.pkl.gz")) or ccd_id_meta_data is not None
            self.ccd_id_meta_data = load_pkl(os.path.join(dataset_path, "ccd_id_meta_data.pkl.gz"))
        else:
            self.ccd_id_meta_data = load_pkl(ccd_id_meta_data)

        # Inference Config
        self.inference_mode = inference_mode
        self.infer_use_pocket = infer_use_pocket
        self.infer_use_key_res = infer_use_key_res
        self.infer_pocket_type = infer_pocket_type
        self.infer_pocket_cutoff = infer_pocket_cutoff
        self.infer_pocket_dist_type = infer_pocket_dist_type

        # Training Config
        self.train_pocket_type_atom_ratio = train_pocket_type_atom_ratio
        self.train_pocket_cutoff_ligand_min = train_pocket_cutoff_ligand_min
        self.train_pocket_cutoff_ligand_max = train_pocket_cutoff_ligand_max
        self.train_pocket_cutoff_ligand_centre_min = train_pocket_cutoff_ligand_centre_min
        self.train_pocket_cutoff_ligand_centre_max = train_pocket_cutoff_ligand_centre_max
        self.train_pocket_dist_type_ligand_ratio = train_pocket_dist_type_ligand_ratio
        self.train_use_pocket_ratio = train_use_pocket_ratio
        self.train_use_key_res_ratio = train_use_key_res_ratio
        self.train_shuffle_sym_id = train_shuffle_sym_id
        self.train_spatial_crop_ligand_ratio = train_spatial_crop_ligand_ratio
        self.train_spatial_crop_interface_ratio = train_spatial_crop_interface_ratio
        self.train_spatial_crop_interface_threshold = train_spatial_crop_interface_threshold
        self.train_charility_augmentation_ratio = train_charility_augmentation_ratio
        self.train_use_template_ratio = train_use_template_ratio
        self.train_template_mask_max_ratio = train_template_mask_max_ratio

        # Other Configs
        # Max covalent-bond distance (Angstrom) used when deriving token bonds.
        self.token_bond_threshold = 2.4
        self.key_res_random_mask_ratio = key_res_random_mask_ratio
        self.crop_size = crop_size
        self.atom_crop_size = atom_crop_size
        self.max_msa_clusters = max_msa_clusters

        self.use_x_gt_ligand_as_ref_pos = use_x_gt_ligand_as_ref_pos

        self.num_recycles = num_recycles
137
+
138
    def _update_CONF_META_DATA(self, CONF_META_DATA, ccds):
        """Ensure CONF_META_DATA has an entry for every CCD code in *ccds*.

        Each entry bundles per-atom reference features ("ref_feat": centred
        reference positions + charge + one-hot element/degree/hybridization/
        valence/chirality + ring-membership flags), a pairwise token-relation
        tensor ("rel_tok_feat": graph distance, bond type and bond flags),
        plus raw atom names, element ids and the bond matrix. Existing
        entries are left untouched; the dict is mutated in place and returned.
        """
        for ccd in ccds:
            if ccd in CONF_META_DATA:
                continue
            ccd_features = self.ccd_id_meta_data[ccd]
            ref_pos = ccd_features["ref_pos"]
            # Centre the reference conformer at the origin.
            ref_pos = ref_pos - np.mean(ref_pos, axis=0, keepdims=True)
            CONF_META_DATA[ccd] = {
                "ref_feat": np.concatenate([
                    ref_pos,
                    ccd_features["ref_charge"][..., None],
                    rc.eye_128[ccd_features["ref_element"]].astype(np.float32),
                    ccd_features["ref_is_aromatic"].astype(np.float32)[..., None],
                    rc.eye_9[ccd_features["ref_degree"]].astype(np.float32),
                    rc.eye_7[ccd_features["ref_hybridization"]].astype(np.float32),
                    rc.eye_9[ccd_features["ref_implicit_valence"]].astype(np.float32),
                    rc.eye_3[ccd_features["ref_chirality"]].astype(np.float32),
                    ccd_features["ref_in_ring_of_3"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_4"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_5"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_6"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_7"].astype(np.float32)[..., None],
                    ccd_features["ref_in_ring_of_8"].astype(np.float32)[..., None],
                ], axis=-1),
                "rel_tok_feat": np.concatenate([
                    rc.eye_32[ccd_features["d_token"]].astype(np.float32),
                    rc.eye_5[ccd_features["bond_type"]].astype(np.float32),
                    ccd_features["token_bonds"].astype(np.float32)[..., None],
                    ccd_features["bond_as_double"].astype(np.float32)[..., None],
                    ccd_features["bond_in_ring"].astype(np.float32)[..., None],
                    ccd_features["bond_is_conjugated"].astype(np.float32)[..., None],
                    ccd_features["bond_is_aromatic"].astype(np.float32)[..., None],
                ], axis=-1),
                "ref_atom_name_chars": ccd_features["ref_atom_name_chars"],
                "ref_element": ccd_features["ref_element"],
                "token_bonds": ccd_features["token_bonds"],
            }

        return CONF_META_DATA
177
+
178
+ def _update_chain_feature(self, chain_feature, CONF_META_DATA, use_pocket, use_key_res, ):
179
+ ccds_ori = chain_feature["ccds"]
180
+ chain_class = chain_feature["chain_class"]
181
+ if chain_class == "protein":
182
+ sequence = "".join([protein_letters_3to1_extended.get(ccd, "X") for ccd in ccds_ori])
183
+ md5 = convert_md5_string(f"protein:{sequence}")
184
+ # with open("add_msa.fasta", "a") as f:
185
+ # f.write(f">{md5}\n{sequence}\n")
186
+ try:
187
+ # import shutil
188
+ # shutil.copy(
189
+ # os.path.join(self.msa_features_path, f"{md5}.pkl.gz"),
190
+ # os.path.join("/home/zhangkexin/research/PhysDock/examples/demo/features/msa_features",
191
+ # f"{md5}.pkl.gz")
192
+ # )
193
+ # shutil.copy(
194
+ # os.path.join(self.uniprot_msa_features_path, f"{md5}.pkl.gz"),
195
+ # os.path.join("/home/zhangkexin/research/PhysDock/examples/demo/features/uniprot_msa_features",
196
+ # f"{md5}.pkl.gz")
197
+ # )
198
+ chain_feature.update(
199
+ load_pkl(os.path.join(self.msa_features_path, f"{md5}.pkl.gz"))
200
+ )
201
+ except:
202
+ print(f"Can't find msa feature!!! md5: {md5}")
203
+ with open("add_msa.fasta", "a") as f:
204
+ f.write(f">{md5}\n{sequence}\n")
205
+
206
+ chain_feature.update(
207
+ load_pkl(os.path.join(self.uniprot_msa_features_path, f"{md5}.pkl.gz"))
208
+ )
209
+ else:
210
+ chain_feature["msa"] = np.array([[rc.standard_ccds.index(ccd)
211
+ if ccd in rc.standard_ccds else 20 for ccd in ccds_ori]] * 2,
212
+ dtype=np.int8)
213
+ chain_feature["deletion_matrix"] = np.zeros_like(chain_feature["msa"])
214
+
215
+ # Merge Key Res Feat & Augmentation
216
+ if "salt bridges" in chain_feature and use_key_res:
217
+ key_res_feat = np.stack([
218
+ chain_feature["salt bridges"],
219
+ chain_feature["pi-cation interactions"],
220
+ chain_feature["hydrophobic interactions"],
221
+ chain_feature["pi-stacking"],
222
+ chain_feature["hydrogen bonds"],
223
+ chain_feature["metal complexes"],
224
+ np.zeros_like(chain_feature["salt bridges"]),
225
+ ], axis=-1).astype(np.float32)
226
+ else:
227
+ key_res_feat = np.zeros([len(ccds_ori), 7], dtype=np.float32)
228
+ is_key_res = np.any(key_res_feat.astype(np.bool_), axis=-1).astype(np.float32)
229
+ # Augmentation
230
+ # if not self.inference_mode:
231
+ key_res_feat = (key_res_feat *
232
+ (np.random.random([len(ccds_ori), 7]) > self.key_res_random_mask_ratio))
233
+ if "pocket_res_feat" in chain_feature and use_pocket:
234
+ pocket_res_feat = chain_feature["pocket_res_feat"]
235
+ else:
236
+ pocket_res_feat = np.zeros([len(ccds_ori)], dtype=np.float32)
237
+ x_gt = []
238
+ atom_id_to_conformer_atom_id = []
239
+
240
+ # Conformer
241
+ conformer_id_to_chunk_sizes = []
242
+ residue_index = []
243
+ restype = []
244
+ ccds = []
245
+
246
+ conformer_exists = []
247
+
248
+ for c_id, ccd in enumerate(chain_feature["ccds"]):
249
+ no_atom_this_conf = len(CONF_META_DATA[ccd]["ref_feat"])
250
+ conformer_atom_ids_this_conf = np.arange(no_atom_this_conf)
251
+
252
+ x_gt_this_conf = chain_feature["all_atom_positions"][c_id]
253
+ x_exists_this_conf = chain_feature["all_atom_mask"][c_id].astype(np.bool_)
254
+
255
+ # TODO DEBUG
256
+ # conformer_exist = np.sum(x_exists_this_conf).item() > len(x_exists_this_conf) - 2
257
+ conformer_exist = np.any(x_exists_this_conf).item()
258
+
259
+ if rc.is_standard(ccd):
260
+ conformer_exist = conformer_exist and x_exists_this_conf[1]
261
+ if ccd != "GLY":
262
+ conformer_exist = conformer_exist and x_exists_this_conf[4]
263
+
264
+ conformer_exists.append(conformer_exist)
265
+ if conformer_exist:
266
+ # Atomwise
267
+ x_gt.append(x_gt_this_conf[x_exists_this_conf])
268
+ atom_id_to_conformer_atom_id.append(conformer_atom_ids_this_conf[x_exists_this_conf])
269
+ # Tokenwise
270
+ residue_index.append(c_id)
271
+ conformer_id_to_chunk_sizes.append(np.sum(x_exists_this_conf).item())
272
+ restype.append(rc.standard_ccds.index(ccd) if ccd in rc.standard_ccds else 20)
273
+ ccds.append(ccd)
274
+ # print(ccds)
275
+ # print("x_gt", x_gt)
276
+ x_gt = np.concatenate(x_gt, axis=0)
277
+ atom_id_to_conformer_atom_id = np.concatenate(atom_id_to_conformer_atom_id, axis=0, dtype=np.int32)
278
+ residue_index = np.array(residue_index, dtype=np.int64)
279
+ conformer_id_to_chunk_sizes = np.array(conformer_id_to_chunk_sizes, dtype=np.int64)
280
+ restype = np.array(restype, dtype=np.int64)
281
+
282
+ conformer_exists = np.array(conformer_exists, dtype=np.bool_)
283
+
284
+ chain_feature_update = {
285
+ "x_gt": x_gt,
286
+ "atom_id_to_conformer_atom_id": atom_id_to_conformer_atom_id,
287
+ "residue_index": residue_index,
288
+ "conformer_id_to_chunk_sizes": conformer_id_to_chunk_sizes,
289
+ "restype": restype,
290
+ "ccds": ccds,
291
+ "msa": chain_feature["msa"][:, conformer_exists],
292
+ "deletion_matrix": chain_feature["deletion_matrix"][:, conformer_exists],
293
+ "chain_class": chain_class,
294
+ "key_res_feat": key_res_feat[conformer_exists],
295
+ "is_key_res": is_key_res[conformer_exists],
296
+ "pocket_res_feat": pocket_res_feat[conformer_exists],
297
+ }
298
+
299
+ chain_feature_update["is_protein"] = np.array([chain_class == "protein"] * len(ccds)).astype(np.float32)
300
+ chain_feature_update["is_ligand"] = np.array([chain_class != "protein"] * len(ccds)).astype(np.float32)
301
+ # Assert Short Poly Chain like peptide
302
+ chain_feature_update["is_short_poly"] = np.array(
303
+ [chain_class != "protein" and len(ccds) >= 2 and rc.is_standard(ccd) for ccd in ccds]
304
+ ).astype(np.float32)
305
+
306
+ if "msa_all_seq" in chain_feature:
307
+ chain_feature_update["msa_all_seq"] = chain_feature["msa_all_seq"][:, conformer_exists]
308
+ chain_feature_update["deletion_matrix_all_seq"] = \
309
+ chain_feature["deletion_matrix_all_seq"][:, conformer_exists]
310
+ chain_feature_update["msa_species_identifiers_all_seq"] = chain_feature["msa_species_identifiers_all_seq"]
311
+ del chain_feature
312
+ return chain_feature_update
313
+
314
    def _update_smi(self, smi, all_chain_labels, CONF_META_DATA):
        """Register a SMILES-defined ligand under the fixed code "XXX"/chain "99".

        Runs RDKit-based featurization on *smi*, stores its label features in
        *all_chain_labels* and its conformer meta data (same layout as
        ``_update_CONF_META_DATA`` entries) in *CONF_META_DATA*. Atom names
        are synthesized as element symbol + atom index, left-justified to
        width 4. Returns the two updated dicts plus the RDKit reference mol.
        """
        ccd = "XXX"
        chain_id = "99"
        label_feature, conf_feature, ref_mol = get_features_from_smi(smi)
        all_chain_labels[chain_id] = {
            "all_atom_positions": label_feature["x_gt"][None],
            "all_atom_mask": label_feature["x_exists"][None],
            "ccds": [ccd]
        }
        # Synthesize unique atom names, e.g. "C0  ", "N1  ".
        ref_atom_name_chars = []
        for id, ele in enumerate(conf_feature["ref_element"]):
            atom_name = f"{PeriodicTable[ele] + str(id):<4}"
            ref_atom_name_chars.append(atom_name)
        CONF_META_DATA[ccd] = {
            "ref_feat": np.concatenate([
                conf_feature["ref_pos"],
                conf_feature["ref_charge"][..., None],
                rc.eye_128[conf_feature["ref_element"]].astype(np.float32),
                conf_feature["ref_is_aromatic"].astype(np.float32)[..., None],
                rc.eye_9[conf_feature["ref_degree"]].astype(np.float32),
                rc.eye_7[conf_feature["ref_hybridization"]].astype(np.float32),
                rc.eye_9[conf_feature["ref_implicit_valence"]].astype(np.float32),
                rc.eye_3[conf_feature["ref_chirality"]].astype(np.float32),
                conf_feature["ref_in_ring_of_3"].astype(np.float32)[..., None],
                conf_feature["ref_in_ring_of_4"].astype(np.float32)[..., None],
                conf_feature["ref_in_ring_of_5"].astype(np.float32)[..., None],
                conf_feature["ref_in_ring_of_6"].astype(np.float32)[..., None],
                conf_feature["ref_in_ring_of_7"].astype(np.float32)[..., None],
                conf_feature["ref_in_ring_of_8"].astype(np.float32)[..., None],
            ], axis=-1),
            "rel_tok_feat": np.concatenate([
                rc.eye_32[conf_feature["d_token"]].astype(np.float32),
                rc.eye_5[conf_feature["bond_type"]].astype(np.float32),
                conf_feature["token_bonds"].astype(np.float32)[..., None],
                conf_feature["bond_as_double"].astype(np.float32)[..., None],
                conf_feature["bond_in_ring"].astype(np.float32)[..., None],
                conf_feature["bond_is_conjugated"].astype(np.float32)[..., None],
                conf_feature["bond_is_aromatic"].astype(np.float32)[..., None],
            ], axis=-1),
            "ref_atom_name_chars": ref_atom_name_chars,
            "ref_element": conf_feature["ref_element"],
            "token_bonds": conf_feature["token_bonds"],
        }

        return all_chain_labels, CONF_META_DATA, ref_mol
359
+
360
+ def _add_assembly_feature(self, all_chain_features, SEQ3, ASYM_ID):
361
+ entities = {}
362
+ for chain_id, seq3 in SEQ3.items():
363
+ if seq3 not in entities:
364
+ entities[seq3] = [chain_id]
365
+ else:
366
+ entities[seq3].append(chain_id)
367
+
368
+ asym_id = 0
369
+ for entity_id, seq3 in enumerate(list(entities.keys())):
370
+ chain_ids = copy.deepcopy(entities[seq3])
371
+ if not self.inference_mode and self.train_shuffle_sym_id:
372
+ # sym_id augmentation
373
+ random.shuffle(chain_ids)
374
+ for sym_id, chain_id in enumerate(chain_ids):
375
+ num_conformers = len(all_chain_features[chain_id]["ccds"])
376
+ all_chain_features[chain_id]["asym_id"] = \
377
+ np.full([num_conformers], fill_value=asym_id, dtype=np.int32)
378
+ all_chain_features[chain_id]["sym_id"] = \
379
+ np.full([num_conformers], fill_value=sym_id, dtype=np.int32)
380
+ all_chain_features[chain_id]["entity_id"] = \
381
+ np.full([num_conformers], fill_value=entity_id, dtype=np.int32)
382
+
383
+ all_chain_features[chain_id]["sequence_3"] = seq3
384
+ ASYM_ID[asym_id] = chain_id
385
+
386
+ asym_id += 1
387
+ return all_chain_features, ASYM_ID
388
+
389
    def _crop_all_chain_features(self, all_chain_features, infer_meta_data, crop_centre=None):
        """Spatially crop the assembly to at most crop_size tokens / atom_crop_size atoms.

        Picks a crop centre (ligand centroid at inference; ligand atom /
        interface CA / random CA during training, chosen by
        train_spatial_crop_* ratios), then greedily keeps whole conformers in
        order of centre-atom distance until a token or atom budget would be
        exceeded. All per-chain features are masked down to the kept
        conformers/atoms, and chains left with nothing are removed.
        NOTE(review): the *crop_centre* parameter is currently unused.
        """
        CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
        ordered_chain_ids = list(all_chain_features.keys())

        x_gt = np.concatenate([all_chain_features[chain_id]["x_gt"] for chain_id in ordered_chain_ids], axis=0)

        # Flat token-level bookkeeping across all chains.
        token_id_to_centre_atom_id = []
        token_id_to_conformer_id = []
        token_id_to_ccd_chunk_sizes = []
        token_id_to_ccd = []
        asym_id_ca = []
        token_id = 0
        atom_id = 0
        conf_id = 0
        x_gt_ligand = []
        for chain_id in ordered_chain_ids:
            # Single-conformer chains with numeric ids are treated as ligands.
            if chain_id.isdigit() and len(all_chain_features[chain_id]["ccds"]) == 1:
                x_gt_ligand.append(all_chain_features[chain_id]["x_gt"])
            atom_offset = 0
            for ccd, chunk_size_this_ccd, asym_id in zip(
                    all_chain_features[chain_id]["ccds"],
                    all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
                    all_chain_features[chain_id]["asym_id"],
            ):
                inner_atom_idx = all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][
                                 atom_offset:atom_offset + chunk_size_this_ccd]
                atom_names = [CONF_META_DATA[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
                if rc.is_standard(ccd):
                    # One token per residue, anchored at its centre atom (CA/C1').
                    for atom_id_this_ccd, atom_name in enumerate(atom_names):
                        if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                            token_id_to_centre_atom_id.append(atom_id)
                            token_id_to_conformer_id.append(conf_id)
                            token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
                            token_id_to_ccd.append(ccd)
                            asym_id_ca.append(asym_id)
                        atom_id += 1
                    token_id += 1
                else:
                    # Ligands / non-standard residues: one token per atom.
                    for atom_id_this_ccd, atom_name in enumerate(atom_names):
                        token_id_to_centre_atom_id.append(atom_id)
                        token_id_to_conformer_id.append(conf_id)
                        token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
                        token_id_to_ccd.append(ccd)
                        asym_id_ca.append(asym_id)
                        atom_id += 1
                        token_id += 1
                atom_offset += chunk_size_this_ccd
            conf_id += 1

        x_gt_ca = x_gt[token_id_to_centre_atom_id]
        asym_id_ca = np.array(asym_id_ca)

        crop_scheme_seed = random.random()

        # Crop Ligand Centre (inference with exactly one ligand).
        if self.inference_mode and len(x_gt_ligand) == 1:
            x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)
            x_gt_sel = np.mean(x_gt_ligand, axis=0)[None]

        # Spatial Crop Ligand: centre on a random ligand atom.
        elif crop_scheme_seed < (self.train_spatial_crop_ligand_ratio if not self.inference_mode else 1.0) and len(
                x_gt_ligand) > 0:
            x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)
            x_gt_sel = random.choice(x_gt_ligand)[None]
        # Spatial Crop Interface: centre on a CA close to a different chain.
        elif crop_scheme_seed < self.train_spatial_crop_ligand_ratio + self.train_spatial_crop_interface_ratio and len(
                set(asym_id_ca.tolist())) > 1:
            # NOTE(review): the product only distinguishes pairs when one id
            # is 0 — an equality test (asym_id_ca[None] == asym_id_ca[:, None])
            # looks intended here; confirm before changing.
            chain_same = asym_id_ca[None] * asym_id_ca[:, None]
            dist = np.linalg.norm(x_gt_ca[:, None] - x_gt_ca[None], axis=-1)

            # Push "same-chain" pairs beyond the interface threshold.
            dist = dist + chain_same * 100
            mask = np.any(dist < self.train_spatial_crop_interface_threshold, axis=-1)
            if sum(mask) > 0:
                x_gt_sel = random.choice(x_gt_ca[mask])[None]
            else:
                x_gt_sel = random.choice(x_gt_ca)[None]
        # Spatial Crop: centre on a random token centre atom.
        else:
            x_gt_sel = random.choice(x_gt_ca)[None]
        dist = np.linalg.norm(x_gt_ca - x_gt_sel, axis=-1)
        token_idxs = np.argsort(dist)

        # Greedily keep whole conformers, nearest first, within both budgets.
        select_ccds_idx = []
        to_sum_atom = 0
        to_sum_token = 0
        for token_idx in token_idxs:
            ccd_idx = token_id_to_conformer_id[token_idx]
            ccd_chunk_size = token_id_to_ccd_chunk_sizes[token_idx]
            ccd_this_token = token_id_to_ccd[token_idx]
            if ccd_idx in select_ccds_idx:
                continue
            if to_sum_atom + ccd_chunk_size > self.atom_crop_size:
                break
            # Standard residues cost 1 token; ligand conformers cost one per atom.
            to_add_token = 1 if rc.is_standard(ccd_this_token) else ccd_chunk_size
            if to_sum_token + to_add_token > self.crop_size:
                break
            select_ccds_idx.append(ccd_idx)
            to_sum_atom += ccd_chunk_size
            to_sum_token += to_add_token

        # Mask every per-chain feature down to the selected conformers/atoms.
        ccd_all_id = 0
        crop_chains = []
        for chain_id in ordered_chain_ids:
            conformer_used_mask = []
            atom_used_mask = []
            ccds = []
            for ccd, chunk_size_this_ccd in zip(
                    all_chain_features[chain_id]["ccds"],
                    all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
            ):
                if ccd_all_id in select_ccds_idx:
                    ccds.append(ccd)
                    if chain_id not in crop_chains:
                        crop_chains.append(chain_id)
                conformer_used_mask.append(ccd_all_id in select_ccds_idx)
                atom_used_mask += [ccd_all_id in select_ccds_idx] * chunk_size_this_ccd
                ccd_all_id += 1
            conf_mask = np.array(conformer_used_mask).astype(np.bool_)
            atom_mask = np.array(atom_used_mask).astype(np.bool_)
            # Update All Chain Features
            all_chain_features[chain_id]["x_gt"] = all_chain_features[chain_id]["x_gt"][atom_mask]
            all_chain_features[chain_id]["atom_id_to_conformer_atom_id"] = \
                all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][atom_mask]
            all_chain_features[chain_id]["restype"] = all_chain_features[chain_id]["restype"][conf_mask]
            all_chain_features[chain_id]["residue_index"] = all_chain_features[chain_id]["residue_index"][conf_mask]
            all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = \
                all_chain_features[chain_id]["conformer_id_to_chunk_sizes"][conf_mask]
            # BUG Fix
            all_chain_features[chain_id]["key_res_feat"] = all_chain_features[chain_id]["key_res_feat"][conf_mask]
            all_chain_features[chain_id]["pocket_res_feat"] = all_chain_features[chain_id]["pocket_res_feat"][conf_mask]
            all_chain_features[chain_id]["is_key_res"] = all_chain_features[chain_id]["is_key_res"][conf_mask]
            all_chain_features[chain_id]["is_protein"] = all_chain_features[chain_id]["is_protein"][conf_mask]
            all_chain_features[chain_id]["is_short_poly"] = all_chain_features[chain_id]["is_short_poly"][conf_mask]
            all_chain_features[chain_id]["is_ligand"] = all_chain_features[chain_id]["is_ligand"][conf_mask]
            all_chain_features[chain_id]["asym_id"] = all_chain_features[chain_id]["asym_id"][conf_mask]
            all_chain_features[chain_id]["sym_id"] = all_chain_features[chain_id]["sym_id"][conf_mask]
            all_chain_features[chain_id]["entity_id"] = all_chain_features[chain_id]["entity_id"][conf_mask]

            all_chain_features[chain_id]["ccds"] = ccds
            if "msa" in all_chain_features[chain_id]:
                all_chain_features[chain_id]["msa"] = all_chain_features[chain_id]["msa"][:, conf_mask]
                all_chain_features[chain_id]["deletion_matrix"] = \
                    all_chain_features[chain_id]["deletion_matrix"][:, conf_mask]
            if "msa_all_seq" in all_chain_features[chain_id]:
                all_chain_features[chain_id]["msa_all_seq"] = all_chain_features[chain_id]["msa_all_seq"][:, conf_mask]
                all_chain_features[chain_id]["deletion_matrix_all_seq"] = \
                    all_chain_features[chain_id]["deletion_matrix_all_seq"][:, conf_mask]
        # Remove Unused Chains
        for chain_id in list(all_chain_features.keys()):
            if chain_id not in crop_chains:
                all_chain_features.pop(chain_id, None)
        return all_chain_features, infer_meta_data
544
+
545
    def _make_ccd_features(self, raw_feats, infer_meta_data):
        """Build atom-wise and token-wise indexing features from the CCD list.

        Walks the conformer (CCD) sequence once while maintaining running
        `atom_id` / `token_id` counters:
          * UNK conformers produce a single masked token with no atoms.
          * Standard residues produce one token spanning all their atoms.
          * Non-standard residues and ligands are tokenized per atom.

        Args:
            raw_feats: merged features; must contain "ccds",
                "atom_id_to_conformer_atom_id" (global atom id -> index into
                the reference conformer) and "conformer_id_to_chunk_sizes".
            infer_meta_data: carries "CONF_META_DATA", a per-CCD dict with
                "ref_atom_name_chars", "ref_feat", etc.

        Returns:
            dict of numpy arrays — atom-wise ("atom_id_to_conformer_id",
            "atom_id_to_token_id", "ref_feat", "ref_pos") and token-wise
            ("token_id_to_*", "s_mask") features.
        """
        CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
        ccds = raw_feats["ccds"]
        atom_id_to_conformer_atom_id = raw_feats["atom_id_to_conformer_atom_id"]
        conformer_id_to_chunk_sizes = raw_feats["conformer_id_to_chunk_sizes"]

        # Atomwise
        atom_id_to_conformer_id = []
        atom_id_to_token_id = []
        ref_feat = []

        # Tokenwise
        s_mask = []
        token_id_to_conformer_id = []
        token_id_to_chunk_sizes = []
        token_id_to_centre_atom_id = []
        token_id_to_pseudo_beta_atom_id = []

        token_id = 0
        atom_id = 0
        for conf_id, (ccd, ccd_atoms) in enumerate(zip(ccds, conformer_id_to_chunk_sizes)):
            conf_meta_data = CONF_META_DATA[ccd]
            # UNK Conformer: one masked token, no atoms consumed.
            if rc.is_unk(ccd):
                s_mask.append(0)
                token_id_to_chunk_sizes.append(0)
                token_id_to_conformer_id.append(conf_id)
                token_id_to_centre_atom_id.append(-1)  # -1 marks "no such atom"
                token_id_to_pseudo_beta_atom_id.append(-1)
                token_id += 1
            # Standard Residue: one token covering all of its atoms.
            elif rc.is_standard(ccd):
                inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
                atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
                ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
                token_id_to_conformer_id.append(conf_id)
                token_id_to_chunk_sizes.append(ccd_atoms.item())
                s_mask.append(1)
                for atom_id_this_ccd, atom_name in enumerate(atom_names):
                    # Update Atomwise Features
                    atom_id_to_conformer_id.append(conf_id)
                    atom_id_to_token_id.append(token_id)
                    # Record the special atom ids (token centre / pseudo-beta)
                    # by matching reference atom names for this CCD.
                    if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                        token_id_to_centre_atom_id.append(atom_id)
                    if atom_name == rc.standard_ccd_to_token_pseudo_beta_atom_name[ccd]:
                        token_id_to_pseudo_beta_atom_id.append(atom_id)
                    atom_id += 1
                token_id += 1
            # Non-standard Residue & Ligand: one token per atom; each atom is
            # its own centre and pseudo-beta atom.
            else:
                inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
                atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
                ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
                for atom_id_this_ccd, atom_name in enumerate(atom_names):
                    # Update Atomwise Features
                    atom_id_to_conformer_id.append(conf_id)
                    atom_id_to_token_id.append(token_id)
                    # Update Tokenwise Features
                    token_id_to_chunk_sizes.append(1)
                    token_id_to_conformer_id.append(conf_id)
                    s_mask.append(1)
                    token_id_to_centre_atom_id.append(atom_id)
                    token_id_to_pseudo_beta_atom_id.append(atom_id)
                    atom_id += 1
                    token_id += 1

        # NOTE(review): assumes at least one non-UNK conformer exists; an
        # all-UNK system would raise IndexError on ref_feat[0].
        if len(ref_feat) > 1:
            ref_feat = np.concatenate(ref_feat, axis=0).astype(np.float32)
        else:
            ref_feat = ref_feat[0].astype(np.float32)

        features = {
            # Atomwise
            "atom_id_to_conformer_id": np.array(atom_id_to_conformer_id, dtype=np.int64),
            "atom_id_to_token_id": np.array(atom_id_to_token_id, dtype=np.int64),
            "ref_feat": ref_feat,
            # Tokenwise
            "token_id_to_conformer_id": np.array(token_id_to_conformer_id, dtype=np.int64),
            "s_mask": np.array(s_mask, dtype=np.int64),
            "token_id_to_centre_atom_id": np.array(token_id_to_centre_atom_id, dtype=np.int64),
            "token_id_to_pseudo_beta_atom_id": np.array(token_id_to_pseudo_beta_atom_id, dtype=np.int64),
            "token_id_to_chunk_sizes": np.array(token_id_to_chunk_sizes, dtype=np.int64),
        }
        # The first three channels of ref_feat are the reference coordinates.
        features["ref_pos"] = features["ref_feat"][..., :3]
        return features
+ def pair_and_merge(self, all_chain_features, infer_meta_data):
634
+ CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"] # Dict
635
+ CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
636
+ ASYM_ID = infer_meta_data["ASYM_ID"]
637
+ homo_feats = {}
638
+
639
+ all_chain_ids = list(all_chain_features.keys())
640
+ if len(all_chain_ids) == 1 and CHAIN_CLASS[all_chain_ids[0]] == "ligand":
641
+ ordered_chain_ids = all_chain_ids
642
+ raw_feats = all_chain_features[all_chain_ids[0]]
643
+ raw_feats["msa"] = np.repeat(raw_feats["msa"][:1], 256, axis=0)
644
+ raw_feats["deletion_matrix"] = np.repeat(raw_feats["msa"][:1], 256, axis=0)
645
+ keys = list(raw_feats.keys())
646
+
647
+ for feature_name in keys:
648
+ if feature_name not in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index",
649
+ "conformer_id_to_chunk_sizes", "restype", "is_protein", "is_short_poly",
650
+ "is_ligand",
651
+ "asym_id", "sym_id", "entity_id", "msa", "deletion_matrix", "ccds",
652
+ "pocket_res_feat", "key_res_feat", "is_key_res"]:
653
+ raw_feats.pop(feature_name)
654
+
655
+ # Update Profile and Deletion Mean
656
+ msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
657
+ raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
658
+ del msa_one_hot
659
+ raw_feats["deletion_mean"] = (torch.atan(
660
+ torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
661
+ ) * (2.0 / torch.pi)).numpy()
662
+ else:
663
+
664
+ for chain_id in list(all_chain_features.keys()):
665
+ homo_feats[chain_id] = {
666
+ "asym_id": copy.deepcopy(all_chain_features[chain_id]["asym_id"]),
667
+ "sym_id": copy.deepcopy(all_chain_features[chain_id]["sym_id"]),
668
+ "entity_id": copy.deepcopy(all_chain_features[chain_id]["entity_id"]),
669
+ }
670
+ for chain_id in list(all_chain_features.keys()):
671
+ homo_feats[chain_id]["chain_class"] = all_chain_features[chain_id].pop("chain_class")
672
+ homo_feats[chain_id]["sequence_3"] = all_chain_features[chain_id].pop("sequence_3")
673
+ homo_feats[chain_id]["msa"] = all_chain_features[chain_id].pop("msa")
674
+ homo_feats[chain_id]["deletion_matrix"] = all_chain_features[chain_id].pop("deletion_matrix")
675
+ if "msa_all_seq" in all_chain_features[chain_id]:
676
+ homo_feats[chain_id]["msa_all_seq"] = all_chain_features[chain_id].pop("msa_all_seq")
677
+ homo_feats[chain_id]["deletion_matrix_all_seq"] = all_chain_features[chain_id].pop(
678
+ "deletion_matrix_all_seq")
679
+ homo_feats[chain_id]["msa_species_identifiers_all_seq"] = all_chain_features[chain_id].pop(
680
+ "msa_species_identifiers_all_seq")
681
+
682
+ # Initial raw feats with merged homo feats
683
+ raw_feats = pair_and_merge(homo_feats, is_homomer_or_monomer=False)
684
+
685
+ # Update Profile and Deletion Mean
686
+ msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
687
+ raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
688
+ del msa_one_hot
689
+ raw_feats["deletion_mean"] = (torch.atan(
690
+ torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
691
+ ) * (2.0 / torch.pi)).numpy()
692
+
693
+ # Merge no homo feats according to asym_id
694
+ ordered_asym_ids = []
695
+ for i in raw_feats["asym_id"]:
696
+ if i not in ordered_asym_ids:
697
+ ordered_asym_ids.append(i)
698
+ ordered_chain_ids = [ASYM_ID[i] for i in ordered_asym_ids]
699
+ for feature_name in ["chain_class", "sequence_3", "assembly_num_chains", "entity_mask", "seq_length",
700
+ "num_alignments"]:
701
+ raw_feats.pop(feature_name, None)
702
+
703
+ for feature_name in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index", "conformer_id_to_chunk_sizes",
704
+ "restype", "is_protein", "is_short_poly", "is_ligand", "pocket_res_feat",
705
+ "key_res_feat", "is_key_res"]:
706
+ raw_feats[feature_name] = np.concatenate([
707
+ all_chain_features[chain_id].pop(feature_name) for chain_id in ordered_chain_ids
708
+ ], axis=0)
709
+
710
+ # Conformerwise Chain Class
711
+ CHAIN_CLASS_NEW = []
712
+ for chain_id in ordered_chain_ids:
713
+ CHAIN_CLASS_NEW += [CHAIN_CLASS[chain_id]] * len(all_chain_features[chain_id]["ccds"])
714
+ infer_meta_data["CHAIN_CLASS"] = CHAIN_CLASS_NEW
715
+
716
+ raw_feats["ccds"] = reduce(add, [all_chain_features[chain_id].pop("ccds") for chain_id in ordered_chain_ids])
717
+
718
+ # Create Atomwise and Tokenwise Features
719
+ raw_feats.update(self._make_ccd_features(raw_feats, infer_meta_data))
720
+ if self.use_x_gt_ligand_as_ref_pos:
721
+ is_ligand_atom = raw_feats["is_ligand"][raw_feats["atom_id_to_conformer_id"]].astype(np.bool_)
722
+ raw_feats["ref_pos"][is_ligand_atom] = raw_feats["x_gt"][is_ligand_atom] - np.mean(
723
+ raw_feats["x_gt"][is_ligand_atom], axis=0, keepdims=True)
724
+
725
+ asym_id_conformerwise = copy.deepcopy(raw_feats["asym_id"])
726
+ residue_index_conformerwise = copy.deepcopy(raw_feats["residue_index"])
727
+
728
+ # Conformerwise to Tokenwise
729
+ token_id_to_conformer_id = raw_feats["token_id_to_conformer_id"]
730
+ for key in ["is_protein", "is_short_poly", "is_ligand", "residue_index", "restype", "asym_id", "entity_id",
731
+ "sym_id", "deletion_mean", "profile", "pocket_res_feat", "key_res_feat", "is_key_res"]:
732
+ raw_feats[key] = raw_feats[key][token_id_to_conformer_id]
733
+ for key in ["msa", "deletion_matrix"]:
734
+ if key in raw_feats:
735
+ raw_feats[key] = raw_feats[key][:, token_id_to_conformer_id]
736
+ ###################################################
737
+ # Centre Random Augmentation of ref pos #
738
+ ###################################################
739
+ # atom_id_to_token_id
740
+ # atom_id_to_conformer_id
741
+ raw_feats["ref_pos"] = centre_random_augmentation_np_apply(
742
+ raw_feats["ref_pos"], raw_feats["atom_id_to_conformer_id"]).astype(np.float32)
743
+ raw_feats["ref_feat"][:, :3] = raw_feats["ref_pos"]
744
+
745
+ ###################################################
746
+ # Create token pair features #
747
+ ###################################################
748
+ no_token = len(raw_feats["token_id_to_conformer_id"])
749
+ token_bonds = np.zeros([no_token, no_token], dtype=np.float32)
750
+ rel_tok_feat = np.zeros([no_token, no_token, 42], dtype=np.float32)
751
+ offset = 0
752
+ atom_offset = 0
753
+ for ccd, len_atoms in zip(
754
+ raw_feats["ccds"],
755
+ raw_feats["conformer_id_to_chunk_sizes"]
756
+ ):
757
+ if rc.is_standard(ccd) or rc.is_unk(ccd):
758
+ offset += 1
759
+ else:
760
+ len_atoms = len_atoms.item()
761
+ inner_atom_idx = raw_feats["atom_id_to_conformer_atom_id"][atom_offset:atom_offset + len_atoms]
762
+ token_bonds[offset:offset + len_atoms, offset:offset + len_atoms] = \
763
+ CONF_META_DATA[ccd]["token_bonds"][inner_atom_idx][:, inner_atom_idx]
764
+ rel_tok_feat[offset:offset + len_atoms, offset:offset + len_atoms] = \
765
+ CONF_META_DATA[ccd]["rel_tok_feat"][inner_atom_idx][:, inner_atom_idx]
766
+ offset += len_atoms
767
+ atom_offset += len_atoms
768
+ raw_feats["token_bonds"] = token_bonds.astype(np.float32)
769
+ raw_feats["token_bonds_feature"] = token_bonds.astype(np.float32)
770
+ raw_feats["rel_tok_feat"] = rel_tok_feat.astype(np.float32)
771
+ ###################################################
772
+ # Charility Augmentation #
773
+ ###################################################
774
+ if not self.inference_mode:
775
+ # TODO Charility probs
776
+ charility_seed = random.random()
777
+ if charility_seed < self.train_charility_augmentation_ratio:
778
+ ref_chirality = raw_feats["ref_feat"][:, 158:161]
779
+ ref_chirality_replace = np.zeros_like(ref_chirality)
780
+ ref_chirality_replace[:, 2] = 1
781
+
782
+ is_ligand_atom = raw_feats["is_ligand"][raw_feats["atom_id_to_token_id"]]
783
+ remove_charility = (np.random.randint(0, 2, [len(is_ligand_atom)]) * is_ligand_atom).astype(
784
+ np.bool_)
785
+ ref_chirality = np.where(remove_charility[:, None], ref_chirality_replace, ref_chirality)
786
+ raw_feats["ref_feat"][:, 158:161] = ref_chirality
787
+
788
+ # MASKS
789
+ raw_feats["x_exists"] = np.ones_like(raw_feats["x_gt"][..., 0]).astype(np.float32)
790
+ raw_feats["a_mask"] = raw_feats["x_exists"]
791
+ raw_feats["s_mask"] = np.ones_like(raw_feats["asym_id"]).astype(np.float32)
792
+ raw_feats["ref_space_uid"] = raw_feats["atom_id_to_conformer_id"]
793
+
794
+ # Write Infer Meta Data
795
+ infer_meta_data["ccds"] = raw_feats.pop("ccds")
796
+ infer_meta_data["atom_id_to_conformer_atom_id"] = raw_feats.pop("atom_id_to_conformer_atom_id")
797
+ infer_meta_data["residue_index"] = residue_index_conformerwise
798
+ infer_meta_data["asym_id"] = asym_id_conformerwise
799
+ infer_meta_data["conformer_id_to_chunk_sizes"] = raw_feats.pop("conformer_id_to_chunk_sizes")
800
+
801
+ return raw_feats, infer_meta_data
802
+
803
+ def make_feats(self, tensors):
804
+ # Target Feat
805
+ tensors["target_feat"] = torch.cat([
806
+ F.one_hot(tensors["restype"].long(), 32).float(),
807
+ tensors["profile"].float(),
808
+ tensors["deletion_mean"][..., None].float()
809
+ ], dim=-1)
810
+
811
+ if self.num_recycles is None:
812
+ # MSA Feat
813
+ inds = [0] + torch.randperm(len(tensors["msa"]))[:self.max_msa_clusters - 1].tolist()
814
+
815
+ tensors["msa"] = tensors["msa"][inds]
816
+ tensors["deletion_matrix"] = tensors["deletion_matrix"][inds]
817
+
818
+ has_deletion = torch.clamp(tensors["deletion_matrix"].float(), min=0., max=1.)
819
+ pi = torch.acos(torch.zeros(1, device=tensors["deletion_matrix"].device)) * 2
820
+ deletion_value = (torch.atan(tensors["deletion_matrix"] / 3.) * (2. / pi))
821
+ tensors["msa_feat"] = torch.cat([
822
+ F.one_hot(tensors["msa"].long(), 32).float(),
823
+ has_deletion[..., None].float(),
824
+ deletion_value[..., None].float(),
825
+ ], dim=-1)
826
+ else:
827
+ batch_msa_feat = []
828
+ for i in range(self.num_recycles):
829
+ inds = [0] + torch.randperm(len(tensors["msa"]))[:self.max_msa_clusters - 1].tolist()
830
+
831
+ tensors["msa"] = tensors["msa"][inds]
832
+ tensors["deletion_matrix"] = tensors["deletion_matrix"][inds]
833
+
834
+ has_deletion = torch.clamp(tensors["deletion_matrix"].float(), min=0., max=1.)
835
+ pi = torch.acos(torch.zeros(1, device=tensors["deletion_matrix"].device)) * 2
836
+ deletion_value = (torch.atan(tensors["deletion_matrix"] / 3.) * (2. / pi))
837
+ msa_feat = torch.cat([
838
+ F.one_hot(tensors["msa"].long(), 32).float(),
839
+ has_deletion[..., None].float(),
840
+ deletion_value[..., None].float(),
841
+ ], dim=-1)
842
+ batch_msa_feat.append(msa_feat)
843
+ tensors["msa_feat"] = batch_msa_feat[0]
844
+ tensors["batch_msa_feat"] = torch.stack(batch_msa_feat, dim=0)
845
+
846
+ tensors.pop("msa", None)
847
+ tensors.pop("deletion_mean", None)
848
+ tensors.pop("profile", None)
849
+ tensors.pop("deletion_matrix", None)
850
+
851
+ return tensors
852
+
853
+ def _make_token_bonds(self, tensors):
854
+ # Get Polymer-Ligand & Ligand-Ligand Within Conformer Token Bond
855
+ # Atomwise asym_id
856
+ asym_id = tensors["asym_id"][tensors["atom_id_to_token_id"]]
857
+ is_ligand = tensors["is_ligand"][tensors["atom_id_to_token_id"]]
858
+
859
+ x_gt = tensors["x_gt"]
860
+ a_mask = tensors["a_mask"]
861
+
862
+ # Get
863
+ atom_id_to_token_id = tensors["atom_id_to_token_id"]
864
+
865
+ num_token = len(tensors["asym_id"])
866
+ between_conformer_token_bonds = torch.zeros([num_token, num_token])
867
+
868
+ # create chainwise feature
869
+ asym_id_chain = []
870
+ asym_id_atom_offset = []
871
+ asym_id_is_ligand = []
872
+ for atom_offset, (a_id, i_id) in enumerate(zip(asym_id.tolist(), is_ligand.tolist())):
873
+ if len(asym_id_chain) == 0 or asym_id_chain[-1] != a_id:
874
+ asym_id_chain.append(a_id)
875
+ asym_id_atom_offset.append(atom_offset)
876
+ asym_id_is_ligand.append(i_id)
877
+
878
+ len_asym_id_chain = len(asym_id_chain)
879
+ if len_asym_id_chain >= 2:
880
+ for i in range(0, len_asym_id_chain - 1):
881
+ asym_id_i = asym_id_chain[i]
882
+ mask_i = asym_id == asym_id_i
883
+ x_gt_i = x_gt[mask_i]
884
+ a_mask_i = a_mask[mask_i]
885
+ for j in range(i + 1, len_asym_id_chain):
886
+ if not bool(asym_id_is_ligand[i]) and not bool(asym_id_is_ligand[j]):
887
+ continue
888
+ asym_id_j = asym_id_chain[j]
889
+ mask_j = asym_id == asym_id_j
890
+ x_gt_j = x_gt[mask_j]
891
+ a_mask_j = a_mask[mask_j]
892
+ dis_ij = torch.norm(x_gt_i[:, None, :] - x_gt_j[None, :, :], dim=-1)
893
+ dis_ij = dis_ij + (1 - a_mask_i[:, None] * a_mask_j[None]) * 1000
894
+ if torch.min(dis_ij) < self.token_bond_threshold:
895
+ ij = torch.argmin(dis_ij).item()
896
+ l_j = len(x_gt_j)
897
+ atom_i = int(ij // l_j) # raw
898
+ atom_j = int(ij % l_j) # col
899
+ global_atom_i = atom_i + asym_id_atom_offset[i]
900
+ global_atom_j = atom_j + asym_id_atom_offset[j]
901
+ token_i = atom_id_to_token_id[global_atom_i]
902
+ token_j = atom_id_to_token_id[global_atom_j]
903
+
904
+ between_conformer_token_bonds[token_i, token_j] = 1
905
+ between_conformer_token_bonds[token_j, token_i] = 1
906
+ token_bond_seed = random.random()
907
+ tensors["token_bonds"] = tensors["token_bonds"] + between_conformer_token_bonds
908
+ # Docking Indicate Token Bond
909
+ # if token_bond_seed >= 1:
910
+ # tensors["token_bonds_feature"] = tensors["token_bonds"]
911
+ return tensors
912
+
913
+ def _pad_to_size(self, tensors):
914
+
915
+ to_pad_atom = self.atom_crop_size - len(tensors["x_gt"])
916
+ to_pad_token = self.crop_size - len(tensors["residue_index"])
917
+ if to_pad_token > 0:
918
+ for k in ["restype", "residue_index", "is_protein", "is_short_poly", "is_ligand", "is_key_res",
919
+ "asym_id", "entity_id", "sym_id", "token_id_to_conformer_id", "s_mask",
920
+ "token_id_to_centre_atom_id", "token_id_to_pseudo_beta_atom_id", "token_id_to_chunk_sizes",
921
+ "pocket_res_feat"]:
922
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_token])
923
+ for k in ["target_feat", "msa_feat", "key_res_feat", "batch_msa_feat"]:
924
+ if k in tensors:
925
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_token])
926
+ for k in ["token_bonds", "token_bonds_feature"]:
927
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_token, 0, to_pad_token])
928
+ for k in ["rel_tok_feat"]:
929
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_token, 0, to_pad_token])
930
+ if to_pad_atom > 0:
931
+ for k in ["a_mask", "x_exists", "atom_id_to_conformer_id", "atom_id_to_token_id", "ref_space_uid"]:
932
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom])
933
+ for k in ["x_gt", "ref_feat", "ref_pos"]: # , "ref_pos_new"
934
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_atom])
935
+ # for k in ["z_mask"]: # , "ref_pos_new"
936
+ # tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom, 0, to_pad_atom])
937
+ # for k in ["conformer_mask_atom"]:
938
+ # tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom, 0, to_pad_atom])
939
+ # for k in ["rel_token_feat_atom"]:
940
+ # tensors[k] = torch.nn.functional.pad(tensors[k], [0,0,0, to_pad_atom, 0, to_pad_atom])
941
+ # rel_token_feat_atom
942
+ return tensors
943
+
944
    def get_template_feat(self, tensors):
        """Build the pseudo-template distogram feature from ground-truth coords.

        Produces a 39-bin distogram over pseudo-beta atom positions restricted
        to protein-protein pairs, concatenated with its pair mask as
        tensors["templ_feat"]; also sets the scalar tensors["t_mask"].

        Training: with probability (1 - train_use_template_ratio) the template
        is enabled (t_mask=1) under a random bert-style residue mask whose
        keep-rate is at least train_template_mask_max_ratio; otherwise
        t_mask=0 and the distogram is still computed under the full mask.
        Inference: the full template is always enabled.

        NOTE(review): the branch condition is
        `random.random() > self.train_use_template_ratio` — i.e. the ratio
        reads as the probability of *disabling* the template; confirm the
        intended semantics of the config name.
        """
        x_gt = tensors["x_gt"][tensors["token_id_to_pseudo_beta_atom_id"]]
        z_mask = tensors["z_mask"]
        asym_id = tensors["asym_id"]
        is_protein = tensors["is_protein"]
        # NOTE(review): chain_same is computed but currently unused.
        chain_same = (asym_id[None] == asym_id[:, None]).float()
        protein2d = is_protein[None] * is_protein[:, None]
        dgram = dgram_from_positions(x_gt, no_bins=39)
        # Zero out non-protein pairs and invalid (masked) token pairs.
        dgram = dgram * protein2d[..., None] * z_mask[..., None]

        if not self.inference_mode:
            if random.random() > self.train_use_template_ratio:
                tensors["t_mask"] = torch.tensor(1, dtype=torch.float32)
                # Random per-token keep mask; a pair survives only when both
                # of its tokens are kept.
                bert_mask = torch.rand([len(x_gt)]) > random.random() * (1 - self.train_template_mask_max_ratio)
                template_pseudo_beta_mask = (bert_mask[None] * bert_mask[:, None]) * z_mask * protein2d
            else:
                tensors["t_mask"] = torch.tensor(0, dtype=torch.float32)
                template_pseudo_beta_mask = z_mask * protein2d
        else:
            tensors["t_mask"] = torch.tensor(1, dtype=torch.float32)
            template_pseudo_beta_mask = z_mask * protein2d
        dgram = dgram * template_pseudo_beta_mask[..., None]
        # Final feature: distogram bins plus the pair mask channel (40 total).
        templ_feat = torch.cat([dgram, template_pseudo_beta_mask[..., None]], dim=-1)
        tensors["templ_feat"] = templ_feat.float()
        return tensors
+ def transform(self, raw_feats):
971
+ # np to tensor
972
+ tensors = dict()
973
+ for key in raw_feats.keys():
974
+ tensors[key] = torch.from_numpy(raw_feats[key])
975
+ # Make Target & MSA Feat
976
+ tensors = self.make_feats(tensors)
977
+
978
+ # Make Token Bond Feat
979
+ tensors = self._make_token_bonds(tensors)
980
+
981
+ # Padding
982
+ if not self.inference_mode:
983
+ tensors = self._pad_to_size(tensors)
984
+
985
+ # Mask
986
+ tensors["z_mask"] = tensors["s_mask"][None] * tensors["s_mask"][:, None]
987
+ tensors["ap_mask"] = tensors["a_mask"][None] * tensors["a_mask"][:, None]
988
+ tensors["is_dna"] = torch.zeros_like(tensors["is_protein"])
989
+ tensors["is_rna"] = torch.zeros_like(tensors["is_protein"])
990
+
991
+ # Template
992
+ tensors = self.get_template_feat(tensors)
993
+
994
+ # Correct Type
995
+ is_short_poly = tensors.pop("is_short_poly")
996
+ tensors["is_protein"] = tensors["is_protein"] + is_short_poly
997
+ tensors["is_ligand"] = tensors["is_ligand"] - is_short_poly
998
+ return tensors
999
+
1000
    # NOTE(review): scratch notes left by the author — they appear to describe
    # the inputs of load() below: per-chain residue_index values (roughly
    # 0-100) and the "ccds" lists, whose ligand entries are named like
    # "CCD<RES_ID>". TODO: confirm and fold into the load() docstring, or
    # remove.
+ def load(
1005
+ self,
1006
+ system_pkl_path, # Receptor chains: all_atom_positions pocket_res_feat Ligand_chains
1007
+ template_receptor_chain_ids=None, # ["A"]
1008
+ template_ligand_chain_ids=None, # ["1"]
1009
+ remove_receptor=False,
1010
+ remove_ligand=False, # True, CCD_META_DATA ref_mol
1011
+ smi=None, # "CCCCC"
1012
+ ):
1013
+ ##########################################################
1014
+ # Initialization of Configs #
1015
+ ##########################################################
1016
+ if self.inference_mode:
1017
+ pocket_type = self.infer_pocket_type
1018
+ pocket_cutoff = self.infer_pocket_cutoff
1019
+ pocket_dist_type = self.infer_pocket_dist_type
1020
+ use_pocket = self.infer_use_pocket
1021
+ use_key_res = self.infer_use_key_res
1022
+ else:
1023
+ pocket_type = random.choices(
1024
+ ["atom", "ca"],
1025
+ [self.train_pocket_type_atom_ratio, 1 - self.train_pocket_type_atom_ratio])
1026
+
1027
+ pocket_dist_type = random.choices(
1028
+ ["ligand", "ligand_cetre"],
1029
+ [self.train_pocket_dist_type_ligand_ratio, 1 - self.train_pocket_dist_type_ligand_ratio])
1030
+
1031
+ if pocket_dist_type == "ligand":
1032
+ pocket_cutoff = self.train_pocket_cutoff_ligand_min + random.random() * (
1033
+ self.train_pocket_cutoff_ligand_max - self.train_pocket_cutoff_ligand_min)
1034
+ else:
1035
+ pocket_cutoff = self.train_pocket_cutoff_ligand_centre_min + random.random() * (
1036
+ self.train_pocket_cutoff_ligand_centre_max - self.train_pocket_cutoff_ligand_centre_min)
1037
+
1038
+ use_pocket = random.random() < self.train_use_pocket_ratio
1039
+ use_key_res = random.random() < self.train_use_key_res_ratio
1040
+
1041
+ ##########################################################
1042
+ # Initialization of features #
1043
+ ##########################################################
1044
+ system_id = os.path.split(system_pkl_path)[1][:-7]
1045
+ all_chain_labels = {}
1046
+ all_chain_features = {}
1047
+
1048
+ CONF_META_DATA = {}
1049
+ ref_mol = None
1050
+ CHAIN_CLASS = {}
1051
+ SEQ3 = {}
1052
+ ASYM_ID = {}
1053
+
1054
+ ##########################################################
1055
+ # Load All Chain Labels #
1056
+ ##########################################################
1057
+ data = load_pkl(system_pkl_path)
1058
+ # print(data)
1059
+ if template_receptor_chain_ids is None:
1060
+ template_receptor_chain_ids = [chain_id for chain_id in data.keys() if not chain_id.isdigit()]
1061
+
1062
+ if template_ligand_chain_ids is None:
1063
+ template_ligand_chain_ids = [chain_id for chain_id in data.keys() if chain_id.isdigit()]
1064
+ # TODO: Save Ligand Centre for cropped screening
1065
+ # Calculate Pocket Residue According to Template receptor and ligand
1066
+ if not remove_receptor and len(template_ligand_chain_ids) > 0:
1067
+ for receptor_chain_id in template_receptor_chain_ids:
1068
+ ccds_receptor = data[receptor_chain_id]["ccds"]
1069
+ x_gt_receptor = data[receptor_chain_id]["all_atom_positions"]
1070
+ x_exists_receptor = data[receptor_chain_id]["all_atom_mask"]
1071
+ x_gt_this_receptor = []
1072
+ atom_id_to_ccd_id = []
1073
+ for ccd_id, (ccd, x_gt_ccd, x_exists_ccd) in enumerate(
1074
+ zip(ccds_receptor, x_gt_receptor, x_exists_receptor)):
1075
+
1076
+ if rc.is_standard(ccd):
1077
+ x_exists_ccd_bool = x_exists_ccd.astype(np.bool_)
1078
+ if x_exists_ccd_bool[1]: # CA exsits
1079
+ if pocket_type == "atom":
1080
+ num_atoms = sum(x_exists_ccd_bool)
1081
+ x_gt_this_receptor.append(x_gt_ccd[x_exists_ccd_bool])
1082
+ atom_id_to_ccd_id += num_atoms * [ccd_id]
1083
+ else:
1084
+ x_gt_this_receptor.append(x_gt_ccd[1][None])
1085
+ atom_id_to_ccd_id.append(ccd_id)
1086
+ x_gt_this_receptor = np.concatenate(x_gt_this_receptor, axis=0)
1087
+ atom_id_to_ccd_id = np.array(atom_id_to_ccd_id)
1088
+ used_ccd_ids = []
1089
+ for ligand_chain_id in template_ligand_chain_ids:
1090
+ x_gt_ligand = data[ligand_chain_id]["all_atom_positions"]
1091
+ x_exists_ligand = data[ligand_chain_id]["all_atom_mask"]
1092
+ x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)[
1093
+ np.concatenate(x_exists_ligand, axis=0).astype(np.bool_)]
1094
+ if pocket_dist_type == "ligand":
1095
+ used_ccd_bool = np.any(
1096
+ np.linalg.norm(x_gt_this_receptor[:, None] - x_gt_ligand[None], axis=-1) < pocket_cutoff,
1097
+ axis=-1)
1098
+ elif pocket_dist_type == "ligand_centre":
1099
+ x_mean = np.min(x_gt_ligand, axis=0, keepdims=True)
1100
+ used_ccd_bool = np.any(
1101
+ np.linalg.norm(x_gt_this_receptor[:, None] - x_mean[None], axis=-1) < pocket_cutoff,
1102
+ axis=-1)
1103
+ else:
1104
+ raise NotImplementedError()
1105
+ used_ccd_ids.append(atom_id_to_ccd_id[used_ccd_bool])
1106
+ used_ccd_ids = list(sorted(list(set(np.concatenate(used_ccd_ids, axis=0).tolist()))))
1107
+ pocket_res_feat = np.zeros([len(ccds_receptor)], dtype=np.float32)
1108
+ pocket_res_feat[used_ccd_ids] = 1.
1109
+ all_chain_labels[receptor_chain_id] = data[receptor_chain_id]
1110
+ all_chain_labels[receptor_chain_id]["pocket_res_feat"] = pocket_res_feat
1111
+
1112
+ if remove_ligand:
1113
+ if remove_receptor:
1114
+ assert smi is not None and self.inference_mode
1115
+ if smi is not None:
1116
+ all_chain_labels, CONF_META_DATA, ref_mol = self._update_smi(smi, all_chain_labels, CONF_META_DATA)
1117
+ else:
1118
+ assert smi is None
1119
+ for ligand_chain_id in template_ligand_chain_ids:
1120
+ all_chain_labels[ligand_chain_id] = data[ligand_chain_id]
1121
+ # For Benchmarking
1122
+ if len(template_ligand_chain_ids) == 1:
1123
+ ccds = all_chain_labels[template_ligand_chain_ids[0]]["ccds"]
1124
+ if len(ccds) == 1:
1125
+ # ref_mol = self.ccd_id_ref_mol[ccds[0]]
1126
+ ref_mol = self.ccd_id_meta_data[ccds[0]]["ref_mol"]
1127
+ ##########################################################
1128
+ # Init All Chain Features #
1129
+ ##########################################################
1130
+ for chain_id, chain_feature in all_chain_labels.items():
1131
+ ccds = chain_feature["ccds"]
1132
+ CONF_META_DATA = self._update_CONF_META_DATA(CONF_META_DATA, ccds)
1133
+ SEQ3[chain_id] = "-".join(ccds)
1134
+ chain_class = "protein" if not chain_id.isdigit() else "ligand"
1135
+ chain_feature["chain_class"] = chain_class
1136
+ # print(chain_id)
1137
+ all_chain_features[chain_id] = self._update_chain_feature(
1138
+ chain_feature,
1139
+ CONF_META_DATA,
1140
+ use_pocket,
1141
+ use_key_res,
1142
+ )
1143
+ CHAIN_CLASS[chain_id] = chain_class
1144
+ ##########################################################
1145
+ # Add Assembly Feature #
1146
+ ##########################################################
1147
+ all_chain_features, ASYM_ID = self._add_assembly_feature(all_chain_features, SEQ3, ASYM_ID)
1148
+
1149
+ infer_meta_data = {
1150
+ "CONF_META_DATA": CONF_META_DATA,
1151
+ "SEQ3": SEQ3,
1152
+ "ASYM_ID": ASYM_ID,
1153
+ "CHAIN_CLASS": CHAIN_CLASS,
1154
+ "ref_mol": ref_mol,
1155
+ "system_id": system_id,
1156
+ }
1157
+ ##########################################################
1158
+ # Cropping #
1159
+ ##########################################################
1160
+ # if not self.inference_mode:
1161
+ if self.crop_size is not None:
1162
+ all_chain_features, infer_meta_data = self._crop_all_chain_features(
1163
+ all_chain_features, infer_meta_data, crop_centre=None) # TODO: Add Cropping Centre
1164
+ ##########################################################
1165
+ # Pair & Merge #
1166
+ ##########################################################
1167
+ raw_feats, infer_meta_data = self.pair_and_merge(all_chain_features, infer_meta_data)
1168
+
1169
+ ##########################################################
1170
+ # Transform #
1171
+ ##########################################################
1172
+ tensors = self.transform(raw_feats)
1173
+ return tensors, infer_meta_data
1174
+
1175
+ def write_pdb(self, x_pred, fname, infer_meta_data, receptor_only=False, ligand_only=False):
1176
+ ccds = infer_meta_data["ccds"]
1177
+ atom_id_to_conformer_atom_id = infer_meta_data["atom_id_to_conformer_atom_id"]
1178
+ ccd_chunk_sizes = infer_meta_data["conformer_id_to_chunk_sizes"].tolist()
1179
+ CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]
1180
+ conf_meta_data = infer_meta_data["CONF_META_DATA"]
1181
+ residue_index = infer_meta_data["residue_index"].tolist()
1182
+ asym_id = infer_meta_data["asym_id"].tolist()
1183
+
1184
+ atom_lines = []
1185
+ atom_offset = 0
1186
+ for ccd_id, (ccd, chunk_size, res_id) in enumerate(zip(ccds, ccd_chunk_sizes, residue_index)):
1187
+ inner_atom_idx = atom_id_to_conformer_atom_id[atom_offset:atom_offset + chunk_size]
1188
+ atom_names = [conf_meta_data[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
1189
+ atom_elements = [PeriodicTable[conf_meta_data[ccd]["ref_element"][i]] for i in inner_atom_idx]
1190
+ chain_tag = PDB_CHAIN_IDS[int(asym_id[ccd_id])]
1191
+ record_type = "HETATM" if CHAIN_CLASS[ccd_id] == "ligand" else "ATOM"
1192
+
1193
+ for ccd_atom_idx, atom_name in enumerate(atom_names):
1194
+ x = x_pred[atom_offset]
1195
+ name = atom_name if len(atom_name) == 4 else f" {atom_name}"
1196
+ res_name_3 = ccd
1197
+ alt_loc = ""
1198
+ insertion_code = ""
1199
+ occupancy = 1.00
1200
+ element = atom_elements[ccd_atom_idx]
1201
+ # b_factor = torch.argmax(plddt[atom_offset],dim=-1).item()*2 +1
1202
+ b_factor = 70.
1203
+ charge = 0
1204
+ pos = x.tolist()
1205
+ atom_line = (
1206
+ f"{record_type:<6}{atom_offset + 1:>5} {name:<4}{alt_loc:>1}"
1207
+ f"{res_name_3.split()[0][-3:]:>3} {chain_tag:>1}"
1208
+ f"{res_id + 1:>4}{insertion_code:>1} "
1209
+ f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
1210
+ f"{occupancy:>6.2f}{b_factor:>6.2f} "
1211
+ f"{element:>2}{charge:>2}"
1212
+ )
1213
+ if receptor_only and not ligand_only:
1214
+ if record_type == "ATOM":
1215
+ atom_lines.append(atom_line)
1216
+ elif not receptor_only and ligand_only:
1217
+ if record_type == "HETATM":
1218
+ atom_lines.append(atom_line)
1219
+ elif not receptor_only and not ligand_only:
1220
+ atom_lines.append(atom_line)
1221
+ else:
1222
+ raise NotImplementedError()
1223
+ atom_offset += 1
1224
+ if atom_offset == len(atom_id_to_conformer_atom_id):
1225
+ break
1226
+ out = "\n".join(atom_lines)
1227
+ out = f"MODEL 1\n{out}\nTER\nENDMDL\nEND"
1228
+ dump_txt(out, fname)
1229
+
1230
    def write_pdb_block(self, x_pred, infer_meta_data, receptor_only=False, ligand_only=False):
        """Format predicted coordinates as a single-model PDB string.

        Args:
            x_pred: per-atom predicted coordinates, indexable as [atom_id] -> (3,).
            infer_meta_data: bookkeeping produced by load()/pair_and_merge():
                "ccds", "atom_id_to_conformer_atom_id",
                "conformer_id_to_chunk_sizes", "CHAIN_CLASS" (per conformer),
                "CONF_META_DATA", "residue_index", "asym_id".
            receptor_only: keep only ATOM records (polymer chains).
            ligand_only: keep only HETATM records (ligand chains).

        Returns:
            The full PDB content as a string (MODEL/TER/ENDMDL/END wrapped).

        Raises:
            NotImplementedError: if both receptor_only and ligand_only are set.
        """
        ccds = infer_meta_data["ccds"]
        atom_id_to_conformer_atom_id = infer_meta_data["atom_id_to_conformer_atom_id"]
        ccd_chunk_sizes = infer_meta_data["conformer_id_to_chunk_sizes"].tolist()
        CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]
        conf_meta_data = infer_meta_data["CONF_META_DATA"]
        residue_index = infer_meta_data["residue_index"].tolist()
        asym_id = infer_meta_data["asym_id"].tolist()

        atom_lines = []
        atom_offset = 0
        for ccd_id, (ccd, chunk_size, res_id) in enumerate(zip(ccds, ccd_chunk_sizes, residue_index)):
            # Map global atom ids back into this conformer's reference atoms.
            inner_atom_idx = atom_id_to_conformer_atom_id[atom_offset:atom_offset + chunk_size]
            atom_names = [conf_meta_data[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
            atom_elements = [PeriodicTable[conf_meta_data[ccd]["ref_element"][i]] for i in inner_atom_idx]
            chain_tag = PDB_CHAIN_IDS[int(asym_id[ccd_id])]
            # Ligand conformers are written as HETATM, polymers as ATOM.
            record_type = "HETATM" if CHAIN_CLASS[ccd_id] == "ligand" else "ATOM"

            for ccd_atom_idx, atom_name in enumerate(atom_names):
                x = x_pred[atom_offset]
                # PDB atom-name convention: 4-char names fill the column,
                # shorter names are shifted right by one space.
                name = atom_name if len(atom_name) == 4 else f" {atom_name}"
                res_name_3 = ccd
                alt_loc = ""
                insertion_code = ""
                occupancy = 1.00
                element = atom_elements[ccd_atom_idx]
                # Fixed placeholder B-factor; no per-atom confidence written.
                b_factor = 70.
                charge = 0
                pos = x.tolist()
                # Fixed-column PDB ATOM/HETATM record; residue names longer
                # than 3 characters are truncated to their last 3.
                atom_line = (
                    f"{record_type:<6}{atom_offset + 1:>5} {name:<4}{alt_loc:>1}"
                    f"{res_name_3.split()[0][-3:]:>3} {chain_tag:>1}"
                    f"{res_id + 1:>4}{insertion_code:>1}   "
                    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                    f"{element:>2}{charge:>2}"
                )
                # Record filtering: receptor_only keeps ATOM, ligand_only
                # keeps HETATM; both set is unsupported.
                if receptor_only and not ligand_only:
                    if record_type == "ATOM":
                        atom_lines.append(atom_line)
                elif not receptor_only and ligand_only:
                    if record_type == "HETATM":
                        atom_lines.append(atom_line)
                elif not receptor_only and not ligand_only:
                    atom_lines.append(atom_line)
                else:
                    raise NotImplementedError()
                atom_offset += 1
                # Stop once the (possibly cropped) atom list is exhausted;
                # any remaining conformers would slice empty anyway.
                if atom_offset == len(atom_id_to_conformer_atom_id):
                    break
        out = "\n".join(atom_lines)
        out = f"MODEL 1\n{out}\nTER\nENDMDL\nEND"
        return out
PhysDock/data/feature_loader_plinder.py ADDED
@@ -0,0 +1,1258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ############################
2
+ # 0.85 Receptor+Ligand
3
+ # 0.5 APO Template
4
+ # 0.5 HOLO Template
5
+ # 0.05 Protein (All APO or PRED)
6
+ # 0.1 Ligand
7
+ ############################
8
+ import copy
9
+ import os
10
+ import random
11
+ from functools import reduce
12
+ from operator import add
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ # Key Res
17
+ # Dynamic Cutoff
18
+
19
+ import numpy as np
20
+
21
+ from stdock.data.constants.PDBData import protein_letters_3to1_extended
22
+ from stdock.data.constants import restype_constants as rc
23
+ from stdock.utils.io_utils import convert_md5_string, load_json, load_pkl, dump_txt
24
+ from stdock.data.tools.feature_processing_multimer import pair_and_merge
25
+ from stdock.utils.tensor_utils import centre_random_augmentation_np_apply, dgram_from_positions, \
26
+ centre_random_augmentation_np_batch
27
+ from stdock.data.constants.periodic_table import PeriodicTable
28
+
29
+ PDB_CHAIN_IDS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
30
+
31
+
32
+ class FeatureLoader:
33
+ def __init__(
34
+ self,
35
+ # config,
36
+ token_crop_size=256,
37
+ atom_crop_size=256 * 8,
38
+ inference_mode=False,
39
+ ):
40
+ self.inference_mode = inference_mode
41
+ self.msa_features_path = "/2022133002/data/stfold-data-v5/features/msa_features/"
42
+ self.uniprot_msa_features_path = "/2022133002/data/stfold-data-v5/features/uniprot_msa_features/"
43
+
44
+ self.token_crop_size = token_crop_size
45
+ self.atom_crop_size = atom_crop_size
46
+ self.token_bond_threshold = 2.4
47
+
48
+ self.ccd_id_meta_data = load_pkl(
49
+ "/2022133002/projects/stdock/stdock_v9.5/scripts/ccd_meta_data_confs_chars.pkl.gz")
50
+
51
+ self.samples_path = "/2022133002/data/plinder/2024-06/v2/plinder_samples_raw_data_v2"
52
+ to_remove = set(load_json("/2022133002/projects/stdock/stdock_v9.6/scripts/to_remove.json"))
53
+ weights = load_json(
54
+ "/2022133002/projects/stdock/stdock_v9.5/scripts/cluster_scripts/train_samples_new_weight_seq.json")
55
+ self.splits = load_json("/2022133002/data/plinder/2024-06/v2/splits/splits.json")
56
+ self.used_sample_ids = [sample_id for sample_id in weights
57
+ if sample_id in self.splits and self.splits[sample_id] != "test"
58
+ and sample_id not in to_remove]
59
+ print("train samples", len(self.used_sample_ids))
60
+ # self.weights = np.array(list(weights.values()))
61
+ self.weights = np.array([weights[sample_id] for sample_id in self.used_sample_ids])
62
+
63
+ self.probabilities = torch.from_numpy(self.weights / self.weights.sum())
64
+
65
+ self.used_test_sample_ids = [sample_id for sample_id in weights
66
+ if sample_id in self.splits and self.splits[sample_id] == "test"
67
+ and sample_id not in to_remove]
68
+ print("test samples", len(self.used_test_sample_ids))
69
+ # self.weights = np.array(list(weights.values()))
70
+ self.test_weights = np.array([weights[sample_id] for sample_id in self.used_test_sample_ids])
71
+
72
+ self.test_probabilities = torch.from_numpy(self.test_weights / self.test_weights.sum())
73
+
74
+ def _update_CONF_META_DATA(self, CONF_META_DATA, ccds):
75
+ for ccd in ccds:
76
+ if ccd in CONF_META_DATA:
77
+ continue
78
+ ccd_features = self.ccd_id_meta_data[ccd]
79
+ ref_pos = ccd_features["ref_pos"]
80
+
81
+ ref_pos = ref_pos - np.mean(ref_pos, axis=0, keepdims=True)
82
+
83
+ CONF_META_DATA[ccd] = {
84
+ "ref_feat": np.concatenate([
85
+ ref_pos,
86
+ ccd_features["ref_charge"][..., None],
87
+ rc.eye_128[ccd_features["ref_element"]].astype(np.float32),
88
+ ccd_features["ref_is_aromatic"].astype(np.float32)[..., None],
89
+ rc.eye_9[ccd_features["ref_degree"]].astype(np.float32),
90
+ rc.eye_7[ccd_features["ref_hybridization"]].astype(np.float32),
91
+ rc.eye_9[ccd_features["ref_implicit_valence"]].astype(np.float32),
92
+ rc.eye_3[ccd_features["ref_chirality"]].astype(np.float32),
93
+ ccd_features["ref_in_ring_of_3"].astype(np.float32)[..., None],
94
+ ccd_features["ref_in_ring_of_4"].astype(np.float32)[..., None],
95
+ ccd_features["ref_in_ring_of_5"].astype(np.float32)[..., None],
96
+ ccd_features["ref_in_ring_of_6"].astype(np.float32)[..., None],
97
+ ccd_features["ref_in_ring_of_7"].astype(np.float32)[..., None],
98
+ ccd_features["ref_in_ring_of_8"].astype(np.float32)[..., None],
99
+ ], axis=-1),
100
+ "rel_tok_feat": np.concatenate([
101
+ rc.eye_32[ccd_features["d_token"]].astype(np.float32),
102
+ rc.eye_5[ccd_features["bond_type"]].astype(np.float32),
103
+ ccd_features["token_bonds"].astype(np.float32)[..., None],
104
+ ccd_features["bond_as_double"].astype(np.float32)[..., None],
105
+ ccd_features["bond_in_ring"].astype(np.float32)[..., None],
106
+ ccd_features["bond_is_conjugated"].astype(np.float32)[..., None],
107
+ ccd_features["bond_is_aromatic"].astype(np.float32)[..., None],
108
+ ], axis=-1),
109
+ "ref_atom_name_chars": ccd_features["ref_atom_name_chars"],
110
+ "ref_element": ccd_features["ref_element"],
111
+ "token_bonds": ccd_features["token_bonds"],
112
+
113
+ }
114
+ if not rc.is_standard(ccd):
115
+ conformers = ccd_features["conformers"]
116
+ if conformers is None:
117
+ conformers = np.repeat(ccd_features["ref_pos"][None], 32, axis=0)
118
+ else:
119
+ conformers = np.stack([random.choice(ccd_features["conformers"]) for i in range(32)], axis=0)
120
+ CONF_META_DATA[ccd]["batch_ref_pos"] = centre_random_augmentation_np_batch(conformers)
121
+
122
+ return CONF_META_DATA
123
+
124
+ def _update_CONF_META_DATA_ligand(self, CONF_META_DATA, sequence_3, ccd_features):
125
+ ccds = sequence_3.split("-")
126
+ # ccd_features = self.ccd_id_meta_data[ccd]
127
+ for ccd in ccds:
128
+ CONF_META_DATA[ccd] = {
129
+ "ref_feat": np.concatenate([
130
+ ccd_features["ref_pos"],
131
+ ccd_features["ref_charge"][..., None],
132
+ rc.eye_128[ccd_features["ref_element"]].astype(np.float32),
133
+ ccd_features["ref_is_aromatic"].astype(np.float32)[..., None],
134
+ rc.eye_9[ccd_features["ref_degree"]].astype(np.float32),
135
+ rc.eye_7[ccd_features["ref_hybridization"]].astype(np.float32),
136
+ rc.eye_9[ccd_features["ref_implicit_valence"]].astype(np.float32),
137
+ rc.eye_3[ccd_features["ref_chirality"]].astype(np.float32),
138
+ ccd_features["ref_in_ring_of_3"].astype(np.float32)[..., None],
139
+ ccd_features["ref_in_ring_of_4"].astype(np.float32)[..., None],
140
+ ccd_features["ref_in_ring_of_5"].astype(np.float32)[..., None],
141
+ ccd_features["ref_in_ring_of_6"].astype(np.float32)[..., None],
142
+ ccd_features["ref_in_ring_of_7"].astype(np.float32)[..., None],
143
+ ccd_features["ref_in_ring_of_8"].astype(np.float32)[..., None],
144
+ ], axis=-1),
145
+ "rel_tok_feat": np.concatenate([
146
+ rc.eye_32[ccd_features["d_token"]].astype(np.float32),
147
+ rc.eye_5[ccd_features["bond_type"]].astype(np.float32),
148
+ ccd_features["token_bonds"].astype(np.float32)[..., None],
149
+ ccd_features["bond_as_double"].astype(np.float32)[..., None],
150
+ ccd_features["bond_in_ring"].astype(np.float32)[..., None],
151
+ ccd_features["bond_is_conjugated"].astype(np.float32)[..., None],
152
+ ccd_features["bond_is_aromatic"].astype(np.float32)[..., None],
153
+ ], axis=-1),
154
+ "ref_atom_name_chars": ccd_features["ref_atom_name_chars"],
155
+ "ref_element": ccd_features["ref_element"],
156
+ "token_bonds": ccd_features["token_bonds"],
157
+
158
+ }
159
+ if not rc.is_standard(ccd):
160
+ conformers = ccd_features["conformers"]
161
+ if conformers is None:
162
+ conformers = np.repeat(ccd_features["ref_pos"][None], 32, axis=0)
163
+ else:
164
+ conformers = np.stack([random.choice(ccd_features["conformers"]) for i in range(32)], axis=0)
165
+ CONF_META_DATA[ccd]["batch_ref_pos"] = centre_random_augmentation_np_batch(conformers)
166
+
167
+ return CONF_META_DATA
168
+
169
+ def _update_chain_feature(self, chain_feature, CONF_META_DATA):
170
+
171
+ ccds_ori = chain_feature["ccds"]
172
+ chain_class = chain_feature["chain_class"]
173
+ if chain_class == "protein":
174
+
175
+ sequence = "".join([protein_letters_3to1_extended.get(ccd, "X") for ccd in ccds_ori])
176
+ md5 = convert_md5_string(f"protein:{sequence}")
177
+
178
+ chain_feature.update(
179
+ load_pkl(os.path.join(self.msa_features_path, f"{md5}.pkl.gz"))
180
+ )
181
+ chain_feature.update(
182
+ load_pkl(os.path.join(self.uniprot_msa_features_path, f"{md5}.pkl.gz"))
183
+ )
184
+ else:
185
+ chain_feature["msa"] = np.array([[rc.standard_ccds.index(ccd)
186
+ if ccd in rc.standard_ccds else 20 for ccd in ccds_ori]] * 2,
187
+ dtype=np.int8)
188
+ chain_feature["deletion_matrix"] = np.zeros_like(chain_feature["msa"])
189
+
190
+ # Merge Key Res Feat & Augmentation
191
+ if "salt bridges" in chain_feature:
192
+ key_res_feat = np.stack([
193
+ chain_feature["salt bridges"],
194
+ chain_feature["pi-cation interactions"],
195
+ chain_feature["hydrophobic interactions"],
196
+ chain_feature["pi-stacking"],
197
+ chain_feature["hydrogen bonds"],
198
+ chain_feature["metal complexes"],
199
+ np.zeros_like(chain_feature["salt bridges"]),
200
+ ], axis=-1).astype(np.float32)
201
+ else:
202
+ key_res_feat = np.zeros(
203
+ [len(ccds_ori), 7], dtype=np.float32
204
+ )
205
+ is_key_res = np.any(key_res_feat.astype(np.bool_), axis=-1).astype(np.float32)
206
+ # Augmentation
207
+ if not self.inference_mode:
208
+ key_res_feat = key_res_feat * (np.random.random([len(ccds_ori), 7]) < 0.5)
209
+ # else:
210
+ # # TODO: No key res in inference mode
211
+ # key_res_feat = key_res_feat * 0
212
+ # Atom
213
+ x_gt = []
214
+ atom_id_to_conformer_atom_id = []
215
+
216
+ # Conformer
217
+ conformer_id_to_chunk_sizes = []
218
+ residue_index = []
219
+ restype = []
220
+ ccds = []
221
+
222
+ conformer_exists = []
223
+
224
+ for c_id, ccd in enumerate(chain_feature["ccds"]):
225
+ no_atom_this_conf = len(CONF_META_DATA[ccd]["ref_feat"])
226
+ conformer_atom_ids_this_conf = np.arange(no_atom_this_conf)
227
+ x_gt_this_conf = chain_feature["all_atom_positions"][c_id]
228
+ x_exists_this_conf = chain_feature["all_atom_mask"][c_id].astype(np.bool_)
229
+
230
+ # TODO DEBUG
231
+ #
232
+ conformer_exist = np.any(x_exists_this_conf).item()
233
+ if rc.is_standard(ccd):
234
+ conformer_exist = np.sum(x_exists_this_conf).item() > len(x_exists_this_conf) - 2
235
+ # conformer_exist = conformer_exist and x_exists_this_conf[1]
236
+ # if ccd != "GLY":
237
+ # conformer_exist = conformer_exist and x_exists_this_conf[4]
238
+
239
+ conformer_exists.append(conformer_exist)
240
+ if conformer_exist:
241
+ # Atomwise
242
+ x_gt.append(x_gt_this_conf[x_exists_this_conf])
243
+ atom_id_to_conformer_atom_id.append(conformer_atom_ids_this_conf[x_exists_this_conf])
244
+ # Tokenwise
245
+ residue_index.append(c_id)
246
+ conformer_id_to_chunk_sizes.append(np.sum(x_exists_this_conf).item())
247
+ restype.append(rc.standard_ccds.index(ccd) if ccd in rc.standard_ccds else 20)
248
+ ccds.append(ccd)
249
+ x_gt = np.concatenate(x_gt, axis=0)
250
+ atom_id_to_conformer_atom_id = np.concatenate(atom_id_to_conformer_atom_id, axis=0, dtype=np.int32)
251
+ residue_index = np.array(residue_index, dtype=np.int64)
252
+ conformer_id_to_chunk_sizes = np.array(conformer_id_to_chunk_sizes, dtype=np.int64)
253
+ restype = np.array(restype, dtype=np.int64)
254
+
255
+ conformer_exists = np.array(conformer_exists, dtype=np.bool_)
256
+
257
+ chain_feature_update = {
258
+ "x_gt": x_gt,
259
+ "atom_id_to_conformer_atom_id": atom_id_to_conformer_atom_id,
260
+ "residue_index": residue_index,
261
+ "conformer_id_to_chunk_sizes": conformer_id_to_chunk_sizes,
262
+ "restype": restype,
263
+ "ccds": ccds,
264
+ "msa": chain_feature["msa"][:, conformer_exists],
265
+ "deletion_matrix": chain_feature["deletion_matrix"][:, conformer_exists],
266
+ "chain_class": chain_class,
267
+ "key_res_feat": key_res_feat[conformer_exists],
268
+ "is_key_res": is_key_res[conformer_exists],
269
+ }
270
+
271
+ chain_feature_update["is_protein"] = np.array([chain_class == "protein"] * len(ccds)).astype(np.float32)
272
+ chain_feature_update["is_ligand"] = np.array([chain_class != "protein"] * len(ccds)).astype(np.float32)
273
+ # Assert Short Poly Chain like peptide
274
+ chain_feature_update["is_short_poly"] = np.array(
275
+ [chain_class != "protein" and len(ccds) >= 2 and rc.is_standard(ccd) for ccd in ccds]
276
+ ).astype(np.float32)
277
+
278
+ if "msa_all_seq" in chain_feature:
279
+ chain_feature_update["msa_all_seq"] = chain_feature["msa_all_seq"][:, conformer_exists]
280
+ chain_feature_update["deletion_matrix_all_seq"] = \
281
+ chain_feature["deletion_matrix_all_seq"][:, conformer_exists]
282
+ chain_feature_update["msa_species_identifiers_all_seq"] = chain_feature["msa_species_identifiers_all_seq"]
283
+ del chain_feature
284
+ return chain_feature_update
285
+
286
+ def _add_assembly_feature(self, all_chain_features, SEQ3):
287
+ entities = {}
288
+ for chain_id, seq3 in SEQ3.items():
289
+ if seq3 not in entities:
290
+ entities[seq3] = [chain_id]
291
+ else:
292
+ entities[seq3].append(chain_id)
293
+
294
+ asym_id = 0
295
+ ASYM_ID = {}
296
+ for entity_id, seq3 in enumerate(list(entities.keys())):
297
+ chain_ids = copy.deepcopy(entities[seq3])
298
+ if not self.inference_mode:
299
+ # sym_id augmentation
300
+ random.shuffle(chain_ids)
301
+ for sym_id, chain_id in enumerate(chain_ids):
302
+ num_conformers = len(all_chain_features[chain_id]["ccds"])
303
+ all_chain_features[chain_id]["asym_id"] = \
304
+ np.full([num_conformers], fill_value=asym_id, dtype=np.int32)
305
+ all_chain_features[chain_id]["sym_id"] = \
306
+ np.full([num_conformers], fill_value=sym_id, dtype=np.int32)
307
+ all_chain_features[chain_id]["entity_id"] = \
308
+ np.full([num_conformers], fill_value=entity_id, dtype=np.int32)
309
+
310
+ all_chain_features[chain_id]["sequence_3"] = seq3
311
+ ASYM_ID[asym_id] = chain_id
312
+
313
+ asym_id += 1
314
+ return all_chain_features, ASYM_ID
315
+
316
+ def load_all_chain_features(self, all_chain_labels):
317
+ all_chain_features = {}
318
+ CONF_META_DATA = {}
319
+ SEQ3 = {}
320
+ CHAIN_CLASS = {}
321
+
322
+ for chain_id, chain_feature in all_chain_labels.items():
323
+ ccds = chain_feature["ccds"]
324
+ CONF_META_DATA = self._update_CONF_META_DATA(CONF_META_DATA, ccds)
325
+ SEQ3[chain_id] = "-".join(ccds)
326
+ chain_class = "protein" if not chain_id.isdigit() else "ligand"
327
+ chain_feature["chain_class"] = chain_class
328
+
329
+ all_chain_features[chain_id] = self._update_chain_feature(
330
+ chain_feature,
331
+ CONF_META_DATA
332
+ )
333
+ CHAIN_CLASS[chain_id] = chain_class
334
+
335
+ all_chain_features, ASYM_ID = self._add_assembly_feature(all_chain_features, SEQ3)
336
+
337
+ infer_meta_data = {
338
+ "CONF_META_DATA": CONF_META_DATA,
339
+ "SEQ3": SEQ3,
340
+ "ASYM_ID": ASYM_ID,
341
+ "CHAIN_CLASS": CHAIN_CLASS
342
+ }
343
+
344
+ return all_chain_features, infer_meta_data
345
+
346
+ def _spatial_crop_v2(self, all_chain_features, infer_meta_data):
347
+ CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
348
+ ordered_chain_ids = list(all_chain_features.keys())
349
+
350
+ x_gt = np.concatenate([all_chain_features[chain_id]["x_gt"] for chain_id in ordered_chain_ids], axis=0)
351
+ # asym_id = np.concatenate([all_chain_features[chain_id]["asym_id"] for chain_id in ordered_chain_ids], axis=0)
352
+
353
+ token_id_to_centre_atom_id = []
354
+ token_id_to_conformer_id = []
355
+ token_id_to_ccd_chunk_sizes = []
356
+ token_id_to_ccd = []
357
+ asym_id_ca = []
358
+ token_id = 0
359
+ atom_id = 0
360
+ conf_id = 0
361
+ x_gt_ligand = []
362
+ for chain_id in ordered_chain_ids:
363
+ if chain_id.isdigit() and len(all_chain_features[chain_id]["ccds"]) == 1:
364
+ x_gt_ligand.append(all_chain_features[chain_id]["x_gt"])
365
+ atom_offset = 0
366
+ for ccd, chunk_size_this_ccd, asym_id in zip(
367
+ all_chain_features[chain_id]["ccds"],
368
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
369
+ all_chain_features[chain_id]["asym_id"],
370
+ ):
371
+ inner_atom_idx = all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][
372
+ atom_offset:atom_offset + chunk_size_this_ccd]
373
+ atom_names = [CONF_META_DATA[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
374
+ if rc.is_standard(ccd):
375
+
376
+ for atom_id_this_ccd, atom_name in enumerate(atom_names):
377
+ if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
378
+ token_id_to_centre_atom_id.append(atom_id)
379
+ token_id_to_conformer_id.append(conf_id)
380
+ token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
381
+ token_id_to_ccd.append(ccd)
382
+ asym_id_ca.append(asym_id)
383
+ atom_id += 1
384
+ token_id += 1
385
+
386
+ else:
387
+ for atom_id_this_ccd, atom_name in enumerate(atom_names):
388
+ token_id_to_centre_atom_id.append(atom_id)
389
+ token_id_to_conformer_id.append(conf_id)
390
+ token_id_to_ccd_chunk_sizes.append(chunk_size_this_ccd)
391
+ token_id_to_ccd.append(ccd)
392
+ asym_id_ca.append(asym_id)
393
+ atom_id += 1
394
+ token_id += 1
395
+ atom_offset += chunk_size_this_ccd
396
+ conf_id += 1
397
+
398
+ x_gt_ca = x_gt[token_id_to_centre_atom_id]
399
+ asym_id_ca = np.array(asym_id_ca)
400
+
401
+ crop_scheme_seed = random.random()
402
+ # Spatial Crop Ligand
403
+
404
+ if crop_scheme_seed < (0.6 if not self.inference_mode else 1.0) and len(x_gt_ligand) > 0:
405
+ x_gt_ligand = np.concatenate(x_gt_ligand, axis=0)
406
+ x_gt_sel = random.choice(x_gt_ligand)[None]
407
+ # Spatial Crop Interface
408
+ elif crop_scheme_seed < 0.8 and len(set(asym_id_ca.tolist())) > 1:
409
+ chain_same = asym_id_ca[None] * asym_id_ca[:, None]
410
+ dist = np.linalg.norm(x_gt_ca[:, None] - x_gt_ca[None], axis=-1)
411
+
412
+ dist = dist + chain_same * 100
413
+ # interface_threshold
414
+ mask = np.any(dist < 15, axis=-1)
415
+ if sum(mask) > 0:
416
+ x_gt_sel = random.choice(x_gt_ca[mask])[None]
417
+ else:
418
+ x_gt_sel = random.choice(x_gt_ca)[None]
419
+ # Spatial Crop
420
+ else:
421
+ x_gt_sel = random.choice(x_gt_ca)[None]
422
+ dist = np.linalg.norm(x_gt_ca - x_gt_sel, axis=-1)
423
+ token_idxs = np.argsort(dist)
424
+
425
+ select_ccds_idx = []
426
+ to_sum_atom = 0
427
+ to_sum_token = 0
428
+ for token_idx in token_idxs:
429
+ ccd_idx = token_id_to_conformer_id[token_idx]
430
+ ccd_chunk_size = token_id_to_ccd_chunk_sizes[token_idx]
431
+ ccd_this_token = token_id_to_ccd[token_idx]
432
+ if ccd_idx in select_ccds_idx:
433
+ continue
434
+ if to_sum_atom + ccd_chunk_size > self.atom_crop_size:
435
+ break
436
+ to_add_token = 1 if rc.is_standard(ccd_this_token) else ccd_chunk_size
437
+ if to_sum_token + to_add_token > self.token_crop_size:
438
+ break
439
+ select_ccds_idx.append(ccd_idx)
440
+ to_sum_atom += ccd_chunk_size
441
+ to_sum_token += to_add_token
442
+
443
+ ccd_all_id = 0
444
+ crop_chains = []
445
+ for chain_id in ordered_chain_ids:
446
+ conformer_used_mask = []
447
+ atom_used_mask = []
448
+ ccds = []
449
+ for ccd, chunk_size_this_ccd in zip(
450
+ all_chain_features[chain_id]["ccds"],
451
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
452
+ ):
453
+ if ccd_all_id in select_ccds_idx:
454
+ ccds.append(ccd)
455
+ if chain_id not in crop_chains:
456
+ crop_chains.append(chain_id)
457
+ conformer_used_mask.append(ccd_all_id in select_ccds_idx)
458
+ atom_used_mask += [ccd_all_id in select_ccds_idx] * chunk_size_this_ccd
459
+ ccd_all_id += 1
460
+ conf_mask = np.array(conformer_used_mask).astype(np.bool_)
461
+ atom_mask = np.array(atom_used_mask).astype(np.bool_)
462
+ # Update All Chain Features
463
+ all_chain_features[chain_id]["x_gt"] = all_chain_features[chain_id]["x_gt"][atom_mask]
464
+ all_chain_features[chain_id]["atom_id_to_conformer_atom_id"] = \
465
+ all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][atom_mask]
466
+ all_chain_features[chain_id]["restype"] = all_chain_features[chain_id]["restype"][conf_mask]
467
+ all_chain_features[chain_id]["residue_index"] = all_chain_features[chain_id]["residue_index"][conf_mask]
468
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = \
469
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"][conf_mask]
470
+ # BUG Fix
471
+ all_chain_features[chain_id]["key_res_feat"] = all_chain_features[chain_id]["key_res_feat"][conf_mask]
472
+ all_chain_features[chain_id]["is_key_res"] = all_chain_features[chain_id]["is_key_res"][conf_mask]
473
+ all_chain_features[chain_id]["is_protein"] = all_chain_features[chain_id]["is_protein"][conf_mask]
474
+ all_chain_features[chain_id]["is_short_poly"] = all_chain_features[chain_id]["is_short_poly"][conf_mask]
475
+ all_chain_features[chain_id]["is_ligand"] = all_chain_features[chain_id]["is_ligand"][conf_mask]
476
+ all_chain_features[chain_id]["asym_id"] = all_chain_features[chain_id]["asym_id"][conf_mask]
477
+ all_chain_features[chain_id]["sym_id"] = all_chain_features[chain_id]["sym_id"][conf_mask]
478
+ all_chain_features[chain_id]["entity_id"] = all_chain_features[chain_id]["entity_id"][conf_mask]
479
+
480
+ all_chain_features[chain_id]["ccds"] = ccds
481
+ if "msa" in all_chain_features[chain_id]:
482
+ all_chain_features[chain_id]["msa"] = all_chain_features[chain_id]["msa"][:, conf_mask]
483
+ all_chain_features[chain_id]["deletion_matrix"] = \
484
+ all_chain_features[chain_id]["deletion_matrix"][:, conf_mask]
485
+ if "msa_all_seq" in all_chain_features[chain_id]:
486
+ all_chain_features[chain_id]["msa_all_seq"] = all_chain_features[chain_id]["msa_all_seq"][:, conf_mask]
487
+ all_chain_features[chain_id]["deletion_matrix_all_seq"] = \
488
+ all_chain_features[chain_id]["deletion_matrix_all_seq"][:, conf_mask]
489
+ # Remove Unused Chains
490
+ for chain_id in list(all_chain_features.keys()):
491
+ if chain_id not in crop_chains:
492
+ all_chain_features.pop(chain_id, None)
493
+
494
+ return all_chain_features
495
+
496
+ def _spatial_crop(self, all_chain_features):
497
+
498
+ ordered_chain_ids = list(all_chain_features.keys())
499
+ atom_id_to_ccd_id = []
500
+ atom_id_to_ccd_chunk_sizes = []
501
+ atom_id_to_ccd = []
502
+
503
+ ccd_all_id = 0
504
+ for chain_id in ordered_chain_ids:
505
+ for ccd, chunk_size_this_ccd in zip(
506
+ all_chain_features[chain_id]["ccds"],
507
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
508
+ ):
509
+ atom_id_to_ccd_id += [ccd_all_id] * chunk_size_this_ccd
510
+ atom_id_to_ccd_chunk_sizes += [chunk_size_this_ccd] * chunk_size_this_ccd
511
+ atom_id_to_ccd += [ccd] * chunk_size_this_ccd
512
+ ccd_all_id += 1
513
+
514
+ to_sum_atom = 0
515
+ to_sum_token = 0
516
+ x_gt = np.concatenate([all_chain_features[chain_id]["x_gt"] for chain_id in ordered_chain_ids], axis=0)
517
+
518
+ spatial_crop_ratio = 0.3 if not self.inference_mode else 0
519
+ if random.random() < spatial_crop_ratio or len(ordered_chain_ids) == 1:
520
+ x_gt_sel = random.choice(x_gt)[None]
521
+ else:
522
+ asym_id = np.array(reduce(add, [
523
+ [asym_id] * len(all_chain_features[chain_id]["x_gt"])
524
+ for asym_id, chain_id in enumerate(ordered_chain_ids)
525
+ ]))
526
+ chain_same = asym_id[None] * asym_id[:, None]
527
+ dist = np.linalg.norm(x_gt[:, None] - x_gt[None], axis=-1)
528
+
529
+ dist = dist + chain_same * 100
530
+ mask = np.any(dist < 4, axis=-1)
531
+ if sum(mask) > 0:
532
+ x_gt_ = x_gt[mask]
533
+ x_gt_sel = random.choice(x_gt)[None]
534
+ dist = np.linalg.norm(x_gt - x_gt_sel, axis=-1)
535
+ atom_idxs = np.argsort(dist)
536
+ select_ccds_idx = []
537
+ for atom_idx in atom_idxs:
538
+ ccd_idx = atom_id_to_ccd_id[atom_idx]
539
+ ccd_chunk_size = atom_id_to_ccd_chunk_sizes[atom_idx]
540
+ ccd_this_atom = atom_id_to_ccd[atom_idx]
541
+ if ccd_idx in select_ccds_idx:
542
+ continue
543
+ if to_sum_atom + ccd_chunk_size > self.atom_crop_size:
544
+ break
545
+ to_add_token = 1 if rc.is_standard(ccd_this_atom) else ccd_chunk_size
546
+ if to_sum_token + to_add_token > self.token_crop_size:
547
+ break
548
+ select_ccds_idx.append(ccd_idx)
549
+ to_sum_atom += ccd_chunk_size
550
+ to_sum_token += to_add_token
551
+ ccd_all_id = 0
552
+ crop_chains = []
553
+ for chain_id in ordered_chain_ids:
554
+ conformer_used_mask = []
555
+ atom_used_mask = []
556
+ ccds = []
557
+ for ccd, chunk_size_this_ccd in zip(
558
+ all_chain_features[chain_id]["ccds"],
559
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"],
560
+ ):
561
+ if ccd_all_id in select_ccds_idx:
562
+ ccds.append(ccd)
563
+ if chain_id not in crop_chains:
564
+ crop_chains.append(chain_id)
565
+ conformer_used_mask.append(ccd_all_id in select_ccds_idx)
566
+ atom_used_mask += [ccd_all_id in select_ccds_idx] * chunk_size_this_ccd
567
+ ccd_all_id += 1
568
+ conf_mask = np.array(conformer_used_mask).astype(np.bool_)
569
+ atom_mask = np.array(atom_used_mask).astype(np.bool_)
570
+ # Update All Chain Features
571
+ all_chain_features[chain_id]["x_gt"] = all_chain_features[chain_id]["x_gt"][atom_mask]
572
+ all_chain_features[chain_id]["atom_id_to_conformer_atom_id"] = \
573
+ all_chain_features[chain_id]["atom_id_to_conformer_atom_id"][atom_mask]
574
+ all_chain_features[chain_id]["restype"] = all_chain_features[chain_id]["restype"][conf_mask]
575
+ all_chain_features[chain_id]["residue_index"] = all_chain_features[chain_id]["residue_index"][conf_mask]
576
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = \
577
+ all_chain_features[chain_id]["conformer_id_to_chunk_sizes"][conf_mask]
578
+ all_chain_features[chain_id]["ccds"] = ccds
579
+ if "msa" in all_chain_features[chain_id]:
580
+ all_chain_features[chain_id]["msa"] = all_chain_features[chain_id]["msa"][:, conf_mask]
581
+ all_chain_features[chain_id]["deletion_matrix"] = \
582
+ all_chain_features[chain_id]["deletion_matrix"][:, conf_mask]
583
+ if "msa_all_seq" in all_chain_features[chain_id]:
584
+ all_chain_features[chain_id]["msa_all_seq"] = all_chain_features[chain_id]["msa_all_seq"][:, conf_mask]
585
+ all_chain_features[chain_id]["deletion_matrix_all_seq"] = \
586
+ all_chain_features[chain_id]["deletion_matrix_all_seq"][:, conf_mask]
587
+ # Remove Unused Chains
588
+ for chain_id in list(all_chain_features.keys()):
589
+ if chain_id not in crop_chains:
590
+ all_chain_features.pop(chain_id, None)
591
+ return all_chain_features
592
+
593
+ def crop_all_chain_features(self, all_chain_features, infer_meta_data):
594
+ # all_chain_features = self._spatial_crop(all_chain_features)
595
+ all_chain_features = self._spatial_crop_v2(all_chain_features, infer_meta_data)
596
+ return all_chain_features, infer_meta_data
597
+
598
+ def _make_pocket_features(self, all_chain_features):
599
+ # minimium distance 6-12
600
+ all_chain_ids = list(all_chain_features.keys())
601
+
602
+ for chain_id in all_chain_ids:
603
+ all_chain_features[chain_id]["pocket_res_feat"] = np.zeros(
604
+ [len(all_chain_features[chain_id]["ccds"])], dtype=np.bool_)
605
+
606
+ ligand_chain_ids = [i for i in all_chain_ids if i.isdigit()]
607
+ receptor_chain_ids = [i for i in all_chain_ids if not i.isdigit()]
608
+
609
+ use_pocket = random.random() < 0.5
610
+ # TODO: Inference mode assign
611
+ if len(ligand_chain_ids) == 0 or len(receptor_chain_ids) == 0 or not use_pocket:
612
+ for chain_id in all_chain_ids:
613
+ all_chain_features[chain_id]["pocket_res_feat"] = all_chain_features[chain_id][
614
+ "pocket_res_feat"].astype(np.float32)
615
+ return all_chain_features
616
+
617
+ # Aug Part
618
+ for ligand_chain_id in ligand_chain_ids:
619
+ x_gt_ligand = all_chain_features[ligand_chain_id]["x_gt"]
620
+
621
+ # x_gt_mean = np.mean(x_gt_ligand, axis=0) + np.random.randn(3)
622
+
623
+ for receptor_chain_id in receptor_chain_ids:
624
+ x_gt_receptor = all_chain_features[receptor_chain_id]["x_gt"]
625
+
626
+ # dist = np.linalg.norm(x_gt_receptor - x_gt_mean[None], axis=-1)
627
+ # is_pocket_atom = (dist < (random.random() * 6 + 8)).astype(np.bool_)
628
+ is_pocket_atom = np.any(
629
+ np.linalg.norm(x_gt_receptor[:, None] - x_gt_ligand[None], axis=-1) < (random.random() * 6 + 6),
630
+ axis=-1
631
+ )
632
+
633
+ is_pocket_ccd = []
634
+ offset = 0
635
+ for chunk_size in all_chain_features[receptor_chain_id]["conformer_id_to_chunk_sizes"]:
636
+ is_pocket_ccd.append(np.any(is_pocket_atom[offset:offset + chunk_size]).item())
637
+ offset += chunk_size
638
+ is_pocket_ccd = np.array(is_pocket_ccd, dtype=np.bool_)
639
+
640
+ is_pocket_ccd = np.array([np.any(i).item() for i in is_pocket_ccd], dtype=np.bool_)
641
+ all_chain_features[receptor_chain_id]["pocket_res_feat"] = all_chain_features[receptor_chain_id][
642
+ "pocket_res_feat"] | is_pocket_ccd
643
+
644
+ for chain_id in all_chain_ids:
645
+ all_chain_features[chain_id]["pocket_res_feat"] = all_chain_features[chain_id]["pocket_res_feat"].astype(
646
+ np.float32)
647
+
648
+ return all_chain_features
649
+
650
    def _make_ccd_features(self, raw_feats, infer_meta_data):
        """Expand per-conformer (CCD) chunks into atom-wise and token-wise arrays.

        Tokenisation scheme:
          * standard residue              -> one token spanning all of its atoms;
          * non-standard residue / ligand -> one token per heavy atom;
          * UNK conformer                 -> one masked token with no atoms
                                             (sentinel -1 for its special atoms).

        Args:
            raw_feats: merged chain features; uses "ccds",
                "atom_id_to_conformer_atom_id" and "conformer_id_to_chunk_sizes".
            infer_meta_data: carries "CONF_META_DATA" with per-CCD reference
                atom names and reference features.

        Returns:
            Dict of atom-wise ("atom_id_to_*", "ref_feat", "ref_pos") and
            token-wise ("token_id_to_*", "s_mask") index/feature arrays.
        """
        CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
        ccds = raw_feats["ccds"]
        atom_id_to_conformer_atom_id = raw_feats["atom_id_to_conformer_atom_id"]
        conformer_id_to_chunk_sizes = raw_feats["conformer_id_to_chunk_sizes"]

        # Atomwise
        atom_id_to_conformer_id = []
        atom_id_to_token_id = []
        ref_feat = []

        # Tokenwise
        s_mask = []
        token_id_to_conformer_id = []
        token_id_to_chunk_sizes = []
        token_id_to_centre_atom_id = []
        token_id_to_pseudo_beta_atom_id = []

        token_id = 0
        atom_id = 0
        for conf_id, (ccd, ccd_atoms) in enumerate(zip(ccds, conformer_id_to_chunk_sizes)):
            conf_meta_data = CONF_META_DATA[ccd]
            # UNK Conformer: single masked token, consumes no atom ids.
            if rc.is_unk(ccd):
                s_mask.append(0)
                token_id_to_chunk_sizes.append(0)
                token_id_to_conformer_id.append(conf_id)
                token_id_to_centre_atom_id.append(-1)
                token_id_to_pseudo_beta_atom_id.append(-1)
                token_id += 1
            # Standard Residue: one token covering all resolved atoms.
            elif rc.is_standard(ccd):
                # Indices of this chunk's atoms within the reference conformer.
                inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
                atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
                ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
                token_id_to_conformer_id.append(conf_id)
                token_id_to_chunk_sizes.append(ccd_atoms.item())
                s_mask.append(1)
                for atom_id_this_ccd, atom_name in enumerate(atom_names):
                    # Update Atomwise Features
                    atom_id_to_conformer_id.append(conf_id)
                    atom_id_to_token_id.append(token_id)
                    # Update special atom ids (token centre / pseudo-beta atoms,
                    # e.g. CA / CB for amino acids).
                    if atom_name == rc.standard_ccd_to_token_centre_atom_name[ccd]:
                        token_id_to_centre_atom_id.append(atom_id)
                    if atom_name == rc.standard_ccd_to_token_pseudo_beta_atom_name[ccd]:
                        token_id_to_pseudo_beta_atom_id.append(atom_id)
                    atom_id += 1
                token_id += 1
            # Nonstandard Residue & Ligand: every atom becomes its own token,
            # and is its own centre / pseudo-beta atom.
            else:
                inner_atom_idx = atom_id_to_conformer_atom_id[atom_id:atom_id + ccd_atoms.item()]
                atom_names = [conf_meta_data["ref_atom_name_chars"][i] for i in inner_atom_idx]
                ref_feat.append(conf_meta_data["ref_feat"][inner_atom_idx])
                # ref_pos_new.append(conf_meta_data["ref_pos_new"][:, inner_atom_idx])
                for atom_id_this_ccd, atom_name in enumerate(atom_names):
                    # Update Atomwise Features
                    atom_id_to_conformer_id.append(conf_id)
                    atom_id_to_token_id.append(token_id)
                    # Update Tokenwise Features
                    token_id_to_chunk_sizes.append(1)
                    token_id_to_conformer_id.append(conf_id)
                    s_mask.append(1)
                    token_id_to_centre_atom_id.append(atom_id)
                    token_id_to_pseudo_beta_atom_id.append(atom_id)
                    atom_id += 1
                    token_id += 1

        # NOTE(review): if every conformer were UNK, ref_feat would be empty and
        # ref_feat[0] below would raise IndexError -- presumably such systems
        # never reach this point; confirm upstream filtering.
        if len(ref_feat) > 1:
            ref_feat = np.concatenate(ref_feat, axis=0).astype(np.float32)
        else:
            ref_feat = ref_feat[0].astype(np.float32)

        features = {
            # Atomwise
            "atom_id_to_conformer_id": np.array(atom_id_to_conformer_id, dtype=np.int64),
            "atom_id_to_token_id": np.array(atom_id_to_token_id, dtype=np.int64),
            "ref_feat": ref_feat,
            # Tokenwise
            "token_id_to_conformer_id": np.array(token_id_to_conformer_id, dtype=np.int64),
            "s_mask": np.array(s_mask, dtype=np.int64),
            "token_id_to_centre_atom_id": np.array(token_id_to_centre_atom_id, dtype=np.int64),
            "token_id_to_pseudo_beta_atom_id": np.array(token_id_to_pseudo_beta_atom_id, dtype=np.int64),
            "token_id_to_chunk_sizes": np.array(token_id_to_chunk_sizes, dtype=np.int64),
        }
        # Reference positions are the first three channels of ref_feat.
        features["ref_pos"] = features["ref_feat"][..., :3]
        return features
737
+
738
+ def pair_and_merge(self, all_chain_features, infer_meta_data):
739
+ CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"] # Dict
740
+ CONF_META_DATA = infer_meta_data["CONF_META_DATA"]
741
+ ASYM_ID = infer_meta_data["ASYM_ID"]
742
+ homo_feats = {}
743
+
744
+ # Create Aug Pocket Feature
745
+ all_chain_features = self._make_pocket_features(all_chain_features)
746
+
747
+ all_chain_ids = list(all_chain_features.keys())
748
+ if len(all_chain_ids) == 1 and CHAIN_CLASS[all_chain_ids[0]] == "ligand":
749
+ ordered_chain_ids = all_chain_ids
750
+ raw_feats = all_chain_features[all_chain_ids[0]]
751
+ raw_feats["msa"] = np.repeat(raw_feats["msa"][:1], 256, axis=0)
752
+ raw_feats["deletion_matrix"] = np.repeat(raw_feats["msa"][:1], 256, axis=0)
753
+ keys = list(raw_feats.keys())
754
+
755
+ for feature_name in keys:
756
+ if feature_name not in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index",
757
+ "conformer_id_to_chunk_sizes", "restype", "is_protein", "is_short_poly",
758
+ "is_ligand",
759
+ "asym_id", "sym_id", "entity_id", "msa", "deletion_matrix", "ccds",
760
+ "pocket_res_feat", "key_res_feat", "is_key_res"]:
761
+ raw_feats.pop(feature_name)
762
+
763
+ # Update Profile and Deletion Mean
764
+ msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
765
+ raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
766
+ del msa_one_hot
767
+ raw_feats["deletion_mean"] = (torch.atan(
768
+ torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
769
+ ) * (2.0 / torch.pi)).numpy()
770
+ else:
771
+
772
+ for chain_id in list(all_chain_features.keys()):
773
+ homo_feats[chain_id] = {
774
+ "asym_id": copy.deepcopy(all_chain_features[chain_id]["asym_id"]),
775
+ "sym_id": copy.deepcopy(all_chain_features[chain_id]["sym_id"]),
776
+ "entity_id": copy.deepcopy(all_chain_features[chain_id]["entity_id"]),
777
+ }
778
+ for chain_id in list(all_chain_features.keys()):
779
+ homo_feats[chain_id]["chain_class"] = all_chain_features[chain_id].pop("chain_class")
780
+ homo_feats[chain_id]["sequence_3"] = all_chain_features[chain_id].pop("sequence_3")
781
+ homo_feats[chain_id]["msa"] = all_chain_features[chain_id].pop("msa")
782
+ homo_feats[chain_id]["deletion_matrix"] = all_chain_features[chain_id].pop("deletion_matrix")
783
+ if "msa_all_seq" in all_chain_features[chain_id]:
784
+ homo_feats[chain_id]["msa_all_seq"] = all_chain_features[chain_id].pop("msa_all_seq")
785
+ homo_feats[chain_id]["deletion_matrix_all_seq"] = all_chain_features[chain_id].pop(
786
+ "deletion_matrix_all_seq")
787
+ homo_feats[chain_id]["msa_species_identifiers_all_seq"] = all_chain_features[chain_id].pop(
788
+ "msa_species_identifiers_all_seq")
789
+
790
+ # Initial raw feats with merged homo feats
791
+ raw_feats = pair_and_merge(homo_feats, is_homomer_or_monomer=False)
792
+
793
+ # Update Profile and Deletion Mean
794
+ msa_one_hot = F.one_hot(torch.from_numpy(raw_feats["msa"]).long(), 32).type(torch.float32)
795
+ raw_feats["profile"] = torch.mean(msa_one_hot, dim=-3).numpy()
796
+ del msa_one_hot
797
+ raw_feats["deletion_mean"] = (torch.atan(
798
+ torch.sum(torch.from_numpy(raw_feats["deletion_matrix"]), dim=0) / 3.0
799
+ ) * (2.0 / torch.pi)).numpy()
800
+
801
+ # Merge no homo feats according to asym_id
802
+ ordered_asym_ids = []
803
+ for i in raw_feats["asym_id"]:
804
+ if i not in ordered_asym_ids:
805
+ ordered_asym_ids.append(i)
806
+ ordered_chain_ids = [ASYM_ID[i] for i in ordered_asym_ids]
807
+ for feature_name in ["chain_class", "sequence_3", "assembly_num_chains", "entity_mask", "seq_length",
808
+ "num_alignments"]:
809
+ raw_feats.pop(feature_name, None)
810
+ for feature_name in ["x_gt", "atom_id_to_conformer_atom_id", "residue_index", "conformer_id_to_chunk_sizes",
811
+ "restype", "is_protein", "is_short_poly", "is_ligand", "pocket_res_feat",
812
+ "key_res_feat", "is_key_res"]:
813
+ raw_feats[feature_name] = np.concatenate([
814
+ all_chain_features[chain_id].pop(feature_name) for chain_id in ordered_chain_ids
815
+ ], axis=0)
816
+
817
+ # Conformerwise Chain Class
818
+ CHAIN_CLASS_NEW = []
819
+ for chain_id in ordered_chain_ids:
820
+ CHAIN_CLASS_NEW += [CHAIN_CLASS[chain_id]] * len(all_chain_features[chain_id]["ccds"])
821
+ infer_meta_data["CHAIN_CLASS"] = CHAIN_CLASS_NEW
822
+
823
+ raw_feats["ccds"] = reduce(add, [all_chain_features[chain_id].pop("ccds") for chain_id in ordered_chain_ids])
824
+
825
+ # Create Atomwise and Tokenwise Features
826
+ raw_feats.update(self._make_ccd_features(raw_feats, infer_meta_data))
827
+
828
+ asym_id_conformerwise = copy.deepcopy(raw_feats["asym_id"])
829
+ residue_index_conformerwise = copy.deepcopy(raw_feats["residue_index"])
830
+
831
+ # Conformerwise to Tokenwise
832
+ token_id_to_conformer_id = raw_feats["token_id_to_conformer_id"]
833
+ for key in ["is_protein", "is_short_poly", "is_ligand", "residue_index", "restype", "asym_id", "entity_id",
834
+ "sym_id", "deletion_mean", "profile", "pocket_res_feat", "key_res_feat", "is_key_res"]:
835
+ raw_feats[key] = raw_feats[key][token_id_to_conformer_id]
836
+ for key in ["msa", "deletion_matrix"]:
837
+ if key in raw_feats:
838
+ raw_feats[key] = raw_feats[key][:, token_id_to_conformer_id]
839
+ ###################################################
840
+ # Centre Random Augmentation of ref pos #
841
+ ###################################################
842
+ raw_feats["ref_pos"] = centre_random_augmentation_np_apply(
843
+ raw_feats["ref_pos"], raw_feats["atom_id_to_token_id"]).astype(np.float32)
844
+ raw_feats["ref_feat"][:, :3] = raw_feats["ref_pos"]
845
+
846
+ ###################################################
847
+ # Create token pair features #
848
+ ###################################################
849
+ no_token = len(raw_feats["token_id_to_conformer_id"])
850
+ token_bonds = np.zeros([no_token, no_token], dtype=np.float32)
851
+ rel_tok_feat = np.zeros([no_token, no_token, 42], dtype=np.float32)
852
+ batch_ref_pos = np.zeros([32, no_token, 3], dtype=np.float32)
853
+ offset = 0
854
+ atom_offset = 0
855
+ for ccd, len_atoms in zip(
856
+ raw_feats["ccds"],
857
+ raw_feats["conformer_id_to_chunk_sizes"]
858
+ ):
859
+ if rc.is_standard(ccd) or rc.is_unk(ccd):
860
+ offset += 1
861
+ else:
862
+ len_atoms = len_atoms.item()
863
+ inner_atom_idx = raw_feats["atom_id_to_conformer_atom_id"][atom_offset:atom_offset + len_atoms]
864
+ batch_ref_pos[:, offset:offset + len_atoms] = CONF_META_DATA[ccd]["batch_ref_pos"][:, inner_atom_idx]
865
+ token_bonds[offset:offset + len_atoms, offset:offset + len_atoms] = \
866
+ CONF_META_DATA[ccd]["token_bonds"][inner_atom_idx][:, inner_atom_idx]
867
+ rel_tok_feat[offset:offset + len_atoms, offset:offset + len_atoms] = \
868
+ CONF_META_DATA[ccd]["rel_tok_feat"][inner_atom_idx][:, inner_atom_idx]
869
+ offset += len_atoms
870
+ atom_offset += len_atoms
871
+ raw_feats["token_bonds"] = token_bonds.astype(np.float32)
872
+ raw_feats["token_bonds_feature"] = token_bonds.astype(np.float32)
873
+ raw_feats["rel_tok_feat"] = rel_tok_feat.astype(np.float32)
874
+ raw_feats["batch_ref_pos"] = batch_ref_pos.astype(np.float32)
875
+ ###################################################
876
+ # Charility Augmentation #
877
+ ###################################################
878
+ if not self.inference_mode:
879
+ # TODO Charility probs
880
+ charility_seed = random.random()
881
+ if charility_seed < 0.1:
882
+ ref_chirality = raw_feats["ref_feat"][:, 158:161]
883
+ ref_chirality_replace = np.zeros_like(ref_chirality)
884
+ ref_chirality_replace[:, 2] = 1
885
+
886
+ is_ligand_atom = raw_feats["is_ligand"][raw_feats["atom_id_to_token_id"]]
887
+ remove_charility = (np.random.randint(0, 2, [len(is_ligand_atom)]) * is_ligand_atom).astype(
888
+ np.bool_)
889
+ ref_chirality = np.where(remove_charility[:, None], ref_chirality_replace, ref_chirality)
890
+ raw_feats["ref_feat"][:, 158:161] = ref_chirality
891
+
892
+ # MASKS
893
+ raw_feats["x_exists"] = np.ones_like(raw_feats["x_gt"][..., 0]).astype(np.float32)
894
+ raw_feats["a_mask"] = raw_feats["x_exists"]
895
+ raw_feats["s_mask"] = np.ones_like(raw_feats["asym_id"]).astype(np.float32)
896
+ raw_feats["ref_space_uid"] = raw_feats["atom_id_to_conformer_id"]
897
+
898
+ # Write Infer Meta Data
899
+ infer_meta_data["ccds"] = raw_feats.pop("ccds")
900
+ infer_meta_data["atom_id_to_conformer_atom_id"] = raw_feats.pop("atom_id_to_conformer_atom_id")
901
+ infer_meta_data["residue_index"] = residue_index_conformerwise
902
+ infer_meta_data["asym_id"] = asym_id_conformerwise
903
+ infer_meta_data["conformer_id_to_chunk_sizes"] = raw_feats.pop("conformer_id_to_chunk_sizes")
904
+
905
+ return raw_feats, infer_meta_data
906
+
907
+ def make_feats(self, tensors):
908
+ # Target Feat
909
+ tensors["target_feat"] = torch.cat([
910
+ F.one_hot(tensors["restype"].long(), 32).float(),
911
+ tensors["profile"].float(),
912
+ tensors["deletion_mean"][..., None].float()
913
+ ], dim=-1)
914
+
915
+ # MSA Feat
916
+ inds = [0] + torch.randperm(len(tensors["msa"]))[:127].tolist()
917
+
918
+ tensors["msa"] = tensors["msa"][inds]
919
+ tensors["deletion_matrix"] = tensors["deletion_matrix"][inds]
920
+
921
+ has_deletion = torch.clamp(tensors["deletion_matrix"].float(), min=0., max=1.)
922
+ pi = torch.acos(torch.zeros(1, device=tensors["deletion_matrix"].device)) * 2
923
+ deletion_value = (torch.atan(tensors["deletion_matrix"] / 3.) * (2. / pi))
924
+ tensors["msa_feat"] = torch.cat([
925
+ F.one_hot(tensors["msa"].long(), 32).float(),
926
+ has_deletion[..., None].float(),
927
+ deletion_value[..., None].float(),
928
+ ], dim=-1)
929
+ tensors.pop("msa", None)
930
+ tensors.pop("deletion_mean", None)
931
+ tensors.pop("profile", None)
932
+ tensors.pop("deletion_matrix", None)
933
+
934
+ return tensors
935
+
936
    def _make_token_bonds(self, tensors):
        """Add inter-conformer (polymer-ligand / ligand-ligand) token bonds.

        For every pair of chains where at least one side is a ligand, the
        closest resolved atom pair is located; if its distance is below
        ``self.token_bond_threshold`` a symmetric bond is recorded between the
        two corresponding tokens in ``token_bonds`` / ``token_bonds_feature``.
        Assumes atoms of one chain are contiguous in the atom-wise ordering.
        """
        # Get Polymer-Ligand & Ligand-Ligand Within Conformer Token Bond

        # Broadcast token-wise ids down to atoms.
        asym_id = tensors["asym_id"][tensors["atom_id_to_token_id"]]
        is_ligand = tensors["is_ligand"][tensors["atom_id_to_token_id"]]

        x_gt = tensors["x_gt"]
        a_mask = tensors["a_mask"]

        # Get
        atom_id_to_token_id = tensors["atom_id_to_token_id"]

        num_token = len(tensors["asym_id"])
        between_conformer_token_bonds = torch.zeros([num_token, num_token])

        # create chainwise feature: first atom offset and ligand flag for each
        # contiguous run of identical asym_id.
        asym_id_chain = []
        asym_id_atom_offset = []
        asym_id_is_ligand = []
        for atom_offset, (a_id, i_id) in enumerate(zip(asym_id.tolist(), is_ligand.tolist())):
            if len(asym_id_chain) == 0 or asym_id_chain[-1] != a_id:
                asym_id_chain.append(a_id)
                asym_id_atom_offset.append(atom_offset)
                asym_id_is_ligand.append(i_id)

        len_asym_id_chain = len(asym_id_chain)
        if len_asym_id_chain >= 2:
            for i in range(0, len_asym_id_chain - 1):
                asym_id_i = asym_id_chain[i]
                mask_i = asym_id == asym_id_i
                x_gt_i = x_gt[mask_i]
                a_mask_i = a_mask[mask_i]
                for j in range(i + 1, len_asym_id_chain):
                    # Only pairs involving at least one ligand chain are bonded.
                    if not bool(asym_id_is_ligand[i]) and not bool(asym_id_is_ligand[j]):
                        continue
                    asym_id_j = asym_id_chain[j]
                    mask_j = asym_id == asym_id_j
                    x_gt_j = x_gt[mask_j]
                    a_mask_j = a_mask[mask_j]
                    dis_ij = torch.norm(x_gt_i[:, None, :] - x_gt_j[None, :, :], dim=-1)
                    # Inflate distances of unresolved atoms so argmin below
                    # only ever picks a resolved-resolved pair.
                    dis_ij = dis_ij + (1 - a_mask_i[:, None] * a_mask_j[None]) * 1000
                    if torch.min(dis_ij) < self.token_bond_threshold:
                        # Recover (row, col) of the flattened argmin.
                        ij = torch.argmin(dis_ij).item()
                        l_j = len(x_gt_j)
                        atom_i = int(ij // l_j)  # row
                        atom_j = int(ij % l_j)  # col
                        global_atom_i = atom_i + asym_id_atom_offset[i]
                        global_atom_j = atom_j + asym_id_atom_offset[j]
                        token_i = atom_id_to_token_id[global_atom_i]
                        token_j = atom_id_to_token_id[global_atom_j]

                        between_conformer_token_bonds[token_i, token_j] = 1
                        between_conformer_token_bonds[token_j, token_i] = 1
        token_bond_seed = random.random()
        tensors["token_bonds"] = tensors["token_bonds"] + between_conformer_token_bonds
        # Docking Indicate Token Bond
        # NOTE(review): `token_bond_seed >= 0` is always true, so the feature
        # copy below runs unconditionally -- looks like a leftover ablation
        # switch; confirm whether a real probability was intended.
        if token_bond_seed >= 0:
            tensors["token_bonds_feature"] = tensors["token_bonds"]
        return tensors
996
+
997
+ def _pad_to_size(self, tensors):
998
+
999
+ to_pad_atom = self.atom_crop_size - len(tensors["x_gt"])
1000
+ to_pad_token = self.token_crop_size - len(tensors["residue_index"])
1001
+ if to_pad_token > 0:
1002
+ for k in ["restype", "residue_index", "is_protein", "is_short_poly", "is_ligand", "is_key_res",
1003
+ "asym_id", "entity_id", "sym_id", "token_id_to_conformer_id", "s_mask",
1004
+ "token_id_to_centre_atom_id", "token_id_to_pseudo_beta_atom_id", "token_id_to_chunk_sizes",
1005
+ "pocket_res_feat"]:
1006
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_token])
1007
+ for k in ["target_feat", "msa_feat", "batch_ref_pos", "key_res_feat"]:
1008
+ if k in tensors:
1009
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_token])
1010
+ for k in ["token_bonds", "token_bonds_feature"]:
1011
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_token, 0, to_pad_token])
1012
+ for k in ["rel_tok_feat"]:
1013
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_token, 0, to_pad_token])
1014
+ if to_pad_atom > 0:
1015
+ for k in ["a_mask", "x_exists", "atom_id_to_conformer_id", "atom_id_to_token_id", "ref_space_uid"]:
1016
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom])
1017
+ for k in ["x_gt", "ref_feat", "ref_pos"]: # , "ref_pos_new"
1018
+ tensors[k] = torch.nn.functional.pad(tensors[k], [0, 0, 0, to_pad_atom])
1019
+ # for k in ["z_mask"]: # , "ref_pos_new"
1020
+ # tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom, 0, to_pad_atom])
1021
+ # for k in ["conformer_mask_atom"]:
1022
+ # tensors[k] = torch.nn.functional.pad(tensors[k], [0, to_pad_atom, 0, to_pad_atom])
1023
+ # for k in ["rel_token_feat_atom"]:
1024
+ # tensors[k] = torch.nn.functional.pad(tensors[k], [0,0,0, to_pad_atom, 0, to_pad_atom])
1025
+ # rel_token_feat_atom
1026
+ return tensors
1027
+
1028
    def get_template_feat(self, tensors):
        """Build a single pseudo-template feature from ground-truth coordinates.

        A distogram over pseudo-beta atom positions is computed for
        protein-protein token pairs only, masked by z_mask, and stored as one
        template (leading dim 1).  At inference the template is always enabled
        via t_mask; during training it is enabled with probability 0.1.
        """
        x_gt = tensors["x_gt"][tensors["token_id_to_pseudo_beta_atom_id"]]
        z_mask = tensors["z_mask"]
        asym_id = tensors["asym_id"]
        is_protein = tensors["is_protein"]
        chain_same = (asym_id[None] == asym_id[:, None]).float()
        protein2d = is_protein[None] * is_protein[:, None]
        dgram = dgram_from_positions(x_gt)
        dgram = dgram * protein2d[..., None] * z_mask[..., None]

        # Disabled experiment: per-chain BERT-style masking of template pairs.
        # if not self.inference_mode:
        #     bert_mask = torch.rand([len(x_gt)]) > random.random() * 0.4
        #     asym_ids = list(set(asym_id.tolist()))
        #     used_asym_ids = []
        #     for a in asym_ids:
        #         if random.random() > 0.6:
        #             used_asym_ids.append(a)
        #     if len(used_asym_ids) > 0:
        #         used_asym_ids = torch.tensor(used_asym_ids)
        #         chain_bert_mask = torch.any(asym_id[:, None] == used_asym_ids[None], dim=-1)
        #         bert_mask = chain_bert_mask * bert_mask
        #     else:
        #         bert_mask = bert_mask * 0
        #     template_pseudo_beta_mask = (bert_mask[None] * bert_mask[:, None]) * z_mask * protein2d
        # else:
        #     template_pseudo_beta_mask = z_mask * protein2d
        template_pseudo_beta_mask = z_mask * protein2d
        # template_pseudo_beta_mask = protein2d * z_mask
        dgram = dgram * template_pseudo_beta_mask[..., None]
        templ_feat = torch.cat([dgram, template_pseudo_beta_mask[..., None]], dim=-1)
        # Leading [None] makes a single-template stack.
        tensors["templ_feat"] = templ_feat.float()[None]
        t_mask_seed = random.random()
        # Template Augmentation: drop the template 90% of the time in training.
        if self.inference_mode or t_mask_seed < 0.1:
            tensors["t_mask"] = torch.ones([len(tensors["templ_feat"])], dtype=torch.float32)
        else:
            tensors["t_mask"] = torch.zeros([len(tensors["templ_feat"])], dtype=torch.float32)

        # TODO: No Template
        # if not self.inference_mode:
        #     if random.random() < 0.5:
        #         tensors["templ_feat"] *= 0
        return tensors
1071
+
1072
+ def transform(self, raw_feats):
1073
+ # np to tensor
1074
+ tensors = dict()
1075
+ for key in raw_feats.keys():
1076
+ tensors[key] = torch.from_numpy(raw_feats[key])
1077
+ # Make Target & MSA Feat
1078
+ tensors = self.make_feats(tensors)
1079
+
1080
+ # Make Token Bond Feat
1081
+ tensors = self._make_token_bonds(tensors)
1082
+
1083
+ # Padding
1084
+ if not self.inference_mode:
1085
+ tensors = self._pad_to_size(tensors)
1086
+
1087
+ # # Make Pocket Res Feat
1088
+ #
1089
+ # tensors["pocket_res_feat"] = torch.zeros([l], dtype=torch.float32)
1090
+
1091
+ # Make Key Res Feat
1092
+ # l = len(tensors["asym_id"])
1093
+ # tensors["key_res_feat"] = torch.zeros([l, 7], dtype=torch.float32)
1094
+ # tensors["key_res_feat"][:, 0] = 1.
1095
+
1096
+ # Mask
1097
+ tensors["z_mask"] = tensors["s_mask"][None] * tensors["s_mask"][:, None]
1098
+
1099
+ # Template
1100
+ tensors = self.get_template_feat(tensors)
1101
+
1102
+ # Correct Type
1103
+ is_short_poly = tensors.pop("is_short_poly")
1104
+ tensors["is_protein"] = tensors["is_protein"] + is_short_poly
1105
+ tensors["is_ligand"] = tensors["is_ligand"] - is_short_poly
1106
+ tensors["is_dna"] = torch.zeros_like(tensors["is_protein"])
1107
+ tensors["is_rna"] = torch.zeros_like(tensors["is_protein"])
1108
+ return tensors
1109
+
1110
+ def load(self, sample_id):
1111
+ all_chain_labels = load_pkl(os.path.join(self.samples_path, f"{sample_id}.pkl.gz"))
1112
+
1113
+ all_chain_features, infer_meta_data = self.load_all_chain_features(all_chain_labels)
1114
+ infer_meta_data["system_id"] = sample_id
1115
+ # if not self.inference_mode:
1116
+ all_chain_features, infer_meta_data = self.crop_all_chain_features(all_chain_features, infer_meta_data)
1117
+ raw_feats, infer_meta_data = self.pair_and_merge(all_chain_features, infer_meta_data)
1118
+
1119
+ tensors = self.transform(raw_feats)
1120
+ return tensors, infer_meta_data
1121
+
1122
+ def random_load(self):
1123
+
1124
+ sample_id = self.used_sample_ids[torch.multinomial(self.probabilities, 1).item()]
1125
+ print(sample_id)
1126
+ return self.load(sample_id)
1127
+
1128
+ def random_load_test(self):
1129
+
1130
+ sample_id = self.used_test_sample_ids[torch.multinomial(self.test_probabilities, 1).item()]
1131
+ print(sample_id)
1132
+ return self.load(sample_id)
1133
+
1134
+ def weighted_random_load(self):
1135
+ weight_seed = random.random()
1136
+ if weight_seed < 0.95:
1137
+ sample_id = self.used_sample_ids[torch.multinomial(self.probabilities, 1).item()]
1138
+ else:
1139
+ return self.random_load_mol_chunks()
1140
+ return self.load(sample_id)
1141
+
1142
    def load_ligand(self, sample_id, chain_features):
        """Build a single-ligand "system" from precomputed ligand chain features.

        Wraps one ligand (e.g. sampled from an SDF-derived database) into the
        same all_chain_features / infer_meta_data structures used for full
        protein-ligand systems, then runs pair_and_merge + transform on it.

        Args:
            sample_id: identifier of the ligand sample (used for chain naming
                and recorded as system_id).
            chain_features: dict with at least "restype",
                "all_atom_positions", "all_atom_mask" and
                "ref_atom_name_chars" for the ligand conformer.

        Returns:
            (tensors, infer_meta_data) as produced by ``transform``.
        """
        all_chain_features = {}
        CHAIN_META_DATA = {
            "spatial_crop_chain_ids": None,
            "chain_class": {},
            "chain_sequence_3s": {},
            "fake_ccds": [],
        }
        CONF_META_DATA = {}
        num_prev_fake_ccds = len(CHAIN_META_DATA["fake_ccds"])
        # Synthetic 3-char CCD code, '#'-padded (e.g. "##0"), so it can never
        # collide with a real CCD identifier.
        fake_ccd = f"{num_prev_fake_ccds:#>3}"

        # Single-row "MSA" is the ligand restype row itself.
        chain_features["msa"] = chain_features["restype"][None]
        # NOTE(review): deletion_matrix is initialised with ones, not zeros --
        # unusual for a single-sequence MSA; confirm this is intended.
        chain_features["deletion_matrix"] = np.ones_like(chain_features["msa"])
        chain_features["ccds"] = [fake_ccd]
        chain_features["chain_class"] = "ligand"
        # Add a leading conformer dimension expected downstream.
        chain_features["all_atom_positions"] = chain_features["all_atom_positions"][None]
        chain_features["all_atom_mask"] = chain_features["all_atom_mask"][None]
        # Update Chain and Conf
        CHAIN_META_DATA["fake_ccds"].append(fake_ccd)
        sequence_3 = fake_ccd
        chain_id = f"SDFM_{sample_id}"
        CHAIN_META_DATA["chain_sequence_3s"][chain_id] = sequence_3
        CHAIN_META_DATA["chain_class"][chain_id] = "ligand"
        CONF_META_DATA = self._update_CONF_META_DATA_ligand(
            CONF_META_DATA, sequence_3, chain_features)
        all_chain_features[chain_id] = chain_features
        all_chain_features[chain_id] = self._update_chain_feature(
            chain_features,
            CONF_META_DATA
        )
        SEQ3 = {}
        CHAIN_CLASS = {}
        SEQ3[chain_id] = "-".join([fake_ccd])
        all_chain_features, ASYM_ID = self._add_assembly_feature(all_chain_features, SEQ3)
        # One conformer whose chunk size is the ligand's atom count.
        all_chain_features[chain_id]["conformer_id_to_chunk_sizes"] = np.array(
            [len(chain_features["ref_atom_name_chars"])], dtype=np.int64)

        all_chain_features[chain_id]["x_gt"] = chain_features["all_atom_positions"][0]
        all_chain_features[chain_id]["x_exists"] = chain_features["all_atom_mask"][0]
        CHAIN_CLASS[chain_id] = "ligand"
        infer_meta_data = {
            "CONF_META_DATA": CONF_META_DATA,
            "SEQ3": SEQ3,
            "ASYM_ID": ASYM_ID,
            "CHAIN_CLASS": CHAIN_CLASS
        }

        # Cropping is skipped: a single ligand always fits within the crop.
        # if not self.inference_mode:
        #     all_chain_features, infer_meta_data = self.crop_all_chain_features(all_chain_features, infer_meta_data)

        raw_feats, infer_meta_data = self.pair_and_merge(all_chain_features, infer_meta_data)

        tensors = self.transform(raw_feats)
        infer_meta_data["system_id"] = sample_id
        return tensors, infer_meta_data
1198
+
1199
    def random_load_mol_chunks(self, ligand_db_name="1"):
        """Load a random small-molecule sample from one of the ligand databases.

        Args:
            ligand_db_name: "0", "1" (default, 374 pre-featurised shards) or
                "2"; selects which pickled ligand database to draw from.

        Returns:
            (tensors, infer_meta_data) for the sampled ligand-only system.

        NOTE(review): database locations are machine-specific absolute paths;
        they should be made configurable before release.
        """
        if ligand_db_name == "0":
            ligand_db = load_pkl("/2022133002/projects/stdock/stdock_v9.5/scripts/try_new.pkl.gz")
        elif ligand_db_name == "1":
            # Pick one of the 374 ligand shards at random.
            id = random.randint(1, 374)
            ligand_db = load_pkl(f"/2022133002/data/ligand_samples/samples_{id}.pkl.gz")
        elif ligand_db_name == "2":
            ligand_db = load_pkl("/2022133002/projects/stdock/stdock_v9.5/scripts/try_400k_2_new.pkl.gz")
        else:
            raise ValueError("MOL DB Name is Wrong!")
        sample_id = random.choice(list(ligand_db.keys()))
        sample_feature = ligand_db[sample_id]
        tensors, infer_meta_data = self.load_ligand(sample_id, sample_feature)
        return tensors, infer_meta_data
1213
+
1214
    def write_pdb(self, x_pred, fname, infer_meta_data):
        """Write predicted coordinates to a single-model PDB file.

        Args:
            x_pred: per-atom predicted coordinates, indexable so that
                ``x_pred[i].tolist()`` yields [x, y, z].
            fname: output PDB file path.
            infer_meta_data: conformer-wise bookkeeping written by
                ``pair_and_merge`` (ccds, atom_id_to_conformer_atom_id,
                conformer_id_to_chunk_sizes, residue_index, asym_id,
                CHAIN_CLASS, CONF_META_DATA).
        """
        ccds = infer_meta_data["ccds"]
        atom_id_to_conformer_atom_id = infer_meta_data["atom_id_to_conformer_atom_id"]
        ccd_chunk_sizes = infer_meta_data["conformer_id_to_chunk_sizes"].tolist()
        CHAIN_CLASS = infer_meta_data["CHAIN_CLASS"]
        conf_meta_data = infer_meta_data["CONF_META_DATA"]
        residue_index = infer_meta_data["residue_index"].tolist()
        asym_id = infer_meta_data["asym_id"].tolist()

        atom_lines = []
        atom_offset = 0
        for ccd_id, (ccd, chunk_size, res_id) in enumerate(zip(ccds, ccd_chunk_sizes, residue_index)):
            # Map this conformer's atoms back to reference atom names/elements.
            inner_atom_idx = atom_id_to_conformer_atom_id[atom_offset:atom_offset + chunk_size]
            atom_names = [conf_meta_data[ccd]["ref_atom_name_chars"][i] for i in inner_atom_idx]
            atom_elements = [PeriodicTable[conf_meta_data[ccd]["ref_element"][i]] for i in inner_atom_idx]
            chain_tag = PDB_CHAIN_IDS[int(asym_id[ccd_id])]
            # Ligand conformers are written as HETATM records.
            record_type = "HETATM" if CHAIN_CLASS[ccd_id] == "ligand" else "ATOM"

            for ccd_atom_idx, atom_name in enumerate(atom_names):
                x = x_pred[atom_offset]
                # PDB columns 13-16: 4-char names start at column 13, shorter
                # names are indented by one space.
                name = atom_name if len(atom_name) == 4 else f" {atom_name}"
                res_name_3 = ccd
                alt_loc = ""
                insertion_code = ""
                occupancy = 1.00
                element = atom_elements[ccd_atom_idx]
                # b_factor = torch.argmax(plddt[atom_offset],dim=-1).item()*2 +1
                # Constant placeholder B-factor; per-atom pLDDT is disabled above.
                b_factor = 70.
                charge = 0
                pos = x.tolist()
                atom_line = (
                    f"{record_type:<6}{atom_offset + 1:>5} {name:<4}{alt_loc:>1}"
                    f"{res_name_3.split()[0]:>3} {chain_tag:>1}"
                    f"{res_id + 1:>4}{insertion_code:>1}   "
                    f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
                    f"{occupancy:>6.2f}{b_factor:>6.2f}          "
                    f"{element:>2}{charge:>2}"
                )
                atom_lines.append(atom_line)
                atom_offset += 1
            # Stop once every atom has been written (padding-safe guard).
            if atom_offset == len(atom_id_to_conformer_atom_id):
                break
        out = "\n".join(atom_lines)
        out = f"MODEL     1\n{out}\nTER\nENDMDL\nEND"
        dump_txt(out, fname)
PhysDock/data/generate_system.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import argparse
4
+ import warnings
5
+
6
+ import numpy as np
7
+ from Bio.PDB import PDBParser
8
+ from rdkit import Chem
9
+
10
+ from PhysDock.utils.io_utils import dump_pkl, load_pkl, convert_md5_string
11
+ from PhysDock.data.constants.PDBData import protein_letters_3to1_extended
12
+
13
+ warnings.filterwarnings("ignore")
14
+
15
+
16
def generate_system(
        receptor_pdb_path,
        ligand_sdf_path,
        ligand_ccd_id,
        systems_dir,
        ccd_id_meta_data=None,
):
    """
    Parse a receptor PDB and a ligand SDF into a protein-ligand system pickle.

    The receptor's chains are converted to per-residue atom position/mask
    arrays (gaps in author numbering are filled with 1-atom UNK placeholders),
    the ligand's heavy-atom coordinates are read from the SDF, and the merged
    features are saved to ``systems_dir``.  One FASTA file per protein chain
    is also written under ``systems_dir/fastas`` for homology search.

    Args:
        receptor_pdb_path (str): Path to the receptor PDB file.
        ligand_sdf_path (str): Path to the ligand SDF file.
        ligand_ccd_id (str): CCD ID of the ligand (stored upper-cased).
        systems_dir (str): Directory to save the system pickle and FASTA files.
        ccd_id_meta_data (dict, optional): Preloaded CCD metadata; when None it
            is loaded from the bundled params file.
    """
    # Load the bundled CCD metadata lazily when the caller did not supply it.
    if ccd_id_meta_data is None:
        print("Loading CCD meta data ...")
        ccd_id_meta_data = load_pkl(os.path.join(os.path.split(__file__)[0], "../../params/ccd_id_meta_data.pkl.gz"))
    os.makedirs(systems_dir, exist_ok=True)

    # Interaction channels (zero placeholders).  # TODO PLIP
    # BUGFIX: defined once here instead of inside the chain loop, so the ligand
    # section below no longer raises NameError for a receptor with zero chains.
    interaction_keys = ['salt bridges', 'pi-cation interactions', 'hydrophobic interactions',
                        'pi-stacking', 'hydrogen bonds', 'metal complexes']

    # Initialize parser and data containers
    pdb_parser = PDBParser()
    structure = pdb_parser.get_structure("", receptor_pdb_path)
    model = structure[0]

    all_chain_features = {}
    used_chain_ids = []

    # Extract protein chains from PDB
    for chain in model:
        chain_id = chain.id
        used_chain_ids.append(chain_id)
        all_chain_features[chain_id] = {
            "all_atom_positions": [],
            "all_atom_mask": [],
            "ccds": []
        }

        offset = None
        for residue in chain:
            # First residue number of the chain anchors the 0-based index.
            if offset is None:
                offset = int(residue.id[1])

            resname = residue.get_resname().strip().ljust(3)
            res_idx = int(residue.id[1]) - offset
            num_atoms = len(ccd_id_meta_data[resname]["ref_atom_name_chars"])

            # Fill gaps in author numbering with 1-atom UNK placeholders.
            # NOTE(review): assumes residue numbers are strictly increasing with
            # no insertion codes; duplicated or decreasing numbers would
            # mis-index the per-residue lists -- confirm inputs are renumbered.
            while len(all_chain_features[chain_id]["ccds"]) < res_idx:
                all_chain_features[chain_id]["ccds"].append("UNK")
                all_chain_features[chain_id]["all_atom_positions"].append(np.zeros([1, 3], dtype=np.float32))
                all_chain_features[chain_id]["all_atom_mask"].append(np.zeros([1], dtype=np.int8))

            # Initialize residue data
            all_chain_features[chain_id]["ccds"].append(resname)
            all_chain_features[chain_id]["all_atom_positions"].append(np.zeros([num_atoms, 3], dtype=np.float32))
            all_chain_features[chain_id]["all_atom_mask"].append(np.zeros([num_atoms], dtype=np.int8))

            # Copy coordinates for atoms known to the reference conformer.
            ref_atom_names = ccd_id_meta_data[resname]["ref_atom_name_chars"]
            for atom in residue:
                if atom.name in ref_atom_names:
                    atom_idx = ref_atom_names.index(atom.name)
                    all_chain_features[chain_id]["all_atom_positions"][res_idx][atom_idx] = atom.coord
                    all_chain_features[chain_id]["all_atom_mask"][res_idx][atom_idx] = 1

        # Zero-filled interaction features for this chain.
        for key in interaction_keys:
            all_chain_features[chain_id][key] = np.zeros(len(all_chain_features[chain_id]["ccds"]), dtype=np.int8)

    # Extract ligand from SDF (hydrogens removed, no sanitisation).
    supplier = Chem.SDMolSupplier(ligand_sdf_path, removeHs=True, sanitize=False)
    mol = supplier[0]
    mol = Chem.RemoveAllHs(mol)
    conf = mol.GetConformer()
    ligand_chain_id = "1"
    used_chain_ids.append(ligand_chain_id)

    ligand_atom_count = mol.GetNumAtoms()
    ligand_positions = np.zeros([ligand_atom_count, 3], dtype=np.float32)
    ligand_masks = np.ones([ligand_atom_count], dtype=np.int8)

    for atom in mol.GetAtoms():
        idx = atom.GetIdx()
        pos = conf.GetAtomPosition(idx)
        ligand_positions[idx] = [pos.x, pos.y, pos.z]

    all_chain_features[ligand_chain_id] = {
        "all_atom_positions": [ligand_positions],
        "all_atom_mask": [ligand_masks],
        "ccds": [ligand_ccd_id.upper()]
    }

    for key in interaction_keys:
        all_chain_features[ligand_chain_id][key] = np.zeros(1, dtype=np.int8)

    # System pickle name: <pdb basename>_<chain ids...>.pkl.gz
    save_name = os.path.basename(receptor_pdb_path).replace('.pdb', '')
    for cid in used_chain_ids:
        save_name += f"_{cid}"

    dump_pkl(all_chain_features, os.path.join(systems_dir, f"{save_name}.pkl.gz"))

    # Generate FASTA files (protein chains only) to run homology search;
    # file names are the md5 of the typed sequence so duplicates collapse.
    for cid, features in all_chain_features.items():
        if cid == ligand_chain_id:
            continue
        sequence = ''.join(protein_letters_3to1_extended.get(ccd, "X") for ccd in features["ccds"])
        md5_hash = convert_md5_string(f"protein:{sequence}")
        os.makedirs(os.path.join(systems_dir, "fastas"), exist_ok=True)
        with open(os.path.join(systems_dir, "fastas", f"{md5_hash}.fasta"), "w") as f:
            f.write(f">{md5_hash}\n{sequence}\n")
    print("Make system successfully!")
+ print("Make system successfully!")
PhysDock/data/relaxation.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import sys
3
+ import torch
4
+ import os
5
+ import tqdm
6
+ import pandas as pd
7
+ import argparse
8
+
9
+ sys.path.append("../")
10
+ from PhysDock.utils.io_utils import run_pool_tasks, load_txt, dump_txt
11
+ from pathlib import Path
12
+ import openmm.app as mm_app
13
+ import openmm.unit as mm_unit
14
+ import openmm as mm
15
+ import os.path
16
+ import sys
17
+ import mdtraj
18
+ from openmm.app import PDBFile, Modeller
19
+ import pdbfixer
20
+ from openmmforcefields.generators import SystemGenerator
21
+ from openff.toolkit import Molecule
22
+ from openff.toolkit.utils.exceptions import UndefinedStereochemistryError, RadicalsNotSupportedError
23
+ from openmm import CustomExternalForce
24
+ from posebusters import PoseBusters
25
+ from posebusters.posebusters import _dataframe_from_output
26
+ from posebusters.cli import _select_mode, _format_results
27
+
28
+
29
def get_bust_results(  # noqa: PLR0913
        mol_pred,
        mol_true,
        mol_cond,
        top_n: int | None = None,
):
    """Run PoseBusters on a single predicted pose and return the full report.

    Args:
        mol_pred: Path to the predicted ligand pose file.
        mol_true: Path to the reference (true) ligand file.
        mol_cond: Path to the conditioning receptor file; each bust run may
            use a different receptor.
        top_n: Optional limit on the number of poses checked per file.

    Returns:
        A pandas DataFrame with the full PoseBusters report for the first
        produced result, or ``None`` if PoseBusters yielded no output.
    """
    pred_paths = [Path(mol_pred)]
    true_path = Path(mol_true)
    cond_path = Path(mol_cond)  # Each bust running has a different receptor

    # Infer the PoseBusters mode from which inputs were provided.
    provided = {k for k, v in dict(mol_pred=pred_paths, mol_true=true_path, mol_cond=cond_path).items() if v}
    mode = _select_mode(None, provided)
    posebusters = PoseBusters(mode, top_n=top_n)
    cols = ["mol_pred", "mol_true", "mol_cond"]
    # FIX: the original comprehension shadowed the `mol_pred` list with its
    # own loop variable; use distinct names for the list and each element.
    posebusters.file_paths = pd.DataFrame(
        [[pred, true_path, cond_path] for pred in pred_paths], columns=cols
    )
    # Only the first PoseBusters output is converted to a DataFrame.
    results = None
    for results_dict in posebusters._run():
        results = _dataframe_from_output(results_dict, posebusters.config, full_report=True)
        break
    return results
51
+
52
+
53
def fix_pdb(pdbname, outdir, file_name):
    """Repair a PDB file with PDBFixer and return its topology and positions.

    Missing residues and atoms are rebuilt, non-standard residues are
    replaced by their standard equivalents, and hydrogens are added at
    pH 7.0. ``outdir`` and ``file_name`` are accepted for interface
    compatibility; the fixed structure is returned, not written to disk.
    """
    fixer = pdbfixer.PDBFixer(pdbname)
    # Run the canonical PDBFixer repair pipeline in order.
    fixer.findMissingResidues()
    fixer.findNonstandardResidues()
    fixer.replaceNonstandardResidues()
    fixer.findMissingAtoms()
    fixer.addMissingAtoms()
    fixer.addMissingHydrogens(7.0)
    return fixer.topology, fixer.positions
69
+
70
+
71
def set_system(topology):
    """Build an OpenMM System for ``topology`` with the Amber14 force field.

    The system uses no nonbonded cutoff, rigid water, and keeps
    center-of-mass motion (no CMMotion remover).
    """
    # Let a force field create the system so particles need not be added manually.
    force_field = mm_app.ForceField('amber14-all.xml', 'amber14/tip3pfb.xml')
    return force_field.createSystem(
        topology,
        removeCMMotion=False,
        nonbondedMethod=mm_app.NoCutoff,
        rigidWater=True,
    )
84
+
85
+
86
def minimize_energy(
        topology,
        system,
        positions,
        outdir,
        out_title
):
    """Minimize the energy of ``system`` and write the minimized structure.

    A Brownian integrator drives the minimization; the trajectory is logged
    to ``positions.dcd`` in the current working directory. The minimized
    structure is written to ``outdir`` when ``out_title`` contains
    "relaxed_complex", otherwise to the current working directory.

    Returns:
        (topology, minimized_positions, minimized_energy)
    """
    integrator = mm.BrownianIntegrator(
        100 * mm.unit.kelvin,
        100. / mm.unit.picoseconds,
        2.0 * mm.unit.femtoseconds
    )
    # CUDA platform selection is intentionally left to OpenMM's default.
    simulation = mm.app.Simulation(topology, system, integrator)

    # Log the trajectory every `report_interval` steps.
    report_interval = 100  # Adjust this value as needed
    dcd_reporter = mdtraj.reporters.DCDReporter('positions.dcd', report_interval)
    simulation.reporters.append(dcd_reporter)

    simulation.context.setPositions(positions)
    simulation.minimizeEnergy(1, 100)  # tolerance=1, maxIterations=100

    min_positions = simulation.context.getState(getPositions=True).getPositions()

    # Only "relaxed_complex" outputs are routed to the requested directory.
    if "relaxed_complex" in out_title:
        target_path = outdir + f'/{out_title}.pdb'
    else:
        target_path = f'{out_title}.pdb'
    mm_app.PDBFile.writeFile(topology, min_positions, open(target_path, 'w'))

    minimized_energy = simulation.context.getState(getEnergy=True).getPotentialEnergy()
    dcd_reporter.close()

    return topology, min_positions, minimized_energy
130
+
131
+
132
def add_restraints(
        system,
        topology,
        positions,
        restraint_type
):
    """Add positional harmonic restraints to a selected group of atoms.

    Code adapted from
    https://gist.github.com/peastman/ad8cda653242d731d75e18c836b2a3a5

    Args:
        system: OpenMM System the restraint force is added to (modified in place).
        topology: OpenMM Topology used to select atoms.
        positions: Reference positions the atoms are restrained towards.
        restraint_type: ``'protein'`` restrains every atom whose name does not
            contain ``'x'`` (ligand atoms are assumed to carry an ``'x'``);
            ``'CA+ligand'`` restrains ligand atoms plus protein CA atoms.
            Matched case-insensitively.

    Returns:
        The same ``system`` with the restraint force added.
    """
    restraint = CustomExternalForce('k*periodicdistance(x, y, z, x0, y0, z0)^2')
    system.addForce(restraint)
    restraint.addGlobalParameter('k', 100000000.0 * mm_unit.kilojoules_per_mole / mm_unit.nanometer ** 2)
    restraint.addPerParticleParameter('x0')
    restraint.addPerParticleParameter('y0')
    restraint.addPerParticleParameter('z0')

    # BUG FIX: callers pass e.g. "ca+ligand" (lower case, the default in
    # run()), which previously matched neither branch, so NO restraints were
    # ever added on the default path. Compare case-insensitively so either
    # spelling selects a branch.
    restraint_type = restraint_type.lower()
    for atom in topology.atoms():
        if restraint_type == 'protein':
            if 'x' not in atom.name:
                restraint.addParticle(atom.index, positions[atom.index])
        elif restraint_type == 'ca+ligand':
            if ('x' in atom.name) or (atom.name == "CA"):
                restraint.addParticle(atom.index, positions[atom.index])

    return system
159
+
160
+
161
def run(
        input_pdb,
        outdir,
        mol_in,
        file_name,
        restraint_type="ca+ligand",
        relax_protein_first=False,
        steps=100,
):
    """Relax a protein-ligand complex with restrained energy minimization.

    Loads the ligand with OpenFF, fixes/protonates the protein with PDBFixer,
    optionally pre-relaxes the protein alone, assembles the complex, adds
    positional restraints, and minimizes the complex. The minimized complex
    is written by ``minimize_energy`` as ``{file_name}_relaxed_complex.pdb``.

    Args:
        input_pdb: Path to the receptor PDB file.
        outdir: Directory where relaxed structures are written.
        mol_in: Path to the ligand file readable by ``Molecule.from_file``.
        file_name: Stem used for output file names.
        restraint_type: Restraint selection passed to ``add_restraints``.
        relax_protein_first: If True, minimize the protein alone first.
        steps: Unused; kept for interface compatibility.
    """
    try:
        ligand_mol = Molecule.from_file(mol_in)
    # Allow undefined stereochemistry to be loaded.
    except UndefinedStereochemistryError:
        print('Undefined Stereochemistry Error found! Trying with undefined stereo flag True')
        ligand_mol = Molecule.from_file(mol_in, allow_undefined_stereo=True)
    # Radicals are unsupported -- bail out of the script.
    except RadicalsNotSupportedError:
        print('OpenFF does not currently support radicals -- use unrelaxed structure')
        sys.exit()
    # The default charge method (am1bcc) does not work here; use gasteiger.
    ligand_mol.assign_partial_charges(partial_charge_method='gasteiger')

    # Read the protein PDB, repair it and add hydrogens.
    protein_topology, protein_positions = fix_pdb(input_pdb, outdir, file_name)

    system = set_system(protein_topology)
    if relax_protein_first:
        print('Relaxing ONLY protein structure...')
        # BUG FIX: minimize_energy returns THREE values; the original
        # two-target unpacking raised ValueError whenever this branch ran.
        protein_topology, protein_positions, _ = minimize_energy(
            protein_topology,
            system,
            protein_positions,
            outdir,
            f'{file_name}_relaxed_protein'
        )

    # Assemble the complex: protein first, then the ligand.
    modeller = Modeller(protein_topology, protein_positions)
    lig_top = ligand_mol.to_topology()
    modeller.add(lig_top.to_openmm(), lig_top.get_positions().to_openmm())

    # GAFF for the ligand, Amber14 + GBn2 implicit solvent for the protein.
    system_generator = SystemGenerator(
        forcefields=['amber14-all.xml', 'implicit/gbn2.xml'],
        small_molecule_forcefield='gaff-2.11',
        molecules=[ligand_mol],
    )
    system = system_generator.create_system(modeller.topology, molecules=ligand_mol)

    system = add_restraints(system, modeller.topology, modeller.positions, restraint_type=restraint_type)

    # Minimize the restrained complex; minimize_energy writes the structure.
    _, _, minimized_energy = minimize_energy(
        modeller.topology,
        system,
        modeller.positions,
        outdir,
        f'{file_name}_relaxed_complex'
    )
241
+
242
+
243
def relax(receptor_pdb, ligand_mol_sdf):
    """Relax a receptor/ligand pair and write a ligand-free relaxed receptor.

    Runs the full restrained minimization via ``run``, then strips HETATM
    (ligand) records from the relaxed complex and writes the remaining
    receptor as ``{file_name}_relaxed_complex.pdb`` next to the input.
    Failures are logged, not raised.

    Args:
        receptor_pdb: Path to the receptor PDB; its name must contain
            "receptor" (used to derive the system file stem).
        ligand_mol_sdf: Path to the ligand SDF file.
    """
    output_dir = os.path.split(receptor_pdb)[0]
    file_name = os.path.split(receptor_pdb)[1].split(".")[0]
    system_file_name = "system" + file_name.split("receptor")[1]
    try:
        run(
            input_pdb=receptor_pdb,
            outdir=output_dir,
            mol_in=ligand_mol_sdf,
            file_name=system_file_name
        )
        lines = load_txt(
            os.path.join(output_dir, f"{system_file_name}_relaxed_complex.pdb")).split("\n")
        # Keep only receptor records; HETATM lines belong to the ligand.
        receptor = "\n".join([i for i in lines if "HETATM" not in i])
        dump_txt(receptor, os.path.join(output_dir, f"{file_name}_relaxed_complex.pdb"))
    except Exception as e:
        # BUG FIX: the original printed the builtin `dir` function instead of
        # the failing input path.
        print(receptor_pdb, "can't relax,", e)
PhysDock/data/tools/PDBData.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2000 Andrew Dalke. All rights reserved.
2
+ #
3
+ # This file is part of the Biopython distribution and governed by your
4
+ # choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
5
+ # Please see the LICENSE file that should have been included as part of this
6
+ # package.
7
+ """Information about the IUPAC alphabets."""
8
+
9
+ protein_letters = "ACDEFGHIKLMNPQRSTVWY"
10
+ extended_protein_letters = "ACDEFGHIKLMNPQRSTVWYBXZJUO"
11
+
12
+ # B = "Asx"; aspartic acid or asparagine (D or N)
13
+ # X = "Xxx"; unknown or 'other' amino acid
14
+ # Z = "Glx"; glutamic acid or glutamine (E or Q)
15
+ # http://www.chem.qmul.ac.uk/iupac/AminoAcid/A2021.html#AA212
16
+ #
17
+ # J = "Xle"; leucine or isoleucine (L or I, used in NMR)
18
+ # Mentioned in http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
19
+ # Also the International Nucleotide Sequence Database Collaboration (INSDC)
20
+ # (i.e. GenBank, EMBL, DDBJ) adopted this in 2006
21
+ # http://www.ddbj.nig.ac.jp/insdc/icm2006-e.html
22
+ #
23
+ # Xle (J); Leucine or Isoleucine
24
+ # The residue abbreviations, Xle (the three-letter abbreviation) and J
25
+ # (the one-letter abbreviation) are reserved for the case that cannot
26
+ # experimentally distinguish leucine from isoleucine.
27
+ #
28
+ # U = "Sec"; selenocysteine
29
+ # http://www.chem.qmul.ac.uk/iubmb/newsletter/1999/item3.html
30
+ #
31
+ # O = "Pyl"; pyrrolysine
32
+ # http://www.chem.qmul.ac.uk/iubmb/newsletter/2009.html#item35
33
+
34
# Canonical 1-letter -> 3-letter codes for the 20 standard amino acids.
# Values start in title case and are upper-cased below to match PDB
# residue-name conventions.
protein_letters_1to3 = {
    "A": "Ala",
    "C": "Cys",
    "D": "Asp",
    "E": "Glu",
    "F": "Phe",
    "G": "Gly",
    "H": "His",
    "I": "Ile",
    "K": "Lys",
    "L": "Leu",
    "M": "Met",
    "N": "Asn",
    "P": "Pro",
    "Q": "Gln",
    "R": "Arg",
    "S": "Ser",
    "T": "Thr",
    "V": "Val",
    "W": "Trp",
    "Y": "Tyr",
}
# Normalize both sides to upper case (PDB residue names are upper case).
protein_letters_1to3 = {k.upper(): v.upper() for k, v in protein_letters_1to3.items()}
# Inverse mapping: upper-case 3-letter code -> 1-letter code.
protein_letters_3to1 = {v: k for k, v in protein_letters_1to3.items()}
58
+
59
+ protein_letters_3to1_extended = {
60
+ "A5N": "N", "A8E": "V", "A9D": "S", "AA3": "A", "AA4": "A", "AAR": "R",
61
+ "ABA": "A", "ACL": "R", "AEA": "C", "AEI": "D", "AFA": "N", "AGM": "R",
62
+ "AGQ": "Y", "AGT": "C", "AHB": "N", "AHL": "R", "AHO": "A", "AHP": "A",
63
+ "AIB": "A", "AKL": "D", "AKZ": "D", "ALA": "A", "ALC": "A", "ALM": "A",
64
+ "ALN": "A", "ALO": "T", "ALS": "A", "ALT": "A", "ALV": "A", "ALY": "K",
65
+ "AME": "M", "AN6": "L", "AN8": "A", "API": "K", "APK": "K", "AR2": "R",
66
+ "AR4": "E", "AR7": "R", "ARG": "R", "ARM": "R", "ARO": "R", "AS7": "N",
67
+ "ASA": "D", "ASB": "D", "ASI": "D", "ASK": "D", "ASL": "D", "ASN": "N",
68
+ "ASP": "D", "ASQ": "D", "AYA": "A", "AZH": "A", "AZK": "K", "AZS": "S",
69
+ "AZY": "Y", "AVJ": "H", "A30": "Y", "A3U": "F", "ECC": "Q", "ECX": "C",
70
+ "EFC": "C", "EHP": "F", "ELY": "K", "EME": "E", "EPM": "M", "EPQ": "Q",
71
+ "ESB": "Y", "ESC": "M", "EXY": "L", "EXA": "K", "E0Y": "P", "E9V": "H",
72
+ "E9M": "W", "EJA": "C", "EUP": "T", "EZY": "G", "E9C": "Y", "EW6": "S",
73
+ "EXL": "W", "I2M": "I", "I4G": "G", "I58": "K", "IAM": "A", "IAR": "R",
74
+ "ICY": "C", "IEL": "K", "IGL": "G", "IIL": "I", "ILE": "I", "ILG": "E",
75
+ "ILM": "I", "ILX": "I", "ILY": "K", "IML": "I", "IOR": "R", "IPG": "G",
76
+ "IT1": "K", "IYR": "Y", "IZO": "M", "IC0": "G", "M0H": "C", "M2L": "K",
77
+ "M2S": "M", "M30": "G", "M3L": "K", "M3R": "K", "MA ": "A", "MAA": "A",
78
+ "MAI": "R", "MBQ": "Y", "MC1": "S", "MCL": "K", "MCS": "C", "MD3": "C",
79
+ "MD5": "C", "MD6": "G", "MDF": "Y", "ME0": "M", "MEA": "F", "MEG": "E",
80
+ "MEN": "N", "MEQ": "Q", "MET": "M", "MEU": "G", "MFN": "E", "MGG": "R",
81
+ "MGN": "Q", "MGY": "G", "MH1": "H", "MH6": "S", "MHL": "L", "MHO": "M",
82
+ "MHS": "H", "MHU": "F", "MIR": "S", "MIS": "S", "MK8": "L", "ML3": "K",
83
+ "MLE": "L", "MLL": "L", "MLY": "K", "MLZ": "K", "MME": "M", "MMO": "R",
84
+ "MNL": "L", "MNV": "V", "MP8": "P", "MPQ": "G", "MSA": "G", "MSE": "M",
85
+ "MSL": "M", "MSO": "M", "MT2": "M", "MTY": "Y", "MVA": "V", "MYK": "K",
86
+ "MYN": "R", "QCS": "C", "QIL": "I", "QMM": "Q", "QPA": "C", "QPH": "F",
87
+ "Q3P": "K", "QVA": "C", "QX7": "A", "Q2E": "W", "Q75": "M", "Q78": "F",
88
+ "QM8": "L", "QMB": "A", "QNQ": "C", "QNT": "C", "QNW": "C", "QO2": "C",
89
+ "QO5": "C", "QO8": "C", "QQ8": "Q", "U2X": "Y", "U3X": "F", "UF0": "S",
90
+ "UGY": "G", "UM1": "A", "UM2": "A", "UMA": "A", "UQK": "A", "UX8": "W",
91
+ "UXQ": "F", "YCM": "C", "YOF": "Y", "YPR": "P", "YPZ": "Y", "YTH": "T",
92
+ "Y1V": "L", "Y57": "K", "YHA": "K", "200": "F", "23F": "F", "23P": "A",
93
+ "26B": "T", "28X": "T", "2AG": "A", "2CO": "C", "2FM": "M", "2GX": "F",
94
+ "2HF": "H", "2JG": "S", "2KK": "K", "2KP": "K", "2LT": "Y", "2LU": "L",
95
+ "2ML": "L", "2MR": "R", "2MT": "P", "2OR": "R", "2P0": "P", "2QZ": "T",
96
+ "2R3": "Y", "2RA": "A", "2RX": "S", "2SO": "H", "2TY": "Y", "2VA": "V",
97
+ "2XA": "C", "2ZC": "S", "6CL": "K", "6CW": "W", "6GL": "A", "6HN": "K",
98
+ "60F": "C", "66D": "I", "6CV": "A", "6M6": "C", "6V1": "C", "6WK": "C",
99
+ "6Y9": "P", "6DN": "K", "DA2": "R", "DAB": "A", "DAH": "F", "DBS": "S",
100
+ "DBU": "T", "DBY": "Y", "DBZ": "A", "DC2": "C", "DDE": "H", "DDZ": "A",
101
+ "DI7": "Y", "DHA": "S", "DHN": "V", "DIR": "R", "DLS": "K", "DM0": "K",
102
+ "DMH": "N", "DMK": "D", "DNL": "K", "DNP": "A", "DNS": "K", "DNW": "A",
103
+ "DOH": "D", "DON": "L", "DP1": "R", "DPL": "P", "DPP": "A", "DPQ": "Y",
104
+ "DYS": "C", "D2T": "D", "DYA": "D", "DJD": "F", "DYJ": "P", "DV9": "E",
105
+ "H14": "F", "H1D": "M", "H5M": "P", "HAC": "A", "HAR": "R", "HBN": "H",
106
+ "HCM": "C", "HGY": "G", "HHI": "H", "HIA": "H", "HIC": "H", "HIP": "H",
107
+ "HIQ": "H", "HIS": "H", "HL2": "L", "HLU": "L", "HMR": "R", "HNC": "C",
108
+ "HOX": "F", "HPC": "F", "HPE": "F", "HPH": "F", "HPQ": "F", "HQA": "A",
109
+ "HR7": "R", "HRG": "R", "HRP": "W", "HS8": "H", "HS9": "H", "HSE": "S",
110
+ "HSK": "H", "HSL": "S", "HSO": "H", "HT7": "W", "HTI": "C", "HTR": "W",
111
+ "HV5": "A", "HVA": "V", "HY3": "P", "HYI": "M", "HYP": "P", "HZP": "P",
112
+ "HIX": "A", "HSV": "H", "HLY": "K", "HOO": "H", "H7V": "A", "L5P": "K",
113
+ "LRK": "K", "L3O": "L", "LA2": "K", "LAA": "D", "LAL": "A", "LBY": "K",
114
+ "LCK": "K", "LCX": "K", "LDH": "K", "LE1": "V", "LED": "L", "LEF": "L",
115
+ "LEH": "L", "LEM": "L", "LEN": "L", "LET": "K", "LEU": "L", "LEX": "L",
116
+ "LGY": "K", "LLO": "K", "LLP": "K", "LLY": "K", "LLZ": "K", "LME": "E",
117
+ "LMF": "K", "LMQ": "Q", "LNE": "L", "LNM": "L", "LP6": "K", "LPD": "P",
118
+ "LPG": "G", "LPS": "S", "LSO": "K", "LTR": "W", "LVG": "G", "LVN": "V",
119
+ "LWY": "P", "LYF": "K", "LYK": "K", "LYM": "K", "LYN": "K", "LYO": "K",
120
+ "LYP": "K", "LYR": "K", "LYS": "K", "LYU": "K", "LYX": "K", "LYZ": "K",
121
+ "LAY": "L", "LWI": "F", "LBZ": "K", "P1L": "C", "P2Q": "Y", "P2Y": "P",
122
+ "P3Q": "Y", "PAQ": "Y", "PAS": "D", "PAT": "W", "PBB": "C", "PBF": "F",
123
+ "PCA": "Q", "PCC": "P", "PCS": "F", "PE1": "K", "PEC": "C", "PF5": "F",
124
+ "PFF": "F", "PG1": "S", "PGY": "G", "PHA": "F", "PHD": "D", "PHE": "F",
125
+ "PHI": "F", "PHL": "F", "PHM": "F", "PKR": "P", "PLJ": "P", "PM3": "F",
126
+ "POM": "P", "PPN": "F", "PR3": "C", "PR4": "P", "PR7": "P", "PR9": "P",
127
+ "PRJ": "P", "PRK": "K", "PRO": "P", "PRS": "P", "PRV": "G", "PSA": "F",
128
+ "PSH": "H", "PTH": "Y", "PTM": "Y", "PTR": "Y", "PVH": "H", "PXU": "P",
129
+ "PYA": "A", "PYH": "K", "PYX": "C", "PH6": "P", "P9S": "C", "P5U": "S",
130
+ "POK": "R", "T0I": "Y", "T11": "F", "TAV": "D", "TBG": "V", "TBM": "T",
131
+ "TCQ": "Y", "TCR": "W", "TEF": "F", "TFQ": "F", "TH5": "T", "TH6": "T",
132
+ "THC": "T", "THR": "T", "THZ": "R", "TIH": "A", "TIS": "S", "TLY": "K",
133
+ "TMB": "T", "TMD": "T", "TNB": "C", "TNR": "S", "TNY": "T", "TOQ": "W",
134
+ "TOX": "W", "TPJ": "P", "TPK": "P", "TPL": "W", "TPO": "T", "TPQ": "Y",
135
+ "TQI": "W", "TQQ": "W", "TQZ": "C", "TRF": "W", "TRG": "K", "TRN": "W",
136
+ "TRO": "W", "TRP": "W", "TRQ": "W", "TRW": "W", "TRX": "W", "TRY": "W",
137
+ "TS9": "I", "TSY": "C", "TTQ": "W", "TTS": "Y", "TXY": "Y", "TY1": "Y",
138
+ "TY2": "Y", "TY3": "Y", "TY5": "Y", "TY8": "Y", "TY9": "Y", "TYB": "Y",
139
+ "TYC": "Y", "TYE": "Y", "TYI": "Y", "TYJ": "Y", "TYN": "Y", "TYO": "Y",
140
+ "TYQ": "Y", "TYR": "Y", "TYS": "Y", "TYT": "Y", "TYW": "Y", "TYY": "Y",
141
+ "T8L": "T", "T9E": "T", "TNQ": "W", "TSQ": "F", "TGH": "W", "X2W": "E",
142
+ "XCN": "C", "XPR": "P", "XSN": "N", "XW1": "A", "XX1": "K", "XYC": "A",
143
+ "XA6": "F", "11Q": "P", "11W": "E", "12L": "P", "12X": "P", "12Y": "P",
144
+ "143": "C", "1AC": "A", "1L1": "A", "1OP": "Y", "1PA": "F", "1PI": "A",
145
+ "1TQ": "W", "1TY": "Y", "1X6": "S", "56A": "H", "5AB": "A", "5CS": "C",
146
+ "5CW": "W", "5HP": "E", "5OH": "A", "5PG": "G", "51T": "Y", "54C": "W",
147
+ "5CR": "F", "5CT": "K", "5FQ": "A", "5GM": "I", "5JP": "S", "5T3": "K",
148
+ "5MW": "K", "5OW": "K", "5R5": "S", "5VV": "N", "5XU": "A", "55I": "F",
149
+ "999": "D", "9DN": "N", "9NE": "E", "9NF": "F", "9NR": "R", "9NV": "V",
150
+ "9E7": "K", "9KP": "K", "9WV": "A", "9TR": "K", "9TU": "K", "9TX": "K",
151
+ "9U0": "K", "9IJ": "F", "B1F": "F", "B27": "T", "B2A": "A", "B2F": "F",
152
+ "B2I": "I", "B2V": "V", "B3A": "A", "B3D": "D", "B3E": "E", "B3K": "K",
153
+ "B3U": "H", "B3X": "N", "B3Y": "Y", "BB6": "C", "BB7": "C", "BB8": "F",
154
+ "BB9": "C", "BBC": "C", "BCS": "C", "BCX": "C", "BFD": "D", "BG1": "S",
155
+ "BH2": "D", "BHD": "D", "BIF": "F", "BIU": "I", "BL2": "L", "BLE": "L",
156
+ "BLY": "K", "BMT": "T", "BNN": "F", "BOR": "R", "BP5": "A", "BPE": "C",
157
+ "BSE": "S", "BTA": "L", "BTC": "C", "BTK": "K", "BTR": "W", "BUC": "C",
158
+ "BUG": "V", "BYR": "Y", "BWV": "R", "BWB": "S", "BXT": "S", "F2F": "F",
159
+ "F2Y": "Y", "FAK": "K", "FB5": "A", "FB6": "A", "FC0": "F", "FCL": "F",
160
+ "FDL": "K", "FFM": "C", "FGL": "G", "FGP": "S", "FH7": "K", "FHL": "K",
161
+ "FHO": "K", "FIO": "R", "FLA": "A", "FLE": "L", "FLT": "Y", "FME": "M",
162
+ "FOE": "C", "FP9": "P", "FPK": "P", "FT6": "W", "FTR": "W", "FTY": "Y",
163
+ "FVA": "V", "FZN": "K", "FY3": "Y", "F7W": "W", "FY2": "Y", "FQA": "K",
164
+ "F7Q": "Y", "FF9": "K", "FL6": "D", "JJJ": "C", "JJK": "C", "JJL": "C",
165
+ "JLP": "K", "J3D": "C", "J9Y": "R", "J8W": "S", "JKH": "P", "N10": "S",
166
+ "N7P": "P", "NA8": "A", "NAL": "A", "NAM": "A", "NBQ": "Y", "NC1": "S",
167
+ "NCB": "A", "NEM": "H", "NEP": "H", "NFA": "F", "NIY": "Y", "NLB": "L",
168
+ "NLE": "L", "NLN": "L", "NLO": "L", "NLP": "L", "NLQ": "Q", "NLY": "G",
169
+ "NMC": "G", "NMM": "R", "NNH": "R", "NOT": "L", "NPH": "C", "NPI": "A",
170
+ "NTR": "Y", "NTY": "Y", "NVA": "V", "NWD": "A", "NYB": "C", "NYS": "C",
171
+ "NZH": "H", "N80": "P", "NZC": "T", "NLW": "L", "N0A": "F", "N9P": "A",
172
+ "N65": "K", "R1A": "C", "R4K": "W", "RE0": "W", "RE3": "W", "RGL": "R",
173
+ "RGP": "E", "RT0": "P", "RVX": "S", "RZ4": "S", "RPI": "R", "RVJ": "A",
174
+ "VAD": "V", "VAF": "V", "VAH": "V", "VAI": "V", "VAL": "V", "VB1": "K",
175
+ "VH0": "P", "VR0": "R", "V44": "C", "V61": "F", "VPV": "K", "V5N": "H",
176
+ "V7T": "K", "Z01": "A", "Z3E": "T", "Z70": "H", "ZBZ": "C", "ZCL": "F",
177
+ "ZU0": "T", "ZYJ": "P", "ZYK": "P", "ZZD": "C", "ZZJ": "A", "ZIQ": "W",
178
+ "ZPO": "P", "ZDJ": "Y", "ZT1": "K", "30V": "C", "31Q": "C", "33S": "F",
179
+ "33W": "A", "34E": "V", "3AH": "H", "3BY": "P", "3CF": "F", "3CT": "Y",
180
+ "3GA": "A", "3GL": "E", "3MD": "D", "3MY": "Y", "3NF": "Y", "3O3": "E",
181
+ "3PX": "P", "3QN": "K", "3TT": "P", "3XH": "G", "3YM": "Y", "3WS": "A",
182
+ "3WX": "P", "3X9": "C", "3ZH": "H", "7JA": "I", "73C": "S", "73N": "R",
183
+ "73O": "Y", "73P": "K", "74P": "K", "7N8": "F", "7O5": "A", "7XC": "F",
184
+ "7ID": "D", "7OZ": "A", "C1S": "C", "C1T": "C", "C1X": "K", "C22": "A",
185
+ "C3Y": "C", "C4R": "C", "C5C": "C", "C6C": "C", "CAF": "C", "CAS": "C",
186
+ "CAY": "C", "CCS": "C", "CEA": "C", "CGA": "E", "CGU": "E", "CGV": "C",
187
+ "CHP": "G", "CIR": "R", "CLE": "L", "CLG": "K", "CLH": "K", "CME": "C",
188
+ "CMH": "C", "CML": "C", "CMT": "C", "CR5": "G", "CS0": "C", "CS1": "C",
189
+ "CS3": "C", "CS4": "C", "CSA": "C", "CSB": "C", "CSD": "C", "CSE": "C",
190
+ "CSJ": "C", "CSO": "C", "CSP": "C", "CSR": "C", "CSS": "C", "CSU": "C",
191
+ "CSW": "C", "CSX": "C", "CSZ": "C", "CTE": "W", "CTH": "T", "CWD": "A",
192
+ "CWR": "S", "CXM": "M", "CY0": "C", "CY1": "C", "CY3": "C", "CY4": "C",
193
+ "CYA": "C", "CYD": "C", "CYF": "C", "CYG": "C", "CYJ": "K", "CYM": "C",
194
+ "CYQ": "C", "CYR": "C", "CYS": "C", "CYW": "C", "CZ2": "C", "CZZ": "C",
195
+ "CG6": "C", "C1J": "R", "C4G": "R", "C67": "R", "C6D": "R", "CE7": "N",
196
+ "CZS": "A", "G01": "E", "G8M": "E", "GAU": "E", "GEE": "G", "GFT": "S",
197
+ "GHC": "E", "GHG": "Q", "GHW": "E", "GL3": "G", "GLH": "Q", "GLJ": "E",
198
+ "GLK": "E", "GLN": "Q", "GLQ": "E", "GLU": "E", "GLY": "G", "GLZ": "G",
199
+ "GMA": "E", "GME": "E", "GNC": "Q", "GPL": "K", "GSC": "G", "GSU": "E",
200
+ "GT9": "C", "GVL": "S", "G3M": "R", "G5G": "L", "G1X": "Y", "G8X": "P",
201
+ "K1R": "C", "KBE": "K", "KCX": "K", "KFP": "K", "KGC": "K", "KNB": "A",
202
+ "KOR": "M", "KPI": "K", "KPY": "K", "KST": "K", "KYN": "W", "KYQ": "K",
203
+ "KCR": "K", "KPF": "K", "K5L": "S", "KEO": "K", "KHB": "K", "KKD": "D",
204
+ "K5H": "C", "K7K": "S", "OAR": "R", "OAS": "S", "OBS": "K", "OCS": "C",
205
+ "OCY": "C", "OHI": "H", "OHS": "D", "OLD": "H", "OLT": "T", "OLZ": "S",
206
+ "OMH": "S", "OMT": "M", "OMX": "Y", "OMY": "Y", "ONH": "A", "ORN": "A",
207
+ "ORQ": "R", "OSE": "S", "OTH": "T", "OXX": "D", "OYL": "H", "O7A": "T",
208
+ "O7D": "W", "O7G": "V", "O2E": "S", "O6H": "W", "OZW": "F", "S12": "S",
209
+ "S1H": "S", "S2C": "C", "S2P": "A", "SAC": "S", "SAH": "C", "SAR": "G",
210
+ "SBG": "S", "SBL": "S", "SCH": "C", "SCS": "C", "SCY": "C", "SD4": "N",
211
+ "SDB": "S", "SDP": "S", "SEB": "S", "SEE": "S", "SEG": "A", "SEL": "S",
212
+ "SEM": "S", "SEN": "S", "SEP": "S", "SER": "S", "SET": "S", "SGB": "S",
213
+ "SHC": "C", "SHP": "G", "SHR": "K", "SIB": "C", "SLL": "K", "SLZ": "K",
214
+ "SMC": "C", "SME": "M", "SMF": "F", "SNC": "C", "SNN": "N", "SOY": "S",
215
+ "SRZ": "S", "STY": "Y", "SUN": "S", "SVA": "S", "SVV": "S", "SVW": "S",
216
+ "SVX": "S", "SVY": "S", "SVZ": "S", "SXE": "S", "SKH": "K", "SNM": "S",
217
+ "SNK": "H", "SWW": "S", "WFP": "F", "WLU": "L", "WPA": "F", "WRP": "W",
218
+ "WVL": "V", "02K": "A", "02L": "N", "02O": "A", "02Y": "A", "033": "V",
219
+ "037": "P", "03Y": "C", "04U": "P", "04V": "P", "05N": "P", "07O": "C",
220
+ "0A0": "D", "0A1": "Y", "0A2": "K", "0A8": "C", "0A9": "F", "0AA": "V",
221
+ "0AB": "V", "0AC": "G", "0AF": "W", "0AG": "L", "0AH": "S", "0AK": "D",
222
+ "0AR": "R", "0BN": "F", "0CS": "A", "0E5": "T", "0EA": "Y", "0FL": "A",
223
+ "0LF": "P", "0NC": "A", "0PR": "Y", "0QL": "C", "0TD": "D", "0UO": "W",
224
+ "0WZ": "Y", "0X9": "R", "0Y8": "P", "4AF": "F", "4AR": "R", "4AW": "W",
225
+ "4BF": "F", "4CF": "F", "4CY": "M", "4DP": "W", "4FB": "P", "4FW": "W",
226
+ "4HL": "Y", "4HT": "W", "4IN": "W", "4MM": "M", "4PH": "F", "4U7": "A",
227
+ "41H": "F", "41Q": "N", "42Y": "S", "432": "S", "45F": "P", "4AK": "K",
228
+ "4D4": "R", "4GJ": "C", "4KY": "P", "4L0": "P", "4LZ": "Y", "4N7": "P",
229
+ "4N8": "P", "4N9": "P", "4OG": "W", "4OU": "F", "4OV": "S", "4OZ": "S",
230
+ "4PQ": "W", "4SJ": "F", "4WQ": "A", "4HH": "S", "4HJ": "S", "4J4": "C",
231
+ "4J5": "R", "4II": "F", "4VI": "R", "823": "N", "8SP": "S", "8AY": "A",
232
+ }
233
+
234
# Nucleic Acids
# Standard nucleotide residue codes (space-padded to 3 characters, as stored
# in mmCIF/PDB component IDs) -> 1-letter codes. RNA and DNA combined.
nucleic_letters_3to1 = {
    "A ": "A", "C ": "C", "G ": "G", "U ": "U",
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}

# RNA-only subset of the standard codes.
rna_letters_3to1 = {
    "A ": "A", "C ": "C", "G ": "G", "U ": "U",
}

# DNA-only subset of the standard codes.
dna_letters_3to1 = {
    "DA ": "A", "DC ": "C", "DG ": "G", "DT ": "T",
}
247
+
248
+ # fmt: off
249
+ nucleic_letters_3to1_extended = {
250
+ "A ": "A", "A23": "A", "A2L": "A", "A2M": "A", "A34": "A", "A35": "A",
251
+ "A38": "A", "A39": "A", "A3A": "A", "A3P": "A", "A40": "A", "A43": "A",
252
+ "A44": "A", "A47": "A", "A5L": "A", "A5M": "C", "A5O": "A", "A6A": "A",
253
+ "A6C": "C", "A6G": "G", "A6U": "U", "A7E": "A", "A9Z": "A", "ABR": "A",
254
+ "ABS": "A", "AD2": "A", "ADI": "A", "ADP": "A", "AET": "A", "AF2": "A",
255
+ "AFG": "G", "AMD": "A", "AMO": "A", "AP7": "A", "AS ": "A", "ATD": "T",
256
+ "ATL": "T", "ATM": "T", "AVC": "A", "AI5": "C", "E ": "A", "E1X": "A",
257
+ "EDA": "A", "EFG": "G", "EHG": "G", "EIT": "T", "EXC": "C", "E3C": "C",
258
+ "E6G": "G", "E7G": "G", "EQ4": "G", "EAN": "T", "I5C": "C", "IC ": "C",
259
+ "IG ": "G", "IGU": "G", "IMC": "C", "IMP": "G", "IU ": "U", "I4U": "U",
260
+ "IOO": "G", "M1G": "G", "M2G": "G", "M4C": "C", "M5M": "C", "MA6": "A",
261
+ "MA7": "A", "MAD": "A", "MCY": "C", "ME6": "C", "MEP": "U", "MG1": "G",
262
+ "MGQ": "A", "MGT": "G", "MGV": "G", "MIA": "A", "MMT": "T", "MNU": "U",
263
+ "MRG": "G", "MTR": "T", "MTU": "A", "MFO": "G", "M7A": "A", "MHG": "G",
264
+ "MMX": "C", "QUO": "G", "QCK": "T", "QSQ": "A", "U ": "U", "U25": "U",
265
+ "U2L": "U", "U2P": "U", "U31": "U", "U34": "U", "U36": "U", "U37": "U",
266
+ "U8U": "U", "UAR": "U", "UBB": "U", "UBD": "U", "UD5": "U", "UPV": "U",
267
+ "UR3": "U", "URD": "U", "US3": "T", "US5": "U", "UZR": "U", "UMO": "U",
268
+ "U23": "U", "U48": "C", "U7B": "C", "Y ": "A", "YCO": "C", "YG ": "G",
269
+ "YYG": "G", "23G": "G", "26A": "A", "2AR": "A", "2AT": "T", "2AU": "U",
270
+ "2BT": "T", "2BU": "A", "2DA": "A", "2DT": "T", "2EG": "G", "2GT": "T",
271
+ "2JV": "G", "2MA": "A", "2MG": "G", "2MU": "U", "2NT": "T", "2OM": "U",
272
+ "2OT": "T", "2PR": "G", "2SG": "G", "2ST": "T", "63G": "G", "63H": "G",
273
+ "64T": "T", "68Z": "G", "6CT": "T", "6HA": "A", "6HB": "A", "6HC": "C",
274
+ "6HG": "G", "6HT": "T", "6IA": "A", "6MA": "A", "6MC": "A", "6MP": "A",
275
+ "6MT": "A", "6MZ": "A", "6OG": "G", "6PO": "G", "6FK": "G", "6NW": "A",
276
+ "6OO": "C", "D00": "C", "D3T": "T", "D4M": "T", "DA ": "A", "DC ": "C",
277
+ "DCG": "G", "DCT": "C", "DDG": "G", "DFC": "C", "DFG": "G", "DG ": "G",
278
+ "DG8": "G", "DGI": "G", "DGP": "G", "DHU": "U", "DNR": "C", "DOC": "C",
279
+ "DPB": "T", "DRT": "T", "DT ": "T", "DZM": "A", "D4B": "C", "H2U": "U",
280
+ "HN0": "G", "HN1": "G", "LC ": "C", "LCA": "A", "LCG": "G", "LG ": "G",
281
+ "LGP": "G", "LHU": "U", "LSH": "T", "LST": "T", "LDG": "G", "L3X": "A",
282
+ "LHH": "C", "LV2": "C", "L1J": "G", "P ": "G", "P2T": "T", "P5P": "A",
283
+ "PG7": "G", "PGN": "G", "PGP": "G", "PMT": "C", "PPU": "A", "PPW": "G",
284
+ "PR5": "A", "PRN": "A", "PST": "T", "PSU": "U", "PU ": "A", "PVX": "C",
285
+ "PYO": "U", "PZG": "G", "P4U": "U", "P7G": "G", "T ": "T", "T2S": "T",
286
+ "T31": "U", "T32": "T", "T36": "T", "T37": "T", "T38": "T", "T39": "T",
287
+ "T3P": "T", "T41": "T", "T48": "T", "T49": "T", "T4S": "T", "T5S": "T",
288
+ "T64": "T", "T6A": "A", "TA3": "T", "TAF": "T", "TBN": "A", "TC1": "C",
289
+ "TCP": "T", "TCY": "A", "TDY": "T", "TED": "T", "TFE": "T", "TFF": "T",
290
+ "TFO": "A", "TFT": "T", "TGP": "G", "TCJ": "C", "TLC": "T", "TP1": "T",
291
+ "TPC": "C", "TPG": "G", "TSP": "T", "TTD": "T", "TTM": "T", "TXD": "A",
292
+ "TXP": "A", "TC ": "C", "TG ": "G", "T0N": "G", "T0Q": "G", "X ": "G",
293
+ "XAD": "A", "XAL": "A", "XCL": "C", "XCR": "C", "XCT": "C", "XCY": "C",
294
+ "XGL": "G", "XGR": "G", "XGU": "G", "XPB": "G", "XTF": "T", "XTH": "T",
295
+ "XTL": "T", "XTR": "T", "XTS": "G", "XUA": "A", "XUG": "G", "102": "G",
296
+ "10C": "C", "125": "U", "126": "U", "127": "U", "12A": "A", "16B": "C",
297
+ "18M": "G", "1AP": "A", "1CC": "C", "1FC": "C", "1MA": "A", "1MG": "G",
298
+ "1RN": "U", "1SC": "C", "5AA": "A", "5AT": "T", "5BU": "U", "5CG": "G",
299
+ "5CM": "C", "5FA": "A", "5FC": "C", "5FU": "U", "5HC": "C", "5HM": "C",
300
+ "5HT": "T", "5IC": "C", "5IT": "T", "5MC": "C", "5MU": "U", "5NC": "C",
301
+ "5PC": "C", "5PY": "T", "9QV": "U", "94O": "T", "9SI": "A", "9SY": "A",
302
+ "B7C": "C", "BGM": "G", "BOE": "T", "B8H": "U", "B8K": "G", "B8Q": "C",
303
+ "B8T": "C", "B8W": "G", "B9B": "G", "B9H": "C", "BGH": "G", "F3H": "T",
304
+ "F3N": "A", "F4H": "T", "FA2": "A", "FDG": "G", "FHU": "U", "FMG": "G",
305
+ "FNU": "U", "FOX": "G", "F2T": "U", "F74": "G", "F4Q": "G", "F7H": "C",
306
+ "F7K": "G", "JDT": "T", "JMH": "C", "J0X": "C", "N5M": "C", "N6G": "G",
307
+ "N79": "A", "NCU": "C", "NMS": "T", "NMT": "T", "NTT": "T", "N7X": "C",
308
+ "R ": "A", "RBD": "A", "RDG": "G", "RIA": "A", "RMP": "A", "RPC": "C",
309
+ "RSP": "C", "RSQ": "C", "RT ": "T", "RUS": "U", "RFJ": "G", "V3L": "A",
310
+ "VC7": "G", "Z ": "C", "ZAD": "A", "ZBC": "C", "ZBU": "U", "ZCY": "C",
311
+ "ZGU": "G", "31H": "A", "31M": "A", "3AU": "U", "3DA": "A", "3ME": "U",
312
+ "3MU": "U", "3TD": "U", "70U": "U", "7AT": "A", "7DA": "A", "7GU": "G",
313
+ "7MG": "G", "7BG": "G", "73W": "C", "75B": "U", "7OK": "C", "7S3": "G",
314
+ "7SN": "G", "C ": "C", "C25": "C", "C2L": "C", "C2S": "C", "C31": "C",
315
+ "C32": "C", "C34": "C", "C36": "C", "C37": "C", "C38": "C", "C42": "C",
316
+ "C43": "C", "C45": "C", "C46": "C", "C49": "C", "C4S": "C", "C5L": "C",
317
+ "C6G": "G", "CAR": "C", "CB2": "C", "CBR": "C", "CBV": "C", "CCC": "C",
318
+ "CDW": "C", "CFL": "C", "CFZ": "C", "CG1": "G", "CH ": "C", "CMR": "C",
319
+ "CNU": "U", "CP1": "C", "CSF": "C", "CSL": "C", "CTG": "T", "CX2": "C",
320
+ "C7S": "C", "C7R": "C", "G ": "G", "G1G": "G", "G25": "G", "G2L": "G",
321
+ "G2S": "G", "G31": "G", "G32": "G", "G33": "G", "G36": "G", "G38": "G",
322
+ "G42": "G", "G46": "G", "G47": "G", "G48": "G", "G49": "G", "G7M": "G",
323
+ "GAO": "G", "GCK": "C", "GDO": "G", "GDP": "G", "GDR": "G", "GF2": "G",
324
+ "GFL": "G", "GH3": "G", "GMS": "G", "GN7": "G", "GNG": "G", "GOM": "G",
325
+ "GRB": "G", "GS ": "G", "GSR": "G", "GSS": "G", "GTP": "G", "GX1": "G",
326
+ "KAG": "G", "KAK": "G", "O2G": "G", "OGX": "G", "OMC": "C", "OMG": "G",
327
+ "OMU": "U", "ONE": "U", "O2Z": "A", "OKN": "C", "OKQ": "C", "S2M": "T",
328
+ "S4A": "A", "S4C": "C", "S4G": "G", "S4U": "U", "S6G": "G", "SC ": "C",
329
+ "SDE": "A", "SDG": "G", "SDH": "G", "SMP": "A", "SMT": "T", "SPT": "T",
330
+ "SRA": "A", "SSU": "U", "SUR": "U", "00A": "A", "0AD": "G", "0AM": "A",
331
+ "0AP": "C", "0AV": "A", "0R8": "C", "0SP": "A", "0UH": "G", "47C": "C",
332
+ "4OC": "C", "4PC": "C", "4PD": "C", "4PE": "C", "4SC": "C", "4SU": "U",
333
+ "45A": "A", "4U3": "C", "8AG": "G", "8AN": "A", "8BA": "A", "8FG": "G",
334
+ "8MG": "G", "8OG": "G", "8PY": "G", "8AA": "G", "85Y": "U", "8OS": "G",
335
+ "UNK": "X", # DEBUG
336
+ }
337
+
338
# Aliases for the canonical (standard-residue) mappings.
standard_protein_letters_3to1 = protein_letters_3to1
standard_protein_letters_1to3 = protein_letters_1to3
# Non-standard residues only: the extended table minus canonical entries.
nonstandard_protein_letters_3to1 = {k: v for k, v in protein_letters_3to1_extended.items() if
                                    k not in standard_protein_letters_3to1}

standard_nucleic_letters_3to1 = nucleic_letters_3to1
# NOTE(review): RNA and DNA codes collide on inversion ("A" maps from both
# "A " and "DA "; the later entry, "DA ", wins) -- confirm this is intended.
standard_nucleic_letters_1to3 = {v: k for k, v in standard_nucleic_letters_3to1.items()}
# Non-standard nucleotides only: extended table minus canonical entries.
nonstandard_nucleic_letters_3to1 = {k: v for k, v in nucleic_letters_3to1_extended.items() if
                                    k not in standard_nucleic_letters_3to1}

# Combined protein + nucleic extended 3-letter -> 1-letter table.
letters_3to1_extended = {**protein_letters_3to1_extended, **nucleic_letters_3to1_extended}
PhysDock/data/tools/__init__.py ADDED
File without changes
PhysDock/data/tools/alignment_runner.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os.path
3
+ from functools import partial
4
+ import tqdm
5
+ from typing import Optional, Mapping, Any, Union
6
+
7
+ from PhysDock.data.tools import jackhmmer, nhmmer, hhblits, kalign, hmmalign, parsers, hmmbuild, hhsearch, templates
8
+ from PhysDock.utils.io_utils import load_pkl, load_txt, load_json, run_pool_tasks, convert_md5_string, dump_pkl
9
+ from PhysDock.data.tools.parsers import parse_fasta
10
+
11
+ TemplateSearcher = Union[hhsearch.HHSearch]
12
+
13
+
14
class AlignmentRunner:
    """Runs homology and template searches for a single query fasta.

    Each search-tool runner is created in ``__init__`` only when both its
    binary and its database path(s) exist on disk; otherwise the runner stays
    ``None`` and that particular search is skipped silently in :meth:`run`,
    so a partially configured environment still produces whatever alignments
    it can.
    """

    def __init__(
        self,
        # Homology search tools
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,
        nhmmer_binary_path: Optional[str] = None,
        hmmbuild_binary_path: Optional[str] = None,
        hmmalign_binary_path: Optional[str] = None,
        kalign_binary_path: Optional[str] = None,

        # Template search tools
        hhsearch_binary_path: Optional[str] = None,
        template_searcher: Optional[TemplateSearcher] = None,
        template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,

        # Databases
        uniref90_database_path: Optional[str] = None,
        uniprot_database_path: Optional[str] = None,
        uniclust30_database_path: Optional[str] = None,
        uniref30_database_path: Optional[str] = None,
        bfd_database_path: Optional[str] = None,
        reduced_bfd_database_path: Optional[str] = None,
        mgnify_database_path: Optional[str] = None,
        rfam_database_path: Optional[str] = None,
        rnacentral_database_path: Optional[str] = None,
        nt_database_path: Optional[str] = None,
        #
        no_cpus: int = 8,
        # Limitations
        uniref90_seq_limit: int = 100000,
        uniprot_seq_limit: int = 500000,
        reduced_bfd_seq_limit: int = 50000,
        mgnify_seq_limit: int = 50000,
        uniref90_max_hits: int = 10000,
        uniprot_max_hits: int = 50000,
        reduced_bfd_max_hits: int = 5000,
        mgnify_max_hits: int = 5000,
        rfam_max_hits: int = 10000,
        rnacentral_max_hits: int = 10000,
        nt_max_hits: int = 10000,
    ):
        # Runners default to None; each is filled in below only when the
        # corresponding binary + database are present.
        self.uniref90_jackhmmer_runner = None
        self.uniprot_jackhmmer_runner = None
        self.reduced_bfd_jackhmmer_runner = None
        self.mgnify_jackhmmer_runner = None
        self.bfd_uniref30_hhblits_runner = None
        self.bfd_uniclust30_hhblits_runner = None
        self.rfam_nhmmer_runner = None
        self.rnacentral_nhmmer_runner = None
        self.nt_nhmmer_runner = None
        self.rna_realign_runner = None
        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer

        def _all_exists(*objs, hhblits_mode=False):
            # hhblits "databases" are path prefixes shared by several files,
            # so in hhblits mode only the parent directory can be checked.
            if not hhblits_mode:
                for obj in objs:
                    if obj is None or not os.path.exists(obj):
                        return False
            else:
                for obj in objs:
                    if obj is None or not os.path.exists(os.path.split(obj)[0]):
                        return False
            return True

        def _run_msa_tool(
            fasta_path: str,
            msa_out_path: str,
            msa_runner,
            msa_format: str,
            max_sto_sequences: Optional[int] = None,
        ) -> Mapping[str, Any]:
            """Runs an MSA tool and writes its output in ``msa_format``."""
            if msa_format == "sto" and max_sto_sequences is not None:
                result = msa_runner.query(fasta_path, max_sto_sequences)[0]
            else:
                result = msa_runner.query(fasta_path)[0]

            # Output filename extension must agree with the declared format.
            assert msa_out_path.split('.')[-1] == msa_format
            with open(msa_out_path, "w") as f:
                f.write(result[msa_format])

            return result

        def _run_rna_realign_tool(
            fasta_path: str,
            msa_in_path: str,
            msa_out_path: str,
            use_precompute=True,
        ):
            """Realigns an RNA .sto MSA against the query with hmmalign."""
            runner = hmmalign.Hmmalign(
                hmmbuild_binary_path=hmmbuild_binary_path,
                hmmalign_binary_path=hmmalign_binary_path,
            )
            # An empty input MSA simply yields an empty output file.
            if os.path.exists(msa_in_path) and os.path.getsize(msa_in_path) == 0:
                with open(msa_out_path, "w"):
                    pass
                return
            if use_precompute:
                if os.path.exists(msa_in_path) and os.path.exists(msa_out_path):
                    if os.path.getsize(msa_in_path) > 0 and os.path.getsize(msa_out_path) == 0:
                        # Non-empty input with an empty precomputed output is
                        # suspicious: warn and redo the realignment.
                        logging.warning(f"The msa realign file size is zero but the origin file size is over 0! "
                                        f"fasta: {fasta_path} msa_in_file: {msa_in_path}")
                        runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
                else:
                    runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)
            else:
                runner.realign_sto_with_fasta(fasta_path, msa_in_path, msa_out_path)

        # uniclust30 and uniref30 are alternative hhblits companions.
        assert uniclust30_database_path is None or uniref30_database_path is None, "Only one used"

        # Jackhmmer
        if _all_exists(jackhmmer_binary_path, uniref90_database_path):
            self.uniref90_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniref90_database_path,
                    seq_limit=uniref90_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniref90_max_hits
            )

        if _all_exists(jackhmmer_binary_path, uniprot_database_path):
            self.uniprot_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=uniprot_database_path,
                    seq_limit=uniprot_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=uniprot_max_hits
            )
        if _all_exists(jackhmmer_binary_path, reduced_bfd_database_path):
            self.reduced_bfd_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=reduced_bfd_database_path,
                    seq_limit=reduced_bfd_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=reduced_bfd_max_hits
            )

        if _all_exists(jackhmmer_binary_path, mgnify_database_path):
            self.mgnify_jackhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=jackhmmer.Jackhmmer(
                    binary_path=jackhmmer_binary_path,
                    database_path=mgnify_database_path,
                    seq_limit=mgnify_seq_limit,
                    n_cpu=no_cpus,
                ),
                msa_format="sto",
                max_sto_sequences=mgnify_max_hits
            )

        # HHblits (uniref30 preferred over uniclust30 when both configured)
        if _all_exists(hhblits_binary_path, bfd_database_path, uniref30_database_path, hhblits_mode=True):
            self.bfd_uniref30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniref30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )
        elif _all_exists(hhblits_binary_path, bfd_database_path, uniclust30_database_path, hhblits_mode=True):
            self.bfd_uniclust30_hhblits_runner = partial(
                _run_msa_tool,
                msa_runner=hhblits.HHBlits(
                    binary_path=hhblits_binary_path,
                    databases=[bfd_database_path, uniclust30_database_path],
                    n_cpu=no_cpus,
                ),
                msa_format="a3m",
            )

        # Nhmmer (RNA databases)
        if _all_exists(nhmmer_binary_path, rfam_database_path):
            self.rfam_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rfam_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rfam_max_hits
            )
        if _all_exists(nhmmer_binary_path, rnacentral_database_path):
            self.rnacentral_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=rnacentral_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=rnacentral_max_hits
            )
        if _all_exists(nhmmer_binary_path, nt_database_path):
            self.nt_nhmmer_runner = partial(
                _run_msa_tool,
                msa_runner=nhmmer.Nhmmer(
                    binary_path=nhmmer_binary_path,
                    database_path=nt_database_path,
                    n_cpu=no_cpus
                ),
                msa_format="sto",
                max_sto_sequences=nt_max_hits
            )

        if _all_exists(hmmbuild_binary_path, hmmalign_binary_path):
            self.rna_realign_runner = _run_rna_realign_tool

    def run(self, input_fasta_path, output_msas_dir, use_precompute=True):
        """Runs every configured search for one fasta.

        Searches whose downstream feature pickle (under the sibling
        ``*_features`` directories) already exists are skipped entirely;
        raw search outputs that already exist are reused when
        ``use_precompute`` is True.
        """
        os.makedirs(output_msas_dir, exist_ok=True)
        templates_out_path = os.path.join(output_msas_dir, "templates")
        uniref90_out_path = os.path.join(output_msas_dir, "uniref90_hits.sto")
        uniprot_out_path = os.path.join(output_msas_dir, "uniprot_hits.sto")
        reduced_bfd_out_path = os.path.join(output_msas_dir, "reduced_bfd_hits.sto")
        mgnify_out_path = os.path.join(output_msas_dir, "mgnify_hits.sto")
        bfd_uniref30_out_path = os.path.join(output_msas_dir, "bfd_uniref30_hits.a3m")
        bfd_uniclust30_out_path = os.path.join(output_msas_dir, "bfd_uniclust30_hits.a3m")

        seqs, decs = parse_fasta(load_txt(input_fasta_path))
        prefix = "protein"
        md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
        # Feature pickles live two directory levels above the per-query dir.
        output_feature = os.path.dirname(os.path.dirname(output_msas_dir))
        pkl_save_path_msa = os.path.join(output_feature, "msa_features", f"{md5}.pkl.gz")
        pkl_save_path_msa_uni = os.path.join(output_feature, "uniprot_msa_features", f"{md5}.pkl.gz")
        pkl_save_path_temp = os.path.join(output_feature, "template_features", f"{md5}.pkl.gz")

        if self.uniref90_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_temp):
            # pkl_save_path_temp is known absent here, so the raw uniref90
            # search is (re)run whenever its output file is missing.
            if not os.path.exists(uniref90_out_path):
                print(uniref90_out_path)
                self.uniref90_jackhmmer_runner(input_fasta_path, uniref90_out_path)

            print("begin templates")
            if self.template_searcher is not None and self.template_featurizer is not None:
                try:
                    os.makedirs(templates_out_path, exist_ok=True)
                    input_sequence = seqs[0]
                    msa_for_templates = parsers.truncate_stockholm_msa(
                        uniref90_out_path, max_sequences=10000
                    )
                    msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
                    msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
                        msa_for_templates
                    )
                    if self.template_searcher.input_format == "sto":
                        pdb_templates_result = self.template_searcher.query(msa_for_templates)
                    elif self.template_searcher.input_format == "a3m":
                        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(msa_for_templates)
                        pdb_templates_result = self.template_searcher.query(uniref90_msa_as_a3m)
                    else:
                        raise ValueError(
                            "Unrecognized template input format: "
                            f"{self.template_searcher.input_format}"
                        )

                    pdb_hits_out_path = os.path.join(
                        templates_out_path, f"pdb_hits.{self.template_searcher.output_format}.pkl.gz"
                    )
                    with open(os.path.join(
                            templates_out_path, f"pdb_hits.{self.template_searcher.output_format}"
                    ), "w") as f:
                        f.write(pdb_templates_result)

                    pdb_template_hits = self.template_searcher.get_template_hits(
                        output_string=pdb_templates_result, input_sequence=input_sequence
                    )
                    templates_result = self.template_featurizer.get_templates(
                        query_sequence=input_sequence, hits=pdb_template_hits
                    )
                    # BUG FIX: this dump used to sit *after* the except clause,
                    # so any failure above raised NameError on
                    # `templates_result`; dump only on success instead.
                    dump_pkl(templates_result.features, pdb_hits_out_path, compress=True)
                except Exception:
                    logging.exception("An error in template searching")

        if self.uniprot_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa_uni):
            if not os.path.exists(uniprot_out_path) or not use_precompute:
                self.uniprot_jackhmmer_runner(input_fasta_path, uniprot_out_path)
        if self.reduced_bfd_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(reduced_bfd_out_path) or not use_precompute:
                self.reduced_bfd_jackhmmer_runner(input_fasta_path, reduced_bfd_out_path)
        if self.mgnify_jackhmmer_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(mgnify_out_path) or not use_precompute:
                self.mgnify_jackhmmer_runner(input_fasta_path, mgnify_out_path)
        if self.bfd_uniref30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniref30_out_path) or not use_precompute:
                self.bfd_uniref30_hhblits_runner(input_fasta_path, bfd_uniref30_out_path)
        if self.bfd_uniclust30_hhblits_runner is not None and not os.path.exists(pkl_save_path_msa):
            if not os.path.exists(bfd_uniclust30_out_path) or not use_precompute:
                self.bfd_uniclust30_hhblits_runner(input_fasta_path, bfd_uniclust30_out_path)
351
+
352
class DataProcessor:
    """Drives AlignmentRunner over one or many fasta files.

    Database paths are derived from a single ``alphafold3_database_path``
    root; the ``runner_args_map`` selects which subset of binaries/databases
    a given ``msas_type`` pipeline uses.
    """

    def __init__(
        self,
        alphafold3_database_path,
        jackhmmer_binary_path: Optional[str] = None,
        hhblits_binary_path: Optional[str] = None,
        nhmmer_binary_path: Optional[str] = None,
        kalign_binary_path: Optional[str] = None,
        hmmbuild_binary_path: Optional[str] = None,
        hmmalign_binary_path: Optional[str] = None,
        hhsearch_binary_path: Optional[str] = None,
        template_searcher: Optional[TemplateSearcher] = None,
        template_featurizer: Optional[templates.TemplateHitFeaturizer] = None,
        n_cpus: int = 8,
        n_workers: int = 1,
    ):
        """
        Database versions (as used by the pipelines):
            Training:  uniref90 v2022_05, uniclust30 v2018_08, uniprot v2020_05,
                       mgnify v2022_05, rfam v14.9, rnacentral v21.0, nt v2023_02_23
            Inference: same, except uniprot v2021_04
            Inference (ligand): uniref90 v2020_01, mgnify v2018_12, others as above

        Args:
            alphafold3_database_path: root dir containing all databases.
            *_binary_path: optional paths to the search-tool executables.
            template_searcher / template_featurizer: optional template pipeline.
            n_cpus: CPUs handed to each search tool.
            n_workers: parallel AlignmentRunner workers.
        """
        self.jackhmmer_binary_path = jackhmmer_binary_path
        self.hhblits_binary_path = hhblits_binary_path
        self.nhmmer_binary_path = nhmmer_binary_path
        self.hmmbuild_binary_path = hmmbuild_binary_path
        self.hmmalign_binary_path = hmmalign_binary_path
        self.hhsearch_binary_path = hhsearch_binary_path

        self.template_searcher = template_searcher
        self.template_featurizer = template_featurizer

        self.n_cpus = n_cpus
        self.n_workers = n_workers

        self.uniref90_database_path = os.path.join(
            alphafold3_database_path, "uniref90", "uniref90.fasta"
        )
        self.uniprot_database_path = os.path.join(
            alphafold3_database_path, "uniprot", "uniprot.fasta"
        )
        self.bfd_database_path = os.path.join(
            alphafold3_database_path, "bfd", "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
        )
        self.uniclust30_database_path = os.path.join(
            alphafold3_database_path, "uniclust30", "uniclust30_2018_08", "uniclust30_2018_08"
        )
        # TODO(review): confirm the uniref30 version matches the one used by
        # AlphaFold2-Multimer.
        self.uniref_30_database_path = os.path.join(
            alphafold3_database_path, "uniref30", "v2020_06"
        )

        self.mgnify_database_path = os.path.join(
            alphafold3_database_path, "mgnify", "mgnify", "mgy_clusters.fa"
        )
        self.rfam_database_path = os.path.join(
            alphafold3_database_path, "rfam", "v14.9", "Rfam_af3_clustered_rep_seq.fasta"
        )
        self.rnacentral_database_path = os.path.join(
            alphafold3_database_path, "rnacentral", "v21.0", "rnacentral_db_rep_seq.fasta"
        )

        self.nt_database_path = os.path.join(
            alphafold3_database_path, "nt", "v2023_02_23", "nt.fasta"
        )

        # Named pipeline presets: each entry is the kwargs subset handed to
        # AlignmentRunner for that msas_type.
        self.runner_args_map = {
            "uniref90": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "uniref90_database_path": self.uniref90_database_path,
            },
            "bfd_uniclust30": {
                "hhblits_binary_path": self.hhblits_binary_path,
                "bfd_database_path": self.bfd_database_path,
                "uniclust30_database_path": self.uniclust30_database_path
            },
            "bfd_uniref30": {
                "hhblits_binary_path": self.hhblits_binary_path,
                "bfd_database_path": self.bfd_database_path,
                "uniref_30_database_path": self.uniref_30_database_path
            },

            "mgnify": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "mgnify_database_path": self.mgnify_database_path,
            },
            "uniprot": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "uniprot_database_path": self.uniprot_database_path,
            },
            ###################### RNA ########################
            "rfam": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "rfam_database_path": self.rfam_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
            "rnacentral": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "rnacentral_database_path": self.rnacentral_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
            "nt": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "nt_database_path": self.nt_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },

            ###################################################
            "alphafold2": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "hhblits_binary_path": self.hhblits_binary_path,
                "uniref90_database_path": self.uniref90_database_path,
                "bfd_database_path": self.bfd_database_path,
                "uniclust30_database_path": self.uniclust30_database_path,
                "mgnify_database_path": self.mgnify_database_path,
            },
            "alphafold2_multimer": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "hhblits_binary_path": self.hhblits_binary_path,
                "uniref90_database_path": self.uniref90_database_path,
                "bfd_database_path": self.bfd_database_path,
                "uniref_30_database_path": self.uniref_30_database_path,
                "mgnify_database_path": self.mgnify_database_path,
                "uniprot_database_path": self.uniprot_database_path,
            },
            "alphafold3": {
                "jackhmmer_binary_path": self.jackhmmer_binary_path,
                "hhblits_binary_path": self.hhblits_binary_path,
                "template_searcher": self.template_searcher,
                "template_featurizer": self.template_featurizer,
                "uniref90_database_path": self.uniref90_database_path,
                "bfd_database_path": self.bfd_database_path,
                "uniclust30_database_path": self.uniclust30_database_path,
                "mgnify_database_path": self.mgnify_database_path,
                "uniprot_database_path": self.uniprot_database_path,
            },

            "rna": {
                "nhmmer_binary_path": self.nhmmer_binary_path,
                "rfam_database_path": self.rfam_database_path,
                "rnacentral_database_path": self.rnacentral_database_path,
                "hmmbuild_binary_path": self.hmmbuild_binary_path,
                "hmmalign_binary_path": self.hmmalign_binary_path,
            },
        }

    def _parse_io_tuples(self, input_fasta_path, output_dir, convert_md5=True, prefix="protein"):
        """Expands ``input_fasta_path`` (file, dir, or list) into
        (fasta_path, output_msas_dir) tuples; output dirs are named either by
        sequence md5 or by fasta basename."""
        os.makedirs(output_dir, exist_ok=True)
        if isinstance(input_fasta_path, list):
            input_fasta_paths = input_fasta_path
        elif os.path.isdir(input_fasta_path):
            input_fasta_paths = [os.path.join(input_fasta_path, i) for i in os.listdir(input_fasta_path)]
        elif os.path.isfile(input_fasta_path):
            input_fasta_paths = [input_fasta_path]
        else:
            # BUG FIX: the exception was previously instantiated but never
            # raised, silently yielding an empty task list.
            raise Exception("Can't parse input fasta path!")
        seqs = [parse_fasta(load_txt(i))[0][0] for i in input_fasta_paths]
        if convert_md5:
            output_msas_dirs = [os.path.join(output_dir, convert_md5_string(f"{prefix}:{i}")) for i in
                                seqs]
        else:
            output_msas_dirs = [os.path.join(output_dir, os.path.split(i)[1].split(".")[0]) for i in input_fasta_paths]
        io_tuples = [(i, o) for i, o in zip(input_fasta_paths, output_msas_dirs)]
        return io_tuples

    def _process_iotuple(self, io_tuple, msas_type):
        """Runs one AlignmentRunner task; failures are logged, not raised, so
        a pool of tasks keeps going."""
        i, o = io_tuple
        alignment_runner = AlignmentRunner(
            **self.runner_args_map[msas_type],
            no_cpus=self.n_cpus
        )
        try:
            alignment_runner.run(i, o)
        except Exception:
            # BUG FIX: was a bare `except:` that also swallowed
            # KeyboardInterrupt and dropped the traceback.
            logging.exception(f"{i}:{o} task failed!")

    def process(self, input_fasta_path, output_dir, msas_type="rfam", convert_md5=True):
        """Runs the ``msas_type`` pipeline over all input fastas in parallel."""
        prefix = "rna" if msas_type in ["rfam", "rnacentral", "nt", "rna"] else "protein"
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=convert_md5, prefix=prefix)
        run_pool_tasks(partial(self._process_iotuple, msas_type=msas_type), io_tuples, num_workers=self.n_workers,
                       return_dict=False)

    def convert_output_to_md5(self, input_fasta_path, output_dir, md5_output_dir, prefix="protein"):
        """Copies name-keyed alignment dirs to md5-keyed dirs."""
        io_tuples = self._parse_io_tuples(input_fasta_path, output_dir, convert_md5=False, prefix=prefix)
        io_tuples_md5 = self._parse_io_tuples(input_fasta_path, md5_output_dir, convert_md5=True, prefix=prefix)

        for io0, io1 in tqdm.tqdm(zip(io_tuples, io_tuples_md5)):
            o, o_md5 = io0[1], io1[1]
            # NOTE(review): shell `cp -r` breaks on paths with spaces; consider
            # shutil.copytree if such paths can occur.
            os.system(f"cp -r {os.path.abspath(o)} {os.path.abspath(o_md5)}")
PhysDock/data/tools/convert_unifold_template_to_stfold.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ sys.path.append("../")
8
+
9
+ from PhysDock.utils.io_utils import load_pkl, dump_pkl
10
+ from PhysDock.data.tools.residue_constants import \
11
+ hhblits_id_to_standard_residue_id_np, af3_if_to_residue_id
12
+
13
def dgram_from_positions(
    pos: torch.Tensor,
    min_bin: float = 3.25,
    max_bin: float = 50.75,
    no_bins: int = 39,  # FIX: bin count is a linspace step count, so int, not float
    inf: float = 1e8,
) -> torch.Tensor:
    """One-hot distance histogram ("distogram") over pairwise distances.

    Bin edges are ``linspace(min_bin, max_bin, no_bins)`` squared, compared
    against squared pairwise distances (no sqrt is taken). Distances below
    ``min_bin`` fall into no bin (an all-zero row); distances beyond
    ``max_bin`` fall into the last bin, whose upper edge is ``inf``.

    Args:
        pos: [..., N, 3] coordinates.
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin.
        no_bins: number of distance bins.
        inf: upper edge of the last bin.

    Returns:
        [..., N, N, no_bins] one-hot tensor with ``pos``'s dtype.
    """
    dist2 = torch.sum(
        (pos[..., None, :] - pos[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    lower = torch.linspace(min_bin, max_bin, no_bins, device=pos.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    return ((dist2 > lower) * (dist2 < upper)).type(dist2.dtype)
27
+
28
+
29
def convert_unifold_template_feature_to_stfold_unifold_feature(unifold_template_feature):
    """Converts a Uni-Fold template-feature pickle to stfold/af3 format.

    ``unifold_template_feature`` is a path of the form
    ``<root>/.../<md5>.pkl.gz``; the converted features are dumped to
    ``<root>/msas/template_features/<md5>.pkl.gz``. Conversion is
    best-effort: any failure is swallowed and an empty dict is returned
    (the useful result is the dumped file, not the return value).
    """
    md5_string = None  # FIX: predefine so the final print can't NameError
    try:
        print(unifold_template_feature)
        md5_string = os.path.basename(unifold_template_feature)[:-6]  # strip ".pkl.gz" remainder
        out_path = os.path.dirname(os.path.dirname(unifold_template_feature))
        unifold_template_feature = os.path.join(out_path, "msas", md5_string)
        out_path_final = os.path.join(out_path, "msas", "template_features")
        final = os.path.join(out_path_final, f"{md5_string}.pkl.gz")
        if os.path.exists(final):
            # Already converted.
            return dict()

        unifold_template_feature = os.path.join(unifold_template_feature, "templates", "pdb_hits.hhr.pkl.gz")
        if isinstance(unifold_template_feature, str):
            data = load_pkl(unifold_template_feature)
        else:
            data = unifold_template_feature

        # hhblits one-hot aatype -> standard residue ids -> af3 residue ids.
        template_restype = af3_if_to_residue_id[
            hhblits_id_to_standard_residue_id_np[np.argmax(data["template_aatype"], axis=-1)]]
        assert np.all(template_restype != -1)
        assert len(template_restype) >= 1
        assert len(template_restype[0, :]) >= 4

        # Backbone frame = first three atoms (N, CA, C).
        bb_x_gt = torch.from_numpy(data["template_all_atom_positions"][..., :3, :])
        bb_x_mask = torch.from_numpy(data["template_all_atom_masks"][..., :3])

        # Pseudo-beta: CB (index 3) normally, CA (index 1) for glycine.
        bb_x_gt_beta1 = data["template_all_atom_positions"][..., 3, :]
        bb_x_gt_beta_mask1 = data["template_all_atom_masks"][..., 3]
        bb_x_gt_beta2 = data["template_all_atom_positions"][..., 1, :]
        bb_x_gt_beta_mask2 = data["template_all_atom_masks"][..., 1]

        # NOTE(review): assumes residue id 7 is glycine in the af3 id space —
        # confirm against residue_constants.
        is_gly = template_restype == 7
        template_pseudo_beta = np.where(is_gly[..., None], bb_x_gt_beta2, bb_x_gt_beta1)
        template_pseudo_beta_mask = np.where(is_gly, bb_x_gt_beta_mask2, bb_x_gt_beta_mask1)
        template_backbone_frame_mask = bb_x_mask[..., 0] * bb_x_mask[..., 1] * bb_x_mask[..., 2]
        out = {
            "template_restype": template_restype.astype(np.int8),
            "template_backbone_frame_mask": template_backbone_frame_mask.numpy().astype(np.int8),
            "template_backbone_frame": bb_x_gt.numpy().astype(np.float32),
            "template_pseudo_beta": template_pseudo_beta.astype(np.float32),
            "template_pseudo_beta_mask": template_pseudo_beta_mask.astype(np.int8),
        }
        dump_pkl(out, final, compress=True)  # `final` is the same path as before
    except Exception:
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt.
        pass
    out = dict()
    print(f"dump templ feats to {md5_string}.pkl.gz")
    return out
82
+
83
+
84
+ # HHBLITS_ID_TO_AA = {
85
+ # 0: "ALA",
86
+ # 1: "CYS", # Also U.
87
+ # 2: "ASP", # Also B.
88
+ # 3: "GLU", # Also Z.
89
+ # 4: "PHE",
90
+ # 5: "GLY",
91
+ # 6: "HIS",
92
+ # 7: "ILE",
93
+ # 8: "LYS",
94
+ # 9: "LEU",
95
+ # 10: "MET",
96
+ # 11: "ASN",
97
+ # 12: "PRO",
98
+ # 13: "GLN",
99
+ # 14: "ARG",
100
+ # 15: "SER",
101
+ # 16: "THR",
102
+ # 17: "VAL",
103
+ # 18: "TRP",
104
+ # 19: "TYR",
105
+ # 20: "UNK", # Includes J and O.
106
+ # 21: "GAP",
107
+ # }
108
+ #
109
+ # # Usage: Convert hhblits msa to af3 aatype
110
+ # # msa = hhblits_id_to_standard_residue_id_np[hhblits_msa.astype(np.int64)]
111
+ # hhblits_id_to_standard_residue_id_np = np.array(
112
+ # [standard_ccds.index(ccd) for id, ccd in HHBLITS_ID_TO_AA.items()]
113
+ # )
114
+ #
115
+ # of_restypes = [
116
+ # "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
117
+ # "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X", "-"
118
+ # ]
119
+ #
120
+ # af3_restypes = [amino_acid_3to1[ccd] if ccd in amino_acid_3to1 else "-" if ccd == "GAP" else "None" for ccd in
121
+ # standard_ccds
122
+ # ]
123
+ #
124
+ # af3_if_to_residue_id = np.array(
125
+ # [af3_restypes.index(restype) if restype in of_restypes else -1 for restype in af3_restypes])
126
+
127
+
PhysDock/data/tools/dataset_manager.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ from typing import Optional
3
+ from functools import partial
4
+ import numpy as np
5
+
6
+ from PhysDock.utils.io_utils import run_pool_tasks, load_json, load_txt, dump_pkl, \
7
+ convert_md5_string
8
+ from PhysDock.data.tools.parsers import parse_fasta
9
+ from PhysDock.data.tools.parse_msas import parse_protein_alignment_dir, parse_uniprot_alignment_dir, \
10
+ parse_rna_alignment_dir
11
+ from PhysDock.data.alignment_runner import DataProcessor
12
+ from PhysDock.data.tools.residue_constants import standard_ccds, amino_acid_3to1
13
+ from PhysDock.data.tools.PDBData import protein_letters_3to1_extended, nucleic_letters_3to1_extended
14
+
15
+
16
def get_protein_md5(sequence_3):
    """Return candidate 1-letter sequences and md5 keys for a '-'-joined CCD string.

    Four decodings are attempted: the extended and the standard 3->1 tables,
    each applied to the full token list and to the list with leading/trailing
    unknown tokens stripped. Duplicates are dropped while preserving order.

    Returns:
        (unique_sequences, md5s) where each md5 hashes "protein:<sequence>".
    """
    ccds = sequence_3.split("-")
    unknown_tokens = ("UNK", "N ", "DN ", "GAP")

    # Index of the first informative (non-unknown) token; 0 if none exists.
    lead = 0
    for idx, ccd in enumerate(ccds):
        if ccd not in unknown_tokens:
            lead = idx
            break
    # Distance of the last informative token from the end; 0 if none exists.
    tail = 0
    for idx, ccd in enumerate(reversed(ccds)):
        if ccd not in unknown_tokens:
            tail = idx
            break

    stripped = ccds[lead:-tail] if tail > 0 else ccds[lead:]

    def _decode(tokens, table):
        # Unknown CCDs decode to 'X'.
        return "".join(table.get(ccd, "X") for ccd in tokens)

    candidates = [
        _decode(ccds, protein_letters_3to1_extended),
        _decode(ccds, amino_acid_3to1),
        _decode(stripped, protein_letters_3to1_extended),
        _decode(stripped, amino_acid_3to1),
    ]
    # dict.fromkeys dedupes while keeping first-seen order.
    unique_seqs = list(dict.fromkeys(candidates))
    return unique_seqs, [convert_md5_string(f"protein:{seq}") for seq in unique_seqs]
43
+
44
+
45
def get_rna_md5(sequence_3):
    """Return (1-letter RNA sequence, md5 key) for a '-'-joined CCD string.

    Unknown/gap tokens map to 'N'; every other token is looked up in the
    extended nucleic 3->1 table (KeyError on unmapped tokens, as before).
    The md5 hashes the string "rna:<sequence>".
    """
    one_letter = []
    for ccd in sequence_3.split("-"):
        if ccd in ("UNK", "GAP", "N ", "DN "):
            one_letter.append("N")
        else:
            one_letter.append(nucleic_letters_3to1_extended[ccd])
    sequence = "".join(one_letter)
    return sequence, convert_md5_string(f"rna:{sequence}")
51
+
52
+
53
class DatasetManager:
    """Bookkeeping for the PhysDock training dataset.

    Loads the per-chain / per-PDB / per-CCD metadata json maps from
    ``dataset_path`` and offers static helpers to (1) run homology searches,
    (2) find fasta inputs whose searches have not run yet, and (3) convert
    raw alignment directories into compact ``.pkl.gz`` MSA feature files
    keyed by an md5 of ``"<prefix>:<sequence>"``.
    """

    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        # Metadata maps produced by the dataset build pipeline.
        self.chain_id_to_meta_info = load_json(os.path.join(dataset_path, "chain_id_to_meta_info.json"))
        self.pdb_id_to_meta_info = load_json(os.path.join(dataset_path, "pdb_id_to_meta_info.json"))
        self.ccd_id_to_meta_info = load_json(os.path.join(dataset_path, "ccd_id_to_meta_info.json"))

    @staticmethod
    def _resolve_input_fasta_paths(input_fasta_path):
        """Normalize a list / directory / single-file argument into a list of paths.

        Bug fix: the original code instantiated ``Exception(...)`` without
        raising it and silently continued with an empty task list; an
        unparseable argument now fails loudly.
        """
        if isinstance(input_fasta_path, list):
            return input_fasta_path
        if os.path.isdir(input_fasta_path):
            return [os.path.join(input_fasta_path, name) for name in os.listdir(input_fasta_path)]
        if os.path.isfile(input_fasta_path):
            return [input_fasta_path]
        raise ValueError("Can't parse input fasta path!")

    @staticmethod
    def homo_search(
            input_fasta_path,
            output_dir,
            msas_type,
            convert_md5,
            alphafold3_database_path,
            jackhmmer_binary_path: Optional[str] = None,
            hhblits_binary_path: Optional[str] = None,
            nhmmer_binary_path: Optional[str] = None,
            kalign_binary_path: Optional[str] = None,
            hmmbuild_binary_path: Optional[str] = None,
            hmmalign_binary_path: Optional[str] = None,
            n_cpus: int = 8,
            n_workers: int = 1,
    ):
        """Run a homology search of type ``msas_type`` for the given fasta input(s).

        Thin wrapper that configures a :class:`DataProcessor` with the search
        binaries/databases and delegates to its ``process`` method.
        """
        data_processor = DataProcessor(
            alphafold3_database_path=alphafold3_database_path,
            jackhmmer_binary_path=jackhmmer_binary_path,
            hhblits_binary_path=hhblits_binary_path,
            nhmmer_binary_path=nhmmer_binary_path,
            kalign_binary_path=kalign_binary_path,
            hmmbuild_binary_path=hmmbuild_binary_path,
            hmmalign_binary_path=hmmalign_binary_path,
            n_cpus=n_cpus,
            n_workers=n_workers,
        )
        data_processor.process(
            input_fasta_path=input_fasta_path,
            output_dir=output_dir,
            msas_type=msas_type,
            convert_md5=convert_md5
        )

    @staticmethod
    def get_unsearched_input_fasta_path(input_fasta_path, output_dir, msas_type, convert_md5, num_workers=128):
        """Return the fasta paths for which no ``msas_type`` hits file exists yet."""
        input_fasta_paths = DatasetManager._resolve_input_fasta_paths(input_fasta_path)

        # Sequence-class prefix used when hashing the query sequence.
        prefix = {
            "uniref90": "protein",
            "bfd_uniclust30": "protein",
            "bfd_uniref30": "protein",
            "uniprot": "protein",
            "mgnify": "protein",
            "rfam": "rna",
            "rnacentral": "rna",
            "nt": "rna",
        }[msas_type]
        # Registered at module level so multiprocessing can pickle the worker.
        global _get_unsearched_input_fasta_path

        def _get_unsearched_input_fasta_path(input_fasta_path, convert_md5, prefix, output_dir, msas_type):
            if convert_md5:
                # Bug fix: the original referenced an undefined ``seqs``
                # (NameError); parse the fasta to obtain the query sequence.
                seqs, decs = parse_fasta(load_txt(input_fasta_path))
                dec = convert_md5_string(f"{prefix}:{seqs[0]}")
            else:
                dec = os.path.split(input_fasta_path)[1].split(".")[0]

            if os.path.exists(os.path.join(output_dir, dec, f"{msas_type}_hits.sto")) or \
                    os.path.exists(os.path.join(output_dir, dec, f"{msas_type}_hits.a3m")):
                return dict()
            return {input_fasta_path: False}

        out = run_pool_tasks(partial(
            _get_unsearched_input_fasta_path,
            convert_md5=convert_md5,
            prefix=prefix,
            output_dir=output_dir,
            msas_type=msas_type,
        ), input_fasta_paths, num_workers=num_workers, return_dict=True)
        return list(out.keys())

    @staticmethod
    def convert_msas_out_to_msa_features(
            input_fasta_path,
            output_dir,
            msa_feature_dir,
            convert_md5=True,
            num_workers=128
    ):
        """Convert protein alignment directories into ``<md5>.pkl.gz`` features.

        Returns a dict of failures (fasta path -> reason); empty on full success.
        """
        input_fasta_paths = DatasetManager._resolve_input_fasta_paths(input_fasta_path)

        # Registered at module level so multiprocessing can pickle the worker.
        global _convert_msas_out_to_msa_features

        def _convert_msas_out_to_msa_features(
                input_fasta_path,
                output_dir,
                msa_feature_dir,
                convert_md5=True,
        ):
            prefix = "protein"
            max_seq = 16384
            seqs, decs = parse_fasta(load_txt(input_fasta_path))
            md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
            dec = md5 if convert_md5 else os.path.split(input_fasta_path)[1].split(".")[0]

            # Some legacy search runs nest their hits under an extra
            # "msas" sub-directory.
            if os.path.exists(os.path.join(output_dir, dec, "msas")):
                dec = dec + "/msas"

            pkl_save_path = os.path.join(msa_feature_dir, f"{md5}.pkl.gz")
            if os.path.exists(pkl_save_path):
                return dict()

            alignment_dir = os.path.join(output_dir, dec)

            def _has(*file_names):
                # True when every named hits file exists in the alignment dir.
                return all(os.path.exists(os.path.join(alignment_dir, name)) for name in file_names)

            def _parse_check_and_dump(require_full_msa=False):
                # Parse, verify the query-sequence hash, truncate and dump.
                msa_feature = parse_protein_alignment_dir(alignment_dir)
                if require_full_msa and len(msa_feature["msa"]) < max_seq:
                    return {input_fasta_path: f"MSA is not enough!"}
                sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa"][0]])
                md5_string = convert_md5_string(f"protein:{sequence}")
                if md5 != md5_string:
                    return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
                feature = {
                    "msa": msa_feature["msa"][:max_seq].astype(np.int8),
                    "deletion_matrix": msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
                    "msa_species_identifiers": msa_feature["msa_species_identifiers"][:max_seq]
                }
                dump_pkl(feature, pkl_save_path)
                return dict()

            # Accepted database combinations, deduplicated from the original
            # copy-pasted branches. Only the uniref90+bfd combination without
            # mgnify additionally requires a full-depth MSA.
            if _has("uniref90_hits.sto", "bfd_uniclust30_hits.a3m", "mgnify_hits.sto"):
                return _parse_check_and_dump()
            if _has("uniref90_hits.sto", "bfd_uniref_hits.a3m", "mgnify_hits.sto"):
                return _parse_check_and_dump()
            if _has("uniref90_hits.sto", "bfd_uniclust30_hits.a3m"):
                return _parse_check_and_dump(require_full_msa=True)
            if _has("uniref90_hits.sto", "mgnify_hits.sto"):
                return _parse_check_and_dump()
            return {
                input_fasta_path: f"MSA is not enough!"
            }

        out = run_pool_tasks(partial(
            _convert_msas_out_to_msa_features,
            output_dir=output_dir,
            msa_feature_dir=msa_feature_dir,
            convert_md5=convert_md5
        ), input_fasta_paths, num_workers=num_workers, return_dict=True)
        return out

    @staticmethod
    def convert_msas_out_to_uniprot_msa_features(
            input_fasta_path,
            output_dir,
            uniprot_msa_feature_dir,
            convert_md5=True,
            num_workers=128
    ):
        """Convert uniprot alignment hits into paired-MSA (``*_all_seq``) features.

        Returns a dict of failures (fasta path -> reason); empty on full success.
        """
        input_fasta_paths = DatasetManager._resolve_input_fasta_paths(input_fasta_path)

        global _convert_msas_out_to_uniprot_msa_features

        def _convert_msas_out_to_uniprot_msa_features(
                input_fasta_path,
                output_dir,
                uniprot_msa_feature_dir,
                convert_md5=True,
        ):
            prefix = "protein"
            # Uniprot MSAs keep more rows because they feed cross-chain pairing.
            max_seq = 50000
            seqs, decs = parse_fasta(load_txt(input_fasta_path))
            md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
            dec = md5 if convert_md5 else os.path.split(input_fasta_path)[1].split(".")[0]

            pkl_save_path = os.path.join(uniprot_msa_feature_dir, f"{md5}.pkl.gz")
            if os.path.exists(pkl_save_path):
                return dict()
            if os.path.exists(os.path.join(output_dir, dec, "uniprot_hits.sto")):
                msa_feature = parse_uniprot_alignment_dir(os.path.join(output_dir, dec))
                # Verify the parsed query row hashes back to the fasta query.
                sequence = "".join([amino_acid_3to1[standard_ccds[i]] for i in msa_feature["msa_all_seq"][0]])
                md5_string = convert_md5_string(f"protein:{sequence}")
                if md5 == md5_string:
                    feature = {
                        "msa_all_seq": msa_feature["msa_all_seq"][:max_seq].astype(np.int8),
                        "deletion_matrix_all_seq": msa_feature["deletion_matrix_all_seq"][:max_seq].astype(np.int8),
                        "msa_species_identifiers_all_seq": msa_feature["msa_species_identifiers_all_seq"][:max_seq]
                    }
                    dump_pkl(feature, pkl_save_path)
                    return dict()
                return {input_fasta_path: f"seqs not equal, asset [{sequence}], but found [{seqs[0]}]"}
            return {
                input_fasta_path: f"MSA is not enough!"
            }

        out = run_pool_tasks(partial(
            _convert_msas_out_to_uniprot_msa_features,
            output_dir=output_dir,
            uniprot_msa_feature_dir=uniprot_msa_feature_dir,
            convert_md5=convert_md5
        ), input_fasta_paths, num_workers=num_workers, return_dict=True)
        return out

    @staticmethod
    def convert_msas_out_to_rna_msa_features(
            input_fasta_path,
            output_dir,
            rna_msa_feature_dir,
            convert_md5=True,
            num_workers=128
    ):
        """Convert RNA alignment directories into ``<md5>.pkl.gz`` features.

        Returns a dict of failures (fasta path -> reason); empty on full success.
        """
        os.makedirs(rna_msa_feature_dir, exist_ok=True)
        input_fasta_paths = DatasetManager._resolve_input_fasta_paths(input_fasta_path)

        global _convert_msas_out_to_rna_msa_features

        def _convert_msas_out_to_rna_msa_features(
                input_fasta_path,
                output_dir,
                rna_msa_feature_dir,
                convert_md5=True,
        ):
            prefix = "rna"
            max_seq = 16384
            seqs, decs = parse_fasta(load_txt(input_fasta_path))
            md5 = convert_md5_string(f"{prefix}:{seqs[0]}")
            dec = md5 if convert_md5 else os.path.split(input_fasta_path)[1].split(".")[0]

            # Some legacy search runs nest their hits under an extra
            # "msas" sub-directory.
            if os.path.exists(os.path.join(output_dir, dec, "msas")):
                dec = dec + "/msas"

            pkl_save_path = os.path.join(rna_msa_feature_dir, f"{md5}.pkl.gz")
            if os.path.exists(pkl_save_path):
                return dict()
            rna_msa_feature = parse_rna_alignment_dir(
                os.path.join(output_dir, dec),
                input_fasta_path
            )

            feature = {
                "msa": rna_msa_feature["msa"][:max_seq].astype(np.int8),
                "deletion_matrix": rna_msa_feature["deletion_matrix"][:max_seq].astype(np.int8),
                # RNA alignments carry no species annotation.
                "msa_species_identifiers": None
            }
            dump_pkl(feature, pkl_save_path)

            return dict()

        out = run_pool_tasks(partial(
            _convert_msas_out_to_rna_msa_features,
            output_dir=output_dir,
            rna_msa_feature_dir=rna_msa_feature_dir,
            convert_md5=convert_md5
        ), input_fasta_paths, num_workers=num_workers, return_dict=True)
        return out

    @staticmethod
    def find_chain_ids_without_msa_features(
            polymer_filtering_out_json,
            chain_id_to_meta_info_path,
            dataset_dir,
            uniprot=False,
            num_workers=256,
    ):
        """List protein chain ids whose MSA feature file exists in none of the
        known feature directories (historical variants included)."""
        if not isinstance(polymer_filtering_out_json, list):
            polymer_filtering_out_json = [polymer_filtering_out_json]
        polymer_filtering_out = dict()
        for json_path in polymer_filtering_out_json:
            polymer_filtering_out.update(load_json(json_path))
        chain_id_to_meta_info = load_json(chain_id_to_meta_info_path)
        # Registered at module level so multiprocessing can pickle the worker.
        global _find_chain_ids_without_msa_features

        def _find_chain_ids_without_msa_features(chain_id):
            sequence_3 = chain_id_to_meta_info[chain_id]["sequence_3"]
            seqs, md5s = get_protein_md5(sequence_3)
            # Historical feature directories that may hold the file.
            if uniprot:
                feature_dirs = ["uniprot_msa_features", "uniprot_msa_features_zkx", "uniprot_msa_features_unifold"]
            else:
                feature_dirs = ["msa_features", "msa_features_zkx", "msa_features_whc", "msa_features_unifold"]
            for feature_dir in feature_dirs:
                for md5 in md5s:
                    if os.path.exists(os.path.join(dataset_dir, "features/", feature_dir, f"{md5}.pkl.gz")):
                        return dict()
            return {chain_id: {"state": False, "seqs": seqs}}

        # Only polymer chains that passed filtering and are proteins.
        chain_ids = [k for k, v in polymer_filtering_out.items() if
                     v["state"] and chain_id_to_meta_info[k]["chain_class"] == "protein"]

        out = run_pool_tasks(
            _find_chain_ids_without_msa_features, chain_ids, num_workers=num_workers, return_dict=True)
        return out

    def find_chain_ids_without_rna_msa_features(
            self,
            polymer_filtering_out_json,
    ):
        """Not implemented yet."""
        pass

    @staticmethod
    def check_msa_md5(msa_feature_dir):
        """Not implemented yet."""
        pass

    @staticmethod
    def check_uniprot_msa_md5(uniprot_msa_feature_dir):
        """Not implemented yet."""
        pass

    @staticmethod
    def check_rna_msa_md5(rna_msa_feature_dir):
        """Not implemented yet."""
        pass

    def get_training_pdbs(self):
        """Not implemented yet."""
        pass
507
+
508
+
509
+
510
+
511
+
512
+
513
+
514
+
515
class DataPipeline:
    """Entry point tying the dataset manager into the training-data pipeline.

    Bug fix: the original constructed ``DatasetManager()`` with no arguments
    even though ``DatasetManager.__init__`` requires ``dataset_path``, so
    instantiating ``DataPipeline`` always raised ``TypeError``. The path is
    now accepted and forwarded; when omitted the manager is left unset
    instead of crashing.
    """

    def __init__(self, dataset_path=None):
        super().__init__()
        # Forward the dataset path to the manager when one is supplied.
        self.data_manager = DatasetManager(dataset_path) if dataset_path is not None else None
519
+
520
+ # PDB:
521
+ # polymer_chain_id:
522
+ # weight_chain: contiguous_crop:1/3 spatial_crop: 2/3
523
+ # Interface
524
+ # weight_interface:
525
+ # [chain_id1, chain_id2] 0.2 contiguous_crop
526
+ # [chain_id1, chain_id2] 0.4 spatial_crop_interface
527
+ # [ < 20 chains] 0.4 spatial crop
528
+ #
529
+ #
530
+ #
531
+ #
532
+ # polymer chain contiguous crop sample weight w_chain*1/3 [chain_id]
533
+ # polymer chain spatial crop sample weight w_chain*2/3 [chain_id]
534
+ #
535
+ # interface contiguous crop sample weight w_interface * 0.2 [chain_id, chain_id]
536
+ # interface spatial crop sample weight w_interface * 0.4 >[chain_id, chain_id]
537
+ # interface spatial crop interface sample weight w_interface * 0.4 [chain_id, chain_id]
538
+ #
539
+ #
540
+ # pdb:
541
+ # chain:
542
+ # [chain_id]: 0.14
543
+ # [chain_id]: 0.23
544
+
545
+
546
+ # @staticmethod
547
+ # def get_pdb_info(pdb_id):
548
+ # all_chain_ids = pdb_id_to_meta_info[pdb_id]["chain_ids"]
549
+ #
550
+ # chain_ids_info = {
551
+ # "protein": [],
552
+ # "rna": [],
553
+ # "dna": [],
554
+ # "ligand": []
555
+ # }
556
+ # for chain_id_ in all_chain_ids:
557
+ # chain_id = f"{pdb_id}_{chain_id_}"
558
+ # if chain_id in chain_id_to_meta_info:
559
+ # chain_class = chain_id_to_meta_info[chain_id]["chain_class"].split("_")[0]
560
+ # if chain_id in chain_ids and os.path.exists(os.path.join(stfold_data_path, f"{chain_id}.pkl.gz")):
561
+ # if chain_class == "protein":
562
+ # if check_protein_msa_features(chain_id, chain_id_to_meta_info)[chain_id]["state"]:
563
+ # chain_ids_info[chain_class].append(chain_id)
564
+ # elif chain_class == "rna":
565
+ # if check_rna_msa_features(chain_id, chain_id_to_meta_info)[chain_id]["state"]:
566
+ # chain_ids_info[chain_class].append(chain_id)
567
+ #
568
+ # elif chain_class == "ligand" and os.path.exists(os.path.join(stfold_data_path, f"{chain_id}.pkl.gz")):
569
+ # chain_ids_info[chain_class].append(chain_id)
570
+ # return {pdb_id: chain_ids_info}
PhysDock/data/tools/feature_processing_multimer.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ # Copyright 2022 AlQuraishi Laboratory
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Feature processing logic for multimer data pipeline."""
17
+ from typing import Iterable, MutableMapping, List, Mapping, Dict, Any, Union
18
+ from scipy.sparse import coo_matrix
19
+ import numpy as np
20
+
21
+ from . import msa_pairing
22
+
23
# Per-chain feature dictionary: feature name -> dense array, sparse matrix,
# or other metadata value (None allowed).
FeatureDict = Dict[str, Union[np.ndarray, coo_matrix, None, Any]]
# TODO: Move this into the config
# Feature keys retained by ``_filter_features`` after chains are merged.
REQUIRED_FEATURES = frozenset({
    'aatype', 'all_atom_mask', 'all_atom_positions', 'all_chains_entity_ids',
    'all_crops_all_chains_mask', 'all_crops_all_chains_positions',
    'all_crops_all_chains_residue_ids', 'assembly_num_chains', 'asym_id',
    'bert_mask', 'cluster_bias_mask', 'deletion_matrix', 'deletion_mean',
    'entity_id', 'entity_mask', 'mem_peak', 'msa', 'msa_mask', 'num_alignments',
    'num_templates', 'queue_size', 'residue_index', 'resolution',
    'seq_length', 'seq_mask', 'sym_id', 'template_aatype',
    'template_all_atom_mask', 'template_all_atom_positions'
})

# Maximum number of structural templates retained per chain.
MAX_TEMPLATES = 4
# Total number of MSA rows (paired + unpaired) retained after cropping.
MSA_CROP_SIZE = 16384
38
+
39
+
40
def _is_homomer_or_monomer(chains: Iterable[Mapping[str, np.ndarray]]) -> bool:
    """Checks if a list of chains represents a homomer/monomer example.

    Current heuristic: exactly one chain counts as a monomer/homomer.
    (The entity_id-uniqueness check used upstream was deliberately replaced
    by this simple chain count.)

    NOTE(review): despite the ``Iterable`` annotation, the argument must be
    sized (``len`` is applied) — confirm callers always pass a list.
    """
    return len(chains) == 1
50
+
51
+
52
def pair_and_merge(
        all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]],
        is_homomer_or_monomer,
) -> FeatureDict:
    """Runs processing on features to augment, pair and merge.

    Args:
      all_chain_features: A MutableMap of dictionaries of features for each chain.
      is_homomer_or_monomer: Caller-supplied flag; when False, cross-chain MSA
        pairing is attempted for the protein chains.

    Returns:
      A dictionary of features.
    """

    # Adds deletion_mean, assembly_num_chains and entity_mask to every chain.
    process_unmerged_features(all_chain_features)

    np_chains_list = list(all_chain_features.values())
    # Partition chains by molecular class; only protein MSAs are paired.
    np_chains_list_prot = [chain for chain in np_chains_list if
                           chain['chain_class'] in ['protein']]
    np_chains_list_dna = [chain for chain in np_chains_list if
                          chain['chain_class'] in ['dna']]
    np_chains_list_rna = [chain for chain in np_chains_list if
                          chain['chain_class'] in ['rna', ]]
    # TODO: ligand?
    np_chains_list_ligand = [chain for chain in np_chains_list if chain['chain_class'] in ['ligand']]

    # NOTE(review): two pairing flags are in play — ``pair_msa_sequences``
    # is derived from the total chain count and only forwarded to
    # merge_chain_features, while the caller-supplied flag drives protein
    # pairing and cropping. Confirm this asymmetry is intended.
    pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)
    pair_msa_sequences_prot = not is_homomer_or_monomer
    if pair_msa_sequences_prot and np_chains_list_prot:
        # uniprot : all_seq pairs
        np_chains_list_prot = msa_pairing.create_paired_features(
            chains=np_chains_list_prot
        )
    else:
        if np_chains_list_prot:
            # Unpaired mode: a single-row alignment count placeholder.
            for prot in np_chains_list_prot:
                prot["num_alignments"] = np.ones([], dtype=np.int32)

    # Species identifiers were only needed to build the pairing; drop them.
    for chain in np_chains_list_prot:
        chain.pop("msa_species_identifiers", None)
        chain.pop("msa_species_identifiers_all_seq", None)

    # Reassemble in the order: protein, rna, dna, ligand.
    np_chains_list_prot.extend(np_chains_list_rna)
    np_chains_list_prot.extend(np_chains_list_dna)
    np_chains_list_prot.extend(np_chains_list_ligand)

    np_chains_list = np_chains_list_prot

    np_chains_list = crop_chains(
        np_chains_list,
        msa_crop_size=MSA_CROP_SIZE,
        pair_msa_sequences=pair_msa_sequences_prot,
        max_templates=MAX_TEMPLATES
    )

    np_example = msa_pairing.merge_chain_features(
        np_chains_list=np_chains_list, pair_msa_sequences=pair_msa_sequences,
        max_templates=MAX_TEMPLATES
    )

    return np_example
119
+
120
+
121
def crop_chains(
        chains_list: List[Mapping[str, np.ndarray]],
        msa_crop_size: int,
        pair_msa_sequences: bool,
        max_templates: int
) -> List[Mapping[str, np.ndarray]]:
    """Crops the MSAs for a set of chains.

    Protein chains are delegated to ``_crop_single_chain`` (which also handles
    paired MSAs and templates). Non-protein chains are truncated to
    ``msa_crop_size`` rows, or upsampled with replacement up to that size;
    they are mutated in place.

    Args:
      chains_list: A list of chains to be cropped.
      msa_crop_size: The total number of sequences to crop from the MSA.
      pair_msa_sequences: Whether we are operating in sequence-pairing mode.
      max_templates: The maximum templates to use per chain.

    Returns:
      The chains cropped.
    """
    cropped_chains = []
    for chain in chains_list:
        if chain['chain_class'] in ['protein']:
            cropped_chain = _crop_single_chain(
                chain,
                msa_crop_size=msa_crop_size,
                pair_msa_sequences=pair_msa_sequences,
                max_templates=max_templates)
        else:
            msa_size = chain['msa'].shape[0]
            # Bug fix: this branch previously ignored the ``msa_crop_size``
            # parameter and used the module constant MSA_CROP_SIZE instead;
            # existing callers pass MSA_CROP_SIZE, so behavior is unchanged
            # for them.
            target_size = msa_crop_size
            if msa_size < target_size:
                # Upsample rows with replacement so every chain reaches the
                # same fixed MSA depth.
                sample_msa_id = np.random.choice(np.arange(msa_size), target_size - msa_size, replace=True)
                chain['msa'] = np.concatenate([chain['msa'], chain['msa'][sample_msa_id, :]], axis=0)
                chain['deletion_matrix'] = np.concatenate(
                    [chain['deletion_matrix'], chain['deletion_matrix'][sample_msa_id, :]], axis=0)
            else:
                chain['msa'] = chain['msa'][:target_size, :]
                chain['deletion_matrix'] = chain['deletion_matrix'][:target_size, :]

            cropped_chain = chain
        cropped_chains.append(cropped_chain)

    return cropped_chains
171
+
172
+
173
def _crop_single_chain(chain: Mapping[str, np.ndarray],
                       msa_crop_size: int,
                       pair_msa_sequences: bool,
                       max_templates: int) -> Mapping[str, np.ndarray]:
    """Crops msa sequences to `msa_crop_size`.

    Protein-only helper: truncates paired (``*_all_seq``) MSA features,
    truncates or upsamples (with replacement) unpaired MSA features to a
    fixed row count, and truncates template features. Mutates ``chain`` in
    place and returns it.

    NOTE(review): ``target_size`` below uses the module constant
    MSA_CROP_SIZE, not the ``msa_crop_size`` parameter (the parameter only
    bounds the paired half) — confirm this asymmetry is intended.
    """
    msa_size = len(chain['msa'])

    if pair_msa_sequences:
        # Reserve up to half of the crop budget for the paired MSA.
        msa_size_all_seq = chain['msa_all_seq'].shape[0]
        msa_crop_size_all_seq = np.minimum(msa_size_all_seq, msa_crop_size // 2)
    else:
        msa_crop_size_all_seq = 0

    include_templates = 'template_aatype' in chain and max_templates
    if include_templates:
        num_templates = chain['template_aatype'].shape[0]
        templates_crop_size = np.minimum(num_templates, max_templates)
    # NOTE(review): if a chain carries template features while
    # ``include_templates`` is falsy (e.g. max_templates == 0), the loop
    # below would reference an undefined ``templates_crop_size``.

    # Unpaired rows fill whatever budget the paired MSA did not use.
    target_size = MSA_CROP_SIZE - msa_crop_size_all_seq
    if msa_size < target_size:
        # Upsample existing rows with replacement to reach the target depth.
        sample_msa_id = np.random.choice(np.arange(msa_size), target_size - msa_size, replace=True)
    for k in chain:
        k_split = k.split('_all_seq')[0]
        if k_split in msa_pairing.TEMPLATE_FEATURES:
            chain[k] = chain[k][:templates_crop_size, :]
        elif k_split in msa_pairing.MSA_FEATURES:
            if '_all_seq' in k:
                # Paired features: plain truncation.
                chain[k] = chain[k][:msa_crop_size_all_seq, :]
            else:
                if msa_size < target_size:
                    sample_msa = chain[k][sample_msa_id, :]
                    chain[k] = np.concatenate([chain[k], sample_msa], axis=0)
                else:
                    chain[k] = chain[k][:target_size, :]

    chain['num_alignments'] = np.asarray(len(chain['msa']), dtype=np.int32)
    if include_templates:
        chain['num_templates'] = np.asarray(templates_crop_size, dtype=np.int32)
    if pair_msa_sequences:
        chain['num_alignments_all_seq'] = np.asarray(
            len(chain['msa_all_seq']), dtype=np.int32)

    return chain
225
+
226
+
227
def print_final(
        np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Pass-through hook for inspecting the merged example; returns it unchanged."""
    final_example = np_example
    return final_example
231
+
232
+
233
def _filter_features(
        np_example: Mapping[str, np.ndarray]
) -> Mapping[str, np.ndarray]:
    """Filters features of example to only those requested.

    Keeps exactly the keys listed in the module-level REQUIRED_FEATURES set.
    """
    filtered = {}
    for key, value in np_example.items():
        if key in REQUIRED_FEATURES:
            filtered[key] = value
    return filtered
238
+
239
+
240
def process_unmerged_features(
        all_chain_features: MutableMapping[str, Mapping[str, np.ndarray]]
):
    """Postprocessing stage for per-chain features before merging.

    Mutates each chain's feature dict in place, adding:
      - ``deletion_mean``: column-wise mean of the deletion matrix;
      - ``assembly_num_chains``: total number of chains in the assembly;
      - ``entity_mask``: 1 where entity_id is non-zero (0 marks padding).
    """
    total_chains = len(all_chain_features)
    for feats in all_chain_features.values():
        feats['deletion_mean'] = np.mean(feats['deletion_matrix'], axis=0)
        feats['assembly_num_chains'] = np.asarray(total_chains)
        feats['entity_mask'] = (feats['entity_id'] != 0).astype(np.int32)
PhysDock/data/tools/get_metrics.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+
4
+ sys.path.append("../")
5
+ import os
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from typing import Optional
10
+
11
+ import scipy
12
+ import random
13
+ import logging
14
+ import collections
15
+ from functools import partial
16
+ from typing import Union, Tuple, Dict
17
+ import itertools
18
+ from PhysDock.utils.tensor_utils import tensor_tree_map
19
+ from PhysDock.utils.io_utils import load_json, load_txt,dump_json,dump_txt,dump_pkl, load_pkl
20
+
21
+
22
+
23
+ def _calculate_bin_centers(breaks: np.ndarray):
24
+ """Gets the bin centers from the bin edges.
25
+
26
+ Args:
27
+ breaks: [num_bins - 1] the error bin edges.
28
+
29
+ Returns:
30
+ bin_centers: [num_bins] the error bin centers.
31
+ """
32
+ step = (breaks[1] - breaks[0])
33
+
34
+ # Add half-step to get the center
35
+ bin_centers = breaks + step / 2
36
+ # Add a catch-all bin at the end.
37
+ bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
38
+ axis=0)
39
+ return bin_centers
40
+
41
def _calculate_expected_aligned_error(
    alignment_confidence_breaks: np.ndarray,
    aligned_distance_error_probs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Expected aligned distance error for every residue pair.

    Args:
        alignment_confidence_breaks: [num_bins - 1] error bin edges.
        aligned_distance_error_probs: [num_res, num_res, num_bins] predicted
            per-bin probabilities for each residue pair.

    Returns:
        A pair of:
          * [num_res, num_res] expected aligned distance error, and
          * scalar maximum predicted error (center of the last bin).
    """
    centers = _calculate_bin_centers(alignment_confidence_breaks)
    expected_error = np.sum(aligned_distance_error_probs * centers, axis=-1)
    max_error = np.asarray(centers[-1])
    return expected_error, max_error
61
+
62
+
63
def compute_plddt(logits: np.ndarray) -> np.ndarray:
    """Per-residue pLDDT from the PredictedLDDTHead logits.

    Args:
        logits: [num_res, num_bins] raw head outputs.

    Returns:
        plddt: [num_res] per-residue pLDDT on the 0-100 scale.
    """
    n_bins = logits.shape[-1]
    width = 1.0 / n_bins
    # Bin centers evenly tile (0, 1).
    centers = np.arange(start=0.5 * width, stop=1.0, step=width)
    probs = scipy.special.softmax(logits, axis=-1)
    # Expectation over bins, rescaled from [0, 1] to [0, 100].
    return np.sum(probs * centers[None, :], axis=-1) * 100
78
+
79
def predicted_tm_score(
    logits: np.ndarray,
    breaks: np.ndarray,
    residue_weights: Optional[np.ndarray] = None,
    asym_id: Optional[np.ndarray] = None,
    interface: bool = False) -> np.ndarray:
    """Computes predicted TM alignment or predicted interface TM alignment score.

    Args:
      logits: [num_res, num_res, num_bins] the logits output from
        PredictedAlignedErrorHead.
      breaks: [num_bins] the error bins.
      residue_weights: [num_res] the per residue weights to use for the
        expectation.
      asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
        ipTM calculation, i.e. when interface=True.
      interface: If True, interface predicted TM score is computed.

    Returns:
      ptm_score: The predicted TM alignment or the predicted iTM score.
    """

    # residue_weights has to be in [0, 1], but can be floating-point, i.e. the
    # exp. resolved head's probability.
    if residue_weights is None:
        residue_weights = np.ones(logits.shape[0])

    bin_centers = _calculate_bin_centers(breaks)

    num_res = int(np.sum(residue_weights))
    # Clip num_res to avoid negative/undefined d0.
    clipped_num_res = max(num_res, 19)

    # Compute d_0(num_res) as defined by TM-score, eqn. (5) in Yang & Skolnick
    # "Scoring function for automated assessment of protein structure template
    # quality", 2004: http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf
    d0 = 1.24 * (clipped_num_res - 15) ** (1./3) - 1.8

    # Convert logits to probs.
    probs = scipy.special.softmax(logits, axis=-1)

    # TM-Score term for every bin.
    tm_per_bin = 1. / (1 + np.square(bin_centers) / np.square(d0))
    # E_distances tm(distance).
    predicted_tm_term = np.sum(probs * tm_per_bin, axis=-1)

    pair_mask = np.ones_like(predicted_tm_term, dtype=bool)
    if interface:
        # Restrict the expectation to inter-chain pairs. asym_id must be
        # provided in this branch (None would raise a TypeError here).
        pair_mask *= asym_id[:, None] != asym_id[None, :]

    predicted_tm_term *= pair_mask

    # Normalize weights row-wise so each candidate alignment anchor averages
    # over its (masked, weighted) partner residues.
    pair_residue_weights = pair_mask * (
        residue_weights[None, :] * residue_weights[:, None])
    normed_residue_mask = pair_residue_weights / (1e-8 + np.sum(
        pair_residue_weights, axis=-1, keepdims=True))
    per_alignment = np.sum(predicted_tm_term * normed_residue_mask, axis=-1)
    # Return the score of the best alignment anchor residue (weighted argmax).
    return np.asarray(per_alignment[(per_alignment * residue_weights).argmax()])
137
+
138
+
139
def compute_predicted_aligned_error(
    logits: np.ndarray,
    breaks: np.ndarray) -> Dict[str, np.ndarray]:
    """Aligned-confidence metrics from the PredictedAlignedErrorHead logits.

    Args:
        logits: [num_res, num_res, num_bins] raw head outputs.
        breaks: [num_bins - 1] error bin edges.

    Returns:
        Dict with:
          * 'aligned_confidence_probs': [num_res, num_res, num_bins] per-bin
            probabilities for each residue pair,
          * 'predicted_aligned_error': [num_res, num_res] expected error,
          * 'max_predicted_aligned_error': scalar maximum possible error.
    """
    probs = scipy.special.softmax(logits, axis=-1)
    expected_error, max_error = _calculate_expected_aligned_error(
        alignment_confidence_breaks=breaks,
        aligned_distance_error_probs=probs)
    return {
        'aligned_confidence_probs': probs,
        'predicted_aligned_error': expected_error,
        'max_predicted_aligned_error': max_error,
    }
168
+
169
def get_has_clash(atom_pos, atom_mask, asym_id, is_polymer_chain):
    """Return 1 if any pair of polymer chains clashes, else 0.

    A structure is marked as having a clash (has_clash) if for any two
    polymer chains A,B in the prediction clashes(A,B) > 100 or
    clashes(A,B) / min(NA,NB) > 0.5 where NA is the number of atoms in
    chain A. A clash is an inter-chain atom pair closer than 1.1 (same
    units as atom_pos).

    Args:
        atom_pos: [N_atom, 3] predicted atom coordinates.
        atom_mask: [N_atom] 1 for resolved/valid atoms.
        asym_id: [N_atom] per-atom chain id.
        is_polymer_chain: [N_atom] 1 for atoms of polymer chains.

    Returns:
        int: 1 if a clash is detected, 0 otherwise.
    """
    valid = np.logical_and(atom_mask == 1, is_polymer_chain == 1)
    atom_pos = atom_pos[valid]
    asym_id = asym_id[valid]
    uniq_asym_ids = np.unique(asym_id)
    if len(uniq_asym_ids) <= 1:
        return 0
    # BUGFIX: the previous loop (`uniq[:-1]` x `uniq[1:]`) compared a chain
    # with itself whenever there were 3+ chains; the zero self-distances were
    # then counted as clashes. Iterate distinct unordered pairs instead.
    for aid1, aid2 in itertools.combinations(uniq_asym_ids, 2):
        pos1 = atom_pos[asym_id == aid1]
        pos2 = atom_pos[asym_id == aid2]
        dist = np.sqrt(np.sum((pos1[None] - pos2[:, None]) ** 2, -1))
        n_clash = np.sum(dist < 1.1).astype('float32')
        if n_clash > 100 or n_clash / min(len(pos1), len(pos2)) > 0.5:
            return 1
    return 0
197
+
198
+
199
+
200
+
201
def get_metrics(output, batch):
    """
    Args:
        logits_plddt: (B, N_atom, b_plddt)
        logits_pae: (B, N_token, N_token, b_pae)

    Returns:
        atom_plddts: (B, N_atom)
        mean_plddt: (B,)
        pae: (B, N_token, N_token)
        ptm: (B,)
        iptm: (B,)
        has_clash: (B,)
        ranking_confidence: (B,)
    """
    logit_value = output

    # B = logit_value['p_pae'].shape[0]
    # 63 edges spanning [0, 0.5 * 64] = [0, 32] for the 64 PAE bins
    # (presumably Angstrom — confirm against the PAE head configuration).
    breaks_pae = torch.linspace(0.,
                                0.5 * 64,
                                64 - 1)
    inputs = {
        's_mask': batch['s_mask'],
        'asym_id': batch['asym_id'],
        'breaks_pae': torch.tile(breaks_pae, [1]),
        # 'perm_asym_id': batch['perm_asym_id'],
        # A token counts as polymer if it is protein, DNA, or RNA.
        'is_polymer_chain': ((batch['is_protein'] +
                              batch['is_dna'] + batch['is_rna']) > 0),
        # NOTE(review): **batch is unpacked last, so any key also present in
        # `batch` (e.g. 's_mask', 'asym_id', potentially 'breaks_pae') takes
        # the batch's value — confirm no accidental shadowing.
        **logit_value,
        **batch

    }

    ret_list = []
    # for i in range(B):
    # Downstream confidence math is numpy-based, so convert every torch
    # tensor first (requires CPU tensors without grad).
    cur_input = tensor_tree_map(lambda x: x.numpy(), inputs)
    # cur_input = inputs
    # Only sample 0 is scored; the per-batch loop above is disabled, so the
    # effective batch size here is 1.
    ret = get_all_atom_confidence_metrics(cur_input, 0)
    ret_list.append(ret)

    metrics = {}
    # Stack per-sample results back into torch tensors with a leading batch
    # dimension (currently always 1).
    for k, v in ret_list[0].items():
        metrics[k] = torch.from_numpy(np.stack([r[k] for r in ret_list]))
    return metrics
245
+
246
+
247
+
248
def get_all_atom_confidence_metrics(
    prediction_result, b):
    """Computes pLDDT/PAE/pTM/ipTM/clash metrics for one sample.

    Args:
        prediction_result: dict of numpy arrays. Must contain the logits
            'p_plddt' and 'p_pae', plus 'breaks_pae', 's_mask', 'asym_id',
            'x_pred', 'a_mask', 'atom_id_to_token_id' and 'is_ligand'.
        b: index into 'x_pred' selecting which predicted sample to score.

    Returns:
        dict with keys: atom_plddts, mean_plddt, pae, ptm, iptm, has_clash,
        ranking_confidence.
    """
    metrics = {}
    metrics['atom_plddts'] = compute_plddt(
        prediction_result['p_plddt'])
    metrics['mean_plddt'] = metrics['atom_plddts'].mean()
    metrics['pae'] = compute_predicted_aligned_error(
        logits=prediction_result['p_pae'],
        breaks=prediction_result['breaks_pae'])['predicted_aligned_error']
    metrics['ptm'] = predicted_tm_score(
        logits=prediction_result['p_pae'],
        breaks=prediction_result['breaks_pae'],
        residue_weights=prediction_result['s_mask'],
        asym_id=None)
    # ipTM restricts the pTM expectation to inter-chain residue pairs.
    metrics['iptm'] = predicted_tm_score(
        logits=prediction_result['p_pae'],
        breaks=prediction_result['breaks_pae'],
        residue_weights=prediction_result['s_mask'],
        asym_id=prediction_result['asym_id'],
        interface=True)
    # Per-token annotations are broadcast to atoms via atom_id_to_token_id;
    # non-ligand tokens are treated as polymer chains for the clash check.
    metrics['has_clash'] = get_has_clash(
        prediction_result['x_pred'][b],
        prediction_result['a_mask'],
        prediction_result['asym_id'][prediction_result["atom_id_to_token_id"]],
        ~prediction_result['is_ligand'][prediction_result["atom_id_to_token_id"]])
    # Ranking score: ipTM-dominated blend with a hard penalty for clashes.
    metrics['ranking_confidence'] = (
        0.8 * metrics['iptm'] + 0.2 * metrics['ptm']
        - 1.0 * metrics['has_clash'])
    return metrics
278
+
279
+
280
+ # output = load_pkl("../output.pkl.gz")
281
+ # feats = load_pkl("../feats.pkl.gz")
282
+ # for k,v in output.items():
283
+ # print(k,v.shape)
284
+ # # output[k] = torch.from_numpy(v)
285
+ # for k,v in feats.items():
286
+ # print(k,v.shape)
287
+ # # feats[k] = torch.from_numpy(v)
288
+ # # dump_pkl(output,"../output.pkl.gz")
289
+ # # dump_pkl(feats,"../feats.pkl.gz")
290
+ #
291
+ # metrics = get_metrics(output,feats)
292
+ #
293
+ # for k,v in metrics.items():
294
+ # print(k,v)
PhysDock/data/tools/hhblits.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run HHblits from Python."""
17
+ import glob
18
+ import logging
19
+ import os
20
+ import subprocess
21
+ from typing import Any, List, Mapping, Optional, Sequence
22
+
23
+ from . import utils
24
+
25
+
26
+ _HHBLITS_DEFAULT_P = 20
27
+ _HHBLITS_DEFAULT_Z = 500
28
+
29
+
30
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 4,
        n_iter: int = 3,
        e_value: float = 0.001,
        maxseq: int = 1_000_000,
        realign_max: int = 100_000,
        maxfilt: int = 100_000,
        min_prefilter_hits: int = 1000,
        all_seqs: bool = False,
        alt: Optional[int] = None,
        p: int = _HHBLITS_DEFAULT_P,
        z: int = _HHBLITS_DEFAULT_Z,
    ):
        """Initializes the Python HHblits wrapper.

        Args:
          binary_path: The path to the HHblits executable.
          databases: A sequence of HHblits database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          n_cpu: The number of CPUs to give HHblits.
          n_iter: The number of HHblits iterations.
          e_value: The E-value, see HHblits docs for more details.
          maxseq: The maximum number of rows in an input alignment. Note that this
            parameter is only supported in HHBlits version 3.1 and higher.
          realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
          maxfilt: Max number of hits allowed to pass the 2nd prefilter.
            HHblits default: 20000.
          min_prefilter_hits: Min number of hits to pass prefilter.
            HHblits default: 100.
          all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
            HHblits default: False.
          alt: Show up to this many alternative alignments.
          p: Minimum Prob for a hit to be included in the output hhr file.
            HHblits default: 20.
          z: Hard cap on number of hits reported in the hhr file.
            HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

        Raises:
          RuntimeError: If HHblits binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases

        # Fail fast if any database prefix has no matching index files on disk.
        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHBlits database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHBlits database {database_path}"
                )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
        """Queries the database using HHblits.

        Args:
            input_fasta_path: Path to the query fasta file.

        Returns:
            A single-element list with a dict holding the resulting a3m text,
            raw stdout/stderr, and the n_iter/e_value settings used.

        Raises:
            RuntimeError: If HHblits exits with a non-zero status.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            # Each database contributes its own "-d <path>" pair.
            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-cpu",
                str(self.n_cpu),
                "-oa3m",
                a3m_path,
                "-o",
                "/dev/null",
                "-n",
                str(self.n_iter),
                "-e",
                str(self.e_value),
                "-maxseq",
                str(self.maxseq),
                "-realign_max",
                str(self.realign_max),
                "-maxfilt",
                str(self.maxfilt),
                "-min_prefilter_hits",
                str(self.min_prefilter_hits),
            ]
            # Optional flags are appended only when they differ from the
            # tool's defaults, keeping the command line minimal.
            if self.all_seqs:
                cmd += ["-all"]
            if self.alt:
                cmd += ["-alt", str(self.alt)]
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ["-p", str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ["-Z", str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("HHblits query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits error line by line.
                logging.error("HHblits failed. HHblits stderr begin:")
                for error_line in stderr.decode("utf-8").splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error("HHblits stderr end")
                # stderr is truncated to keep the raised message manageable.
                raise RuntimeError(
                    "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
                )

            with open(a3m_path) as f:
                a3m = f.read()

            raw_output = dict(
                a3m=a3m,
                output=stdout,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )
        # Wrapped in a list — presumably to match a multi-chunk runner
        # interface (cf. Jackhmmer) — TODO confirm.
        return [raw_output]
PhysDock/data/tools/hhsearch.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run HHsearch from Python."""
17
+ import glob
18
+ import logging
19
+ import os
20
+ import subprocess
21
+ from typing import Sequence, Optional
22
+
23
+ from . import parsers
24
+ from . import utils
25
+
26
+
27
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
          binary_path: The path to the HHsearch executable.
          databases: A sequence of HHsearch database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          n_cpu: The number of CPUs to use
          maxseq: The maximum number of rows in an input alignment. Note that this
            parameter is only supported in HHBlits version 3.1 and higher.

        Raises:
          RuntimeError: If HHsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.n_cpu = n_cpu
        self.maxseq = maxseq

        # Fail fast if any database prefix has no matching index files on disk.
        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHsearch database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHsearch database {database_path}"
                )

    @property
    def output_format(self) -> str:
        return 'hhr'

    @property
    def input_format(self) -> str:
        # Queries are fed as a3m rather than Stockholm.
        # return 'sto'
        return 'a3m'

    def query(self, a3m: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using HHsearch using a given a3m.

        Args:
            a3m: Query alignment text in a3m format.
            output_dir: Optional directory to keep the .hhr result file in;
                defaults to a temporary directory that is deleted afterwards.

        Returns:
            The raw .hhr output text.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, "query.a3m")
            output_dir = query_tmp_dir if output_dir is None else output_dir
            hhr_path = os.path.join(output_dir, "hhsearch_output.hhr")
            with open(input_path, "w") as f:
                f.write(a3m)

            # Each database contributes its own "-d <path>" pair.
            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_path,
                "-o",
                hhr_path,
                "-maxseq",
                str(self.maxseq),
                "-cpu",
                str(self.n_cpu),
            ] + db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            # NOTE(review): the failure check below is disabled. If hhsearch
            # exits non-zero without writing the .hhr file, the open() below
            # raises FileNotFoundError instead of a descriptive RuntimeError —
            # confirm this tolerance is intentional.
            # if retcode:
            #     # Stderr is truncated to prevent proto size errors in Beam.
            #     raise RuntimeError(
            #         "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
            #         % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
            #     )

            with open(hhr_path) as f:
                hhr = f.read()
            return hhr

    @staticmethod
    def get_template_hits(
        output_string: str,
        input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool"""
        del input_sequence  # Used by hmmsearch but not needed for hhsearch
        return parsers.parse_hhr(output_string)
PhysDock/data/tools/hmmalign.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from typing import Optional, Sequence
4
+ import logging
5
+
6
+ from . import parsers
7
+ from . import hmmbuild
8
+ from . import utils
9
+
10
+
11
class Hmmalign(object):
    """Python wrapper of the hmmalign binary (with hmmbuild for the profile)."""

    def __init__(
        self,
        *,
        hmmbuild_binary_path: str,
        hmmalign_binary_path: str,
    ):
        """Initializes the wrapper.

        Args:
            hmmbuild_binary_path: Path to the hmmbuild executable, used to
                build an RNA profile from the query fasta.
            hmmalign_binary_path: Path to the hmmalign executable.
        """
        self.binary_path = hmmalign_binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)

    @property
    def output_format(self) -> str:
        return 'sto'

    @property
    def input_format(self) -> str:
        return 'sto'

    def realign_sto_with_fasta(
        self,
        input_fasta_path: str,
        input_sto_path: str,
        output_sto_path: Optional[str] = None,
    ) -> str:
        """Realigns a Stockholm MSA against an RNA profile built from a fasta.

        Args:
            input_fasta_path: Query fasta used to build the RNA HMM profile.
            input_sto_path: Stockholm alignment to realign.
            output_sto_path: Optional path to also write the realigned
                alignment to. If omitted, a temporary file is used.

        Returns:
            The realigned alignment in Stockholm format. (BUGFIX: previously
            returned None whenever output_sto_path was provided.)

        Raises:
            RuntimeError: If hmmalign fails.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_output_path = os.path.join(query_tmp_dir, 'query.hmm')
            if output_sto_path is None:
                output_sto_path = os.path.join(query_tmp_dir, "realigned.sto")
            with open(input_fasta_path, "r") as f:
                hmm = self.hmmbuild_runner.build_rna_profile_from_fasta(f.read())
            with open(hmm_output_path, 'w') as f:
                f.write(hmm)

            cmd = [
                self.binary_path,
                '--rna',  # Treat the profile and sequences as RNA.
                '--mapali', input_fasta_path,  # Merge the fasta into the output alignment.
                "-o", output_sto_path,
                hmm_output_path,
                input_sto_path
            ]
            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(f'hmmalign query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # BUGFIX: the message previously said "hmmsearch failed".
                raise RuntimeError(
                    'hmmalign failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            # Read inside the tmpdir context so a temporary output still exists.
            with open(output_sto_path) as f:
                out_msa = f.read()
        return out_msa
PhysDock/data/tools/hmmbuild.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
16
+
17
+ import os
18
+ import re
19
+ import subprocess
20
+ import logging
21
+
22
+ from . import utils
23
+
24
+
25
class Hmmbuild(object):
    """Python wrapper of the hmmbuild binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 singlemx: bool = False):
        """Initializes the Python hmmbuild wrapper.

        Args:
          binary_path: The path to the hmmbuild executable.
          singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
            just use a common substitution score matrix.

        Raises:
          RuntimeError: If hmmbuild binary not found within the path.
        """
        self.binary_path = binary_path
        self.singlemx = singlemx

    def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
        """Builds a HHM for the aligned sequences given as a Stockholm string.

        Args:
          sto: A string with the aligned sequences in the Stockholm format.
          model_construction: Whether to use reference annotation in the msa to
            determine consensus columns ('hand') or default ('fast').

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        return self._build_profile(sto, model_construction=model_construction)

    def build_profile_from_a3m(self, a3m: str) -> str:
        """Builds a HHM for the aligned sequences given as an A3M string.

        Args:
          a3m: A string with the aligned sequences in the A3M format.

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        lines = []
        for line in a3m.splitlines():
            if not line.startswith('>'):
                line = re.sub('[a-z]+', '', line)  # Remove inserted residues.
            lines.append(line + '\n')
        msa = ''.join(lines)
        return self._build_profile(msa, model_construction='fast')

    def _build_profile(self, msa: str, model_construction: str = 'fast') -> str:
        """Builds a HMM for the aligned sequences given as an MSA string.

        Args:
          msa: A string with the aligned sequences, in A3M or STO format.
          model_construction: Whether to use reference annotation in the msa to
            determine consensus columns ('hand') or default ('fast').

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
          ValueError: If unspecified arguments are provided.
        """
        if model_construction not in {'hand', 'fast'}:
            # BUGFIX: the two adjacent literals previously concatenated
            # without a space ("...onlyhand and fast supported.").
            raise ValueError(f'Invalid model_construction {model_construction} - '
                             f'only hand and fast supported.')

        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.msa')
            output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

            with open(input_query, 'w') as f:
                f.write(msa)

            cmd = [self.binary_path]
            # If adding flags, we have to do so before the output and input:
            if model_construction == 'hand':
                cmd.append(f'--{model_construction}')
            if self.singlemx:
                cmd.append('--singlemx')
            cmd.extend([
                '--amino',
                output_hmm_path,
                input_query,
            ])

            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)

            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))

            if retcode:
                raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                                   % (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()

            return hmm

    def build_rna_profile_from_fasta(self, fasta: str):
        """Builds an RNA HMM profile from a fasta string (hmmbuild --rna).

        Args:
          fasta: A string with the (aligned) sequences in FASTA format.

        Returns:
          A string with the profile in the HMM format.

        Raises:
          RuntimeError: If hmmbuild fails.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            input_query = os.path.join(query_tmp_dir, 'query.fasta')
            output_hmm_path = os.path.join(query_tmp_dir, 'query.hmm')
            with open(input_query, 'w') as f:
                f.write(fasta)
            cmd = [self.binary_path]
            cmd.extend([
                '--rna',
                output_hmm_path,
                input_query,
            ])
            logging.info('Launching subprocess %s', cmd)
            process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)
            with utils.timing('hmmbuild query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info('hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
                             stdout.decode('utf-8'), stderr.decode('utf-8'))
            if retcode:
                raise RuntimeError('hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n'
                                   % (stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(output_hmm_path, encoding='utf-8') as f:
                hmm = f.read()
            return hmm
PhysDock/data/tools/hmmsearch.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A Python wrapper for hmmsearch - search profile against a sequence db."""
16
+
17
+ import os
18
+ import subprocess
19
+ from typing import Optional, Sequence
20
+ import logging
21
+
22
+ from . import parsers
23
+ from . import hmmbuild
24
+ from . import utils
25
+
26
+
27
class Hmmsearch(object):
    """Python wrapper of the hmmsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 hmmbuild_binary_path: str,
                 database_path: str,
                 flags: Optional[Sequence[str]] = None
                 ):
        """Initializes the Python hmmsearch wrapper.

        Args:
          binary_path: The path to the hmmsearch executable.
          hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
            an hmm from an input a3m.
          database_path: The path to the hmmsearch database (FASTA format).
          flags: List of flags to be used by hmmsearch.

        Raises:
          RuntimeError: If hmmsearch binary not found within the path.
        """
        self.binary_path = binary_path
        self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
        self.database_path = database_path
        if flags is None:
            # Default hmmsearch run settings.
            # Very permissive filter/E-value thresholds — presumably so that
            # hits are filtered downstream rather than by hmmsearch itself.
            flags = ['--F1', '0.1',
                     '--F2', '0.1',
                     '--F3', '0.1',
                     '--incE', '100',
                     '-E', '100',
                     '--domE', '100',
                     '--incdomE', '100']
        self.flags = flags

        if not os.path.exists(self.database_path):
            logging.error('Could not find hmmsearch database %s', database_path)
            raise ValueError(f'Could not find hmmsearch database {database_path}')

    @property
    def output_format(self) -> str:
        return 'sto'

    @property
    def input_format(self) -> str:
        return 'sto'

    def query(self, msa_sto: str, output_dir: Optional[str] = None) -> str:
        """Queries the database using hmmsearch using a given stockholm msa."""
        # Build the query profile from the MSA first ('hand' uses the MSA's
        # reference annotation to pick consensus columns).
        hmm = self.hmmbuild_runner.build_profile_from_sto(
            msa_sto,
            model_construction='hand'
        )
        return self.query_with_hmm(hmm, output_dir)

    def query_with_hmm(self,
                       hmm: str,
                       output_dir: Optional[str] = None
                       ) -> str:
        """Queries the database using hmmsearch using a given hmm.

        Args:
            hmm: Query profile text in HMMER format.
            output_dir: Optional directory to keep the .sto result in;
                defaults to a temporary directory that is deleted afterwards.

        Returns:
            The resulting alignment in Stockholm format.

        Raises:
            RuntimeError: If hmmsearch exits with a non-zero status.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
            output_dir = query_tmp_dir if output_dir is None else output_dir
            out_path = os.path.join(output_dir, 'hmm_output.sto')
            with open(hmm_input_path, 'w') as f:
                f.write(hmm)

            # NOTE(review): CPU count is hard-coded to 8 here; consider
            # exposing it as a constructor parameter.
            cmd = [
                self.binary_path,
                '--noali',  # Don't include the alignment in stdout.
                '--cpu', '8'
            ]
            # If adding flags, we have to do so before the output and input:
            if self.flags:
                cmd.extend(self.flags)
            cmd.extend([
                '-A', out_path,
                hmm_input_path,
                self.database_path,
            ])

            logging.info('Launching sub-process %s', cmd)
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with utils.timing(
                    f'hmmsearch ({os.path.basename(self.database_path)}) query'):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8'), stderr.decode('utf-8')))

            with open(out_path) as f:
                out_msa = f.read()

        return out_msa

    @staticmethod
    def get_template_hits(
            output_string: str,
            input_sequence: str
    ) -> Sequence[parsers.TemplateHit]:
        """Gets parsed template hits from the raw string output by the tool."""
        template_hits = parsers.parse_hmmsearch_sto(
            output_string,
            input_sequence,
        )
        return template_hits
PhysDock/data/tools/jackhmmer.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run Jackhmmer from Python."""
17
+
18
+ from concurrent import futures
19
+ import glob
20
+ import logging
21
+ import os
22
+ import subprocess
23
+ from typing import Any, Callable, Mapping, Optional, Sequence
24
+ from urllib import request
25
+
26
+ from . import parsers
27
+ from . import utils
28
+
29
+
30
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,
        n_iter: int = 1,
        e_value: float = 0.0001,
        z_value: Optional[int] = None,
        get_tblout: bool = False,
        filter_f1: float = 0.0005,
        filter_f2: float = 0.00005,
        filter_f3: float = 0.0000005,
        seq_limit: int = 50000,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
          binary_path: The path to the jackhmmer executable.
          database_path: The path to the jackhmmer database (FASTA format).
          n_cpu: The number of CPUs to give Jackhmmer.
          n_iter: The number of Jackhmmer iterations.
          e_value: The E-value, see Jackhmmer docs for more details.
          z_value: The Z-value, see Jackhmmer docs for more details.
          get_tblout: Whether to save tblout string.
          filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
          filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
          filter_f3: Forward pre-filter, set to >1.0 to turn off.
          seq_limit: Stored but currently unused — the corresponding
            '--seq_limit' flag is commented out in `_query_chunk`.
          incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
            round.
          dom_e: Domain e-value criteria for inclusion in tblout.
          num_streamed_chunks: Number of database chunks to stream over.
          streaming_callback: Callback function run after each chunk iteration with
            the iteration number as argument.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # In streamed mode the database chunks are fetched by URL, so a
        # missing local path is only fatal for the non-streamed mode.
        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.seq_limit = seq_limit
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self,
        input_fasta_path: str,
        database_path: str,
        max_sequences: Optional[int] = None
    ) -> Mapping[str, Any]:
        """Queries one database chunk using Jackhmmer.

        Args:
            input_fasta_path: Path to the query fasta file.
            database_path: Path to the (chunk of the) sequence database.
            max_sequences: If given, truncate the resulting stockholm MSA to
                at most this many sequences.

        Returns:
            Dict with keys 'sto' (stockholm MSA text), 'tbl' (tblout text or
            ''), 'stderr' (raw bytes), 'n_iter' and 'e_value'.

        Raises:
            RuntimeError: If the jackhmmer subprocess fails.
        """
        with utils.tmpdir_manager() as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
            # speeds up the pipeline at the expensive of sensitivity. They are
            # currently set very low to make querying Mgnify run in a reasonable
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                "--F1",
                str(self.filter_f1),
                "--F2",
                str(self.filter_f2),
                "--F3",
                str(self.filter_f3),
                # "--seq_limit",
                # str(self.seq_limit),
                "--incE",
                str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--cpu",
                str(self.n_cpu),
                "-N",
                str(self.n_iter),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.z_value:
                cmd_flags.extend(["-Z", str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )
            # print(cmd)
            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Jackhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            # Read (and optionally truncate) the MSA before the tmpdir holding
            # sto_path is removed.
            if (max_sequences is None):
                with open(sto_path) as f:
                    sto = f.read()
            else:
                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

        raw_output = dict(
            sto=sto,
            tbl=tbl,
            stderr=stderr,
            n_iter=self.n_iter,
            e_value=self.e_value,
        )

        return raw_output

    def query(self,
              input_fasta_path: str,
              max_sequences: Optional[int] = None
              ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Convenience wrapper around `query_multiple` for a single fasta."""
        return self.query_multiple([input_fasta_path], max_sequences)

    def query_multiple(self,
                       input_fasta_paths: Sequence[str],
                       max_sequences: Optional[int] = None
                       ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Queries the database using Jackhmmer."""
        if self.num_streamed_chunks is None:
            # NOTE(review): this branch returns a flat list of result
            # mappings (one per fasta), while the streamed branch below
            # returns a list of lists — the declared return type matches
            # only the streamed shape. Confirm what callers expect.
            single_chunk_results = []
            for input_fasta_path in input_fasta_paths:
                single_chunk_result = self._query_chunk(
                    input_fasta_path, self.database_path, max_sequences,
                )
                single_chunk_results.append(single_chunk_result)
            return single_chunk_results

        db_basename = os.path.basename(self.database_path)
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        # NOTE(review): local chunk location is hard-coded to /tmp/ramdisk —
        # confirm that this directory exists on the deployment hosts.
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk; block until the download of
                # the current chunk has finished.
                future.result()
                for fasta_idx, input_fasta_path in enumerate(input_fasta_paths):
                    chunked_outputs[fasta_idx].append(
                        self._query_chunk(
                            input_fasta_path,
                            db_local_chunk(i),
                            max_sequences
                        )
                    )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Do not set next_future for the last chunk so that this works
                # even for databases with only 1 chunk
                if (i < self.num_streamed_chunks):
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_outputs
PhysDock/data/tools/kalign.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """A Python wrapper for Kalign."""
17
+ import os
18
+ import subprocess
19
+ from typing import Sequence
20
+ import logging
21
+
22
+ from . import utils
23
+
24
+
25
+ def _to_a3m(sequences: Sequence[str]) -> str:
26
+ """Converts sequences to an a3m file."""
27
+ names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
28
+ a3m = []
29
+ for sequence, name in zip(sequences, names):
30
+ a3m.append(u">" + name + u"\n")
31
+ a3m.append(sequence + u"\n")
32
+ return "".join(a3m)
33
+
34
+
35
class Kalign:
    """Thin Python wrapper around the Kalign alignment binary."""

    def __init__(self, *, binary_path: str):
        """Records the location of the Kalign executable.

        Args:
            binary_path: Filesystem path to the Kalign binary.

        Raises:
            RuntimeError: If Kalign binary not found within the path.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Runs Kalign over the given sequences and returns an a3m alignment.

        Args:
            sequences: Query sequences; each must be at least 6 residues long
                (a Kalign requirement). Input order may slightly affect the
                alignment tree Kalign builds, and hence the output.

        Returns:
            The alignment, as an a3m-format string.

        Raises:
            ValueError: If any sequence is shorter than 6 residues.
            RuntimeError: If the Kalign subprocess exits with an error.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Validate lengths up front so we fail before touching the filesystem.
        for seq in sequences:
            if len(seq) < 6:
                raise ValueError(
                    "Kalign requires all sequences to be at least 6 "
                    "residues long. Got %s (%d residues)." % (seq, len(seq))
                )

        with utils.tmpdir_manager() as tmp_dir:
            in_fasta = os.path.join(tmp_dir, "input.fasta")
            out_a3m = os.path.join(tmp_dir, "output.a3m")

            with open(in_fasta, "w") as fasta_file:
                fasta_file.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i", in_fasta,
                "-o", out_a3m,
                "-format", "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            proc = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                stdout, stderr = proc.communicate()
                retcode = proc.wait()
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    stdout.decode("utf-8"),
                    stderr.decode("utf-8"),
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            # Read the alignment before the temporary directory is removed.
            with open(out_a3m) as a3m_file:
                return a3m_file.read()
PhysDock/data/tools/mmcif_parsing.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Parses the mmCIF file format."""
17
+ import collections
18
+ import dataclasses
19
+ import functools
20
+ import io
21
+ import json
22
+ import logging
23
+ import os
24
+ from typing import Any, Mapping, Optional, Sequence, Tuple
25
+ import numpy as np
26
+ from Bio import PDB
27
+
28
+ from . import PDBData
29
+
30
@dataclasses.dataclass
class residue_constants:
    # The fixed 37-atom vocabulary for protein heavy atoms. The list order
    # defines the atom37 index of each atom name, so it must not change.
    atom_types = [
        "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG",
        "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE", "CE1",
        "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2",
        "OH", "CZ", "CZ2", "CZ3", "NZ", "OXT",
    ]
    # Number of entries in the atom vocabulary (37).
    atom_type_num = len(atom_types)
    # Maps atom name -> index in atom_types.
    atom_order = {name: index for index, name in enumerate(atom_types)}
39
+
40
+
41
+ """General-purpose errors used throughout the data pipeline"""
42
+
43
+
44
class Error(Exception):
    """Base class for exceptions raised by this module."""
46
+
47
+
48
class MultipleChainsError(Error):
    """An error indicating that multiple chains were found for a given ID."""
50
+
51
+
52
+ # Type aliases:
53
+ ChainId = str
54
+ PdbHeader = Mapping[str, Any]
55
+ PdbStructure = PDB.Structure.Structure
56
+ SeqRes = str
57
+ MmCIFDict = Mapping[str, Sequence[str]]
58
+
59
+
60
@dataclasses.dataclass(frozen=True)
class Monomer:
    """One residue entry from the mmCIF ``_entity_poly_seq`` loop."""
    id: str   # Chemical component id (``_entity_poly_seq.mon_id``).
    num: int  # Position within the polymer (``_entity_poly_seq.num``).
64
+
65
+
66
+ # Note - mmCIF format provides no guarantees on the type of author-assigned
67
+ # sequence numbers. They need not be integers.
68
@dataclasses.dataclass(frozen=True)
class AtomSite:
    """One row of the mmCIF ``_atom_site`` loop.

    All fields are raw strings from the parsed mmCIF dict; numeric conversion
    (e.g. ``int(mmcif_seq_num)``) happens later in ``parse``.
    """
    residue_name: str     # _atom_site.label_comp_id
    author_chain_id: str  # _atom_site.auth_asym_id
    mmcif_chain_id: str   # _atom_site.label_asym_id
    author_seq_num: str   # _atom_site.auth_seq_id
    # _atom_site.label_seq_id; a string at runtime (parse() applies int()).
    mmcif_seq_num: str
    insertion_code: str   # _atom_site.pdbx_PDB_ins_code
    hetatm_atom: str      # _atom_site.group_PDB ("ATOM" or "HETATM")
    # _atom_site.pdbx_PDB_model_num; a string at runtime (parse() compares
    # it against "1").
    model_num: str
78
+
79
+
80
+ # Used to map SEQRES index to a residue in the structure.
81
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
    """Author-assigned location of a residue within the structure."""
    chain_id: str        # Author chain id.
    residue_number: int  # Author sequence number (int(auth_seq_id)).
    insertion_code: str  # PDB insertion code; " " when unset.
86
+
87
+
88
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
    """A SEQRES residue that may or may not be resolved in the structure."""
    position: Optional[ResiduePosition]  # None when the residue is missing.
    name: str                            # Residue (chemical component) name.
    is_missing: bool                     # True if no coordinates exist.
    hetflag: str                         # Biopython hetflag: " ", "W", or "H_<name>".
94
+
95
+
96
@dataclasses.dataclass(frozen=True)
class MmcifObject:
    """Representation of a parsed mmCIF file.

    Contains:
      file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
        files being processed.
      header: Biopython header.
      structure: Biopython structure (first model only; see ``parse``).
      chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
        {'A': 'ABCDEFG'}
      seqres_to_structure: Dict; for each chain_id contains a mapping between
        SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
                                                          1: ResidueAtPosition,
                                                          ...}}
      raw_string: The parsed mmCIF dict (``_mmcif_dict``) used to construct
        the MmcifObject (despite the name, not the raw file text).
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
    chain_to_seqres: Mapping[ChainId, SeqRes]
    seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
    raw_string: Any
120
+
121
+
122
@dataclasses.dataclass(frozen=True)
class ParsingResult:
    """Returned by the parse function.

    Contains:
      mmcif_object: A MmcifObject, may be None if no chain could be successfully
        parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
        Empty on full success.
    """

    mmcif_object: Optional[MmcifObject]
    errors: Mapping[Tuple[str, str], Any]
134
+
135
+
136
class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""
    # NOTE(review): subclasses Exception directly rather than this module's
    # Error base class — confirm that is intentional.
138
+
139
+
140
def mmcif_loop_to_list(
    prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
    """Extracts one mmCIF loop_ as a list of per-row dicts.

    Reference for loop_ in mmCIF:
    http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    Args:
      prefix: Prefix shared by each of the data items in the loop, including
        the trailing period (e.g. '_entity_poly_seq.').
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a
        Biopython parser.

    Returns:
      A list of dicts, one per row of the loop, each mapping the full column
      name to that row's value.
    """
    cols = [key for key in parsed_info if key.startswith(prefix)]
    data = [parsed_info[key] for key in cols]

    # Every column of a loop must have the same number of rows.
    assert all(len(column) == len(data[0]) for column in data), (
        "mmCIF error: Not all loops are the same length: %s" % cols
    )

    return [dict(zip(cols, row)) for row in zip(*data)]
170
+
171
+
172
def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
    """Extracts one mmCIF loop_ as a dict of per-row dicts keyed by a column.

    Args:
      prefix: Prefix shared by each of the data items in the loop, including
        the trailing period (e.g. '_entity_poly_seq.').
      index: Which item of loop data should serve as the key.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a
        Biopython parser.

    Returns:
      A dict mapping each row's value in the index column to that row's dict.
      Later rows with a duplicate key overwrite earlier ones.
    """
    keyed = {}
    for row in mmcif_loop_to_list(prefix, parsed_info):
        keyed[row[index]] = row
    return keyed
193
+
194
+
195
# Results are memoized on (file_id, mmcif_string, catch_all_errors); note the
# cache holds up to 16 full ParsingResults (including Biopython structures).
@functools.lru_cache(16, typed=False)
def parse(
    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages are
        returned as part of the ParsingResult. If False exceptions will be allowed
        to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure("", handle)
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful fields not
        # reflected in the Biopython structure.
        parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

        # Ensure all values are lists, even if singletons. (This mutates the
        # parser's dict in place.)
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to the
        # internal mmCIF numbering scheme (likely but not guaranteed to be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ""): "No protein chains found in this file."}
            )
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
        # the authors / Biopython).
        # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            # model_num is a raw string from the mmCIF dict, hence the string
            # comparison.
            if atom.model_num != "1":
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = " "
                if atom.hetatm_atom == "HETATM":
                    # Water atoms are assigned a special hetflag of W in Biopython. We
                    # need to do the same, so that this hetflag can be used to fetch
                    # a residue from the Biopython structure by id.
                    if atom.residue_name in ("HOH", "WAT"):
                        hetflag = "W"
                    else:
                        hetflag = "H_" + atom.residue_name
                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = " "
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                # 0-based index into the SEQRES sequence of this chain.
                seq_idx = (
                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                )
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {}
                )
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=" ",
                    )

        # Build the one-letter SEQRES sequence per author chain; unknown
        # residues (and multi-letter codes) map to 'X'.
        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = PDBData.protein_letters_3to1_extended.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, "")] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)
324
+
325
+
326
def _get_first_model(structure: PdbStructure) -> PdbStructure:
    """Returns the first model in a Biopython structure."""
    models = structure.get_models()
    return next(models)
329
+
330
+
331
+ _MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
332
+
333
+
334
def get_release_date(parsed_info: MmCIFDict) -> str:
    """Returns the oldest date in the mmCIF revision history."""
    dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
    return min(dates)
338
+
339
+
340
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    """Returns a basic header containing method, release date and resolution.

    The returned dict has keys 'structure_method', 'resolution' and, when
    available, 'release_date'.
    """
    header = {}

    # Comma-join all experimental methods (there can be more than one).
    experiments = mmcif_loop_to_list("_exptl.", parsed_info)
    header["structure_method"] = ",".join(
        [experiment["_exptl.method"].lower() for experiment in experiments]
    )

    # Note: The release_date here corresponds to the oldest revision. We prefer to
    # use this for dataset filtering over the deposition_date.
    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
        header["release_date"] = get_release_date(parsed_info)
    else:
        logging.warning(
            "Could not determine release_date: %s", parsed_info["_entry.id"]
        )

    # Default resolution when no field parses; note that if several of the
    # candidate fields are present, the last parsable one wins (no break).
    header["resolution"] = 0.00
    for res_key in (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    ):
        if res_key in parsed_info:
            try:
                raw_resolution = parsed_info[res_key][0]
                header["resolution"] = float(raw_resolution)
            except ValueError:
                logging.debug(
                    "Invalid resolution format: %s", parsed_info[res_key]
                )

    return header
374
+
375
+
376
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
    """Builds one AtomSite per row of the mmCIF ``_atom_site`` loop.

    The fields are taken column-wise from the parsed dict; the column order
    must match the AtomSite field order.
    """
    columns = (
        parsed_info["_atom_site.label_comp_id"],
        parsed_info["_atom_site.auth_asym_id"],
        parsed_info["_atom_site.label_asym_id"],
        parsed_info["_atom_site.auth_seq_id"],
        parsed_info["_atom_site.label_seq_id"],
        parsed_info["_atom_site.pdbx_PDB_ins_code"],
        parsed_info["_atom_site.group_PDB"],
        parsed_info["_atom_site.pdbx_PDB_model_num"],
    )
    return [AtomSite(*row) for row in zip(*columns)]
391
+
392
+
393
def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
      parsed_info: _mmcif_dict produced by the Biopython parser.

    Returns:
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)

    # Group the monomers of each polymer by entity id, in loop order.
    polymers = collections.defaultdict(list)
    for entity_poly_seq in entity_poly_seqs:
        polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
            Monomer(
                id=entity_poly_seq["_entity_poly_seq.mon_id"],
                num=int(entity_poly_seq["_entity_poly_seq.num"]),
            )
        )

    # Get chemical compositions. Will allow us to identify which of these polymers
    # are proteins.
    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

    # Get chains information for each entity. Necessary so that we can return a
    # dict keyed on chain id rather than entity.
    struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)

    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in struct_asyms:
        chain_id = struct_asym["_struct_asym.id"]
        entity_id = struct_asym["_struct_asym.entity_id"]
        entity_to_mmcif_chains[entity_id].append(chain_id)

    # Identify and return the valid protein chains. A polymer counts as
    # protein if ANY of its monomers has a peptide-like chem_comp type.
    valid_chains = {}
    for entity_id, seq_info in polymers.items():
        chain_ids = entity_to_mmcif_chains[entity_id]

        # Reject polymers without any peptide-like components, such as DNA/RNA.
        if any(
            [
                "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
                for monomer in seq_info
            ]
        ):
            for chain_id in chain_ids:
                valid_chains[chain_id] = seq_info
    return valid_chains
445
+
446
+
447
+ def _is_set(data: str) -> bool:
448
+ """Returns False if data is a special mmCIF character indicating 'unset'."""
449
+ return data not in (".", "?")
450
+
451
+
452
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str,
    _zero_center_positions: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts atom37 coordinates for one chain of a parsed mmCIF.

    Args:
      mmcif_object: A parsed mmCIF structure (output of ``parse``).
      chain_id: Author chain id whose coordinates should be extracted.
      _zero_center_positions: If True, translate the resolved atoms so that
        their mean position is at the origin.

    Returns:
      Tuple of:
        * float32 array of shape [num_res, 37, 3] with atom positions,
        * float32 mask of shape [num_res, 37], 1.0 where an atom is present.

    Raises:
      MultipleChainsError: If the structure does not contain exactly one
        chain with the given id.
    """
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
    relevant_chains = [c for c in chains if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f"Expected exactly one chain in structure with id {chain_id}."
        )
    chain = relevant_chains[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    all_atom_positions = np.zeros(
        [num_res, residue_constants.atom_type_num, 3], dtype=np.float32
    )
    all_atom_mask = np.zeros(
        [num_res, residue_constants.atom_type_num], dtype=np.float32
    )
    for res_index in range(num_res):
        pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            # Biopython residues are addressed by (hetflag, resseq, icode).
            res = chain[
                (
                    res_at_position.hetflag,
                    res_at_position.position.residue_number,
                    res_at_position.position.insertion_code,
                )
            ]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                    mask[residue_constants.atom_order[atom_name]] = 1.0
                elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                    # Put the coords of the selenium atom in the sulphur column
                    pos[residue_constants.atom_order["SD"]] = [x, y, z]
                    mask[residue_constants.atom_order["SD"]] = 1.0

            # Fix naming errors in arginine residues where NH2 is incorrectly
            # assigned to be closer to CD than NH1
            cd = residue_constants.atom_order['CD']
            nh1 = residue_constants.atom_order['NH1']
            nh2 = residue_constants.atom_order['NH2']
            if (
                res.get_resname() == 'ARG' and
                all(mask[atom_index] for atom_index in (cd, nh1, nh2)) and
                (np.linalg.norm(pos[nh1] - pos[cd]) >
                 np.linalg.norm(pos[nh2] - pos[cd]))
            ):
                # Swap both coordinates and masks to keep them consistent.
                pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy()
                mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy()

        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    if _zero_center_positions:
        # Translate so that the mean of all resolved atoms sits at the origin.
        binary_mask = all_atom_mask.astype(bool)
        translation_vec = all_atom_positions[binary_mask].mean(axis=0)
        all_atom_positions[binary_mask] -= translation_vec

    return all_atom_positions, all_atom_mask
PhysDock/data/tools/msa_identifiers.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utilities for extracting identifiers from MSA sequence descriptions."""
16
+
17
+ import dataclasses
18
+ import re
19
+ from typing import Optional
20
+
21
+
22
+ # Sequences coming from UniProtKB database come in the
23
+ # `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
24
+ # or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
25
+ _UNIPROT_PATTERN = re.compile(
26
+ r"""
27
+ ^
28
+ # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
29
+ (?:tr|sp)
30
+ \|
31
+ # A primary accession number of the UniProtKB entry.
32
+ (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
33
+ # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
34
+ (?:_\d)?
35
+ \|
36
+ # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
37
+ # protein ID code.
38
+ (?:[A-Za-z0-9]+)
39
+ _
40
+ # A mnemonic species identification code.
41
+ (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
42
+ # Small BFD uses a final value after an underscore, which we ignore.
43
+ (?:_\d+)?
44
+ $
45
+ """,
46
+ re.VERBOSE,
47
+ )
48
+
49
+
50
@dataclasses.dataclass(frozen=True)
class Identifiers:
    # Mnemonic species code parsed from a UniProtKB entry name
    # (e.g. "FUNHE" from "tr|A0A146SKV9|A0A146SKV9_FUNHE");
    # empty string when no species could be determined.
    species_id: str = ""
53
+
54
+
55
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
    """Extract the species code from a UniProtKB-style sequence identifier.

    The identifier is expected to match _UNIPROT_PATTERN, e.g.
    `tr|A0A146SKV9|A0A146SKV9_FUNHE`.

    Args:
      msa_sequence_identifier: a sequence identifier.

    Returns:
      An `Identifiers` instance. Its species_id is empty when the identifier
      does not match the UniProtKB format.
    """
    match = _UNIPROT_PATTERN.search(msa_sequence_identifier.strip())
    if match is None:
        return Identifiers()
    return Identifiers(species_id=match.group("SpeciesIdentifier"))
73
+
74
+
75
+ def _extract_sequence_identifier(description: str) -> Optional[str]:
76
+ """Extracts sequence identifier from description. Returns None if no match."""
77
+ split_description = description.split()
78
+ if split_description:
79
+ return split_description[0].partition("/")[0]
80
+ else:
81
+ return None
82
+
83
+
84
def get_identifiers(description: str) -> Identifiers:
    """Compute MSA species identifiers from an MSA sequence description line."""
    seq_id = _extract_sequence_identifier(description)
    return Identifiers() if seq_id is None else _parse_sequence_identifier(seq_id)
PhysDock/data/tools/msa_pairing.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Pairing logic for multimer data pipeline."""
16
+
17
+ import collections
18
+ import functools
19
+ import string
20
+ from typing import Any, Dict, Iterable, List, Sequence, Mapping
21
+ import numpy as np
22
+ import pandas as pd
23
+ import scipy.linalg
24
+ from scipy.linalg import block_diag
25
+
26
+ # TODO: This stuff should probably also be in a config
27
+
28
+ MSA_GAP_IDX = 31
29
+ SEQUENCE_GAP_CUTOFF = 0.5
30
+ SEQUENCE_SIMILARITY_CUTOFF = 0.9
31
+
32
+ MAX_MSA_SIZE = 16384
33
+
34
+ MSA_PAD_VALUES = {'msa_all_seq': MSA_GAP_IDX,
35
+ 'msa_mask_all_seq': 1,
36
+ 'deletion_matrix_all_seq': 0,
37
+ 'deletion_matrix_int_all_seq': 0,
38
+ 'msa': MSA_GAP_IDX,
39
+ 'msa_mask': 1,
40
+ 'deletion_matrix': 0,
41
+ 'deletion_matrix_int': 0}
42
+
43
+ MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix')
44
+ SEQ_FEATURES = ('residue_index', 'aatype', 'all_atom_positions',
45
+ 'all_atom_mask', 'seq_mask', 'between_segment_residues',
46
+ 'has_alt_locations', 'has_hetatoms', 'asym_id', 'entity_id',
47
+ 'sym_id', 'entity_mask', 'deletion_mean',
48
+ 'prediction_atom_mask',
49
+ 'literature_positions', 'atom_indices_to_group_indices',
50
+ 'rigid_group_default_frame',
51
+ 'restype', 'token_index', 'token_exists', 's_mask',
52
+ "token_id_to_centre_atom_id",
53
+ "token_id_to_pseudo_beta_atom_id",
54
+ "token_id_to_chunk_sizes",
55
+ "token_id_to_conformer_id",
56
+ "is_protein",
57
+ "is_dna",
58
+ "is_rna",
59
+ "is_ligand",
60
+ "atom_index",
61
+ "atom_id_to_token_id",
62
+ "ref_space_uid",
63
+ "ref_pos",
64
+ "ref_feat",
65
+ "x_gt",
66
+ "x_exists",
67
+ "b_factors",
68
+ "a_mask"
69
+
70
+ )
71
+ TEMPLATE_FEATURES = ('template_aatype', 'template_all_atom_positions',
72
+ 'template_all_atom_mask')
73
+ CHAIN_FEATURES = ('num_alignments', 'seq_length')
74
+
75
+
76
def create_paired_features(
        chains: Iterable[Mapping[str, np.ndarray]],
) -> List[Mapping[str, np.ndarray]]:
    """Returns the original chains with paired NUM_SEQ features.

    Args:
      chains: A list of feature dictionaries for each chain
        (e.g. list(all_chain_features.values())).

    Returns:
      A list of feature dictionaries whose `_all_seq` features contain only
      the rows selected by cross-chain pairing.
    """
    chains = list(chains)
    # A single chain has nothing to pair against.
    if len(chains) < 2:
        return chains

    feature_names = chains[0].keys()
    paired_rows = reorder_paired_rows(pair_sequences(chains))

    updated_chains = []
    for chain_idx, chain in enumerate(chains):
        row_selection = paired_rows[:, chain_idx]
        # Start from all non-paired features, then overwrite the paired ones.
        new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
        for name in feature_names:
            if name.endswith('_all_seq'):
                # Row index -1 selects the padding row appended here.
                padded = pad_features(chain[name], name)
                new_chain[name] = padded[row_selection]
        new_chain['num_alignments_all_seq'] = np.asarray(len(row_selection))
        updated_chains.append(new_chain)
    return updated_chains
114
+
115
+
116
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
    """Add a 'padding' row at the end of the features list.

    The padding row will be selected as a 'paired' row in the case of partial
    alignment - for the chain that doesn't have paired alignment.

    Args:
      feature: The feature to be padded.
      feature_name: The name of the feature to be padded.

    Returns:
      The feature with an additional padding row, or the original array
      unchanged for feature names that need no padding.
    """
    # BUG FIX: `np.string_` was removed in NumPy 2.0; `np.bytes_` is the
    # equivalent alias available in both NumPy 1.x and 2.x.
    assert feature.dtype != np.dtype(np.bytes_)
    if feature_name in ('msa_all_seq',
                        'deletion_matrix_all_seq'):
        num_res = feature.shape[1]
        padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                         feature.dtype)
    elif feature_name == 'msa_species_identifiers_all_seq':
        padding = [b'']
    else:
        # Features without NUM_SEQ leading dimension are returned untouched.
        return feature
    feats_padded = np.concatenate([feature, padding], axis=0)
    return feats_padded
141
+
142
+
143
+ def _make_msa_df(chain_features: Mapping[str, np.ndarray]) -> pd.DataFrame:
144
+ """Makes dataframe with msa features needed for msa pairing."""
145
+ chain_msa = chain_features['msa_all_seq']
146
+ query_seq = chain_msa[0]
147
+
148
+ per_seq_similarity = np.sum(
149
+ query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
150
+ per_seq_gap = np.sum(chain_msa == 31, axis=-1) / float(len(query_seq))
151
+ msa_df = pd.DataFrame({
152
+ 'msa_species_identifiers':
153
+ chain_features['msa_species_identifiers_all_seq'],
154
+ 'msa_row':
155
+ np.arange(len(
156
+ chain_features['msa_species_identifiers_all_seq'])),
157
+ 'msa_similarity': per_seq_similarity,
158
+ 'gap': per_seq_gap
159
+ })
160
+ return msa_df
161
+
162
+
163
+ def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
164
+ """Creates mapping from species to msa dataframe of that species."""
165
+ species_lookup = {}
166
+ for species, species_df in msa_df.groupby('msa_species_identifiers'):
167
+ species_lookup[species] = species_df
168
+ return species_lookup
169
+
170
+
171
+ def _match_rows_by_sequence_similarity(this_species_msa_dfs: List[pd.DataFrame]
172
+ ) -> List[List[int]]:
173
+ """Finds MSA sequence pairings across chains based on sequence similarity.
174
+
175
+ Each chain's MSA sequences are first sorted by their sequence similarity to
176
+ their respective target sequence. The sequences are then paired, starting
177
+ from the sequences most similar to their target sequence.
178
+
179
+ Args:
180
+ this_species_msa_dfs: a list of dataframes containing MSA features for
181
+ sequences for a specific species.
182
+
183
+ Returns:
184
+ A list of lists, each containing M indices corresponding to paired MSA rows,
185
+ where M is the number of chains.
186
+ """
187
+ all_paired_msa_rows = []
188
+
189
+ num_seqs = [len(species_df) for species_df in this_species_msa_dfs
190
+ if species_df is not None]
191
+ take_num_seqs = np.min(num_seqs)
192
+
193
+ sort_by_similarity = (
194
+ lambda x: x.sort_values('msa_similarity', axis=0, ascending=False))
195
+
196
+ for species_df in this_species_msa_dfs:
197
+ if species_df is not None:
198
+ species_df_sorted = sort_by_similarity(species_df)
199
+ msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
200
+ else:
201
+ msa_rows = [-1] * take_num_seqs # take the last 'padding' row
202
+ all_paired_msa_rows.append(msa_rows)
203
+ all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
204
+ return all_paired_msa_rows
205
+
206
+
207
def pair_sequences(
        examples: List[Mapping[str, np.ndarray]],
) -> Dict[int, np.ndarray]:
    """Returns indices for paired MSA sequences across chains.

    Args:
      examples: one feature dict per chain; each must contain 'msa_all_seq'
        and 'msa_species_identifiers_all_seq'.

    Returns:
      Mapping from "number of chains matched" to an array of paired row
      indices (shape [num_pairings, num_chains]); index 0 always pairs the
      query rows of all chains.
    """

    num_examples = len(examples)

    # Per-chain species -> dataframe lookups, plus the union of all species.
    all_chain_species_dict = []
    common_species = set()
    for chain_features in examples:
        msa_df = _make_msa_df(chain_features)
        species_dict = _create_species_dict(msa_df)
        all_chain_species_dict.append(species_dict)
        common_species.update(set(species_dict))

    common_species = sorted(common_species)
    # The query rows carry an empty species id; they are paired explicitly
    # below. NOTE(review): raises ValueError if b'' is absent — assumes every
    # chain's query row has an empty species identifier; confirm upstream.
    common_species.remove(b'')  # Remove target sequence species.

    # NOTE(review): this local is extended but never returned or read again —
    # only all_paired_msa_rows_dict reaches the caller.
    all_paired_msa_rows = [np.zeros(len(examples), int)]

    # Bucket pairings by how many chains actually matched; the all-chains
    # bucket starts with the query-row pairing (row 0 in every chain).
    all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
    all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

    for species in common_species:
        if not species:
            continue
        this_species_msa_dfs = []
        species_dfs_present = 0
        for species_dict in all_chain_species_dict:
            if species in species_dict:
                this_species_msa_dfs.append(species_dict[species])
                species_dfs_present += 1
            else:
                this_species_msa_dfs.append(None)

        # Skip species that are present in only one chain.
        if species_dfs_present <= 1:
            continue

        # Skip species with very deep per-chain MSAs (cost guard).
        if np.any(
                np.array([len(species_df) for species_df in
                          this_species_msa_dfs if
                          isinstance(species_df, pd.DataFrame)]) > 600):
            continue

        paired_msa_rows = _match_rows_by_sequence_similarity(this_species_msa_dfs)
        all_paired_msa_rows.extend(paired_msa_rows)
        all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
    # Convert each bucket's list of row tuples to a single 2-D array.
    all_paired_msa_rows_dict = {
        num_examples: np.array(paired_msa_rows) for
        num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
    }
    return all_paired_msa_rows_dict
260
+
261
+
262
def reorder_paired_rows(all_paired_msa_rows_dict: Dict[int, np.ndarray]
                        ) -> np.ndarray:
    """Creates an array of paired MSA row indices across chains.

    Args:
      all_paired_msa_rows_dict: mapping from the number of paired chains to
        the paired row indices.

    Returns:
      An array of paired row index tuples, ordered by (1) number of chains in
      the pairing, all-chain pairings first, then (2) the magnitude of the
      row-index product within each group.
    """
    ordered_rows = []
    for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
        group = all_paired_msa_rows_dict[num_pairings]
        # Rank pairings inside the group by |product of row indices|.
        sort_keys = abs(np.array([np.prod(rows) for rows in group]))
        ordered_rows.extend(group[np.argsort(sort_keys)])
    return np.array(ordered_rows)
286
+
287
+
288
+ # def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
289
+ # """Like scipy.linalg.block_diag but with an optional padding value."""
290
+ # ones_arrs = [np.ones_like(x) for x in arrs]
291
+ # off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
292
+ # diag = scipy.linalg.block_diag(*arrs)
293
+ # diag += (off_diag_mask * pad_value).astype(diag.dtype)
294
+ # return diag
295
+
296
+
297
+ def _correct_post_merged_feats(
298
+ np_example: Mapping[str, np.ndarray],
299
+ np_chains_list: Sequence[Mapping[str, np.ndarray]],
300
+ pair_msa_sequences: bool
301
+ ) -> Mapping[str, np.ndarray]:
302
+ """Adds features that need to be computed/recomputed post merging."""
303
+ np_example['seq_length'] = np.asarray(
304
+ len(np_example['sequence_3'].split("-")),
305
+ dtype=np.int32
306
+ )
307
+ np_example['num_alignments'] = np.asarray(
308
+ np_example['msa'].shape[0],
309
+ dtype=np.int32
310
+ )
311
+
312
+ return np_example
313
+
314
+
315
def _pad_templates(chains: Sequence[Mapping[str, np.ndarray]],
                   max_templates: int) -> Sequence[Mapping[str, np.ndarray]]:
    """For each chain pad the number of templates to a fixed size.

    Args:
      chains: A list of protein chains.
      max_templates: Each chain will be padded to have this many templates.

    Returns:
      The same list of chains, with template features padded (in place) along
      the leading template axis to max_templates.
    """
    for chain in chains:
        for name, value in chain.items():
            if name not in TEMPLATE_FEATURES:
                continue
            # Pad only axis 0 (template count); every other axis is kept.
            pad_widths = [(0, 0)] * value.ndim
            pad_widths[0] = (0, max_templates - value.shape[0])
            chain[name] = np.pad(value, pad_widths, mode='constant')
    return chains
335
+
336
+
337
def _merge_features_from_multiple_chains(
        chains: Sequence[Mapping[str, np.ndarray]],
        pair_msa_sequences: bool) -> Mapping[str, np.ndarray]:
    """Merge features from multiple chains.

    Args:
      chains: A list of feature dictionaries that we want to merge.
      pair_msa_sequences: Whether to concatenate MSA features along the
        num_res dimension (if True), or to block diagonalize them (if False).

    Returns:
      A feature dictionary for the merged example.
    """
    merged_example = {}

    for feature_name in chains[0]:
        # These are recomputed (or dropped) after merging.
        if feature_name in ("msa_species_identifiers",
                            "num_alignments_all_seq",
                            "num_alignments"):
            continue
        feats = [x[feature_name] for x in chains]

        feature_name_split = feature_name.split('_all_seq')[0]
        if feature_name_split in MSA_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split == "templ_feat":
            # Block-diagonal merge: each chain's pairwise template features go
            # on the diagonal; off-diagonal (cross-chain) entries stay zero.
            num_templ = feats[0].shape[0]
            total_len = sum(feat.shape[1] for feat in feats)
            out_mat = np.zeros([num_templ, total_len, total_len, 108],
                               dtype=np.float32)
            start = 0
            end = 0
            for feat in feats:
                end += feat.shape[1]
                out_mat[:, start:end, start:end] = feat
                start = end
            merged_example[feature_name] = out_mat
        elif feature_name_split in SEQ_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=0)
        elif feature_name_split in TEMPLATE_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split in CHAIN_FEATURES:
            # BUG FIX: the original called np.sum on a generator
            # (np.sum(x for x in feats)), which is deprecated in NumPy and
            # slated to raise; summing the list directly is equivalent.
            merged_example[feature_name] = np.sum(feats).astype(np.int32)
        else:
            # Identical across chains by construction; keep the first copy.
            merged_example[feature_name] = feats[0]

    return merged_example
382
+
383
+
384
def _merge_homomers_dense_msa(
        chains: Iterable[Mapping[str, np.ndarray]]) -> Sequence[Mapping[str, np.ndarray]]:
    """Merge all identical chains, making the resulting MSA dense.

    Args:
      chains: An iterable of features for each chain.

    Returns:
      A list of feature dictionaries, one per entity_id: chains sharing an
      entity_id are merged with their MSA features concatenated along the
      num_res dimension (dense).
    """
    by_entity = collections.defaultdict(list)
    for chain in chains:
        by_entity[chain['entity_id'][0]].append(chain)

    return [
        _merge_features_from_multiple_chains(by_entity[entity_id],
                                             pair_msa_sequences=True)
        for entity_id in sorted(by_entity)
    ]
411
+
412
+
413
def _concatenate_paired_and_unpaired_features(
        example: Mapping[str, np.ndarray]) -> Mapping[str, np.ndarray]:
    """Stacks paired ('_all_seq') MSA rows on top of the unpaired rows.

    Mutates and returns *example*: for each MSA feature present in both
    paired and unpaired form, the merged array replaces the unpaired one and
    the '_all_seq' entry is removed.
    """
    for feature_name in MSA_FEATURES:
        paired_name = feature_name + '_all_seq'
        if feature_name not in example or paired_name not in example:
            continue
        # Paired rows come first.
        example[feature_name] = np.concatenate(
            [example[paired_name], example[feature_name]], axis=0)
        example.pop(paired_name, None)
    return example
430
+
431
+
432
def merge_chain_features(np_chains_list: List[Mapping[str, np.ndarray]],
                         pair_msa_sequences: bool,
                         max_templates: int) -> Mapping[str, np.ndarray]:
    """Merges features for multiple chains to single FeatureDict.

    Args:
      np_chains_list: List of FeatureDicts for each chain.
      pair_msa_sequences: Whether to merge paired MSAs.
      max_templates: The maximum number of templates to include.

    Returns:
      Single FeatureDict for entire complex.
    """
    # Equalize the template count across chains first.
    np_chains_list = _pad_templates(np_chains_list,
                                    max_templates=max_templates)
    # Densify per-entity (homomer) MSAs before cross-chain merging.
    np_chains_list = _merge_homomers_dense_msa(np_chains_list)
    # Stack each chain's paired rows on top of its unpaired rows.
    np_chains_list = [
        _concatenate_paired_and_unpaired_features(np_chain)
        for np_chain in np_chains_list
    ]

    assert len(np_chains_list) > 0, f"np_chain_list, error, {np_chains_list}"
    np_example = _merge_features_from_multiple_chains(
        np_chains_list, pair_msa_sequences=False)

    return _correct_post_merged_feats(
        np_example=np_example,
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences)
472
+
473
+
474
def deduplicate_unpaired_sequences(
        np_chains: List[Mapping[str, np.ndarray]]) -> List[Mapping[str, np.ndarray]]:
    """Removes unpaired sequences which duplicate a paired sequence.

    Mutates each chain dict in place and returns the list. 'num_alignments'
    is refreshed from the (possibly filtered) 'msa' feature.
    """
    feature_names = np_chains[0].keys()
    msa_features = MSA_FEATURES

    for chain in np_chains:
        # Convert each paired MSA row to a tuple so it is hashable.
        sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
        keep_rows = []
        # Keep only unpaired rows that do not duplicate a paired sequence.
        for row_num, seq in enumerate(chain['msa']):
            if tuple(seq) not in sequence_set:
                keep_rows.append(row_num)
        # BUG FIX: the original tested `keep_rows is not None`, which is
        # always true for a list; when every unpaired row was a duplicate
        # this sliced the MSA features down to zero rows. Only filter when
        # at least one row survives (matches upstream AlphaFold behavior).
        if keep_rows:
            for feature_name in feature_names:
                if feature_name in msa_features:
                    chain[feature_name] = chain[feature_name][keep_rows]
        chain['num_alignments'] = np.array(chain['msa'].shape[0],
                                           dtype=np.int32)

    return np_chains
PhysDock/data/tools/nhmmer.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Library to run Jackhmmer from Python."""
17
+
18
+ from concurrent import futures
19
+ import glob
20
+ import logging
21
+ import os
22
+ import subprocess
23
+ from typing import Any, Callable, Mapping, Optional, Sequence
24
+ from urllib import request
25
+
26
+ from . import parsers
27
+ from . import utils
28
+
29
+
30
class Nhmmer:
    """Python wrapper of the nhmmer binary (HMMER nucleotide search).

    Adapted from the Jackhmmer wrapper; several log/error messages still
    mention "Jackhmmer".
    """

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,  # --cpu
        e_value: float = 0.001,  # -E --incE | Z-value is not used in AF3
        filter_f3: float = 0.00005,  # --F3; 0.02 for sequences shorter than 50 nucleotides
        get_tblout: bool = False,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Nhmmer wrapper.

        Args:
          binary_path: The path to the nhmmer executable.
          database_path: The path to the nucleotide database (FASTA format).
          n_cpu: The number of CPUs to give nhmmer.
          e_value: E-value threshold, used for both -E and --incE.
          filter_f3: Forward pre-filter threshold (--F3), set to >1.0 to turn
            off. NOTE: _query_chunk overrides this per query based on the
            query sequence length.
          get_tblout: Whether to save tblout string.
          incdom_e: Domain e-value criteria for inclusion of domains in
            MSA/next round (--incdomE).
          dom_e: Domain e-value criteria for inclusion in tblout (--domE).
          num_streamed_chunks: Number of database chunks to stream over.
          streaming_callback: Callback function run after each chunk iteration
            with the iteration number as argument.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # The database only has to exist locally when it is not streamed in
        # remote chunks.
        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.e_value = e_value
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self,
        input_fasta_path: str,
        database_path: str,
        max_sequences: Optional[int] = None
    ) -> Mapping[str, Any]:
        """Queries a single database chunk with nhmmer.

        Returns a dict with keys 'sto' (Stockholm alignment), 'tbl'
        (tblout text, empty unless get_tblout), 'stderr' and 'e_value'.
        """

        with open(input_fasta_path, "r") as f:
            # `desc` is parsed but unused.
            sequences, desc = parsers.parse_fasta(f.read())
        assert len(sequences) == 1, f"Parse Fasta File with only 1 Sequence, but found {len(sequences)}"
        # Relax the forward filter for very short queries.
        # NOTE(review): mutating self.filter_f3 makes the instance stateful
        # per query and not thread-safe — consider a local variable.
        if len(sequences[0]) < 50:
            self.filter_f3 = 0.02
        else:
            self.filter_f3 = 0.00005
        with utils.tmpdir_manager() as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the
            # filtering stages (which get progressively more expensive);
            # reducing these speeds up the pipeline at the expense of
            # sensitivity.
            cmd_flags = [
                # Don't pollute stdout with nhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--incE",
                str(self.e_value),
                "--rna",
                "--watson",
                "--F3",  # Only F3 is used
                str(self.filter_f3),
                "--cpu",
                str(self.n_cpu),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )
            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Nhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Nhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            # Optionally truncate the alignment to cap memory use downstream.
            if (max_sequences is None):
                with open(sto_path) as f:
                    sto = f.read()
            else:
                sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                e_value=self.e_value,
            )

            return raw_output

    def query(self,
              input_fasta_path: str,
              max_sequences: Optional[int] = None
              ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Queries the database with a single FASTA file; see query_multiple."""
        return self.query_multiple([input_fasta_path], max_sequences)

    def query_multiple(self,
                       input_fasta_paths: Sequence[str],
                       max_sequences: Optional[int] = None
                       ) -> Sequence[Sequence[Mapping[str, Any]]]:
        """Queries the database using Nhmmer for every input FASTA path."""
        # Non-streamed case: one result dict per input file.
        if self.num_streamed_chunks is None:
            single_chunk_results = []
            for input_fasta_path in input_fasta_paths:
                single_chunk_result = self._query_chunk(
                    input_fasta_path, self.database_path, max_sequences,
                )
                single_chunk_results.append(single_chunk_result)
            return single_chunk_results

        db_basename = os.path.basename(self.database_path)
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while nhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_outputs = [[] for _ in range(len(input_fasta_paths))]
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run nhmmer with the chunk once its download has finished.
                future.result()
                for fasta_idx, input_fasta_path in enumerate(input_fasta_paths):
                    chunked_outputs[fasta_idx].append(
                        self._query_chunk(
                            input_fasta_path,
                            db_local_chunk(i),
                            max_sequences
                        )
                    )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # Do not set next_future for the last chunk so that this works
                # even for databases with only 1 chunk
                if (i < self.num_streamed_chunks):
                    future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_outputs
249
+
250
+
251
if __name__ == '__main__':
    # Ad-hoc smoke test with site-specific database paths — not portable
    # outside the original cluster environment.
    nhmmer = Nhmmer(binary_path="/usr/bin/nhmmer",
                    # database_path="/group1/share01/data/alphafold3/rnacentral/v21.0/rnacentral.fasta")
                    database_path="/group1/share01/data/alphafold3/rfam/v14.9/Rfam_af3_clustered_all_seqs.fasta")
    out = nhmmer.query(input_fasta_path="./test.fa")
    print(out[0]["sto"])
PhysDock/data/tools/parse_msas.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import numpy as np
3
+ from typing import Optional, Sequence, Dict, OrderedDict, Any, Union
4
+ from scipy.sparse import coo_matrix
5
+
6
+ from .parsers import parse_fasta, parse_hhr, parse_stockholm, parse_a3m, parse_hmmsearch_a3m, \
7
+ parse_hmmsearch_sto, Msa, parse_stockholm_file, Msa
8
+ from . import msa_identifiers
9
+
10
# Feature-dictionary alias: values may be dense numpy arrays, scipy COO
# sparse matrices, None, or arbitrary objects.
FeatureDict = Dict[str, Union[np.ndarray, coo_matrix, None, Any]]
11
+
12
+
13
def load_txt(fname):
    """Read a text file and return its entire contents as one string."""
    with open(fname, "r") as fh:
        return fh.read()
17
+
18
+
19
# Three-letter codes of the 20 standard amino acids plus UNK; the position
# of each code in this list is its integer residue-type ID.
amino_acids = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
               "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]

# Folds the extended HHblits amino-acid alphabet into canonical one-letter
# codes: ambiguity codes B->D, Z->E; selenocysteine U->C; rare/unknown
# letters J and O -> X; the gap "-" maps to itself.
HHBLITS_AA_TO_AA = {
    "A": "A",
    "B": "D",
    "C": "C",
    "D": "D",
    "E": "E",
    "F": "F",
    "G": "G",
    "H": "H",
    "I": "I",
    "J": "X",
    "K": "K",
    "L": "L",
    "M": "M",
    "N": "N",
    "O": "X",
    "P": "P",
    "Q": "Q",
    "R": "R",
    "S": "S",
    "T": "T",
    "U": "C",
    "V": "V",
    "W": "W",
    "X": "X",
    "Y": "Y",
    "Z": "E",
    "-": "-",
}
# Same content as `amino_acids`; used as the iteration order when building
# AA_TO_ID below.
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
# One-letter -> three-letter amino-acid code mapping (X -> UNK).
amino_acid_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
    "X": "UNK",
}

# Inverse mapping: three-letter -> one-letter code.
amino_acid_3to1 = {v: k for k, v in amino_acid_1to3.items()}

# One-letter amino-acid code -> integer residue-type ID (index in `amino_acids`).
AA_TO_ID = {
    amino_acid_3to1[ccd]: amino_acids.index(ccd) for ccd in standard_protein
}
# Gap token. 31 is presumably the shared gap ID of the global residue-type
# vocabulary (also used by RNA_TO_ID below) - TODO confirm against the
# project's restype constants.
AA_TO_ID["-"] = 31

# RNA bases plus unknown "N"; IDs are offset by 21 so they follow the
# 21 protein residue types.
robon_nucleic_acids = ["A", "G", "C", "U", "N", ]

RNA_TO_ID = {ch: robon_nucleic_acids.index(ch) + 21 for ch in robon_nucleic_acids}
RNA_TO_ID["-"] = 31


# DEBUG
# RNA_TO_ID["."] = 31
92
+
93
+
94
def make_msa_features(msas: Sequence[Msa], is_rna=False) -> FeatureDict:
    """Constructs a feature dict of MSA features.

    Sequences are deduplicated across all provided MSAs (first occurrence
    wins) and encoded into integer residue IDs; per-row deletion counts and
    species identifiers are collected alongside.
    """
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    int_msa = []
    deletion_matrix = []
    species_ids = []
    seen_sequences = set()

    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(f"MSA {msa_index} must contain at least one sequence.")
        for row_idx, (seq, row_deletions) in enumerate(
                zip(msa.sequences, msa.deletion_matrix)):
            if seq in seen_sequences:
                continue
            seen_sequences.add(seq)
            if is_rna:
                # Unknown RNA characters fall back to the generic "N" base ID.
                encoded = [RNA_TO_ID.get(res, RNA_TO_ID["N"]) for res in seq]
            else:
                encoded = [AA_TO_ID[HHBLITS_AA_TO_AA[res]] for res in seq]
            int_msa.append(encoded)
            deletion_matrix.append(row_deletions)
            identifiers = msa_identifiers.get_identifiers(
                msa.descriptions[row_idx]
            )
            species_ids.append(identifiers.species_id.encode("utf-8"))

    return {
        "deletion_matrix": np.array(deletion_matrix, dtype=np.int8),
        "msa": np.array(int_msa, dtype=np.int8),
        "msa_species_identifiers": np.array(species_ids, dtype=np.object_),
    }
134
+
135
+
136
def parse_alignment_dir(
        alignment_dir,
):
    """Parse every known MSA search result found in ``alignment_dir``.

    Protein hits populate "msa"/"deletion_matrix"/"msa_species_identifiers";
    UniProt hits populate the "_all_seq" variants; RNA hits (only allowed
    when no protein MSAs are present) populate the plain keys with RNA IDs.
    """
    # MSA Order: uniref90 bfd_uniclust30/bfd_uniref30 mgnify

    def _sto(name):
        path = os.path.join(alignment_dir, name)
        return parse_stockholm(load_txt(path)) if os.path.exists(path) else None

    def _a3m(name):
        path = os.path.join(alignment_dir, name)
        return parse_a3m(load_txt(path)) if os.path.exists(path) else None

    def _sto_file(name):
        path = os.path.join(alignment_dir, name)
        return parse_stockholm_file(path) if os.path.exists(path) else None

    protein_msas = [
        _sto("uniref90_hits.sto"),
        _a3m("bfd_uniclust30_hits.a3m"),
        _a3m("bfd_uniref30_hits.a3m"),
        _sto("reduced_bfd_hits.sto"),
        _sto("mgnify_hits.sto"),
    ]
    uniprot_msas = [_sto("uniprot_hits.sto")]
    rna_msas = [
        _sto_file("rfam_hits2.sto"),
        _sto_file("rnacentral_hits.sto"),
        _sto_file("nt_hits.sto"),
    ]

    protein_msas = [m for m in protein_msas if m is not None]
    uniprot_msas = [m for m in uniprot_msas if m is not None]
    rna_msas = [m for m in rna_msas if m is not None]

    output = dict()
    if len(uniprot_msas) > 0:
        uniprot_msa_features = make_msa_features(uniprot_msas)
        output["msa_all_seq"] = uniprot_msa_features.pop("msa")
        output["deletion_matrix_all_seq"] = uniprot_msa_features.pop("deletion_matrix")
        output["msa_species_identifiers_all_seq"] = uniprot_msa_features.pop("msa_species_identifiers")
    if len(protein_msas) > 0:
        msa_features = make_msa_features(protein_msas)
        output["msa"] = msa_features.pop("msa")
        output["deletion_matrix"] = msa_features.pop("deletion_matrix")
        output["msa_species_identifiers"] = msa_features.pop("msa_species_identifiers")

    if len(rna_msas) > 0:
        # A chain is either protein or RNA; both at once is a caller error.
        assert len(protein_msas) == 0
        msa_features = make_msa_features(rna_msas, is_rna=True)
        output["msa"] = msa_features.pop("msa")
        output["deletion_matrix"] = msa_features.pop("deletion_matrix")
        output["msa_species_identifiers"] = msa_features.pop("msa_species_identifiers")

    return output
214
+
215
+
216
def parse_protein_alignment_dir(alignment_dir):
    """Parse only the protein MSA search results from ``alignment_dir``.

    Returns a dict with "msa", "deletion_matrix" and
    "msa_species_identifiers" keys, or an empty dict when no hit file exists.
    """
    # MSA Order: uniref90 bfd_uniclust30/bfd_uniref30 mgnify

    def _sto(name):
        path = os.path.join(alignment_dir, name)
        return parse_stockholm(load_txt(path)) if os.path.exists(path) else None

    def _a3m(name):
        path = os.path.join(alignment_dir, name)
        return parse_a3m(load_txt(path)) if os.path.exists(path) else None

    protein_msas = [
        _sto("uniref90_hits.sto"),
        _a3m("bfd_uniclust30_hits.a3m"),
        _a3m("bfd_uniref_hits.a3m"),
        _sto("reduced_bfd_hits.sto"),
        _sto("mgnify_hits.sto"),
    ]
    protein_msas = [m for m in protein_msas if m is not None]

    output = dict()
    if len(protein_msas) > 0:
        msa_features = make_msa_features(protein_msas)
        output["msa"] = msa_features.pop("msa")
        output["deletion_matrix"] = msa_features.pop("deletion_matrix")
        output["msa_species_identifiers"] = msa_features.pop("msa_species_identifiers")

    return output
253
+
254
+
255
def parse_uniprot_alignment_dir(
        alignment_dir,
):
    """Parse uniprot_hits.sto (if present) into paired-MSA "_all_seq" features."""
    output = dict()
    uniprot_out_path = os.path.join(alignment_dir, "uniprot_hits.sto")
    if os.path.exists(uniprot_out_path):
        feats = make_msa_features([parse_stockholm(load_txt(uniprot_out_path))])
        output["msa_all_seq"] = feats.pop("msa")
        output["deletion_matrix_all_seq"] = feats.pop("deletion_matrix")
        output["msa_species_identifiers_all_seq"] = feats.pop("msa_species_identifiers")
    return output
271
+
272
+
273
def parse_rna_from_input_fasta_path(input_fasta_path):
    """Build an Msa from a raw query FASTA file (unaligned, zero deletions).

    Args:
        input_fasta_path: Path to a FASTA file; each record becomes one MSA row.

    Returns:
        An Msa whose rows are the FASTA sequences with all-zero deletion rows.

    Bug fix: the previous version built a single deletion row
    (``[[0] * len(query_sequence[0])]``) regardless of how many records the
    FASTA contained, so any multi-record file made ``Msa.__post_init__``
    raise its length-mismatch ValueError. One all-zero row is now built per
    sequence; single-record behavior is unchanged.
    """
    with open(input_fasta_path, "r") as f:
        sequences, descriptions = parse_fasta(f.read())
    # One all-zero deletion row per sequence: the raw query is unaligned.
    deletion_matrix = [[0] * len(seq) for seq in sequences]

    return Msa(
        sequences=sequences,
        deletion_matrix=deletion_matrix,
        descriptions=descriptions,
    )
284
+
285
+
286
def parse_rna_single_alignment(input_fasta_path):
    """Encode just the query FASTA as a single-row RNA MSA feature dict."""
    feats = make_msa_features(
        [parse_rna_from_input_fasta_path(input_fasta_path)], is_rna=True
    )
    return {
        "msa": feats.pop("msa"),
        "deletion_matrix": feats.pop("deletion_matrix"),
    }
294
+
295
+
296
def parse_rna_alignment_dir(
        alignment_dir,
        input_fasta_path,
):
    """Collect the query plus realigned RNA search hits into MSA features."""

    def _load(name):
        path = os.path.join(alignment_dir, name)
        return parse_stockholm_file(path) if os.path.exists(path) else None

    candidates = [
        parse_rna_from_input_fasta_path(input_fasta_path),
        _load("rfam_hits_realigned.sto"),
        _load("rnacentral_hits_realigned.sto"),
        _load("nt_hits_realigned.sto"),
    ]
    # Drop hit files that are missing or parsed to an empty alignment.
    rna_msas = [m for m in candidates if m is not None and len(m) > 0]

    feats = make_msa_features(rna_msas, is_rna=True)
    return {
        "msa": feats.pop("msa"),
        "deletion_matrix": feats.pop("deletion_matrix"),
    }
PhysDock/data/tools/parsers.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Functions for parsing various file formats."""
17
+ import collections
18
+ import dataclasses
19
+ import itertools
20
+ import re
21
+ import string
22
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set
23
+
24
+ DeletionMatrix = Sequence[Sequence[int]]
25
+
26
+
27
@dataclasses.dataclass(frozen=True)
class Msa:
    """Class representing a parsed MSA file"""
    sequences: Sequence[str]
    deletion_matrix: DeletionMatrix
    descriptions: Optional[Sequence[str]]

    def __post_init__(self):
        # The three parallel lists must describe the same number of rows.
        row_counts = {
            len(self.sequences),
            len(self.deletion_matrix),
            len(self.descriptions),
        }
        if len(row_counts) != 1:
            raise ValueError(
                "All fields for an MSA must have the same length"
            )

    def __len__(self):
        return len(self.sequences)

    def truncate(self, max_seqs: int):
        """Return a copy keeping only the first ``max_seqs`` rows."""
        return Msa(
            sequences=self.sequences[:max_seqs],
            deletion_matrix=self.deletion_matrix[:max_seqs],
            descriptions=self.descriptions[:max_seqs],
        )
53
+
54
+
55
@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """Class representing a template hit (one aligned template from an .hhr file)."""

    # Rank of the hit as reported by the search tool (parsed from the "No N" line).
    index: int
    # Identifier line of the hit template.
    name: str
    # "Aligned_cols" value from the hit summary line.
    aligned_cols: int
    # "Sum_probs" value from the hit summary line, if present.
    sum_probs: Optional[float]
    # Query sequence as aligned in this hit (may contain "-" gaps).
    query: str
    # Template sequence as aligned to the query (may contain "-" gaps).
    hit_sequence: str
    # Per alignment column: 0-based index into the original query sequence,
    # or -1 where the query has a gap.
    indices_query: List[int]
    # Per alignment column: 0-based index into the template sequence,
    # or -1 where the template has a gap.
    indices_hit: List[int]
67
+
68
+
69
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    """Split a FASTA string into sequences and their descriptions.

    Arguments:
      fasta_string: The string contents of a FASTA file.

    Returns:
      A tuple (sequences, descriptions) in file order; descriptions are the
      header lines with the leading '>' removed. Blank lines and lines
      starting with '#' are ignored.
    """
    sequences: List[str] = []
    descriptions: List[str] = []
    for raw_line in fasta_string.splitlines():
        stripped = raw_line.strip()
        if stripped.startswith(">"):
            descriptions.append(stripped[1:])  # Drop the leading '>'.
            sequences.append("")
        elif not stripped or stripped.startswith("#"):
            continue  # Skip blanks and comment lines.
        else:
            # Continuation of the most recent record's sequence.
            sequences[-1] += stripped

    return sequences, descriptions
98
+
99
+
100
+
101
+
102
def parse_stockholm(stockholm_string: str) -> Msa:
    """Parse a Stockholm-format alignment string into an Msa.

    The first alignment row is treated as the query. Columns that are gaps
    in the query are dropped from every row, and each row's deletion vector
    records how many residues were dropped before each kept column.

    Args:
      stockholm_string: Contents of a Stockholm file; the first sequence is
        assumed to be the query.

    Returns:
      An Msa of query-aligned sequences (may contain duplicates), their
      deletion matrix, and the matched target names.
    """
    name_to_sequence = collections.OrderedDict()
    for raw_line in stockholm_string.splitlines():
        raw_line = raw_line.strip()
        # Skip blanks, markup ("#...") and the terminator ("//").
        if not raw_line or raw_line.startswith(("#", "//")):
            continue
        name, chunk = raw_line.split()
        name_to_sequence[name] = name_to_sequence.get(name, "") + chunk

    msa = []
    deletion_matrix = []

    query = ""
    keep_columns = []
    for row_idx, aligned in enumerate(name_to_sequence.values()):
        if row_idx == 0:
            # The query defines which columns survive.
            query = aligned
            keep_columns = [c for c, res in enumerate(query) if res != "-"]

        msa.append("".join(aligned[c] for c in keep_columns))

        # Deletion count = residues of this row falling in query-gap columns
        # immediately before each kept column.
        deletions = []
        running = 0
        for res, query_res in zip(aligned, query):
            if res == "-" and query_res == "-":
                continue
            if query_res == "-":
                running += 1
            else:
                deletions.append(running)
                running = 0
        deletion_matrix.append(deletions)

    return Msa(
        sequences=msa,
        deletion_matrix=deletion_matrix,
        descriptions=list(name_to_sequence.keys()),
    )
162
+
163
+
164
def parse_stockholm_file(stockholm_file: str) -> Msa:
    """Parses sequences and deletion matrix from a Stockholm-format file.

    Unlike `parse_stockholm`, this reads from a file path (streaming line by
    line) and strips lowercase insertion characters plus "*" and "." from
    every alignment row before accumulating it.

    Args:
      stockholm_file: Path to a Stockholm file. The first sequence in the
        file should be the query sequence.

    Returns:
      An Msa containing:
      * A list of sequences that have been aligned to the query. These
        might contain duplicates.
      * The deletion matrix for the alignment as a list of lists. The element
        at `deletion_matrix[i][j]` is the number of residues deleted from
        the aligned sequence i at residue position j.
      * The names of the targets matched.
    """
    name_to_sequence = collections.OrderedDict()
    with open(stockholm_file, "r") as f:
        for line in f:
            line = line.strip()

            # Skip blanks, markup ("#...") and the terminator ("//").
            if not line or line.startswith(("#", "//")):
                continue
            name, sequence = line.split()
            # Drop lowercase insertion states and "*"/"." markers so rows
            # align on match columns only.
            sequence = "".join([c for c in sequence if not c.islower() and c not in ["*","."]])
            if name not in name_to_sequence:
                name_to_sequence[name] = ""
            name_to_sequence[name] += sequence

    msa = []
    deletion_matrix = []

    query = ""
    keep_columns = []
    for seq_index, sequence in enumerate(name_to_sequence.values()):
        if seq_index == 0:
            # Gather the columns with gaps from the query
            query = sequence
            keep_columns = [i for i, res in enumerate(query) if res != "-"]

        # Remove the columns with gaps in the query from all sequences.
        aligned_sequence = "".join([sequence[c] for c in keep_columns])

        msa.append(aligned_sequence)

        # Count the number of deletions w.r.t. query.
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res != "-" or query_res != "-":
                if query_res == "-":
                    deletion_count += 1
                else:
                    deletion_vec.append(deletion_count)
                    deletion_count = 0
        deletion_matrix.append(deletion_vec)
    return Msa(
        sequences=msa,
        deletion_matrix=deletion_matrix,
        descriptions=list(name_to_sequence.keys())
    )
226
+
227
+
228
def parse_a3m(a3m_string: str) -> Msa:
    """Parse an A3M-format alignment string into an Msa.

    Lowercase letters mark insertions relative to the query; each run of
    lowercase characters is counted into the deletion matrix and then
    stripped from the sequences.

    Args:
      a3m_string: Contents of an a3m file; the first sequence is the query.

    Returns:
      An Msa of aligned (deletion-free) sequences, their deletion matrix,
      and the FASTA descriptions.
    """
    sequences, descriptions = parse_fasta(a3m_string)

    deletion_matrix = []
    for msa_sequence in sequences:
        row = []
        run = 0
        for ch in msa_sequence:
            if ch.islower():
                run += 1
            else:
                row.append(run)
                run = 0
        deletion_matrix.append(row)

    # Strip lowercase (insertion) characters to get the aligned rows.
    lowercase_remover = str.maketrans("", "", string.ascii_lowercase)
    return Msa(
        sequences=[s.translate(lowercase_remover) for s in sequences],
        deletion_matrix=deletion_matrix,
        descriptions=descriptions,
    )
264
+
265
+
266
+ def _convert_sto_seq_to_a3m(
267
+ query_non_gaps: Sequence[bool], sto_seq: str
268
+ ) -> Iterable[str]:
269
+ for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
270
+ if is_query_res_non_gap:
271
+ yield sequence_res
272
+ elif sequence_res != "-":
273
+ yield sequence_res.lower()
274
+
275
+
276
def convert_stockholm_to_a3m(
    stockholm_format: str,
    max_sequences: Optional[int] = None,
    remove_first_row_gaps: bool = True,
) -> str:
    """Converts MSA in Stockholm format to the A3M format.

    Args:
        stockholm_format: Contents of a Stockholm file; the first alignment
            row is assumed to be the query.
        max_sequences: If set, keep at most this many distinct sequences.
        remove_first_row_gaps: If True, columns that are gaps in the query
            are removed from match state; residues of other rows in those
            columns are lowercased (the a3m insertion convention).

    Returns:
        The alignment as an a3m-format string with a terminating newline.
    """
    descriptions = {}
    sequences = {}
    reached_max_sequences = False

    # First pass: accumulate aligned sequence chunks per name.
    for line in stockholm_format.splitlines():
        reached_max_sequences = (
            max_sequences and len(sequences) >= max_sequences
        )
        if line.strip() and not line.startswith(("#", "//")):
            # Ignore blank lines, markup and end symbols - remainder are alignment
            # sequence parts.
            seqname, aligned_seq = line.split(maxsplit=1)
            if seqname not in sequences:
                if reached_max_sequences:
                    continue
                sequences[seqname] = ""
            sequences[seqname] += aligned_seq

    # Second pass: pick up "#=GS ... DE ..." description rows.
    # NOTE(review): `reached_max_sequences` below still holds whatever was
    # computed on the *last* line of the first pass - presumably intended to
    # skip descriptions of sequences dropped by the cap; confirm.
    for line in stockholm_format.splitlines():
        if line[:4] == "#=GS":
            # Description row - example format is:
            # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ""
            if feature != "DE":
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
            descriptions[seqname] = value
            # Every retained sequence already has a description - stop early.
            if len(descriptions) == len(sequences):
                break

    # Convert sto format to a3m line by line
    a3m_sequences = {}
    if (remove_first_row_gaps):
        # query_sequence is assumed to be the first sequence
        query_sequence = next(iter(sequences.values()))
        query_non_gaps = [res != "-" for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
        # Dots are optional in a3m format and are commonly removed.
        out_sequence = sto_sequence.replace('.', '')
        if (remove_first_row_gaps):
            out_sequence = ''.join(
                _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)
            )
        a3m_sequences[seqname] = out_sequence

    fasta_chunks = (
        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
        for k in a3m_sequences
    )
    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.
335
+
336
+
337
+ def _keep_line(line: str, seqnames: Set[str]) -> bool:
338
+ """Function to decide which lines to keep."""
339
+ if not line.strip():
340
+ return True
341
+ if line.strip() == '//': # End tag
342
+ return True
343
+ if line.startswith('# STOCKHOLM'): # Start tag
344
+ return True
345
+ if line.startswith('#=GC RF'): # Reference Annotation Line
346
+ return True
347
+ if line[:4] == '#=GS': # Description lines - keep if sequence in list.
348
+ _, seqname, _ = line.split(maxsplit=2)
349
+ return seqname in seqnames
350
+ elif line.startswith('#'): # Other markup - filter out
351
+ return False
352
+ else: # Alignment data - keep if sequence in list.
353
+ seqname = line.partition(' ')[0]
354
+ return seqname in seqnames
355
+
356
+
357
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
    """Read and truncate a Stockholm file to at most ``max_sequences``
    sequences while preventing excessive RAM usage."""
    keep_names = set()
    with open(stockholm_msa_path) as fh:
        # Pass 1: collect the first `max_sequences` distinct sequence names.
        for line in fh:
            if line.strip() and not line.startswith(('#', '//')):
                keep_names.add(line.partition(' ')[0])
                if len(keep_names) >= max_sequences:
                    break

        # Pass 2: re-read, keeping only lines relevant to those names.
        fh.seek(0)
        kept = [line for line in fh if _keep_line(line, keep_names)]

    return ''.join(kept)
378
+
379
+
380
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
    """Removes empty columns (dashes-only) from a Stockholm MSA.

    The alignment may be wrapped into several chunks, each terminated by a
    "#=GC RF" reference-annotation line; each chunk is masked independently.
    Lines are keyed by their original index so the output preserves order.
    """
    processed_lines = {}
    unprocessed_lines = {}
    for i, line in enumerate(stockholm_msa.splitlines()):
        if line.startswith('#=GC RF'):
            reference_annotation_i = i
            reference_annotation_line = line
            # Reached the end of this chunk of the alignment. Process chunk.
            _, _, first_alignment = line.rpartition(' ')
            mask = []
            # A column is kept when any buffered row has a non-dash in it.
            for j in range(len(first_alignment)):
                for _, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    if alignment[j] != '-':
                        mask.append(True)
                        break
                else:  # Every row contained a hyphen - empty column.
                    mask.append(False)
            # Add reference annotation for processing with mask.
            unprocessed_lines[reference_annotation_i] = reference_annotation_line

            if not any(mask):  # All columns were empty. Output empty lines for chunk.
                for line_index in unprocessed_lines:
                    processed_lines[line_index] = ''
            else:
                for line_index, unprocessed_line in unprocessed_lines.items():
                    prefix, _, alignment = unprocessed_line.rpartition(' ')
                    masked_alignment = ''.join(itertools.compress(alignment, mask))
                    processed_lines[line_index] = f'{prefix} {masked_alignment}'

            # Clear raw_alignments.
            unprocessed_lines = {}
        elif line.strip() and not line.startswith(('#', '//')):
            # Alignment row: buffer it until the chunk's "#=GC RF" line arrives.
            unprocessed_lines[i] = line
        else:
            # Markup / blank / terminator lines pass through unchanged.
            processed_lines[i] = line
    return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))
418
+
419
+
420
def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
    """Remove duplicate sequences (ignoring insertions wrt query)."""
    sequence_dict = collections.defaultdict(str)

    # Accumulate the full alignment string per sequence name, skipping
    # markup, blank lines and the "//" terminator.
    for line in stockholm_msa.splitlines():
        if line.strip() and not line.startswith(('#', '//')):
            seqname, alignment = line.strip().split()
            sequence_dict[seqname] += alignment

    # The first row is the query; False marks insertion columns.
    query_align = next(iter(sequence_dict.values()))
    mask = [c != '-' for c in query_align]

    seen_sequences = set()
    seqnames = set()
    for seqname, alignment in sequence_dict.items():
        # Compare sequences with insertions removed; first occurrence wins.
        masked_alignment = ''.join(itertools.compress(alignment, mask))
        if masked_alignment not in seen_sequences:
            seen_sequences.add(masked_alignment)
            seqnames.add(seqname)

    kept = [
        line for line in stockholm_msa.splitlines()
        if _keep_line(line, seqnames)
    ]
    return '\n'.join(kept) + '\n'
453
+
454
+
455
+ def _get_hhr_line_regex_groups(
456
+ regex_pattern: str, line: str
457
+ ) -> Sequence[Optional[str]]:
458
+ match = re.match(regex_pattern, line)
459
+ if match is None:
460
+ raise RuntimeError(f"Could not parse query line {line}")
461
+ return match.groups()
462
+
463
+
464
+ def _update_hhr_residue_indices_list(
465
+ sequence: str, start_index: int, indices_list: List[int]
466
+ ):
467
+ """Computes the relative indices for each residue with respect to the original sequence."""
468
+ counter = start_index
469
+ for symbol in sequence:
470
+ if symbol == "-":
471
+ indices_list.append(-1)
472
+ else:
473
+ indices_list.append(counter)
474
+ counter += 1
475
+
476
+
477
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    """Parses the detailed HMM HMM comparison section for a single Hit.

    This works on .hhr files generated from both HHBlits and HHSearch.

    Args:
      detailed_lines: A list of lines from a single comparison section between 2
        sequences (which each have their own HMM's)

    Returns:
      A TemplateHit with the information from that detailed comparison section

    Raises:
      RuntimeError: If a certain line cannot be processed
    """
    # Parse first 2 lines.
    # Line 0 is "No <hit number>"; line 1 is ">" followed by the hit name.
    number_of_hit = int(detailed_lines[0].split()[-1])
    name_hit = detailed_lines[1][1:]

    # Parse the summary line.
    # NOTE: the pattern is a non-raw string, so "\t" is a literal tab
    # character inside the character classes.
    pattern = (
        "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
        " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
        "]*Template_Neff=(.*)"
    )
    match = re.match(pattern, detailed_lines[2])
    if match is None:
        raise RuntimeError(
            "Could not parse section: %s. Expected this: \n%s to contain summary."
            % (detailed_lines, detailed_lines[2])
        )
    # Only Aligned_cols and Sum_probs are kept; the other captured fields are
    # parsed (validating the line) but discarded.
    (_, _, _, aligned_cols, _, _, sum_probs, _) = [
        float(x) for x in match.groups()
    ]

    # The next section reads the detailed comparisons. These are in a 'human
    # readable' format which has a fixed length. The strategy employed is to
    # assume that each block starts with the query sequence line, and to parse
    # that with a regexp in order to deduce the fixed length used for that block.
    query = ""
    hit_sequence = ""
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        # Parse the query sequence line (skipping secondary-structure and
        # consensus annotation lines, which also start with "Q ").
        if (
            line.startswith("Q ")
            and not line.startswith("Q ss_dssp")
            and not line.startswith("Q ss_pred")
            and not line.startswith("Q Consensus")
        ):
            # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
            # everything after that.
            #   groups: start, sequence, end, (total_sequence_length)
            patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
            groups = _get_hhr_line_regex_groups(patt, line[17:])

            # Get the length of the parsed block using the start and finish indices,
            # and ensure it is the same as the actual block length.
            start = int(groups[0]) - 1  # Make index zero based.
            delta_query = groups[1]
            end = int(groups[2])
            # Gap columns ("-") do not consume query residues, so they are
            # added back when computing the expected block length.
            num_insertions = len([x for x in delta_query if x == "-"])
            length_block = end - start + num_insertions
            assert length_block == len(delta_query)

            # Update the query sequence and indices list.
            query += delta_query
            _update_hhr_residue_indices_list(delta_query, start, indices_query)

        elif line.startswith("T "):
            # Parse the hit sequence.
            if (
                not line.startswith("T ss_dssp")
                and not line.startswith("T ss_pred")
                and not line.startswith("T Consensus")
            ):
                # Thus the first 17 characters must be 'T <hit_name> ', and we can
                # parse everything after that.
                #   groups: start, sequence; end and total length are matched
                #   but not captured.
                patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
                groups = _get_hhr_line_regex_groups(patt, line[17:])
                start = int(groups[0]) - 1  # Make index zero based.
                delta_hit_sequence = groups[1]
                # The hit block must be exactly as long as the query block
                # parsed just before it.
                assert length_block == len(delta_hit_sequence)

                # Update the hit sequence and indices list.
                hit_sequence += delta_hit_sequence
                _update_hhr_residue_indices_list(
                    delta_hit_sequence, start, indices_hit
                )

    return TemplateHit(
        index=number_of_hit,
        name=name_hit,
        aligned_cols=int(aligned_cols),
        sum_probs=sum_probs,
        query=query,
        hit_sequence=hit_sequence,
        indices_query=indices_query,
        indices_hit=indices_hit,
    )
581
+
582
+
583
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
    """Parses the content of an entire HHR file."""
    lines = hhr_string.splitlines()

    # An .hhr file opens with a summary table, followed by one "paragraph"
    # per hit, each introduced by a line of the form "No <hit number>".
    starts = [idx for idx, content in enumerate(lines) if content.startswith("No ")]

    if not starts:
        return []
    # Sentinel marking the end of the final paragraph; consecutive start
    # indices then delimit each hit's section.
    starts.append(len(lines))
    return [
        _parse_hhr_hit(lines[begin:end])
        for begin, end in zip(starts[:-1], starts[1:])
    ]
601
+
602
+
603
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parse target to e-value mapping parsed from Jackhmmer tblout string.

    Args:
        tblout: Raw contents of a Jackhmmer --tblout file.

    Returns:
        Mapping from target sequence name to full-sequence E-value. The
        query itself is always present with an E-value of 0.
    """
    e_values = {"query": 0}
    # Skip comment lines. Also skip blank lines, which previously raised
    # IndexError when the first character was inspected via line[0].
    lines = [line for line in tblout.splitlines()
             if line and not line.startswith("#")]
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name: and
    # (5) E-value (full sequence) (numbering from 1).
    for line in lines:
        fields = line.split()
        e_value = fields[4]
        target_name = fields[0]
        e_values[target_name] = float(e_value)
    return e_values
616
+
617
+
618
+ def _get_indices(sequence: str, start: int) -> List[int]:
619
+ """Returns indices for non-gap/insert residues starting at the given index."""
620
+ indices = []
621
+ counter = start
622
+ for symbol in sequence:
623
+ # Skip gaps but add a placeholder so that the alignment is preserved.
624
+ if symbol == '-':
625
+ indices.append(-1)
626
+ # Skip deleted residues, but increase the counter.
627
+ elif symbol.islower():
628
+ counter += 1
629
+ # Normal aligned residue. Increase the counter and append to indices.
630
+ else:
631
+ indices.append(counter)
632
+ counter += 1
633
+ return indices
634
+
635
+
636
@dataclasses.dataclass(frozen=True)
class HitMetadata:
    """Metadata parsed from one hmmsearch A3M description line.

    E.g. ">4pqx_A/2-217 [subseq from] mol:protein length:217 Free text".
    """
    # Lower-case PDB entry id (first field of the description line).
    pdb_id: str
    # Chain identifier within the PDB entry.
    chain: str
    # First residue of the hit sub-sequence (1-based, per the "/2-217" range).
    start: int
    # Last residue of the hit sub-sequence (inclusive).
    end: int
    # Full template chain length, from the "length:" field.
    length: int
    # Free-form trailing text of the description (may be empty).
    text: str
644
+
645
+
646
def _parse_hmmsearch_description(description: str) -> HitMetadata:
    """Parses the hmmsearch A3M sequence description line."""
    # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
    # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
    pattern = (r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+)'
               r'.*protein length:([0-9]+) *(.*)$')
    match = re.match(pattern, description.strip())

    if match is None:
        raise ValueError(f'Could not parse description: "{description}".')

    pdb_id, chain, start, end, length, text = match.groups()
    return HitMetadata(
        pdb_id=pdb_id,
        chain=chain,
        start=int(start),
        end=int(end),
        length=int(length),
        text=text,
    )
665
+
666
+
667
def parse_hmmsearch_a3m(
    query_sequence: str,
    a3m_string: str,
    skip_first: bool = True
) -> Sequence[TemplateHit]:
    """Parses an a3m string produced by hmmsearch.

    Args:
      query_sequence: The query sequence.
      a3m_string: The a3m string produced by hmmsearch.
      skip_first: Whether to skip the first sequence in the a3m string.

    Returns:
      A sequence of `TemplateHit` results.
    """
    # Pair each aligned sequence with its description line.
    paired = list(zip(*parse_fasta(a3m_string)))
    if skip_first:
        # Drop the first entry (the query sequence itself).
        paired = paired[1:]

    indices_query = _get_indices(query_sequence, start=0)

    hits = []
    for hit_index, (hit_sequence, hit_description) in enumerate(paired, start=1):
        if 'mol:protein' not in hit_description:
            continue  # Skip non-protein chains.
        metadata = _parse_hmmsearch_description(hit_description)
        # Aligned columns are only the match states (upper case, non-gap).
        aligned_cols = sum(1 for r in hit_sequence if r.isupper() and r != '-')
        # Convert the 1-based metadata start to a 0-based index.
        indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)

        hits.append(
            TemplateHit(
                index=hit_index,
                name=f'{metadata.pdb_id}_{metadata.chain}',
                aligned_cols=aligned_cols,
                sum_probs=None,
                query=query_sequence,
                hit_sequence=hit_sequence.upper(),
                indices_query=indices_query,
                indices_hit=indices_hit,
            )
        )

    return hits
711
+
712
+
713
def parse_hmmsearch_sto(
    output_string: str,
    input_sequence: str
) -> Sequence[TemplateHit]:
    """Gets parsed template hits from the raw string output by the tool."""
    # Convert the Stockholm output to a3m (without removing the first row's
    # gap columns), then parse every entry — including the first.
    a3m = convert_stockholm_to_a3m(
        output_string,
        remove_first_row_gaps=False
    )
    return parse_hmmsearch_a3m(
        query_sequence=input_sequence,
        a3m_string=a3m,
        skip_first=False
    )
PhysDock/data/tools/rdkit.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import rdkit
3
+ from rdkit import Chem
4
+ from rdkit.Chem import AllChem, rdmolops
5
+ from rdkit.Chem.rdchem import ChiralType, BondType
6
+ import numpy as np
7
+ import copy
8
+
9
+ from PhysDock.data.constants.periodic_table import PeriodicTable
10
+ from PhysDock.utils.io_utils import load_txt
11
+
12
+
13
def get_ref_mol(string):
    """Builds an RDKit molecule with a 3D conformer and hydrogens removed.

    Args:
        string: Either a SMILES string, or a path to a ".smi" file whose
            contents are a SMILES string.

    Returns:
        An RDKit Mol with an embedded conformer and all hydrogens removed,
        or None if the input could not be parsed.
    """
    # Try to interpret the input directly as SMILES; previously the string
    # was parsed twice (once for the check, once for the assignment).
    mol = Chem.MolFromSmiles(string)
    if mol is None and os.path.isfile(string) and string.split(".")[-1] == "smi":
        mol = Chem.MolFromSmiles(load_txt(string).strip())
    if mol is not None:
        # Large attempt budget so embedding succeeds for awkward molecules.
        AllChem.EmbedMolecule(mol, maxAttempts=100000)
        # mol2 = Chem.MolFromPDBBlock(Chem.MolToPDBBlock(mol))
        # for atom in mol2.GetAtoms():
        #     # if atom.GetChiralTag() != ChiralType.CHI_UNSPECIFIED:
        #     print(f"Atom {atom.GetIdx()} has chiral tag: {atom.GetChiralTag()}")
        # mol = Chem.RemoveAllHs(mol2)
        mol = Chem.RemoveAllHs(mol)
    return mol
29
+
30
+
31
# RDKit hybridization states -> small integer feature codes; states not
# listed here fall back to code 6 at lookup time (via dict.get(..., 6)).
Hybridization = {
    Chem.rdchem.HybridizationType.S: 0,
    Chem.rdchem.HybridizationType.SP: 1,
    Chem.rdchem.HybridizationType.SP2: 2,
    Chem.rdchem.HybridizationType.SP3: 3,
    Chem.rdchem.HybridizationType.SP3D: 4,
    Chem.rdchem.HybridizationType.SP3D2: 5,
}

# Chirality codes: distinct codes for CW/CCW tetrahedral centres; anything
# unspecified or exotic collapses to 2.
Chirality = {ChiralType.CHI_TETRAHEDRAL_CW: 0,
             ChiralType.CHI_TETRAHEDRAL_CCW: 1,
             ChiralType.CHI_UNSPECIFIED: 2,
             ChiralType.CHI_OTHER: 2}
# Add None
# Bond-type codes; unknown bond types map to 4 via Bonds.get(..., 4).
Bonds = {BondType.SINGLE: 0, BondType.DOUBLE: 1, BondType.TRIPLE: 2, BondType.AROMATIC: 3}

# SMARTS matching any chain of four connected atoms (a torsion pattern).
dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]')
48
+
49
+
50
+ # Feats From SMI
51
+ # Feats From MOL
52
+ # Feats From SDF
53
+
54
+
55
def get_features_from_ref_mol(
    ref_mol,
    remove_hs=True
):
    """Extracts label, per-atom, and pairwise features from an RDKit molecule.

    The conformer the molecule arrives with supplies the "ground truth"
    coordinates (x_gt). A fresh conformer is then embedded and
    MMFF-optimized, and all remaining reference features are computed on
    that new geometry.

    Args:
        ref_mol: RDKit Mol that already carries a conformer.
        remove_hs: If True, strip all hydrogens before extracting features.

    Returns:
        Tuple (label_feature, conf_feature, ref_mol); the whole ligand is
        represented as a single token (restype 20, residue_index array of
        length 1) whose atoms map one-to-one onto conformer atoms.
    """
    if remove_hs:
        ref_mol = Chem.RemoveAllHs(ref_mol)
    # print(ref_mol)
    # if ref_mol.GetNumConformers()==0:
    #     AllChem.EmbedMolecule(ref_mol,useExpTorsionAnglePrefs=True, useBasicKnowledge=True,maxAttempts=100000)

    # Ground-truth coordinates from the conformer the caller provided.
    ref_conf = ref_mol.GetConformer()
    x_gt = []
    for atom_id, atom in enumerate(ref_mol.GetAtoms()):
        atom_pos = ref_conf.GetAtomPosition(atom_id)
        x_gt.append(np.array([atom_pos.x, atom_pos.y, atom_pos.z]))
    x_gt = np.stack(x_gt, axis=0).astype(np.float32)
    # Every atom is treated as present/resolved.
    x_exists = np.ones_like(x_gt[:, 0]).astype(np.int64)
    a_mask = np.ones_like(x_gt[:, 0]).astype(np.int64)

    # Ref Mol
    # NOTE: EmbedMolecule generates a new conformer, so `conf` below is a
    # freshly embedded, MMFF-optimized geometry — not the one used for x_gt.
    AllChem.EmbedMolecule(ref_mol,maxAttempts=100000)
    AllChem.MMFFOptimizeMolecule(ref_mol)
    num_atoms = ref_mol.GetNumAtoms()
    conf = ref_mol.GetConformer()
    ring = ref_mol.GetRingInfo()

    # Filtering Conditions
    # if ref_mol.GetNumAtoms() < 4:
    #     return None
    # if ref_mol.GetNumBonds() < 4:
    #     return None
    #
    # k = 0
    # for conf in [conf]:
    #     # skip mols with atoms with more than 4 neighbors for now
    #     n_neighbors = [len(a.GetNeighbors()) for a in ref_mol.GetAtoms()]
    #     if np.max(n_neighbors) > 4:
    #         continue
    #     try:
    #         conf_canonical_smi = Chem.MolToSmiles(Chem.RemoveHs(ref_mol))
    #     except Exception as e:
    #         continue
    #     k += 1
    # if k == 0:
    #     return None

    # Per-atom descriptor accumulators, filled in one pass over the atoms.
    ref_pos = []
    ref_charge = []
    ref_element = []
    ref_is_aromatic = []
    ref_degree = []
    ref_hybridization = []
    ref_implicit_valence = []
    ref_chirality = []
    ref_in_ring_of_3 = []
    ref_in_ring_of_4 = []
    ref_in_ring_of_5 = []
    ref_in_ring_of_6 = []
    ref_in_ring_of_7 = []
    ref_in_ring_of_8 = []
    for atom_id, atom in enumerate(ref_mol.GetAtoms()):
        atom_pos = conf.GetAtomPosition(atom_id)
        ref_pos.append(np.array([atom_pos.x, atom_pos.y, atom_pos.z]))
        ref_charge.append(atom.GetFormalCharge())
        # Zero-based element index (atomic number - 1).
        ref_element.append(atom.GetAtomicNum() - 1)
        ref_is_aromatic.append(int(atom.GetIsAromatic()))
        # Degree and implicit valence are clipped to 8 to bound the vocabulary.
        ref_degree.append(min(atom.GetDegree(), 8))
        # Unlisted hybridization states fall back to code 6.
        ref_hybridization.append(Hybridization.get(atom.GetHybridization(), 6))
        ref_implicit_valence.append(min(atom.GetImplicitValence(), 8))
        ref_chirality.append(Chirality.get(atom.GetChiralTag(), 2))
        # Ring-membership indicators for ring sizes 3 through 8.
        ref_in_ring_of_3.append(int(ring.IsAtomInRingOfSize(atom_id, 3)))
        ref_in_ring_of_4.append(int(ring.IsAtomInRingOfSize(atom_id, 4)))
        ref_in_ring_of_5.append(int(ring.IsAtomInRingOfSize(atom_id, 5)))
        ref_in_ring_of_6.append(int(ring.IsAtomInRingOfSize(atom_id, 6)))
        ref_in_ring_of_7.append(int(ring.IsAtomInRingOfSize(atom_id, 7)))
        ref_in_ring_of_8.append(int(ring.IsAtomInRingOfSize(atom_id, 8)))

    ref_pos = np.stack(ref_pos, axis=0).astype(np.float32)
    ref_charge = np.array(ref_charge).astype(np.float32)
    ref_element = np.array(ref_element).astype(np.int8)
    ref_is_aromatic = np.array(ref_is_aromatic).astype(np.int8)
    ref_degree = np.array(ref_degree).astype(np.int8)
    ref_hybridization = np.array(ref_hybridization).astype(np.int8)
    ref_implicit_valence = np.array(ref_implicit_valence).astype(np.int8)
    ref_chirality = np.array(ref_chirality).astype(np.int8)
    ref_in_ring_of_3 = np.array(ref_in_ring_of_3).astype(np.int8)
    ref_in_ring_of_4 = np.array(ref_in_ring_of_4).astype(np.int8)
    ref_in_ring_of_5 = np.array(ref_in_ring_of_5).astype(np.int8)
    ref_in_ring_of_6 = np.array(ref_in_ring_of_6).astype(np.int8)
    ref_in_ring_of_7 = np.array(ref_in_ring_of_7).astype(np.int8)
    ref_in_ring_of_8 = np.array(ref_in_ring_of_8).astype(np.int8)

    # Pairwise (atom x atom) feature matrices over the molecular graph.
    d_token = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    token_bonds = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_type = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_as_double = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_in_ring = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_is_aromatic = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    bond_is_conjugated = np.zeros([num_atoms, num_atoms], dtype=np.int8)
    # Topological (bond-count) distance for every atom pair, capped at 30.
    for i in range(num_atoms - 1):
        for j in range(i + 1, num_atoms):
            dist = len(rdmolops.GetShortestPath(ref_mol, i, j)) - 1
            dist = min(30, dist)
            d_token[i, j] = dist
            d_token[j, i] = dist
    # Symmetric bond-level features; unknown bond types map to code 4.
    for bond_id, bond in enumerate(ref_mol.GetBonds()):
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        token_bonds[i, j] = 1
        token_bonds[j, i] = 1
        bond_type[i, j] = Bonds.get(bond.GetBondType(), 4)
        bond_type[j, i] = Bonds.get(bond.GetBondType(), 4)
        bond_as_double[i, j] = bond.GetBondTypeAsDouble()
        bond_as_double[j, i] = bond.GetBondTypeAsDouble()
        bond_in_ring[i, j] = bond.IsInRing()
        bond_in_ring[j, i] = bond.IsInRing()
        bond_is_conjugated[i, j] = bond.GetIsConjugated()
        bond_is_conjugated[j, i] = bond.GetIsConjugated()
        bond_is_aromatic[i, j] = bond.GetIsAromatic()
        bond_is_aromatic[j, i] = bond.GetIsAromatic()

    # Element symbols serve as surrogate atom names.
    ref_atom_name_chars = [PeriodicTable[e] for e in ref_element.tolist()]
    ref_mask_in_polymer = [1] * len(ref_pos)

    # ccds, restype, residue_index, atom_id_to_conformer_atom_id, a_mask, x_gt, x_exists
    num_atoms = len(x_gt)
    # The whole ligand is a single token: restype 20 (presumably the UNK
    # code in the standard residue ordering — confirm against
    # residue_constants) with a single residue index.
    label_feature = {
        "x_gt": x_gt,
        "x_exists": x_exists,
        "a_mask": a_mask,
        "restype": np.array([20]).astype(np.int64),
        "residue_index": np.arange(1).astype(np.int64),
        "atom_id_to_conformer_atom_id": np.arange(num_atoms).astype(np.int64),
        "conformer_id_to_chunk_sizes": np.array([num_atoms]).astype(np.int64)
    }
    conf_feature = {
        "ref_pos": ref_pos,
        "ref_charge": ref_charge,
        "ref_element": ref_element,
        "ref_is_aromatic": ref_is_aromatic,
        "ref_degree": ref_degree,
        "ref_hybridization": ref_hybridization,
        "ref_implicit_valence": ref_implicit_valence,
        "ref_chirality": ref_chirality,
        "ref_in_ring_of_3": ref_in_ring_of_3,
        "ref_in_ring_of_4": ref_in_ring_of_4,
        "ref_in_ring_of_5": ref_in_ring_of_5,
        "ref_in_ring_of_6": ref_in_ring_of_6,
        "ref_in_ring_of_7": ref_in_ring_of_7,
        "ref_in_ring_of_8": ref_in_ring_of_8,
        "d_token": d_token,
        "token_bonds": token_bonds,
        "bond_type": bond_type,
        "bond_as_double": bond_as_double,
        "bond_in_ring": bond_in_ring,
        "bond_is_conjugated": bond_is_conjugated,
        "bond_is_aromatic": bond_is_aromatic,
        "ref_atom_name_chars": ref_atom_name_chars,
        "ref_mask_in_polymer": ref_mask_in_polymer,
    }
    return label_feature, conf_feature, ref_mol
215
+
216
+
217
def get_features_from_smi(smi, remove_hs=True):
    """Builds a reference molecule from a SMILES string (or .smi file path)
    and returns its (label_feature, conf_feature, ref_mol) triple."""
    return get_features_from_ref_mol(get_ref_mol(smi), remove_hs=remove_hs)
PhysDock/data/tools/residue_constants.py ADDED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
# One-letter -> three-letter amino-acid code map ("X" maps to "UNK").
amino_acid_1to3 = {
    "A": "ALA",
    "R": "ARG",
    "N": "ASN",
    "D": "ASP",
    "C": "CYS",
    "Q": "GLN",
    "E": "GLU",
    "G": "GLY",
    "H": "HIS",
    "I": "ILE",
    "L": "LEU",
    "K": "LYS",
    "M": "MET",
    "F": "PHE",
    "P": "PRO",
    "S": "SER",
    "T": "THR",
    "W": "TRP",
    "Y": "TYR",
    "V": "VAL",
    "X": "UNK",
}

# Inverse map: three-letter -> one-letter.
amino_acid_3to1 = {v: k for k, v in amino_acid_1to3.items()}
28
+
29
# Ligand atoms are represented as "UNK" in tokens.
# standard_residue entries double as CCD codes.
standard_protein = ["ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
                    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK", ]
standard_rna = ["A ", "G ", "C ", "U ", "N ", ]
standard_dna = ["DA ", "DG ", "DC ", "DT ", "DN ", ]
standard_nucleics = standard_rna + standard_dna
standard_ccds_without_gap = standard_protein + standard_nucleics
GAP = ["GAP"]  # used in msa one-hot
standard_ccds = standard_protein + standard_nucleics + GAP

# CCD code -> index in the canonical residue-type ordering.
# (Loop variable renamed so it no longer shadows the builtin `id`.)
standard_ccd_to_order = {ccd: order for order, ccd in enumerate(standard_ccds)}

standard_purines = ["A ", "G ", "DA ", "DG "]
standard_pyrimidines = ["C ", "U ", "DC ", "DT "]


# Predicates over CCD codes. Written as named functions rather than
# assigned lambdas (PEP 8 E731); behavior is unchanged.
def is_standard(x):
    """True if `x` is one of the standard CCD codes (including GAP)."""
    return x in standard_ccds


def is_unk(x):
    """True if `x` is an unknown residue/nucleotide, a gap, or an unknown ligand."""
    return x in ["UNK", "N ", "DN ", "GAP", "UNL"]


def is_protein(x):
    """True if `x` is a concrete (non-unknown) standard amino acid."""
    return x in standard_protein and not is_unk(x)


def is_rna(x):
    """True if `x` is a concrete (non-unknown) standard RNA nucleotide."""
    return x in standard_rna and not is_unk(x)


def is_dna(x):
    """True if `x` is a concrete (non-unknown) standard DNA nucleotide."""
    return x in standard_dna and not is_unk(x)


def is_nucleics(x):
    """True if `x` is a concrete (non-unknown) standard nucleotide."""
    return x in standard_nucleics and not is_unk(x)


def is_purines(x):
    """True if `x` is a purine nucleotide code."""
    return x in standard_purines


def is_pyrimidines(x):
    """True if `x` is a pyrimidine nucleotide code."""
    return x in standard_pyrimidines
53
+
54
# Per-CCD atom counts, aligned with the ordering of standard_ccds.
# NOTE(review): presumably the number of reference atoms per residue with
# terminal variants (e.g. OXT) excluded — confirm against the reference
# conformer table. Wildcard/unknown codes have no fixed count (None).
standard_ccd_to_atoms_num = {s: n for s, n in zip(standard_ccds, [
    5, 11, 8, 8, 6, 9, 9, 4, 10, 8,
    8, 9, 8, 11, 7, 6, 7, 14, 12, 7, None,
    22, 23, 20, 20, None,
    21, 22, 19, 20, None,
    None,
])}

# Token centre atom: CA for amino acids, C1' for nucleotides.
standard_ccd_to_token_centre_atom_name = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

# The three atoms (indices 0, 1, 2) that define each token's local frame.
standard_ccd_to_frame_atom_name_0 = {
    **{residue: "N" for residue in standard_protein},
    **{residue: "C1'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_1 = {
    **{residue: "CA" for residue in standard_protein},
    **{residue: "C3'" for residue in standard_nucleics},
}

standard_ccd_to_frame_atom_name_2 = {
    **{residue: "C" for residue in standard_protein},
    **{residue: "C4'" for residue in standard_nucleics},
}

# Pseudo-beta atom: CB for amino acids, C4 for purines, C2 for
# pyrimidines; glycine has no CB and falls back to CA via the update below.
standard_ccd_to_token_pseudo_beta_atom_name = {
    **{residue: "CB" for residue in standard_protein},
    **{residue: "C4" for residue in standard_purines},
    **{residue: "C2" for residue in standard_pyrimidines},
}
standard_ccd_to_token_pseudo_beta_atom_name.update({"GLY": "CA"})
88
+
89
# HHBlits amino-acid id -> three-letter code (ids 0..21 are consecutive).
HHBLITS_ID_TO_AA = {
    0: "ALA",
    1: "CYS",  # Also U.
    2: "ASP",  # Also B.
    3: "GLU",  # Also Z.
    4: "PHE",
    5: "GLY",
    6: "HIS",
    7: "ILE",
    8: "LYS",
    9: "LEU",
    10: "MET",
    11: "ASN",
    12: "PRO",
    13: "GLN",
    14: "ARG",
    15: "SER",
    16: "THR",
    17: "VAL",
    18: "TRP",
    19: "TYR",
    20: "UNK",  # Includes J and O.
    21: "GAP",
}

# Usage: Convert hhblits msa to af3 aatype
#   msa = hhblits_id_to_standard_residue_id_np[hhblits_msa.astype(np.int64)]
# The dict's keys are the consecutive ids 0..21 in insertion order, so
# iterating .values() preserves the id -> CCD row mapping; the previous
# .items() form discarded the key and shadowed the builtin `id`.
hhblits_id_to_standard_residue_id_np = np.array(
    [standard_ccds.index(ccd) for ccd in HHBLITS_ID_TO_AA.values()]
)
119
+
120
# OpenFold-style one-letter residue alphabet ("X" = unknown, "-" = gap).
of_restypes = [
    "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
    "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", "X", "-"
]

# One-letter view of standard_ccds: amino acids map through
# amino_acid_3to1, "GAP" maps to "-", and nucleotide codes (which have no
# one-letter amino-acid equivalent) become the literal string "None".
af3_restypes = [
    amino_acid_3to1[ccd] if ccd in amino_acid_3to1 else "-" if ccd == "GAP" else "None"
    for ccd in standard_ccds
]

# For each position: index of the first occurrence of that letter in
# af3_restypes if the letter also exists in of_restypes, else -1 (so all
# nucleotide positions map to -1).
af3_if_to_residue_id = np.array(
    [af3_restypes.index(restype) if restype in of_restypes else -1 for restype in af3_restypes])

########################################################
# periodic table that used to encode elements #
########################################################
# Lower-case element symbols ordered by atomic number (index = Z - 1).
periodic_table = [
    "h", "he",
    "li", "be", "b", "c", "n", "o", "f", "ne",
    "na", "mg", "al", "si", "p", "s", "cl", "ar",
    "k", "ca", "sc", "ti", "v", "cr", "mn", "fe", "co", "ni", "cu", "zn", "ga", "ge", "as", "se", "br", "kr",
    "rb", "sr", "y", "zr", "nb", "mo", "tc", "ru", "rh", "pd", "ag", "cd", "in", "sn", "sb", "te", "i", "xe",
    "cs", "ba",
    "la", "ce", "pr", "nd", "pm", "sm", "eu", "gd", "tb", "dy", "ho", "er", "tm", "yb", "lu",
    "hf", "ta", "w", "re", "os", "ir", "pt", "au", "hg", "tl", "pb", "bi", "po", "at", "rn",
    "fr", "ra",
    "ac", "th", "pa", "u", "np", "pu", "am", "cm", "bk", "cf", "es", "fm", "md", "no", "lr",
    "rf", "db", "sg", "bh", "hs", "mt", "ds", "rg", "cn", "nh", "fl", "mc", "lv", "ts", "og"
]

# Element symbol -> zero-based atomic index (Z - 1).
get_element_id = {ele: ele_id for ele_id, ele in enumerate(periodic_table)}
151
+
152
+ ##########################################################
153
+
154
+ standard_ccd_to_reference_features_table = {
155
+ # letters_3: [ref_pos,ref_charge, ref_mask, ref_elements, ref_atom_name_chars]
156
+ "ALA": [
157
+ [-0.966, 0.493, 1.500, 0., 1, "N", "N"],
158
+ [0.257, 0.418, 0.692, 0., 1, "C", "CA"],
159
+ [-0.094, 0.017, -0.716, 0., 1, "C", "C"],
160
+ [-1.056, -0.682, -0.923, 0., 1, "O", "O"],
161
+ [1.204, -0.620, 1.296, 0., 1, "C", "CB"],
162
+ [0.661, 0.439, -1.742, 0., 0, "O", "OXT"],
163
+ ],
164
+ "ARG": [
165
+ [-0.469, 1.110, -0.993, 0., 1, "N", "N"],
166
+ [0.004, 2.294, -1.708, 0., 1, "C", "CA"],
167
+ [-0.907, 2.521, -2.901, 0., 1, "C", "C"],
168
+ [-1.827, 1.789, -3.242, 0., 1, "O", "O"],
169
+ [1.475, 2.150, -2.127, 0., 1, "C", "CB"],
170
+ [1.745, 1.017, -3.130, 0., 1, "C", "CG"],
171
+ [3.210, 0.954, -3.557, 0., 1, "C", "CD"],
172
+ [4.071, 0.726, -2.421, 0., 1, "N", "NE"],
173
+ [5.469, 0.624, -2.528, 0., 1, "C", "CZ"],
174
+ [6.259, 0.404, -1.405, 0., 1, "N", "NH1"],
175
+ [6.078, 0.744, -3.773, 0., 1, "N", "NH2"],
176
+ [-0.588, 3.659, -3.574, 0., 0, "O", "OXT"],
177
+ ],
178
+ "ASN": [
179
+ [-0.293, 1.686, 0.094, 0., 1, "N", "N"],
180
+ [-0.448, 0.292, -0.340, 0., 1, "C", "CA"],
181
+ [-1.846, -0.179, -0.031, 0., 1, "C", "C"],
182
+ [-2.510, 0.402, 0.794, 0., 1, "O", "O"],
183
+ [0.562, -0.588, 0.401, 0., 1, "C", "CB"],
184
+ [1.960, -0.197, -0.002, 0., 1, "C", "CG"],
185
+ [2.132, 0.697, -0.804, 0., 1, "O", "OD1"],
186
+ [3.019, -0.841, 0.527, 0., 1, "N", "ND2"],
187
+ [-2.353, -1.243, -0.673, 0., 0, "O", "OXT"],
188
+ ],
189
+ "ASP": [
190
+ [-0.317, 1.688, 0.066, 0., 1, "N", "N"],
191
+ [-0.470, 0.286, -0.344, 0., 1, "C", "CA"],
192
+ [-1.868, -0.180, -0.029, 0., 1, "C", "C"],
193
+ [-2.534, 0.415, 0.786, 0., 1, "O", "O"],
194
+ [0.539, -0.580, 0.413, 0., 1, "C", "CB"],
195
+ [1.938, -0.195, 0.004, 0., 1, "C", "CG"],
196
+ [2.109, 0.681, -0.810, 0., 1, "O", "OD1"],
197
+ [2.992, -0.826, 0.543, 0., 1, "O", "OD2"],
198
+ [-2.374, -1.256, -0.652, 0., 0, "O", "OXT"],
199
+ ],
200
+ "CYS": [
201
+ [1.585, 0.483, -0.081, 0., 1, "N", "N"],
202
+ [0.141, 0.450, 0.186, 0., 1, "C", "CA"],
203
+ [-0.095, 0.006, 1.606, 0., 1, "C", "C"],
204
+ [0.685, -0.742, 2.143, 0., 1, "O", "O"],
205
+ [-0.533, -0.530, -0.774, 0., 1, "C", "CB"],
206
+ [-0.247, 0.004, -2.484, 0., 1, "S", "SG"],
207
+ [-1.174, 0.443, 2.275, 0., 0, "O", "OXT"],
208
+ ],
209
+ "GLN": [
210
+ [1.858, -0.148, 1.125, 0., 1, "N", "N"],
211
+ [0.517, 0.451, 1.112, 0., 1, "C", "CA"],
212
+ [-0.236, 0.022, 2.344, 0., 1, "C", "C"],
213
+ [-0.005, -1.049, 2.851, 0., 1, "O", "O"],
214
+ [-0.236, -0.013, -0.135, 0., 1, "C", "CB"],
215
+ [0.529, 0.421, -1.385, 0., 1, "C", "CG"],
216
+ [-0.213, -0.036, -2.614, 0., 1, "C", "CD"],
217
+ [-1.252, -0.650, -2.500, 0., 1, "O", "OE1"],
218
+ [0.277, 0.236, -3.839, 0., 1, "N", "NE2"],
219
+ [-1.165, 0.831, 2.878, 0., 0, "O", "OXT"],
220
+ ],
221
+ "GLU": [
222
+ [1.199, 1.867, -0.117, 0., 1, "N", "N"],
223
+ [1.138, 0.515, 0.453, 0., 1, "C", "CA"],
224
+ [2.364, -0.260, 0.041, 0., 1, "C", "C"],
225
+ [3.010, 0.096, -0.916, 0., 1, "O", "O"],
226
+ [-0.113, -0.200, -0.062, 0., 1, "C", "CB"],
227
+ [-1.360, 0.517, 0.461, 0., 1, "C", "CG"],
228
+ [-2.593, -0.187, -0.046, 0., 1, "C", "CD"],
229
+ [-2.485, -1.161, -0.753, 0., 1, "O", "OE1"],
230
+ [-3.811, 0.269, 0.287, 0., 1, "O", "OE2"],
231
+ [2.737, -1.345, 0.737, 0., 0, "O", "OXT"],
232
+ ],
233
+ "GLY": [
234
+ [1.931, 0.090, -0.034, 0., 1, "N", "N"],
235
+ [0.761, -0.799, -0.008, 0., 1, "C", "CA"],
236
+ [-0.498, 0.029, -0.005, 0., 1, "C", "C"],
237
+ [-0.429, 1.235, -0.023, 0., 1, "O", "O"],
238
+ [-1.697, -0.574, 0.018, 0., 0, "O", "OXT"],
239
+ ],
240
+ "HIS": [
241
+ [-0.040, -1.210, 0.053, 0., 1, "N", "N"],
242
+ [1.172, -1.709, 0.652, 0., 1, "C", "CA"],
243
+ [1.083, -3.207, 0.905, 0., 1, "C", "C"],
244
+ [0.040, -3.770, 1.222, 0., 1, "O", "O"],
245
+ [1.484, -0.975, 1.962, 0., 1, "C", "CB"],
246
+ [2.940, -1.060, 2.353, 0., 1, "C", "CG"],
247
+ [3.380, -2.075, 3.129, 0., 1, "N", "ND1"],
248
+ [3.960, -0.251, 2.046, 0., 1, "C", "CD2"],
249
+ [4.693, -1.908, 3.317, 0., 1, "C", "CE1"],
250
+ [5.058, -0.801, 2.662, 0., 1, "N", "NE2"],
251
+ [2.247, -3.882, 0.744, 0., 0, "O", "OXT"],
252
+ ],
253
+ "ILE": [
254
+ [-1.944, 0.335, -0.343, 0., 1, "N", "N"],
255
+ [-0.487, 0.519, -0.369, 0., 1, "C", "CA"],
256
+ [0.066, -0.032, -1.657, 0., 1, "C", "C"],
257
+ [-0.484, -0.958, -2.203, 0., 1, "O", "O"],
258
+ [0.140, -0.219, 0.814, 0., 1, "C", "CB"],
259
+ [-0.421, 0.341, 2.122, 0., 1, "C", "CG1"],
260
+ [1.658, -0.027, 0.788, 0., 1, "C", "CG2"],
261
+ [0.206, -0.397, 3.305, 0., 1, "C", "CD1"],
262
+ [1.171, 0.504, -2.197, 0., 0, "O", "OXT"],
263
+ ],
264
+ "LEU": [
265
+ [-1.661, 0.627, -0.406, 0., 1, "N", "N"],
266
+ [-0.205, 0.441, -0.467, 0., 1, "C", "CA"],
267
+ [0.180, -0.055, -1.836, 0., 1, "C", "C"],
268
+ [-0.591, -0.731, -2.474, 0., 1, "O", "O"],
269
+ [0.221, -0.583, 0.585, 0., 1, "C", "CB"],
270
+ [-0.170, -0.079, 1.976, 0., 1, "C", "CG"],
271
+ [0.256, -1.104, 3.029, 0., 1, "C", "CD1"],
272
+ [0.526, 1.254, 2.250, 0., 1, "C", "CD2"],
273
+ [1.382, 0.254, -2.348, 0., 0, "O", "OXT"],
274
+ ],
275
+ "LYS": [
276
+ [1.422, 1.796, 0.198, 0., 1, "N", "N"],
277
+ [1.394, 0.355, 0.484, 0., 1, "C", "CA"],
278
+ [2.657, -0.284, -0.032, 0., 1, "C", "C"],
279
+ [3.316, 0.275, -0.876, 0., 1, "O", "O"],
280
+ [0.184, -0.278, -0.206, 0., 1, "C", "CB"],
281
+ [-1.102, 0.282, 0.407, 0., 1, "C", "CG"],
282
+ [-2.313, -0.351, -0.283, 0., 1, "C", "CD"],
283
+ [-3.598, 0.208, 0.329, 0., 1, "C", "CE"],
284
+ [-4.761, -0.400, -0.332, 0., 1, "N", "NZ"],
285
+ [3.050, -1.476, 0.446, 0., 0, "O", "OXT"],
286
+ ],
287
+ "MET": [
288
+ [-1.816, 0.142, -1.166, 0., 1, "N", "N"],
289
+ [-0.392, 0.499, -1.214, 0., 1, "C", "CA"],
290
+ [0.206, 0.002, -2.504, 0., 1, "C", "C"],
291
+ [-0.236, -0.989, -3.033, 0., 1, "O", "O"],
292
+ [0.334, -0.145, -0.032, 0., 1, "C", "CB"],
293
+ [-0.273, 0.359, 1.277, 0., 1, "C", "CG"],
294
+ [0.589, -0.405, 2.678, 0., 1, "S", "SD"],
295
+ [-0.314, 0.353, 4.056, 0., 1, "C", "CE"],
296
+ [1.232, 0.661, -3.066, 0., 0, "O", "OXT"],
297
+ ],
298
+ "PHE": [
299
+ [1.317, 0.962, 1.014, 0., 1, "N", "N"],
300
+ [-0.020, 0.426, 1.300, 0., 1, "C", "CA"],
301
+ [-0.109, 0.047, 2.756, 0., 1, "C", "C"],
302
+ [0.879, -0.317, 3.346, 0., 1, "O", "O"],
303
+ [-0.270, -0.809, 0.434, 0., 1, "C", "CB"],
304
+ [-0.181, -0.430, -1.020, 0., 1, "C", "CG"],
305
+ [1.031, -0.498, -1.680, 0., 1, "C", "CD1"],
306
+ [-1.314, -0.018, -1.698, 0., 1, "C", "CD2"],
307
+ [1.112, -0.150, -3.015, 0., 1, "C", "CE1"],
308
+ [-1.231, 0.333, -3.032, 0., 1, "C", "CE2"],
309
+ [-0.018, 0.265, -3.691, 0., 1, "C", "CZ"],
310
+ [-1.286, 0.113, 3.396, 0., 0, "O", "OXT"],
311
+ ],
312
+ "PRO": [
313
+ [-0.816, 1.108, 0.254, 0., 1, "N", "N"],
314
+ [0.001, -0.107, 0.509, 0., 1, "C", "CA"],
315
+ [1.408, 0.091, 0.005, 0., 1, "C", "C"],
316
+ [1.650, 0.980, -0.777, 0., 1, "O", "O"],
317
+ [-0.703, -1.227, -0.286, 0., 1, "C", "CB"],
318
+ [-2.163, -0.753, -0.439, 0., 1, "C", "CG"],
319
+ [-2.218, 0.614, 0.276, 0., 1, "C", "CD"],
320
+ [2.391, -0.721, 0.424, 0., 0, "O", "OXT"],
321
+ ],
322
+ "SER": [
323
+ [1.525, 0.493, -0.608, 0., 1, "N", "N"],
324
+ [0.100, 0.469, -0.252, 0., 1, "C", "CA"],
325
+ [-0.053, 0.004, 1.173, 0., 1, "C", "C"],
326
+ [0.751, -0.760, 1.649, 0., 1, "O", "O"],
327
+ [-0.642, -0.489, -1.184, 0., 1, "C", "CB"],
328
+ [-0.496, -0.049, -2.535, 0., 1, "O", "OG"],
329
+ [-1.084, 0.440, 1.913, 0., 0, "O", "OXT"],
330
+ ],
331
+ "THR": [
332
+ [1.543, -0.702, 0.430, 0., 1, "N", "N"],
333
+ [0.122, -0.706, 0.056, 0., 1, "C", "CA"],
334
+ [-0.038, -0.090, -1.309, 0., 1, "C", "C"],
335
+ [0.732, 0.761, -1.683, 0., 1, "O", "O"],
336
+ [-0.675, 0.104, 1.079, 0., 1, "C", "CB"],
337
+ [-0.193, 1.448, 1.103, 0., 1, "O", "OG1"],
338
+ [-0.511, -0.521, 2.466, 0., 1, "C", "CG2"],
339
+ [-1.039, -0.488, -2.110, 0., 0, "O", "OXT"],
340
+ ],
341
+ "TRP": [
342
+ [1.278, 1.121, 2.059, 0., 1, "N", "N"],
343
+ [-0.008, 0.417, 1.970, 0., 1, "C", "CA"],
344
+ [-0.490, 0.076, 3.357, 0., 1, "C", "C"],
345
+ [0.308, -0.130, 4.240, 0., 1, "O", "O"],
346
+ [0.168, -0.868, 1.161, 0., 1, "C", "CB"],
347
+ [0.650, -0.526, -0.225, 0., 1, "C", "CG"],
348
+ [1.928, -0.418, -0.622, 0., 1, "C", "CD1"],
349
+ [-0.186, -0.256, -1.396, 0., 1, "C", "CD2"],
350
+ [1.978, -0.095, -1.951, 0., 1, "N", "NE1"],
351
+ [0.701, 0.014, -2.454, 0., 1, "C", "CE2"],
352
+ [-1.564, -0.210, -1.615, 0., 1, "C", "CE3"],
353
+ [0.190, 0.314, -3.712, 0., 1, "C", "CZ2"],
354
+ [-2.044, 0.086, -2.859, 0., 1, "C", "CZ3"],
355
+ [-1.173, 0.348, -3.907, 0., 1, "C", "CH2"],
356
+ [-1.806, 0.001, 3.610, 0., 0, "O", "OXT"],
357
+ ],
358
+ "TYR": [
359
+ [1.320, 0.952, 1.428, 0., 1, "N", "N"],
360
+ [-0.018, 0.429, 1.734, 0., 1, "C", "CA"],
361
+ [-0.103, 0.094, 3.201, 0., 1, "C", "C"],
362
+ [0.886, -0.254, 3.799, 0., 1, "O", "O"],
363
+ [-0.274, -0.831, 0.907, 0., 1, "C", "CB"],
364
+ [-0.189, -0.496, -0.559, 0., 1, "C", "CG"],
365
+ [1.022, -0.589, -1.219, 0., 1, "C", "CD1"],
366
+ [-1.324, -0.102, -1.244, 0., 1, "C", "CD2"],
367
+ [1.103, -0.282, -2.563, 0., 1, "C", "CE1"],
368
+ [-1.247, 0.210, -2.587, 0., 1, "C", "CE2"],
369
+ [-0.032, 0.118, -3.252, 0., 1, "C", "CZ"],
370
+ [0.044, 0.420, -4.574, 0., 1, "O", "OH"],
371
+ [-1.279, 0.184, 3.842, 0., 0, "O", "OXT"],
372
+ ],
373
+ "VAL": [
374
+ [1.564, -0.642, 0.454, 0., 1, "N", "N"],
375
+ [0.145, -0.698, 0.079, 0., 1, "C", "CA"],
376
+ [-0.037, -0.093, -1.288, 0., 1, "C", "C"],
377
+ [0.703, 0.784, -1.664, 0., 1, "O", "O"],
378
+ [-0.682, 0.086, 1.098, 0., 1, "C", "CB"],
379
+ [-0.497, -0.528, 2.487, 0., 1, "C", "CG1"],
380
+ [-0.218, 1.543, 1.119, 0., 1, "C", "CG2"],
381
+ [-1.022, -0.529, -2.089, 0., 0, "O", "OXT"],
382
+ ],
383
+ "A ": [
384
+ [2.135, -1.141, -5.313, 0., 0, "O", "OP3"],
385
+ [1.024, -0.137, -4.723, 0., 1, "P", "P"],
386
+ [1.633, 1.190, -4.488, 0., 1, "O", "OP1"],
387
+ [-0.183, 0.005, -5.778, 0., 1, "O", "OP2"],
388
+ [0.456, -0.720, -3.334, 0., 1, "O", "O5'"],
389
+ [-0.520, 0.209, -2.863, 0., 1, "C", "C5'"],
390
+ [-1.101, -0.287, -1.538, 0., 1, "C", "C4'"],
391
+ [-0.064, -0.383, -0.538, 0., 1, "O", "O4'"],
392
+ [-2.105, 0.739, -0.969, 0., 1, "C", "C3'"],
393
+ [-3.445, 0.360, -1.287, 0., 1, "O", "O3'"],
394
+ [-1.874, 0.684, 0.558, 0., 1, "C", "C2'"],
395
+ [-3.065, 0.271, 1.231, 0., 1, "O", "O2'"],
396
+ [-0.755, -0.367, 0.729, 0., 1, "C", "C1'"],
397
+ [0.158, 0.029, 1.803, 0., 1, "N", "N9"],
398
+ [1.265, 0.813, 1.672, 0., 1, "C", "C8"],
399
+ [1.843, 0.963, 2.828, 0., 1, "N", "N7"],
400
+ [1.143, 0.292, 3.773, 0., 1, "C", "C5"],
401
+ [1.290, 0.091, 5.156, 0., 1, "C", "C6"],
402
+ [2.344, 0.664, 5.846, 0., 1, "N", "N6"],
403
+ [0.391, -0.656, 5.787, 0., 1, "N", "N1"],
404
+ [-0.617, -1.206, 5.136, 0., 1, "C", "C2"],
405
+ [-0.792, -1.051, 3.841, 0., 1, "N", "N3"],
406
+ [0.056, -0.320, 3.126, 0., 1, "C", "C4"],
407
+ ],
408
+ "G ": [
409
+ [-1.945, -1.360, 5.599, 0., 0, "O", "OP3"],
410
+ [-0.911, -0.277, 5.008, 0., 1, "P", "P"],
411
+ [-1.598, 1.022, 4.844, 0., 1, "O", "OP1"],
412
+ [0.325, -0.105, 6.025, 0., 1, "O", "OP2"],
413
+ [-0.365, -0.780, 3.580, 0., 1, "O", "O5'"],
414
+ [0.542, 0.217, 3.109, 0., 1, "C", "C5'"],
415
+ [1.100, -0.200, 1.748, 0., 1, "C", "C4'"],
416
+ [0.033, -0.318, 0.782, 0., 1, "O", "O4'"],
417
+ [2.025, 0.898, 1.182, 0., 1, "C", "C3'"],
418
+ [3.395, 0.582, 1.439, 0., 1, "O", "O3'"],
419
+ [1.741, 0.884, -0.338, 0., 1, "C", "C2'"],
420
+ [2.927, 0.560, -1.066, 0., 1, "O", "O2'"],
421
+ [0.675, -0.220, -0.507, 0., 1, "C", "C1'"],
422
+ [-0.297, 0.162, -1.534, 0., 1, "N", "N9"],
423
+ [-1.440, 0.880, -1.334, 0., 1, "C", "C8"],
424
+ [-2.066, 1.037, -2.464, 0., 1, "N", "N7"],
425
+ [-1.364, 0.431, -3.453, 0., 1, "C", "C5"],
426
+ [-1.556, 0.279, -4.846, 0., 1, "C", "C6"],
427
+ [-2.534, 0.755, -5.397, 0., 1, "O", "O6"],
428
+ [-0.626, -0.401, -5.551, 0., 1, "N", "N1"],
429
+ [0.459, -0.934, -4.923, 0., 1, "C", "C2"],
430
+ [1.384, -1.626, -5.664, 0., 1, "N", "N2"],
431
+ [0.649, -0.800, -3.630, 0., 1, "N", "N3"],
432
+ [-0.226, -0.134, -2.868, 0., 1, "C", "C4"],
433
+ ],
434
+ "C ": [
435
+ [2.147, -1.021, -4.678, 0., 0, "O", "OP3"],
436
+ [1.049, -0.039, -4.028, 0., 1, "P", "P"],
437
+ [1.692, 1.237, -3.646, 0., 1, "O", "OP1"],
438
+ [-0.116, 0.246, -5.102, 0., 1, "O", "OP2"],
439
+ [0.415, -0.733, -2.721, 0., 1, "O", "O5'"],
440
+ [-0.546, 0.181, -2.193, 0., 1, "C", "C5'"],
441
+ [-1.189, -0.419, -0.942, 0., 1, "C", "C4'"],
442
+ [-0.190, -0.648, 0.076, 0., 1, "O", "O4'"],
443
+ [-2.178, 0.583, -0.307, 0., 1, "C", "C3'"],
444
+ [-3.518, 0.283, -0.703, 0., 1, "O", "O3'"],
445
+ [-2.001, 0.373, 1.215, 0., 1, "C", "C2'"],
446
+ [-3.228, -0.059, 1.806, 0., 1, "O", "O2'"],
447
+ [-0.924, -0.729, 1.317, 0., 1, "C", "C1'"],
448
+ [-0.036, -0.470, 2.453, 0., 1, "N", "N1"],
449
+ [0.652, 0.683, 2.514, 0., 1, "C", "C2"],
450
+ [0.529, 1.504, 1.620, 0., 1, "O", "O2"],
451
+ [1.467, 0.945, 3.535, 0., 1, "N", "N3"],
452
+ [1.620, 0.070, 4.520, 0., 1, "C", "C4"],
453
+ [2.464, 0.350, 5.569, 0., 1, "N", "N4"],
454
+ [0.916, -1.151, 4.483, 0., 1, "C", "C5"],
455
+ [0.087, -1.399, 3.442, 0., 1, "C", "C6"],
456
+ ],
457
+ "U ": [
458
+ [-2.122, 1.033, -4.690, 0., 0, "O", "OP3"],
459
+ [-1.030, 0.047, -4.037, 0., 1, "P", "P"],
460
+ [-1.679, -1.228, -3.660, 0., 1, "O", "OP1"],
461
+ [0.138, -0.241, -5.107, 0., 1, "O", "OP2"],
462
+ [-0.399, 0.736, -2.726, 0., 1, "O", "O5'"],
463
+ [0.557, -0.182, -2.196, 0., 1, "C", "C5'"],
464
+ [1.197, 0.415, -0.942, 0., 1, "C", "C4'"],
465
+ [0.194, 0.645, 0.074, 0., 1, "O", "O4'"],
466
+ [2.181, -0.588, -0.301, 0., 1, "C", "C3'"],
467
+ [3.524, -0.288, -0.686, 0., 1, "O", "O3'"],
468
+ [1.995, -0.383, 1.218, 0., 1, "C", "C2'"],
469
+ [3.219, 0.046, 1.819, 0., 1, "O", "O2'"],
470
+ [0.922, 0.723, 1.319, 0., 1, "C", "C1'"],
471
+ [0.028, 0.464, 2.451, 0., 1, "N", "N1"],
472
+ [-0.690, -0.671, 2.486, 0., 1, "C", "C2"],
473
+ [-0.587, -1.474, 1.580, 0., 1, "O", "O2"],
474
+ [-1.515, -0.936, 3.517, 0., 1, "N", "N3"],
475
+ [-1.641, -0.055, 4.530, 0., 1, "C", "C4"],
476
+ [-2.391, -0.292, 5.460, 0., 1, "O", "O4"],
477
+ [-0.894, 1.146, 4.502, 0., 1, "C", "C5"],
478
+ [-0.070, 1.384, 3.459, 0., 1, "C", "C6"],
479
+ ],
480
+ "DA ": [
481
+ [1.845, -1.282, -5.339, 0., 0, "O", "OP3"],
482
+ [0.934, -0.156, -4.636, 0., 1, "P", "P"],
483
+ [1.781, 0.996, -4.255, 0., 1, "O", "OP1"],
484
+ [-0.204, 0.331, -5.665, 0., 1, "O", "OP2"],
485
+ [0.241, -0.771, -3.320, 0., 1, "O", "O5'"],
486
+ [-0.549, 0.270, -2.744, 0., 1, "C", "C5'"],
487
+ [-1.239, -0.251, -1.482, 0., 1, "C", "C4'"],
488
+ [-0.267, -0.564, -0.458, 0., 1, "O", "O4'"],
489
+ [-2.105, 0.859, -0.835, 0., 1, "C", "C3'"],
490
+ [-3.409, 0.895, -1.418, 0., 1, "O", "O3'"],
491
+ [-2.173, 0.398, 0.640, 0., 1, "C", "C2'"],
492
+ [-0.965, -0.545, 0.797, 0., 1, "C", "C1'"],
493
+ [-0.078, -0.047, 1.852, 0., 1, "N", "N9"],
494
+ [0.962, 0.817, 1.689, 0., 1, "C", "C8"],
495
+ [1.535, 1.044, 2.835, 0., 1, "N", "N7"],
496
+ [0.897, 0.346, 3.805, 0., 1, "C", "C5"],
497
+ [1.069, 0.196, 5.191, 0., 1, "C", "C6"],
498
+ [2.079, 0.869, 5.856, 0., 1, "N", "N6"],
499
+ [0.236, -0.603, 5.850, 0., 1, "N", "N1"],
500
+ [-0.729, -1.249, 5.224, 0., 1, "C", "C2"],
501
+ [-0.925, -1.144, 3.927, 0., 1, "N", "N3"],
502
+ [-0.142, -0.368, 3.184, 0., 1, "C", "C4"],
503
+ ],
504
+ "DG ": [
505
+ [-1.603, -1.547, 5.624, 0., 0, "O", "OP3"],
506
+ [-0.818, -0.321, 4.935, 0., 1, "P", "P"],
507
+ [-1.774, 0.766, 4.630, 0., 1, "O", "OP1"],
508
+ [0.312, 0.224, 5.941, 0., 1, "O", "OP2"],
509
+ [-0.126, -0.826, 3.572, 0., 1, "O", "O5'"],
510
+ [0.550, 0.300, 3.011, 0., 1, "C", "C5'"],
511
+ [1.233, -0.113, 1.706, 0., 1, "C", "C4'"],
512
+ [0.253, -0.471, 0.705, 0., 1, "O", "O4'"],
513
+ [1.976, 1.091, 1.073, 0., 1, "C", "C3'"],
514
+ [3.294, 1.218, 1.612, 0., 1, "O", "O3'"],
515
+ [2.026, 0.692, -0.421, 0., 1, "C", "C2'"],
516
+ [0.897, -0.345, -0.573, 0., 1, "C", "C1'"],
517
+ [-0.068, 0.111, -1.575, 0., 1, "N", "N9"],
518
+ [-1.172, 0.877, -1.341, 0., 1, "C", "C8"],
519
+ [-1.804, 1.094, -2.458, 0., 1, "N", "N7"],
520
+ [-1.145, 0.482, -3.472, 0., 1, "C", "C5"],
521
+ [-1.361, 0.377, -4.866, 0., 1, "C", "C6"],
522
+ [-2.321, 0.914, -5.391, 0., 1, "O", "O6"],
523
+ [-0.473, -0.327, -5.601, 0., 1, "N", "N1"],
524
+ [0.593, -0.928, -5.003, 0., 1, "C", "C2"],
525
+ [1.474, -1.643, -5.774, 0., 1, "N", "N2"],
526
+ [0.804, -0.839, -3.709, 0., 1, "N", "N3"],
527
+ [-0.027, -0.152, -2.917, 0., 1, "C", "C4"],
528
+ ],
529
+ "DC ": [
530
+ [1.941, -1.055, -4.672, 0., 0, "O", "OP3"],
531
+ [0.987, -0.017, -3.894, 0., 1, "P", "P"],
532
+ [1.802, 1.099, -3.365, 0., 1, "O", "OP1"],
533
+ [-0.119, 0.560, -4.910, 0., 1, "O", "OP2"],
534
+ [0.255, -0.772, -2.674, 0., 1, "O", "O5'"],
535
+ [-0.571, 0.196, -2.027, 0., 1, "C", "C5'"],
536
+ [-1.300, -0.459, -0.852, 0., 1, "C", "C4'"],
537
+ [-0.363, -0.863, 0.171, 0., 1, "O", "O4'"],
538
+ [-2.206, 0.569, -0.129, 0., 1, "C", "C3'"],
539
+ [-3.488, 0.649, -0.756, 0., 1, "O", "O3'"],
540
+ [-2.322, -0.040, 1.288, 0., 1, "C", "C2'"],
541
+ [-1.106, -0.981, 1.395, 0., 1, "C", "C1'"],
542
+ [-0.267, -0.584, 2.528, 0., 1, "N", "N1"],
543
+ [0.270, 0.648, 2.563, 0., 1, "C", "C2"],
544
+ [0.052, 1.424, 1.647, 0., 1, "O", "O2"],
545
+ [1.037, 1.035, 3.581, 0., 1, "N", "N3"],
546
+ [1.291, 0.212, 4.589, 0., 1, "C", "C4"],
547
+ [2.085, 0.622, 5.635, 0., 1, "N", "N4"],
548
+ [0.746, -1.088, 4.580, 0., 1, "C", "C5"],
549
+ [-0.035, -1.465, 3.541, 0., 1, "C", "C6"],
550
+ ],
551
+ "DT ": [
552
+ [-3.912, -2.311, 1.636, 0., 0, "O", "OP3"],
553
+ [-3.968, -1.665, 3.118, 0., 1, "P", "P"],
554
+ [-4.406, -2.599, 4.208, 0., 1, "O", "OP1"],
555
+ [-4.901, -0.360, 2.920, 0., 1, "O", "OP2"],
556
+ [-2.493, -1.028, 3.315, 0., 1, "O", "O5'"],
557
+ [-2.005, -0.136, 2.327, 0., 1, "C", "C5'"],
558
+ [-0.611, 0.328, 2.728, 0., 1, "C", "C4'"],
559
+ [0.247, -0.829, 2.764, 0., 1, "O", "O4'"],
560
+ [0.008, 1.286, 1.720, 0., 1, "C", "C3'"],
561
+ [0.965, 2.121, 2.368, 0., 1, "O", "O3'"],
562
+ [0.710, 0.360, 0.754, 0., 1, "C", "C2'"],
563
+ [1.157, -0.778, 1.657, 0., 1, "C", "C1'"],
564
+ [1.164, -2.047, 0.989, 0., 1, "N", "N1"],
565
+ [2.333, -2.544, 0.374, 0., 1, "C", "C2"],
566
+ [3.410, -1.945, 0.363, 0., 1, "O", "O2"],
567
+ [2.194, -3.793, -0.240, 0., 1, "N", "N3"],
568
+ [1.047, -4.570, -0.300, 0., 1, "C", "C4"],
569
+ [0.995, -5.663, -0.857, 0., 1, "O", "O4"],
570
+ [-0.143, -3.980, 0.369, 0., 1, "C", "C5"],
571
+ [-1.420, -4.757, 0.347, 0., 1, "C", "C7"],
572
+ [-0.013, -2.784, 0.958, 0., 1, "C", "C6"],
573
+ ],
574
+ }
575
+
576
# For every known (non-UNK) standard CCD code, the ordered list of atom
# names (the last column of each reference-feature row).
standard_ccd_to_ref_atom_name_chars = {
    code: [row[-1] for row in standard_ccd_to_reference_features_table[code]]
    for code in standard_ccds
    if not is_unk(code)
}
580
+
581
# Identity matrices reused as one-hot lookup tables: row i of eye_N is the
# N-way one-hot encoding of index i.
eye_3 = np.eye(3)
eye_5 = np.eye(5)
eye_7 = np.eye(7)
eye_9 = np.eye(9)
eye_32 = np.eye(32)
eye_64 = np.eye(64)
eye_128 = np.eye(128)
588
+
589
+
590
def _get_ref_feat_from_ccd_data(ccd, ref_feat_table):
    """Build the per-atom reference feature matrix for one CCD code.

    Each atom row concatenates: the row's first 5 numeric features, a
    128-way one-hot of the element id (element symbol lower-cased and
    looked up in ``get_element_id``), and four 64-way one-hots encoding the
    atom name left-justified/space-padded to 4 characters (char code - 32).

    Args:
        ccd: CCD component code to look up in ``ref_feat_table``.
        ref_feat_table: Mapping from CCD code to per-atom feature rows.

    Returns:
        np.ndarray of shape (num_atoms, 5 + 128 + 4 * 64).
    """
    rows = []
    for atom in ref_feat_table[ccd]:
        numeric = np.array(atom[:5])
        element_onehot = eye_128[get_element_id[atom[5].lower()]]
        name_onehots = [eye_64[ord(ch) - 32] for ch in f"{atom[-1]:<4}"]
        rows.append(
            np.concatenate([numeric, element_onehot, *name_onehots], axis=-1)
        )
    return np.stack(rows, axis=0)
599
+
600
+
601
# Precomputed reference feature matrices for every known (non-UNK)
# standard CCD code.
standard_ccd_to_ref_feat = {
    code: _get_ref_feat_from_ccd_data(code, standard_ccd_to_reference_features_table)
    for code in standard_ccds
    if not is_unk(code)
}
PhysDock/data/tools/templates.py ADDED
@@ -0,0 +1,1357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Functions for getting templates and calculating template features."""
17
+ import abc
18
+ import dataclasses
19
+ import datetime
20
+ import functools
21
+ import glob
22
+ import json
23
+ import logging
24
+ import os
25
+ import re
26
+ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
27
+ import numpy as np
28
+
29
+ from . import parsers, mmcif_parsing
30
+ from . import kalign
31
+ from .utils import to_date
32
+
33
+
34
@dataclass
class residue_constants:
    """Minimal subset of AlphaFold-style residue constants used by this module.

    All attributes are class-level lookup tables; the class is used as a
    namespace and never instantiated.
    """

    # The 20 standard amino acids (one-letter codes) in canonical order.
    restypes = ["A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
    restype_order = {restype: i for i, restype in enumerate(restypes)}
    # Canonical residues plus the unknown residue "X" (index 20).
    restypes_with_x = restypes + ["X"]
    restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
    # Canonical residues plus unknown "X" and the gap "-" (22 entries).
    # FIX: this attribute was missing although `empty_template_feats` below
    # dereferences it, which raised AttributeError at call time.
    restypes_with_x_and_gap = restypes_with_x + ["-"]
    # atom37 representation: the fixed list of 37 heavy-atom names.
    atom_type_num = 37
    atom_types = ["N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1", "SG", "CD", "CD1", "CD2", "ND1", "ND2",
                  "OD1",
                  "OD2", "SD", "CE", "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1", "NH2", "OH",
                  "CZ",
                  "CZ2", "CZ3", "NZ", "OXT", ]
    atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}

    # HHblits amino-acid alphabet -> integer id (ambiguity codes collapsed:
    # B->D, Z->E, U->C; J/O/X map to unknown (20); "-" is the gap id 21).
    HHBLITS_AA_TO_ID = {
        "A": 0, "B": 2, "C": 1, "D": 2, "E": 3, "F": 4, "G": 5, "H": 6, "I": 7, "J": 20, "K": 8, "L": 9, "M": 10,
        "N": 11, "O": 20, "P": 12, "Q": 13, "R": 14, "S": 15, "T": 16, "U": 1, "V": 17, "W": 18, "X": 20, "Y": 19,
        "Z": 3, "-": 21,
    }

    @staticmethod
    def sequence_to_onehot(
        sequence: str, mapping: Mapping[str, int], map_unknown_to_x: bool = False
    ) -> np.ndarray:
        """Maps the given sequence into a one-hot encoded matrix.

        Args:
            sequence: An amino acid sequence.
            mapping: A dictionary mapping amino acids to integers.
            map_unknown_to_x: If True, any amino acid that is not in the mapping will be
                mapped to the unknown amino acid 'X'. If the mapping doesn't contain
                amino acid 'X', an error will be thrown. If False, any amino acid not in
                the mapping will throw an error.

        Returns:
            A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
            the sequence.

        Raises:
            ValueError: If the mapping doesn't contain values from 0 to
                num_unique_aas - 1 without any gaps, or (when map_unknown_to_x
                is set) if a character is not an upper-case letter.
        """
        num_entries = max(mapping.values()) + 1

        # The mapping's ids must be a dense 0..N-1 range so they can be used
        # directly as column indices.
        if sorted(set(mapping.values())) != list(range(num_entries)):
            raise ValueError(
                "The mapping must have values from 0 to num_unique_aas-1 "
                "without any gaps. Got: %s" % sorted(mapping.values())
            )

        one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)

        for aa_index, aa_type in enumerate(sequence):
            if map_unknown_to_x:
                if aa_type.isalpha() and aa_type.isupper():
                    aa_id = mapping.get(aa_type, mapping["X"])
                else:
                    raise ValueError(f"Invalid character in the sequence: {aa_type}")
            else:
                aa_id = mapping[aa_type]
            one_hot_arr[aa_index, aa_id] = 1

        return one_hot_arr
97
+
98
+
99
# Exception hierarchy for template processing. All hard errors derive from
# `Error`; recoverable pre-filter rejections derive from `PrefilterError`.
class Error(Exception):
    """Base class for exceptions."""


class NoChainsError(Error):
    """An error indicating that template mmCIF didn't have any chains."""


class SequenceNotInTemplateError(Error):
    """An error indicating that template mmCIF didn't contain the sequence."""


class NoAtomDataInTemplateError(Error):
    """An error indicating that template mmCIF didn't contain atom positions."""


class TemplateAtomMaskAllZerosError(Error):
    """An error indicating that template mmCIF had all atom positions masked."""


class QueryToTemplateAlignError(Error):
    """An error indicating that the query can't be aligned to the template."""


class CaDistanceError(Error):
    """An error indicating that a CA atom distance exceeds a threshold."""


# Prefilter exceptions.
class PrefilterError(Exception):
    """A base class for template prefilter exceptions."""


class DateError(PrefilterError):
    """An error indicating that the hit date was after the max allowed date."""


class AlignRatioError(PrefilterError):
    """An error indicating that the hit align ratio to the query was too small."""


class DuplicateError(PrefilterError):
    """An error indicating that the hit was an exact subsequence of the query."""


class LengthError(PrefilterError):
    """An error indicating that the hit was too short."""
146
+
147
+
148
# Names and storage dtypes of the per-hit template features assembled by
# this module; the `object`-dtype entries hold encoded byte strings.
TEMPLATE_FEATURES = {
    "template_aatype": np.int64,
    "template_all_atom_masks": np.float32,
    "template_all_atom_positions": np.float32,
    "template_domain_names": object,
    "template_sequence": object,
    "template_sum_probs": np.float32,
}
156
+
157
+
158
def empty_template_feats(n_res):
    """Return a template feature dict containing zero templates.

    The array features get a leading template dimension of size 0; the
    object-dtype name/sequence features hold one empty placeholder entry.

    Args:
        n_res: Number of residues in the query sequence.

    Returns:
        Dict of empty numpy template features (keys as in TEMPLATE_FEATURES).
    """
    # FIX: this file's `residue_constants` does not define
    # `restypes_with_x_and_gap`, so the original lookup raised
    # AttributeError. The aatype one-hot width is 22 = 20 aa + "X" + gap.
    n_restype = len(residue_constants.restypes_with_x) + 1
    return {
        "template_aatype": np.zeros(
            (0, n_res, n_restype),
            np.float32
        ),
        "template_all_atom_masks": np.zeros(
            (0, n_res, residue_constants.atom_type_num), np.float32
        ),
        "template_all_atom_positions": np.zeros(
            (0, n_res, residue_constants.atom_type_num, 3), np.float32
        ),
        # FIX: `np.object` was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `object` is the supported spelling.
        "template_domain_names": np.array([''.encode()], dtype=object),
        "template_sequence": np.array([''.encode()], dtype=object),
        "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
    }
174
+
175
+
176
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
    """Returns PDB id and chain id for an HHSearch Hit."""
    # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
    match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
    if match is None:
        raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
    pdb_id, chain_id = match.group(0).split("_")
    return pdb_id.lower(), chain_id
184
+
185
+
186
+ def _is_after_cutoff(
187
+ pdb_id: str,
188
+ release_dates: Mapping[str, datetime.datetime],
189
+ release_date_cutoff: Optional[datetime.datetime],
190
+ ) -> bool:
191
+ """Checks if the template date is after the release date cutoff.
192
+
193
+ Args:
194
+ pdb_id: 4 letter pdb code.
195
+ release_dates: Dictionary mapping PDB ids to their structure release dates.
196
+ release_date_cutoff: Max release date that is valid for this query.
197
+
198
+ Returns:
199
+ True if the template release date is after the cutoff, False otherwise.
200
+ """
201
+ pdb_id_upper = pdb_id.upper()
202
+ if release_date_cutoff is None:
203
+ raise ValueError("The release_date_cutoff must not be None.")
204
+ if pdb_id_upper in release_dates:
205
+ return release_dates[pdb_id_upper] > release_date_cutoff
206
+ else:
207
+ # Since this is just a quick prefilter to reduce the number of mmCIF files
208
+ # we need to parse, we don't have to worry about returning True here.
209
+ logging.info(
210
+ "Template structure not in release dates dict: %s", pdb_id
211
+ )
212
+ return False
213
+
214
+
215
+ def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
216
+ """Generates a new obsolete by tracing all cross-references and store the latest leaf to all referencing nodes"""
217
+ obsolete_new = {}
218
+ obsolete_keys = obsolete_mapping.keys()
219
+
220
+ def _new_target(k):
221
+ v = obsolete_mapping[k]
222
+ if v in obsolete_keys:
223
+ return _new_target(v)
224
+ return v
225
+
226
+ for k in obsolete_keys:
227
+ obsolete_new[k] = _new_target(k)
228
+
229
+ return obsolete_new
230
+
231
+
232
def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
    """Parses the data file from PDB that lists which PDB ids are obsolete."""
    mapping = {}
    with open(obsolete_file_path) as f:
        for raw_line in f:
            line = raw_line.strip()
            # Skip obsolete entries that don't contain a mapping to a new
            # entry (those lines are too short to hold a replacement id).
            if not line.startswith("OBSLTE") or len(line) <= 30:
                continue
            # Fixed-column format: Date From To
            # 'OBSLTE    31-JUL-94 116L     216L'
            from_id = line[20:24].lower()
            to_id = line[29:33].lower()
            mapping[from_id] = to_id
    return _replace_obsolete_references(mapping)
246
+
247
+
248
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Scan a directory of mmCIF files and cache their release dates as JSON.

    Args:
        mmcif_dir: Directory containing *.cif files.
        out_path: Path to which the {pdb_id: release_date} JSON mapping is
            written.
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if f.endswith(".cif"):
            path = os.path.join(mmcif_dir, f)
            with open(path, "r") as fp:
                mmcif_string = fp.read()

            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_string
            )
            if mmcif.mmcif_object is None:
                logging.info(f"Failed to parse {f}. Skipping...")
                continue

            mmcif = mmcif.mmcif_object
            release_date = mmcif.header["release_date"]

            dates[file_id] = release_date

    # FIX: the output file was opened with mode "r" (read-only), so the
    # subsequent write raised io.UnsupportedOperation; it must be "w".
    with open(out_path, "w") as fp:
        fp.write(json.dumps(dates))
272
+
273
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses release dates file, returns a mapping from PDBs to release dates."""
    with open(path, "r") as fp:
        data = json.load(fp)

    result = {}
    # Each value is expected to be a dict carrying a "release_date" field.
    for pdb, entry in data.items():
        for field, value in entry.items():
            if field == "release_date":
                result[pdb.upper()] = to_date(value)
    return result
284
+
285
+
286
def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
    hit_pdb_code: str,
    query_sequence: str,
    release_dates: Mapping[str, datetime.datetime],
    release_date_cutoff: datetime.datetime,
    max_subsequence_ratio: float = 0.95,
    min_align_ratio: float = 0.1,
) -> bool:
    """Determines if template is valid (without parsing the template mmcif file).

    Args:
      hit: HhrHit for the template.
      hit_pdb_code: The 4 letter pdb code of the template hit. This might be
        different from the value in the actual hit since the original pdb might
        have become obsolete.
      query_sequence: Amino acid sequence of the query.
      release_dates: Dictionary mapping pdb codes to their structure release
        dates.
      release_date_cutoff: Max release date that is valid for this query.
      max_subsequence_ratio: Exclude any exact matches with this much overlap.
      min_align_ratio: Minimum overlap between the template and query.

    Returns:
      True if the hit passed the prefilter. Raises an exception otherwise.

    Raises:
      DateError: If the hit date was after the max allowed date.
      AlignRatioError: If the hit align ratio to the query was too small.
      DuplicateError: If the hit was an exact subsequence of the query.
      LengthError: If the hit was too short.
    """
    aligned_cols = hit.aligned_cols
    align_ratio = aligned_cols / len(query_sequence)

    # Length ratio is computed on the hit sequence with gaps stripped.
    template_sequence = hit.hit_sequence.replace("-", "")
    length_ratio = float(len(template_sequence)) / len(query_sequence)

    # Reject templates released after the cutoff (prevents train/test leakage).
    if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
        date = release_dates[hit_pdb_code.upper()]
        raise DateError(
            f"Date ({date}) > max template date "
            f"({release_date_cutoff})."
        )

    if align_ratio <= min_align_ratio:
        raise AlignRatioError(
            "Proportion of residues aligned to query too small. "
            f"Align ratio: {align_ratio}."
        )

    # Check whether the template is a large subsequence or duplicate of original
    # query. This can happen due to duplicate entries in the PDB database.
    duplicate = (
        template_sequence in query_sequence
        and length_ratio > max_subsequence_ratio
    )

    if duplicate:
        raise DuplicateError(
            "Template is an exact subsequence of query with large "
            f"coverage. Length ratio: {length_ratio}."
        )

    if len(template_sequence) < 10:
        raise LengthError(
            f"Template too short. Length: {len(template_sequence)}."
        )

    return True
356
+
357
+
358
def _find_template_in_pdb(
    template_chain_id: str,
    template_sequence: str,
    mmcif_object: mmcif_parsing.MmcifObject,
) -> Tuple[str, str, int]:
    """Tries to find the template chain in the given pdb file.

    This method tries the three following things in order:
      1. Tries if there is an exact match in both the chain ID and the sequence.
         If yes, the chain sequence is returned. Otherwise:
      2. Tries if there is an exact match only in the sequence.
         If yes, the chain sequence is returned. Otherwise:
      3. Tries if there is a fuzzy match (X = wildcard) in the sequence.
         If yes, the chain sequence is returned.
    If none of these succeed, a SequenceNotInTemplateError is thrown.

    Args:
      template_chain_id: The template chain ID.
      template_sequence: The template chain sequence.
      mmcif_object: The PDB object to search for the template in.

    Returns:
      A tuple with:
      * The chain sequence that was found to match the template in the PDB object.
      * The ID of the chain that is being returned.
      * The offset where the template sequence starts in the chain sequence.

    Raises:
      SequenceNotInTemplateError: If no match is found after the steps described
        above.
    """
    # Try if there is an exact match in both the chain ID and the (sub)sequence.
    pdb_id = mmcif_object.file_id
    chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
    if chain_sequence and (template_sequence in chain_sequence):
        logging.info(
            "Found an exact template match %s_%s.", pdb_id, template_chain_id
        )
        mapping_offset = chain_sequence.find(template_sequence)
        return chain_sequence, template_chain_id, mapping_offset

    # Try if there is an exact match in the (sub)sequence only.
    for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
        if chain_sequence and (template_sequence in chain_sequence):
            logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
            mapping_offset = chain_sequence.find(template_sequence)
            return chain_sequence, chain_id, mapping_offset

    # Return a chain sequence that fuzzy matches (X = wildcard) the template.
    # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
    regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
    regex = re.compile("".join(regex))
    for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
        match = re.search(regex, chain_sequence)
        if match:
            logging.info(
                "Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
            )
            mapping_offset = match.start()
            return chain_sequence, chain_id, mapping_offset

    # No hits, raise an error.
    raise SequenceNotInTemplateError(
        "Could not find the template sequence in %s_%s. Template sequence: %s, "
        "chain_to_seqres: %s"
        % (
            pdb_id,
            template_chain_id,
            template_sequence,
            mmcif_object.chain_to_seqres,
        )
    )
430
+
431
+
432
+ def _realign_pdb_template_to_query(
433
+ old_template_sequence: str,
434
+ template_chain_id: str,
435
+ mmcif_object: mmcif_parsing.MmcifObject,
436
+ old_mapping: Mapping[int, int],
437
+ kalign_binary_path: str,
438
+ ) -> Tuple[str, Mapping[int, int]]:
439
+ """Aligns template from the mmcif_object to the query.
440
+
441
+ In case PDB70 contains a different version of the template sequence, we need
442
+ to perform a realignment to the actual sequence that is in the mmCIF file.
443
+ This method performs such realignment, but returns the new sequence and
444
+ mapping only if the sequence in the mmCIF file is 90% identical to the old
445
+ sequence.
446
+
447
+ Note that the old_template_sequence comes from the hit, and contains only that
448
+ part of the chain that matches with the query while the new_template_sequence
449
+ is the full chain.
450
+
451
+ Args:
452
+ old_template_sequence: The template sequence that was returned by the PDB
453
+ template search (typically done using HHSearch).
454
+ template_chain_id: The template chain id was returned by the PDB template
455
+ search (typically done using HHSearch). This is used to find the right
456
+ chain in the mmcif_object chain_to_seqres mapping.
457
+ mmcif_object: A mmcif_object which holds the actual template data.
458
+ old_mapping: A mapping from the query sequence to the template sequence.
459
+ This mapping will be used to compute the new mapping from the query
460
+ sequence to the actual mmcif_object template sequence by aligning the
461
+ old_template_sequence and the actual template sequence.
462
+ kalign_binary_path: The path to a kalign executable.
463
+
464
+ Returns:
465
+ A tuple (new_template_sequence, new_query_to_template_mapping) where:
466
+ * new_template_sequence is the actual template sequence that was found in
467
+ the mmcif_object.
468
+ * new_query_to_template_mapping is the new mapping from the query to the
469
+ actual template found in the mmcif_object.
470
+
471
+ Raises:
472
+ QueryToTemplateAlignError:
473
+ * If there was an error thrown by the alignment tool.
474
+ * Or if the actual template sequence differs by more than 10% from the
475
+ old_template_sequence.
476
+ """
477
+ aligner = kalign.Kalign(binary_path=kalign_binary_path)
478
+ new_template_sequence = mmcif_object.chain_to_seqres.get(
479
+ template_chain_id, ""
480
+ )
481
+
482
+ # Sometimes the template chain id is unknown. But if there is only a single
483
+ # sequence within the mmcif_object, it is safe to assume it is that one.
484
+ if not new_template_sequence:
485
+ if len(mmcif_object.chain_to_seqres) == 1:
486
+ logging.info(
487
+ "Could not find %s in %s, but there is only 1 sequence, so "
488
+ "using that one.",
489
+ template_chain_id,
490
+ mmcif_object.file_id,
491
+ )
492
+ new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
493
+ 0
494
+ ]
495
+ else:
496
+ raise QueryToTemplateAlignError(
497
+ f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
498
+ "If there are no mmCIF parsing errors, it is possible it was not a "
499
+ "protein chain."
500
+ )
501
+
502
+ try:
503
+ parsed_a3m = parsers.parse_a3m(
504
+ aligner.align([old_template_sequence, new_template_sequence])
505
+ )
506
+ old_aligned_template, new_aligned_template = parsed_a3m.sequences
507
+ except Exception as e:
508
+ raise QueryToTemplateAlignError(
509
+ "Could not align old template %s to template %s (%s_%s). Error: %s"
510
+ % (
511
+ old_template_sequence,
512
+ new_template_sequence,
513
+ mmcif_object.file_id,
514
+ template_chain_id,
515
+ str(e),
516
+ )
517
+ )
518
+
519
+ logging.info(
520
+ "Old aligned template: %s\nNew aligned template: %s",
521
+ old_aligned_template,
522
+ new_aligned_template,
523
+ )
524
+
525
+ old_to_new_template_mapping = {}
526
+ old_template_index = -1
527
+ new_template_index = -1
528
+ num_same = 0
529
+ for old_template_aa, new_template_aa in zip(
530
+ old_aligned_template, new_aligned_template
531
+ ):
532
+ if old_template_aa != "-":
533
+ old_template_index += 1
534
+ if new_template_aa != "-":
535
+ new_template_index += 1
536
+ if old_template_aa != "-" and new_template_aa != "-":
537
+ old_to_new_template_mapping[old_template_index] = new_template_index
538
+ if old_template_aa == new_template_aa:
539
+ num_same += 1
540
+
541
+ # Require at least 90 % sequence identity wrt to the shorter of the sequences.
542
+ if (
543
+ float(num_same)
544
+ / min(len(old_template_sequence), len(new_template_sequence))
545
+ < 0.9
546
+ ):
547
+ raise QueryToTemplateAlignError(
548
+ "Insufficient similarity of the sequence in the database: %s to the "
549
+ "actual sequence in the mmCIF file %s_%s: %s. We require at least "
550
+ "90 %% similarity wrt to the shorter of the sequences. This is not a "
551
+ "problem unless you think this is a template that should be included."
552
+ % (
553
+ old_template_sequence,
554
+ mmcif_object.file_id,
555
+ template_chain_id,
556
+ new_template_sequence,
557
+ )
558
+ )
559
+
560
+ new_query_to_template_mapping = {}
561
+ for query_index, old_template_index in old_mapping.items():
562
+ new_query_to_template_mapping[
563
+ query_index
564
+ ] = old_to_new_template_mapping.get(old_template_index, -1)
565
+
566
+ new_template_sequence = new_template_sequence.replace("-", "")
567
+
568
+ return new_template_sequence, new_query_to_template_mapping
569
+
570
+
571
def _check_residue_distances(
    all_positions: np.ndarray,
    all_positions_mask: np.ndarray,
    max_ca_ca_distance: float,
):
    """Raises CaDistanceError if consecutive unmasked CA atoms are too far apart.

    Args:
        all_positions: Per-residue atom coordinates, indexed by atom type.
        all_positions_mask: Per-residue atom presence mask, same indexing.
        max_ca_ca_distance: Maximum allowed CA-CA distance between a residue
            and the immediately preceding unmasked residue.
    """
    ca_idx = residue_constants.atom_order["CA"]
    last_ca = None  # CA coords of the previous residue, valid only if it was unmasked
    prev_unmasked = False
    for res_index, (coords, mask) in enumerate(
        zip(all_positions, all_positions_mask)
    ):
        unmasked = bool(mask[ca_idx])
        if unmasked:
            ca = coords[ca_idx]
            # Only check the distance when the directly preceding residue also
            # had an unmasked CA; gaps in the structure suppress the check.
            if prev_unmasked:
                dist = np.linalg.norm(ca - last_ca)
                if dist > max_ca_ca_distance:
                    raise CaDistanceError(
                        "The distance between residues %d and %d is %f > limit %f."
                        % (res_index, res_index + 1, dist, max_ca_ca_distance)
                    )
            last_ca = ca
        prev_unmasked = unmasked
593
+
594
+
595
def _get_atom_positions(
    mmcif_object: mmcif_parsing.MmcifObject,
    auth_chain_id: str,
    max_ca_ca_distance: float,
    _zero_center_positions: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Gets atom positions and mask for one chain of a parsed mmCIF object.

    Args:
        mmcif_object: Parsed mmCIF structure.
        auth_chain_id: Author chain id to extract.
        max_ca_ca_distance: CA-CA distance limit for the sanity check.
        _zero_center_positions: Forwarded to the coordinate extraction.

    Returns:
        A (positions, mask) tuple of numpy arrays.

    Raises:
        CaDistanceError: If consecutive residues are unreasonably far apart.
    """
    positions, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object,
        chain_id=auth_chain_id,
        _zero_center_positions=_zero_center_positions,
    )
    # Reject chains with broken backbone geometry before featurization.
    _check_residue_distances(positions, mask, max_ca_ca_distance)
    return positions, mask
612
+
613
+
614
def _extract_template_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    pdb_id: str,
    mapping: Mapping[int, int],
    template_sequence: str,
    query_sequence: str,
    template_chain_id: str,
    kalign_binary_path: str,
    _zero_center_positions: bool = True,
) -> Tuple[Dict[str, Any], Optional[str]]:
    """Parses atom positions in the target structure and aligns with the query.

    Atoms for each residue in the template structure are indexed to coincide
    with their corresponding residue in the query sequence, according to the
    alignment mapping provided.

    Args:
      mmcif_object: mmcif_parsing.MmcifObject representing the template.
      pdb_id: PDB code for the template.
      mapping: Dictionary mapping indices in the query sequence to indices in
        the template sequence.
      template_sequence: String describing the amino acid sequence for the
        template protein.
      query_sequence: String describing the amino acid sequence for the query
        protein.
      template_chain_id: String ID describing which chain in the structure proto
        should be used.
      kalign_binary_path: The path to a kalign executable used for template
        realignment.
      _zero_center_positions: Whether to zero-center the extracted coordinates.

    Returns:
      A tuple with:
      * A dictionary containing the extra features derived from the template
        protein structure.
      * A warning message if the hit was realigned to the actual mmCIF sequence.
        Otherwise None.

    Raises:
      NoChainsError: If the mmcif object doesn't contain any chains.
      SequenceNotInTemplateError: If the given chain id / sequence can't
        be found in the mmcif object.
      QueryToTemplateAlignError: If the actual template in the mmCIF file
        can't be aligned to the query.
      NoAtomDataInTemplateError: If the mmcif object doesn't contain
        atom positions.
      TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any
        unmasked residues.
    """
    # A None mmcif_object (failed parse) is treated the same as a chain-less one.
    if mmcif_object is None or not mmcif_object.chain_to_seqres:
        raise NoChainsError(
            "No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
        )

    warning = None
    try:
        seqres, chain_id, mapping_offset = _find_template_in_pdb(
            template_chain_id=template_chain_id,
            template_sequence=template_sequence,
            mmcif_object=mmcif_object,
        )
    except SequenceNotInTemplateError:
        # If PDB70 contains a different version of the template, we use the sequence
        # from the mmcif_object.
        chain_id = template_chain_id
        warning = (
            f"The exact sequence {template_sequence} was not found in "
            f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
        )
        logging.warning(warning)
        # This throws an exception if it fails to realign the hit.
        seqres, mapping = _realign_pdb_template_to_query(
            old_template_sequence=template_sequence,
            template_chain_id=template_chain_id,
            mmcif_object=mmcif_object,
            old_mapping=mapping,
            kalign_binary_path=kalign_binary_path,
        )
        logging.info(
            "Sequence in %s_%s: %s successfully realigned to %s",
            pdb_id,
            chain_id,
            template_sequence,
            seqres,
        )
        # The template sequence changed.
        template_sequence = seqres
        # No mapping offset, the query is aligned to the actual sequence.
        mapping_offset = 0

    try:
        # Essentially set to infinity - we don't want to reject templates unless
        # they're really really bad.
        all_atom_positions, all_atom_mask = _get_atom_positions(
            mmcif_object,
            chain_id,
            max_ca_ca_distance=150.0,
            _zero_center_positions=_zero_center_positions,
        )
    except (CaDistanceError, KeyError) as ex:
        raise NoAtomDataInTemplateError(
            "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
        ) from ex

    # Split into per-residue arrays so individual residues can be re-indexed
    # onto the query below.
    all_atom_positions = np.split(
        all_atom_positions, all_atom_positions.shape[0]
    )
    all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])

    output_templates_sequence = []
    templates_all_atom_positions = []
    templates_all_atom_masks = []

    # Initialize every query position as a gap with zeroed coordinates/mask.
    for _ in query_sequence:
        # Residues in the query_sequence that are not in the template_sequence:
        templates_all_atom_positions.append(
            np.zeros((residue_constants.atom_type_num, 3))
        )
        templates_all_atom_masks.append(
            np.zeros(residue_constants.atom_type_num)
        )
        output_templates_sequence.append("-")

    # Copy aligned template residues into their query-indexed slots.
    for k, v in mapping.items():
        template_index = v + mapping_offset
        templates_all_atom_positions[k] = all_atom_positions[template_index][0]
        templates_all_atom_masks[k] = all_atom_masks[template_index][0]
        output_templates_sequence[k] = template_sequence[v]

    # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
    if np.sum(templates_all_atom_masks) < 5:
        raise TemplateAtomMaskAllZerosError(
            "Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
            % (
                pdb_id,
                chain_id,
                min(mapping.values()) + mapping_offset,
                max(mapping.values()) + mapping_offset,
            )
        )

    output_templates_sequence = "".join(output_templates_sequence)

    templates_aatype = residue_constants.sequence_to_onehot(
        output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
    )

    return (
        {
            "template_all_atom_positions": np.array(
                templates_all_atom_positions
            ),
            "template_all_atom_masks": np.array(templates_all_atom_masks),
            "template_sequence": output_templates_sequence.encode(),
            "template_aatype": np.array(templates_aatype),
            "template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
        },
        warning,
    )
772
+
773
+
774
+ def _build_query_to_hit_index_mapping(
775
+ hit_query_sequence: str,
776
+ hit_sequence: str,
777
+ indices_hit: Sequence[int],
778
+ indices_query: Sequence[int],
779
+ original_query_sequence: str,
780
+ ) -> Mapping[int, int]:
781
+ """Gets mapping from indices in original query sequence to indices in the hit.
782
+
783
+ hit_query_sequence and hit_sequence are two aligned sequences containing gap
784
+ characters. hit_query_sequence contains only the part of the original query
785
+ sequence that matched the hit. When interpreting the indices from the .hhr, we
786
+ need to correct for this to recover a mapping from original query sequence to
787
+ the hit sequence.
788
+
789
+ Args:
790
+ hit_query_sequence: The portion of the query sequence that is in the .hhr
791
+ hit
792
+ hit_sequence: The portion of the hit sequence that is in the .hhr
793
+ indices_hit: The indices for each aminoacid relative to the hit sequence
794
+ indices_query: The indices for each aminoacid relative to the original query
795
+ sequence
796
+ original_query_sequence: String describing the original query sequence.
797
+
798
+ Returns:
799
+ Dictionary with indices in the original query sequence as keys and indices
800
+ in the hit sequence as values.
801
+ """
802
+ # If the hit is empty (no aligned residues), return empty mapping
803
+ if not hit_query_sequence:
804
+ return {}
805
+
806
+ # Remove gaps and find the offset of hit.query relative to original query.
807
+ hhsearch_query_sequence = hit_query_sequence.replace("-", "")
808
+ hit_sequence = hit_sequence.replace("-", "")
809
+ hhsearch_query_offset = original_query_sequence.find(
810
+ hhsearch_query_sequence
811
+ )
812
+
813
+ # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
814
+ min_idx = min(x for x in indices_hit if x > -1)
815
+ fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
816
+
817
+ min_idx = min(x for x in indices_query if x > -1)
818
+ fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
819
+
820
+ # Zip the corrected indices, ignore case where both seqs have gap characters.
821
+ mapping = {}
822
+ for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
823
+ if q_t != -1 and q_i != -1:
824
+ if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
825
+ original_query_sequence
826
+ ):
827
+ continue
828
+ mapping[q_i + hhsearch_query_offset] = q_t
829
+
830
+ return mapping
831
+
832
+
833
@dataclasses.dataclass(frozen=True)
class PrefilterResult:
    """Outcome of the cheap prefilter checks applied to one template hit."""

    # True when the hit passed all prefilter checks and may be featurized.
    valid: bool
    # Hard-error message (only set in strict mode), otherwise None.
    error: Optional[str]
    # Non-fatal warning message, otherwise None.
    warning: Optional[str]
838
+
839
+
840
@dataclasses.dataclass(frozen=True)
class SingleHitResult:
    """Result of featurizing one template hit."""

    # Extracted template features, or None when featurization failed/was skipped.
    features: Optional[Mapping[str, Any]]
    # Hard-error message, otherwise None.
    error: Optional[str]
    # Non-fatal warning message, otherwise None.
    warning: Optional[str]
845
+
846
+
847
def _prefilter_hit(
    query_sequence: str,
    hit: parsers.TemplateHit,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    strict_error_check: bool = False,
):
    """Runs the cheap hit-level checks and reports whether the hit is usable.

    Args:
        query_sequence: The query amino-acid sequence.
        hit: The template hit to assess.
        max_template_date: Release-date cutoff for acceptable templates.
        release_dates: PDB id -> release date lookup.
        obsolete_pdbs: Obsolete PDB id -> replacement id lookup.
        strict_error_check: When True, date/duplicate failures become errors.

    Returns:
        A PrefilterResult describing the outcome.
    """
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Resolve obsolete entries to their replacements when we have no release
    # date for the original id.
    if hit_pdb_code not in release_dates and hit_pdb_code in obsolete_pdbs:
        hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    # Pass hit_pdb_code since it might have changed due to the pdb being
    # obsolete.
    try:
        _assess_hhsearch_hit(
            hit=hit,
            hit_pdb_code=hit_pdb_code,
            query_sequence=query_sequence,
            release_dates=release_dates,
            release_date_cutoff=max_template_date,
        )
    except PrefilterError as e:
        msg = f"hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}"
        logging.info(msg)
        # In strict mode we treat some prefilter cases as errors.
        if strict_error_check and isinstance(e, (DateError, DuplicateError)):
            return PrefilterResult(valid=False, error=msg, warning=None)
        return PrefilterResult(valid=False, error=None, warning=None)

    return PrefilterResult(valid=True, error=None, warning=None)
883
+
884
+
885
+ @functools.lru_cache(16, typed=False)
886
+ def _read_file(path):
887
+ with open(path, 'r') as f:
888
+ file_data = f.read()
889
+
890
+ return file_data
891
+
892
+
893
def _process_single_hit(
    query_sequence: str,
    hit: parsers.TemplateHit,
    mmcif_dir: str,
    max_template_date: datetime.datetime,
    release_dates: Mapping[str, datetime.datetime],
    obsolete_pdbs: Mapping[str, str],
    kalign_binary_path: str,
    strict_error_check: bool = False,
    _zero_center_positions: bool = True,
) -> SingleHitResult:
    """Tries to extract template features from a single HHSearch hit.

    Args:
        query_sequence: The query amino-acid sequence.
        hit: The template hit to featurize.
        mmcif_dir: Directory holding `<pdb_id>.cif` template structures.
        max_template_date: Release-date cutoff; newer templates are rejected.
        release_dates: PDB id -> release date lookup.
        obsolete_pdbs: Obsolete PDB id -> replacement id lookup.
        kalign_binary_path: Path to a kalign executable for realignment.
        strict_error_check: When True, date violations and feature-extraction
            failures become errors instead of silent skips/warnings.
        _zero_center_positions: Forwarded to coordinate extraction.

    Returns:
        A SingleHitResult with features on success, or error/warning text.
    """
    # Fail hard if we can't get the PDB ID and chain name from the hit.
    hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)

    # Resolve obsolete PDB ids to their replacements when possible.
    if hit_pdb_code not in release_dates:
        if hit_pdb_code in obsolete_pdbs:
            hit_pdb_code = obsolete_pdbs[hit_pdb_code]

    mapping = _build_query_to_hit_index_mapping(
        hit.query,
        hit.hit_sequence,
        hit.indices_hit,
        hit.indices_query,
        query_sequence,
    )

    # The mapping is from the query to the actual hit sequence, so we need to
    # remove gaps (which regardless have a missing confidence score).
    template_sequence = hit.hit_sequence.replace("-", "")

    cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
    logging.info(
        "Reading PDB entry from %s. Query: %s, template: %s",
        cif_path,
        query_sequence,
        template_sequence,
    )

    # Fail if we can't find the mmCIF file.
    cif_string = _read_file(cif_path)

    parsing_result = mmcif_parsing.parse(
        file_id=hit_pdb_code, mmcif_string=cif_string
    )

    # Enforce the template-date cutoff using the release date from the mmCIF
    # header itself (release_dates may be incomplete).
    if parsing_result.mmcif_object is not None:
        hit_release_date = datetime.datetime.strptime(
            parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
        )
        if hit_release_date > max_template_date:
            error = "Template %s date (%s) > max template date (%s)." % (
                hit_pdb_code,
                hit_release_date,
                max_template_date,
            )
            if strict_error_check:
                return SingleHitResult(features=None, error=error, warning=None)
            else:
                logging.info(error)
                return SingleHitResult(features=None, error=None, warning=None)

    # NOTE(review): if parsing failed, parsing_result.mmcif_object is None and
    # _extract_template_features raises NoChainsError, handled below.
    try:
        features, realign_warning = _extract_template_features(
            mmcif_object=parsing_result.mmcif_object,
            pdb_id=hit_pdb_code,
            mapping=mapping,
            template_sequence=template_sequence,
            query_sequence=query_sequence,
            template_chain_id=hit_chain_id,
            kalign_binary_path=kalign_binary_path,
            _zero_center_positions=_zero_center_positions,
        )

        # A missing sum_probs is recorded as 0 confidence.
        if hit.sum_probs is None:
            features["template_sum_probs"] = [0]
        else:
            features["template_sum_probs"] = [hit.sum_probs]

        # It is possible there were some errors when parsing the other chains in the
        # mmCIF file, but the template features for the chain we want were still
        # computed. In such case the mmCIF parsing errors are not relevant.
        return SingleHitResult(
            features=features, error=None, warning=realign_warning
        )
    except (
        NoChainsError,
        NoAtomDataInTemplateError,
        TemplateAtomMaskAllZerosError,
    ) as e:
        # These 3 errors indicate missing mmCIF experimental data rather than a
        # problem with the template search, so turn them into warnings.
        warning = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs if hit.sum_probs else 0.,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        if strict_error_check:
            return SingleHitResult(features=None, error=warning, warning=None)
        else:
            return SingleHitResult(features=None, error=None, warning=warning)
    except Error as e:
        # Any other pipeline error is always reported as a hard error.
        error = (
            "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
            "%s, mmCIF parsing errors: %s"
            % (
                hit_pdb_code,
                hit_chain_id,
                hit.sum_probs if hit.sum_probs else 0.,
                hit.index,
                str(e),
                parsing_result.errors,
            )
        )
        return SingleHitResult(features=None, error=error, warning=None)
1015
+
1016
+
1017
def get_custom_template_features(
    mmcif_path: str,
    query_sequence: str,
    pdb_id: str,
    chain_id: str,
    kalign_binary_path: str):
    """Builds template features from a user-supplied mmCIF structure.

    The template is assumed to cover the full query (identity residue mapping)
    and is given a sum_probs confidence of 1.0.

    Args:
        mmcif_path: Path to the template mmCIF file.
        query_sequence: The query amino-acid sequence.
        pdb_id: Identifier to record for the template.
        chain_id: Chain in the mmCIF file to use as the template.
        kalign_binary_path: Path to a kalign executable for realignment.

    Returns:
        A TemplateSearchResult holding the stacked template features.
    """
    # FIX: don't shadow the `mmcif_path` argument with the open file handle.
    with open(mmcif_path, "r") as mmcif_file:
        cif_string = mmcif_file.read()

    mmcif_parse_result = mmcif_parsing.parse(
        file_id=pdb_id, mmcif_string=cif_string
    )
    template_sequence = mmcif_parse_result.mmcif_object.chain_to_seqres[chain_id]

    # Identity mapping: query position i maps to template position i.
    mapping = {x: x for x, _ in enumerate(query_sequence)}

    features, warnings = _extract_template_features(
        mmcif_object=mmcif_parse_result.mmcif_object,
        pdb_id=pdb_id,
        mapping=mapping,
        template_sequence=template_sequence,
        query_sequence=query_sequence,
        template_chain_id=chain_id,
        kalign_binary_path=kalign_binary_path,
        _zero_center_positions=True
    )
    features["template_sum_probs"] = [1.0]

    # Stack each feature into a leading num_templates axis of size 1, cast to
    # its canonical dtype (replaces the previous multi-loop construction).
    template_features = {
        name: np.stack([features[name]], axis=0).astype(dtype)
        for name, dtype in TEMPLATE_FEATURES.items()
    }

    return TemplateSearchResult(
        features=template_features, errors=None, warnings=warnings
    )
1061
+
1062
+
1063
@dataclasses.dataclass(frozen=True)
class TemplateSearchResult:
    """Aggregated outcome of a template search over all hits."""

    # Stacked per-template feature arrays keyed by feature name.
    features: Mapping[str, Any]
    # Hard-error messages accumulated during the search.
    errors: Sequence[str]
    # Non-fatal warning messages accumulated during the search.
    warnings: Sequence[str]
1068
+
1069
+
1070
class TemplateHitFeaturizer(abc.ABC):
    """An abstract base class for turning template hits to features."""

    def __init__(
        self,
        mmcif_dir: str,
        max_template_date: str,
        max_hits: int,
        kalign_binary_path: str,
        release_dates_path: Optional[str] = None,
        obsolete_pdbs_path: Optional[str] = None,
        strict_error_check: bool = False,
        _shuffle_top_k_prefiltered: Optional[int] = None,
        _zero_center_positions: bool = True,
    ):
        """Initializes the Template Search.

        Args:
          mmcif_dir: Path to a directory with mmCIF structures. Once a template ID
            is found by HHSearch, this directory is used to retrieve the template
            data.
          max_template_date: The maximum date permitted for template structures. No
            template with date higher than this date will be returned. In ISO8601
            date format, YYYY-MM-DD.
          max_hits: The maximum number of templates that will be returned.
          kalign_binary_path: The path to a kalign executable used for template
            realignment.
          release_dates_path: An optional path to a file with a mapping from PDB IDs
            to their release dates. Thanks to this we don't have to redundantly
            parse mmCIF files to get that information.
          obsolete_pdbs_path: An optional path to a file containing a mapping from
            obsolete PDB IDs to the PDB IDs of their replacements.
          strict_error_check: If True, then the following will be treated as errors:
            * If any template date is after the max_template_date.
            * If any template has identical PDB ID to the query.
            * If any template is a duplicate of the query.
            * Any feature computation errors.
          _shuffle_top_k_prefiltered: If set, the top-k prefiltered hits are
            shuffled before featurization (training-time augmentation).
          _zero_center_positions: Whether template coordinates are zero-centered.
        """
        self._mmcif_dir = mmcif_dir
        # Fail early if the template structure directory is unusable.
        if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
            logging.error("Could not find CIFs in %s", self._mmcif_dir)
            raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")

        try:
            self._max_template_date = datetime.datetime.strptime(
                max_template_date, "%Y-%m-%d"
            )
        except ValueError:
            raise ValueError(
                "max_template_date must be set and have format YYYY-MM-DD."
            )

        self._max_hits = max_hits
        self._kalign_binary_path = kalign_binary_path
        self._strict_error_check = strict_error_check

        # Optional precomputed lookups that avoid re-parsing mmCIF headers.
        if release_dates_path:
            logging.info("Using precomputed release dates %s.", release_dates_path)
            self._release_dates = _parse_release_dates(release_dates_path)
        else:
            self._release_dates = {}

        if obsolete_pdbs_path:
            logging.info("Using precomputed obsolete pdbs %s.", obsolete_pdbs_path)
            self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
        else:
            self._obsolete_pdbs = {}

        self._shuffle_top_k_prefiltered = _shuffle_top_k_prefiltered
        self._zero_center_positions = _zero_center_positions

    @abc.abstractmethod
    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit]
    ) -> TemplateSearchResult:
        """Computes the templates for a given query sequence."""
1151
+
1152
+
1153
class HhsearchHitFeaturizer(TemplateHitFeaturizer):
    """Featurizes template hits produced by HHSearch."""

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit],
    ) -> TemplateSearchResult:
        """Computes the templates for given query sequence (more details above)."""
        logging.info("Searching for template for: %s", query_sequence)

        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        already_seen = set()
        errors = []
        warnings = []

        # Cheap prefilter pass before any expensive mmCIF parsing.
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                hit=hit,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )

            if prefilter_result.error:
                errors.append(prefilter_result.error)

            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)

            if prefilter_result.valid:
                filtered.append(hit)

        # FIX: guard against hits with a missing sum_probs; the bare
        # `x.sum_probs` key raised TypeError on None (HmmsearchHitFeaturizer
        # already guards this way).
        filtered = list(
            sorted(
                filtered,
                key=lambda x: x.sum_probs if x.sum_probs else 0.,
                reverse=True,
            )
        )

        idx = list(range(len(filtered)))
        # Optionally shuffle the top-k hits (training-time augmentation).
        if (self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        for i in idx:
            # We got all the templates we wanted, stop processing hits.
            if len(already_seen) >= self._max_hits:
                break
            try:
                hit = filtered[i]

                result = _process_single_hit(
                    query_sequence=query_sequence,
                    hit=hit,
                    mmcif_dir=self._mmcif_dir,
                    max_template_date=self._max_template_date,
                    release_dates=self._release_dates,
                    obsolete_pdbs=self._obsolete_pdbs,
                    strict_error_check=self._strict_error_check,
                    kalign_binary_path=self._kalign_binary_path,
                    _zero_center_positions=self._zero_center_positions,
                )

                if result.error:
                    errors.append(result.error)

                # There could be an error even if there are some results, e.g. thrown by
                # other unparsable chains in the same mmCIF file.
                if result.warning:
                    warnings.append(result.warning)

                if result.features is None:
                    logging.info(
                        "Skipped invalid hit %s, error: %s, warning: %s",
                        hit.name,
                        result.error,
                        result.warning,
                    )
                else:
                    # Deduplicate templates with identical sequences.
                    already_seen_key = result.features["template_sequence"]
                    if (already_seen_key in already_seen):
                        continue
                    already_seen.add(already_seen_key)
                    for k in template_features:
                        template_features[k].append(result.features[k])
            except Exception as e:
                # FIX: route unexpected per-hit failures through logging
                # instead of print(), so they land in the pipeline logs.
                logging.warning("Failed to featurize template hit: %s", e)
                continue

        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            template_features = empty_template_feats(num_res)

        return TemplateSearchResult(
            features=template_features, errors=errors, warnings=warnings
        )
1257
+
1258
+
1259
class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
    """Featurizes template hits produced by hmmsearch."""

    def get_templates(
        self,
        query_sequence: str,
        hits: Sequence[parsers.TemplateHit]
    ) -> TemplateSearchResult:
        """Computes the templates for the given query sequence."""
        logging.info("Searching for template for: %s", query_sequence)

        template_features = {}
        for template_feature_name in TEMPLATE_FEATURES:
            template_features[template_feature_name] = []

        already_seen = set()
        errors = []
        warnings = []

        # DISCREPANCY: This filtering scheme that saves time
        filtered = []
        for hit in hits:
            prefilter_result = _prefilter_hit(
                query_sequence=query_sequence,
                hit=hit,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
            )

            if prefilter_result.error:
                errors.append(prefilter_result.error)

            if prefilter_result.warning:
                warnings.append(prefilter_result.warning)

            if prefilter_result.valid:
                filtered.append(hit)

        filtered = list(
            sorted(
                filtered, key=lambda x: x.sum_probs if x.sum_probs else 0., reverse=True
            )
        )
        idx = list(range(len(filtered)))
        # Optionally shuffle the top-k hits (training-time augmentation).
        if (self._shuffle_top_k_prefiltered):
            stk = self._shuffle_top_k_prefiltered
            idx[:stk] = np.random.permutation(idx[:stk])

        for i in idx:
            if (len(already_seen) >= self._max_hits):
                break

            hit = filtered[i]

            result = _process_single_hit(
                query_sequence=query_sequence,
                hit=hit,
                mmcif_dir=self._mmcif_dir,
                max_template_date=self._max_template_date,
                release_dates=self._release_dates,
                obsolete_pdbs=self._obsolete_pdbs,
                strict_error_check=self._strict_error_check,
                kalign_binary_path=self._kalign_binary_path,
                # FIX: forward the configured flag; it was previously dropped,
                # so _process_single_hit always used its default (True)
                # regardless of configuration (HhsearchHitFeaturizer passes it).
                _zero_center_positions=self._zero_center_positions,
            )

            if result.error:
                errors.append(result.error)

            if result.warning:
                warnings.append(result.warning)

            if result.features is None:
                logging.debug(
                    "Skipped invalid hit %s, error: %s, warning: %s",
                    hit.name, result.error, result.warning,
                )
            else:
                # Deduplicate templates with identical sequences.
                already_seen_key = result.features["template_sequence"]
                if (already_seen_key in already_seen):
                    continue
                # Increment the hit counter, since we got features out of this hit.
                already_seen.add(already_seen_key)
                for k in template_features:
                    template_features[k].append(result.features[k])

        if already_seen:
            for name in template_features:
                template_features[name] = np.stack(
                    template_features[name], axis=0
                ).astype(TEMPLATE_FEATURES[name])
        else:
            num_res = len(query_sequence)
            # Construct a default template with all zeros.
            template_features = empty_template_feats(num_res)

        return TemplateSearchResult(
            features=template_features,
            errors=errors,
            warnings=warnings,
        )
PhysDock/data/tools/utils.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Common utilities for data pipeline tools."""
17
+ import contextlib
18
+ import datetime
19
+ import logging
20
+ import shutil
21
+ import tempfile
22
+ import time
23
+ from typing import Optional
24
+
25
+
26
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager yielding a temporary directory deleted on exit.

    Args:
        base_dir: Optional parent directory for the temporary directory.

    Yields:
        The path of the newly created temporary directory.
    """
    path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield path
    finally:
        # Best-effort cleanup; ignore races and permission issues.
        shutil.rmtree(path, ignore_errors=True)
34
+
35
+
36
@contextlib.contextmanager
def timing(msg: str):
    """Context manager logging the wall-clock duration of the enclosed block.

    Note: the "Finished" line is only logged when the block exits normally.
    """
    logging.info("Started %s", msg)
    start = time.perf_counter()
    yield
    logging.info(
        "Finished %s in %.3f seconds", msg, time.perf_counter() - start
    )
43
+
44
+
45
def to_date(s: str):
    """Parses the leading ``YYYY-MM-DD`` portion of *s* into a datetime."""
    year, month, day = int(s[:4]), int(s[5:7]), int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
PhysDock/models/__init__.py ADDED
File without changes
PhysDock/models/layers/__init__.py ADDED
File without changes